深度學習入門筆記系列 ( 八 ) ——基於 tensorflow 的手寫數字的識別(進階)
- 2019 年 11 月 12 日
- 筆記

基於 tensorflow 的手寫數字的識別(進階)
本系列將分為 8 篇 。本次為第 8 篇 ,基於 tensorflow ,利用卷積神經網路 CNN 進行手寫數字識別 。
1.引言
關於 mnist 數據集的介紹和卷積神經網路的筆記在本系列文章中已有過介紹 ,有需要可見下述兩篇文章 。本系列第 5 篇曾實現利用最簡單的 BP 神經網路進行手寫數字識別 。本系列第 6 篇簡單介紹了下卷積神經網路的知識 。
基於 tensorflow 的手寫數字識別
卷積神經網路(CNN)學習筆記
2.設計的 CNN 結構
本系列第 4 講講過實戰可以大致分為 "三步走"
- 定義神經網路的結構和前向傳播的輸出結果
- 定義損失函數以及選擇反向傳播優化的演算法
- 生成會話(tf.Session) 並在訓練數據上反覆運行反向傳播優化演算法
這裡也一樣 ,當然首先是設計我們針對此實戰的卷積神經網路 ,設計一個最簡單的如下手繪 (還是那句話 ,字醜人帥 ,拒絕反駁)

上圖得到兩次卷積池化結果後 ,將結果展平為 1 維向量 ,即1 *(7*7*64),再連接到十個節點的輸出層 。
3.手動幹起來 !
首先 ,需要讀取 MNIST 數據集 ,利用 TF 框架自帶類進行下載讀取 。

接下來就是根據之前的 「三步走」 進行實踐 。實現上述的網路結構 ,並依舊選擇二次代價函數和梯度下降法 。
首先 ,定義兩個函數 ,用於初始化參數 。再定義兩個函數實現卷積核池化(只是便於模組化 ,提高可讀性)。

根據上述手繪結構圖進行編程實現該結構 。

這裡有一個 dropout 操作 ,目的是訓練過程中使一部分神經元參數不變 ,即不參與訓練 ,相當於簡化結構 ,減少過擬合 。

再在會話 Session 中執行 ,並保存好模型參數 。

測試結果(小詹在按時付費的某伺服器跑的結果)如下圖 :
