CIFAR10数据集实战-ResNet网络构建(中)
- 2020 年 2 月 24 日
- 筆記
再定义一个ResNet网络
我们本次准备构建ResNet-18层结构
class ResNet(nn.Module):
    """CIFAR-10 ResNet-18 style backbone: a conv stem plus four residual stages."""

    def __init__(self):
        super(ResNet, self).__init__()
        # Stem: 3 -> 64 channels, spatial size preserved (3x3, stride 1, pad 1).
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
        )
        # Four residual units follow the stem, doubling the channel count
        # at every stage: [b, 64, h, w] -> 128 -> 256 -> 512 -> 1024.
        self.blk1 = ResBlk(64, 128)
        self.blk2 = ResBlk(128, 256)
        self.blk3 = ResBlk(256, 512)
        self.blk4 = ResBlk(512, 1024)
接下来构建ResNet-18的forward函数
def forward(self, x):
    """Run the stem conv + ReLU, then the four residual stages in order.

    NOTE(review): this snippet deliberately has no return yet — the
    classifier head and `return` are added later in the tutorial.
    """
    out = F.relu(self.conv1(x))
    # [b, 64, h, w] => [b, 1024, h, w] across the four stages.
    for stage in (self.blk1, self.blk2, self.blk3, self.blk4):
        out = stage(out)
由于我们要解决10分类问题,需要添加代码
# Classifier head: map the 1024-channel features to the 10 CIFAR-10 classes.
self.outlayer = nn.Linear(1024, 10)
和
# Apply the classification head and return the class logits.
# NOTE(review): nn.Linear(1024, 10) needs a [b, 1024] input, so x must be
# pooled/flattened from [b, 1024, h, w] before this point — confirm upstream.
x = self.outlayer(x)
return x
为确定具体维度大小,我们先构建假数据
def main():
    """Smoke-test a single ResBlk with fake data to inspect the output shape."""
    blk = ResBlk(64, 128)
    # Bug fix: ResBlk(64, 128) expects 64 input channels, so the fake batch
    # must be [b, 64, h, w] — the original [2, 3, 32, 32] raises a
    # channel-mismatch RuntimeError inside the first conv.
    tmp = torch.randn(2, 64, 32, 32)
    out = blk(tmp)
    print(out.shape)


if __name__ == "__main__":
    main()
此时代码为
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResBlk(nn.Module):
    """Basic residual block: two 3x3 conv-BN layers plus a shortcut branch.

    When ch_in != ch_out, the shortcut uses a 1x1 conv + BN projection so the
    element-wise addition of [b, ch_in, h, w] and [b, ch_out, h, w] is valid.
    """

    def __init__(self, ch_in, ch_out):
        super(ResBlk, self).__init__()
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
        # BatchNorm keeps training faster and more stable.
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)
        # Identity shortcut by default; 1x1 projection when channels change.
        self.extra = nn.Sequential()
        if ch_out != ch_in:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1),
                nn.BatchNorm2d(ch_out),
            )

    def forward(self, x):
        """Map [b, ch_in, h, w] -> [b, ch_out, h, w]."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        # Element-wise add of the shortcut branch and the main branch.
        out = self.extra(x) + out
        return out


class ResNet(nn.Module):
    """ResNet-18 style network for 10-class classification (CIFAR-10)."""

    def __init__(self):
        super(ResNet, self).__init__()
        # Stem: 3 -> 64 channels, spatial size preserved.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
        )
        # Four residual stages: 64 -> 128 -> 256 -> 512 -> 1024 channels.
        self.blk1 = ResBlk(64, 128)
        self.blk2 = ResBlk(128, 256)
        self.blk3 = ResBlk(256, 512)
        self.blk4 = ResBlk(512, 1024)
        # Classifier head: 1024 features -> 10 class logits.
        self.outlayer = nn.Linear(1024, 10)

    def forward(self, x):
        """Map [b, 3, h, w] images to [b, 10] class logits."""
        x = F.relu(self.conv1(x))
        # [b, 64, h, w] => [b, 1024, h, w]
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)
        # Bug fix: nn.Linear(1024, 10) requires a [b, 1024] input, so pool the
        # spatial dims down to 1x1 and flatten before the classifier head
        # (feeding [b, 1024, h, w] directly would raise a shape error).
        x = F.adaptive_avg_pool2d(x, [1, 1])  # [b, 1024, h, w] -> [b, 1024, 1, 1]
        x = x.view(x.size(0), -1)             # -> [b, 1024]
        x = self.outlayer(x)
        return x


def main():
    """Smoke-test a single ResBlk with fake data to inspect the output shape."""
    blk = ResBlk(64, 128)
    # Bug fix: ResBlk(64, 128) expects 64 input channels — the original
    # [2, 3, 32, 32] fake batch raised a channel-mismatch RuntimeError.
    tmp = torch.randn(2, 64, 32, 32)
    out = blk(tmp)
    print(out.shape)


if __name__ == "__main__":
    main()