深入理解seq2seq—attention的前置重要知識

  • 2021 年 3 月 23 日
  • AI

為了更好的了解seq2seq以及attention的思想,特意買了本:

看,不得不說,這本書對機器翻譯的講解非常的全面系統,不過我不做機器翻譯的,主要對seq2seq部分感興趣,看完之後做個總結。

1、機器翻譯的任務是一種序列到序列,即sequence to sequence形式的問題,輸入是一個序列,輸出也是一個序列,本質上來說是一種序列轉換的任務,因此使用seq2seq的框架結構非常的適合處理這類問題,當然除了機器翻譯之外,還有許多任務也符合sequence to sequence的形式,例如nlp中的:

以及表格數據中的多元多步時間序列問題。(這意味著transformer可以稍加經過改造來適配時間序列問題,當然已經有人做了,還發了論文放了源程式碼:

憶臻:AAAI21最佳論文Informer:效果遠超Transformer的長序列預測神器!zhuanlan.zhihu.com圖標

個人暫時還沒有把transformer的結構適配到時間序列預測問題中,不過在mxnet的gluonts里,transformer已經納入進來了:

gluonts.model packagets.gluon.ai

以及fastai出的torch版時間序列library:

timeseriesAI/tsaigithub.com圖標

不過整體的完整性上來看,gluonts應該算是目前實現模型最全的,雖然底層使用mxnet,不過別擔心,封裝的非常完整,不太需要接觸mxnet的語法除非你希望能夠自定義model,當然mxnet本身的語法也很簡單,和keras的設計非常類似,上手也是很快的

2、在transformer之前,seq2seq後續的改進基本上是三個方向:

(1) 魔改encoder和decoder的結構,例如encoder部分使用LSTM,雙向LSTM,CNN等;

(2)在encoder或者decoder中加入各種網路組件幫助訓練,例如resnet+deep lstm作為encoder來保證深層的lstm擁有更強大的抽象能力的同時,訓練的效率也得到提升;

(3)encoder和decoder部分的不同設計策略

下面稍微展開一下:

(1)encoder部分的採用過反向LSTM、雙向LSTM和CNN看下面三個圖:

其中圖6-14的結構實際上是LSTM+CNN的複合結構,先使用LSTM(單向or雙向)得到每一個時刻的隱含狀態(或者說語義編碼),而後在語義編碼向量的基礎上再引入CNN的卷積+池化最終得到更高階的句子向量;

而decoder部分基本上差不了太多,以RNN為主。這一階段,seq2seq整體的訓練策略比較簡單,

基本上是encoder生成源句子的高階向量化表示,然後僅使用這個高階向量來decode decoder部分的輸出。

(2). encoder和decoder部分的各類複雜的網路組件,例如resnet ,maxout,zoneout,dropout、glu 等等本質上是一種換湯不換藥的設計,他們其實和網路本身的基礎架構關係不大,大都是為了增強模型的訓練效率和泛化能力等而引入的組件,換句話說,即使不使用seq2seq的結構,在常規的網路設計上,這類組件也可以應用起來;

(3).encoder和decoder的不同涉及策略

馬東什麼:seq2seq by keras 總結zhuanlan.zhihu.com圖標

發現這塊之前總結過,github上有一個用keras寫的seq2seq的項目還有Google的seq2seq都是這個階段的各類seq2seq的encoder和decoder的設計模式的實現:

//google.github.io/seq2seq/google.github.io

//github.com/farizrahman4u/seq2seqgithub.com

可以看到,基本上是基於decoder的解碼方式進行的各種魔改。

而對於decoder的設計模式可以劃分為:

1、greedy search:

圖片來源,之前寫的:

馬東什麼:seq2seq 重溫以及時間序列預測應用zhuanlan.zhihu.com圖標

greed search的思路很貪心,解碼的過程中,選擇一個softmax預測的概率輸出最高的詞作為當前的翻譯結果,然後再將這個詞作為下一個目標詞的輸入和上一個詞的LSTM對應的step階段的hidden state一起預測下一個目標詞,貪心搜索的缺點比較明顯,就是誤差累積的問題,這個和基於seq2seq的時間序列預測任務存在的問題也是一樣,一旦某一步decode的預測出錯,則後面所有的decode的預測都是在上一步誤差的基礎上進行累計,從而使得整體的誤差很高。

2、為了彌補greedy search的問題,beam search被提出

beam search的思路也很簡單,就是取topk個預測概率最高的目標詞同時進入下一個的decode,然後在最終的輸出中才進行最終的top1的選擇。

而與時間序列相關的seq2seq結構無法使用beam search,因為回歸型的seq2seq的輸出沒有概率分布的概念而是輸出一個單值,針對於這個問題實際上時間序列預測 of seq2seq 捨棄了beam search的思路,使用了teacher forcing的擴展方法——curriculum learning,對應的取樣方法叫做scheduled sampling。

[論文解讀]Scheduled Sampling for Sequence Prediction with Recurrent Neural Networksblog.csdn.net

具體可見上。

3、ensemble decoding,使用了集成學習的思想:


前面我們提到了greedy search和beam search,他們的共同點就是解碼的時候,終究還是使用t時刻的預測結果來幫助預測t+1時刻的標籤,這樣始終無法脫離累計誤差的問題,

陳猛:簡說Seq2Seq原理及實現zhuanlan.zhihu.com圖標

圖片來源於上:

這種模式也稱之為 free running mode

為了避免累計誤差的問題,於是乎teacher forcing mode的解碼策略引入:

它是一種網路訓練方法,對於開發用於機器翻譯,文本摘要,影像字幕的深度學習語言模型以及許多其他應用程式至關重要。它每次不使用上一個state的輸出作為下一個state的輸入,而是直接使用訓練數據的標準答案(ground truth)的對應上一項作為下一個state的輸入。
看一下大佬們對它的評價:
存在把輸出返回到模型輸入中的這種循環連接單元的模型可以通過Teacher Forcing機制進行訓練。
這種技術最初被作為反向傳播的替代技術進行宣傳與開發
在動態監督學習任務中經常使用的一種有趣的技術是,在計算過程中用教師訊號 d(t)替換上一個單元的實際輸出 y ( t )。我們稱這種技術為Teacher Forcing。
Teacher Forcing工作原理: 在訓練過程的 t時刻,使用訓練數據集的期望輸出或實際輸出: y(t), 作為下一時間步驟的輸入: x(t+1),而不是使用模型生成的輸出 h(t)。
在訓練過程中接收ground truth的輸出 y(t) 作為 t+1時刻的輸入

思路很簡單,我們則編碼階段使用ground truth(就是真實的decode的目標詞的embedding來輔助decode而不是使用上一步的預測結果來輔助decode)

單純使用teacher forcing策略的缺點在於:

Teacher Forcing同樣存在缺點: 一直靠老師帶的孩子是走不遠的。
因為依賴標籤數據,在訓練過程中,模型會有較好的效果,但是在測試的時候因為不能得到ground truth的支援,(預測的時候沒有ground truth可用,所以實際的過程還是用上一個目標詞的hidden state參與生成下一個目標詞)所以如果目前生成的序列在訓練過程中有很大不同,模型就會變得脆弱。
也就是說,這種模型的cross-domain能力會更差,也就是如果測試數據集與訓練數據集來自不同的領域,模型的performance就會變差。

炫云:Teacher Forcing訓練機制zhuanlan.zhihu.com圖標

beam search在一定程度上緩解了這個問題。


接下來就是重磅,attention機制的引入,這一時期,各類attention-based model層出不窮,專治各種花里胡哨,attention引入的緣由主要在於RNN的問題:

1、對於長文本,encoder部分,越早進入的word,其資訊越難以被保留下來,因為encoder部分的rnn,每當t時刻接受一個新詞的embedding,則包含t時刻之前的word的hidden state會繼續經過變換和當前詞的embedding經過RNN轉換後的資訊進行融合產生新的隱含狀態,這種變換和融合的過程很容易造成前面已經添加的詞的資訊的丟失或改變,越早添加進來的詞,則最終的隱含狀態越難以保留,

例如上圖,如果需要源句子的「economic」影響目標句子中「經濟」這個詞的產生,那麼「economic」這個詞的資訊首先需要保存在編碼器的第一個隱藏狀態中,並且要一直保證這個資訊能夠不發生變化直到編碼器的最後一個隱藏狀態,這個是比較費勁的。

其實這個過程可以理解為傳話遊戲,我們要保證第一個人說的話到達最後一個人的時候不發生變化還是比較困難的。

同時,反向傳播的時候,「經濟」這個詞的損失也同樣要跋山涉水經過同樣長的反向路徑才能傳播回來影響「economic」對應的參數。

2、encoder的最後一個編碼器的輸出,一般稱之為context vector,即對encoder部分的輸入的源句子的高階抽象的文本向量,一般大小是固定的,固定長度的向量的記憶能力有限,特別是輸入的源句子很長的時候,固定長度的context vector的性能會變得很長,如下圖:

因此,引入了 attention機制,接下來的部分就是attention的描述了,之前也已經總結過了:

馬東什麼:從attention到bertzhuanlan.zhihu.com圖標

就不再贅述了。


最後補充一下關於cnn+seq2seq的網路結構。

cnn版的seq2seq的例子相對於RNN少得多,git上找得到的demo也基本上是RNN+CNN這樣的複合結構,直到facebook發表了《Convolutional Sequence to Sequence Learning》,提出了完全使用CNN來構成Seq2Seq模型,用於機器翻譯,超越了Google創造的基於LSTM機器翻譯的效果,我才得以完全窺探CNN+seq2seq的結構,之前wavenet的seq2seq的結構也一直沒有完全理解,因為接受的初始的觀念基本上是基於RNN+seq2seq結構的,而純CNN應用於seq2seq中的機制有一些不同。

關於這部分,打算寫一篇新的介紹cnn+seq2seq的文章來總結,Convolutional Sequence to Sequence Learning和wavenet的cnn seq2seq架構方式。