图神经网络库DGL-message pasing

Message Passing

这是一篇关于DGL消息传递的读后感。

Message Passing Paradigm

​ 这里,x_v\in\mathbb{R}^{d_1}是节点v的特征,w_e\in\mathbb{R}^{d_2}是边(u,v)的特征。这个t+1时刻的消息传递框架定义如下的逐边(node-wise)和逐边的计算:

\text{Edge-wise}:m_e^{t+1}=\phi(x_v^{t},x_u^{t},w_e^{t}),\ (u,v,e)\in\varepsilon \\
\text{Node-wise}:x_v^{t+1}=\psi(x_v^{t},\rho{\left(\left\{m_e^{t+1}:(u,v,e)\in \varepsilon\right\}\right)})

​ 在上面的等式中,\phi是定义在每条边上的消息函数,它通过结合边特征和两个节点来产生message;\psi是定义在每个节点上的更新函数,通过使用reduce functaion\rho聚合(aggregating)输入信息,从而更新此节点。

​ Note:上面对Message Passing Paradigm的解释包含了很多信息。接下来我尝试通过下面这个图来解释这个计算框架:
image-20201113112324278.png

​ 在更新图节点的特征值,我们通常都是根据此节点的邻居节点更新的。假设我们想要更新节点6。那么,我们首先要知道的就是节点6和其邻居节点(2, 3, 7, 8, 9)通过相连的边传递的是什么信息,是相加吗?还是其它计算。此时这里的信息就是利用消息函数去完成的,消息函数通过(u,v,e)在每条边上都生成了这个message,待会就利用这些信息去实施更新。

​ 要更新节点6了,我们要知道节点6的信息更新只能通过边a_{36},a_{26},a_{76},a_{68},a_{69}(对于无向图来说,节点的顺序不重要;有向图的就集中一个方向就行)上的信息去更新。这时候reduce function\rho非常重要,它就是根据将节点6的邻居节点上的边信息收集起来,然后进行操作,可以是直接相加、相减、求和、平均、求权重和等等。那么最后,就利用\psi去更新节点6的值。如果没听懂,我很抱歉,或许接下来的内容能让你重新了解。如果听懂了,我感到很欣慰,接下来的内容能让你更深刻的理解。

2.1 Built-in Functions and Message Passing APIs

​ 接下来分别介绍message functionreduce functionupdate function

  • message function:由于消息函数会在边上产生信息,那么,它需要一个edges参数。这个edges有三个成员,分别是src, dst, data。能偶用来访问边上源节点的特征,目标节点的特征和边本身的特征。例子如下,假设我们将边上src的hu特征和dst上的hv特征相加,然后保存到he上。

    def message_func(edges):  # 消息函数的参数为edge
        return {'he':edges.src['hu']+edges.dst['hv']}
    

    当然,dgl库本身也有处理这方面的内置函数:dgl.function.u_add_v('hu', 'hv', 'he')这里的u_add_v就表明了把源节点的特征和目标节点上的特征相加。

  • reduce function:需要有一个nodes参数。它有一个mailbox成员,能够用来访问这个节点收到的message。就像最开始讲的,一个节点只能够收到来自于邻居节点上的信息。所以它这个mailbox就存储了这些信息。所以,如果我们想把mailbox收到的message相加,然后存储到h里的话,也很简单:

    import torch
    def reduce_func(nodes):
        return {'h':torch.sum(nodes.mailbox['m'], dim=1)}# 这里之所以有'm',就是消息存储到'm'这个键值上了,就像这里的'h'
    

    当然,dgl也准备了内建函数dgl.function.sum('m', h)

  • update function:也需要一个参数nodes参数。它通常在最后一步去结合reduce function聚合的结果和目标节点的特征去更新目标节点的特征。

  • update_all():它是一个高阶函数,融合了消息的产生、聚合和节点的更新。所以,它需要三个参数:message function,reduce function和update function。也可以在update_all()函数外调用更新函数而不用在update特别指定。DGL推荐这种在update_all()外定义更新函数,因为为了代码的简洁,update function一般写成纯张量操作。具体实现例子如下,表达式为:将节点特征ft_j和i-j相连接的边特征a_ij求和再乘以2。

    \text{final_ft}_{i}=2*\sum_{j\in{N(i)}}(ft_j*a_{ij})
    def update_all_example(graph):
        """
        	注意,这里的特征名称是一开始都设置好的。这个图本身包含了:
        	graph.ndata[ft]
        	graph.edata['a']
        	而m是临时用来存储message的
        """
        graph.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
        final_ft = graph.ndata['ft'] * 2
        return final_ft
    

    dgl.function里实现了很多了message functionreduce function

    接下来,介绍一个重点,update_all函数。使用这个函数,将极大简化代码。

    此处不再赘述。DGL库在message passing教程中接下来的内容都是关于如何优化使用和在不同场景下的使用,核心并没有改变。接下来,我尝试讲解DGL例子中使用Message Passing来构造GCN的例子,来更加清晰的使用Message Passing。

使用Message Passing构造GCN

​ 在这篇文章中,节点的更新过程为:

Z=\widetilde{D}^{-\frac{1}{2}}\widetilde{A}\widetilde{D}^{-\frac{1}{2}}X\Theta \ \ \ \ (2)

​ 这里\widetilde{A}=A+I,也就是邻接矩阵增加了自环。主要,在实践中,已经有自环的节点无需再增加自环。\widetilde{D}_{ii}=\sum_{j}\widetilde{A}_{ij}度矩阵是增加自环之后求得的。

g = dgl.remove_self_loop(g) # 增加自环时,首先去除原本的自环
g = dgl.add_self_loop(g)
degs = g.in_degrees().float() #无向图中,入度和出度是相同的。
norm = torch.pow(degs, -0.5)

​ 接下来,我们需要将节点更新过程(2)用Message Passing来表达。首先,我们要知道Message Passing的思想就是在目标节点上求得edges上的信息,然后聚合起来更新目标节点。先给出最终表达式,有个目的性,然后再一步步推导:

x_i^{k}=\sum_{j\in N(i)\cup{(i)}}\frac{1}{\sqrt{\deg(i)}\sqrt{\text{deg}(j)}}x_j^{k-1}\Theta \ \ \ (3)
以上的表达式说明,在更新第$k$层的第$i$个节点特征时,将$k-1$层第$i$个节点特征与其邻居节点$j$特征进行$\Theta$转换、度的标准化,最后求和更新。这很符合图卷积的思想:<font color="yellow">将邻居节点的信息结合起来,更新目标节点。</font>接下来将公式(2)进行分解到公式(3)。

​ 我们一步步看公式(2):

\widetilde{D}^{-\frac{1}{2}}\widetilde{A} =
\begin{bmatrix}
\frac{1}{\sqrt{\deg{(1)}}} & 0 & \cdots & 0\\
0 & \frac{1}{\deg{(2)}} &\cdots &0\\
0 & 0 & \cdots& 0\\
\vdots & \vdots &\ddots &0 \\
0 & 0 & \cdots & \frac{1}{\sqrt{\deg{(n)}}}
\end{bmatrix}*\widetilde{A}

​ Note:\frac{1}{\sqrt{\deg(i)}}是节点i的度的标准化。

​ 这里,\widetilde{D}^{-\frac{1}{2}}\widetilde{A}就相当于将\widetilde{A}的第i行的值乘以节点i的度。要进行下一步操作时,我们首先要搞清楚邻接矩阵(此文章增加了自环,但我们仍然用邻接矩阵称呼它)\widetilde{A}的意义。如下,

\begin{cases}
a_{ij}=1 & 当节点j是节点i的邻居节点,那么第i行的第j列为1 \\
a_{ij}=0 & 其它\\
\end{cases}

​ 我们令\widetilde{D}^{-\frac{1}{2}}\widetilde{A}\widetilde{D}^{-\frac{1}{2}}=M

\widetilde{D}^{-\frac{1}{2}}\widetilde{A} \widetilde{D}^{-\frac{1}{2}}=\begin{bmatrix}
\frac{1}{\sqrt{\deg{(1)}}} & 0 & \cdots & 0\\
0 & \frac{1}{\deg{(2)}} &\cdots &0\\
0 & 0 & \cdots& 0\\
\vdots & \vdots &\ddots &0 \\
0 & 0 & \cdots & \frac{1}{\sqrt{\deg{(n)}}}
\end{bmatrix}*\widetilde{A}*\begin{bmatrix}
\frac{1}{\sqrt{\deg{(1)}}} & 0 & \cdots & 0\\
0 & \frac{1}{\deg{(2)}} &\cdots &0\\
0 & 0 & \cdots& 0\\
\vdots & \vdots &\ddots &0 \\
0 & 0 & \cdots & \frac{1}{\sqrt{\deg{(n)}}}
\end{bmatrix} \\
\begin{bmatrix}a_{11}\frac{1}{\sqrt{\deg{(1)}}}*\frac{1}{\sqrt{\deg{(1)}}} & a_{12}\frac{1}{\sqrt{\deg{(1)}}}\frac{1}{\sqrt{\deg{(2)}}} & \cdots & a_{1n}\frac{1}{\sqrt{\deg{(1)}}}*\frac{1}{\sqrt{\deg{(n)}}} \\a_{21}\frac{1}{\sqrt{\deg{(2)}}}*\frac{1}{\sqrt{\deg{(1)}}} & a_{22}\frac{1}{\sqrt{\deg{(2)}}}\frac{1}{\sqrt{\deg{(2)}}} & \cdots & a_{2n}\frac{1}{\sqrt{\deg{(2)}}}*\frac{1}{\sqrt{\deg{(n)}}}\\
\vdots & \vdots &\ddots &\vdots \\a_{n1}\frac{1}{\sqrt{\deg{(n)}}}*\frac{1}{\sqrt{\deg{(1)}}} & a_{n2}\frac{1}{\sqrt{\deg{(n)}}}\frac{1}{\sqrt{\deg{(2)}}} & \cdots & a_{nn}\frac{1}{\sqrt{\deg{(n)}}}*\frac{1}{\sqrt{\deg{(n)}}}
\end{bmatrix}

​ 大家看到这里应该比较清楚了,这里的系数a_{ij}只能为0或者1,并且取决于j是否为i的节点。X\Theta仅仅只是将特征X进行特征映射了,仅仅只是改变X的列维度。那么,接下来MX的值就意味着矩阵M使用系数去选取X的值,这样就实现了选取邻居节点特征的意义。(害,其实我刚看到公式(2)的时候,内心是崩溃的,它怎么就实现抽取邻居节点特征的卷积效果,推导到公式(3)时才恍然大悟,原来是这么回事)。

​ 那么MX的每一行值,就是公式三的结果了。我们再次看下公式(3):

x_i^{k}=\sum_{j\in N(i)\cup{(i)}}\frac{1}{\sqrt{\deg(i)}\sqrt{\text{deg}(j)}}x_j^{k-1}\Theta \ \ \ (3)

​ 分析出message function、reduce function为:

  • message function:每个源节点特征乘以其度的正则
  • reduce function:将message function产生的信息求和,并且乘以目标节点的度正则。
def gcn_msg(edge):
    # 在边上的源节点上,乘以其度的正则
    msg = edge.src['h'] * edge.src['norm']
    return {'m': msg} 
def gcn_reduce(node):
    # 将目标节点的边上的信息聚合(这里是sum),再乘以目标节点上的度的正则
    # 这里的torch.sum(, dim=1)在维度1上相加,是因为node.mailbox['m']的shape = [batch, mails, feat]
    # 需要将mails整个加起来
    accum = torch.sum(node.mailbox['m'], 1) * node.data['norm'] 
    return {'h': accum} # 这时候,节点就存在一个数据node.data['h'] = accum
class NodeApplyModule(nn.Module):
    def __init__(self, out_feats, activation=None, bias=True):
        super(NodeApplyModule, self).__init__()
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_feats))
        else:
            self.bias = None
        self.activation = activation
        self.reset_parameters()

    def reset_parameters(self):
        if self.bias is not None:
            stdv = 1. / math.sqrt(self.bias.size(0))
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, nodes): # 要更新时,添加上偏置和激活函数
        h = nodes.data['h']
        if self.bias is not None:
            h = h + self.bias
        if self.activation:
            h = self.activation(h)
        return {'h': h} # 此时nodes['h'] = h 这是就被更新了
class GCNLayer(nn.Module):
    def __init__(self,
                 g,
                 in_feats,
                 out_feats,
                 activation,
                 dropout,
                 bias=True):
        super(GCNLayer, self).__init__()
        self.g = g
        self.weight = nn.Parameter(torch.Tensor(in_feats, out_feats))
        if dropout:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = 0.
        self.node_update = NodeApplyModule(out_feats, activation, bias)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, h):
        if self.dropout:
            h = self.dropout(h)
        self.g.ndata['h'] = torch.mm(h, self.weight) # 这里是首先将节点特征进行映射,也就是X*O
        self.g.update_all(gcn_msg, gcn_reduce, self.node_update) # 然后求message,聚合,更新。
        h = self.g.ndata.pop('h')
        return h
# add self loop
g = dgl.remove_self_loop(g)
g = dgl.add_self_loop(g)
n_edges = g.number_of_edges()

# normalization
degs = g.in_degrees().float()
norm = torch.pow(degs, -0.5)
norm[torch.isinf(norm)] = 0
g.ndata['norm'] = norm.unsqueeze(1) # 传递到GCN的图已经存在了'norm'数据