CIFAR10数据集实战-LeNet5神经网络(下)

下面开始加入test部分

先写入test部分代码

for x, label in cifar_test:      x, label = x.to(device), label.to(device)        logits = model(x)      pred = logits.armax(dim=1)      # 用argmax选出可能性最大的值的索引

为进行比对

定义正确率

写入对比

total_correct += torch.eq(pred, label).float().sum().item()  # torch.eq函数用于对比,同时要转为numpy数据  total_num += x.size(0)

再定义正确率并输出

acc = total_correct / total_num  print('acc:', acc)

可以加入模式切换

Model.train()和model.eval()

最终main.py文件为

import torch  from torchvision import datasets  # 引入pytorch、datasets工具包  from torchvision import transforms  # 引入数据变换工具包  from torch.utils.data import DataLoader  # 多线程数据读取  from LeNet5 import LeNet5  import torch.nn as nn    import torch.optim as optim  def main():        batchsz=32      # 这个batch_size数值不宜太大也不宜过小        cifar_train = datasets.CIFAR10('cifar', train=True, transform=transforms.Compose([          transforms.Resize((32, 32)),          # .Compose相当于一个数据转换的集合          # 进行数据转换,首先将图片统一为32*32          transforms.ToTensor()          # 将数据转化到Tensor中        ]), download=True)      # 直接在datasets中导入CIFAR10数据集,放在"cifar"文件夹中        cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)      # 按照其要求,这里的参数需要有batch_size,      # 在该部分代码前面定义batch_size      # 再使数据加载的随机化            cifar_test = datasets.CIFAR10('cifar', train=False, transform=transforms.Compose([          transforms.Resize((32, 32)),          transforms.ToTensor()      ]), download=True)        cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)          x, label = iter(cifar_train).next()      # 通过.iter方法输出一个数据进行查看      # print('s.shape:', x.shape, 'label.shape:', label.shape)      # 输出shape进行查看              device = torch.device('cuda')      model = LeNet5().to(device)      criteon = nn.CrossEntropyLoss().to(device)      optimizer = optim.Adam(model.parameters(), lr=1e-3)      print(model)        model.train()      for epoch in range(1000):            for batchidx, (x, label) in enumerate(cifar_train):              # batchidx代表了有多少个batch,              x, label = x.to(device), label.to(device)                logits = model(x)              loss = criteon(logits, label)              optimizer.zero_grad()              loss.backward()              optimizer.step()              # print(epoch, loss.item())            model.eval()          total_correct = 0          total_num = 0            for x, label in cifar_test:              x, label = x.to(device), label.to(device)                logits = model(x)              pred = logits.argmax(dim=1)              # 用argmax选出可能性最大的值的索引              # 进行比对              total_correct += torch.eq(pred, label).float().sum().item()              # torch.eq函数用于对比,同时要转为numpy数据              total_num += x.size(0)          acc = total_correct / total_num          print('acc:', acc)

输出为

可以看出正确率在逐渐上升