PyTorch的元學習庫:Torchmeta

Torchmeta是擴展和數據載入器的集合,用於在PyTorch中進行少量學習和元學習。Torchmeta在2019年全球PyTorch夏季黑客馬拉松上獲得了最佳表演獎。該庫是開源的,可以嘗試使用pip install torchmeta。

https://github.com/tristandeleu/pytorch-meta

什麼是元學習?

當無法訪問大量數據時會發生什麼?畢竟,與當前形式的深度學習不同,人類只有幾次演示就能非常快速,高效地學習執行新任務。僅通過少數幾個訓練示例就可以獲得性能良好的模型尤其具有挑戰性,因此需要一定水平的先驗知識才能解決任務,從而在學習新知識時有效地「領先」。例如,領域專家可以例如通過正則化或體系結構選擇將這種先驗知識明確引入模型中。

或者,可以從過去的經驗中獲得此先驗知識;這是元學習中採用的方法。在元學習中,利用從許多不同的元訓練任務中獲得的經驗,目的是提高在新的下游任務上的表現和學習效率(即,必要的訓練示例的數量)。

創建元學習的「健身房」

創建Torchmeta的動機是為了促進對不同數據集上的元學習演算法進行評估,並儘可能減少更改。它的設計靈感來自OpenAI Gym,它通過提供適用於多種環境的通用介面,使強化學習變得更加容易。Gym作為標準工具的採用,使大多數開源項目都可以不受環境選擇的影響,並且可以無縫測試多個環境。

同樣,Torchmeta在統一的介面下引入了數據載入器,以處理各種標準的幾次鏡頭分類和回歸問題。從1.3版開始,Torchmeta中提供以下數據集:

  • 快速回歸

-正弦波(Finn等,2017)

-諧波函數(Lacoste等,2018)

-正弦和直線(Finn等,2018)

  • 為數不多的鏡頭分類(圖片分類)

– Omniglot(湖等人,2015年,2019)

– (迷你ImageNet 。Vinyals等人在2016年,。拉維等人,2017年)

-分層-ImageNet(Ren等人,2018。)

-CIFAR-FS(Bertinetto等,2018)

-Fewshot-CIFAR100(Oreshkin等,2018)

-加州理工學院-UCSD鳥類(Hilliard等,2018,Chen等,2019)

-雙重和三重MNIST (2019年,星期日)

在Omniglot(左),Mini-ImageNet(中)和Caltech-UCSD Birds(右)上進行5次5拍學習問題的任務示例。

所有這些數據載入器都與PyTorch生態系統完全兼容,包括PyTorch DataLoader和torchvision軟體包。根據相應的數據集隨機生成一批任務,每個任務包含一個培訓和一個測試數據集-這是元學習中的常見做法。儘管可以完全控制數據載入器的定義方式,但Torchmeta還包括適用於最受歡迎基準的幫助程式功能,以及文獻中有用的默認值。

https://pytorch.org/docs/master/torchvision/

from torchmeta.datasets.helpers import omniglot  from torchmeta.utils.data import BatchMetaDataLoader    dataset = omniglot("data", ways=5, shots=5, test_shots=15, meta_train=True, download=True)  dataloader = BatchMetaDataLoader(dataset, batch_size=16, num_workers=4)    for batch in dataloader:      train_inputs, train_targets = batch["train"]      print('Train inputs shape: {0}'.format(train_inputs.shape))    # (16, 25, 1, 28, 28)      print('Train targets shape: {0}'.format(train_targets.shape))  # (16, 25)        test_inputs, test_targets = batch["test"]      print('Test inputs shape: {0}'.format(test_inputs.shape))      # (16, 75, 1, 28, 28)      print('Test targets shape: {0}'.format(test_targets.shape))    # (16, 75)

一個為5次5通道Omniglot數據集創建數據載入器的最小示例。數據載入器載入一批隨機生成的任務,並將所有樣本串聯到一個張量中。

元學習模組

除了數據載入器之外,Torchmeta還提供了PyTorch的擴展nn.Module,稱為MetaModule,以簡化某些元學習演算法的實現。這些元模組使可以選擇使用完整的計算圖手動指定模組的參數。例如,這允許通過更新參數進行反向傳播,這些參數是基於梯度的元學習方法的關鍵成分(Finn等人,2017 ; Finn,2018 ; Grant等人,2018 ; Lee等人等人,2019年 ; Raghu等人,2019年)和各種混合方法(Rusu等人,2018年 ; Zintgraf等人,2019年)。

MetaLinear元模組的插圖,nn.Linear的擴展。左:MetaLinear元模組的實例化。中:默認行為,等效於nn.Linear。右:具有額外參數的行為(此處為一步式漸變更新,Finn等人,2017年)。漸變表示為虛線箭頭。

默認情況下(即沒有額外的參數),元模組的行為與其在PyTorch中對應的模組相同。因此,創建與這些元學習方法兼容的模型對於Torchmeta來說非常自然,並且只需對現有PyTorch模型進行最小的更改即可。還可以將元模組與標準nn.Module實例進行交錯,以僅對模型的某些部分進行快速適應(Raghu等人,2019)。

import torch.nn as nn  from torchmeta.modules import (MetaModule, MetaSequential,                                 MetaConv2d, MetaLinear)  from torchmeta.modules.utils import get_subdict    class Model(MetaModule):      def __init__(self, in_channels, num_classes):          super(Model, self).__init__()          self.features = MetaSequential(MetaConv2d(in_channels, 64, 3),                                         nn.ReLU(),                                         nn.MaxPool2d(2))          self.classifier = MetaLinear(64, num_classes)        def forward(self, inputs, params=None):          features = self.features(inputs,                                   params=get_subdict(params, 'features'))          logits = self.classifier(features.view((inputs.size(0), -1)),                                   params=get_subdict(params, 'classifier'))          return logits

提高元學習研究的可重複性

由於缺乏文獻中所使用的某些數據集的標準,因此元學習的可重複性可能非常具有挑戰性,尤其是在數據載入方面。例如,雖然Vinyals等人介紹了Mini-ImageNet數據集。(2016),Ravi&Larochelle(2017)使用的分割現在已被社區廣泛接受為官方數據集。到目前為止,對於某些數據集(例如CUB),這種情況仍然存在。很難跟蹤應評估元學習演算法的「正確」版本。

藉助眾多速記學習和元學習數據集以及標準拆分,Torchmeta的目標是提供所有必要的工具,以使元學習演算法的開發和可重複性儘可能容易且可訪問。

結論

PyTorch中元學習的未來是光明的,並且最近發布了許多偉大的開源項目。特別要提到兩個,learn2learn提供了一些標準元學習演算法的實現,而更高版本的則是一個庫,可以對現有PyTorch模型進行高階優化。Torchmeta通過為各種數據集提供統一的介面以及一組簡化元學習演算法開發的工具,很好地補充了這些其他庫。這使得在不同基準上對這些方法的評估變得無縫,因此是在元學習中更好地再現研究的關鍵一步。

要了解有關Torchmeta的更多資訊,請查看項目存儲庫中可用的示例以及MAML的此實現,以更詳細地展示Torchmeta的所有功能。

https://github.com/tristandeleu/pytorch-meta/tree/master/examples

https://github.com/tristandeleu/pytorch-maml