徒手擼演算法 | 線性回歸
- 2019 年 12 月 24 日
- 筆記
01
線性回歸模型

其中,x1, x2, …, xp是預測變數,x0恆為1;θ1, θ2, …, θp是預測變數係數(權重或參數),θ0是偏置項(截距),令θ=(θ0, θ1, θ2, …, θp)T.
02
損失函數
回歸問題常用的損失函數是平方損失

通常用所有樣本誤差的均值衡量模型預測的效果

最優參數是使平均損失最小的參數

對於線性回歸來說,可以求解出參數的解析解;但大多數機器學習模型沒有解析解,所以一般採用優化演算法通過多次迭代模型參數使損失函數降低來求其近似解,即為數值解。
03
優化演算法
在求數值解的優化演算法中,梯度下降法被廣泛使用。梯度下降法一般有三種,第一種是批量梯度下降,每一次迭代使用所有樣本來進行梯度更新,因為每次迭代使用所有樣本,所以批量梯度下降法每次迭代都更為準確地朝著極值方向移動,但也因為每次迭代使用的樣本最多,所以影響了訓練的速度;第二種是隨機梯度下降,每一次迭代使用一個樣本來進行梯度更新,所以訓練速度更快,但準確性下降了;第三種是小批量梯度下降,每一次迭代使用部分樣本來進行梯度更新,是上面兩種方法的折中。下面使用小批量梯度下降優化線性回歸參數。
小批量梯度下降法:
令每次迭代的小批量樣本為B,|B|為小批量樣本的大小,則小批量樣本上的平均損失為

參數θ的梯度,即偏導數為

其中,

則小批量梯度下降法下參數θ的迭代公式為

,其中η為學習率
即為

04
演算法實現
1)構造數據
使用真實係數θ_true構造數據集y=Xθ_true+ϵ,其中雜訊項ϵ服從均值為0、標準差為0.01的正態分布。
import numpy as np import random # 設置預測變數個數 num_features = 3 # 設置樣本個數 num_samples = 10000 # 創建預測變數,服從均值為0、標準差為1的正態分布 features = np.random.normal(scale=1, size=(num_samples, num_features)) # 加入恆為1的x0變數 inputs = np.concatenate((np.ones((num_samples, 1)), features), axis=1) # 設置真實參數 θ_true = np.array([[3.2, 2, -3.4, 1.7]]).T # 計算 y=Xθ_true labels = np.dot(inputs, θ_true) # 在y中加入雜訊ϵ, labels += np.random.normal(scale=0.01, size=labels.shape)
2)設置數據集上小批量樣本迭代函數
def batch_iter(X, y, batch_size): num_samples = len(X) indices = list(range(num_samples)) random.shuffle(indices) X_iter = [] y_iter = [] for i in range(0, num_samples, batch_size): j = np.array(indices[i: min(i + batch_size, num_samples)]) X_iter.append(X[j]) y_iter.append(y[j]) return zip(X_iter, y_iter)
3)定義線性回歸模型
def linreg(X, θ): return np.dot(X, θ)
4)初始化模型參數c
# 將係數初始化成均值為0、標準差為0.01的正態隨機數 θ = np.random.normal(scale=0.01, size=(num_features+1, 1))
5)訓練模型
# 設置學習率 η = 0.05 # 設置迭代次數,每一次迭代都會通過小批量方式遍歷所有樣本 num_epochs = 5 # 設置小批量樣本大小 batch_size = 100 for epoch in range(num_epochs): # 通過小批量樣本上的梯度更新參數,X和y分別是小批量樣本的變數和標籤 for X, y in batch_iter(inputs, labels, batch_size): # 計算y的估計值y_hat y_hat = linreg(X, θ) # 計算參數θ的梯度,這裡應用矩陣乘法實現並行運算 grad_θ = np.dot(X.T, y_hat-y)/(num_features+1) # 更新參數θ θ -= η*grad_θ
查看一下θ的估計值,和真實值θ_true=(3.2, 2, -3.4, 1.7)T已經非常接近了。
print(θ) [[ 3.19945356] [ 2.00027749] [-3.39981785] [ 1.69967308]]