Torchmeta:PyTorch的元学习库

  • 2019 年 10 月 7 日
  • 筆記

作者 | sharmistha chatterjee

来源 | Medium

编辑 | 代码医生团队

介绍

元学习研究和开放源代码库提供了一种通过标准化基准和各种可用数据集对不同算法进行详细比较的方法,从而可以完全控制此评估的复杂性。但是,大多数在线可用的代码都有以下限制:

  • 数据管道通常特定于一个数据集,而对另一个数据集进行测试需要大量的返工。
  • 元学习中的基准测试由数据集组成,这给数据管道增加了一层复杂性。因此大多数元学习项目都实现了适合其方法的自己的特定数据加载组件。
  • 输入级别缺乏标准会导致围绕每种元学习算法的机制产生差异,从而使比较过程更具挑战性。

为了解决这个限制,Google AI引入了Torchmeta,这是一个基于PyTorch深度学习框架构建的库,可以对多个数据集的元学习算法进行无缝且一致的评估。为了解释Torchmeta,使用了一些初步的概念,例如DataLoader和BatchLoader,可以解释为:

DataLoader是一种通用实用程序,可用作应用程序数据获取层的一部分,以通过批处理和缓存在各种远程数据源(例如数据库或Web服务)上提供简化且一致的API。

批处理是DataLoader的主要功能。批处理加载函数接受键列表,并返回一个Promise,该Promise解析为值列表DataLoader合并在单个执行框架内发生的所有单个加载(一旦解决了包装承诺,即执行),然后是具有全部功能的批处理函数要求的钥匙。

  • Torchmeta具有以下功能。
  • Torchmeta通过少量的分类和回归为大多数标准基准提供了DataLoader,并提供了新的元数据集抽象。
  • 数据加载器与PyTorch的标准数据组件完全兼容,例如Dataset和DataLoader。
  • Torchmeta为所有可用的基准提供了相同的界面,从而使不同数据集之间的转换尽可能无缝。
  • Torchmeta还对PyTorch进行了一些扩展,以简化与元学习算法兼容的模型的开发,其中一些需要更高阶的区分。
  • 可用的基准有助于为开发新的元学习算法提供参考。
  • Torchmeta提供了一个框架,研究人员可以围绕该框架构建自己的元学习算法,而不是使数据管道适应其方法。
  • Torchmeta通过将元数据集与算法本身解耦来促进代码重用,从而提供了这一抽象层。

数次学习的数据加载器

快速学习很少能具有使用先验知识快速推广具有有限监督经验的新任务的能力。快速学习分为三类:

  • 数据使用先验知识来增强监督经验。
  • 该模型通过先验知识约束假设空间,
  • 算法使用先验知识来更改对假设空间中最佳假设参数的搜索。

Torchmeta在其库中具有以下内容。

  • 该库提供了与元学习文献中经典的几次快照分类和回归问题相对应的数据集。
  • 该界面旨在支持分类和回归的数据集之间的模块化,以简化对全套基准测试的评估过程。

为了平衡几次学习中固有的数据缺乏,元学习算法从称为元训练集的数据集D-meta = {D1,…,Dn}中获取一些先验知识。在几次学习中,每个元素Di仅包含几个输入/输出对(x,y),其中y取决于问题的性质。由于这些数据集可以包含过去执行的不同任务的示例。Torchmeta提供了一种解决方案,可以使用最少的问题特定组件来自动创建每个数据集Di。

极少回归

少有的回归问题中的大多数是通过不同功能的输入和输出之间的简单回归问题,其中每个功能对应一个任务。这些功能被参数化以允许任务之间的可变性,同时在各个任务之间保持不变的“主题”。例如,这些函数可以是形式为fi(x)= ai sin(x + bi)的正弦波,其中a和b在某些范围内变化。

在Torchmeta中,元训练集继承自名为MetaDataset的对象,每个数据集Di(i = 1,…,n,用户定义n)对应于该函数的特定参数选择,所有在元训练集创建时采样一次的参数。一旦知道了函数的参数,我们就可以通过在给定范围内对输入进行采样并将其提供给函数来创建数据集。

少拍分类

对于少有的分类问题,数据集Di的创建通常遵循两个步骤:

  • 前N个类别是从大量候选项中取样的(对应于“ N向分类”中的N)。
  • 在下一步中,每个班级选择k个示例(对应于“ k-shot学习”中的k个)。
  • 这是一个分为两步的过程,它是作为继承自MetaDataset的CombinationMetaDataset对象的一部分而提供的,它为用户提供了针对特定问题的大量类候选者的用户规范。
  • 为了促进元学习的可重复性,每个任务都与一个唯一的标识符(类标识符的N元组)相关联。选择任务后,对象将返回数据集Di以及来自相应类集中的所有示例。
  • Torchmeta还包括一些有用的功能,以增加诸如旋转图像之类的变体来增加班级候选人的数量。

下图展示了元学习器的作用,在元测试中,另一个不相交的任务集Tt〜p(T)(p(T)->任务T的分布)用于测试元学习者。每个Tt都作用于N个数据集,其中数据集= {D train Tt,D test Tt}。学习者从训练集D train Tt和测试集D test Tt上学习。Tt的平均损耗被视为元学习测试误差。

训练和测试数据集拆分

  • 在元学习中,每个数据集Di分为两部分:训练集(或支持集),用于使模型适应当前的任务;测试集(或查询集),用于评估和元优化。
  • 当任务保持不变时,这两个部分不会重叠,在训练和测试集中都没有任何示例。
  • Torchmeta在数据集上引入了一个称为Splitter的包装器,该包装器负责创建训练和测试数据集,以及可选地对数据进行混排。

为了实例化基于Mini Imagenet的5向1发分类问题的元训练集,使用:

数据集= torchmeta.datasets.MiniImagenet(“数据”,num_classes_per_task = 5,meta_train = True,下载= True)

数据集= torchmeta.transforms.ClassSplitter(数据集,num_train_per_class = 1,num_test_per_class = 15,shuffle = True)

除了元训练集之外,大多数基准测试还提供了元测试集,用于对元学习算法的总体评估(以及可能的元验证集)。创建MetaDataset对象时,可以使用meta_test = True(或meta_val = True)而不是meta_train = True来选择这些不同的元数据集。

元数据加载器

可以迭代一些镜头分类和回归问题中的元训练集对象,以生成PyTorch数据集对象,该对象包含在任何标准数据管道(与DataLoader组合)中。

元学习算法在批次任务上运行效果更好。与在PyTorch中将示例与DataLoader一起批处理的方式类似,Torchmeta公开了一个MetaDataLoader,该对象可以在迭代时产生大量任务。这样的元数据加载器能够输出一个大张量,其中包含批处理中来自不同任务的所有示例,如下所示:

数据集= torchmeta.datasets.helpers.miniimagenet(“数据”,镜头= 1,方式= 5,meta_train = True,下载= True)

数据加载器= torchmeta.utils.data.BatchMetaDataLoader(数据集,batch_size = 16)

元学习模块

下图显示了使用学习者的损失和错误信号进行元学习的顺序步骤。

元学习者的学习步骤:来源:

https : //arxiv.org/pdf/1904.05046.pdf

在元学习中,PyTorch中的模型是由称为模块的基本组件创建的,该基本组件等效于神经网络中包含该层的计算图及其参数的一层。这些模块将其参数视为其计算图的组成部分,足以训练带有反向传播的模型。

但是,一些元学习算法需要通过参数更新(例如梯度更新)进行反向传播,以进行元优化(或“外环”),因此涉及高阶微分。

因此,适应PyTorch中的现有模块至关重要,以便它们可以处理任意计算图来替代这些参数。因此,Torchmeta扩展了现有模块,并保留了提供新参数作为附加输入的选项。这些新对象称为MetaModule,它们的默认行为(即,未指定任何其他参数)等同于它们的PyTorch对应对象。否则,如果指定了额外的参数(例如,梯度下降的一步的结果),则MetaModule会将它们视为计算图的一部分,并且反向传播将按预期进行。

  • 上图描述了带有或不带有附加参数的线性模块(称为MetaLinear)的扩展如何工作,以及对梯度的影响。
  • 左图显示了元模块作为参数W和b的容器的实例,以及带有占位符的重量和偏差参数的计算图。
  • 中间的图显示了MetaLinear元模块的默认行为,其中的占位符用W&b替换,这等效于PyTorch的Linear模块。
  • 右图显示了如何使用完整的计算图填充这些占位符,就像一个梯度下降步骤。在后一种情况下,外循环更新中必需的外循环相对于W的坡度可以正确地一直流到参数W。

下面的代码演示了如何从Torchmeta的现有数据集中生成训练,验证和测试元数据集。

from torchmeta.datasets import Omniglot, MiniImagenet, CIFARFS, FC100, TieredImagenet, TCGA  from torchmeta.transforms import Categorical, ClassSplitter, Rotation  from torchvision.transforms import Compose, Resize, ToTensor  from torchmeta.utils.data import BatchMetaDataLoader    dataset = Omniglot("data",                     # Number of ways                     num_classes_per_task=5,                     # Resize the images to 28x28 and converts them to PyTorch tensors (from Torchvision)                     transform=Compose([Resize(28), ToTensor()]),                     # Transform the labels to integers (e.g. ("Glagolitic/character01", "Sanskrit/character14", ...) to (0, 1, ...))                     target_transform=Categorical(num_classes=5),                     # Creates new virtual classes with rotated versions of the images (from Santoro et al., 2016)                     class_augmentations=[Rotation([90, 180, 270])],                     meta_train=True,                     download=True)    dataset = ClassSplitter(dataset, shuffle=True, num_train_per_class=5, num_test_per_class=15)  dataloader = BatchMetaDataLoader(dataset, batch_size=16, num_workers=4)  for batch in dataloader:      train_inputs, train_targets = batch["train"]      print('Train inputs shape: {0}'.format(train_inputs.shape))    # (16, 25, 1, 28, 28)      print('Train targets shape: {0}'.format(train_targets.shape))  # (16, 25)        test_inputs, test_targets = batch["test"]      print('Test inputs shape: {0}'.format(test_inputs.shape))      # (16, 75, 1, 28, 28)      print('Test targets shape: {0}'.format(test_targets.shape))    # (16, 75)

下图显示了下载后从Omnichlot和MiniImagenet从Torchmeta的数据集中生成的元学习数据集。

此处Omniglot数据集包含50个字母。将其分为30个字母的背景集和20个字母的评估集。在将背景大小调整为28×28张量后,应该使用背景集学习有关字符的一般知识(例如,特征学习,元学习)。此外,将标签传送到整数Glagolitic / character01”,“ Sanskrit / character14”,……)到(0,1,..,n)。

MiniImageNet包含60,000个84×84 RGB图像,每个类别600个图像。使用Torchmeta,可以生成HDF5格式的元学习数据集。

Torchmeta具有以HDF5格式下载数据集的功能,该功能允许:

  1. 要将包含HDF5文件的文件夹(包括子文件夹)用作数据源,
  2. 在数据集中维护一个简单的HDF5组层次结构,
  3. 启用延迟数据加载(即应DataLoader的请求),以便允许使用不适合内存的数据集,
  4. 配备了数据缓存以加快数据加载过程,并且
  5. 允许对源或目标数据集进行自定义转换。

用于定义Torchmeta数据集(例如Omniglot)的元学习参数的TieredImagenetClassDataset包含来自34个类别的图像。元训练/验证/测试拆分超过20/6/8个类别。每个类别包含10到30个类别。按类别划分(而不是按类别划分)可确保所有训练课程与测试课程完全不同(不同于Mini-Imagenet)。它带有以下一组参数,这些参数定义了训练,验证和测试数据集的划分以及应用于它们的转换和增强技术

num_classes_per_task(int):每个任务的类数,对应于“ N向”分类中的“ N”。

meta_train:bool(`False`):使用数据集的元火车拆分。如果设置为True,则必须将参数meta_val和meta_test设置为False。这三个参数中的一个必须正确设置为“ True”。

meta_val:bool(`False`):使用数据集的元验证拆分。如果设置为True,则参数meta_train和metatest必须设置为False。这三个参数中只有一个必须设置为“ True”。

meta_test:bool(`False`):使用数据集的元测试拆分。如果设置为True,则参数meta_train和meta_val必须设置为False。这三个参数中只有一个必须设置为“ True”。

meta_split:{'train','val','test'}中的字符串,可选要使用的拆分名称,如果所有三个都设置为False,则覆盖参数meta_train,metaval和metatest。

transform:可调用的,可选的:获取“ PIL”图像并返回转换后版本的函数/转换。

target_transform:可调用,可选:接受目标并返回转换版本的函数/转换。

dataset_transform:可调用,可选:函数/转换,它接受数据集(即任务),并返回其转换后的版本。-> torchmeta.transforms.ClassSplitter()。

class_augmentations:可调用的,可选的列表:使用新类扩展数据集的函数列表。这些类是现有类的转换。

download:bool(默认值:False)如果为True,则下载pickle文件并处理根目录(位于tieredimagenet文件夹下)中的数据集。如果数据集已经可用,则不会再次下载/处理数据集。

结论

在此博客中,了解了Google AI最新发布的库Torchmeta,它提供了哪些功能以及可以解决什么样的元学习问题。可以浏览其他PyTorch元学习库,例如元Agonistic机器学习,以学习可以快速适应新任务的网络初始化。

https://github.com/dragen1860/MAML-Pytorch

如下图所示,在Torchmeta中很少有镜头学习可用于图像分类。

参考

https://github.com/markdtw/meta-learning-lstm-pytorch

https://arxiv.org/abs/1909.06576

https://docs.graphene-python.org/en/latest/execution/dataloader/