TensorFlow 的三种模型格式:ckpt、pb、pb-SavedModel
- 2020 年 8 月 8 日
- 筆記
- TensorFlow
1、CKPT
目录结构
checkpoint:
model.ckpt-1000.index
model.ckpt-1000.data-00000-of-00001
model.ckpt-1000.meta
特点:
首先这种模型文件是依赖 TensorFlow 的,只能在其框架下使用;
数据(变量的值,保存在 .data/.index 文件中)和计算图(保存在 .meta 文件中)是分开存放的;
这种格式在训练的时候用得比较多。
代码:就省略了
2、pb模型-只有模型
这种方式只保存了模型的图结构,因此可以在保护隐私的前提下发布到网上。
感觉一些水的论文会用这种方式。
代码:
thanks to: https://www.jianshu.com/p/9221fbf52c55
·
"""Demo: save a tiny TF1 model as a checkpoint, re-export it as a SavedModel,
then load the SavedModel back and run inference on a few inputs.

Uses the TensorFlow 1.x graph/session API (``tf.placeholder``, ``tf.Session``).

NOTE(review): the surrounding text describes a graph-only .pb file, but this
script exports a full SavedModel (saved_model.pb plus a variables/ directory).
A single-file frozen graph would instead be produced with
``tf.graph_util.convert_variables_to_constants`` -- confirm which format this
section is meant to demonstrate.
"""
import os
import shutil

import tensorflow as tf
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import (
    signature_constants,
    signature_def_utils,
    tag_constants,
    utils,
)

# Checkpoint prefix written by save_model() and restored in __main__.
CKPT_PATH = "model_pb/model.ckpt"
# SavedModel directory written by export_model() and read by load_pb().
EXPORT_PATH = "pb_model/1"


class model:
    """Linear model ``y = a * w + b`` with a one-element weight and bias."""

    def __init__(self):
        # Input: a 1-D float32 tensor of any length.
        self.a = tf.placeholder(tf.float32, [None])
        self.w = tf.Variable(tf.constant(2.0, shape=[1]), name="w")
        b = tf.Variable(tf.constant(0.5, shape=[1]), name="b")
        # w and b have shape [1], so they broadcast over every element of a.
        self.y = self.a * self.w + b


# Save the model as a ckpt checkpoint.
def save_model():
    """Build the model, set w to 10, print y for a=3.0 and save a checkpoint."""
    graph1 = tf.Graph()
    with graph1.as_default():
        m = model()
        update = tf.assign(m.w, [10.0])
        saver = tf.train.Saver()
        with tf.Session(graph=graph1) as session:
            session.run(tf.global_variables_initializer())
            session.run(update)
            predict_y = session.run(m.y, feed_dict={m.a: [3.0]})
            print(predict_y)

            # Create the checkpoint directory up front instead of relying on
            # the Saver to do it.
            os.makedirs(os.path.dirname(CKPT_PATH), exist_ok=True)
            saver.save(session, CKPT_PATH)


# Export as a pb (SavedModel) model.
def export_model(session, m):
    """Export the graph and variables held in ``session`` as a SavedModel.

    Args:
        session: a ``tf.Session`` whose graph contains ``m`` and whose
            variables have already been initialised or restored.
        m: the ``model`` instance supplying the input/output tensors.

    Export failures are printed rather than raised (best-effort export).
    """
    # Only this signature needs to change for another model: it maps the
    # public input/output keys to the tensors they refer to.
    model_signature = signature_def_utils.build_signature_def(
        inputs={"input": utils.build_tensor_info(m.a)},
        outputs={"output": utils.build_tensor_info(m.y)},
        method_name=signature_constants.PREDICT_METHOD_NAME,
    )

    export_path = EXPORT_PATH
    # SavedModelBuilder will not write into an existing export directory, so
    # remove any previous export first (without shelling out to `rm -rf`).
    if os.path.exists(export_path):
        shutil.rmtree(export_path)
    print("Export the model to {}".format(export_path))

    # Build the init op in the session's graph regardless of which graph the
    # caller currently has as default.
    with session.graph.as_default():
        try:
            legacy_init_op = tf.group(
                tf.tables_initializer(), name="legacy_init_op")
            builder = saved_model_builder.SavedModelBuilder(export_path)
            builder.add_meta_graph_and_variables(
                session,
                [tag_constants.SERVING],
                clear_devices=True,
                signature_def_map={
                    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                        model_signature,
                },
                legacy_init_op=legacy_init_op,
            )
            builder.save()
        except Exception as e:
            print("Fail to export saved model, exception: {}".format(e))


# Load the pb (SavedModel) model.
def load_pb():
    """Load the exported SavedModel and print predictions for inputs 1..5."""
    model_file_path = EXPORT_PATH
    # The context manager closes the session once inference is done.
    with tf.Session(graph=tf.Graph()) as session:
        meta_graph = tf.saved_model.loader.load(
            session, [tf.saved_model.tag_constants.SERVING], model_file_path)

        # Prefer the signature export_model() registered; fall back to the
        # first one if the key is absent. A membership test is used because
        # indexing a protobuf message map with a missing key inserts it.
        signatures = meta_graph.signature_def
        sig_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        if sig_key in signatures:
            model_graph_signature = signatures[sig_key]
        else:
            model_graph_signature = next(iter(signatures.values()))

        output_tensor_names = [
            info.name for info in model_graph_signature.outputs.values()
        ]
        print("load model finish!")

        # Test the loaded pb model.
        for test_x in [[1], [2], [3], [4], [5]]:
            inputs_by_key = {"input": test_x}
            feed_dict_map = {
                info.name: inputs_by_key[input_key]
                for input_key, info in model_graph_signature.inputs.items()
            }
            predict_y = session.run(
                output_tensor_names, feed_dict=feed_dict_map)
            print("predict pb y:", predict_y)


if __name__ == "__main__":
    save_model()

    graph2 = tf.Graph()
    with graph2.as_default():
        m = model()
        saver = tf.train.Saver()
        with tf.Session(graph=graph2) as session:
            # Restore the ckpt checkpoint, then re-export it as a SavedModel.
            saver.restore(session, CKPT_PATH)
            export_model(session, m)

    load_pb()
3、pb模型-Saved model
这是一种格式简单的 pb 模型保存方式(SavedModel)。
目录结构
└── 1
    ├── saved_model.pb
    └── variables
        ├── variables.data-00000-of-00001
        └── variables.index
特点:
对于训练好的模型,我们通常是拿来使用的,也就是做推理(inference)。
这个时候模型就不再变化了,因此可以把变量的权重固化成常量(即"冻结"模型)。注意:上面目录结构所示的 SavedModel 仍把权重单独保存在 variables/ 目录下,并没有转成常量;把权重转为常量得到的是单文件的冻结 .pb 模型。
这样得到的模型会更小。
在一些嵌入式系统,或者用 C/C++ 编写的系统中,我们也常用 .pb 格式的模型。
代码:
thanks to: https://www.jianshu.com/p/9221fbf52c55
"""Demo: save a tiny TF1 model as a checkpoint, re-export it as a SavedModel
(saved_model.pb plus a variables/ directory), then load the SavedModel back
and run inference on a few inputs.

Uses the TensorFlow 1.x graph/session API (``tf.placeholder``, ``tf.Session``).
"""
import os
import shutil

import tensorflow as tf
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import (
    signature_constants,
    signature_def_utils,
    tag_constants,
    utils,
)

# Checkpoint prefix written by save_model() and restored in __main__.
CKPT_PATH = "model_pb/model.ckpt"
# SavedModel directory written by export_model() and read by load_pb().
EXPORT_PATH = "pb_model/1"


class model:
    """Linear model ``y = a * w + b`` with a one-element weight and bias."""

    def __init__(self):
        # Input: a 1-D float32 tensor of any length.
        self.a = tf.placeholder(tf.float32, [None])
        self.w = tf.Variable(tf.constant(2.0, shape=[1]), name="w")
        b = tf.Variable(tf.constant(0.5, shape=[1]), name="b")
        # w and b have shape [1], so they broadcast over every element of a.
        self.y = self.a * self.w + b


# Save the model as a ckpt checkpoint.
def save_model():
    """Build the model, set w to 10, print y for a=3.0 and save a checkpoint."""
    graph1 = tf.Graph()
    with graph1.as_default():
        m = model()
        update = tf.assign(m.w, [10.0])
        saver = tf.train.Saver()
        with tf.Session(graph=graph1) as session:
            session.run(tf.global_variables_initializer())
            session.run(update)
            predict_y = session.run(m.y, feed_dict={m.a: [3.0]})
            print(predict_y)

            # Create the checkpoint directory up front instead of relying on
            # the Saver to do it.
            os.makedirs(os.path.dirname(CKPT_PATH), exist_ok=True)
            saver.save(session, CKPT_PATH)


# Export as a pb (SavedModel) model.
def export_model(session, m):
    """Export the graph and variables held in ``session`` as a SavedModel.

    Args:
        session: a ``tf.Session`` whose graph contains ``m`` and whose
            variables have already been initialised or restored.
        m: the ``model`` instance supplying the input/output tensors.

    Export failures are printed rather than raised (best-effort export).
    """
    # Only this signature needs to change for another model: it maps the
    # public input/output keys to the tensors they refer to.
    model_signature = signature_def_utils.build_signature_def(
        inputs={"input": utils.build_tensor_info(m.a)},
        outputs={"output": utils.build_tensor_info(m.y)},
        method_name=signature_constants.PREDICT_METHOD_NAME,
    )

    export_path = EXPORT_PATH
    # SavedModelBuilder will not write into an existing export directory, so
    # remove any previous export first (without shelling out to `rm -rf`).
    if os.path.exists(export_path):
        shutil.rmtree(export_path)
    print("Export the model to {}".format(export_path))

    # Build the init op in the session's graph regardless of which graph the
    # caller currently has as default.
    with session.graph.as_default():
        try:
            legacy_init_op = tf.group(
                tf.tables_initializer(), name="legacy_init_op")
            builder = saved_model_builder.SavedModelBuilder(export_path)
            builder.add_meta_graph_and_variables(
                session,
                [tag_constants.SERVING],
                clear_devices=True,
                signature_def_map={
                    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                        model_signature,
                },
                legacy_init_op=legacy_init_op,
            )
            builder.save()
        except Exception as e:
            print("Fail to export saved model, exception: {}".format(e))


# Load the pb (SavedModel) model.
def load_pb():
    """Load the exported SavedModel and print predictions for inputs 1..5."""
    model_file_path = EXPORT_PATH
    # The context manager closes the session once inference is done.
    with tf.Session(graph=tf.Graph()) as session:
        meta_graph = tf.saved_model.loader.load(
            session, [tf.saved_model.tag_constants.SERVING], model_file_path)

        # Prefer the signature export_model() registered; fall back to the
        # first one if the key is absent. A membership test is used because
        # indexing a protobuf message map with a missing key inserts it.
        signatures = meta_graph.signature_def
        sig_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        if sig_key in signatures:
            model_graph_signature = signatures[sig_key]
        else:
            model_graph_signature = next(iter(signatures.values()))

        output_tensor_names = [
            info.name for info in model_graph_signature.outputs.values()
        ]
        print("load model finish!")

        # Test the loaded pb model.
        for test_x in [[1], [2], [3], [4], [5]]:
            inputs_by_key = {"input": test_x}
            feed_dict_map = {
                info.name: inputs_by_key[input_key]
                for input_key, info in model_graph_signature.inputs.items()
            }
            predict_y = session.run(
                output_tensor_names, feed_dict=feed_dict_map)
            print("predict pb y:", predict_y)


if __name__ == "__main__":
    save_model()

    graph2 = tf.Graph()
    with graph2.as_default():
        m = model()
        saver = tf.train.Saver()
        with tf.Session(graph=graph2) as session:
            # Restore the ckpt checkpoint, then re-export it as a SavedModel.
            saver.restore(session, CKPT_PATH)
            export_model(session, m)

    load_pb()