【GAN優化】什麼是模式崩潰,以及如何從優化目標上解決這個問題

  • 2019 年 10 月 8 日
  • 筆記

今天講述的內容是GAN中的模式崩潰問題,之前的文章有提到這個問題,在接下來的兩三期內,將和大家一起討論有關模式崩潰的解決方法。

本期將會首先介紹什麼是模式崩潰,然後給出兩種通過修改GAN目標函數的解決方法,而下一期將從網路結構和mini-batch判別器的角度出發討論模式崩潰的解決方法。

本期不會涉及什麼數學知識,示意圖將會最大限度的說明本質問題,如有問題歡迎底部留言。

作者 | 小米粥

編輯 | 言有三

1. 模式崩潰(mode collapse)

GAN,屬於無監督生成模型中的一類。既然是生成模型,我們最起碼應該關注兩點:生成器生成樣本的品質和樣本多樣性。

根據數據的流形分布定律,自然界中同一類別的高維數據,往往集中在某個低維流形附近,所以生成器最理想的情況是:將輸入的雜訊都映射到訓練數據所在的流形上,並且與訓練數據的概率分布對應。舉一個簡單的例子,對於某一個訓練數據集,其中樣本的概率分布為一個簡單的一維高斯混合分布,包含兩個峰:

理想狀態下,生成器應該生成的樣本應該如下所示(綠色標記),生成樣本的位置幾乎都在兩個峰下,且分布符合訓練集的概率分布:

但是,上述情況在實際中是不可能出現的,我們實際中大多時候生成的樣本品質是比較差的,例如:

上圖裡生成器生成了一些品質比較差的樣本(標記為紅色),訓練集中本不包含紅色樣本,生成器應該著力生成綠色樣本而不應該產生紅色樣本,這屬於生成品質問題(比如使用貓的影像訓練GAN,最終GAN生成了一些狗?四不像?之類的照片)。而模式崩潰問題是針對於生成樣本的多樣性,即生成的樣本大量重複類似,例如下圖中,

生成的樣本全部聚集在左邊的峰下,這時雖然生成樣本的品質比較高,但是生成器完全沒有捕捉到右邊的峰的模式。(如果使用多種貓的影像訓練GAN,最終GAN只能產生逼真的英短,而無法產生其他品種)。

關於GAN模式崩潰問題的緩解方式有很多,我們接下來關注兩種修改目標函數的解決方案。

2. unrolled GAN

首先需要說明:其實,生成器在某一時刻單純地將樣本都聚集到某幾個高概率的峰下並不是我們討厭模式崩潰的根本原因,如果生成器能「及時發現問題」,自動調整權值,將生成樣本分散到整個訓練數據的流形上,則能自動跳出當前的模式崩潰狀態,並且理論上生成器確實「具備」該項能力(因為GoodFellow證明了GAN會實現最優解)。

但是實際情況是:對於生成器的不斷訓練並未使其學會提高生成樣本的多樣性,生成器只是在不斷將樣本從一個峰轉移聚集到另一個峰下。這樣的過程「沒完沒了」,無法跳出模式崩潰的循環。無論你在何時終止訓練,都面臨著模式崩潰,只是在不同時刻,生成樣本所聚集的峰不同罷了。

不過,這種情況的發生有一定的必然性,我們先使用原始形式GAN對這個過程進行示意描述,其目標函數為:

真實數據集的概率分布還是如第一部分所示,生成器生成樣本的概率分布如下:

我們先更新判別器:

假設判別器達到了最優狀態,則其表達式應為:

對應的,D(x)的影像為:

可以看出,這時判別器會立刻「懷疑」x=-3附近樣本點的真實性,接下來更新生成器:

此時的生成器將會非常「無可奈何」,為了使得目標函數f最小,最好的方法便是將樣本聚集到x=3附近,即:

再更新判別器,同上述過程,判別器會立刻「懷疑」x=3附近樣本點的真實性……這樣的糟糕結果會不斷循環下去。

對此,unrolled GAN認為:正是因為生成器缺乏「先見之明」,導致了無法跳出模式崩潰的困境,生成器每次更新參數時,只考慮在當前生成器和判別器的狀態下可以獲得的最優解,生成器並不知道當前選擇的最優解從長遠來看並不是最優解。

我們通過一定的改進,來賦予生成器「先見之明」。具體說來,判別器的目標函數仍然為:

參數更新方式為採用梯度下降方式連續更新K次,如下:

而生成器的優化目標修改為:

即生成器在更新時,不僅僅考慮當前生成器的狀態,還會額外考慮以當前狀態為起始點,判別器更新K次後的狀態,綜合兩個資訊做出最優解。其梯度的變化為:

其中,第一項就是非常熟悉的標準GAN形式的計算得到的梯度,而第二項便是考慮K次更新後判別器的狀態而產生的附加項。

我們現在再看剛才的問題,unrolled GAN會跳出模式崩潰的循環。同樣的初始狀態,

生成器在進行下一步更新時,面對以下兩種可能性(左邊是之前提到過的模式崩潰狀態,右邊是比較理想的樣本生成狀態):

經計算,選擇右邊會比選擇左邊產生更小的目標函數值,故實際中,生成器進行梯度更新將會趨向於右邊的狀態從而跳出模式崩潰。可以看出,生成器跳出模式崩潰的核心原因就是更新參數時不僅考慮當下狀態,而且額外考慮了K步判別器的反應,從而避免了短視行為,當然需要說明,這樣做是明顯增加了計算量的。

3. DRAGAN

GAN的參數優化問題並不是一個凸優化問題,存在許多局部納什均衡狀態。即使GAN進入某個納什均衡狀態,損失函數表現為收斂,其仍舊可產生模式崩潰,我們認為此時參數進入一個壞的局部均衡點。

通過實踐,發現當GAN出現模式崩潰問題時,通常伴隨著這樣的表現:當判別器在訓練樣本附近更新參數時,其梯度值非常大,故DRAGAN的解決方法是:對判別器,在訓練樣本附近施加梯度懲罰項:

這種方式試圖在訓練樣本附近構建線性函數,因為線性函數為凸函數具有全局最優解。需要額外說明,DRAGAN的形式與WGAN-GP頗為相似,只是WGAN-GP是在全樣本空間施加梯度懲罰,而DRAGAN只在訓練樣本附近施加梯度懲罰。

[1] Kodali N , Abernethy J , Hays J , et al. On Convergence and Stability of GANs[J]. 2017.

[2] Metz L , Poole B , Pfau D , et al. Unrolled Generative Adversarial Networks[J]. 2016.

總結

這篇文章首先講了GAN的模式崩潰問題,並用一個簡單的例子做了過程示意,接著重點描述了unrolled GAN的思想,並同樣進行了過程示意描述,最後又比較簡單地描述了另一種方案:DRAGAN。下一期,我們將從GAN結構方面去考慮模式崩潰問題。

下期預告:解決模式崩潰的GAN結構