深度学习(五)之原型网络

在本文中,将介绍一些关于小样本学习的相关知识,以及介绍如何使用pytorch构建一个原型网络(Prototypical Networks[1]),并应用于miniImageNet 数据集。

实验环境:

pytorch:1.11.0
代码地址://github.com/xiaohuiduan/deeplearning-study/tree/main/小样本学习

小样本学习引入

在这一节将简要的对小样本学习(FSL)相关的知识进行介绍。由于我并不是专门研究小样本的(我学习FSL也只是为了完成我的课程作业),因此,如果本文存在任何问题,欢迎进行批评指正😎。

邮箱📫:[email protected]

首先将模型看成一个黑盒子,不去关注它的内部结构,而是关注其inputoutput

在分类模型[2]中,input是一张猫or狗的图片,output则为0/1(代表其为猫或者狗;实际上输出的是两者的预测概率)。

但是关于上面的模型,存在一个问题,训练这样的模型需要大量的数据,根据@petewarden[3]的说法,训练一个分类图片的网络,每个类别需要大约1000张。但是,很多场景我们并没有足够的数据进行训练,也就是说数据集的样本比较少,这时候针对小样本数据集可以有两种处理方式[4]

  • 数据增强:比如说对图像进行旋转,裁剪等等。

  • 数据建模:如使用小样本学习的方法进行对数据进行建模。

小样本学习是什么

以VGG模型预测猫狗分类举例,模型对于一张新的图片的预测可以形象的解释为:

图片中的动物因为有着尖尖的耳朵,有着较长的胡须,鼻子那一个地方不是很突出,因此,我(VGG)判断它是一只猫o(=•ェ•=)m。

但是在小样本学习中,不是这样的,小样本学习对于一张新的图片的预测可以形象的解释为:

我手中有2张图片,图片A和图片B。对于新的图片,我(模型)也不知道他是啥,但是我发现它跟图片B长得很相似,因此我(模型)判断这张新的图片和图片B是同一个类别。

在上述解释中,图片A和B称之为support sets,而新的图片称之为query sets

小样本的分类模型与传统的深度学习分类模型(如VGG)有着不同,这里引用论文[1:1]的一句话:

Few-shot classification is a task in which a classifier must be adapted to accommodate
new classes not seen in training, given only a few examples of each of these classes. A naive approach,
such as re-training the model on the new data, would severely overfit.

也就是说,小样本分类并不是像VGG一样,test的数据是以前训练过的类别,对于小样本分类来说,进行test的数据是一些新的类,并且这些类别的样本很少,因此,没法对其进行re-training,否则会造成过拟合。

小样本学习方法

小样本学习可以认为是一个N-way K-shot的分类问题(不确定是不是所有的小样本分类任务都被认为是N-way K-shot分类问题)。

无论对于测试集还是训练集,都需要进行如下的划分,将数据集划分为两个部分:左边DataSet代表数据集,右边分别代表Support set,右边代表Query Set。在train或者test数据集中:

  1. 首先对于所有类别,随机选择其中N(图中N=3)个类别(图中,选择了类别2,类别3和类别5)。
  2. 在step 1中选择的类别样本中,随机选择K(图中K=3)个样本(绿色的部分),构成Support Set。也就是说,Support Set中拥有K*N个样本。
  3. 然后在所选择类别的剩余样本中,选择X(这里X=1)个样本(红色的部分),构成Query Set。也就是说,Query Set中拥有X*N个样本。

在训练集中,以上步骤构成的Support Set和Query Set会被input到Model中进行训练,称之为一个eposides(相当于mini-batch)。

对于模型来说,其目的则为判断Query Set中的样本与哪一个支持集最相似

原型网络(Prototype Network)

原理简述

Prototype Network的原理很简单,可以简单的概括为:将support set中的图片\(data_1,data_2,\cdots,data_n\)映射到某一个向量空间\(c_1,c_2,\cdots,c_n\);对于Query set中的某一张图片\(query_i\)使用同一个映射函数,也映射到到向量空间\(x_i\),然后判断\(x_i\)\(c_1,c_2,\cdots,c_n\)的距离(余弦距离or欧氏距离),选择距离最近向量所对应的类别作为\(query_i\)所属的类别。

示意图[5]如下所示:

如果了解NLP中word2vec的话,会发现,其与word2vec的Embedding思想是很相似的。

算法流程

算法流程图[1:2]如下所示,红色框和绿色框中的过程已经在前文进行介绍,这里主要是来介绍一下loss的计算方式。

image-20220507144721234

实际上,loss的计算方式就是一个交叉熵损失函数,pytorch中CrossEntropyLoss的计算方法如下所示,class代表\(x\)实际所属类别\(x[j]\)代表模型对于\(x\)所属类别\(j\)的概率预测。

\[\operatorname{loss}(x, \text { class })=-\log \left(\frac{\exp (x[\text { class }])}{\sum_{j} \exp (x[j])}\right)=-x[\text { class }]+\log \left(\sum_{j} \exp (x[j])\right)
\]

但是,在算法流程图中,大家会发现,其loss计算的正负号刚好与上面公式中的相反,解释如下:

以欧式距离为例,距离越远(\(d\)则越大),则代表两者的相似度越低。如果不加负号的话,进行softmax计算,距离越远的则predict概率越大,这明显是错误的。因此,加了一个负号之后,距离越远,进行softmax之后,输出则越小,predict的概率也变小,这才是合理的。

以上,便是原型网络的算法流程。

算法实现

数据集处理

mini-Imagenet是一个专门用于训练小样本学习的训练集,数据集中一共有100个类别,每个类别600张图片,一共有60000张图片。数据集可以从mini-ImageNet | Kaggles上面下载。在下载文件中,一共有4个文件,一个是数据集图片的压缩包,另外3个csv文件分别代表了训练、验证和测试集相关的信息。其中训练集有64个类,验证集16个类,测试集20个类。

csv的部分数据如下所示,filename代表了图片的名字,label代表了图片对应的标签。

因此,可以构建一个label所对应filename的字典

def read_csv(csv_path):
    dict = collections.defaultdict(list)
    df = pd.read_csv(csv_path)
    for index,row in df.iterrows():
        dict[row["label"]].append(row["filename"])
    return dict
train_dict = read_csv(train_csv_path)
val_dict = read_csv(val_csv_path)
test_dict = read_csv(test_csv_path)

同时,构建data与labels的对应关系:

from PIL import Image
import numpy as np
from torchvision import transforms
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

resize_transform = transforms.Resize(84) # 提前对图片进行缩放,以节省内存空间,将最短的边变成84
def build_data(data_dict):
    datas = []
    labels = []
    label_index = 0
    for label in data_dict.keys(): # 对图片的标签进行迭代
        for path in data_dict[label]: # 对标签对应的文件名进行迭代
            img_path = os.path.join(img_root_dir,path) 
            img = Image.open(img_path) # 读取文件
            img = resize_transform(img) # 进行缩放
            datas.append(img) 
            labels.append(label_index)

        label_index += 1
    return {"datas":datas,"labels":labels}

随机产生Support set 和 Query set

在下面代码中,CategoriesSampler的作用是为了产生index,然后供给dataloader使用。

class CategoriesSampler():
    """
        目的是为了随机产生K_way*(N_support+N_query)个图片对应的index
    """
    def __init__(self, data, n_batch, K_way, N_per):
        self.n_batch = n_batch
        self.K_way = K_way
        self.N_per = N_per
        labels = np.array(data["labels"]) # [0,0,0,0,1,1,1,1,2,2,2,2……]
        self.index = [] # 记录label对应的索引位置
        for i in range(max(labels)+1):
            ind = np.argwhere(labels == i).reshape(-1)
            self.index.append(torch.from_numpy(ind))   

    def __len__(self):
        return self.n_batch
    
    def __iter__(self):
        for i_batch in range(self.n_batch):  
            batch = []
            classes = torch.randperm(len(self.index))[:self.K_way] # 随机选择K个类别构成support set和query set
            for c in classes:
                l = self.index[c] # 类别c对应的图片数组的索引,如 l = [5,6,7,8,9]
                pos = torch.randperm(len(l))[:self.N_per] # 如 pos = [4,1,0]
                batch.append(l[pos]) # 如 l[pos] = [9,6,5]

            batch = torch.stack(batch).reshape(-1)
            yield batch

同时,定义Dataset类,如下:

class MiniImageNet(Dataset):

    def __init__(self, data):

        self.datas = data["datas"]
        self.labels = data["labels"]
        self.transform = transforms.Compose([
            transforms.CenterCrop(84),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.datas)

    def __getitem__(self, i):
        img, label = self.datas[i], self.labels[i]
        return self.transform(img), label

使用示例如下,在MiniImageNet的__getitem__函数中,其参数iCategoriesSampler__iter__函数所产生。

batch_sampler = CategoriesSampler(datas,eval_step,K_way,N_shot+N_query)
data_loader = DataLoader(dataset=data_set, batch_sampler=batch_sampler,
                                            num_workers=16, pin_memory=True)

实际上,上面的代码就是为了实现如下图所示的功能:

Embedding网络构建

前文说到,需要将图片进行向量化表示,在下面的代码中,便可以将一张图片(shape=\(3\times84\times84\))变成一个1600维的向量。(网络结构来自于论文)

class CNN_Net(nn.Module):
    """
        用于特征提取
    """

    def __init__(self, input_dim):
        super(CNN_Net, self).__init__()
        
        self.input_dim = input_dim
        def conv_block(in_channel,out_channel):
            return nn.Sequential(
                nn.Conv2d(in_channel, out_channel, 3,padding=1),
                nn.BatchNorm2d(out_channel),
                nn.ReLU(),
                nn.MaxPool2d(2)
            )
        self.encoder = nn.Sequential(
            conv_block(input_dim,64),
            conv_block(64,64),
            conv_block(64,64),
            conv_block(64,64),
        )
    def forward(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)
        return x

损失函数

关于计算损失的关键函数如下所示(参考了论文作者的源代码[6]):

def cal_euc_distance(self, query_z, center,K_way, N_query):
    """
        计算query_z与center的距离
        query_z : (K_way*N_query,z_dim)
        center : (K_way,z_dim)
    """
    center = center.unsqueeze(0).expand(
        K_way*N_query, K_way, self.z_dim)  # (K_way*N_query,K_way,z_dim)
    query_z = query_z.unsqueeze(1).expand(
        K_way*N_query, K_way, self.z_dim)  # (K_way*N_query,K_way,z_dim)

    return torch.pow(query_z-center, 2).sum(2)  # (K_way*N_query,K_way)

def loss_acc(self, query_z, center, K_way, N_query):
    """
        计算loss和acc
        query_z : (K_way*N_query,z_dim)
        center : (K_way,z_dim)
    """
    target_inds = torch.arange(0, K_way).view(K_way, 1).expand(
        K_way, N_query).long().to(self.device) # shape=(K_way, N_query)
    
    distance = self.cal_euc_distance(query_z, center,K_way, N_query)    # (K_way*N_query,K_way) 
    predict_label = torch.argmin(distance, dim=1)  # (K_way*N_query) 预测出来的label

    acc = torch.eq(target_inds.contiguous().view(-1),
                    predict_label).float().mean() # 准确率

    loss = F.log_softmax(-distance, dim=1).view(K_way,
                                                N_query, K_way)  # (K_way,N_query,K_way)
    loss = - \
        loss.gather(dim=2, index=target_inds.unsqueeze(2)).view(-1).mean()
    return loss, acc

def set_forward_loss(self, K_way, N_shot, N_query,sample_datas):
    """
        sample_datas: shape(K_way*(N_shot+N_query),3,84,84)
    """

    z = self.cnn_net(sample_datas) # shape=(K_way*(N_shot+N_query),z_dim) ,将support set和query set都进行向量化表示
    z = z.view(K_way,N_shot+N_query,-1) # shape = (K_way,N_shot+N_query,1600)
    
    support_z = z[:,:N_shot] # support set的向量化表示 shape=(K_way,N_shot,1600)
    query_z = z[:,N_shot:].contiguous().view(K_way*N_query,-1) # Query set的向量化表示 shape=(K_way*N_query,1600)
    
    center = torch.mean(support_z, dim=1) # 计算support set的向量均值,shape=(K_way,1600)
    return self.loss_acc(query_z, center,K_way,N_query)

关于实验中具体的参数设计,可以参考论文[1:3]或者Github上面的源代码。在原论文中,对于实验的设计讲得非常清楚。

实验结果

下面的表格为测试集的acc(当验证集acc为最大值时测试集所对应的acc):

N-shot=1 N-shot=5
K_way=5 0.4313 0.6684

总结

总的来说,原型网络是一个容易理解的网络模型,思想简单,易于实现。

References


  1. [1703.05175] Prototypical Networks for Few-shot Learning (arxiv.org) ↩︎ ↩︎ ↩︎ ↩︎

  2. 深度学习(二)之猫狗分类 – 段小辉 – 博客园 (cnblogs.com) ↩︎

  3. How many images do you need to train a neural network? « Pete Warden’s blog ↩︎

  4. 刘颖, 雷研博, 范九伦, 王富平, 公衍超, 田奇. 基于小样本学习的图像分类技术综述. 自动化学报, 2021, 47(2): 297−315 ↩︎

  5. 【Pytorch】prototypical network原型网络小样本图像分类简述及其实现_Jnchin的博客-CSDN博客_原型网络小样本 ↩︎

  6. jakesnell/prototypical-networks: Code for the NeurIPS 2017 Paper “Prototypical Networks for Few-shot Learning” (github.com) ↩︎