(PyTorch)50行程式碼實現對抗生成網路(GAN)

  • 2019 年 10 月 4 日
  • 筆記

2014年,蒙特利爾大學(University of Montreal)的伊恩•古德費洛(Ian Goodfellow)和他的同事發表了一篇令人震驚的論文,向全世界介紹了GANs,即生成式對抗網路。通過計算圖和博弈論的創新結合,他們表明,如果有足夠的建模能力,兩個相互競爭的模型將能夠通過普通的反向傳播進行協同訓練。

這些模型扮演兩個截然不同的角色(字面意思是對抗性的)給定一些真實的數據集R, G是生成器,試圖創建看起來像真實數據的假數據,而D是鑒別器,從真實數據集或G中獲取數據並標記差異。古德費羅的比喻(也是一個很好的比喻)是G就像一組偽造者試圖將真實的繪畫與他們的作品相匹配,而D則是一組偵探試圖分辨兩者不同之處。(除了在這種情況下,偽造者G永遠看不到原始數據——只有D的判斷。他們就像盲人偽造者一樣。)

在理想的情況下,隨著時間的推移,D和G都會變得更好,直到G本質上成為真品的「偽造大師」,而D則不知所措,「無法區分這兩種分布」。

在實踐中,古德費洛所展示的是,G將能夠在原始數據集上執行一種無監督學習的形式,以一種(可能)低得多的維度的方式來表示數據。正如楊立昆(Yann LeCun)所言,無監督學習是真正人工智慧的「蛋糕」。

這個強大的技術似乎需要大量的程式碼才能開始,對嗎?不。使用PyTorch,我們實際上可以用50行程式碼創建一個非常簡單的GAN。實際上只有5個組成部分需要考慮:

  • R:原始的、真實的數據集
  • I:作為熵源進入生成器的隨機雜訊
  • G:試圖複製/模擬原始數據集的生成器
  • D:鑒別器,用來區分G和R的輸出
  • 在實際的「訓練」循環中,我們教G欺騙D, D提防G。

1.)R:在我們的例子中,我們將從最簡單的R-鐘形曲線開始。此函數接受平均值和標準偏差,並返回一個函數,該函數使用這些參數從高斯函數中提供正確形狀的樣本數據。在我們的示例程式碼中,我們將使用平均值4.0和標準偏差1.25。

2.)I:生成器的輸入也是隨機的,但是為了讓我們的工作更困難一點,我們用均勻分布而不是正態分布。這意味著我們的模型G不能簡單地移動/縮放輸入來複制R,而是必須以非線性的方式重塑數據。

3.)G:生成器是一個標準的前饋圖——兩個隱層,三個線性映射。我們用的是雙曲正切激活函數,因為我們太老派了。G將從I中得到均勻分布的數據樣本以某種方式模擬R的正態分布樣本而不需要看到R。

4.)D:鑒別器程式碼與G的生成器程式碼非常相似;一個包含兩個隱層和三個線性映射的前饋圖。這裡的激活函數是一個S形,沒什麼特別的。它將從R或G中獲取樣本,並輸出一個介於0和1之間的標量,解釋為「假的」和「真實的」。換句話說,這是神經網路所能得到的最脆弱的東西。

5.)最後,訓練循環在兩種模式之間交替進行:第一種模式是真實數據的訓練D,另一種模式是虛假數據的訓練D,具有準確的標籤(可以將其視為警察學院);然後用不準確的標籤訓練G去愚弄D(這更像是《十一羅漢》中的準備蒙太奇)。這是一場正義與邪惡之間的戰爭。

即使您以前沒有見過PyTorch,您也可能知道發生了什麼。在第一個(綠色)部分中,我們將這兩種類型的數據都推入D,並對D的猜測與實際標籤應用可微標準。這種推動是『forward』」的一步;然後我們顯式調用『backward()』來計算梯度,然後使用梯度在d_optimizer step()調用中更新D的參數。G在這裡使用,但沒有經過訓練。

然後在最後(紅色)部分中,我們對G執行同樣的操作——注意,我們還通過D運行G的輸出(我們實際上是在給偽造者一個偵探來練習),但是在這一步我們沒有優化或更改D。我們不想讓D偵探知道錯誤的標籤。因此,我們只調用g_optimizer.step()。

還有……就這些。還有一些其他的樣板程式碼,但是特定於GAN的東西只是這5個組件,沒有其他的。

在D和G之間跳了幾千輪這種被禁止的舞蹈之後,我們得到了什麼?鑒別器D很快就會變好(而G則慢慢上升),但一旦它達到一定的能力水平,G就會有一個值得尊敬的對手並開始改進,真正改善了。

超過5000個訓練回合,每回合訓練D 20次,G 20次,G輸出的平均值超過4.0,但隨後回到一個相當穩定、正確的範圍(左)。同樣,標準偏差最初下降的方向是錯誤的,但隨後上升到期望的1.25範圍(右),與R匹配。

好。所以基本的統計數據最終與R相匹配。那麼更高的時刻呢?分布的形狀看起來對嗎?畢竟,均值為4.0,標準差為1.25的分布是均勻的,但這和R並不匹配。我們來看看G的最終分布:

還不賴。右尾比左尾稍粗,但是歪斜和峰度,我們可以說,是原始高斯函數的再現。

G幾乎完美地恢復了原始分布R,而D則蜷縮在角落裡,喃喃自語,無法分辨事實與虛構。這正是我們想要的行為(參見Goodfellow中的圖1)。少於50行程式碼。

現在,警告一句: GANs可能很挑剔和脆弱。當他們進入一種奇怪的狀態時,他們通常不會不經過一點勸說就出來。運行我的示例程式碼10次(每次超過5000輪),顯示了以下10個發行版:

10次運行中有8次的最終分布非常好——類似於高斯分布,均值為4,標準差在正確的範圍內。但是兩次運行不是—在一次運行(運行5)中,有一個凹分布,平均值在6.0左右,在最後一次運行(運行10)中,在-11處有一個狹窄的峰值!當您開始在幾乎所有的上下文中應用GANs時,您將會看到這種現象——GANs並不像一般的監督學習工作流那樣穩定。但當它們發揮作用時,它們看起來是非常神奇的。

Goodfellow將繼續發表關於GANs的許多其他論文,包括曾經的gem,其中描述了一些實際的改進,包括這裡採用的小批量識別方法。這是他發布的一個2小時的教程。對於TensorFlow用戶,這裡有一篇來自Aylien在GANs上的類似文章。