[源碼解析] PyTorch分佈式優化器(3)—- 模型並行

[源碼解析] PyTorch分佈式優化器(3)—- 模型並行

0x00 摘要

本系列介紹分佈式優化器,分為三篇文章,分別是基石篇,DP/DDP/Horovod 之中數據並行的優化器,PyTorch 分佈式優化器,按照深度遞進。本文介紹PyTorch 分佈式優化器和PipeDream之中的優化器,主要涉及模型並行(流水線並行)。

PyTorch分佈式其他文章如下:

深度學習利器之自動微分(1)

深度學習利器之自動微分(2)

[源碼解析]深度學習利器之自動微分(3) — 示例解讀

[源碼解析]PyTorch如何實現前向傳播(1) — 基礎類(上)

[源碼解析]PyTorch如何實現前向傳播(2) — 基礎類(下)

[源碼解析] PyTorch如何實現前向傳播(3) — 具體實現

[源碼解析] Pytorch 如何實現後向傳播 (1)—- 調用引擎

[源碼解析] Pytorch 如何實現後向傳播 (2)—- 引擎靜態結構

[源碼解析] Pytorch 如何實現後向傳播 (3)—- 引擎動態邏輯

[源碼解析] PyTorch 如何實現後向傳播 (4)—- 具體算法

[源碼解析] PyTorch 分佈式(1)——歷史和概述

[源碼解析] PyTorch 分佈式(2) —– DataParallel(上)

[源碼解析] PyTorch 分佈式(3) —– DataParallel(下)

[源碼解析] PyTorch 分佈式(4)——分佈式應用基礎概念

[源碼解析] PyTorch分佈式(5) —— DistributedDataParallel 總述&如何使用

[源碼解析] PyTorch分佈式(6) —DistributedDataParallel — 初始化&store

[源碼解析] PyTorch 分佈式(7) —– DistributedDataParallel 之進程組

[源碼解析] PyTorch 分佈式(8) ——– DistributedDataParallel之論文篇

[源碼解析] PyTorch 分佈式(9) —– DistributedDataParallel 之初始化

[源碼解析] PyTorch 分佈式(10)——DistributedDataParallel 之 Reducer靜態架構

[源碼解析] PyTorch 分佈式(11) —– DistributedDataParallel 之 構建Reducer和Join操作

[源碼解析] PyTorch 分佈式(12) —– DistributedDataParallel 之 前向傳播

[源碼解析] PyTorch 分佈式(13) —– DistributedDataParallel 之 反向傳播

[源碼解析] PyTorch 分佈式 Autograd (1) —- 設計

[源碼解析] PyTorch 分佈式 Autograd (2) —- RPC基礎

[源碼解析] PyTorch 分佈式 Autograd (3) —- 上下文相關

[源碼解析] PyTorch 分佈式 Autograd (4) —- 如何切入引擎

[源碼解析] PyTorch 分佈式 Autograd (5) —- 引擎(上)

[源碼解析] PyTorch 分佈式 Autograd (6) —- 引擎(下)

[源碼解析] PyTorch分佈式優化器(1)—-基石篇

[源碼解析] PyTorch分佈式優化器(2)—-數據並行優化器

為了更好的說明,本文代碼會依據具體情況來進行相應精簡。

0x01 前文回顧

之前無論是 DP, DDP,或者 Horovod,實質上的都是處理數據並行,比如 DDP 將相同的模型複製到所有 GPU,其中每個 GPU 使用輸入數據的不同分區。雖然它可以顯着加速訓練過程,但它不適用於模型太大而無法放入單個 GPU 的某些用例。於是人們引入了模型並行(model parallel)。

與此對應,優化器也需要做不同的修改以適應模型並行的需求。為了更好的分析,本文首先介紹單機模型並行,然後介紹PyTorch分佈式優化器。

0x02 單機模型

下面文字翻譯自 //pytorch.org/tutorials/intermediate/model_parallel_tutorial.html ,加入了一些自己的思考和理解。

模型並行被廣泛用於分佈式訓練。與DataParallel相比,模型並行將單個模型拆分到不同的 GPU 上,而不是在每個 GPU 上複製整個模型(具體來說,假設一個模型 m包含 10 層,當使用DataParallel,每個 GPU 將擁有這 10 層的全部副本,而當在兩個 GPU 上使用模型並行時,每個 GPU 可以託管 5 層)。

模型並行的高級思想是將模型的不同子網絡放置在不同的設備上,並相應地實現該forward方法以便跨設備移動中間輸出。由於單個設備上只有模型的一部分在運行,因此一組設備可以共同服務於一個更大的模型。

在這篇文章中,我們不會嘗試構建巨大的模型並將它們壓縮到有限數量的 GPU 中。相反,這篇文章側重於展示模型並行的想法。讀者可以將這些想法應用到實際應用中。

2.1 基本用法

讓我們從一個包含兩個線性層的玩具模型開始。要在兩個 GPU 上運行這個模型,只需將每個線性層放在不同的 GPU 上,並相應地移動輸入和中間輸出以匹配層設備。

import torch
import torch.nn as nn
import torch.optim as optim


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = torch.nn.Linear(10, 10).to('cuda:0')
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5).to('cuda:1')

    def forward(self, x):
        x = self.relu(self.net1(x.to('cuda:0')))
        return self.net2(x.to('cuda:1'))

ToyModel的代碼看起來與在單個 GPU 上的實現方式非常相似。只是修改了兩個部分:網絡構造部分和forward部分。

  • __init__方法使用了兩個to(device)語句用來在適當的設備上放置線性層,這樣就把整個網絡拆分成兩個部分,然後就可以分別運行在不同的GPU之上。
  • forward 方法使用了兩個to(device)語句用來在適當的設備上放置張量,這樣可以把一個layer的輸出結果通過tensor.to的語義拷貝到另一個layer所在的GPU上。

這是模型中唯一需要更改的地方。backward()torch.optim會可以應付這種情況,它們自動接管梯度,彷彿模型是一個GPU之上。在調用損失函數時,您只需要確保標籤與網絡的輸出在同一設備上。

model = ToyModel()
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

optimizer.zero_grad()
outputs = model(torch.randn(20, 10))
labels = torch.randn(20, 5).to('cuda:1')
loss_fn(outputs, labels).backward()
optimizer.step()

這裡最重要的是 labels = torch.randn(20, 5).to(‘cuda:1’),這保證了標籤在 cuda:1’。

回憶一下之前forward的代碼:self.net2(x.to(‘cuda:1′))。這兩行代碼確保標籤與輸出在同一設備 cuda:1’ 上。

初始化之後如下:

+--------------------+                       +------------------------+
| cuda:0             |                       | cuda:1                 |
|                    |                       |                        |
|                    |                       |                        |
|                    |                       |                        |
|       net1(x)      |                       |        net2(x)         |
|                    |                       |                        |
|                    |                       |                        |
|                    |                       |                        |
+--------------------+                       +------------------------+

forward 操作和設定label之後如下,現在輸出和label都在GPU 1 之上:

               +--------------------+                       +------------------------+
               | cuda:0             |                       | cuda:1                 |
               |                    |                       |                        |
               |                    |                       |                        |
               |                    |                       |                        |
x.to('cuda:0')-------> net1(x)  +-------> x.to('cuda:1') +-------->  net2(x)         |
               |                    |                       |                        |
               |                    |                       |   labels.to('cuda:1')  |
               |                    |                       |                        |
               +--------------------+                       +------------------------+

2.2 將模型並行應用到現有模塊

還可以通過更改幾行代碼把一個現有的單 GPU 模塊轉換到在多個 GPU 上運行。下面的代碼展示了如何分解 torchvision.models.resnet50()到兩個 GPU之上。基本想法是繼承現有ResNet模塊,並在構建過程中將層拆分為兩個 GPU。然後,重載forward方法以便把兩個子網絡拼接起來,forward具體是通過相應地移動中間輸出來完成。

from torchvision.models.resnet import ResNet, Bottleneck

num_classes = 1000

class ModelParallelResNet50(ResNet):
    def __init__(self, *args, **kwargs):
        super(ModelParallelResNet50, self).__init__(
            Bottleneck, [3, 4, 6, 3], num_classes=num_classes, *args, **kwargs)

        self.seq1 = nn.Sequential(
            self.conv1,
            self.bn1,
            self.relu,
            self.maxpool,

            self.layer1,
            self.layer2
        ).to('cuda:0')

        self.seq2 = nn.Sequential(
            self.layer3,
            self.layer4,
            self.avgpool,
        ).to('cuda:1')

        self.fc.to('cuda:1')

    def forward(self, x):
        x = self.seq2(self.seq1(x).to('cuda:1'))
        return self.fc(x.view(x.size(0), -1))

上述實現解決了模型太大而無法放入單個 GPU 的情況下的問題。但是,您可能已經注意到,即使您的模型適合這種情況,它也許會比在單個 GPU 上運行要慢。這是因為,在任何時候,兩個 GPU 中只有一個在工作,而另一個坐在那裡什麼也不做。在 layer2layer3 之中需要把中間輸出從cuda:0拷貝到 cuda:1,這將進一步引起性能惡化。

讓我們運行一個實驗,以更從一個可以量化地角度來了解執行時間。在這個實驗中,我們通過運行隨機輸入和標籤來訓練ModelParallelResNet50和現有 torchvision.models.resnet50()。訓練後,模型不會產生任何有用的預測,但我們可以對執行時間有一個合理的了解。

import torchvision.models as models

num_batches = 3
batch_size = 120
image_w = 128
image_h = 128


def train(model):
    model.train(True)
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    one_hot_indices = torch.LongTensor(batch_size) \
                           .random_(0, num_classes) \
                           .view(batch_size, 1)

    for _ in range(num_batches):
        # generate random inputs and labels
        inputs = torch.randn(batch_size, 3, image_w, image_h)
        labels = torch.zeros(batch_size, num_classes) \
                      .scatter_(1, one_hot_indices, 1)

        # run forward pass
        optimizer.zero_grad()
        outputs = model(inputs.to('cuda:0'))

        # run backward pass
        labels = labels.to(outputs.device)
        loss_fn(outputs, labels).backward()
        optimizer.step()

上述train(model)方法使用nn.MSELoss用作損失函數,使用optim.SGD作為優化器。它模仿 128 X 128圖像的訓練,這些圖像被組織成 3 個批次,每批次包含 120 個圖像。然後,我們使用timeit來運行 train(model) 10 次,並且用標準差來繪製執行時間。

import matplotlib.pyplot as plt
plt.switch_backend('Agg')
import numpy as np
import timeit

num_repeat = 10

stmt = "train(model)"

setup = "model = ModelParallelResNet50()"
mp_run_times = timeit.repeat(
    stmt, setup, number=1, repeat=num_repeat, globals=globals())
mp_mean, mp_std = np.mean(mp_run_times), np.std(mp_run_times)

setup = "import torchvision.models as models;" + \
        "model = models.resnet50(num_classes=num_classes).to('cuda:0')"
rn_run_times = timeit.repeat(
    stmt, setup, number=1, repeat=num_repeat, globals=globals())
rn_mean, rn_std = np.mean(rn_run_times), np.std(rn_run_times)


def plot(means, stds, labels, fig_name):
    fig, ax = plt.subplots()
    ax.bar(np.arange(len(means)), means, yerr=stds,
           align='center', alpha=0.5, ecolor='red', capsize=10, width=0.6)
    ax.set_ylabel('ResNet50 Execution Time (Second)')
    ax.set_xticks(np.arange(len(means)))
    ax.set_xticklabels(labels)
    ax.yaxis.grid(True)
    plt.tight_layout()
    plt.savefig(fig_name)
    plt.close(fig)


plot([mp_mean, rn_mean],
     [mp_std, rn_std],
     ['Model Parallel', 'Single GPU'],
     'mp_vs_rn.png')

img

結果表明,模型並行需要的執行時間比但GPU實現需要的時間長 4.02/3.75-1=7%。所以我們可以得出結論,在 GPU 之間來回複製張量大約有 7% 的開銷。

2.3 問題與方案

2.3.1 目前狀況

我們總結一下目前狀況:

  • 雖然有多塊GPU,但是在整個執行過程中的每一個時刻,只有一個GPU在計算,其他GPU處於空閑狀態。
  • 另外還有中間計算結果在GPU之間的拷貝工作,這也使得性能惡化。

因此我們需要針對這兩個問題進行針對性處理:

  • 讓所有 GPU 都動起來。
  • 減少拷貝傳輸時間。

2.3.2 解決方案

兩個問題解決方案如下:

讓所有 GPU 都動起來的一種選擇是加入流水線機制:將每個批次進一步劃分,組成一個分割(split )管道,這樣當一個分割到達第二個子網絡時,可以將接下來的分割送入第一個子網絡。這樣,兩個連續的分割(split )就可以在兩個 GPU 上同時運行。

為什麼可以做到這一點?這是因為 CUDA 的異步並行執行邏輯。

  • CUDA 的一些操作是異步的,比如:核發射,設備間數據拷貝,主機和設備內拷貝小存儲塊等等。
  • 幾乎所有具有計算能力1.1及更高計算能力的CUDA設備都支持並發複製和核執行,即數據拷貝和數值計算可以並行。
  • 一些計算能力2.x的設備可並發執行多個內核。
  • 在一些計算能力2.x的設備上,兩個方向的拷貝可以並行(GPU到CPU,CPU到GPU)。

如何減少拷貝傳輸時間?這個可以使用一些硬件和軟件的結合來增加帶寬減少延遲,比如:

  • 硬件層面包括:單機內部的PCIe、NVlink、NVSwitch;多機之間的RDMA網絡(IB或RoCE)。
  • 軟件堆棧包括:GPUDirect的一系列技術:P2P(Peer-to-Peer),RDMA,Async,Storage等。

PyTorch使用了NCCL庫(基於CUDA計算)。

2.4 通過流水線輸入加速

在接下來的實驗中,我們進一步將每個”120 個圖像批次” 分成 “20 個圖像分割(split)”。由於 PyTorch 異步啟動 CUDA 操作,因此實現不需要產生多個線程來實現並發。

class PipelineParallelResNet50(ModelParallelResNet50):
    def __init__(self, split_size=20, *args, **kwargs):
        super(PipelineParallelResNet50, self).__init__(*args, **kwargs)
        self.split_size = split_size

    def forward(self, x):
        splits = iter(x.split(self.split_size, dim=0))
        s_next = next(splits)
        s_prev = self.seq1(s_next).to('cuda:1')
        ret = []

        for s_next in splits:
            # A. s_prev runs on cuda:1
            s_prev = self.seq2(s_prev)
            ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))

            # B. s_next runs on cuda:0, which can run concurrently with A
            s_prev = self.seq1(s_next).to('cuda:1')

        s_prev = self.seq2(s_prev)
        ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))

        return torch.cat(ret)


setup = "model = PipelineParallelResNet50()"
pp_run_times = timeit.repeat(
    stmt, setup, number=1, repeat=num_repeat, globals=globals())
pp_mean, pp_std = np.mean(pp_run_times), np.std(pp_run_times)

plot([mp_mean, rn_mean, pp_mean],
     [mp_std, rn_std, pp_std],
     ['Model Parallel', 'Single GPU', 'Pipelining Model Parallel'],
     'mp_vs_rn_vs_pp.png')

請注意,設備到設備張量複製操作會在源設備和目標設備上的當前流上進行同步。如果創建多個流,則必須確保複製操作正確同步。在完成複製操作之前寫入源張量或讀取/寫入目標張量可能會導致未定義的行為。上述實現僅在源設備和目標設備上使用默認流,因此沒有必要強制執行額外的同步操作。

img

實驗結果表明,把流水線輸入加入到 ResNet50 的模型並行之後,訓練過程加快了大約3.75/2.51-1=49%。雖然它離理想的 100% 加速還很遠。由於我們在流水線並行實現中引入了一個新參數split_sizes,因此尚不清楚此新參數如何影響整體訓練時間。直觀地說,使用小的split_size會導致許多微小的 CUDA 核啟動,而使用大split_size結果會導致在第一次和最後一次拆分期間產生相對較長的空閑時間。兩者都不是最優的。split_size這個特定實驗可能有一個最佳配置。讓我們嘗試通過使用幾個不同的split_size值運行實驗來找到它。

means = []
stds = []
split_sizes = [1, 3, 5, 8, 10, 12, 20, 40, 60]

for split_size in split_sizes:
    setup = "model = PipelineParallelResNet50(split_size=%d)" % split_size
    pp_run_times = timeit.repeat(
        stmt, setup, number=1, repeat=num_repeat, globals=globals())
    means.append(np.mean(pp_run_times))
    stds.append(np.std(pp_run_times))

fig, ax = plt.subplots()
ax.plot(split_sizes, means)
ax.errorbar(split_sizes, means, yerr=stds, ecolor='red', fmt='ro')
ax.set_ylabel('ResNet50 Execution Time (Second)')
ax.set_xlabel('Pipeline Split Size')
ax.set_xticks(split_sizes)
ax.yaxis.grid(True)
plt.tight_layout()
plt.savefig("split_size_tradeoff.png")
plt.close(fig)

img

結果表明,設置split_size為 12 實現了最快的訓練速度,從而導致3.75/2.43-1=54%加速。我們仍有機會進一步加快訓練進程。例如,目前所有cuda:0上的操作都放在其默認流上。這意味着下一個拆分的計算不能與上一個拆分的複製操作重疊。但是,由於 prev 和 next 拆分(split)是不同的張量,因此將一個張量的計算與另一個張量的拷貝重疊起來是沒有問題的。這種實現需要在兩個GPU上使用多個流,並且不同的子網結構需要不同的流管理策略。由於沒有一個適用於所有模型並行用例的通用的多流解決方案,我們不會在本教程中討論它。

這篇文章展示了幾個性能測量。在您自己的機器上運行相同的代碼時,您可能會看到不同的性能結果,因為結果取決於底層硬件和軟件。要為您的環境獲得最佳性能,正確的方法是首先生成結果曲線,並根據曲線來確定最佳分割大小,然後將該分割大小應用到管道輸入之上。

0x03 分佈式問題和方案

我們已經了解了單機之上的模型並行,接下來就要看模型跨越多個服務器的分佈式模型並行訓練。

3.1 思路

我們先設想一下如果自己實現分佈式優化器,應該如何處理。

假如模型分為三個部分,有三個主機可以訓練。

+----------------------------------------------------------------+
| Model                                                          |
|                                                                |
| +-----------------+  +------------------+  +-----------------+ |
| | Sub+model 1     |  | Sub+model 2      |  | Sub+model 3     | |
| |                 |  |                  |  |                 | |
| |                 |  |                  |  |                 | |
| +-----------------+  +------------------+  +-----------------+ |
|                                                                |
+----------------------------------------------------------------+

+-------------------+  +------------------+  +-----------------+
| Host 1            |  | Host 2           |  | Host 3          |
|                   |  |                  |  |                 |
|                   |  |                  |  |                 |
|                   |  |                  |  |                 |
|                   |  |                  |  |                 |
|                   |  |                  |  |                 |
+-------------------+  +------------------+  +-----------------+

我們會顯式的把這三個部分分別部署到三個主機之上,在三個主機之上都有一套自己的訓練代碼,每個訓練代碼之中都有自己的本地優化器負責優化本地子模型的參數。

+---------------------+         +---------------------+         +---------------------+
| Host 1              |         | Host 2              |         | Host 3              |
|                     |         |                     |         |                     |
| +-----------------+ |         | +-----------------+ |         | +-----------------+ |
| | Sub model 1     | |forward  | | Sub model 2     | |forward  | | Sub model 3     | |
| |                 | +-------> | |                 | +-------> | |                 | |
| |_parameters <--+ | |         | |_parameters <--+ | |         | |_parameters <--+ | |
| |               | | | <-------+ |               | | | <-------+ |               | | |
| |               | | | backward| |               | | | backward| |               | | |
| +-----------------+ |         | +-----------------+ |         | +-----------------+ |
|                 |   |         |                 |   |         |                 |   |
|                 |   |         |                 |   |         |                 |   |
| ------------------+ |         | +-----------------+ |         | +-----------------+ |
| |Optimizer 1    | | |         | | Optimizer 2   | | |         | | Optimizer 3   | | |
| |               | | |         | |               | | |         | |               | | |
| |    step() +---+ | |         | |    step() +---+ | |         | |     step()+---+ | |
| |                 | |         | |                 | |         | |                 | |
| +-----------------+ |         | +-----------------+ |         | +-----------------+ |
+---------------------+         +---------------------+         +---------------------+

但是這樣有幾個問題需要我們解決:

  • 如何劃分模型到不同機器上?如何把代碼分割到不同機器上?
  • 如何跨機器把前向傳播,後向傳播連接在一起?
  • 各個機器之間是同步運行還是異步運行?
  • 如果是同步,如何讓整個系統用同一個步驟運行?
  • 如何把這些優化器結合在一起?還是優化器各做各的,彼此沒有任何聯繫?
  • 如何儘力讓用戶少修改代碼?
  • 如何能讓開發者感覺就是開發本地版本代碼?

經過思考就會發現,這裏面錯綜複雜。如果我們自己基於 PyTorch 來實現,你會發現這可能最終結果是一個 PipeDream。於是我們看看 PyTorch 如何處理。

3.2 PyTorch 的思路

PyTorch 使用 RPC 來解決這些問題。

3.2.1 四大天王

前文我們提到了,PyTorch的分佈式框架使用了四大天王:

  • **遠程過程調用 (RPC) ** 使用給定的參數在指定的worker上運行函數並獲取返回值或創建對返回值的引用。有三個主要的 API: rpc_sync()(同步)、 rpc_async()(異步)和 remote()(異步並返回對遠程返回值的引用)。
    • 如果用戶代碼在沒有返回值的情況下無法繼續,請使用同步 API。
    • 否則,使用異步 API 獲取 Future,並在調用者需要返回值時等待 Future。
    • remote() API適用如下情況:需要在遠程創建某些內容但從不需要將其獲取給調用者。
  • 遠程引用 (RRef) 是指向本地或遠程對象的分佈式共享指針,就是本地或者跨機器的變量引用。
  • **Distributed Autograd **將所有參與前向傳播 worker的本地 autograd 引擎縫合在一起,並在後向傳播期間自動聯繫它們以計算梯度。在進行前向傳遞如果需要跨越多台機器時,這尤其有用,例如分佈式模型並行訓練、參數服務器訓練等。 有了這個特性,用戶代碼不再需要擔心如何跨 RPC 邊界發送梯度和應該以什麼順序啟動本地 autograd 引擎,如果前向傳遞中有嵌套和相互依賴的 RPC 調用,這可能會變得非常複雜。
  • 分佈優化器的構造需要一個 Optimizer()(例如,SGD()Adagrad()等)和一個RRefs的參數列表。即,在每個不同的Ref所有者之上創建一個 Optimizer()實例,然後運行step()相應更新參數。當用戶進行分佈式前向和後向傳播時,參數和梯度將分散在多個 worker 中,因此需要對每個相關 worker 進行優化。Distributed Optimizer 將所有這些本地優化器合而為一,並提供了簡潔的構造函數和step()API。

3.2.2 邏輯關係

我們使用官方圖示,可以看到 PyTorch 分佈式包的內部架構和邏輯關係。分佈式優化器基於另外三者之上。

我們會在後續結合代碼進行講解如何使用。

0x04 PyTorch 分佈式優化器

首先說明一下,為了清晰的分析,我們後續忽略所有 script 相關部分。

4.1 示例

DistributedOptimizer 的使用方法如下:

  1. 獲取要優化的遠程參數列表 (RRef)。 這些也可以是包裝在本地RRef中的本地參數。
  2. Optimizer 類作為本地優化器來運行所有的RRef owner。
  3. 分佈式優化器在每個 worker 節點上創建其本地優化器的實例,並持有這些本地優化器的 RRef。
  4. 當調用 torch.distributed.optim.DistributedOptimizer.step() 時,分佈式優化器使用 RPC 在適當的遠程 worker 上遠程執行所有本地優化器。torch.distributed.optim.DistributedOptimizer.step 必須獲得一個分佈式autograd context_id作為輸入,本地優化器將把梯度保存在相關的context之中。
  5. 如果多個並發的分佈式優化器同時更新工作器上的相同參數,則這些更新將通過鎖序列化。

看起來有點抽象,我們需要一步一步分析。

4.2 簡單的端到端示例

綜上所述,以下是使用分佈式 autograd 和分佈式優化器的簡單端到端示例。 如果將代碼放入名為「 dist_autograd_simple.py」的文件中,則可以使用命令MASTER_ADDR="localhost" MASTER_PORT=29500 python dist_autograd_simple.py運行該代碼:

import multiprocessing as mp
import torch
import torch.distributed.autograd as dist_autograd
from torch.distributed import rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer

def random_tensor():
    return torch.rand((3, 3), requires_grad=True)

def _run_process(rank, dst_rank, world_size):
    name = "worker{}".format(rank)
    dst_name = "worker{}".format(dst_rank)

    # Initialize RPC.
    rpc.init_rpc(
        name=name,
        rank=rank,
        world_size=world_size
    )

    # Use a distributed autograd context.
    with dist_autograd.context() as context_id: # 本地優化器將把梯度保存在相關的context之中
        # Forward pass (create references on remote nodes).
        rref1 = rpc.remote(dst_name, random_tensor) # 在遠端創建一個 random_tensor
        rref2 = rpc.remote(dst_name, random_tensor) # 在遠端創建一個 random_tensor
        loss = rref1.to_here() + rref2.to_here() # 獲取要優化的遠程參數列表 (`RRef`)

        # Backward pass (run distributed autograd).
        dist_autograd.backward([loss.sum()])

        # Build DistributedOptimizer.
        dist_optim = DistributedOptimizer( # 分佈式優化器在每個 worker 節點上創建其本地Optimizer的實例,並將持有這些本地優化器的 RRef。
        optim.SGD,
        [rref1, rref2],
        lr=0.05,
        )

        # Run the distributed optimizer step.
        dist_optim.step()

def run_process(rank, dst_rank, world_size):
    _run_process(rank, dst_rank, world_size)
    rpc.shutdown()

processes = []

# Run world_size workers.
world_size = 2
for i in range(world_size):
    p = mp.Process(target=run_process, args=(i, (i + 1) % 2, world_size))
    p.start()
    processes.append(p)

for p in processes:
    p.join()

4.3 定義

DistributedOptimizer 得到了分散在 workers 之上參數的遠端引用,然後對於這些參數在本地運行優化器。

對於單個worker來說,如果它接受到來自相同或不同客戶端的~torch.distributed.optim.DistributedOptimizer.step的並發調用,則這些調用將會在這個worker之上串行進行,因為每個worker的優化器一次只能處理一組梯度。

DistributedOptimizer 的定義其實看不到啥東西,這是因為 Python 的語言特性,我們沒辦法在統一地方看到類的成員變量,但是有一個 functional_optim_map 值得我們關注。 這裡是把每個內置優化器又配置了一個對應的新優化器,比如 optim.Adagrad 對應的是 _FunctionalAdagrad,我們就選擇一個新優化器看看。

class DistributedOptimizer:
    """
    DistributedOptimizer takes remote references to parameters scattered
    across workers and applies the given optimizer locally for each parameter.

    This class uses :meth:`~torch.distributed.autograd.get_gradients` in order
    to retrieve the gradients for specific parameters.

    Concurrent calls to
    :meth:`~torch.distributed.optim.DistributedOptimizer.step`,
    either from the same or different clients, will
    be serialized on each worker -- as each worker's optimizer can only work
    on one set of gradients at a time. However, there is no guarantee that
    the full forward-backward-optimizer sequence will execute for one client
    at a time. This means that the gradients being applied may not correspond
    to the latest forward pass executed on a given worker. Also, there is no
    guaranteed ordering across workers.

    `DistributedOptimizer` creates the local optimizer with TorchScript enabled
    by default, so that optimizer updates are not blocked by the Python Global
    Interpreter Lock (GIL) in the case of multithreaded training (e.g. Distributed
    Model Parallel). This feature is currently enabled for most optimizers. You
    can also follow `the recipe`__ in PyTorch tutorials to enable TorchScript support
    for your own custom optimizers.

    Args:
        optimizer_class (optim.Optimizer): the class of optimizer to
            instantiate on each worker.
        params_rref (list[RRef]): list of RRefs to local or remote parameters
            to optimize.
        args: arguments to pass to the optimizer constructor on each worker.
        kwargs: arguments to pass to the optimizer constructor on each worker.
        
    """
    
    # dict to map a user passed in optimizer_class to a functional
    # optimizer class if we have already defined inside the
    # distributed.optim package, this is so that we hide the
    # functional optimizer to user and still provide the same API.
    functional_optim_map = {
        optim.Adagrad: _FunctionalAdagrad,
        optim.Adam: _FunctionalAdam,
        optim.AdamW: _FunctionalAdamW,
        optim.SGD: _FunctionalSGD,
        optim.Adadelta: _FunctionalAdadelta,
        optim.RMSprop: _FunctionalRMSprop,
        optim.Rprop: _FunctionalRprop,
        optim.Adamax: _FunctionalAdamax,
    }        

4.3.1_FunctionalSGD

optim.SGD 對應的是 _FunctionalSGD。其代碼位於 torch/distributed/optim/functional_sgd.py。具體是定義一個與TorchScript兼容的函數式SGD優化器,PyTorch 將以函數的方式使用這些優化器。在更新參數時,PyTorch 不使用 param.grad,而是顯式地允許分佈式優化器將梯度傳遞給 step 函數。注意:此優化器應該僅由分佈式優化器內部使用,而不是向用戶公開。

import torch.optim._functional as F

# Define a TorchScript compatible Functional SGD Optimizer
# where we use these optimizer in a functional way.
# Instead of using the `param.grad` when updating parameters,
# we explicitly allow the distributed optimizer pass gradients to
# the `step` function. In this way, we could separate the gradients
# and parameters and allow multithreaded trainer to update the
# parameters without data traces on accumulating to the same .grad.
# NOTE: This should be only used by distributed optimizer internals
# and not meant to expose to the user.
@torch.jit.script
class _FunctionalSGD(object):
    def __init__(
        self,
        params: List[Tensor],
        lr: float = 1e-2,
        momentum: float = 0.0,
        dampening: float = 0.0,
        weight_decay: float = 0.0,
        nesterov: bool = False
    ):
        self.defaults = {
            "lr": lr,
            "momentum": momentum,
            "dampening": dampening,
            "weight_decay": weight_decay,
        }
        self.nesterov = nesterov
        self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})

        # NOTE: we only have one param_group and don't allow user to add additional
        # param group as it's not a common use case.
        self.param_group = {"params": params}

    def step(self, gradients: List[Optional[Tensor]]):
        params = self.param_group['params']
        grads = []
        momentum_buffer_list: List[Optional[Tensor]] = []
        lr = self.defaults['lr']
        weight_decay = self.defaults['weight_decay']
        momentum = self.defaults['momentum']
        dampening = self.defaults['dampening']

        for param, gradient in zip(params, gradients):
            if gradient is not None:
                grads.append(gradient)

                if param not in self.state:
                    self.state[param] = {}

                state = self.state[param]
                if 'momentum_buffer' not in state:
                    momentum_buffer_list.append(None)
                else:
                    momentum_buffer_list.append(state['momentum_buffer'])

        with torch.no_grad():
            F.sgd(params,
                  grads,
                  momentum_buffer_list,
                  weight_decay=weight_decay,
                  momentum=momentum,
                  lr=lr,
                  dampening=dampening,
                  nesterov=self.nesterov)

        # update momentum_buffers in state
        for i, p in enumerate(params):
            state = self.state[p]
            momentum_buffer = momentum_buffer_list[i]
            if momentum_buffer is not None:
                state['momentum_buffer'] = momentum_buffer

4.4 初始化

4.4.1 初始化

這部分代碼主要對應了:分佈式優化器在每個 worker 節點上創建其本地Optimizer的實例,並將持有這些本地優化器的 RRef。具體結合我們之前示例代碼來看,params_rref 就是需要優化的參數列表,每個會對應一個優化器,就是 DistributedOptimizer 生成了所有節點上的優化器,以 rpc.RRef(_LocalOptimizer) 形式保存在 self.remote_optimizers 之中。

def __init__(self, optimizer_class, params_rref, *args, **kwargs):
    per_worker_params_rref = defaultdict(list)
    for param in params_rref: # 
        per_worker_params_rref[param.owner()].append(param) # [owner] = param

    # 拿到對應的本地優化器類    
    if optimizer_class in DistributedOptimizer.functional_optim_map and jit._state._enabled:
        optim_ctor = DistributedOptimizer.functional_optim_map.get(optimizer_class)
    else:
        optim_ctor = optimizer_class
    self.is_functional_optim = (optim_ctor != optimizer_class)

    if self.is_functional_optim:
        optimizer_new_func = _new_script_local_optimizer
    else:
        optimizer_new_func = _new_local_optimizer # 下面會介紹

    remote_optim_futs = []
    for worker, param_rrefs in per_worker_params_rref.items():
        remote_optim_rref_fut = rpc.rpc_async(
            worker, # 在 worker 之上生成其本地優化器
            optimizer_new_func, # rpc_async 會調用
            args=(optim_ctor, param_rrefs) + args,
            kwargs=kwargs,
        )
        remote_optim_futs.append(remote_optim_rref_fut)

    self.remote_optimizers = _wait_for_all(remote_optim_futs) # 本地保存的遠端各個節點上優化器

4.4.2 生成優化器 _LocalOptimizer

_new_local_optimizer 是生成了_LocalOptimizer

def _new_local_optimizer(optim_cls, local_params_rref, *args, **kwargs):
    return rpc.RRef(
        _LocalOptimizer(optim_cls, local_params_rref, *args, **kwargs))

_LocalOptimizer 是本地優化器,其運行在遠端worker節點之上,master 擁有這些優化器的代理。

class _LocalOptimizer(object):
    # Ideally we would only need to share a lock for instances of
    # _LocalOptimizer that deal with the same parameters. We are
    # making a simplifying assumption here that if there is more
    # than one instance of _LocalOptimizer per worker, they will
    # be optimizing the same parameters (e.g. each data parallel
    # trainer will create its own instance of _LocalOptimizer but
    # they will all optimize the same parameters on each worker)
    global_lock = Lock()

    def __init__(self, optim_cls, local_params_rref, *args, **kwargs):
        self._local_params = [rref.local_value() for rref in local_params_rref]
        self.optim = optim_cls( # 優化器還是普通的優化器,因為優化器代碼還是之前的,只是優化的參數對象變成了異地節點參數
            self._local_params, # 用參數代理初始化
            *args,
            **kwargs)

    def step(self, autograd_ctx_id):
        # 獲取到分佈上下文裏面計算好的梯度
        all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)

        with _LocalOptimizer.global_lock:
            for param, grad in all_local_grads.items():
                param.grad = grad
            self.optim.step() # 參數優化

4.4.3 等待完成

用 _wait_for_all 等待異步完成。

def _wait_for_all(rpc_futs):
    # TODO: improve error propagation
    exception = None
    results = []
    for fut in rpc_futs:
        try:
            results.append(fut.wait())
        except Exception as e:
            results.append(e)
            exception = e
    if exception is not None:
        raise exception
    return results

對應的邏輯如下:

  • ref1, ref2 是遠端待優化的參數,都是 torch.rand((3, 3))。
  • optim_rref1,optim_rref2 分別是 Node 2,Node 3上本地優化器的 rref。
                                                      +----------------------------------+
+--------------------------------------------+        | Node 2                   worker 1|
| Node 1                              master |        |                                  |
|                                            |        |    +--------------------------+  |
|                                            |        |    | _LocalOptimizer          |  |
|  +---------------------------------+       |        |    |                          |  |
|  | DistributedOptimizer            |       |        |    |                          |  |
|  |                                 |       |        |    |   optim = _FunctionalSGD |  |
|  |                                 |       |        |    |                          |  |
|  |     remote_optimizers = [       |       |        |    |   _local_params = rref1  |  |
|  |                optim_rref1 +------------------------> |                     +    |  |
|  |                ,                |       |        |    |                     |    |  |
|  |                optim_rref2 +-------+    |        |    +--------------------------+  |
|  |                ]                |  |    |        |                          |       |
|  |                                 |  |    |        |                          v       |
|  |                                 |  |    |   +-------------->   torch.rand((3, 3))   |
|  |                                 |  |    |   |    |                                  |
|  +---------------------------------+  |    |   |    +----------------------------------+
|                                       |    |   |
|                                       |    |   |    +-----------------------------------+
|                                       |    |   |    | Node 3                   worker 2 |
|                                       |    |   |    |                                   |
|                                       |    |   |    |     +--------------------------+  |
|                                       |    |   |    |     | _LocalOptimizer          |  |
|                                       |    |   |    |     |                          |  |
|                                       +-----------------> |                          |  |
|                                            |   |    |     |   optim = _FunctionalSGD |  |
|                                            |   |    |     |                          |  |
|                             rref1 +------------+    |     |   _local_params = rref2  |  |
|                                            |        |     |                     +    |  |
|                                            |        |     |                     |    |  |
|                             rref2 +------------+    |     +--------------------------+  |
|                                            |   |    |                           |       |
|                                            |   |    |                           |       |
|                                            |   |    |                           v       |
|                                            |   +--------------->   torch.rand((3, 3))   |
|                                            |        |                                   |
+--------------------------------------------+        +-----------------------------------+

4.5 更新參數

DistributedOptimizer 在優化時候,會遍歷保存的優化器,逐一調用 _local_optimizer_step。

為什麼可以在Node 1 之上統一調用這些遠端優化器?因為最後更新所有參數完畢之後,才能調用下一輪前向傳播,所以可以統一調用然後等待都完成

def step(self, context_id):
    """
    Performs a single optimization step.

    This will call :meth:`torch.optim.Optimizer.step` on each worker
    containing parameters to be optimized, and will block until all workers
    return. The provided ``context_id`` will be used to retrieve the
    corresponding :class:`~torch.distributed.autograd.context` that
    contains the gradients that should be applied to the parameters.

    Args:
        context_id: the autograd context id for which we should run the
            optimizer step.
    """
    dist_autograd._is_valid_context(context_id)

    if self.is_functional_optim:
        optimizer_step_func = _script_local_optimizer_step
    else:
        optimizer_step_func = _local_optimizer_step # 

    rpc_futs = []
    for optimizer in self.remote_optimizers: # 遍歷 _LocalOptimizer
        rpc_futs.append(rpc.rpc_async( # 異步異地調用
            optimizer.owner(),
            optimizer_step_func, # 逐一調用
            args=(optimizer, context_id),
        ))
    _wait_for_all(rpc_futs)

4.5.1 本地優化

_local_optimizer_step 就是得到 _LocalOptimizer,然後調用其 step。

def _local_optimizer_step(local_optim_rref, autograd_ctx_id):
    local_optim = local_optim_rref.local_value()
    local_optim.step(autograd_ctx_id)

_LocalOptimizer 的 step 首先獲取分佈式梯度,然後用這個梯度進行參數優化。

class _LocalOptimizer(object):

    def step(self, autograd_ctx_id):
        # 獲取到分佈上下文裏面計算好的梯度
        all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)

        with _LocalOptimizer.global_lock:
            for param, grad in all_local_grads.items():
                param.grad = grad
            self.optim.step() # 參數優化

4.5.2 獲取分佈式梯度

get_gradients 的 Python 代碼其實沒有意義。

def get_gradients(context_id): # real signature unknown; restored from __doc__
    """
    get_gradients(context_id: int) -> Dict[Tensor, Tensor]
    
    Retrieves a map from Tensor to the appropriate gradient for that Tensor
    accumulated in the provided context corresponding to the given ``context_id``
    as part of the distributed autograd backward pass.
    
    Arguments:
        context_id(int): The autograd context id for which we should retrieve the
                         gradients.
    
    Returns:
        A map where the key is the Tensor and the value is the associated gradient
        for that Tensor.
    
    Example::
        >>> import torch.distributed.autograd as dist_autograd
        >>> with dist_autograd.context() as context_id:
        >>>     t1 = torch.rand((3, 3), requires_grad=True)
        >>>     t2 = torch.rand((3, 3), requires_grad=True)
        >>>     loss = t1 + t2
        >>>     dist_autograd.backward(context_id, [loss.sum()])
        >>>     grads = dist_autograd.get_gradients(context_id)
        >>>     print(grads[t1])
        >>>     print(grads[t2])
    """
    return {}

其對應 C++ 的位於 torch/csrc/jit/runtime/register_distributed_ops.cpp。是調用了上下文的函數。

// Implementations located in
// torch/csrc/jit/runtime/register_distributed_ops.cpp
TORCH_LIBRARY_IMPL(aten, CatchAll, m) {
  m.impl("get_gradients", [](int64_t context_id) {
    const auto& autogradContext =
        dist_autograd::DistAutogradContainer::getInstance().retrieveContext(
            context_id);
    return autogradContext->getGradients(); // 上下文
  });
}

C++世界的 getGradients 代碼如下:

const c10::Dict<torch::Tensor, torch::Tensor> DistAutogradContext::
    getGradients() const {
  std::lock_guard<std::mutex> guard(lock_);
  // block current streams before accessing gradients to make sure that
  // gradient computations are finished before use.
  for (auto& entry : gradReadyEvents_) {
    auto& event = entry.second;
    event.block(impl_.getStream(event.device()));
  }
  return accumulatedGrads_; // 分佈式梯度累積在這裡
}

在 torch/csrc/distributed/autograd/context/context.h之中有:

// DistAutogradContext which stores information for a single distributed
// autograd pass on a worker.
class TORCH_API DistAutogradContext {
  // Gradients accumulated in this context so far. The key is the variable on
  // which the gradient needs to be accumulated and the value is the gradient
  // that needs to be accumulated on that variable..
  c10::Dict<torch::Tensor, torch::Tensor> accumulatedGrads_;

所以我們邏輯拓展如下:

  1. DistributedOptimizer 調用 optim_rref1 和 optim_rref2 的 step 方法在遠端 worker 之上進行運行,優化。
  2. Worker 1 和 worker 2 之上的 _LocalOptimizer 分別獲得對本地 _local_params_ 進行優化。
  3. 優化結果在 _Node DistAutogradContext 之中的accumulatedGrads_累積。

這樣,整個模型的各個子模型就在各個 Node 之上以統一的步驟進行訓練/優化。

                                                   +--------------------------------------+
                                                   | Node 2                      worker 1 |
                                                   |                                      |
                                                   |    +--------------------------+      |
                                                   |    | DistAutogradContext      |      |
                                                   |    |                          |  3   |
                                                   |    |     accumulatedGrads_ <------+  |
+-----------------------------------------+        |    |                          |   |  |
| Node 1                           master |        |    +--------------------------+   |  |
|                                         |        |    +--------------------------+   |  |
| +-------------------------------+       |  +--------> | _LocalOptimizer          |   |  |
| | DistributedOptimizer          |       |  |     |    |                          |   |  |
| |                               |       |  |     |    |   optim = _FunctionalSGD |   |  |
| |                               |       |  |     |    |                          |   |  |
| |   remote_optimizers = [       |       |  |     |    |   _local_params = rref1  |   |  |
| |              optim_rref1 +---------------+     |    |                     +    |   |  |
| |              ,                |       |     +---------> step() +-------------------+  |
| |              optim_rref2 +-------+    |     |  |    |                     |    |      |
| |                               |  |    |     |  |    +--------------------------+      |
| |              ]           +----------------->+  |                        2 |           |
| |                          |    |  |    |        |                          v           |
| |                          |    |  |    |   +----------------> torch.rand((3, 3))       |
| |                        1 |    |  |    |   |    |                                      |
| |   step() {               |    |  |    |   |    +--------------------------------------+
| |                          |    |  |    |   |
| |     optim_rref1.step()+--+    |  |    |   |    +--------------------------------------+
| |                               |  |    |   |    | Node 3                      worker 2 |
| |     optim_rref2.step()+--+    |  |    |   |    |                                      |
| |                          |    |  |    |   |    |     +--------------------------+     |
| |   }                      |    |  |    |   |    |     | _LocalOptimizer          |     |
| |                          |    |  |    |   |    |     |                          |     |
| +-------------------------------+  +-----------------> |                          |     |
|                            |            |   |    |     |   optim = _FunctionalSGD |     |
|                            |            |   |    |     |                          |     |
|                          1 |            |   |    |     |   _local_params = rref2  |     |
|                            |            |   |    |     |                     +    |  3  |
|                            +-----------------------------> step() +------------------v  |
|                                         |   |    |     |                     |    |  |  |
|                         rref1 +-------------+    |     +--------------------------+  |  |
|                                         |        |                        2  |       |  |
|                                         |        |                           v       |  |
|                         rref2 +-------------------------------> torch.rand((3, 3))   |  |
|                                         |        |                                   |  |
+-----------------------------------------+        |     +--------------------------+  |  |
                                                   |     | DistAutogradContext      |  |  |
                                                   |     |                          |  |  |
                                                   |     |     accumulatedGrads_ <-----+  |
                                                   |     |                          |     |
                                                   |     +--------------------------+     |
                                                   +--------------------------------------+

0x05 PipeDream 優化器

最後,我們來看看 PipeDream,看看它是怎麼實現分佈式優化器的,我們探尋的思路是:

  • 因為PipeDream是在每個worker之上啟動全部代碼,所以每個本地優化器如何確定自己要優化的參數?
  • 優化時候如何更新參數?

5.1 如何確定優化參數

我們先提前說一下:

  • 每個node的module不同,所以每個優化器的待優化參數是本地module的參數。
  • 每個node優化自己負責的部分module。

我們需要從頭梳理。

5.1.1 main 方法

來到 runtime/translation/main_with_runtime.py。這裡首先構建一個 StageRuntime,然後用 StageRuntime 的參數來構建優化器。

def main():
    r = runtime.StageRuntime(
        model=model, distributed_backend=args.distributed_backend,
        fp16=args.fp16, loss_scale=args.loss_scale,
        training_tensor_shapes=training_tensor_shapes,
        eval_tensor_shapes=eval_tensor_shapes,
        training_tensor_dtypes=dtypes,
        inputs_module_destinations=inputs_module_destinations,
        target_tensor_names=target_tensor_names,
        configuration_maps=configuration_maps,
        master_addr=args.master_addr,
        rank=args.rank, local_rank=args.local_rank,
        num_ranks_in_server=args.num_ranks_in_server,
        verbose_freq=args.verbose_frequency,
        model_type=runtime.TRANSLATION,
        enable_recompute=args.recompute)
    
    if use_adam_optimizer:
        optimizer = adam.AdamWithWeightStashing(
            modules=r.modules(), master_parameters=r.master_parameters,
            model_parameters=r.model_parameters, loss_scale=args.loss_scale,
            num_versions=num_versions, lr=args.lr, betas=(0.9,0.999),
            weight_decay=args.weight_decay, verbose_freq=args.verbose_frequency,
            macrobatch=args.macrobatch)
    else:
        optimizer = sgd.SGDWithWeightStashing(
            modules=r.modules(), master_parameters=r.master_parameters,
            model_parameters=r.model_parameters, loss_scale=args.loss_scale,
            num_versions=num_versions, lr=args.lr, momentum=args.momentum,
            weight_decay=args.weight_decay, verbose_freq=args.verbose_frequency)    

5.1.2 構建runtime

StageRuntime 的 initialize 函數會構建 module,這裡通過本 node 的stage 來構建自己的 modules。

我們從前面文章中摘錄。

stage_to_module_map 就是設置 stage 到 modules 的關係,目的是為了得到本stage所對應的modules。

本stage(數值為 3)對應的是 index 為 3,4 的兩個 module,就是下面的 3 ,3.

module_to_stage_map = {list: 5} [0, 1, 2, 3, 3]

具體代碼是:

def initialize(self, model, inputs_module_destinations,
               configuration_maps, master_addr, rank,
               local_rank, num_ranks_in_server):
  
        if module_to_stage_map is None:
            self.modules_with_dependencies = ModulesWithDependencies(model)
        else:
            # 依據本stage來找到自己的modules。
            modules = stage_to_module_map[self.stage]
            self.modules_with_dependencies = ModulesWithDependencies(
                [model[module] for module in modules])
        
        # 確定哪些模型layers
        modules = self.modules_with_dependencies.modules()            

        # 拿到 master_parameters 和 model_parameters
        if self.fp16:
            self.master_parameters = []
            self.model_parameters = []
            for i in range(len(modules)):
                import apex.fp16_utils as fp16_utils
                module_parameters, module_master_parameters = \
                    fp16_utils.prep_param_lists(modules[i])
                self.master_parameters.extend(module_master_parameters)
                self.model_parameters.extend(module_parameters)
        else:
            self.master_parameters = list(self.parameters())
            self.model_parameters = None     
            
            

比如模型被分配到兩個node之上,每個node兩個layers,這裡 Node 2有一個DDP數據並行。

每個 Node 的模型參數就是不同的,Node 1 的待優化參數是 Layer 1,Layer 2 的參數;Node 2 的待優化參數是 Layer 3,Layer 4 的參數。

                                              Node 2
                                              +----------------------------------------+
                                              | Stage 2                   StageRuntime |
                                              |                                        |
Node 1                                        |           CommunicationHandler         |
+---------------------------------------+     |                                        |
| Stage 1        StageRuntime           |     |      +----------------------------+    |
|                                       |     |      | +------------------------+ |    |
|                                       |     |      | |Rank 2                  | |    |
|         CommunicationHandler          |     |      | |                        | |    |
|                                       |     |      | |                        | |    |
|      +-----------------------+        |     |      | |  Layer 3 +---> Layer 4 | |    |
|      |Rank 1                 |        |     |      | |                        | |    |
|      |                       |        |     | DDP  | |                        | |    |
|      | Layer 1 +---> Layer 2 |        +----------->+ +------------------------+ |    |
|      |                       |        |     |      | +------------------------+ |    |
|      |                       |        |     |      | |Rank 3                  | |    |
|      +-----------------------+        |     |      | |                        | |    |
|                                       |     |      | |                        | |    |
|   master_parameters = Parameters(     |     |      | |  Layer 3 +---> Layer 4 | |    |
|                   Layer 1, Layer 2)   |     |      | |                        | |    |
|                                       |     |      | |                        | |    |
|   model_parameters                    |     |      | +------------------------+ |    |
|                                       |     |      +----------------------------+    |
+---------------------------------------+     |                                        |
                                              |                                        |
                                              |  master_parameters = Parameters(       |
                                              |                      Layer 3, Layer 4) |
                                              |                                        |
                                              |                                        |
                                              |  model_parameters                      |
                                              |                                        |
                                              +----------------------------------------+

5.1.3 SGDWithWeightStashing

然後用 runtime 的 master_parameters 和 model_parameters 構建本地優化器 SGDWithWeightStashing。

OptimizerWithWeightStashing 是 SGDWithWeightStashing 的基類。

class SGDWithWeightStashing(OptimizerWithWeightStashing): # 基類
    """
    SGD optimizer with weight stashing.
    """
    def __init__(self, modules, master_parameters, model_parameters,
                 loss_scale, num_versions, lr=required, momentum=0,
                 dampening=0, weight_decay=0, nesterov=False, verbose_freq=0,
                 macrobatch=False):
        super(SGDWithWeightStashing, self).__init__(
            optim_name='SGD',
            modules=modules, master_parameters=master_parameters,
            model_parameters=model_parameters, loss_scale=loss_scale,
            num_versions=num_versions, lr=lr, momentum=momentum,
            dampening=dampening, weight_decay=weight_decay,
            nesterov=nesterov, verbose_freq=verbose_freq,
            macrobatch=macrobatch,
        )

基類 OptimizerWithWeightStashing 會生成一個原生優化器,賦值在 base_optimizer。

class OptimizerWithWeightStashing(torch.optim.Optimizer):
    """Wrapper class that adds weight stashing to a vanilla torch.optim.Optimizer.

    Arguments:
        - optim_name: the name of optimizer, required to create the corresponding
                      base_optimizer (torch.optim.{optim_name}).
        - optimizer_args: the keyword arguments passed to base_optimizer.
    """

    def __init__(self, optim_name, modules, master_parameters, model_parameters,
                 loss_scale, num_versions, verbose_freq=0, macrobatch=False,
                 **optimizer_args):
        self.modules = modules
        self.master_parameters = master_parameters
        self.model_parameters = model_parameters  # model_parameters is None if not fp16.
        self.loss_scale = loss_scale

        # Only need at most 2 versions if using macrobatching.
        if macrobatch:
            num_versions = min(2, num_versions)
        self.num_versions = num_versions
        
        # 生成一個原生優化器
        self.base_optimizer = getattr(torch.optim, optim_name)(
            master_parameters, **optimizer_args)
        self.latest_version = Version()
        self.current_version = Version()
        self.initialize_queue()
        self.verbose_freq = verbose_freq
        self.batch_counter = 0

        # If macrobatching, push and pop versions at the right rate.
        if macrobatch:
            self.update_interval = self.num_versions
        else:
            self.update_interval = 1

邏輯拓展如下,每個優化器使用自己 Node 的參數進行優化。

                                              +----------------------------------------+
                                              | Stage 2                   StageRuntime |
                                              |                                        |
                                              |           CommunicationHandler         |
+---------------------------------------+     |                                        |
| Stage 1        StageRuntime           |     |      +----------------------------+    |
|                                       |     |      | +------------------------+ |    |
|                                       |     |      | |Rank 2                  | |    |
|         CommunicationHandler          |     |      | |                        | |    |
|                                       |     |      | |                        | |    |
|      +-----------------------+        |     |      | |  Layer 3 +---> Layer 4 | |    |
|      |Rank 1                 |        |     |      | |                        | |    |
|      |                       |        |     | DDP  | |                        | |    |
|      | Layer 1 +---> Layer 2 |        +----------->+ +------------------------+ |    |
|      |                       |        |     |      | +------------------------+ |    |
|      |                       |        |     |      | |Rank 3                  | |    |
|      +-----------------------+        |     |      | |                        | |    |
|                                       |     |      | |                        | |    |
|   master_parameters = Parameters(     |     |      | |  Layer 3 +---> Layer 4 | |    |
|                   Layer 1, Layer 2)   |     |      | |                        | |    |
|                             +         |     |      | |                        | |    |
|   model_parameters          |         |     |      | +------------------------+ |    |
|                             |         |     |      +----------------------------+    |
|  +---------------------------------+  |     |                                        |
|  |SGDWithWeightStashing     |      |  |     |                                        |
|  |                          |      |  |     |  master_parameters = Parameters(       |
|  |   base_optimizer = SGB(  v      |  |     |                      Layer 3, Layer 4) |
|  |              master_parameters) |  |     |                               +        |
|  |                                 |  |     |  model_parameters             |        |
|  +---------------------------------+  |     |                               |        |
|                                       |     |  +----------------------------------+  |
+---------------------------------------+     |  |SGDWithWeightStashing       |     |  |
                                              |  |                            |     |  |
                                              |  |      base_optimizer = SGB( v     |  |
                                              |  |               master_parameters) |  |
                                              |  +----------------------------------+  |
                                              |                                        |
                                              +----------------------------------------+

5.2 優化

5.2.2 整體優化

整體是異步運行,也就是異步優化。

def train(train_loader, r, optimizer, epoch):

  	# 省略其他
    
    # start num_warmup_minibatches forward passes
    for i in range(num_warmup_minibatches):
        r.run_forward()

    for i in range(n - num_warmup_minibatches):
        # perform forward pass
        r.run_forward()

        # perform backward pass
        if args.fp16:
            r.zero_grad()
        else:
            optimizer.zero_grad()
        optimizer.load_old_params()

        r.run_backward()
        optimizer.load_new_params()
        optimizer.step()

    # finish remaining backward passes
    for i in range(num_warmup_minibatches):
        optimizer.zero_grad()
        optimizer.load_old_params()
        r.run_backward()
        optimizer.load_new_params()
        optimizer.step()

    # wait for all helper threads to complete
    r.wait()

5.2.2 優化器優化

優化直接使用 SGDWithWeightStashing 的 step 方法。其最後也是 class OptimizerWithWeightStashing(torch.optim.Optimizer) 的 step 方法。

def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
                                      and returns the loss.
    """
    # Update the gradient every `update_interval` steps.
    if self.batch_counter % self.update_interval != self.update_interval - 1:
        self.batch_counter += 1
        return None

    if self.model_parameters is not None:
        import apex.fp16_utils as fp16_utils
        fp16_utils.model_grads_to_master_grads(self.model_parameters,
                                               self.master_parameters)
        if self.loss_scale != 1.0:
            # 處理梯度
            for parameter in self.master_parameters:
                parameter.grad.data = parameter.grad.data / self.loss_scale

    for p in self.param_groups[0]['params']:
        if p.grad is not None: # 繼續處理累積的梯度
            p.grad.div_(self.update_interval)

    loss = self.base_optimizer.step() # 進行優化
    if self.model_parameters is not None:
        import apex.fp16_utils as fp16_utils
        fp16_utils.master_params_to_model_params(self.model_parameters,
                                                 self.master_parameters)
    self.latest_version = self.latest_version.incr()
    if self.num_versions > 1:
        self.buffered_state_dicts = self.queue[0][0]
        self.queue.append(self.get_params(clone=False))

    self.batch_counter += 1
    return loss

具體如下:

                                               Node 2
                                               +-----------------------------------------+
                                               | Stage 2                    StageRuntime |
                                               |                                         |
Node 1                                         |           CommunicationHandler          |
+-----------------------------------------+    |                                         |
| Stage 1                    StageRuntime |    |      +----------------------------+     |
|                                         |    |      | +------------------------+ |     |
|                                         |    |      | |Rank 2                  | |     |
|          CommunicationHandler           |    |      | |                        | |     |
|                                         |    |      | |                        | |     |
|       +-----------------------+         |    |      | |  Layer 3 +---> Layer 4 | |     |
|       |Rank 1                 |         |    |      | |                        | |     |
|       |                       |         |    | DDP  | |                        | |     |
|       | Layer 1 +---> Layer 2 |         +---------->+ +------------------------+ |     |
|       |                       |         |    |      | +------------------------+ |     |
|       |                       |         |    |      | |Rank 3                  | |     |
|       +-----------------------+         |    |      | |                        | |     |
|                                         |    |      | |                        | |     |
|  master_parameters = Parameters(        |    |      | |  Layer 3 +---> Layer 4 | |     |
|                  Layer 1, Layer 2)      |    |      | |                        | |     |
|                                +        |    |      | |                        | |     |
|  model_parameters              |        |    |      | +------------------------+ |     |
|                                |        |    |      +----------------------------+     |
|  step()                        |        |    |                                         |
|   +                            |        |    |                                         |
|   |                            |        |    |  master_parameters = Parameters(        |
|   |  +-------------------------------+  |    |                      Layer 3, Layer 4)  |
|   |  |SGDWithWeightStashing    |     |  |    |                                   +     |
|   |  |                         |     |  |    |  model_parameters                 |     |
|   |  |   base_optimizer = SGB( v     |  |    |                                   |     |
|   |  |            master_parameters) |  |    |  step()                           |     |
|   |  |                               |  |    |   +                               |     |
|   +----> base_optimizer.step()       |  |    |   |                               |     |
|      |                               |  |    |   |  +-------------------------------+  |
|      +-------------------------------+  |    |   |  |SGDWithWeightStashing       |  |  |
|                                         |    |   |  |                            |  |  |
+-----------------------------------------+    |   |  |      base_optimizer = SGB( v  |  |
                                               |   |  |            master_parameters) |  |
                                               |   |  |                               |  |
                                               |   +------>  base_optimizer.step()    |  |
                                               |      |                               |  |
                                               |      +-------------------------------+  |
                                               |                                         |
                                               +-----------------------------------------+

至此,分佈式優化器系列完成,在後續分析ZeRO時候,我們還會介紹 PyTorch ZeroRedundancyOptimizer,估計要等待幾周之後了。我們從下一篇開始,介紹 PyTorch 分佈式 的幾個官方文檔應用例子,以此來把 PyTorch 分佈式整個邏輯串聯起來看看在實際之中應該如何應用,敬請期待。

0xFF 參考

torch.optim.optimizer源碼閱讀和靈活使用

pytorch源碼閱讀(二)optimizer原理

pytorch 優化器(optim)不同參數組,不同學習率設置的操作

Pytorch——momentum動量

各種優化方法總結比較(sgd/momentum/Nesterov/adagrad/adadelta)

【優化器】優化器算法及PyTorch實現(一):永不磨滅的SGD

以optim.SGD為例介紹pytorch優化器

Pytorch學習筆記08—-優化器算法Optimizer詳解(SGD、Adam)

pytorch中使用torch.optim優化神經網絡以及優化器的選擇 – pytorch中文網

pytorch優化器詳解:SGD

聊聊GPU通信那些事

//developer.nvidia.com/gpudirect

//www.nvidia.cn/data-center/magnum-io/

//www.nvidia.cn/data-center/nvlink/