當前位置:
首頁 > 最新 > Pointer-network理論及tensorflow實戰

Pointer-network理論及tensorflow實戰

數據下載地址:鏈接:https://pan.baidu.com/s/1nwJiu4T 密碼:6joq

本文代碼地址:https://github.com/princewen/tensorflow_practice/tree/master/myPtrNetwork


1、什麼是pointer-network

Pointer Networks 是發表在機器學習頂級會議NIPS 2015上的一篇文章,其作者分別來自Google Brain和UC Berkeley。

Pointer Networks 也是一種seq2seq模型。他在attention mechanism的基礎上做了改進,克服了seq2seq模型中「輸出嚴重依賴輸入」的問題。

什麼是「輸出嚴重依賴輸入」呢?

論文里舉了個例子,給定一些二維空間中[0,1]*[1,0]範圍內的點,求這些點的凸包(convex hull)。凸包是凸優化里的重要概念,含義如下圖所示,通俗來講,即找到幾個點能把所有點「包」起來。比如,模型的輸入是序列,輸出序列是凸包。到這裡,「輸出嚴重依賴輸入」的意思也就明了了,即輸出是從輸入序列中提取出來的。換個輸入,如,那麼輸出序列就是從裡面選出來。用論文中的語言來描述,即和的凸包,輸出分別依賴於輸入的長度,兩個問題求解的target class不一樣,一個是7,另一個是1000。

Pointer Network在求凸包上的效果如何呢?

從Accuracy一欄可以看到,Ptr-net明顯優於LSTM和LSTM+Attention。

為啥叫pointer network呢?

前面說到,對於凸包的求解,就是從輸入序列中選點的過程。選點的方法就叫pointer,他不像attetion mechanism將輸入信息通過encoder整合成context vector,而是將attention轉化為一個pointer,來選擇原來輸入序列中的元素。

與attention的區別:如果你也了解attention的原理,可以看看pointer是如何修改attention的?如果不了解,這一部分就可以跳過了。

首先搬出attention mechanism的公式,前兩個公式是整合encoder和decoder的隱式狀態,學出來encoder、decoder隱式狀態與當前輸出的權重關係a,然後根據權重關係a和隱式狀態e得到context vector用來預測下一個輸出。

Pointer Net沒有最後一個公式,即將權重關係a和隱式狀態整合為context vector,而是直接進行通過softmax,指向輸入序列選擇中最有可能是輸出的元素。

如果你對上面的理論還沒有理解的很到位,那麼我們通過代碼來進一步講解,相信你通過這段代碼,可以對Ptr的理論有一個更深入的認識。


2、pointer-network實現

這段代碼源自:https://github.com/devsisters/pointer-network-tensorflow

上面的代碼 實現比較複雜,連下載數據的過程都有,真的是十分費勁,我直接把數據下載好了,上傳到百度雲上了,大家可以自行下載(地址見文章開頭)。

代碼目錄如下:

config.py 定義了模型的配置

data_util.py 定義了數據處理過程

main.py 模型的主入口,定義了模型的訓練過程

model.py 定義了我們的pointer-network模型

我們這裡主要講解我們的數據處理和模型定義兩個文件

2.1 數據處理

好了,我們來看看我們的數據吧:

每行是一條數據,由於一條太長,所以分了三行顯示。輸入和target由output隔開,每個輸入的點由兩個坐標構成。

我們用下面的代碼讀入數據,這裡,我們把最後一個target的最後一個去掉了,我們認為我們正常的target的輸出序列不包含最後一個1,最後一個1作為結束標記在後面的代碼里會加入。

由於每條記錄的長度可能不同,因此,我們需要把所有數據的長度補成一樣的:

2.2 模型建立

在model.py文件中,我們定義了Model類以及兩個輔助的函數:

trainable_initial_state :建立可訓練的lstm初始狀態

index_matrix_to_pairs:這個主要是幫助我們使用gather_nd函數來選擇輸入的內容,該函數的一個簡單處理效果如下:

我們這裡重點講解Model類的_build_model函數,該函數用來建立一個pointer-network模型。

定義輸入

我們定義了四部分的輸入,分別是encoder的輸入及長度,decoder的預測序列及長度

輸入處理

我們要對輸入進行處理,將輸入轉換為embedding,embedding的長度和lstm的隱藏神經元個數相同。

在對輸入進行處理之後,輸入的形狀就變為[batch , max_enc_seq_length, hidden_dim]

Encoder

根據配置中的lstm層數,我們建立encoder,同時將我們處理好的輸入輸入到模型中,得到encoder的輸出以及encoder的最終狀態。:

在得到輸出之後,我們要給最前面的輸出添加一個開始的輸出,同時這個添加的開始的輸出還將作為encoder的最開始的輸入。看下面的圖片:

training decoder

與seq2seq不同的是,pointer-network的輸入並不是target序列的embedding,而是根據target序列的值選擇相應位置的encoder的輸出。

我們知道encoder的輸出長度在添加了開始輸出之後形狀為[batch ,max_enc_seq_length + 1]。現在假設我們拿第一條記錄進行訓練,第一條記錄的預測序列是[1,2,4],那麼decoder依次的輸入是

self.enc_outputs[0][0], self.enc_outputs[0][1],self.enc_outputs[0][2],self.enc_outputs[0][4],那麼如何根據target序列來選擇encoder的輸出呢,這裡就要用到我們剛剛定義的index_matrix_to_pairs函數以及gather_nd函數:

由於decoder的輸出變成了原先的target序列的長度+1,因此我們要在每個target後面補充一個結束標記,我們補充1作為結束標記:

同樣,我們建立一個多層的lstm網路:

對於decoder來說,這裡我們每次每個batch只輸入一個值,然後使用循環來實現整個decoder的過程:

可以看到,我們定義了兩個數組來保存輸出的序列,以及每次輸出的softmax值。這裡定義了一個choose_index函數,這個函數的作用即我們的pointer機制,即得到每個decoder輸出與encoder輸出按如下公式相互作用的softmax數組:

在論文中還提到一個詞叫做glimpse function,他首先將上面式子中的q進行了處理,公式如下:

glimpse function可以實現多層,當然我們代碼里只有一層:

可以看到,我們的attention函數高度還原了上面的式子,哈哈!

decoder predicting

對預測來說,我們不能實現選擇好用哪個encoder的輸出,必須根據上一輪的輸出來決定,所以與training的代碼不同的是,我們在每層循環里都是用index_matrix_to_pairs函數以及gather_nd函數來選擇下一時刻的輸出。

** 定義loss**

這裡定義的loss與seq2seq的loss相似:

定義訓練、驗證、預測函數

我們還定義了訓練、驗證、預測函數:

2.3 訓練及驗證

在main.py函數中,我們獲取數據,並進行訓練。這裡代碼還有待完善,因為沒有進行預測,嘻嘻!

實驗效果如下:


3、 參考文獻

1、神經網路之Pointer Net (Ptr-net) :https://zhuanlan.zhihu.com/p/30860157

2、https://github.com/devsisters/pointer-network-tensorflow

微信公眾號

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 小小挖掘機 的精彩文章:

實戰深度強化學習DQN-理論和實踐

TAG:小小挖掘機 |