PyTorch 中 weight decay 的設置

  • 2020 年 10 月 21 日
  • 筆記

先介紹一下 Caffe 和 TensorFlow 中 weight decay 的設置:

  • Caffe 中, SolverParameter.weight_decay 可以作用於所有的可訓練參數, 不妨稱為 global weight decay, 另外還可以為各層中的每個可訓練參數設置獨立的 decay_mult, global weight decay 和當前可訓練參數的 decay_mult 共同決定了當前可訓練參數的 weight decay.
  • TensorFlow 中, 某些接口可以為其下創建的可訓練參數設置獨立的 weight decay (如 slim.conv2d 可通過 weights_regularizer, bias_regularizer 分別為其下定義的 weight 和 bias 設置不同的 regularizer).

在 PyTorch 中, 模塊 (nn.Module) 和參數 (nn.Parameter) 的定義沒有暴露與 weight decay 設置相關的 argument, 它把 weight decay 的設置放到了 torch.optim.Optimizer (嚴格地說, 是 torch.optim.Optimizer 的子類, 下同) 中.

torch.optim.Optimizer 中直接設置 weight_decay, 其將作用於該 optimizer 負責優化的所有可訓練參數 (和 Caffe 中 SolverParameter.weight_decay 的作用類似), 這往往不是所期望的: BatchNorm 層的 \(\gamma\)\(\beta\) 就不應該添加正則化項, 卷積層和全連接層的 bias 也往往不用加正則化項. 幸運地是, torch.optim.Optimizer 支持為不同的可訓練參數設置不同的 weight_decay (params 支持 dict 類型), 於是問題轉化為如何將不期望添加正則化項的可訓練參數 (如 BN 層的可訓練參數及卷積層和全連接層的 bias) 從可訓練參數列表中分離出來. 筆者借鑒網上的一些方法, 寫了一個滿足該功能的函數, 沒有經過嚴格測試, 僅供參考.

"""
作者: 採石工
博客: //www.cnblogs.com/quarryman/
發佈時間: 2020年10月21日
版權聲明: 自由分享, 保持署名-非商業用途-非衍生, 知識共享3.0協議.
"""
import torch
from torchvision import models


def split_parameters(module):
    params_decay = []
    params_no_decay = []
    for m in module.modules():
        if isinstance(m, torch.nn.Linear):
            params_decay.append(m.weight)
            if m.bias is not None:
                params_no_decay.append(m.bias)
        elif isinstance(m, torch.nn.modules.conv._ConvNd):
            params_decay.append(m.weight)
            if m.bias is not None:
                params_no_decay.append(m.bias)
        elif isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
            params_no_decay.extend([*m.parameters()])
        elif len(list(m.children())) == 0:
            params_decay.extend([*m.parameters()])
    assert len(list(module.parameters())) == len(params_decay) + len(params_no_decay)
    return params_decay, params_no_decay


def print_parameters_info(parameters):
    for k, param in enumerate(parameters):
        print('[{}/{}] {}'.format(k+1, len(parameters), param.shape))
        
        
if __name__ == '__main__':
    model = models.resnet18(pretrained=False)
    params_decay, params_no_decay = split_parameters(model)
    print_parameters_info(params_decay)
    print_parameters_info(params_no_decay)

參考

版權聲明

版權聲明:自由分享,保持署名-非商業用途-非衍生,知識共享3.0協議。
如果你對本文有疑問或建議,歡迎留言!轉載請保留版權聲明!
如果你覺得本文不錯, 也可以用微信讚賞一下哈.