聯邦學習中的優化演算法
導引
聯邦學習做為一種特殊的分散式機器學習,仍然面臨著分散式機器學習中存在的問題,那就是設計分散式的優化演算法。
以分散式機器學習中常採用的client-server架構(同步)為例,我們常常會將各client節點計算好的局部梯度收集到server節點進行求和,然後再根據這個總梯度進行權重更新。
不過相比傳統的分散式機器學習,它需要關注系統異質性(system heterogeneity)、統計異質性(statistical heterogeneity)和數據隱私性(data privacy
)。系統異質性體現為昂貴的通訊代價和節點隨時可能宕掉的風險(容錯);統計異質性數據的不獨立同分布(Non-IID)和不平衡。由於以上限制,傳統分散式機器學習的優化演算法便不再適用,需要設計專用的聯邦學習優化演算法。
舉個例子,傳統分散式機器學習中也提出了許多降低通訊量的演算法,包括近似牛頓法[1][2][3]、小樣本平均[5]等,但這些演算法只考慮了數據IID的情況,不能照搬過來。演算法[4]沒有假設數據IID,但是不適用深度學習,因為神經網路很難求對偶問題。
目前已經針對聯邦學習提出了許多新的優化演算法。同時,同時除了中心(centralized)化優化演算法,針對聯邦學習的去中心化(decentralized)優化演算法也得到了廣泛研究。
FedAvg——旨在減少通訊的開山之作
在聯邦學習中,首先的不同便是通訊代價。我們希望每輪通訊能夠在client上完成更多的運算(client端一般是用戶手機等設備,充電的時候都可以計算),這也是聯邦學習開山論文[6]提出的FedAvg演算法的初衷。該演算法是聯邦學習領域最為基礎的梯度聚合方法。
相比傳統分散式機器學習方法在client節點只計算出梯度,FedAvg方法希望client節點能夠多做一些運算,得到比梯度更好的下降方向。由於這個下降方向比梯度更好,所以可以收斂更快。而收斂快了,那麼通訊次數自然就少了。這就是該演算法設計的基本想法。
該演算法的每輪通訊描述如下:
(1) 第\(k\)個client節點執行
- 從server接收全局模型參數\(w^t\)並令\(w_k=w^t\)。
- 執行\(E\)個局部epoch的SGD:
\[w_k = w_k – \eta \nabla \mathcal{l}(w_k; b)
\](此處將局部數據\(D_k\)劃分為多個\(b\))
- 將新的\(w_k\)發往server。
(2) server節點執行
-
從\(K\)個client接收\(w_1^{t+1}、w_2^{t+1},…w_K^{t+1}\)
-
按(加權)平均更新模型參數:
\]
其中\(t\)為第\(t\)輪迭代。可以看到相比傳統分時機器學習中每個client計算完梯度就發給server,FedAvg計算完梯度後會直接更新局部參數,同時重複該過程多次。而對於server,會對client傳來的參數進行加權平均。
注意,FedAvg還有一種變種寫法如下:
(1) 第\(k\)個client節點執行
- 從server接收全局模型參數\(w^t\)並令\(w_k = w^t\)。
- 執行\(E\)個局部epoch的SGD:
\[w_k = w_k – \eta \nabla \mathcal{l}(w_k; b)
\](此處將局部數據\(D_k\)劃分為多個\(b\))
- 將\(g_k = w_k-w^t\)發往server。
(2) server節點執行
-
從\(K\)個client接收\(g_1、g_2,…g_K\)
-
按(加權)平均更新模型參數:
\]
兩種寫法本質上等效的。
綜上所述,FedAvg演算法在通訊次數相同的情況下,自然會收斂更快。如果實驗對比FedAvg和傳統分散式機器學習的SGD,我們會得到這樣的結果:
不過這麼做是有代價的,當client節點的計算量(以epoch來衡量)相同,那麼FedAvg的收斂速度是不如傳統SGD的。
這是典型的以計算換通訊策略。 而聯邦學習中計算代價小,通訊代價大,因此FedAvg演算法很有用。該演算法的作者已證明,FedAvg能夠在Non-IID條件下收斂[7]。論文[8]以Gboard輸入法背景下的單次預測任務為例,從工程上證明了FedAvg演算法的優越性。
FedProx——關注掉隊者
FedProx[9]主要從系統異質性和統計異質性兩個方向入手來改良FedAvg演算法。不過,介於後來FedAvg演算法已被證明在Non-IID數據集上本來能收斂[7],該演算法的貢獻還是在於提供了一個收斂更快、效果更好的演算法。
我們知道,系統異質性下,FedAvg演算法要求的每個節點執行的\(E\)個epoch的局部迭代可能無法得到保證,因為節點隨時可能宕掉。FedProx作者還探究了統計異質性和系統異質性之間的相互作用,並認為系統異質性產生的掉隊者(stragglers) 以及掉隊者發往server的帶偏差的參數資訊會進一步增加統計異質性,最終影響收斂。因此,作者提出在client的優化目標函數中增加一個近端項,這樣可以使優化演算法更加穩定,最終使得FedProx在統計異質性下也收斂更快。
FedProx中server的操作和FedAvg相同,都是採用(加權)平均,但是其第\(k\)個client端不再是執行\(E\)輪的SGD,而是求解以下帶近端項的優化問題:
w^{t+1}_k &=\underset{w}{\text{argmin}}h(w,w^t_k)\\
& = \mathcal{l}(w, D_k) + \frac{\mu}{2}||w-w^t_k||^2
\end{aligned}
\]
其中\(\mathcal{l}(w, D_k)\)為客戶端原本的優化函數,\(\frac{\mu}{2}||w-w^t||^2\)為近端項。作者在論文中證明了近端項的添加能夠使FedProx更好地適用於統計異質和系統異質的環境。
我們可以認為FedAvg是FedProx中將\(\mu\)設置為0,求解器設置為SGD,Epoch設置為\(E\)(當然這樣就無法處理系統異質性)的特殊情況。
FedAvg+ ———用元學習做模型個性化
在傳統的聯邦學習中,每個client節點聯合聯合訓練出各一個全局的模型(在前文中即server節點的\(w_t\))。但是由於數據Non-IID,訓練出的全局模型很難對每個局部節點都適用,不夠「個性化」。
Jiang Y等人的這篇論文[10]首次採用了個性化聯邦學習的思路:不求訓練出一個全局的模型,而使每個節點訓練各不相同的模型。作者在論文中採用模型不可知的元學習(Model Agnostic Meta Learning, MAML) 思路。元學習在給定小樣本實例的條件下進行自適應,可以優化在異構任務上的表現,它由兩個步驟構成:「meta training」——訓練初始化模型/元模型和「meta testing」——使初始化模型在特定的任務下完成自適應。
作者認為傳統的FedAvg演算法[6]可以被解釋為一種元學習演算法。在此基礎上再進行仔細的微調(fine-tuning)能夠使全局模型少一些泛化性,但同時能夠更容易個性化。我們將全局的已訓練的模型稱為初始化模型(initial model),將局部的已訓練模型稱為個性化模型(personalized model)。論文沒有採用[11]中將訓練初始化模型和模型個性化的操作分離,作者認為這樣會陷入局部最優,作者提出的演算法包括以下3個連續的步驟:
(1) 運行傳統的FedAvg演算法得到初始化模型,其中採用更大的\(E\),並使用帶動量的SGD做為優化器。
(2) 採用FedAvg的變種演算法對初始化模型進行微調: 此時採用Adam做為優化器,且迭代不再是採用\(E\)個epoch,而是先從\(D_k\)中隨取樣\(M\)個數據集\(\{D_{k,m}\}\)(\(M\)一般較小),然後進行如下的\(M\)個迭代步:
\]
(3) 對client進行進行個性化操作,採用和訓練期間相同的優化器。
作者認為該演算法能夠得到更穩固的初始化模型,這樣對於一些clients只有有限的甚至沒有數據來做個性化的情況很有好處。
Clustered FL——多任務知識共享
聚類聯邦學習(CFL)[12]這篇論文針對數據Non-IID導致的局部最優,提出了一種新的聯邦學習個性化方法:聚類(多任務)聯邦學習。
CFL保持著個性化聯邦學習的基本假設:每個節點訓練各不相同的模型。但並沒有採用元學習中初始化模型+自適應的措施,而是借用多任務學習中的常見手段,即讓節點在訓練的過程中就進行知識共享(可以參見我的部落格《基於正則表示的多任務學習》),而無需另設一個初始化模型。更具體的,CFL採用的是聚類多任務學習(clustered multitask learning),在訓練的過程中將參數相似的節點劃分為同一個任務簇,同一個任務簇共享參數的變化量\(g\),以此既能達到完成知識共享和相似的節點相互促進的目的。
聚類聯邦學習演算法的每輪通訊描述如下:
(1) 第\(k\)個client節點執行
- 從server接收\(g_{c(k)}\)
- 另\(w_{old}=w_k=w_k + g_{c(k)}\)
- 執行\(E\)個局部epoch的SGD:
\[w_k = w_k – \eta \nabla \mathcal{l}(w_k; b)
\](此處將局部數據\(D_k\)劃分為多個\(b\))
- 將\(g_k = w_k-w_{old}\)發往server。
(2) server節點執行
-
從\(K\)個client接收\(g_1、g_2,…g_K\)
-
對每一個簇\(c\in \mathcal{C}\),計算簇內平均參數變化:
\]
- 根據不同節點參數變化量的餘弦距離\(\alpha_{i,j}=\frac{\langle g_i, g_k\rangle}{||g_i||||g_j||}\)重新劃分聚類簇。
CFL的簇劃分演算法採用的是不斷進行二分裂的方式,無需指定簇的數量做為先驗。該演算法最重要的貢獻就是簇間知識共享思想的引入(並不共享參數,而共享參數的變化量,注意和論文[15]中直接參數的平均進行區分)。
pFedMe—純優化視角的個性化
pFedMe[13]這篇論文繼續瞄準聯邦學習個性化,它的創新點是使用Moreau envelope(也稱Moreau-Yosida正則化)做為client的正則損失函數。該演算法比已有的許多演算法收斂速度更快。
這個方法的一大貢獻將個性化模型與全局模型同時進行優化求解,該方法按照與標準FedAvg相似的方法來更新全局模型(多了個一階指數平滑),不過會以更低的複雜度來對個性化模型進行優化。
該篇論文演算法的每輪通訊描述如下:
(1) 第\(k\)個client節點執行
-
從server接收全局模型參數\(w^t\)並令\(w_k = w^t\)。
-
執行\(R\)輪局部迭代:
\[w_k = w_k – \eta \mu(w_k – \hat{\theta}_k(w_k))
\]其中
\[ \hat{\theta}_k(w_k)= \underset{\theta_k \in \mathbb{R}^d}{\text{argmin}} \{ \mathcal{l}(\theta_k, D_k) + \frac{\mu}{2}||\theta_k – w_k||^2 \}
\] -
將新的\(w_k\)發往server。
(2) server節點執行
-
從\(K\)個client接收\(w_1^{t+1}、w_2^{t+1},…w_K^{t+1}\)
-
按與平均值的一次指數平滑更新模型參數:
\[w^{t+1} = (1-\beta)w^t + \beta \frac{1}{K} \sum_{k=1}^K w_k^{t+1}
\]
其中重點在於client每輪局部迭代中求解Moreau envelope的部分,即求解\(\hat{\theta}_k(w_k)\)的部分。這裡\(\theta_k\)表示第\(k\)個client的個性化模型,\(\mu\)參數用於控制全局模型參數\(w_k\)相對於個性化模型的強度。其中Moreau envelope部分可以採用任意迭代方法求解。
FedEM—混合分布假設與EM演算法
FedEM[14]這篇論文另闢蹊徑,沒有關注模型的個性化,而是考慮從優化演算法上去提高聯邦學習模型的精度,其中採用的手段有兩點,一點是基於client節點數據滿足混合分布的假設,使每個client節點訓練由\(M\)個子模型集成所得的模型;二點是針對混合分布的假設,採用EM演算法來做參數估計,提高了模型的整體精度。
該演算法中心化形式的每輪通訊描述如下:
(1) 第\(k\)個client節點執行
-
從server接收全局模型參數\(w^t\)並令\(w_k = w^t\)。
-
對每一個模型成分\(m\)(\(m=1,…, M\))以及每一個局部樣本\(i\)(\(i=1,…,n_t\))執行\(\text{E}\)步驟
\[q_k(z^{(i)}_k=m)\leftarrow \frac{\pi _{km}\cdot \text{exp}\left(-\mathcal{l}(h_{w_{km}}(x_k^{(i)}), y_k^{(i)})\right)}
{\sum_{m’=1}^M \pi _{km’}\cdot \text{exp}\left(-\mathcal{l}(h_{w_{km’}}(x_k^{(i)}), y_k^{(i)})\right)}
\]對每一個模型成分\(m\)執行\(\text{M}\)步驟
\[ \pi_{km} = \frac{\sum_{i=1}^{n_t} q_k(z^{(i)}_k=m)}{n_t}
\]對每一個模型成分\(m\)執行\(J\)輪局部迭代:
\[w_{km} = w_{km} – \eta_j\sum_{i\in \mathcal{I}}q_k(z^{(i)}_k=m)\cdot \nabla_{w_{km}}\mathcal{l}(h_{w_{km}}(x_k^{(i)}), y_k^{(i)})
\]\(\mathcal{I}\)為每輪迭代有放回地從\(1,2,…|D_k|\)中採的隨機樣本索引集合
-
將新的\(w_k\)發往server。
(2) server節點執行
-
從\(K\)個client接收\(w_1^{t+1}、w_2^{t+1},…w_K^{t+1}\)
-
對每一個模型成分\(m\)按(加權)平均更新模型參數:
\]
該演算法的中心化形式在許多數據集上精度都取得了SOTA的水平。
該演算法的去中心化形式的每輪通訊描述如下:
第\(k\)個client節點執行:
-
對每一個模型成分\(m\)(\(m=1,…, M\))以及每一個局部樣本\(i\)(\(i=1,…,n_t\))執行\(\text{E}\)步驟
\[q_k(z^{(i)}_k=m)\leftarrow \frac{\pi _{km}\cdot \text{exp}\left(-\mathcal{l}(h_{w_{km}}(x_k^{(i)}), y_k^{(i)})\right)}
{\sum_{m’=1}^M \pi _{km’}\cdot \text{exp}\left(-\mathcal{l}(h_{w_{km’}}(x_k^{(i)}), y_k^{(i)})\right)}
\]對每一個模型成分\(m\)執行\(\text{M}\)步驟
\[ \pi_{km} = \frac{\sum_{i=1}^{n_t} q_k(z^{(i)}_k=m)}{n_t}
\]對每一個模型成分\(m\)執行\(J\)輪局部迭代:
\[w_{km} = w_{km} – \eta_j\sum_{i\in \mathcal{I}}q_k(z^{(i)}_k=m)\cdot \nabla_{w_{km}}\mathcal{l}(h_{w_{km}}(x_k^{(i)}), y_k^{(i)})
\]\(\mathcal{I}\)為每輪迭代有放回地從\(1,2,…|D_k|\)中採的隨機樣本索引集合
-
將新的\(w_k\)發往其鄰居節點
-
從鄰居節點接收新的\(w_k\)
-
對每一個模型成分\(m\)按(加權)平均更新模型參數:
\]
(其中加權參數\(\lambda\)為隨機初始化)
引用
-
[1] Shamir O, Srebro N, Zhang T. Communication-efficient distributed optimization using an approximate newton-type method[C]//International conference on machine learning. PMLR, 2014: 1000-1008.
-
[2] Wang S, Roosta F, Xu P, et al. Giant: Globally improved approximate newton method for distributed optimization[J]. Advances in Neural Information Processing Systems, 2018, 31.
-
[3] Mahajan D, Agrawal N, Keerthi S S, et al. An efficient distributed learning algorithm based on effective local functional approximations[J]. arXiv preprint arXiv:1310.8418, 2013.
-
[4] Smith V, Forte S, Chenxin M, et al. CoCoA: A general framework for communication-efficient distributed optimization[J]. Journal of Machine Learning Research, 2018, 18: 230.
-
[5] Zhang Y, Duchi J, Wainwright M. Divide and conquer kernel ridge regression: A distributed algorithm with minimax optimal rates[J]. The Journal of Machine Learning Research, 2015, 16(1): 3299-3340.
-
[6] McMahan B, Moore E, Ramage D, et al. Communication-efficient learning of deep networks from decentralized data[C]//Artificial intelligence and statistics. PMLR, 2017: 1273-1282.
-
[7] Stich S U. Local SGD converges fast and communicates little[C]///International Conference on Learning Representations, 2018.
-
[8] Hard A, Rao K, Mathews R, et al. Federated learning for mobile keyboard prediction[J]. arXiv preprint arXiv:1811.03604, 2018.
-
[9] Tian Li, Anit Kumar Sahu, Manzil Zaheer, Maziar Sanjabi, Ameet Talwalkar, and Virginia
Smith. 「Federated Optimization in Heterogeneous Networks」. In: Third MLSys Conference.2020. -
[10] Jiang Y, Konečný J, Rush K, et al. Improving federated learning personalization via model agnostic meta learning[J]. arXiv preprint arXiv:1909.12488, 2019.Presented at NeurIPS FL workshop 2019.
-
[11] Sim K C, Zadrazil P, Beaufays F. An investigation into on-device personalization of end-to-end automatic speech recognition models[J]. In Interspeech, 2019.
-
[12] Sattler F, Müller K R, Samek W. Clustered federated learning: Model-agnostic distributed multitask optimization under privacy constraints[J]. IEEE transactions on neural networks and learning systems, 2020, 32(8): 3710-3722.
-
[13]
T Dinh C, Tran N, Nguyen J. Personalized federated learning with moreau envelopes[J]. Advances in Neural Information Processing Systems, 2020, 33: 21394-21405. -
[14] Marfoq O, Neglia G, Bellet A, et al. Federated multi-task learning under a mixture of distributions[J]. Advances in Neural Information Processing Systems, 2021, 34.
-
[15]
Liu B, Guo Y, Chen X. PFA: Privacy-preserving Federated Adaptation for Effective Model Personalization[C]//Proceedings of the Web Conference 2021. 2021: 923-934.