當前位置:
首頁 > 知識 > Pytorch 中如何處理 RNN 輸入變長序列 padding

Pytorch 中如何處理 RNN 輸入變長序列 padding


本文作者憶臻,原載於知乎專欄 —— 機器學習亂髮與自然語言處理。

一、為什麼 RNN 需要處理變長輸入

假設我們有情感分析的例子,對每句話進行一個感情級別的分類,主體流程大概是下圖所示:

思路比較簡單,但是當我們進行 batch 個訓練數據一起計算的時候,我們會遇到多個訓練樣例長度不同的情況,這樣我們就會很自然的進行 padding,將短句子 padding 為跟最長的句子一樣。

比如向下圖這樣:

但是這會有一個問題,什麼問題呢?比如上圖,句子 「Yes」 只有一個單詞,但是 padding 了 5 的 pad 符號,這樣會導致 LSTM 對它的表示通過了非常多無用的字元,這樣得到的句子表示就會有誤差,更直觀的如下圖:

那麼我們正確的做法應該是怎麼樣呢?

這就引出 pytorch 中 RNN 需要處理變長輸入的需求了。在上面這個例子,我們想要得到的表示僅僅是 LSTM 過完單詞 "Yes" 之後的表示,而不是通過了多個無用的 「Pad」 得到的表示:如下圖:

二、pytorch 中 RNN 如何處理變長 padding

這裡的 pack,理解成壓緊比較好。 將一個 填充過的變長序列 壓緊。(填充時候,會有冗餘,所以壓緊一下)

輸入的形狀可以是 (T×B×*)。T 是最長序列長度,B 是 batch size,* 代表任意維度 (可以是 0)。如果 batch_first=True 的話,那麼相應的 input size 就是 (B×T×*)。

Variable 中保存的序列,應該按序列長度的長短排序,長的在前,短的在後(特別注意需要進行排序)。即 input[:,0] 代表的是最長的序列,input[:, B-1] 保存的是最短的序列。

參數說明:

input (Variable) – 變長序列 被填充後的 batch

lengths (list[int]) – Variable 中 每個序列的長度。(知道了每個序列的長度,才能知道每個序列處理到多長停止)

batch_first (bool, optional) – 如果是 True,input 的形狀應該是 B*T*size。

返回值:

一個 PackedSequence 對象。一個 PackedSequence 表示如下所示:

具體代碼如下:

此時,返回的 h_last 和 c_last 就是剔除 padding 字元後的 hidden state 和 cell state,都是 Variable 類型的。代表的意思如下(各個句子的表示,lstm 只會作用到它實際長度的句子,而不是通過無用的 padding 字元,下圖用紅色的打鉤來表示):

但是返回的 output 是 PackedSequence 類型的,可以使用:

將 encoderoutputs 在轉換為 Variable 類型,得到的_代表各個句子的長度。


三、總結

參考:

pytorch 對可變長度序列的處理http://suo.im/2c3XOA

pytorch RNN 變長輸入 paddinghttp://suo.im/2xkED8

限時拼團

3 大模塊,30 個課時

高校數學系教授帶班

100% 學員好評

與 100 + 同學一起夯實數學基礎,走穩機器學習入門第一步!

新人福利

關注 AI 研習社(okweiwu),回復1領取

【超過 1000G 神經網路 / AI / 大數據,教程,論文】

PyTorch 合輯


喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 AI研習社 的精彩文章:

第 14 彈:斯坦福Serena Yeung教你深度增強學習

TAG:AI研習社 |