【Tech Share】Fine-tuning in Practice with PyTorch (ResNet18, CIFAR-10)

  • December 1, 2019
  • Notes

This post trains a ResNet18 model on CIFAR-10 with PyTorch, then reworks the CIFAR-10 labels, loads the trained model, and fine-tunes it on the adjusted data. The official PyTorch fine-tuning tutorial is also worth a look.

The ResNet18 model

The PyTorch ResNet18 implementation comes from: https://github.com/kuangliu/pytorch-cifar

The model details are in models/resnet.py in that repo, so I won't walk through them here. The README reports 93.02% accuracy; my local run of 200 epochs did not reach that number, topping out at 87.40%.
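For reference, here is a minimal sketch of my own (not part of the original post) that instantiates the repo's ResNet18 and checks the 10-class output shape for a CIFAR-sized input, assuming models/resnet.py from that repo is importable:

import torch
from models import ResNet18

# Sketch: the repo's ResNet18 expects 3x32x32 inputs and outputs 10 logits.
model = ResNet18()
out = model(torch.randn(1, 3, 32, 32))
print(out.shape)  # expected: torch.Size([1, 10])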

Import the required packages

import os

import numpy as np
import torch                      # needed below for torch.manual_seed, DataLoader, etc.
import torch.nn as nn             # needed below for nn.Linear / nn.CrossEntropyLoss
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from models import *
from utils import progress_bar

Set the random seed so results are reproducible

Getting this right took quite a while. On CPU, calling torch.manual_seed(SEED) is enough for stable reproduction, but on GPU there was always some remaining randomness. With a friend's help and the official documentation the problem was finally solved, thanks. TensorFlow, on the other hand, does not seem able to guarantee reproducible results on GPU; if anyone knows how, please share~

SEED = 0
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(SEED)
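One caveat of my own (not from the original post): with num_workers=2 in the DataLoader below, each worker process has its own random state, so the random augmentations are not automatically covered by the seeds above. A common extra step is to seed the workers explicitly, for example:

import random

# Hypothetical helper (my addition): seed each DataLoader worker so that
# random augmentations are also reproducible when num_workers > 0.
def seed_worker(worker_id):
    random.seed(SEED + worker_id)
    np.random.seed(SEED + worker_id)

# It would then be passed when building the loaders, e.g.:
# trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True,
#                                           num_workers=2, worker_init_fn=seed_worker)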

Choose whether to run on CPU or GPU

The device is chosen based on whether a GPU is available. Watch out for driver installation and version compatibility; the drivers cost me a lot of time. Since I run inside Docker, a mismatched driver version meant the GPU was never detected.

device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0
start_epoch = 0

Data loading and preprocessing

The data lives in the data folder next to the .py file; if it is missing and download is set to True, torchvision downloads it automatically. Different transforms are applied to the training set to increase data diversity.

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

Adjusting the dataset

The original CIFAR-10 dataset contains 10 classes:

['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

To practice fine-tuning, the dataset is reworked from 10 classes into 2: animals and transport. (Does a horse count as transport? ^.^)

clz_idx = trainset.class_to_idx
clz_to_idx = {'animal': 0, 'transport': 1}
clz = ['animal', 'transport']
animal_name = ["bird", "cat", "deer", "dog", "frog", "horse"]
animal = [clz_idx[x] for x in animal_name]

trainset.targets = [0 if x in animal else 1 for x in trainset.targets]
trainset.class_to_idx = clz_to_idx
trainset.classes = clz

testset.targets = [0 if x in animal else 1 for x in testset.targets]
testset.class_to_idx = clz_to_idx
testset.classes = clz
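A quick sanity check after the remapping (my own sketch, not part of the original post): six of the ten CIFAR-10 classes are animals, so the training set should split roughly 30000/20000 between the two new labels, and the test set 6000/4000.

# Sketch: verify the two-class label distribution after remapping.
print(np.bincount(trainset.targets))  # expected: [30000 20000]
print(np.bincount(testset.targets))   # expected: [6000 4000]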

Load the pretrained model

The checkpoint lives in the checkpoint directory and was produced by the ResNet18 training described above. If the model was trained on GPU, pay particular attention to the order of the statements inside the if branch.

  • Wrap net in DataParallel for parallel training. Because the original ResNet18 was trained on GPU with DataParallel, it must be wrapped the same way here; this adds a module layer around the model.
  • Fine-tuning: replace the final 10-class output layer with a 2-class one. Note the GPU form: net.module.linear.
  • net = net.to(device): push the model to the GPU only after modifying it. Doing it earlier leaves the new layer's parameters off the GPU and triggers an error.
assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
checkpoint = torch.load('./checkpoint/ckpt.pth')

net = ResNet18()
if device == 'cuda':
    net = torch.nn.DataParallel(net)
    net.load_state_dict(checkpoint['net'])
    net.module.linear = nn.Linear(net.module.linear.in_features, 2)
else:
    net.load_state_dict(checkpoint['net'])
    net.linear = nn.Linear(net.linear.in_features, 2)

net = net.to(device)
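One hedge on the CPU branch above: if the checkpoint was saved from a DataParallel-wrapped model, its state-dict keys usually carry a module. prefix and may not load directly into an unwrapped ResNet18. A common workaround, shown as a sketch under that assumption (not from the original post), is to strip the prefix first:

# Sketch: strip the 'module.' prefix added by DataParallel so a GPU-saved
# checkpoint can be loaded into an unwrapped model on CPU.
state_dict = {k.replace('module.', '', 1): v for k, v in checkpoint['net'].items()}
net.load_state_dict(state_dict)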

Specify the layers to freeze

Freeze the parameters of the first 40 layers so they are no longer updated during training.

for idx, (name, param) in enumerate(net.named_parameters()):
    if idx < 40:  # 62 parameter tensors in total; freeze the first 40
        param.requires_grad = False

    if param.requires_grad:
        print("\t", idx, name)
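To double-check that the freezing took effect, a small count of trainable versus total parameter tensors helps (again my own addition, not the original code):

# Sketch: count trainable vs. total parameter tensors after freezing.
n_total = sum(1 for _ in net.parameters())
n_trainable = sum(1 for p in net.parameters() if p.requires_grad)
print('trainable parameter tensors: %d / %d' % (n_trainable, n_total))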

Loss function and optimizer

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
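Since some parameters are now frozen, an alternative of my own (not from the original post) is to hand the optimizer only the trainable parameters; this makes the intent explicit and keeps the optimizer state smaller:

# Sketch: build the optimizer over the trainable parameters only.
trainable_params = [p for p in net.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable_params, lr=0.1, momentum=0.9, weight_decay=5e-4)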

Training and test functions

Based on main.py from the ResNet18 repo above. During testing, the current weights are saved as a checkpoint so training can be resumed later; the checkpoint goes into a separate folder, and it is only written when the test accuracy improves.

def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))


best_acc = 0


def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))

    # Save a checkpoint only when the test accuracy improves.
    acc = 100. * correct / total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint_ft'):
            os.mkdir('checkpoint_ft')
        torch.save(state, './checkpoint_ft/ckpt.pth')
        best_acc = acc

Start training

Since we start from an already-trained model, only a few epochs are needed to reach high accuracy.

for epoch in range(start_epoch, start_epoch + 20):
    train(epoch)
    test(epoch)
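Once training finishes, the best fine-tuned weights sit in checkpoint_ft/ckpt.pth and can be loaded back for inference. A minimal sketch of my own, assuming the loaders and the two-class names (clz) defined above:

# Sketch (not in the original post): reload the best fine-tuned checkpoint
# and predict the two classes for one test batch.
checkpoint_ft = torch.load('./checkpoint_ft/ckpt.pth')
net.load_state_dict(checkpoint_ft['net'])
net.eval()
with torch.no_grad():
    images, labels = next(iter(testloader))
    outputs = net(images.to(device))
    preds = outputs.argmax(dim=1)
print([clz[p] for p in preds[:8].tolist()])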

Results

Per-epoch summary, condensed from the progress-bar output (loss values are running averages):

Epoch | Train loss | Train acc | Test loss | Test acc | Checkpoint saved
    0 |      0.520 |   88.662% |     0.449 |   95.090% | yes
    1 |      0.430 |   95.342% |     0.411 |   95.590% | yes
    2 |      0.394 |   95.816% |     0.373 |   96.110% | yes
    3 |      0.376 |   96.002% |     0.386 |   94.560% |
    4 |      0.368 |   96.160% |     0.365 |   96.350% | yes
    5 |      0.362 |   96.214% |     0.381 |   93.430% |
    6 |      0.360 |   96.070% |     0.362 |   95.400% |
    7 |      0.358 |   96.062% |     0.400 |   90.730% |
    8 |      0.356 |   96.214% |     0.362 |   96.280% |
    9 |      0.353 |   96.242% |     0.376 |   94.590% |
   10 |      0.352 |   96.348% |     0.384 |   93.080% |
   11 |      0.351 |   96.236% |     0.356 |   95.480% |
   12 |      0.350 |   96.348% |     0.383 |   93.170% |
   13 |      0.348 |   96.358% |     0.373 |   93.330% |
   14 |      0.347 |   96.446% |     0.391 |   91.670% |
   15 |      0.346 |   96.324% |     0.347 |   95.880% |
   16 |      0.344 |   96.488% |     0.343 |   95.980% |
   17 |      0.344 |   96.416% |     0.344 |   95.890% |
   18 |      0.344 |   96.370% |     0.354 |   95.060% |
   19 |      0.344 |   96.338% |     0.399 |   89.760% |

Fine-tuning the ResNet18 model (87.4% accuracy on the original 10-class task) for binary classification reaches 95.09% test accuracy after the very first epoch; convergence is fast and the classification results are good.

Over the 20 epochs, the best test accuracy is 96.35% (epoch 4).

Final thoughts

Building models with PyTorch is straightforward, the code stays readable, and the documentation is quite comprehensive.