Pytorch-nn.Module
- 2019 年 12 月 5 日
- 筆記
本节介绍在pytorch中十分重要的“类”:nn.Module。
在实现自己设计的层结构功能时,必须要使用自己继承的类。
类的书写如下
import torch import torch.nn as nn import torch.nn.functional as F class MyLinear(nn.Module): # 先定义自己的类 def __init__(self, inp, outp): super(MyLinear, self).__init__() # 初始化自己定义的类 self.w = nn.Parameter(torch.randn(outp, inp)) self.b = nn.Parameter(torch.randn(outp)) def forward(self, x): # 定义前向 x = x @ self.w.t() + self.b return x
那么nn.Module类到底是什么?
(1)nn.Module在pytorch中是基本的复类,继承它后会很方便的使用nn.linear、nn.normalize等。
(2)还可以进行嵌套,便于书写树形结构
(3)nn.Module提供了很多已经编写好的功能,如Linear、ReLU、Sigmoid、Conv2d、ConvTransposed2d、Dropout等。
最主要的功能是书写代码方便
self.net = nn.Sequential( # .Sequential()相当于设定了一个容器, # 将需要进行forward的函数代入其中, # 但不用每一个步骤都写上, # 直接放在容器中,后面再定义一个forward代码即可 nn.Conv2d(1, 32, 5, 1, 1), nn.MaxPool2d(2, 2), ... )
使用nn.Module的第三个好处是可以对网络中的参数进行有效的管理
通过.parameters()即可很方便的对参数进行查看
net = nn.Sequential(nn.Linear(4, 2), nn.Linear(2, 2)) print(list(net.parameters()))[0].shape # 输出查看第0层的参数
也可用.named_parameters()来输出网络结构编好名字的参数
print(list(net.named_parameters()))[0].shape
后续再加上.item(),来对各种属性进行查看
print(list(net.named_parameters()))[0].item()
另外nn.Module还可以自己定义类的顺序。
也可以很方便的将所有的运算都转入到GPU上去。使用.device函数,
device = torch.device('cuda') net = Net() net.to(device)
还可以很方便的进行save和load,以防止突然发生的断点和系统崩溃的现象
net.load_state_dict(torch.load('ckpt.mdl')) torch.save(net.state_dict(), 'ckpt.mdl')
nn.Modele还可以很方便的切换状态
# 切换到train状态 net.train() # 切换到test