當前位置:
首頁 > 知識 > 一文讀懂生成對抗網路GANs

一文讀懂生成對抗網路GANs

原文標題:AnIntuitive Introduction to Generative Adversarial Networks

作者:KeshavDhandhania、ArashDelijani

翻譯:申利彬

校對:和中華

本文約4000字,建議閱讀10分鐘

本文以圖像生成問題引出GAN模型,並介紹了GAN模型的數學原理和訓練過程,最後提供了豐富的GAN學習資料。

本文討論生成對抗網路,簡稱GANs。在生成任務或更廣泛的無監督學習中,GANs是為數不多在此領域表現較好的機器學習技術之一。特別是他們在圖像生成相關任務上擁有出色表現。深度學習領域先驅Yann LeCun,稱讚GAN是機器學習近十年來最好的想法。最重要的是,GAN相關核心概念很容易理解(事實上,讀完本文後你就可以對它有個清晰的認識)。

我們將GANs應用在圖像生成任務中,並以此來解釋GANs,下面是本文的概要:

簡單回顧深度學習

圖像生成問題

生成任務中的關鍵問題

生成對抗網路

挑戰

進一步閱讀

總結

簡單回顧深度學習

(前饋)神經網路示意圖,棕色為輸入層,黃色為隱藏層,紅色為輸出層

我們先簡單介紹一下深度學習。上圖是神經網路示意圖,它是由神經元組成,神經元之間通過邊相互連接,而且神經元按層排列,中間為隱藏層,輸入層和輸出層分別在左右兩側。神經元之間的連接邊都有權重,每一個神經元都會根據與其連接的神經元的輸入值加權求和,然後帶入非線性激活函數中計算,這類激活函數有Sigmoid和ReLU。例如,第一層隱藏層神經元對來自輸入層神經元的值進行加權求和,然後再應用ReLU函數。激活函數引入了非線性,它使神經網路可以模擬複雜的現象(多個線性層等價於一個線性層)。

給一個特定的輸入,我們依次計算每個神經元輸出的值(也稱為神經元的活性)。從左到右,我們利用前層計算好的值來逐層計算,最後得到輸出層的值。然後根據輸出層的值和期望值(目標值)定義一個損失函數,例如,均方誤差損失函數。

其中,x是輸入,h(x)是輸出,y是目標值,總和包含數據集中所有數據點。

在每步中,我們的目標是以合適的數值優化每條邊的權重,從而儘可能降低損失函數的大小。我們計算出梯度值,然後利用梯度具體優化每一個權重。當我們計算出損失函數值,就可以用反向傳播演算法計算梯度。反向傳播演算法的主要結果是:利用鏈式求導法則和後一層參數的梯度值來計算這層的梯度。然後,我們通過與各個梯度成比例的量(即梯度下降)來更新每個權重。

如果你想要進一步了解神經網路和反向傳播演算法的細節,我推薦你閱讀Nikhil Buduma寫的簡單學習深度學習(Deep Learning in aNutshell)

http://nikhilbuduma.com/2014/12/29/deep-learning-in-a-nutshell/

圖像生成問題

圖像生成問題上,我們希望機器學習模型可以生成圖像。為了訓練模型,我們得到了一個圖像數據集(比如從網路下載的1,000,000張圖片)。在測試的時候,模型可以生成圖像,這些圖像看起來像屬於訓練集,但實際上並不是訓練集中的圖像。也就是說,我們想生成新的圖像(與單純地記憶相反),但仍然希望它能捕獲訓練數據集中的模式,從而使新的圖像感覺與訓練數據集相似。

圖像生成問題:沒有輸入,所需的輸出是一個圖像

需要注意的一點是:在測試或預測階段,這個問題沒有輸入。每次「運行模型」時,我們希望它生成(輸出)一個新的圖像。這可以說輸入將從一個容易抽樣的分布(例如均勻分布或高斯分布)中隨機抽樣而來。

生成任務中的關鍵問題

生成任務中的關鍵問題是:什麼是一個好的損失函數?假如你有兩張機器學習模型生成的圖片,我們如何決定哪一個更好,好多少呢?

在以前的方法中,這個問題最常見的解決方案是計算輸出圖像和訓練集中最鄰近圖像的距離,其中使用一些預定義的距離度量標準來計算距離。例如,在語言翻譯任務中,我們通常有一個源語句和一個小的(約5個)目標句子集,也就是由不同翻譯人員提供的譯文。當模型生成一個譯文,我們把譯文與提供的目標句子比較,然後根據它距離哪個目標句子最近,分配一個相應的分數(特別是,我們是用BLEU分數,它是根據兩個句子之間有多少個n-grams匹配的距離度量標準)。但這是一種單句子翻譯方法,當目標是一個較大的文本時,同樣的方法會使損失函數的質量嚴重惡化。例如,我們的任務可能是生成給定文章的段落摘要,那麼這種惡化源於少量的樣本無法代表在所有可能的正確答案中觀察到的變化範圍。

生成對抗網路

GAN針對上面問題的回答是:用另外一個神經網路---記分神經網路(稱為判別器 Discriminator),它將評估生成神經網路輸出的圖像的真實性。這兩個神經網路具有相反的目標(對抗),生成網路的目標是生成一個看起來真實的假圖像,判別網路的目標是區分假圖像和真實圖像。

這將生成任務的設置類似於強化學習的雙人遊戲(如象棋,Atari games or 圍棋),在強化學習中我們有一個從零開始通過自我對抗不斷改進的機器學習模型 。象棋或者圍棋這些遊戲的對抗雙方總是對稱的(儘管並非總是如此),但對於GAN的設置,兩個網路的目標和角色是不相同的。一個網路產生假的樣本,而另一個網路區分真的和假的樣本。

生成對抗網路的示意圖,生成器網路標記為G,判別器網路標記為D

如上圖所示,是生成對抗網路示意圖。生成網路G和判別網路D在進行一場雙方極大極小博弈。首先,為了更好地理解這種對抗機制,需要注意到判別網路(D)的輸入可以是從訓練集中抽樣出的樣本,也可以是生成網路(G)的輸出,不過一般是50%來自訓練集,剩餘50%來自G。為了從G中生成樣本,我們從高斯分布中提取潛在的向量並輸入生成網路(G)。如果我們想生成200*200的灰度圖像,那麼生成網路(G)的輸出應該是200*200的向量。下面給出目標函數,它是判別網路(D)做預測的標準對數似然函數。

生成網路(G)是最小化目標函數,也就是減小對數似然函數或是說「迷惑」判別網路(D)。也就是說,無論何時從生成網路(G)輸出中抽取樣本作為判別網路(D)的輸入,都希望判別網路識別為真樣本。判別網路(D)是要最大化目標函數,也就是要增大對數似然函數或者說是把真實樣本和生成樣本區分開。換句話說,如果生成網路(G)在「迷惑」判別網路(D)上效果很好,也就會通過增大公式第二項中D(G(z))來最小化目標函數。另外,如果判別網路(D)能很好地工作,那麼在從訓練數據中選擇樣本的情況下,會通過第一項(因為D(x)很大)增大目標函數,也會通過第二項減小它(因為D(x)很小)。

如同平常的訓練過程一樣,使用隨機初始化和反向傳播,此外,我們需要單獨交替迭代更新生成器和判別器。下面是在特定問題上應用GANs的端到端的工作流程描述:

1. 決定GAN網路架構:G的架構是什麼?D的架構是什麼?

2. 訓練:一定數量的交替更新

更新D(固定G):一半樣本是真的,另一半是假的

更新G(固定D):生成所有樣本(注意,即使D保持不變,梯度流還是會經過D)

3. 人工檢查一些假樣本,如果質量很高(或者質量沒有提升)則停止,否則重複2。

當G和D都是前饋神經網路時,我們得到的結果如下(在MNIST數據集中訓練)

來自Goodfellow et. Al,從訓練集開始,最右邊一列(黃色框內)圖像與其緊鄰左邊一列的圖像最接近。其它所有圖像都是生成的樣本

關於G和D我們可以使用更複雜的架構,例如使用跳格卷積(strided convolutional)和adam優化器來代替隨機梯度下降。另外,還有其它一些方面的改進,例如優化架構,使用超參數和優化器(具體可參考論文)。改進後,我們得到了如下的結果:

卧室圖片,來自Alec Radford et. Al

挑戰

訓練GANs最關鍵的挑戰是有可能不收斂,有時這個問題也被稱為模式崩潰(mode collapse)。舉個例子,來簡單解釋這個問題。假設任務是生成數字圖像,就像MNIST數據集中的一樣。可能出現的問題(實踐中確實出現)是生成器G開始生成數字6,而不能生成其它數字。一旦D適應G的當前行為,為了最大限度地提高分類的準確性,它開始把所有的數字6歸為假,所有其它數字都是真實的(假設它不能分辨假的6和真實的6)。然後G又適應了D的當前行為,開始只生成數字8而不生成其它數字。然後D又適應,開始把數字8歸為假,其它的都是真。接著G又開始只生成3,如此循環下去。基本上,生成器G僅生成與訓練數據集的一個小的子集相似的圖像,而一旦識別器D開始把這個小的子集與其餘的區分開,生成器G又轉換到另外的子集,它們將一直簡單的來回震蕩。雖然這個問題沒有被完全解決,但還是有一些方法可以避免這個問題。這些方法涉及小批量特徵(minibatch features)和多次更新D的反向傳播。我們不再討論這些方法的細節,如果要了解更多信息,請查看下一節中的建議閱讀材料。

進一步閱讀

如果你想更深一步了解GANs,我建議你閱讀ICCV 2017 tutorials on GANs( https://sites.google.com/view/iccv-2017-gans/home),那裡有很多最新的教程,並且它們對GANs的不同方面各有側重。

我還想說一點關於條件GANs(Conditional GANs)的概念,條件GANs,是在輸入的條件下產生輸出。例如,任務可能是輸出與輸入描述相匹配的圖像。所以,當你輸入狗時,輸出的應該是狗的圖像。

下面是一些最近研究的成果(附論文鏈接)。

文本到圖像合成( 『Text to Image synthesis』)成果,作者Reed et. al

超解析度圖像(Image Super-resolution)成果,作者Ledig et. Al

圖像到圖像轉換(Image to Image translation)成果,作者Isola et. Al

生成高解析度名人相(Generating high resolution 『celebritylike』 images),作者Karras et. Al

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 數據派THU 的精彩文章:

8種方法用Python實現線性回歸,為你解析最高效選擇
機器學習在熱門微博推薦中的應用

TAG:數據派THU |