【Embedding】GraphSAGE:不得不學的圖網絡

  • 2020 年 10 月 29 日
  • AI

GraphSAGE 是 17 年的文章了,但是一直在工業界受到重視,最主要的就是它論文名字中的兩個關鍵詞:inductive 和 large graph。今天我們就梳理一下這篇文章的核心思路,和一些容易被忽視的細節。

為什麼要用 GraphSAGE

大家先想想圖為什麼這麼火,主要有這麼幾點原因,圖的數據來源豐富,圖包含的信息多。所以現在都在考慮如何更好的使用圖的信息。

那麼我們用圖需要做到什麼呢?最核心的就是利用圖的結構信息,為每個 node 學到一個合適的 embedding vector。只要有了合適的 embedding 的結果,接下來無論做什麼工作,我們就可以直接拿去套模型了。

在 GraphSAGE 之前,主要的方法有 DeepWalk,GCN 這些,但是不足在於需要對全圖進行學習。而且是以 transductive learning 為主,也就是說需要在訓練的時候,圖就已經包含了要預測的節點。

考慮到實際應用中,圖的結構會頻繁變化,在最終的預測階段,可能會往圖中新添加一些節點。那麼該怎麼辦呢?GraphSAGE 就是為此而提出的,它的核心思路其實就是它的名字 GraphSAGE = Graph Sample Aggregate。也就是說對圖進行 sample 和 aggregate。

GraphSAGE 的思路

我們提到了 sample 和 aggregate,具體指的是什麼呢?這個步驟如何進行?為什麼它可以應用到大規模的圖上?接下來就為大家用通俗易懂的語言描述清楚。

顧名思義,sample 就是選一些點出來,aggregate 就是再把它們的信息聚合起來。那麼整個流程怎麼走?看下面這張圖:

我們在第一幅圖上先學習 sample 的過程。假如我有一張這樣的圖,需要對最中心的節點進行 mebedding 的更新,先從它的鄰居中選擇 S1 個(這裡的例子中是選擇 3 個)節點,假如 K=2,那麼我們對第 2 層再進行採樣,也就是對剛才選擇的 S1 個鄰居再選擇它們的鄰居。

在第二幅圖上,我們就可以看到對於聚合的操作,也就是說先拿鄰居的鄰居來更新鄰居的信息,再用更新後的鄰居的信息來更新目標節點(也就是中間的紅色點)的信息。聽起來可能稍微有點啰嗦,但是思路上並不繞,大家仔細梳理一下就明白了。

第三幅圖中,如果我們要預測一個未知節點的信息,只需要用它的鄰居們來進行預測就可以了。

我們再梳理一下這個思路:如果我想知道小明是一個什麼性格的人,我去找幾個他關係好的小夥伴觀察一下,然後我為了進一步確認,我再去選擇他的小夥伴們的其他小夥伴,再觀察一下。也就是說,通過小明的小夥伴們的小夥伴,來判斷小明的小夥伴們是哪一類人,然後再根據他的小夥伴們,我就可以粗略的得知,小明是哪一類性格的人了。

GraphSAGE 思路補充

現在我們知道了 GraphSAGE 的基本思路,可能小夥伴們還有一些困惑:單個節點的思路是這樣子,那麼整體的訓練過程該怎麼進行呢?至今也沒有告訴我們 GraphSAGE 為什麼可以應用在大規模的圖上,為什麼是 inductive 的呢?

接下來我們就補充一下 GraphSAGE 的訓練過程,以及在這個過程它有哪些優勢。

首先是考慮到我們要從初始特徵開始,一層一層的做 embedding 的更新,我們該如何知道自己需要對哪些點進行聚合呢?應用前面提到的 sample 的思路,具體的方法來看一看算法

首先看算法的第 2-7 行,其實就是一個 sample 的過程,並且將 sample 的結果保存到 B 中。接下來的 9-15 行,就是一個 aggregate 的過程,按照前面 sample 的結果,將對應的鄰居信息 aggregate 到目標節點上來。

細心的小夥伴肯定發現了 sample 的過程是從 K 到 1 的(看第 2 行),而 aggregate 的過程是從 1 到 K 的(第 9 行)。這個道理很明顯,採樣的時候,我們先從整張圖選擇自己要給哪些節點 embedding,然後對這些節點的鄰居進行採樣,並且逐漸採樣到遠一點的鄰居上。

但是在聚合時,肯定先從最遠處的鄰居上開始進行聚合,最後第 K 層的時候,才能聚合到目標節點上來。這就是 GraphSAGE 的完整思路。

那麼需要思考一下的是,這麼簡單的思路其中有哪些奧妙呢?

GraphSAGE 的精妙之處

首先是為什麼要提出 GraphSAGE 呢?其實最主要的是 inductive learning 這一點。這兩天在幾個討論群同時看到有同學對 transductive learning 和 inductive learning 有一些討論,總體來說,inductive learning 無疑是可以在測試時,對新加入的內容進行推理的。

因此,GraphSAGE 的一大優點就是,訓練好了以後,可以對新加入圖網絡中的節點也進行推理,這在實際場景的應用中是非常重要的。

另一方面,在圖網絡的運用中,往往是數據集都非常大,因此 mini batch 的能力就非常重要了。但是正因為 GraphSAGE 的思路,我們只需要對自己採樣的數據進行聚合,無需考慮其它節點。每個 batch 可以是一批 sample 結果的組合。

再考慮一下聚合函數的部分,這裡訓練的結果中,聚合函數占很大的重要性。關於聚合函數的選擇有兩個條件:

  • 首先要可導,因為要反向傳遞來訓練目標的聚合函數參數;
  • 其次是對稱,這裡的對稱指的是對輸入不敏感,因為我們在聚合的時候,圖中的節點關係並沒有順序上的特徵。

所以在作者原文中選擇的都是諸如 Mean,max pooling 之類的聚合器,雖然作者也使用了 LSTM,但是在輸入前會將節點進行 shuffle 操作,也就是說 LSTM 從序列順序中並不能學到什麼知識。

此外在論文中還有一個小細節,我初次看的時候沒有細讀論文,被一位朋友指出後才發現果然如此,先貼一下原文:

這裡的 lines 4 and 5 in Algorithm 1,也就是我們前面給出的算法中的第 11 和 12 行。

也就是說,作者在文中提到的 GraphSAGE-GCN 其實就是用上面這個聚合函數,替代掉其它方法中先聚合,再 concat 的操作,並且作者指出這種方法是局部譜卷積的線性近似,因此將其稱為 GCN 聚合器。

來點善後工作

最後我們就簡單的補充一些喜聞樂見,且比較簡單的東西吧。用 GraphSAGE 一般用來做什麼?

首先作者提出,它既可以用來做無監督學習,也可以用來做有監督學習,有監督學習我們就可以直接使用最終預測的損失函數為目標,反向傳播來訓練。那麼無監督學習呢?

其實無論是哪種用途,需要注意的是圖本身,我們還是主要用它來完成 embedding 的操作。也就是得到一個節點的 embedding 後比較有效的 feature vector。那麼做無監督時,如何知道它的 embedding 結果是對是錯呢?

作者選擇了一個很容易理解的思路,就是鄰居的關係。默認當兩個節點距離相近時,就會讓它們的 embedding 結果比較相似,如果距離遠,那 embedding 的結果自然應該區別較大。這樣一來下面的損失函數就很容易理解了:

z_v 表示是目標節點 u 的鄰居,而 v_n 則表示不是,P_n(v) 是負樣本的分佈,Q 是負樣本的數量。

那麼現在剩下唯一的問題就是鄰居怎麼定義?

作者選擇了一個很簡單的思路:直接使用 DeepWalk 進行隨機遊走,步長為 5,測試 50 次,走得到的都是鄰居。

總結

實驗結果我們就不展示了,其實可以看到作者在很多地方都用了一些比較 baseline 的思路,大家可以在對應的地方進行更換和調整,以適應自己的業務需求。

後面我們也會繼續分享 GNN 和 embedding 方面比較經典和啟發性的一些 paper,歡迎大家持續關注~~~

PS:看到這裡的都是真愛了,悄悄說一句謝謝大家的信任,我這個懶人都這麼久沒更了,每天還有新小夥伴們持續關注,實在是太感動了,哭花臉o(╥﹏╥)o

點個再看嘛,不點就點個贊嘛,實在不點那我也只能算了嘛,再哭花臉(╥╯^╰╥)

一文搞懂 PyTorch 內部機制

一篇長文學懂 pytorch

一個例子告訴你,在 pytorch 中應該如何並行生成數據