當前位置:
首頁 > 最新 > 我是這樣學習 GAN 的——開發者自述

我是這樣學習 GAN 的——開發者自述

為了

每篇文章的文末都有一個小話題

歡迎大家參與討論

有任何想說的都可以在評論區交流~

AI 研習社按:本文作者馬少楠,原載於作者知乎專欄,雷鋒網 AI 研習社經授權發布。

Generative Adversarial Network,就是大家耳熟能詳的 GAN,由 Ian Goodfellow 首先提出,在這兩年更是深度學習中最熱門的東西,彷彿什麼東西都能由 GAN 做出來。我最近剛入門 GAN,看了些資料,做一些筆記。

1.Generation

什麼是生成(generation)?就是模型通過學習一些數據,然後生成類似的數據。讓機器看一些動物圖片,然後自己來產生動物的圖片,這就是生成。

以前就有很多可以用來生成的技術了,比如 auto-encoder(自編碼器),結構如下圖:

你訓練一個 encoder,把 input 轉換成 code,然後訓練一個 decoder,把 code 轉換成一個 image,然後計算得到的 image 和 input 之間的 MSE(mean square error),訓練完這個 model 之後,取出後半部分 NN Decoder,輸入一個隨機的 code,就能 generate 一個 image。

但是 auto-encoder 生成 image 的效果,當然看著很彆扭啦,一眼就能看出真假。所以後來還提出了比如VAE這樣的生成模型,我對此也不是很了解,在這就不細說。

上述的這些生成模型,其實有一個非常嚴重的弊端。比如 VAE,它生成的 image 是希望和 input 越相似越好,但是 model 是如何來衡量這個相似呢?model 會計算一個 loss,採用的大多是 MSE,即每一個像素上的均方差。loss 小真的表示相似嘛?

比如這兩張圖,第一張,我們認為是好的生成圖片,第二張是差的生成圖片,但是對於上述的 model 來說,這兩張圖片計算出來的 loss 是一樣大的,所以會認為是一樣好的圖片。

這就是上述生成模型的弊端,用來衡量生成圖片好壞的標準並不能很好的完成想要實現的目的。於是就有了下面要講的 GAN。

2.GAN

大名鼎鼎的 GAN 是如何生成圖片的呢?首先大家都知道 GAN 有兩個網路,一個是 generator,一個是 discriminator,從二人零和博弈中受啟發,通過兩個網路互相對抗來達到最好的生成效果。流程如下:

主要流程類似上面這個圖。首先,有一個一代的 generator,它能生成一些很差的圖片,然後有一個一代的 discriminator,它能準確的把生成的圖片,和真實的圖片分類,簡而言之,這個 discriminator 就是一個二分類器,對生成的圖片輸出 0,對真實的圖片輸出 1。

接著,開始訓練出二代的 generator,它能生成稍好一點的圖片,能夠讓一代的 discriminator 認為這些生成的圖片是真實的圖片。然後會訓練出一個二代的 discriminator,它能準確的識別出真實的圖片,和二代 generator 生成的圖片。以此類推,會有三代,四代。。。n 代的 generator 和 discriminator,最後 discriminator 無法分辨生成的圖片和真實圖片,這個網路就擬合了。

這就是 GAN,運行過程就是這麼的簡單。這就結束了嘛?顯然沒有,下面還要介紹一下 GAN 的原理。

3.原理

首先我們知道真實圖片集的分布 Pdata(x),x 是一個真實圖片,可以想像成一個向量,這個向量集合的分布就是 Pdata。我們需要生成一些也在這個分布內的圖片,如果直接就是這個分布的話,怕是做不到的。

我們現在有的 generator 生成的分布可以假設為 PG(x;θ),這是一個由 θ 控制的分布,θ 是這個分布的參數(如果是高斯混合模型,那麼 θ 就是每個高斯分布的平均值和方差)

假設我們在真實分布中取出一些數據,,我們想要計算一個似然 PG(xi; θ)。

對於這些數據,在生成模型中的似然就是

我們想要最大化這個似然,等價於讓 generator 生成那些真實圖片的概率最大。這就變成了一個最大似然估計的問題了,我們需要找到一個 θ* 來最大化這個似然。

尋找一個 θ* 來最大化這個似然,等價於最大化 log 似然。因為此時這 m 個數據,是從真實分布中取的,所以也就約等於,真實分布中的所有 x 在 PG分布中的 log 似然的期望。

真實分布中的所有 x 的期望,等價於求概率積分,所以可以轉化成積分運算,因為減號後面的項和 θ 無關,所以添上之後還是等價的。然後提出共有的項,括弧內的反轉,max 變 min,就可以轉化為 KL divergence 的形式了,KL divergence 描述的是兩個概率分布之間的差異。

所以最大化似然,讓 generator 最大概率的生成真實圖片,也就是要找一個 θ 讓 PG更接近於 Pdata。

那如何來找這個最合理的 θ 呢?我們可以假設 PG(x; θ) 是一個神經網路。

首先隨機一個向量 z,通過 G(z)=x 這個網路,生成圖片 x,那麼我們如何比較兩個分布是否相似呢?只要我們取一組 sample z,這組 z 符合一個分布,那麼通過網路就可以生成另一個分布 PG,然後來比較與真實分布 Pdata。

大家都知道,神經網路只要有非線性激活函數,就可以去擬合任意的函數,那麼分布也是一樣,所以可以用一直正態分布,或者高斯分布,取樣去訓練一個神經網路,學習到一個很複雜的分布。

如何來找到更接近的分布,這就是 GAN 的貢獻了。先給出 GAN 的公式:

這個式子的好處在於,固定 G,max V(G,D) 就表示 PG和 Pdata之間的差異,然後要找一個最好的 G,讓這個最大值最小,也就是兩個分布之間的差異最小。

表面上看這個的意思是,D 要讓這個式子儘可能的大,也就是對於 x 是真實分布中,D(x) 要接近與 1,對於 x 來自於生成的分布,D(x) 要接近於 0,然後 G 要讓式子儘可能的小,讓來自於生成分布中的 x,D(x) 儘可能的接近 1。

現在我們先固定 G,來求解最優的 D:

對於一個給定的 x,得到最優的 D 如上圖,範圍在 (0,1) 內,把最優的 D 帶入

可以得到:

JS divergence 是 KL divergence 的對稱平滑版本,表示了兩個分布之間的差異,這個推導就表明了上面所說的,固定 G。

表示兩個分布之間的差異,最小值是 -2log2,最大值為 0。

現在我們需要找個 G,來最小化

觀察上式,當 PG(x)=Pdata(x) 時,G 是最優的。

4.訓練

有了上面推導的基礎之後,我們就可以開始訓練 GAN 了。結合我們開頭說的,兩個網路交替訓練,我們可以在起初有一個 G和 D,先訓練 D找到 :

然後固定 D開始訓練 G, 訓練的過程都可以使用 gradient descent,以此類推,訓練 D1,G1,D2,G2,...

但是這裡有個問題就是,你可能在 D* 的位置取到了:

然後更新 G為 G1,可能

了,但是並不保證會出現一個新的點 D1* 使得

這樣更新 G 就沒達到它原來應該要的效果,如下圖所示:

避免上述情況的方法就是更新 G 的時候,不要更新 G 太多。

知道了網路的訓練順序,我們還需要設定兩個 loss function,一個是 D 的 loss,一個是 G 的 loss。下面是整個 GAN 的訓練具體步驟:

上述步驟在機器學習和深度學習中也是非常常見,易於理解。

5.存在的問題

但是上面 G 的 loss function 還是有一點小問題,下圖是兩個函數的圖像:

log(1-D(x)) 是我們計算時 G 的 loss function,但是我們發現,在 D(x) 接近於 0 的時候,這個函數十分平滑,梯度非常的小。這就會導致,在訓練的初期,G 想要騙過 D,變化十分的緩慢,而上面的函數,趨勢和下面的是一樣的,都是遞減的。但是它的優勢是在 D(x) 接近 0 的時候,梯度很大,有利於訓練,在 D(x) 越來越大之後,梯度減小,這也很符合實際,在初期應該訓練速度更快,到後期速度減慢。

所以我們把 G 的 loss function 修改為

這樣可以提高訓練的速度。

還有一個問題,在其他 paper 中提出,就是經過實驗發現,經過許多次訓練,loss 一直都是平的,也就是

JS divergence 一直都是 log2,PG和 Pdata完全沒有交集,但是實際上兩個分布是有交集的,造成這個的原因是因為,我們無法真正計算期望和積分,只能使用 sample 的方法,如果訓練的過擬合了,D 還是能夠完全把兩部分的點分開,如下圖:

對於這個問題,我們是否應該讓 D 變得弱一點,減弱它的分類能力,但是從理論上講,為了讓它能夠有效的區分真假圖片,我們又希望它能夠 powerful,所以這裡就產生了矛盾。

還有可能的原因是,雖然兩個分布都是高維的,但是兩個分布都十分的窄,可能交集相當小,這樣也會導致 JS divergence 算出來 =log2,約等於沒有交集。

解決的一些方法,有添加雜訊,讓兩個分布變得更寬,可能可以增大它們的交集,這樣 JS divergence 就可以計算,但是隨著時間變化,雜訊需要逐漸變小。

還有一個問題叫 Mode Collapse,如下圖:

這個圖的意思是,data 的分布是一個雙峰的,但是學習到的生成分布卻只有單峰,我們可以看到模型學到的數據,但是卻不知道它沒有學到的分布。

造成這個情況的原因是,KL divergence 里的兩個分布寫反了

這個圖很清楚的顯示了,如果是第一個 KL divergence 的寫法,為了防止出現無窮大,所以有 Pdata出現的地方都必須要有 PG覆蓋,就不會出現 Mode Collapse。

6.參考

這是對 GAN 入門學習做的一些筆記和理解,後來太懶了,不想打公式了,主要是參考了李宏毅老師的視頻:

http://t.cn/RKXQOV0

---------------------------

來聊聊吧

喜歡GAN這種演算法嗎?

歡迎在評論區分享


喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 唯物 的精彩文章:

一文詳解卷積神經網路的演變歷程!
深度學習自動編碼器還能用於數據生成?這篇文章告訴你答案
只要130 行代碼即可生成二維樣本,心動了嗎?
如何用 Caffe 生成對抗樣本?這篇文章告訴你一個更高效的演算法

TAG:唯物 |

您可能感興趣

別人家的孩子已經開始用iPad學習AR開發啦
如果你是一位APP開發者,你一定懂我在說什麼
想成為一名Web開發者?你應該學習Node.js而不是PHP
新時代遊戲腳本外掛?EA開發自我學習AI打戰地1
這恐怕是WWDC蘋果開發者大會多年來最軟的一次了
Adobe專業開發者眼中未來的VR是什麼樣的? | VR網
想成為一名Web開發者?或許應該學習Node.js而不是PHP
F8開發者大會:用戶還是開發者,這是一個問題
學習UE4開發,為什麼要學習C+?
Mock API是如何在開發中發光發熱的?
索尼推薦,PS VR開發者分享遊戲開發應該做和不該做的事情
谷歌將開放地圖,遊戲開發者們可以開發下一款 Pokémon Go 了
優秀的機器學習開發者都是這樣做的!
GitHub為開發者推出機器人學習實驗室
IKEA開發的這些「未來食物」,你有沒有膽量嘗試一下?
這款不用看PC設備的VR已經向開發者發貨
你在做軟體開發?據說有個AIDO 很好用……
MIT:我們用深度學習開發了一個能識別人類情緒的機器人
為 TV 開發的 App,你說要運行在手機上?
GDC:VR遊戲開發熱情下滑 Vive仍最受VR開發者喜愛