­

kaldi中CD-DNN-HMM網路參數更新公式手寫推導

  • 2019 年 11 月 5 日
  • 筆記

在基於DNN-HMM的語音識別中,DNN的作用跟GMM是一樣的,即它是取代GMM的,具體作用是算特徵值對每個三音素狀態的概率,算出來哪個最大這個特徵值就對應哪個狀態。只不過以前是用GMM算的,現在用DNN算了。這是典型的多分類問題,所以輸出層用的激活函數是softmax,損失函數用的是cross entropy(交叉熵)。不用均方差做損失函數的原因是在分類問題上它是非凸函數,不能保證全局最優解(只有凸函數才能保證全局最優解)。Kaldi中也支援DNN-HMM,它還依賴於上下文(context dependent, CD),所以叫CD-DNN-HMM。在kaldi的nnet1中,特徵提取用filterbank,每幀40維數據,默認取當前幀前後5幀加上當前幀共11幀作為輸入,所以輸入層維數是440(440 = 40*11)。同時默認有4個隱藏層,每層1024個網元,激活函數是sigmoid。今天我們看看網路的各種參數是怎麼得到的(手寫推導)。由於真正的網路比較複雜,為了推導方便這裡對其進行了簡化,只有一個隱藏層,每層的網元均為3,同時只有weight沒有bias。這樣網路如下圖:

上圖中輸入層3個網元為i1/i2/i3(i表示input),隱藏層3個網元為h1/h2/h3(h表示hidden),輸出層3個網元為o1/o2/o3(o表示output)。隱藏層h1的輸入為 (q11等表示輸入層和隱藏層之間的權值),輸出為。輸出層o1的輸入為(w11等表示隱藏層和輸出層之間的權值),輸出為。其他可類似推出。損失函數用交叉熵。今天我們看看網路參數(以隱藏層和輸出層之間的w11以及輸入層和隱藏層之間的q11為例)在每次迭代訓練後是怎麼更新的。先看隱藏層和輸出層之間的w11。

 

1,隱藏層和輸出層之間的w11的更新

 

 先分別求三個導數的值:

 

 所以最終的w11更新公式如下圖:

 

2,輸入層和隱藏層之間的q11的更新

 

先分別求三個導數的值:

 

所以最終的q11更新公式如下圖:

 

以上的公式推導中如有錯誤,煩請指出,非常感謝!