CIFAR10数据集实战-LeNet5神经网络(下)
- 2020 年 1 月 2 日
- 筆記
下面开始加入test部分
先写入test部分代码
for x, label in cifar_test: x, label = x.to(device), label.to(device) logits = model(x) pred = logits.armax(dim=1) # 用argmax选出可能性最大的值的索引
为进行比对
定义正确率
写入对比
total_correct += torch.eq(pred, label).float().sum().item() # torch.eq函数用于对比,同时要转为numpy数据 total_num += x.size(0)
再定义正确率并输出
acc = total_correct / total_num print('acc:', acc)
可以加入模式切换
Model.train()和model.eval()
最终main.py文件为
import torch from torchvision import datasets # 引入pytorch、datasets工具包 from torchvision import transforms # 引入数据变换工具包 from torch.utils.data import DataLoader # 多线程数据读取 from LeNet5 import LeNet5 import torch.nn as nn import torch.optim as optim def main(): batchsz=32 # 这个batch_size数值不宜太大也不宜过小 cifar_train = datasets.CIFAR10('cifar', train=True, transform=transforms.Compose([ transforms.Resize((32, 32)), # .Compose相当于一个数据转换的集合 # 进行数据转换,首先将图片统一为32*32 transforms.ToTensor() # 将数据转化到Tensor中 ]), download=True) # 直接在datasets中导入CIFAR10数据集,放在"cifar"文件夹中 cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True) # 按照其要求,这里的参数需要有batch_size, # 在该部分代码前面定义batch_size # 再使数据加载的随机化 cifar_test = datasets.CIFAR10('cifar', train=False, transform=transforms.Compose([ transforms.Resize((32, 32)), transforms.ToTensor() ]), download=True) cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True) x, label = iter(cifar_train).next() # 通过.iter方法输出一个数据进行查看 # print('s.shape:', x.shape, 'label.shape:', label.shape) # 输出shape进行查看 device = torch.device('cuda') model = LeNet5().to(device) criteon = nn.CrossEntropyLoss().to(device) optimizer = optim.Adam(model.parameters(), lr=1e-3) print(model) model.train() for epoch in range(1000): for batchidx, (x, label) in enumerate(cifar_train): # batchidx代表了有多少个batch, x, label = x.to(device), label.to(device) logits = model(x) loss = criteon(logits, label) optimizer.zero_grad() loss.backward() optimizer.step() # print(epoch, loss.item()) model.eval() total_correct = 0 total_num = 0 for x, label in cifar_test: x, label = x.to(device), label.to(device) logits = model(x) pred = logits.argmax(dim=1) # 用argmax选出可能性最大的值的索引 # 进行比对 total_correct += torch.eq(pred, label).float().sum().item() # torch.eq函数用于对比,同时要转为numpy数据 total_num += x.size(0) acc = total_correct / total_num print('acc:', acc)
输出为

可以看出正确率在逐渐上升