­

使用MATLAB快速搭建神經網路實現分類任務(模式識別)

  • 2020 年 5 月 4 日
  • AI

使用神經網路能執行幾種典型的任務:聚類、擬合、分類(模式識別)以及時間序列預測。

其中分類任務可以說是最常應用的場景之一,在之前的文章里也使用了分類任務作為案例對神經網路進行了入門講解。

時常遇到想要使用神經網路快速地實現分類的同學。

今天就講一講怎麼用MATLAB快速地完成吧。

1.準備數據

這裡使用MNIST數據集作為案例。

MNIST是一個很有名的手寫數字識別數據集。對於每張照片,都是以一個28*28的矩陣存儲的,數據「展平」之後是一個長度為784的一維數據。

MNIST數據集數字0~9示例

數據分為四組:

TRAIN_images:訓練集輸入數據,維度為60000*784。其中60000代表訓練集共有60000組數據(批次數bench)。

TRAIN_labels:訓練集的標籤數據,維度為60000*10,其中的10代表的數字0~9共有10種類型。比如數字0使用[1 0 0 0 0 0 0 0 0 0]代表,數字1使用[0 1 0 0 0 0 0 0 0 0]代表,以此類推。

TEST_images:測試集輸入數據,維度為10000*784

TEST_labels:訓練集的標籤數據,維度為10000*10

其中訓練集是用於訓練神經網路,測試集用於驗證神經網路分類的優劣度(正確率)。

需要注意的是,有的時候標籤值不是以矩陣的形式表示的,而是以0~9這樣的類別數字表示的。這時候需要對標籤類型進行轉換,部落客寫了這樣一個函數:

function label = class2label(class)
% 將class轉為label
% 例如將[3,2,1,1]轉換為[0,0,1;0,1,0;1,0,0;1,0,0]
% 輸入:
% class 為類別數據,一維數據
% 輸出:
% label 為多維矩陣,一行為一個標籤,共n行,即n個標籤
% 示例:label = class2label([3,2,1,1])

同樣從矩陣形式轉變為數字類別形式也有,這兩個函數時常會用到:

function class = label2class(label)
% 將label轉為class
% 例如將[0,0,1;0,1,0;1,0,0;1,0,0]轉換為[3,2,1,1]
% 輸入:
% label 為多維矩陣,一行為一個標籤,共n行,即n個標籤(一定要注意輸入的label的維度正確)
% 輸出:
% class 為分類結果,行向量,一維數據
% 示例:class = label2class([0,0,1;0,1,0;1,0,0;1,0,0])

2.初始化神經網路

patternnet是MATLAB內置的用於對目標類別進行分類的神經網路。使用patternnet時的標籤值必須是矩陣形式的。該函數的使用方法比較簡單,如下圖所示:

% 2.初始化神經網路
hiddenSizes = 120;                   %隱藏層數
net = patternnet(hiddenSizes);       %初始化模式識別神經網路
view(net)                            %查看神經網路結構

此時神經網路的結構如下圖,可以看出隱藏層的激活函數默認的是tansig,該激活函數是可以更換的。

模式識別神經網路結構

3.訓練神經網路

MATLAB使用train函數訓練淺層神經網路。訓練神經網路的指令也可以說是很簡單了,不過需要注意的是訓練數據的行列方向。

訓練集輸入數據是R*Q的矩陣,R為特徵維度,Q為批次數,訓練集標籤數據是U*Q的矩陣,U為標籤種類數,Q為批次數。如果不是這樣的方向則需要對數據進行轉置。

如下述語句對數據就進行了轉置。

% 3.訓練神經網路
net = train(net,TRAIN_images',TRAIN_labels'); %訓練網路

4.使用測試集進行分類

這步中使用了測試集進行了分類測試,並與真實的分類值進行對比,得到分類的正確率,程式碼如下:

% 4.使用測試集進行分類
testLen = 10000;                            %測試集長度
val = sim(net,TEST_images');                %計算測試集分類結果
classes = vec2ind(val);                     %將分類結果轉換為class
r = sum(classes == label2class(TEST_labels'))/(testLen);  %計算正確率
disp(['模式識別的正確率為',num2str(r)])                   %列印結果

分類的結果為:

96.69%算是差強人意的正確率,通過優化還可以將正確率進一步提升。例如在第二部分的程式碼後加入這樣兩行程式碼:

net.layers{1}.transferFcn = 'logsig';  %將激活函數改為sigmod
net.trainFcn = 'traincgf';             %將訓練函數換為traincgf

關於激活函數和訓練函數更多的說明可以看這裡:神經網路「分類」工具使用手冊

此時分類結果為:

正確率有所提升,同學們可以再優化隱藏層數、激活函數、訓練函數以及net的更多屬性獲取更好的分類效果。

5.封裝

上述流程可以封裝為一個函數文件,通過該函數文件可以快速實現分類任務的神經網路訓練以及分類性能測試。

而你需要做的只是把數據準備好即可。

封裝好的函數的說明如下:

function [net,r] = fastPatternnet(dTrain,dTrainLabel,dTest,dTestLabel,hiddenSizes,auto,set)
% 快速模式識別(分類)神經網路,可以自主設定訓練集比例,並得到測試集分類正確率
% 輸入:
% dTrain:神經網路輸入的訓練集,R*Q的矩陣,R為特徵維度,Q為批次數,輸入該變數時一定要注意行列方向是否正確
% dTrainLabel:神經網路的標籤值,U*Q的矩陣,U為標籤種類數,Q為批次數,輸入該變數時一定要注意行列方向是否正確
% dTest:神經網路輸入的測試集,R*Q的矩陣,R為特徵維度,Q為批次數,輸入該變數時一定要注意行列方向是否正確
% dTestLabel:神經網路輸入的測試集,R*Q的矩陣,R為特徵維度,Q為批次數,輸入該變數時一定要注意行列方向是否正確
% hiddenSizes:神經網路隱藏層數
% auto:是否進行自動糾錯,'on'為是,否則為否。開啟自動糾錯後會智慧調整訓練、測試集的行列方向。
% set:網路的額外設置,具體設置見:神經網路「分類」工具使用手冊 | 工具箱文檔
% 輸出:
% net:訓練完成的神經網路
% r:使用測試集得到的模式識別正確率,1對應100%正確

獲取更多資訊歡迎關注我的公眾號「括弧的城堡」,公眾號里可能還會有更多有趣的東西分享。

後續可能還會補充聚類、擬合和時間序列預測的MATLAB快速實現方法~

分類無處不在