Pytorch系列:(二)數據載入
DataLoader
DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,
batch_sampler=None,num_workers=0,collate_fn=None,pin_memory=False,
drop_last=False,timeout=0,work_init_fn=None)
常用參數說明:
-
dataset: Dataset類 ( 詳見下文數據集構建 ),可以自定義數據集或者讀取pytorch自帶數據集
-
batch_size: 每個batch載入多少個樣本, 默認1
-
shuffle: 是否順序讀取,True表示隨機打亂,默認False
-
sampler:定義從數據集中提取樣本的策略。如果指定,則忽略shuffle參數。
-
batch_sampler: 定義一個按照batch_size大小返回索引的取樣器。取樣器詳見下文Batch_Sampler
-
num_workers: 數據讀取進程數量, 默認0
-
collate_fn: 自定義一個函數,接收一個batch的數據,進行自定義處理,然後返回處理後這個batch的數據。例如改變數據類型:
def my_collate_fn(batch_data):
x_batch = []
y_batch = []
for x,y in batch_data:
x_batch.append(x.float())
y_batch.append(y.int())
return x_batch,y_batch
-
pin_memory:設置pin_memory=True,則意味著生成的Tensor數據最開始是屬於記憶體中的鎖頁記憶體,這樣將記憶體的Tensor轉義到GPU的顯示記憶體就會更快一些。默認為False.
主機中的記憶體,有兩種,一種是鎖頁,一種是不鎖頁。鎖頁記憶體存放的內容在任何情況下都不會與主機的虛擬記憶體 (硬碟)進行交換,而不鎖頁記憶體在主機記憶體不足時,數據會存放在虛擬記憶體中。注意顯示卡中的顯示記憶體全部都是鎖業記憶體。如果電腦記憶體充足的話,設置為True可以加快數據交換順序。
-
drop_last:默認False, 最後剩餘數據量不夠batch_size時候,是否丟棄。
-
timeout: 設置數據讀取的時間限制,超過限制時間還未完成數據讀取則報錯。數值必須大於等於0
數據集構建
自定義數據集
自定義數據集,需要繼承torch.utils.data.Dataset
,然後在__getitem__()
中,接受一個索引,返回一個樣本, 基本流程,首先在__init__()
載入數據以及做一些處理,在__getitem__()
中返回單個數據樣本,在__len__()
中,返回樣本數量
import torch
import torch.utils.data.dataset as Data
class MyDataset(Data.Dataset):
def __init__(self):
self.x = torch.randn((10,20))
self.y = torch.tensor([1 if i>5 else 0 for i in range(10)],
dtype=torch.long)
def __getitem__(self,idx):
return self.x[idx],self.y[idx]
def __len__(self):
return self.x.__len__()
torchvision數據集
pytorch自帶torchvision庫可以幫助我們方便快捷的讀取和載入數據
import torch
from torchvision import datasets, transforms
# 定義一個預處理方法
transform = transforms.Compose([transforms.ToTensor()])
# 載入一個自帶數據集
trainset = datasets.MNIST('/pytorch/MNIST_data/', download=True, train=True,
transform=transform)
TensorDataset
注意這裡的tensor必須是一維度的數據。
import torch.utils.data as Data
x = torch.tensor([1,2,3,4,5])
y = torch.tensor([0,0,0,1,1])
dataset = Data.TensorDataset(x,y)
從文件夾中載入數據集
如果想要載入自己的數據集可以這樣,用貓狗數據集舉例,根目錄下 ( "data/train" )
,分別放置兩個文件夾,dog和cat,這樣使用ImageFolder函數就可以自動的將貓狗照片自動的按照文件夾定義為貓狗兩個標籤
import torch
from torchvision import datasets, transforms
data_dir = "data/train"
transform = transforms.Compose([transforms.Resize(255),transforms.ToTensor()])
dataset = datasets.ImageFolder(data_dir, transform=transform)
數據集操作
數據拼接
連接不同的數據集以構成更大的新數據集。
class torch.utils.data.ConcatDataset( [datasets, … ] )
newDataset = torch.utils.data.ConcatDataset([dataset1,dataset2])
數據切分
方法一: class torch.utils.data.Subset(dataset, indices)
取指定一個索引序列對應的子數據集。
from torch.utils.data import Subset
train_set = Subset(dataset,[i for i in range(1,100)]
test_set = Subset(test0_ds,[i for i in range(100,150)]
方法二:torch.utils.data.random_split(dataset, lengths)
from torch.utils.data import random_split
train_set, test_set = random_split(dataset,[100,50])
取樣器
所有取樣器都在 torch.utils.data
中,取樣器會根據該有的策略返回一組索引,在DataLoader中設定了取樣器之後,會根據索引讀取相應的樣本, 不同取樣器生成的索引不一樣,從而實現不同的取樣目的。
Sampler
所有取樣器的基類,自定義取樣器的時候需要實現 __iter__()
函數
class Sampler(object):
"""
Base class for all Samplers.
"""
def __init__(self, data_source):
pass
def __iter__(self):
raise NotImplementedError
RandomSampler
RandomSampler,當DataLoader的shuffle
參數為True時,系統會自動調用這個取樣器,實現打亂數據。默認的是採用SequentialSampler,它會按順序一個一個進行取樣。
SequentialSampler
按順序取樣,當DataLoader的shuffle
參數為False時,使用的就是SequentialSampler。
SubsetRandomSampler
輸入一個列表,按照這個列表取樣。也可以通過這個取樣器來分割數據集。
BatchSampler
參數:sampler, batch_size, drop_last
每此返回batch_size
數量的取樣索引,通過設置sampler
參數來使用不同的取樣方法。
WeightedRandomSampler
參數:weights, num_samples, replacement
它會根據每個樣本的權重選取數據,在樣本比例不均衡的問題中,可用它來進行重取樣。通過weights
設定樣本權重,權重越大的樣本被選中的概率越大,待選取的樣本數目一般小於全部的樣本數目。num_samples
為返回索引的數量,replacement
表示是否是放回抽樣,如果為True,表示可以重複取樣,默認為True
自定義取樣器
集成Sampler類,然後實現__iter__()
方法,比如,下面實現一個SequentialSampler類
class SequentialSampler(Sampler):
r"""Samples elements sequentially, always in the same order.
Arguments:
data_source (Dataset): dataset to sample from
"""
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self):
return len(self.data_source)