Model Architecture in Distributed Machine Learning

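In the previous post, "Model Aggregation in Distributed Machine Learning" (link: //www.cnblogs.com/orion-orion/p/15635803.html), we looked at model aggregation (parameter communication) in distributed machine learning. We did not cover how each client's model is built or how its parameters are optimized, so that is the focus of this post.
First, in the paper I am following [1] (code in [2]), each client holds an ensemble model, and each ensemble consists of several component models, as illustrated below: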
[Figure: each client's ensemble model and its component models]
Next, we go through the design top-down, one layer at a time: Client, Learners_ensemble, and each Learner.

1. Client

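Client is the class representing one client (task) node. Like a model class, it provides a get_next_batch() method and a step() method, which are called by the aggregator class covered in the earlier post "Model Aggregation in Distributed Machine Learning". Keep in mind that operating on a Client really means operating on its Learners_ensemble, that is, on all of its component Learner models.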
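To make this delegation concrete, here is a minimal, framework-free sketch of how a step() call on a client trickles down to every component learner. ToyLearner, ToyEnsemble and ToyClient are hypothetical stand-ins written in plain Python (no PyTorch), and each toy learner's "training" merely moves a single scalar parameter toward the weighted batch mean; only the call structure is meant to mirror the real classes.

```python
class ToyLearner:
    """Hypothetical stand-in for Learner: holds one scalar parameter."""
    def __init__(self, param):
        self.param = param

    def fit_batch(self, batch, weights=None):
        # Toy "training": move the parameter halfway toward the weighted batch mean
        w = weights if weights is not None else [1.0] * len(batch)
        target = sum(wi * xi for wi, xi in zip(w, batch)) / sum(w)
        self.param += 0.5 * (target - self.param)


class ToyEnsemble:
    """Hypothetical stand-in for LearnersEnsemble: trains every learner on the batch."""
    def __init__(self, learners):
        self.learners = learners

    def fit_batch(self, batch, weights):
        updates = []
        for learner_id, learner in enumerate(self.learners):
            old = learner.param
            learner.fit_batch(batch, weights[learner_id])
            updates.append(learner.param - old)  # plays the role of client_updates[learner_id]
        return updates


class ToyClient:
    """Hypothetical stand-in for Client: step() only forwards work to the ensemble."""
    def __init__(self, ensemble, batches):
        self.ensemble = ensemble
        self.batches = batches
        self.counter = 0

    def step(self):
        batch = self.batches[self.counter % len(self.batches)]
        self.counter += 1
        n = len(self.ensemble.learners)
        uniform = [[1.0 / n] * len(batch) for _ in range(n)]  # uniform sample weights
        return self.ensemble.fit_batch(batch, uniform)


client = ToyClient(ToyEnsemble([ToyLearner(0.0), ToyLearner(4.0)]), batches=[[1.0, 3.0]])
print(client.step())  # [1.0, -1.0]
```

The client itself holds no model parameters; everything it returns is computed one level down, which is the same division of labor as in the real Client and LearnersEnsemble.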

Its core methods are as follows:
[Figure: core methods of the Client class]

Its implementation is as follows:

from copy import deepcopy

import torch


class Client(object):
    r"""A client (task) node
    """
    def __init__(
            self,
            learners_ensemble,
            train_iterator,
            val_iterator,
            test_iterator,
            logger,
            local_steps,
            tune_locally=False
    ):
        # The local learners_ensemble model
        self.learners_ensemble = learners_ensemble
        self.n_learners = len(self.learners_ensemble)
        self.tune_locally = tune_locally

        # Whether to fine-tune locally; to keep things simple we skip this feature
        if self.tune_locally:
            self.tuned_learners_ensemble = deepcopy(self.learners_ensemble)
        else:
            self.tuned_learners_ensemble = None

        # Whether this is a binary classification problem
        self.binary_classification_flag = self.learners_ensemble.is_binary_classification

        # Store the train/val/test DataLoaders (each Client has its own dataset).
        # Setting them once at construction time means step() never needs data passed in.
        # The "iterators" here are actually torch.utils.data.DataLoader objects;
        # call iter(train_iterator) to get an iterator (a for loop does this implicitly)
        self.train_iterator = train_iterator
        self.val_iterator = val_iterator
        self.test_iterator = test_iterator

        # Build an iterator from train_iterator
        self.train_loader = iter(self.train_iterator)

        self.n_train_samples = len(self.train_iterator.dataset)
        self.n_test_samples = len(self.test_iterator.dataset)

        # Weight of every training sample under every component model (values in [0, 1]),
        # shape (n_learners, n_train_samples), initialized uniformly to 1 / n_learners
        self.samples_weights = torch.ones(self.n_learners, self.n_train_samples) / self.n_learners

        self.local_steps = local_steps

        self.counter = 0  # number of optimization steps (step() calls) performed so far
        self.logger = logger

    def get_next_batch(self):
        """
        Safely read one batch from the iterator built from train_iterator (self.train_loader).
        When the dataset is exhausted, start over from the beginning (cyclic reading).
        """
        try:
            batch = next(self.train_loader)
        except StopIteration:
            self.train_loader = iter(self.train_iterator)
            batch = next(self.train_loader)

        return batch

    def step(self, single_batch_flag=False, *args, **kwargs):
        """
        Perform one training step of the client

        :param single_batch_flag: if True, the client updates using a single batch only
        :return client_updates: np.ndarray of shape (n_learners, model_dim)
        """
        self.counter += 1  # increment the step counter

        self.update_sample_weights()
        self.update_learners_weights()

        # The actual optimization is delegated to learners_ensemble
        if single_batch_flag:
            # Only one batch per step: read the next batch from train_loader
            batch = self.get_next_batch()
            client_updates = \
                self.learners_ensemble.fit_batch(
                    batch=batch,
                    weights=self.samples_weights
                )
        else:
            # Otherwise pass the whole DataLoader and train for local_steps epochs
            client_updates = \
                self.learners_ensemble.fit_epochs(
                    iterator=self.train_iterator,
                    n_epochs=self.local_steps,
                    weights=self.samples_weights
                )

        return client_updates

    def write_logs(self):
        r"""
        Record the loss and accuracy on the train and test data; they are printed to the console later.
        Note that evaluation here calls the evaluate_iterator() method of tuned_learners_ensemble;
        evaluate_iterator() itself is described later. (The method body is omitted in this post.)
        """

    def update_sample_weights(self):
        # Updates the weight of each sample;
        # overridden in the MixtureClient subclass
        pass

    def update_learners_weights(self):
        # Updates the weight of each component model;
        # overridden in the MixtureClient subclass
        pass
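
The try/except StopIteration pattern in get_next_batch() is what lets a client keep drawing batches from a finite DataLoader. Below is a minimal sketch of the same idea; to keep it runnable without PyTorch it uses a plain Python list in place of the DataLoader, and the class name CyclicBatchReader is hypothetical.

```python
class CyclicBatchReader:
    """Reads batches cyclically from any re-iterable object (e.g. a list or a DataLoader)."""

    def __init__(self, iterable):
        self.iterable = iterable
        self.loader = iter(self.iterable)

    def get_next_batch(self):
        try:
            batch = next(self.loader)
        except StopIteration:
            # End of the data reached: rebuild the iterator and start a new pass
            self.loader = iter(self.iterable)
            batch = next(self.loader)
        return batch


reader = CyclicBatchReader(["b0", "b1", "b2"])
print([reader.get_next_batch() for _ in range(5)])  # ['b0', 'b1', 'b2', 'b0', 'b1']
```

With a real DataLoader created with shuffle=True, each rebuilt iterator yields the batches in a fresh random order, so successive passes are not identical.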

Note that the Client class above leaves update_learners_weights and update_sample_weights as empty stubs. They are implemented in the MixtureClient subclass below:

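import torch
import torch.nn.functional as F


class MixtureClient(Client):
    def update_sample_weights(self):
        # all_losses: (n_learners, n_samples), the per-sample loss of each component
        all_losses = self.learners_ensemble.gather_losses(self.val_iterator)
        # E-step: the weight of sample i under learner m is proportional to
        # learners_weights[m] * exp(-loss_m(x_i)), normalized over the learners
        self.samples_weights = F.softmax((torch.log(self.learners_ensemble.learners_weights) - all_losses.T), dim=1).T

    def update_learners_weights(self):
        # M-step: each learner's mixture weight is the mean of its sample weights
        self.learners_ensemble.learners_weights = self.samples_weights.mean(dim=1)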
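Together, the two MixtureClient methods perform one EM-style update. For sample \(i\) and component \(m\), update_sample_weights computes \(q_{mi} \propto \pi_m \exp(-\ell_{mi})\), normalized over \(m\) (the softmax of log-weight minus loss), and update_learners_weights then sets \(\pi_m\) to the mean of \(q_{mi}\) over all samples. The plain-Python sketch below reproduces both steps as standalone functions (same names as the methods, but hypothetical free functions on nested lists rather than tensors) on a made-up loss matrix with 2 learners and 3 samples:

```python
import math


def update_sample_weights(learners_weights, all_losses):
    """E-step: q[m][i] = softmax over m of (log(pi_m) - loss[m][i]).

    learners_weights: mixture weights pi, one per learner
    all_losses: n_learners x n_samples nested list of per-sample losses
    returns: n_learners x n_samples nested list whose columns each sum to 1
    """
    n_learners, n_samples = len(all_losses), len(all_losses[0])
    q = [[0.0] * n_samples for _ in range(n_learners)]
    for i in range(n_samples):
        logits = [math.log(learners_weights[m]) - all_losses[m][i] for m in range(n_learners)]
        top = max(logits)  # shift by the max for numerical stability
        exps = [math.exp(z - top) for z in logits]
        total = sum(exps)
        for m in range(n_learners):
            q[m][i] = exps[m] / total
    return q


def update_learners_weights(samples_weights):
    """M-step: pi_m = average of q[m][i] over all samples i."""
    return [sum(row) / len(row) for row in samples_weights]


pi = [0.5, 0.5]
losses = [[0.1, 0.2, 2.0],   # learner 0 fits samples 0 and 1 well
          [2.0, 1.5, 0.1]]   # learner 1 fits sample 2 well
q = update_sample_weights(pi, losses)
pi = update_learners_weights(q)
print([round(v, 3) for v in q[0]])  # [0.87, 0.786, 0.13]
print([round(v, 3) for v in pi])    # [0.595, 0.405]
```

Learner 0 has the lower loss on two of the three samples, so after one round its mixture weight grows from 0.5 to about 0.595.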

2. Learners_ensemble

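Learners_ensemble is an ensemble of several component models. During optimization, each component is optimized separately; at output time, the ensemble returns a weighted average of the components' outputs: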

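\[\bm{y}_{t} = \sum_{m=1}^{M} w_{tm}\, h(\mathbf{X}_t; \bm{\theta}_{tm})\]

where \(M\) is the number of components, \(w_{tm}\) is the weight of the \(m\)-th component on client \(t\) (learners_weights in the code), and \(h(\cdot; \bm{\theta}_{tm})\) is the output of the \(m\)-th component model.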
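As a quick numeric check of this formula, the sketch below mixes the softmax outputs of two components for a single 3-class input with weights \(w_{t1} = 0.7\) and \(w_{t2} = 0.3\). The logits are made up for illustration, and plain Python stands in for PyTorch. Because each softmax output sums to 1 and the weights sum to 1, the mixture is again a probability vector, which is what allows evaluate_iterator() to pass its log directly to NLLLoss.

```python
import math


def softmax(logits):
    top = max(logits)  # shift by the max for numerical stability
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def ensemble_predict(weights, component_logits):
    """y = sum_m w_m * softmax(h_m(x)) for a single input x."""
    n_classes = len(component_logits[0])
    y = [0.0] * n_classes
    for w, logits in zip(weights, component_logits):
        probs = softmax(logits)
        for c in range(n_classes):
            y[c] += w * probs[c]
    return y


weights = [0.7, 0.3]                 # w_{t1}, w_{t2}
logits = [[2.0, 0.5, -1.0],          # h(x; theta_{t1})
          [0.0, 1.5, 0.2]]           # h(x; theta_{t2})
y = ensemble_predict(weights, logits)
print([round(p, 3) for p in y], round(sum(y), 6))  # [0.595, 0.323, 0.082] 1.0
```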

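In addition, Learners_ensemble provides an evaluate_iterator() method for model evaluation, which the Client class above it calls. The numbers it returns are computed from the weighted average of all components' predictions (the formula above), not by averaging the scores of the individual components.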

Its core methods are as follows:

[Figure: core methods of the LearnersEnsemble class]

Its implementation is as follows:


import torch
import torch.nn as nn
import torch.nn.functional as F


class LearnersEnsemble(object):
    """
    A LearnersEnsemble made up of several component Learners.
    (It is iterable: it implements __iter__, __getitem__ and __len__.)
    """
    def __init__(self, learners, learners_weights):
        self.learners = learners
        self.learners_weights = learners_weights

        # Assume all learners have the same number of parameters (model_dim)
        self.model_dim = self.learners[0].model_dim
        # Boolean flag: whether this is a binary classification task
        self.is_binary_classification = self.learners[0].is_binary_classification
        # Assume all learners share the same device and metric
        self.device = self.learners[0].device
        self.metric = self.learners[0].metric

    def fit_batch(self, batch, weights):
        """
        Update every learner component using one batch.
        :param batch: tuple (x, y, indices)
        :param weights: tensor, the weight of each sample (may be None)
        :return client_updates: np.ndarray of shape (n_learners, model_dim), the difference between
            the new and old parameters of each learner in the ensemble
        """
        # Update of every parameter dimension of every learner
        client_updates = torch.zeros(len(self.learners), self.model_dim)

        for learner_id, learner in enumerate(self.learners):
            old_params = learner.get_param_tensor()
            if weights is not None:
                learner.fit_batch(batch=batch, weights=weights[learner_id])
            else:
                learner.fit_batch(batch=batch, weights=None)

            params = learner.get_param_tensor()

            client_updates[learner_id] = (params - old_params)

        return client_updates.cpu().numpy()


    def fit_epochs(self, iterator, n_epochs, weights=None):
        """
        Update every learner component by passing over the training set several times (epochs).
        :param iterator: DataLoader of the training set
        :param n_epochs: number of passes (epochs) over the training set
        :param weights: tensor, the weight of each sample (may be None)
        :return client_updates: np.ndarray of shape (n_learners, model_dim), the difference between
            the new and old parameters of each learner in the ensemble
        """
        client_updates = torch.zeros(len(self.learners), self.model_dim)

        for learner_id, learner in enumerate(self.learners):
            old_params = learner.get_param_tensor()
            if weights is not None:
                learner.fit_epochs(iterator, n_epochs, weights=weights[learner_id])
            else:
                learner.fit_epochs(iterator, n_epochs, weights=None)
            params = learner.get_param_tensor()

            client_updates[learner_id] = (params - old_params)

        return client_updates.cpu().numpy()

    def evaluate_iterator(self, iterator):
        """
        Evaluate the learners on the data yielded by the iterator.

        :param iterator: yields x, y, indices
        :return: global_loss, global_acc (on that data)
        """
        if self.is_binary_classification:
            criterion = nn.BCELoss(reduction="none")
        else:
            criterion = nn.NLLLoss(reduction="none")

        for learner in self.learners:
            # Put each learner's model in evaluation mode
            learner.model.eval()

        global_loss = 0.
        global_metric = 0.
        n_samples = 0

        with torch.no_grad():
            for (x, y, _) in iterator:
                x = x.to(self.device).type(torch.float32)
                y = y.to(self.device)
                n_samples += y.size(0)

                y_pred = 0.
                for learner_id, learner in enumerate(self.learners):
                    # Note 1: sigmoid/softmax are applied outside the model class for flexibility,
                    # but we still usually regard them as part of the classifier h(x)
                    # Note 2: this is effectively a weighted average of the classifiers' outputs
                    if self.is_binary_classification:
                        y_pred += self.learners_weights[learner_id] * torch.sigmoid(learner.model(x))
                    else:
                        y_pred += self.learners_weights[learner_id] * F.softmax(learner.model(x), dim=1)

                y_pred = torch.clamp(y_pred, min=0., max=1.)

                if self.is_binary_classification:
                    y = y.type(torch.float32).unsqueeze(1)
                    global_loss += criterion(y_pred, y).sum().item()
                    y_pred = torch.logit(y_pred, eps=1e-10)
                else:
                    global_loss += criterion(torch.log(y_pred), y).sum().item()

                global_metric += self.metric(y_pred, y).item()

            return global_loss / n_samples, global_metric / n_samples

    def gather_losses(self, iterator):
        """
        Collect the losses of every learner on all samples of the iterator's dataset
        :param iterator: DataLoader
        :return: tensor of shape (n_learners, n_samples), each learner's loss on every sample
        """
        n_samples = len(iterator.dataset)
        all_losses = torch.zeros(len(self.learners), n_samples)
        for learner_id, learner in enumerate(self.learners):
            all_losses[learner_id] = learner.gather_losses(iterator)

        return all_losses

    def free_memory(self):
        """
        Free the model weights
        """
        for learner in self.learners:
            learner.free_memory()

    def free_gradients(self):
        """
        Free the model gradients
        """
        for learner in self.learners:
            learner.free_gradients()

    # The following three methods make LearnersEnsemble iterable
    def __iter__(self):
        # Iterate over the component learners (the original code returns a
        # LearnersEnsembleIterator, which is not shown in this post)
        return iter(self.learners)

    def __len__(self):
        return len(self.learners)

    def __getitem__(self, idx):
        return self.learners[idx]
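
fit_batch() and fit_epochs() both report client_updates as "flattened parameters after training minus flattened parameters before", one row per learner. This works because get_param_tensor() builds its result with torch.cat, which allocates a new tensor, so old_params is a snapshot that the optimizer's in-place updates do not overwrite. The sketch below mimics that bookkeeping with nested Python lists instead of tensors; the parameter shapes and the toy "training step" (subtracting 0.5 from every weight) are assumptions for illustration only.

```python
def flatten_params(params):
    """Concatenate (possibly nested) parameter lists into one flat list,
    analogous to torch.cat([p.view(-1) for p in model.parameters()])."""
    flat = []
    for p in params:
        if isinstance(p, list):
            flat.extend(flatten_params(p))
        else:
            flat.append(p)
    return flat


def compute_client_updates(learners_params, train_step):
    """Return one row per learner: flattened params after train_step minus before."""
    updates = []
    for params in learners_params:
        old = flatten_params(params)  # a new list, i.e. a snapshot of the old values
        new = flatten_params(train_step(params))
        updates.append([n - o for n, o in zip(new, old)])
    return updates


def toy_train_step(params):
    """Hypothetical 'optimizer step': subtract 0.5 from every weight, keeping the nesting."""
    return [toy_train_step(p) if isinstance(p, list) else p - 0.5 for p in params]


# Two learners, each with a 2x2 "weight matrix" and a length-2 "bias"
learners = [[[[1.0, 2.0], [3.0, 4.0]], [0.0, 0.0]],
            [[[0.5, 0.5], [0.5, 0.5]], [1.0, -1.0]]]
updates = compute_client_updates(learners, toy_train_step)
print(len(updates), len(updates[0]))  # 2 6, i.e. shape (n_learners, model_dim)
print(updates[0])                     # [-0.5, -0.5, -0.5, -0.5, -0.5, -0.5]
```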

3. Learner

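Learner is a thin wrapper around a concrete model such as a CNN or RNN and implements the training interface. Its attribute Learner.model is the concrete model object, an instance of a model class such as the following: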

import torch.nn as nn
import torch.nn.functional as F


class CIFAR10CNN(nn.Module):
    def __init__(self, num_classes):
        super(CIFAR10CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 5)        # 3x32x32 -> 32x28x28
        self.pool = nn.MaxPool2d(2, 2)          # halves the spatial size
        self.conv2 = nn.Conv2d(32, 64, 5)       # 32x14x14 -> 64x10x10
        self.fc1 = nn.Linear(64 * 5 * 5, 2048)  # input is 64x5x5 after the second pooling
        self.output = nn.Linear(2048, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = self.output(x)
        return x  # raw logits (no softmax)
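
The 64 * 5 * 5 in fc1 follows from the 3x32x32 CIFAR10 input: a 5x5 convolution with stride 1 and no padding shrinks each side by 4, and 2x2 max-pooling with stride 2 halves it. A small helper based on the standard output-size formula floor((n + 2p - k) / s) + 1 reproduces the chain 32 -> 28 -> 14 -> 10 -> 5:

```python
def conv_out(n, kernel, stride=1, padding=0):
    """Output side length of a conv or pooling layer: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * padding - kernel) // stride + 1


side = 32                            # CIFAR10 images are 32x32
side = conv_out(side, 5)             # conv1: 32 -> 28
side = conv_out(side, 2, stride=2)   # pool:  28 -> 14
side = conv_out(side, 5)             # conv2: 14 -> 10
side = conv_out(side, 2, stride=2)   # pool:  10 -> 5
print(side, 64 * side * side)        # 5 1600
```

This is also why x.view(-1, 64 * 5 * 5) in forward() assumes 32x32 inputs; a different image size would require recomputing the flattened dimension.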

Its core methods are as follows:

[Figure: core methods of the Learner class]

Its implementation is as follows:


import torch


class Learner:
    """
    Trains and evaluates a single (deep) learner

    Attributes
    ----------
    model (nn.Module): the model trained by the learner
    criterion (torch.nn.modules.loss): loss function used to train `model`; here we set reduction="none",
        so the loss of a batch is returned as a vector instead of being summed/averaged
    metric (fn): function computing the evaluation metric; takes two vectors and returns a scalar
    device (str or torch.device):
    optimizer (torch.optim.Optimizer):
    lr_scheduler (torch.optim.lr_scheduler):
    is_binary_classification (bool): whether to convert labels to float; must be True when `BCELoss`
        (binary cross-entropy loss) is used

    Methods
    ------
    optimizer_step: perform one optimization step; gradients must already be computed
    fit_batch: perform one optimization step on a single batch
    fit_epoch: make one pass over all samples from iterator, performing a series of batch updates
    fit_epochs: make several passes over the training set provided by iterator
    gather_losses: collect the losses of all samples from iterator and return them concatenated
    get_param_tensor: get the parameters of `model` as one flattened tensor
    free_memory: free the model weights
    free_gradients: free the model gradients
    """

    def __init__(
            self, model,
            criterion,
            metric,
            device,
            optimizer,
            lr_scheduler=None,
            is_binary_classification=False
    ):

        self.model = model.to(device)
        self.criterion = criterion.to(device)
        self.metric = metric
        self.device = device
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.is_binary_classification = is_binary_classification

        # Total number of model parameters (length of the flattened parameter vector)
        self.model_dim = int(self.get_param_tensor().shape[0])

    def optimizer_step(self):
        """
        Perform one optimization step; gradients must already have been computed by
        backpropagation (i.e. loss.backward() has been called)
        """
        self.optimizer.step()
        if self.lr_scheduler:
            self.lr_scheduler.step()

    def fit_batch(self, batch, weights=None):
        """
        Perform one optimization step on one batch of samples from `iterator`
        :param batch (tuple (x, y, indices)):
        :param weights (tensor): weight of each sample, may be None
        :return: loss.detach(), metric.detach() (on the training data)
        """
        self.model.train()

        x, y, indices = batch
        x = x.to(self.device).type(torch.float32)
        y = y.to(self.device)

        if self.is_binary_classification:
            y = y.type(torch.float32).unsqueeze(1)

        self.optimizer.zero_grad()

        y_pred = self.model(x)
        loss_vec = self.criterion(y_pred, y)
        metric = self.metric(y_pred, y) / len(y)

        if weights is not None:
            # weights holds one weight per training sample; weights[indices] picks this batch's
            weights = weights.to(self.device)
            loss = (loss_vec.T @ weights[indices]) / loss_vec.size(0)
        else:
            loss = loss_vec.mean()

        loss.backward()

        self.optimizer.step()
        if self.lr_scheduler:
            self.lr_scheduler.step()

        return loss.detach(), metric.detach()

    def fit_epoch(self, iterator, weights=None):
        """
        Make one pass over all batches from `iterator`, performing an optimization step on each
        :param iterator (torch.utils.data.DataLoader):
        :param weights (torch.tensor): vector holding the weight of each sample, may be None
        :return: loss.detach(), metric.detach() (on the training data)
        """
        self.model.train()

        global_loss = 0.
        global_metric = 0.
        n_samples = 0

        for x, y, indices in iterator:
            x = x.to(self.device).type(torch.float32)
            y = y.to(self.device)

            n_samples += y.size(0)

            if self.is_binary_classification:
                y = y.type(torch.float32).unsqueeze(1)

            self.optimizer.zero_grad()

            y_pred = self.model(x)

            loss_vec = self.criterion(y_pred, y)
            if weights is not None:
                weights = weights.to(self.device)
                loss = (loss_vec.T @ weights[indices]) / loss_vec.size(0)
            else:
                loss = loss_vec.mean()
            loss.backward()

            self.optimizer.step()

            global_loss += loss.detach() * loss_vec.size(0)
            global_metric += self.metric(y_pred, y).detach()

        return global_loss / n_samples, global_metric / n_samples

    def gather_losses(self, iterator):
        """
        Compute the losses of all samples from iterator and assemble them into all_losses
        :param iterator (torch.utils.data.DataLoader):
        :return: tensor of the losses of all samples in iterator.dataset
        """
        self.model.eval()
        n_samples = len(iterator.dataset)
        all_losses = torch.zeros(n_samples, device=self.device)

        with torch.no_grad():
            for (x, y, indices) in iterator:
                x = x.to(self.device).type(torch.float32)
                y = y.to(self.device)

                if self.is_binary_classification:
                    y = y.type(torch.float32).unsqueeze(1)

                y_pred = self.model(x)
                # indices are the dataset positions of the samples, so each loss lands in its own slot
                all_losses[indices] = self.criterion(y_pred, y).squeeze()

        return all_losses

    def fit_epochs(self, iterator, n_epochs, weights=None):
        """
        Train for n_epochs epochs
        :param iterator (torch.utils.data.DataLoader):
        :param n_epochs (int):
        :param weights: vector holding the weight of each sample, may be None
        :return: None
        """
        for _ in range(n_epochs):
            self.fit_epoch(iterator, weights)

            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

    def get_param_tensor(self):
        """
        Return all model parameters as one flattened 1-D tensor
        :return: torch.tensor
        """
        param_list = []

        for param in self.model.parameters():
            param_list.append(param.data.view(-1, ))

        return torch.cat(param_list)

    def get_grad_tensor(self):
        """
        Return the gradients of all parameters of `model` as one flattened 1-D tensor
        :return: torch.tensor
        """
        grad_list = []

        for param in self.model.parameters():
            if param.grad is not None:
                grad_list.append(param.grad.data.view(-1, ))

        return torch.cat(grad_list)

    def free_memory(self):
        """
        Free the model weights
        """
        del self.optimizer
        del self.model

    def free_gradients(self):
        """
        Free the model gradients
        """
        self.optimizer.zero_grad(set_to_none=True)
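
The weighted loss in fit_batch() and fit_epoch(), (loss_vec.T @ weights[indices]) / loss_vec.size(0), relies on each batch carrying the dataset indices of its samples: weights is this learner's row of samples_weights (one entry per training sample), and weights[indices] picks out exactly the entries that belong to the batch, even when the DataLoader shuffles. The plain-Python sketch below computes the same quantity; weighted_batch_loss is a hypothetical helper name and the values are made up.

```python
def weighted_batch_loss(loss_vec, weights, indices):
    """sum_j loss_vec[j] * weights[indices[j]], divided by the batch size.

    loss_vec: per-sample losses of the batch (criterion with reduction="none")
    weights: one weight per sample of the whole training set
    indices: dataset index of each sample in the batch
    """
    batch_weights = [weights[i] for i in indices]  # analogous to weights[indices]
    return sum(l * w for l, w in zip(loss_vec, batch_weights)) / len(loss_vec)


weights = [0.1, 0.2, 0.3, 0.4, 0.5]   # one learner's row of samples_weights (5 training samples)
indices = [4, 0, 2]                   # a shuffled batch of 3 samples
loss_vec = [1.0, 2.0, 3.0]
print(weighted_batch_loss(loss_vec, weights, indices))  # (0.5 + 0.2 + 0.9) / 3
```

With all weights equal to 1 this reduces to loss_vec.mean(), the unweighted branch; with the initial samples_weights of 1 / n_learners it equals that mean scaled by 1 / n_learners.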

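4. Comparison of Client, Learners_ensemble and Learner

The architecture diagram comparing the three is shown below:
[Figure: call relationships between the methods of Client, Learners_ensemble and Learner]

Here an arrow from method \(A\) to method \(B\) means that \(A\) calls \(B\).

As we can see, the functions client.step() and client.write_logs() that we called in the previous post, "Model Aggregation in Distributed Machine Learning" (link: //www.cnblogs.com/orion-orion/p/15635803.html), actually wrap a great deal of implementation underneath.

Note that gradient computation and parameter updates are ultimately carried out in the Learner class, whereas model evaluation is done directly in the LearnersEnsemble class, so Learner does not need a separate evaluation method.

5. Model Testing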

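We tested the model proposed in the paper on the CIFAR10 dataset, and the results look good. Training runs for 200 iterations by default; after only 9 iterations we already reach Train Acc: 73.587% and Test Acc: 70.577%. This is still short of the 78.1% reported in the paper, but given this trend it seems plausible that a full run would reach that accuracy, which suggests the paper's reported results are largely credible.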

==> Clients initialization..
===> Building data iterators..

100%|██████████| 80/80 [00:02<00:00, 31.41it/s]
===> Initializing clients..

100%|██████████| 80/80 [00:22<00:00,  3.52it/s]
==> Test Clients initialization..
===> Building data iterators..

0it [00:00, ?it/s]
===> Initializing clients..

0it [00:00, ?it/s]
++++++++++++++++++++++++++++++
Global..
Train Loss: 2.299 | Train Acc: 10.643% |Test Loss: 2.298 | Test Acc: 10.503% |
++++++++++++++++++++++++++++++++++++++++++++++++++
################################################################################
Training..

  2%|▏         | 4/200 [04:29<3:39:15, 67.12s/it]
++++++++++++++++++++++++++++++
Global..
Train Loss: 1.003 | Train Acc: 65.321% |Test Loss: 1.036 | Test Acc: 63.872% |
++++++++++++++++++++++++++++++++++++++++++++++++++
################################################################################

  4%|▍         | 9/200 [10:20<3:35:23, 67.66s/it]
++++++++++++++++++++++++++++++
Global..
Train Loss: 0.754 | Train Acc: 73.587% |Test Loss: 0.835 | Test Acc: 70.577% |
++++++++++++++++++++++++++++++++++++++++++++++++++

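For reference, below are the datasets used in the paper with their corresponding models, and the accuracy the paper reports on each of these datasets.
[Figure: datasets and the corresponding models used in the paper]
[Figure: accuracy reported in the paper on each dataset]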

References

  • [1] Marfoq O, Neglia G, Bellet A, et al. Federated multi-task learning under a mixture of distributions[J]. Advances in Neural Information Processing Systems, 2021, 34.
  • [2] //github.com/omarfoq/FedEM