論文閱讀《A Transductive Multi-Head Model for Cross-Domain Few-Shot Learning》

  • 2021 年 4 月 21 日
  • AI

好久不見哈 一下子就快月底啦 (已經滿心歡喜期待五一啦嘻嘻)
最近更新都是圍繞域適應 20/21 較新的論文(arxiv上的)
大都數網上還沒有出現解讀材料,故記錄僅自我理解,若有偏差可簡信交流。

論文名稱:
《A Transductive Multi-Head Model for Cross-Domain Few-Shot Learning》
論文地址://arxiv.org/abs/2006.11384v1
論文程式碼://github.com/leezhp1994/TMHFS
本篇文章只記錄個人閱讀論文的筆記,具體翻譯、程式碼等不展開,詳細可見上述的鏈接.

Background

之前的論文閱讀中提過了幾次小樣本域適應問題的背景(提出),這邊就不再詳細敘述,簡單摘錄幾句。
The main challenge of cross-domain few-shot learning lies in the cross domain divergences in both the input
data space and the output label space;
主要挑戰在於輸入數據空間和輸出標籤空間中的跨域差異;

Work

In this paper, we present a new method, Transductive Multi-Head Few-Shot learning (TMHFS), to address the cross-domain few-shot learning challenge
針對跨域問題,提出了TMHFS模型(Transductive Multi-Head Few-Shot learning),轉導(直推)多頭小樣本學習模型。
TMHFS is based on the Meta-Confidence Transduction (MCT) and Dense Feature-Matching Networks (DFMN) method
It extends the transductive model by adding an instance-wise global classification network based on the
semantic information, after the common feature embedding network as a new prediction 「head」.
多頭:也就是說整個模型的基礎是MCT和DFMN,在這個基礎上加入了一個基於語義資訊的實例全局分類網路,將其公共特徵嵌入網路作為一種新的預測。(「兩頭變成三頭」)

Model

Problem Statement
In cross-domain few-shot learning setting, we have a source domain S = {Xs, Ys} from a total Cg classes and a target domain T = {Xt, Yt} from a set of totally different classes.(跨域問題定義)
The two domains have different marginal distributions in the input feature space, and disjoint output class sets.

模型如上圖所示,整個模型包含三個過程以及三個頭。
三個過程:
train:根據圖中箭頭可以看出,訓練過程三個頭都使用了,即使用MCT(基於距離的實例元訓練分類器)、DFMN(像素分類器)和基於全局資訊的語義分類器來訓練嵌入網路。
fine-tining:我們只使用語義全局分類器和目標域中的支援實例來微調模型。
test:我們使用MCT部分,即元訓練的實例分類器,用微調的嵌入網路來預測查詢集的標籤。

三個頭
MCT
可參考此文《 Transductive few-shot learning with meta-learned confidence》
//arxiv.org/pdf/2002.12017.pdf
The MCT uses distance based prototype classifier to make pre�diction for the query instances
(使用基於距離的原型分類器對查詢實例進行預測,感覺有點是基於原型網路的基礎)

DFMN
used solely in the training stage(看公式和度量學習有點類似,但這應該是直推/轉導學習的通用,後續看還需要看一下這方面的知識。)

值得注意的是,以上兩個基礎結構共享相同的特徵提取網路{f_θ}
a new global instance-wise prediction head
對於這個預測頭,我們考慮了所有Cg類上的全局分類問題。如圖1所示,支援集和查詢集都用於作為該分支的訓練輸入,例如:

Loss
Training stage.
The purpose of training is to pre-train an embedding model fθ (i.e., the feature extractor) in the source domain.

Fine-tuning stage:
Given a few-shot learning task in the target domain, we fine-tune the embedding model fθ on the
support set by using only the instance-wise prediction head fδ, aiming to adapt fθ to the target domain data

Experiments

總的來說整個模型是基於轉導/直推學習上的集成(多頭)
從實驗上來看效果也還不錯
Ending~


希望四月依舊好運 加油呀!五月