Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification
- 2022 年 10 月 29 日
- 筆記
- GraphMachineLearning, ML、DL
## 背景 **消息传递模型(Message Passing Model)**基于**拉普拉斯平滑假设**(领居是相似的),试图聚合图中的邻居的信息来获取足够的依据,以实现更鲁棒的半监督节点分类。 **图神经网络(Graph Neural Networks, GNN)**和**标签传播算法(Label Propagation, LPA)**均为消息传递算法,其中GNN主要基于传播特征来提升预测效果,而LPA基于迭代式的标签传播来作预测。 一些工作要么用LPA对GNN预测结果做后处理,要么用LPA对GNN进行正则化。但是,它们仍不能直接将GNN和LPA有效地整合到消息传递模型中。 为解决这个问题,本文提出了**统一消息传递模型(UNIMP)**[1],它可以在训练和推理时结合特征和标签传播。UniMP基于两个简单而有效的想法: – 将特征嵌入和标签嵌入同时作为输入信息进行传播 – 随机掩码部分标签信息,并在训练时对其进行预测 UniMP在概念上统一了特征传播和标签传播,具有强大的经验能力。 ![image](//img2022.cnblogs.com/blog/1596082/202210/1596082-20221029104805109-1144057785.png) ## 实现 ### 关键部分 – 将标签进行嵌入(原有的C类One-hot标签,通过线性变换成与原始节点特征相同的维度)。 – 然后,将标签嵌入和节点特征相加作为GNN输入。 为避免训练时使用标签导致标签泄露,这里使用了掩码标签训练的策略。每个Epoch随机将训练集中部分节点的标签置(掩码)0(视为训练监督信号),然后利用节点特征 $\mathbf{X}$ 和 $\mathbf{A}$以及剩余的标签去预测被掩码的标签)。 ### 模型部分 UniMP中使用了GraphTransformer(Transformer中的Q、K、V注意力形式,加上边特征),同时引入了H-GCN的门控残差机制来缓解过平滑。 ### 个人实验 将标签作为输入,在ArixV数据集节点分类任务上,能在小数点后第2位提升接近2个点。 在论文BOT[2]中也对标签作为输入做了阐述,其作者还发表了相应的论文来论证标签作为输入的有效性的原因。 ### 总结 标签有效的直觉就是,在图上的节点分类任务中,邻居标签也是预测目标节点标签的关键特征(这也和标签传播的思想一致) 标签嵌入和掩码标签预测是提升节点分类任务简单有效的方法。 ### 参考文献 > [1] Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification > [2] Bag of Tricks for Node Classification with Graph Neural Networks 2022-10-29 11:10:13 星期六