Paper Reading: "A Transductive Multi-Head Model for Cross-Domain Few-Shot Learning"

  • April 21, 2021
  • AI

Long time no see! The month is almost over already (and I'm already happily looking forward to the May Day holiday, hehe).
Recent posts have all revolved around domain adaptation, covering newer 2020/2021 papers (from arXiv).
Most of them have no write-ups online yet, so these notes record only my own understanding; if anything is off, feel free to message me.

Paper title:
"A Transductive Multi-Head Model for Cross-Domain Few-Shot Learning"
Paper link: https://arxiv.org/abs/2006.11384v1
Code: https://github.com/leezhp1994/TMHFS
This post records only my personal notes from reading the paper; for the full text, code, and other details, see the links above.

Background

The background of the cross-domain few-shot learning problem has come up several times in earlier paper-reading posts, so I won't go over it again in detail; a brief quote:
The main challenge of cross-domain few-shot learning lies in the cross domain divergences in both the input data space and the output label space.

Work

In this paper, we present a new method, Transductive Multi-Head Few-Shot learning (TMHFS), to address the cross-domain few-shot learning challenge.
TMHFS is based on the Meta-Confidence Transduction (MCT) and Dense Feature-Matching Networks (DFMN) method
It extends the transductive model by adding an instance-wise global classification network based on the
semantic information, after the common feature embedding network as a new prediction “head”.
Multi-head: in other words, the base of the whole model is MCT plus DFMN; on top of that, an instance-wise global classification network based on semantic information is attached after the shared feature embedding network as a new prediction "head" ("two heads become three").

Model

Problem Statement
In the cross-domain few-shot learning setting, we have a source domain S = {Xs, Ys} with Cg classes in total and a target domain T = {Xt, Yt} drawn from a set of entirely different classes. (This is the cross-domain problem definition.)
The two domains have different marginal distributions in the input feature space, and disjoint output class sets.

The model (see Figure 1 of the paper) consists of three stages and three heads.
Three stages:
train: as the arrows in the figure show, all three heads are used during training; that is, the embedding network is trained with MCT (a distance-based, meta-trained instance classifier), DFMN (a pixel-wise classifier), and the semantic classifier based on global information.
fine-tuning: only the semantic global classifier and the support instances from the target domain are used to fine-tune the model.
test: the MCT part, i.e., the meta-trained instance classifier, is used together with the fine-tuned embedding network to predict the labels of the query set. (A pipeline sketch follows below.)
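A minimal sketch of the three-stage pipeline, assuming PyTorch. The names here (embed, mct_head, dfmn_head, global_head, mct_head.classify) are my own placeholders, not names from the paper or its released code:

```python
import torch

def train_step(embed, mct_head, dfmn_head, global_head, batch, optimizer):
    """Source-domain training: all three heads supervise the shared embedding f_theta."""
    x, y = batch
    features = embed(x)
    # Sum of the three head losses (the exact weighting is not given in these notes).
    loss = mct_head(features, y) + dfmn_head(features, y) + global_head(features, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

def finetune_step(embed, global_head, support_x, support_y, optimizer):
    """Target-domain fine-tuning: only the global semantic head and the support set."""
    loss = global_head(embed(support_x), support_y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def predict(embed, mct_head, support_x, support_y, query_x):
    """Test: the meta-trained MCT classifier on top of the fine-tuned embedding."""
    return mct_head.classify(embed(support_x), support_y, embed(query_x))
```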

Three heads
MCT
See "Transductive few-shot learning with meta-learned confidence" for reference:
https://arxiv.org/pdf/2002.12017.pdf
The MCT uses distance based prototype classifier to make prediction for the query instances
(A distance-based prototype classifier is used to predict the query instances; it feels like this builds on the Prototypical Networks idea. A minimal sketch follows.)
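For concreteness, here is a minimal distance-based prototype classifier in PyTorch, the building block that MCT extends (MCT additionally meta-learns a confidence scaling for the distances, which is omitted here):

```python
import torch

def proto_classify(support_feat, support_y, query_feat, n_way):
    """Prototypical-network style prediction: nearest class mean in embedding space."""
    # Class prototype = mean of the support embeddings of each class.
    protos = torch.stack([support_feat[support_y == c].mean(dim=0)
                          for c in range(n_way)])           # (n_way, D)
    # Negative squared Euclidean distance to each prototype as the logit.
    logits = -torch.cdist(query_feat, protos) ** 2          # (Q, n_way)
    return logits.softmax(dim=-1)
```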

DFMN
used solely in the training stage
(The formula looks somewhat like metric learning, but this seems to be common in transductive learning; I still need to read up on this. A guessed sketch follows.)
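These notes do not contain DFMN's formula, so the following is only a guess at what "dense feature matching" could look like: compare every spatial location of the conv feature map with the class prototypes before pooling, instead of comparing a single pooled vector:

```python
import torch

def dense_match_logits(query_map, protos):
    """query_map: (Q, D, H, W) conv features; protos: (C, D) class prototypes."""
    q, d, h, w = query_map.shape
    flat = query_map.permute(0, 2, 3, 1).reshape(q, h * w, d)   # (Q, HW, D)
    # Per-location squared distance to each prototype; cdist broadcasts the
    # prototype batch dimension: (Q, HW, C).
    dists = torch.cdist(flat, protos.unsqueeze(0)) ** 2
    # Average the per-location scores into one logit vector per query image.
    return (-dists).mean(dim=1)                                 # (Q, C)
```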

Note that the two base structures above (MCT and DFMN) share the same feature extraction network f_θ.
a new global instance-wise prediction head
For this prediction head, the global classification problem over all Cg classes is considered. As shown in Figure 1, both the support set and the query set serve as training inputs to this branch; for example:
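A minimal sketch of such a head, assuming it is a plain linear classifier over all Cg source classes trained with cross-entropy (the layer layout is my assumption; the notes only say it is a global classification problem over all Cg classes):

```python
import torch.nn as nn
import torch.nn.functional as F

class GlobalHead(nn.Module):
    """Instance-wise global classifier f_delta over all Cg source classes."""
    def __init__(self, feat_dim, n_global_classes):
        super().__init__()
        self.fc = nn.Linear(feat_dim, n_global_classes)

    def forward(self, features, labels=None):
        logits = self.fc(features)
        if labels is None:
            return logits
        # During training, both support and query instances carry global
        # class labels, so both contribute to this cross-entropy loss.
        return F.cross_entropy(logits, labels)
```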

Loss
Training stage.
The purpose of training is to pre-train an embedding model fθ (i.e., the feature extractor) in the source domain.

Fine-tuning stage.
Given a few-shot learning task in the target domain, we fine-tune the embedding model fθ on the support set by using only the instance-wise prediction head fδ, aiming to adapt fθ to the target domain data.
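A minimal sketch of this fine-tuning loop, assuming a fresh linear head for the episode's classes and a plain SGD loop; the step count, learning rate, and label remapping are my placeholders, not the paper's settings:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def finetune(embed, feat_dim, n_way, support_x, support_y, steps=50, lr=1e-2):
    """Adapt f_theta to the target domain using only the support set."""
    # support_y is assumed to be remapped to 0..n_way-1 for this episode.
    head = nn.Linear(feat_dim, n_way)
    opt = torch.optim.SGD(list(embed.parameters()) + list(head.parameters()),
                          lr=lr, momentum=0.9)
    for _ in range(steps):
        loss = F.cross_entropy(head(embed(support_x)), support_y)
        opt.zero_grad()
        loss.backward()
        opt.step()
    return embed  # the fine-tuned embedding is then used by MCT at test time
```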

Experiments

Overall, the whole model is a multi-head ensemble built on top of transductive learning.
Judging from the experiments, the results are fairly good.
Ending~


Hoping April's good luck continues. Let's go, May!