從零搭建Pytorch模型教程(一)數據讀取
前言
本文介紹了classdataset的幾個要點,由哪些部分組成,每個部分需要完成哪些事情,如何進行數據增強,如何實現自己設計的數據增強。然後,介紹了分散式訓練的數據載入方式,數據讀取的整個流程,當面對超大數據集時,記憶體不足的改進思路。
本文延續了以往的寫作態度和風格,即便是自己知道的內容,也仍然在寫之前看了很多的文章來保證內容的正確性和全面性,因此寫得極累,耗費時間較長。若有讀者看完後覺得有所幫助,文末可以讚賞一點。
文末掃描二維碼關注公眾號CV技術指南 ,專註於電腦視覺的技術總結、最新技術跟蹤、經典論文解讀,招聘資訊發布。
(零) 概述
浮躁是人性的一個典型的弱點,很多人總擅長看別人分享的現成程式碼解讀的文章,看起來學會了好多東西,實際上仍然不具備自己從零搭建一個pipeline的能力。
在公眾號(CV技術指南)的交流群里(群內交流氛圍不錯,有需要的請關注公眾號加群),常有不少人問到一些問題,根據這些問題明顯能看出是對pipeline不了解,卻已經在搞項目或論文了,很難想像如果基本的pipeline都不懂,如何分析程式碼問題所在?如何分析結果不正常的可能原因?遇到問題如何改?
Pytorch在這幾年逐漸成為了學術上的主流框架,其具有簡單易懂的特點。網上有很多pytorch的教程,如果是一個已經懂的人去看這些教程,確實pipeline的要素都寫到了,感覺這教程挺不錯的。但實際上更多地像是寫給自己看的一個筆記,記錄了pipeline要寫哪些東西,卻沒有介紹要怎麼寫,為什麼這麼寫,剛入門的小白看的時候容易雲里霧裡。
鑒於此,本教程嘗試對於pytorch搭建一個完整pipeline寫一個比較明確且易懂的說明。
本教程將介紹以下內容:
-
準備數據,自定義classdataset,分散式訓練的數據載入方式,載入超大數據集的改進思路。
-
搭建模型與模型初始化。
-
編寫訓練過程,包括載入預訓練模型、設置優化器、設置損失函數等。
-
可視化並保存訓練過程。
-
編寫推理函數。
(一)數據讀取
classdataset的定義
先來看一個完整的classdataset
classdataset的幾個要點:
-
classdataset類繼承torch.utils.data.dataset。
-
classdataset的作用是將任意格式的數據,通過讀取、預處理或數據增強後以tensor的形式輸出。其中任意格式的數據可能是以文件夾名作為類別的形式、或以txt文件存儲圖片地址的形式、或影片、或十幾幀影像作為一份樣本的形式。而輸出則指的是經過處理後的一個batch的tensor格式數據和對應標籤。
-
classdataset主要有三個函數要完成:__init__函數、__getitem__ 函數和__len__函數。
__init__函數
init函數主要是完成兩個靜態變數的賦值。一個是用於存儲所有數據路徑的變數,變數的每個元素即為一份訓練樣本,(註:如果一份樣本是十幾幀影像,則變數每個元素存儲的是這十幾幀影像的路徑),可以命名為self.filenames。一個是用於存儲與數據路徑變數一一對應的標籤變數,可以命名為self.labels。
假如數據集的格式如下:
可通過per_classes = os.listdir(data_path) 獲得所有類別的文件夾,在此處per_classes的每個元素即為對應的數據標籤,通過for遍歷per_classes即可獲得每個類的標籤,將其轉換成int的tensor形式即可。在for下獲得每個類下每張圖片的路徑,通過self.join獲得每份樣本的路徑,通過append添加到self.filenames中。
__getitem__ 函數
getitem 函數主要是根據索引返回對應的數據。這個索引是在訓練前通過dataloader切片獲得的,這裡先不管。它的參數默認是index,即每次傳回在init函數中獲得的所有樣本中索引對應的數據和標籤。因此,可通過下面兩行程式碼找到對應的數據和標籤。
獲得數據後,進行數據預處理。數據預處理主要通過 torchvision.transforms 來完成,這裡面已經包含了常用的預處理、數據增強方式。其完整使用方式在官網有詳細介紹://pytorch.org/vision/stable/transforms.html
上面這裡介紹了最常用的幾種,主要就是resize,隨機裁剪,翻轉,歸一化等。
最後通過transforms.Compose(transform_train_list)來執行。
除了這些已經有的數據增強方式外,在《
下面以隨機擦除作為例子。
如上所示,自己寫一個類RandomErasing,繼承object,在call函數里完成你的操作。在transform_train_list里添加上RandomErasing的定義即可。
transform_train_list = [ transforms.Resize((self.opt.h, self.opt.w), interpolation=3), transforms.Pad(self.opt.pad, padding_mode='edge'), transforms.RandomCrop((self.opt.h, self.opt.w)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) RandomErasing(probability=self.opt.erasing_p, mean=[0.0, 0.0, 0.0]) #添加到這裡 ]
__len__函數
len函數主要就是返回數據長度,即樣本的總數量。前面介紹了self.filenames的每個元素即為每份樣本的路徑,因此,self.filename的長度就是樣本的數量。通過return len(self.filenames)即可返回數據長度。
驗證classdataset
分散式訓練的數據載入方式
前面介紹的是單卡的數據載入,實際上分散式也是這樣,但為了高速高效讀取,每張卡上也會保存所有數據的資訊,即self.filenames和self.labels的資訊。只是在DistributedSampler 中會給每張卡分配互不交叉的索引,然後由torch.utils.data.DataLoader來載入。
dataset = My_Dataset(data_folder=data_folder) sampler = DistributedSampler(dataset) if is_distributed else None loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler)
數據讀取的完整流程
結合上面這段程式碼,在這裡,我們介紹以下讀取數據的整個流程。
-
首先定義一個classdataset,在初始化函數里獲得所有數據的資訊。
-
classdataset中實現getitem函數,通過索引來獲取對應的數據,然後對數據進行預處理和數據增強。
-
在模型訓練前,初始化classdataset,通過Dataloader來載入數據,其載入方式是通過Dataloader中分配的索引,調用getitem函數來獲取。
關於索引的分配,在單卡上,可通過設置shuffle=True來隨機生成索引順序;在多機多卡的分散式訓練上,shuffle操作通過DistributedSampler來完成,因此shuffle與sampler只能有一個,另一個必須為None。
超大數據集的載入思路
問題所在
再回顧一下上面這個流程,前面提到所有數據資訊在classdataset初始化部分都會保存在變數中,因此當面對超大數據集時,會出現記憶體不足的情況。
思路
將切片獲取索引的步驟放到classdataset初始化的位置,此時每張卡都是保存不同的數據子集。通過這種方式,可以將記憶體用量減少到原來的world_size倍(world_size指卡的數量)。
參考程式碼
class RankDataset(Dataset): ''' 實際流程 獲取rank和world_size 資訊 -> 獲取dataset長度 -> 根據dataset長度產生隨機indices -> 給不同的rank 分配indices -> 根據這些indices產生metas ''' def __init__(self, meta_file, world_size, rank, seed): super(RankDataset, self).__init__() random.seed(seed) np.random.seed(seed) self.world_size = world_size self.rank = rank self.metas = self.parse(meta_file) def parse(self, meta_file): dataset_size = self.get_dataset_size(meta_file) # 獲取metafile的行數 local_rank_index = self.get_local_index(dataset_size, self.rank, self.world_size) # 根據world size和rank,獲取當前epoch,當前rank需要訓練的index。 self.metas = self.read_file(meta_file, local_rank_index) def __getitem__(self, idx): return self.metas[idx] def __len__(self): return len(self.metas) ##train for epoch_num in range(epoch_num): dataset = RankDataset("/path/to/meta", world_size, rank, seed=epoch_num) sampler = RandomSampler(datset) dataloader = DataLoader( dataset=dataset, batch_size=32, shuffle=False, num_workers=4, sampler=sampler)
這一節參考鏈接://zhuanlan.zhihu.com/p/357809861
總結
本篇文章介紹了數據讀取的完整流程,如何自定義classdataset,如何進行數據增強,自己設計的數據增強如何寫,分散式訓練是如何載入數據的,超大數據集的數據載入改進思路。
相信讀完本文的讀者對數據讀取有了比較清晰的認識,下一篇將介紹搭建模型與模型初始化。
關注公眾號可加電腦視覺交流群
歡迎關注公眾號 CV技術指南 ,專註於電腦視覺的技術總結、最新技術跟蹤、經典論文解讀。
在公眾號中回復關鍵字 「入門指南「可獲取電腦視覺入門所有必備資料。
其它文章