Bert不完全手冊9. 長文本建模 BigBird & Longformer & Reformer & Performer
這一章我們來嘮嘮如何優化BERT對文本長度的限制。BERT使用的Transformer結構核心在於注意力機制強大的交互和記憶能力。不過Attention本身O(n^2)的計算和內存複雜度,也限制了Transformer在長文本中的應用。
之前對長文檔的一些處理方案多是暴力截斷,或者分段得到文本表徵後再進行融合。這一章我們看下如何通過優化attention的計算方式,降低內存/計算複雜度,實現長文本建模。Google出品的Efficient Transformers: A Survey裏面對更高效的Transformer魔改進行了分類,這一章我們主要介紹以下5個方向:
- 以Transformer-XL為首的片段遞歸
- Longformer等通過稀疏注意力,降低內存使用方案
- Performer等通過矩陣分解,降低attention內積計算複雜度的低秩方案
- Reformer等可學習pattern的注意力方案
- Bigbird等固定pattern注意力機制
Transformer-xl
- paper: Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context
- github://github.com/kimiyoung/transformer-xl
- Take Away: 相對位置編碼 + 片段遞歸機制
為了突破Transformer對固定長度建模的限制,Transformer-xl提出了相對位置編碼和片段遞歸的方案,後續也被XLNET沿用~
- 片段遞歸
片段遞歸的思路其實很早就有,不過之前的方案多是保留上一個片段的last hidden state,作為當前片段的補充信息。而Transformer-xl則是直接保留並cache了上個片段的所有hidden state,和當前片段進行拼接,梯度更新時只更新當前片段的隱藏層。
具體的Attentenion計算中如下,\(\tau\)是片段,\(n\)是hidden layer,\(\circ\)是向量拼接,\(SG()\)是不進行梯度更新的意思。於是當前片段Q,K,V是由上個片段的隱藏層和當前片段的隱藏層拼接得到。每個片段完成計算後會把隱藏層計算結果進行存儲,用於下個片段的計算,用空間換時間,既避免了重複計算,又使得新的片段能保留大部分的歷史片段信息。這裡的歷史片段信息並不一定只使用T-1,理論上在內存允許的情況下可以拼接更多歷史片段~
- 相對位置編碼
片段遞歸如果和絕對位置編碼一起使用會存在問題,因為不同片段相同位置的
絕對位置編碼相同,模型無法區分它們來自不同的片段。因此作者提出了相對位置編碼。之前在討論絕對位置編碼不適用於NER任務時有分析過相對位置編碼>>中文NER的那些事兒5. Transformer相對位置編碼&TENER代碼實現,這裡我們再回顧下~
絕對位置編碼是直接加到詞向量上,在Attention計算中進行交互。把內積展開可以得到如上形式,包括4個部分:Query和Key的純語義交互,各自的位置信息和語義的交互,以及反映相對距離的位置交互。
Transformer-XL的相對位置編碼和以上的展開形式基本一一對應,也使用了三角函數的編碼方式,只需要兩點調整
- key對應的絕對位置編碼\(p_j\)替換為兩個token相對位置i-j的相對位置編碼\(R_{i,j}\)
- query的位置編碼\(P_iW_q\)替換成兩個learnable的參數u和v
和以上絕對位置編碼的Attention計算對比:
- 語義交互不變
- 位置交互:絕對位置編碼內積替換為相對位置編碼對應的全局位置偏置, 在表徵距離的同時加入了方向信息
- query位置*key語義:因為交互是計算一個query token對全部key token的attention,所以這裡的位置編碼部分是個常量,作者替換為了trainable的參數u,於是這部分有了更明確的含義就是key對應的全局語義偏置
- query語義*key位置: 替換為query語義 * query和key的相對位置編碼,也就是語義和位置交互
結合片段遞歸和相對位置編碼,Transformer-xl突破了Transformer對固定文本長度的限制。同時片段遞歸和以下4種Transformer優化方案是正交的關係,可以在以下的四種方案中疊加使用片段遞歸去優化長文本建模
Longformer
- paper: Longformer: The Long-Document Transformer
- github://github.com/allenai/longformer
- 中文預訓練模型://github.com/SCHENLIU/longformer-chinese
- Take Away: 滑動窗口稀疏注意力機制
Longformer的3點主要創新是
- 滑動窗口attention(圖b)
解決attention計算複雜度最簡單直觀的方案,就是把原本all-2-all的attention計算限制到適當的window size(w)內,這樣對於長度為n的序列,原本O(n^2)的複雜度就縮減到了O(n*w)。因為attention本質是引入當前token的上下文信息,但token其實很難和八丈遠外的內容進行交互,所以合理的窗口選擇並不會損失太多的信息,並且和stack-cnn相同更高的layer會擁有更大的感知野。Longformer這裡選擇了512作為窗口大小,attention的複雜度和BERT相同。
- 空洞滑窗attention(圖C)
和Dilated-CNN相同,這裡作者也採用了dilation來擴大相同計算量下的感知野。不過感覺這裡和CNN還是有些區別,圖像使用Dilation因為單一像素本身信息有限,需要通過kernel來提取圖片局部特徵,而對文本序列來說,每個token就是最小粒度的信息元包含信息更多,因此dilation會帶來更多的信息損失。不過作者在使用過程中也加了一些tricks,包括對多頭的不同頭使用不同的dilation策略,以及底層layer不使用dilation保留更多信息,更高層使用更大的dilation擴大感知野。不過在後面的消融實驗中空洞滑窗的效果提升並不十分顯著。
- 任務導向全局attention(圖d)
以上局部attenion在一些任務中存在不足,例如QA任務中可能問題無法和上下文進行完整交互,以及分類任務中CLS無法獲得全部上下文信息。因此作者在下游任務微調中加入了針對部分token的全局attention。因此在下游微調時,需要進行全局交互的token,會用預訓練的Q,K,V進行初始化,不過會用兩套線性映射的參數分別對全局和滑動窗口的Q,K,V進行映射。
Longformer的預訓練是在Roberta的基礎上用長文本進行continue train。原始Roberta的position embedding只有512維,這裡longformer把PE直接複製了8遍,得到4096維度的PE用於初始化,這樣在有效保留原始PE局部信息的同時,也和以上512的window-size有了對應。至於longformer的效果,可以直接看和下面BigBird的對比。
Bigbird
- paper: Big Bird:Transformers for Longer Sequences
- github: //github.com/google-research/bigbird
- Take Away: 使用補充固定token計算全局注意力
又是一個非常清新脫俗的模型起名~ 大鳥模型和longformer相比增加了隨機注意力機制,不過感覺主要的創新是對全局注意力機制進行了改良,提出了固定注意力patten的ETC全局注意力機制。
- 隨機注意力機制
在滑動窗口注意力之外,模型會每行隨機採樣r個token來進行交互,不過這裡的隨機注意力並不和以下的ETC全局注意力一同使用~
- 全局注意力
只使用滑動窗口注意力+隨機注意力,作者發現效果和BERT相比還是有所損失,因此加入了全局注意力。和longformer的區別在於,Bigbird除了支持對部分已有token(一般是序列的第一個和最後一個字符)進行全局attention之外,簡稱Bigbird-ITC。還
支持加入額外token(類似CLS)來做全局注意力,簡稱Bigbird-ETC,ETC不和隨機注意力一同使用。從下游任務效果上來看ETC的效果略好於ITC+隨機注意力,效果對比基本是用的BigBird-ETC,不過這也限制了BigBird只能用在NLU場景~
整體效果,在QA和長文本摘要任務上上Bigbird基本是新SOTA
Reformer
- paper: REFORMER: THE EFFICIENT TRANSFORMER
- github: //github.com/google/trax/tree/master/trax/models/reformer
- Take Away: LSH搜索序列中的高權重token,做固定長度局部注意力計算
先來看下原始Transformer的空間複雜度: \(max(b*l* d_{ffn}, b *n_{h} * l^2)*n_{l}\)。其中b是batch,l是文本長度,\(d_{ffn}\)是Feed Forward層大小,\(n_{h}\)是多頭的head size,\(n_l\)是層數。Reformer引入了三個方案來降低Transformer的計算和內存複雜度
- LSH Attention:近似計算,針對l,只計算注意力中高權重的部分
- 可逆網絡:時間換空間,針對\(n_l\),只存儲最後一層的參數
- 分塊計算:時間換空間,針對\(d_{ffn}\),對FFN層做分塊計算
後兩個方案和長文本無關這裡我們簡單過,重點是LSH Attention部分的創新~
- LSH Attention
Local Sensitentive Hashing Attention是Reformer的主要貢獻,也就是最初分類中的可學習pattern注意力機制。考慮Attention的結果是被高權重的key主導的,因此每個token的注意力權重可以被部分高權重的token近似,只計算局部注意力從而避免計算\(L^2\)的注意力矩陣。難點轉換成了如何更高效的找到高權重的key,也就是和query token向量空間更相似的key token來進行局部交互,這裡作者使用了LSH,一種在高維數據中快速近似查找的算法。
LSH使用哈希函數對高位空間的向量x計算哈希函數h(x),\(h(x)\)滿足在高維空間中更近的向量有更高的概率落在相同的哈希桶中,反之在高維空間中距離更遠的向量有更低的概率會落在相同的哈希桶中。LSH有很多種算法,這裡作者使用的是基於角距離的局部敏感哈希算法。隨機初始化向量R維度是\(d_{model} * bucket/2\),哈希結果為旋轉(xR)之後最近的一個正或者負的單位向量\(h(x) = argmax([xR;-xR])\)
使用LSH計算Attention會存在幾個問題
- query和key的hashing不同:為了解決這個問題作者把計算注意力之前query和key各自的線性映射統一成了一個,\(k_j=\frac{q_j}{||q_j||}\),這樣二者的哈希也會相同,只需要對key進行計算就得到token的哈希分桶。例如上圖(b)長度為6的序列被分成3個桶[1,2,4],[3,6],[5]
- 哈希的誤差:哈希只是使得相似的向量落入相同桶的概率更高,為了進一步提高這個概率,可以進行多次不同的哈希函數對輸出結果取交際,進一步降低近似帶來的信息損失。也就是用更多的時間和空間來換取更好的近似效果
- 每個序列哈希分桶的大小可能不盡相同,無法進行batch計算:這裡作者又做了一步近似。根據以上的哈希結果對token進行重排序,相同哈希的token放一起,桶內按原始位置排序,按固定長度m進行切分,每個chunk的query對當前chunk和前一個chunk的key計算注意,也就是位於[m,2m]的query對[0,2m]的key計算注意力,這裡m和哈希桶數反向相關\(m=\frac{l}{n_{bucket}}\),也就是和平均哈希桶的大小正相關。實際上LSH只是用來排序,提高固定長度內注意力權重占整個序列的比例,從而通過有限長度的注意力矩陣近似全序列的注意力結果。同樣是固定窗口,LSH使得該窗口內的token權重會高於以上Longformer,BigBird這類完全基於位置的固定窗口的注意力機制,不過LSH的搜索和排序也會進一步提高時間複雜度
- 可逆殘差網絡
可逆殘差的概念是來自The reversible residual network: Backpropagation without storing activations(Gomez2017)。通過引入可逆變換,RevNet使得模型不需要存儲中間層的參數計算梯度
,每一層的參數可以由下一層通過可逆運算得到。屬於時間換空間的方案,因為反向傳播計算梯度時需要先還原本層的參數,因此時間上會增加50%左右~ 細節我們就不多展開想看math的往蘇神這看可逆ResNet:極致的暴力美學, 簡單易懂的往大師兄這看可逆殘差網絡RevNet
- 分塊計算
分塊主要針對FFN層。因為Feed Forward一般會設置幾倍於Attention層的hidden size,通過先升維再降維的操作提高中間層的信息表達能力,優化信息的空間分佈,以及抵消Relu帶來的信息損失。但是過大的hidden size會帶來極高的空間佔用。因為是在embedding維度進行變換每個位置之間的計算獨立,因此可以分塊進行計算再拼接,用時間來換空間
效果評測部分我們在下面的performer里一起討論
Performer
- paper: Rethinking Attention with Performers
- github: //github.com/google-research/google-research/tree/master/performer
- Take Away: 提出核函數使得QK變換後的內積可以近似注意力矩陣,配合乘法結合律把複雜度從平方降低到線性
多頭注意力機制的計算是query和key先計算Attention矩陣A,再對V進行加權,也就是上圖等號左邊的計算順序,複雜度是序列長度的平方。為了避免計算\(L^2\)的注意力矩陣,作者採用矩陣分解\(q^{\prime} \in R^{L,r},k^{\prime} \in R^{L,r}\),這裡r<d<<L,配合矩陣乘法的結合律,K先和V計算再和Q內積,把空間複雜度從平方級降低到線性。但是注意力矩陣過softmax之後無法直接做可逆轉換得到\(q^{\prime},k^{\prime}\), 因此作者提出了使用positive Random Feature對QK進行映射,使得映射後的矩陣\(q^{\prime},k^{\prime}\)內積可以近似Attention矩陣。簡單解釋就是以下的變換
\]
所以Performer的核心在\(\phi\)核函數的設計使得映射後的QK內積可以高度近似注意力矩陣,具體設計如下
這裡\(SM(x,y) = exp(x^Ty)\)也就是原。始的注意力矩陣,按照\(f(x)=exp(w^Tx-\frac{||x||^2}{2})\)對Q和K進行變換後,QK內積的期望就等於原始的注意力矩陣。不過在實際計算中只能對隨機變量w進行有限次採樣, 因此是近似原始注意力矩陣。論文有大量篇幅在進行推導和證明,這裡就不做展開了。
效果對比我們直接參考Google給出的效果對比,橫軸是速度,縱軸是效果(多任務平均值),點的大小是內存。整體上BigBird還是拔得頭籌,它並不是所有任務的SOTA但是整體效果穩定優秀,想看詳細對比結果的參考REF2~
Reference
- Efficient Transformers: A Survey
- Long Range Arena: A Benchmark for Efficient Transformers