當前位置:
首頁 > 新聞 > 要合作,不要對抗!無需預訓練超越經典演算法,上交大提出合作訓練式生成模型CoT

要合作,不要對抗!無需預訓練超越經典演算法,上交大提出合作訓練式生成模型CoT

作者:盧思迪 上海交通大學

【新智元導讀】上海交通大學APEX實驗室研究團隊提出合作訓練(Cooperative Training),通過交替訓練生成器(G)和調和器(M),無需任何預訓練即可穩定地降低當前分布與目標分布的JS散度,且在生成性能和預測性能上都超越了以往的演算法。對於離散序列建模任務來說,該演算法無需改動模型的網路結構,同時計算代價較理想,是一種普適的高效演算法。本文是論文第一作者盧思迪帶來的解讀。

論文地址:https://arxiv.org/pdf/1804.03782.pdf

GitHub:https://github.com/desire2020/Cooperative-Training

生成式模型是無監督學習這一領域的一個重要話題。對於連續數據(如圖片)的建模,自2014年生成式對抗網路(Generative Adversarial Network, GAN)發表以來,研究已取得了不少進展。然而,對於離散數據,特別是離散序列的建模與生成,針對這個問題的研究仍沒有產生足夠令人滿意的突破。

對於這一類數據建模問題,經典演算法如極大似然估計(Maximum Likelihood Estimation, MLE)很難稱得上是理想的演算法。在數據有限的情況下,它和生成式任務並不能完美地相適應。如下圖,MLE等價於優化單側KL散度KL(P||G):

由於KL散度不對稱,對於預測中的失誤,MLE這一目標函數能夠給出比較好的懲罰進而給予糾正;但是對於潛在的生成失誤,MLE並不能很好地起到作用。

針對這一問題,研究者們提出了序列生成式網路(Sequence Generative Adversarial Network, SeqGAN)。SeqGAN是這一領域針對MLE問題的早期嘗試之一,其使用強化學習來優化GAN的目標函數,即:

相比於經典演算法,SeqGAN在樣本生成的質量上有了一些改進。然而由於對抗網路固有的不穩定性,SeqGAN常常在預測式任務中表現不佳。此外,受限於策略梯度法這一基於策略的強化學習(Policy-based Reinforcement Learning)的能力,SeqGAN並不能單獨使用,需要使用MLE進行預訓練。

針對這個問題,上海交通大學APEX實驗室研究團隊提出合作訓練(Cooperative Training),通過交替訓練生成器(G)和調和器(M),無需任何預訓練即可穩定地降低當前分布與目標分布的JS散度,且在生成性能和預測性能上都超越了以往的演算法。對於離散序列建模任務來說,該演算法無需改動模型的網路結構,同時計算代價較理想,是一種普適的高效演算法。


一個支點,撬動分布

在圖片生成等任務里,GAN之所以能奏效,是因為其本質上優化的是當前分布與目標分布的Jensen-Shannon散度(JSD),即:

其中M=0.5P + 0.5G,是當前已習得分布G與目標分布P的一個均衡混合分布。從定義可以看出,JSD對於P和G是對稱的。也就是說,對於模型在生成式任務和判別式任務中的錯誤,這個衡量標準都可以均衡地反饋出來。如果能夠直接最小化JSD或者它的一個無偏差(unbiased)近似,那麼對於目標分布的擬合就是比較準確的。遺憾的是,直接對JSD本身進行優化是不可能的。原因是,我們只有對於自己當前分布的建模G,但是無法直接拿到目標分布P,進而構造準確的M是不可能的。但是,受到GAN的啟發,我們可以訓練一個模型去近似混合分布M,並且以它為支點來優化一個JSD的好的近似。

基於這一想法,研究者們提出了合作訓練(Cooperative Training, CoT)。如圖所示,在合作訓練的框架內,有兩個架構相同的模塊,稱為生成器(Generator, G)和調和器(Mediator, M)

每一次迭代,從G中采出一些樣本,再從訓練數據中隨機選出等量的樣本,把兩者混合,用來訓練M。由於這種情況下,我們只關心M對於給定樣本的似然度估計,因此在訓練M時,我們使用MLE就不會產生一般意義下的各種問題。在M得到訓練後,對於來自G的一組樣本s,用M給出的估計值M(s)來代替真實值M*(s),從而得到一個JSD的近似估計。在訓練G時,最小化這個近似估計,即可達到對於目標分布的趨近。


要合作,不要對抗

通過一些推導,我們可以給出這個演算法中兩個模塊各自的目標函數:

調和器(Mediator):

生成器(Generator):

其中π代表兩個模塊在給定前綴下對於下一步所有決策的概率估計。為了不使這篇介紹過於無趣,詳細的推導請參見原文。完整演算法流程如下:

可以看出,這個演算法的計算複雜度與MLE一致,兩者僅差一個常數倍數。

對於CoT來說,最終的優化問題可以寫成:

這是一個合作式目標(而非GAN中的非合作博弈目標)。通過推導我們可以知道,這個優化目標的一半和JSD的相反數趨勢一致,兩者的差值就是目標分布的熵!


實驗及更多討論

對於合成數據上的驗證性實驗,研究者使用了由SeqGAN提出,並在TexyGen(一個基準評測系統)中得以完善的數據,即合成數據圖靈測試(Synthetic Turing Test)。結果如表所示,公平起見,這一測試中所有的模型未使用任何正則化,且生成器架構完全相同。

注意到,即使是在反映預測式任務性能的NLL test(這本身是MLE的優化目標)這一指標上,CoT也超越了MLE,不僅僅是在收斂性能上優於MLE,即使訓練途中所探索到的最好局部最優(7.54)也好於MLE。而在生成質量的測試指標NLL oracle方面,從零開始訓練,無需任何MLE預熱的CoT達到了使用簡單生成器架構模型中最優水平。如果綜合考察生成質量和預測準確性,之前的模型在兩個指標之和的意義下相比MLE並沒有產生改進。而CoT不但有明顯改進,而且在兩個任務下的性能水平基本一致(均為8.1左右)。反觀MLE,則很不均衡(生成損失:9.43預測損失:7.67)。這更說明一個無偏的優化目標對於數據建模的有效性之重要。

我們注意到相比較G,由於M的訓練目標形式上更接近有監督學習,再加上在推薦設定中它的容量比G更大,它很容易過擬合,進而影響模型的表現。因此,在使用一些簡單的正則化技術,如Dropout之後,模型的表現更加令人滿意。在合成數據上,我們可以通過算出真正的JSD來說明這一點。如圖,使用了正則化後,我們可以發現我們的演算法達到了對於真正的JSD的持續、一致、較為穩定的優化。而且,由調和器提供的對JSD的趨勢估計也非常準確。

除此之外,作為一個通用的離散序列建模演算法,我們也進行了一些文本上的實驗。為了控制變數,我們使用這一領域前人工作大都評測過的一個較長文本數據集EMNLP 2017 WMT News Section。如表所示,在使用相同(或接近)的架構和細節設定的前提下,我們的演算法達到了最佳水平。


Nested CoT

我們注意到這個演算法還可以用於提高其自身的效果。具體來說,對於M我們也可以使用CoT來代替MLE對其進行訓練。由於CoT具有提高模型在預測任務中泛化性能的能力,這樣做可以使得模型更加穩定。然而,受限於篇幅和時間,我們沒有給出實踐上的驗證,但這一想法本身非常有趣。


總結

我們提出新的生成式模型訓練演算法合作訓練(Cooperative Training),用於優化當前已習得分布和目標分布的JS散度。該演算法無需預訓練,計算速度和MLE同等理想,且在所有離散序列建模任務(包括生成式和預測式)裡面超越了以往的演算法。我們希望能進一步地對這一演算法展開研究,並將其延拓至其他類型數據如圖片上,為生成式模型建立一個新的範式。我們也期待研究者能夠就CoT與GAN之間更深層次的聯繫展開研究,併產生一些有趣的結論。


喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 新智元 的精彩文章:

AI 里程碑!機器翻譯系統提前 7 年達到人類專業翻譯水平!
機器人蝙蝠俠和蜘蛛俠:一個靠機器學習飛,一個折成輪子滾,動作逆天!

TAG:新智元 |