當前位置:
首頁 > 知識 > 手把手教你訓練 RNN

手把手教你訓練 RNN

本文為雷鋒字幕組編譯的技術博客,原標題 Step-by-step walkthrough of RNN Training - Part I,作者為 Eniola Alese。

翻譯 | 趙朋飛 程思婕 整理 | 凡江


RNN 前向傳播逐步演練

單個 RNN Cell 中的前向傳播演算法

在之前的文章中,我們介紹了 RNN 的基本結構並將其按時間序列展開成 Cells 循環鏈,稱為 RNN cells。下面,我們將揭示單個 RNN Cell 的內部結構和前向傳播計算過程。

將其過程分解成多個步驟:

第一步:cell 接受兩個輸入:x?t? 和 a?t-1?。

第二步:接下來,計算矩陣乘積 ?,W_xh 乘 x?t?,W_ah 乘 a?t-1?。然後,通過將以上兩項乘積相加,並加上偏置 b_h,計算得出 h?t?。

第三步:緊接著上面的步驟,將 h(t) 傳給激活函數(比如 tanh 或 relu)計算 a(t)。本例中激活函數使用 tanh 函數。

第四步: cell 輸出 a?t? 並將其傳給下一 cell 做進一步計算。

第五步:然後,計算 o?t?; 這是所有輸出可能取值的非標準化對數概率。方法是計算矩陣乘積?,W_ao 乘 a?t?,並與 b_o 相加。

第六步:最後,通過將 o?t? 傳輸給激活函數(例如 sigmoid 或 softmax),得到了一個實際輸出的標準化概率向量 ??t?。輸出的激活函數的選擇通常取決於期望的輸出類型(sigmoid 用於二元輸出,softmax 用於多類別輸出)。

前向傳播演算法

前向傳播演算法在整個 RNN 網路中運行以上步驟,而不僅僅在單個 RNN cell 中運行。從隱藏層狀態 a?0?的初始化開始,在所有時間序列 t = 1 to T 中共享權值和偏置向量 W_xh,W_ah, W_ao, b_h, b_o,在每個時間序列中重複上面的每一步。

例如,如果我們擁有一個 8 個序列的輸入 x?1?,x?2?,......x?8?,這個網路的前向傳播計算過程是步驟 1-6 在循環中重複 8 次。


RNN 的反向傳播是為了計算出關於損失函數的梯度值

單個 RNN 單元的反向傳播

RNN 中反向傳播的目的是計算出最終的損失值 L 分別對權值矩陣(W_xh,W_ah,W_ao)和偏置向量(b_h,b_o)的偏導數值。

推導出所需的導數值非常簡單,我們只需要利用鏈式法則(https://en.wikipedia.org/wiki/Chain_rule)就能計算出它們。

第一步:為了計算代價,需要先定義損失函數。一般根據具體手中的任務來選擇該損失函數。在這個例子里,對於多分類輸出問題,我們採用交叉熵損失函數 L?t?,其具體計算過程如下:

第二步:接下來我們開始往後計算損失函數 L?t? 對預測輸出值的激活值 ??t? 的偏導數值。因為在前向傳播過程中 softmax 函數以多分類的輸出值作為輸入,因此下面的偏導數值的計算分為兩種情況:分類 i 時和分類 k 時:

第三步:接著利用分類 i 時和分類 k 時的偏導數值,可以計算出損失函數 L?t? 對預測輸出值 o?t? 的偏導數值:

第四步:利用偏導數值及鏈式法則,計算出損失函數 L?t? 對輸出過程中的偏置向量 b_o 的偏導數值:

第五步:利用偏導數值及鏈式法則,計算出損失函數 L?t? 對隱層至輸出層中的權值矩陣 W_ao 的偏導數值:

第六步: 利用偏導數值、及鏈式法則,計算出損失函數 L?t? 對隱狀態的激活值 a?t? 的偏導數值:

第七步: 利用偏導數值及鏈式法則,計算出損失函數 L?t? 對隱狀態 h?t? 的偏導數值:

第八步: 利用偏導數值及鏈式法則,計算出損失函數 L?t? 對隱狀態的偏置向量 b_h 的偏導數值:

第九步:利用偏導數值及鏈式法則,計算出損失函數 L?t? 對輸入層至隱層中的偏置矩陣 W_xh 的偏導數值:

第十步:利用偏導數值及鏈式法則,計算出損失函數 L?t? 對輸入層至隱層中的偏置矩陣 W_ah 的偏導數值:

隨時間反向傳播(BPTT)

就像前文中提到的前向傳播過程一樣,將循環網路展開,BPTT 將沿此一直運行著上述步驟。

主要的區別在於我們必須將每個時間步 t 的偏導數值累加起來,從而更新權值和偏置,這是因為這些參數在前向傳播的過程中是被各個時間步所共享的。

總結

在本文的第一部分和第二部分中,我們了解了循環神經網路訓練過程中所涉及到的前向傳播和反向傳播。接下來,我們將著眼於 RNN 中所存在的梯度消失問題,並討論 LSTM 和 GRU 網路的進展。

博客原址:

Part I

https://medium.com/learn-love-ai/step-by-step-walkthrough-of-rnn-training-part-i-7aee5672dea3

Part II

https://medium.com/learn-love-ai/step-by-step-walkthrough-of-rnn-training-part-ii-7141084d274b

添加雷鋒字幕組微信號(leiphonefansub)為好友

從Python入門-如何成為AI工程師

BAT資深演算法工程師獨家研發課程

最貼近生活與工作的好玩實操項目

班級管理助學搭配專業的助教答疑

學以致用拿offer,學完即推薦就業

新人福利

關注 AI 研習社(okweiwu),回復1領取

【超過 1000G 神經網路 / AI / 大數據資料】

Pytorch 中如何處理 RNN 輸入變長序列 padding

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 AI研習社 的精彩文章:

狗狗視角看世界,用視覺數據預測狗的行為
來了!2018 MIT 6.S094 中文譯版「深度學習和自動駕駛課」今日上線

TAG:AI研習社 |