可視化PLA
- 2019 年 10 月 6 日
- 筆記

可視化PLA
0.說在前面
1.實現
2.作者的話
0.說在前面
之前Perceptron Learning Algorithm這篇文章詳細講了感知機PLA算法。
前兩天買了本統計學習方法,今天早上看了兩章,其中第二章就是這個PLA,跟李老師的課程講的基本一致,本節主要通過python實現這個感知機算法,並通過matlibplot可視化圖形,以及終端打印出下圖結果!

書上2.1圖
1.實現
原理請參考網上教程,或者我在前言的文章,再或者統計學習方法書上算法。
【導包】
分別用於矩陣,表格數據打印,數據可視化。
import numpy as np from prettytable import PrettyTable from matplotlib import pyplot as plt
【初始化】
# 原始數據 data = [[3, 3], [4, 3], [1, 1]] # shape=(3,2) X = np.array(data) print(X) # shape=(3,1) y = np.array([1, 1, -1]) # 設a=1,b=0,w為shape=(2,1) a=1 # 初始化為0 w=np.zeros((2,1)) print(w) b=0 # 設定循環 flag = True length = len(X) print(length) j = 1 # 誤分類列表 errorpoint_list=[0] # 權重列表 w_list=[0] # 偏值列表 b_list=[0] # 函數表達式列表 wb_list=[0]
【算法實現】
while flag: count = 0 print("第" + str(j) + "次糾正") for i in range(length): # 取出x的坐標點,shape=(1,2)x(2,1)=(1,1) # w*x+b運算 wb = int(np.dot(X[i,:], w) + b) # 尋找錯誤點 if wb * y[i] <= 0: w += (y[i]*a*X[i,:]).reshape(w.shape) b += a*y[i] count += 1 print("x"+str(i+1)+"為誤分類點") errorpoint_list.append((i+1)) print(w) w_list.append((int(w[0][0]),int(w[1][0]))) print(b) b_list.append(b) wb_function = str(int(w[0][0]))+"*x1+"+str(int(w[1][0]))+"*x2+("+str((b))+")" print(wb_function) wb_list.append(wb_function) break if count == 0: flag = False j+=1 # 最後被break掉的數據添加到各自列表 errorpoint_list.append(0) w_list.append((int(w[0][0]),int(w[1][0]))) b_list.append(b) wb_function = str(int(w[0][0]))+"*x1+"+str(int(w[1][0]))+"*x2+("+str((b))+")" wb_list.append(wb_function)
【可視化表】
# 可視化表2.1 pt = PrettyTable() pt.add_column("迭代次數",np.linspace(0,8,9,dtype=int)) pt.add_column("誤分類點",errorpoint_list) pt.add_column("w",w_list) pt.add_column("b",b_list) pt.add_column("w*x+b",wb_list) print(pt)

可視化表2.1圖
【最終結果可視化】
# 可視化 x = np.linspace(0, 7, 200) # 最終的函數表達式為w[0][0]*x+w[1][0]*y=0,推導後就是下面的式子 y = (-b - w[0][0] * x) / w[1][0] plt.plot(x, y, color='r') plt.scatter(X[:2, 0], X[:2, 1], color='blue', marker='o', label='Positive') plt.scatter(X[2:, 0], X[2:, 1], color='red', marker='x', label='Negative') plt.xlabel('x') plt.ylabel('y') plt.legend() plt.title('PLA') plt.savefig('pla.png', dpi=75) plt.show()

可視化結果圖
2.作者的話
最後,您如果覺得本公眾號對您有幫助,歡迎您多多支持,轉發,謝謝! 更多內容,請關注本公眾號機器學習系列!