看完立刻理解GAN!初學者也沒關係
雷鋒網按:本文原作者天雨粟,原文載於作者的知乎專欄——機器不學習,雷鋒網經授權發布。
前言
GAN 從 2014 年誕生以來發展的是相當火熱,比較著名的 GAN 的應用有 Pix2Pix、CycleGAN 等。本篇文章主要是讓初學者通過代碼了解 GAN 的結構和運作機制,對理論細節不做過多介紹。我們還是採用 MNIST 手寫數據集(不得不說這個數據集對於新手來說非常好用)來作為我們的訓練數據,我們將構建一個簡單的 GAN 來進行手寫數字圖像的生成。
認識 GAN
GAN 主要包括了兩個部分,即生成器 generator 與判別器 discriminator。生成器主要用來學習真實圖像分布從而讓自身生成的圖像更加真實,以騙過判別器。判別器則需要對接收的圖片進行真假判別。在整個過程中,生成器努力地讓生成的圖像更加真實,而判別器則努力地去識別出圖像的真假,這個過程相當於一個二人博弈,隨著時間的推移,生成器和判別器在不斷地進行對抗,最終兩個網路達到了一個動態均衡:生成器生成的圖像接近於真實圖像分布,而判別器識別不出真假圖像,對於給定圖像的預測為真的概率基本接近 0.5(相當於隨機猜測類別)。
對於 GAN 更加直觀的理解可以用一個例子來說明:造假幣的團伙相當於生成器,他們想通過偽造金錢來騙過銀行,使得假幣能夠正常交易,而銀行相當於判別器,需要判斷進來的錢是真錢還是假幣。因此假幣團伙的目的是要造出銀行識別不出的假幣而騙過銀行,銀行則是要想辦法準確地識別出假幣。
因此,我們可以將上面的內容進行一個總結。給定真 = 1,假 = 0,那麼有:
對於給定的真實圖片(real image),判別器要為其打上標籤 1;
對於給定的生成圖片(fake image),判別器要為其打上標籤 0;
對於生成器傳給辨別器的生成圖片,生成器希望辨別器打上標籤 1。
有了上面的直觀理解,下面就讓我們來實現一個 GAN 來生成手寫數據吧!還有一些細節會在代碼部分進行介紹。
說明
TensorFlow 1.0
Python 3
Jupyter Notebook
GitHub 地址:NELSONZHAO/zhihu
建議將代碼 pull 下來,有部分代碼實現沒有寫在文章中。
代碼部分
數據載入與查看
數據我們使用 TensorFlow 中給定的 MNIST 數據介面。
在構建模型之前,我們首先來看一下我們需要完成的任務:
Inputs
generator
discriminator
定義參數
loss & optimizer
訓練模型
顯示結果
輸入 inputs
輸入函數主要來定義真實圖片與生成圖片兩個 tensor。
定義生成器
我們的生成器結構如下:
我們使用了一個採用 Leaky ReLU 作為激活函數的隱層,並在輸出層加入 tanh 激活函數。
下面是生成器的代碼。注意在定義生成器和判別器時,我們要指定變數的 scope,這是因為 GAN 中實際上包含生成器與辨別器兩個網路,在後面進行訓練時是分開訓練的,因此我們要把 scope 定義好,方便訓練時候指定變數。
在這個網路中,我們使用了一個隱層,並加入 dropout 防止過擬合。通過輸入雜訊圖片,generator 輸出一個與真實圖片一樣大小的圖像。
在這裡我們的隱層激活函數採用的是 Leaky ReLU(中文不知道咋翻譯),這個函數在 ReLU 函數基礎上改變了左半邊的定義。
圖片來自維基百科。Andrej Karpathy 在 CS231n 中也提到有模型通過這個函數取得了不錯的效果。
由於 TensorFlow 中沒有這個函數的實現,在這裡我們通過函數定義實現了 Leaky ReLU,其中 alpha 是一個很小的數。在輸出層我們使用 tanh 函數,這是因為 tanh 在這裡相比 sigmoid 的結果會更好一點(在這裡要注意,由於生成器的生成圖片像素限制在了 (-1, 1) 的取值之間,而 MNIST 數據集的像素區間為 [0, 1],所以在訓練時要對 MNIST 的輸入做處理,具體見訓練部分的代碼)。到此,我們構建好了生成器,它通過接收一個雜訊圖片輸出一個與真實圖片一樣 size 的圖像。
定義判別器
判別器的結構如下:
判別器接收一張圖片,並判斷它的真假,同樣隱層使用了 Leaky ReLU,輸出層為 1 個結點,輸出為 1 的概率。代碼如下:
在這裡,我們需要注意真實圖片與生成圖片是共享判別器的參數的,因此在這裡我們留了 reuse 介面來方便我們後面調用。
定義參數
img_size 是我們真實圖片的 size=32*32=784。
smooth 是進行 Label Smoothing Regularization 的參數,在後面會介紹。
構建網路
接下來我們來構建我們的網路,並獲得生成器與判別器返回的變數。
我們分別獲得了生成器與判別器的 logits 和 outputs。注意真實圖片與生成圖片是共享參數的,因此在判別器輸入生成圖片時,需要 reuse 參數。
定義 Loss 和 Optimizer
有了上面的 logits,我們就可以定義我們的 loss 和 Optimizer。在這之前,我們再來回顧一下生成器和判別器各自的目的是什麼:
對於給定的真實圖片,辨別器要為其打上標籤 1;
對於給定的生成圖片,辨別器要為其打上標籤 0;
對於生成器傳給辨別器的生成圖片,生成器希望辨別器打上標籤 1。
我們來把上面這三句話轉換成代碼:
d_loss_real 對應著真實圖片的 loss,它儘可能讓判別器的輸出接近於 1。在這裡,我們使用了單邊的 Label Smoothing Regularization,它是一種防止過擬合的方式,在傳統的分類中,我們的目標非 0 即 1,從直覺上來理解的話,這樣的目標不夠 soft,會導致訓練出的模型對於自己的預測結果過於自信。因此我們加入一個平滑值來讓判別器的泛化效果更好。
d_loss_fake 對應著生成圖片的 loss,它儘可能地讓判別器輸出為 0。
d_loss_real 與 d_loss_fake 加起來就是整個判別器的損失。
而在生成器端,它希望讓判別器對自己生成的圖片儘可能輸出為 1,相當於它在於判別器進行對抗。
下面我們定義了優化函數,由於 GAN 中包含了生成器和判別器兩個網路,因此需要分開進行優化,這也是我們在之前定義 variable_scope 的原因。
訓練模型
由於訓練部分代碼太長,我在這裡就不貼出來了,請前往我的 GitHub 下載代碼。在訓練部分,我們記錄了部分圖像的生成過程,並記錄了訓練數據的 loss 變化。
我們將整個訓練過程的 loss 變化繪製出來:
從圖中可以看出來,最終的判別器總體 loss 在 1 左右波動,而 real loss 和 fake loss 幾乎在一條水平線上波動,這說明判別器最終對於真假圖像已經沒有判別能力,而是進行隨機判斷。
查看過程結果
我們在整個訓練過程中記錄了 25 個樣本在不同階段的 samples 圖像,以序列化的方式進行了保存,我們的將 samples 載入進來。samples 的 size=epochs x 2 x n_samples x 784,我們的迭代次數為 300 輪,25 個樣本,因此,samples 的 size=300 x 2 x 25 x 784。我們將最後一輪的生成結果列印出來:
這就是我們的 GAN 通過學習真實圖片的分布後生成的圖像結果。
那麼有同學可能會問了,我們如果想要看這 300 輪中生成圖像的變化是什麼樣該怎麼辦呢?因為我們已經有了 samples,存儲了每一輪迭代的結果,我們可以挑選幾次迭代,把對應的圖像打出來:
這裡我挑選了第 0, 5, 10, 20, 40, 60, 80, 100, 150, 250 輪的迭代效果圖,在這個圖中,我們可以看到最開始的時候只有中間是白色,背景黑色塊中存在著很多雜訊。隨著迭代次數的不斷增加,生成器製造 「假圖」 的能力也越來越強,它逐漸學得了真實圖片的分布,最明顯的一點就是圖片區分出了黑色背景和白色字元的界限。
生成新的圖片
如果我們想重新生成新的圖片呢?此時我們只需要將我們之前保存好的模型文件載入進來就可以啦。
總結
整篇文章基於 MNIST 數據集構造了一個簡單的 GAN 模型,相信小夥伴看完代碼會對 GAN 有一個初步的了解。從最終的模型結果來看,生成的圖像能夠將背景與數字區分開,黑色塊雜訊逐漸消失,但從顯示結果來看還是有很多模糊區域的。
對於這裡的圖片處理,相信很多小夥伴會想到卷積神經網路,那麼後面我們還會將生成器和判別器改為卷積神經網路來構造深度卷積 GAN,它對於圖片的生成會取得更好的效果。
如果覺得不錯,請給 GitHub 點個 Star 吧~
※專訪閱面科技童志軍:FDDB、LFW雙奪冠的人臉識別技術
※重磅講座預告:黃鐵軍、陳雲霽等專家齊聚CCF ADL,分享類腦計算與深度學習處理器
※報告解讀:醫療AI應用將貢獻年增長率的40%,總市場將達100億美元
※「異鬼Ⅱ」Bootkit木馬藏身甜椒刷機軟體 騰訊電腦管家精準攔截
TAG:雷鋒網 |
※娃娃表示RNG對換線理解太可怕,RNG已經真正理解了拳頭的意圖
※用你的常識就可以理解EOS
※解說米勒談RNG輸給FW: RNG不能總指望UZI! 版本理解也不行!
※你不當CEO,不會理解有些CEO為什麼要自殺
※RNG贏比賽Uzi失誤卻登熱搜,解說表示難以理解,網友認為是BUG!
※VAR加入世界盃,C羅:我們應該理解它
※談AI隱私泄露問題:理解AI,AI「理解」人類同樣重要
※TCP 三次握手原理,你真的理解嗎?
※淺談我對OKR的理解
※ctDNA可靠性又遭質疑?一文助你真正理解ASCO&CAP聯合綜述
※也許,這樣理解HTTPS更容易
※JSONP淺顯理解
※「一哭就抱」會寵壞孩子?看完最新研究,才知道自己完全理解錯了
※7個目前科學無法合理解釋的未解之謎,你知道嗎?
※「一哭就抱」會寵壞孩子嗎?看完最新研究,才知道自己完全理解錯了!
※「你得理解我」「理解你,理解你,那誰來理解我?」
※為何KTV包廂里有獨立廁所,很多人不理解,看完後知道其中內幕
※LOL夏季賽:RNG誰都打不過,怎麼理解?看看他們的比賽才知道
※看完今日的分享,你或許就理解一些谷歌傳道AI的方法論
※難以理解的韓國文化TOP3,你猜對了嗎?