卷積神經網路處理影像識別(三)

  • 2019 年 11 月 25 日
  • 筆記

本篇接著上一篇來介紹卷積神經網路的訓練(即反向傳播)和應用。

訓練神經網路和保存訓練結果的程式碼如下:

"""Train the MNIST convolutional network and save checkpoints.

Builds the training graph on top of ``CNN_MNIST_inference.inference``,
runs mini-batch gradient descent, periodically reports validation
accuracy, writes checkpoints, and finally plots the loss/accuracy curves.
"""
import os

import numpy as np
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data

import CNN_MNIST_inference

# Directory and file-name prefix of the periodic checkpoints.  Other scripts
# read these constants, so their names and values are kept as they were.
MODEL_SAVE_PATH = "E:/Python36/my tensorflow/CNN/model_path/"
MODEL_NAME = "MNIST_CNNmodel.ckpt"
print(os.path.join(MODEL_SAVE_PATH, MODEL_NAME))

BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
REGULARIZATION_RATE = 0.0001
MOVING_AVERAGE_DECAY = 0.99
TRAINING_STEPS = 20000


def _as_image_batch(flat_images):
    """Reshape flat image rows to [batch, height, width, channels]."""
    return np.reshape(flat_images,
                      [-1,
                       CNN_MNIST_inference.IMAGE_HEIGHT,
                       CNN_MNIST_inference.IMAGE_WIDTH,
                       CNN_MNIST_inference.NUM_CHANNELS])


def train(mnist):
    """Train the network on ``mnist`` and persist the result.

    Args:
        mnist: the dataset object returned by ``input_data.read_data_sets``
            with one-hot labels; its ``train``, ``validation`` and ``test``
            splits are used.

    Side effects:
        Writes checkpoints under ``MODEL_SAVE_PATH`` every 25 steps and at
        the end, writes the graph for TensorBoard, saves one extra final
        checkpoint, prints progress, and shows a matplotlib figure.
    """
    x = tf.placeholder(tf.float32,
                       [None,
                        CNN_MNIST_inference.IMAGE_HEIGHT,
                        CNN_MNIST_inference.IMAGE_WIDTH,
                        CNN_MNIST_inference.NUM_CHANNELS], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, CNN_MNIST_inference.OUTPUT_NODE],
                        name='y-input')

    # L2 regularisation.
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    y = CNN_MNIST_inference.inference(x, True, regularizer, None, reuse=False)
    global_step = tf.Variable(0, trainable=False)

    # Exponential moving average of the trainable variables.
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY,
                                                          global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    # NOTE(review): the second positional argument is True here as well; if
    # that flag switches on training-only behaviour inside inference(), the
    # accuracy below is measured with it enabled. Likewise, passing the
    # regularizer again may register its terms twice -- confirm against
    # CNN_MNIST_inference.
    average_y = CNN_MNIST_inference.inference(x, True, regularizer,
                                              variable_averages, reuse=True)

    # Loss: mean cross entropy plus whatever else is in the 'losses'
    # collection (presumably the L2 terms added by inference() -- confirm).
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    tf.add_to_collection('losses', cross_entropy_mean)
    loss = tf.add_n(tf.get_collection('losses'))

    # Learning rate decays once per epoch (staircase).
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE, global_step,
        mnist.train.num_examples / BATCH_SIZE,
        LEARNING_RATE_DECAY, staircase=True)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step)
    # One op that both updates the weights and refreshes the averages.
    train_op = tf.group(train_step, variables_averages_op)

    correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    saver = tf.train.Saver()
    checkpoint_prefix = os.path.join(MODEL_SAVE_PATH, MODEL_NAME)
    # NOTE(review): the original raw string had lost its path separators
    # ("E:Python36my tensorflowckpt filesmode_mnist.ckpt"); the separators
    # below are a best-guess reconstruction -- confirm.
    final_checkpoint = r"E:\Python36\my tensorflow\ckpt files\mode_mnist.ckpt"
    # Saver.save() does not create missing parent directories.
    os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
    os.makedirs(os.path.dirname(final_checkpoint), exist_ok=True)

    # Curves collected only for the plot at the end.
    steps = []
    accs = []
    losses = []

    with tf.Session() as sess:
        tf.global_variables_initializer().run()

        # Validation set.
        validate_feed = {x: _as_image_batch(mnist.validation.images),
                         y_: mnist.validation.labels}
        # Test set.
        test_feed = {x: _as_image_batch(mnist.test.images),
                     y_: mnist.test.labels}

        for i in range(TRAINING_STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run(
                [train_op, loss, global_step],
                feed_dict={x: _as_image_batch(xs), y_: ys})

            if i % 25 == 0:
                validate_acc = sess.run(accuracy, feed_dict=validate_feed)
                steps.append(step)
                accs.append(validate_acc * 100)
                losses.append(loss_value)
                # loss_value comes from the training batch just fed, so it
                # is reported as a training loss.
                print("After %d training steps, validation dataset accuracy "
                      "after this batch is %g%%, training loss on this batch "
                      "is %g" % (step, validate_acc * 100, loss_value))
                saver.save(sess, checkpoint_prefix, global_step=global_step)

        saver.save(sess, checkpoint_prefix, global_step=global_step)
        test_acc = sess.run(accuracy, feed_dict=test_feed)
        print("After %d training steps, test accuracy using average model is %g%%" %
              (TRAINING_STEPS, test_acc * 100))

        # Close the writer so the graph event file is flushed to disk.
        writer = tf.summary.FileWriter("E://TensorBoard//test", sess.graph)
        writer.close()

        saver.save(sess, final_checkpoint)

    # Plot only: loss on top, validation accuracy below.
    from matplotlib import pyplot as plt
    import matplotlib.ticker as mtick
    plt.subplot(211)
    plt.plot(steps, losses, color="red")
    plt.scatter(steps, losses, s=20, color="red")
    plt.xlabel("訓練的步數(Batch數)")
    plt.ylabel("訓練batch上的Loss(含L2正則Loss)")
    plt.subplot(212)
    plt.plot(steps, accs, color="green")
    plt.scatter(steps, accs, s=20, color="green")
    yticks = mtick.FormatStrFormatter("%.3f%%")
    plt.gca().yaxis.set_major_formatter(yticks)
    plt.xlabel("step")
    plt.ylabel("驗證集上的預測準確率")
    plt.show()


def main(argv=None):
    """Load MNIST with one-hot labels and run training."""
    # NOTE(review): separators restored in this raw string (the original
    # read "E:Python36my tensorflowMNIST_data") -- confirm the location.
    mnist = input_data.read_data_sets(r"E:\Python36\my tensorflow\MNIST_data",
                                      one_hot=True)
    train(mnist)


if __name__ == "__main__":
    tf.app.run()  # calls main()

下面是訓練Batch的總Loss和驗證集上的準確率的收斂趨勢圖。由於我的電腦性能不好,所以我大幅度削減了待訓練參數個數。儘管如此,2000輪訓練之後,在驗證集上5000個圖片的預測正確率已達98.3%。如若不削減參數,準確率可達99.4%。

下面的程式碼是利用訓練好的卷積神經網路模型來評估它在驗證集上的準確率(可以在正式訓練時不評估從而節省訓練時間),以及用它來識別單張圖片。

"""Evaluate the trained MNIST CNN and classify a single image.

Restores the newest checkpoint written by ``CNN_MNIST_train``, reports the
accuracy on the validation set, then displays one JPEG file and prints the
label predicted for it.
"""
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data

import CNN_MNIST_inference
import CNN_MNIST_train


def _read_image_as_float(path):
    """Return a float32 tensor holding the decoded JPEG stored at ``path``."""
    # Context manager so the file handle is closed as soon as it is read.
    with tf.gfile.FastGFile(path, "rb") as image_file:
        raw_bytes = image_file.read()
    decoded = tf.image.decode_jpeg(raw_bytes)
    if decoded.dtype != tf.float32:
        decoded = tf.image.convert_image_dtype(decoded, dtype=tf.float32)
    return decoded


def evaluate(mnist):
    """Print the accuracy of the latest checkpoint on the validation set.

    Args:
        mnist: dataset object returned by ``input_data.read_data_sets`` with
            one-hot labels; only ``mnist.validation`` is used.

    The moving-average (shadow) values stored in the checkpoint are loaded
    into the model variables. Prints a notice and returns if no checkpoint
    is found under ``CNN_MNIST_train.MODEL_SAVE_PATH``.
    """
    # Build everything in a fresh graph so it does not clash with ops that
    # other functions add to the default graph.
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32,
                           [None,
                            CNN_MNIST_inference.IMAGE_HEIGHT,
                            CNN_MNIST_inference.IMAGE_WIDTH,
                            CNN_MNIST_inference.NUM_CHANNELS], name='x-input')
        y_ = tf.placeholder(tf.float32,
                            [None, CNN_MNIST_inference.OUTPUT_NODE],
                            name='y-input')
        validation_set = np.reshape(mnist.validation.images,
                                    [-1,
                                     CNN_MNIST_inference.IMAGE_HEIGHT,
                                     CNN_MNIST_inference.IMAGE_WIDTH,
                                     CNN_MNIST_inference.NUM_CHANNELS])
        validate_feed = {x: validation_set, y_: mnist.validation.labels}

        y = CNN_MNIST_inference.inference(x, False, None, None, reuse=False)
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        # Map each variable to its moving-average name in the checkpoint.
        variable_averages = tf.train.ExponentialMovingAverage(
            CNN_MNIST_train.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        with tf.Session() as sess:
            # Find the newest model file in the checkpoint directory.
            ckpt = tf.train.get_checkpoint_state(CNN_MNIST_train.MODEL_SAVE_PATH)
            if not (ckpt and ckpt.model_checkpoint_path):
                print('No checkpoint file found')
                return

            saver.restore(sess, ckpt.model_checkpoint_path)
            # The training step is the suffix after the last '-' of the file
            # name; basename() copes with the platform's path separators.
            global_step = os.path.basename(
                ckpt.model_checkpoint_path).split("-")[-1]
            accuracy_score = sess.run(accuracy, feed_dict=validate_feed)
            # global_step is a str here, hence %s.
            print("After %s training steps, validation accuary = %g" %
                  (global_step, accuracy_score))


# Keep every input tensor and constant in the same graph as the model: a
# tensor created in the default graph cannot be fed to ops built inside a
# separate ``with tf.Graph().as_default():`` block without an error.
def recognize(input_x):
    """Print the label predicted for ``input_x``.

    Args:
        input_x: image tensor living in the default graph, already shaped
            the way ``CNN_MNIST_inference.inference`` expects.

    Prints a notice and returns if no checkpoint is found.
    """
    # input_x lives in the default graph, so the model is built there too.
    g = tf.get_default_graph()
    with g.as_default():
        y = CNN_MNIST_inference.inference(input_x, False, None, None,
                                          reuse=False)
        # NOTE(review): unlike evaluate(), this restores the variables under
        # their own names rather than their moving averages -- confirm that
        # this difference is intended.
        saver = tf.train.Saver()
        with tf.Session() as sess:
            # Find the newest model file in the checkpoint directory.
            ckpt = tf.train.get_checkpoint_state(CNN_MNIST_train.MODEL_SAVE_PATH)
            if not (ckpt and ckpt.model_checkpoint_path):
                print('No checkpoint file found')
                return

            saver.restore(sess, ckpt.model_checkpoint_path)
            predicted_label = tf.argmax(y, 1)
            print("predicted_label: ", sess.run(predicted_label)[0])


def plotImage(path):
    """Display the image at ``path`` in grey scale (plotting only)."""
    img_data = _read_image_as_float(path)
    # A tensor has to be evaluated inside a session to get a numpy array.
    with tf.Session() as sess:
        image_data = sess.run(img_data)
    # Drop the channel axis; this assumes a single-channel image.
    image_data_shaped1 = image_data.reshape(image_data.shape[0],
                                            image_data.shape[1])
    plt.imshow(image_data_shaped1, cmap='gray')
    plt.show()


def main(argv=None):
    """Evaluate on the validation set, then classify one image file."""
    # NOTE(review): the raw-string paths below had lost their separators in
    # the original ("E:Python36my tensorflowMNIST_data",
    # "E:Python36MNIST picturetest50.jpg"); the separators are a best-guess
    # reconstruction -- confirm.
    mnist = input_data.read_data_sets(r"E:\Python36\my tensorflow\MNIST_data",
                                      one_hot=True)
    evaluate(mnist)  # accuracy on the validation set

    # Input image.  The helper always returns a bound float32 tensor, so the
    # reshape below never sees an undefined name.
    image_path = r"E:\Python36\MNIST picture\test50.jpg"
    img_data = _read_image_as_float(image_path)

    # Reshape the picture to the input shape the network requires.
    input_x = tf.reshape(img_data, [1,
                                    CNN_MNIST_inference.IMAGE_HEIGHT,
                                    CNN_MNIST_inference.IMAGE_WIDTH,
                                    CNN_MNIST_inference.NUM_CHANNELS])
    plotImage(image_path)
    recognize(input_x)


if __name__ == "__main__":
    main()