GAN對抗網路入門教程
- 2019 年 10 月 5 日
- 筆記
譯:A Beginner's Guide to Generative Adversarial Networks (GANs) https://skymind.ai/wiki/generative-adversarial-network-gan
1 GAN簡介
生成對抗網路(英語:Generative Adversarial Network,簡稱GAN)是非監督式學習的一種方法,通過讓兩個神經網路相互博弈的方式進行學習。該方法由伊恩·古德費洛等人於2014年提出。生成對抗網路由一個生成網路與一個判別網路組成。生成網路從潛在空間(latent space)中隨機取樣作為輸入,其輸出結果需要盡量模仿訓練集中的真實樣本。判別網路的輸入則為真實樣本或生成網路的輸出,其目的是將生成網路的輸出從真實樣本中儘可能分辨出來。而生成網路則要儘可能地欺騙判別網路。兩個網路相互對抗、不斷調整參數,最終目的是使判別網路無法判斷生成網路的輸出結果是否真實。
生成對抗網路常用於生成以假亂真的圖片。此外,該方法還被用於生成影片、三維物體模型等。
雖然生成對抗網路原先是為了無監督學習提出的,它也被證明對半監督學習、完全監督學習 、強化學習是有用的。
image
2 生成與判別演算法
要理解GAN,你應該知道生成演算法是如何工作的,但是在理解生成演算法之前,將它們與判別演算法進行對比可以加深理解。我們先看下什麼事判別演算法?
判別演算法試圖對輸入數據進行分類; 也就是說,給定數據實例的特徵,它們預測該數據所屬的標籤或類別。
例如,給定電子郵件中的所有單詞(數據實例),判別演算法可以預測該消息是spam(垃圾郵件)還是not_spam(非垃圾郵件)。 其中spam
是標籤之一,從電子郵件收集的單詞包是構成輸入數據的特徵。 當以數學方式表達此問題時,標籤稱為y,並且要素稱為x。公式p(y|x)用於表示「給定x條件下y發生的概率」,在這種情況下,它將轉換為「在給定郵件所包含的字詞情況下,電子郵件是垃圾郵件的概率」。
因此,判別演算法是將特徵映射到標籤,而生成演算法恰恰在做相反的事情。生成演算法試圖預測給定某個標籤下的特徵,而不是預測給定某些特徵的標籤。
生成演算法試圖回答的問題是:假設這封電子郵件是垃圾郵件,特徵的分布或者概率是怎麼樣的? 雖然判別模型關注y和x之間的關係,但是生成模型關心「你如何得到x。」生成演算法是為了計算出(x | y),給出y條件下x發生的概率,或者說給出標籤時,特徵的概率。 (也就是說,生成演算法也可以用作分類器。恰好它們不是對輸入數據進行分類。)
下面兩句話將判別與生成區分開來:
- 判別模型學習了類之間的界限
- 生成模型模擬各個類的分布
3 GANs原理
GAN的基本原理其實非常簡單,這裡以生成圖片為例進行說明。假設我們有兩個網路,G(Generator)和D(Discriminator)。正如它的名字所暗示的那樣,它們的功能分別是:一個神經網路,稱為生成器,生成新的數據實例,而另一個神經網路,判別器,評估它們的真實性; 即判別器決定它所評測的每個數據實例是否屬於實際訓練數據集。
G是一個生成圖片的網路,它接收一個隨機的雜訊z,通過這個雜訊生成圖片,記做G(z)。 D是一個判別網路,判別一張圖片是不是「真實的」。它的輸入參數是x,x代表一張圖片,輸出D(x)代表x為真實圖片的概率,如果為1,就代表100%是真實的圖片,而輸出為0,就代表不可能是真實的圖片。 在訓練過程中,生成網路G的目標就是盡量生成真實的圖片去欺騙判別網路D。而D的目標就是盡量把G生成的圖片和真實的圖片分別開來。這樣,G和D構成了一個動態的「博弈過程」。
最後博弈的結果是什麼?在最理想的狀態下,G可以生成足以「以假亂真」的圖片G(z)。對於D來說,它難以判定G生成的圖片究竟是不是真實的,因此D(G(z)) = 0.5。
reference:https://zhuanlan.zhihu.com/p/24767059
以下是GAN大致步驟:
- 生成器接收隨機數並返回影像。
- 將生成的影像與從真實數據集中獲取的影像流一起饋送到判別器中。
- 判別器接收真實和假影像並返回概率,0到1之間的數字,1表示真實性的預測,0表示假。
image
您可以將GAN視為詐騙者和警察在貓與老鼠遊戲中的反對,其中詐騙者正在學習傳遞虛假資訊,並且警察正在學習如何檢測它們。 兩者都是動態的; 也就是說,警察也在接受培訓,每一方都在不斷升級中學習對方的方法。
對於MNIST數據集,判別器網路是標準卷積網路,可以對饋送給它的影像進行分類,二項分類器將影像標記為真實或偽造。 在某種意義上,生成器是反卷積網路:當標準卷積分類器採用影像並對其進行下取樣以產生概率時,生成器採用隨機雜訊矢量並將其上取樣到影像。 第一個通過下取樣技術(如maxpooling)丟棄數據,第二個生成新數據。
image
4 GANs, Autoencoders and VAEs
下面對生成性對抗網路與其他神經網路(例如自動編碼器和變分自動編碼器)進行比較。
自動編碼器將輸入數據編碼為矢量。它們創建原始數據的隱藏或壓縮表示,在減少維數方面很有用; 也就是說,用作隱藏表示的向量將原始數據壓縮為較少數量的突出維度。 自動編碼器可以與所謂的解碼器配對,允許您根據其隱藏的表示重建輸入數據,就像使用受限制的Boltzmann機器一樣。
image
變分自動編碼器是生成演算法,其為編碼輸入數據添加額外約束,即隱藏表示被標準化。 變分自動編碼器能夠像自動編碼器一樣壓縮數據並像GAN一樣合成數據。 然而GAN可以更精細、細粒度的生成數據,VAE生成的影像往往更加模糊。 Deeplearning4j的例子包括自動編碼器和變分自動編碼器。(https://github.com/deeplearning4j/dl4j-examples/tree/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/unsupervised)
5 Keras 實現GAN
https://github.com/eriklindernoren/Keras-GAN
from __future__ import print_function, division from keras.datasets import mnist from keras.layers import Input, Dense, Reshape, Flatten, Dropout from keras.layers import BatchNormalization, Activation, ZeroPadding2D from keras.layers.advanced_activations import LeakyReLU from keras.layers.convolutional import UpSampling2D, Conv2D from keras.models import Sequential, Model from keras.optimizers import Adam import matplotlib.pyplot as plt import sys import numpy as np
Using TensorFlow backend.
class GAN(): def __init__(self): self.img_rows = 28 self.img_cols = 28 self.channels = 1 self.img_shape = (self.img_rows, self.img_cols, self.channels) optimizer = Adam(0.0002, 0.5) # Build and compile the discriminator self.discriminator = self.build_discriminator() self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy']) # Build and compile the generator self.generator = self.build_generator() self.generator.compile(loss='binary_crossentropy', optimizer=optimizer) # The generator takes noise as input and generated imgs z = Input(shape=(100,)) img = self.generator(z) # For the combined model we will only train the generator self.discriminator.trainable = False # The valid takes generated images as input and determines validity valid = self.discriminator(img) # The combined model (stacked generator and discriminator) takes # noise as input => generates images => determines validity self.combined = Model(z, valid) self.combined.compile(loss='binary_crossentropy', optimizer=optimizer) def build_generator(self): noise_shape = (100,) model = Sequential() model.add(Dense(256, input_shape=noise_shape)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(512)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(1024)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(np.prod(self.img_shape), activation='tanh')) model.add(Reshape(self.img_shape)) model.summary() noise = Input(shape=noise_shape) img = model(noise) return Model(noise, img) def build_discriminator(self): img_shape = (self.img_rows, self.img_cols, self.channels) model = Sequential() model.add(Flatten(input_shape=img_shape)) model.add(Dense(512)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(256)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(1, activation='sigmoid')) model.summary() img = Input(shape=img_shape) validity = model(img) return Model(img, validity) def train(self, epochs, batch_size=128, save_interval=50): # Load the dataset (X_train, _), (_, _) = mnist.load_data() # Rescale -1 to 1 X_train = (X_train.astype(np.float32) - 127.5) / 127.5 X_train = np.expand_dims(X_train, axis=3) half_batch = int(batch_size / 2) for epoch in range(epochs): # --------------------- # Train Discriminator # --------------------- # Select a random half batch of images idx = np.random.randint(0, X_train.shape[0], half_batch) imgs = X_train[idx] noise = np.random.normal(0, 1, (half_batch, 100)) # Generate a half batch of new images gen_imgs = self.generator.predict(noise) # Train the discriminator d_loss_real = self.discriminator.train_on_batch(imgs, np.ones((half_batch, 1))) d_loss_fake = self.discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1))) d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) # --------------------- # Train Generator # --------------------- noise = np.random.normal(0, 1, (batch_size, 100)) # The generator wants the discriminator to label the generated samples # as valid (ones) valid_y = np.array([1] * batch_size) # Train the generator g_loss = self.combined.train_on_batch(noise, valid_y) # Plot the progress if epoch%1000==0: print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss)) # If at save interval => save generated image samples if epoch % save_interval == 0: self.save_imgs(epoch) def save_imgs(self, epoch): r, c = 5, 5 noise = np.random.normal(0, 1, (r * c, 100)) gen_imgs = self.generator.predict(noise) # Rescale images 0 - 1 gen_imgs = 0.5 * gen_imgs + 0.5 fig, axs = plt.subplots(r, c) cnt = 0 for i in range(r): for j in range(c): axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray') axs[i,j].axis('off') cnt += 1 fig.savefig("data/gan/images/mnist_%d.png" % epoch) plt.close() if __name__ == '__main__': gan = GAN() gan.train(epochs=30000, batch_size=32, save_interval=200)
WARNING:tensorflow:From D:ProgramDataAnaconda3libsite-packageskerasbackendtensorflow_backend.py:66: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead. WARNING:tensorflow:From D:ProgramDataAnaconda3libsite-packageskerasbackendtensorflow_backend.py:541: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead. WARNING:tensorflow:From D:ProgramDataAnaconda3libsite-packageskerasbackendtensorflow_backend.py:4432: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead. Model: "sequential_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= flatten_1 (Flatten) (None, 784) 0 _________________________________________________________________ dense_1 (Dense) (None, 512) 401920 _________________________________________________________________ leaky_re_lu_1 (LeakyReLU) (None, 512) 0 _________________________________________________________________ dense_2 (Dense) (None, 256) 131328 _________________________________________________________________ leaky_re_lu_2 (LeakyReLU) (None, 256) 0 _________________________________________________________________ dense_3 (Dense) (None, 1) 257 ================================================================= Total params: 533,505 Trainable params: 533,505 Non-trainable params: 0 _________________________________________________________________ WARNING:tensorflow:From D:ProgramDataAnaconda3libsite-packageskerasoptimizers.py:793: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead. WARNING:tensorflow:From D:ProgramDataAnaconda3libsite-packageskerasbackendtensorflow_backend.py:3657: The name tf.log is deprecated. Please use tf.math.log instead. WARNING:tensorflow:From D:ProgramDataAnaconda3libsite-packagestensorflowpythonopsnn_impl.py:180: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version. Instructions for updating: Use tf.where in 2.0, which has the same broadcast rule as np.where WARNING:tensorflow:From D:ProgramDataAnaconda3libsite-packageskerasbackendtensorflow_backend.py:148: The name tf.placeholder_with_default is deprecated. Please use tf.compat.v1.placeholder_with_default instead. Model: "sequential_2" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_4 (Dense) (None, 256) 25856 _________________________________________________________________ leaky_re_lu_3 (LeakyReLU) (None, 256) 0 _________________________________________________________________ batch_normalization_1 (Batch (None, 256) 1024 _________________________________________________________________ dense_5 (Dense) (None, 512) 131584 _________________________________________________________________ leaky_re_lu_4 (LeakyReLU) (None, 512) 0 _________________________________________________________________ batch_normalization_2 (Batch (None, 512) 2048 _________________________________________________________________ dense_6 (Dense) (None, 1024) 525312 _________________________________________________________________ leaky_re_lu_5 (LeakyReLU) (None, 1024) 0 _________________________________________________________________ batch_normalization_3 (Batch (None, 1024) 4096 _________________________________________________________________ dense_7 (Dense) (None, 784) 803600 _________________________________________________________________ reshape_1 (Reshape) (None, 28, 28, 1) 0 ================================================================= Total params: 1,493,520 Trainable params: 1,489,936 Non-trainable params: 3,584 _________________________________________________________________ D:ProgramDataAnaconda3libsite-packageskerasenginetraining.py:493: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ? 'Discrepancy between trainable weights and collected trainable' 0 [D loss: 0.735185, acc.: 46.88%] [G loss: 0.829077] D:ProgramDataAnaconda3libsite-packageskerasenginetraining.py:493: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ? 'Discrepancy between trainable weights and collected trainable' 1000 [D loss: 0.590758, acc.: 71.88%] [G loss: 0.793450] 2000 [D loss: 0.587990, acc.: 62.50%] [G loss: 0.956186] 3000 [D loss: 0.644352, acc.: 59.38%] [G loss: 0.914777] 4000 [D loss: 0.673936, acc.: 62.50%] [G loss: 0.971460] 5000 [D loss: 0.759974, acc.: 53.12%] [G loss: 0.904706] 6000 [D loss: 0.555306, acc.: 81.25%] [G loss: 0.835633] 7000 [D loss: 0.674409, acc.: 62.50%] [G loss: 0.823623] 8000 [D loss: 0.672854, acc.: 53.12%] [G loss: 0.863680] 9000 [D loss: 0.743683, acc.: 46.88%] [G loss: 0.868321] 10000 [D loss: 0.635190, acc.: 59.38%] [G loss: 0.854181] 11000 [D loss: 0.700397, acc.: 56.25%] [G loss: 0.778778] 12000 [D loss: 0.741978, acc.: 46.88%] [G loss: 0.813542] 13000 [D loss: 0.760614, acc.: 46.88%] [G loss: 0.833507] 14000 [D loss: 0.671199, acc.: 68.75%] [G loss: 0.853395] 15000 [D loss: 0.676217, acc.: 62.50%] [G loss: 0.920993] 16000 [D loss: 0.593898, acc.: 68.75%] [G loss: 0.889001] 17000 [D loss: 0.724363, acc.: 50.00%] [G loss: 0.893431] 18000 [D loss: 0.779740, acc.: 43.75%] [G loss: 0.853765] 19000 [D loss: 0.642237, acc.: 59.38%] [G loss: 0.830348] 20000 [D loss: 0.587237, acc.: 62.50%] [G loss: 0.876839] 21000 [D loss: 0.645381, acc.: 62.50%] [G loss: 0.827465] 22000 [D loss: 0.723597, acc.: 46.88%] [G loss: 0.862281] 23000 [D loss: 0.671319, acc.: 65.62%] [G loss: 0.903444] 24000 [D loss: 0.684801, acc.: 62.50%] [G loss: 0.807403] 25000 [D loss: 0.737355, acc.: 43.75%] [G loss: 0.813877] 26000 [D loss: 0.606201, acc.: 68.75%] [G loss: 0.802509] 27000 [D loss: 0.711020, acc.: 56.25%] [G loss: 0.894887] 28000 [D loss: 0.641023, acc.: 56.25%] [G loss: 0.856079] 29000 [D loss: 0.696889, acc.: 46.88%] [G loss: 0.728626]
可以看到D的判別準確率最終在46%-56%之間,也就是說G網路生成的圖片已經真假難分
6 參考資料
- GAN學習指南:從原理入門到製作生成Demo https://zhuanlan.zhihu.com/p/24767059
- A Beginner's Guide to Generative Adversarial Networks (GANs) https://skymind.ai/wiki/generative-adversarial-network-gan