當前位置:
首頁 > 知識 > 增強學習在image caption任務上的應用

增強學習在image caption任務上的應用

引言


第二十二期的PaperWeekly對Image Captioning進行了綜述。今天這篇文章中,我們會介紹一些近期的工作。(如果你對Image Captioning這個任務不熟悉的話,請移步二十二期PaperWeekly 第二十二期---Image Caption任務綜述)


Image Captioning的模型一般是encoder-decoder的模型。模型對$p(S|I)$進行建模,$S$是描述,$I$是圖片。模型的訓練目標是最大化log似然:$max_ hetasum_i log P(S_i|I_i, heta)$。

然而使用最大似然訓練有兩個問題:


1、雖然訓練時最大化後驗概率,但是在評估時使用的測度則為BLEU,METEOR,ROUGE,CIDER等。這裡有訓練loss和評估方法不統一的問題。而且log似然可以認為對每個單詞都給予一樣的權重,然而實際上有些單詞可能更重要一些(比如說一些表示內容的單詞)。


2、第二個問題為Exposure bias。訓練的時候,每個時刻的輸入都是來自於真實的caption。而生成的時候,每個時刻的輸入來自於前一時刻的輸出;所以一旦有一個單詞生成的不好,錯誤可能會接著傳遞,使得生成的越來越糟糕。


如何解決這兩個問題呢?很顯而易見的想法就是盡量使得訓練和評估時的情形一樣。我們可以在訓練的時候不優化log似然,而是直接最大化CIDER(或者BLEU,METEOR,ROUGE等)。並且,在訓練時也和測試時一樣使用前一時刻的輸入,而不是全使用ground truth輸入。

然而這有什麼難點呢?第一,CIDER或者這一些metric並不是可直接求導。(這就是為什麼在分類問題中,我們把0-1 error近似成log loss,hinge loss的原因)。其次從前一時刻輸出獲得後一時刻的輸入涉及到採樣操作,這也是不可微的。為了能夠解決這些不可微的問題,人們就想到了Reinforcement learning。


RL基本概念


RL中有一些比較重要的基本概念:狀態(state),行為(action),回報(reward)和決策(policy)。決策是一個狀態到動作的函數,一般是需要學習的東西。拿打遊戲的例子介紹RL最簡單。如果說是玩flappy bird,RL要學習的就是在什麼位置跳,能使得最後得到的分數越高。在這個例子里,最後的分數就是回報,位置就是狀態,跳或者不跳就是行為,而什麼時候跳就是學到的策略。


如果放在Image captioning中,狀態就是你看到的圖片和已生成的單詞,而動作就是下一個單詞生成什麼,回報就是CIDER等metric。


相關文獻

最近已經有很多工作將RL用在NLP相關的問題上。[1]第一次將REINFORCE演算法用在image caption和seq2seq問題上。[5]將使用了更先進的RL演算法 — Actor-critic — 來做machine translation上。[2,4]將[1]的演算法進行稍許改進(仍舊是REINFORCE演算法),使用在了image captioning上。[3]將REINFORCE用在序列生成GAN中,解決了之前序列生成器輸出為離散不可微的問題。[6]將RL用在自然對話系統中。這篇文章中我們主要介紹[1,2,4]。


RL演算法背景


這三篇文章使用的是REINFORCE演算法,屬於增強學習中Policy Gradient的一種。我們需要將deterministic的策略形式 $a=pi(s, heta)$轉化為概率形式,$p(a) = pi(a|s, heta)$。Policy Gradient就是對參數$ heta$求梯度的方法。


直觀的想,如果我們希望最後的決策能獲得更高的reward,最簡單的就是使得高reward的行為有高概率,低reward的行為有低概率。所以REINFORCE的更新目標為


$$max_{ heta} sum R(a,s)log pi(a|s, heta)$$

$R(s,a)$是回報函數。有了目標,我們可以通過隨機梯度下降來更新$ heta$來獲得更大的回報。


然而這個方法有一個問題,訓練時梯度的方差過大,導致訓練不穩定。我們可以思考一下,如果reward的值為100到120之間,現在的方法雖然能更大地提高reward為120的行為的概率,但是也還是會提升低reward的行為的概率。所以為了克服這個問題,又有了REINFORCE with baseline。


$$max_{ heta} sum (R(a,s) - b(s))log pi(a|s, heta)$$


$b(s)$在這裡就是baseline,目的是通過給回報一個基準來減少方差。假設還是100到120的回報,我們將baseline設為110,那麼只有100回報的行為就會被降低概率,而120回報的行為則會被提升概率。

三篇paper


第一篇是FAIR在ICLR2016發表的[1]。這篇文章是第一個將RL的演算法應用的離散序列生成的文章。文章中介紹了三種不同的方法,這裡我們只看最後一種演算法,Mixed Incremental Cross-Entropy Reinforce。


大體的想法就是用REINFORCE with baseline來希望直接優化BLEU4分數。具體訓練的時候,他們先用最大似然方法做預訓練,然後用REINFORCE finetune。在REINFORCE階段,生成器不再使用任何ground truth信息,而是直接從RNN模型隨機採樣,最後獲得採樣的序列的BLEU4的分數r作為reward來更新整個序列生成器。


這裡他們使用baseline在每個時刻是不同的;是每個RNN隱變數的一個線性函數。這個線性函數也會在訓練中更新。他們的系統最後能比一般的的cross extropy loss,和scheduled sampling等方法獲得更好的結果。


他們在github開源了基於torch的代碼,https://github.com/facebookresearch/MIXER


第二篇論文是今年CVPR的投稿。這篇文章在[1]的基礎上改變了baseline的選取。他們並沒有使用任何函數來對baseline進行建模,而是使用了greedy decoding的結果的回報作為baseline。他們聲稱這個baseline減小了梯度的variance。


這個baseline理解起來也很簡單:如果採樣得到句子沒有greedy decoding的結果好,那麼降低這句話的概率,如果比greedy decoding還要好,則提高它的概率。


這個方法的好處在於避免了訓練一個模型,並且這個baseline也極易獲得。有一個很有意思的現象是,一旦使用了這樣的訓練方法,beam search和greedy decoding的結果就幾乎一致了。


目前這篇文章的結果是COCO排行榜上第一名。他們使用CIDEr作為優化的reward,並且發現優化CIDEr能夠使所有其他metric如BLEU,ROUGE,METEOR都能提高。


他們的附錄中有一些captioning的結果。他們發現他們的模型在一些非尋常的圖片上表現很好,比如說有一張手心裡捧著一個長勁鹿的圖。


第三篇論文[4]也是這次CVPR的投稿。這篇文章則是在$R(a,s)$這一項動了手腳。


前兩篇都有一個共同特點,對所有時刻的單詞,他們的$R(a,s)$都是一樣的。然而這篇文章則給每個時刻的提供了不同的回報。


其實這個動機很好理解。比如說,定冠詞a,無論生成的句子質量如何,都很容易在句首出現。假設說在一次採樣中,a在句首,且最後的獲得回報減去baseline後為負,這時候a的概率也會因此被調低,但是實際上大多數情況a對最後結果的好壞並沒有影響。所以這篇文章採用了在每個時刻用$Q(w_)$來代替了原來一樣的$R$。


這個$Q$的定義為,


$Q heta(w) = mathbb}[R(w, w)]$


也就是說,當前時刻的回報,為固定了前t個單詞的期望回報。考慮a的例子,由於a作為句首生成的結果有好有壞,最後的Q值可能接近於baseline,所以a的概率也就不會被很大地更新。實際使用中,這個Q值可以通過rollout來估計:固定前t個詞後,隨機採樣K個序列,取他們的平均回報作為Q值。文中K為3。這篇文章中的baseline則跟[1]中類似。


從實驗結果上,第三篇並沒有第二篇好,但是很大一部分原因是因為使用的模型和特徵都比較老舊。


總結


將RL用在序列生成上似乎是現在新的潮流。但是現在使用的大多數的RL方法還比較簡單,比如本文中的REINFORCE演算法可追溯到上個世紀。RL本身也是一個很火熱的領域,所以可以預計會有更多的論文將二者有機地結合。


參考文獻


[1] Ranzato, Marc』Aurelio, Sumit Chopra, Michael Auli, and Wojciech Zaremba. 「Sequence level training with recurrent neural networks.」 arXiv preprint arXiv:1511.06732 (2015).


[2] Rennie, Steven J., Etienne Marcheret, Youssef Mroueh, Jarret Ross, and Vaibhava Goel. 「Self-critical Sequence Training for Image Captioning.」 arXiv preprint arXiv:1612.00563 (2016).


[3] Yu, Lantao, Weinan Zhang, Jun Wang, and Yong Yu. 「Seqgan: sequence generative adversarial nets with policy gradient.」 arXiv preprint arXiv:1609.05473 (2016).


[4] Liu, Siqi, Zhenhai Zhu, Ning Ye, Sergio Guadarrama, and Kevin Murphy. 「Optimization of image description metrics using policy gradient methods.」 arXiv preprint arXiv:1612.00370 (2016).


[5] Bahdanau, Dzmitry, Philemon Brakel, Kelvin Xu, Anirudh Goyal, Ryan Lowe, Joelle Pineau, Aaron Courville, and Yoshua Bengio. 「An actor-critic algorithm for sequence prediction.」 arXiv preprint arXiv:1607.07086 (2016).


[6] Li, Jiwei, Will Monroe, Alan Ritter, Michel Galley, Jianfeng Gao, and Dan Jurafsky. 「Deep reinforcement learning for dialogue generation.」 arXiv preprint arXiv:1606.01541 (2016).


作者


羅若天,TTIC博士生研究方向CV+NLP


github:https://github.com/ruotianluo

您的贊是小編持續努力的最大動力,動動手指贊一下吧!


本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 科研圈 的精彩文章:

罹患自閉症的女孩有著更「男性化」的大腦
年輕學者說:獨立誰不想?現實好殘酷
自然選擇可能讓人類越來越不傾向於接受更多的教育
科研圈一周精選招聘
上海交通大學電子信息與電氣工程學院儀器系「青年千人」宋傑課題組博士後招聘

TAG:科研圈 |

您可能感興趣

如何學習Javascript
Spring boot學習
Apple Watch如此領先 Android Wear仍需從中學習更多
java學習 JavaScript學習心得
學習清單:如何提高soft skills-how to negotiate
為什麼要學習React Native
Android學習Broadcast Receiver 注意事項
Android Framework 如何學習,如何從應用深入到Framework?
谷歌工程師:聊一聊深度學習的weight initialization
Windows Shellcode學習筆記——利用VirtualAlloc繞過DEP
資深演算法工程師眼中的深度學習:Ian Goodfellow 和Yoshua Bengio的「Deep Learning」讀書分享
Slow Down,向Gentlewoman學習穿衣哲學
Apache ZooKeeper進一步學習
遭遇痛風的髖 Hip Gout Arthritis文獻學習
Java學習之static關鍵字
Intel+Cloudera,用BigDL玩轉深度學習
Elasticsearch 5.4 beta 新功能:機器學習官方支持來了!
Play-with-Docker:線上免費學習Docker
基於 Python 和 Scikit-Learn 的機器學習介紹