[源碼解析] PyTorch分佈式(6) ——– DistributedDataParallel — 初始化&store
- 2021 年 11 月 18 日
- 筆記
- 001_機器學習, 006_深度學習, 011_分佈式機器學習
[源碼解析] PyTorch分佈式(6) —DistributedDataParallel — 初始化&store
0x00 摘要
本文是 PyTorch 分佈式系列的第六篇, 介紹 DistributedDataParallel 所依賴的初始化方法和Store這兩個概念。
本系列其他文章如下:
[源碼解析]PyTorch如何實現前向傳播(1) — 基礎類(上)
[源碼解析]PyTorch如何實現前向傳播(2) — 基礎類(下)
[源碼解析] PyTorch如何實現前向傳播(3) — 具體實現
[源碼解析] Pytorch 如何實現後向傳播 (1)—- 調用引擎
[源碼解析] Pytorch 如何實現後向傳播 (2)—- 引擎靜態結構
[源碼解析] Pytorch 如何實現後向傳播 (3)—- 引擎動態邏輯
[源碼解析] PyTorch 如何實現後向傳播 (4)—- 具體算法
[源碼解析] PyTorch 分佈式(2) —– DataParallel(上)
[源碼解析] PyTorch 分佈式(3) —– DataParallel(下)
[源碼解析] PyTorch 分佈式(4)——分佈式應用基礎概念
[源碼解析] PyTorch分佈式(5) —— DistributedDataParallel 總述&如何使用
0x01 回顧
1.1 基本概念
關於分佈式通信,PyTorch 提供的幾個概念是:進程組,後端,初始化,Store。
- 進程組 :DDP是真正的分佈式訓練,可以使用多台機器來組成一次並行運算的任務。為了能夠讓 DDP 的各個worker之間通信,PyTorch 設置了進程組這個概念。
- 後端 :後端這個概念是一個邏輯上的概念。本質上後端是一種IPC通信機制。
- 初始化 : 雖然有了後端和進程組的概念,但是如何讓 worker 在建立進程組之前發現彼此? 這就需要一種初始化方法來告訴大家傳遞一個信息:如何聯繫到其它機器上的進程。
- Store : 可以認為是分佈式鍵值存儲,利用這個存儲就可以在組中的進程之間共享信息以及初始化分佈式包 (通過顯式創建存儲來作為
init_method
的替代)。
1.2 初始化進程組
在調用任何 DDP 其他方法之前,需要使用torch.distributed.init_process_group()
進行初始化。該方法會初始化默認分佈式進程組和分佈式包。此方法會阻塞,直到所有進程都加入,函數定義如下:
init_process_group ( backend ,
init_method = None ,
timeout = default_pg_timeout ,
world_size =- 1 ,
rank =- 1 ,
store = None ,
group_name = '' ,
pg_options = None )
初始化進程組有兩種主要方法:
- 明確指定 store,rank 和 world_size。
- 指定 init_method(一個 URL 字符串),它指示在哪裡/如何發現對等點。
如果兩者都沒有指定,init_method
則假定為「env://」。因此大家可以看到,store 和 init_method 是互斥的。
init_process_group 的參數具體如下:
- 後端 – 要使用的後端。有效值包括
mpi
,gloo
,和nccl
。該字段應作為小寫字符串(例如"gloo"
)給出,也可以通過Backend
屬性(例如Backend.GLOO
)訪問 。如果在nccl
後端每台機器上使用多個進程,則每個進程必須對其使用的每個 GPU 具有獨佔訪問權限,因為在進程之間共享 GPU 可能會導致死鎖。 - init_method – 指定如何初始化進程組的 URL。如果未指定
init_method
或store
指定,則默認為「env://」 。與store
互斥。 - world_size – 參與作業的進程數。如果
store
指定,則 world_size 為必需。 - rank – 當前進程的等級(它應該是一個介於 0 和
world_size
-1之間的數字)。如果store
指定,則 rank 為必需。 - store – 所有 worker 都可以訪問的鍵/值存儲,用於交換連接/地址信息。與
init_method
互斥。 - timeout – 針對進程組執行的操作超時。默認值等於 30 分鐘。這適用於
gloo
後端。對於nccl
,這僅在環境變量NCCL_BLOCKING_WAIT
或NCCL_ASYNC_ERROR_HANDLING
設置為 1 時 適用。 - group_name – 組名。
- pg_options ( Process Group Options , optional ) – 進程組選項,指定在構建特定進程組期間需要傳入哪些附加選項。
0x02 初始化
2.1 初始化方法
目前DDP模塊支持三種初始化方式:
- Environment variable initialization
- Shared file-system initialization:init_method=‘file:///mnt/nfs/sharedfile’
- TCP initialization :init_method=‘tcp://10.1.1.20:23456’
環境變量
此方法將從環境變量中讀取配置,是允許完全自定義獲取信息的方式。通過在所有機器上設置以下四個環境變量,所有進程都可以正常連接到master(就是 rank 0 進程)以獲取其他進程的信息,並最終與它們握手。
MASTER_PORT
:rank 0 進程的機器上的端口。MASTER_ADDR
:rank 0 進程的機器上的 IP 地址。WORLD_SIZE
: 進程總數,因此master知道要等待多少worker。RANK
: 每個進程的rank,所以進程會知道自己是否是master。
共享文件系統
共享文件系統要求所有進程都可以訪問共享文件系統,並將通過共享文件協調它們。這意味着每個進程都將打開文件,寫入其信息,並等待每個進程都這樣做。之後,所有所需的信息都將可供所有流程使用。為了避免競爭條件,文件系統必須通過fcntl支持鎖定 。
dist.init_process_group(
init_method='file:///mnt/nfs/sharedfile',
rank=args.rank,
world_size=4)
TCP
TCP 初始化方式是通過提供rank 0進程的IP和端口來實現的,在這裡,所有worker都可以連接到等級為 0 的進程並交換有關如何相互聯繫的信息。
dist.init_process_group(
init_method='tcp://10.1.1.20:23456',
rank=args.rank,
world_size=4)
2.2 init_method VS store
我們很好奇,為什麼要有 init_method 和 store 這兩個參數?
通過看 init_process_group 代碼我們可以發現以下規律。
-
當 MPI 時候, init_method 沒有用處。
-
在非 MPI 後端時候,如果沒有 store 參數,則使用 init_method 構建一個store。
所以最終還是落到了 store 之上,store才是其作用的實體。
if store is None:
rendezvous_iterator = rendezvous(
init_method, rank, world_size, timeout=timeout
)
store, rank, world_size = next(rendezvous_iterator)
store.set_timeout(timeout)
init_process_group 代碼如下:
def init_process_group(backend,
init_method=None,
timeout=default_pg_timeout,
world_size=-1,
rank=-1,
store=None,
group_name='',
pg_options=None):
global _pg_group_ranks
global _backend
global _default_pg_init_method
if store is not None:
assert world_size > 0, 'world_size must be positive if using store'
assert rank >= 0, 'rank must be non-negative if using store'
elif init_method is None:
init_method = "env://"
backend = Backend(backend)
if backend == Backend.MPI:
default_pg = _new_process_group_helper(
-1,
-1,
[],
Backend.MPI,
None,
group_name=group_name,
timeout=timeout)
_update_default_pg(default_pg)
else:
# backward compatible API
if store is None:
# 如果沒有store,還是要用init_method構建一個store。
rendezvous_iterator = rendezvous(
init_method, rank, world_size, timeout=timeout
)
store, rank, world_size = next(rendezvous_iterator)
store.set_timeout(timeout)
default_pg = _new_process_group_helper(
world_size,
rank,
[],
backend,
store,
pg_options=pg_options,
group_name=group_name,
timeout=timeout)
_update_default_pg(default_pg)
_pg_group_ranks[GroupMember.WORLD] = {i: i for i in range(GroupMember.WORLD.size())} # type: ignore[attr-defined, index]
_backend = _pg_map[GroupMember.WORLD][0] # type: ignore[index]
_default_pg_init_method = init_method
# 省略
2.3 rendezvous
上面代碼之中提到了 rendezvous,我們就來看看這個概念。
在我們可以運行集合算法之前,參與的進程需要找到彼此並交換信息才能夠進行通信。我們稱這個過程為rendezvous。rendezvous過程的結果是一個三元組,其中包含一個共享鍵/值存儲(store),進程的等級(rank)和參與進程的總數。如果內置的rendezvous方法都不適用於您的執行環境,那麼您可以選擇註冊自己的rendezvous處理程序。在調用rendezvous
函數時,選擇一個唯一的名稱並使用URL方案來標識它。
rendezvous 方法就是依據參數,選擇不同的handler來處理。
def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):
# Append node-specific arguments.
result = urlparse(url)
if rank != -1 or world_size != -1:
query_dict: Dict[str, Union[int, str]] = dict(
# mypy doesn't allow dict() to accept List of values (#257)
pair.split("=") for pair in filter(None, result.query.split("&")) # type: ignore[arg-type, misc]
)
if rank != -1:
query_dict["rank"] = rank
if world_size != -1:
query_dict["world_size"] = world_size
result = result._replace(
query="{}".format("&".join(["{}={}".format(k, v) for k, v in query_dict.items()]))
)
url = urlunparse(result)
return _rendezvous_handlers[result.scheme](url, **kwargs)
handler 如下,你會發現,其實 handler 就是對應了初始化的三種方法:
register_rendezvous_handler("tcp", _tcp_rendezvous_handler)
register_rendezvous_handler("env", _env_rendezvous_handler)
register_rendezvous_handler("file", _file_rendezvous_handler)
2.4 小結
從目前分析結果來看,我們得到了如下結論:
- init_method 最終還是落到了 store 之上,store才是起作用的實體。
- 參與的進程需要找到彼此並交換信息才能夠進行通信。這個過程被稱為rendezvous。
0x03 Store
我們給出一個正式的概念。Store 是分佈式包(distributed package)所提供的分佈式鍵值存儲,所有的 workers 都會訪問這個存儲以共享信息以及初始化分佈式包 。用戶可以通過顯式創建存儲來作為init_method
的替代。目前有 3 種鍵值存儲:TCPStore
, FileStore
,和HashStore
。
我們接着上節繼續看 handler 概念。
3.1 _rendezvous_handlers
在 PyTorch 定義了一個全局變量 _rendezvous_handlers,用來保存如何返回 store 的方法,可以認為是工廠方法。
_rendezvous_handlers = {}
具體註冊方式是:
register_rendezvous_handler("tcp", _tcp_rendezvous_handler)
register_rendezvous_handler("env", _env_rendezvous_handler)
register_rendezvous_handler("file", _file_rendezvous_handler)
註冊代碼如下,就是往全局變量之中插入handler。
def register_rendezvous_handler(scheme, handler):
"""Registers a new rendezvous handler.
Args:
scheme (str): URL scheme to identify your rendezvous handler.
handler (function): Handler that is invoked when the
`rendezvous()` function is called with a URL that uses
the corresponding scheme. It must be a generator function
that yields the triplet.
"""
global _rendezvous_handlers
if scheme in _rendezvous_handlers:
raise RuntimeError(
"Rendezvous handler for {}:// already registered".format(scheme)
)
_rendezvous_handlers[scheme] = handler
3.2 handlers
如果仔細看 handlers 的代碼,就會發現其就是返回了不同的 store,比如 _tcp_rendezvous_handler具體就是使用各種信息建立 TCPStore,然後返回。
以下代碼均刪除非關鍵代碼。
3.2.1 _file_rendezvous_handler
這裡返回了FileStore。
def _file_rendezvous_handler(url: str, **kwargs):
result = urlparse(url)
path = result.path
query: Dict[str, str]
# mypy doesn't allow dict() to accept List of values (#257)
query = dict(pair.split("=") for pair in filter(None, result.query.split("&"))) # type: ignore[misc, arg-type]
rank = int(query["rank"])
world_size = int(query["world_size"])
store = FileStore(path, world_size)
yield (store, rank, world_size)
# If this configuration is invalidated, there is nothing we can do about it
raise RuntimeError("Unable to perform rerendezvous using file:// method")
3.2.2 _tcp_rendezvous_handler
這裡返回了 TCPStore。
def _tcp_rendezvous_handler(url: str, timeout: timedelta = default_pg_timeout, **kwargs):
result = urlparse(url)
query: Dict[str, Union[int, str]]
# mypy doesn't allow dict() to accept List of values (#257)
query = dict(pair.split("=") for pair in filter(None, result.query.split("&"))) # type: ignore[misc, arg-type]
rank = int(query["rank"])
world_size = int(query["world_size"])
start_daemon = rank == 0
assert result.hostname is not None
store = TCPStore(result.hostname, result.port, world_size, start_daemon, timeout)
yield (store, rank, world_size)
# If this configuration is invalidated, there is nothing we can do about it
raise RuntimeError("Unable to perform rerendezvous using tcp:// method")
3.2.3 _env_rendezvous_handler
居然也返回了 TCPStore,但是其會從環境變量中提取需要的信息。
def _env_rendezvous_handler(url: str, timeout: timedelta = default_pg_timeout, **kwargs):
result = urlparse(url)
query: Dict[str, Union[int, str]]
query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))
rank: Optional[Union[str, int]]
world_size: Optional[Union[str, int]]
master_port: Optional[Union[str, int]]
if "rank" in query:
rank = int(query["rank"])
else:
rank = int(_get_env_or_raise("RANK"))
if "world_size" in query:
world_size = int(query["world_size"])
else:
world_size = int(_get_env_or_raise("WORLD_SIZE"))
master_addr = _get_env_or_raise("MASTER_ADDR")
master_port = int(_get_env_or_raise("MASTER_PORT"))
use_torchelastic_store = os.environ.get("TORCHELASTIC_USE_AGENT_STORE", None)
if use_torchelastic_store == str(True):
worker_process_prefix = "/worker"
# When TORCHELASTIC_USE_AGENT_STORE is set up, the worker process is assumed
# to be invoked by the torchelastic agent. Torchelastic agent creates a tcp daemon thread
# on the GROUP_RANK=0, as a result all user worker processes should create store with: daemon=False
tcp_store = TCPStore(master_addr, master_port, world_size, False, timeout)
yield (PrefixStore(worker_process_prefix, tcp_store), rank, world_size)
else:
# Start the TCP store daemon on the rank 0
start_daemon = rank == 0
store = TCPStore(master_addr, master_port, world_size, start_daemon, timeout)
yield (store, rank, world_size)
# If this configuration is invalidated, there is nothing we can do about it
raise RuntimeError("Unable to perform rerendezvous using env:// method")
3.3 使用
3.3.1 使用 handler
如何使用 handler?在 init_process_group 之中有:
rendezvous_iterator = rendezvous(
init_method, rank, world_size, timeout=timeout
)
store, rank, world_size = next(rendezvous_iterator)
rendezvous 具體就是依據 init_method 來選擇一個 _rendezvous_handler,然後 _rendezvous_handler 返回了 store。
def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):
# Append node-specific arguments.
result = urlparse(url)
if rank != -1 or world_size != -1:
query_dict: Dict[str, Union[int, str]] = dict(
# mypy doesn't allow dict() to accept List of values (#257)
pair.split("=") for pair in filter(None, result.query.split("&")) # type: ignore[arg-type, misc]
)
if rank != -1:
query_dict["rank"] = rank
if world_size != -1:
query_dict["world_size"] = world_size
result = result._replace(
query="{}".format("&".join(["{}={}".format(k, v) for k, v in query_dict.items()]))
)
url = urlunparse(result)
return _rendezvous_handlers[result.scheme](url, **kwargs)
3.3.2 使用 Store
我們繼續看如何使用 store。在 init_process_group 代碼之中,接下來就使用了 store 來初始化進程組。
default_pg = _new_process_group_helper(
world_size,
rank,
[],
backend,
store,
pg_options=pg_options,
group_name=group_name,
timeout=timeout)
_update_default_pg(default_pg)
3.3.2.1 _new_process_group_helper
為了接着看 _new_process_group_helper,我們首先看看幾個全局變量。以下幾個變量 ProcessGroup 信息做了全局存儲,比如 _pg_map[pg] = (Backend.NCCL, store)。
# Cached process groups
# For NCCL and GLOO pg, it is a map from ProcessGroup to (Backend, Store)
# For MPI pg, it is a map from ProcessGroup to (Backend, None)
_pg_map: Dict[ProcessGroup, Tuple[str, Optional[Store]]] = {}
# Process group's names, map from ProcessGroup to str
_pg_names: Dict[ProcessGroup, str] = {}
# Process group's global rank to local rank mapping
_pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {}
_new_process_group_helper
之中得到了 store 參數之後,據此生成了一個 prefix_store,然後再根據這個 pre_store 來生成了 ProcessGroupGloo。_new_process_group_helper
代碼具體如下:
def _new_process_group_helper(world_size,
rank,
group_ranks,
backend,
store,
pg_options=None,
group_name=None,
timeout=default_pg_timeout):
"""
Create a new distributed process group.
This function must be called by ALL processes in the global group, even if
the calling process is not part of the newly created group. In that case,
this function returns GroupMember.NON_GROUP_MEMBER.
This function is called with ``group_ranks == []`` for the default group.
"""
global _pg_map
global _group_count
global _pg_names
if not group_name:
group_name = str(_group_count)
_group_count += 1
# The list of group ranks is empty if we're creating the default group.
is_default_group = (len(group_ranks) == 0)
backend = Backend(backend)
pg: Union[ProcessGroupGloo, ProcessGroupMPI, ProcessGroupNCCL]
if backend == Backend.MPI: # 沒有使用store
pg = ProcessGroupMPI.create(group_ranks)
if not pg:
return GroupMember.NON_GROUP_MEMBER
_pg_map[pg] = (Backend.MPI, None)
_pg_names[pg] = group_name
else:
# 這裡會使用store
# If this is a subgroup (which means group_ranks is specified),
# we check if the current process is a member of the new group.
if not is_default_group:
global_rank = _get_default_group().rank()
if global_rank not in group_ranks:
return GroupMember.NON_GROUP_MEMBER
# Use the group name as prefix in the default store, such that
# a single store can be reused by multiple groups.
prefix_store = PrefixStore(group_name, store) # 構建了 PrefixStore
if backend == Backend.GLOO:
pg = ProcessGroupGloo(
prefix_store, # 使用PrefixStore構建進程組
rank,
world_size,
timeout=timeout)
_pg_map[pg] = (Backend.GLOO, store)
_pg_names[pg] = group_name
elif backend == Backend.NCCL:
if pg_options is not None:
assert isinstance(pg_options, ProcessGroupNCCL.Options), \
"Expected pg_options argument to be of type ProcessGroupNCCL.Options"
else:
# default pg_options for NCCL
pg_options = ProcessGroupNCCL.Options()
pg_options.is_high_priority_stream = False
pg_options._timeout = timeout
pg = ProcessGroupNCCL(
prefix_store, # 使用PrefixStore構建進程組
rank,
world_size,
pg_options)
_pg_map[pg] = (Backend.NCCL, store)
_pg_names[pg] = group_name
else:
pg = getattr(Backend, backend.upper())(
prefix_store,
rank,
world_size,
timeout)
_pg_map[pg] = (backend, store)
_pg_names[pg] = group_name
return pg
3.3.2.2 ProcessGroupGloo
在 ProcessGroupGloo 之中有具體使用,比如在PrefixStore之上生成了一個GlooStore,利用 PrefixStore 建立網絡等等。
ProcessGroupGloo::ProcessGroupGloo(
const c10::intrusive_ptr<Store>& store,
int rank,
int size,
c10::intrusive_ptr<Options> options)
: ProcessGroup(rank, size),
store_(new GlooStore(store)), // 在PrefixStore之上生成了一個GlooStore
options_(options),
stop_(false),
collectiveCounter_(0) {
auto& devices = options->devices;
contexts_.reserve(options->devices.size());
for (size_t i = 0; i < options->devices.size(); i++) {
auto context = std::make_shared<::gloo::rendezvous::Context>(rank_, size_);
// 又生成了一個PrefixStore
auto store = ::gloo::rendezvous::PrefixStore(std::to_string(i), *store_);
context->setTimeout(options->timeout);
// 利用 PrefixStore 建立網絡
context->connectFullMesh(store, options->devices[i]);
contexts_.push_back(std::move(context));
}
// Every worker thread stores the AsyncWork object it's currently
// working on in the workInProgress_ vector. It must have size equal
// to the number of workers such that they can simply index into it
// using the worker index they are started with.
workInProgress_.resize(options->threads);
threads_.resize(options->threads);
for (size_t i = 0; i < threads_.size(); i++) {
threads_[i] = std::thread(&ProcessGroupGloo::runLoop, this, i);
}
}
在下面代碼之中,也有對store_的使用,比如等待,存取。
void ProcessGroupGloo::setSequenceNumberForGroup() {
if (rank_ == 0) {
// Create and broadcast sequence number
auto seq = 1 + rand();
sequenceNum_ = c10d::SequenceNum(seq);
std::vector<char> values = c10d::toVec<char>(seq, kBytes);
store_->set(kSeqNumStoreKey, values); // 存value
} else {
// Read rank 0's sequence number from store.
sequenceNum_ = c10d::SequenceNum();
store_->wait({kSeqNumStoreKey}, options_->timeout); // 等待
std::vector<char> values = store_->get(kSeqNumStoreKey); // 取value
uint64_t num = c10d::fromVec<char>(values);
sequenceNum_->set(num);
}
}
3.4 小結
從目前分析結果來看,我們拓展結論如下:
- init_method 最終還是落到了 store 之上,store才是起作用的實體。
- 參與的進程需要找到彼此並交換信息才能夠進行通信。這個過程被稱為rendezvous。
- rendezvous 其實就是返回了某一種store 以供後續通信使用。
- 在進程組之中,會使用 store 來構建通信,等待,存取等。
我們接下來選擇 TCPStore進行相信分析。
0x04 TCPStore
TCPStore 是基於 TCP 的分佈式鍵值存儲實現。服務器存儲/保存數據,而存儲客戶端可以通過 TCP 連接到服務器存儲並執行諸如set()
插入鍵值對、get()
檢索鍵值對等操作。系統中應該有一個初始化完畢的TCPStore存儲服務器,因為存儲客戶端將等待這個存儲服務以建立連接。
TCPStore 的參數如下:
- host_name ( str ) – 主機名或 IP 地址。存儲服務器在其上運行。
- port ( int ) – 存儲服務器在這個端口上偵聽傳入請求。
- world_size ( int , optional ) – 用戶總數。
- world_size = 客戶端數 + 1,1 代表服務器。
- 默認值為 -1(負值表示不固定的用戶數)。
- is_master ( bool , optional ) – 初始化存儲服務器時為真,初始化存儲客戶端時為假。默認值為假。
- timeout ( timedelta , optional ) – store在初始化期間,以及get()和 wait()方法使用的超時時間。默認為 timedelta(seconds=300)。
- wait_for_worker ( bool , optional ) – 是否等待所有worker與存儲服務器連接。這僅在 world_size 為固定值時適用。默認值為真。
使用例子如下:
import torch.distributed as dist
from datetime import timedelta
# Run on process 1 (server)
server_store = dist.TCPStore("127.0.0.1", 1234, 2, True, timedelta(seconds=30))
# Run on process 2 (client)
client_store = dist.TCPStore("127.0.0.1", 1234, 2, False)
# Use any of the store methods from either the client or server after initialization
server_store.set("first_key", "first_value")
client_store.get("first_key")
或者
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Using TCPStore as an example, other store types can also be used
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> # This will throw an exception after 10 seconds
>>> store.wait(["bad_key"], timedelta(seconds=10))
從例子上看,就是簡單的 server,client 或者說 master, worker 的關係,我們接下來仔細分析。
4.1 TCPStore in python
在 Python 世界之中,就是簡單的設定了 host 和 port。
class TCPStore(Store):
def __init__(self, host_name, port, world_size=-1, is_master=False, timeout=None, *args, **kwargs): # real signature unknown; NOTE: unreliably restored from __doc__
pass
host = property(lambda self: object(), lambda self, v: None, lambda self: None) # default
"""Gets the hostname on which the store listens for requests."""
port = property(lambda self: object(), lambda self, v: None, lambda self: None) # default
"""Gets the port number on which the store listens for requests."""
我們需要深入到 C++ 世界看看。
4.2 TCPStore in CPP
4.2.1 API接口
首先,C++之中的 TCPStore 可以認為是一個API接口,其定義如下:
class TCPStore : public Store {
public:
explicit TCPStore(
const std::string& masterAddr,
PortType masterPort,
c10::optional<int> numWorkers = c10::nullopt_t(-1),
bool isServer = false,
const std::chrono::milliseconds& timeout = kDefaultTimeout,
bool waitWorkers = true);
virtual ~TCPStore();
void set(const std::string& key, const std::vector<uint8_t>& value) override;
std::vector<uint8_t> compareSet(
const std::string& key,
const std::vector<uint8_t>& expectedValue,
const std::vector<uint8_t>& desiredValue) override;
std::vector<uint8_t> get(const std::string& key) override;
int64_t add(const std::string& key, int64_t value) override;
bool deleteKey(const std::string& key) override;
// NOTE: calling other TCPStore APIs inside the callback is NOT threadsafe
// watchKey() is a blocking operation. It will register the socket on
// TCPStoreMasterDaemon and the callback on TCPStoreWorkerDaemon. It will
// return once it has verified the callback is registered on both background
// threads. Only one thread can call watchKey() at a time.
void watchKey(const std::string& key, WatchKeyCallback callback) override;
bool check(const std::vector<std::string>& keys) override;
int64_t getNumKeys() override;
void wait(const std::vector<std::string>& keys) override;
void wait(
const std::vector<std::string>& keys,
const std::chrono::milliseconds& timeout) override;
// Waits for all workers to join.
void waitForWorkers();
// Returns the hostname used by the TCPStore.
const std::string& getHost() const noexcept;
// Returns the port used by the TCPStore.
PortType getPort() const noexcept;
private:
int64_t addHelper_(const std::string& key, int64_t value);
std::vector<uint8_t> getHelper_(const std::string& key);
void waitHelper_(
const std::vector<std::string>& keys,
const std::chrono::milliseconds& timeout);
std::mutex watchKeyMutex_;
bool isServer_;
int storeSocket_ = -1; //
int listenSocket_ = -1; //
int masterListenSocket_ = -1; // master 在這裡監聽
std::string tcpStoreAddr_;
PortType tcpStorePort_;
c10::optional<int> numWorkers_;
const std::string initKey_;
const std::string regularPrefix_;
std::unique_ptr<TCPStoreMasterDaemon> tcpStoreMasterDaemon_ = nullptr;
std::unique_ptr<TCPStoreWorkerDaemon> tcpStoreWorkerDaemon_ = nullptr;
};
4.2.2 socket用處
其成員變量之中最主要的是三個socket,或者說他們是 store 的精華(難點)所在。
int storeSocket_ = -1; //
int listenSocket_ = -1; //
int masterListenSocket_ = -1; // master 在這裡監聽
4.2.2.1 業務分工
具體解釋如下(後面還會結合代碼繼續分析):
- masterListenSocket_ 是 listen 在 masterPort 之上。
tcpStoreMasterDaemon_
本身是一個master,就是為整個 TCPStore提供服務的 server。tcpStoreMasterDaemon_
使用tcputil::addPollfd(fds, storeListenSocket_, POLLIN)
來監聽masterListenSocket_
。- key-value 就是std::unordered_map<std::string, std::vector<uint8_t>> tcpStore。
storeSocket_
在 tcpStoreWorkerDaemon_ 之上,其連接到masterListenSocket_
:masterPort
之上。storeSocket_
的作用是封裝面對 master port 的操作,用戶只管 set,get 等操作,不用知道 master port。- set(key, data) 的作用就是通過
storeSocket_
向master 發送一個設置key : value 的請求。 tcpStoreMasterDaemon_
監聽到socket變化,就開始相應。tcpStoreMasterDaemon_
內部把 key : value 添加到std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_
之上。
- listenSocket_ 在 tcpStoreWorkerDaemon_ 之上,也連接到
masterListenSocket_
:masterPort
之上。下面有一個解耦,如注釋所述,It will register the socket on TCPStoreMasterDaemon and the callback on TCPStoreWorkerDaemon
。- listenSocket_ 封裝了對 watchKey 的處理。Store Client 使用
watchKey(const std::string& key, WatchKeyCallback callback)
請求註冊,即:- Worker 請求註冊。使用
tcpStoreWorkerDaemon_->setCallback(regKey, callback)
來為tcpStoreWorkerDaemon_
的std::unordered_map<std::string, WatchKeyCallback> keyToCallbacks_
之上添加一個 callback。 - Worker 發送請求。通過 listenSocket_ 給 master 發消息 (key, WATCH_KEY),告訴master,如果 key 的 value 有變化,就調用這個 callback。
- Worker 請求註冊。使用
- Master 執行註冊。Master 接到 WATCH_KEY 消息之後進行註冊,調用 watchHandler,使用 watchedSockets_[key].push_back(socket) 來配置,告訴自己,如果這個 key 有變化,就給這個 socket 發消息。
- Master通知Worker。在 TCPStoreMasterDaemon::setHandler 之中,如果設置了新 value 之後,調用 sendKeyUpdatesToClients,其會遍歷 watchedSockets_[key],如果有 socket,就給 socket 發送消息變化通知。
- Worker執行callback。所以如果 key 有變化,就在 tcpStoreWorkerDaemon_ 之中調用了這個 callback。
- listenSocket_ 封裝了對 watchKey 的處理。Store Client 使用
4.2.2.2 Set 例子
我們首先看看 Set 的例子如下,就是 Worker 通過 socket 來在 Master 之上設置 value。
+
+----------------------------------------------------------------------+ | +----------------------------------------------+
| TCPStore Master | | | TCPStore Worker |
| | | | |
| | | | |
| | | | |
| +------------------------------------------------------------+ | | | |
| | TcpStoreMasterDaemon_ MasterPort| | | | |
| | | | | | |
| | TCPStore.masterListenSocket_ | | | | +---------------------------------+ |
| | | | | | | set(key, value) | |
| | | | | | | | |
| | tcpStore_[key] = value <------------------------------------------------+ | storeSocket_ | |
| | | | | | | | |
| | | | | | +---------------------------------+ |
| | | | | | |
| +------------------------------------------------------------+ | | | |
| | | | |
+----------------------------------------------------------------------+ | +----------------------------------------------+
+
手機如下:
4.2.2.3 Set 和 watchKey 結合
Set 和 watchKey 結合起來的示意圖如下(worker請求註冊,具體執行回調;master執行註冊,通知worker執行回調):
- Worker 請求註冊。Store Client 使用
watchKey(const std::string& key, WatchKeyCallback callback)
就是使用tcpStoreWorkerDaemon_->setCallback(regKey, callback)
來為tcpStoreWorkerDaemon_
的std::unordered_map<std::string, WatchKeyCallback> keyToCallbacks_
之上添加一個callback。 - Worker 發送請求。Worker 通過 listenSocket_ 給 master 發消息 (key, WATCH_KEY),告訴master,如果 key 的 value 有變化,就調用這個 callback。
- Master 執行註冊。Master 接到 WATCH_KEY 消息之後,調用 watchHandler,使用 watchedSockets_[key].push_back(socket) 來配置,告訴自己,如果這個 key 有變化,就給這個 socket 發消息。
- 下面我們假設 Store Client(這裡假設是同一個worker設置,實際上可能是不同worker)設置了一個 value。
- Master通知Worker。Master 在 TCPStoreMasterDaemon::setHandler 之中,如果設置了新 value 之後,調用 sendKeyUpdatesToClients,其會遍歷 watchedSockets_[key],如果有 socket,就給 socket 發送消息變化通知。
- Worker執行callback。如果 key 有變化,就在 tcpStoreWorkerDaemon_ 之中調用了這個 callback。
+----------------------------------------------------------------------+ + +------------------------------------------------------------------------+
| TCPStore Master | | | TCPStore Worker |
| | | | |
| +------------------------------------------------------------+ | | | |
| | TcpStoreMasterDaemon_ MasterPort| | | | +---------------------------------+ |
| | | | | | | | |
| | 2 | | | | | watchKey(key, callback) +----------------------+ |
| | TCPStore.masterListenSocket_ <----------------------------------+ | | | |
| | + | | | | | listenSocket_ | | |
| | | 3 | | | | | | 1 | |
| | v | | | | | | | |
| | watchedSockets_[key] = socket | | | | +---------------------------------+ | |
| | | | | | | |
| | +-------------------------------------------------+ | | | | | |
| | | | | | | | | |
| | | setHandler | | | | | +----------------------------------------------------------------+ |
| | | | | | | | | TCPStoreWorkerDaemon | | |
| | | | | | | | | v | |
| | | tcpStore_[key] = newData | | | | | | unordered_map<string, WatchKeyCallback> keyToCallbacks_ | |
| | | + | | | | | | | |
| | | | | | | | | | TCPStore.listenSocket_ | |
| | | | | | | | | | | |
| | | v | | | | | | +----------------------------------------------------------+ | |
| | | sendKeyUpdatesToClients | | | | | | | run | | |
| | | + | | 5 | | | | | | | |
| | | | | +---------------------->+ 6 | | |
| | | | | | | | | | | | callbackHandler +-----> keyToCallbacks_(callback) | | |
| | | v | | | | | | | | | | |
| | | | | | | | | | +----------------------------------------------------------+ | |
| | | for (int socket : watchedSockets_[key]){ | | | | | | +----------------------------------------------------------------+ |
| | | tcputil::sendString(socket, key, true) +-----+ | | | | |
| | | } | | | | | |
| | | | | | | | +------------------------+ |
| | | | | 4 | | | | set(key, newData) | |
| | | | <-----------------------+ | | |
| | +-------------------------------------------------+ | | | | | | |
| | | | | | +------------------------+ |
| +------------------------------------------------------------+ | | | |
| | | | |
+----------------------------------------------------------------------+ + +------------------------------------------------------------------------+
手機如下:
4.2.3 功能函數
TCPStore 提供了若干功能函數。
void TCPStore::set(const std::string& key, const std::vector<uint8_t>& data) {
std::string regKey = regularPrefix_ + key;
tcputil::sendValue<QueryType>(storeSocket_, QueryType::SET);
tcputil::sendString(storeSocket_, regKey, true);
tcputil::sendVector<uint8_t>(storeSocket_, data);
}
std::vector<uint8_t> TCPStore::get(const std::string& key) {
std::string regKey = regularPrefix_ + key;
return getHelper_(regKey);
}
int64_t TCPStore::add(const std::string& key, int64_t value) {
std::string regKey = regularPrefix_ + key;
return addHelper_(regKey, value);
}
int64_t TCPStore::addHelper_(const std::string& key, int64_t value) {
tcputil::sendValue<QueryType>(storeSocket_, QueryType::ADD);
tcputil::sendString(storeSocket_, key, true);
tcputil::sendValue<int64_t>(storeSocket_, value);
return tcputil::recvValue<int64_t>(storeSocket_);
}
這些功能函數是調用如下基礎函數來發送接收。
// this is only for convenience when sending rvalues
template <typename T>
void sendValue(int socket, const T& value, bool moreData = false) {
sendBytes<T>(socket, &value, 1, moreData);
}
template <typename T>
T recvValue(int socket) {
T value;
recvBytes<T>(socket, &value, 1);
return value;
}
4.2.4 構建函數
我們從構建函數可以看到:
- 對於存儲服務器角色,主要就是啟動了
tcpStoreMasterDaemon_
,注意在啟動了 daemon 之後,server 就進入了等待worker狀態,不會啟動接下來代碼中的 tcpStoreWorkerDaemon_。 - 對於存儲客戶端,則啟動了 tcpStoreWorkerDaemon_。
// TCPStore class methods
TCPStore::TCPStore(
const std::string& masterAddr,
PortType masterPort,
c10::optional<int> numWorkers,
bool isServer,
const std::chrono::milliseconds& timeout,
bool waitWorkers)
: Store(timeout),
isServer_(isServer),
tcpStoreAddr_(masterAddr),
tcpStorePort_(masterPort),
numWorkers_(numWorkers),
initKey_("init/"),
regularPrefix_("/") {
tcputil::socketInitialize();
if (isServer_) { // 如果設置了是server,就在masterPort上監聽
// Opening up the listening socket
std::tie(masterListenSocket_, tcpStorePort_) = tcputil::listen(masterPort);
}
try {
if (isServer_) { // 如果設置了是server,就啟動 tcpStoreMasterDaemon_
// Now start the daemon
tcpStoreMasterDaemon_ =
std::make_unique<TCPStoreMasterDaemon>(masterListenSocket_);
}
// Connect to the daemon
// worker 會與 master port 建立聯繫
storeSocket_ = tcputil::connect(
tcpStoreAddr_, tcpStorePort_, /* wait= */ true, timeout_);
if (numWorkers.value_or(-1) >= 0 && waitWorkers) {
waitForWorkers(); // server 等待 worker
}
// socket to handle requests from server,因為 master 也會給 worker 發消息
listenSocket_ = tcputil::connect(
tcpStoreAddr_, tcpStorePort_, /* wait= */ true, timeout_);
// 啟動 worker daemon
tcpStoreWorkerDaemon_ =
std::make_unique<TCPStoreWorkerDaemon>(listenSocket_);
} catch (const std::exception&) {
if (isServer_) {
tcpStoreMasterDaemon_ = nullptr;
tcputil::closeSocket(masterListenSocket_);
}
tcpStoreWorkerDaemon_ = nullptr;
if (listenSocket_ != -1) {
tcputil::closeSocket(listenSocket_);
}
if (storeSocket_ != -1) {
tcputil::closeSocket(storeSocket_);
}
throw;
}
}
server 會使用如下函數來等待 worker.
void TCPStore::waitForWorkers() {
addHelper_(initKey_, 1);
// Let server block until all workers have completed, this ensures that
// the server daemon thread is always running until the very end
if (isServer_) {
const auto start = std::chrono::steady_clock::now();
while (true) {
std::vector<uint8_t> value = getHelper_(initKey_);
auto buf = reinterpret_cast<const char*>(value.data());
auto len = value.size();
int numWorkersCompleted = std::stoi(std::string(buf, len));
if (numWorkersCompleted >= numWorkers_.value_or(-1)) {
break;
}
const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
std::chrono::steady_clock::now() - start);
if (timeout_ != kNoTimeout && elapsed > timeout_) {
break;
}
/* sleep override */
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
}
}
4.2.5 TCPStoreWorkerDaemon
這個 daemon 進程只是用來處理 watchKey。
// Separate thread that is launched on all instances (including master)
// Right now only handles callbacks registered from watchKey()
class TCPStoreWorkerDaemon : public BackgroundThread {
public:
explicit TCPStoreWorkerDaemon(int listenSocket);
// Set the callback to run key change
void setCallback(std::string key, WatchKeyCallback cb);
void waitForCallbackRegistration() {
// Block until callback has been registered successfully
std::unique_lock<std::mutex> callbackRegistrationLock(
callbackRegistrationMutex_);
callbackRegisteredCV_.wait(
callbackRegistrationLock, [&] { return callbackRegisteredData_; });
// Reset payload for next callback
callbackRegisteredData_ = false;
}
void setCallbackRegistered() {
callbackRegisteredData_ = true;
callbackRegisteredCV_.notify_one();
}
private:
void run();
void callbackHandler(int socket);
// List of callbacks map each watched key
std::unordered_map<std::string, WatchKeyCallback> keyToCallbacks_;
std::mutex keyToCallbacksMutex_;
std::mutex callbackRegistrationMutex_;
std::condition_variable callbackRegisteredCV_;
bool callbackRegisteredData_ = false;
};
其構建函數只是建立一個線程。
// TCPStoreListener class methods
TCPStoreWorkerDaemon::TCPStoreWorkerDaemon(int listenSocket)
: BackgroundThread(listenSocket) {
daemonThread_ = std::thread(&TCPStoreWorkerDaemon::run, this);
}
4.2.5.1 watchKey
Client Store 使用watchKey(const std::string& key, WatchKeyCallback callback)
的作用是往master註冊監聽key:
- Worker 請求註冊。使用
tcpStoreWorkerDaemon_->setCallback(regKey, callback)
來為tcpStoreWorkerDaemon_
的std::unordered_map<std::string, WatchKeyCallback> keyToCallbacks_
之上添加一個 callback。 - Worker 發送請求。通過 listenSocket_ 給 master 發消息 (key, WATCH_KEY),告訴master,如果 key 的 value 有變化,就調用這個 callback。
- 然後使用 waitForCallbackRegistration 等待註冊完成。
void TCPStore::watchKey(const std::string& key, WatchKeyCallback callback) {
// Only allow one thread to perform watchKey() at a time
const std::lock_guard<std::mutex> watchKeyLock(watchKeyMutex_);
// Register callback with TCPStoreMasterDaemon to call TCPStoreWorkerDaemon on
// key change
std::string regKey = regularPrefix_ + key;
tcpStoreWorkerDaemon_->setCallback(regKey, callback);
tcputil::sendValue<QueryType>(listenSocket_, QueryType::WATCH_KEY);
tcputil::sendString(listenSocket_, regKey);
// Block until callback has been registered successfully
tcpStoreWorkerDaemon_->waitForCallbackRegistration();
}
4.2.5.2 運行
其運行分為 windows 和 其他系統,但是主要就是收到了業務key,然後進行相關業務處理。
- Master 執行註冊。Master 接到 WATCH_KEY 消息之後,調用 watchHandler,使用 watchedSockets_[key].push_back(socket) 來配置,告訴自己,如果這個 key 有變化,就給這個 socket 發消息。
- Master通知Worker。在 TCPStoreMasterDaemon::setHandler 之中,如果設置了新 value 之後,調用 sendKeyUpdatesToClients,其會遍歷 watchedSockets_[key],如果有 socket,就給 socket 發送消息變化通知。
- Worker執行callback。所以如果 key 有變化,就在 tcpStoreWorkerDaemon_ 之中調用了這個 callback。
#ifdef _WIN32
void TCPStoreWorkerDaemon::run() { // 這裡是windows系統
std::vector<struct pollfd> fds;
tcputil::addPollfd(fds, storeListenSocket_, POLLIN);
while (true) {
// Check control and exit early if triggered
int res;
SYSCHECK_ERR_RETURN_NEG1(
res = WSAPoll(fds.data(), fds.size(), checkTimeout_.count()))
if (res == 0) {
auto rvPoll = WaitForSingleObject(ghStopEvent_, 0);
if (rvPoll != WAIT_TIMEOUT) {
break;
}
continue;
}
// if connection is closed gracefully by master, peeked data will return 0
char data;
int ret = recv(fds[0].fd, &data, 1, MSG_PEEK);
if (ret == 0) {
auto rvData = WaitForSingleObject(ghStopEvent_, 0);
if (rvData != WAIT_TIMEOUT) {
break;
}
continue;
}
// valid request, perform callback logic
callbackHandler(fds[0].fd); // 業務處理
}
}
#else
void TCPStoreWorkerDaemon::run() {
std::vector<struct pollfd> fds;
tcputil::addPollfd(fds, controlPipeFd_[0], POLLHUP);
tcputil::addPollfd(fds, storeListenSocket_, POLLIN);
while (true) {
SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1));
// Check control and exit early if triggered
// The pipe receives an event which tells us to shutdown the listener thread
if (fds[0].revents != 0) {
// Will be POLLUP when the pipe is closed
if (fds[0].revents ^ POLLHUP) {
throw std::system_error(
ECONNABORTED,
std::system_category(),
"Unexpected poll revent on the control pipe's reading fd: " +
std::to_string(fds[0].revents));
}
break;
}
// if connection is closed gracefully by master, peeked data will return 0
char data;
int ret = recv(fds[1].fd, &data, 1, MSG_PEEK);
if (ret == 0) {
continue;
}
// valid request, perform callback logic
callbackHandler(fds[1].fd); // 業務處理
}
}
#endif
4.2.6 TCPStoreMasterDaemon
這裡的 std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_; 是真實的 kv。
所以,TCPStoreMasterDaemon 就是負責對 kv 的操作,比如存取。
// Separate thread that is only launched on master
class TCPStoreMasterDaemon : public BackgroundThread {
public:
explicit TCPStoreMasterDaemon(int storeListenSocket);
private:
void run();
void queryFds(std::vector<struct pollfd>& fds);
void query(int socket);
// The master runs on a single thread so only
// one handler can be executed at a time
void setHandler(int socket);
void compareSetHandler(int socket);
void addHandler(int socket);
void getHandler(int socket) const;
void checkHandler(int socket) const;
void getNumKeysHandler(int socket) const;
void deleteHandler(int socket);
void waitHandler(int socket);
void watchHandler(int socket);
bool checkKeys(const std::vector<std::string>& keys) const;
// Helper function to alerts waiting workers, used in setHandler, getHandler
void wakeupWaitingClients(const std::string& key);
// Helper function used when the key is changed
// used in setHandler, addHandler, getHandler, deleteHandler
void sendKeyUpdatesToClients(
const std::string& key,
const enum WatchResponseType& type,
std::vector<uint8_t>& oldData,
std::vector<uint8_t>& newData);
std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_;
// From key -> the list of sockets waiting on the key
std::unordered_map<std::string, std::vector<int>> waitingSockets_;
// From socket -> number of keys awaited
std::unordered_map<int, size_t> keysAwaited_;
// From key -> the list of sockets watching the key
std::unordered_map<std::string, std::vector<int>> watchedSockets_;
};
4.2.6.1 運行
TCPStoreMasterDaemon 就是等待在 socket 之上,即 masterListenSocket_ 是 listen 在 masterPort 之上。
tcpStoreMasterDaemon_
使用tcputil::addPollfd(fds, storeListenSocket_, POLLIN)
來監聽masterListenSocket_
。- tcpStoreMasterDaemon_本身成為一個master,就是為整個 TCPStore提供服務的 server。
- key-value 就是std::unordered_map<std::string, std::vector<uint8_t>> tcpStore。
#ifdef _WIN32
void TCPStoreMasterDaemon::run() {
std::vector<struct pollfd> fds;
tcputil::addPollfd(fds, storeListenSocket_, POLLIN);
// receive the queries
bool finished = false;
while (!finished) {
for (size_t i = 0; i < sockets_.size(); i++) {
fds[i].revents = 0;
}
int res;
SYSCHECK_ERR_RETURN_NEG1(
res = WSAPoll(fds.data(), fds.size(), checkTimeout_.count()))
if (res == 0) {
auto rv = WaitForSingleObject(ghStopEvent_, 0);
if (rv != WAIT_TIMEOUT) {
finished = true;
break;
}
continue;
}
// TCPStore's listening socket has an event and it should now be able to
// accept new connections.
if (fds[0].revents != 0) { // 收到了消息
if (!(fds[0].revents & POLLIN)) {
throw std::system_error(
ECONNABORTED,
std::system_category(),
"Unexpected poll revent on the master's listening socket: " +
std::to_string(fds[0].revents));
}
int sockFd = std::get<0>(tcputil::accept(storeListenSocket_));
sockets_.push_back(sockFd);
tcputil::addPollfd(fds, sockFd, POLLIN);
}
queryFds(fds); // 業務處理
}
}
#else
void TCPStoreMasterDaemon::run() {
std::vector<struct pollfd> fds;
tcputil::addPollfd(fds, storeListenSocket_, POLLIN);
// Push the read end of the pipe to signal the stopping of the daemon run
tcputil::addPollfd(fds, controlPipeFd_[0], POLLHUP);
// receive the queries
bool finished = false;
while (!finished) {
for (size_t i = 0; i < sockets_.size(); i++) {
fds[i].revents = 0;
}
SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1));
// TCPStore's listening socket has an event and it should now be able to
// accept new connections.
if (fds[0].revents != 0) {
if (fds[0].revents ^ POLLIN) {
throw std::system_error(
ECONNABORTED,
std::system_category(),
"Unexpected poll revent on the master's listening socket: " +
std::to_string(fds[0].revents));
}
int sockFd = std::get<0>(tcputil::accept(storeListenSocket_));
sockets_.push_back(sockFd);
tcputil::addPollfd(fds, sockFd, POLLIN);
}
// The pipe receives an event which tells us to shutdown the daemon
if (fds[1].revents != 0) { // 收到了消息
// Will be POLLUP when the pipe is closed
if (fds[1].revents ^ POLLHUP) {
throw std::system_error(
ECONNABORTED,
std::system_category(),
"Unexpected poll revent on the control pipe's reading fd: " +
std::to_string(fds[1].revents));
}
finished = true;
break;
}
queryFds(fds); // 業務處理
}
}
#endif
4.2.6.2 調用業務
queryFds 會根據 socket 監聽結果而調用不同業務。
void TCPStoreMasterDaemon::queryFds(std::vector<struct pollfd>& fds) {
// Skipping the fds[0] and fds[1],
// fds[0] is master's listening socket
// fds[1] is control pipe's reading fd, it is not for Windows platform
for (size_t fdIdx = CONNECT_SOCKET_OFFSET; fdIdx < fds.size(); ++fdIdx) {
if (fds[fdIdx].revents == 0) {
continue;
}
// Now query the socket that has the event
try {
query(fds[fdIdx].fd); // 處理業務
} catch (...) {
tcputil::closeSocket(fds[fdIdx].fd);
// Remove all the tracking state of the close FD
for (auto it = waitingSockets_.begin(); it != waitingSockets_.end();) {
for (auto vecIt = it->second.begin(); vecIt != it->second.end();) {
if (*vecIt == fds[fdIdx].fd) {
vecIt = it->second.erase(vecIt);
} else {
++vecIt;
}
}
if (it->second.size() == 0) {
it = waitingSockets_.erase(it);
} else {
++it;
}
}
for (auto it = keysAwaited_.begin(); it != keysAwaited_.end();) {
if (it->first == fds[fdIdx].fd) {
it = keysAwaited_.erase(it);
} else {
++it;
}
}
fds.erase(fds.begin() + fdIdx);
sockets_.erase(sockets_.begin() + fdIdx - CONNECT_SOCKET_OFFSET);
--fdIdx;
continue;
}
}
}
4.2.6.4 處理業務
從 socket 之中讀取消息,依據消息內容來進行相關業務處理。
// query communicates with the worker. The format
// of the query is as follows:
// type of query | size of arg1 | arg1 | size of arg2 | arg2 | ...
// or, in the case of wait
// type of query | number of args | size of arg1 | arg1 | ...
void TCPStoreMasterDaemon::query(int socket) {
QueryType qt;
tcputil::recvBytes<QueryType>(socket, &qt, 1);
if (qt == QueryType::SET) {
setHandler(socket);
} else if (qt == QueryType::COMPARE_SET) {
compareSetHandler(socket);
} else if (qt == QueryType::ADD) {
addHandler(socket);
} else if (qt == QueryType::GET) {
getHandler(socket);
} else if (qt == QueryType::CHECK) {
checkHandler(socket);
} else if (qt == QueryType::WAIT) {
waitHandler(socket);
} else if (qt == QueryType::GETNUMKEYS) {
getNumKeysHandler(socket);
} else if (qt == QueryType::DELETE_KEY) {
deleteHandler(socket);
} else if (qt == QueryType::WATCH_KEY) {
watchHandler(socket);
} else {
throw std::runtime_error("Unexpected query type");
}
}
添加
此處是處理添加 value 的業務。
void TCPStoreMasterDaemon::setHandler(int socket) {
std::string key = tcputil::recvString(socket);
std::vector<uint8_t> newData = tcputil::recvVector<uint8_t>(socket);
std::vector<uint8_t> oldData;
bool newKey = true;
auto it = tcpStore_.find(key);
if (it != tcpStore_.end()) {
oldData = it->second;
newKey = false;
}
tcpStore_[key] = newData;
// On "set", wake up all clients that have been waiting
wakeupWaitingClients(key);
// Send key update to all watching clients
newKey ? sendKeyUpdatesToClients(
key, WatchResponseType::KEY_CREATED, oldData, newData)
: sendKeyUpdatesToClients(
key, WatchResponseType::KEY_UPDATED, oldData, newData);
}
獲取
出處處理獲取 value 的業務。
void TCPStoreMasterDaemon::getHandler(int socket) const {
std::string key = tcputil::recvString(socket);
auto data = tcpStore_.at(key);
tcputil::sendVector<uint8_t>(socket, data);
}
watchKey
此處添加了想要監控的 key。
對於WATCH_KEY,給對應的key添加了一個socket,作為以後發送通知的對象。
void TCPStoreMasterDaemon::watchHandler(int socket) {
std::string key = tcputil::recvString(socket);
// Record the socket to respond to when the key is updated
watchedSockets_[key].push_back(socket);
// Send update to TCPStoreWorkerDaemon on client
tcputil::sendValue<WatchResponseType>(
socket, WatchResponseType::KEY_CALLBACK_REGISTERED);
}
通知
如果key 有變化,就通知客戶端。
void TCPStoreMasterDaemon::sendKeyUpdatesToClients(
const std::string& key,
const enum WatchResponseType& type,
std::vector<uint8_t>& oldData,
std::vector<uint8_t>& newData) {
for (int socket : watchedSockets_[key]) {
tcputil::sendValue<WatchResponseType>(socket, type);
tcputil::sendString(socket, key, true);
tcputil::sendVector<uint8_t>(socket, oldData);
tcputil::sendVector<uint8_t>(socket, newData);
}
}
4.2.7 總結
我們總結圖例如下:
- Master 之中使用MasterPort 進行監聽請求。
- 關於存取value。
- Worker 之中,storeSocket_ 被用來存儲/獲取value,對應下圖 數字 1。
- 在 Master 之中對應了 tcpStore_。
- 關於監控。
- Worker 之中,listenSocket_ 被用來通知 Master 我需要監聽這個 key,對應下圖 數字 2。同時 worker 內部給這個 key 設置了 callback,對應了下圖 數字 3。
- 監聽在 Master 之中對應了
watchedSockets_[key] = socket_
。 - Master 之中,如果設置 value 時候,發現是一個被監控的 key,就通知 watchedSockets_[key],對應了下圖 數字 4。
- Worker 之中會進行相關業務調用。
+
+----------------------------------------------------------------------+ | +------------------------------------------------------------------------+
| TCPStore Master | | | TCPStore Worker |
| | | | |
| storeSocket_ | | | |
| | | | |
| +------------------------------------------------------------+ | | | |
| | TcpStoreMasterDaemon_ MasterPort| | | | 1 +---------------------------------+ |
| | | <--------------+ | set(key, value) | |
| | unordered_map<string, vector<uint8_t> > tcpStore_+---+ | | | | | | |
| | | | | | | | storeSocket_ | |
| | TCPStore.masterListenSocket_ | | | | | | | |
| | | | | | | +---------------------------------+ |
| | +-----------------------------------------------+ | | | | | |
| | | run | | | | | | 2 +---------------------------------+ |
| | | | | | <--------------+ | | |
| | | queryFds query | | | | | | | watchKey(key, callback) +-------------------------------+ |
| | | | | | | | | | | 3 | |
| | | setHandler getHandler | | | | | | | listenSocket_ | | |
| | | | | | | | | | | | |
| | +-----------------------------------------------+ | | | | | | | | |
| | | | | | | +---------------------------------+ | |
| +------------------------------------------------------------+ | | | | |
| | | | | | |
| | | | | | |
| | | | | +----------------------------------------------------------------+ |
| | | | | | TCPStoreWorkerDaemon | | |
| | | | | | | | |
| | | | | | unordered_map<string, WatchKeyCallback> keyToCallbacks_ | | |
| | | | | | | | |
| | | | | | TCPStore.listenSocket_ +----+ | |
| | | | | | | | |
| | | | | | +----------------------------------------------------------+ | |
| | | | | | | run | | | |
| | 4 | | | | | | | | |
| +--------------------->+ v | | |
| | | | | | callbackHandler +-----> keyToCallbacks_(callback) | | |
| | | | | | | | |
| | | | | +----------------------------------------------------------+ | |
| | | | +----------------------------------------------------------------+ |
+----------------------------------------------------------------------+ + +------------------------------------------------------------------------+
手機如下:
至此,我們梳理了初始化方法和Store這兩個概念,最終其實是Store這個概念在初始化過程中起了作用。我們也通過TCPStore 的分析知道了一個Store應該具備的功能,比如設置KV,監控某個key的變等等,正是這些功能才可以讓若干進程彼此知道對方的存在。
下一篇我們介紹進程組的概念,敬請期待。