多次回歸分析及推導
- 2019 年 10 月 3 日
- 筆記
多次回歸分析
在線性回歸分析的時候,我用了一條直線去擬合年齡和工資的數據,結果不是太貼合的。我們嘗試先用多次方程組來擬合數據。
我們先把數據讀出出來。
import tensorflow as tf import pandas as pd import numpy as np from matplotlib import pyplot as plt
unrate = pd.read_csv('SD.csv') unrate = unrate.sort_values('Year') print(unrate)
Year Salary 0 1.0 39451 30 1.1 40343 1 1.2 46313 31 1.3 47605 2 1.4 37839 .. ... ... 85 12.0 106247 86 12.5 117634 87 12.6 113300 88 13.3 123056 89 13.5 122537 [90 rows x 2 columns]
這次我們用一個二次方程來擬合一下這些數據。
方程我們定義為如下:
[ hat(y_i)=W_1*x_i^2 + W_2*x_i+b]
那麼這樣的話,我們就有三個參數 W_1, W_2, b。我們先給這三個參數一個初始數值。
w_1 = 1000 w_2 =1000 b = 1000 print(w_1) print(w_2) print(b) y_pred = w_1* np.power(unrate['Year'],2) + w_2* unrate['Year'] + b plt.scatter(unrate['Year'],unrate['Salary']) plt.plot(unrate['Year'],y_pred) plt.show()
1000 1000 1000
我們如果按照上述的模型,求出預測值(hat{y}),我們需要一個函數來評估這個值的好壞。
[loss=sum_{i=0}^{n} (y_i -hat{y}_i)^2]
這個函數和一次的一樣,沒有任何變化。接下來,我們需要求出這個函數的導函數。
[frac{dl}{dw_1} = frac{dl}{dhat{y}}*frac{dhat{y}}{dw_1} =-2sum_{i=0}^{n}(y_i-hat{y}_i)*x_i^2 ]
[ frac{dl}{dw_2} = frac{dl}{dhat{y}}*frac{dhat{y}}{dw_2}=-2sum_{i=0}^{n}(y_i-hat{y}_i)*x_i ]
[ frac{dl}{db}=frac{dl}{dhat{y}}*frac{dhat{y}}{db}=-2sum_{i=0}^{n}(y_i-hat{y}_i) ]
我們來把上述的函數代碼化
def train(w_1,w_2, b): learning_rate = 0.000001 y_pred = w_1* np.power(unrate['Year'],2) + w_2* unrate['Year'] + b dw_1 = -2*np.sum( np.transpose(unrate['Salary'] - y_pred)*np.power(unrate['Year'],2)) dw_2 = -2*np.sum( np.transpose(unrate['Salary'] - y_pred)*unrate['Year']) db = -2*np.sum((unrate['Salary'] - y_pred)) temp_w_1 = w_1 - learning_rate * dw_1 temp_w_2 = w_2 - learning_rate * dw_2 temp_b = b - learning_rate * db w_1 = temp_w_1 w_2= temp_w_2 b = temp_b return w_1,w_2,b
我們來運行下測試下效果:
for i in range(10000): w_1, w_2, b = train(w_1,w_2,b) print(w_1) print(w_2) print(b) y_pred = w_1 * np.power(unrate['Year'],2) + w_2 * unrate['Year'] + b loss = np.power((y_pred-unrate['Salary']),2).sum() plt.scatter(unrate['Year'],unrate['Salary']) plt.plot(unrate['Year'],y_pred)
-695.3117280326662 17380.592541992835 8744.131370136933 8487947406.30475
上面就是我們擬合出來的效果。
我們可以看出來,比我們之前一次的擬合的數據要好很多。