用最簡陋的LSTM,超越各種Transformer
- 2019 年 12 月 2 日
- 筆記
栗子 發自 凹非寺 量子位 報道 | 公眾號 QbitAI
如今,語言模型的世界,幾乎被Transformer/BERT佔領了。但如果回到2017年,把轟動世界的論文Attention Is All You Need從時間線上抹掉呢?
多頭注意力不存在了,只剩下原始的LSTM,自然語言處理又會朝怎樣的方向進化?
機器學習大佬Stephen Merity (Smerity) 突發奇想,強行回到過去,依靠簡單質樸的LSTM,做出了單頭注意力RNN,簡稱SHA-RNN。這個古法炮製的新模型,只用單個GPU訓練不到24小時,就在語言建模數據集enwik8上獲得了接近SOTA的成績。
除了算力要求不高,它還支援最多5000個token的長距離依賴。
論文引發了大量圍觀和討論,推特已有1700贊,Reddit熱度達到了170。
要那麼多頭做什麼?
就像蝴蝶效應,大佬Smerity說他要證明的是:只要方法稍有改變,整個領域會朝完全不同的方向發展。他開發的新模型,是由幾個部分組成的:一個可訓練的嵌入層,一層或者多層堆疊的單頭注意力RNN (SHA-RNN) ,再加一個softmax分類器。其中,SHA-RNN的結構就是下圖這樣:

△ LN=Layer Normalization
大致說來,SHA-RNN用的是單頭的、基於指針的注意力 (Pointer Based Attention) ,借鑒了2017年作者本人領銜的研究;還包含一個改造過的前饋層,名叫「Boom」,帶有層歸一化。
那麼,分別來觀察一下,注意力和前饋層。
首先是注意力。Smerity老師說,許多受Transformer啟發的模型架構,都假設在構造上沒有順序 (Sequentiality) ,且每層都有幾十個頭,計算起來太複雜了,大家也並不知道有多少頭是有效的。
相比之下,SHA-RNN模型的注意力是簡化的,只留一個頭,唯一的矩陣乘法出現在query (下圖Q) 那裡,A是縮放點乘注意力 (Scaled Dot-Product Attention) ,是向量之間的運算。

△ MM=Matrix Manipulation,LN=Layer Normalization
這樣一來,計算起來效率很高,普通台式機也可以訓練。
接下來講前饋層 (「Boom」 Layer) 。雖然這是從Transformer借鑒來的,不過Smerity老師重新排布了一下:
用了一個v∈ℝH向量,又用矩陣乘法 (GeLU激活) 得到另一個向量u∈ℝN×H。
然後,把u向量分解成N個向量,再求和,得到w∈ℝH向量。
這樣一來,與傳統的下映射層 (Down-Projection Layers) 相比,減少了運算量,除掉了一整個矩陣的參數。
那麼,SHA-RNN成績怎麼樣呢?
拉出來遛遛
Smerity老師說,雖然能用家裡的台式機訓練,但跑著跑著沒了耐心,於是改用GPU (12GB Titan V) 訓練了不到一天。
然後,就在兩個數據集enwik8和WikiText-103試一試吧。
其中,enwik8數據集包含了上億位元組維基百科XML轉儲。這是比賽結果:
當然,直接和純LSTM比是沒意義的,直接和無頭SHA-RNN比也沒意義。

測試機上的表現,超越了各種Transformer。
另一場比賽,在WikiText-103數據集上進行,測試的是Tokenization (分詞) 。結果認為,SHA-RNN可以有效抵禦Tokenization攻擊。
成功了。