GAN生成對抗網路代碼解析
上一次我們寫了生成對抗網路的工作原理,俗話說,學而不編則罔,編而不學則殆,跑起來才能加深對演算法的理解。
在跑之前,首先需要先裝上python(廢話),tensorflow (https://www.tensorflow.org/install/),裝完之後試一下
看能不能出現「Hello, TensorFlow!」成功了就可以正式進入GAN了!
首先第一步是導入數據包,說明此前你應該已經預先安裝好了numpy和matplotlib,此外,我們訓練所用的資料庫MNIST (http://yann.lecun.com/exdb/mnist/)也能通過調用一個tensorflow的函數read_data_sets直接導入(不用手動下載)。
這一步如果成功的話,運行會出現:
就是MNIST被提取出來了,這4個文件依次是,訓練圖、訓練標籤、測試圖、測試標籤。
如果你想看看MNIST裡面的數據圖長啥樣,可以用next_batch來調用一個batch用PyPlot來看看:
由於隨機調用,所以每次運行會出現的數字都不同:
接下來我們就可以開始構建判別器Discriminator和生成器Generator了。
在這裡有一個小tip要注意就是最開頭那段if語句,因為小姐姐沒加它的時候出現報錯,意思是參數不能重複使用,可是怎麼可能,這些權重參數下面都要反覆用來計算,所以需要把你要重複使用的參數都包括進去才行繼續往下跑(生成器同此)。
生成器就可以看作是反卷積的過程,判別器輸入2維或者3維的像素矩陣,輸出一個概率,生成器則是反過來——將一個多維的噪音向量輸出為一張28*28的像素圖(但其實是28*28*1,因為數字只有灰度1一個維度,但是tensorflow通常能處理3通道的RGB像素圖),可以看到最後一層加了tf.sigmoid激活函數,它的作用是將灰度轉化成黑或者白來輸出圖像。
生成器構造完了,我們來看看未經訓練之前的噪音圖像什麼樣子的:
看一下輸出,噪音真的就是噪音本尊:
好,終於要開始訓練了,論文告訴我們說,判別器和生成器各自有自己的損失函數,我們同時訓練它們,使得生成器能生成更像正樣本的圖片,並使得判別器能更準確地判斷出正負樣本圖片。
然後構造優化器,這裡使用的是Adam梯度下降,同時讓判別器對真實圖和噪音圖分別訓練,以方便分別調整步長:
最後,GAN初版不好訓練(常常崩),可以用tensorboard同時監控損失函數和訓練圖的變化,甚至還能畫出神經網路的拓撲結構:
使用方法就是先在終端激活tensorflow
然後運行命令
再然後打開網頁
http://localhost:6000
就能看到了。
好,萬事俱備,只欠訓練。友情提示,訓練要很久(GPU至少3小時,CPU。。。30個小時吧),而且容易崩(GAN本身的結構決定的),祝福你成功。
開始訓練的界面是這樣子的:
Tensorboard裡面也能看到損失在逐步下降:
最後的最後,其實可以下載訓練好的權重來測試 (https://github.com/jonbruner/generative-adversarial-networks/blob/master/pretrained-model/pretrained_gan.ckpt),把它下載放在你的本地文件夾里,然後運行:
就能看到美麗的小數字們啦,啦啦啦啦。


TAG:夏至又一年 |