Keras遷移學習
- 2019 年 12 月 18 日
- 筆記
遷移學習
簡單來說遷移學習是把在ImageNet等大型數據集上訓練好的CNN模型拿過來,經過簡單的調整應用到自己的項目上去。

遷移學習的分類
遷移學習分為三種:
- 第一種叫transfer learning。用於影像分類的卷積神經網路由兩部分組成:從一系列卷積層和池化層開始,並以全連接的分類器結束。第一部分稱為模型的卷積基(convolutional base),即全連接層之前的卷積池化部分,特徵提取就是利用預訓練好的的網路模型的卷積基,運行新的數據,並在輸出之上訓練一個新的分類器(見圖1.1)。因此,我們只需要訓練分類器部分,卷積基直接用現成的不動。

為什麼只重用卷積基?能使用相同的分類器嗎?一般來說前面的卷積基提取了低級特徵,這在很多其他類似問題是可以通用的。而最後的全連接層是與具體問題相關的高級特徵,因此不太可復用。
- 第二種是fine tune,即微調,就是讓一部分底層也參與訓練。一般來說,只有在頂層的分類器已經被訓練好之後,才去微調卷積基的頂層。
- 預訓練模型。例如,Caffe庫有一個model zoo,其他人可以在這裡找到各種訓練好的模型的checkpoint。 一個典型的遷移學習過程是這樣的。首先通過transfer learning對新的數據集進行訓練,訓練過一定epoch之後,改用fine tune方法繼續訓練,同時降低學習率。這樣做是因為如果一開始就採用fine tune方法的話,網路還沒有適應新的數據,那麼在進行參數更新的時候,比較大的梯度可能會導致原本訓練的比較好的參數被污染,反而導致效果下降。 微調經驗總結 主要的因素是數據集的大小和原始數據集的相似性。有一點一定記住:網路前幾層學到的是通用特徵,後面幾層學到的是與類別相關的特徵。下面分情況討論:
- 新數據集很小,與原始數據集類似。 因為新數據集比較小,如果fine-tune可能會過擬合;又因為新舊數據集類似,所以可能高層特徵類似,可以使用預訓練網路當做特徵提取器,用提取的特徵訓練線性分類器。
- 新數據集很大,與原始數據集類似。 可以微調,不用擔心過擬合。
- 新數據集很小但與原始數據集非常不同。 新數據集小,最好不要fine-tune,和原數據集不類似,最好也不使用高層特徵。這時可是使用前面層的特徵來訓練SVM分類器。
- 新數據集很大,與原始數據集非常不同。 因為新數據集足夠大,可以重新訓練。但是實踐中fine-tune預訓練模型還是有益的。新數據集足夠大,可以fine-tine整個網路。

程式碼步驟 載入數據 這一步很正常,主要是處理圖片數據和劃分數據集載入MobileNetV2模型(不含全連接層) Keras的應用模組Application提供了帶有預訓練權重的Keras模型,這些模型可以用來進行預測、特徵提取和finetune。你可以從keras.applications模組中導入它。base_model = MobileNetV2(weights='imagenet', include_top=False)
添加新的頂層
def add_new_last_layer(base_model, nb_classes): x = base_model.output x = GlobalAveragePooling2D()(x) # GlobalAveragePooling2D 將 MxNxC 的張量轉換成 1xC 張量,C是通道數 x = Dense(FC_SIZE, activation='relu')(x) predictions = Dense(nb_classes, activation='softmax')(x) model = Model(input=base_model.input, output=predictions) return model
訓練頂層分類器
凍結base_model所有層,然後進行訓練。
def setup_to_transfer_learn(model, base_model): for layer in base_model.layers: layer.trainable = False model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy']) setup_to_transfer_learn(model, base_model)
其實這一步也可以和上一步結合起來寫,更加簡潔:
from keras import models from keras import layers # 在conv_base的基礎上添加全連接分類網路 conv_base = MobileNetV2(weights='imagenet', include_top=False) conv_base.trainable = False model = models.Sequential() model.add(conv_base) model.add(layers.Flatten()) model.add(layers.Dense(256, activation='relu')) model.add(layers.Dense(1, activation='sigmoid'))
對頂層分類器進行fine_tuning
凍結部分層,對頂層分類器進行Fine-tune
Fine-tune以一個預訓練好的網路為基礎,在新的數據集上重新訓練一小部分權重。fine-tune應該在很低的學習率下進行。
def setup_to_finetune(model): for layer in model.layers[:NB_MobileNetV2_LAYERS_TO_FREEZE]: layer.trainable = False for layer in model.layers[NB_MobileNetV2_LAYERS_TO_FREEZE:]: layer.trainable = True model.compile(optimizer=SGD(lr=0.0001, momentum=0.9), loss='categorical_crossentropy', metrics=['accuracy'])
這裡可能比較疑惑的是NB_MobileNetV2_LAYERS_TO_FREEZE是多少呢,怎麼找呢。方法是利用Pycharm的Debug功能,查看base_model.layers中的值。 當然也可以選擇使用layer name來進行選擇:
conv_base.trainable = True set_trainable = False for layer in conv_base.layers: if layer.name == 'block5_conv1': set_trainable = True if set_trainable: layer.trainable = True else: layer.trainable = False
總體程式碼
from keras.applications import MobileNetV2 from keras import layers from keras.models import Model from keras.optimizers import SGD from keras.utils import plot_model FC_SIZE = 256 IM_WIDTH, IM_HEIGHT = 28, 28 nb_classes = 100 NB_MobileNetV2_LAYERS_TO_FREEZE = 149 def add_new_last_layer(base_model, nb_classes): x = base_model.output x = layers.GlobalAveragePooling2D()(x) x = layers.Dense(FC_SIZE, activation='relu')(x) predictions = layers.Dense(nb_classes, activation='softmax')(x) model = Model(inputs=base_model.input, outputs=predictions) return model def setup_to_transfer_learn(model, base_model): for layer in base_model.layers: layer.trainable = False model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy']) def setup_to_finetune(model): for layer in model.layers[:NB_MobileNetV2_LAYERS_TO_FREEZE]: layer.trainable = False for layer in model.layers[NB_MobileNetV2_LAYERS_TO_FREEZE:]: layer.trainable = True model.compile(optimizer=SGD(lr=0.0001, momentum=0.9), loss='categorical_crossentropy', metrics=['accuracy']) if __name__ == '__main__': base_model = MobileNetV2(weights='imagenet', include_top=False) model = add_new_last_layer(base_model, nb_classes) setup_to_transfer_learn(model, base_model) model.fit() setup_to_finetune(model) model.fit() print(model.summary()) plot_model(model, to_file='mobilev2.png', show_shapes=True)
總結
- 在小數據集上,過擬合將是主要問題。數據擴充是處理影像數據時過擬合的強大方法。
- 通過卷積基特徵提取可以利用先前學習的特徵。
- 作為特徵提取的補充,我們可以使用微調來適應新的問題。Reference