TensorFlow2.0(12):模型保存與序列化

  • 2019 年 12 月 30 日
  • 筆記

TensorFlow2.0(1):基本數據結構——張量

TensorFlow2.0(2):數學運算

TensorFlow2.0(3):張量排序、最大最小值

TensorFlow2.0(4):填充與複製

TensorFlow2.0(5):張量限幅

TensorFlow2.0(6):利用data模組進行數據預處理

TensorFlow2.0(7):4種常用的激活函數

TensorFlow2.0(8):誤差計算:損失函數總結

TensorFlow2.0(9):神器級可視化工具TensorBoard

TensorFlow2.0(10):載入自定義圖片數據集到Dataset

TensorFlow2.0(11):tf.keras建模三部曲

模型訓練好之後,我們就要想辦法將其持久化保存下來,不然關機或者程式退出後模型就不復存在了。本文介紹兩種持久化保存模型的方法:

在介紹這兩種方法之前,我們得先創建並訓練好一個模型,還是以mnist手寫數字識別數據集訓練模型為例:

import tensorflow as tf  from tensorflow import keras  from tensorflow.keras import layers, optimizers, Sequential
model = Sequential([  # 創建模型      layers.Dense(256, activation=tf.nn.relu),      layers.Dense(128, activation=tf.nn.relu),      layers.Dense(64, activation=tf.nn.relu),      layers.Dense(32, activation=tf.nn.relu),      layers.Dense(10)      ]  )  (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()  x_train = x_train.reshape(60000, 784).astype('float32') / 255  x_test = x_test.reshape(10000, 784).astype('float32') / 255    model.compile(loss='sparse_categorical_crossentropy',                optimizer=keras.optimizers.RMSprop())  history = model.fit(x_train, y_train,  # 進行簡單的1次迭代訓練                      batch_size=64,                      epochs=1)
Train on 60000 samples  60000/60000 [==============================] - 3s 46us/sample - loss: 2.3700  

方法一:model.save()

通過模型自帶的save()方法可以將模型保存到一個指定文件中,保存的內容包括:

  • 模型的結構
  • 模型的權重參數
  • 通過compile()方法配置的模型訓練參數
  • 優化器及其狀態
model.save('mymodels/mnist.h5')

使用save()方法保存後,在mymodels目錄下就會有一個mnist.h5文件。需要使用模型時,通過keras.models.load_model()方法從文件中再次載入即可。

new_model = keras.models.load_model('mymodels/mnist.h5')
WARNING:tensorflow:Sequential models without an `input_shape` passed to the first layer cannot reload their optimizer state. As a result, your model isstarting with a freshly initialized optimizer.  

新載入出來的new_model在結構、功能、參數各方面與model是一樣的。

通過save()方法,也可以將模型保存為SavedModel 格式。SavedModel格式是TensorFlow所特有的一種序列化文件格式,其他程式語言實現的TensorFlow中同樣支援:

model.save('mymodels/mnist_model', save_format='tf')  # 將模型保存為SaveModel格式
WARNING:tensorflow:From /home/chb/anaconda3/envs/study_python/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1781: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.  Instructions for updating:  If using Keras pass *_constraint arguments to layers.  INFO:tensorflow:Assets written to: mymodels/mnist_model/assets  
new_model = keras.models.load_model('mymodels/mnist_model')  # 載入模型
print(keras.models.__dir__())
['__name__', '__doc__', '__package__', '__loader__', '__spec__', '__path__', '__file__', '__cached__', '__builtins__', '_sys', 'Sequential', 'Model', 'clone_model', 'model_from_config', 'model_from_json', 'model_from_yaml', 'load_model', 'save_model']  

方法二:model.save_weights()

save()方法會保留模型的所有資訊,但有時候,我們僅對部分資訊感興趣,例如僅對模型的權重參數感興趣,那麼就可以通過save_weights()方法進行保存。

model.save_weights('mymodels/mnits_weights')  # 保存模型權重資訊
new_model = Sequential([  # 創建新的模型      layers.Dense(256, activation=tf.nn.relu),      layers.Dense(128, activation=tf.nn.relu),      layers.Dense(64, activation=tf.nn.relu),      layers.Dense(32, activation=tf.nn.relu),      layers.Dense(10)      ]  )  new_model.compile(loss='sparse_categorical_crossentropy',                optimizer=keras.optimizers.RMSprop())  new_model.load_weights('mymodels/mnits_weights')  # 將保存好的權重資訊載入的新的模型中
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f49c42b87d0>  

註:本系列所有部落格將持續更新並發布在github上,您可以通過github下載本系列所有文章筆記文件。

https://github.com/ChenHuabin321/tensorflow2_tutorials

作者部落格:

https://www.cnblogs.com/chenhuabin