論文解讀(GGD)《Rethinking and Scaling Up Graph Contrastive Learning: An Extremely Efficient Approach with Group Discrimination》

論文資訊

論文標題:Rethinking and Scaling Up Graph Contrastive Learning: An Extremely Efficient Approach with Group Discrimination
論文作者:Yizhen Zheng, Shirui Pan, Vincent Cs Lee, Yu Zheng, Philip S. Yu
論文來源:2022,NeurIPS
論文地址:download 
論文程式碼:download 

1 Introduction

  GCL 需要大量的 Epoch 在數據集上訓練,本文的啟發來自 GCL 的代表性工作 DGI 和 MVGRL,因為 Sigmoid 函數存在的缺陷,因此,本文提出  Group Discrimination (GD) ,並基於此提出本文的模型 Graph Group Discrimination (GGD)。

  Graph ContrastiveLearning 和 Group Discrimination 的區別:

  

  • GD directly discriminates a group of positive nodes from a group of negative nodes.
  • GCL maximise the mutual information (MI) between an anchor node and its positive counterparts, sharing similar semantic information while doing the opposite for negative counterparts.

  貢獻:

  • 1) We re-examine existing GCL approaches (e.g., DGI and MVGRL), and we introduce a novel and efficient self-supervised GRL paradigm, namely, Group Discrimination (GD).
  • 2) Based on GD, we propose a new self-supervised GRL model, GGD, which is fast in training and convergence, and possess high scalability.
  • 3) We conduct extensive experiments on eight datasets, including an extremely large dataset, ogbn-papers100M with billion edges.

2 Rethinking Representative GCL Methods

  本節以經典的 DGI 、MVGRL 為例子,說明了互資訊最大化並不是對比學習的貢獻因素,而是一個新的範式,群體歧視(group discrimination)。

2.1 Rethinking GCL Methods

  回顧一下 DGI :

  

  程式碼:

class DGI(nn.Module):
    def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout):
        super(DGI, self).__init__()
        self.encoder = Encoder(g, in_feats, n_hidden, n_layers, activation, dropout)
        self.discriminator = Discriminator(n_hidden)
        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, features):
        positive = self.encoder(features, corrupt=False)
        negative = self.encoder(features, corrupt=True)
        summary = torch.sigmoid(positive.mean(dim=0))
        positive = self.discriminator(positive, summary)
        negative = self.discriminator(negative, summary)
        l1 = self.loss(positive, torch.ones_like(positive))
        l2 = self.loss(negative, torch.zeros_like(negative))
        return l1 + l2

View Code

  本文研究 DGI 結論:一個 Sigmoid 函數不適用於權重被 Xavier 初始化的 GNN 生成的 summary vector,且 summary vector  中的元素非常接近於相同的值。

  接著嘗試將 Summary vector 中的數值變換成不同的常量 (from 0 to 1):

  

  結論:

    • 將 summary vector 中的數值變成 0,求解相似度時導致所有的 score 變成 0,也就是 postive 項的損失函數變成 負無窮,無法優化;
    • summary vector 設置其他值,導致 數值不穩定;

  DGI 的簡化:

  ① 將 summary vector 設置為 單位向量(縮放對損失不影響);

  ② 去掉 Discriminator (Bilinear​ :先做線性變換,再求內積相似度)的權重向量;【雙線性層的 $W$ 其實就是一個線性變換層】

    $\begin{aligned}\mathcal{L}_{D G I} &=\frac{1}{2 N}\left(\sum\limits _{i=1}^{N} \log \mathcal{D}\left(\mathbf{h}_{i}, \mathbf{s}\right)+\log \left(1-\mathcal{D}\left(\tilde{\mathbf{h}}_{i}, \mathbf{s}\right)\right)\right) \\&\left.=\frac{1}{2 N}\left(\sum\limits_{i=1}^{N} \log \left(\mathbf{h}_{i} \cdot \mathbf{s}\right)+\log \left(1-\tilde{\mathbf{h}}_{i} \cdot \mathbf{s}\right)\right)\right) \\&=\frac{1}{2 N}\left(\sum\limits_{i=1}^{N} \log \left(\operatorname{sum}\left(\mathbf{h}_{i}\right)\right)+\log \left(1-\operatorname{sum}\left(\tilde{\mathbf{h}}_{i}\right)\right)\right)\end{aligned} \quad\quad\quad(1)$

  Bilinear :

    $\mathcal{D}\left(\mathbf{h}_{i}, \mathbf{s}\right)=\sigma_{s i g}\left(\mathbf{h}_{i} \cdot \mathbf{W} \cdot \mathbf{s}\right)\quad\quad\quad(2)$

  實驗:替換 $\text{Eq.1}$ 中的 aggregation function ,即 sum 函數

  

  替換形式為:

    $\mathcal{L}_{B C E}=-\frac{1}{2 N}\left(\sum\limits _{i=1}^{2 N} y_{i} \log \hat{y}_{i}+\left(1-y_{i}\right) \log \left(1-\hat{y}_{i}\right)\right)\quad\quad\quad(3)$

  其中,$\hat{y}_{i}=\operatorname{agg}\left(\mathbf{h}_{i}\right)$ ,$y_{i} \in \mathbb{R}^{1 \times 1}$ ,$\hat{y}_{i} \in \mathbb{R}^{1 \times 1}$。論文中闡述 $y_{i}$ 和 $\hat{y}_{i}$ 分別代表 node $i$ 是否是 postive sample ,及其預測輸出。Q :當 aggregation function 採用 $\text{mean}$ 的時候,對於 postive  sample $i$ ,$\hat{y}_{i}$ 值會趨於 $1$ 么?

  DGI 真正所做的是區分正確拓撲生成的一組節點和損壞拓撲生成的節點,如 Figure 1 所示。可以這麼理解,DGI 是使用一個固定的向量 $s$ 去區分兩組節點嵌入矩陣(postive and negative)。

  為解決上述 GD 的問題,本文將考慮使用 $\text{Eq.3}$ 去替換 DGI 中的損失函數。替換的好處:節省顯示記憶體和加快計算速度,對於精度沒啥改變,說的天花亂墜。

  

  Note:方差大的稍微大一點的 method ,就是容易被詆毀。

  Group Discrimination 定義:GRL method,將不同組別的節點劃分給不同的組,對於 postive pair 和 negative pair 分別劃分到 “1” 組 和 “0” 組。

3 Methodology 

  整體框架:

  

  組成部分

    • Siamese Network :模仿 MVGRL 的架構;
    • Data Augmentation :提供相似意義資訊,帶來的是時間成本;【dropout edge、feature mask】
    • Loss function : $\text{Eq.3}$;
  模型推斷:

  首先:固定 GNN encoder、MLP predict 的參數,獲得初步的節點表示 $\mathbf{H}_{\theta}$;

  其次:MVGRL 多視圖對比工作給本文深刻的啟發,所以考慮引入全局資訊 :$ \mathbf{H}_{\theta}^{\text {global }}=\mathbf{A}^{n} \mathbf{H}_{\theta}$;

  最後:得到局部表示和全局表示的聚合 $\mathbf{H}=\mathbf{H}_{\theta}^{\text {global }}+\mathbf{H}_{\theta}$ ;

4 Experiments

4.1 Datasets

  

4.2 Result

節點分類

  

訓練時間 和 記憶體消耗

  

  

4.3 Evaluating on Large-scale datasets

  

  

  

  

 

5 Future Work

   For example, can we extend the current binary Group Discrimination scheme (i.e., classifying nodes generated with different topology) to discrimination among multiple groups?