PyTorch | 保存和加載模型教程
- 2019 年 10 月 4 日
- 筆記
原題 | SAVING AND LOADING MODELS
作者 | Matthew Inkawhich
原文 | https://pytorch.org/tutorials/beginner/saving_loading_models.html
譯者 | kbsc13("算法猿的成長"公眾號作者)
聲明 | 翻譯是出於交流學習的目的,歡迎轉載,但請保留本文出於,請勿用作商業或者非法用途
簡介
本文主要介紹如何加載和保存 PyTorch 的模型。這裡主要有三個核心函數:
torch.save
:把序列化的對象保存到硬盤。它利用了 Python 的pickle
來實現序列化。模型、張量以及字典都可以用該函數進行保存;torch.load
:採用pickle
將反序列化的對象從存儲中加載進來。torch.nn.Module.load_state_dict
:採用一個反序列化的state_dict
加載一個模型的參數字典。
本文主要內容如下:
- 什麼是狀態字典(state_dict)?
- 預測時加載和保存模型
- 加載和保存一個通用的檢查點(Checkpoint)
- 在同一個文件保存多個模型
- 採用另一個模型的參數來預熱模型(Warmstaring Model)
- 不同設備下保存和加載模型
1. 什麼是狀態字典(state_dict)
PyTorch 中,一個模型(torch.nn.Module
)的可學習參數(也就是權重和偏置值)是包含在模型參數(model.parameters()
)中的,一個狀態字典就是一個簡單的 Python 的字典,其鍵值對是每個網絡層和其對應的參數張量。模型的狀態字典只包含帶有可學習參數的網絡層(比如卷積層、全連接層等)和註冊的緩存(batchnorm
的 running_mean
)。優化器對象(torch.optim
)同樣也是有一個狀態字典,包含的優化器狀態的信息以及使用的超參數。
由於狀態字典也是 Python 的字典,因此對 PyTorch 模型和優化器的保存、更新、替換、恢復等操作都很容易實現。
下面是一個簡單的使用例子,例子來自:https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py
# Define model class TheModelClass(nn.Module): def __init__(self): super(TheModelClass, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x # Initialize model model = TheModelClass() # Initialize optimizer optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # Print model's state_dict print("Model's state_dict:") for param_tensor in model.state_dict(): print(param_tensor, "t", model.state_dict()[param_tensor].size()) # Print optimizer's state_dict print("Optimizer's state_dict:") for var_name in optimizer.state_dict(): print(var_name, "t", optimizer.state_dict()[var_name])
上述代碼先是簡單定義一個 5 層的 CNN,然後分別打印模型的參數和優化器參數。
輸出結果:
Model's state_dict: conv1.weight torch.Size([6, 3, 5, 5]) conv1.bias torch.Size([6]) conv2.weight torch.Size([16, 6, 5, 5]) conv2.bias torch.Size([16]) fc1.weight torch.Size([120, 400]) fc1.bias torch.Size([120]) fc2.weight torch.Size([84, 120]) fc2.bias torch.Size([84]) fc3.weight torch.Size([10, 84]) fc3.bias torch.Size([10]) Optimizer's state_dict: state {} param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]
2. 預測時加載和保存模型
加載/保存狀態字典(推薦做法)
保存的代碼:
torch.save(model.state_dict(), PATH)
加載的代碼:
model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH)) model.eval()
當需要為預測保存一個模型的時候,只需要保存訓練模型的可學習參數即可。採用 torch.save()
來保存模型的狀態字典的做法可以更方便加載模型,這也是推薦這種做法的原因。
通常會用 .pt
或者 .pth
後綴來保存模型。
記住
- 在進行預測之前,必須調用
model.eval()
方法來將dropout
和batch normalization
層設置為驗證模型。否則,只會生成前後不一致的預測結果。 load_state_dict()
方法必須傳入一個字典對象,而不是對象的保存路徑,也就是說必須先反序列化字典對象,然後再調用該方法,也是例子中先採用torch.load()
,而不是直接model.load_state_dict(PATH)
加載/保存整個模型
保存:
torch.save(model, PATH)
加載:
# Model class must be defined somewhere model = torch.load(PATH) model.eval()
保存和加載模型都是採用非常直觀的語法並且都只需要幾行代碼即可實現。這種實現保存模型的做法將是採用 Python 的 pickle
模塊來保存整個模型,這種做法的缺點就是序列化後的數據是屬於特定的類和指定的字典結構,原因就是 pickle
並沒有保存模型類別,而是保存一個包含該類的文件路徑,因此,當在其他項目或者在 refactors
後採用都可能出現錯誤。
3. 加載和保存一個通用的檢查點(Checkpoint)
保存的示例代碼:
torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss, ... }, PATH)
加載的示例代碼:
model = TheModelClass(*args, **kwargs) optimizer = TheOptimizerClass(*args, **kwargs) checkpoint = torch.load(PATH) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) epoch = checkpoint['epoch'] loss = checkpoint['loss'] model.eval() # - or - model.train()
當保存一個通用的檢查點(checkpoint)時,無論是用於繼續訓練還是預測,都需要保存更多的信息,不僅僅是 state_dict
,比如說優化器的 state_dict
也是非常重要的,它包含了用於模型訓練時需要更新的參數和緩存信息,還可以保存的信息包括 epoch
,即中斷訓練的批次,最後一次的訓練 loss,額外的 torch.nn.Embedding
層等等。
上述保存代碼就是介紹了如何保存這麼多種信息,通過用一個字典來進行組織,然後繼續調用 torch.save
方法,一般保存的文件後綴名是 .tar
。
加載代碼也如上述代碼所示,首先需要初始化模型和優化器,然後加載模型時分別調用 torch.load
加載對應的 state_dict
。然後通過不同的鍵來獲取對應的數值。
加載完後,根據後續步驟,調用 model.eval()
用於預測,model.train()
用於恢復訓練。
4. 在同一個文件保存多個模型
保存模型的示例代碼:
torch.save({ 'modelA_state_dict': modelA.state_dict(), 'modelB_state_dict': modelB.state_dict(), 'optimizerA_state_dict': optimizerA.state_dict(), 'optimizerB_state_dict': optimizerB.state_dict(), ... }, PATH)
加載模型的示例代碼:
modelA = TheModelAClass(*args, **kwargs) modelB = TheModelBClass(*args, **kwargs) optimizerA = TheOptimizerAClass(*args, **kwargs) optimizerB = TheOptimizerBClass(*args, **kwargs) checkpoint = torch.load(PATH) modelA.load_state_dict(checkpoint['modelA_state_dict']) modelB.load_state_dict(checkpoint['modelB_state_dict']) optimizerA.load_state_dict(checkpoint['optimizerA_state_dict']) optimizerB.load_state_dict(checkpoint['optimizerB_state_dict']) modelA.eval() modelB.eval() # - or - modelA.train() modelB.train()
當我們希望保存的是一個包含多個網絡模型 torch.nn.Modules
的時候,比如 GAN、一個序列化模型,或者多個模型融合,實現的方法其實和保存一個通用的檢查點的做法是一樣的,同樣採用一個字典來保持模型的 state_dict
和對應優化器的 state_dict
。除此之外,還可以繼續保存其他相同的信息。
加載模型的示例代碼如上述所示,和加載一個通用的檢查點也是一樣的,同樣需要先初始化對應的模型和優化器。同樣,保存的模型文件通常是以 .tar
作為後綴名。
5. 採用另一個模型的參數來預熱模型(Warmstaring Model)
保存模型的示例代碼:
torch.save(modelA.state_dict(), PATH)
加載模型的示例代碼:
modelB = TheModelBClass(*args, **kwargs) modelB.load_state_dict(torch.load(PATH), strict=False)
在之前遷移學習教程中也介紹了可以通過預訓練模型來微調,加快模型訓練速度和提高模型的精度。
這種做法通常是加載預訓練模型的部分網絡參數作為模型的初始化參數,然後可以加快模型的收斂速度。
加載預訓練模型的代碼如上述所示,其中設置參數 strict=False
表示忽略不匹配的網絡層參數,因為通常我們都不會完全採用和預訓練模型完全一樣的網絡,通常輸出層的參數就會不一樣。
當然,如果希望加載參數名不一樣的參數,可以通過修改加載的模型對應的參數名字,這樣參數名字匹配了就可以成功加載。
6. 不同設備下保存和加載模型
在GPU上保存模型,在 CPU 上加載模型
保存模型的示例代碼:
torch.save(model.state_dict(), PATH)
加載模型的示例代碼:
device = torch.device('cpu') model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH, map_location=device))
在 CPU 上加載在 GPU 上訓練的模型,必須在調用 torch.load()
的時候,設置參數 map_location
,指定採用的設備是 torch.device('cpu')
,這個做法會將張量都重新映射到 CPU 上。
在GPU上保存模型,在 GPU 上加載模型
保存模型的示例代碼:
torch.save(model.state_dict(), PATH)
加載模型的示例代碼:
device = torch.device('cuda') model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH) model.to(device) # Make sure to call input = input.to(device) on any input tensors that you feed to the model
在 GPU 上訓練和加載模型,調用 torch.load()
加載模型後,還需要採用 model.to(torch.device('cuda'))
,將模型調用到 GPU 上,並且後續輸入的張量都需要確保是在 GPU 上使用的,即也需要採用 my_tensor.to(device)
。
在CPU上保存,在GPU上加載模型
保存模型的示例代碼:
torch.save(model.state_dict(), PATH)
加載模型的示例代碼:
device = torch.device("cuda") model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want model.to(device) # Make sure to call input = input.to(device) on any input tensors that you feed to the model
這次是 CPU 上訓練模型,但在 GPU 上加載模型使用,那麼就需要通過參數 map_location
指定設備。然後繼續記得調用 model.to(torch.device('cuda'))
。
保存 torch.nn.DataParallel 模型
保存模型的示例代碼:
torch.save(model.module.state_dict(), PATH)
torch.nn.DataParallel
是用於實現多 GPU 並行的操作,保存模型的時候,是採用 model.module.state_dict()
。
加載模型的代碼也是一樣的,採用 torch.load()
,並可以放到指定的 GPU 顯卡上。
完整的代碼:
https://github.com/pytorch/tutorials/blob/master/beginner_source/saving_loading_models.py