[源碼解析] 模型並行分散式訓練 Megatron (3) -模型並行實現 - ⎝⎛CodingNote.cc ⎞⎠
­

[源碼解析] 模型並行分散式訓練 Megatron (3) —模型並行實現

[源碼解析] 模型並行分散式訓練 Megatron (3) —模型並行實現

0x00 摘要

NVIDIA Megatron 是一個基於 PyTorch 的分散式訓練框架,用來訓練超大Transformer語言模型,其通過綜合應用了數據並行,Tensor並行和Pipeline並行來複現 GPT3,值得我們深入分析其背後機理。

本系列大概有6~7篇文章,通過論文和源碼和大家一起學習研究。本文將看看 Megatron 如何處理模型並行。

本系列其他文章為:

[源碼解析] 模型並行分散式訓練Megatron (1) — 論文 & 基礎

[源碼解析] 模型並行分散式訓練Megatron (2) — 整體架構

0x01 並行Transformer層

在論文篇之中,我們了解到,因為模型越來越大,其尺寸遠遠超過了處理器的記憶體限制,因此產生了諸如激活檢查點(activation checkpointing)這樣的記憶體管理技術。而模型並行則通過對模型進行各種分片來克服單個處理器記憶體限制,這樣模型權重和其關聯的優化器狀態就可以分散到多個設備之上。

ParallelTransformerLayer 就是對 Transformer 層的並行實現,所以我們接著分析。

1.1 初始化

ParallelTransformerLayer 初始化方法之中,建立了如下:

  • 生成一個LayerNorm處理輸入數據。
  • 生成並行Attention。
  • 生成處理attention輸出的LayerNorm。
  • 如果是decoder,則生成一個ParallelAttention。
  • 生成一個並行MLP。
class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [b, s, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm( # 生成一個LayerNorm處理輸入數據
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        # Self attention.
        self.self_attention = ParallelAttention( # 生成並行Attention
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm( # 生成處理attention輸出的LayerNorm
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        if self.layer_type == LayerType.decoder: # 如果本層是decoder
            self.inter_attention = ParallelAttention( # 則生成一個ParallelAttention
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm)

        # MLP
        self.mlp = ParallelMLP(init_method, # 生成一個並行MLP
                               output_layer_init_method)

對應就是:

1.2 前向傳播

其前向傳播方法如下,就是調用各種成員函數進行前向操作。

def forward(self, hidden_states, attention_mask,
            encoder_output=None, enc_dec_attn_mask=None,
            inference_params=None):
    # hidden_states: [b, s, h]

    # Layer norm at the beginning of the transformer layer.
    layernorm_output = self.input_layernorm(hidden_states) # 對輸入進行處理
    
    # Self attention.
    attention_output, attention_bias = \ # attention操作
        self.self_attention(
            layernorm_output,
            attention_mask,
            inference_params=inference_params)

    # Residual connection. 殘差連接
    if self.apply_residual_connection_post_layernorm:
        residual = layernorm_output #norm之後結果作為X
    else:
        residual = hidden_states # 原始輸入X

    # jit scripting for a nn.module (with dropout) is not
    # trigerring the fusion kernel. For now, we use two
    # different nn.functional routines to account for varying
    # dropout semantics during training and inference phases.
    if self.bias_dropout_fusion: # dropout操作
        if self.training:
            bias_dropout_add_func = bias_dropout_add_fused_train
        else:
            bias_dropout_add_func = bias_dropout_add_fused_inference
    else:
        bias_dropout_add_func = get_bias_dropout_add(self.training)

    # re-enable torch grad to enable fused optimization.
    with torch.enable_grad():
        layernorm_input = bias_dropout_add_func( # dropout操作
            attention_output,
            attention_bias.expand_as(residual),
            residual,
            self.hidden_dropout)

    # Layer norm post the self attention.
    layernorm_output = self.post_attention_layernorm(layernorm_input) # 處理attention輸出

    if self.layer_type == LayerType.decoder:
        attention_output, attention_bias = \
            self.inter_attention(layernorm_output,
                                 enc_dec_attn_mask,
                                 encoder_output=encoder_output)
        # residual connection
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        # Layer norm post the decoder attention
        layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

    # MLP.
    mlp_output, mlp_bias = self.mlp(layernorm_output) # MLP操作 

    # Second residual connection.
    if self.apply_residual_connection_post_layernorm: # 殘差操作
        residual = layernorm_output
    else:
        residual = layernorm_input

    # re-enable torch grad to enable fused optimization.
    with torch.enable_grad():
        output = bias_dropout_add_func( # dropout操作
            mlp_output,
            mlp_bias.expand_as(residual),
            residual,
            self.hidden_dropout)

    return output

0x02 並行MLP

ParallelTransformerLayer 裡面包含了 Attention 和 MLP,因為篇幅所限,我們這裡主要對MLP進行分析。對於 Attention 則簡單研究一下其行切分機制,畢竟我們想了解的是如何進行模型並行,而非深入理解Transformer。

Megatron的並行MLP包含了兩個線性層,第一個線性層實現了 hidden size 到 4 x hidden size 的轉換,第二個線性層實現了 4 x hidden size 回到 hidden size。具體 MLP 的邏輯如下:

圖:具有模型並行性的 MLP。f和g表示和通訊切塊相關的操作,其是共軛的。f 的前向傳播是一個identity運算符,而後向傳播是一個all-reduce,g 的前向傳播是 all-reduce,後向傳播是一個identity運算符。這裡的 f 來自 ColumnParallelLinear,g 來自 RowParallelLinear。即,MLP 就是把 ColumnParallelLinear 和 RowParallelLinear 結合起來。

於是,這裡焦點問題就是:如何把這兩種線性層切開到不同的GPU卡之上?參見前文,這裡採用了第二種方案,

另一個選項是沿列拆分A,得到 \(A=[A_1,A_2]\)。該分區允許GeLU非線性獨立應用於每個分區GEMM的輸出:

\[\begin{bmatrix}
Y_1& Y_2
\end{bmatrix}= \begin{bmatrix}
GeLU(XA_1),GeLU(XA_2)
\end{bmatrix}
\]

這個方法更好,因為它刪除了同步點,直接把兩個 GeLU 的輸出拼接在一起就行。因此,我們以這種列並行方式劃分第一個GEMM,並沿其行分割第二個GEMM,以便它直接獲取GeLU層的輸出,而不需要任何其他通訊(比如 all-reduce 就不需要了),如圖所示。

我們再深入分析一下為何選擇這個方案。

按照常規邏輯,MLP 的前向傳播應該分為兩個階段,分別對應了下面圖之中的兩行,

  • 第一行是把參數 A 按照列切分,然後把結果按照列拼接起來,得到的結果就是與不使用並行策略完全等價的結果。
  • 第二行是把激活 Y 按照列切分,參數B按照行切分做並行,最後把輸出做加法,得到 Z。

但是每個split會導致兩次額外的通訊(前向傳播和後向傳播各一次,下面只給出了前向傳播)。因為對於第二行來說,其輸入Y其實本質是 XA1,XA2並行的,所以為了降低通訊量,我們可以把數據通訊延後或者乾脆取消通訊,就是把第一行最後的 all_gather 和第二行最初的 split 省略掉,這其實就是數學上的傳遞性和結合律(局部和之和為全局和)。於是我們就得到了論文之中的第二種方案。

結合程式碼,就是:

  • ColumnParallelLinear 實現了 MLP 的前半部分或者考慮了這個線性層獨立使用的情況。
  • RowParallelLinear 實現了 MLP 的後半部分或者考慮了這個線性層獨立使用的情況。

j

2.1 命名規範

我們首先看看命名規範,後文使用如下:

  • h: hidden size
  • n: number of attention heads
  • p: number of model parallel partitions
  • np: n/p
  • hp: h/p
  • hn: h/n
  • b: batch size
  • s: sequence length
  • l: number of layers
  • Transformer 的輸入size是 [s, b, h],返回一個同樣size的張量,我們使用 hyperparameters 作為transformer 的超參數。

2.2 MLP 程式碼

2.2.1 初始化

megatron/model/transformer.py 之中有 ParallelMLP 定義如下:

  • 定義了一個 ColumnParallelLinear 用來進行第一個 H 到 4 H 的轉換。
  • 然後是一個 gelu。
  • 接著是 RowParallelLinear 用來進行 4H 到 H 的轉換回來。

dropout操作是在上面ParallelTransformerLayer的forward之中進行。

所以,MLP大致如圖,這裡A,B是各自的權重矩陣:

也就是對應論文之中這個圖形。

程式碼如下。

class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to 4h.
        self.dense_h_to_4h = mpu.ColumnParallelLinear( # 列切分
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False, # 這裡是false,採用第二種方案
            init_method=init_method,
            skip_bias_add=True)

        self.bias_gelu_fusion = args.bias_gelu_fusion # gelu
        self.activation_func = F.gelu
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu

        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear( # 行切分
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

2.2.2 前向操作

這裡分別調用了 ColumnParallelLinear 完成了 H 到 4H 的轉換,RowParallelLinear 完成了 4H 到 H 的轉換。

def forward(self, hidden_states):

    # [s, b, 4hp]
    intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states) # 縱向切分

    if self.bias_gelu_fusion:
         intermediate_parallel = \
                 bias_gelu_impl(intermediate_parallel, bias_parallel)
    else:
        intermediate_parallel = \
            self.activation_func(intermediate_parallel + bias_parallel)

    # [s, b, h]
    output, output_bias = self.dense_4h_to_h(intermediate_parallel) # 橫向切分
    return output, output_bias

我們接下來分別介紹 ColumnParallelLinear 和 RowParallelLinear。ColumnParallelLinear 分別可以獨立使用或者作為 ParallelMLP 的前半段,RowParallelLinear 也可以獨立使用或者作為 ParallelMLP 的後半段。

0x03 ColumnParallelLinear

ColumnParallelLinear 就是按列進行切分,也就是縱刀流。注意,這裡說的是對權重進行列切分。就是:

\[Y = XA = X[A_1, A_2] = [XA_1, XA_2]
\]

具體切分如下:

3.1 定義

因為 Python 語言特性,這裡有用的只是注釋,從注釋中可以看出來,對於 $ Y = XA + b \(,A 被以如下方式進行並行化:\) A = [A_1, …, A_p] $

class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias
        gather_output: If true, call all-gether on output and make Y avaiable
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
        skip_bias_add: This was added to enable performance optimations where bias
                       can be fused with other elementwise operations. we skip 
                       adding bias but instead return it.
    """

3.2 初始化

初始化程式碼之中主要是用切分的資訊來初始化權重。

def __init__(self, input_size, output_size, bias=True, gather_output=True,
             init_method=init.xavier_normal_, stride=1,
             keep_master_weight_for_test=False,
             skip_bias_add=False):
    super(ColumnParallelLinear, self).__init__()

    # Keep input parameters
    self.input_size = input_size
    self.output_size = output_size
    self.gather_output = gather_output
    # Divide the weight matrix along the last dimension.
    world_size = get_tensor_model_parallel_world_size() # 獲得本tensor並行組的world size
    self.output_size_per_partition = divide(output_size, world_size) # 獲得本子模型應輸出size
    self.skip_bias_add = skip_bias_add

    # Parameters.
    # Note: torch.nn.functional.linear performs XA^T + b and as a result
    # we allocate the transpose.
    # Initialize weight.
    args = get_args()
    if args.use_cpu_initialization:
        # 用切分的size初始化權重
        self.weight = Parameter(torch.empty(self.output_size_per_partition,
                                            self.input_size,
                                            dtype=args.params_dtype))
        self.master_weight = _initialize_affine_weight_cpu( # 初始化權重
            self.weight, self.output_size, self.input_size,
            self.output_size_per_partition, 0, init_method,
            stride=stride, return_master_weight=keep_master_weight_for_test)
    else:
        # 用切分的size初始化權重
        self.weight = Parameter(torch.empty(
            self.output_size_per_partition, self.input_size,
            device=torch.cuda.current_device(), dtype=args.params_dtype))
        _initialize_affine_weight_gpu(self.weight, init_method, # 初始化權重
                                      partition_dim=0, stride=stride)

    if bias:
        if args.use_cpu_initialization:
            # 用切分的size初始化權重
            self.bias = Parameter(torch.empty(
                self.output_size_per_partition, dtype=args.params_dtype))
        else:
            # 用切分的size初始化權重
            self.bias = Parameter(torch.empty(
                self.output_size_per_partition,
                device=torch.cuda.current_device(),
                dtype=args.params_dtype))
        set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
        # Always initialize bias to zero.
        with torch.no_grad():
            self.bias.zero_()
    else:
        self.register_parameter('bias', None)
    self.async_tensor_model_parallel_allreduce = (
            not args.no_async_tensor_model_parallel_allreduce and
            world_size > 1)

3.2.1 切分size

self.output_size_per_partition = divide(output_size, world_size) 這裡有一個分割size操作,得到每個子模型應該擁有的權重大小。

def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, '{} is not divisible by {}'.format(
        numerator, denominator)


def divide(numerator, denominator):
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator

3.2.2 初始化權重

以下程式碼實現了初始化權重。

def _initialize_affine_weight_gpu(weight, init_method,
                                  partition_dim, stride=1):
    """Initialize affine weight for model parallel on GPU."""

    set_tensor_model_parallel_attributes(tensor=weight,
                                         is_parallel=True,
                                         dim=partition_dim,
                                         stride=stride)

    with get_cuda_rng_tracker().fork():
        init_method(weight)


def _initialize_affine_weight_cpu(weight, output_size, input_size,
                                  per_partition_size, partition_dim,
                                  init_method, stride=1,
                                  return_master_weight=False):
    """Initialize affine weight for model parallel.

    Build the master weight on all processes and scatter
    the relevant chunk."""

    set_tensor_model_parallel_attributes(tensor=weight,
                                         is_parallel=True,
                                         dim=partition_dim,
                                         stride=stride)

    # Initialize master weight
    master_weight = torch.empty(output_size, input_size,
                                dtype=torch.float,
                                requires_grad=False)
    init_method(master_weight)
    args = get_args()
    master_weight = master_weight.to(dtype=args.params_dtype)

    # Split and copy
    per_partition_per_stride_size = divide(per_partition_size, stride)
    weight_list = torch.split(master_weight, per_partition_per_stride_size,
                              dim=partition_dim)
    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    my_weight_list = weight_list[rank::world_size]

    with torch.no_grad():
        torch.cat(my_weight_list, dim=partition_dim, out=weight)
    if return_master_weight:
        return master_weight
    return None

3.3 邏輯梳理

為了更好的分析,我們引入下圖(來自參考1),這個圖對應了 ColumnParallelLinear 類的前向傳播和後向傳播過程這裡的 f 和 g 操作其實是從程式碼之中抽象出來的,可以理解為 f 是對輸入的處理,g 則是處理之後得到最終輸出。此處對應了論文中描述的粗體字:

Figure 3. Blocks of Transformer with Model Parallelism. f and g are conjugate. f is an identity operator in the forward pass and all reduce in the backward pass while g is an all reduce in the forward pass and identity in the backward pass.

圖片來自 GTC 2020: Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism

我們針對上圖,梳理一下邏輯。

3.3.1 前向傳播

我們一步一步細化。

首先,總體語義為:Y = XA + b。

其次,前向傳播時候的邏輯如下:

  • 輸入:這裡 A 沿著列做切分,X 是全部的輸入(每個GPU都擁有相同的X)。
  • 計算:經過計算之後,輸出的 \(Y_1, Y_2\) 也是按照列被切分過的。每個GPU只有自己對應的分區。
  • 輸出:\(Y_1, Y_2\) 只有合併在一起,才能得到最終輸出的 Y。

再次,我們使用operator來細化一下:

  • 輸入:因為每個GPU需要拿到一個完整的輸入 X,所以前向操作之中需要把X分發到每個GPU,這樣就使用了 Identity 操作
  • 計算:經過計算之後,輸出的 \(Y_1, Y_2\) 也是按照列被切分過的。每個GPU只有自己對應的分區。
  • 輸出:因為\(Y_1, Y_2\) 需要合併在一起,才能得到最終輸出的 Y。所以需要有一個 all-gather 操作來進行聚合,即得到 $ Y = [Y_1, Y_2]$。

我們把這些邏輯點在上圖上用紅色方框標示,輸入 X 先經過 f 來處理,輸出 Y 是 g 整合之後的結果。

3.3.2 後向傳播

我們接下來看看後向傳播,對於上圖來說,後向傳播是從上至下,梯度先經過 g,最後被 f 處理。

反向傳播的邏輯如下:

  • 目前得到了反向傳播上游傳過來的梯度 \(\frac{\partial L}{\partial Y}\),現在需要對其進行切分,保證每個GPU之上都有一份梯度 \(\frac{\partial L}{\partial Y_i}\)。操作是\(\frac{\partial L}{\partial Y_i}(split)\)
  • 每個GPU之上會進行關於X的梯度計算,於是每個GPU都有一份對X的梯度(但是其內容不一樣)。
  • 最後需要把各個 GPU 之上關於X的梯度進行相加,得到完整梯度,這就需要一個 all-reduce 操作。即 $\frac{\partial L}{\partial X} = \frac{\partial L}{\partial X} |_1 + \frac{\partial L}{\partial X} |_2 $

所以我們在圖上用藍色圓角矩形標示出來後向傳播對應的運算元。

3.4 程式碼實現

我們接下來結合程式碼來分析。

3.3.1 ColumnParallelLinear

ColumnParallelLinear 的 forward 程式碼之中,主要是實施了 f 和 g 的forward操作,同時把 f 和 g 的backward 操作搭建起來,具體如下:

  • 如果配置了非同步操作,則使用 ColumnParallelLinearWithAsyncAllreduce 完成 f 運算符的功能,這一個函數包括了identity 操作,矩陣乘法,搭建後向傳播操作。
  • 如果是同步操作,則:
    • 使用 copy_to_tensor_model_parallel_region 完成前向傳播 identity 操作,建立反向傳播all-reduce,就是圖中f的backward。identity 操作 就是把輸入 X 完整的拷貝到多個GPU之上,類似 X 通過 f 的前向操作,變成了 [X, X, …, X]。
    • 使用 linear 對 [X, X, …, X] 和 權重 A 完成矩陣乘法操作。
  • 如果gather_output為True,則在前向傳播時候把 \(Y_i\) 做all-gather,因為反向傳播時需要把完整梯度scatter到對應GPU之上,所以要搭建對於的split操作。MLP實現之中,此處設置為 False,這樣每個GPU輸出的是自己partition 的 4h/p,直接傳送給下一個線性層
def forward(self, input_):
    # 如果選擇忽略bias,就會設置為None,後續就不用處理了
    bias = self.bias if not self.skip_bias_add else None

    # 下面主要是圖中的 f 操作
    if self.async_tensor_model_parallel_allreduce:
        # 建立反向傳播時候的非同步all-reduce
        input_shape = input_.shape
        input_ = input_.view(input_shape[0] * input_shape[1],input_shape[2])
        # Maxtrix multiply with asynchronouse all-reduce execution
        output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply(
                input_, self.weight, bias)
        output_parallel = output_parallel.view(
                input_shape[0], input_shape[1], output_parallel.shape[1])
    else:
        # Set up backprop all-reduce.、
        # 建立反向傳播all-reduce,就是圖中f的backward
        input_parallel = copy_to_tensor_model_parallel_region(input_) 

        # Matrix multiply.
        output_parallel = F.linear(input_parallel, self.weight, bias) # 矩陣乘法操作

    # 下面就是圖中的 g 操作    
    if self.gather_output: # 是否需要聚合操作
        # All-gather across the partitions.
        # 聚合輸出,就是圖中g的forward
        output = gather_from_tensor_model_parallel_region(output_parallel) #
    else:
        output = output_parallel
        
    output_bias = self.bias if self.skip_bias_add else None # 如果不忽略bias,還得傳出去
    return output, output_bias

3.3.2 f 操作

F 操作是對輸入進行初步處理,具體是:

  • 前向傳播時候直接拷貝。
  • 後向傳播做all-reduce。

3.3.2.1 同步操作

這裡我們主要分析 copy_to_tensor_model_parallel_region,其做了前向copy操作,同時構建了後向 all-reduce。

def copy_to_tensor_model_parallel_region(input_):
    return _CopyToModelParallelRegion.apply(input_)

我們還是需要看看 _CopyToModelParallelRegion。可以看到,其 forward 就是簡單的把輸入轉移到輸出,就是對應了前向複製identity。

class _CopyToModelParallelRegion(torch.autograd.Function):
    """Pass the input to the model parallel region."""

    @staticmethod
    def symbolic(graph, input_):
        return input_
    
    @staticmethod
    def forward(ctx, input_):
        return input_ # 簡單的把輸入轉移到輸出,就是對應了前向複製identity

    @staticmethod
    def backward(ctx, grad_output):
        return _reduce(grad_output) # 反向傳播時候,輸入是多個GPU上的梯度整體,通過all-reduce合併

對應的後向傳播就使用了All-reduce,反向傳播時候,輸入是多個GPU上的梯度整體,通過all-reduce合併。

def _reduce(input_):
    """All-reduce the input tensor across model parallel group."""

    # Bypass the function if we are using only 1 GPU.
    if get_tensor_model_parallel_world_size()==1:
        return input_

    # All-reduce.
    torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())

    return input_
3.3.2.2 非同步 All-Reduce

ColumnParallelLinearWithAsyncAllreduce 這裡把同步之中的乘法操作也放置進來。

class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
    """
    Column-parallel linear layer execution with asynchronous all-reduce
    execution in backprop.
    """
    @staticmethod
    def forward(ctx, input, weight, bias):
        ctx.save_for_backward(input, weight)
        ctx.use_bias = bias is not None
        output = torch.matmul(input, weight.t()) # 同步時候的乘法也在這裡了
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        use_bias = ctx.use_bias
        grad_input = grad_output.matmul(weight)
        # Asyncronous all-reduce
        handle = torch.distributed.all_reduce( # 反向傳播操作
                grad_input, group=get_tensor_model_parallel_group(), async_op=True)
        # Delay the start of weight gradient computation shortly (3us) to have
        # all-reduce scheduled first and have GPU resources allocated
        _ = torch.empty(1, device=grad_output.device) + 1
        grad_weight = grad_output.t().matmul(input)
        grad_bias = grad_output.sum(dim=0) if use_bias else None
        handle.wait()
        return grad_input, grad_weight, grad_bias

3.3.3 g 操作

以下對應了圖之中的 g 操作。G操作是最終生成輸出Y,邏輯是:

  • 前向傳播時候做 all-gather;
  • 後向傳播需要執行 split,把梯度scatter到不同GPU之上。

def gather_from_tensor_model_parallel_region(input_):
    return _GatherFromModelParallelRegion.apply(input_)

具體程式碼如下:

class _GatherFromModelParallelRegion(torch.autograd.Function):
    """Gather the input from model parallel region and concatinate."""

    @staticmethod
    def symbolic(graph, input_):
        return _gather(input_)
    
    @staticmethod
    def forward(ctx, input_):
        return _gather(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output)

3.3.4 基礎函數

我們接下來看看上面用到的一些基礎函數。

3.3.4.1 gather

_gather 是沿著最後一個維度進行拼接。

def _gather(input_):
    """Gather tensors and concatinate along the last dimension."""

    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size==1:
        return input_

    # Size and dimension.
    last_dim = input_.dim() - 1
    rank = get_tensor_model_parallel_rank() # 獲得本worker在tensor並行之中的rank

    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    # 在本 tensor 進程組之間進行 all-gather操作
    torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())

    # Note: torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=last_dim).contiguous()

    return output
3.3.4.2 split

_split 完成了張量切分操作。

def split_tensor_along_last_dim(tensor, num_partitions,
                                contiguous_split_chunks=False):
    """Split a tensor along its last dimension.
    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor
        contiguous_split_chunks: If True, make each chunk contiguous
                                 in memory.
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = divide(tensor.size()[last_dim], num_partitions) # 得到每個切分的size
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) # 對張量進行切分
    # Note: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list

def _split(input_):
    """Split the tensor along its last dimension and keep the
    corresponding slice."""

    world_size = get_tensor_model_parallel_world_size() # 獲取本tensor進程組的world size
    # Bypass the function if we are using only 1 GPU.
    if world_size==1:
        return input_

    # Split along last dimension.
    input_list = split_tensor_along_last_dim(input_, world_size)

    # Note: torch.split does not create contiguous tensors by default.
    rank = get_tensor_model_parallel_rank() # 獲取自己的rank
    output = input_list[rank].contiguous() # 獲取切分後,自己對應的rank

    return output

其中,get_tensor_model_parallel_rank 作用是獲取本進程在tensor並行組的rank。

def get_tensor_model_parallel_rank():
    """Return my rank for the tensor model parallel group."""
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
        return _MPU_TENSOR_MODEL_PARALLEL_RANK
    return torch.distributed.get_rank(group=get_tensor_model_parallel_group())

0x04 RowParallelLinear

RowParallelLinear 這裡是按照行進行切分,就是橫刀流,注意這裡是對權重A實施行切分。比如公式為 Y = XA,X是輸入,A是權重,Y是輸出,行切分就是針對A的第一個維度進行切分,這裡 \(X_1\) 最後一個維度等於 \(A_1\) 第一個維度。

\[XA = \begin{bmatrix}X_1,X_2\end{bmatrix} \begin{bmatrix}A_1 \\ A_2\end{bmatrix} = X_1 A_1 + X_2 A_2 = Y_1 + Y_2 = Y
\]

具體如下:

4.1 定義

定義之中只有注釋有用,可以看出來如何切分。

class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
        skip_bias_add: This was added to enable performance optimization where bias
                       can be fused with other elementwise operations. We skip
                       adding bias but instead return it.
    """

4.2 初始化

和列切分類似,初始化之中主要是獲取每個權重分區的大小,然後據此切分權重。

def __init__(self, input_size, output_size, bias=True,
             input_is_parallel=False,
             init_method=init.xavier_normal_, stride=1,
             keep_master_weight_for_test=False,
             skip_bias_add=False):
    super(RowParallelLinear, self).__init__()

    # Keep input parameters
    self.input_size = input_size
    self.output_size = output_size
    self.input_is_parallel = input_is_parallel
    # Divide the weight matrix along the last dimension.
    world_size = get_tensor_model_parallel_world_size()
    self.input_size_per_partition = divide(input_size, world_size) # 獲取每個權重分區的大小
    self.skip_bias_add = skip_bias_add

    # Parameters.
    # Note: torch.nn.functional.linear performs XA^T + b and as a result
    # we allocate the transpose.
    # Initialize weight.
    args = get_args()
    if args.use_cpu_initialization:
        self.weight = Parameter(torch.empty(self.output_size,
                                            self.input_size_per_partition,
                                            dtype=args.params_dtype))
        # 切分權重
        self.master_weight = _initialize_affine_weight_cpu(
            self.weight, self.output_size, self.input_size,
            self.input_size_per_partition, 1, init_method,
            stride=stride, return_master_weight=keep_master_weight_for_test)
    else:
        self.weight = Parameter(torch.empty(
            self.output_size, self.input_size_per_partition,
            device=torch.cuda.current_device(), dtype=args.params_dtype))
        # 切分權重
        _initialize_affine_weight_gpu(self.weight, init_method,
                                      partition_dim=1, stride=stride)
    if bias:
        if args.use_cpu_initialization:
            self.bias = Parameter(torch.empty(self.output_size,
                                              dtype=args.params_dtype))
        else:
            self.bias = Parameter(torch.empty(
                self.output_size, device=torch.cuda.current_device(),
                dtype=args.params_dtype))
        # Always initialize bias to zero.
        with torch.no_grad():
            self.bias.zero_()
    else:
        self.register_parameter('bias', None)

4.3 邏輯梳理

為了更好的分析,我們引入下圖(來自參考1),這個圖對應了 RowParallelLinear 類的前向傳播和後向傳播過程這裡的 f 和 g 操作其實是從程式碼之中抽象出來的,可以理解為 f 是對輸入的處理,g 則是處理之後得到最終輸出

我們針對上圖,梳理一下邏輯。

4.3.1 前向傳播

我們一步一步細化。

首先,總體語義為:Y = XA + b。

其次,前向傳播時候的邏輯如下:

  • 輸入:這裡 A 沿著行做切分,因為A的維度發生了變化,所以X也需要做相應變化,X就必須按照列做切分,這樣 X 每個分塊才能與A 每個分塊進行相乘。這裡如果輸入是已經split過的(input_is_parallel 為True),則就不需要再進行split。
  • 計算:計算就是 \(Y_1 = X_1 A_1\)\(Y_2 = X_2A_2\)。經過計算之後,輸出的 \(Y_1, Y_2\) 的shape就是最終 Y 的shape。每個GPU只有自己對應的分區。
  • 輸出:\(Y_1, Y_2\) 只有合併在一起,才能得到最終輸出的 Y。但是因為 \(Y_1, Y_2\) 形狀相同,都等於Y的形狀,所以只要簡單矩陣相加即可。

再次,我們使用operator來細化一下:

  • 輸入:需要對 X 進行縱向切分,這就是一個split操作,得到了 \([X_1, X_2]\),這兩個分區要分別放到兩個GPU之上。
  • 計算:經過計算之後,每個GPU只有自己對應的分區。
  • 輸出:因為\(Y_1, Y_2\) 需要合併在一起,才能得到最終輸出的 Y。這樣需要把 \(Y_1\)\(Y_2\) 相加(因為是兩個GPU,所以之間還有等待操作),這就是 all-reduce 操作

我們把這些邏輯點在上圖上用紅色方框標示,輸入 X 先經過 f 來處理,輸出 Y 是 g 整合之後的結果。

4.3.2 後向傳播

我們接下來看看後向傳播,對於上圖來說,後向傳播是從上至下,梯度先經過 g,最後被 f 處理。

反向傳播的邏輯如下:

  • 目前得到了反向傳播上游傳過來的梯度 \(\frac{\partial L}{\partial Y}\),因為 \(Y_1, Y_2\) 的形狀相同,所以直接把梯度 \(\frac{\partial L}{\partial Y}\)傳給每個GPU即可,操作是\(\frac{\partial L}{\partial Y_i} = \frac{\partial L}{\partial Y} (id entity)\)。這裡解釋一下,在前向傳播時候,XA 的結果需要 all-reduce,可以理解為 sum operator,所以反向傳播時候直接拷貝梯度即可。
  • 每個GPU之上會進行關於X的梯度計算,於是每個GPU都有一份對X的梯度(但是其內容不一樣)。
  • 最後需要把各個 GPU 之上關於X的梯度進行聚合,得到完整梯度,就是forward 之中 split 的反向操作,按照最後一列對梯度進行拼接,即all-gather操作。

所以我們在圖上用藍色圓角矩形標示出來後向傳播對應的運算元。

4.4 程式碼實現

我們接下來看看程式碼如何實現。

4.4.1 RowParallelLinear

RowParallelLinear 的 forward 程式碼之中,主要是實施了 f 和 g 的forward操作,同時把 f 和 g 的backward 操作搭建起來,具體如下:

def forward(self, input_):
    # 這裡,輸入的張量已經被分割到每個GPU,輸出張量是all-reduce之後的整體
    # Set up backprop all-reduce.
    if self.input_is_parallel:  # 是否已經是split的輸入
        # Transformer's MLP 到達這裡,因為已經split,所以直接就接了輸入,不會scatter
        input_parallel = input_
    else:
        # 獨立 row parallel 線性層到這裡,會進行前向切分和後向拼接
        input_parallel = scatter_to_tensor_model_parallel_region(input_)
        
    # Matrix multiply.
    # 比如 X_i 和 A_i 進行乘法操作
    output_parallel = F.linear(input_parallel, self.weight)
    
    # All-reduce across all the partitions.
    # 進行前向all-reduce操作,這樣每個GPU之上都是完整的最新結果,同時搭建了後向的identity操作。
    output_ = reduce_from_tensor_model_parallel_region(output_parallel)
    if not self.skip_bias_add:
        # 加上bias
        output = output_ + self.bias if self.bias is not None else output_
        output_bias = None
    else:
        output = output_
        output_bias = self.bias
    return output, output_bias

4.4.1 f 操作

scatter_to_tensor_model_parallel_region 對應了f操作,其作用是:

  • 前向切分split輸入,同時搭建後向的 all-gather 操作。
  • 後向操作進行 all-gather 操作。

程式碼為:

def scatter_to_tensor_model_parallel_region(input_):
    return _ScatterToModelParallelRegion.apply(input_)

具體 _ScatterToModelParallelRegion 完成了實際業務,具體 _split, _gather 操作在前面都介紹過。

class _ScatterToModelParallelRegion(torch.autograd.Function):
    """Split the input and keep only the corresponding chuck to the rank."""

    @staticmethod
    def symbolic(graph, input_):
        return _split(input_)

    @staticmethod
    def forward(ctx, input_):
        return _split(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather(grad_output)

4.4.2 g 操作

reduce_from_tensor_model_parallel_region 對應了 g 操作,作用是:

  • 前向操作是 all-reduce之後得到最終輸出.

  • 反向操作則直接拷貝操作。

程式碼為:

def reduce_from_tensor_model_parallel_region(input_):
    return _ReduceFromModelParallelRegion.apply(input_)

具體業務如下:

class _ReduceFromModelParallelRegion(torch.autograd.Function):
    """All-reduce the input from the model parallel region."""

    @staticmethod
    def symbolic(graph, input_):
        return _reduce(input_)
    
    @staticmethod
    def forward(ctx, input_):
        return _reduce(input_) # 前面有介紹

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output #就是indentity 操作,直接把輸入拷貝到兩個GPU之上

0x05 Embedding

我們接下來看看 embedding。為了讓記憶體做到均衡配置,對embedding也會按照vocab維度來做shard操作,最終把分區放到多個GPU之上。這樣每個卡上都有嵌入表的一部分。

class VocabParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.

    This is mainly adapted from torch.nn.Embedding and all the default
    values are kept.
    Arguments:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        init_method: method to initialize weights.
    """

    def __init__(self, num_embeddings, embedding_dim,
                 init_method=init.xavier_normal_):
        super(VocabParallelEmbedding, self).__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        # Set the detauls for compatibility.
        self.padding_idx = None
        self.max_norm = None
        self.norm_type = 2.
        self.scale_grad_by_freq = False
        self.sparse = False
        self._weight = None
        self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
        # Divide the weight matrix along the vocaburaly dimension.
        self.vocab_start_index, self.vocab_end_index = \ # 得到分區的起始,終止位置
            VocabUtility.vocab_range_from_global_vocab_size(
                self.num_embeddings, get_tensor_model_parallel_rank(),
                self.tensor_model_parallel_size)
        self.num_embeddings_per_partition = self.vocab_end_index - \ # 得到分區內嵌入數目
            self.vocab_start_index

        # Allocate weights and initialize.
        args = get_args()
        if args.use_cpu_initialization:
            self.weight = Parameter(torch.empty(
                self.num_embeddings_per_partition, self.embedding_dim,
                dtype=args.params_dtype))
            _initialize_affine_weight_cpu( # 對權重進行分區
                self.weight, self.num_embeddings, self.embedding_dim,
                self.num_embeddings_per_partition, 0, init_method)
        else:
            self.weight = Parameter(torch.empty(
                self.num_embeddings_per_partition, self.embedding_dim,
                device=torch.cuda.current_device(), dtype=args.params_dtype))
            _initialize_affine_weight_gpu(self.weight, init_method, # 對權重進行分區
                                          partition_dim=0, stride=1)

因為每一個GPU只是獲得了總體嵌入的一部分,所以對於每個worker來說,可能有一個輸入找不到嵌入,因此需要對embedding最終輸出做一個 all-reduce操作,這樣可以得到完整embedding。

def forward(self, input_):
        if self.tensor_model_parallel_size > 1:
            # Build the mask.
        		# input_mask 意思是單詞不在本worker的 embedding 分區範圍內,所以設置為0
            input_mask = (input_ < self.vocab_start_index) | \
                         (input_ >= self.vocab_end_index)
            # Mask the input.
            masked_input = input_.clone() - self.vocab_start_index
            masked_input[input_mask] = 0
        else:
            masked_input = input_
            # Get the embeddings.
        output_parallel = F.embedding(masked_input, self.weight,
                                      self.padding_idx, self.max_norm,
                                      self.norm_type, self.scale_grad_by_freq,
                                      self.sparse)
        # Mask the output embedding.
        if self.tensor_model_parallel_size > 1:
            output_parallel[input_mask, :] = 0.0
        # Reduce across all the model parallel GPUs.
        output = reduce_from_tensor_model_parallel_region(output_parallel)
        return output

0x06 總結

6.1 MLP並行

我們總結一下MLP的並行實現,具體如下圖,其中邏輯如下:

  • 中間灰色的是論文中的概念圖。
  • 聯繫程式碼之後,我們可以知道,其是由一個 ColumnParallelLinear 接上一個 RowParallelLinear 完成的,我們把概念圖轉化為圖左側兩個方框。
  • ColumnParallelLinear 是對權重進行列切分,RowParallelLinear 是對權重進行行切分。
  • 其中 ColumnParallelLinear 的 \(Y_1, Y_2\) 沒有經過 all-gather 操作(就是略過了 g 操作),而是直接輸入到了 RowParallelLinear 之中,接到了RowParallelLinear 的 \(X_1, X_2\),即,RowParallelLinear 沒有 f 操作。
  • 概念圖之中的 f 就是ColumnParallelLinear 的 f,g 就是 RowParallelLinear 的 g。具體邏輯如圖上所示。

6.2 共軛函數

論文之中提到了共軛函數。

f and g are conjugate. f is an identity operator in the forward pass and all reduce in the backward pass while g is an all reduce in the forward pass and identity in the backward pass.

我們前面程式碼之中也有使用到,我們整理出來如下,其中兩兩互為共軛函數。

  • copy_to_tensor_model_parallel_region 是前向操作copy(identity),後向操作 all-reduce。
  • reduce_from_tensor_model_parallel_region 是前向操作 all-reduce,後向操作 copy(identity)。

其實,就是MLP之中的 f,g 操作,這兩個是共軛函數。

類似,gather_from_tensor_model_parallel_region 是前向操作 all-gather,後向操作 scatter,這和scatter_to_tensor_model_parallel_region 也是共軛函數。

這些函數程式碼具體如下:

def copy_to_tensor_model_parallel_region(input_):
    return _CopyToModelParallelRegion.apply(input_)


def reduce_from_tensor_model_parallel_region(input_):
    return _ReduceFromModelParallelRegion.apply(input_)


def scatter_to_tensor_model_parallel_region(input_):
    return _ScatterToModelParallelRegion.apply(input_)


def gather_from_tensor_model_parallel_region(input_):
    return _GatherFromModelParallelRegion.apply(input_)

至此,我們已經完成了對模型並行實現的分析,下一篇我們看看在源碼之中如何設定各種並行配置。

0xFF 參考

//developer.nvidia.com/gtc/2020/slides/s21496-megatron-lm-training-multi-billion-parameter-language-models-using-model-parallelism.pdf

[細讀經典]Megatron論文和程式碼詳細分析(2)

[細讀經典]Megatron論文和程式碼詳細分析(1)

Megatron-LM源碼閱讀(一)

Megatron-LM源碼閱讀(二)

megatron學習總結

GTC 2020: Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism

大規模訓練之 transformer 中的張量模型並行