神經網路中,設計loss function有哪些技巧?
- 2019 年 10 月 29 日
- 筆記
☞機器學習、深度學習、python全棧開發乾貨
來源:深度學習與自然語言處理
神經網路中,設計loss function有哪些技巧?
本文綜合了幾位大咖的觀點而成。 作者:Alan Huang https://www.zhihu.com/question/268105631/answer/335246543
對於 gradient balancing問題,劉詩昆同學回答得挺不錯。我這邊再額外補充一些。
multi-task learning 中, tasks之間彼此的相容性對結果也會有一些影響。當兩個任務矛盾的時候, 往往結果會比單任務還要差不少。
Multi-task learning 還需要解決的是Gradient domination的問題。這個問題產生的原因是不同任務的loss的梯度相差過大, 導致梯度小的loss在訓練過程中被梯度大的loss所帶走。 題主所說的問題1和2都是指這個問題。
如果一開始就給不同的Loss進行加權, 讓它們有相近的梯度, 是不是就能訓練的好呢? 結果往往不是這樣的。不同的loss, 他們的梯度在訓練過程中變化情況也是不一樣的;而且不同的loss, 在梯度值相同的時候, 它們在task上的表現也是不同的。在訓練開始的時候,雖然balance了, 但是隨著訓練過程的進行, 中間又發生gradient domination了。 所以要想解決這個問題, 還是要合適地對不同loss做合適的均衡。
實踐中應該要如何調整呢?其實很簡單:
假設我們有兩個task, 用A和B表示。 假設網路設計足夠好, 容量足夠大, 而且兩個任務本身具有相關性,能夠訓得足夠好。
如果A和B單獨訓練, 他們在收斂的時候的梯度大小分別記為 Grad_a, Grad_b, 那麼我們只需要在兩個任務一起訓練的時候, 分別用各自梯度的倒數(1/Grad_a, 1/Grad_b)對兩個任務做平衡, 然後統一乘一個scalar就可以了。(根據單任務的收斂時候的loss梯度去確定multi-task訓練中不同任務的權重。)
因為loss的梯度在訓練中通常會變小(這裡用通常是因為一般mean square error等loss是這樣, 其他有的Loss並不是。), 如果我們確定這個網路在multi-task訓練的時候能夠達到原來的效果, 我們就只需要把平衡點設在兩個任務都足夠好的時候。這樣網路在訓練過程中, 就自然能夠達到那個平衡點, 即使一開始的時候會有gradient domination出現。
作者:劉詩昆 https://www.zhihu.com/question/268105631/answer/333738561
題主這個問題是 multi-task learning 里相當重要的一個核心問題。我正好在做相關的工作,很多細節將在論文投稿後再更新此答案。此外,我對 re-identfication 相關研究不熟所以無法回答第三個問題請見諒,望其他研究者補充。
理解多任務學習: Understanding Multi-task learning
Multi-task learning 核心的問題通常是可簡單分為兩類:
- How to share: 這裡主要涉及到基於 multi-task learning 的網路設計。
- Share how much: 如何平衡多任務的相關性使得每個任務都能有比 single-task training 取得更好的結果。
題主的問題主要落在第二類,儘管這兩個問題通常同時出現也互相關聯。對於 multi-task learning 更加粗略的介紹以及和 transfer learning 的關係請參看我之前的回答:劉詩昆:什麼是遷移學習 (Transfer Learning)?這個領域歷史發展前景如何?其中同樣包括了 task weighting 的一些討論,以下再做更加細節的補充。
網路設計和梯度平衡的關係: The Relationship Between Network Design and Gradient Balancing
無論是網路設計還是平衡梯度傳播,我們的目標永遠是讓網路更好的學習到 transferable, generalisable feature representation 以此來緩解 over-fitting。為了鼓勵多任務里多分享各自的 training signal 來學泛化能力更好的 feature,之前絕大部分研究工作的重點在網路設計上。直到去年才有陸續一兩篇文章開始討論 multi-task learning 里的 gradient balancing 問題。
再經過大量實驗後,我得出的結論是,一個好的 gradient balancing method 可以繼續有效增加網路的泛化能力,但是在網路設計本身的提高強度面前,這點增加不足一提。更加直白的表達是:
Gradient balancing method 一定需要建立在網路設計足夠好的基礎上,不然光憑平衡梯度並不會對網路泛化能力有著顯著的改變。
梯度統治: Gradient Domination
在 multi-task learning 里又可根據 training data 的類別再次分為兩類:
- one-to-many (single visual domain): 輸入一個數據,輸出多個標籤。通常是基於 image-to-image 的 dense prediction。一個簡單的例子,輸入一張圖片,輸出 semantic segmentation + depth estimation。
- many-to-many (multi visual domain):輸入多個數據,輸入各自標籤。比如如何同時訓練好多個圖片分類任務。
由於不同任務之間會有較大的差異,平衡梯度的目標是為了減緩任務本身的由於 variance, scale, complexity 不同而導致的差異。
在訓練 multi-task 網路時候則會因為任務複雜度的差異出現一個現象,我把他稱之為: Gradient Domination, 通常發生在 many-to-many 的任務訓練中。因為圖片分類可以因為圖片類別和本身數據數量而出現巨大差異。而基於 single visual domain 的 multi-task learning 則不容易出現這個問題因為數據集是固定的。
最極端的例子:MNIST + ImageNet 對於這種極端差異的多任務訓練基本可以看成基於 MNIST initialisation 的網路對於 ImageNet 的 finetune。所以這種情況的建議就是:優先訓練複雜度高的數據集,收斂之後再訓練複雜度低的數據集。當然這種情況下,多任務學習也沒有太大必要了。
對於一些差別比較大但是還是可接受範圍的比如:SVHN + CIFAR100。這種情況的 gradient balancing 就會出現一定的效果但也取決於你輸入數據的方式。輸入數據的通常方法,例如在這篇文章里:Incremental Learning Through Deep Adaptation 就是通過一個 dataset switch 來決定更新哪一個數據集的參數。對於這種方法,起始 learning rate 調的低,網路本身就會有一個較好的下降速率。
動態加權梯度傳播: Adaptive Weighting Scheme
即使光對優化網路調參並不能給多任務學習有著本質的改變。在考慮最 straightforward 的 loss:

我們的目標是學習好一個
能夠根據訓練效果動態變化使得平衡網路的梯度傳播。
這個問題目前只有兩篇文章做出了相關成果,
- Weight Uncertainty: 這個是通過 Gaussian approximation 的方式直接對修改了 loss 的方式,並同時以梯度傳播的方式來更新裡面的兩個參數。實際實驗效果也還不錯,在我復現的結果來看能有顯著的提升但是比較依賴並敏感於一個合適的的 learning rate 的設置。
- GradNorm : 是通過網路本身 back-propagation 的梯度大小進行 renormalisation。這篇文章寫的比較草率並被最近的 ICLR 2018 拒絕收錄了。個人期待他的更新作品能對方法本身有著更細節的描述。
- Dynamic Weight Average: 我對於 GradNorm 一個更加簡約且有效的改進,細節將會被補充。
一些總結
平衡梯度問題最近一年才剛剛開始吸引併產出部分深入研究的工作,這個方向對於理解 multi-task learning 來說至關重要,也可以引導我們去更加高效且條理化的訓練多任務網路。但在之前,更重要的事情是理解泛化能力本身,個人覺得 multi-task learning 的核心目標不在於訓練多個任務並得到超越單任務學習的性能,而是通過理解 multi-task learning 學習的過程重新思考並加深理解深度學習里 generlisation 的真正意義和價值。
作者:張小磊 https://www.zhihu.com/question/268105631/answer/333601828
一直認為設計或者改造loss function是機器學習領域的精髓,好的損失函數定義可以既能夠反映模型的訓練誤差,也能夠一定程度反映模型泛化誤差,可以很好的指導參數向著模型最優的道路進發。接下來關於設計損失函數提一些自己的看法:
1、設計損失函數之前應該明確自己的具體任務(分類、回歸或者排序等等),因為任務不同,具體的損失定義也會有所區別。對於分類問題,分類錯誤產生誤差;對於排序問題,樣本的偏序錯誤才產生誤差等。
2、設計損失函數應該以評價指標為導向,因為你的損失函數需要你的評價指標來評判,因此應該做到對號入座,回歸問題用均方誤差來衡量,那麼損失函數應為平方損失;二分類問題用準確率來衡量,那麼損失函數應為交叉熵損失,等等。
3、設計損失函數應該明確模型的真實誤差和模型複雜度(有種說法是,經驗誤差最小化和結構誤差最小化),既要保證損失函數能夠很好的反映訓練誤差,又要保證模型不至於過度繁瑣(過擬合的風險),也就是奧卡姆剃刀原理,如無必要,勿增實體。
4、設計損失函數時我們應該善於變通、善於借鑒、善於遷移。以2017年WWW上的Collaborative metric learning為例,該文將SVM的hinge loss引入到了metric learning裡邊,使得越相近的類里的越近,不相近的類距離越遠,同時會有一個最大邊界來處理分類錯誤的點(軟間隔),最後將該損失函數又引入到了推薦系統中的協同過濾演算法(CF)中。可以看出對於自己的研究領域,我們可以借鑒經典的損失函數來為我所用,以此來提升該領域的性能。
當然,以上說的更多的是普適思路,適用於傳統機器學習,相信對於深度學習同樣有借鑒意義。至於對於深度學習其他的技巧,應該還需要考慮深度學習模型獨有的一些問題,比如模型相對複雜以至於極易過擬合的風險,以及涉及參數眾多需要簡化調參等。