神器:多卡同步的Batch Normalization

作者簡介

CW,廣東深圳人,畢業於中山大學(SYSU)數據科學與電腦學院,畢業後就業於騰訊電腦系統有限公司技術工程與事業群(TEG)從事Devops工作,期間在AI LAB實習過,實操過道路交通元素與醫療病例影像分割、影片實時人臉檢測與表情識別、OCR等項目。

目前也有在一些自媒體平台上參與外包項目的研發工作,項目專註於CV領域(傳統影像處理與深度學習方向均有)。

Foreword

使用多GPU卡訓練的情況下Batch Normalization(BN)可能會帶來很多問題,目前在很多深度學習框架如 Caffe、MXNet、TensorFlow 和 PyTorch 等,所實現的 BN 都是非同步的(unsynchronized),即歸一化操作是基於每個GPU上的數據獨立進行的。**

本文會為大家解析 BN 的多卡同步版本,這裡簡稱 SyncBN,首先解釋為何需要進行同步,接著為大家揭曉需要同步哪些資訊,最後結合基於 Pytorch 實現的程式碼解析實現過程中的關鍵部分。

Outline

i Why Synchronize BN:為何在多卡訓練的情況下需要對BN進行同步?

ii What is Synchronized BN:什麼是同步的BN,具體同步哪些東西?

iii How to implement:如何實現多卡同步的BN?

  1. 2次同步 vs 1次同步;

  2. 介紹torch.nn.DataParallel的前向回饋;

  3. 重載torch.nn.DataParallel.replicate方法;

  4. SyncBN 的同步註冊機制;

  5. SyncBN 的前向回饋

一·、Why Synchronize BN:

為何在多卡訓練的情況下需要對BN進行同步?

對於視覺分類和目標檢測等這類任務,batch size 通常較大,因此在訓練時使用 BN 沒太大必要進行多卡同步,同步反而會由於GPU之間的通訊而導致訓練速度減慢;

然而,對於語義分割等這類稠密估計問題而言,解析度高通常會得到更好的效果,這就需要消耗更多的GPU記憶體,因此其 batch size 通常較小,那麼每張卡計算得到的統計量可能與整體數據樣本具有較大差異,這時候使用 BN 就有一定必要性進行多卡同步了。

多卡情況下的BN(非同步)

這裡再提一點,如果使用pytorch的torch.nn.DataParallel,由於數據被可使用的GPU卡分割(通常是均分),因此每張卡上 BN 層的batch size(批次大小)實際為,下文也以torch.nn.DataParallel為背景進行說明。

二、What is Synchronized BN:

什麼是同步的BN,具體同步哪些東西?

由開篇至今,CW 一直提到「同步」這兩個字眼,那麼到底是什麼是同步的BN,具體同步的是什麼東西呢?

同步是發生在各個GPU之間的,需要同步的東西必然是它們互不相同的東西,那到底是什麼呢?或許你會說是它們拿到的數據,嗯,沒錯,但肯定不能把數據同步成一樣的了,不然這就和單卡訓練沒差別了,浪費了多張卡的資源…

現在,聰明的你肯定已經知道了,需要同步的是每張卡上計算的統計量,即 BN 層用到的(均值)和(方差),這樣子每張卡對其拿到的數據進行歸一化後的效果才能與單卡情況下對一個 batch 的數據歸一化後的效果相當。

因此,同步的 BN,指的就是每張卡上對應的 BN 層,分別計算出相應的統計量,接著基於每張卡的計算結果計算出統一的 ,然後相互進行同步,最後它們使用的都是同樣的

三、How to implement:

如何實現多卡同步的BN?

  1. 2次同步vs 1次同步

我們已經知道,在前向回饋過程中各卡需要同步均值和方差,從而計算出全局的統計量,或許大家第一時間想到的方式是先同步各卡的均值,計算出全局的均值,然後同步給各卡,接著各卡同步計算方差…這種方式當然沒錯,但是需要進行2次同步,而同步是需要消耗資源並且影響模型訓練速度的,那麼,是否能夠僅用1次同步呢?

全局的均值很容易通過同步計算得出,因此我們來看看方差的計算:

方差的計算,其中m為各GPU卡拿到的數據批次大小()。

由上可知,每張卡計算出,然後進行同步求和,即可計算出全局的方差。同時,全局的均值可通過各卡的同步求和得到,這樣,僅通過1次同步,便可完成全局均值及方差的計算。

1次同步完成全局統計量的計算

  1. 介紹nn.DataParallel的前向回饋

熟悉 pytorch 的朋友們應該知道,在進行GPU多卡訓練的場景中,通常會使用nn.DataParallel來包裝網路模型,它會將模型在每張卡上面都複製一份,從而實現並行訓練。這裡我自定義了一個類繼承nn.DataParallel,用它來包裝SyncBN,並且重載了nn.DataParallel的部分操作,因此需要先簡單說明下nn.DataParallel的前向回饋涉及到的一些操作。

nn.DataParallel的使用,其中DEV_IDS是可用的各GPU卡的id,模型會被複制到這些id對應的各個GPU上,DEV是主卡,最終反向傳播的梯度會被匯聚到主卡統一計算。

先來看看nn.DataParallel的前向回饋方法的源碼:

nn.DataParallel.forward

其中,主要涉及調用了以下4個方法:

(1) scatter:將輸入數據及參數均分到每張卡上;

(2) replicate:將模型在每張卡上複製一份(注意,卡上必須有scatter分割的數據存在!);

(3) parallel_apply:每張卡並行計算結果,這裡會調用被包裝的具體模型的前向回饋操作(在我們這裡就是會調用 SyncBN 的前向回饋方法);

(4) gather:將每張卡的計算結果統一匯聚到主卡。

注意,我們的關鍵在於重載replicate方法,原生的該方法只是將模型在每張卡上複製一份,並且沒有建立起聯繫,而我們的** SyncBN 是需要進行同步的,因此需要重載該方法,讓各張卡上的SyncBN ****通過某種數據結構和同步機制建立起聯繫。

  1. 重載nn.DataParallel.replicate方法

在這裡,可以設計一個繼承nn.DataParallel的子類DataParallelWithCallBack,重載了replicate方法,子類的該方法先是調用父類的replicate方法,然後調用一個自定義的回調函數(這也是之所以命名為DataParallelWithCallBack的原因),該回調函數用於將各卡對應的** SyncBN ****層關聯起來,使得它們可以通過某種數據結構進行通訊。

子類重載的replicate方法

自定義的回調函數,將各卡對應的Syn-BN層進行關聯,其中DataParallelContext是一個自定義類,其中沒有定義實質性的東西,作為一個上下文數據結構,實例化這個類的對象主要用於將各個卡上對應的Syn-BN層進行關聯;_sync_replicas是在Syn-BN中定義的方法,在該方法中其餘子卡上的Syn-BN層會向主卡進行註冊,使得主卡能夠通過某種數據結構和各卡進行通訊。

  1. Syn-BN的同步註冊機制

由上可知,我們需要在 SyncBN 中實現一個用於同步的註冊方法,SyncBN 中還需要設置一個用於管理同步的對象(下圖中的 _sync_master),這個對象有一個註冊方法,可將子卡註冊到其主卡。

在 SyncBN 的方法中,若是主卡,則將上下文管理器的 sync_master 屬性設置為這個管理同步的對象(_sync_master);否則,則調用上下文對象的同步管理對象的註冊方法,將該卡向其主卡進行註冊。

Syn-BN的同步註冊機制

主卡進行同步管理的類中註冊子卡的方法

主卡進行同步管理的類

子卡進行同步操作的類

  1. Syn-BN的前向回饋

如果你認真看完了以上部分,相信這部分你也知道大致是怎樣一個流程了。

首先,每張卡上的 SyncBN 各自計算出 mini-batch 的和以及平方和,然後主卡上的 SyncBN 收集來自各個子卡的計算結果,從而計算出全局的均值和方差,接著發放回各個子卡,最後各子卡的 SyncBN 收到來自主卡返回的計算結果各自進行歸一化(和縮放平移)操作。當然,主卡上的 SyncBN 計算出全局統計量後就可以進行它的歸一化(和縮放平移)操作了。

Syn-BN前向回饋(主卡)

Syn-BN前向回饋(子卡)

最後

在同步過程中,還涉及執行緒和條件對象的使用,這裡就不展開敘述了,感興趣的朋友可以到SyncBN源碼鏈接://github.com/chrisway613/Synchronized-BatchNormalization。另外,在資訊同步這部分,還可以設計其它方式進行優化,如果你有更好的意見,還請積極回饋,CW熱烈歡迎!

深藍學院 發起了一個讀者討論大家有什麼想法,歡迎和讀者溝通呀~