多任務學習中的數據分布問題(一)

今天這個專題源於我在做分散式多任務學習實驗時在選取數據集的時候的疑惑,以下我們討論多任務學習中(尤其是在分散式的環境下)如何選擇數據集和定義任務。

多任務學習最初的定義是:”多任務學習是一種歸納遷移機制,基本目標是提高泛化性能。多任務學習通過相關任務訓練訊號中的領域特定資訊來提高泛化能力,利用共享表示採用並行訓練的方法學習多個任務”。然而其具體實現手段卻有許多(如基於神經網路的和不基於神經網路的,這也是容易讓人糊塗的地方),但是不管如何,其關鍵點——共享表示是核心。

1.經典(非神經網路的)多任務學習

經典(非神經網路的)多任務學習我們已經在博文《分散式多任務學習論文閱讀(一):多任務學習速覽 》中詳細討論,此處不再贅述。在這種模式中給定\(t\)個學習任務\(\{\mathcal{T}_t\}_{t=1}^T\),每個任務各對應一個數據集\(\mathcal{D}_t = {\{(\bm{x}_{ti}, y_{ti})}_{i=1}^{m_t}\}\)(其中\(\bm{x_{ti}} \in \mathbb{R}^{d}\)\(y_{ti} \in \mathbb{R}\)),然後根據根據\(T\)個任務的訓練集學習\(T\)個函數\(\{f_t(\bm{x})\}_{t=1}^{T}\)。在這種模式下,每個任務的模型假設(比如都是線性函數)都常常是相同,導致每個任務的模型(權重)不同的原因歸根結底在於每個任務的數據集不同(每個任務的損失函數默認相同,但其實可同可不同)。 此模式優化的目標函數可以寫作:

\[\begin{aligned}
\underset{\textbf{W}}{\min}
& \sum_{t=1}^{T}\mathbb{E}_{(\bm{x_{ti}, y_{ti})\sim \mathcal{D}_t}}[L(y_{ti}, f(\bm{x}_{ti}; \bm{w}_t))]+ \lambda g(\textbf{W})\\

=& \sum_{t=1}^{T} [\frac{1}{m_t}\sum_{i=1}^{m_t}L(y_{ti}, f(\bm{x}_{ti}; \bm{w}_t))]+\lambda g(\textbf{W})\\
\end{aligned}
\tag{2}
\]

(此處\(\textbf{W}=(\bm{w}_1,\bm{w}_2,…,\bm{w}_T)\)為所有任務參數構成的矩陣,\(g(\textbf{W})\)編碼了任務的相關性)

我們下列所討論的分散式多任務學習,採用的數據分布假設也大多來自這種情況。

2. 聯邦學習和經典多任務學習中的數據分布對比

論文[1][2]在聯邦學習的情景下引入了多任務學習,從這篇兩篇論文我們可以看到聯邦學習和多任務學習的關聯和差異。
在標準的聯邦學習中,我們需要實現訓練一個「元模型」,然後再分發在各任務節點上微調。每個節點任務不共享數據,但是可以共享參數,以此聯合訓練出各一個全局的模型。也就是說,聯邦學習下每個節點的任務是一樣的,但是由於數據不獨立同分布,每個模型訓練出的局部模型差異會很大,就會使得構建一個全局的、通用的模型難度很大。比如同樣一個下一個單詞預測的任務,同樣給定”I love eating,”,但對於下一個單詞每個client會給出不同的答案。

[1][2]論文都提出一個思想,如果我們不求訓練出一個全局的模型,使每個節點訓練各不相同的模型這樣一種訓練方式,這被冠名為聯邦多任務學習了。論文[1][2]都保持了經典多任務學習的假設,不過有些許區別。論文[1]中每個任務的訓練數據分布和損失函數都不同。但是論文[2]中假定每個任務不同之處只有訓練數據的分布。

3.基於神經網路的多任務學習中的數據分布

基於神經網路的多任務學習(也就是大多數在CV、NLP)中使用的那種,分類和定義其實非常會亂,下面我們來看其中的一些常見方式。

3.1 同樣的輸入數據,不同的loss

大多數基於神經網路的多任務學習採用的方式是各任務基於同樣的輸入數據(或者可以看做將不同任務的數據混在一起使用),用不同的loss定義不同任務的。

如CV中使用的深度關係多任務學習模型:

CV多任務學習

NLP中的Joint learning:

NLP多任務學習

推薦系統中的用戶序列多任務模型:

NLP多任務學習

3.1 不同的輸入數據,不同的loss

我們也可以保持共享表示層這一關鍵特性不變,但是每個任務有不同的輸入數據和不同的loss,如下圖所示:
NLP多任務學習
在這種架構中,Input x表示不同任務的輸入數據,綠色部分表示不同任務之間共享的層,紫色表示每個任務特定的層,Task x表示不同任務對應的損失函數層。在多任務深度網路中,低層次語義資訊的共享有助於減少計算量,同時共享表示層可以使得幾個有共性的任務更好的結合相關性資訊,任務特定層則可以單獨建模任務特定的資訊,實現共享資訊和任務特定資訊的統一。

(注意,在深度網路中,多任務的語義資訊還可以從不同的層次輸出,例如GoogLeNet中的兩個輔助損失層。另外一個例子比如衣服影像檢索系統,顏色這類的資訊可以從較淺層的時候就進行輸出判斷,而衣服的樣式風格這類的資訊,更接近高層語義,需要從更高的層次進行輸出,這裡的輸出指的是每個任務對應的損失層的前一層。)

3.2 不同的輸入數據,相同的loss

我們想一下,每個任務對應不同的輸入數據,相同的loss的情況。比如我們同一個影像分類網路和交叉熵損失,但一個任務的數據集是男人和女人,一個任務數據集是人和狗,我們將這兩個數據集進行聯合學習,這是否算是多任務學習?如果是,是否能同時提升人-人分類器的精度和人-狗分類器的精度?(如下圖所示)
NLP多任務學習

第一個問題,按照經典多任務學習的分類,這種應該是算的,因為每個任務的數據集不同,直接導致了學得的模型不同,又由於有共享表示這一關鍵特性,也可以算是多任務學習。至於第二個問題,我覺得是可以的,因為這兩個任務雖然數據集不同,但是是互相關聯的,比如人的話可能會檢測頭髮,狗的話可能會檢測耳朵,但是都有一個檢測局部特徵的相似性在裡面。

參考文獻

  • [1] Smith V, Chiang C K, Sanjabi M, et al. Federated multi-task learning[J]. Advances in Neural Information Processing Systems, 2017.
  • [2] Marfoq O, Neglia G, Bellet A, et al. Federated multi-task learning under a mixture of distributions[J]. Advances in Neural Information Processing Systems, 2021, 34.