當前位置:
首頁 > 知識 > 深度學習對話系統實戰篇-簡單 chatbot 代碼實現

深度學習對話系統實戰篇-簡單 chatbot 代碼實現

本文作者 劉沖,本文首發於知乎專欄【基於深度學習的對話系統】,AI 研習社獲其授權轉載。

本文的代碼都可以到我的 github 中下載:https://github.com/lc222/seq2seq_chatbot

前面幾篇文章我們已經介紹了 seq2seq 模型的理論知識,並且從 tensorflow 源碼層面解析了其實現原理,本篇文章我們會聚焦於如何調用 tf 提供的 seq2seq 的 API,實現一個簡單的 chatbot 對話系統。這裡先給出幾個參考的博客和代碼:

tensorflow 官網 API 指導(http://t.cn/R8MiZcR)

Chatbots with Seq2Seq Learn to build a chatbot using TensorFlow(http://t.cn/R8MiykP)

DeepQA(http://t.cn/R8MiVld)

Neural_Conversation_Models(http://t.cn/RtthjXn)

經過一番調查發現網上現在大部分 chatbot 的代碼都是基於 1.0 版本之前的 tf 實現的,而且都是從 tf 官方指導文檔 nmt 上進行遷移和改進,所以基本上大同小異,但是在實際使用過程中會發現一個問題,由於 tf 版本之間的兼容問題導致這些代碼在新版本的 tf 中無法正常運行,常見的幾個問題主要是:

關於上面第三個錯誤這裡多說幾句,因為確實困擾了我很久,基本上我在網上找到的每一份代碼都會有這個錯(DeepQA 除外)。首先來講一種最簡單的方法是將 tf 版本換成 1.0.0,這樣問題就解決了。

然後說下不想改 tf 版本的辦法,我在網上找了很久,自己也嘗試著去找 bug 所在,錯誤定位在 embedding_attention_seq2seq 函數中調用 deepcopy 函數,於是就有人嘗試著把 deepcopy 改成 copy,或者乾脆不進行 copy 直接讓 encoder 和 decoder 使用相同參數的 RNNcell,但這明顯是不正確的做法。我先想出了一種解決方案就是將 embedding_attention_seq2seq 的傳入參數中的 cell 改成兩個,分別是 encoder_cell 和 decoder_cell,然後這兩個 cell 分別使用下面代碼進行初始化:

這樣做不需要調用 deepcopy 函數對 cell 進行複製了,問題也就解決了,但是在模型構建的時候速度會比較慢,我猜測是因為需要構造兩份 RNN 模型,但是最後訓練的時候發現速度也很慢,就先放棄了這種做法。

然後我又分析了一下代碼,發現問題並不是單純的出現在 embedding_attention_seq2seq 這個函數,而是在調用 module_with_buckets 的時候會構建很多個不同 bucket 的 seq2seq 模型,這就導致了 embedding_attention_seq2seq 會被重複調用很多次,後來經過測試發現確實是這裡出現的問題,因為即便不使用 model_with_buckets 函數,我們自己為每個 bucket 構建模型時同樣也會報錯,但是如果只有一個 bucket 也就是只調用一次 embedding_attention_seq2seq 函數時就不會報錯,其具體的內部原理我現在還沒有搞清楚,就看兩個最簡單的例子:

所以先忽視原因,只看解決方案的話就是,不適用 buckets 構建模型,而是簡單的將所有序列都 padding 到統一長度,然後直接調用一次 embedding_attention_seq2seq 函數構建模型即可,這樣是不會抱錯的。(希望看到這的同學如果對這裡比較理解可以指點一二,或者互相探討一下)

最後我也是採用的這種方案,綜合了別人的代碼實現了一個 embedding+attention+beam_search 等多種功能的 seq2seq 模型,訓練一個基礎版本的 chatbot 對話機器人,tf 的版本是 1.4。寫這份代碼的目的一方面是為了讓自己對 tf 的 API 介面的使用方法更熟悉,另一方面是因為網上的一些代碼都很繁雜,想 DeepQA 這種,裡面會有很多個文件還實現了前端,然後各種封裝,顯得很複雜,不適合新手入門,所以就想寫一個跟 textcnn 相似風格的代碼,只包含四個文件,代碼讀起來也比較友好。接下來就讓我們看一下具體的代碼實現吧。最終的代碼我會放在 github 上

數據處理

這裡我們借用 [DeepQA](https://github.com/Conchylicultor/DeepQA#chatbot) 裡面數據處理部分的代碼,省去從原始本文文件構造對話的過程直接使用其生成的 dataset-cornell-length10-filter1-vocabSize40000.pkl 文件。有了該文件之後數據處理的代碼就精簡了很多,主要包括:

1. 讀取數據的函數 loadDataset()

2. 根據數據創建 batches 的函數 getBatches() 和 createBatch()

3. 預測時將用戶輸入的句子轉化成 batch 的函數 sentence2enco()

具體的代碼含義在注釋中都有詳細的介紹,這裡就不贅述了,見下面的代碼:


模型構建

有了數據之後看一下模型構建的代碼,其實主體代碼還是跟前面說到的 tf 官方指導文檔差不多,主要分為以下幾個功能模塊:

1. 一些變數的傳入和定義

2. OutputProjection 層和 sampled_softmax_loss 函數的定義

3. RNNCell 的定義和創建

4. 根據訓練或者測試調用相應的 embedding_attention_seq2seq 函數構建模型

5. step 函數定義,主要用於給定一個 batch 的數據,構造相應的 feed_dict 和 run_opt

代碼如下所示:

接下來我們主要說一下我做的主要工作,就是 beam_search 這部分,其原理想必大家看過前面的文章應該已經很清楚了,那麼如何編程實現呢,首先我們要考慮的是在哪裡進行 beam search,因為 beam search 是在預測時需要用到,代替 greedy 的一種搜索策略,所以第一種方案是在 tf 之外,用 python 實現,這樣做的缺點是 decode 速度會很慢。第二種方案是在 tf 內模型構建時進行,這樣做的好處是速度快但是比較麻煩。

在網上找了很久在 tensorflow 的一個 issue(http://t.cn/R8M6mDo) 裡面發現了一個方案,他的思路是修改 loop_function 函數,也就是之前根據上一時刻輸出得到下一時刻輸入的函數,在 loop function 裡面實現 top_k 取出概率最大的幾個序列,並把相應的路徑和單詞對應關係保存下來。但是存在一個問題就是一開始 decode 的時候傳入的是一句話,也就是 batch_size 為 1,但是經過 loop_function 之後返回的是 beam_size 句話,但是再將其傳入 RNNCell 的時候就會報錯,如何解決這個問題呢,想了很久決定直接從 decode 開始的時候就把輸入擴展為 beam_size 個,把 encoder 階段的輸出和 attention 向量都變成 beam_size 維的 tensor,就說把 decoder 階段的 RNN 輸入的 batch_size 當做為 beam_size。

但是這樣做仍然會出現一個問題,就是你會發現最後的輸出全部都相同,原因就在於 decoder 開始的時候樣本是 beam_szie 個完全相同的輸入,所以經過 loop_function 得到的 beam_size 個最大序列也是完全相同的,為了解決這個問題我們需要在第一次編碼的時候不取整體最大的前 beam_size 個序列,而是取第一個元素編碼結果的前 beam_size 個值作為結果。這部分代碼比較多就只貼出來 loop_function 的函數,有興趣的同學可以去看我 github 上面的代碼,就在 seq2seq 文件中。


模型訓練

其實模型訓練部分的代碼很簡單,就是每個 epoch 都對樣本進行 shuffle 然後分 batches,接下來將每個 batch 的數據分別傳入 model.step() 進行模型的訓練,這裡比較好的一點是,DeepQA 用的是 embedding_rnn_seq2seq 函數,訓練過程中 loss 經過 30 個人 epoch 大概可以降到 3 點多,但是我這裡改成了 embedding_attention_seq2seq 函數,最後 loss 可以降到 2.0 以下,可以說效果還是很顯著的,而且模型的訓練速度並沒有降低,仍然是 20 個小時左右就可以完成訓練。

貼上兩張圖看一下訓練的效果,這裡用的是 deepQA 的截圖,因為我的代碼訓練的時候忘了加 tensorboard 的東西:


模型預測

預測好模型之後,接下來需要做的就是對模型效果進行測試,這裡也比較簡單,主要是如何根據 beam_search 都所處的結果找到對應的句子進行輸出。代碼如下所示:

接下來我們看一下幾個例子,這裡 beam_size=5,並去掉了一些重複的回答:

NLP 工程師入門實踐班:基於深度學習的自然語言處理

三大模塊,五大應用,手把手快速入門 NLP

海外博士講師,豐富項目經驗

演算法 + 實踐,搭配典型行業應用

隨到隨學,專業社群,講師在線答疑

新人福利

關注 AI 研習社(okweiwu),回復1領取

【超過 1000G 神經網路 / AI / 大數據,教程,論文】

如何基於 rasa 搭建一個中文對話系統 (有源碼視頻)


喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 AI研習社 的精彩文章:

微軟推出開源自動駕駛模擬平台 AirSim 教程,機器學習新手也能快速上手自動駕駛
50篇學術訪談實錄:一份557頁的年終答卷

TAG:AI研習社 |