[源碼解析] PyTorch 分散式 Autograd (1) —- 設計
- 2021 年 11 月 29 日
- 筆記
- 001_機器學習, 006_深度學習, 011_分散式機器學習
[源碼解析] PyTorch 分散式 Autograd (1) —- 設計
0x00 摘要
本文以幾篇PyTorch官方文檔為基礎來了解分散式 autograd 的設計和內部結構,在翻譯時並沒有逐字翻譯,其中加入了自己的部分理解。分散式 autograd 後續文章的分析也會基於本文進行。
PyTorch分散式其他文章如下:
[源碼解析]PyTorch如何實現前向傳播(1) — 基礎類(上)
[源碼解析]PyTorch如何實現前向傳播(2) — 基礎類(下)
[源碼解析] PyTorch如何實現前向傳播(3) — 具體實現
[源碼解析] Pytorch 如何實現後向傳播 (1)—- 調用引擎
[源碼解析] Pytorch 如何實現後向傳播 (2)—- 引擎靜態結構
[源碼解析] Pytorch 如何實現後向傳播 (3)—- 引擎動態邏輯
[源碼解析] PyTorch 如何實現後向傳播 (4)—- 具體演算法
[源碼解析] PyTorch 分散式(2) —– DataParallel(上)
[源碼解析] PyTorch 分散式(3) —– DataParallel(下)
[源碼解析] PyTorch 分散式(4)——分散式應用基礎概念
[源碼解析] PyTorch分散式(5) —— DistributedDataParallel 總述&如何使用
[源碼解析] PyTorch分散式(6) —DistributedDataParallel — 初始化&store
[源碼解析] PyTorch 分散式(7) —– DistributedDataParallel 之進程組
[源碼解析] PyTorch 分散式(8) ——– DistributedDataParallel之論文篇
[源碼解析] PyTorch 分散式(9) —– DistributedDataParallel 之初始化
[源碼解析] PyTorch 分散式(10)——DistributedDataParallel 之 Reducer靜態架構
[源碼解析] PyTorch 分散式(11) —– DistributedDataParallel 之 構建Reducer和Join操作
[源碼解析] PyTorch 分散式(12) —– DistributedDataParallel 之 前向傳播
[源碼解析] PyTorch 分散式(13) —– DistributedDataParallel 之 反向傳播
0x01 分散式RPC框架
本文主要以 //pytorch.org/docs/master/rpc/distributed_autograd.html 為基準,但是原文檔要求用戶熟悉 Autograd 機制和分散式 RPC 框架,因為我們已經分析過 Autograd 機制,所以我們先研究一下 分散式 RPC 框架。
1.1 RPC 框架
RPC(Remote Procedure Call)是一種設計或者技術思想,而不是協議或者規範。
對於 RPC 最簡單的理解就是一個節點請求另外一個節點所提供的服務,但是對於用戶程式碼來說需要維護一個”本地調用”的感覺,即,對於遠程函數調用需要像調用本地的函數一樣,遠程服務或者程式碼看起來像運行在本地。
RPC 需要解決幾個問題:
- 如何通訊:即如何在調用者和服務提供者之間建立連接。
- 如何定址:即調用者如何找到服務提供者,怎麼知道其中有什麼服務。
- 如何發送參數:調用者發起遠程調用時候,方法的參數需要通過 TCP 等協議傳輸到伺服器,參數如何序列化?
- 如何接受參數:服務提供者收到參數之後如何反序列化,如何調用。
- 如何返回:服務提供者調用本地提供的服務之後,如何把返回值發送給調用者。
1.2 PyTorch RPC 四大支柱
以下翻譯自官方文檔 //pytorch.org/docs/master/rpc.html。
分散式 RPC 框架通過一組原語提供了多機模型訓練機制以允許遠程通訊,以及一個更高級別的 API 來自動區分拆分到多台機器上的模型。分散式 RPC 框架使遠程運行函數變得容易,支援引用遠程對象而無需複製真實數據,並提供 autograd 和優化器 API 以透明地向後運行和跨 RPC 邊界更新參數。這些功能可以分為四組 API。
- **遠程過程調用 (RPC) ** 支援使用給定的參數在指定的worker上運行函數並獲取返回值或創建對返回值的引用。有三個主要的 RPC API:
rpc_sync()
(同步)、rpc_async()
(非同步)和remote()
(非同步並返回對遠程返回值的引用)。如果用戶程式碼在沒有返回值的情況下無法繼續,請使用同步 API。否則,使用非同步 API 獲取 Future,並在調用者需要返回值時等待 Future。remote()
API 在需要遠程創建某些內容但從不需要將其獲取給調用者時很有用。想像一下driver進程設置參數伺服器和訓練器的情況。Driver 可以在參數伺服器上創建嵌入表,然後與訓練器共享嵌入表的引用,但其本身永遠不會在本地使用嵌入表。在這種情況下,rpc_sync()
和rpc_async()
已不再適用,因為他們總是意味著立即或在將來把返回值發給調用者。 - 遠程引用 (RRef)用作指向本地或遠程對象的分散式共享指針。它可以與其他 worker 共享,並且引用計數將被透明處理。每個 RRef 只有一個所有者,並且對象只存在於該所有者之中。持有 RRef 的非所有者worker 可以通過明確請求從所有者那裡獲取對象的副本。當 worker 需要訪問某個數據對象,但它本身既不是對象的創建者
remote()
函數的調用者也不是對象的所有者時,這很有用。分散式優化器就是此類用例的一個示例。 - Distributed Autograd將所有參與前向傳播 worker的本地 autograd 引擎縫合在一起,並在後向傳播期間自動聯繫他們以計算梯度。在進行前向傳遞如果需要跨越多台機器時,這尤其有用,例如分散式模型並行訓練、參數伺服器訓練等。 有了這個特性,用戶程式碼不再需要擔心如何跨 RPC 邊界發送梯度和應該以什麼順序啟動本地 autograd 引擎,如果前向傳遞中有嵌套和相互依賴的 RPC 調用,這可能會變得非常複雜。
- 分布優化器的構造需要一個
Optimizer()
(例如,SGD()
,Adagrad()
等)和一個RRefs的參數列表。即,在每個不同的Ref所有者之上創建一個Optimizer()
實例,然後運行step()
相應更新參數。當用戶進行分散式前向和後向傳播時,參數和梯度將分散在多個 worker 中,因此需要對每個相關 worker 進行優化。Distributed Optimizer 將所有這些本地優化器合而為一,並提供了簡潔的構造函數和step()
API。
1.3 RRef
下面我們以 //pytorch.org/docs/master/rpc/rref.html 為基準來學習遠程引用協議的基本概念和部分設計細節。
RRef 是遠程參考(Remote REFerence)的縮寫。 它是位於本地或遠程工作worker上對象的引用,並且透明地在內部進行引用計數。 從概念上講,它可以被視為一個分散式共享指針。 應用程式可以調用 remote()
創建 一個RRef。 每個 RRef 都被 remote()
的調用者(即所有者)所擁有,並且可以由多個用戶使用。 所有者存儲實際數據,並跟蹤全局參考計數。 每個 RRef 可以由全局RRefId
唯一標識,該全局RRefId
在創建時由 remote()
調用者分配。
在所有者worker中,只有一個OwnerRRef
實例包含真實數據,而在用戶worker之中,可以根據需要包含任意數量的UserRRefs
,UserRRef
不保存數據。當使用 RRP 時,所有者將使用全局唯一的RRefId來獲取唯一的OwnerRRef實例。 在 rpc_sync()
, rpc_async()
或 remote()
調用中,所有者創建一個UserRRef
,並將其用作參數或返回值。所有者將被通知並且相應更新參考計數。 如果全局沒有UserRRef
實例,並且所有者上也沒有對OwnerRRef
的引用,則OwnerRRef
及其數據將被刪除。
1.3.1 假設條件
RRef 協議的設計基於以下假設。
- 瞬態網路故障(Transient Network Failures):RRef 設計旨在通過重試消息來應對瞬態網路故障。 RRef不能處理節點崩潰或永久性網路分區,當這些事件發生時,應用程式應該關閉所有worker,還原到先前的checkpoint,然後恢復訓練。
- 非冪等 UDF (Non-idempotent UDFs):我們假設提供給
rpc_sync()
,rpc_async()
或remote()
的用戶函數(UDF)不是冪等的,因此無法重試。 但是,內部 RRef 控制消息是冪等且消息失敗時可重試。 - 消息傳遞無序(Out of Order Message Delivery):我們不會對一對節點之間的消息傳遞順序做假設,因為發送者和接收者都使用多個執行緒,所以無法保證首先處理哪個消息。
接下來我們只是大致講解如何使用,具體大家可以參閱 //pytorch.org/docs/master/rpc.html#distributed-rpc-framework。
1.3.2 同步調用
如下是同步調用API,該方法在 worker to
之上執行一個阻塞 RPC 調用來運行func
。RPC 消息的發送和接收與 Python 程式碼的執行並行。此方法是執行緒安全的。
torch.distributed.rpc.rpc_sync( to , func , args = None , kwargs = None , timeout = - 1.0 )
具體參數如下:
- to – 目標worker的name/rank/WorkerInfo。
- func (callable) – 一個可調用函數,例如 Python callables、內置運算符(例如add())和帶注釋的 TorchScript 函數。
- args –
func
調用的參數元組。 - kwargs –
func
調用關鍵字參數的字典。 - timeout – 用於此 RPC 的超時時間(以秒為單位)
返回值就是使用args
and kwargs
運行 func
的結果。
樣例:
確保 MASTER_ADDR
and MASTER_PORT
已經在兩個worker之上設置。
export MASTER_ADDR=localhost
export MASTER_PORT=5678
然後在兩個不同的進程中運行以下程式碼
>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
>>> rpc.shutdown()
>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
1.3.2 非同步調用
如下是非同步調用API,該方法在 worker to
之上執行一個非阻塞 RPC 調用來運行func
。RPC 消息的發送和接收與 Python 程式碼的執行並行。此方法是執行緒安全的。該方法立刻返回一個可以被等待的Future
。
torch.distributed.rpc.rpc_async(to, func, args=None, kwargs=None, timeout=- 1.0)
具體參數如下:
- to – 目標worker的name/rank/
WorkerInfo
。 - func (callable) – 一個可調用函數,例如 Python callables、內置運算符(例如add())和帶注釋的 TorchScript 函數。
- args –
func
調用的參數元組。 - kwargs – 是
func
調用關鍵字參數的字典。 - timeout – 用於此 RPC 的超時時間(以秒為單位)
返回一個可等待的Future
對象。完成後,可以從 對象中檢索出func
的返回值。
樣例:
確保 MASTER_ADDR
and MASTER_PORT
已經在兩個worker之上設置。
>>> export MASTER_ADDR=localhost
>>> export MASTER_PORT=5678
然後在兩個不同的進程中運行以下程式碼
>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
>>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
>>> result = fut1.wait() + fut2.wait()
>>> rpc.shutdown()
>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
0x02 示例
我們接下來以 //pytorch.org/docs/master/rpc/distributed_autograd.html 為基礎進行學習。
假設您有兩個節點和一個跨兩個節點分區的非常簡單的模型。這可以使用torch.distributed.rpc
如下實現。
分散式 autograd 背後的主要動機是在這種分散式模型上運行反向傳播loss
,我們已經計算並記錄了所有需要梯度的張量的梯度。
import torch
import torch.distributed.rpc as rpc
def my_add(t1, t2):
return torch.add(t1, t2)
# On worker 0:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))
# Perform some computation locally based on remote result.
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)
# Compute some loss.
loss = t5.sum()
0x03 前向傳播期間的 Autograd 記錄
PyTorch 在前向傳播期間構建 autograd 圖,該圖用於執行後向傳播。有關更多詳細資訊,請參閱 autograd 如何編碼歷史記錄。
對於分散式 autograd,我們需要在前向傳播期間跟蹤所有 RPC,以確保正確執行後向傳播。為此,當執行 RPC 時候,我們把 send
和recv
functions 附加到autograd圖之上。
- 該
send
函數附加到 RPC 的發起源節點之上,其輸出邊指向 RPC 輸入張量的 autograd 函數。在向後傳播期間,send
函數的輸入是從目標接收的,是對應recv
函數的輸出。 - 該
recv
函數附加到 RPC 的接受目標節點之上,其輸入從某些運算符得到,這些運算符使用輸入張量在RPC接受目標上執行。在後向傳播期間,recv
函數的輸出梯度將被發送到源節點之上,並且作為send
方法的輸入。 - 每
send-recv
對被分配一個全局唯一的autograd_message_id
以唯一地標識該send-recv
對。這對於在向後傳播期間查找遠程節點上的相應函數很有用。 - 對於RRef,每當我們調用
torch.distributed.rpc.RRef.to_here()
時,我們都為涉及的張量添加了一個適當的send-recv
對。
例如,這就是我們上面示例的 autograd 圖的樣子(為簡單起見,t5.sum() 被排除在外)。
我們可以看到,send方法在前向傳播中是發送者,但是在反向傳播之中就是接受者。
0x04 分散式 Autograd 上下文
每個使用分散式 autograd 的前向和後向傳播都被分配了一個唯一的torch.distributed.autograd.context
,並且這個上下文具有一個全局唯一的autograd_context_id
。如果有需要,在每個節點上都會創建上下文。
上下文的作用如下:
- 運行分散式反向傳播的多個節點可能會在同一個張量上累積梯度並且存儲在張量的
.grad
之上。在我們運行優化器之前,張量的.grad
可能累積了來自各種分散式反向傳播的梯度。這類似於把torch.autograd.backward()
在本地進行多次調用。為了提供一種把每個反向傳播梯度分離開的方法,在每個反向傳播過程里,梯度將被累積在torch.distributed.autograd.context
之中。 - 在前向傳播期間,我們在上下文中存儲每個 autograd 傳播的
send
和recv
函數。這確保我們在 autograd 圖中保存對適當節點的引用以使其保持活動狀態。除此之外,這也使得在向後傳播期間很容易查找到對應的send
和recv
函數。 - 一般來說,我們也使用這個上下文來存儲每個分散式 autograd 傳播的一些元數據。
從用戶的角度來看,autograd 上下文設置如下:
import torch.distributed.autograd as dist_autograd
with dist_autograd.context() as context_id:
loss = model.forward()
dist_autograd.backward(context_id, loss)
需要注意的是,模型的前向傳播必須在分散式autograd上下文管理器中調用,因為需要一個有效的上下文來確保:所有的send
和recv
方法被存儲起來,並且在所有參與節點之上執行後向傳播。
0x05 分散式反向傳播
在本節中,我們將概述在分散式反向傳播期間準確計算依賴關係所遇到的挑戰,並且也講述幾種如何執行分散式反向傳播的演算法(演算法內部有權衡)。
5.1 計算依賴關係
首先,考慮在單台機器上運行以下程式碼
import torch
a = torch.rand((3, 3), requires_grad=True)
b = torch.rand((3, 3), requires_grad=True)
c = torch.rand((3, 3), requires_grad=True)
d = a + b
e = b * c
d.sum.().backward()
下圖就是上面程式碼對應的 autograd 圖。
作為反向傳播的一部分,autograd 引擎執行的第一步是計算 autograd 圖中每個節點的依賴項數量。這有助於 autograd 引擎知道圖中的節點何時準備好了可以執行。括弧內為數字add(1)
和mul(0)
表示依賴關係的數量。如您所見,這意味著在向後傳播期間,add
節點需要 1 個輸入,mul
節點不需要任何輸入(換句話說,不需要執行)。本地 autograd 引擎通過從根節點(在本例中是d
)遍歷圖來計算這些依賴關係。
實際上,Autograd 圖中的某些節點可能不會在向後傳播中執行。這一事實對分散式 autograd 提出了挑戰。考慮這段使用 RPC 的程式碼。
import torch
import torch.distributed.rpc as rpc
a = torch.rand((3, 3), requires_grad=True)
b = torch.rand((3, 3), requires_grad=True)
c = torch.rand((3, 3), requires_grad=True)
d = rpc.rpc_sync("worker1", torch.add, args=(a, b))
e = rpc.rpc_sync("worker1", torch.mul, args=(b, c))
loss = d.sum()
上面程式碼的關聯 autograd 圖將是:
計算此分散式 autograd 圖的依賴項更具挑戰性,並且需要一些開銷(在計算或網路通訊方面)。
對於性能敏感的應用,我們可以通過假設每個send
和recv
函數都是反向傳播的有效成分來避免大量開銷(大多數應用不會執行未使用的 RPC)。這簡化了分散式 autograd 演算法並且效率更高,但代價是應用程式需要了解這些限制。這種演算法稱為FAST模式演算法,下面詳細介紹。
在一般情況下, 作為向後傳播的一部分,可能不需要每個send
和recv
函數都是有效的。為了解決這個問題,我們提出了一種SMART 模式演算法,此演算法將在後面的部分中描述。請注意,目前僅實現了FAST模式演算法。
5.2 FAST模式演算法
該演算法的關鍵假設是:當我們運行反向傳播時,每個send
函數的依賴為 1。換句話說,我們假設我們會從另一個節點通過 RPC 接收梯度。
演算法如下:
- 我們從具有反向傳播根的worker開始(所有根都必須是本地的)。
- 查找當前Distributed Autograd Context 的所有
send
函數 。 - 從提供的根和我們檢索到的所有
send
函數開始,我們在本地計算依賴項 。 - 計算依賴項後,使用提供的根來啟動本地 autograd 引擎。
- 當 autograd 引擎執行該
recv
函數時,該recv
函數通過 RPC 將輸入梯度發送到適當的worker。每個recv
函數都知道目標 worker id,因為它被記錄為前向傳播的一部分。通過autograd_context_id
和autograd_message_id
該recv
函數被發送到遠程主機。 - 當遠程主機收到這個請求時,我們使用
autograd_context_id
和autograd_message_id
來查找適當的send
函數。 - 如果這是worker第一次收到對給定
autograd_context_id
的請求,它將按照上面的第 1-3 點所述在本地計算依賴項。 - 然後將在第6點接受到的
send
方法插入隊列,以便在該worker的本地 autograd 引擎上執行。 - 最後,我們不是在 Tensor的
.grad
之上累積梯度,而是在每個Distributed Autograd Context之上分別累積梯度 。梯度存儲在Dict[Tensor, Tensor]
之中 ,Dict[Tensor, Tensor]
基本上是從 Tensor 到其關聯梯度的映射,並且可以使用 get_gradients() API檢索該映射 。
例如,分散式 autograd 的完整程式碼如下:
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
def my_add(t1, t2):
return torch.add(t1, t2)
# On worker 0:
# Setup the autograd context. Computations that take
# part in the distributed backward pass must be within
# the distributed autograd context manager.
with dist_autograd.context() as context_id:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))
# Perform some computation locally based on remote result.
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)
# Compute some loss.
loss = t5.sum()
# Run the backward pass.
dist_autograd.backward(context_id, [loss])
# Retrieve the gradients from the context.
dist_autograd.get_gradients(context_id)
具有依賴關係的分散式 autograd 圖如下(為簡單起見,t5.sum() 被排除在外):
應用於上述示例的FAST 模式演算法如下:
- 在
Worker 0
上,我們從根loss
和send1
開始計算依賴關係。 結果,send1
對Worker 0
的依賴數為 1,mul
對Worker 0
的依賴數為 1。 - 現在,我們在
Worker 0
上啟動本地 autograd 引擎。 我們首先執行mul
函數,將其輸出作為t4
的梯度,累積存儲在 autograd 上下文中。 然後,我們執行recv2
,它將這些梯度發送到Worker 1
。 - 由於這是
Worker 1
第一次知道有關此反向傳播的資訊,因此它將進行依賴關係計算,並且相應地標記send2
,add
和recv1
的依賴性。 - 接下來,在
Worker 1
的本地autograd
引擎上將send2
插入隊列,該引擎將依次執行add
和recv1
。 - 當執行
recv1
時,它將梯度發送到Worker 0
。 - 由於
Worker 0
已經計算了此向後傳播的依賴性,因此它僅僅在本地將send1
插入隊列並且執行。 - 最後,
t1
,t2
和t4
的梯度會累積在分散式 Autograd 上下文中。
5.3 SMART模式演算法
該演算法的全部細節仍在研究中,但對於總體思路,您可以參考RFC中的分散式 Autograd 演算法智慧模式部分 。
0x06 分散式優化器
該DistributedOptimizer
操作如下:
- 獲取要優化的遠程參數(
RRef
)列表。這些參數也可以是包含在本地RRef
的本地參數。 - 將一個
Optimizer
類作為本地優化器,該優化器將在所有不同的RRef
擁有者之上運行。 - 分散式優化器在每個工作節點上創建一個本地
Optimizer
實例,並且對於每一個Optimizer
保存一個RRef
。 - 當調用
torch.distributed.optim.DistributedOptimizer.step()
時,分散式優化器使用 RPC 在適當的遠程工作者上遠程執行所有本地優化器。必須為torch.distributed.optim.DistributedOptimizer.step()
提供一個分散式autogradcontext_id
。 本地優化器使用context_id
在相應上下文中存儲梯度。 - 如果多個並發分散式優化器正在更新一個 worker 上的同一批參數,這些更新將通過鎖來進行序列操作。
0x07 簡單的端到端示例
綜上所述,以下是一個使用分散式 autograd 和分散式優化器的簡單端到端示例。如果將程式碼放入名為「dist_autograd_simple.py」的文件中,則可以使用以下命令運行 :MASTER_ADDR="localhost" MASTER_PORT=29500 python dist_autograd_simple.py
import torch
import torch.multiprocessing as mp
import torch.distributed.autograd as dist_autograd
from torch.distributed import rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer
def random_tensor():
return torch.rand((3, 3), requires_grad=True)
def _run_process(rank, dst_rank, world_size):
name = "worker{}".format(rank)
dst_name = "worker{}".format(dst_rank)
# Initialize RPC.
rpc.init_rpc(
name=name,
rank=rank,
world_size=world_size
)
# Use a distributed autograd context.
with dist_autograd.context() as context_id:
# Forward pass (create references on remote nodes).
rref1 = rpc.remote(dst_name, random_tensor)
rref2 = rpc.remote(dst_name, random_tensor)
loss = rref1.to_here() + rref2.to_here()
# Backward pass (run distributed autograd).
dist_autograd.backward(context_id, [loss.sum()])
# Build DistributedOptimizer.
dist_optim = DistributedOptimizer(
optim.SGD,
[rref1, rref2],
lr=0.05,
)
# Run the distributed optimizer step.
dist_optim.step(context_id)
def run_process(rank, world_size):
dst_rank = (rank + 1) % world_size
_run_process(rank, dst_rank, world_size)
rpc.shutdown()
if __name__ == '__main__':
# Run world_size workers
world_size = 2
mp.spawn(run_process, args=(world_size,), nprocs=world_size)
0xFF 參考
//pytorch.org/docs/master/rpc/distributed_autograd.html#distributed-autograd-design
//pytorch.org/docs/master/rpc.html#distributed-autograd-framework
//pytorch.org/docs/master/rpc/rref.html
//pytorch.org/docs/master/rpc.html#distributed-rpc-framework