0702-電腦視覺工具包torchvision

0702-電腦視覺工具包torchvision

pytorch完整教程目錄://www.cnblogs.com/nickchen121/p/14662511.html

一、torchvision 概述

電腦視覺是深度學習中最重要的一類應用,為了方便研究者使用,torch 專門開發了一個視覺工具包 torchvision,這個包獨立於 torch,需要使用 pip install torchvision 進行安裝。

之前的我們已經使用過它的部分功能,在這裡我們在做一個系統的介紹,它主要包含以下三個功能:

  • models:提供深度學習中各種經典網路的網路結構以及訓練好的模型,包括 Alex-Net、VGG 系列、ResNet 系列、Inception 系列等
  • datasets:提供常用的數據集載入,設計上都是集成 torch.utils.data.Dataset,主要包括 MNIST、CIFAR10/100、ImageNet、COCO 等
  • transforms:提供常用的數據預處理操作,主要包括對 Tensor 以及 PIL Image 對象的操作

二、通過 torchvision 載入模型

from torchvision import models
from torch import nn

# 載入預訓練好的模型,如果不存在會下載
# 預訓練好的模型保存在 ~/.torch/modes/ 下面
resnet34 = models.resnet34(pretrained=True, num_classes=1000)

# 修改最後的全連接層為 10 分類問題(默認是 ImageNet 上的 1000 分類)
resnet34.fc = nn.Linear(512, 10)

三、通過 torchvision 載入並處理數據集

from torchvision import datasets
from torchvision import transforms as T
# 指定數據集路徑為 data,如果數據集不存在則進行下載
# 通過 train=False 獲取測試集

normalize = T.Normalize(mean=[0.4, 0.4, 0.4], std=[0.2, 0.2, 0.2])
transform = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ToTensor(),  # 把圖片轉成 Tensor,歸一化至 [0,1]
    T.Lambda(lambda x: x.repeat(3, 1, 1)),  # 把圖片轉為 3 通道的
    normalize,
])

dataset = datasets.MNIST('data/',
                         download=True,
                         train=False,
                         transform=transform)

Transforms 中涵蓋了大部分對 Tensor 和 PIL Image 的常用處理,這個轉換通常分為兩步:

  1. 第一步:構建轉換操作,例如 transf = transforms.Normalize(mean=x, std=y)
  2. 第二步:執行轉換操作,例如 otuput = transf(inp)
import torch as t

# 構建隨機雜訊,圖片如下圖所示
to_pil = T.ToPILImage()
to_pil(t.rand(3, 64, 64))

四、通過 torchvision 拼接並保存圖片

torchvision 還提供了兩個常用的函數:

  1. make_grid,它能把多張圖片拼接在一個網格中
  2. save_img,它能把 Tensor 保存成圖片
len(dataset)
10000
from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, shuffle=True, batch_size=16)
from torchvision.utils import make_grid, save_image
dataiter = iter(dataloader)
dataiter
img = make_grid(next(dataiter)[0], 4)  # 拼接成 4*4 網格圖片,並且會轉成 3 通道,如下圖所示
to_img = T.ToPILImage()
to_img(img)

save_image(img, 'a.png')
from PIL import Image
Image.open('a.png')