當前位置:
首頁 > 知識 > 用GAN來做圖像生成,這是最好的方法

用GAN來做圖像生成,這是最好的方法

本文原作者天雨粟,原文載於作者的知乎專欄——機器不學習。AI 研習社經授權發布。前言

對於圖像問題,卷積神經網路相比於簡單地全連接的神經網路更具優勢。

本文將繼續深入 GAN,通過融合卷積神經網路來對我們的 GAN 進行改進,實現一個深度卷積 GAN。

如果還沒有親手實踐過 GAN 的小夥伴可以先去學習一下上一篇專欄:生成對抗網路(GAN)之 MNIST 數據生成。

本次代碼在 NELSONZHAO/zhihu/dcgan,裡面包含了兩個文件:

dcgan_mnist:基於 MNIST 手寫數據集構造深度卷積 GAN 模型

dcgan_cifar:基於 CIFAR 數據集構造深度卷積 GAN 模型

本文主要以 MNIST 為例進行介紹,兩者在本質上沒有差別,只在細微的參數上有所調整。由於窮學生資源有限,沒有對模型增加迭代次數,也沒有構造更深的模型。並且也沒有選取像素很高的圖像,高像素非常消耗計算量。

本節只是一個拋磚引玉的作用,讓大家了解 DCGAN 的結構,如果有資源的小夥伴可以自己去嘗試其他更清晰的圖片以及更深的結構,相信會取得很不錯的結果。

工具

Python3

TensorFlow 1.0

Jupyter notebook

正文

整個正文部分將包括以下部分:

- 數據載入

- 模型輸入

- Generator

- Discriminator

- Loss

- Optimizer

- 訓練模型

- 可視化

數據載入

數據載入部分採用 TensorFlow 中的 input_data 介面來進行載入。關於載入細節在前面的文章中已經寫了很多次啦,相信看過我文章的小夥伴對 MNIST 載入也非常熟悉,這裡不再贅述。

模型輸入

在 GAN 中,我們的輸入包括兩部分,一個是真實圖片,它將直接輸入給 discriminator 來獲得一個判別結果;另一個是隨機雜訊,隨機雜訊將作為 generator 來生成圖片的材料,generator 再將生成圖片傳遞給 discriminator 獲得一個判別結果。

上面的函數定義了輸入圖片與雜訊圖片兩個 tensor。

Generator

生成器接收一個雜訊信號,基於該信號生成一個圖片輸入給判別器。在上一篇專欄文章生成對抗網路(GAN)之 MNIST 數據生成中,我們的生成器是一個全連接層的神經網路,而本節我們將生成器改造為包含卷積結構的網路,使其更加適合處理圖片輸入。整個生成器結構如下:

我們採用了 transposed convolution 將我們的雜訊圖片轉換為了一個與輸入圖片具有相同 shape 的生成圖像。我們來看一下具體的實現代碼:

上面的代碼是整個生成器的實現細節,裡面包含了一些 trick,我們來一步步地看一下。

首先我們通過一個全連接層將輸入的雜訊圖像轉換成了一個 1 x 4*4*512 的結構,再將其 reshape 成一個 [batch_size, 4, 4, 512] 的形狀,至此我們其實完成了第一步的轉換。

接下來我們使用了一個對加速收斂及提高卷積神經網路性能中非常有效的方法——加入 BN(batch normalization),它的思想是歸一化當前層輸入,使它們的均值為 0 和方差為 1,類似於我們歸一化網路輸入的方法。

它的好處在於可以加速收斂,並且加入 BN 的卷積神經網路受權重初始化影響非常小,具有非常好的穩定性,對於提升卷積性能有很好的效果。關於 batch normalization,我會在後面專欄中進行一個詳細的介紹。

完成 BN 後,我們使用 Leaky ReLU 作為激活函數,在上一篇專欄中我們已經提過這個函數,這裡不再贅述。最後加入 dropout 正則化。剩下的 transposed convolution 結構層與之類似,只不過在最後一層中,我們不採用 BN,直接採用 tanh 激活函數輸出生成的圖片。

在上面的 transposed convolution 中,很多小夥伴肯定會對每一層 size 的變化疑惑,在這裡來講一下在 TensorFlow 中如何來計算每一層 feature map 的 size。首先,在卷積神經網路中,假如我們使用一個 k x k 的 filter 對 m x m x d 的圖片進行卷積操作,strides 為 s,在 TensorFlow 中,當我們設置 padding= same 時,卷積以後的每一個 feature map 的 height 和 width 為

;當設置 padding= valid 時,每一個 feature map 的 height 和 width 為

。那麼反過來,如果我們想要進行 transposed convolution 操作,比如將 7 x 7 的形狀變為 14 x 14,那麼此時,我們可以設置 padding= same ,strides=2 即可,與 filter 的 size 沒有關係;而如果將 4 x 4 變為 7 x 7 的話,當設置 padding= valid 時,即

,此時 s=1,k=4 即可實現我們的目標。

上面的代碼中我也標註了每一步 shape 的變化。

Discriminator

Discriminator 接收一個圖片,輸出一個判別結果(概率)。其實 Discriminator 完全可以看做一個包含卷積神經網路的圖片二分類器。結構如下:

實現代碼如下:

上面代碼其實就是一個簡單的卷積神經網路圖像識別問題,最終返回 logits(用來計算 loss)與 outputs。這裡沒有加入池化層的原因在於圖片本身經過多層卷積以後已經非常小了,並且我們加入了 batch normalization 加速了訓練,並不需要通過 max pooling 來進行特徵提取加速訓練。

Loss Function

Loss 部分分別計算 Generator 的 loss 與 Discriminator 的 loss,和之前一樣,我們加入 label smoothing 防止過擬合,增強泛化能力。

Optimizer

GAN 中實際包含了兩個神經網路,因此對於這兩個神經網路要分開進行優化。代碼如下:

這裡的 Optimizer 和我們之前不同,由於我們使用了 TensorFlow 中的 batch normalization 函數,這個函數中有很多 trick 要注意。首先我們要知道,batch normalization 在訓練階段與非訓練階段的計算方式是有差別的,這也是為什麼我們在使用 batch normalization 過程中需要指定 training 這個參數。上面使用 tf.control_dependencies 是為了保證在訓練階段能夠一直更新 moving averages。具體參考 A Gentle Guide to Using Batch Normalization in Tensorflow - Rui Shu。

訓練

到此為止,我們就完成了深度卷積 GAN 的構造,接著我們可以對我們的 GAN 來進行訓練,並且定義一些輔助函數來可視化迭代的結果。代碼太長就不放上來了,可以直接去我的 GitHub 下載。

我這裡只設置了 5 輪 epochs,每隔 100 個 batch 列印一次結果,每一行代表同一個 epoch 下的 25 張圖:

我們可以看出僅僅經過了少部分的迭代就已經生成非常清晰的手寫數字,並且訓練速度是非常快的。

上面的圖是最後幾次迭代的結果。我們可以回顧一下上一篇的一個簡單的全連接層的 GAN,收斂速度明顯不如深度卷積 GAN。

總結

到此為止,我們學習了一個深度卷積 GAN,並且看到相比於之前簡單的 GAN 來說,深度卷積 GAN 的性能更加優秀。當然除了 MNST 數據集以外,小夥伴兒們還可以嘗試很多其他圖片,比如我們之前用到過的 CIFAR 數據集,我在這裡也實現了一個 CIFAR 數據集的圖片生成,我只選取了馬的圖片進行訓練:

剛開始訓練時:

訓練 50 個 epochs:

這裡我只設置了 50 次迭代,可以看到最後已經生成了非常明顯的馬的圖像,可見深度卷積 GAN 的優勢。

我的 GitHub:NELSONZHAO (Nelson Zhao)

上面包含了我的專欄中所有的代碼實現,歡迎 star,歡迎 fork。

文章還不夠?來看直播吧!

掃碼進入直播間

關注 AI 研習社(okweiwu),回復1領取

【超過 1000G 神經網路/AI/大數據、教程、論文!】

普林斯頓聯合Adobe 連聲音都能PS了

點擊展開全文

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 唯物 的精彩文章:

Google AI 實力打臉:你真的懂機器學習嘛?
在ADAS中運用多任務深度學習框架
別人在刷屏而我在讀書!吳恩達這本書是你踏入deeplearning.ai的必修課
前Twitter矽谷資深工程師解讀Yolo2和Yolo9000目標檢測系統
霧霾太重?深度神經網路教你如何圖像去霧

TAG:唯物 |

您可能感興趣

AI版「大家來找茬」上線,究竟誰是真人,誰是GAN生成的假臉?
斯坦福AI Lab:除了生成圖像,GAN還可以用來合成基因
刷新一次,生成一張逼真假臉:用英偉達StyleGAN做的網站,生出了靈異事件
定製人臉圖像沒那麼難!使用TL-GAN模型輕鬆變臉
用GAN自動生成法線貼圖,讓圖形設計更輕鬆
谷歌大腦打造「以一當十」的GAN:僅用10%標記數據,生成圖像卻更逼真
為什麼讓GAN一家獨大?Facebook提出非對抗式生成方法GLANN
用英偉達StyleGAN生成老婆吧,他生成了一百多隻明日香
深度生成模型如何工作?一起去看看VAE和GAN背後的功與名!
GaN是5G最好選擇 手機端應用現實嗎?
還在腦補畫面?這款GAN能把故事畫出來
學界 | 用GAN自動生成法線貼圖,讓圖形設計更輕鬆
以合成假臉、假畫聞名的GAN很成熟了?那這些問題呢?
以為GAN只能「炮製假圖」?它還有這7種另類用途
超越GAN!OpenAI提出可逆生成模型,AI合成超逼真人像
谷歌大腦提出MaskGAN,可更好地實現文本生成
見過火系的暴鯉龍嗎?這個項目利用CycleGAN生成不同屬性神奇寶貝
GaN和SiC這幾大變化不得不看?
殺戮都市裡GANTZ是怎麼來的,你真知道?
GAN生成「照片級」 emoji!有人將扎克伯格做成了表情包