教程 | 基於LSTM實現手寫數字識別

  • 2019 年 10 月 7 日
  • 筆記

基於tensorflow,如何實現一個簡單的循環神經網絡,完成手寫數字識別,附完整演示代碼。

LSTM網絡構建

01

基於tensorflow實現簡單的LSTM網絡,完成mnist手寫數字數據集訓練與識別。這個其中最重要的構建一個LSTM網絡,tensorflow已經給我們提供相關的API, 我們只要使用相關API就可以輕鬆構建一個簡單的LSTM網絡。

首先定義輸入與目標標籤

# create RNN network  X = tf.placeholder(shape=[None, time_steps, num_features], dtype=tf.float32)  Y = tf.placeholder(shape=[None, 10], dtype=tf.float32)

其中

  • None: 表示batchsize的大小或者數目
  • time_steps: 網絡把輸出重新輸入的次數
  • num_features: 輸入矩陣/神經元

構建LSTM單元

lstm_cell = rnn.BasicLSTMCell(num_hidden, forget_bias=1.0)  outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)

其中:

lstm_cell 表示 LSTM 的單元 num_hidden : 隱藏層節點數目 forget_bias: 遺忘門中要加上的增益偏置 outputs: 網絡輸出 states:狀態

這樣我們就構建好一個LSTM循環神經網絡了,它的執行過程是很魔幻的。簡直是神奇!以後再說。

代碼程序執行與輸出

02

完整的代碼演示分為如下幾個部分:

  • 加載數據集
  • 創建LSTM網絡
  • 訓練網絡
  • 執行測試
import tensorflow as tf  from tensorflow.contrib import rnn  import numpy as np      from tensorflow.examples.tutorials.mnist import input_data  print(tf.__version__)  mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)    num_hidden = 128  time_steps = 28  num_features = 28  num_classes = 10  batch_size = 128    # create RNN network  X = tf.placeholder(shape=[None, time_steps, num_features], dtype=tf.float32)  Y = tf.placeholder(shape=[None, 10], dtype=tf.float32)    # Define weights  weights = {      'out': tf.Variable(tf.random_normal([num_hidden, num_classes]))  }  biases = {      'out': tf.Variable(tf.random_normal([num_classes]))  }      def rnn_network(x, weights, biases):      x = tf.unstack(x, time_steps, 1)      lstm_cell = rnn.BasicLSTMCell(num_hidden, forget_bias=1.0)      outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)      return tf.matmul(outputs[-1], weights['out']) + biases['out']      # 輸入預測  logits = rnn_network(X, weights, biases)  prediction = tf.nn.softmax(logits)    # 定義損失函數與優化器  loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(      logits=logits, labels=Y))  optimizer = tf.train.AdamOptimizer()  train_op = optimizer.minimize(loss_op)    # 計算識別精度  correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))  accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))    # 開始訓練  with tf.Session() as sess:      sess.run(tf.global_variables_initializer())      for step in range(1, 5001):          batch_x, batch_y = mnist.train.next_batch(batch_size)          # Reshape data to get 28 seq of 28 elements          batch_x = batch_x.reshape((batch_size, time_steps, num_features))          # Run optimization op (backprop)          sess.run(train_op, feed_dict={X: batch_x, Y: batch_y})          if step % 1000 == 0 or step == 1:              # Calculate batch loss and accuracy              loss, acc = sess.run([loss_op, accuracy], feed_dict={X: batch_x,                                                                   Y: batch_y})              print("Step " + str(step) + ", Loss= " +                     "{:.4f}".format(loss) + ", Training Accuracy= " +                     "{:.3f}".format(acc))        print("Optimization Finished!")        # 使用測試數據集測試訓練號的模型, 測試128張手寫數字圖像      test_len = 128      test_data = mnist.test.images[:test_len].reshape((-1, time_steps, num_features))      test_label = mnist.test.labels[:test_len]      print("Testing Accuracy:",           sess.run(accuracy, feed_dict={X: test_data, Y: test_label}))

運行輸出如下: