[源碼解析] PyTorch 流水線並行實現 (5)–計算依賴

[源碼解析] PyTorch 流水線並行實現 (5)–計算依賴

0x00 摘要

前幾篇文章我們介紹了 PyTorch 流水線並行的基本知識,自動平衡機制和切分數據等,本文我們結合論文內容來看看如何實現流水線依賴,核心就是如何建立這些小批次之間的跨設備依賴關係

流水線並行其他文章鏈接如下:

[源碼解析] 深度學習流水線並行Gpipe(1)—流水線基本實現

[源碼解析] 深度學習流水線並行GPipe (2) —– 梯度累積

[源碼解析] 深度學習流水線並行 GPipe(3) —-重計算

[源碼解析] 深度學習流水線並行之PipeDream(1)— Profile階段

[源碼解析] 深度學習流水線並行 PipeDream(2)— 計算分區

[源碼解析] 深度學習流水線並行 PipeDream(3)— 轉換模型

[源碼解析] 深度學習流水線並行 PipeDream(4)— 運行時引擎

[源碼解析] 深度學習流水線並行 PipeDream(5)— 通信模塊

[源碼解析] 深度學習流水線並行 PipeDream(6)— 1F1B策略

[源碼解析] PyTorch 流水線並行實現 (1)–基礎知識

[源碼解析] PyTorch 流水線並行實現 (2)–如何劃分模型

[源碼解析] PyTorch 流水線並行實現 (3)–切分數據和運行時系統

[源碼解析] PyTorch 流水線並行實現 (4)–前向計算

本文圖片來自論文和github源碼。

0x01 前文回顧

為了更好的理解本文,我們首先看看前文之中的關鍵部分。

  • 原始流水線狀態如下:
    • 管道並行的策略是根據分區索引 j 分配任務,以便第 j 個分區完全位於第 j 個設備中。
    • 持有模型後期部分的設備必須等待,直到持有模型早期部分的設備計算結束。

img

  • 目標流水線狀態如下:

img

  • 目前問題

    • 如果分成若干個微批次,則需要強制要求 \(F_{i,j}\) 必須在 \(F_{i+1,j}\) 之前完成,以及 \(B{i,j}\) 必須在執行\(B{i-1,j}\) 之前完成
    • 後向傳播的計算圖是在前向傳播過程中動態構造的。PyTorch既不記錄正向計算圖,也不維護一個梯度磁帶(gradient tape),PyTorch的自動微分(autograd)引擎僅對計算圖進行反向傳播。這意味着自動加載引擎可能不會完全按照與正向過程相反的執行順序運行,除非由圖的結構強制執行
  • 目前難點

    • 如何在每個設備中以正確的順序發佈那些綁定到設備的任務,以避免由於Python解釋器未能提前請求而延遲在設備上(與CPU異步)執行任務。[這個前文已經介紹]
    • 如何建立這些小批次之間的跨設備依賴關係
  • 實現方案

    • 如何保證正確執行順序?torchgpipe引入了確定性時鐘周期(deterministic clock-cycle),它給出了任務的總體順序[這個前文已經介紹]
    • 如何保證計算圖中的動態顯式依賴關係?針對clock_cycles產生的每一個運行計劃:
      • 利用 fence 函數調用「fork」和「join」,以此在向後計算圖中動態創建顯式後向傳播依賴關係。
      • 利用 compute(schedule, skip_trackers, in_queues, out_queues) 進行計算。

因為前文已經介紹了執行順序方案,所以本文介紹如何計算依賴。

0x02 計算依賴

+-----------------------------------------------------------------------------------------+
|                                                                                         |
| Layer 1 +--->  Layer 2 +-----> Layer 3 +----->  Layer 4 +-----> Layer 5  +---> Layer 6  |
|                                                                                         |
+--------------------------+---------------------------+----------------------------------+
                                          +
                                          |
                                          |
                                          v
 +------------------------------------------------------------------------------------+
 | +--------------------+         +---------------------+      +--------------------+ |
 | |Partition 1         |         |Partition 2          |      |Partition 3         | |
 | |                    |         |                     |      |                    | |
 | |      Layer 1       |    +---------> Layer 4        |      |                    | |
 | |         +          |    |    |         +           |  +------->   Layer 6      | |
 | |         |          |    |    |         |           |  |   |                    | |
 | |         v          |    |    |         |           |  |   |                    | |
 | |      Layer 2       |    |    |         |           |  |   |                    | |
 | |         +          |    |    |         v           |  |   |                    | |
 | |         |          |    |    |      Layer 5 +---------+   |                    | |
 | |         v          |    |    |                     |      |                    | |
 | |      Layer 3  +---------+    |                     |      |                    | |
 | |                    |         |                     |      |                    | |
 | +---------+----------+         +---------+-----------+      +-----------+--------+ |
 |                                                                                    |
 +------------------------------------------------------------------------------------+

為什麼需要計算依賴?

  • 因為模型已經被分層,模型的不同部分拆開放到不同設備上,數據也分成微批次,所以本來模型內部是線性依賴關係,現在需要變成流水線依賴關係。因此原始計算圖不能滿足需求,因此需要有針對性的補充。就像上圖那樣,6個層被分成了三個partitions,這三個partitons 之間的依賴如何構建
  • 之前的線性依賴關係其實是在模型定義時候就基本確定了,現在則需要每次運行時候建立一個動態依賴關係。

所以針對流水線並行,torchgpipe需要自己補充一個本機跨設備偽分佈式依賴關係。torchgpipe 通過在前向計算圖和後向計算圖做各種調整來達到目的。計算圖就意味着各種依賴邏輯,依賴邏輯的補足就是依靠本節介紹的 Fork 和 Join 兩個函數完成的。

這裡最初有一個疑問,就是Torchgpipe怎麼在不使用 PyTorch RPC 和 p2p的情況下,構建出來一個異地反向計算圖。後來發現,原來是我想多了,因為Torchgpipe沒有考慮到這種情況,它針對都是在同一個主機之上的GPU,不涉及異地多機器計算。

Torchgpipe 本質上還是一個進程內運行多個線程進行計算,是 DP 的替代。比如源碼中就有對比如下:

### ResNet-101 Accuracy Benchmark

Batch size | torchgpipe | nn.DataParallel | Goyal et al.
---------- | ---------: | --------------: | -----------:
256        | 21.99±0.13 |      22.02±0.11 |   22.08±0.06
1K         | 22.24±0.19 |      22.04±0.24 |          N/A
4K         | 22.13±0.09 |             N/A |          N/A

再比如代碼中明確提到:

If you decide not to use checkpointing at all, :class:`nn.DataParallel
<torch.nn.DataParallel>` might be more efficient than GPipe.

0x03 反向傳播依賴

我們首先看看反向傳播依賴,這個是論文的重點。

2.1 解析

我們還是要回憶一下前面兩個圖例。

圖1

img

圖2

img

這裡需要完成兩種依賴:

  • 行間依賴,就是 batch 之間的依賴,就是設備內的依賴。從圖上看,就是藍色列內的 \(F_{1,1}\) 必須在 \(F_{2,1}\)之前完成,\(B_{2,1}\) 必須在\(B_{1,1}\) 之前完成。
  • 列間依賴,就是 partitions(設備) 之間的依賴。從圖上看,就是藍色 \(F_{1,1}\) 必須在黃色 \(F_{1,2}\)之前完成,即第一個設備必須在第二個設備之前完成,而且第一個設備的輸出是第二個設備的輸入。

假定我們依據確定性時鐘周期(deterministic clock-cycle)算法來運行一個前向傳播。即使前向傳播是按照在第j個設備上應該執行的順序來執行任務 \(F_{1,j},…,F_{m,j}\) ,得到的後向傳播結果計算圖看起來也更像圖1而非圖2,

從圖1上看,PyTorch 的 autograd 引擎不知道 \(B_{i+1,j}\) 必須在 \(B_{i,j}\) 之前運行,因此會打亂後向傳播的時間流。因此,虛擬依賴(圖2的虛線箭頭)必須在前向傳播中被顯式繪製出來。

我們再仔細分析一下圖2。圖2之中,每一行都表示一個 micro-batch 在訓練中的運行流,這個流的前向是由clock算法確定的。後向關係是由前向傳播中自動確定完成的

現在的問題是:一個 mini-batch 被分成了4個 micro-batch,分別在不同時鐘周期進入訓練。就是每一列。這一列由上到下的傳播也是由clock算法確定,但是反向傳播(由下自上)目前是不確定的。比如最後一列中,反向傳播的順序應是:\(B_{4,1},B_{3,1},B_{2,1},B_{1,1}\)。但是這個目前從前向傳播的結果來看,無法確定這個順序。

所以需要依靠本節介紹的 Fork 和 Join 兩個函數完成這個依賴關係。圖中斜線表示checkpoint之中需要先有一個重計算,然後才能由下往上走

因此,torchpipe定義兩個基礎函數,Fork 和 Join 來表達這種依賴關係:

  • Fork 是 auto grad 函數,其把一個張量 x 映射到 pair(x, \(\phi\)),這裡 \(\phi\) 是一個空張量。

  • Join 是 auto grad 函數,其把 pair(x, \(\phi\)) 映射到一個張量 x ,這裡 \(\phi\) 是一個空張量。

現在,\(F_{i+1,j}\) 對於 \(F_{i,j}\) 的依賴(其在後向傳播計算圖中被轉換為 \(B_{i,j}\) 到 $B_{i+1,j} $ 的依賴關係)可以被如下表示

所以,圖中這裡實線都是前向傳播時候構建的,虛線是由 fork & join 構建的。

原則上,表示虛擬依賴關係的張量可以是任意的。然而,torchgpipe選擇使用空張量,以消除由張量引起的任何不必要的計算,例如PyTorch中的梯度累積。

具體如下圖。就是使用 Fork 和 Join 的後向計算圖。圖中,不同顏色對應不同的設備。箭頭依據後向傳播圖的方向來繪製,這些聯繫是在前向傳播中被構建的。因此,\(F^{‘}_{i,j}\) 對於 \(B_{i+1,j}\) 的虛擬依賴通過 Fork 和 Join 被構建出來,用虛線表示。

2.2 基礎功能

2.2.1 Function

首先,我們要看看 torch.autograd.Function 的作用。

torch.autograd.Function類實際上是一個操作函數的基礎父類,這樣的操作函數必須具備兩個基本的過程,即前向的運算過程和反向的求導過程,

如果某些操作無法通過 PyTorch 已有的層或者是已有的方法實現不了,就需要實現一個新的方法對 PyTorch 進行拓展。當不使用自動求導機制,需要自定義求導規則的時候,就應該拓展torch.autograd.Function類。 由於pytorch不再提供自動求導機制,就要用戶自己定義實現前向傳播和反向傳播的計算過程,這就是 “Extending torch.autograd”。

我們接下來介紹Backward Dependency 的關鍵算法:Fork and Join。

2.2.2 Fork

Fork 是auto grad 函數,其把一個張量 x 映射到 pair(x, \(\phi\)),這裡 \(\phi\) 是一個空張量。Fork 方法就是拓展了torch.autograd.Function

def fork(input: Tensor) -> Tuple[Tensor, Tensor]:
    """Branches out from an autograd lane of the given tensor."""
    if torch.is_grad_enabled() and input.requires_grad:
        input, phony = Fork.apply(input)
    else:
        phony = get_phony(input.device, requires_grad=False)

    return input, phony


class Fork(torch.autograd.Function):
    @staticmethod
    def forward(ctx: 'Fork', input: Tensor) -> Tuple[Tensor, Tensor]:  # type: ignore
        phony = get_phony(input.device, requires_grad=False)
        return input.detach(), phony.detach()

    @staticmethod
    def backward(ctx: 'Fork', grad_input: Tensor, grad_grad: Tensor) -> Tensor:  # type: ignore
        return grad_input

2.2.3 Join

Join 是auto grad 函數,其把 pair(x, \(\phi\)) 映射到一個張量 x ,這裡 \(\phi\) 是一個空張量。Join 方法也是拓展了torch.autograd.Function

def join(input: Tensor, phony: Tensor) -> Tensor:
    """Merges two autograd lanes."""
    if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad):
        input = Join.apply(input, phony)

    return input


class Join(torch.autograd.Function):
    @staticmethod
    def forward(ctx: 'Join', input: Tensor, phony: Tensor) -> Tensor:  # type: ignore
        return input.detach()

    @staticmethod
    def backward(ctx: 'Join', grad_input: Tensor) -> Tuple[Tensor, None]:  # type: ignore
        return grad_input, None

2.2.4 Phony

Phony是沒有空間的張量,因為它不需要任何梯度累積,所以可以在 autograd 圖中構建任意的依賴。

def get_phony(device: torch.device, *, requires_grad: bool) -> Tensor:
    """Gets a phony. Phony is tensor without space. It is useful to make
    arbitrary dependency in a autograd graph because it doesn't require any
    gradient accumulation.

    .. note::

        Phonies for each device are cached. If an autograd function gets a phony
        internally, the phony must be detached to be returned. Otherwise, the
        autograd engine will mutate the cached phony in-place::

            class Phonify(torch.autograd.Function):
                @staticmethod
                def forward(ctx, input):
                    phony = get_phony(input.device, requires_grad=False)
                    return phony.detach()  # detach() is necessary.

    """
    key = (device, requires_grad)

    try:
        phony = _phonies[key]
    except KeyError:
        with use_stream(default_stream(device)):
            phony = torch.empty(0, device=device, requires_grad=requires_grad)

        _phonies[key] = phony

    return phony

2.2.5 detach

在代碼中,經常可以見到 detach 的使用,這個從注釋可以看出來,是為了解決 PyTorch 的一個bug。

    # A Python autograd function might fail with this error:
    #
    #   RuntimeError: Returning Variables sharing storage with other Variables
    #   that require grad is not supported in Python functions. Please submit a
    #   feature request if you hit this error.
    #
    # It doesn't look like an essential restriction. But it happens on the
    # current PyTorch version. To avoid it, we should detach the tensor before
    # returning by identity autograd functions, such as Wait, Fork, and Join.
    #

2.3 使用

在 Pipeline 之中我們可以看到具體的使用方法,fence 方法(省略部分代碼)利用 depend 來構建後向傳播的依賴關係,確保 batches[i-1] 在 batches[i] 之後完成。

    def fence(self,
              schedule: List[Tuple[int, int]],
              skip_trackers: List[SkipTrackerThroughPotals],
              ) -> None:
        """Copies micro-batches after computation for the previous
        micro-batches.
        """
        batches = self.batches
        copy_streams = self.copy_streams
        skip_layout = self.skip_layout

        for i, j in schedule:
            # Ensure that batches[i-1] is executed after batches[i] in
            # backpropagation by an explicit dependency.
            if i != 0:
                depend(batches[i-1], batches[i]) # 在這裡建立了後向傳播依賴關係
                
            next_stream = copy_streams[j][i]

            for prev_j, ns, name in skip_layout.copy_policy(j):
                prev_stream = copy_streams[prev_j][i]
                skip_trackers[i].copy(batches[i], prev_stream, next_stream, ns, name)

            if j != 0:
                prev_stream = copy_streams[j-1][i]
                copy(batches[i], prev_stream, next_stream)                

具體 depend 代碼如下:

def depend(fork_from: Batch, join_to: Batch) -> None:
    fork_from[0], phony = fork(fork_from[0])
    join_to[0] = join(join_to[0], phony)

我們結合示例代碼把傳入的參數賦值一下,重新把方法解釋如下,這樣大家就可以更好的理解。

def depend(batches[i-1]: Batch, batches[i]: Batch) -> None:
    batches[i-1][0], phony = fork(batches[i-1][0])
    batches[i][0] = join(batches[i][0], phony)

具體邏輯如下,通過 phony 完成了一個橋接,即在正向傳播之中,batches[i] 依賴 batches[i-1] 的執行結果

      +----------------+          +--------------+
      |                |          |              |
      |  batches[i-1]  |          |  batches[i]  |
      |                |          |              |
      +----------+-----+          +-----+--------+
                 |                      |
                 |                      |
                 |                      |
                 v                      v
+--------------------------------------------------------+
| depend         |                      |                |
|                |                      |                |
|                |                      |                |
|                v                      |                |
|        +-----------------------+      |                |
|        | fork  |               |      |                |
|        |       |    get_phony  |      |                |
|        |       |        +      |      |                |
|        |       |        |      |      |                |
|        |       |        |      |      |                |
|        +-----------------------+      |                |
|                |        |             |                |
|                |        |             |                |
|                |        |             |                |
|                v        v             |                |
|    +-----------+--+  +--+-----+       |                |
|    |              |  |        |       |                |
|    | batches[i-1] |  | phony  |       |                |
|    |              |  |        |       |                |
|    +--------------+  +--+-----+       |                |
|                         |             |                |
|                         |             |                |
|                         v             v                |
|                      +--+------------------+           |
|                      |Join            |    |           |
|                      |                |    |           |
|                      |                |    |           |
|                      |                v    |           |
|                      +---------------------+           |
|                                       |                |
|                                       |                |
|                                       |                |
|                                       v                |
|                                 +-----+------+         |
|                                 |            |         |
|                                 | batches[i] |         |
|                                 |            |         |
|                                 +------------+         |
|                                                        |
+--------------------------------------------------------+

我們把多個 batches 聯合起來看看,這樣就能看出來一個依賴鏈條。

                  +----------------------------------------------------------+
                  | depend                                                   |
                  |                                                          |
                  | +------------+                                           |
 +-------------   | |fork        |     +-----------+                         |
 |            |   | |            |     |           |                         |
 |batches[i]  +----------------------> | batches[i]|                         |
 |            |   | |            |     |           |                         |
 +-------------   | |            |     +-----------+                         |
                  | |            |             +-------+                     |
                  | |            +-----------> | Join  |                     |
                  | |            |             |       |                     |
                  | +------------+             |       |                     |
 +-------------   |                            |       |    +--------------+ |
 |            |   |                            |       |    |              | |
 |batches[i+1]+-------------------------------------------->+ batches[i+1] | |
 |            |   |                            |       |    |              | |
 +---------+---   |                            |       |    +--------------+ |
           |      |                            +-------+                     |
           |      |                                                          |
           |      +----------------------------------------------------------+
           |      +----------------------------------------------------------+
           |      | depend                                                   |
           |      |                                                          |
           |      | +-------------+                                          |
           |      | |fork         |     +------------+                       |
           |      | |             |     |            |                       |
           +--------------------------> |batches[i+1]|                       |
                  | |             |     |            |                       |
                  | |             |     +------------+                       |
                  | |             |           +-------+                      |
                  | |             +---------> |Join   |                      |
                  | +-------------+           |       |                      |
+------------+    |                           |       |     +-------------+  |
|            |    |                           |       |     |             |  |
|batches[i+2]+--------------------------------------------> | batches[i+2]|  |
|            |    |                           |       |     |             |  |
+----------+-+    |                           |       |     +-------------+  |
           |      |                           +-------+                      |
           |      |                                                          |
           |      +----------------------------------------------------------+
           |
           |      +-----------------------------------------------------------+
           |      | depend                                                    |
           |      |                                                           |
           +----------------------------->    ......                          |
                  |                                                           |
                  |                                                           |
                  +-----------------------------------------------------------+

這樣,上圖就是前向計算圖,於是在後向傳播之中,batches[i] 就 必須在 batches[i-1] 之前完成了

我們再結合論文的圖來看看。

本來示例代碼中是:

depend(batches[i-1], batches[i])

為了和論文中的圖對應,我們修改為:

depend(batches[i], batches[i+1])

depend 代碼也變化為:

def depend(batches[i]: Batch, batches[i+1]: Batch) -> None:
    batches[i][0], phony = fork(batches[i][0])
    batches[i+1][0] = join(batches[i+1][0], phony)

對應下圖,就是在後向傳播計算圖之中 batches[i+1] 通過一個join, 一個fork,排在了 batches[i] 前面,就是下面大箭頭所示,具體細化一下:

  • 從這個圖上,PyTorch 的 autograd 引擎不知道 \(B_{i+1,j}\) 必須在 \(B_{i,j}\) 之前運行,因此會打亂後向傳播的時間流。因此,虛擬依賴(前面圖的虛線箭頭)必須在前向傳播中被顯式繪製出來。

  • 圖上的實線箭頭依據後向傳播圖的方向來繪製,這些聯繫是在前向傳播中被構建的。就是說,對於 \({Batch}_i\) 來說,其反向傳播順序是固定的。就是上面一行內順序是固定的,下面一行內順序也是固定的

  • 但是,上下兩行之間的順序是不可知的,需要用虛線來保證,就是用 Join & Fork 來保證。

0x03 正向傳播依賴

我們回頭再來看正向依賴。因為正向傳播的部分目的就是完成反向傳播依賴,而目前反向傳播只完成了行之間的依賴,列之間的依賴沒有完成,我們現在補全

列之間的依賴就是設備之間的依賴,即前一個設備的輸出是後一個設備的輸入

3.1 分割模型

首先還是需要回顧下如何切分模型,從 split_module 可以看到,

GPipe 的 partitions 成員變量是 nn.ModuleList 類型。nn.ModuleList是一個容器,其儲存不同 module,並自動將每個 module 的 parameters 添加到網絡中。但是nn.ModuleList 並沒有定義一個網絡,而只是將不同的模塊儲存在一起,這些模塊之間並沒有什麼先後順序,網絡的執行順序是根據 forward 函數來決定的。

def split_module(module: nn.Sequential,
                 balance: Iterable[int],
                 devices: List[torch.device],
                 ) -> Tuple[List[nn.Sequential], List[int], List[torch.device]]:

    balance = list(balance)

    j = 0
    partitions = []
    layers: NamedModules = OrderedDict()

    for name, layer in module.named_children(): # 遍歷模型包含的層
        layers[name] = layer # 把新的層加入到數組中

        if len(layers) == balance[j]: # 如果數組大小等於balance[j],就是達到了device j應該包含的層數
            # Group buffered layers as a partition.
            partition = nn.Sequential(layers) # 把層數組組合成一個sequential module

            device = devices[j]
            partition.to(device) # 把層放置到相關設備之上

            partitions.append(partition) # 這個新module加入到分區數組中

            # Prepare for the next partition.
            layers.clear()
            j += 1 # 去下一個device看看

    partitions = cast(List[nn.Sequential], nn.ModuleList(partitions))
    del devices[j:]

    return partitions, balance, devices

隨之而來問題就是:partition內部可以用Sequential來進行一系列的前向操作,但是如何配置partitions 之間的執行順序?

+-----------------------------------------------------------------------------------------+
|                                                                                         |
| Layer 1 +--->  Layer 2 +-----> Layer 3 +----->  Layer 4 +-----> Layer 5  +---> Layer 6  |
|                                                                                         |
+-----------------------------------------+-----------------------------------------------+
                                          |
                                          |
                                          |
                                          v
+-----------------------------------------------------------------------------------------+
| +--------------------+           +---------------------+         +--------------------+ |
| |Partition 1         |           |Partition 2          |         |Partition 3         | |
| |                    |   ???     |                     |         |                    | |
| |      Layer 1       |     +----------> Layer 4        |   ???   |                    | |
| |         +          |     |     |         +           |     +------->   Layer 6      | |
| |         |          |     |     |         |           |     |   |                    | |
| |         v          |     |     |         |           |     |   |                    | |
| |      Layer 2       |     |     |         |           |     |   |                    | |
| |         +          |     |     |         v           |     |   |                    | |
| |         |          |     |     |      Layer 5 +------------+   |                    | |
| |         v          |     |     |                     |         |                    | |
| |      Layer 3  +----------+     |                     |         |                    | |
| |                    |           |                     |         |                    | |
| +--------------------+           +---------------------+         +--------------------+ |
|                                                                                         |
+-----------------------------------------------------------------------------------------+

3.2 建立依賴

我們還是從論文中入手。假定我們有一個神經網絡,其由一系列子網絡構成。我們假定這些子網絡是 \(f^1,…,f^n\),其參數分別是 \(\theta^1,…,\theta^n\),則整個網絡是:

參數是 \(\theta = (\theta^1,…,\theta^n)\),為了清楚起見,我們稱 \(f^j\) 表示 f 的第 j 個分區,並假設分區的參數是相互不相交的。

在訓練網絡時,基於梯度的方法(如隨機梯度下降法)需要在給定小批量訓練數據 x 和相應損失之後,計算網絡的輸出結果f(x)。以及損失相對於網絡參數 \(\theta\) 的梯度g。這兩個階段分別稱為向前傳播和向後傳播。

既然 f 由其 L 層 子模塊 (\(f^L, f^{L-1},…f^1\)) 順序組成,那麼前向傳播\(f(x)\) 可以通過如下方式計算:讓 \(x^0=x\)(就是輸入x),然後順序應用每一個 partition,即 \(x^j = f^j (x^{j-1})\),這裡 $ j = 1, …, L$。就是 \(f(x)\) 可以表示為 :

\[f(x) = f^L(f^{L-1}(f^{L-2}(… f^1(x))))
\]

於是我們知道了,前向傳播的順序是由 \(f(x) = f^L(f^{L-1}(f^{L-2}(… f^1(x))))\) 來確定的

我們可以針對代碼,進一步解析,看看如何實施partitions之間的順序依賴。

    def run(self) -> None:
        """Runs pipeline parallelism.

        It modifies the given batches in place.

        """
        batches = self.batches
        partitions = self.partitions
        devices = self.devices
        skip_layout = self.skip_layout

        m = len(batches)
        n = len(partitions)

        skip_trackers = [SkipTrackerThroughPotals(skip_layout) for _ in batches]

        with spawn_workers(devices) as (in_queues, out_queues):
            for schedule in clock_cycles(m, n): # 這裡使用,給出了執行序列計劃,後續按照這個來執行
                self.fence(schedule, skip_trackers)
                self.compute(schedule, skip_trackers, in_queues, out_queues)

解析的目標是 for schedule in clock_cycles(m, n) 這個 for 循環,其:

  • 針對clock_cycles產生的每一個運行計劃:
    • 利用 fence(schedule, skip_trackers) 構建後向傳播依賴關係。
    • 利用 compute(schedule, skip_trackers, in_queues, out_queues) 進行計算。

現在我們完成了兩步:

  1. 確定性時鐘周期算法給定了前向傳播的執行順序,我們只要按照 clock_cycles 方法提供的計劃一一運行即可
  2. fence 方法通過調用 join 和 fork,我們做到了在後向傳播之中,batches[i] 就 必須在 batches[i-1] 之前完成了,即 \(B_{i+1,j}\) 必須在 \(B_{i,j}\) 之前運行。

對於我們的圖來說,第二步就是完成了下圖的列依賴。

我們的問題是:怎麼通過這個 for 循環,做到 \(B_{i,{j+1}}\) 必須在 \(B_{i,j}\) 之前運行?,即怎麼安排反向傳播逐次運行?就是怎麼完成行內的依賴?

這就要通過 compute 的源碼進行分析。重點說明的是:

  • batches[i] 這裡是會變化的,比如 batches[0] 在經過 partitions[j] 的計算之後,會變成 batches[0][j]
  • 對於 compute 方法,關鍵就是在最底部的代碼 batches[i] = batch。就是把 第 j 個device 對 第 i 個 batch 的計算結果 賦值到 batches[i],賦值之後,batches[i]就是 batches[i][j],這樣,在下次計算時候,構建的就是 F[i, j+1], 下一次 fence 之中的 depend 操作,就是針對 batches[i, j+1]
  • 因此,在前向計算圖上,通過這個賦值操作, batches[i, j+1] 就依賴 batches[i, j],所以反向計算時候,batches[i, j + 1] 就必須在 batches[i, j] 之前完成
    def compute(self,
                schedule: List[Tuple[int, int]],
                skip_trackers: List[SkipTrackerThroughPotals],
                in_queues: List[InQueue],
                out_queues: List[OutQueue],
                ) -> None:
        """Runs tasks with synchronization to copy streams."""
        batches = self.batches
        partitions = self.partitions
        devices = self.devices
        n = len(partitions)
        streams = [current_stream(d) for d in devices]
  
        for i, j in schedule: # 針對 schedule 之中的每一對 i,j
            batch = batches[i]
            partition = partitions[j]

            # Synchronize with the copied input. ([1] in the diagram)

            # Determine whether checkpointing or not.

            if checkpoint:
							# 忽略
            else:
                def compute(batch: Batch = batch,
                            partition: nn.Sequential = partition,
                            skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                            ) -> Batch:
                    with use_skip_tracker(skip_tracker):
                        return batch.call(partition) # 前向計算,計算以 partition為單位計算,partition內部的層是順序計算,由 Sequential保證。

                task = Task(streams[j], compute=compute, finalize=None)
                del compute

            # Compute tasks in parallel. ([2] in the diagram)
            in_queues[j].put(task) # 讓 worker計算

        for i, j in schedule:
            ok, payload = out_queues[j].get() # 獲取 worker 的前向計算結果,就是 第 j 個device 對 第 i 個 batch 的計算結果

            task, batch = cast(Tuple[Task, Batch], payload)

            # The copy stream synchronizes to copy the output. ([3] in the
            # diagram)

            # Finalize tasks. If checkpointing is enabled, here the
            # recomputation is scheduled at backpropagation. ([4] in the
            # diagram)

            # 第 j 個device 對 第 i 個 batch 的計算 就是 F[i,j]

            batches[i] = batch # 這裡是關鍵,就是把 第 j 個device 對 第 i 個 batch 的計算結果 賦值到 batches[i],batches[i]就是 batches[i][j],在下次計算時候,構建的就是 F[i,j+1], 下一次 fence 之中的 depend 操作,就是針對 batches[i,j+1]

關於這個賦值操作,其對應的grad_fn 是 PermuteBackward,比如:

a = torch.tensor([2., 3.], requires_grad=True)
c = a
c.backward(gradient=external_grad)
print(c)

具體是:

c = {Tensor: 2} tensor([2., 3.], requires_grad=True)
  T = {Tensor: 2} tensor([2., 3.], grad_fn=<PermuteBackward>)

現在,我們把下圖進行升級。

                 +-------------------------------------------------------------------+
                 | depend                                                            |
                 |                                                                   |
                 | +---------------+                                                 |
                 | |fork           |                                                 |
+-------------   | |               |     +-----------+                               |
|            |   | |               |     |           |                               |
|batches[i]  +-------------------------> | batches[i]|                               |
|            |   | |               |     |           |                               |
+-------------   | |               |     +-----------+                               |
                 | |               |                                                 |
                 | |               |                                                 |
                 | |               |     +--------+    +-------+                     |
                 | |  get_phony +------> |        +--->+ Join  |                     |
                 | |               |     | phony  |    |       |                     |
                 | +---------------+     |        |    |       |                     |
                 |                       +--------+    |       |                     |
                 |                                     |       |                     |
+-------------   |                                     |       |    +--------------+ |
|            |   |                                     |       |    |              | |
|batches[i+1]+----------------------------------------------------->+ batches[i+1] | |
|            |   |                                     |       |    |              | |
+-------------   |                                     |       |    +--------------+ |
                 |                                     +-------+                     |
                 |                                                                   |
                 +-------------------------------------------------------------------+

我們進行橫向拓展,得到如下,即一個batch 被分成兩個小批次: batches[i],batches[i+1] ,它們在兩個設備 partitions[j],partitions[j + 1] 之上流水線,這樣行和列都有反向傳播的依賴。

                                 F[i,j]                                                                            F[i,j+1]

                    +------------------------------------------------+                            +-----------------------------------------------+
                    | partitions[j]                                  |                            |  partitions[j+1]                              |
                    |                                                |                            |                                               |
                    | +--------------------+   +------------------+  |                            | +-------------------+   +------------------+  |
                    | |fence               |   | compute          |  |                            | | fence             |   | compute          |  |
                    | |                    |   |                  |  |                            | |                   |   |                  |  |
+--------------+    | |  +--------------+  |   |  +------------+  |  |     +-----------------+    | |   +-------------+ |   |  +------------+  |  |       +-----------------+
|              |    | |  | depend       |  |   |  |forward     |  |  |     |                 |    | |   | depend      | |   |  |forward     |  |  |       |                 |
|  batches[i]  +---------------------------------------------------------> | batches[i][j]   +----------------------------------------------------------> | batches[i][j+1] |
|              |    | |  |              |  |   |  |            |  |  |     |                 |    | |   |             | |   |  |            |  |  |       |                 |
+--------------+    | |  |              |  |   |  |            |  |  |     +-----------------+    | |   |             | |   |  |            |  |  |       +-----------------+
                    | |  |              |  |   |  +------------+  |  |                            | |   |             | |   |  +------------+  |  |
                    | |  |              |  |   |                  |  |                            | |   |             | |   |                  |  |
+--------------+    | |  |              |  |   +------------------+  |     +-----------------+    | |   |             | |   +------------------+  |       +-------------------+
|              |    | |  |              |  |                         |     |                 |    | |   |             | |                         |       |                   |
|  batches[i+1]+---------------------------------------------------------> | batches[i+1][j] +----------------------------------------------------------> | batches[i+1][j+1] |
|              |    | |  |              |  |                         |     |                 |    | |   |             | |                         |       |                   |
+--------------+    | |  +--------------+  |                         |     +-----------------+    | |   +-------------+ |                         |       +-------------------+
                    | |                    |                         |                            | |                   |                         |
                    | +--------------------+                         |                            | +-------------------+                         |
                    +------------------------------------------------+                            +-----------------------------------------------+

手機如下:

0x04 總結

下圖 $ m = 4, n = 3$。即,模型被分成3個子網絡,小批次被分割成 4個微批次。F 和 B 的下標是 (m, n)。

img

如上圖,這裡需要完成兩種依賴:

  • 行間依賴,就是 batch 之間的依賴,就是設備內的依賴。從圖上看是虛線,就是 \(F_{1,1}\) 必須在 \(F_{2,1}\)之前完成,\(B_{2,1}\) 必須在\(B_{1,1}\) 之前完成。
  • 列間依賴,就是 partitions(設備) 之間的依賴。從圖上看是實線,就是 \(F_{1,1}\) 必須在 \(F_{1,2}\)之前完成,即第一個設備必須在第二個設備之前完成,而且第一個設備的輸出是第二個設備的輸入。

如上圖,我們需要完成行,列兩方面的依賴。

  • 行間依賴是用 Join & Fork 來保證,利用空張量完成了依賴關係的設定,確保 batches[i-1] 在 batches[i] 之後完成。
  • 列間依賴是通過 batches[i] = batch 完成,利用 PermuteBackward 來完成了設備之間的依賴。

至此,我們完成了執行順序和依賴關係的設定,下一篇我們介紹如何並行處理。

0xFF 參考

Markdown公式用法大全

markdown中公式編輯教程

//docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html#stream-sync-behavior

CUDA學習:基礎知識小結

CUDA隨筆之Stream的使用

NVIDIA解決方案架構師深度解析大規模參數語言模型Megatron-BERT

Accelerating Wide & Deep Recommender Inference on GPUs

HugeCTR: High-Performance Click-Through Rate Estimation Training

//discuss.pytorch.org/t/how-to-prefetch-data-when-processing-with-gpu/548

//github.com/NVIDIA/apex/

//github.com/justheuristic/prefetch_generator

//pytorch.org/tutorials/intermediate/model_parallel_turotial.html

//pytorch.org/docs/stable/autograd.html

//pytorch.org/docs/notes/cuda.html

//zhuanlan.zhihu.com/p/61765561

//pytorch.apachen.org/docs/1.7/64.html

//zhidx.com/p/217999.html