論文筆記:Sequence Generative Adversarial Nets with Policy Gradient
1 問題描述
GAN在圖像生成領域取得了巨大的突破,而直接將GAN用於文本生成則存在以下困難:
1.GAN難以直接處理離散數據,與圖像生成中pixel的值連續不同,文本生成中需要生成離散的tokens序列。由於生成器G需要從判別器D獲取誤差進行訓練,因此G和D都需要完全可微。在圖像生成中,G的參數進行一點點的更新都會造成生成圖片的像素的改變,使得生成的數據更加「逼真」。而在文本生成中,G參數的微小變化無法在引起離散token的變化,使得D無法給出相應的信息。
2.一般來說,判別器D只可以對生成的完整序列打分,由於文本生成中一般通過token by token的方式生成,如何判斷生成的部分序列的質量也是一個困難。
本文章提出使用policy-gradient的方法,將判別器D對生成文本的打分作為reward,對生成器G進行更新,來解決離散問題。通過蒙特卡洛採樣的方法,對部分序列進行估值,來解決判斷生成的部分序列的質量問題。生成器使用LSTM-RNN language model,判別器使用Text-CNN。
2 模型
整個模型架構如下圖
對於判別器,對輸入的真實數據和生成數據進行二分類,它的損失函數如下:
對於生成器,先通過MLE進行預訓練,再利用policy-gradient進行更新。
policy-gradient的思想就是直接通過梯度上升的方法來調整策略(policy)以最大化reward。這裡policy-gradient即最大化下式:
上式可簡化為:
其中Q代表生成Y1:t-1序列時,生成yttoken的expected accumulative reward。即增強學習中在狀態 Y1:t-1下,執行動作yt的Q值。該Q值可以通過蒙特卡洛採樣的方式,利用生成器在該狀態下繼續採樣出N個完整句子,由判別器對這N個完整句子打分,取平均值,來估計該Q值,如下:
從(7)式我們可以看出,直觀來說,如果生成yt帶來的reward高,那麼我們就增大生成它的概率。
最終對於生成器G,我們通過最大化(7)式進行參數更新,如下:
整個演算法的流程為先通過真實數據對G進行MLE的預訓練,再從G中採樣出負樣本,通過負樣本和真實數據進行判別器D的預訓練。預訓練完畢後,通過對G和D進行迭代訓練,訓練G時,先通過生成一個batch的數據,再通過MC search的D對該batch數據每一時刻的reward進行計算,最終通過policy-gradient調整概率分布,整個流程如下:
3 實驗
作者在奧巴馬演講稿、詩歌生成、音樂生成等真實數據上進行了實驗,通過BLEU和人工評價的方式證明了該方法比單純的MLE效果要好。事實上BLEU並不能很好的衡量生成質量,而人工評價又耗時好力,且具有一定偏差和隨機性。作者創新性的設計了一種Synthetic Data Experiments。
該實驗先將一個生成模型LSTM作為oracle,該模型參數可以隨機初始化。該模型即代表目標數據分布,我們從該模型上採樣出10000個樣本作為真實數據,來訓練G,而將G生成的數據輸入oracle模型,通過計算平均的NLL值(negative log-likelihood)作為評價標準。這相當於我們通過訓練G,來擬合oracle模型的分布,而由於oracle分布已知,我們可以通過NLL計算出G對oracle分布的擬合程度,作為一種可信的自動化評價標準。作者在這個實驗上比較了隨機生成、PG-BLEU、scheduled sampling和MLE等模型,證明了SeqGan能夠取得很好的效果,實驗結果如下:
作者也通過該實驗測試了G和D迭代訓練時,每一輪迭代,兩個模型應該分別訓練的次數,結果如下:
發現當g-step=1,d-step=5,k=3時效果較好。
4.總結
本文章通過policy-gradient和MC search解決了GAN在文本生成上的兩大問題,並取得了較好的實驗結果,開啟了GAN生成文本的大門。然而方法仍然存在一些問題,比如訓練不穩定,當D訓練的很好時,傳給G的信號很弱(Reward vanishing),mode-collapse,判別器必須要等到生成完整句子才能給出打分,訓練計算量大等問題。我將在以後陸續介紹針對這些問題提出改進方法的文章。
TAG:文本智能 |