GANs為何引爆機器學習?這是篇基於TensorFlow教程
GIF/1.7M
「機器人圈導覽」: 生成對抗網路無疑是機器學習領域近三年來最火爆的研究領域,相關論文層出不求,各種領域的應用層出不窮。那麼,GAN到底如何實踐?本文編譯自Medium,該文作者以一朵玫瑰花為例,詳細闡述了GAN的原理,以及基於谷歌TensorFlow的實現,文章略長,閱讀大約需要15分鐘。
想像有一天,我們可以利用一個神經網路觀看電影並製作自己的電影,或者聽歌和創作歌曲。神經網路將從它看到的內容中學習,而且你並不需要明確地告訴它,這種使神經網路學習的方式被稱為無監督學習。
實際上,以無監督的方式訓練的GAN(生成對抗網路)在過去三年中獲得了極大的關注,被認為是目前AI領域最熱門的話題之一。就像Facebook AI的主管Yann LeCun認為的那樣:
生成對抗網路是機器學習過去十年最有趣的想法。
GAN是理想的神經網路,它在看到某些圖像後生成新圖像。那麼,這可以用來做什麼?為什麼這很重要?
生成的卧室圖像
直到最近,神經網路(特別是卷積神經網路)只擅長分類任務,如在貓和狗、飛機和汽車之間進行分類。但現在,他們可以用來生成圖片的貓或狗(即使它們看起來很奇怪),這告訴我們他們已經學會記住特徵。
GAN的這種非凡的能力可以應用於許多驚人的應用程序,如:
?生成給定文本描述的圖像。
文本到圖像
點擊此處鏈接以了解更多信息:The major advancements in Deep Learning in 2016
?圖像到圖像的翻譯:
這可能是GAN最酷的應用。圖像到圖像翻譯可用於很多場景,比如從草圖生成逼真的圖像,將白天拍攝的圖像轉換為夜間圖像,甚至將黑白圖像轉換為彩色圖像。
查看此鏈接了解更多詳情:基於條件抗網路的圖像到圖像的翻譯
讓我們了解一下GAN有什麼能力讓所有人都對齊大肆吹捧。我們用一個簡單的GAN實例生成玫瑰圖像。
我們來看看GAN到底是什麼?
在我們開始構建GAN之前,我們可以了解它的工作原理。 生成對抗網路包含兩個神經網路,一個鑒別器和一個生成器。鑒別器是一個卷積神經網路(CNN)(不知道CNN是什麼?請看這個帖子),學習區分真實和假的圖像。真實的圖像是從資料庫中獲取的,而假的圖像來自生成器。
鑒別器
生成器的工作原理就像CNN反向運行,它將一個隨機數向量作為輸入,並在輸出端生成一個圖像。
生成器
稍後我們將介紹生成器和鑒別器的工作和實現,但現在我們通過一個非常有名的實例來解釋GAN(濫用生成對抗網路生成8位像素藝術)。
我們可以把生成器比作一個偽造者,而把鑒別器視作一個警察,他必須從兩枚貨幣中區分真假。在最開始的時候,我們要確保偽造貨幣的偽造者和警察都是同樣不擅長他們的工作的。因此,偽造者首先生成一些隨機的雜訊圖像。
偽造者產生的雜訊圖像
現在警察接受訓練來區分偽造者產生假的的圖像和真實的貨幣。
訓練警察
偽造者現在已經知道它的圖像已被歸類為「假」,而警察正在尋找貨幣所具有的一些獨特的特徵(如顏色和圖案)。偽造者現在在學習這些特徵,並生成具有這些特徵的圖像。
訓練偽造者
現在,警察再次區分出數據集中的出真正貨幣和來自偽造者新改進的圖像,並要求對它們進行分類,因此,該警察將會學到更多的關於真實圖像的特徵(如貨幣的表面特徵)。
用新的虛假圖像來訓練警察
而偽造者再次學習這些特徵,併產生更好看的假圖像。
再次訓練偽造者
偽造者和警察之間的這場拉鋸戰將一直持續,直到偽造者生成的圖像看起來與真實的圖像完全相同,而且警察將無法對其進行分類。
真假難辨
在Tensorflow上生成玫瑰花
我們只用tensorflow而不用其它(除了pillow)來構建一個簡單的DCGAN(深度卷積生成對抗式網路)。那麼,DCGAN是什麼呢?
DCGAN是普通GAN的一個修改版本,以解決普通GAN所涵蓋的一些難題,例如:使偽造的圖像視覺上看起來比較滿意,通過反覆輸出符合鑒別器正在尋找的但不在實際圖像附近的數據分布的圖像,在訓練過程中提高穩定性,從而使發生器不會在鑒別器中找到缺陷。
下圖就是我們正在嘗試去構建的鑒別器架構:
鑒別器架構
可以看出,它將圖像作為輸入並輸出一個logit(1為真類,0為偽類)。
接下來,我們用一個生成器架構,它由conv_transpose層組成,它們將一組隨機數作為輸入,並在輸出端生成一個圖像。
生成器架構
DCGAN可直接產生這篇論文中提到的變化:
?用分段卷積(鑒別器)和分數階卷積(生成器)替換任何合并層。
?在發生器和鑒別器中使用batchnorm。
?刪除完全連接的隱藏層以進行更深層次的體系結構。
?在除了使用Tanh的輸出之外的所有圖層中的生成器中使用ReLU激活函數。
?在所有層的鑒別器中使用LeakyReLU激活函數。
我們首先需要收集玫瑰圖像。一個簡單方法就是在Google上進行玫瑰圖像搜索,並使用諸如ImageSpark這樣的Chrome插件下載搜索結果中的所有圖像。
我們收集了67張圖片(當然是越多越好)並在這裡可用。在以下目錄中提取這些圖像:
/Dataset/Roses。
點擊鏈接獲取更多信息:GANs_N_Roses
既然我們已經有了圖像,下一步就是通過將它們重構為64 * 64,並將其縮放值設置為-1和1之間,以預處理這些圖像。
def load_dataset(path, data_set="birds", image_size=64):
"""
Loads the images from the specified path
:param path: string indicating the dataset path.
:param data_set: "birds" -> loads data from birds directory, "flowers" -> loads data from the flowers directory.
:param image_size: size of images in the returned array
:return: numpy array, shape : [number of images, image_size, image_size, 3]
"""
all_dirs = os.listdir(path)
image_dirs = [i for i in all_dirs if i.endswith(".jpg") or i.endswith(".jpeg") or i.endswith(".png")]
number_of_images = len(image_dirs)
images = []
print("{} images are being loaded...".format(data_set[:-1]))
for c, i in enumerate(image_dirs):
images.append(np.array(ImageOps.fit(Image.open(path + "/" + i),
(image_size, image_size), Image.ANTIALIAS))/127.5 - 1.)
.format(c + 1, number_of_images))
print("
")
images = np.reshape(images, [-1, image_size, image_size, 3])
return images.astype(np.float32)
首先,我們寫出可用於執行卷積、卷積轉置、緻密完全連接層和LeakyReLU激活(因為它在Tensorflow上不可用)的函數。
def conv2d(x, inputFeatures, outputFeatures, name):
with tf.variable_scope(name):
w = tf.get_variable("w", [5, 5, inputFeatures, outputFeatures],
initializer=tf.truncated_normal_initializer(stddev=0.02))
b = tf.get_variable("b", [outputFeatures], initializer=tf.constant_initializer(0.0))
return conv
實現卷積層的函數
我們使用get_variable()而不是通常的Variable(),在tensorflow上創建一個變數,以便以後在不同的函數調用之間共享權重和偏差。 查看這篇文章了解更多有關共享變數的信息。
def conv_transpose(x, outputShape, name):
with tf.variable_scope(name):
w = tf.get_variable("w", [5, 5, outputShape[-1], x.get_shape()[-1]],
initializer=tf.truncated_normal_initializer(stddev=0.02))
b = tf.get_variable("b", [outputShape[-1]], initializer=tf.constant_initializer(0.0))
return convt
實現卷積轉置的函數
# fully-conected layer
def dense(x, inputFeatures, outputFeatures, scope=None, with_w=False):
with tf.variable_scope(scope or "Linear"):
matrix = tf.get_variable("Matrix", [inputFeatures, outputFeatures], tf.float32,
tf.random_normal_initializer(stddev=0.02))
bias = tf.get_variable("bias", [outputFeatures], initializer=tf.constant_initializer(0.0))
if with_w:
return tf.matmul(x, matrix) + bias, matrix, bias
else:
return tf.matmul(x, matrix) + bias
實現緻密完全連接層的函數
def lrelu(x, leak=0.2, name="lrelu"):
with tf.variable_scope(name):
f1 = 0.5 * (1 + leak)
f2 = 0.5 * (1 - leak)
return f1 * x + f2 * abs(x)
Leaky ReLU
下一步是構建生成器和鑒別器。我們先從主角—生成器開始。我們需要構建的生成器架構如下所示:
我們又一次試圖實現的生成器架構
def generator(z, z_dim):
"""
Used to generate fake images to fool the discriminator.
:param z: The input random noise.
:param z_dim: The dimension of the input noise.
:return: Fake images -> [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3]
"""
gf_dim = 64
z2 = dense(z, z_dim, gf_dim * 8 * 4 * 4, scope="g_h0_lin")
center=True, scale=True, is_training=True, scope="g_bn1"))
center=True, scale=True, is_training=True, scope="g_bn2"))
center=True, scale=True, is_training=True, scope="g_bn3"))
center=True, scale=True, is_training=True, scope="g_bn4"))
h4 = conv_transpose(h3, [mc.BATCH_SIZE, 64, 64, 3], "g_h4")
generator()函數使用上圖中的體系架構構建一個生成器。諸如除去所有完全連接層,僅在發生器上使用ReLU以及使用批量歸一化,這些任務DCGAN要求已經達標。
類似地,鑒別器也可以很容易地構造成如下圖所示:
所需架構:
鑒別器架構
def discriminator(image, reuse=False):
"""
Used to distinguish between real and fake images.
:param image: Images feed to the discriminate.
:param reuse: Set this to True to allow the weights to be reused.
:return: A logits value.
"""
df_dim = 64
if reuse:
tf.get_variable_scope().reuse_variables()
h0 = lrelu(conv2d(image, 3, df_dim, name="d_h0_conv"))
h1 = lrelu(batch_norm(conv2d(h0, df_dim, df_dim * 2, name="d_h1_conv"),
center=True, scale=True, is_training=True, scope="d_bn1"))
h2 = lrelu(batch_norm(conv2d(h1, df_dim * 2, df_dim * 4, name="d_h2_conv"),
center=True, scale=True, is_training=True, scope="d_bn2"))
h3 = lrelu(batch_norm(conv2d(h2, df_dim * 4, df_dim * 8, name="d_h3_conv"),
center=True, scale=True, is_training=True, scope="d_bn3"))
h4 = dense(tf.reshape(h3, [-1, 4 * 4 * df_dim * 8]), 4 * 4 * df_dim * 8, 1, scope="d_h3_lin")
return h4
我們再次避免了密集的完全連接的層,使用了Leaky ReLU,並在Discriminator處進行了批處理。
下面到了有趣的部分,我們要訓練這些網路:
鑒別器和發生器的損耗函數如下所示:
鑒別器損耗函數
生成器損耗函數
G = generator(zin, z_dim) # G(z)
Dx = discriminator(images) # D(x)
Dg = discriminator(G, reuse=True) # D(G(x))
我們將隨機輸入傳遞給發生器,輸入阻抗為[BATCH_SIZE,Z_DIM],生成器現在應該在其輸出端給出BATCH_SIZE偽圖像數。生成器輸出的大小現在將為[BATCH_SIZE,IMAGE_SIZE,IMAGE_SIZE,3]。
D(x)是識別真實圖像或虛假圖像,進行訓練以便區分它們的鑒別器。為了在真實圖像上訓練鑒別器,我們將真實圖像批處理傳遞給D(x),並將目標設置為1。類似地,想要在來自生成器的假圖像上對其進行訓練的話,我們將使用D(g)將生成器的輸出連接到鑒別器的輸入上。D的損失是使用tensorflow的內置函數實現的:
dloss = d_loss_real + d_loss_fake
我們接下來需要訓練生成器,以便D(g)將輸出為1,即我們將修正鑒別器上的權重,並且僅在生成器權重上返回,以便鑒別器總是輸出為1。
因此,生成器的損耗函數為:
接下來我們將收集鑒別器和生成器的所有權重(以後需要對發生器或判別器進行訓練):
# Get the variables which need to be trained
t_vars = tf.trainable_variables()
d_vars = [var for var in t_vars if "d_" in var.name]
g_vars = [var for var in t_vars if "g_" in var.name]
我們使用tensorflow的AdamOptimizer來學習權重。接下來我們將需要修改的權重傳遞給鑒別器和生成器的優化器。
with tf.variable_scope(tf.get_variable_scope(), reuse=False) as scope:
d_optim = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(dloss, var_list=d_vars)
g_optim = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(gloss, var_list=g_vars)
最後一步是運行會話並將所需的圖像批處理傳遞給優化器。我們將對這個模型進行30000次迭代訓練,並定期顯示鑒別器和發生器損耗。
with tf.Session() as sess:
tf.global_variables_initializer().run()
writer = tf.summary.FileWriter(logdir=logdir, graph=sess.graph)
if not load:
for idx in range(n_iter):
batch_images = next_batch(real_img, batch_size=batch_size)
for k in range(1):
sess.run([d_optim], feed_dict=)
sess.run([g_optim], feed_dict=)
print("[%4d/%4d] time: %4.4f, " % (idx, n_iter, time.time() - start_time))
if idx % 10 == 0:
# Display the loss and run tf summaries
summary = sess.run(summary_op, feed_dict=)
writer.add_summary(summary, global_step=idx)
d_loss = d_loss_fake.eval()
g_loss = gloss.eval()
print("
Discriminator loss:
Generator loss:
".format(d_loss, g_loss))
if idx % 1000 == 0:
# Save the model after every 1000 iternations
saver.save(sess, saved_models_path + "/train", global_step=idx)
為了簡化調整超參數,並在每次運行時保存結果,我們實現了form_results函數和mission_control.py文件。
該網路的所有超參數可以在mission_control.py文件中進行修改,之後運行的main.py文件將自動為每次運行創建文件夾,並保存資料庫文件和生成的圖像。
"""
Contains all the variables necessary to run gans_n_roses.py file.
"""
# Set LOAD to True to load a trained model or set it False to train a new one.
LOAD = False
# Dataset directories
DATASET_PATH = "./Dataset/Roses/"
DATASET_CHOSEN = "roses" # required by utils.py -> ["birds", "flowers", "black_birds"]
# Model hyperparameters
Z_DIM = 100 # The input noise vector dimension
BATCH_SIZE = 12
N_ITERATIONS = 30000
LEARNING_RATE = 0.0002
BETA_1 = 0.5
IMAGE_SIZE = 64 # Change the Generator model if the IMAGE_SIZE needs to be changed to a different value
我們可以通過打開tensorboard,在訓練期間的每次迭代中查鑒別器和生成器的損耗,並將其指向在每個運行文件夾下創建的Tensorboard目錄中。
訓練期間發生器損耗的變化
訓練期間鑒別器損耗的變化
從這些圖可以看出,在訓練階段,鑒別器和生成器損耗在不斷增加,表明生成器和鑒別器都試圖相互執行。
代碼還可以為每次運行保存生成的圖像,其中一些圖像如下所示:
在第0次迭代:
第100次迭代:
第1000次迭代:
圖像在第30000次迭代中被過度擬合:
訓練階段生成的圖像如圖所示:
這些圖像是有希望實現目標的,但是經過大約1000次迭代,可以看出,發生器只是從訓練數據集中再現圖像。我們可以使用較大的數據集,並進行較少數量的迭代訓練,以減少過度擬合。
GAN易於實現,但如果沒有正確的超參數和網路架構,是難以進行訓練的。我們寫這篇文章主要是為了幫助人們開始使用生成網路。
用GAN還可以做什麼呢?
基於條件對抗網路的圖像到圖像翻譯
Wasserstein GAN
GAN的Pytorch實現


※前沿探索 對大腦還一知半解,科學家就直接用它來「指導」機器學習了
※機器學習演算法在自動駕駛領域的應用大盤點!
※首屆北美計算機華人學者年會:伊利諾伊大學劉兵—終身機器學習
※基於機器學習的KPI自動化異常檢測系統
※晶元設計遇上機器學習,專家們都這麼看
TAG:機器學習 |
※FAIR開源Tensor Comprehensions,讓機器學習與數學運算高性能銜接
※FAIR開源Tensor Comprehensions,讓機器學習與數學運算高性能銜接
※如何使用 Android Things 和 TensorFlow 在物聯網上應用機器學習
※Learning Memory Access Patterns,資料庫+機器學習探索
※Google推出AI晶元Edge TPU,可在邊緣運行TensorFlow Lite機器學習模型
※Andrew Ng經典機器學習課程的Python實現2
※英特爾宣布Windows機器學習Movidius Myriad X VPU
※機器學習基石-Noise and Error
※Facebook發布Tensor Comprehensions:自動編譯高性能機器學習核心的C+庫
※機器學習基石-The Learning Problem
※機器學習技法-lecture5:Kernel Logistic Regression
※VR動作遊戲《Astro Bot Rescue Mission》與機器人一起戰鬥
※ClickHouse如何結合自家的GNDT演算法庫CatBoost來做機器學習
※德國Festo推出仿生機器人BionicWheelBot
※Windows Defender ATP機器學習和AMSI:發掘基於腳本的攻擊
※用AI 打造遊戲,Unity 機器學習 Agent——ml-agents
※Learn with Google AI:谷歌開放更多免費AI及機器學習在線資源
※用Scratch+IBM Watson實現機器學習
※Feature Tools:可自動構造機器學習特徵的Python庫
※亞馬遜AWS首席科學家Animashree Anandkumar:機器學習將引領未來革命