在K40小破卡訓練50層BERT Large的寶藏trick
- 2020 年 4 月 22 日
- AI
前言
雖然TPU的顯示記憶體令人羨慕,但是由於眾所周知的原因,絕大部分人還是很難日常化使用的。NVIDIA 又一直在擠牙膏,至今單卡的最大顯示記憶體也僅僅到32G(參考V100、DGX-2)。然而,訓練一個24層的BERT Large模型的時候,如果sequence length開滿512,那麼batch size僅僅開到8(有時候能到10)就把這寥寥32G的顯示記憶體打滿了。如果想訓練一個48層乃至100層的BERT Large,那完全是土豪們的遊戲了,需要瘋狂的模型並行+分散式多機訓練。
但!是!萬能的小夕前不久在Daxiang Dong大佬的安利下,發現了 @陳天奇 大佬2016年的一篇寶藏paper!

傳送門://arxiv.org/pdf/1604.06174.pdf
簡單的劃一下重點:
這篇paper用時間換空間的思想,在前向時只保存部分中間節點,在反向時重新計算沒保存的部分。論文通過這種機制,在每個batch只多計算一次前向的情況下,把n層網路的佔用顯示記憶體優化到了 ( √)O(n)。在極端情況下,仍可用 ( )O(nlogn)的計算時間換取到 ( )O(logn)的顯示記憶體佔用。在論文的實驗中,他們成功將將1000層的殘差網路從48G優化到了7G。且,這種方法同樣可以直接應用於RNN結構中。
看完摘要,瞬間感覺在小破卡上訓練BERT Large有救了!!!
此外,來快速過一遍paper中最重要的三點結論:
- 梯度計算等價,理論上沒有精度損失

2. 可以節省4倍+的顯示記憶體開銷

3. 訓練速度僅僅會被拖慢30%

不過論文發表在2016年,當時還沒有BERT,不過Baidu Paddle團隊補了一個BERT的實驗結果,發現在BERT上面只用22.5%的訓練速度損失就能換來5倍+的顯示記憶體開銷節省!相關實驗在本文末尾,不著急,接下來我們先一起分析一下在訓練階段時顯示記憶體為什麼容易不足。
感謝Baidu Paddle團隊提供本節圖文素材和測試數據
訓練階段顯示記憶體為何不足
深度學習中,網路的一次訓練包含前向計算、後向計算和優化三個步驟。

在這個過程中,前向計算會輸出大量的隱層變數Tensor,當模型層數加深時,Tensor數量可達成千上萬個。如Bert Large模型,單個Tensor可達到1GB,這些Tensor在顯示記憶體中累積,顯示記憶體很快就爆掉了
下圖是Bert Large模型在一次訓練過程中的顯示記憶體使用情況,可以明顯看到在前向計算過程中,顯示記憶體累積趨勢是一個陡峭的上升直線。而在反向計算過程中,這些隱層Tensor又會很快地被消耗掉,又是一個陡峭的下降曲線,顯示記憶體直接降到低位。

那麼問題來了,為什麼不直接刪除這些前向計算的Tensor呢?
答案很簡單,因為這些隱層的Tensor在反向的時會被用到(手動狗頭
來個簡單的證明。
假設前向計算中有一個矩陣乘法計算:
Y = W × X
對W求梯度:

很容易發現,對W求梯度的公式里有X,而X就是那個巨能吃顯示記憶體的隱層Tensor!
那我們是否可以暫時扔掉這些隱層Tensor,在反向計算時再把它們重新生成出來呢?當然可以,這正是上面這篇paper的思想。
重計算
顧名思義,”重計算”就是讓每個訓練迭代過程做兩次前向計算,看起來有點奇怪,實際上卻非常有效!對於剛剛那個吃顯示記憶體的Bert Large,支援重電腦制後,顯示記憶體佔用直接從175GB降低到20GB,陡峭的顯示記憶體上升直線變成了緩慢增長的Z形曲線,如下圖所示。

核心思想是將前向計算分割成多個段,將每個段的起始Tensor作為這個段的檢查點(checkpoints)。前向計算時,除了檢查點以外的其他隱層Tensor佔有的顯示記憶體可以及時釋放。反向計算用到這些隱層Tensor時,從前一個檢查點開始,重新進行這個段的前向計算,就可以重新獲得隱層Tensor。
重電腦制有點像玩單機遊戲。每過一個關卡就會保存一個檢查點,而隱層Tensor就相當於遊戲中任何一個時刻的影像。普通的訓練方式是打通一遍遊戲,並且將遊戲中所有時刻的影像保存下來;而重電腦制的思路是先把遊戲通關,保存檢查點,後面當收到某一時刻影像的請求時,再重打一遍這一關卡就可以了。

如下圖,舉一個簡單的例子,添加重電腦制前,前向計算中需要存儲的隱層是4個紅點;添加重電腦制後,需要存儲的隱層變為2個藍點, 從而節省了這部分記憶體。

雖然時間也是寶貴的,但重計算方法的性價比很高。在論文的實驗中,作者用30%的計算時間換取了4倍的記憶體空間。並且重計算只是重複了一次前向的過程,理論上精度沒有任何損失。
這麼寶藏的演算法當然也少不了開源實現。
開源實現
調研了一波,似乎TF沒有原生支援,但是生態里有OpenAI的第三方實現;pytorch和paddlepaddle中都有原生API支援
- Pytorch:
torch.utils.checkpoint
- PaddlePaddle:
optimizer.RecomputeOptimizer
不過pytorch的文檔比較略,也沒有提供更細緻的示例和相關數據,有興趣的小夥伴自行試一下。paddle框架中提供了詳細到哭的文檔,甚至還有一個現成的BERT+重計算的例子,以及非常詳細的實驗測試結果。這裡直接貼過來(真香系列
Paddle中實現顯示記憶體重計算大體分為三步:
- 定義一個經典的優化器,如SGD優化器;
- 在外麵包一層重計算優化器;
- 設置檢查點。
以MLP為例,只需要增加兩行程式碼就可以進入重計算模式
import paddle.fluid as fluid
# 定義MLP
def mlp(input_x, input_y, hid_dim=128, label_dim=2):
print(input_x)
fc_1 = fluid.layers.fc(input=input_x, size=hid_dim)
prediction = fluid.layers.fc(input=[fc_1], size=label_dim, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
sum_cost = fluid.layers.reduce_mean(cost)
return sum_cost, fc_1, prediction
input_x = fluid.layers.data(name="x", shape=[32], dtype='float32')
input_y = fluid.layers.data(name="y", shape=[1], dtype='int64')
cost, fc_1, pred = mlp(input_x, input_y)
# 定義RecomputeOptimizer
sgd = fluid.optimizer.SGD(learning_rate=0.01)
recompute_optimizer = fluid.optimizer.RecomputeOptimizer(sgd)
# 設置checkpoints
recompute_optimizer._set_checkpoints([fc_1, pred])
# 運行優化演算法
recompute_optimizer.minimize(cost)
該示例github鏈接:
//github.com/PaddlePaddle/examples/blob/master/community_examples/recompute/demo.py
此外,官方還給出了一個BERT中做重計算的示例
github鏈接:
//github.com/PaddlePaddle/Fleet/tree/develop/examples/recompute/bert
BERT實驗結論(劃重點
根據上面paddle官方提供的BERT示例和實驗結果,得出以下幾個結論
結論一
在32GB顯示記憶體的Tesla V100顯示卡上應用重電腦制,可以訓練更大、更深的深度學習模型。當num_tokens為4096(batch size=32,seqlen=128)時,可以訓練100層的Bert網路。

從Github的實驗結果也可以看出,顯示記憶體上的收益比速度的損失要大很多:

在batch_size上提升了5倍,速度只降低了約1/5,且精度沒有損失。
結論二
模型訓練的batch size最大可提升為原來的5倍+,且只有少量的速度損失。
重電腦制在Bert Large這一模型上收益最大,最大batch size從93提升到562!而在VGG-16這種比較淺的模型上,重電腦制的收益則比較小。這充分符合重電腦制的設計理念:為了訓練更大、更深的模型。
結論三
在古董顯示卡Tesla K40顯示卡(12G顯示記憶體)上,訓練BERT Large時batch size可以開到130

最後,希望本文可以幫助大家在小破卡上盡情訓練BERT Large~