當前位置:
首頁 > 最新 > GAN生成對抗網路代碼解析

GAN生成對抗網路代碼解析

上一次我們寫了生成對抗網路的工作原理,俗話說,學而不編則罔,編而不學則殆,跑起來才能加深對演算法的理解。

在跑之前,首先需要先裝上python(廢話),tensorflow (https://www.tensorflow.org/install/),裝完之後試一下

看能不能出現「Hello, TensorFlow!」成功了就可以正式進入GAN了!

首先第一步是導入數據包,說明此前你應該已經預先安裝好了numpy和matplotlib,此外,我們訓練所用的資料庫MNIST (http://yann.lecun.com/exdb/mnist/)也能通過調用一個tensorflow的函數read_data_sets直接導入(不用手動下載)。

這一步如果成功的話,運行會出現:

就是MNIST被提取出來了,這4個文件依次是,訓練圖、訓練標籤、測試圖、測試標籤。

如果你想看看MNIST裡面的數據圖長啥樣,可以用next_batch來調用一個batch用PyPlot來看看:

由於隨機調用,所以每次運行會出現的數字都不同:

接下來我們就可以開始構建判別器Discriminator和生成器Generator了。

在這裡有一個小tip要注意就是最開頭那段if語句,因為小姐姐沒加它的時候出現報錯,意思是參數不能重複使用,可是怎麼可能,這些權重參數下面都要反覆用來計算,所以需要把你要重複使用的參數都包括進去才行繼續往下跑(生成器同此)。

生成器就可以看作是反卷積的過程,判別器輸入2維或者3維的像素矩陣,輸出一個概率,生成器則是反過來——將一個多維的噪音向量輸出為一張28*28的像素圖(但其實是28*28*1,因為數字只有灰度1一個維度,但是tensorflow通常能處理3通道的RGB像素圖),可以看到最後一層加了tf.sigmoid激活函數,它的作用是將灰度轉化成黑或者白來輸出圖像。

生成器構造完了,我們來看看未經訓練之前的噪音圖像什麼樣子的:

看一下輸出,噪音真的就是噪音本尊:

好,終於要開始訓練了,論文告訴我們說,判別器和生成器各自有自己的損失函數,我們同時訓練它們,使得生成器能生成更像正樣本的圖片,並使得判別器能更準確地判斷出正負樣本圖片。

然後構造優化器,這裡使用的是Adam梯度下降,同時讓判別器對真實圖和噪音圖分別訓練,以方便分別調整步長:

最後,GAN初版不好訓練(常常崩),可以用tensorboard同時監控損失函數和訓練圖的變化,甚至還能畫出神經網路的拓撲結構:

使用方法就是先在終端激活tensorflow

然後運行命令

再然後打開網頁

http://localhost:6000

就能看到了。

好,萬事俱備,只欠訓練。友情提示,訓練要很久(GPU至少3小時,CPU。。。30個小時吧),而且容易崩(GAN本身的結構決定的),祝福你成功。

開始訓練的界面是這樣子的:

Tensorboard裡面也能看到損失在逐步下降:

最後的最後,其實可以下載訓練好的權重來測試 (https://github.com/jonbruner/generative-adversarial-networks/blob/master/pretrained-model/pretrained_gan.ckpt),把它下載放在你的本地文件夾里,然後運行:

就能看到美麗的小數字們啦,啦啦啦啦。


喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 夏至又一年 的精彩文章:

TAG:夏至又一年 |