一分鐘帶你認識深度學習中的知識蒸餾

摘要:知識蒸餾(knowledge distillation)是模型壓縮的一種常用的方法

一、知識蒸餾入門

1.1 概念介紹

知識蒸餾(knowledge distillation)是模型壓縮的一種常用的方法,不同於模型壓縮中的剪枝和量化,知識蒸餾是通過構建一個輕量化的小模型,利用性能更好的大模型的監督信息,來訓練這個小模型,以期達到更好的性能和精度。最早是由Hinton在2015年首次提出並應用在分類任務上面,這個大模型我們稱之為teacher(教師模型),小模型我們稱之為Student(學生模型)。來自Teacher模型輸出的監督信息稱之為knowledge(知識),而student學習遷移來自teacher的監督信息的過程稱之為Distillation(蒸餾)。

1.2 知識蒸餾的種類

圖1 知識蒸餾的種類

1、 離線蒸餾

離線蒸餾方式即為傳統的知識蒸餾,如上圖(a)。用戶需要在已知數據集上面提前訓練好一個teacher模型,然後在對student模型進行訓練的時候,利用所獲取的teacher模型進行監督訓練來達到蒸餾的目的,而且這個teacher的訓練精度要比student模型精度要高,差值越大,蒸餾效果也就越明顯。一般來講,teacher的模型參數在蒸餾訓練的過程中保持不變,達到訓練student模型的目的。蒸餾的損失函數distillation loss計算teacher和student之前輸出預測值的差別,和student的loss加在一起作為整個訓練loss,來進行梯度更新,最終得到一個更高性能和精度的student模型。

2、 半監督蒸餾

半監督方式的蒸餾利用了teacher模型的預測信息作為標籤,來對student網絡進行監督學習,如上圖(b)。那麼不同於傳統離線蒸餾的方式,在對student模型訓練之前,先輸入部分的未標記的數據,利用teacher網絡輸出標籤作為監督信息再輸入到student網絡中,來完成蒸餾過程,這樣就可以使用更少標註量的數據集,達到提升模型精度的目的。

3、 自監督蒸餾

自監督蒸餾相比於傳統的離線蒸餾的方式是不需要提前訓練一個teacher網絡模型,而是student網絡本身的訓練完成一個蒸餾過程,如上圖(c)。具體實現方式 有多種,例如先開始訓練student模型,在整個訓練過程的最後幾個epoch的時候,利用前面訓練的student作為監督模型,在剩下的epoch中,對模型進行蒸餾。這樣做的好處是不需要提前訓練好teacher模型,就可以變訓練邊蒸餾,節省整個蒸餾過程的訓練時間。

1.3 知識蒸餾的功能

1、提升模型精度

用戶如果對目前的網絡模型A的精度不是很滿意,那麼可以先訓練一個更高精度的teacher模型B(通常參數量更多,時延更大),然後用這個訓練好的teacher模型B對student模型A進行知識蒸餾,得到一個更高精度的模型。

2、降低模型時延,壓縮網絡參數

用戶如果對目前的網絡模型A的時延不滿意,可以先找到一個時延更低,參數量更小的模型B,通常來講,這種模型精度也會比較低,然後通過訓練一個更高精度的teacher模型C來對這個參數量小的模型B進行知識蒸餾,使得該模型B的精度接近最原始的模型A,從而達到降低時延的目的。

3、圖片標籤之間的域遷移

用戶使用狗和貓的數據集訓練了一個teacher模型A,使用香蕉和蘋果訓練了一個teacher模型B,那麼就可以用這兩個模型同時蒸餾出一個可以識別狗,貓,香蕉以及蘋果的模型,將兩個不同與的數據集進行集成和遷移。

圖2 圖像域遷移訓練

4、降低標註量

該功能可以通過半監督的蒸餾方式來實現,用戶利用訓練好的teacher網絡模型來對未標註的數據集進行蒸餾,達到降低標註量的目的。

1.4 知識蒸餾的原理

圖3 知識蒸餾原理介紹

一般使用蒸餾的時候,往往會找一個參數量更小的student網絡,那麼相比於teacher來說,這個輕量級的網絡不能很好的學習到數據集之前隱藏的潛在關係,如上圖所示,相比於one hot的輸出,teacher網絡是將輸出的logits進行了softmax,更加平滑的處理了標籤,即將數字1輸出成了0.6(對1的預測)和0.4(對0的預測)然後輸入到student網絡中,相比於1來說,這種softmax含有更多的信息。好模型的目標不是擬合訓練數據,而是學習如何泛化到新的數據。所以蒸餾的目標是讓student學習到teacher的泛化能力,理論上得到的結果會比單純擬合訓練數據的student要好。另外,對於分類任務,如果soft targets的熵比hard targets高,那顯然student會學習到更多的信息。最終student模型學習的是teacher模型的泛化能力,而不是「過擬合訓練數據」

二、動手實踐知識蒸餾

ModelArts模型市場中的efficientDet目標檢測算法目前已經支持知識蒸餾,用戶可以通過下面的一個案例,來入門和熟悉知識蒸餾在檢測網絡中的使用流程。

2.1 準備數據集

數據集使用kaggle公開的Images of Canine Coccidiosis Parasite的識別任務,下載地址:。用戶下載數據集之後,發佈到ModelArts的數據集管理中,同時進行數據集切分,默認按照8:2的比例切分成train和eval兩種。

2.2 訂閱市場算法efficientDet

進到模型市場算法界面,找到efficientDet算法,點擊「訂閱」按鈕

圖4 市場訂閱efficientDet算法

然後到算法管理界面,找到已經訂閱的efficientDet,點擊同步,就可以進行算法訓練

圖5 算法管理同步訂閱算法

2.3 訓練student網絡模型

起一個efficientDet的訓練作業,model_name=efficientdet-d0,數據集選用2.1發佈的已經切分好的數據集,選擇好輸出路徑,點擊創建,具體創建參數如下:

圖6 創建student網絡的訓練作業

得到訓練的模型精度信息在評估結果界面,如下:

圖7 student模型訓練結果

可以看到student的模型精度在0.8473。

2.4 訓練teacher網絡模型

下一步就是訓練一個teacher模型,按照efficientDet文檔的描述,這裡選擇efficientdet-d3,同時需要添加一個參數,表明該訓練作業生成的模型是用來作為知識蒸餾的teacher模型,新起一個訓練作業,具體參數如下:

圖8 teacher模型訓練作業參數

得到的模型精度在評估結果一欄,具體如下:

圖9 teacher模型訓練結果

可以看到teacher的模型精度在0.875。

2.5 使用知識蒸餾提升student模型精度

有了teacher網絡,下一步就是進行知識蒸餾了,按照官方文檔,需要填寫teacher model url,具體填寫的內容就是2.4訓練輸出路徑下面的model目錄,注意需要選到model目錄的那一層級,同時需要添加參數use_offline_kd=True,具體模型參數如下所示:

圖10 採用知識蒸餾的student模型訓練作業參數

得到模型精度在評估結果一欄,具體如下:

圖11 使用知識蒸餾之後的student模型訓練結果

可以看到經過知識蒸餾之後的student的模型精度提升到了0.863,精度相比於之前的student網絡提升了1.6%百分點。

2.6 在線推理部署

訓練之後的模型就可以進行模型部署了,具體點擊「創建模型」

 

圖12 創建模型

界面會自動讀取模型訓練的保存路徑,點擊創建:

圖13 導入模型

模型部署成功之後,點擊創建在線服務:

圖14 部署在線服務

部署成功就可以進行在線預測了:

圖15 模型推理結果展示

三、知識蒸餾目前的應用領域

目前知識蒸餾的算法已經廣泛應用到圖像語義識別,目標檢測等場景中,並且針對不同的研究場景,蒸餾方法都做了部分的定製化修改,同時,在行人檢測,人臉識別,姿態檢測,圖像域遷移,視頻檢測等方面,知識蒸餾也是作為一種提升模型性能和精度的重要方法,隨着深度學習的發展,這種技術也會更加的成熟和穩定。

參考文獻:

[1]Data Distillation: Towards Omni-Supervised Learning

[2]On the Efficacy of Knowledge Distillation

[3]Knowledge Distillation and Student-Teacher Learning for Visual Intelligence: A Review and New Outlooks

[4]Towards Understanding Knowledge Distillation

[5]Model Compression via Distillation and Quantization

 

點擊關注,第一時間了解華為雲新鮮技術~