【DL】深度解析LSTM神經網路的設計原理

  • 2020 年 2 月 20 日
  • 筆記

以下文章來源於夕小瑤的賣萌屋 ,作者夕小瑤

引人入勝的開篇:

想要搞清楚LSTM中的每個公式的每個細節為什麼是這樣子設計嗎?想知道simple RNN是如何一步步的走向了LSTM嗎?覺得LSTM的工作機制看不透?恭喜你打開了正確的文章!

前方核彈級高能預警!本文資訊量非常大,文章長且思維連貫性強,建議預留20分鐘以上的時間進行閱讀。

前置知識1:

在上一篇文章《前饋到回饋:解析RNN》中,小夕從最簡單的無隱藏層的前饋神經網路引出了簡單的循環神經網路:

它就是無隱藏層的循環神經網路,起名叫「simple RNN」

這種方式即在每個時刻做決策的時候都考慮一下上一個時刻的決策結果。畫出圖來就是醬的:

(其中圓球里的下半球代表兩向量的內積,上半球代表將內積結果激活)

雖然通過這種簡單回饋確實可以看出每個時間點的決策會受前一時間點決策的影響,但是似乎很難讓人信服這竟然能跟記憶扯上邊!

想一下,人的日常行為流程是這樣的。比如你在搭積木,那麼每個時間點你的行為都會經歷下面的子過程:

1、眼睛看到現在手裡的積木。

2、回憶一下目前最高層的積木的場景。

3、結合1和2的資訊來做出當前時刻積木插到哪裡的決策。

相信聰明的小夥伴已經知道我要表達的意思啦。第1步手裡的積木就是當前時刻的外部輸入X;第2步就是調用歷史時刻的資訊/記憶;第3步就是融合X和歷史記憶的資訊來推理出決策結果,即RNN的一步前向過程的輸出y(t)。

有沒有更加聰明的小夥伴驚奇的注意到第2步!!!我們在回憶歷史的時候,一般不是簡單的回憶上一個積木的形狀,而是去回憶一個更加模糊而宏觀的場景。在這個例子中,這個場景就是最近幾次行為所產生出的抽象記憶——即「積木最高層的地形圖」!

也就是說,人們在做很多時序任務的時候,尤其是稍微複雜的時序任務時,潛意識的做法並不是直接將上個時刻的輸出y(t-1)直接連接進來,而是連接一個模糊而抽象的東西進來!這個東西是什麼呢?

當然就是神經網路中的隱結點h啊!也就是說,人們潛意識裡直接利用的是一段歷史記憶融合後的東西h,而不單單是上一時間點的輸出。而網路的輸出則取自這個隱結點。所以更合理的刻畫人的潛意識的模型應該是這樣的:

(記憶在隱單元中存儲和流動,輸出取自隱單元)

這種加入了隱藏層的循環神經網路就是經典的RNN神經網路!即「standard RNN」

RNN從simple到standard的變動及其意義對於本文後續內容非常重要哦。

前置知識2:

在上一篇文章《從前饋到回饋:循環神經網路(RNN)》中簡單講解和證明過,由於在誤差反向傳播時,算出來的梯度會隨著往前傳播而發生指數級的衰減或放大!而且這是在數學上板上釘釘的事情。因此,RNN的記憶單元是短時的。

好啦,那我們就借鑒前輩設計RNN的經驗,從simple版本開始,即無隱藏層的、簡單完成輸出到輸入回饋的網路結構開始,去設計一個全新的、可以解決梯度爆炸消失問題從而記住長距離依賴關係的神經網路吧!

那麼如何讓梯度隨著時間的流動不發生指數級消失或者爆炸呢?

好像想起來挺難的,但是這個問題可能中學生會解答!那就是讓算出來的梯度恆為1!因為1的任何次方都是1嘛( ̄∇ ̄)

所以按照這個搞笑的想法,我們把要設計的長時記憶單元記為c(以下全部用c指代長時記憶單元),那麼我們設計出來的長時記憶單元的數學模型就是這樣子嘍:

c(t) = c(t-1)

這樣的話,誤差反向傳播時的導數就恆定為1啦~誤差就可以一路無損耗的向前傳播到網路的前端,從而學習到遙遠的前端與網路末端的遠距離依賴關係。

路人:Excuse me?

不要急不要急,反正假設我們的c中存儲了資訊,那麼c就能把這個資訊一路帶到輸出層沒問題吧?在T時刻算出來的梯度資訊存儲在c里後,它也能把梯度一路帶到時刻0而無任何損耗也沒問題吧?對吧( ̄∇ ̄)

所以資訊的運輸問題解決了,那麼就要解決對資訊進行裝箱卸車的問題。

先來看裝箱問題,即如何把新資訊寫入c裡面去呢?

當然要先定義一下新資訊是什麼。不妨直接拿來simple RNN中對新資訊的定義,即當前時刻的外部輸入x(t)與前一時刻的網路輸出(即回饋單元)y(t-1)聯合得到網路在當前這一時刻get到的新資訊,記為

。即:

好,新資訊

定義完成。下面考慮把

加到c裡面去。如果把這個問題拿去問小學生的話,那麼可能會兵分兩路:

1、乘進去!

2、加進去!

那麼這兩種哪種可行呢?

其實稍微一想就很容易判斷:乘法操作更多的是作為一種對資訊進行某種控制的操作(比如任意數與0相乘後直接消失,相當於關閉操作;任意數與大於1的數相乘後會被放大規模等),而加法操作則是新資訊疊加舊資訊的操作。

下面我們深入的討論一下乘性操作加性操作,這在理解LSTM里至關重要。當然,首先,你要掌握偏導的概念和方法、複合函數的求導法則、鏈式求導法則。有了這三點微積分基礎後才能看懂哦。

(害怕數學和基礎不夠的童鞋可以跳過這裡的論乘法和論加法小節。)

論乘法:

乘法時即令長時記憶添加資訊時的數學模型為:

因此網路完整數學模型如下:

公式【0.1】

公式【0.2】

公式【0.3】

為了計算方便,還是像之前一樣假設激活函數為線性激活(即沒有激活函數。實際上tanh在小值時可以近似為線性,relu在正數時也為線性,這個假設還是很無可厚非的),這時網路模型簡化為:

【1】

假如網路經過了T個時間步到了loss端,這時若要更新t=0時刻下網路參數V的權重,則即對t=0時刻的參數V求偏導,即計算

其中

(其中的f_loss(·)為損失函數)

好,稍微一算,發現

中的f_loss'的值就是我們要往前傳的梯度(參數更新資訊),則我們的目標就是討論y'(t=T),寫全了就是

【2】

對V求偏導時其他變數(就是說的W和x)自然也就成了常量,這裡我們再做一個過分簡化,直接刪掉

項!(在y二階乘方存在的情況下忽略一階乘方),這時就可以直接展開公式【1】:

對v(0)求導的話,會得到

如果說RNN的

是音速級的梯度爆炸和消失,那這

簡直是光速級爆炸和消失了吶!~

所以說直接將歷史記憶乘進長時記憶單元只會讓情況更糟糕,導致當初c(t)=c(t-1)讓導數恆為1的構想完全失效,這也說明了乘性更新並不是簡單的資訊疊加,而是控制和scaling。

論加法:

如果改成加性規則呢?此時添加資訊的數學模型為

與前面的做法一樣,假設線性激活並代入網路模型後得到

噫?也有指數項~不過由於v加了一個偏置1,導致爆炸的可能性遠遠大於消失。不過通過做梯度截斷,也能很大程度的緩解梯度爆炸的影響。

嗯~梯度消失的概率小了很多,梯度爆炸也能勉強緩解,看起來比RNN靠譜多了,畢竟控制好爆炸的前提下,梯度消失的越慢,記憶的距離就越長嘛。

因此,在往長時記憶單元添加資訊方面,加性規則要顯著優於乘性規則。也證明了加法更適合做資訊疊加,而乘法更適合做控制和scaling。

由此,我們就確定應用加性規則啦,至此我們設計的網路應該是這樣子的:

【3.1】

【3.2】

【3.3】

那麼有沒有辦法讓資訊裝箱和運輸同時存在的情況下,讓梯度消失的可能性變的更低,讓梯度爆炸的可能性和程度也更低呢?

你想呀,我們往長時記憶單元添加新資訊的頻率肯定是很低的,現實生活中只有很少的時刻我們可以記很久,大部分時刻的資訊沒過幾天就忘了。因此現在這種模型一股腦的試圖永遠記住每個時刻的資訊的做法肯定是不合理的,我們應該只記憶該記的資訊。

顯然,對新資訊選擇記或者不記是一個控制操作,應該使用乘性規則。因此在新資訊前加一個控制閥門,只需要讓公式【3.1】變為

這個

我們就叫做「輸入門」啦,取值0.0~1.0。

為了實現這個取值範圍,我們很容易想到使用sigmoid函數作為輸入門的激活函數,畢竟sigmoid的輸出範圍一定是在0.0到1.0之間嘛。因此以輸入門為代表的控制門的激活函數均為sigmoid,因此控制門:

當然,這是對一個長時記憶單元的控制。我們到時候肯定要設置很多記憶單元的,要不然腦容量也太低啦。因此每個長時記憶單元都有它專屬的輸入門,在數學上我們不妨使用

來表示這個按位相乘的操作,用大寫字母C來表示長時記憶單元集合。即:

【4】

嗯~由於輸入門只會在必要的時候開啟,因此大部分情況下公式【4】可以看成

C(t)=C(t-1),也就是我們最理想的狀態。由此加性操作帶來的梯度爆炸也大大減輕啦,梯度消失更更更輕了。

等等,愛思考的同學可能會注意到一個問題。萬一神經網路讀到一段資訊量很大的文本,以致於這時輸入門欣喜若狂,一直保持大開狀態,狼吞虎咽的試圖記住所有這些資訊,會發生什麼呢?

顯然就會導致c的值變的非常大!

要知道,我們的網路要輸出的時候是要把c激活的(參考公式【0.3】),當c變的很大時,sigmoid、tanh這些常見的激活函數的輸出就完全飽和了!比如如圖tanh:

當c很大時,tanh趨近於1,這時c變得再大也沒有什麼意義了,因為飽和了!腦子記不住這麼多東西!

這種情況怎麼辦呢?顯然relu函數這種正向無飽和的激活函數是一種選擇,但是我們總不能將這個網路輸出的激活函數限定為relu吧?那也設計的太失敗啦!

那怎麼辦呢?

其實想想我們自己的工作原理就知道啦。我們之所以既可以記住小時候的事情,也可以記住一年前的事情,也沒有覺得腦子不夠用,不就是因為我們。。。愛忘事嘛。所以還需要加一個門用來忘事!這個門就叫做「遺忘門」吧。這樣每個時刻到來的時候,記憶要先通過遺忘門忘掉一些事情再考慮要不要接受這個時刻的新資訊。

顯然,遺忘門是用來控制記憶消失程度的,因此也要用乘性運算,即我們設計的網路已進化成:

或者向量形式的:

好啦~解決了如何為我們的長時記憶單元可控的添加新資訊的問題,又貼心的考慮到並優雅解決了資訊輸入太過豐富導致輸入控制門「合不攏嘴」的尷尬情況,那麼是時候考慮我們的長時記憶單元如何輸出啦~

有人說,輸出有什麼好考慮的,當前的輸出難道不就僅僅是激活當前的記憶嗎?難道不就是最前面說的y(t)=f(c(t))?(其中f(·)為激活函數)

試想,假如人有1萬個長時記憶的腦細胞,每個腦細胞記一件事情,那麼我們在處理眼前的事情的時候是每個時刻都把這1萬個腦細胞里的事情都回憶一遍嗎?顯然不是呀,我們只會讓其中一部分跟當前任務當前時刻相關的腦細胞輸出,即應該給我們的長時記憶單元添加一個輸出閥門!也就是說應該輸出:

嗯~終於看起來好像沒有什麼問題了。

那麼我們最後再定義一下控制門們(輸入門、遺忘門、輸出門)受誰的控制就可以啦。

這個問題也很顯然,當然就是讓各個門受當前時刻的外部輸入x(t)和上一時刻的輸出y(t-1)啦,即

。。。。。。?

好像這樣的思維在RNN中並不會有什麼問題,但!是!不要忘了在我們這個新設計的網路中,多了一堆閥門!尤其注意到輸出門,一旦輸出門關閉,就會導致其控制的記憶f(c(t))被截斷,下一時刻各個門就僅僅受當前時刻的外部輸入x(t)控制了!這顯然不符合我們的設計初衷(儘可能的讓決策考慮到儘可能久的歷史資訊)。怎麼辦呢?

最簡單的做法就是再把長時記憶單元接入各個門,即把上一時刻的長時記憶c(t-1)接入遺忘門和輸入門,把當前時刻的長時記憶c(t)接入輸出門(當資訊流動到輸出門的時候,當前時刻的長時記憶已經被計算完成了)。即

當然,這個讓各個門考慮長時記憶的做法是後人打的修補程式,這些從長時記憶單元到門單元的連接被稱為"peephole(貓眼)"

至此還有什麼問題嗎?看起來真沒有問題啦~我們設計的simple版的網路就完成啦,總結一下,即:

就起名叫「門限simple RNN」吧!(非學術界認可)

然而,作為偉大的設計者,怎麼能止步於simple呢!我們要像simple RNN推廣出standardRNN的做法那樣,推廣出我們的standard版本!即加入隱藏層!

為什麼要加隱藏層已經在本文開頭提到了,這也是simpleRNN到standardRNN的核心區別,這也是RNN及其變種可以作為深度學習的主角之一的原因。模仿RNN的做法,我們直接用隱藏層單元h來代替最終輸出y:

顯然,由於h隨時都可以被輸出門截斷,所以我們可以很感性的把h理解為短時記憶單元。

而從數學上看的話,更是短時記憶了,因為梯度流經h的時候,經歷的是h(t)->c(t)->h(t-1)的連環相乘的路徑(在輸入輸出門關閉前),顯然如前邊的數學證明中所述,這樣會發生梯度爆炸和消失,而梯度消失的時候就意味著記憶消失了,即h為短時記憶單元。

同樣的思路可以再證明一下,由於梯度只從c走的時候,存在一條無連環相乘的路徑,可以避免梯度消失。又有遺忘門避免激活函數和梯度飽和,因此c為長時記憶單元。

好啦,我們standard版本的新型網路也完成了!有沒有覺得資訊量超級大,又亂掉了呢?不要急,貼心的小夕就再帶你總結一下我們這個網路的前饋過程:

新時刻t剛剛到來的時候,

1、首先長時記憶單元c(t-1)通過遺忘門g_forget去遺忘一些資訊。

2、其中g_forget受當前時刻的外部輸入x(t)、上一時刻的輸出(短時記憶)h(t-1)、上一時刻的長時記憶c(t-1)的控制。

3、然後由當前時刻外部輸入x(t)和上一時刻的短時記憶h(t-1)計算出當前時刻的新資訊

4、然後由輸入門g_in控制,將當前時刻的部分新資訊

寫入長時記憶單元,產生新的長時記憶c(t)。

5、其中g_in受x(t)、h(t-1)、c(t-1)的控制。

6、激活長時記憶單元c(t),準備上天(輸出)。

7、然後由輸出門g_out把控,將至目前積累下來的記憶c(t)選出部分相關的記憶生成這一時刻我們關注的記憶h(t),再把這部分記憶進行輸出y(t)。

8、其中輸出門g_out受x(t)、h(t-1)和當前時刻的長時記憶c(t)的控制。

前饋的過程寫完了,梯度反傳的過程就讓深度學習平台去自動求導來完成吧~有M傾向的同學可以嘗試對上述過程進行手動求導。

好啦,最後對全文的設計過程總結一下:

1、我們為了解決RNN中的梯度消失的問題,為了讓梯度無損傳播,想到了c(t)=c(t-1)這個樸素卻沒毛病的梯度傳播模型,我們於是稱c為「長時記憶單元」。

2、然後為了把新資訊平穩安全可靠的裝入長時記憶單元,我們引入了「輸入門」。

3、然後為了解決新資訊裝載次數過多帶來的激活函數飽和的問題,引入了「遺忘門」。

4、然後為了讓網路能夠選擇合適的記憶進行輸出,我們引入了「輸出門」。

5、然後為了解決記憶被輸出門截斷後使得各個門單元受控性降低的問題,我們引入了「peephole」連接。

6、然後為了將神經網路的簡單回饋結構升級成模糊歷史記憶的結構,引入了隱單元h,並且發現h中存儲的模糊歷史記憶是短時的,於是記h為短時記憶單元。

7、於是該網路既具備長時記憶,又具備短時記憶,就乾脆起名叫「長短時記憶神經網路(Long Short Term Memory Neural Networks,簡稱LSTM)「啦。

(呼~歷時三天終於完稿了。將12000字的手稿強行壓縮到了5600字,將初稿里5、6個啰哩啰嗦的故事全都刪了。天,我只想說,我再也不抱著講透徹的想法給別人講解LSTM了!!!

參考文獻:

1. Hochreiter S, Schmidhuber J. Long Short-Term Memory[J]. Neural Computation, 1997, 9(8): 1735-1780.

2. Gers F A, Schmidhuber J, Cummins F, et al.Learning to Forget: Continual Prediction with LSTM[J]. Neural Computation,2000, 12(10): 2451-2471.

3. Gers F A, Schraudolph N N, Schmidhuber J, etal. Learning precise timing with lstm recurrent networks[J]. Journal of MachineLearning Research, 2003, 3(1): 115-143.

4. A guide to recurrent neural networks and backpropagation. Mikael Bod ́en.

5. http://colah.github.io/posts/2015-08-Understanding-LSTMs/

6. 《Supervised Sequence Labelling with Recurrent Neural Networks》Alex Graves

7. 《Hands on machine learning with sklearn and tf》Aurelien Geron

8. 《Deep learning》Goodfellow et.

The End