gat和transformer

  • 2021 年 6 月 18 日
  • AI

理解了transformer之后理解gat非常简单,因为基本上非常类似的。之前有篇很火的文章介绍了transformer是一种特殊的gnn:

//towardsdatascience.com/transformers-are-graph-neural-networks-bca9f75412aatowardsdatascience.com

,一起都是从这里开始的。

直接对比来看吧:(下文的transformer特指transformer encoder)

根据数据的流动来看看gat 和transformer到底啥关系:

1、模型输入:transformer:

token embedding+position embedding,假设某个句子拆分为n个token,则该句子输入transformer的时候转化为的 (1,n,features)的形式,1表示batch size,n表示句子的token数量,features的size是embedding 的维度。

GCN和GAT :

为了便于比较这里就用单层的GCN作为演示,因为GCN的聚合计算是直接求和,包括自身,所以上图可以转化一下使其和transformer变得更像一点:

这里不要去考虑公式做拉普拉斯矩阵变换之类的,纯粹看数据流动就行。

GAT的输入部分和GCN是一样的。

可以看到,就初始数据部分来说,以上图的节点C为例,GCN或GAT的输入就是(1,n,features),1是batch size,n是节点c的邻节点数量(包括c自己),features就是node features。

因此,从输入数据的形式来看,GCN,GAT和TFM对应起来了。

2、self-attention处理输入:

transformer:

现在我们有(1,n,features)的输入,输入到multi head attention层,

然后做这个事情,

对应这部分:

multi head attention吃了一个 (1,n,features)的样本,然后吐出了 (1,n,hidden size)的表征向量。

为了更好理解还是举个例子,假设输入的是5个tokenx1,x2,x3,x4,x5,是[[1,1,1,1,1],[2,2,2,2,2],[3,3,3,3,3],[4,4,4,4,4],[5,5,5,5,5]],q,v,k转化之后,两两做attention score计算得到attention score矩阵

x1,x2,x3,x4,x5

x1 [1, 1, 1, 1, 1],

x2 [1, 1, 1, 1, 1],

x3 [1, 1, 1, 1, 1],

x4 [1, 1, 1, 1, 1],

x5 [1, 1, 1, 1, 1],

一个5X5的方阵。然后加权求和value完事儿(上述是一个单头的计算过程),hidden size的大小取决于q v k 矩阵的size和head的数量,调用transformer的keras或者torch的实现的时候直接指定就可以了,内部会自动根据头的数量来准备q v k矩阵的size。

ok,看下gat的部分:

在这里,gat其实没用到self attention,而是另一种自己定义的attention计算:

1、

transformer里用了q v k三个线性变换矩阵,gat里就用了一个W;

2、

从这里开始和transformer完全不一样了,这里经过W变换之后,两两不是像transformer一样做点积,而是直接concat然后乘以一个随机初始化的linear线性变换层,经过leakyrelu之后得到一个标量eij,最后softmax得到 alpha ij 作为node i和node j之间的attention score,

最后加权求和,得到 节点i的最终的 表征向量,

公式图片来源:

机器之心:深入理解图注意力机制zhuanlan.zhihu.com图标

多头注意力机制基本差不多,看上面链接里的代码就好了~非常好理解。

3、最终输出:multi head attention吃了一个 (1,n,features)的样本,然后吐出了 (1,n,hidden size)的表征向量,gat则是吃了一个 (1,n,features)的样本,然后吐出了 (1,1,hidden size)的表征向量


没了,

机器之心:深入理解图注意力机制zhuanlan.zhihu.com图标

感觉看这篇再扫下原论文就够了,总的来说,gat就是把聚合函数换成了一种简化版的self-attention计算,算是构建了GNN和attention机制,这两个热门的领域之间的桥梁,但是论文和代码上,

Python/dgl/nn/pytorch/conv/gatconv.py” data-draft-node=”block” data-draft-type=”link-card” class=”LinkCard old LinkCard–noImage”>//github.com/dmlc/dgl/blob/master/python/dgl/nn/pytorch/conv/gatconv.pygithub.com

都有一些遗漏,就是有权图和有向图的情况,不过也很好解决,有权图简单,attention score分别和edge weights做乘法计算就行没啥问题,有向图也简单,转化为有权只不过u-v和v-u的edge weights不同而已,通过引入edge weights的信息就可以了,单向的话有一个edge的weight=0,这样就能比较简单的将gat推广到所有的有向无向有权无权的问题中去了。