Pytorch-nn.Module

  • 2019 年 12 月 5 日
  • 筆記

本节介绍在pytorch中十分重要的“类”:nn.Module。

在实现自己设计的层结构功能时,必须要使用自己继承的类。

类的书写如下

import torch  import torch.nn as nn  import torch.nn.functional as F      class MyLinear(nn.Module):      # 先定义自己的类      def __init__(self, inp, outp):          super(MyLinear, self).__init__()          # 初始化自己定义的类            self.w = nn.Parameter(torch.randn(outp, inp))          self.b = nn.Parameter(torch.randn(outp))        def forward(self, x):          # 定义前向          x = x @ self.w.t() + self.b          return x

那么nn.Module类到底是什么?

(1)nn.Module在pytorch中是基本的复类,继承它后会很方便的使用nn.linear、nn.normalize等。

(2)还可以进行嵌套,便于书写树形结构

(3)nn.Module提供了很多已经编写好的功能,如Linear、ReLU、Sigmoid、Conv2d、ConvTransposed2d、Dropout等。

最主要的功能是书写代码方便

self.net = nn.Sequential(      # .Sequential()相当于设定了一个容器,      # 将需要进行forward的函数代入其中,      # 但不用每一个步骤都写上,      # 直接放在容器中,后面再定义一个forward代码即可      nn.Conv2d(1, 32, 5, 1, 1),      nn.MaxPool2d(2, 2),      ...    )

使用nn.Module的第三个好处是可以对网络中的参数进行有效的管理

通过.parameters()即可很方便的对参数进行查看

net = nn.Sequential(nn.Linear(4, 2), nn.Linear(2, 2))  print(list(net.parameters()))[0].shape  # 输出查看第0层的参数

也可用.named_parameters()来输出网络结构编好名字的参数

print(list(net.named_parameters()))[0].shape

后续再加上.item(),来对各种属性进行查看

print(list(net.named_parameters()))[0].item()

另外nn.Module还可以自己定义类的顺序。

也可以很方便的将所有的运算都转入到GPU上去。使用.device函数,

device = torch.device('cuda')  net = Net()  net.to(device)

还可以很方便的进行save和load,以防止突然发生的断点和系统崩溃的现象

net.load_state_dict(torch.load('ckpt.mdl'))  torch.save(net.state_dict(), 'ckpt.mdl')

nn.Modele还可以很方便的切换状态

# 切换到train状态  net.train()  # 切换到test