圖神經網絡庫DGL-message pasing
Message Passing
這是一篇關於DGL消息傳遞的讀後感。
Message Passing Paradigm
這裡,x_v\in\mathbb{R}^{d_1}是節點v的特徵,w_e\in\mathbb{R}^{d_2}是邊(u,v)的特徵。這個t+1時刻的消息傳遞框架定義如下的逐邊(node-wise)和逐邊的計算:
\text{Node-wise}:x_v^{t+1}=\psi(x_v^{t},\rho{\left(\left\{m_e^{t+1}:(u,v,e)\in \varepsilon\right\}\right)})
在上面的等式中,\phi是定義在每條邊上的消息函數,它通過結合邊特徵和兩個節點來產生message;\psi是定義在每個節點上的更新函數,通過使用reduce functaion
\rho去聚合(aggregating)輸入信息,從而更新此節點。
Note:上面對Message Passing Paradigm的解釋包含了很多信息。接下來我嘗試通過下面這個圖來解釋這個計算框架:
在更新圖節點的特徵值,我們通常都是根據此節點的鄰居節點更新的。假設我們想要更新節點6。那麼,我們首先要知道的就是節點6和其鄰居節點(2, 3, 7, 8, 9)通過相連的邊傳遞的是什麼信息,是相加嗎?還是其它計算。此時這裡的信息就是利用消息函數去完成的,消息函數通過(u,v,e)在每條邊上都生成了這個message,待會就利用這些信息去實施更新。
要更新節點6了,我們要知道節點6的信息更新只能通過邊a_{36},a_{26},a_{76},a_{68},a_{69}(對於無向圖來說,節點的順序不重要;有向圖的就集中一個方向就行)上的信息去更新。這時候reduce function
\rho非常重要,它就是根據將節點6的鄰居節點上的邊信息收集起來,然後進行操作,可以是直接相加、相減、求和、平均、求權重和等等。那麼最後,就利用\psi去更新節點6的值。如果沒聽懂,我很抱歉,或許接下來的內容能讓你重新了解。如果聽懂了,我感到很欣慰,接下來的內容能讓你更深刻的理解。
2.1 Built-in Functions and Message Passing APIs
接下來分別介紹message function,reduce function和update function。
-
message function:由於消息函數會在邊上產生信息,那麼,它需要一個
edges
參數。這個edges
有三個成員,分別是src, dst, data。能偶用來訪問邊上源節點的特徵,目標節點的特徵和邊本身的特徵。例子如下,假設我們將邊上src的hu特徵和dst上的hv特徵相加,然後保存到he上。def message_func(edges): # 消息函數的參數為edge return {'he':edges.src['hu']+edges.dst['hv']}
當然,dgl庫本身也有處理這方面的內置函數:
dgl.function.u_add_v('hu', 'hv', 'he')
這裡的u_add_v就表明了把源節點的特徵和目標節點上的特徵相加。 -
reduce function:需要有一個
nodes
參數。它有一個mailbox
成員,能夠用來訪問這個節點收到的message。就像最開始講的,一個節點只能夠收到來自於鄰居節點上的信息。所以它這個mailbox
就存儲了這些信息。所以,如果我們想把mailbox
收到的message相加,然後存儲到h里的話,也很簡單:import torch def reduce_func(nodes): return {'h':torch.sum(nodes.mailbox['m'], dim=1)}# 這裡之所以有'm',就是消息存儲到'm'這個鍵值上了,就像這裡的'h'
當然,dgl也準備了內建函數
dgl.function.sum('m', h)
。 -
update function:也需要一個參數
nodes
參數。它通常在最後一步去結合reduce function
聚合的結果和目標節點的特徵去更新目標節點的特徵。 -
update_all()
:它是一個高階函數,融合了消息的產生、聚合和節點的更新。所以,它需要三個參數:message function,reduce function和update function。也可以在update_all()
函數外調用更新函數而不用在update特別指定。DGL推薦這種在update_all()
外定義更新函數,因為為了代碼的簡潔,update function一般寫成純張量操作。具體實現例子如下,表達式為:將節點特徵ft_j
和i-j相連接的邊特徵a_ij
求和再乘以2。\text{final_ft}_{i}=2*\sum_{j\in{N(i)}}(ft_j*a_{ij})def update_all_example(graph): """ 注意,這裡的特徵名稱是一開始都設置好的。這個圖本身包含了: graph.ndata[ft] graph.edata['a'] 而m是臨時用來存儲message的 """ graph.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft')) final_ft = graph.ndata['ft'] * 2 return final_ft
dgl.function
里實現了很多了message function和reduce function。接下來,介紹一個重點,
update_all
函數。使用這個函數,將極大簡化代碼。此處不再贅述。DGL庫在message passing教程中接下來的內容都是關於如何優化使用和在不同場景下的使用,核心並沒有改變。接下來,我嘗試講解DGL例子中使用Message Passing來構造GCN的例子,來更加清晰的使用Message Passing。
使用Message Passing構造GCN
在這篇文章中,節點的更新過程為:
這裡\widetilde{A}=A+I,也就是鄰接矩陣增加了自環。主要,在實踐中,已經有自環的節點無需再增加自環。\widetilde{D}_{ii}=\sum_{j}\widetilde{A}_{ij}度矩陣是增加自環之後求得的。
g = dgl.remove_self_loop(g) # 增加自環時,首先去除原本的自環
g = dgl.add_self_loop(g)
degs = g.in_degrees().float() #無向圖中,入度和出度是相同的。
norm = torch.pow(degs, -0.5)
接下來,我們需要將節點更新過程(2)用Message Passing來表達。首先,我們要知道Message Passing的思想就是在目標節點上求得edges上的信息,然後聚合起來更新目標節點。先給出最終表達式,有個目的性,然後再一步步推導:
以上的表達式說明,在更新第$k$層的第$i$個節點特徵時,將$k-1$層第$i$個節點特徵與其鄰居節點$j$特徵進行$\Theta$轉換、度的標準化,最後求和更新。這很符合圖卷積的思想:<font color="yellow">將鄰居節點的信息結合起來,更新目標節點。</font>接下來將公式(2)進行分解到公式(3)。
我們一步步看公式(2):
\begin{bmatrix}
\frac{1}{\sqrt{\deg{(1)}}} & 0 & \cdots & 0\\
0 & \frac{1}{\deg{(2)}} &\cdots &0\\
0 & 0 & \cdots& 0\\
\vdots & \vdots &\ddots &0 \\
0 & 0 & \cdots & \frac{1}{\sqrt{\deg{(n)}}}
\end{bmatrix}*\widetilde{A}
Note:\frac{1}{\sqrt{\deg(i)}}是節點i的度的標準化。
這裡,\widetilde{D}^{-\frac{1}{2}}\widetilde{A}就相當於將\widetilde{A}的第i行的值乘以節點i的度。要進行下一步操作時,我們首先要搞清楚鄰接矩陣(此文章增加了自環,但我們仍然用鄰接矩陣稱呼它)\widetilde{A}的意義。如下,
a_{ij}=1 & 當節點j是節點i的鄰居節點,那麼第i行的第j列為1 \\
a_{ij}=0 & 其它\\
\end{cases}
我們令\widetilde{D}^{-\frac{1}{2}}\widetilde{A}\widetilde{D}^{-\frac{1}{2}}=M
\frac{1}{\sqrt{\deg{(1)}}} & 0 & \cdots & 0\\
0 & \frac{1}{\deg{(2)}} &\cdots &0\\
0 & 0 & \cdots& 0\\
\vdots & \vdots &\ddots &0 \\
0 & 0 & \cdots & \frac{1}{\sqrt{\deg{(n)}}}
\end{bmatrix}*\widetilde{A}*\begin{bmatrix}
\frac{1}{\sqrt{\deg{(1)}}} & 0 & \cdots & 0\\
0 & \frac{1}{\deg{(2)}} &\cdots &0\\
0 & 0 & \cdots& 0\\
\vdots & \vdots &\ddots &0 \\
0 & 0 & \cdots & \frac{1}{\sqrt{\deg{(n)}}}
\end{bmatrix} \\
\begin{bmatrix}a_{11}\frac{1}{\sqrt{\deg{(1)}}}*\frac{1}{\sqrt{\deg{(1)}}} & a_{12}\frac{1}{\sqrt{\deg{(1)}}}\frac{1}{\sqrt{\deg{(2)}}} & \cdots & a_{1n}\frac{1}{\sqrt{\deg{(1)}}}*\frac{1}{\sqrt{\deg{(n)}}} \\a_{21}\frac{1}{\sqrt{\deg{(2)}}}*\frac{1}{\sqrt{\deg{(1)}}} & a_{22}\frac{1}{\sqrt{\deg{(2)}}}\frac{1}{\sqrt{\deg{(2)}}} & \cdots & a_{2n}\frac{1}{\sqrt{\deg{(2)}}}*\frac{1}{\sqrt{\deg{(n)}}}\\
\vdots & \vdots &\ddots &\vdots \\a_{n1}\frac{1}{\sqrt{\deg{(n)}}}*\frac{1}{\sqrt{\deg{(1)}}} & a_{n2}\frac{1}{\sqrt{\deg{(n)}}}\frac{1}{\sqrt{\deg{(2)}}} & \cdots & a_{nn}\frac{1}{\sqrt{\deg{(n)}}}*\frac{1}{\sqrt{\deg{(n)}}}
\end{bmatrix}
大家看到這裡應該比較清楚了,這裡的係數a_{ij}只能為0或者1,並且取決於j是否為i的節點。X\Theta僅僅只是將特徵X進行特徵映射了,僅僅只是改變X的列維度。那麼,接下來MX的值就意味着矩陣M使用係數去選取X的值,這樣就實現了選取鄰居節點特徵的意義。(害,其實我剛看到公式(2)的時候,內心是崩潰的,它怎麼就實現抽取鄰居節點特徵的卷積效果,推導到公式(3)時才恍然大悟,原來是這麼回事)。
那麼MX的每一行值,就是公式三的結果了。我們再次看下公式(3):
分析出message function、reduce function為:
- message function:每個源節點特徵乘以其度的正則
- reduce function:將message function產生的信息求和,並且乘以目標節點的度正則。
def gcn_msg(edge):
# 在邊上的源節點上,乘以其度的正則
msg = edge.src['h'] * edge.src['norm']
return {'m': msg}
def gcn_reduce(node):
# 將目標節點的邊上的信息聚合(這裡是sum),再乘以目標節點上的度的正則
# 這裡的torch.sum(, dim=1)在維度1上相加,是因為node.mailbox['m']的shape = [batch, mails, feat]
# 需要將mails整個加起來
accum = torch.sum(node.mailbox['m'], 1) * node.data['norm']
return {'h': accum} # 這時候,節點就存在一個數據node.data['h'] = accum
class NodeApplyModule(nn.Module):
def __init__(self, out_feats, activation=None, bias=True):
super(NodeApplyModule, self).__init__()
if bias:
self.bias = nn.Parameter(torch.Tensor(out_feats))
else:
self.bias = None
self.activation = activation
self.reset_parameters()
def reset_parameters(self):
if self.bias is not None:
stdv = 1. / math.sqrt(self.bias.size(0))
self.bias.data.uniform_(-stdv, stdv)
def forward(self, nodes): # 要更新時,添加上偏置和激活函數
h = nodes.data['h']
if self.bias is not None:
h = h + self.bias
if self.activation:
h = self.activation(h)
return {'h': h} # 此時nodes['h'] = h 這是就被更新了
class GCNLayer(nn.Module):
def __init__(self,
g,
in_feats,
out_feats,
activation,
dropout,
bias=True):
super(GCNLayer, self).__init__()
self.g = g
self.weight = nn.Parameter(torch.Tensor(in_feats, out_feats))
if dropout:
self.dropout = nn.Dropout(p=dropout)
else:
self.dropout = 0.
self.node_update = NodeApplyModule(out_feats, activation, bias)
self.reset_parameters()
def reset_parameters(self):
stdv = 1. / math.sqrt(self.weight.size(1))
self.weight.data.uniform_(-stdv, stdv)
def forward(self, h):
if self.dropout:
h = self.dropout(h)
self.g.ndata['h'] = torch.mm(h, self.weight) # 這裡是首先將節點特徵進行映射,也就是X*O
self.g.update_all(gcn_msg, gcn_reduce, self.node_update) # 然後求message,聚合,更新。
h = self.g.ndata.pop('h')
return h
# add self loop
g = dgl.remove_self_loop(g)
g = dgl.add_self_loop(g)
n_edges = g.number_of_edges()
# normalization
degs = g.in_degrees().float()
norm = torch.pow(degs, -0.5)
norm[torch.isinf(norm)] = 0
g.ndata['norm'] = norm.unsqueeze(1) # 傳遞到GCN的圖已經存在了'norm'數據