當前位置:
首頁 > 新聞 > textgenrnn:只需幾行代碼即可訓練文本生成網路

textgenrnn:只需幾行代碼即可訓練文本生成網路

本文是一個 GitHub 項目,介紹了 textgenrnn,一個基於 Keras/TensorFlow 的 Python 3 模塊。只需幾行代碼即可訓練文本生成網路。

項目地址:https://github.com/minimaxir/textgenrnn?reddit=1

通過簡簡單單的幾行代碼,使用預訓練神經網路生成文本,或者在任意文本數據集上訓練你自己的任意規模和複雜度的文本生成神經網路。

textgenrnn 是一個基於 Keras/TensorFlow 的 Python 3 模塊,用於創建 char-rnn,具有許多很酷炫的特性:

  • 它是一個使用注意力權重(attention-weighting)和跳躍嵌入(skip-embedding)等先進技術的現代神經網路架構,用於加速訓練並提升模型質量。

  • 能夠在字元層級和詞層級上進行訓練和預測。

  • 能夠設置 RNN 的大小、層數,以及是否使用雙向 RNN。

  • 能夠對任何通用的輸入文本文件進行訓練。

  • 能夠在 GPU 上訓練模型,然後在 CPU 上使用這些模型。

  • 在 GPU 上訓練時能夠使用強大的 CuDNN 實現 RNN,這比標準的 LSTM 實現大大加速了訓練時間。

  • 能夠使用語境標籤訓練模型,能夠更快地學習並在某些情況下產生更好的結果。

你可以使用
textgenrnn,並且在該 Colaboratory
Notebook(https://drive.google.com/file/d/1mMKGnVxirJnqDViH7BDJxFqWrsXlPSoK/view?usp=sharing)中免費使用
GPU 訓練任意文本文件。

示例

from textgenrnn import textgenrnn
textgen = textgenrnn()
textgen.generate()

[Spoiler] Anyone else find
this post and their person that was a little more than I really like
the Star Wars in the fire or health and posting a personal house of the
2016 Letter for the game in a report of my backyard.

該模型可以很容易地在新的文本上進行訓練,甚至可以在僅僅輸入一次數據之後生成合適的文本。

textgen.train_from_file("hacker-news-2000.txt", num_epochs=1)
textgen.generate()

Project State Project Firefox

這個模型的權重比較小(占磁碟上
2 MB 的空間),它們可以很容易地被保存並載入到新的 textgenrnn
實例中。因此,你可以使用經過數百次數據輸入訓練的模型。(實際上,textgenrnn
的學習能力過於強大了,以至於你必須大大提高溫度(Temperature)來得到有創造性的輸出。)

textgen_2 = textgenrnn("/weights/hacker_news.hdf5")
textgen_2.generate(3, temperature=1.0)

Why we got money 「regular alter」

Urburg to Firefox acquires Nelf Multi Shamn

Kubernetes by Google』s Bern

您還可以訓練一個支持詞級別嵌入和雙向 RNN 層的新模型。

使用方法

textgenrnn 可以通過 pip 從 pypi(https://pypi.python.org/pypi/textgenrnn)中安裝:

pip3 install textgenrnn

  • 你可以在該 Jupyter Notebook(https://github.com/minimaxir/textgenrnn/blob/master/docs/textgenrnn-demo.ipynb)中查看常見的功能和配置選項的演示案例。

  • /datasets 包含用於訓練 textgenrnn 的 Hacker News 和 Reddit data 示例數據集。

  • /weights 包含在上述的數據集上進一步預訓練的模型,它可以被載入到 textgenrnn 中。

  • /output 包含從上述預訓練模型中生成文本的示例。

神經網路架構及實現

textgenrnn 基於 Andrej Karpathy 的 char-rnn 項目(https://github.com/karpathy/char-rnn),並且融入了一些最新的優化,如處理非常小的文本序列的能力。

textgenrnn:只需幾行代碼即可訓練文本生成網路

本文涉及到的預訓練模型遵循
DeepMoji
的神經網路架構(https://github.com/bfelbo/DeepMoji/blob/master/deepmoji/model_def.py)的啟發。對於默認的模型,textgenrnn
接受最多 40 個字元的輸入,它將每個字元轉換為 100 維的字元嵌入向量,並將這些向量輸入到一個包含 128
個神經元的長短期記憶(LSTM)循環層中。接著,這些輸出被傳輸至另一個包含 128 個神經元的 LSTM
中。以上所有三層都被輸入到一個注意力層中,用來給最重要的時序特徵賦權,並且將它們取平均(由於嵌入層和第一個 LSTM
層是通過跳躍連接與注意力層相連的,因此模型的更新可以更容易地向後傳播並且防止梯度消失)。該輸出被映射到最多 394
個不同字元的概率分布上,這些字元是序列中的下一個字元,包括大寫字母、小寫字母、標點符號和表情。(如果在新的數據集上訓練一個新模型,可以配置所有上面提到的數值參數。)

textgenrnn:只需幾行代碼即可訓練文本生成網路

或者,如果可以獲得每個文本文檔的語境標籤,則可以在語境模式下訓練模型。在這種模式下,模型會學習給定語境的文本,這樣循環層就會學習到非語境化的語言。前面提到的只包含文本的路徑可以藉助非語境化層提升性能;總之,這比單純使用文本訓練的模型訓練速度更快,且具備更好的定量和定性的模型性能。

軟體包包含的模型權重是基於(通過 BigQuery)在 Reddit 上提交的成千上萬的文本文檔訓練的,它們來自各種各樣的 subreddit 板塊。此外,該網路還採用了上文提到的非語境方法,從而提高訓練的性能,同時減少作者的偏見。

當使用
textgenrnn
在新的文本數據集上對模型進行微調時,所有的層都會被重新訓練。然而,由於原始的預訓練網路最初具備魯棒性強得多的「知識」,新的 textgenrnn
最終能夠訓練地更快、更準確,並且可以學習原始數據集中未出現的新關係。(例如:預訓練的字元嵌入包含所有可能的現代互聯網語法類型中的字元語境。)

此外,重新訓練是通過基於動量的優化器和線性衰減的學習率實現的,這兩種方法都可以防止梯度爆炸,並且大大降低模型在長時間訓練後發散的可能性。

注意事項

即使使用經過嚴格訓練的神經網路,你也不能每次都能得到高質量的文本。這就是使用神經網路文本生成的博文(http://aiweirdness.com/post/170685749687/candy-heart-messages-written-by-a-neural-network)或推文(https://twitter.com/botnikstudios/status/955870327652970496)通常生成大量文本,然後挑選出最好的那些再進行編輯的主要原因。

不同的數據集得到的結果差異很大。因為預訓練的神經網路相對來說較小,因此它不能像上述博客展示的
RNN 那樣存儲大量的數據。為了獲得最佳結果,請使用至少包含 2000-5000
個文檔的數據集。如果數據集較小,你需要在調用訓練方法和/或從頭開始訓練一個新模型時,通過調高 num_epochs
參數來對模型進行更長時間的訓練。即便如此,目前也沒有一個判斷模型」好壞」的啟發式方法。

你並不一定需要用 GPU 重新訓練 textgenrnn,但是在 CPU 上訓練花費的時間較長。如果你使用 GPU 訓練,我建議你增加 batch_size 參數,獲得更好的硬體利用率。

未來計劃

  • 更多正式文檔;

  • 一個使用 tensorflow.js 的基於 web 的實現(由於網路規模小,效果特別好);

  • 一種將注意力層輸出可視化的方法,以查看神經網路是如何「學習」的;

  • 有監督的文本生成模式:允許模型顯示 top n 選項,並且由用戶選擇生成的下一個字元/單詞(https://fivethirtyeight.com/features/some-like-it-bot/);

  • 一個允許將模型架構用於聊天機器人對話的模式(也許可以作為單獨的項目發布);

  • 對語境進行更深入的探索(語境位置 + 允許多個語境標籤);

  • 一個更大的預訓練網路,它能容納更長的字元序列和對語言的更深入理解,生成更好的語句;

  • 層次化的作用於詞級別模型的 softmax 激活函數(Keras 對此有很好的支持);

  • 在 Volta/TPU 上進行超高速訓練的 FP16 浮點運算(Keras 對此有很好的支持)。

使用 textgenrnn 的項目

Tweet Generator:訓練一個為任意數量的 Twitter 用戶生成推文而優化的神經網路。

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 機器之心 的精彩文章:

自監督對抗哈希SSAH:當前最佳的跨模態檢索框架
學界 | 極端圖像壓縮的生成對抗網路,可生成低碼率的高質量圖像

TAG:機器之心 |