gat和transformer
- 2021 年 6 月 18 日
- AI
理解了transformer之後理解gat非常簡單,因為基本上非常類似的。之前有篇很火的文章介紹了transformer是一種特殊的gnn:
//towardsdatascience.com/transformers-are-graph-neural-networks-bca9f75412aa
,一起都是從這裡開始的。
直接對比來看吧:(下文的transformer特指transformer encoder)
根據數據的流動來看看gat 和transformer到底啥關係:
1、模型輸入:transformer:
token embedding+position embedding,假設某個句子拆分為n個token,則該句子輸入transformer的時候轉化為的 (1,n,features)的形式,1表示batch size,n表示句子的token數量,features的size是embedding 的維度。
GCN和GAT :
為了便於比較這裡就用單層的GCN作為演示,因為GCN的聚合計算是直接求和,包括自身,所以上圖可以轉化一下使其和transformer變得更像一點:
這裡不要去考慮公式做拉普拉斯矩陣變換之類的,純粹看數據流動就行。
GAT的輸入部分和GCN是一樣的。
可以看到,就初始數據部分來說,以上圖的節點C為例,GCN或GAT的輸入就是(1,n,features),1是batch size,n是節點c的鄰節點數量(包括c自己),features就是node features。
因此,從輸入數據的形式來看,GCN,GAT和TFM對應起來了。
2、self-attention處理輸入:
transformer:
現在我們有(1,n,features)的輸入,輸入到multi head attention層,
然後做這個事情,
對應這部分:
multi head attention吃了一個 (1,n,features)的樣本,然後吐出了 (1,n,hidden size)的表徵向量。
為了更好理解還是舉個例子,假設輸入的是5個tokenx1,x2,x3,x4,x5,是[[1,1,1,1,1],[2,2,2,2,2],[3,3,3,3,3],[4,4,4,4,4],[5,5,5,5,5]],q,v,k轉化之後,兩兩做attention score計算得到attention score矩陣
x1,x2,x3,x4,x5
x1 [1, 1, 1, 1, 1],
x2 [1, 1, 1, 1, 1],
x3 [1, 1, 1, 1, 1],
x4 [1, 1, 1, 1, 1],
x5 [1, 1, 1, 1, 1],
一個5X5的方陣。然後加權求和value完事兒(上述是一個單頭的計算過程),hidden size的大小取決於q v k 矩陣的size和head的數量,調用transformer的keras或者torch的實現的時候直接指定就可以了,內部會自動根據頭的數量來準備q v k矩陣的size。
ok,看下gat的部分:
在這裡,gat其實沒用到self attention,而是另一種自己定義的attention計算:
1、
transformer里用了q v k三個線性變換矩陣,gat里就用了一個W;
2、
從這裡開始和transformer完全不一樣了,這裡經過W變換之後,兩兩不是像transformer一樣做點積,而是直接concat然後乘以一個隨機初始化的linear線性變換層,經過leakyrelu之後得到一個標量eij,最後softmax得到 alpha ij 作為node i和node j之間的attention score,
最後加權求和,得到 節點i的最終的 表徵向量,
公式圖片來源:
多頭注意力機制基本差不多,看上面鏈接里的程式碼就好了~非常好理解。
3、最終輸出:multi head attention吃了一個 (1,n,features)的樣本,然後吐出了 (1,n,hidden size)的表徵向量,gat則是吃了一個 (1,n,features)的樣本,然後吐出了 (1,1,hidden size)的表徵向量
沒了,
感覺看這篇再掃下原論文就夠了,總的來說,gat就是把聚合函數換成了一種簡化版的self-attention計算,算是構建了GNN和attention機制,這兩個熱門的領域之間的橋樑,但是論文和程式碼上,
都有一些遺漏,就是有權圖和有向圖的情況,不過也很好解決,有權圖簡單,attention score分別和edge weights做乘法計算就行沒啥問題,有向圖也簡單,轉化為有權只不過u-v和v-u的edge weights不同而已,通過引入edge weights的資訊就可以了,單向的話有一個edge的weight=0,這樣就能比較簡單的將gat推廣到所有的有向無向有權無權的問題中去了。