當前位置:
首頁 > 知識 > 生成式對抗網路如何快速理解?

生成式對抗網路如何快速理解?

導讀:生成式對抗網路(GAN)是一個最新的研究領域,主要用在圖像技術方面的圖像生成和自然語言方面的生成式對話內容。簡單說:就是機器可以根據需要生成新的圖像和對話內容。

讓我們假設這樣一種情景:你的鄰居正在舉辦一場非常酷的聚會,你非常想去參加。但有要參加聚會的話,你需要一張特價票,而這個票早就已經賣完了。

而對於這次聚會的組織者來說,為了讓聚會能夠成功舉辦,他們僱傭了一個合格的安全機構。主要目標就是不允許任何人破壞這次的聚會。為了做到這一點,他們在會場入口處安置了很多警衛,檢查每個人所持門票的真實性。

考慮到你沒有任何武術上的天賦,而你又特別想去參加聚會,那麼唯一的辦法就是用一張非常有說服力的假票來騙他們。但是這個計劃存在一個很大的bug——你從來沒有真正看到過這張門票到底是什麼樣的。所以,在這種情況下,如果你僅是根據自己的創造力設計了一張門票,那麼在第一次嘗試期間就想要騙過警衛幾乎是不可能的。除此之外,除非你有一個很好的關於此次聚會的門票的複印件,否則你最好不要把你的臉展露出來。

為了幫助解決問題,你決定打電話給你的朋友Bob為你做這個工作。Bob的任務非常簡單。他會試圖用你的假通行證進入聚會。如果他被拒絕了,他將返回,然後告訴你一些有關真正的門票應該是什麼樣的建議。

基於這個反饋,你可以製作一張全新版本的門票,然後將其交給Bob,再去檢票處嘗試一下。這個過程不斷重複,直到你能夠設計一個完美的門票「複製品」。

這是一個必須去的派對。而下面這張照片,其實是我其實從一個假票據生成器器網站上拿到的。

對於上面這個小故事,拋開裡面的假想成分,這幾乎就是生成對抗網路(GAN)的工作方式。

目前,生成對抗網路的大部分應用都是在計算機視覺領域。其中一些應用包括訓練半監督分類器,以及從低解析度圖像中生成高解析度圖像。

本篇文章對GAN進行了一些介紹,並對圖像生成問題進行了實際實踐。你可以在你的筆記本電腦上進行演示。

生成對抗網路(Generative Adversarial Networks)

生成對抗網路框架

GAN是由Goodfellow等人於2014年設計的生成模型。在GAN設置中,兩個由神經網路進行表示的可微函數被鎖定在一個遊戲中。這兩個參與者(生成器和鑒別器)在這個框架中要扮演不同的角色。

生成器試圖生成來自某種概率分布的數據。即你想重新生成一張聚會的門票。

鑒別器就像一個法官。它可以決定輸入是來自生成器還是來自真正的訓練集。這就像是聚會中的安保設置,比將你的假票和這正的門票進行比較,以找到你的設計中存在的缺陷。

我們將一個4層卷積網路用於生成器和鑒別器,進行批量正則化。對該模型進行訓練以生成SVHN和MNIST圖像。以上是訓練期間SVHN(上)和MNIST(下)生成器樣本

總而言之,遊戲如下:

?生成器試圖最大化鑒別器將其輸入錯認為正確的的概率。

?鑒別器引導生成器生成更逼真的圖像。

在完美的平衡狀態中,生成器將捕獲通用的訓練數據分布。結果,鑒別器總是不確定其輸入是否是真實的。

摘自DCGAN論文。生成器網路在這裡實現。注意:完全連接層和池化層的不存在

在DCGAN論文中,作者描述了一些深度學習技術的組合,它們是訓練GAN的關鍵。這些技術包括:(i)所有的卷積網路;(ii)批量正則化(BN)。

第一個強調的重點是帶步幅的卷積(strided convolutions),而不是池化層:增加和減少特徵的空間維度;第二個是,對特徵向量進行正則化以使其在所有層中具有零均值和單位方差。這有助於穩定學習和處理權重不佳的初始化問題。

言歸正傳,在這裡闡述一下實施細節,以及GAN的相關知識。我們提出了深度卷積生成對抗網路(DCGAN)的實現。我們的實現使用的是Tensorflow並遵循DCGAN論文中描述的一些實踐方法。

生成器

該網路有4個卷積層,所有的位於BN(輸出層除外)和校正線性單元(ReLU)激活之後。

它將隨機向量z(從正態分布中抽取)作為輸入。將z重塑為4D形狀之後,將其饋送到啟動一系列上採樣層的生成器中。

每個上採樣層都代表一個步幅為2的轉置卷積(Transpose convolution)運算。轉置卷積與常規卷積類似。

一般來說,常規卷積從寬且淺的層延展為更窄、更深的層。轉移卷積走另一條路。他們從深而窄的層次走向更寬更淺。

轉置卷積運算的步幅定義了輸出層的大小。在「相同」的填充和步幅為2時,輸出特徵的大小將是輸入層的兩倍。

發生這種情況的原因是,每次我們移動輸入層中的一個像素時,我們都會將輸出層上的卷積內核移動兩個像素。換句話說,輸入圖像中的每個像素都用於在輸出圖像中繪製一個正方形。

將一個3x3的內核在一個步幅為2的2x2輸入上進行轉置,就相當於將一個3x3的內核在一個步幅為2的5x5輸入上進行卷積運算。對於二者,均不使用填充「有效」

簡而言之,生成器開始於這個非常深但很窄的輸入向量開始。在每次轉置卷積之後,z變得更寬、更淺。所有的轉置卷積都使用5x5內核的大小,且深度從512減少到3——代表RGB彩色圖像。

def transpose_conv2d(x, output_space):

kernel_size=5, strides=2, padding="same",

kernel_initializer=tf.random_normal_initializer(mean=0.0,

stddev=0.02))

最後一層通過雙曲正切(tanh)函數輸出一個32x32x3的張量——值在-1和1之間進行壓縮。

這個最終的輸出形狀是由訓練圖像的大小來定義的。在這種情況下,如果是用於SVHN的訓練,生成器生成32x32x3的圖像。但是,如果是用於MNIST的訓練,則會生成28x28的灰度圖像。

最後,請注意,在將輸入向量z饋送到生成器之前,我們需要將其縮放到-1到1的區間。這是遵循使用tanh函數的選擇。

def generator(z, output_dim, reuse=False, alpha=0.2, training=True):

"""

Defines the generator network

:param z: input random vector z

:param output_dim: output dimension of the network

:param reuse: Indicates whether or not the existing model variables should be used or recreated

:param alpha: scalar for lrelu activation function

:param training: Boolean for controlling the batch normalization statistics

:return: model"s output

"""

with tf.variable_scope("generator", reuse=reuse):

fc1 = dense(z, 4*4*512)

# Reshape it to start the convolutional stack

fc1 = tf.reshape(fc1, (-1, 4, 4, 512))

fc1 = batch_norm(fc1, training=training)

t_conv1 = transpose_conv2d(fc1, 256)

t_conv1 = batch_norm(t_conv1, training=training)

t_conv2 = transpose_conv2d(t_conv1, 128)

t_conv2 = batch_norm(t_conv2, training=training)

logits = transpose_conv2d(t_conv2, output_dim)

out = tf.tanh(logits)

return out

鑒別器

鑒別器也是一個包含BN(除了其輸入層之外)和leaky ReLU激活的4層CNN。許多激活函數都可以在這種基礎GAN體系結構中進行良好的運算。但是leaky ReLUs有著非常廣泛的應用,因為它們可以幫助梯度在結構中更輕易地流動。

常規的RELU函數通過將負值截斷為0來工作。這樣做的效果是阻止梯度流通過網路。leaky ReLU允許一個小負值通過,而非要求函數為0。也就是說,函數用來計算特徵與小因素之間的最大值。

def lrelu(x, alpha=0.2):

# non-linear activation function

return tf.maximum(alpha * x, x)

leaky ReLU表示了一種解決崩潰邊緣ReLU問題的嘗試。這種情況發生在神經元陷於某一特定情況下,此時ReLU單元對於任何輸入都輸出0。對於這些情況,梯度完全關閉以通過網路迴流。

這對於GAN來說尤為重要,因為生成器必須學習的唯一方法是接受來自鑒別器的梯度。

(上)ReLU,(下)leaky ReLU激活函數。 請注意,當x為負值時, leaky ReLU允許有一個小的斜率

這個鑒別器首先接收一個32x32x3的圖像張量。與生成器相反的是,鑒別器執行一系列步幅為2的卷積。每一種方法都是通過將特徵向量的空間維度縮小一半,從而使學習過濾器的數量加倍。

最後,鑒別器需要輸出概率。為此,我們在最後的邏輯(logits)上使用Logistic Sigmoid激活函數。

def discriminator(x, reuse=False, alpha=0.2, training=True):

"""

Defines the discriminator network

:param x: input for network

:param reuse: Indicates whether or not the existing model variables should be used or recreated

:param alpha: scalar for lrelu activation function

:param training: Boolean for controlling the batch normalization statistics

:return: A tuple of (sigmoid probabilities, logits)

"""

with tf.variable_scope("discriminator", reuse=reuse):

# Input layer is 32x32x?

conv1 = conv2d(x, 64)

conv1 = lrelu(conv1, alpha)

conv2 = conv2d(conv1, 128)

conv2 = batch_norm(conv2, training=training)

conv2 = lrelu(conv2, alpha)

conv3 = conv2d(conv2, 256)

conv3 = batch_norm(conv3, training=training)

conv3 = lrelu(conv3, alpha)

# Flatten it

flat = tf.reshape(conv3, (-1, 4*4*256))

logits = dense(flat, 1)

out = tf.sigmoid(logits)

return out, logits

需要注意的是,在這個框架中,鑒別器充當一個常規的二進位分類器。一半的時間從訓練集接收圖像,另一半時間從生成器接收圖像。

回到我們的故事中,為了複製聚會的票,你唯一的信息來源是朋友Bob的反饋。換言之,Bob在每次嘗試期間向你提供的反饋的質量對於完成工作至關重要。

同樣的,每次鑒別器注意到真實圖像和虛假圖像之間的差異時,都會向生成器發送一個信號。該信號是從鑒別器向生成器反向流動的梯度。通過接收它,生成器能夠調整其參數以接近真實的數據分布。

這就是鑒別器的重要性所在。實際上,生成器將要儘可能好地產生數據,因為鑒別器正在不斷地縮小真實和虛假數據的差距。

損失

現在,讓我們來描述這一結構中最棘手的部分——損失。首先,我們知道鑒別器收集來自訓練集和生成器的圖像。

我們希望鑒別器能夠區分真實和虛假的圖像。我們每次通過鑒別器運行一個小批量(mini-batch)的時候,都會得到邏輯(logits)。這些是來自模型的未縮放值(unscaled values)。然而,我們可以將鑒別器接收的小批量(mini-batches)分成兩種類型。第一種類型只由來自訓練集的真實圖像組成,第二種類型只包含由生成器生成的假圖像。

def model_loss(input_real, input_z, output_dim, alpha=0.2, smooth=0.1):

"""

Get the loss for the discriminator and generator

:param input_real: Images from the real dataset

:param input_z: random vector z

:param out_channel_dim: The number of channels in the output image

:param smooth: label smothing scalar

:return: A tuple of (discriminator loss, generator loss)

"""

g_model = generator(input_z, output_dim, alpha=alpha)

d_model_real, d_logits_real = discriminator(input_real, alpha=alpha)

d_model_fake, d_logits_fake = discriminator(g_model, reuse=True, alpha=alpha)

# for the real images, we want them to be classified as positives,

# so we want their labels to be all ones.

# notice here we use label smoothing for helping the discriminator to generalize better.

# Label smoothing works by avoiding the classifier to make extreme predictions when extrapolating.

d_loss_real = tf.reduce_mean(

# for the fake images produced by the generator, we want the discriminator to clissify them as false images,

# so we set their labels to be all zeros.

d_loss_fake = tf.reduce_mean(

# since the generator wants the discriminator to output 1s for its images, it uses the discriminator logits for the

# fake images and assign labels of 1s to them.

g_loss = tf.reduce_mean(

d_loss = d_loss_real + d_loss_fake

return d_loss, g_loss

由於兩個網路同時進行訓練,因此GAN需要兩個優化器。它們分別用於最小化鑒別器和發生器的損失函數。

我們希望鑒別器輸出真實圖像的概率接近於1,輸出假圖像的概率接近於0。要做到這一點,鑒別器需要兩部分損失。因此,鑒別器的總損失是這兩部分損失之和。其中一部分損失用於將真實圖像的概率最大化,另一部分損失用於將假圖像的概率最小化。

比較真實(左)和生成的(右)SVHN樣本圖像。雖然有些圖像看起來很模糊,且有些圖像很難識別,但值得注意的是,數據分布是由模型捕獲的

在訓練開始的時候,會出現兩個有趣的情況。首先,生成器不清楚如何創建與訓練集中圖像相似的圖像。其次,鑒別器不清楚如何將接收到的圖像分為真、假兩類。

結果,鑒別器接收兩種類型截然不同的批量(batches)。一個由訓練集的真實圖像組成,另一個包含含有雜訊的信號。隨著訓練的不斷進行,生成器輸出的圖像更加接近於訓練集中的圖像。這種情況是由生成器學習組成訓練集圖像的數據分布而造成的。

與此同時,鑒別器開始真正善於將樣本分類為真或假。結果,這兩種小批量(mini-batch)在結構上開始相互類似。因此,鑒別器無法識別出真實或虛假的圖像。

對於損失,我們認為,使用具有Adam演算法的vanilla交叉熵(vanilla cross-entropy)作為優化器是一個不錯的選擇。

比較real(左)和Generate(右)MNIST示例圖像。由於MNIST圖像具有更簡單的數據結構,因此與SVHN相比,該模型能夠生成更真實的樣本

目前,GANs是機器學習中最熱門的學科之一。這些模型具有解開無監督學習方法(unsupervised learning methods)的潛力,並且可以將ML拓展到新領域。

自從GANs誕生以來,研究人員開發了許多用於訓練GANs的技術。在改進過後的GANs訓練技術中,作者描述了圖像生成(image generation)和半監督學習(semi-supervised learning)的最新技術。


喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 AI講堂 的精彩文章:

機器學習必知的8大神經網路架構和原理

TAG:AI講堂 |