[PyTorch 學習筆記] 7.1 模型保存與加載
本章代碼:
- //github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_save.py
- //github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_load.py
- //github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/checkpoint_resume.py
- //github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/save_checkpoint.py
這篇文章主要介紹了序列化與反序列化,以及 PyTorch 中的模型保存於加載的兩種方式,模型的斷點續訓練。
序列化與反序列化
模型在內存中是以對象的邏輯結構保存的,但是在硬盤中是以二進制流的方式保存的。
-
序列化是指將內存中的數據以二進制序列的方式保存到硬盤中。PyTorch 的模型保存就是序列化。
-
反序列化是指將硬盤中的二進制序列加載到內存中,得到模型的對象。PyTorch 的模型加載就是反序列化。
PyTorch 中的模型保存與加載
torch.save
torch.save(obj, f, pickle_module, pickle_protocol=2, _use_new_zipfile_serialization=False)
主要參數:
- obj:保存的對象,可以是模型。也可以是 dict。因為一般在保存模型時,不僅要保存模型,還需要保存優化器、此時對應的 epoch 等參數。這時就可以用 dict 包裝起來。
- f:輸出路徑
其中模型保存還有兩種方式:
保存整個 Module
這種方法比較耗時,保存的文件大
torch.savev(net, path)
只保存模型的參數
推薦這種方法,運行比較快,保存的文件比較小
state_sict = net.state_dict()
torch.savev(state_sict, path)
下面是保存 LeNet 的例子。在網絡初始化中,把權值都設置為 2020,然後保存模型。
import torch
import numpy as np
import torch.nn as nn
from common_tools import set_seed
class LeNet2(nn.Module):
def __init__(self, classes):
super(LeNet2, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 6, 5),
nn.ReLU(),
nn.MaxPool2d(2, 2),
nn.Conv2d(6, 16, 5),
nn.ReLU(),
nn.MaxPool2d(2, 2)
)
self.classifier = nn.Sequential(
nn.Linear(16*5*5, 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, classes)
)
def forward(self, x):
x = self.features(x)
x = x.view(x.size()[0], -1)
x = self.classifier(x)
return x
def initialize(self):
for p in self.parameters():
p.data.fill_(2020)
net = LeNet2(classes=2019)
# "訓練"
print("訓練前: ", net.features[0].weight[0, ...])
net.initialize()
print("訓練後: ", net.features[0].weight[0, ...])
path_model = "./model.pkl"
path_state_dict = "./model_state_dict.pkl"
# 保存整個模型
torch.save(net, path_model)
# 保存模型參數
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)
運行完之後,文件夾中生成了“model.pkl和
model_state_dict.pkl`,分別保存了整個網絡和網絡的參數
torch.load
torch.load(f, map_location=None, pickle_module, **pickle_load_args)
主要參數:
- f:文件路徑
- map_location:指定存在 CPU 或者 GPU。
加載模型也有兩種方式
加載整個 Module
如果保存的時候,保存的是整個模型,那麼加載時就加載整個模型。這種方法不需要事先創建一個模型對象,也不用知道模型的結構,代碼如下:
path_model = "./model.pkl"
net_load = torch.load(path_model)
print(net_load)
輸出如下:
LeNet2(
(features): Sequential(
(0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(4): ReLU()
(5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(classifier): Sequential(
(0): Linear(in_features=400, out_features=120, bias=True)
(1): ReLU()
(2): Linear(in_features=120, out_features=84, bias=True)
(3): ReLU()
(4): Linear(in_features=84, out_features=2019, bias=True)
)
)
只加載模型的參數
如果保存的時候,保存的是模型的參數,那麼加載時就參數。這種方法需要事先創建一個模型對象,再使用模型的load_state_dict()
方法把參數加載到模型中,代碼如下:
path_state_dict = "./model_state_dict.pkl"
state_dict_load = torch.load(path_state_dict)
net_new = LeNet2(classes=2019)
print("加載前: ", net_new.features[0].weight[0, ...])
net_new.load_state_dict(state_dict_load)
print("加載後: ", net_new.features[0].weight[0, ...])
模型的斷點續訓練
在訓練過程中,可能由於某種意外原因如斷點等導致訓練終止,這時需要重新開始訓練。斷點續練是在訓練過程中每隔一定次數的 epoch 就保存模型的參數和優化器的參數,這樣如果意外終止訓練了,下次就可以重新加載最新的模型參數和優化器的參數,在這個基礎上繼續訓練。
下面的代碼中,每隔 5 個 epoch 就保存一次,保存的是一個 dict,包括模型參數、優化器的參數、epoch。然後在 epoch 大於 5 時,就break
模擬訓練意外終止。關鍵代碼如下:
if (epoch+1) % checkpoint_interval == 0:
checkpoint = {"model_state_dict": net.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"epoch": epoch}
path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
torch.save(checkpoint, path_checkpoint)
在 epoch 大於 5 時,就break
模擬訓練意外終止
if epoch > 5:
print("訓練意外中斷...")
break
斷點續訓練的恢復代碼如下:
path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
scheduler.last_epoch = start_epoch
需要注意的是,還要設置scheduler.last_epoch
參數為保存的 epoch。模型訓練的起始 epoch 也要修改為保存的 epoch。
參考資料
如果你覺得這篇文章對你有幫助,不妨點個贊,讓我有更多動力寫出好文章。