TDM 三部曲 (與 Deep Retrieval)

推薦系統的主要目的是從海量物品庫中高效檢索用戶最感興趣的物品,既然是「海量」,意味著用戶基本不可能瀏覽完所有的物品,所以才需要推薦系統來輔助用戶高效獲取感興趣的資訊。同樣也正是因為「海量」,由於算力的限制,複雜模型也是很難直接遍歷每個物品算出分數來排序。如今的推薦系統通常大致分為召回 (retrieval) 和排序 (ranking) 兩個階段,召回是從全量物品庫中快速得到一個候選集合,通常是幾百到幾千,後面的排序模組則使用更複雜的模型對候選集排序得到 top-k 物品推薦給用戶。

召回需要在速度和準確性上作平衡,其結果很大程度上決定了推薦的上限。如果其返回的候選集中沒有包含用戶感興趣的物品,那麼後面的排序模型能力再強也沒用。但是受速度的限制,長期以來的主流做法是使用簡單模型如物品協同過濾,或者獲取 embedding 後轉換成向量最近鄰搜索問題。這種方案在模型表達能力上有一定的局限,而且近鄰搜索與實際的目標 (如提升點擊率) 在優化方向上不一定一致。想要在召回中直接使用複雜模型特別是近幾年湧現出來的各種深度學習模型作推理,在這個領域很長一段時間來都沒什麼大的進展。

不過局面終有一天會被打破,本篇介紹的這些近幾年公開的演算法皆是致力於探索在大規模召回問題中直接使用複雜模型。標題中的 TDM 三部曲指的是以 TDM 為首的三篇圍繞樹結構的論文:

Deep Retrieval 指的是論文:

這裡需要指出的是雖然從論文發表時間來看是 TDM -> JTM -> OTM,但 OTM 嚴格意義上不能算是 JTM 的改進版本。因為 TDM 的訓練大致分為兩步:樹的學習和模型的學習,JTM 改進的是前者,而 OTM 改進的是後者,因而 JTM 和 OTM 看上去更像是同父異母的姐妹。

現在回到最開始的問題,是什麼制約了召回中複雜模型的使用?複雜模型不可避免地使得線上單個樣本的計算時間增大,那麼遍歷全量物品庫顯然不可承受。如果把召回看作是一個檢索的過程,即從全量庫中檢索符合條件的物品,那麼可以產生一些新的思路。我們知道在傳統資料庫中可以通過添加索引來極大增加查詢效率,那麼在召回中是否也可以遷移這種思想?TDM 和 Deep Retrieval 論文的核心就是立足於如何構建這樣一套高效的索引結構來增加檢索效率,從而使得在召回中直接使用複雜模型成為可能。

下面逐一說明這些演算法的內部原理,另外從應用的角度也會講一些實現細節,完整程式碼見//github.com/massquantity/dismember 。TDM 有官方的開源實現,而我的 TDM 實現在原版基礎上未做過多修改,相當於將原版的 Python2 和 C++ 程式碼用 Scala 重寫了一遍。是的,我詫異地發現原版用的貌似是 Python2 。而 JTM、OTM 和 Deep Retrieval 甚至沒找到什麼開源的實現 (不排除以後有),因此也就自由發揮了。

TDM


如上文所述,TDM 通過樹這種數據結構來構建索引。照原論文里的說法是可以使用多叉樹,但無論是論文還是官方實現使用的都是二叉樹,所以這裡僅討論二叉樹的情況。考慮下圖的這棵樹,每一個物品對應著樹上的一個葉節點 (圖中 id 為 7 到14),我們的目標是得到用戶可能偏好最大的 K 個物品,即最底層的 K 個葉節點。用戶的偏好可用 \(p(n|u)\) 來表示,意為用戶 \(u\) 對於節點 \(n\) 感興趣的概率。這實際上就是個二分類問題,將用戶特徵和節點 id 輸入某個模型如深度神經網路就可以得到這個概率,按概率排序後進而得到 top-K 個物品。現在的問題是如何高效地得到這些葉節點的概率?

雖然圖中最底層只有 8 個葉節點,但實際場景中可能會有百萬到上億個物品,所以直接用複雜模型遍歷葉節點檢索是不大可行的。論文中採用的是 beam search 的方法從根節點 (root node) 開始逐層挑選 top-K 節點,而挑選的依據正是用戶對每個節點的偏好 \(p(n|u)\),然後將這些 top-K 節點的子節點作為下一層的候選節點,一直到最後一層。二叉樹有一個很好的性質,如果一個節點 id 是 \(n\) ,那麼其葉子結點是 \(2 * n + 1\)\(2 * n + 2\) ,一次遍歷就能取得當前層節點的所有葉子結點。這相當於每個選中的節點有兩個候選子節點,那麼整體的計算次數是 \(2 * k * \text{log}|C|\) ,其中 \(|C|\) 是所有物品集的數量。若物品總量為 1 億,\(k\) 為 10,那麼推理一次需要計算 \(2 * 10 * \text{log}_2(100000000) \approx 532\) 次,對比原來的 1 億次下降了不止一個數量級,時間複雜度從 \(\mathcal{O}(n)\) 下降到 \(\mathcal{O}(\text{log}(n))\),這樣複雜模型就可以使用了。

不過上述流程很容易會給人帶來一個疑惑,這樣層層檢索下來,如何保證最終得到的葉節點一定是 \(p(n|u)\) 最大的 K 個?為了解答這個問題論文里引入了一個興趣最大堆樹的概念,直接衍生自傳統數據結構中的堆 (heap):

\[p^{(j)}(n|u) = \frac{\max\limits_{n_c \,\in\, \{n\text{‘s children nodes in level } j\text{+1}\}} \;p^{(j+1)}(n_c|u)}{\alpha^{(j)}} \tag{1.1}
\]

這個公式意為每個節點的 \(p(n|u)\) 都等於其所有子節點 \(p(n_c|u)\) 的最大值,\(j\) 代表樹的第 \(j\) 層,\(\alpha^{(j)}\) 是一個歸一化參數可以被忽略。由於 beam search 在每一層都會搜索到 \(p(n|u)\) 最大的 K 個節點,那麼滿足了這個性質之後,這 K 個節點的子節點也一定包含了下一層的 top-K,這樣一直檢索到最後一層就能得到 top-K 的葉節點了。可以看到樹的結構充當著索引的作用,讓檢索過程中能直接跳過眾多不相干的節點。

以上更多地是 TDM 訓練完後的推理 (檢索) 過程,那麼接下來的問題是 TDM 的訓練。TDM 整個體系大致分為兩個部分 —— 模型和樹,那麼訓練也是分別要學習這兩樣東西。這裡的模型作用是計算用戶 \(u\) 對於某個節點 \(n\) 的偏好 \(p(n|u)\) ,如前文所述(幾乎)可以是任意複雜模型,所以論文里果然上了一個帶有時序特徵和 attention 的複雜深度學習模型,具體模型結構這裡就不細述了,因為不是 TDM 的重點,僅談如何在樹結構確定的情況下訓練這個模型。

通常模型的訓練離不開樣本,所以這裡的關鍵是先要構造樣本,而後用 TensorFlow 之類的框架搭個神經網路訓練就比較常規了。將用戶交互過的物品 (葉節點) 設為正樣本,由於樹結構已預先確定,那麼每個葉節點的祖先節點也就確定了,而根據上面的最大堆性質,正樣本葉節點的祖先節點也同樣為正樣本,同時在各層隨機取樣一些除正樣本以外的節點作為負樣本。對於一棵二叉樹,根節點 id 為 0,於是參考上圖第 \(j\) 層的節點 id 範圍為左閉右開的 \([2^j – 1, \,2^{j+1}-1)\) ,那麼每一層在這個範圍內取樣就可以了。

模型訓練完後,接下來看樹結構的學習。所謂的樹結構,說地直白一點就是物品 id 和樹葉節點 id 的一一對應關係。對於二叉樹節點 \(n\) ,其父節點是 \((n – 1) / 2\) ,那麼依次上溯就可得到節點 \(n\) 所有的祖先節點。所以這裡我們只需要關心物品和葉節點的對應關係,這層關係確定後上面的祖先節點也都確定了。關於樹結構學習的具體方法,論文里說的是將所有物品的 embedding 向量遞歸地使用 k-means 聚類來確定最終的葉節點分布,而物品的 embedding 則來自於前面模型的 embedding 層。不過說實話光看論文里的描述很難搞得明白 TDM 的這棵樹究竟是怎麼聚類得到的,所以下面我們來看具體實現。

不得不說論文里只是給了個模糊的框架,而真正寫程式碼的時候又牽扯出了諸多細節,主要體現在樹的構建和操作上。首先來看上文中的樹聚類學習,我覺得這個過程可以這麼解釋:目標是將所有物品分配到各個葉節點,但是直接分配不可行,所以先將所有物品分配到根節點 0,然後通過聚類將所有物品分成兩類,分別分配給根節點的子節點,也就是一半的物品分配到節點 1,另一半分配到節點 2 。然後再對屬於節點 1 和 2 的物品分別聚類,繼續將物品往下分配到各自的子節點,這是一個遞歸的過程,最後在葉節點只分配到一個物品的時候終止。為了保證樹的平衡性,每次聚類的時候都要進行再平衡,即保證聚類出來兩個子類的物品數量一樣,具體方法是計算每個物品到聚類質心的距離,距離最遠的幾個會被調整到另外一個子類。如果你熟悉後面的 JTM,會發現這整個層層分配 + 再平衡的操作和 JTM 的步驟如出一轍,不同之處在於 TDM 中分成兩類的依據是聚類,而 JTM 中是目標函數,這一點後文再述。

基於以上流程,在各個節點上聚類是可以並行的,原版實現用的是 Python 多進程,通過隊列(Queue)和管道(Pipe)進行進程間通訊。不過對於這樣一種將一個大任務遞歸地層層拆分成多個子任務的並行計算,Java 7 中增加的 ForkJoinPool 看來是更適合的選擇,而在 Scala 實現中則可以直接使用 Java 的類庫。另外值得一提的是論文中提到聚類的原始方案是譜聚類 (spectral clustering),但因為計算複雜度太高所以改成了 K-means ,在我的數據集上試驗下來譜聚類的效果確實比 K-means 好一點,當然耗時也長得多。

接下來考慮一下在樹上檢索的流程。因為論文里討論的是一般情況,所以給的演算法流程里是從根節點開始檢索,但實際上並不需要。假設要獲取 top-3 的物品,每一層 beam search 的候選節點數是 6,那麼完全可以跳過前幾層而直接從 level 3 開始檢索,因為上方的比如 level 2 只包含 4 個節點,beam search 的時候肯定會全包括進去而不需要按偏好排序。召回中一般需要取幾百到幾千的物品,這樣可以跳過開始的很多層,從而節省計算資源和加速推理。同樣程式碼里有一個參數 start_sample_level ,表示開始進行負取樣的層,如果推理的時候前幾層的節點不需要包括,那麼這幾層也同樣不需要取樣和訓練了。

最後再來看一個細節,論文在描述的時候給出了一棵樹的圖例:

這是一顆滿二叉樹,即每一層的節點數都達到了最大值。然而我相信大部分第一次看論文的人都不會注意到的一點是,物品的數量不會正好是 2 的 n 次冪,那麼就不會覆蓋完樹的最後一層,也就是極有可能會出現下面這種情況:

這是一顆完全二叉樹,也滿足論文里的描述,然而這種葉節點不是在同一層的樹在實現上並不是很友好,比如在向上取樣時沒法規定一個統一的起始取樣層,以及檢索的時候如果指定了最大高度則容易跳過倒數第二層的葉節點。那麼原版實現是怎麼處理這個問題的呢?就是強行把所有葉節點都拉平到最下面同一層,見程式碼 ,對應到我的程式碼

JTM


前文提到在 TDM 中樹的學習採用的是一種層次化聚類的方式,並沒什麼理論依據,論文里也說這只是一個直覺性 (intuitively) 的方法。這種層次化聚類方法的問題”直覺上”與常用的向量最近鄰方法類似,即模型和最近鄰搜索的優化方向不一致。前者優化的是用戶 \(u\) 對節點的偏好 \(p(n|u)\),而後者優化的是向量相似度。所以從這個角度上來說 TDM 中的層層聚類也是在向量相似度上作文章。

顯然作者認為這樣拍腦袋出來的方法是不大合理的,所以才有了 JTM 的出現。其核心思路還是比較直接的,就是讓模型和樹優化同一個目標。模型優化部分和 TDM 中的差不多,變化的僅是樹的學習這一部分。

對比 JTM 論文中的這張圖與前面 TDM 中的樹的圖,最顯著的不同是右圖最下方出現了一個物品和葉節點的映射函數 \(\pi(\cdot)\) 。前面講 TDM 的時候提到過樹結構取決於物品 id 和葉節點 id 的一一對應關係,這一點實際上是在 JTM 論文中被明確提出來的。有了這個之後統一優化目標為:

\[\mathcal{L}(\theta, \pi) = -\sum\limits_{i=1}^n\sum\limits_{j=0}^{l_{max}}\,\text{log}\,\hat{p}(b_j(\pi(c_i))|u_i;\theta,\pi) \tag{2.1}
\]

對於第 \(i\) 個正樣本 \((u_i, c_i)\)\(u_i\) 為用戶,\(c_i\) 為其感興趣的物品,那麼 \(c_i\) 通過 \(\pi(\cdot)\) 映射到某一個葉節點即 \(\pi(c_i) = n_i\)\(b_j(\cdot)\) 為某一節點到樹的第 \(j\) 層祖先節點的映射。於是上式的意思是最大化正樣本中用戶與節點偏好的概率,而這裡的節點包括物品對應的葉節點和相應的祖先節點,由於損失函數一般為最小化,所以上面採用的是 \(-\text{log}(\cdot)\)

\((2.1)\) 式代表的是所有物品的目標函數,而對於單個物品 \(c_i\) 來說並不需要囊括所有的樣本,而只需要計算其為目標物品 (target item) 的樣本。於是設 \(\mathcal{A_i}\) 為目標物品是 \(c_i\) 的所有樣本,則 \(c_i\) 的目標函數為:

\[\mathcal{L}_{c_i, \pi(c_i)} = -\sum\limits_{(u,c) \in \mathcal{A}_i}\sum\limits_{j=0}^{l_{max}}\,\text{log}\,\hat{p}(b_j(\pi(c_i))|u;\theta,\pi) \tag{2.2}
\]

TDM 中訓練模型來優化 \((2.1)\) 式的方法是葉節點上溯得到所有祖先節點作為正樣本,同時在每一層隨機取樣另外的節點作為負樣本。而在 JTM中這部分可變可不變,放到後面講實現的時候再說明。這裡先假設解決了模型優化的問題後,樹如何學習來同樣滿足 \((2.1)\) 式就成為了 JTM 的核心。

JTM 中樹的學習簡單來說就是貪心 + 試錯法,上面的映射函數 \(\pi(\cdot)\) 在具體的實現中就是一個 map (或者是 Python 里的 dict),將物品 \(c\) 映射到葉節點 \(n\) 。那麼所謂的試錯法就是把一個物品映射到每一個可能的節點,分別計算 \((2.2)\) 式,最後將物品映射到值最大的那個節點。然而直接使用這種方法過於簡單粗暴,計算量非常大。注意 \((2.1)\) 式中的兩個加和涉及到所有的樣本和所有的層,假設有 1 百萬樣本,1 萬種物品,那麼可能的葉節點位置也為 1 萬,而樹的層數為 \(\text{log}_2(10000) \approx 14\) 。對於所有物品 ,就需要計算 \(1000000 \times 14 \times 10000 = 1.4 \times 10^{11}\) 次才能得到最佳的 \(\pi(\cdot)\)。一般推薦系統里肯定遠遠不止 1 百萬樣本和 1 萬物品,所以總體計算量會快速增長到不可承受。

可以看到上面對於 JTM 的計算可分為三個部分 —— 總樣本數、樹的層數以及候選節點數。論文中提出的貪心法主要是通過減少後兩者來降低整體計算量。先將所有物品都映射到樹的根節點即 \(\pi(c_i) = 0\) ,再每隔 \(d\) 層將物品分配到對應層的子節點,一直到最後一層每一個物品都分配到一個葉節點。下式代表從 \(s\) 層到 \(d\) 層的目標函數:

\[\mathcal{L}_{c_i, \pi(c_i)}^{s,d} = -\sum\limits_{(u,c) \in \mathcal{A}_i}\sum\limits_{j=s}^d\,\text{log}\,\hat{p}(b_j(\pi(c_i))|u;\theta,\pi) \tag{2.3}
\]

我們來看下這個方法是怎麼降低計算量的。依然以上面的例子,原來的方案需要計算所有層 14,加上所有的葉節點位置 1 萬。貪心法需要計算 \(d\) 層,\(d\) 是一個超參數,理論上 \(d\) 越大越精確,但相應的計算量也越大,當 \(d = 14\) 時就和原來的一樣了。而利用二叉樹本身的特點,一個節點往下 \(d\) 層的子節點數是 \(2^d\) 個,論文中給的例子是 \(d = 7,\; 2^d = 128\) ,相比於原來的 1 萬就小了很多。

前文提到過樹的平衡對於檢索效率很重要,因此每分配完 \(d\) 層後,論文中還加了一個再平衡 (rebalance) 操作。如果僅僅是通過計算 \((2.2)\) 式來分配物品到節點,很可能出現的情況是某個節點分配了超多的物品,那麼學習出來的樹會變得非常不平衡,而再平衡的目的就是使得一個節點可分配的物品數不超過 \(2^{l_\max – d}\) 個。這個再平衡操作使得實現的複雜度上了一個台階,想要分配一個物品,並不是每個節點計算一遍 \((2.2)\) 式取值最大的分配就好了,而是需要把所有計算值保存下來並排序,再通過 rebalance 將超過數量的物品分配到別的空閑節點上,具體實現見程式碼

在 JTM 的具體實現中還有幾個點值得討論。首先,JTM 中的模型優化和 TDM 中的是否是一樣的?如果嚴格按照論文里的那應該是不一樣的,因為 TDM 用的是二分類而 JTM 用的是多分類,然而使用多分類至少可能產生兩個問題,都和 softmax 的計算有關。一是 softmax 的分母計算應該包含哪些類別?這在論文中沒有明確說明,如果使用當前層的所有節點作為類別,那麼每一層使用的模型就變得不一樣了,因為每一層的節點數是不一樣的,這樣無論訓練還是推理都會帶來更大的複雜性。另一個是為了緩解 softmax 訓練計算量大的問題,論文中明確提到了使用常見的 NCE 來取樣訓練,然而這類取樣方法通常只適用於訓練,實際的推理過程中仍然需要計算全量 softmax,這樣利用樹結構來加速檢索的效果會大打折扣。基於以上考慮,我的 JTM 實現中仍然沿用了 TDM 的二分類模型訓練,那麼相應的樹學習中計算 \((2.2)\) 式也就是二分類模型輸出的概率。而實際上並不需要計算概率,我們需要的是相對大小並排序,那麼只需要計算模型的標量輸出 logit 就可以了。

這裡我不負責任地猜測一下作者為什麼要在論文里強行上這麼個不好實現的多分類目標函數,最可能的原因是這樣寫能讓提出的理論更加「優雅」。JTM 的核心是模型和樹優化同一個目標函數,如果使用二分類那麼這個公式可能就沒法寫得非常統一了,至少不那麼一目了然,讀者倒回去看一下 TDM 論文中的公式 \((4)\) 就明白了。

其次 JTM 論文的 \(3.2\) 節末尾簡略提了一句,”Furthermore, each sub-task can run in parallel to further improve the efficiency” 。雖然只有一句話,但實現中這一點其實挺重要的,因為 JTM 的貪心法雖然降低了很多計算量,但如果想算得精確一些 \(d\) 就不能取得太小,而 \(d\) 越大計算量也越大,所以利用並行計算來加速樹的學習是有必要的。然而論文里也沒說具體的 sub-task 究竟是什麼,只能我自己猜了。

在樹學習 (Tree Learning) 這個演算法 (論文中的 Algorithm 2) 中大致有兩個可以並行的地方,即節點內並行和節點間並行。前者指的是同一個節點內的所有 item 在往下 \(d\) 層分配子節點時並行;後者指的是同一層的節點之間並行。假設設置的最大並行度為 16,那麼對於靠近根節點的幾層可以使用節點內並行,因為 0 – 3 層的節點數都小於 16 ,如果使用節點間並行則無法達到最大並行度,而 4 層以下則可以使用節點間並行。

另外我發現節點間並行還有另外一種實現思路,那就是非同步學習。上面的方法其實是一種同步學習,也就是每一層都要等待該層所有的節點都分配好了,再繼續往下 \(d\) 層分配,如下圖 level 2 的 4 個節點就需要相互等待:

但實際上每層節點往下 \(d\) 層分配一直到最後一層,這個過程的每個節點之間是相互獨立互不影響的,那麼每個節點一路分配到最後一層的過程可視為一個 sub-task ,同一層的節點之間就不需要相互等待了,如下圖中每一個框內就是一個 sub-task ,4 個可以並行計算,對應程式碼為 JTMAsync

OTM


OTM 這篇論文,乍看上去比較理論化不大好懂 (與其他幾篇比起來),但核心 idea 卻很簡潔明了,即解決訓練和測試數據分布不一致的問題。回憶一下 TDM 中的模型訓練數據來自於正樣本葉節點及其祖先節點,以及每一層取樣的負樣本節點。然而實際推理過程中用的是自頂而下的 beam search,每一層保留 top-K 節點,這樣推理中經過的節點和訓練過程中使用的樣本節點可能分布截然不同,致使最終召回效果下降。

因而 OTM 在模型訓練時捨棄了 TDM 的這套構造樣本的方式,而是直接使用當前模型在樹上作 beam search,得到的每一層 top-K 節點作為訓練樣本。那麼接下來的問題是得到的這些樣本,哪些是正樣本哪些是負樣本呢?

如果想要偷懶點,可以直接採用類似 TDM 的模式,將 beam search 得到的節點中屬於正樣本祖先節點的設為正樣本,其餘的則設為負樣本。然而作者認為這樣並不能保證最後得到的葉節點一定是用戶偏好 \(p(n|u)\) 最大的 K 個。為了證明這一點 (以及其他相關的) 論文里洋洋洒洒上了一大坨,甚至很多證明還都放到了另外的補充材料中 (supplemental material) 中。最後得出來的結論是節點的標籤 \(z_n^*\) 滿足下式才是最優的:

\[z_n^* = y_{\pi(n’)}, \;n’ \in \mathop{\text{argmax}}_{n’ \in \mathcal{L}(n)} \, \eta_{\pi(n’)}(\bold{x}) \tag{3.1}
\]

其中 \(\pi(\cdot)\) 為上文 JTM 中提到的物品到節點的映射, \(\mathcal{L}(n)\) 為節點 \(n\) 對應的所有葉節點,\(\eta_{n}(\bold{x}) = p(y_{n} = 1|\bold{x})\) 為模型節點 \(n\) 的預測概率。那麼 \((3.1)\) 式的意思是節點 \(n\) 的標籤取決於模型對於其所有葉節點中預測概率最大的那個。論文中稱 \(z_n^*\) 為 pseudo target ,並配合下圖對提出的核心 idea 作了說明。

最底下的一層標號 1 – 8 的為物品,跨過映射函數 \(\pi(\cdot)\) 映射到了樹的根節點 7 – 14 。圖 \((\rm{a})\) 中的紅色節點為 TDM 中採用的正樣本上溯得到的訓練節點,對照圖 \((\rm{b})\) 中的藍色節點為實際 beam search 中的每層 top-K 節點,不同的流程導致二者的節點分布可能差別很大。而圖 \((\rm{c})\) 則顯示了 pseudo target 的生成過程,與 TDM 不同,OTM 中並不是每個正樣本的祖先節點也都是設為正樣本,比如節點 6 在圖 \((\rm{a})\) 中是正樣本,而在圖 \((\rm{c})\) 中則是負樣本,因為其葉節點為 13 和 14,而 \(\eta_{13}(\bold{x})= 0.5 > \eta_{14}(\bold{x}) = 0.4\) ,所以根據 \((3.1)\) 式節點 6 的 pseudo target 應和節點 13 相同,即為 0 。

然而直接根據 \((3.1)\) 式算出所有節點的 pseudo target 是不現實的, 因為計算一個節點需要遍歷該節點的所有葉節點得出最大值,而像上層的一些節點幾乎牽涉到了樹的所有葉節點。因此論文中提出的方案是每一層節點的 pseudo target 取決於其子節點的預測概率較大的那個:

\[\hat{z}_n({\bold{x}};\boldsymbol{\theta}) = \hat{z}_{n’}(\bold{x};\boldsymbol{\theta}), \; n’ \in \mathop{\text{argmax}}_{n’ \in \,\mathcal{C}(n)}\, p_{g_{\boldsymbol{\theta}}}(z_{n’} = 1|\bold{x}) \tag{3.2}
\]

其中 \(\mathcal{C}(n)\) 表示節點 \(n\) 的子節點。對於二叉樹來說,一個節點的子節點只有兩個,計算量就小了很多。葉節點因為沒有子節點,所以其 pseudo target 取決於數據本身 \(\hat{z}_n({\bold{x}};\boldsymbol{\theta}) = y_{\pi(n)}\) ,即正樣本對應的葉節點為 1,負樣本為 0 ,那麼從葉節點自底而上計算 \((3.2)\) 式就能得到樹上任意節點的 pseudo target 。

本篇開頭提到過,OTM 改進的是 TDM 中模型學習這一部分,那麼樹的學習這一部分論文中是直接沿用 JTM 的方法。OTM 的核心 idea 雖然簡潔明了,但其真正的實現還是比較複雜的,其複雜性主要來源於樣本的構造,因為已經不是 TDM 那樣簡單的節點上溯和負取樣了。首先看一下論文中給出的 Algorithm 1:

說實話論文中的這個演算法流程我看著是有點奇怪的。注意第 4 和第 5 步使用的都是 \(\boldsymbol{\theta} _t\) ,即模型上一輪的固定參數,而 \(\tilde{\mathcal{B}}_h(\bold{x};\boldsymbol{\theta}_t)\) 下標是 \(h\) 也就是樹的第 \(h\) 層,那麼這個流程的意思是 beam search 過程中每一層都計算 \(\tilde{\mathcal{B}}_h(\bold{x};\boldsymbol{\theta}_t)\)\(\hat{z}_n({\bold{x}};\boldsymbol{\theta}_t)\) 然後更新模型參數(第 6 步)?這樣豈不是 beam search 進行下一層計算的時候模型參數就不是上一輪的固定參數了?抑或是論文里說的固定參數範圍僅限定於 beam search 中的一層而不是整個 beam search 過程?

而且如果嚴格按照論文中的演算法流程,beam search 得到的每一層節點都單獨計算 \(\hat{z}_n({\bold{x}};\boldsymbol{\theta}_t)\) 勢必會產生很多重複計算,因為每次計算 pseudo target 都要從葉節點開始上溯。所以我在實現中每次真正更新模型參數前先將一批數據中所有層的 pseudo target 和 beam search 節點都計算好。這樣既能使用上一輪的固定參數模型,又能一次性不重複地計算完所有的 pseudo target。

根據論文的補充材料 (supplementary material) 顯示,第 5 步中只需要為滿足 \(n \in \tilde{\mathcal{B}}_h(\bold{x};\boldsymbol{\theta}) \bigcap \mathcal{S}^+_h(\bold{y})\) 的節點計算 pseudo target ,而對於 \(n \in \tilde{\mathcal{B}}_h(\bold{x};\boldsymbol{\theta}) \,\backslash\, \mathcal{S}^+_h(\bold{y})\) 節點的 \(\hat{z}_n({\bold{x}};\boldsymbol{\theta})\) 可直接設為 0 。\(\mathcal{S}^+_h(\bold{y})\) 代表正樣本節點在 \(h\) 層的祖先節點,那麼這裡的意思是每一層 beam search 得到的節點,只有與正樣本的祖先節點有重合的才需要計算 pseudo target 。

綜上所述,我認為效率最高的訓練流程是先從樹的葉節點自底而上計算每一層正樣本祖先節點的 pseudo target ,再從根節點自頂而下進行 beam search 獲取訓練節點,最後在訓練節點中搜索是否存在正樣本祖先節點,如果存在則把節點 label 設為相應的 pseudo target,若不存在則 label 為 0 。訓練節點的 label 都確定後就可以使用這些節點正式更新模型參數。

論文中還有一點值得注意,TDM 中一個樣本只需要單個 label,在 OTM 中擴增到了一個樣本多 label 的情況,若用論文中的符號表示則分別對應 \(|\mathcal{I}_\bold{x}| = 1\)\(|\mathcal{I}_\bold{x}| \geqslant 1\) 。之前在看 TDM 論文的時候就有這個疑惑: 一個用戶可能對多個物品感興趣,如果把這多個物品分散到不同的樣本中,再像 TDM 中那樣直接每一層負取樣,極有可能會把一個正樣本當成了另外樣本的負樣本來訓練。而如果是一個樣本有多個 label 的話則可以避免這種情況,比如 OTM 中每一層 beam search 得到的訓練節點,可以有多個正樣本,只要這些正樣本分別對應於多個 label 的祖先節點。

然而多 label 帶來的問題是一個樣本不同的 target 節點可能有同一個父節點,那麼這個父節點的 pseudo target 應該取決於哪一個 target 節點呢?這一點在論文中沒有明確說明,不過參照論文中的 \((1)\) 式對於 target 的正式定義 (這裡記為 \((3.3)\) 式),可以將有相同父節點的 target 節點進行聚合,即先將一組 target 節點按父節點分組,屬於同一組的再進行加和。

\[z_n = \mathbb{I}(\sum\limits_{n’ \in \mathcal{L}(n)} y_{\pi(n’)} \geq 1) \tag{3.3}
\]

Scala 2.13 在集合庫中新增了 groupMapReduce 方法,非常適合這個需求,假設已經得到了一組節點組成的列表 nodes,每個節點用元組 (id, score) 表示,那麼想要將其中相同父節點的 target 分組聚合得到一個新的列表,只需要一行程式碼 nodes.groupMapReduce(n => (n._1 - 1) / 2)(_._2)(_ + _)

Deep Retrieval


Deep Retrieval 的核心賣點和 TDM 系列差不多,即在大規模召回中直接使用複雜模型,因而兩者總免不了被拿來作比較。TDM 系列為了能快速檢索引入了樹作為索引結構,而 DR 中的索引結構是一個 \(K \times D\) 的矩陣,總共有 \(D\) 層,每層 \(K\) 個節點,見論文中的圖 \((\rm{a})\)

在檢索的時候同樣使用了 beam search,從最左側的一層開始使用 user embed 作為輸入,每一層選擇 top-B 的節點,最後得到 top-B 的 path,再通過映射函數找到 path 對應的物品。path 指的是每一層選出的節點組成的序列,論文中用 \(c = (c_1, c_2,…,c_D)\) 表示,每條 path 可以看作是一個 cluster 。這個步驟得到的 path 以及物品之間的順序並不重要,因為論文中還同時訓練了一個重排序 (rerank) 模型,對得到的物品作進一步排序最後輸出召回結果。從論文里看這個 rerank 模型是屬於 Deep Retrieval 的一部分,而不是一般意義上跟在召回模組後的粗排或精排。

與 TDM 一樣,Deep Retrieval 的整個體系也需要訓練兩個部分 —— 模型和索引結構,不過這裡的索引結構被具象化為了一個映射函數 \(\pi(\cdot)\) 。這一點和 JTM 類似,不同之處在於 JTM 中僅僅是物品到葉節點的映射,而 DR 中是物品到多條 path 的映射。上圖 \((\rm{b})\) 為 DR 的模型結構,第一層的輸入為 user embed,而後的每一層輸入為 user embed 和之前層的節點 embed 的拼接,每一層的輸出為 \(K\) 個節點的 softmax。由於 DR 中每個物品可以映射到 \(J\) 條 path,那麼總的目標函數為:

\[\mathcal{Q}_{\text{str}}(\theta, \pi) = \sum\limits_{i=1}^N\text{log}\left(\sum\limits_{j=1}^J p(c_{i,j} = \pi_j(y_i)|x_i,\theta)\right) \tag{4.1}
\]

上文講 JTM 的時候提到過模型使用二分類還是多分類的選擇,使用多分類的問題是會使樹每層的模型不同,且推理的時候計算量大。從上面的圖 (\(\rm{b}\)) 看 DR 使用的正是多分類 softmax 輸出概率,而每一層的輸入輸出都不相同,所以 DR 中每一層 MLP 本質上是不同的模型,僅在 user embed 層面是共享的,這一點和 TDM 所有節點共享同一個模型不一樣。另一方面,由於 DR 模型中每一層的類別比較少 (論文中 K = 100),也就不需要 NCE 這樣的近似計算了,可直接通過原始 softmax 更新模型。所以綜合來看雖然每層模型不同致使參數量變大,但類別設定的少的話訓練和推理在這方面應該不構成什麼問題。

論文里將需要訓練的兩部分,即模型和索引結構,分為了類似於 EM 演算法的 E-step 和 M-step ,E-step 為固定 \(\pi(\cdot)\) 優化模型參數 \(\theta\),M-step 為固定模型參數 \(\theta\) 優化 \(\pi(\cdot)\) ,二者優化的是同一個目標函數:

\[\mathcal{Q}_{\text{pen}}(\theta, \pi) = \mathcal{Q}_{\text{str}}(\theta, \pi) – \alpha \cdot \sum\limits_{c \in [K]^D} f(|c|) \tag{4.2}
\]

\((4.2)\) 式和 \((4.1)\) 式的不同點在於引入了一個懲罰函數 \(f(|c|)\),用於防止一條 path 被分配到了太多的物品。不過仔細看的話可以發現加的這個懲罰函數只會影響 M-step,而 E-step 只優化模型參數,所以 E-step 訓練的時候可以忽略這個 \(f(|c|)\)

E-step 的訓練完成後,接下來是 M-step 的優化。如果之前沒有寫 JTM,我大概對這部分也不會有什麼特別的感覺,然而現在我越看越覺得 DR 的這個 M-step 與 JTM 很像。當然不是說具體的演算法步驟,而是背後的核心思想相似。M-step 中比較重要的是理解論文中定義的打分函數 score function :

\[s[v,c] \triangleq \sum\limits_{i:y_i = v} p(c|x_i,\theta) \tag{4.3}
\]

\(s[v,c]\) 表示物品 \(v\) 分配到 path \(c\) 的累計重要度,使用的是所有目標物品為 \(v\) 的樣本加和,表示為 \(i:y_i=v\) 。拋開符號的差異,\((4.3)\) 式所表示的意思其實和 JTM 中的 \((2.2)\) 式如出一轍,\(i:y_i=v\) 就約等於 \((2.2)\) 式的 \(\mathcal{A_i}\) 。二者流程的內在含義都是想要獲得物品的最佳映射,那麼就把所有可能的映射對應物都計算一遍目標函數。不同點在於 JTM 中一個物品只映射到一個葉節點,所以取目標函數最大的那個節點;而 DR 中一個物品可以對應多條 path,因而取分數最大的 \(S\) 條候選 path,\(S\) 是一個超參數。

在得到了所有的 \(s[v,c]\) 後就意味著得到了每個物品 \(v\)\(S\) 條候選 path,接下來的目標是從 \(S\) 條中選出最終的 \(J\) 條。之所以在之前的計算中不直接選擇 \(J\) 條出來,是因為之前 \(s[v,c]\) 的計算沒有考慮 \((4.2)\) 式里的懲罰函數。DR 中加入懲罰函數 \(f(|c|)\) 是為了防止一條 path 被分配太多的物品導致不均衡,而這與 JTM 中的 rebalance 操作異曲同工,因為 rebalance 也是為了防止一個節點被分配太多的物品,所以到這裡我確信 DR 的 M-step 絕對借鑒了 JTM 里的思想。

經過一系列推導,論文中提到了依據 incremental gain 的大小來選擇最終的 \(J\) 條 path ,如下演算法流程 :

\[\text{incremental gain} = N_v \left(\text{log}(\sum\limits_{j=1}^{i-1}s[v,\pi_j(v)] + s[v,c]) – \text{log}(\sum\limits_{j=1}^{i-1}s[v,\pi_j(v)])\right) – \alpha(f(|c|+1) – f(c)) \tag{4.4}
\]

注意這個演算法流程的輸入是 \(s[v,c]\) ,也就是默認 \((4.3)\) 式的 \(s[v,c]\) 已經提前計算好了。不過這一步實際上是挺耗時的,因為需要所有的樣本都推理一遍。\(s[v,c]\) 可以通過流式訓練 (streaming training),細節就不細述了,論文里這一塊寫地比較詳細。在實現中由於我用的是固定數據集,所以無論是直接計算 \((4.3)\) 式訓練還是使用流式訓練都可以,在程式碼中前者用」batch「表示,後者用」streaming「表示。這裡的直接計算 \((4.3)\) 式指的是先將所有數據都扔進模型計算出所有樣本的 \(p(c|x_i,\theta)\) ,再對各個物品與 path 分組 (groupby) 加和,最後排序得到每個物品分數最大的 \(S\) 條 path 。

最後關於 beam search 後的重排序 (rerank) 模型,在論文 2.3 節說這個 rerank 模型用的是 softmax ,然而後面的實驗部分又說這只是在公開數據集上使用的,實際生產環境用的是 logistic regression ,原因是 softmax 的效果不大好。這個操作就有點迷了,合著這個 softmax 就是用來在公開數據集上刷榜的? 反正我的實現就是按照論文里的原始提法,用 sampled_softmax 近似 softmax 以解決物品數過多的問題。

Deep Retrieval 論文中還有一個槽點,如果我之前沒看過 OTM 論文大概率也不會察覺,那就是 DR 的實驗為什麼沒和 OTM 作比較?一開始我以為是因為兩者都首發表於 2020 年,所以互相不知道對方的工作。然而重看論文的時候發現 DR 論文的 Related Works 里赫然寫著 TDM, JTM, OTM 。所以又回頭看了一下 OTM 論文就明白了,因為在實驗的數據集上 OTM 的指標遠高於 Deep Retrieval ,不可能在論文里拿一個效果更好的模型作對比。當然僅憑這點並不能蓋棺定論 OTM 一定優於 Deep Retrieval 。

/