徒手撸算法 | 线性回归
- 2019 年 12 月 24 日
- 筆記
01
线性回归模型

其中,x1, x2, …, xp是预测变量,x0恒为1;θ1, θ2, …, θp是预测变量系数(权重或参数),θ0是偏置项(截距),令θ=(θ0, θ1, θ2, …, θp)T.
02
损失函数
回归问题常用的损失函数是平方损失

通常用所有样本误差的均值衡量模型预测的效果

最优参数是使平均损失最小的参数

对于线性回归来说,可以求解出参数的解析解;但大多数机器学习模型没有解析解,所以一般采用优化算法通过多次迭代模型参数使损失函数降低来求其近似解,即为数值解。
03
优化算法
在求数值解的优化算法中,梯度下降法被广泛使用。梯度下降法一般有三种,第一种是批量梯度下降,每一次迭代使用所有样本来进行梯度更新,因为每次迭代使用所有样本,所以批量梯度下降法每次迭代都更为准确地朝着极值方向移动,但也因为每次迭代使用的样本最多,所以影响了训练的速度;第二种是随机梯度下降,每一次迭代使用一个样本来进行梯度更新,所以训练速度更快,但准确性下降了;第三种是小批量梯度下降,每一次迭代使用部分样本来进行梯度更新,是上面两种方法的折中。下面使用小批量梯度下降优化线性回归参数。
小批量梯度下降法:
令每次迭代的小批量样本为B,|B|为小批量样本的大小,则小批量样本上的平均损失为

参数θ的梯度,即偏导数为

其中,

则小批量梯度下降法下参数θ的迭代公式为

,其中η为学习率
即为

04
算法实现
1)构造数据
使用真实系数θ_true构造数据集y=Xθ_true+ϵ,其中噪声项ϵ服从均值为0、标准差为0.01的正态分布。
import numpy as np import random # 设置预测变量个数 num_features = 3 # 设置样本个数 num_samples = 10000 # 创建预测变量,服从均值为0、标准差为1的正态分布 features = np.random.normal(scale=1, size=(num_samples, num_features)) # 加入恒为1的x0变量 inputs = np.concatenate((np.ones((num_samples, 1)), features), axis=1) # 设置真实参数 θ_true = np.array([[3.2, 2, -3.4, 1.7]]).T # 计算 y=Xθ_true labels = np.dot(inputs, θ_true) # 在y中加入噪声ϵ, labels += np.random.normal(scale=0.01, size=labels.shape)
2)设置数据集上小批量样本迭代函数
def batch_iter(X, y, batch_size): num_samples = len(X) indices = list(range(num_samples)) random.shuffle(indices) X_iter = [] y_iter = [] for i in range(0, num_samples, batch_size): j = np.array(indices[i: min(i + batch_size, num_samples)]) X_iter.append(X[j]) y_iter.append(y[j]) return zip(X_iter, y_iter)
3)定义线性回归模型
def linreg(X, θ): return np.dot(X, θ)
4)初始化模型参数c
# 将系数初始化成均值为0、标准差为0.01的正态随机数 θ = np.random.normal(scale=0.01, size=(num_features+1, 1))
5)训练模型
# 设置学习率 η = 0.05 # 设置迭代次数,每一次迭代都会通过小批量方式遍历所有样本 num_epochs = 5 # 设置小批量样本大小 batch_size = 100 for epoch in range(num_epochs): # 通过小批量样本上的梯度更新参数,X和y分别是小批量样本的变量和标签 for X, y in batch_iter(inputs, labels, batch_size): # 计算y的估计值y_hat y_hat = linreg(X, θ) # 计算参数θ的梯度,这里应用矩阵乘法实现并行运算 grad_θ = np.dot(X.T, y_hat-y)/(num_features+1) # 更新参数θ θ -= η*grad_θ
查看一下θ的估计值,和真实值θ_true=(3.2, 2, -3.4, 1.7)T已经非常接近了。
print(θ) [[ 3.19945356] [ 2.00027749] [-3.39981785] [ 1.69967308]]