分布式机器学习:模型平均MA与弹性平均EASGD(PySpark)

计算机科学一大定律:许多看似过时的东西可能过一段时间又会以新的形式再次回归。

1 模型平均方法(MA)

1.1 算法描述与实现

我们在博客《分布式机器学习:同步并行SGD算法的实现与复杂度分析(PySpark)》中介绍的SSGD算法由于通信比较频繁,在通信与计算比较大时(不同节点位于不同的地理位置),难以取得理想的加速效果。接下来我们介绍一种通信频率比较低的同步算法——模型平均方法(Model Average, MA)[1]。在MA算法中,每个工作节点会根据本地数据对本地模型进行多轮的迭代更新,直到本地模型收敛说本地迭代轮数超过一个预设的阈值,再进行一次全局的模型平均,并以此均值做为最新的全局模型继续训练,其具体流程如下:


MA算法按照通信间隔的不同,可分为下面两种情况:

  • 只在所有工作节点完成本地训练之后,做一次模型平均。这种情况下所需通信量极小,本地模型在迭代过程中没有任何交互,可以完全独立地完成并行计算,通信只在模型训练的最后发生一次。这类算法只在强凸情形下收敛率有保障,但对非凸问题不一定适用(如神经网络),因为本地模型可能落到了不同的局部凸子域,对参数的平均无法保证最终模型的性能。
  • 在本地完成一定轮数的迭代之后,就做一次模型平均,然后用这次平均的模型的结果做为接下来的训练起点,然后继续迭代,循环往复。相比只在最终做一次模型平均,中间的多次平均控制了各工作节点模型之间的差异,降低了它们落在局部凸子域的可能性,从而保证了最终的模型精度。这种方法被广发应用于很多实际的机器学习系统(如CNTK)中,此外该思想在联邦学习[2]中也应用广泛)。

该算法的PySpark实现如下:

from typing import Tuple
from sklearn.datasets import load_breast_cancer
import numpy as np
from pyspark.sql import SparkSession
from operator import add
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

n_slices = 4  # Number of Slices
n_iterations = 1500  # Number of iterations
eta = 0.1
mini_batch_fraction = 0.1 # the fraction of mini batch sample 
n_local_iterations = 1 # the number local epochs

def logistic_f(x, w):
    return 1 / (np.exp(-x.dot(w)) + 1 +1e-6)


def gradient(pt_w: Tuple):
    """ Compute linear regression gradient for a matrix of data points
    """
    idx, (point, w) = pt_w
    y = point[-1]    # point label
    x = point[:-1]   # point coordinate
    # For each point (x, y), compute gradient function, then sum these up
    return  (idx, (w, - (y - logistic_f(x, w)) * x))


def update_local_w(iter):
    iter = list(iter)
    idx, (w, _) = iter[0]
    g_mean = np.mean(np.array([ g for _, (_, g) in iter]), axis=0) 
    return  [(idx, w - eta * g_mean)]


def draw_acc_plot(accs, n_iterations):
    def ewma_smooth(accs, alpha=0.9):
        s_accs = np.zeros(n_iterations)
        for idx, acc in enumerate(accs):
            if idx == 0:
                s_accs[idx] = acc
            else:
                s_accs[idx] = alpha * s_accs[idx-1] + (1 - alpha) * acc
        return s_accs

    s_accs = ewma_smooth(accs, alpha=0.9)
    plt.plot(np.arange(1, n_iterations + 1), accs, color="C0", alpha=0.3)
    plt.plot(np.arange(1, n_iterations + 1), s_accs, color="C0")
    plt.title(label="Accuracy on test dataset")
    plt.xlabel("Round")
    plt.ylabel("Accuracy")
    plt.savefig("ma_acc_plot.png")


if __name__ == "__main__":

    X, y = load_breast_cancer(return_X_y=True)

    D = X.shape[1]

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=0, shuffle=True)
    n_train, n_test = X_train.shape[0], X_test.shape[0]

    spark = SparkSession\
        .builder\
        .appName("Model Average")\
        .getOrCreate()

    matrix = np.concatenate(
        [X_train, np.ones((n_train, 1)), y_train.reshape(-1, 1)], axis=1)

    points = spark.sparkContext.parallelize(matrix, n_slices).cache()
    points = points.mapPartitionsWithIndex(lambda idx, iter: [ (idx, arr) for arr in iter])

    ws = spark.sparkContext.parallelize(2 * np.random.ranf(size=(n_slices, D + 1)) - 1, n_slices).cache()
    ws = ws.mapPartitionsWithIndex(lambda idx, iter: [(idx, next(iter))])

    w = 2 * np.random.ranf(size=D + 1) - 1
    print("Initial w: " + str(w))
    
    accs = []
    for t in range(n_iterations):
        print("On iteration %d" % (t + 1))
        w_br = spark.sparkContext.broadcast(w)
        ws = ws.mapPartitions(lambda iter: [(iter[0][0], w_br.value)])
                            
        for local_t in range(n_local_iterations):
            ws = points.sample(False, mini_batch_fraction, 42 + t)\
                .join(ws, numPartitions=n_slices)\
                    .map(lambda pt_w: gradient(pt_w))\
                        .mapPartitions(update_local_w) 
            
        par_w_sum = ws.mapPartitions(lambda iter: [iter[0][1]]).treeAggregate(0.0, add, add)           
  
        w  = par_w_sum / n_slices 

        y_pred = logistic_f(np.concatenate(
            [X_test, np.ones((n_test, 1))], axis=1), w)
        pred_label = np.where(y_pred < 0.5, 0, 1)
        acc = accuracy_score(y_test, pred_label)
        accs.append(acc)
        print("iterations: %d, accuracy: %f" % (t, acc))

    print("Final w: %s " % w)
    print("Final acc: %f" % acc)

    spark.stop()

    draw_acc_plot(accs, n_iterations)

1.2 算法收敛表现

算法初始化权重如下:

Initial w: [-4.59895046e-01  4.81609930e-01 -2.98562178e-01  4.37876789e-02
 -9.12956525e-01  6.72295704e-01  6.02029280e-01 -4.01078397e-01
  9.08559315e-02 -1.07924749e-01  4.64202010e-01 -6.69343161e-01
 -7.98638952e-01  2.56715359e-01 -4.08737254e-01 -6.20120002e-01
 -8.59081121e-01  9.25086249e-01 -8.64084351e-01  6.18274961e-01
 -3.05928664e-01 -6.96321445e-01 -3.70347891e-01  8.45658259e-01
 -3.46329338e-01  9.75573025e-01 -2.37675425e-01  1.26656795e-01
 -6.79589868e-01  9.48379550e-01 -2.04796940e-04]

算法的终止权重和acc如下:

Final w: [ 3.56746113e+01  5.23773783e+01  2.10458066e+02  1.13411434e+02
  9.03684544e-01 -1.64116119e-01 -8.50221455e-01 -1.02747339e-01
  1.01249316e+00  5.86541201e-01 -7.95885004e-01  4.08388583e+00
 -3.72622203e+00 -1.22789161e+02  7.15869286e-01 -9.59608820e-01
  5.85750752e-01  9.52410298e-01 -4.74590169e-01  8.01080392e-01
  3.72102291e+01  6.54164219e+01  2.11443556e+02 -1.34266020e+02
  1.19299917e+00 -9.48551465e-01 -3.02582118e-01 -9.50371027e-01
  1.30280295e+00  5.02509358e-01  5.34567937e+00] 
Final acc: 0.923977

MA算法的在测试集上的ACC曲线如下:

我们可以尝试与SSGD算法的ACC曲线做对比(下列是SSGD算法的ACC曲线):

可以看到MA算法在将本地迭代轮数\(M\)设置为1,其它超参数和SSGD算法一样的情况下,二者收敛速率接近。

2 模型平均方法的改进——BMUF算法

2.1 算法描述与实现

在MA算法中,不论参数本地更新流程是什么策略,在聚合的时候都只是将来自各个工作节点的模型进行简单平均。如果把每次平均之间的本地更新称作一个数据块(block)的话,那么模型平均可以看做基于数据块的全局模型更新流程。我们知道,在单机优化算法中,常常会加入动量[3]以有效利用历史更新信息来减少随机梯度下降中梯度噪声的影响。类似地,我们也可以考虑在MA算法中对每次全局模型的更新引入动量的概念。一种称为块模型更新过滤(Block-wise Model Update Filtering, BMUF)[4]的算法基于数据块的动量思想对MA进行了改进,其有效性在相关文献中被验证。BMUF算法实际上是想利用全局的动量,使历史上本地迭代对全局模型更新的影响有一定的延续性,从而达到加速模型优化进程的作用。具体流程如下:

该算法的PySpark实现只需要在MA算法的基础上对参数聚合部分做如下修改即可:

mu = 0.5
zeta = 0.5
# weight update
delta_w = 2 * np.random.ranf(size=D + 1) - 1
    
for t in range(n_iterations):
    ...

    w_avg  = par_w_sum / n_slices 

    delta_w = mu * delta_w + zeta * (w_avg - w)
    w = w + delta_w

2.2 算法收敛表现

BMUF算法的终止权重和acc如下:

Final w: [ 3.55587119e+01  4.74476386e+01  2.05646360e+02  8.43207945e+01
 -6.24992160e-01  4.09293120e-01 -2.03903959e-01 -7.48743358e-01
  6.17507919e-01  1.37260407e-01  3.02077368e-01  2.29466375e+00
 -4.29517662e+00 -1.25214901e+02 -3.93164203e-01 -6.65440018e-01
 -9.50817683e-01  9.09075928e-01 -8.22611068e-01  6.22626019e-01
  3.78429147e+01  5.77291936e+01  2.04653067e+02 -1.17393706e+02
 -5.53476597e-03  2.76373535e-02 -1.98339171e+00 -3.34173754e-01
  4.17088085e-03  1.15206427e+00  4.68333285e+00] 
Final acc: 0.923977

BMUF算法的在测试集上的ACC曲线如下:

我们发现当通信间隔\(M\)设置为1时,BMUF算法的收敛速率要略快于SSGD算法和MA算法。由于利用了历史动量信息,其ACC曲线也要略为稳定一些。

3 弹性平均SGD算法(EASGD)

3.1 算法描述与实现

前面介绍的几种算法无论本地模型用什么方法更新,都会在某个时刻聚合出一个全局模型,并且用其替代本地模型。但这种处理方法对于深度学习这种有很多个局部极小点的优化问题而言,是否是最合适的选择呢?答案是不确定的。由于各个工作节点所使用的训练数据不同,本地模型训练的模型有所差别,各个工作节点实际上是在不同的搜索空间里寻找局部最优点,由于探索的方向不同,得到的模型有可能是大相径庭的(最极端的情况也就是联邦学习,不同节点间数据直接是Non-IID的)。简单的中心化聚合可能会抹杀各个工作节点自身探索的有益信息。

为了解决以上问题,研究人员提出了一种非完全一致的分布式机器学习算法,称为弹性平均SGD(简称EASGD)[5]。该方法不强求各个工作节点继承全局模型(也是后来联邦学习中个性化联邦学习的思想。如果我们定义\(w_k\)为第\(k\)个工作节点上的模型,\(\overline{w}\)为全局模型,则可将分布式优化描述为如下式子:

\[\underset{w^1, w^2, \cdots, w^k}{\min}
\sum_{k=1}^K\hat{l}_k(w_k) + \frac{\rho}{2}\lVert w_k – \overline{w}\rVert^2
\]

换言之,分布式优化有两个目标:

  • 使得各个工作节点本身的损失函数得到最小化。
  • 希望各个工作节点上的本地模型和全局模型之间的差距比较小。

按照这个优化目标,如果分别对\(w_k\)\(\overline{w}\)求导,就可以得到下列算法中的更新公式:

如果我们将EASGD与SSGD或者MA进行对比,可以看出EASGD在本地模型和服务器模型更新时都兼顾全局一致性和本地模型的独立性。具体而言,是指:

  • 当对本地模型进行更新时,在按照本地数据计算梯度的同时,也力求用全局模型来约束本地模型不要偏离太远。
  • 在对全局模型进行更新时,不直接把各个本地模型的平均值做为下一轮的全局模型,而是部分保留了历史上全局模型的参数信息。

这种弹性更新的方法,即可保持工作节点探索各自的探索方向,同时也不会让它们彼此相差太远(事实上,该思想也体现于ICML2021个性化联邦学习论文Ditto[6]中)实验表明,EASGD算法的精度和稳定性都有较好的表现。除了同步的设置,EASGD算法也有异步的版本,我们后面再进行介绍。

该算法的PySpark实现只需要在MA算法的基础上去掉用全局参数对本地参数的覆盖,并参数聚合部分和本地更新的部分修改即可:

 
beta = 0.5 # the parameter of history information
alpha = 0.1 # constraint coefficient


def update_local_w(iter, w):
    iter = list(iter)
    idx, (local_w, _) = iter[0]
    g_mean = np.mean(np.array([ g for _, (_, g) in iter]), axis=0) 
    return  [(idx, local_w - eta * g_mean - alpha * (local_w - w))]

...

if __name__ == "__main__":

    ...

    for t in range(n_iterations):
        print("On iteration %d" % (t + 1))
        w_br = spark.sparkContext.broadcast(w)
                            
        ws = points.sample(False, mini_batch_fraction, 42 + t)\
            .join(ws, numPartitions=n_slices)\
                .map(lambda pt_w: gradient(pt_w))\
                    .mapPartitions(lambda iter: update_local_w(iter, w=w_br.value)) 
            
        par_w_sum = ws.mapPartitions(lambda iter: [iter[0][1]]).treeAggregate(0.0, add, add)           
  
        w  = (1 - beta) * w + beta * par_w_sum / n_slices 

3.2 算法收敛表现

BMUF算法的终止权重和acc如下:

Final w: [ 4.75633325e+01  7.05407657e+01  2.79006876e+02  1.45465411e+02
  4.54467492e-01 -2.10142380e-01 -6.30421903e-01  4.53977048e-01
  1.01717057e-01 -2.14420411e-01 -2.94989128e-01  4.89426514e+00
 -3.05999725e+00 -1.62456459e+02  1.27772367e-01 -4.68403685e-02
 -8.63345165e-03  2.15800745e-01  5.77719463e-01 -1.83278567e-02
  5.01647916e+01  8.80774672e+01  2.79145194e+02 -1.81621547e+02
  2.14490664e-01 -8.83817758e-01 -1.43244912e+00 -5.96750910e-01
  1.04627441e+00  4.37109225e-01  6.04818129e+00] 
Final acc: 0.929825

BMUF算法的在测试集上的ACC曲线如下:

我们发现和BMUF算法类似,EASGD的算法收敛速率也要略快于SSGD算法和MA算法。而由于其弹性更新操作,其ACC曲线比上面介绍的所有算法都要稳定。

4 总结

上述介绍的都是分布式机器学习中常用的同步算法。MA相比SSGD,允许工作节点在本地进行多轮迭代(尤其适用于高通信计算比的情况),因而更加高效。但是MA算法通常会带来精度损失,实践中需要仔细调整参数设置,或者通过增加数据块粒度的动量来获取更好的效果。EASGD方法则不强求全局模型的一致性,而是为每个工作节点保持了独立的探索能力。

以上这些算法的共性是:所有的工作节点会以一定的频率进行全局同步。当工作节点的计算性能存在差异,或者某些工作节点无法正常工作(比如死机)时,分布式系统的整体运行效率不好,甚至无法完成训练任务。而这就需要异步的并行算法来解决了。

参考

  • [1]
    McDonald R, Hall K, Mann G. Distributed training strategies for the structured perceptron[C]//Human language technologies: The 2010 annual conference of the North American chapter of the association for computational linguistics. 2010: 456-464.

  • [2] McMahan B, Moore E, Ramage D, et al. Communication-efficient learning of deep networks from decentralized data[C]//Artificial intelligence and statistics. PMLR, 2017: 1273-1282.

  • [3] Sutskever I, Martens J, Dahl G, et al. On the importance of initialization and momentum in deep learning[C]//International conference on machine learning. PMLR, 2013: 1139-1147.

  • [4] Chen K, Huo Q. Scalable training of deep learning machines by incremental block training with intra-block parallel optimization and blockwise model-update filtering[C]//2016 ieee international conference on acoustics, speech and signal processing (icassp). IEEE, 2016: 5880-5884.

  • [5]
    Zhang S, Choromanska A E, LeCun Y. Deep learning with elastic averaging SGD[J]. Advances in neural information processing systems, 2015, 28.

  • [6] Li T, Hu S, Beirami A, et al. Ditto: Fair and robust federated learning through personalization[C]//International Conference on Machine Learning. PMLR, 2021: 6357-6368.

  • [7] 刘浩洋,户将等. 最优化:建模、算法与理论[M]. 高教出版社, 2020.

  • [8] 刘铁岩,陈薇等. 分布式机器学习:算法、理论与实践[M]. 机械工业出版社, 2018.

  • [9] Stanford CME 323: Distributed Algorithms and Optimization (Lecture 7)