tensorflow保存與恢復模型
- 2019 年 11 月 20 日
- 筆記
ckpt模型與pb模型比較
- ckpt模型可以重新訓練,pb模型不可以(pb一般用於線上部署)
- ckpt模型可以指定保存最近的n個模型,pb不可以
保存ckpt模型
保存路徑必須帶.ckpt這個後綴名,不能是文件夾,否則無法保存meta文件
12345678 |
CKPT_PATH = './model.ckpt'vgg16_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='vgg19')outputs_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='outputs') # max_to_keep是指在文件夾中保存幾個最近的模型 saver = tf.train.Saver(vgg16_variables + outputs_variables, max_to_keep=1)saver.save(sess, CKPT_PATH) |
---|
恢復ckpt模型
12345 |
ckpt = tf.train.get_checkpoint_state('ckpt') if ckpt: saver.restore(sess, ckpt.model_checkpoint_path) print('Restore from', ckpt.model_checkpoint_path) gstep = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] |
---|
保存pb模型
保存為pb模型時要指明對外暴露哪些介面
123456789 |
graph_def = tf.get_default_graph().as_graph_def() output_graph_def = graph_util.convert_variables_to_constants( sess, graph_def, ['inputs','labels','keep_prob','accuracy'] ) with tf.gfile.GFile('save.pb', 'wb') as fid: serialized_graph = output_graph_def.SerializeToString() fid.write(serialized_graph) |
---|
載入pb模型
pb 格式模型保存與恢復相比於前面的 .ckpt 格式而言要稍微麻煩一點,但使用更靈活,特別是模型恢復,因為它可以脫離會話(Session)而存在,便於部署。
載入步驟如下:
tf.Graph()
定義了一張新的計算圖,與上面的計算圖區分開ParseFromString
將保存的計算圖反序列化tf.import_graph_def
導入一張計算圖- 新建
Session
,獲取Tensor
- 使用模型進行預測
12345678910111213141516171819 |
model_graph = tf.Graph()with model_graph.as_default(): od_graph_def = tf.GraphDef() with tf.gfile.GFile('save.pb', 'rb') as fid: serialized_graph = fid.read() od_graph_def.ParseFromString(serialized_graph) tf.import_graph_def(od_graph_def, name='') with tf.Session(graph=model_graph) as sess: inputs = tf.get_default_graph().get_tensor_by_name('inputs:0') labels = tf.get_default_graph().get_tensor_by_name('labels:0') keep_prob = tf.get_default_graph().get_tensor_by_name('keep_prob:0') accuracy = tf.get_default_graph().get_tensor_by_name('accuracy:0') batch_xs, batch_ys = mnist.test.next_batch(100) batch_xs = batch_xs.reshape([-1, 28, 28, 1]) acc = sess.run(accuracy, feed_dict={inputs: batch_xs, labels: batch_ys, keep_prob:1.0}) print('After restore sess from pb file, accuracy is ', acc) |
---|