BYOL:輕鬆進行自監督學習
譯者:AI研習社(季一帆)
雙語原文鏈接:Easy Self-Supervised Learning with BYOL
註:本文所有代碼可見Google Colab notebook,你可用Colab的免費GPU運行或改進。
自監督學習
在深度學習中,經常遇到的問題是沒有足夠的標記數據,而手工標記數據耗費大量時間且人工成本高昂。基於此,自我監督學習成為深度學習的研究熱點,旨在從未標記樣本中進行學習,以緩解數據標註困難的問題。子監督學習的目標很簡單,即訓練一個模型使得相似的樣本具有相似的表示,然而具體實現卻困難重重。經過谷歌這樣的諸多先驅者若干年的研究,子監督學習如今已取得一系列的進步與發展。
在BYOL之前,多數自我監督學習都可分為對比學習或生成學習,其中,生成學習一般GAN建模完整的數據分佈,計算成本較高,相比之下,對比學習方法就很少面臨這樣的問題。對此,BYOL的作者這樣說道:
通過對比方法,同一圖像不同視圖的表示更接近(正例),不同圖像視圖的表示相距較遠(負例),通過這樣的方式減少表示的生成成本。
為了實現對比方法,我們必須將每個樣本與其他許多負例樣本進行比較。然而這樣會使訓練很不穩定,同時會增大數據集的系統偏差。BYOL的作者顯然明白這點:
對比方法對圖像增強的方式非常敏感。例如,當消除圖像增強中的顏色失真時,SimCLR表現不佳。可能的原因是,同一圖像的不同裁切一般會共享顏色直方圖,而不同圖像的顏色直方圖是不同的。因此,在對比任務中,可以通過關注顏色直方圖,使用隨機裁切方式實現圖像增強,其結果表示幾乎無法保留顏色直方圖之外的信息。
不僅僅是顏色失真,其他類型的數據轉換也是如此。一般來說,對比訓練對數據的系統偏差較為敏感。在機器學習中,數據偏差是一個廣泛存在的問題(見facial recognition for women and minorities),這對對比方法來說影響更大。不過好在BYOL不依賴負採樣,從而很好的避免了該問題。
BYOL:Bootstrap Your Own Latent(發掘自身潛能)
BYOL的目標與對比學習相似,但一個很大的區別是,BYOL不關心不同樣本是否具有不同的表徵(即對比學習中的對比部分),僅僅使相似的樣品表徵類似。看上去似乎無關緊要,但這樣的設定會顯著改善模型訓練效率和泛化能力:
-
由於不需要負採樣,BLOY有更高的訓練效率。在訓練中,每次遍歷只需對每個樣本採樣一次,而無需關注負樣本。
-
BLOY模型對訓練數據的系統偏差不敏感,這意味着模型可以對未見樣本也有較好的適用性。
BYOL最小化樣本表徵和該樣本變換之後的表徵間的距離。其中,不同變換類型包括0:平移、旋轉、模糊、顏色反轉、顏色抖動、高斯噪聲等(我在此以圖像操作來舉例說明,但BYOL也可以處理其他數據類型)。至於是單一變換還是幾種不同類型的聯合變換,這取決於你自己,不過我一般會採用聯合變換。但有一點需要注意,如果你希望訓練的模型能夠應對某種變換,那麼用該變換處理訓練數據時必要的。
手把手教你編碼BYOL
首先是數據轉換增強的編碼。BYOL的作者定義了一組類似於SimCLR的特殊轉換:
import random from typing import Callable, Tuple from kornia import augmentation as aug class RandomApply(nn.Module): def forward(self, x: Tensor) -> Tensor: def default_augmentation(image_size: Tuple[int, int] = (224, 224)) -> nn.Module: |
上述代碼通過Kornia實現數據轉換,這是一個基於 PyTorch 的可微分的計算機視覺開源庫。當然,你可以用其他開源庫實現數據轉換擴充,甚至是自己編寫。實際上,可微分性對BYOL而言並沒有那麼必要。
接下來,我們編寫編碼器模塊。該模塊負責從基本模型提取特徵,並將這些特徵投影到低維隱空間。具體的,我們通過wrapper類實現該模塊,這樣我們可以輕鬆將BYOL用於任何模型,無需將模型編碼到腳本。該類主要由兩部分組成:
特徵抽取,獲取模型最後一層的輸出。
映射,非線性層,將輸出映射到更低維空間。
特徵提取通過hooks實現(如果你不了解hooks,推薦閱讀我之前的介紹文章How to Use PyTorch Hooks)。除此之外,代碼其他部分很容易理解。
from typing import Union
def mlp(dim: int, projection_size: int = 256, hidden_size: int = 4096) -> nn.Module: class EncoderWrapper(nn.Module): self._projector = None @property def _hook(self, _, __, output): def _register_hook(self): layer.register_forward_hook(self._hook) def forward(self, x: Tensor) -> Tensor: |
BYOL包含兩個相同的編碼器網絡。第一個編碼器網絡的權重隨着每一訓練批次進行更新,而第二個網絡(稱為「目標」網絡)使用第一個編碼器權重均值進行更新。在訓練過程中,目標網絡接收原始批次訓練數據,而另一個編碼器則接收相應的轉換數據。兩個編碼器網絡會分別為相應數據生成低維表示。然後,我們使用多層感知器預測目標網絡的輸出,並最大化該預測與目標網絡輸出之間的相似性。
圖源:Bootstrap Your Own Latent, Figure 2
也許有人會想,我們不是應該直接比較數據轉換之前和之後的隱向量表徵嗎?為什麼還有設計多層感知機?假設沒有MLP層的話,網絡可以通過將權重降低到零方便的使所有圖像的表示相似化,可這樣模型並沒有學到任何有用的東西,而MLP層可以識別出數據轉換並預測目標隱向量。這樣避免了權重趨零,可以學習更恰當的數據表示!
訓練結束後,捨棄目標網絡編碼器,只保留一個編碼器,根據該編碼器,所有訓練數據可生成自洽表示。這正是BYOL能夠進行自監督學習的關鍵!因為學習到的表示具有自洽性,所以經不同的數據變換後幾乎保持不變。這樣,模型使得相似示例的表示更加接近!
接下來編寫BYOL的訓練代碼。我選擇使用Pythorch Lightning開源庫,該庫基於PyTorch,對深度學習項目非常友好,能夠進行多GPU培訓、實驗日誌記錄、模型斷點檢查和混合精度訓練等,甚至在cloud TPU上也支持基於該庫運行PyTorch模型!
from copy import deepcopy from itertools import chain from typing import Dict, List import pytorch_lightning as pl def normalized_mse(x: Tensor, y: Tensor) -> Tensor: class BYOL(pl.LightningModule): self.encoder(torch.zeros(2, 3, *image_size)) def forward(self, x: Tensor) -> Tensor: @property def update_target(self): # — Methods required for PyTorch Lightning only! — def configure_optimizers(self): def training_step(self, batch, *_) -> Dict[str, Union[Tensor, Dict]]: pred1, pred2 = self.forward(x1), self.forward(x2) self.log(“train_loss”, loss.item()) @torch.no_grad() return {“loss”: loss} @torch.no_grad() |
上述代碼部分源自Pythorch Lightning提供的示例代碼。這段代碼你尤其需要關注的是training_step,在此函數實現模型的數據轉換、特徵投影和相似性損失計算等。
實例說明
下文我們將在STL10數據集上對BYOL進行實驗驗證。因為該數據集同時包含大量未標記的圖像以及標記的訓練和測試集,非常適合無監督和自監督學習實驗。STL10網站這樣描述該數據集:
STL-10數據集是一個用於研究無監督特徵學習、深度學習、自學習算法的圖像識別數據集。該數據集是對CIFAR-10數據集的改進,最明顯的便是,每個類的標記訓練數據比CIFAR-10中的要少,但在監督訓練之前,數據集提供大量的未標記樣本訓練模型學習圖像模型。因此,該數據集主要的挑戰是利用未標記的數據(與標記數據相似但分佈不同)來構建有用的先驗知識。
通過Torchvision可以很方便的加載STL10,因此無需擔心數據的下載和預處理。
from torchvision.datasets import STL10 from torchvision.transforms import ToTensor TRAIN_DATASET = STL10(root=”data”, split=”train”, download=True, transform=ToTensor()) |
同時,我們使用監督學習方法作為基準模型,以此衡量本文模型的準確性。基線模型也可通過Lightning模塊輕易實現:
class SupervisedLightningModule(pl.LightningModule): def __init__(self, model: nn.Module, **hparams): super().__init__() self.model = model def forward(self, x: Tensor) -> Tensor: def configure_optimizers(self): def training_step(self, batch, *_) -> Dict[str, Union[Tensor, Dict]]: @torch.no_grad() @torch.no_grad() |
可以看到,使用Pythorch Lightning可以方便的構建並訓練模型。只需為訓練集和測試集創建DataLoader
對象,將其導入需要訓練的模型即可。本實驗中,epoch設置為25,學習率為1e-4。
from os import cpu_count
from torch.utils.data import DataLoader model = resnet18(pretrained=True) |
經訓練,僅通過一個非常小的模型ResNet18就取得約85%的準確率。但實際上,我們還可以做得更好!
接下來,我們使用BYOL對ResNet18模型進行預訓練。在這次實驗中,我選擇epoch為50,學習率依然是1e-4。註:該過程是本文代碼耗時最長的部分,在K80 GPU的標準Colab中大約需要45分鐘。
model = resnet18(pretrained=True) byol = BYOL(model, image_size=(96, 96)) trainer = pl.Trainer( max_epochs=50, gpus=-1, accumulate_grad_batches=2048 // 128, weights_summary=None, ) train_loader = DataLoader( TRAIN_UNLABELED_DATASET, batch_size=128, shuffle=True, drop_last=True, ) trainer.fit(byol, train_loader, val_loader) |
然後,我們使用新的ResNet18模型重新進行監督學習。(為徹底清除BYOL中的前向hook,我們實例化一個新模型,在該模型引入經過訓練的狀態字典。)
# Extract the state dictionary, initialize a new ResNet18 model, # and load the state dictionary into the new model. # # This ensures that we remove all hooks from the previous model, # which are automatically implemented by BYOL. state_dict = model.state_dict() model = resnet18() model.load_state_dict(state_dict) supervised = SupervisedLightningModule(model) |
通過這種方式,模型準確率提高了約2.5%,達到了87.7%!雖然該方法需要更多的代碼(大約300行)以及一些庫的支撐,但相比其他自監督方法仍顯得簡潔。作為對比,可以看下官方的SimCLR或SwAV是多麼複雜。而且,本文具有更快的訓練速度,即使是Colab的免費GPU,整個實驗也不到一個小時。
結論
本文要點總結如下。首先也是最重要的,BYOL是一種巧妙的自監督學習方法,可以利用未標記的數據來最大限度地提高模型性能。此外,由於所有ResNet模型都是使用ImageNet進行預訓練的,因此BYOL的性能優於預訓練的ResNet18。STL10是ImageNet的一個子集,所有圖像都從224×224像素縮小到96×96像素。雖然分辨率發生改變,我們希望自監督學習能避免這樣的影響,表現出較好性能,而僅僅依靠STL10的小規模訓練集是不夠的。
類似ResNet這樣的模型中,ML從業人員過於依賴預先訓練的權重。雖然這在一定情況下是很好的選擇,但不一定適合其他數據,哪怕在STL10這樣與ImageNet高度相似的數據中表現也不如人意。因此,我迫切希望將來在深度學習的研究中,自監督方法能夠獲得更多的關注與實踐應用。
參考資料
//arxiv.org/pdf/2006.07733.pdf
//arxiv.org/pdf/2006.10029v2.pdf
//github.com/lucidrains/byol-pytorch
//github.com/google-research/simclr
//cs.stanford.edu/~acoates/stl10/
AI研習社是AI學術青年和AI開發者技術交流的在線社區。我們與高校、學術機構和產業界合作,通過提供學習、實戰和求職服務,為AI學術青年和開發者的交流互助和職業發展打造一站式平台,致力成為中國最大的科技創新人才聚集地。
如果,你也是位熱愛分享的AI愛好者。歡迎與譯站一起,學習新知,分享成長。