gat和transformer
- 2021 年 6 月 18 日
- AI
理解了transformer之后理解gat非常简单,因为基本上非常类似的。之前有篇很火的文章介绍了transformer是一种特殊的gnn:
//towardsdatascience.com/transformers-are-graph-neural-networks-bca9f75412aa
,一起都是从这里开始的。
直接对比来看吧:(下文的transformer特指transformer encoder)
根据数据的流动来看看gat 和transformer到底啥关系:
1、模型输入:transformer:
token embedding+position embedding,假设某个句子拆分为n个token,则该句子输入transformer的时候转化为的 (1,n,features)的形式,1表示batch size,n表示句子的token数量,features的size是embedding 的维度。
GCN和GAT :
为了便于比较这里就用单层的GCN作为演示,因为GCN的聚合计算是直接求和,包括自身,所以上图可以转化一下使其和transformer变得更像一点:
这里不要去考虑公式做拉普拉斯矩阵变换之类的,纯粹看数据流动就行。
GAT的输入部分和GCN是一样的。
可以看到,就初始数据部分来说,以上图的节点C为例,GCN或GAT的输入就是(1,n,features),1是batch size,n是节点c的邻节点数量(包括c自己),features就是node features。
因此,从输入数据的形式来看,GCN,GAT和TFM对应起来了。
2、self-attention处理输入:
transformer:
现在我们有(1,n,features)的输入,输入到multi head attention层,
然后做这个事情,
对应这部分:
multi head attention吃了一个 (1,n,features)的样本,然后吐出了 (1,n,hidden size)的表征向量。
为了更好理解还是举个例子,假设输入的是5个tokenx1,x2,x3,x4,x5,是[[1,1,1,1,1],[2,2,2,2,2],[3,3,3,3,3],[4,4,4,4,4],[5,5,5,5,5]],q,v,k转化之后,两两做attention score计算得到attention score矩阵
x1,x2,x3,x4,x5
x1 [1, 1, 1, 1, 1],
x2 [1, 1, 1, 1, 1],
x3 [1, 1, 1, 1, 1],
x4 [1, 1, 1, 1, 1],
x5 [1, 1, 1, 1, 1],
一个5X5的方阵。然后加权求和value完事儿(上述是一个单头的计算过程),hidden size的大小取决于q v k 矩阵的size和head的数量,调用transformer的keras或者torch的实现的时候直接指定就可以了,内部会自动根据头的数量来准备q v k矩阵的size。
ok,看下gat的部分:
在这里,gat其实没用到self attention,而是另一种自己定义的attention计算:
1、
transformer里用了q v k三个线性变换矩阵,gat里就用了一个W;
2、
从这里开始和transformer完全不一样了,这里经过W变换之后,两两不是像transformer一样做点积,而是直接concat然后乘以一个随机初始化的linear线性变换层,经过leakyrelu之后得到一个标量eij,最后softmax得到 alpha ij 作为node i和node j之间的attention score,
最后加权求和,得到 节点i的最终的 表征向量,
公式图片来源:
多头注意力机制基本差不多,看上面链接里的代码就好了~非常好理解。
3、最终输出:multi head attention吃了一个 (1,n,features)的样本,然后吐出了 (1,n,hidden size)的表征向量,gat则是吃了一个 (1,n,features)的样本,然后吐出了 (1,1,hidden size)的表征向量
没了,
感觉看这篇再扫下原论文就够了,总的来说,gat就是把聚合函数换成了一种简化版的self-attention计算,算是构建了GNN和attention机制,这两个热门的领域之间的桥梁,但是论文和代码上,
都有一些遗漏,就是有权图和有向图的情况,不过也很好解决,有权图简单,attention score分别和edge weights做乘法计算就行没啥问题,有向图也简单,转化为有权只不过u-v和v-u的edge weights不同而已,通过引入edge weights的信息就可以了,单向的话有一个edge的weight=0,这样就能比较简单的将gat推广到所有的有向无向有权无权的问题中去了。