教程 | 基于LSTM实现手写数字识别
- 2019 年 10 月 7 日
- 筆記
基于tensorflow,如何实现一个简单的循环神经网络,完成手写数字识别,附完整演示代码。
LSTM网络构建
01
基于tensorflow实现简单的LSTM网络,完成mnist手写数字数据集训练与识别。这个其中最重要的构建一个LSTM网络,tensorflow已经给我们提供相关的API, 我们只要使用相关API就可以轻松构建一个简单的LSTM网络。
首先定义输入与目标标签
# create RNN network X = tf.placeholder(shape=[None, time_steps, num_features], dtype=tf.float32) Y = tf.placeholder(shape=[None, 10], dtype=tf.float32)
其中
- None: 表示batchsize的大小或者数目
- time_steps: 网络把输出重新输入的次数
- num_features: 输入矩阵/神经元
构建LSTM单元
lstm_cell = rnn.BasicLSTMCell(num_hidden, forget_bias=1.0) outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)
其中:
lstm_cell 表示 LSTM 的单元 num_hidden : 隐藏层节点数目 forget_bias: 遗忘门中要加上的增益偏置 outputs: 网络输出 states:状态
这样我们就构建好一个LSTM循环神经网络了,它的执行过程是很魔幻的。简直是神奇!以后再说。
代码程序执行与输出
02
完整的代码演示分为如下几个部分:
- 加载数据集
- 创建LSTM网络
- 训练网络
- 执行测试
import tensorflow as tf from tensorflow.contrib import rnn import numpy as np from tensorflow.examples.tutorials.mnist import input_data print(tf.__version__) mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) num_hidden = 128 time_steps = 28 num_features = 28 num_classes = 10 batch_size = 128 # create RNN network X = tf.placeholder(shape=[None, time_steps, num_features], dtype=tf.float32) Y = tf.placeholder(shape=[None, 10], dtype=tf.float32) # Define weights weights = { 'out': tf.Variable(tf.random_normal([num_hidden, num_classes])) } biases = { 'out': tf.Variable(tf.random_normal([num_classes])) } def rnn_network(x, weights, biases): x = tf.unstack(x, time_steps, 1) lstm_cell = rnn.BasicLSTMCell(num_hidden, forget_bias=1.0) outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32) return tf.matmul(outputs[-1], weights['out']) + biases['out'] # 输入预测 logits = rnn_network(X, weights, biases) prediction = tf.nn.softmax(logits) # 定义损失函数与优化器 loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits( logits=logits, labels=Y)) optimizer = tf.train.AdamOptimizer() train_op = optimizer.minimize(loss_op) # 计算识别精度 correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1)) accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) # 开始训练 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) for step in range(1, 5001): batch_x, batch_y = mnist.train.next_batch(batch_size) # Reshape data to get 28 seq of 28 elements batch_x = batch_x.reshape((batch_size, time_steps, num_features)) # Run optimization op (backprop) sess.run(train_op, feed_dict={X: batch_x, Y: batch_y}) if step % 1000 == 0 or step == 1: # Calculate batch loss and accuracy loss, acc = sess.run([loss_op, accuracy], feed_dict={X: batch_x, Y: batch_y}) print("Step " + str(step) + ", Loss= " + "{:.4f}".format(loss) + ", Training Accuracy= " + "{:.3f}".format(acc)) print("Optimization Finished!") # 使用测试数据集测试训练号的模型, 测试128张手写数字图像 test_len = 128 test_data = mnist.test.images[:test_len].reshape((-1, time_steps, num_features)) test_label = mnist.test.labels[:test_len] print("Testing Accuracy:", sess.run(accuracy, feed_dict={X: test_data, Y: test_label}))
运行输出如下:


