當前位置:
首頁 > 新聞 > 基於TensorFlow的簡單故事生成案例:帶你了解LSTM

基於TensorFlow的簡單故事生成案例:帶你了解LSTM

機器之心編譯

參與:Ellen Han、吳攀



在深度學習中,循環神經網路(RNN)是一系列善於從序列數據中學習的神經網路。由於對長期依賴問題的魯棒性,長短期記憶(LSTM)是一類已經有實際應用的循環神經網路。現在已有大量關於LSTM的文章和文獻,其中推薦如下兩篇:


Goodfellow et.al. 《深度學習》一書第十章:http://www.deeplearningbook.org/

Chris Olah:理解 LSTM:http://colah.github.io/posts/2015-08-Understanding-LSTMs/

已存在大量優秀的庫可以幫助你基於LSTM構建機器學習應用。在GitHub中,谷歌的TensorFlow在此文成文時已有超過 50000 次星,表明了其在機器學習從業者中的流行度。

與此形成對比,相對缺乏的似乎是關於如何基於LSTM建立易於理解的TensorFlow應用的優秀文檔和示例,這也是本文嘗試解決的問題。

假設我們想用一個樣本短故事來訓練LSTM預測下一個單詞,伊索寓言:


long ago , the mice had a general council to consider what measures they could take to outwit their common enemy , the cat . some said this , and some said that but at last a young mouse got up and said he had a proposal to make , which he thought would meet the case . you will all agree , said he , that our chief danger consists in the sly and treacherous manner in which the enemy approaches us . now , if we could receive some signal of her approach , we could easily escape from her . i venture , therefore , to propose that a small bell be procured , and attached by a ribbon round the neck of the cat . by this means we should always know when she was about , and could easily retire while she was in the neighbourhood . this proposal met with general applause , until an old mouse got up and said that is all very well , but who is to bell the cat ? the mice looked at one another and nobody spoke . then the old mouse said it is easy to propose impossible remedies .

表1.取自伊索寓言的短故事,其中有112個不同的符號。單詞和標點符號都視作符號。

如果我們將文本中的3個符號以正確的序列輸入LSTM,以1個標記了的符號作為輸出,最終神經網路將學會正確地預測下一個符號(Figure1)。

基於TensorFlow的簡單故事生成案例:帶你了解LSTM

圖 1.有3個輸入和1個輸出的LSTM單元

嚴格說來,LSTM只能理解輸入的實數。一種將符號轉化為數字的方法是基於每個符號出現的頻率為其分配一個對應的整數。例如,上面的短文中有112個不同的符號。如列表2所示的函數建立了一個有如下條目 [ 「,」 : 0 ] [ 「the」 : 1 ], …, [ 「council」 : 37 ],…,[ 「spoke」 = 111 ]的詞典。而為了解碼LSTM的輸出,同時也生成了逆序字典。

build_dataset(words):

表 2.建立字典和逆序字典的函數

類似地,預測值也是一個唯一的整數值與逆序字典中預測符號的索引相對應。例如:如果預測值是37,預測符號便是「council」。

輸出的生成看起來似乎簡單,但實際上LSTM為下一個符號生成了一個含有112個元素的預測概率向量,並用softmax()函數歸一化。有著最高概率值的元素的索引便是逆序字典中預測符號的索引值(例如:一個 one-hot 向量)。圖2 給出了這個過程。

基於TensorFlow的簡單故事生成案例:帶你了解LSTM

圖2.每一個輸入符號被分配給它的獨一無二的整數值所替代。輸出是一個表明了預測符號在反向詞典中索引的 one-hot 向量。

LSTM模型是這個應用的核心部分。令人驚訝的是,它很易於用TensorFlow實現:


def RNN(x, weights, biases):

# reshape to [1, n_input]

x = tf.reshape(x, [-1, n_input])

# Generate a n_input-element sequence of inputs

# (eg. [had] [a] [general] → [20] [6] [33])

x = tf.split(x,n_input,1)

# 1-layer LSTM with n_hidden units.

rnn_cell = rnn.BasicLSTMCell(n_hidden)

# generate prediction

outputs, states = rnn.static_rnn(rnn_cell, x, dtype=tf.float32)

# there are n_input outputs but

# we only want the last output

return tf.matmul(outputs[-1], weights["out"]) + biases["out"]

表3.有512個LSTM 單元的網路模型

最難部分是以正確的格式和順序完成輸入。在這個例子中,LSTM的輸入是一個有3個整數的序列(例如:1x3 的整數向量)

網路的常量、權值和偏差設置如下:

vocab_size = len(dictionary)

表4.常量和訓練參數

訓練過程中的每一步,3個符號都在訓練數據中被檢索。然後3個符號轉化為整數以形成輸入向量。

symbols_in_keys = [ [dictionary[ str(training_data[i])]] for i in range(offset, offset+n_input) ]

表 5.將符號轉化為整數向量作為輸入

訓練標籤是一個位於3個輸入符號之後的 one-hot 向量

symbols_out_onehot = np.zeros([vocab_size], dtype=float)

表6.單向量作為標籤

在轉化為輸入詞典的格式後,進行如下的優化過程:

_, acc, loss, onehot_pred = session.run([optimizer, accuracy, cost, pred], feed_dict={x: symbols_in_keys, y: symbols_out_onehot})

表 7.訓練過程中的優化

精度和損失被累積以監測訓練過程。通常50,000次迭代足以達到可接受的精度要求。

...

表 8.一個訓練間隔的預測和精度數據示例(間隔1000步)

代價是標籤和softmax()預測之間的交叉熵,它被RMSProp以 0.001的學習率進行優化。在本文示例的情況中,RMSProp通常比Adam和SGD表現得更好。

pred = RNN(x, weights, biases)

表 9.損失和優化器

LSTM的精度可以通過增加層來改善。

rnn_cell = rnn.MultiRNNCell([rnn.BasicLSTMCell(n_hidden),rnn.BasicLSTMCell(n_hidden)])

Listing 10. 改善的LSTM

現在,到了有意思的部分。讓我們通過將預測得到的輸出作為輸入中的下一個符號輸入LSTM來生成一個故事吧。示例輸入是「had a general」,LSTM給出了正確的輸出預測「council」。然後「council」作為新的輸入「a general council」的一部分輸入神經網路得到下一個輸出「to」,如此循環下去。令人驚訝的是,LSTM創作出了一個有一定含義的故事。

had a general council to consider what measures they could take to outwit their common enemy , the cat . some said this , and some said that but at last a young mouse got

表11.截取了樣本故事生成的故事中的前32個預測值

如果我們輸入另一個序列(例如:「mouse」, 「mouse」, 「mouse」)但並不一定是這個故事中的序列,那麼會自動生成另一個故事。

mouse mouse mouse , neighbourhood and could receive a outwit always the neck of the cat . some said this , and some said that but at last a young mouse got up and said

表 12.並非來源於示例故事中的輸入序列

示例代碼可以在這裡找到:https://github.com/roatienza/Deep-Learning-Experiments/blob/master/Experiments/Tensorflow/RNN/rnn_words.py

示例文本的鏈接在這裡:https://github.com/roatienza/Deep-Learning-Experiments/blob/master/Experiments/Tensorflow/RNN/belling_the_cat.txt



小貼士:

  1. 用整數值編碼符號容易操作但會丟失單詞的意思。本文中將符號轉化為整數值是用來簡化關於用TensorFlow建立LSTM應用的討論的。更推薦採用Word2Vec將符號編碼為向量。

  2. 將輸出表達成單向量是效率較低的方式,尤其當我們有一個現實的單詞量大小時。牛津詞典有超過170,000個單詞,而上面的例子中只有112個單詞。再次聲明,本文中的示例只為了簡化討論。

  3. 這裡採用的代碼受到了Tensorflow-Examples的啟發:https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/recurrent_network.py

  4. 本文例子中的輸入大小為3,看一看當採用其它大小的輸入時會發生什麼吧(例如:4,5或更多)。

  5. 每次運行代碼都可能生成不同的結果,LSTM的預測能力也會不同。這是由於精度依賴於初始參數的隨機設定。訓練次數越多(超過150,000次)精度也會相應提高。每次運行代碼,建立的詞典也會不同

  6. Tensorboard在調試中,尤其當檢查代碼是否正確地建立了圖時很有用。

  7. 試著用另一個故事測試LSTM,尤其是用另一種語言寫的故事。

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 機器之心 的精彩文章:

經濟學人:亞馬遜帝國
伯克利論文提出實時機器學習:可解決實時性和靈活性等七大要求
新手教程:在新應用中實踐深度學習的最佳建議
英偉達Titan Xp之後,如何為深度學習挑選合適的GPU?

TAG:機器之心 |

您可能感興趣

APT案例分析:一個基於meterpreter和Windows代理的攻擊事件
Google Project Shield如何抵禦DDoS攻擊?這個案例可以初探端倪
跨國風險管理公司Baker Engineering and Risk Consultants案例
案例:Oracle報錯ASM磁碟組不存在或沒有mount
Danimer Scientific和PepsiCo於EFIB 2017呈現品牌合作案例研究 &nbsp
SocialBeta 本周 Top 6 營銷案例
案例分析 Toro Gastrobar
我是不會告訴你 Kanye 成為大學精神科臨床案例的;OVO x Jordan 8
乾貨 | TensorFlow的55個經典案例
houdini學習案例+CG獵人vip群
SocialBeta 本周 Top 5 海外營銷案例
Thought Works的10個案例教你打造團隊文化
案例分析 De Lemos
Lime的森林小鳥via——copic sketch24色套裝實戰案例
Hadoop-CERN案例研究
盤點Stella Mccartney獨創性、先鋒性、引導性的印花案例
Haomang量催案例 人生如夢
機器學習其實沒想的那麼難,谷歌TensorFlow應用案例解讀
這套97㎡的freestyle案例,完美解答覆式loft該怎麼裝!