tensoflow serving 實戰之GAN 識別門牌號的識別服務接口

  • 2019 年 11 月 27 日
  • 筆記

TensorFlow服務

TensorFlow服務,託管模型並提供遠程訪問。TensorFlow服務有一個很好的文檔的架構和有用的教程。不幸的是,這個有點難用,你需要做較大改動來為自己的模型提供服務。

安裝測試使用請參看  安裝並測試demo

目錄:

作為一個例子,採取了一個GAN模型的半監督學習

  • 街景房屋號碼數據集上訓練半監督學習的GAN模型
  • 使用GAN鑒別器來預測房屋號碼。作為輸出,有10個對應於從0到9的數字的預測信心分數。
  • 讓TensorFlow在Docker容器中服務我的模型
  • 創建客戶端以請求數字圖像的分數

您可以在我的GitHub信息庫中找到實現細節。

主要步驟是:

對於正在使用TensorFlow創建Deep Learning模型的任何人來說,前兩個步驟非常簡單,我不想在這裡詳細介紹。但是最後兩個步驟對我來說是相當新的,我花了一些時間來了解它的工作原理和所需要的。

TensorFlow服務。它是什麼?

TensorFlow服務實現運行機器學習模型的服務器,並提供對它們的遠程訪問。常見的任務是提供數據(例如圖像)的預測和分類。

幾個技術亮點:

  • 服務器實現GRPC接口,因此您無法從瀏覽器發出請求。相反,我們需要創建一個可以通過GRPC進行通信的客戶端
  • TensorFlow服務已經為存儲為Protobuf的模型提供了操作
  • 您可以創建自己的實例來處理以其他格式存儲的模型

所以我需要將我的模型導出到Protobuf。Protobuf協議緩衝區(或Protobuf)允許高效的數據序列化。它是一個軟件的開源軟件,已經開發出來了…,對,谷歌:-)

將模型導出為Protobuf

TensorFlow服務提供SavedModelBuild類,將模型保存為Protobuf。這裡描述很好。

我的GAN模型接受一個形狀[batch_num,width,height,channels]的圖像張量,其中批次數為1,用於投放(您只能預測一個圖像在時間),寬度和高度為32像素,圖像通道數為3必須對輸入圖像進行縮放,使每個像素在[-1,1]的範圍內,而不在[0,255]的範圍內。

從另一方面,服務模式必須接受JPEG圖像作為輸入,因此為了服務,我需要注入層以將JPEG轉換為所需的圖像張量。

首先,我實現了圖像轉換。這對我來說有點棘手。

serialized_tf_example = tf.placeholder(                              tf.string,name ='input_image')  feature_configs = {'image / encoded':tf.FixedLenFeature(                                           shape = [],                                           dtype = tf.string),}  tf_example = tf.parse_example(serialized_tf_example,feature_configs)  jpegs = tf_example ['image / encoded']  images = tf.map_fn(preprocess_image,jpegs,dtype = tf.float32)  image = tf.squeeze(images,[0])  #現在圖像形狀是(1,?,?,3)

基本上,您需要一個佔位符,用於串行輸入圖像,功能配置(字典名稱到功能),您可以列出預期輸入(在我的情況下為JPEG格式的圖像/編碼)和功能類型。然後,您解析序列化示例並從中提取JPEG。最後一步是將JPEG轉換為所需的圖像張量。請參閱我的GitHub的實現細節(preprocess_image方法)。

然後我可以使用該圖像張量作為我的GAN模型的輸入,創建會話對象並加載保存的檢查點。

......  net = GAN(images,z_size,learning_rate,drop_rate = 0。)  ......  saver = tf.train.Saver()  whth tf.Session() as sess:      #從上一個檢查點恢復模型      ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)          saver.restore(sess,ckpt.model_checkpoint_path)  ......

接下來的挑戰是,如何使用提供的SavedModelBuilder將還原的模型轉換為Protobuf。

builder = tf.saved_model.builder.SavedModelBuilder(export_path)

您必須使用輸入,輸出和方法名稱(例如分類或預測)創建所謂的簽名。TensorFlow提供了一個方法tf.saved_model.utils.build_tensor_info來創建張量信息。我用它來定義輸入和輸出(在我的情況下的分數)。

predict_tensor_inputs_info =       tf.saved_model.utils.build_tensor_info(jpegs)  predict_tensor_scores_info =       tf.saved_model.utils.build_tensor_info(net.discriminator_out)

現在我準備好創建簽名。

prediction_signature =(      tf.saved_model.signature_def_utils.build_signature_def(          inputs = {'images':predict_tensor_inputs_info},          outputs = {'scores':predict_tensor_scores_info},          method_name =               tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

「圖像」「分數」是預定義的名稱,您必須在輸入和輸出字典中使用它們。

教程 TensorFlow團隊中創建兩個簽名 – 一個用於分類,一個用於預測。我不想要任何分類結果,所以預測簽名對我來說足夠了。

最後一步 – 保存模型。

legacy_init_op = tf.group(tf.tables_initializer(),                            name ='legacy_init_op')  builder.add_meta_graph_and_variables(      sess,      [tf.saved_model.tag_constants.SERVING],      signature_def_map = {'predict_images':prediction_signature},      legacy_init_op = legacy_init_op)  builder.save()

這是非常簡單的,現在你的模型存儲為Protobuf。導出文件夾的結構應該是:

  • variables.data -xxx-of-yyyvariables.index的變量文件夾
  • saved_model.pb文件

工作的第一部分完成 – 模型成功導出為Protobuf。

把它放在一起

環境

我在以下環境中開發和測試:

  • GPU供電的PC(NVidia GeForce GTX 1060 6 GB)
  • Ubuntu 16.04
  • 蟒蛇 4.3.14
  • Python 3.5
  • TensorFlow 1.1,GPU構建。注意:我有TensorFlow 1.2的問題,所以我回到以前的版本

自己試試

以下是您需要執行的步驟,以便自己嘗試。

  • 克隆來源
cd ~  git clone <a class="markup--anchor markup--pre-anchor" href="https://github.com/Vetal1977/tf_serving_example.git" target="_blank" rel="nofollow noopener" data-href="https://github.com/Vetal1977/tf_serving_example.git">https://github.com/Vetal1977/tf_serving_example.git  </a>cd tf_serving_example
  • 訓練模型
python3 svnh_semi_supervised_model_train.py

下載date 約需5-10分鐘,並測試街景房屋號碼數據集和另一個測試集合 需要20分鐘訓練模型(在我的環境中)。

  • 檢查保存模型
ls ./checkpoints

您應該看到數據,索引和元數據文件。

  • 導出模型到Protobuf由TensorFlow提供服務
python3 svnh_semi_supervised_model_saved.py --checkpoint-dir =。/ checkpoints --output_dir =。/ gan-export --model-version = 1

應打印出以下內容

成功將GAN模型版本'1'導出到'./gan-export'

如果你輸入

ls ./gan-export/1

你應該得到變量文件夾和saved_model.pb文件。

如何測試接口?

啟動接口服務

tensorflow_model_server --port=9000 --model_name=gan --model_base_path=/home/abc/Desktop/tf_serving_example-master/gan-export

首先從二進制文件中恢復為圖像文件

python3 svnh_semi_supervised_model_save_test_images.py

可以從該目錄下看到svnh_test_images 隨機抽取64張的門派圖像

發起請求

python svnh_semi_supervised_client.py --server=localhost:9000 --image=./svnh_test_images/image_22.jpg

返回信息如下:

outputs {    key: "scores"    value {      dtype: DT_FLOAT      tensor_shape {        dim {          size: 1        }        dim {          size: 10        }      }      float_val: 9.66807189862e-10      float_val: 0.000227736207307      float_val: 0.980206489563      float_val: 8.44745736686e-05      float_val: 0.0005895147915      float_val: 4.84909996601e-08      float_val: 2.95252248179e-05      float_val: 0.0188583470881      float_val: 3.83840551876e-06      float_val: 9.54504270761e-12    }  }

其中的float_val 就是其softmax的數值,可以這麼理解,從上到下共計10行,代表該圖對其數值 0-9的分別預測信心度,越靠近1則信心越高,

我測試的圖是

 則float第三行 的數值是最大的。說明預測對了

訓練過程如下:

Raising pool_size_limit_ from 100 to 110  		Classifier train accuracy:  0.19  		Classifier test accuracy 0.136140135218  		Step time:  0.05168628692626953  		Epoch time:  22.38613224029541  Epoch 1  		Classifier train accuracy:  0.244  		Classifier test accuracy 0.274585125999  		Step time:  0.030593395233154297  		Epoch time:  21.693989753723145  Epoch 2  		Classifier train accuracy:  0.394  		Classifier test accuracy 0.445221266134  		Step time:  0.030382871627807617  		Epoch time:  21.80638337135315  Epoch 3  		Classifier train accuracy:  0.58  		Classifier test accuracy 0.533804548248  		Step time:  0.03650689125061035  		Epoch time:  21.941813707351685

原創文章,轉載請註明: 轉載自URl-team

本文鏈接地址: tensoflow serving 實戰之GAN 識別門牌號的識別服務接口

  1. 目標檢測筆記二:Object Detection API 小白實踐指南
  2. CNN結構模型一句話概述:從LeNet到ShuffleNet
  3. TensorFlow識別字母扭曲干擾型驗證碼-開放源碼與98%模型
  4. TensorFlow 資源大全–中文版
  5. image net 2012數據集以及中文標籤分享