詳解十大經典機器學習算法——EM算法

本文始發於個人公眾號:TechFlow,原創不易,求個關注

今天是機器學習專題的第14篇文章,我們來聊聊大名鼎鼎的EM算法。

EM算法的英文全稱是Expectation-maximization algorithm,即最大期望算法,或者是期望最大化算法。EM算法號稱是十大機器學習算法之一,聽這個名頭就知道它非同凡響。我看過許多博客和資料,但是少有資料能夠將這個算法的來龍去脈以及推導的細節全部都講清楚,所以我今天博覽各家所長,試着儘可能地將它講得清楚明白。

從本質上來說EM算法是最大似然估計方法的進階版,還記得最大似然估計嗎,我們之前介紹貝葉斯模型的文章當中有提到過,來簡單複習一下。

最大似然估計

假設當下我們有一枚硬幣,我們想知道這枚硬幣拋出去之後正面朝上的概率是多少,於是我們拋了10次硬幣做了一個實驗。發現其中正面朝上的次數是5次,反面朝上的次數也是5次。所以我們認為硬幣每次正面朝上的概率是50%。

從表面上來看,這個結論非常正常,理所應當。但我們仔細分析會發現這是有問題的,問題在於我們做出來的實驗結果和實驗參數之間不是強耦合的。也就是說如果硬幣被人做過手腳,它正面朝上的概率是60%,我們拋擲10次,也有可能得到5次正面5次反面的概率。同理,如果正面朝上的概率是70%,我們也有一定的概率可以得到5次正面5次反面的結果。現在我們得到了這樣的結果,怎麼能說明就一定是50%朝上的概率導致的呢?

那我們應該怎麼辦呢,繼續做實驗嗎?

顯然不管我們做多少次實驗都不能從根本上解決這個問題,既然參數影響的是出現結果的概率,我們還是應該回到這個角度,從概率上下手。我們知道,拋硬幣是一個二項分佈的事件,我們假設拋擲硬幣正面朝上的概率是p,那麼反面朝上的概率就是1-p。於是我們可以帶入二項分佈的公式,算出10次拋擲之後,5次是正面結果在當前p參數下出現的概率是多少。

於是,我們可以得到這樣一條曲線:

也就是正面朝上的概率是0.5的時候,10次拋擲出現5次正面的概率最大。我們把正面朝上的概率看成是實驗當中的參數,我們把似然看成是概率。那麼最大似然估計,其實就是指的是使得當前實驗結果出現概率最大的參數。

也就是說我們通過實驗結果和概率,找出最有可能導致這個結果的原因或者說參數,這個就叫做最大似然估計。

原理理解了,解法也就順水推舟了。

首先,我們需要用函數將實驗結果出現的概率表示出來。這個函數的學名叫做似然函數(likelihood function)。

有了函數之後,我們需要對函數進行化簡,比如一些多次進行的實驗,需要對似然函數求對數,將累乘計算轉化成累加運算等。

最後,我們對化簡完的似然函數進行求導,令導數為0,找出極值點處參數的值,就是我們通過最大似然估計方法找到的最佳參數。

引入隱變量

以上只是最大似然估計的基礎用法,如果我們把問題稍微變化一下,引入多一個變量,會發生什麼情況呢?

我們來看一個經典的例子,同樣是拋硬幣,但是我們將題目的條件稍作修改,那麼整個問題就會完全不同。

這個例子來源於闡述EM算法的經典論文:《Do, C. B., & Batzoglou, S. (2008). What is the expectation maximization algorithm?. Nature biotechnology, 26(8), 897.》在這個例子當中,我們有A和B兩枚硬幣,其中A硬幣正面朝上的概率是0.5,B硬幣正面朝上的概率是0.4,我們隨機從兩枚硬幣當中選取一枚進行實驗。

每次實驗我們一共進行5次,記錄下正反面的個數。經過5輪實驗之後,我們得到的結果如下:

由於我們知道每一輪當中選擇了什麼硬幣進行實驗,所以整個過程依然非常順利。如果我們去掉硬幣的信息,假設我們並不知道每一輪當中選擇了什麼硬幣進行實驗,我們又該怎麼求A和B向上的概率呢?

在新的實驗當中,我們不知道硬幣選擇的情況,也就是說實驗當中隱藏了一個我們無法得知的變量。這種變量稱為隱變量,隱變量的存在干擾了參數和實驗結果的直接聯繫。比如在這個問題當中,我們想要知道每種硬幣正面向上的概率,我們要計算這個概率首先要知道每一輪用了哪一種硬幣。如果我們想要推算每一次實驗用了哪一種硬幣又需要先知道硬幣正面朝上的概率。也就是說這兩個變量互相糾纏、互相依賴,我們已知的信息太少,無法直接解開。就好像先有雞還是先有蛋的問題,陷入死循環。

EM算法正是為了解決這個問題誕生的。

EM算法

前面我們說了,隱變量和我們想要求的參數互相糾纏,形成了一個死循環,但是我們已有的信息不足以讓我們解開這個糾纏。既然無法解開,那麼我們就不解了,我們直接暴力破解

是的,你沒有看錯,EM算法的本質非常簡單粗暴:既然我們無法求解隱變量,我們就不求了,我們直接假設一個初始值代入計算,有了結果之後再進行迭代。

比如我們假設p1是硬幣A正面向上的概率,p2是硬幣B正面向上的概率。原本我們是希望通過最大似然估計來求解使得結果出現的p1和p2,現在我們直接假設,進行迭代:

我們假設p1=0.7,p2=0.3,這個值是我們隨便假設的,你可以任意假設其他的值。我們把p1,p2代入上面的結果當中進行計算。

比如第一輪當中,出現的結果是3正2反,如果是A硬幣,出現這樣結果的概率根據二項分佈很容易計算:(0.7^3 * 0.3^2 = 0.03087),同理,我們可以算出硬幣B的概率是0.01323。我們用同樣的方法算出所有的概率:

既然我們概率有了,顯然我們可以做預測了,根據這個概率表猜測每一輪究竟用了哪一個硬幣。

根據最大似然的法則,我們可以得出每一輪用的硬幣是:

第一輪是硬幣A

第二輪是硬幣B

第三輪是硬幣B

第四輪是硬幣A

第五輪是硬幣B

猜測出硬幣的分佈之後有什麼用呢?很簡單,我們可以用猜測的結果重新估計p1和p2的值

比如說硬幣A出現在第一輪和第四輪當中,這兩輪一共做了10次實驗,其中6正4反,那麼我們可以修正p1的值為0.6。硬幣B出現在第2,3,5輪當中,這三輪當中做了15次實驗,一共5正10反,所以正面向上的概率是1/3。可以發現,經過了一次迭代之後,我們的結果向真實值逼近了一些

雖然結果還可以,但這種方法依然比較粗糙,我們還有更好的辦法。

例子改進

我們來改進一下上面這個例子的計算過程,主要的問題在於我們在根據假設出來的概率計算分佈之後,我們直接通過似然估計去猜測當前輪次拋了哪一枚硬幣。這樣做當然是可以的,但感覺不夠嚴謹,因為我們直接猜測有些武斷,並不一定準確。

那有沒有更好的辦法?

其實是有的,相比於直接猜測某個輪次當中選擇了哪一枚硬幣,我們可以用選擇硬幣的概率來代入來計算期望,這樣的效果會更好,比如根據剛才的計算結果,我們可以算出每個輪次當中選擇硬幣的概率:

我們在用這個概率帶入實驗結果當中計算期望,可以得到p1的期望表格:

[p_1 = frac{2.1+0.6+0.0729+2.1+0.6}{2.7+2.7+0.292+2.7+2.7}=0.490 ]

同樣的方法,我們可以算計出新的p2的期望表格:

代入,我們可以得到新的p2是0.377

把估計結果改成使用概率代入迭代之後,我們的估計的結果精準了許多,也就是說我們收斂的速度更快了。我們重複以上的過程,直到收斂,當收斂的時候,我們就能獲得極大似然估計最大時候p1和p2的取值。這也是整個EM算法的精髓。

我們整理一下EM算法的運作過程,首先我們先隨機出來一個參數的值代入實驗結果,計算出隱變量的概率分佈或者是取值,我們再通過隱變量迭代我們的參數值,如此重複迭代,直到收斂。我們進一步抽象,可以把它主要總結成兩個步驟,分別是E步驟和M步驟

在E步驟當中,我們根據假設出來的參數值計算出未知變量的期望估計,應用在隱變量上
在M步驟當中,我們根據隱變量的估計值,再計算當前參數的極大似然估計

根據這個理論,我們還可以對上面的過程進行改進。

這個方法到這裡就介紹完了,我想大家也應該都能理解,但是我們還沒有從數學上去證明,為什麼這樣操作行得通呢?為什麼這個方法一定會收斂,我們收斂的值就是最優解呢?所以我們還需要通過數學來證明一下。

數學證明

假設我們有一個樣本集X它是由m個樣本構成的,可以寫成(X={x_1, x_2, cdots x_m}),對於這m個樣本當中,它們都有一個隱變量z是未知的。並且還有一個參數(theta),也就是我們希望通過極大似然估計求解的參數。由於當中包含隱變量z,所以我們沒辦法直接對概率函數求導求極值進行計算。

我們先寫出含有隱變量的概率函數:

[P_i = P(x_i, z_i; theta) ]

我們希望找到對於全局最優的參數(theta),所以我們希望找到使得(prod_{i=1}^mP_i)最大,我們對這個式子求log,可以得到:

[sum_{i=1}^mlog P_i= sum_{i=1}^m log sum_{z_i}P(x_i, z_i; theta) ]

我們假設隱變量z的概率分佈是(Q_i),所以上式可以變形為:

[sum_{i=1}^mlog P_i= sum_{i=1}^m log sum_{z_i}Q_i(z_i)frac{P(x_i, z_i; theta)}{Q_i(z_i)} ]

到這裡似乎卡住了,其實沒有,我們在之前的文章當中寫過,對於凸函數有Jensen不等式:E[f(x)] >= f(E[x]),即函數的期望值大於等於期望值的函數值。而對數函數是廣義上的凸函數,嚴格意義上的凹函數,它可以使用Jensen不等式,但是不等號的方向需要變號。

而上式當中(Q_i(z_i))是隱變量的概率分佈,所以(sum_{z_i}Q_i(z_i)[frac{P(x_i, z_i; theta)}{Q_i(z_i)}])(frac{P(x_i, z_i; theta)}{Q_i(z_i)})的期望,於是我們可以代入Jensen不等式得到:

[sum_{i=1}^mlog P_i geq sum_{i=1}^m sum_{z_i}Q_i(z_i)log frac{P(x_i, z_i; theta)}{Q_i(z_i)} ]

上面這個不等號右邊的式子就容易求解多了,當我們固定z變量的時候,我們可以很方便地求解似然最大時的參數(theta)。同理當我們有了(theta)的取值之後,又可以來優化z。這種兩個變量固定一個,輪流優化另一個的方法叫做坐標上升法,也是機器學習當中非常常用的求解方式。

如上圖所示,這個一圈一圈的是損失函數的等高線。當我們使用坐標上升法的時候,我們每次固定一個軸的變量,優化另一個變量,然後交替進行,我們同樣可以得到全局最優解。

除此之外,我們也可以從數學上進行解釋。

由於上面的式子是一個不等式,我們沒有辦法直接求解左邊的最值,所以我們通過不斷優化右邊式子的方法來逼近左邊的最值。我們令左邊的一串式子是(L(theta)),不等號右邊的式子是(J(z, theta)),然後我們來看一張圖,這張圖是我從大神的博客里找來的神圖:

上圖當中最上方的紅色是(L(theta)),下面的圖像是J。我們每次固定z,都可以找到一個更好的(theta),使得(J(z, theta))朝着高點不斷逼近,最終達到它的最大值。

直覺上這是OK的,但是我們還需要從數學上來證明。

根據Jensen不等式,只有當自變量x是常數的時候才可以取等,我們的自變量是(frac{P(x_i, z_i; theta)}{Q_i(z_i)}),我們令它等於常數c:

[frac{P(x_i, z_i, theta)}{Q_i(z_i)} = c ]

由於(sum_{z_i}Q_i(z_i)=1),所以我們可以知道(sum_{z_i}P(x_i,z_i, theta)=c),我們代入上式,可以得到:

[begin{aligned} Q_i(z_i)cdot c &= P(x_i, z_i, theta) \ Q_i(z_i) &=frac{P(x_i. z_i; theta)}{c}\ Q_i(z_i) &= frac{P(x_i. z_i; theta)}{sum_{z_i}P(x_i, z_i; theta)}\ Q_i(z_i) &= frac{P(x_i. z_i; theta)}{P(x_i; theta)}\ Q_i(z_i) &= P(z_i|x_i; theta)\ end{aligned} ]

經過這一串變形之後,我們得到了(Q_i(z_i))的計算公式其實是一個後驗概率。這一步也就是我們剛才介紹的E步,之後,在確定了(Q_i(z_i))之後,我們來求導求極值的方法求使得函數最大時的(theta),也就是剛才的M步。

所以,整個EM算法的過程就是重複這個過程,直到收斂。

那麼我們又該怎麼保證算法能夠一定收斂呢?其實也不難,由於我們在進行E步驟的時候遵循了Jensen不等式的取等條件求出的z,所以可以保證能夠取到等號,也就是:

[L(theta_t) = sum_{i=1}^m sum_{z_i}Q_i(z_i)log frac{P(x_i, z_i; theta_t)}{Q_i(z_i)} ]

當我們固定(Q_i(z_i))求導得到極大化的參數(theta_{t+1})之後,我們得到右式,一定是優於(L(theta))的,但是我們不能確定對於新的(theta_{t+1}),我們之前的(Q_i(t_i))的分佈也能滿足Jensen不等式的取等條件,所以:

[begin{aligned} L(theta_{t+1}) &geq sum_{i=1}^m sum_{z_i}Q_i(z_i)log frac{P(x_i, z_i; theta_{t+1})}{Q_i(z_i)}\ &geq sum_{i=1}^m sum_{z_i}Q_i(z_i)log frac{P(x_i, z_i; theta_{t})}{Q_i(z_i)} \ &=L(theta) end{aligned} ]

這樣我們就證明了似然函數的取值是在遞增的,當最後收斂的時候,就是最大似然估計時的值,此時的參數(theta)就是我們需要的最大似然估計方法得出的參數。

總結

到這裡,EM算法就算是介紹完了。整個算法給我最大的感受是這又是一個建立在數學推導上的算法,它的推導過程非常嚴謹,效果也非常好,通過它可以解決很多直觀上無法解決的問題。並且更難得的是,即使我們拋棄掉數學上嚴謹的證明和推導,也不妨礙我們直觀地理解算法的思路。難怪該算法可以列入十大機器學習算法之一,的確非常經典。

最後,不知道大家在看的時候有沒有一種感覺,就是EM算法的思路好像之前在什麼地方見到過?有種似曾相識的感覺?

有這種感覺是對的,如果你回想一下之前講的Kmeans,你會發現我們好像也是一開始的時候由於不知道聚類的中心進行了猜測。然後通過迭代一點一點地逼近。如果再多想一點,可以發現Kmeans的計算過程是可以和EM算法的過程相印證的。通過建模我們是可以把Kmeans的問題轉化成EM算法的模型,感興趣的同學可以研究一下這個問題,當然也可以期待一下我們後續的文章。

最後,關於EM算法的內容就到這裡,如果覺得有所收穫,請順手點個關注或者轉發吧,你們的舉手之勞對我來說很重要。