PyTorch | 保存和加载模型教程
- 2019 年 10 月 4 日
- 筆記
原题 | SAVING AND LOADING MODELS
作者 | Matthew Inkawhich
原文 | https://pytorch.org/tutorials/beginner/saving_loading_models.html
译者 | kbsc13("算法猿的成长"公众号作者)
声明 | 翻译是出于交流学习的目的,欢迎转载,但请保留本文出于,请勿用作商业或者非法用途
简介
本文主要介绍如何加载和保存 PyTorch 的模型。这里主要有三个核心函数:
torch.save
:把序列化的对象保存到硬盘。它利用了 Python 的pickle
来实现序列化。模型、张量以及字典都可以用该函数进行保存;torch.load
:采用pickle
将反序列化的对象从存储中加载进来。torch.nn.Module.load_state_dict
:采用一个反序列化的state_dict
加载一个模型的参数字典。
本文主要内容如下:
- 什么是状态字典(state_dict)?
- 预测时加载和保存模型
- 加载和保存一个通用的检查点(Checkpoint)
- 在同一个文件保存多个模型
- 采用另一个模型的参数来预热模型(Warmstaring Model)
- 不同设备下保存和加载模型
1. 什么是状态字典(state_dict)
PyTorch 中,一个模型(torch.nn.Module
)的可学习参数(也就是权重和偏置值)是包含在模型参数(model.parameters()
)中的,一个状态字典就是一个简单的 Python 的字典,其键值对是每个网络层和其对应的参数张量。模型的状态字典只包含带有可学习参数的网络层(比如卷积层、全连接层等)和注册的缓存(batchnorm
的 running_mean
)。优化器对象(torch.optim
)同样也是有一个状态字典,包含的优化器状态的信息以及使用的超参数。
由于状态字典也是 Python 的字典,因此对 PyTorch 模型和优化器的保存、更新、替换、恢复等操作都很容易实现。
下面是一个简单的使用例子,例子来自:https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py
# Define model class TheModelClass(nn.Module): def __init__(self): super(TheModelClass, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x # Initialize model model = TheModelClass() # Initialize optimizer optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # Print model's state_dict print("Model's state_dict:") for param_tensor in model.state_dict(): print(param_tensor, "t", model.state_dict()[param_tensor].size()) # Print optimizer's state_dict print("Optimizer's state_dict:") for var_name in optimizer.state_dict(): print(var_name, "t", optimizer.state_dict()[var_name])
上述代码先是简单定义一个 5 层的 CNN,然后分别打印模型的参数和优化器参数。
输出结果:
Model's state_dict: conv1.weight torch.Size([6, 3, 5, 5]) conv1.bias torch.Size([6]) conv2.weight torch.Size([16, 6, 5, 5]) conv2.bias torch.Size([16]) fc1.weight torch.Size([120, 400]) fc1.bias torch.Size([120]) fc2.weight torch.Size([84, 120]) fc2.bias torch.Size([84]) fc3.weight torch.Size([10, 84]) fc3.bias torch.Size([10]) Optimizer's state_dict: state {} param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]
2. 预测时加载和保存模型
加载/保存状态字典(推荐做法)
保存的代码:
torch.save(model.state_dict(), PATH)
加载的代码:
model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH)) model.eval()
当需要为预测保存一个模型的时候,只需要保存训练模型的可学习参数即可。采用 torch.save()
来保存模型的状态字典的做法可以更方便加载模型,这也是推荐这种做法的原因。
通常会用 .pt
或者 .pth
后缀来保存模型。
记住
- 在进行预测之前,必须调用
model.eval()
方法来将dropout
和batch normalization
层设置为验证模型。否则,只会生成前后不一致的预测结果。 load_state_dict()
方法必须传入一个字典对象,而不是对象的保存路径,也就是说必须先反序列化字典对象,然后再调用该方法,也是例子中先采用torch.load()
,而不是直接model.load_state_dict(PATH)
加载/保存整个模型
保存:
torch.save(model, PATH)
加载:
# Model class must be defined somewhere model = torch.load(PATH) model.eval()
保存和加载模型都是采用非常直观的语法并且都只需要几行代码即可实现。这种实现保存模型的做法将是采用 Python 的 pickle
模块来保存整个模型,这种做法的缺点就是序列化后的数据是属于特定的类和指定的字典结构,原因就是 pickle
并没有保存模型类别,而是保存一个包含该类的文件路径,因此,当在其他项目或者在 refactors
后采用都可能出现错误。
3. 加载和保存一个通用的检查点(Checkpoint)
保存的示例代码:
torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss, ... }, PATH)
加载的示例代码:
model = TheModelClass(*args, **kwargs) optimizer = TheOptimizerClass(*args, **kwargs) checkpoint = torch.load(PATH) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) epoch = checkpoint['epoch'] loss = checkpoint['loss'] model.eval() # - or - model.train()
当保存一个通用的检查点(checkpoint)时,无论是用于继续训练还是预测,都需要保存更多的信息,不仅仅是 state_dict
,比如说优化器的 state_dict
也是非常重要的,它包含了用于模型训练时需要更新的参数和缓存信息,还可以保存的信息包括 epoch
,即中断训练的批次,最后一次的训练 loss,额外的 torch.nn.Embedding
层等等。
上述保存代码就是介绍了如何保存这么多种信息,通过用一个字典来进行组织,然后继续调用 torch.save
方法,一般保存的文件后缀名是 .tar
。
加载代码也如上述代码所示,首先需要初始化模型和优化器,然后加载模型时分别调用 torch.load
加载对应的 state_dict
。然后通过不同的键来获取对应的数值。
加载完后,根据后续步骤,调用 model.eval()
用于预测,model.train()
用于恢复训练。
4. 在同一个文件保存多个模型
保存模型的示例代码:
torch.save({ 'modelA_state_dict': modelA.state_dict(), 'modelB_state_dict': modelB.state_dict(), 'optimizerA_state_dict': optimizerA.state_dict(), 'optimizerB_state_dict': optimizerB.state_dict(), ... }, PATH)
加载模型的示例代码:
modelA = TheModelAClass(*args, **kwargs) modelB = TheModelBClass(*args, **kwargs) optimizerA = TheOptimizerAClass(*args, **kwargs) optimizerB = TheOptimizerBClass(*args, **kwargs) checkpoint = torch.load(PATH) modelA.load_state_dict(checkpoint['modelA_state_dict']) modelB.load_state_dict(checkpoint['modelB_state_dict']) optimizerA.load_state_dict(checkpoint['optimizerA_state_dict']) optimizerB.load_state_dict(checkpoint['optimizerB_state_dict']) modelA.eval() modelB.eval() # - or - modelA.train() modelB.train()
当我们希望保存的是一个包含多个网络模型 torch.nn.Modules
的时候,比如 GAN、一个序列化模型,或者多个模型融合,实现的方法其实和保存一个通用的检查点的做法是一样的,同样采用一个字典来保持模型的 state_dict
和对应优化器的 state_dict
。除此之外,还可以继续保存其他相同的信息。
加载模型的示例代码如上述所示,和加载一个通用的检查点也是一样的,同样需要先初始化对应的模型和优化器。同样,保存的模型文件通常是以 .tar
作为后缀名。
5. 采用另一个模型的参数来预热模型(Warmstaring Model)
保存模型的示例代码:
torch.save(modelA.state_dict(), PATH)
加载模型的示例代码:
modelB = TheModelBClass(*args, **kwargs) modelB.load_state_dict(torch.load(PATH), strict=False)
在之前迁移学习教程中也介绍了可以通过预训练模型来微调,加快模型训练速度和提高模型的精度。
这种做法通常是加载预训练模型的部分网络参数作为模型的初始化参数,然后可以加快模型的收敛速度。
加载预训练模型的代码如上述所示,其中设置参数 strict=False
表示忽略不匹配的网络层参数,因为通常我们都不会完全采用和预训练模型完全一样的网络,通常输出层的参数就会不一样。
当然,如果希望加载参数名不一样的参数,可以通过修改加载的模型对应的参数名字,这样参数名字匹配了就可以成功加载。
6. 不同设备下保存和加载模型
在GPU上保存模型,在 CPU 上加载模型
保存模型的示例代码:
torch.save(model.state_dict(), PATH)
加载模型的示例代码:
device = torch.device('cpu') model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH, map_location=device))
在 CPU 上加载在 GPU 上训练的模型,必须在调用 torch.load()
的时候,设置参数 map_location
,指定采用的设备是 torch.device('cpu')
,这个做法会将张量都重新映射到 CPU 上。
在GPU上保存模型,在 GPU 上加载模型
保存模型的示例代码:
torch.save(model.state_dict(), PATH)
加载模型的示例代码:
device = torch.device('cuda') model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH) model.to(device) # Make sure to call input = input.to(device) on any input tensors that you feed to the model
在 GPU 上训练和加载模型,调用 torch.load()
加载模型后,还需要采用 model.to(torch.device('cuda'))
,将模型调用到 GPU 上,并且后续输入的张量都需要确保是在 GPU 上使用的,即也需要采用 my_tensor.to(device)
。
在CPU上保存,在GPU上加载模型
保存模型的示例代码:
torch.save(model.state_dict(), PATH)
加载模型的示例代码:
device = torch.device("cuda") model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want model.to(device) # Make sure to call input = input.to(device) on any input tensors that you feed to the model
这次是 CPU 上训练模型,但在 GPU 上加载模型使用,那么就需要通过参数 map_location
指定设备。然后继续记得调用 model.to(torch.device('cuda'))
。
保存 torch.nn.DataParallel 模型
保存模型的示例代码:
torch.save(model.module.state_dict(), PATH)
torch.nn.DataParallel
是用于实现多 GPU 并行的操作,保存模型的时候,是采用 model.module.state_dict()
。
加载模型的代码也是一样的,采用 torch.load()
,并可以放到指定的 GPU 显卡上。
完整的代码:
https://github.com/pytorch/tutorials/blob/master/beginner_source/saving_loading_models.py