Pytorch Dataset和Dataloader 學習筆記(二)
- 2021 年 6 月 18 日
- 筆記
Pytorch Dataset & Dataloader
Pytorch框架下的工具包中,提供了數據處理的兩個重要介面,Dataset 和 Dataloader,能夠方便的使用和按批裝載自己的數據集。
-
數據的預處理,載入數據並轉化為tensor格式
-
使用Dataset構建自己的數據
-
使用Dataloader裝載數據
【數據】鏈接://pan.baidu.com/s/1gdWFuUakuslj-EKyfyQYLA
提取碼:10d4
複製這段內容後打開百度網盤手機App,操作更方便哦
數據的預處理與載入
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
## 1. 數據的處理,載入轉化為tensor
x_data = 'X.csv'
y_data = 'y.csv'
x = np.loadtxt(x_data, delimiter=' ', dtype=np.float32)
y = np.loadtxt(y_data, delimiter=' ', dtype=np.float32).reshape(-1, 1)
x = torch.from_numpy(x[:, :])
y = torch.from_numpy(y[:, :])
torch.utils.data.Dataset
Dataset抽象類,用於包裝構建自己的數據集,該類包括三個基本的方法:
- __init__ 進行數據的讀取操作
- __getitem__ 數據集需支援索引訪問
- __len__ 返回數據集的長度
## 2. 構建自己的數據集
class Mydataset(Dataset):
def __init__(self, train_data, label_data):
self.train = train_data
self.label = label_data
self.len = len(train_data)
def __getitem__(self, item):
return self.train[item], self.label[item]
def __len__(self):
return self.len
dataset = Mydataset(x, y)
samples = dataset.__len__()
print("總樣本數:",samples)
torch.utils.data.Dataloader
Dataloader抽象類,構建可迭代的數據集裝載器,從Dataset實例對象中按batch_size裝載數據以送入訓練。包含以下幾個參數:
- batch_size 批大小
- shuffle 裝載的batch是否亂序
- drop_last 不足batch大小的最後部分是否捨去
- num_workers 是否多進程讀取數據
## 3. 創建數據集裝載器
train_loader = DataLoader(dataset=dataset,
batch_size=64,
shuffle=True,
drop_last=True,
num_workers=4)
測試
if __name__ == "__main__":
iteration = 0
for train_data, train_label in train_loader:
print("x: ", train_data, "\ny: ", train_label)
iteration += 1
### 這裡dataloader中drop_last為True,所以迭代次數應為 samples/batch_size = 6
print("每個epoch迭代次數:",iteration)
完整程式碼
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
## 1. 數據的處理,載入轉化為tensor
x_data = 'X.csv'
y_data = 'y.csv'
x = np.loadtxt(x_data, delimiter=' ', dtype=np.float32)
y = np.loadtxt(y_data, delimiter=' ', dtype=np.float32).reshape(-1, 1)
x = torch.from_numpy(x[:, :])
y = torch.from_numpy(y[:, :])
## 2. 構建自己的數據集
class Mydataset(Dataset):
def __init__(self, train_data, label_data):
self.train = train_data
self.label = label_data
self.len = len(train_data)
def __getitem__(self, item):
return self.train[item], self.label[item]
def __len__(self):
return self.len
dataset = Mydataset(x, y)
## 3. 創建數據集裝載器
train_loader = DataLoader(dataset=dataset,
batch_size=64,
shuffle=True,
drop_last=True,
num_workers=4)
if __name__ == "__main__":
iteration = 0
samples = dataset.__len__()
print("總樣本數:", samples)
for train_data, train_label in train_loader:
print("x: ", train_data, "\ny: ", train_label)
iteration += 1
### 這裡dataloader中drop_last為True,所以迭代次數應為 samples/batch_size = 6
print("每個epoch迭代次數:",iteration)