當前位置:
首頁 > 知識 > 在 PyTorch 上跑 GAN 只需要 50 行代碼,不試試?

在 PyTorch 上跑 GAN 只需要 50 行代碼,不試試?

「TensorFlow & 神經網路演算法高級應用班」 要開課啦!


從初級到高級,理論 + 實戰,一站式深度了解 TensorFlow!


本課程面向深度學習開發者,講授如何利用 TensorFlow 解決圖像識別、文本分析等具體問題。課程跨度為 10 周,將從 TensorFlow 的原理與基礎實戰技巧開始,一步步教授學員如何在 TensorFlow 上搭建 CNN、自編碼、RNN、GAN 等模型,並最終掌握一整套基於 TensorFlow 做深度學習開發的專業技能。


兩名授課老師佟達、白髮川身為 ThoughtWorks 的資深技術專家,具有豐富的大數據平台搭建、深度學習系統開發項目經驗。


「TensorFlow & 神經網路演算法高級應用班」 要開課啦!

從初級到高級,理論 + 實戰,一站式深度了解 TensorFlow!


本課程面向深度學習開發者,講授如何利用 TensorFlow 解決圖像識別、文本分析等具體問題。課程跨度為 10 周,將從 TensorFlow 的原理與基礎實戰技巧開始,一步步教授學員如何在 TensorFlow 上搭建 CNN、自編碼、RNN、GAN 等模型,並最終掌握一整套基於 TensorFlow 做深度學習開發的專業技能。


兩名授課老師佟達、白髮川身為 ThoughtWorks 的資深技術專家,具有豐富的大數據平台搭建、深度學習系統開發項目經驗。

在 PyTorch 上跑 GAN 只需要 50 行代碼,不試試?



編者按:本文作者為前谷歌高級工程師、AI 初創公司 Wavefront 創始人兼 CTO Dev Nag,介紹了他是如何用不到五十行代碼,在 PyTorch 平台上完成對 GAN 的訓練。AI 研習社編譯整理。

在 PyTorch 上跑 GAN 只需要 50 行代碼,不試試?



Dev Nag


什麼是 GAN?

在進入技術層面之前,為照顧新入門的開發者,我們先來介紹下什麼是 GAN。


2014 年,Ian Goodfellow 和他在蒙特利爾大學的同事發表了一篇震撼學界的論文。沒錯,我說的就是《Generative Adversarial Nets》,這標誌著生成對抗網路(GAN)的誕生,而這是通過對計算圖和博弈論的創新性結合。他們的研究展示,給定充分的建模能力,兩個博弈模型能夠通過簡單的反向傳播(backpropagation)來協同訓練。


這兩個模型的角色定位十分鮮明。給定真實數據集 R,G 是生成器(generator),它的任務是生成能以假亂真的假數據;而 D 是判別器 (discriminator),它從真實數據集或者 G 那裡獲取數據, 然後做出判別真假的標記。Ian Goodfellow 的比喻是,G 就像一個贗品作坊,想要讓做出來的東西儘可能接近真品,矇混過關。而 D 就是文物鑒定專家,要能區分出真品和高仿(但在這個例子中,造假者 G 看不到原始數據,而只有 D 的鑒定結果——前者是在盲幹)。

在 PyTorch 上跑 GAN 只需要 50 行代碼,不試試?



理想情況下,D 和 G 都會隨著不斷訓練,做得越來越好——直到 G 基本上成為了一個「贗品製造大師」,而 D 因無法正確區分兩種數據分布輸給 G。


實踐中,Ian Goodfellow 展示的這項技術在本質上是:G 能夠對原始數據集進行一種無監督學習,找到以更低維度的方式(lower-dimensional manner)來表示數據的某種方法。而無監督學習之所以重要,就好像 AI 研習社反覆引用的 Yann LeCun 的那句話:「無監督學習是蛋糕的糕體」。這句話中的蛋糕,指的是無數學者、開發者苦苦追尋的「真正的 AI」。

在 PyTorch 上跑 GAN 只需要 50 行代碼,不試試?



上圖是 Yann LeCun 對 GAN 的讚揚,意為「GAN 是機器學習過去 10 年發展中最有意思的想法。」

用 PyTorch 訓練 GAN


Dev Nag:在表面上,GAN 這門如此強大、複雜的技術,看起來需要編寫天量的代碼來執行,但事實未必如此。我們使用 PyTorch,能夠在 50 行代碼以內創建出簡單的 GAN 模型。這之中,其實只有五個部分需要考慮:


R:原始、真實數據集


I:作為熵的一項來源,進入生成器的隨機噪音


G:生成器,試圖模仿原始數據


D:判別器,試圖區別 G 的生成數據和 R


我們教 G 糊弄 D、教 D 當心 G 的「訓練」環。


1.) R:在我們的例子里,從最簡單的 R 著手——貝爾曲線(bell curve)。它把平均數(mean)和標準差(standard deviation)作為輸入,然後輸出能提供樣本數據正確圖形(從 Gaussian 用這些參數獲得 )的函數。在我們的代碼例子中,我們使用 4 的平均數和 1.25 的標準差。


2.) I:生成器的輸入是隨機的,為提高點難度,我們使用均勻分布(uniform distribution )而非標準分布。這意味著,我們的 Model G 不能簡單地改變輸入(放大/縮小、平移)來複制 R,而需要用非線性的方式來改造數據。


3.) G: 該生成器是個標準的前饋圖(feedforward graph)——兩層隱層,三個線性映射(linear maps)。我們使用了 ELU (exponential linear unit)。G 將從 I 獲得平均分布的數據樣本,然後找到某種方式來模仿 R 中標準分布的樣本。

在 PyTorch 上跑 GAN 只需要 50 行代碼,不試試?



4.) D: 判別器的代碼和 G 的生成器代碼很接近。一個有兩層隱層和三個線性映射的前饋圖。它會從 R 或 G 那裡獲得樣本,然後輸出 0 或 1 的判別值,對應反例和正例。這幾乎是神經網路的最弱版本了。

在 PyTorch 上跑 GAN 只需要 50 行代碼,不試試?



5.) 最後,訓練環在兩個模式中變幻:第一步,用被準確標記的真實數據 vs. 假數據訓練 D;隨後,訓練 G 來騙過 D,這裡是用的不準確標記。道友們,這是正邪之間的較量。

在 PyTorch 上跑 GAN 只需要 50 行代碼,不試試?



即便你從沒接觸過 PyTorch,大概也能明白髮生了什麼。在第一部分(綠色),我們讓兩種類型的數據經過 D,並對 D 的猜測 vs. 真實標記執行不同的評判標準。這是 「forward」 那一步;隨後我們需要 「backward()」 來計算梯度,然後把這用來在 d_optimizer step() 中更新 D 的參數。這裡,G 被使用但尚未被訓練。


在最後的部分(紅色),我們對 G 執行同樣的操作——注意我們要讓 G 的輸出穿過 D (這其實是送給造假者一個鑒定專家來練手)。但在這一步,我們並不優化、或者改變 D。我們不想讓鑒定者 D 學習到錯誤的標記。因此,我們只執行 g_optimizer.step()。

這就完成了。據 AI 研習社了解,還有一些其他的樣板代碼,但是對於 GAN 來說只需要這五個部分,沒有其他的了。


在 D 和 G 之間幾千輪交手之後,我們會得到什麼?判別器 D 會快速改進,而 G 的進展要緩慢許多。但當模型達到一定性能之後,G 才有了個配得上的對手,並開始提升,巨幅提升。


兩萬輪訓練之後,G 的輸入平均值超過 4,但會返回到相當平穩、合理的範圍(左圖)。同樣的,標準差一開始在錯誤的方向降低,但隨後攀升至理想中的 1.25 區間(右圖),達到 R 的層次。

在 PyTorch 上跑 GAN 只需要 50 行代碼,不試試?



所以,基礎數據最終會與 R 吻合。那麼,那些比 R 更高的時候呢?數據分布的形狀看起來合理嗎?畢竟,你一定可以得到有 4.0 的平均值和 1.25 標準差值的均勻分布,但那不會真的符合 R。我們一起來看看 G 生成的最終分布。

在 PyTorch 上跑 GAN 只需要 50 行代碼,不試試?



結果是不錯的。左側的尾巴比右側長一些,但偏離程度和峰值與原始 Gaussian 十分相近。G 接近完美地再現了原始分布 R——D 落於下風,無法分辨真相和假相。而這就是我們想要得到的結果——使用不到 50 行代碼。


喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 唯物 的精彩文章:

PyTorch 特輯!網紅 5 分鐘帶你入門 PyTorch
英偉達深度學習學院現場教你如何實操深度學習,作為 AI 開發者的你不來看看?
CNN+TensorFlow 就能教機器人作曲!
這些關於TensorFlow問題的解答,你不能錯過

TAG:唯物 |

您可能感興趣

Jeff Dean推薦:用TPU跑Julia程序,只需不到1000行代碼
5行代碼秀碾壓,比Keras還好用的fastai來了,嘗鮮PyTorch 1.0必備伴侶
將 30 萬行代碼從 Flow 遷移到 TypeScript 是一種怎樣的體驗?
Pandas on Ray:僅需改動一行代碼,即可讓Pandas加速四倍
Oracle開源GraphPipe:幾行代碼讓你在TensorFlow部署PyTorch模型
資源 | Pandas on Ray:僅需改動一行代碼,即可讓Pandas加速四倍
PyTorch代碼調試利器:自動print每行代碼的Tensor信息
利用PHPstorm進行代碼review
GitHub發布GitHub Actions平台,可直接運行代碼
iOS被曝新漏洞:15行代碼讓iPhone崩潰
iOS被爆新漏洞:只需15行代碼就可以讓你的iPhone崩潰重啟!
Reddit熱議:只要2行代碼,免費開源ML管理工具TRAINS
Linux 將不再支持舊 CPU 架構,可節省 50 萬行代碼
黑進iPhone讓手機崩潰重啟,只需15行代碼:iOS漏洞你可知?
Python10行代碼就可以搞定的-對象檢測,你值得擁有
一行代碼讓 Python 的運行速度提高100倍
一行代碼切換TensorFlow與PyTorch,模型訓練也能用倆框架
TP-Link TL-WR740N存在嚴重漏洞,允許攻擊者在設備上遠程執行代碼
想要千行代碼搞定Transformer?這份高效的PaddlePaddle官方實現請收下
批量下載網頁圖片,python只需23行代碼