Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification
- 2022 年 10 月 29 日
- 筆記
- GraphMachineLearning, ML、DL
## 背景 **消息傳遞模型(Message Passing Model)**基於**拉普拉斯平滑假設**(領居是相似的),試圖聚合圖中的鄰居的信息來獲取足夠的依據,以實現更魯棒的半監督節點分類。 **圖神經網絡(Graph Neural Networks, GNN)**和**標籤傳播算法(Label Propagation, LPA)**均為消息傳遞算法,其中GNN主要基於傳播特徵來提升預測效果,而LPA基於迭代式的標籤傳播來作預測。 一些工作要麼用LPA對GNN預測結果做後處理,要麼用LPA對GNN進行正則化。但是,它們仍不能直接將GNN和LPA有效地整合到消息傳遞模型中。 為解決這個問題,本文提出了**統一消息傳遞模型(UNIMP)**[1],它可以在訓練和推理時結合特徵和標籤傳播。UniMP基於兩個簡單而有效的想法: – 將特徵嵌入和標籤嵌入同時作為輸入信息進行傳播 – 隨機掩碼部分標籤信息,並在訓練時對其進行預測 UniMP在概念上統一了特徵傳播和標籤傳播,具有強大的經驗能力。  ## 實現 ### 關鍵部分 – 將標籤進行嵌入(原有的C類One-hot標籤,通過線性變換成與原始節點特徵相同的維度)。 – 然後,將標籤嵌入和節點特徵相加作為GNN輸入。 為避免訓練時使用標籤導致標籤泄露,這裡使用了掩碼標籤訓練的策略。每個Epoch隨機將訓練集中部分節點的標籤置(掩碼)0(視為訓練監督信號),然後利用節點特徵 $\mathbf{X}$ 和 $\mathbf{A}$以及剩餘的標籤去預測被掩碼的標籤)。 ### 模型部分 UniMP中使用了GraphTransformer(Transformer中的Q、K、V注意力形式,加上邊特徵),同時引入了H-GCN的門控殘差機制來緩解過平滑。 ### 個人實驗 將標籤作為輸入,在ArixV數據集節點分類任務上,能在小數點後第2位提升接近2個點。 在論文BOT[2]中也對標籤作為輸入做了闡述,其作者還發表了相應的論文來論證標籤作為輸入的有效性的原因。 ### 總結 標籤有效的直覺就是,在圖上的節點分類任務中,鄰居標籤也是預測目標節點標籤的關鍵特徵(這也和標籤傳播的思想一致) 標籤嵌入和掩碼標籤預測是提升節點分類任務簡單有效的方法。 ### 參考文獻 > [1] Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification > [2] Bag of Tricks for Node Classification with Graph Neural Networks 2022-10-29 11:10:13 星期六