CogLTX:應用BERT處理長文本

論文標題:CogLTX: Applying BERT to Long Texts
論文鏈接://arxiv.org/abs/2008.02496
論文來源:NeurIPS 2020

一、概述

BERT由於其隨文本長度二次增長的記憶體佔用和時間消耗,在處理長文本時顯得力不從心。通常BERT最大支援輸入序列長度為512,這對於標準的benchmark比如SQuAD和GLUE數據集是足夠的,但對於更加通常的情況下,比如更加複雜的數據集或者真實世界的文本數據,512的序列長度是不夠用的。

  1. 相關工作

目前BERT處理長文本的方法有截斷法、Pooling法和壓縮法。本文提出的CogLTX(Cognize Long TeXts)屬於壓縮法的一種。下面簡單介紹下這三種處理方法:

  • 截斷法

截斷法的處理方式其實就是暴力截斷,分為head截斷、tail截斷和head+tail截斷。head截斷就是從頭開始保留限制的token數,tail截斷就是從末尾往前截斷,head+tail截斷是開頭和結尾各保留一部分。

  • Pooling法

將長文本分成多個segment,拆分可以使用暴力截斷的方法,也可以使用斷句或者劃窗的方法。每一個segment都通過BERT進行encoding,然後對得到的[CLS]的embedding進行max-pooing或者mean-pooling,亦或將max-pooing、mean-pooling進行拼接。如果考慮性能,只能使用一個Pooling的話,就使用max-pooing,因為捕獲的特徵很稀疏,max-pooling會保留突出的特徵,mean-pooling會將特徵打平。

這種方法有明顯缺點,首先需要將多個segment進行encoding,文本越長,速度越慢。另外這種拆分文本的方式也犧牲了長距離token之間進行attention的可能性,舉例來說,如下圖,這是HotpotQA中的一個例子,解答問題的兩個關鍵句子之間長度相差超過512,因此他們不會出現在任何一個segment里,因此沒法做attention:

example

  • 壓縮法

壓縮法是將長文本按照句子分為多個segment,然後使用規則或者單獨訓練一個模型來剔除一些無意義的segment。

使用滑窗一類的方法通常會將每個segment的結果做aggregate,通常這一類方法會使用max-pooing或者mean-pooling,或者接一個額外的MLP或者LSTM,但是這樣犧牲了長程注意力的機會並且需要O(512^2\cdot L/512)=O(512L)的空間複雜度,這樣的複雜度在batch size為1,token總數為2500,BERT版本為large的情況下對於RTX 2080ti的GPU仍然太大。並且這種方法只能優化分類問題,對於其他任務比如span extraction,有L個BERT輸出,需要O(L^2)的空間來做aggregate。

  1. CogLTX

BERT之所以難以處理長文本的問題根源在於其O(L^2)的時間和空間複雜度,因此另外一個思路是簡化transformer的結構,但是目前為止這一部分的成果很少能夠應用於BERT。

CogLTX的靈感來自working memory,這是人類用來推理和決策的資訊儲存系統。實驗表明working memory每次只能保留5~9個item或者word,但是它卻有從長文本中進行推理的能力。 Working memory中的central executive,其功能就像有限容量的注意力系統,職責是協調綜合資訊。研究表明working memory中的內容會隨著時間衰減,除非通過rehearsal來保持,也就是注意和刷新頭腦中的資訊,然後通過retrieval competition從長時記憶中更新忽略的資訊用來推理和決策。

CogLTX的基本理念是通過拼接關鍵句子來進行推理。CogLTX的關鍵步驟是將文本拆分成多個blcok,然後識別關鍵的文本block,在CogLTX中叫做MemRecall,這是CogLTX的關鍵步驟。CogLTX中有另一個BERT模型,叫做judge,用來給block的相關性進行打分,並且它和原來的BERT(叫做reasoner)是jointly train的。CogLTX能夠通過一些干預(intervention)來將特定於任務的標籤(task-oriented label)轉換成相關性標註(relevance annotation),以此來訓練judge。

二、方法

  1. 方法論

CogLTX的基本假設是:對於大多數NLP任務來說,一些文本中關鍵的句子存儲了完成任務所需要的充分且必要的資訊,具體地,對於長文本x,其中存在由某些句子組成的短文本z,滿足reasoner(x^+)\approx reasoner(z^+),其中x^+z^+是下圖所示的BERT(reasoner)的輸入:

CogLTX for different tasks

我們將長文本x分解成多個blockx=\left [ x_{0}\; \cdots x_{T-1}\right ],當BERT的長度限制為L=512時,每個block的最大長度限制為B=63。關鍵短文本z由部分x中的block組成,也就是z=\left [ x_{z_{0}}\; \cdots x_{z_{n-1}}\right ],滿足len(z^+)\leq L並且z_{0}< \cdots < z_{n-1},接下來我們用z_{i}表示x_{z_{i}}z中的所有block都會自動排序,以保持x中的原始相對順序。

CogLTX中兩個重要的要素是MemRecall和兩個jointly train的BERT,其中MemRecall是利用judge模型來檢索關鍵塊然後輸入到reasoner中來完成任務的演算法。

  1. MemRecall

做QA任務時MemRecall的整個流程如下:

MemRecall

  • 輸入

前面圖中(a)Span Extraction Tasks、(b)Sequence-level Tasks和(c)Token-wise Tasks三種不同類型的任務有不同的特定設置,在任務(a)、(c)中,問題Q和子句x[i]作為query來檢索相關的block,但是在任務(b)中沒有query,並且相關性只被訓練數據隱式地定義。

z^+是被MemRecall維護的關鍵短文本,MemRecall除了接受x外還接受一個額外的初始z^+,在任務(a)、(c)中,query成為初始的z^+,judge在z^+的輔助下學習預測特定於任務的相關性。

  • 模型

MemRecall只使用一個模型,這是一個為每個token的相關性進行打分的BERT。假設z^{+}=\left [ [CLS]\; Q\; [SEP]\; z_{0}\; [SEP]\; \cdots\; z_{n-1}\right ],則有:

judge(z^{+})=sigmoid(MLP(BERT(z^{+})))\in (0,1)^{len(z^{+})}

一個blockz_{i}的得分是這個block內所有token得分的均值,記作judge(z^{+})[z_{i}]

  • 流程

MemRecall首先進行retrieval competition,每個blockx_i都會被打一個粗相關性得分judge(z^{+}\; [SEP]\; x_i)[x_i]。得分最高的幾個「winner」 block被插入到z中直到len(z^+)<L

接下來的rehearsal-decay過程會給每個z_i分配一個精相關性得分judge(z^{+})[z_{i}],得分最高的被保留在z^+中,得到一個new\; z^+。進行精相關性打分的原因是在沒有block之間交互和比較的情況下粗相關性得分不夠精確。

然後使用這個new\; z^+重複進行retrieval competition和rehearsal-decay,整個過程可以重複多次。通過這個迭代的過程可以實現multi-step reasoning。需要注意的是如果被z^+中的新block的更多資訊證明相關性不夠,在上一步中保留在new\; z^+中的block也可能decay,這是之前的multi-step reasoning方法所忽略的。

  1. 訓練
  • judge的監督學習

通常span extraction tasks會把答案block標記為relevant,即使是multi-hop的數據集比如HotpotQA,通常也會標註支援的句子。在這些情況下,judge通常使用監督學習的方式來訓練:

loss_{judge}(z)=CrossEntropy(judge(z^{+}),relv\_label(z^{+}))\\
relv\_label(z^{+})=[\underset{for\; query}{\underbrace{1,1,\cdots ,1}}\; \; \underset{z_0\; is\; irrelevant}{\underbrace{0,0,\cdots ,0}}\; \; \underset{z_1\; is\; relevant}{\underbrace{1,1,\cdots ,1}}\; \; \cdots ]\in [0,1]^{len(z^{+})}

這裡訓練的樣本z是從x中取樣出的多個連續blockz_{rand}(對應retrieval competition的數據分布)或者所有相關和隨機選擇的不相關blockz_{relv}(對應 rehearsal的數據分布)。

  • reasoner的監督學習

理想情況下,訓練時reasoner的輸入應該由MemRecall來生成,但是並不能保證所有的相關block都能被檢索到。以QA任務為例,如果答案的blcok沒有被檢索到,reasoner就無法通過檢索到的block進行訓練,因此解決方案為做一個近似,將所有相關blcok和retrieval competition中的「winner」 block輸入到reasoner中進行訓練。

  • judge的無監督學習

大多數的任務不會提供相關性的label。對於這種情況我們使用干預的手段來推斷相關性標籤:通過從z中剔除某個block來看它是否是不可或缺的。

假設z種包含所有的相關block,則根據我們的假設則有:

loss_{reasoner}(z_{-z_{i}})-loss_{reasoner}(z)>t,\forall z_{i}\in z,\; \; (necessity)\\
loss_{reasoner}([zx_{i}])-loss_{reasoner}(z)\approx 0,\forall x_{i}\notin z,\; \; (sufficiency)

z_{-z_{i}}是從z中移除z_{i}的結果並且t是一個閾值。每次訓練reasoner後,我們會剔除每個z中的z_i,然後根據loss的增加調整它們的相關性標籤。如果loss的增加是不顯著的,則表明這個block是不相關的,它可能在下一個epoch中不會再贏得retrieval competition,因為它將在下一個epoch中被標記為irrelevant來訓練judge。在實際中,閾值t會被劃分為t_{up}t_{down},保留一個buffer zone來避免標籤的頻繁切換。

下圖展示了20News數據集上無監督學習的一個例子:

example

  • 總結

訓練的流程圖總結如下:

訓練

三、實驗

  1. Reading comprehension

Reading comprehension

  1. Multi-hop question answering

Multi-hop question answering

  1. Text classification

Text classification

  1. Multi-label classification

Multi-label classification

  1. 記憶體和時間消耗

記憶體和時間消耗