Pytorch數據集讀入——Dataset類,實現數據集打亂Shuffle

  • 2019 年 11 月 7 日
  • 筆記

在進行相關平台的練習過程中,由於要自己導入數據集,而導入方法在市面上五花八門,各種庫都可以應用,在這個過程中我準備嘗試torchvision的庫dataset
torchvision.datasets.ImageFolder
簡單應用起來非常簡單,用torchvision.datasets.ImageFolder實現圖片的導入,在隨後訓練過程中用Datalodar處理後可按批次取出訓練集

class ImageFolder(root, transform=None, target_transform=None, loader=default_loader, is_valid_file=None)
ImageFolder有這麼幾個參數,其中root指的是數據所在的文件夾,其中該文件夾的存儲方式應為
root/labels/xxx.jpg
即根據自身分類標籤存儲在對應標籤名的文件夾內
ImageFolder在讀入的過程中會自行加好標籤,最後形成一對對的數據
另外比較常用的就是transform,表示對於傳入圖片的預處理,如剪裁,顏色選擇等等
比如

transform_t = transforms.Compose([      transforms.Resize([64, 64]),      transforms.Grayscale(num_output_channels=1),      transforms.ToTensor()]      )

具體參數可以上網查看
在之後用DataLodar處理後雖然的確有Shuffle的參數,但是卻只是在一個小批次內進行打亂,原本是按照類別存儲的,這樣的話會導致很嚴重的過擬合,為了避免這個,我決定常識改寫一下Dataset的類(主要是看起來Dataset看起來改寫比較順手…ImageFolder還沒有看源碼並沒要對此下手)
但是Dataset需要讀入一個個的訓練數據的位置,怎麼辦呢?我就先寫了一個小腳本,生成一個txt文件來存儲所有數據的名稱(相對路徑),同時在這一步就進行打亂操作【一眼看下去甚至會發現init的classnum參數完全沒用上(捂臉

import os  import numpy as np  '''  self.target     順序存儲數據集  self.DataFile   存儲根目錄  self.s          存儲所有數據  self.label      存儲所有標籤及其對應的值  '''  class create_list():      def __init__(self,root,classnum=2):          self.target=open("./Data.txt",'w')          self.DataFile=root          self.s=[]          self.label={}          self.datanum=0        def create(self):          files=os.listdir(self.DataFile)          for labels in files:              tempdata=os.listdir(self.DataFile+"/"+labels)              self.label[labels]=len(self.label)              for img in tempdata:                  self.datanum+=1                  self.target.write(self.DataFile+"/"+labels+"/"+img+" "+labels+"n")                  self.s.append([self.DataFile+"/"+labels+"/"+img,labels])        def detail(self):          #查看數據數量以及標籤對應          print(self.datanum)          print(self.label)        def get_all(self):          #查看所有數據          print(self.s)        def get_root(self):          #獲得根目錄          return self.DataFile        def shuffle(self):          #獲得打亂的存儲txt          shuffle_file=open("./Shuffle_Data.txt",'w')          temp=self.s          np.random.shuffle(temp)          for i in temp:              shuffle_file.write(i[0]+" "+str(i[1])+"n")          return self.DataFile+"/Shuffle_Data.txt"        def label_id(self,label):          #獲得該標籤對應的值          return self.label[label]

數據集的存儲方式上的要求跟之前的ImageFolder一樣
最終會生成一個這樣的txt文件
image
數據集來源於某x光胸片判斷…
而Shuffle操作就是為了生成打亂後的txt文件,我寫的比較簡單粗暴…先將就看吧,生成後大概就是這個樣子
image
至少真正的做到打亂數據了
完成這個以後,就可以用此來幫助DataLodar了
接下來的程式碼或許比較辣眼睛…但是事實證明是有用的,但是可能Python技巧不太熟練所以就會顯得很生澀…
我重現的Dataset類:

from PIL import Image  import torch    class cDataset(torch.utils.data.Dataset):      def __init__(self, datatxt, root="", transform=None, target_transform=None, LabelDic=None):          super(cDataset,self).__init__()          files = open(root + "/" + datatxt, 'r')          self.img=[]          for i in files:              i = i.rstrip()              temp = i.split()              if LabelDic!=None:                  self.img.append((temp[0],LabelDic[temp[1]]))              else:                  self.img.append((temp[0],temp[0]))            self.transform = transform          self.target_transform = target_transform        def __getitem__(self, index):          files, label = self.img[index]          img = Image.open(files).convert('RGB')          if self.transform is not None:              img = self.transform(img)          return img,label        def __len__(self):          return len(self.img)

其實直接看就能大概看明白,主要也就是要實現類裡面的幾個方法

class cDataset(torch.utils.data.Dataset):      def __init__():      def __getitem__(self, index):      def __len__(self):

其中getitm類似一次次的取出數據,len就是返回數據集數目
其中init的參數我做了稍許調整,由於我之前的txt內標籤是字元串,而為了能讓對應生成的tag是所要求的,可以傳入一個字典,如:
LabelDic={"NORMAL":0,"PNEUMONIA":1}
這樣就可以在之後轉化為數字的標籤,onehot或者怎麼怎麼樣了,,,