用最簡陋的LSTM,超越各種Transformer

  • 2019 年 12 月 2 日
  • 筆記

栗子 發自 凹非寺 量子位 報道 | 公眾號 QbitAI

如今,語言模型的世界,幾乎被Transformer/BERT佔領了。但如果回到2017年,把轟動世界的論文Attention Is All You Need從時間線上抹掉呢?

多頭注意力不存在了,只剩下原始的LSTM,自然語言處理又會朝怎樣的方向進化?

機器學習大佬Stephen Merity (Smerity) 突發奇想,強行回到過去,依靠簡單質樸的LSTM,做出了單頭注意力RNN,簡稱SHA-RNN。這個古法炮製的新模型,只用單個GPU訓練不到24小時,就在語言建模數據集enwik8上獲得了接近SOTA的成績。

除了算力要求不高,它還支援最多5000個token的長距離依賴。

論文引發了大量圍觀和討論,推特已有1700贊,Reddit熱度達到了170。

要那麼多頭做什麼?

就像蝴蝶效應,大佬Smerity說他要證明的是:只要方法稍有改變,整個領域會朝完全不同的方向發展。他開發的新模型,是由幾個部分組成的:一個可訓練的嵌入層,一層或者多層堆疊的單頭注意力RNN (SHA-RNN) ,再加一個softmax分類器。其中,SHA-RNN的結構就是下圖這樣:

LN=Layer Normalization

大致說來,SHA-RNN用的是單頭的、基於指針的注意力 (Pointer Based Attention) ,借鑒了2017年作者本人領銜的研究;還包含一個改造過的前饋層,名叫「Boom」,帶有層歸一化。

那麼,分別來觀察一下,注意力和前饋層。

首先是注意力。Smerity老師說,許多受Transformer啟發的模型架構,都假設在構造上沒有順序 (Sequentiality) ,且每層都有幾十個頭,計算起來太複雜了,大家也並不知道有多少頭是有效的。

相比之下,SHA-RNN模型的注意力是簡化的,只留一個頭,唯一的矩陣乘法出現在query (下圖Q) 那裡,A是縮放點乘注意力 (Scaled Dot-Product Attention) ,是向量之間的運算。

MM=Matrix Manipulation,LN=Layer Normalization

這樣一來,計算起來效率很高,普通台式機也可以訓練。

接下來講前饋層 (「Boom」 Layer) 。雖然這是從Transformer借鑒來的,不過Smerity老師重新排布了一下:

用了一個v∈ℝH向量,又用矩陣乘法 (GeLU激活) 得到另一個向量u∈ℝN×H。

然後,把u向量分解成N個向量,再求和,得到w∈ℝH向量。

這樣一來,與傳統的下映射層 (Down-Projection Layers) 相比,減少了運算量,除掉了一整個矩陣的參數。

那麼,SHA-RNN成績怎麼樣呢?

拉出來遛遛

Smerity老師說,雖然能用家裡的台式機訓練,但跑著跑著沒了耐心,於是改用GPU (12GB Titan V) 訓練了不到一天。

然後,就在兩個數據集enwik8WikiText-103試一試吧。

其中,enwik8數據集包含了上億位元組維基百科XML轉儲。這是比賽結果:

當然,直接和純LSTM比是沒意義的,直接和無頭SHA-RNN比也沒意義。

測試機上的表現,超越了各種Transformer。

另一場比賽,在WikiText-103數據集上進行,測試的是Tokenization (分詞) 。結果認為,SHA-RNN可以有效抵禦Tokenization攻擊。

成功了。

開源了