【技術分享】迭代再加權最小二乘
- 2020 年 2 月 14 日
- 筆記
本文原作者:尹迪,經授權後發佈。
1 原理
迭代再加權最小二乘(IRLS
)用於解決特定的最優化問題,這個最優化問題的目標函數如下所示:
$$arg min_{beta} sum_{i=1}^{n}|y_{i} – f_{i}(beta)|^{p}$$
這個目標函數可以通過迭代的方法求解。在每次迭代中,解決一個帶權最小二乘問題,形式如下:
$$beta ^{t+1} = argmin_{beta} sum_{i=1}^{n} w_{i}(beta^{(t)}))|y_{i} – f_{i}(beta)|^{2} = (X^{T}W^{(t)}X)^{-1}X^{T}W^{(t)}y$$
在這個公式中,$W^{(t)}$是權重對角矩陣,它的所有元素都初始化為1。每次迭代中,通過下面的公式更新。
$$W_{i}^{(t)} = |y_{i} – X_{i}beta^{(t)}|^{p-2}$$
2 源碼分析
在spark ml
中,迭代再加權最小二乘主要解決廣義線性回歸問題。下面看看實現代碼。
2.1 更新權重
// Update offsets and weights using reweightFunc val newInstances = instances.map { instance => val (newOffset, newWeight) = reweightFunc(instance, oldModel) Instance(newOffset, newWeight, instance.features) }
這裡使用reweightFunc
方法更新權重。具體的實現在廣義線性回歸的實現中。
/** * The reweight function used to update offsets and weights * at each iteration of [[IterativelyReweightedLeastSquares]]. */ val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double) = { (instance: Instance, model: WeightedLeastSquaresModel) => { val eta = model.predict(instance.features) val mu = fitted(eta) val offset = eta + (instance.label - mu) * link.deriv(mu) val weight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu)) (offset, weight) } } def fitted(eta: Double): Double = family.project(link.unlink(eta))
這裡的model.predict
利用帶權最小二乘模型預測樣本的取值,然後調用fitted
方法計算均值函數$mu$。offset
表示 更新後的標籤值,weight
表示更新後的權重。關於鏈接函數的相關計算可以參考廣義線性回歸的分析。
有一點需要說明的是,這段代碼中標籤和權重的更新並沒有參照上面的原理或者說我理解有誤。
2.2 訓練新的模型
// 使用更新過的樣本訓練新的模型 model = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0, standardizeFeatures = false, standardizeLabel = false).fit(newInstances) // 檢查是否收斂 val oldCoefficients = oldModel.coefficients val coefficients = model.coefficients BLAS.axpy(-1.0, coefficients, oldCoefficients) val maxTolOfCoefficients = oldCoefficients.toArray.reduce { (x, y) => math.max(math.abs(x), math.abs(y)) } val maxTol = math.max(maxTolOfCoefficients, math.abs(oldModel.intercept - model.intercept)) if (maxTol < tol) { converged = true }
訓練完新的模型後,重複2.1步,直到參數收斂或者到達迭代的最大次數。
3 參考文獻
【1】Iteratively reweighted least squares