SMASH:經典One-Shot神經網絡搜索,僅需單卡 | ICLR 2018
SMASH 方法使用輔助網絡生成次優權重來支持網絡的快速測試,從結果來看,生成的權重與正常訓練的權重在準確率上存在關聯性,整體搜索速度很快,僅需要單卡進行搜索,提供了一個很好的新思路。
來源:曉飛的算法工程筆記 公眾號
論文: SMASH: One-Shot Model Architecture Search through HyperNetworks
Introduction
常規的網絡設計需要耗費大量的時間進行驗證,為了節約驗證時間,論文提出訓練一個輔助網絡 HyperNet,用於動態生成不同結構模型的權重。儘管使用這些生成權重的性能不如常規學習到的權重,但在訓練初期,使用生成權重的不同網絡的相對性能可以在一定程度上映射出其最優狀態時的相對性能。因此,論文提出 one-shot 模型結構搜索 SMASH(one-Shot Model Architecture Search through Hypernetworks),結合輔助網絡生成的權重,可以僅使用一輪訓練來對大量的結構進行排序。
One-Shot Model Architecture Search through HyperNetworks
SMASH 的邏輯如算法 1,核心是通過輔助網絡 HyperNet 根據不同的網絡結構生成對應的權重,然後根據驗證集表現進行排序:
- 首先訓練輔助網絡 HyperNet,在每個訓練階段,隨機採樣一個網絡,然後使用 HyperNet 生成權重,end-to-end 地對其進行完整的反向訓練。
- 在訓練好 HyperNet 後,隨機採樣大量的網絡結構,使用 HyperNet 生成的權重,然後在測試集上驗證性能。
- 選擇性能最好的結構進行最終的訓練測試。
SMASH 包含兩個核心部分:
- 如何生成網絡結構。論文採用基於存儲體(memory bank)的前向網絡,能夠生成複雜且多分支的拓撲結構,並且能夠使用二進制向量進行編碼。
- 如何根據網絡結構生成權重。訓練一個輔助網絡 HyperNet,直接學習二進制結構編碼到權重空間的映射。
論文認為,只要 HyperNet 學習到如何生成有效的權重,那麼在驗證集上,使用生成權重的網絡的準確率會和正常訓練的網絡的準確率產生關聯,此時,網絡的結構將會變成影響驗證集準確率的主要因子。
Defining Variable Network Configurations
為了能夠生成多種的網絡結構並且方便編碼輸入 HyperNet,論文採用存儲體(memory-bank)的方式進行網絡表示,將網絡視為一系列初始為 0 的存儲體,每層的操作視為對存儲體的讀寫。對於單分支網絡,網絡包含一個大的存儲體,每次操作都覆蓋存儲體的內容(對 ResNet 是相加),對於 DenseNet 的多分枝網絡,則讀取所有前面的存儲體,然後將結果寫入空的存儲體,而對於 FractalNet,則構造更為複雜。
SMASH 的基礎模型包含多個 block,如圖 2(b),每個 block 包含多個特定分辨率的存儲體,前後 block 間的存儲體分辨率為 1/2 倍,通過1\times 1卷積加平均池化進行下採樣,1\times 1卷積和全連接輸出層的權重是學習來的,不是生成的。
在採樣網絡時,每個 block 中的存儲體個數以及每個存儲體的 channel 數都是隨機的,而 block 中的層則隨機選擇讀寫模式以及處理數據的 op 操作。當讀入多個存儲體時,在 channel 維度對存儲體的 tensor 進行 concat,而寫入時則將結果與每個存儲體中的 tensor 相加。在實驗中,層僅允許讀取所屬的 block 的存儲體。
op 操作包含用於降維1\times 1卷積、多個常規卷積和非線性激活,如圖 2(a),每次隨機選擇 4 個卷積中一個激活,包括其卷積核大小,輸出 channel 等超參也是隨機的,1\times 1卷積的輸出 channel 數與 op 的輸出 channel 數成一定比例,比例也是隨機選取的,特別說明:
- 1\times 1卷積的權重由 HyperNet 生成,其它卷積則通過正常訓練獲得(算法 1 的 first loop)。
- 為了保證可變的深度,每個 block 僅學習 4 個卷積,並且在 block 的 op 操作中共享其權值。限制最大卷積核大小以及最大輸出 channel 數,假設選擇的 op 操作的參數小於最大值,則將權重裁剪至目標大小。
- 下採樣卷積和輸出層同樣基於輸入的 channel 數對權重進行裁剪。
在設計時,為了讓網絡更多地採用 HyperNet 產生的權重,僅在下採樣層中以及輸出層之前使用 BatchNorm,主要由於很難通過生成的方式產生這種運行時統計的結果。為了彌補這一舉措,使用 WeightNorm 的改進版,將生成的1\times 1卷積核除以其歐幾里得範數進行正則化(不是單獨正則化各 channel),這對 SMASH 十分有效,僅帶來些許的性能下降。
Learning to map architectures to weights
Hypernet 採用全卷積網絡,這樣輸出的W可以根據輸入c的大小改變而改變,輸入c為 4 維 tensor(BCHW),batch size 為 1,這樣輸出就不會存在完全獨立性。輸出W的每個 channel 都對應c的一個子集,而權重W對應 op 操作的信息都 embedding 在c的 channel 中。
假設 op 讀取 1,2,4 存儲體然後寫入 2,4 存儲體,則輸入c的 1、2 和 4 channel 會填入 1,代表輸入的存儲體,而 6、8 channel 也會填入 1,代表輸出的存儲體,剩餘的 channel 用於描述 op 的其它超參數,比如膨脹值(dilation),輸入c的 width 方向是對 op 操作的輸出 channel 數的編碼。
基於以上的 Hypernet 結構,na ï ve 的實現要求輸入c的大小和W的大小一致或者使用上採樣來產生更多的輸出,但這樣效果不好。論文使用 channel-based 的權重壓縮方法,不僅能夠減小c的大小,還能保持 HyperNet 的表達能力。簡單講就是將輸入c的分辨率設定為W的大小進行1/k,HyperNet 的輸出 channel 設定為k,最後將結果 reshape 成W的大小,具體可以看看論文的附錄 B。
Experiments
Testing the SMASH correlation
對比 SMASH 生成權重的網絡與正常訓練的網絡的準確率,證明 SMASH 生成的權重可以快速地比較相對準確率。
Benchmarking
CONCLUSION
SMASH 方法使用輔助網絡生成次優權重來支持網絡的快速測試,從結果來看,生成的權重與正常訓練的權重在準確率上存在關聯性,整體搜索速度很快,僅需要單卡進行搜索,提供了一個很好的新思路。
如果本文對你有幫助,麻煩點個贊或在看唄 ~
更多內容請關注 微信公眾號【曉飛的算法工程筆記】