­

[源碼解析] PyTorch 分佈式(17) — 結合DDP和分佈式 RPC 框架

[源碼解析] PyTorch 分佈式(17) — 結合DDP和分佈式 RPC 框架

0x00 摘要

在前面的文章之中,我們已經學習了PyTorch 分佈式的基本模塊,接下來我們通過幾篇文章來看看如何把這些模塊應用到實踐之中,順便把PyTorch分佈式邏輯整體梳理一下。本文介紹如何把DDP和RPC framework結合起來。

本文以 COMBINING DISTRIBUTED DATAPARALLEL WITH DISTRIBUTED RPC FRAMEWORK 的翻譯為基礎,加入了自己的理解。

PyTorch分佈式其他文章如下:

深度學習利器之自動微分(1)

深度學習利器之自動微分(2)

[源碼解析]深度學習利器之自動微分(3) — 示例解讀

[源碼解析]PyTorch如何實現前向傳播(1) — 基礎類(上)

[源碼解析]PyTorch如何實現前向傳播(2) — 基礎類(下)

[源碼解析] PyTorch如何實現前向傳播(3) — 具體實現

[源碼解析] Pytorch 如何實現後向傳播 (1)—- 調用引擎

[源碼解析] Pytorch 如何實現後向傳播 (2)—- 引擎靜態結構

[源碼解析] Pytorch 如何實現後向傳播 (3)—- 引擎動態邏輯

[源碼解析] PyTorch 如何實現後向傳播 (4)—- 具體算法

[源碼解析] PyTorch 分佈式(1)——歷史和概述

[源碼解析] PyTorch 分佈式(2) —– DataParallel(上)

[源碼解析] PyTorch 分佈式(3) —– DataParallel(下)

[源碼解析] PyTorch 分佈式(4)——分佈式應用基礎概念

[源碼解析] PyTorch分佈式(5) —— DistributedDataParallel 總述&如何使用

[源碼解析] PyTorch分佈式(6) —DistributedDataParallel — 初始化&store

[源碼解析] PyTorch 分佈式(7) —– DistributedDataParallel 之進程組

[源碼解析] PyTorch 分佈式(8) ——– DistributedDataParallel之論文篇

[源碼解析] PyTorch 分佈式(9) —– DistributedDataParallel 之初始化

[源碼解析] PyTorch 分佈式(10)——DistributedDataParallel 之 Reducer靜態架構

[源碼解析] PyTorch 分佈式(11) —– DistributedDataParallel 之 構建Reducer和Join操作

[源碼解析] PyTorch 分佈式(12) —– DistributedDataParallel 之 前向傳播

[源碼解析] PyTorch 分佈式(13) —– DistributedDataParallel 之 反向傳播

[源碼解析] PyTorch 分佈式 Autograd (1) —- 設計

[源碼解析] PyTorch 分佈式 Autograd (2) —- RPC基礎

[源碼解析] PyTorch 分佈式 Autograd (3) —- 上下文相關

[源碼解析] PyTorch 分佈式 Autograd (4) —- 如何切入引擎

[源碼解析] PyTorch 分佈式 Autograd (5) —- 引擎(上)

[源碼解析] PyTorch 分佈式 Autograd (6) —- 引擎(下)

[源碼解析] PyTorch分佈式優化器(1)—-基石篇

[源碼解析] PyTorch分佈式優化器(2)—-數據並行優化器

[源碼解析] PyTorch分佈式優化器(3)—- 模型並行

[源碼解析] PyTorch 分佈式(14) –使用 Distributed Autograd 和 Distributed Optimizer

[源碼解析] PyTorch 分佈式(15) — 使用分佈式 RPC 框架實現參數服務器

[源碼解析] PyTorch 分佈式(16) — 使用異步執行實現批處理 RPC

註:本文沒有完全按照原文順序進行翻譯,而是按照自己理解的思路重新組織了文章。

0x00 綜述

本教程使用一個簡單的示例來演示如何將 DistributedDataParallel (DDP) 與分佈式 RPC 框架 相結合,將分佈式數據並行性與分佈式模型並行性相結合,以訓練一個簡單的模型。該示例的源代碼可以在這裡找到。

前面的教程 入門分佈式數據並行入門分佈式RPC框架 分別描述了如何執行分佈式數據並行和分佈式模型平行訓練。儘管如此,您可能希望在多種訓練範式中結合這兩種技術。例如:

  1. 如果我們有一個包含稀疏部分(大型嵌入表)和密集部分(FC 層)的模型,我們可能希望將嵌入表放在參數服務器上,並使用DistributedDataParallel在多個trainer之間複製 FC 層。分佈式RPC框架 就可被用於在參數服務器上執行嵌入查找。
  2. PipeDream論文中所述啟用混合併行性。我們可以使用分佈式 RPC 框架 將模型的各個階段跨多個worker 進行流水線化,並使用DistributedDataParallel 對每個階段進行數據並行(如果需要)。

在本教程中,我們將介紹上述案例 1。我們的設置中共有 4 個 worker,如下所示:

  • 1 個Master,負責在參數服務器上創建嵌入表(nn.EmbeddingBag)。master 還負責驅動兩個trainer上的訓練循環。
  • 1 個Parameter Server,它將嵌入表保存在內存中,並響應來自 Master 和 Trainer 的 RPC 請求。
  • 2 個trainer,它存儲一個 FC 層 (nn.Linear),其使用DistributedDataParallel 進行數據並行。trainer還負責執行前向傳播、後向傳播和優化器步驟。

整個訓練過程執行如下:

  1. Master 創建一個RemoteModule ,在參數服務器上保存一個嵌入表。
  2. Master 在trainer上啟動訓練循環,並將遠程模塊(remote module)傳播給trainer。
  3. Trainer 創建一個HybridModel,其首先使用 master 提供的遠程模塊執行嵌入查找(embedding lookup),然後執行封裝在 DDP 中的 FC 層。
  4. Trainer 執行模型的前向傳播,並使用Distributed Autograd 對損失執行後向傳播。
  5. 作為反向傳播的一部分,首先計算 FC 層的梯度,並通過 DDP 中的 allreduce 同步到所有trainer。
  6. 接下來,分佈式 Autograd 將梯度傳播到參數服務器,在那裡更新嵌入表的梯度。
  7. 最後,分佈式優化器被用於更新所有參數。

注意:如果您將 DDP 和 RPC 結合使用,則應始終使用Distributed Autograd進行反向傳播。

0x01 啟動

我們看看系統如何啟動。首先,在進行訓練之前,需要設置所有worker。我們創建了 4 個進程,其中 rank 0 和 rank 1 是我們的trainer,rank 2是master,rank 3是參數服務器。

初始化邏輯如下:

  • 我們使用 TCP init_method 在所有 4 個 worker 上初始化 RPC 框架。
  • 對於 Master,代碼做了如下操作:
    • 完成 RPC 初始化後,master 創建一個遠程模塊RemoteModule,該模塊指向一個在參數服務器上保存的EmbeddingBag層。
    • 然後 master 遍歷每個trainer,並通過使用rpc_async調用_run_trainer 在每個trainer之上啟動訓練循環。
    • 最後,master 在退出之前等待所有訓練完成。
  • Trainers做了如下操作:
    • Trainers 首先使用 init_process_group為DDP初始化一個world_size = 2(對於兩個trainer)的ProcessGroup
    • 接下來,Trainers 使用 TCP init_method 初始化 RPC 框架。注意RPC初始化和ProcessGroup初始化的端口是不同的。這是為了避免兩個框架的初始化之間的端口衝突。
    • 初始化完成後,trainer只需等待來自 master的_run_trainer RPC。
  • 參數服務器只是初始化 RPC 框架並等待來自trainer和master的 RPC。

具體代碼如下:

def run_worker(rank, world_size):
    r"""
    A wrapper function that initializes RPC, calls the function, and shuts down
    RPC.
    """

    # We need to use different port numbers in TCP init_method for init_rpc and
    # init_process_group to avoid port conflicts.
    rpc_backend_options = TensorPipeRpcBackendOptions()
    rpc_backend_options.init_method = "tcp://localhost:29501"

    # Rank 2 is master, 3 is ps and 0 and 1 are trainers.
    if rank == 2: # Master代碼
        rpc.init_rpc(
            "master",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )

        remote_emb_module = RemoteModule( # 指向一個在參數服務器上保存的EmbeddingBag層
            "ps",
            torch.nn.EmbeddingBag,
            args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
            kwargs={"mode": "sum"},
        )

        # Run the training loop on trainers.
        futs = []
        for trainer_rank in [0, 1]:
            trainer_name = "trainer{}".format(trainer_rank)
            fut = rpc.rpc_async( # 啟動 trainer循環
                trainer_name, _run_trainer, args=(remote_emb_module, trainer_rank)
            )
            futs.append(fut)

        # Wait for all training to finish.
        for fut in futs:
            fut.wait()
    elif rank <= 1:
        # Initialize process group for Distributed DataParallel on trainers.
        dist.init_process_group(
            backend="gloo", rank=rank, world_size=2, init_method="tcp://localhost:29500"
        )

        # Initialize RPC.
        trainer_name = "trainer{}".format(rank)
        rpc.init_rpc(
            trainer_name,
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )

        # 只需等待來自 master的 _run_trainer RPC
        # Trainer just waits for RPCs from master.
    else:
        rpc.init_rpc( # 參數服務器
            "ps",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )
        # parameter server do nothing
        pass # 啥也不幹,只是等待來自trainer和master的 RPC

    # block until all rpcs finish
    rpc.shutdown()


if __name__ == "__main__":
    # 2 trainers, 1 parameter server, 1 master.
    world_size = 4
    mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)

目前邏輯如下,我們後續會繼續拓展:

                               torch.multiprocessing.spawn
                                          +
                                          |
                                          |
              +----------------------------------------------------------------+----------------------------------+
              |                           |                                    |                                  |
              |                           |                                    |                                  |
              v                           v                                    v                                  v
+-------------+-------------+  +----------+---------------+ +------------------+------------------+ +-------------+--------+
|trainer 0         rank = 0 |  |trainer 1     rank = 1    | | master                     rank = 2 | |ps          rank = 3  |
|                           |  |                          | |                                     | |                      |
|                           |  |                          | |   rpc.init_rpc                      | |     rpc.init_rpc     |
|                           |  |                          | |                                     | |                      |
|   dist.init_process_group |  |  dist.init_process_group | |   remote_emb_module =  RemoteModule | |                      |
|                           |  |                          | |                                     | |                      |
|                           |  |                          | |                                     | |                      |
|   rpc.init_rpc            |  |  rpc.init_rpc            | |   fut = rpc.rpc_async(_run_trainer) | |                      |
|                           |  |                          | |                                     | |                      |
|                           |  |                          | |                                     | |                      |
+---------------------------+  +--------------------------+ +-------------------------------------+ +----------------------+

手機如下:

0x03 支撐系統

支撐系統主要指的就是 _RemoteModule,其作用是在異地建立一個模型,具體代碼在:torch/distributed/nn/api/remote_module.py。

3.1 功能

RemoteModule實例只能在RPC初始化之後創建,它可以在指定的遠程節點上創建用戶指定的模塊,其行為類似於常規的nn.Module方法,但不同之處是 RemoteModule 在遠程節點上執行forward方法。RemoteModule 負責autograd recording,以確保向後傳播可以將梯度傳播回相應的遠程模塊。

RemoteModule 可以使用RPC framework <//pytorch.org/docs/stable/rpc.html> 在處理器之間共享,且不會產生複製實際模塊的任何開銷,這相當於使用一個~torch.distributed.rpc.RRef指向遠程模塊。

3.2 使用

3.2.1 混合模型

要創建混合模型,通常應該在遠程模塊之外創建本地模塊,而不是作為任何遠程模塊的子模塊。如果遠程模塊放置在cuda設備上,那麼任何輸入CPU張量將自動移動到同一cuda設備之上。混合模型例子如下:

            >>> class HybridModel(nn.Module):
            >>>     def __init__(self):
            >>>         nn.Module.__init__(self)
            >>>         self.remote_embedding = RemoteModule(...) # 在遠端創建嵌入層
            >>>         self.local_linear = nn.Linear(...)

3.2.2 使用

使用例子如下,需要在兩個不同進程上運行如下代碼,例子之中,RemoteModule 創建時候,傳入了一個”worker1/cpu”參數,意思是在 worker1 的 cpu 設備上運行這個RemoteModule。具體格式是: <workername> / <device>,其中 <device> 是torch.device類型。

    Example::
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> from torch import nn, Tensor
        >>> from torch.distributed.nn.api.remote_module import RemoteModule
        >>>
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> remote_linear_module = RemoteModule(
        >>>     "worker1/cpu", nn.Linear, args=(20, 30),
        >>> )
        >>> input = torch.randn(128, 20)
        >>> ret_fut = remote_linear_module.forward_async(input)
        >>> ret = ret_fut.wait()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>>
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()

3.3 定義

_RemoteModule定義如下,具體初始化邏輯是:

  • (1). 準備參數。
  • (2). 設置運行的遠端worker和遠端設備。
  • (3). 如果設置了_module_interface_cls
    • (3.1) 使用 _module_interface_cls 來在遠端構建模塊。_
    • (3.2) 在本地構建函數代理生成器。
    • (3.3) 等待創建完成。
    • (3.4) 在本地構建句柄。
  • (4) 沒有設置_module_interface_cls。
    • (4.1) 在本地構建函數代理生成器。
    • (4.2) 在遠端創建模塊。
  • (5). 在本地創建遠端函數代理。
class _RemoteModule(nn.Module):
    def __init__(
        self,
        remote_device: str,
        module_cls: nn.Module,
        args: Tuple = None,
        kwargs: Dict[str, Any] = None,
        _module_interface_cls: Any = None,
    ):
        """
        Args:
            remote_device (str): Device on the destination worker where we'd like to place this module.
                The format should be "<workername>/<device>", where the device field can be parsed as torch.device type.
                E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0".
                In addition, the device field can be optional and the default value is "cpu".

        Returns:
            A remote module instance which wraps the :class:`~nn.Module` created by the
            user-provided ``module_cls``, it has a blocking ``forward`` method and an
            asynchronous ``forward_async`` method that returns a future of the ``forward`` call
            on the user-provided module on the remote side.
        """
        super().__init__()

        # NOTE: if a new attribute is added to this class, also need to add it
        # to ``_REMOTE_MODULE_PICKLED_ATTRIBUTES`` for pickling/unpickling.

        # Default arguments preperation.
        # 1. 準備參數
        args = args if args is not None else ()
        kwargs = kwargs if kwargs is not None else {}

        # 2. 設置運行的遠端worker和遠端設備
        self.on, self.device = _parse_remote_device(remote_device)
        agent = rpc._get_current_rpc_agent()
        # If the device map of the remote worker is set,
        # then enable moving any input CPU tensors to the same cuda device.
        self.is_device_map_set = bool(
            agent._get_device_map(agent.get_worker_info(self.on))
        )
        # ``enable_moving_cpu_tensors_to_cuda`` is less strict than ``is_device_map_set``:
        # If ``enable_moving_cpu_tensors_to_cuda`` is true, but the device map is not set,
        # then any CPU tensors can still be moved to a cuda device to run forward,
        # but the output must be moved back to CPU before being sent over the wire.
        enable_moving_cpu_tensors_to_cuda = torch.device(self.device).type == "cuda"

        # 3. 如果設置了_module_interface_cls
        if _module_interface_cls is not None:
            # Users reply on this field to know if this generated RemoteModule is TorchScript-able.
            self.is_scriptable = True

            # 3.1 使用 _module_interface_cls 來在遠端構建模塊
            # Instantiate template on remote side.
            fut = rpc.rpc_async(
                self.on,
                _instantiate_template,
                (_module_interface_cls, enable_moving_cpu_tensors_to_cuda),
            )

            # 3.2 在本地構建函數代理生成器
            # Instantiate template on local side.
            generated_module = (
                instantiator.instantiate_scriptable_remote_module_template(
                    _module_interface_cls, enable_moving_cpu_tensors_to_cuda
                )
            )
            self.generated_methods = generated_module._generated_methods

            # 3.3 等待創建完成
            # Create the module on the remote side.
            fut.wait()  # Ensure remote_module_cls is available on remote side.

            # 3.4 在本地構建句柄
            self.module_rref = rpc.rpc_sync(
                self.on,
                _create_module_with_interface,
                (module_cls, args, kwargs, self.device, _module_interface_cls),
            )
        else: # 4 沒有設置_module_interface_cls
            self.is_scriptable = False
            # 4.1 在本地構建函數代理生成器
            self.generated_methods = (
                _NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods
            )
            # 4.2 在遠端創建模塊
            # Create the module on the remote side.
            self.module_rref = rpc.remote(
                self.on,
                _create_module,
                (module_cls, args, kwargs, self.device),
            )

        # Install generated methods.
        # 5. 在本地創建遠端函數代理
        for method in self.generated_methods:
            method_name = method.__name__
            method = torch.jit.export(method)
            setattr(self, method_name, types.MethodType(method, self))

3.4 主要函數

其主要函數如下:

  • rpc.rpc_sync 返回指向遠程模塊參數的~torch.distributed.rpc.RRef列表。通常可以與~torch.distributed.optim.DistributedOptimizer結合使用。

  • get_module_rref 返回一個指向遠程模塊的~torch.distributed.rpc.RRef(RRef[nn.Module])類。

def remote_parameters(self, recurse: bool = True) -> List[rpc.RRef[Parameter]]:
    """
    Returns a list of :class:`~torch.distributed.rpc.RRef` pointing to the
    remote module's parameters. This can typically be used in conjuction
    with :class:`~torch.distributed.optim.DistributedOptimizer`.

    Args:
        recurse (bool): if True, then returns parameters of the remote
            module and all submodules of the remote module. Otherwise,
            returns only parameters that are direct members of the
            remote module.

    Returns:
        A list of :class:`~torch.distributed.rpc.RRef` (``List[RRef[nn.Parameter]]``)
        to remote module's parameters.
    """
    return rpc.rpc_sync(self.on, _param_rrefs, args=(self.module_rref, recurse))

def get_module_rref(self) -> rpc.RRef[nn.Module]:
    """
    Returns an :class:`~torch.distributed.rpc.RRef` (``RRef[nn.Module]``)
    pointing to the remote module.
    """
    return self.module_rref

於是邏輯圖轉換如下,在上圖基礎之上多了一個remote_emb_module,其在ps之上創建了一個RemoteModule

                                torch.multiprocessing.spawn
                                           +
                                           |
                                           |
               +----------------------------------------------------------------+----------------------------------+
               |                           |                                    |                                  |
               |                           |                                    |                                  |
               v                           v                                    v                                  v
+--------------+-------------+ +-----------+--------------+ +-------------------+-----------------+  +-------------+--------+
|trainer 0          rank = 0 | |trainer 1     rank = 1    | | master                     rank = 2 |  |ps          rank = 3  |
|                            | |                          | |                                     |  |                      |
|                            | |                          | |     rpc.init_rpc                    |  |     rpc.init_rpc     |
|                            | |                          | |                                     |  |                      |
|    dist.init_process_group | |  dist.init_process_group | |   remote_emb_module +----------------------> RemoteModule     |
|                            | |                          | |                                     |  |                      |
|                            | |                          | |                                     |  |                      |
|    rpc.init_rpc            | |  rpc.init_rpc            | |   fut = rpc.rpc_async(_run_trainer) |  |                      |
|                            | |                          | |                                     |  |                      |
|                            | |                          | |                                     |  |                      |
|                            | |                          | |                                     |  |                      |
+----------------------------+ +--------------------------+ +-------------------------------------+  +----------------------+

手機如下:

0x04 HybridModel

在討論 Trainer 的細節之前,讓我們先介紹一下 Trainer使用的HybridModel。該模型由稀疏部分和稠密部分組成。

  • 稠密部分是一個nn.Linear,使用DistributedDataParallel在所有trainer中複製,即 在 DDP 內包裝了一個 nn.Linear層。

  • 稀疏部分是一個遠程模塊 (remote_emb_module) ,它持有一個在參數服務器上的nn.EmbeddingBag。即,此遠程模塊可以獲取參數服務器上嵌入表的遠程引用。

該模型的前向方法非常簡單。它使用 RemoteModule 在參數服務器上執行嵌入查找forward ,並將其輸出傳播到 FC 層,這裡的 FC 使用了DDP

class HybridModel(torch.nn.Module):
    r"""
    The model consists of a sparse part and a dense part.
    1) The dense part is an nn.Linear module that is replicated across all trainers using DistributedDataParallel.
    2) The sparse part is a Remote Module that holds an nn.EmbeddingBag on the parameter server.
    This remote model can get a Remote Reference to the embedding table on the parameter server.
    """

    def __init__(self, remote_emb_module, device):
        super(HybridModel, self).__init__()
        self.remote_emb_module = remote_emb_module
        self.fc = DDP(torch.nn.Linear(16, 8).cuda(device), device_ids=[device])
        self.device = device

    def forward(self, indices, offsets):
        emb_lookup = self.remote_emb_module.forward(indices, offsets)
        return self.fc(emb_lookup.cuda(self.device))

邏輯拓展如下,兩個trainer 之上也建立了remote_emb_module,指向了ps之上的RemoteModule

                                         torch.multiprocessing.spawn
                                                    +
                                                    |
                                                    |
            +-----------------------------------------------------------------------------------+----------------------------------+
            |                                       |                                           |                                  |
            |                                       |                                           |                                  |
            v                                       v                                           v                                  v
+-----------+-------------+ +-----------------------+-------------------+ +---------------------+---------------+    +-------------+--------+
|trainer 0       rank = 0 | | trainer 1                        rank = 1 | | master                     rank = 2 |    |ps          rank = 3  |
|                         | |                                           | |                                     |    |                      |
|                         | |                                           | |   rpc.init_rpc                      |    |     rpc.init_rpc     |
| dist.init_process_group | | dist.init_process_group                   | |                                     |    |                      |
|                         | |                                           | |   remote_emb_module +------------------------> RemoteModule     |
| rpc.init_rpc            | | rpc.init_rpc                              | |                                     |    |         ^     ^      |
|                         | |                                           | |                                     |    |         |     |      |
|                         | |                                           | |   fut = rpc.rpc_async(_run_trainer) |    |         |     |      |
|                         | |                                           | |                                     |    |         |     |      |
| +---------------------+ | |            +---------------------------+  | |                                     |    |         |     |      |
| | HybridModel         | | |            |HybridModel                |  | |                                     |    |         |     |      |
| |                     | | |            |                           |  | +-------------------------------------+    +----------------------+
| |                     | | |            |                           |  |                                                      |     |
| |   fc = DDP(Linear)  | | |            |      fc = DDP(Linear())   |  |                                                      |     |
| |                     | | |            |                           |  |                                                      |     |
| |   remote_emb_module | | |            |      remote_emb_module+-------------------------------------------------------------+     |
| |             +       | | |            |                           |  |                                                            |
| +---------------------+ | |            +---------------------------+  |                                                            |
|               |         | |                                           |                                                            |
+-------------------------+ +-------------------------------------------+                                                            |
                |                                                                                                                    |
                +--------------------------------------------------------------------------------------------------------------------+

手機如下:

0x05 訓練

5.1 初始化

之前初始化時候,我們漏過了trainer的初始化,這裡我們分析一下。

我們先看看 Trainer 上的設置。

  • 首先,trainer使用遠程模塊(remote module)和自己的rank 來創建上面提到的 HybridModel,遠程模塊持有參數服務器上的嵌入表。
  • 其次,我們需要得到一個RRef 列表,該列表指向我們想要使用DistributedOptimizer優化的所有參數。
    • 要從參數服務器嵌入表之中拿到這些參數,我們可以調用 RemoteModule 的remote_parameters,它會遍歷嵌入表的所有參數並返回一個 RRef 列表。trainer通過 RPC 在參數服務器上調用此方法來得到所需參數的 RRef 列表。
    • 由於 DistributedOptimizer 始終持有一個需要優化參數的 RRef 列表,因此我們需要為 FC 層的局部參數創建 RRef。這是通過遍歷model.fc.parameters()來完成的,其將為每個參數創建一個 RRef 並將其附加到從remote_parameters()返回的列表中。
    • 請注意,我們不能使用model.parameters(),因為它會遞歸調用model.remote_emb_module.parameters(),而RemoteModule不支持這種操作。
  • 最後,我們使用所有 RRef 創建我們的 DistributedOptimizer 並定義一個 CrossEntropyLoss 函數。
def _run_trainer(remote_emb_module, rank):
    r"""
    Each trainer runs a forward pass which involves an embedding lookup on the
    parameter server and running nn.Linear locally. During the backward pass,
    DDP is responsible for aggregating the gradients for the dense part
    (nn.Linear) and distributed autograd ensures gradients updates are
    propagated to the parameter server.
    """

    # Setup the model.
    model = HybridModel(remote_emb_module, rank)

    # Retrieve all model parameters as rrefs for DistributedOptimizer.

    # Retrieve parameters for embedding table.
    model_parameter_rrefs = model.remote_emb_module.remote_parameters()

    # model.fc.parameters() only includes local parameters.
    # NOTE: Cannot call model.parameters() here,
    # because this will call remote_emb_module.parameters(),
    # which supports remote_parameters() but not parameters().
    for param in model.fc.parameters(): 
        model_parameter_rrefs.append(RRef(param)) # 這裡添加了需要分佈式優化的 DDP 的參數

    # Setup distributed optimizer
    opt = DistributedOptimizer(
        optim.SGD,
        model_parameter_rrefs, # dense參數和sparse參數一起分佈式優化
        lr=0.05,
    )

    criterion = torch.nn.CrossEntropyLoss()

我們邏輯拓展如下,這裡省略了 trainer 0 指向 參數服務器的箭頭,與上圖相比,增加了 DistributedOptimizer。

                                            torch.multiprocessing.spawn
                                                       +
                                                       |
                                                       |
               +-----------------------------------------------------------------------------------+----------------------------------+
               |                                       |                                           |                                  |
               |                                       |                                           |                                  |
               v                                       v                                           v                                  v
+--------------+-------------+ +-----------------------+-------------------+ +---------------------+---------------+  +---------------+-------------+
|trainer 0          rank = 0 | | trainer 1                        rank = 1 | | master                     rank = 2 |  |  ps                rank = 3 |
|                            | |                                           | |                                     |  |                             |
|                            | |                                           | |                                     |  |      rpc.init_rpc           |
| dist.init_process_group    | | dist.init_process_group                   | |   rpc.init_rpc                      |  |                             |
|                            | |                                           | |                                     |  |    +----------------------+ |
| rpc.init_rpc               | | rpc.init_rpc                              | |                            1        |  |    | RemoteModule         | |
|                            | |                                           | |   remote_emb_module +---------------------> |                      | |
| +------------------------+ | | +---------------------------------------+ | |                                     |  |    |                      | |
| | _run_trainer           | | | | _run_trainer                          | | |                                     |  |    |  remote_parameters() | |
| |                        | | | |                                       | | |   fut = rpc.rpc_async(_run_trainer) |  |    |                      | |
| |                        | | | |   output = model(indices, offsets)    | | |                                     |  |    |                      | |
| |                        | | | |   dist_autograd.backward              | | |                                     |  |    +------+--------+------+ |
| |                        | | | |   opt.step                            | | |                                     |  |           ^        ^        |
| |                        | | | |                                       | | |                                     |  |           |        |        |
| | +-------------------+  | | | |                                       | | +-------------------------------------+  +-----------------------------+
| | | HybridModel       |  | | | |  +-----------------------------+      | |                                                      |        |
| | |                   |  | | | |  | HybridModel                 |      | |                                                      |        |
| | | fc = DDP(Linear)  |  | | | |  |                             |      | |                                                      |        |
| | | remote_emb_module |  | | | |  |  fc = DDP(Linear().cuda()   |      | |                                                      |        |
| | |                   |  | | | |  |  remote_emb_module+------------------------------------------------------------------------->        |
| | +-------------------+  | | | |  |                             |      | |                             2                                 |
| |                        | | | |  +-----------------------------+      | |                                                               |
| | +--------------------+ | | | |  +-----------------------------+      | |                                                               |
| | |DistributedOptimizer| | | | |  |DistributedOptimizer         |      | |                                                               |
| | +--------------------+ | | | |  |                             +------------------------------------------------------------------------>
| |                        | | | |  +-----------------------------+      | |                              3
| +------------------------+ | | +---------------------------------------+ |
+----------------------------+ +-------------------------------------------+


手機如下:

5.2 訓練循環

現在我們介紹在每個trainer上運行的主訓練循環。這裡 get_next_batch只是一個輔助函數,用於生成隨機輸入和訓練目標。我們為多個epoch和每個batch運行該訓練循環:

  1. 為Distributed Autograd.設置Distributed Autograd Context
  2. 運行模型的前向傳播並拿到其輸出。
  3. 使用損失函數根據我們的輸出和target來計算損失。
  4. 使用 Distributed Autograd 對損失執行分佈式反向傳播。
  5. 最後,運行分佈式優化器step 來優化所有參數。
    def get_next_batch(rank):
        for _ in range(10):
            num_indices = random.randint(20, 50)
            indices = torch.LongTensor(num_indices).random_(0, NUM_EMBEDDINGS)

            # Generate offsets.
            offsets = []
            start = 0
            batch_size = 0
            while start < num_indices:
                offsets.append(start)
                start += random.randint(1, 10)
                batch_size += 1

            offsets_tensor = torch.LongTensor(offsets)
            target = torch.LongTensor(batch_size).random_(8).cuda(rank)
            yield indices, offsets_tensor, target

    # Train for 100 epochs
    for epoch in range(100):
        # create distributed autograd context
        for indices, offsets, target in get_next_batch(rank):
            with dist_autograd.context() as context_id:
                output = model(indices, offsets)
                loss = criterion(output, target)

                # Run distributed backward pass
                dist_autograd.backward(context_id, [loss])

                # Tun distributed optimizer
                opt.step(context_id)

                # Not necessary to zero grads as each iteration creates a different
                # distributed autograd context which hosts different grads
        print("Training done for epoch {}".format(epoch))

因為篇幅所限,我們只是把上面的trainer再細化如下圖:

  1. 初始化時候,調用 dist.init_process_group 來初始化 DistributedDataParallel,調用 rpc.init_rpc 來初始化 RPC。
  2. HybridModel 之中,fc 是DistributedDataParallel方式,remote_emb_module 是參數服務器上的 RemoteModule。
  3. DistributedOptimizer 之中,對於 HybridModel 的 fc 和 remote_emb_module 都會進行分佈式優化。
  4. _run_trainer 之中,使用 model(indices, offsets) 進行前向傳播,其中會調用到 HybridModel.forward。
  5. HybridModel.forward 之中則對embedding 和 fc 進行操作。
    1. embedding 是利用RPC 和 參數服務器。
    2. fc 是利用 DistributedDataParallel。
    3. 將嵌入表放在參數服務器上,並使用DistributedDataParallel 在多個trainer之間複製 FC 層。

這些序號與下圖中數字對應。

+---------------------------------------------------------------------+
| trainer 1                                                 rank = 1  |
|                +-----------------------------------+                |
|                |    dist.init_process_group      1 |                |
|                |                                   |                |
|                |    rpc.init_rpc                   |                |
|                |                                   |                |
|                +-----------------------------------+                |
| +-----------------------------------------------------------------+ |
| | _run_trainer                                                    | |
| |                                                                 | |
| |     output = model(indices, offsets)                            | |
| |     dist_autograd.backward      +                               | |
| |     opt.step                    |                               | |
| |  +-----------------------------------------------------------+  | |
| |  | HybridModel                  |                          2 |  | |
| |  |                              |                            |  | |
| |  |    fc = DDP(Linear().cuda()  |                            |  | |
| |  |                              |4                           |  | |
| |  |    remote_emb_module         |                            |  | |
| |  |                              |                            |  | |
| |  |                              v                            |  | |
| |  |   +--------------------------+--------------------------+ |  | |
| |  |   |forward                                              | |  | |
| |  |   |  emb_lookup = remote_emb_module.forward()           | |  | |
| |  |   |                  +                                  | |  | |
| |  |   |                  |  5                               | |  | |
| |  |   |                  |                                  | |  | |
| |  |   |                  v                                  | |  | |
| |  |   |  fc(emb_lookup.cuda(device)                         | |  | |
| |  |   |                                                     | |  | |
| |  |   +-----------------------------------------------------+ |  | |
| |  +-----------------------------------------------------------+  | |
| |  +-----------------------------------------------------------+  | |
| |  | DistributedOptimizer                                    3 |  | |
| |  |                                                           |  | |
| |  |         HybridModel.remote_emb_module.remote_parameters() |  | |
| |  |                                                           |  | |
| |  |         HybridModel.fc.parameters()                       |  | |
| |  |                                                           |  | |
| |  +-----------------------------------------------------------+  | |
| +-----------------------------------------------------------------+ |
+---------------------------------------------------------------------+

手機如下:

注,可以在此處找到整個示例的源代碼。

0x06 比對

我們已經看了三篇PyTorch官方樣例,裏面對參數服務器的實現各有不同。對於本文來說,又加入了一個master作為協調者來統一各個worker。

總的來說,在PyTorch 之中,因為有了 RPC 機制,所以PyTorch 的參數服務器實現比 ps-lite, paracel 更佳靈活機動:

  • 首先參數服務器目前可以放在 GPU 之中。
  • 其次,可以在參數服務器只放置參數,也可以運行優化代碼,甚至可以在參數服務之上啟動控制trainer。
  • 具體優化器根據實際需要,可以是普通優化器,也可以是DistributedOptimizer。
  • 訓練代碼從用戶編寫角度看則完全是運行在本地。

0xFF 參考

COMBINING DISTRIBUTED DATAPARALLEL WITH DISTRIBUTED RPC FRAMEWORK