[源碼解析] PyTorch 分散式(16) — 使用非同步執行實現批處理 RPC

[源碼解析] PyTorch 分散式(16) — 使用非同步執行實現批處理 RPC

0x00 摘要

在前面的文章之中,我們已經學習了PyTorch 分散式的基本模組,接下來我們通過幾篇文章來看看如何把這些模組應用到實踐之中,順便把PyTorch分散式邏輯整體梳理一下。本文介紹如何使用非同步執行操作來實現批處理 RPC,大家可以學習到PyTorch對參數伺服器一個新的實現方式。

本文以IMPLEMENTING BATCH RPC PROCESSING USING ASYNCHRONOUS EXECUTIONS的翻譯為基礎,加入了自己的理解。

PyTorch分散式其他文章如下:

深度學習利器之自動微分(1)

深度學習利器之自動微分(2)

[源碼解析]深度學習利器之自動微分(3) — 示例解讀

[源碼解析]PyTorch如何實現前向傳播(1) — 基礎類(上)

[源碼解析]PyTorch如何實現前向傳播(2) — 基礎類(下)

[源碼解析] PyTorch如何實現前向傳播(3) — 具體實現

[源碼解析] Pytorch 如何實現後向傳播 (1)—- 調用引擎

[源碼解析] Pytorch 如何實現後向傳播 (2)—- 引擎靜態結構

[源碼解析] Pytorch 如何實現後向傳播 (3)—- 引擎動態邏輯

[源碼解析] PyTorch 如何實現後向傳播 (4)—- 具體演算法

[源碼解析] PyTorch 分散式(1)——歷史和概述

[源碼解析] PyTorch 分散式(2) —– DataParallel(上)

[源碼解析] PyTorch 分散式(3) —– DataParallel(下)

[源碼解析] PyTorch 分散式(4)——分散式應用基礎概念

[源碼解析] PyTorch分散式(5) —— DistributedDataParallel 總述&如何使用

[源碼解析] PyTorch分散式(6) —DistributedDataParallel — 初始化&store

[源碼解析] PyTorch 分散式(7) —– DistributedDataParallel 之進程組

[源碼解析] PyTorch 分散式(8) ——– DistributedDataParallel之論文篇

[源碼解析] PyTorch 分散式(9) —– DistributedDataParallel 之初始化

[源碼解析] PyTorch 分散式(10)——DistributedDataParallel 之 Reducer靜態架構

[源碼解析] PyTorch 分散式(11) —– DistributedDataParallel 之 構建Reducer和Join操作

[源碼解析] PyTorch 分散式(12) —– DistributedDataParallel 之 前向傳播

[源碼解析] PyTorch 分散式(13) —– DistributedDataParallel 之 反向傳播

[源碼解析] PyTorch 分散式 Autograd (1) —- 設計

[源碼解析] PyTorch 分散式 Autograd (2) —- RPC基礎

[源碼解析] PyTorch 分散式 Autograd (3) —- 上下文相關

[源碼解析] PyTorch 分散式 Autograd (4) —- 如何切入引擎

[源碼解析] PyTorch 分散式 Autograd (5) —- 引擎(上)

[源碼解析] PyTorch 分散式 Autograd (6) —- 引擎(下)

[源碼解析] PyTorch分散式優化器(1)—-基石篇

[源碼解析] PyTorch分散式優化器(2)—-數據並行優化器

[源碼解析] PyTorch分散式優化器(3)—- 模型並行

[源碼解析] PyTorch 分散式(14) –使用 Distributed Autograd 和 Distributed Optimizer

[源碼解析] PyTorch 分散式(15) — 使用分散式 RPC 框架實現參數伺服器

註:本文沒有完全按照原文順序進行翻譯,而是按照自己理解的思路重新組織了文章。

0x01 前言

1.1 先決條件

本文的先決條件如下:

本教程演示了如何使用@rpc.functions.async_execution 裝飾器構建批處理 RPC 應用程式,這有助於通過減少被阻塞的 RPC 執行緒的數量,並且在被調用方整合 CUDA 操作來加快訓練速度。這與使用 TorchServer 進行批量推理的想法相同。Batch RPC 有助於將動作整合到較少的 CUDA 操作中,從而攤銷開銷。

注意:本教程需要 PyTorch v1.6.0 或更高版本。

1.2 基礎知識

之前的教程已經展示了使用torch.distributed.rpc構建分散式訓練應用程式的步驟,但他們沒有詳細說明在處理 RPC 請求時被調用方會發生什麼。從 PyTorch v1.5 開始,針對每個 RPC 請求,被調用者都會啟動一個執行緒來執行該請求中的函數,該執行緒會阻塞直到該函數返回。這適用於許多用例,但有一個問題:如果用戶函數在 IO 上阻塞,例如使用嵌套的 RPC 調用或訊號(例如等待不同的 RPC 請求來解除阻塞),則被調用者上的 RPC 執行緒將不得不空閑等待,直到 IO 完成或訊號(signal)事件發生。因此,RPC 被調用者使用的執行緒可能會使用比實際需要更多。造成這個問題的原因是RPC把用戶函數當成黑盒,對函數中發生的事情知之甚少。為了讓用戶函數能夠讓出和釋放 RPC 執行緒,需要向 RPC 系統提供更多的提示。

從 v1.6.0 開始,PyTorch 通過引入兩個新概念來解決這個問題:

  • torch.futures.Future 封裝了一個非同步執行,同時也支援安裝回調函數。
  • @rpc.functions.async_execution 裝飾器,它允許應用程式告訴被調用者,本目標函數將返回一個future,並且可以在執行過程中多次暫停和yield。

使用這兩個工具,應用程式程式碼可以將用戶函數分解為多個較小的函數,將它們鏈接在一起作為Future 對象的回調方法,並返回包含最終結果的 Future給調用者。在被調用方,在獲取Future對象時,它也會安裝後續的 RPC 響應處理作為回調方法,這些回調會在最終結果準備好時被觸發。這樣,被調用者不再需要阻塞一個執行緒,只是等待最終返回值準備好就行。 簡單的例子請參考@rpc.functions.async_execution的API文檔 。

除了減少被調用者的空閑執行緒數量外,這些工具還使批處理 RPC 處理更容易、更快。本教程演示了如何使用@rpc.functions.async_execution 裝飾器構建分散式批量更新參數伺服器和批量處理強化學習應用程式 。

註:我們不考慮強化學習的領域,那樣會影響我們的思路,牽扯精力

1.3 程式碼

因為原文主要是強化學習程式碼講解,而我們只關注普通分散式批量更新參數伺服器,所以需要看原始程式碼。

程式碼位於 //github.com/pytorch/examples/blob/master/distributed/rpc/batch/parameter_server.py。先全部摘錄如下:

import os
import threading
from datetime import datetime

import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn
from torch import optim

import torchvision


batch_size = 20
image_w = 64
image_h = 64
num_classes = 30
batch_update_size = 5
num_batches = 6

def timed_log(text):
    print(f"{datetime.now().strftime('%H:%M:%S')} {text}")

class BatchUpdateParameterServer(object):

    def __init__(self, batch_update_size=batch_update_size):
        self.model = torchvision.models.resnet50(num_classes=num_classes)
        self.lock = threading.Lock()
        self.future_model = torch.futures.Future()
        self.batch_update_size = batch_update_size
        self.curr_update_size = 0
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
        for p in self.model.parameters():
            p.grad = torch.zeros_like(p)

    def get_model(self):
        return self.model

    @staticmethod
    @rpc.functions.async_execution
    def update_and_fetch_model(ps_rref, grads):
        self = ps_rref.local_value()
        timed_log(f"PS got {self.curr_update_size}/{batch_update_size} updates")
        for p, g in zip(self.model.parameters(), grads):
            p.grad += g
        with self.lock:
            self.curr_update_size += 1
            fut = self.future_model

            if self.curr_update_size >= self.batch_update_size:
                for p in self.model.parameters():
                    p.grad /= self.batch_update_size
                self.curr_update_size = 0
                self.optimizer.step()
                self.optimizer.zero_grad()
                fut.set_result(self.model)
                timed_log("PS updated model")
                self.future_model = torch.futures.Future()

        return fut


class Trainer(object):

    def __init__(self, ps_rref):
        self.ps_rref = ps_rref
        self.loss_fn = nn.MSELoss()
        self.one_hot_indices = torch.LongTensor(batch_size) \
                                    .random_(0, num_classes) \
                                    .view(batch_size, 1)

    def get_next_batch(self):
        for _ in range(num_batches):
            inputs = torch.randn(batch_size, 3, image_w, image_h)
            labels = torch.zeros(batch_size, num_classes) \
                        .scatter_(1, self.one_hot_indices, 1)
            yield inputs.cuda(), labels.cuda()

    def train(self):
        name = rpc.get_worker_info().name
        m = self.ps_rref.rpc_sync().get_model().cuda()
        for inputs, labels in self.get_next_batch():
            timed_log(f"{name} processing one batch")
            self.loss_fn(m(inputs), labels).backward()
            timed_log(f"{name} reporting grads")
            m = rpc.rpc_sync(
                self.ps_rref.owner(),
                BatchUpdateParameterServer.update_and_fetch_model,
                args=(self.ps_rref, [p.grad for p in m.cpu().parameters()]),
            ).cuda()
            timed_log(f"{name} got updated model")


def run_trainer(ps_rref):
    trainer = Trainer(ps_rref)
    trainer.train()


def run_ps(trainers):
    timed_log("Start training")
    ps_rref = rpc.RRef(BatchUpdateParameterServer())
    futs = []
    for trainer in trainers:
        futs.append(
            rpc.rpc_async(trainer, run_trainer, args=(ps_rref,))
        )

    torch.futures.wait_all(futs)
    timed_log("Finish training")


def run(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    options=rpc.TensorPipeRpcBackendOptions(
        num_worker_threads=16,
        rpc_timeout=0  # infinite timeout
     )
    if rank != 0:
        rpc.init_rpc(
            f"trainer{rank}",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=options
        )
        # trainer passively waiting for ps to kick off training iterations
    else:
        rpc.init_rpc(
            "ps",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=options
        )
        run_ps([f"trainer{r}" for r in range(1, world_size)])

    # block until all rpcs finish
    rpc.shutdown()


if __name__=="__main__":
    world_size = batch_update_size + 1
    mp.spawn(run, args=(world_size, ), nprocs=world_size, join=True)

0x02 啟動

我們首先看看如何啟動。

2.1 總體啟動

我們假設有一個master(rank 0),一個worker。Master 之上運行的是參數伺服器,worker 之上是訓練程式碼。

def run(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    options=rpc.TensorPipeRpcBackendOptions(
        num_worker_threads=16,
        rpc_timeout=0  # infinite timeout
     )
    if rank != 0:
        rpc.init_rpc( # 訓練程式碼
            f"trainer{rank}",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=options
        )
        # trainer passively waiting for ps to kick off training iterations
    else:
        rpc.init_rpc( # 參數伺服器
            "ps", 
            rank=rank,
            world_size=world_size,
            rpc_backend_options=options
        )
        run_ps([f"trainer{r}" for r in range(1, world_size)])

    # block until all rpcs finish
    rpc.shutdown()


if __name__=="__main__":
    world_size = batch_update_size + 1
    mp.spawn(run, args=(world_size, ), nprocs=world_size, join=True)

邏輯如下圖:

             torch.multiprocessing.spawn
                        +
                        |
                        |
           +------------+-------------------------------------------------
           |                                                             |
           |                                                             |
           v                                                             v
+----------+----------------------------------------------+ +------------+----------------+
| "ps"                                           rank = 0 | | f"trainer{rank}"   rank = 1 |
|                                                         | |                             |
|                                                         | |                             |
|                     rpc.init_rpc                        | |         rpc.init_rpc        |
|                                                         | |                             |
|                                                         | |                             |
|  run_ps([f"trainer{r}" for r in range(1, world_size)])  | |                             |
|                                                         | |                             |
|                                                         | |                             |
+---------------------------------------------------------+ +-----------------------------+

2.2 啟動參數伺服器

run_ps 啟動了參數伺服器和trainer。注意,這裡在參數伺服器之中啟動 trainer,即,master 不僅僅有一個參數伺服器,還負責通過 rpc 來驅動trainer上的訓練循環。

def run_ps(trainers):
    timed_log("Start training")
    ps_rref = rpc.RRef(BatchUpdateParameterServer())
    futs = []
    for trainer in trainers: # trainer 是字元串,比如"trainer1"
        futs.append(
            rpc.rpc_async(trainer, run_trainer, args=(ps_rref,)) # 運行run_trainer
        )

    torch.futures.wait_all(futs)
    timed_log("Finish training")
    
def run_trainer(ps_rref):
    trainer = Trainer(ps_rref)
    trainer.train() # 調用 Trainer 的方法   

具體拓展如下:

這裡沒有給出參數伺服器和trainer的邏輯,我們會在後續分析之後陸續給出。trainer 也只給出了一個。

0x03 參數伺服器

上面圖中沒有給出具體參數伺服器程式碼,我們接下來就分析一下。

這裡考慮具有一個參數伺服器 (PS) 和多個trainer的同步訓練應用程式。在這個應用中,PS 持有參數並等待所有訓練器報告梯度。在每次迭代中,它等待直到從所有訓練器接收梯度,然後一次性更新所有參數。

下面的程式碼顯示了 PS 類的實現。

  • PS初始化時候生成了常規SGB優化器,不是分散式優化器,而且優化器是在PS之上
  • update_and_fetch_model方法被 @rpc.functions.async_execution所裝飾,將由trainer調用。
  • 每次調用都會返回一個Future對象,該對象將被用來處理更新後的模型。
  • 大多數訓練器發起的調用只是累積梯度到 .grad成員變數 ,然後立即返回,並在 PS 上產生 RPC 執行緒。
  • 最後到達的訓練器將觸發優化器步驟並消耗所有先前上報的梯度。然後它使用更新後的模型來設置future_model,這是依靠通過Future對象來依次通知來自其他訓練者的先前請求,並將更新後的模型發送給所有訓練者。

具體程式碼如下:

batch_size = 20
image_w = 64
image_h = 64
num_classes = 30
batch_update_size = 5
num_batches = 6

def timed_log(text):
    print(f"{datetime.now().strftime('%H:%M:%S')} {text}")

class BatchUpdateParameterServer(object):

    def __init__(self, batch_update_size=batch_update_size):
        self.model = torchvision.models.resnet50(num_classes=num_classes)
        self.lock = threading.Lock()
        self.future_model = torch.futures.Future()
        self.batch_update_size = batch_update_size
        self.curr_update_size = 0
        # 重點:這裡是常規SGB優化器,不是分散式優化器
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
        for p in self.model.parameters():
            p.grad = torch.zeros_like(p)

    def get_model(self):
        return self.model

    @staticmethod
    @rpc.functions.async_execution # trainer會直接調用
    def update_and_fetch_model(ps_rref, grads):
        self = ps_rref.local_value()
        timed_log(f"PS got {self.curr_update_size}/{batch_update_size} updates")
        for p, g in zip(self.model.parameters(), grads): # 得到
            p.grad += g # 累積梯度
        with self.lock:
            self.curr_update_size += 1
            fut = self.future_model

            if self.curr_update_size >= self.batch_update_size:
                # 最後到達的訓練器將觸發優化器步驟並消耗所有先前上報的梯度。
                for p in self.model.parameters():
                    p.grad /= self.batch_update_size
                self.curr_update_size = 0
                self.optimizer.step() # 更新模型
                self.optimizer.zero_grad()
                fut.set_result(self.model) # 將更新後的模型發送給所有訓練者
                timed_log("PS updated model")
                self.future_model = torch.futures.Future() # 使用更新後的模型來設置future_model

        return fut # 該對象將被用來處理更新後的模型

邏輯拓展如下,這裡省略了參數伺服器生成trainer的步驟:

手機如下:

0x04 Trainer

對於訓練器,它們都使用來自 PS 的相同參數集進行初始化。在每次迭代中執行如下操作:

  • 每個訓練器首先運行前向和後向傳播以在本地生成梯度。
  • 然後,每個訓練器使用 RPC 向 PS 報告其梯度,並通過同一 RPC 請求的返回值取回更新後的參數。

在訓練器的實現中,目標函數是否被標記 @rpc.functions.async_execution是沒有區別的。訓練器只需使用 rpc_sync 調用update_and_fetch_model,其將阻塞訓練器,直到返回更新的模型。

可以看到,參數伺服器存儲模型,模型可以返回到trainer。

class Trainer(object):

    def __init__(self, ps_rref):
        self.ps_rref = ps_rref
        self.loss_fn = nn.MSELoss()
        self.one_hot_indices = torch.LongTensor(batch_size) \
                                    .random_(0, num_classes) \
                                    .view(batch_size, 1)

    def get_next_batch(self):
        for _ in range(num_batches):
            inputs = torch.randn(batch_size, 3, image_w, image_h)
            labels = torch.zeros(batch_size, num_classes) \
                        .scatter_(1, self.one_hot_indices, 1)
            yield inputs.cuda(), labels.cuda()

    def train(self):
        name = rpc.get_worker_info().name
        # 從參數伺服器獲取模型
        m = self.ps_rref.rpc_sync().get_model().cuda()
        for inputs, labels in self.get_next_batch():
            timed_log(f"{name} processing one batch")
            # 利用模型來前向傳播/反向傳播
            self.loss_fn(m(inputs), labels).backward()
            timed_log(f"{name} reporting grads")
            # 調用參數伺服器的函數來提交梯度
            m = rpc.rpc_sync( # rpc_sync 操作完成之後,m就是最新模型了
                self.ps_rref.owner(),
                BatchUpdateParameterServer.update_and_fetch_model,
                args=(self.ps_rref, [p.grad for p in m.cpu().parameters()]),
            ).cuda()
            timed_log(f"{name} got updated model")

拓展邏輯如下:

  1. 參數伺服器的run_trainer 方法會直接調用 trainer.train() 方法來執行一步step。
  2. train 方法之中,會調用 self.ps_rref.rpc_sync().get_model().cuda() 從參數伺服器獲得模型,放到本地設備之上(圖上是雙向箭頭,表示這是一個get/return動作,需要把模型存儲在worker本地)。
  3. 調用 self.loss_fn(m(inputs), labels).backward() 來進行前向傳播/反向傳播。
  4. 調用參數伺服器的 update_and_fetch_model 函數來提交梯度,這裡使用了非同步RPC
  5. 參數伺服器的 update_and_fetch_model 之中,進行梯度累積,模型更新是通過PS之上常規SGD優化器完成,最後調用 fut.set_result(self.model) 來發布新模型給trainer。在trainer 之中,就是 m = rpc.rpc_sync(…) 這個賦值之後,m 是最新模型了。

0x05 對比

前文結尾,我們對比參數伺服器的經典實現 ps-lite 和 前兩篇實現的參數伺服器。

  • ps-lite 是類似傳統伺服器實現,有自己主動的業務循環,可以響應用戶的顯式請求,也有自己明確的邏輯,本地也有自己的KV存儲。
  • PyTorch 前兩篇官方文檔(本系列前兩篇文章)之中,參數伺服器則是另外一種思路:
    • 參數伺服器上沒有主動的循環,沒有KV存儲,沒有伺服器邏輯,而是可以直接存儲業務模型,ps 會把業務模型需要優化的參數返回給trainer 之上的 DistributedOptimizer。
    • 業務驅動由trainer完成:train loop程式碼在trainer 之中,DistributedOptimizer 在trainer 之中,DistributedOptimizer 負責進行分散式優化。
  • 本文又與上面不同,看起來更像是ps-lite,但是又糅合了RPC實現:
    • ps進程會啟動trainer的訓練循環
    • 每個迭代之中,trainer 會從參數伺服器獲取最新模型,前向操作/後向傳播都在trainer 完成。
    • trainer 會通過非同步RPC把梯度提交給參數伺服器。
    • 模型更新是通過PS之上常規SGD優化器完成
    • 模型更新之後通過非同步RPC把模型再次分發給trainer。

不得不說,官方這幾篇文章快把各種實現方式玩出花來了,大家可以依據自己業務特點來參考實現。

0xFF 參考

IMPLEMENTING BATCH RPC PROCESSING USING ASYNCHRONOUS EXECUTIONS