圖神經網路庫DGL-message pasing

Message Passing

這是一篇關於DGL消息傳遞的讀後感。

Message Passing Paradigm

​ 這裡,x_v\in\mathbb{R}^{d_1}是節點v的特徵,w_e\in\mathbb{R}^{d_2}是邊(u,v)的特徵。這個t+1時刻的消息傳遞框架定義如下的逐邊(node-wise)和逐邊的計算:

\text{Edge-wise}:m_e^{t+1}=\phi(x_v^{t},x_u^{t},w_e^{t}),\ (u,v,e)\in\varepsilon \\
\text{Node-wise}:x_v^{t+1}=\psi(x_v^{t},\rho{\left(\left\{m_e^{t+1}:(u,v,e)\in \varepsilon\right\}\right)})

​ 在上面的等式中,\phi是定義在每條邊上的消息函數,它通過結合邊特徵和兩個節點來產生message;\psi是定義在每個節點上的更新函數,通過使用reduce functaion\rho聚合(aggregating)輸入資訊,從而更新此節點。

​ Note:上面對Message Passing Paradigm的解釋包含了很多資訊。接下來我嘗試通過下面這個圖來解釋這個計算框架:
image-20201113112324278.png

​ 在更新圖節點的特徵值,我們通常都是根據此節點的鄰居節點更新的。假設我們想要更新節點6。那麼,我們首先要知道的就是節點6和其鄰居節點(2, 3, 7, 8, 9)通過相連的邊傳遞的是什麼資訊,是相加嗎?還是其它計算。此時這裡的資訊就是利用消息函數去完成的,消息函數通過(u,v,e)在每條邊上都生成了這個message,待會就利用這些資訊去實施更新。

​ 要更新節點6了,我們要知道節點6的資訊更新只能通過邊a_{36},a_{26},a_{76},a_{68},a_{69}(對於無向圖來說,節點的順序不重要;有向圖的就集中一個方向就行)上的資訊去更新。這時候reduce function\rho非常重要,它就是根據將節點6的鄰居節點上的邊資訊收集起來,然後進行操作,可以是直接相加、相減、求和、平均、求權重和等等。那麼最後,就利用\psi去更新節點6的值。如果沒聽懂,我很抱歉,或許接下來的內容能讓你重新了解。如果聽懂了,我感到很欣慰,接下來的內容能讓你更深刻的理解。

2.1 Built-in Functions and Message Passing APIs

​ 接下來分別介紹message functionreduce functionupdate function

  • message function:由於消息函數會在邊上產生資訊,那麼,它需要一個edges參數。這個edges有三個成員,分別是src, dst, data。能偶用來訪問邊上源節點的特徵,目標節點的特徵和邊本身的特徵。例子如下,假設我們將邊上src的hu特徵和dst上的hv特徵相加,然後保存到he上。

    def message_func(edges):  # 消息函數的參數為edge
        return {'he':edges.src['hu']+edges.dst['hv']}
    

    當然,dgl庫本身也有處理這方面的內置函數:dgl.function.u_add_v('hu', 'hv', 'he')這裡的u_add_v就表明了把源節點的特徵和目標節點上的特徵相加。

  • reduce function:需要有一個nodes參數。它有一個mailbox成員,能夠用來訪問這個節點收到的message。就像最開始講的,一個節點只能夠收到來自於鄰居節點上的資訊。所以它這個mailbox就存儲了這些資訊。所以,如果我們想把mailbox收到的message相加,然後存儲到h里的話,也很簡單:

    import torch
    def reduce_func(nodes):
        return {'h':torch.sum(nodes.mailbox['m'], dim=1)}# 這裡之所以有'm',就是消息存儲到'm'這個鍵值上了,就像這裡的'h'
    

    當然,dgl也準備了內建函數dgl.function.sum('m', h)

  • update function:也需要一個參數nodes參數。它通常在最後一步去結合reduce function聚合的結果和目標節點的特徵去更新目標節點的特徵。

  • update_all():它是一個高階函數,融合了消息的產生、聚合和節點的更新。所以,它需要三個參數:message function,reduce function和update function。也可以在update_all()函數外調用更新函數而不用在update特別指定。DGL推薦這種在update_all()外定義更新函數,因為為了程式碼的簡潔,update function一般寫成純張量操作。具體實現例子如下,表達式為:將節點特徵ft_j和i-j相連接的邊特徵a_ij求和再乘以2。

    \text{final_ft}_{i}=2*\sum_{j\in{N(i)}}(ft_j*a_{ij})
    def update_all_example(graph):
        """
        	注意,這裡的特徵名稱是一開始都設置好的。這個圖本身包含了:
        	graph.ndata[ft]
        	graph.edata['a']
        	而m是臨時用來存儲message的
        """
        graph.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
        final_ft = graph.ndata['ft'] * 2
        return final_ft
    

    dgl.function里實現了很多了message functionreduce function

    接下來,介紹一個重點,update_all函數。使用這個函數,將極大簡化程式碼。

    此處不再贅述。DGL庫在message passing教程中接下來的內容都是關於如何優化使用和在不同場景下的使用,核心並沒有改變。接下來,我嘗試講解DGL例子中使用Message Passing來構造GCN的例子,來更加清晰的使用Message Passing。

使用Message Passing構造GCN

​ 在這篇文章中,節點的更新過程為:

Z=\widetilde{D}^{-\frac{1}{2}}\widetilde{A}\widetilde{D}^{-\frac{1}{2}}X\Theta \ \ \ \ (2)

​ 這裡\widetilde{A}=A+I,也就是鄰接矩陣增加了自環。主要,在實踐中,已經有自環的節點無需再增加自環。\widetilde{D}_{ii}=\sum_{j}\widetilde{A}_{ij}度矩陣是增加自環之後求得的。

g = dgl.remove_self_loop(g) # 增加自環時,首先去除原本的自環
g = dgl.add_self_loop(g)
degs = g.in_degrees().float() #無向圖中,入度和出度是相同的。
norm = torch.pow(degs, -0.5)

​ 接下來,我們需要將節點更新過程(2)用Message Passing來表達。首先,我們要知道Message Passing的思想就是在目標節點上求得edges上的資訊,然後聚合起來更新目標節點。先給出最終表達式,有個目的性,然後再一步步推導:

x_i^{k}=\sum_{j\in N(i)\cup{(i)}}\frac{1}{\sqrt{\deg(i)}\sqrt{\text{deg}(j)}}x_j^{k-1}\Theta \ \ \ (3)
以上的表達式說明,在更新第$k$層的第$i$個節點特徵時,將$k-1$層第$i$個節點特徵與其鄰居節點$j$特徵進行$\Theta$轉換、度的標準化,最後求和更新。這很符合圖卷積的思想:<font color="yellow">將鄰居節點的資訊結合起來,更新目標節點。</font>接下來將公式(2)進行分解到公式(3)。

​ 我們一步步看公式(2):

\widetilde{D}^{-\frac{1}{2}}\widetilde{A} =
\begin{bmatrix}
\frac{1}{\sqrt{\deg{(1)}}} & 0 & \cdots & 0\\
0 & \frac{1}{\deg{(2)}} &\cdots &0\\
0 & 0 & \cdots& 0\\
\vdots & \vdots &\ddots &0 \\
0 & 0 & \cdots & \frac{1}{\sqrt{\deg{(n)}}}
\end{bmatrix}*\widetilde{A}

​ Note:\frac{1}{\sqrt{\deg(i)}}是節點i的度的標準化。

​ 這裡,\widetilde{D}^{-\frac{1}{2}}\widetilde{A}就相當於將\widetilde{A}的第i行的值乘以節點i的度。要進行下一步操作時,我們首先要搞清楚鄰接矩陣(此文章增加了自環,但我們仍然用鄰接矩陣稱呼它)\widetilde{A}的意義。如下,

\begin{cases}
a_{ij}=1 & 當節點j是節點i的鄰居節點,那麼第i行的第j列為1 \\
a_{ij}=0 & 其它\\
\end{cases}

​ 我們令\widetilde{D}^{-\frac{1}{2}}\widetilde{A}\widetilde{D}^{-\frac{1}{2}}=M

\widetilde{D}^{-\frac{1}{2}}\widetilde{A} \widetilde{D}^{-\frac{1}{2}}=\begin{bmatrix}
\frac{1}{\sqrt{\deg{(1)}}} & 0 & \cdots & 0\\
0 & \frac{1}{\deg{(2)}} &\cdots &0\\
0 & 0 & \cdots& 0\\
\vdots & \vdots &\ddots &0 \\
0 & 0 & \cdots & \frac{1}{\sqrt{\deg{(n)}}}
\end{bmatrix}*\widetilde{A}*\begin{bmatrix}
\frac{1}{\sqrt{\deg{(1)}}} & 0 & \cdots & 0\\
0 & \frac{1}{\deg{(2)}} &\cdots &0\\
0 & 0 & \cdots& 0\\
\vdots & \vdots &\ddots &0 \\
0 & 0 & \cdots & \frac{1}{\sqrt{\deg{(n)}}}
\end{bmatrix} \\
\begin{bmatrix}a_{11}\frac{1}{\sqrt{\deg{(1)}}}*\frac{1}{\sqrt{\deg{(1)}}} & a_{12}\frac{1}{\sqrt{\deg{(1)}}}\frac{1}{\sqrt{\deg{(2)}}} & \cdots & a_{1n}\frac{1}{\sqrt{\deg{(1)}}}*\frac{1}{\sqrt{\deg{(n)}}} \\a_{21}\frac{1}{\sqrt{\deg{(2)}}}*\frac{1}{\sqrt{\deg{(1)}}} & a_{22}\frac{1}{\sqrt{\deg{(2)}}}\frac{1}{\sqrt{\deg{(2)}}} & \cdots & a_{2n}\frac{1}{\sqrt{\deg{(2)}}}*\frac{1}{\sqrt{\deg{(n)}}}\\
\vdots & \vdots &\ddots &\vdots \\a_{n1}\frac{1}{\sqrt{\deg{(n)}}}*\frac{1}{\sqrt{\deg{(1)}}} & a_{n2}\frac{1}{\sqrt{\deg{(n)}}}\frac{1}{\sqrt{\deg{(2)}}} & \cdots & a_{nn}\frac{1}{\sqrt{\deg{(n)}}}*\frac{1}{\sqrt{\deg{(n)}}}
\end{bmatrix}

​ 大家看到這裡應該比較清楚了,這裡的係數a_{ij}只能為0或者1,並且取決於j是否為i的節點。X\Theta僅僅只是將特徵X進行特徵映射了,僅僅只是改變X的列維度。那麼,接下來MX的值就意味著矩陣M使用係數去選取X的值,這樣就實現了選取鄰居節點特徵的意義。(害,其實我剛看到公式(2)的時候,內心是崩潰的,它怎麼就實現抽取鄰居節點特徵的卷積效果,推導到公式(3)時才恍然大悟,原來是這麼回事)。

​ 那麼MX的每一行值,就是公式三的結果了。我們再次看下公式(3):

x_i^{k}=\sum_{j\in N(i)\cup{(i)}}\frac{1}{\sqrt{\deg(i)}\sqrt{\text{deg}(j)}}x_j^{k-1}\Theta \ \ \ (3)

​ 分析出message function、reduce function為:

  • message function:每個源節點特徵乘以其度的正則
  • reduce function:將message function產生的資訊求和,並且乘以目標節點的度正則。
def gcn_msg(edge):
    # 在邊上的源節點上,乘以其度的正則
    msg = edge.src['h'] * edge.src['norm']
    return {'m': msg} 
def gcn_reduce(node):
    # 將目標節點的邊上的資訊聚合(這裡是sum),再乘以目標節點上的度的正則
    # 這裡的torch.sum(, dim=1)在維度1上相加,是因為node.mailbox['m']的shape = [batch, mails, feat]
    # 需要將mails整個加起來
    accum = torch.sum(node.mailbox['m'], 1) * node.data['norm'] 
    return {'h': accum} # 這時候,節點就存在一個數據node.data['h'] = accum
class NodeApplyModule(nn.Module):
    def __init__(self, out_feats, activation=None, bias=True):
        super(NodeApplyModule, self).__init__()
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_feats))
        else:
            self.bias = None
        self.activation = activation
        self.reset_parameters()

    def reset_parameters(self):
        if self.bias is not None:
            stdv = 1. / math.sqrt(self.bias.size(0))
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, nodes): # 要更新時,添加上偏置和激活函數
        h = nodes.data['h']
        if self.bias is not None:
            h = h + self.bias
        if self.activation:
            h = self.activation(h)
        return {'h': h} # 此時nodes['h'] = h 這是就被更新了
class GCNLayer(nn.Module):
    def __init__(self,
                 g,
                 in_feats,
                 out_feats,
                 activation,
                 dropout,
                 bias=True):
        super(GCNLayer, self).__init__()
        self.g = g
        self.weight = nn.Parameter(torch.Tensor(in_feats, out_feats))
        if dropout:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = 0.
        self.node_update = NodeApplyModule(out_feats, activation, bias)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, h):
        if self.dropout:
            h = self.dropout(h)
        self.g.ndata['h'] = torch.mm(h, self.weight) # 這裡是首先將節點特徵進行映射,也就是X*O
        self.g.update_all(gcn_msg, gcn_reduce, self.node_update) # 然後求message,聚合,更新。
        h = self.g.ndata.pop('h')
        return h
# add self loop
g = dgl.remove_self_loop(g)
g = dgl.add_self_loop(g)
n_edges = g.number_of_edges()

# normalization
degs = g.in_degrees().float()
norm = torch.pow(degs, -0.5)
norm[torch.isinf(norm)] = 0
g.ndata['norm'] = norm.unsqueeze(1) # 傳遞到GCN的圖已經存在了'norm'數據