Hands-On with the CIFAR10 Dataset: Building the ResNet Network (Part 2)

  • February 24, 2020
  • Notes

Defining the ResNet network

This time we are going to build the ResNet-18 structure.

class ResNet(nn.Module):

    def __init__(self):
        super(ResNet, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64)
        )
        # The stem is followed by four units of this kind.
        # Helper block: [b, 64, h, w] => [b, 128, h, w]
        self.blk1 = ResBlk(64, 128)
        # Helper block: [b, 128, h, w] => [b, 256, h, w]
        self.blk2 = ResBlk(128, 256)
        # Helper block: [b, 256, h, w] => [b, 512, h, w]
        self.blk3 = ResBlk(256, 512)
        # Helper block: [b, 512, h, w] => [b, 1024, h, w]
        self.blk4 = ResBlk(512, 1024)
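A quick shape sanity check, before we go further: every convolution in these blocks uses stride 1 with size-preserving padding, so each ResBlk only doubles the channel count. A [b, 64, 32, 32] feature map from a CIFAR10 batch therefore comes out of blk4 as [b, 1024, 32, 32], with the spatial size untouched.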

Next, we build the forward function for ResNet-18:

def forward(self, x):
    x = F.relu(self.conv1(x))
    # [b, 64, h, w] => [b, 1024, h, w]
    x = self.blk1(x)
    x = self.blk2(x)
    x = self.blk3(x)
    x = self.blk4(x)

Since we are solving a 10-way classification problem, we need to add the following code:

self.outlayer = nn.Linear(1024, 10)
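Note that nn.Linear(1024, 10) expects a flattened [b, 1024] input, while blk4 outputs a [b, 1024, h, w] feature map, so calling x = self.outlayer(x) on it directly would fail. A standard way to bridge the two (assumed here as a fix) is a global average pooling followed by a flatten, so the end of forward becomes: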

# [b, 1024, h, w] => [b, 1024, 1, 1]
x = F.adaptive_avg_pool2d(x, [1, 1])
# [b, 1024, 1, 1] => [b, 1024]
x = x.view(x.size(0), -1)
x = self.outlayer(x)
return x

To pin down the concrete dimensions, we first construct some fake data:

def main():
    blk = ResBlk(64, 128)
    # ResBlk(64, 128) expects 64 input channels, so the fake batch is [2, 64, 32, 32]
    tmp = torch.randn(2, 64, 32, 32)
    out = blk(tmp)
    print(out.shape)

if __name__ == "__main__":
    main()
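Running this should print torch.Size([2, 128, 32, 32]): the block doubles the channel count from 64 to 128, while the stride-1 convolutions leave the 32×32 spatial size unchanged.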

At this point the full code is:

import torch
import torch.nn as nn
import torch.nn.functional as F


class ResBlk(nn.Module):
    # Same as in the previous section: the ResNet block unit, subclassing nn.Module
    def __init__(self, ch_in, ch_out):
        super(ResBlk, self).__init__()
        # Complete the initialization

        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        # Batch normalization makes training faster and more stable
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()

        if ch_out != ch_in:
            # 1x1 convolution on the shortcut path so the channel counts match
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1),
                nn.BatchNorm2d(ch_out),
            )

    def forward(self, x):
        # The input here is [b, ch, h, w]
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))

        out = self.extra(x) + out
        # This is the element-wise add: self.extra maps the shortcut from
        # [b, ch_in, h, w] to [b, ch_out, h, w] so the two tensors can be added

        return out


class ResNet(nn.Module):

    def __init__(self):
        super(ResNet, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64)
        )
        # The stem is followed by four units of this kind.
        # Helper block: [b, 64, h, w] => [b, 128, h, w]
        self.blk1 = ResBlk(64, 128)
        # Helper block: [b, 128, h, w] => [b, 256, h, w]
        self.blk2 = ResBlk(128, 256)
        # Helper block: [b, 256, h, w] => [b, 512, h, w]
        self.blk3 = ResBlk(256, 512)
        # Helper block: [b, 512, h, w] => [b, 1024, h, w]
        self.blk4 = ResBlk(512, 1024)

        self.outlayer = nn.Linear(1024, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        # [b, 64, h, w] => [b, 1024, h, w]
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)

        # [b, 1024, h, w] => [b, 1024, 1, 1] => [b, 1024]
        x = F.adaptive_avg_pool2d(x, [1, 1])
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)
        return x


def main():
    blk = ResBlk(64, 128)
    # ResBlk(64, 128) expects 64 input channels, so the fake batch is [2, 64, 32, 32]
    tmp = torch.randn(2, 64, 32, 32)
    out = blk(tmp)
    print(out.shape)


if __name__ == "__main__":
    main()
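As one more sanity check (a minimal sketch, not part of the original notes), the whole network can be run on a fake CIFAR10 batch the same way, for example by appending these lines to main(); with the pooling and flatten in place it should print a [2, 10] logits tensor:

model = ResNet()
# A fake batch of two 32x32 RGB images, matching CIFAR10 dimensions
x = torch.randn(2, 3, 32, 32)
out = model(x)
print(out.shape)  # expected: torch.Size([2, 10])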