通過autograd或jax 快速實現自定義損失函數下的lightgbm

  • 2021 年 1 月 28 日
  • AI

//github.com/HIPS/autogradgithub.com

//github.com/google/jaxgithub.com

autograd目前和xla一起集成到jax里了,google/jaxautograd目前和xla一起積極維護jax去了,不過jax的應用和autograd差不多,就是api有一些不同,不過jax目前在windows上沒法使用,autograd倒是不限制系統,對自動求導感興趣的各位大佬可以直接上手jax,性能高,lightgbm驗證起來非常方便;

之前寫自定義損失函數的時候總是需要自己去推導一下損失函數的一階和二階梯度的表達式,這一塊兒後來找到了sympy,但是總覺得不太方便,後來找到了autograd,順藤摸瓜找到了jax。

下面主要是autograd應用於lightgbm的demo,jax有時間我再好好研究一下,真是個好東西~~~~~

下面我們定義一個smape自定義損失函數:

from autograd import hessian
from autograd import grad
import autograd.numpy as np

def smape_eval(y_true, y_pred):
    y_ture=np.array(y_true)
    result= np.abs(y_pred - y_true) / (np.abs(y_pred) + np.abs(y_true))
    return result

smape求導的一個問題是如果分母為0,也就是y_true=y_pred=0的時候,式子是沒有意義的,在numpy下分母為0計算的結果為nan,使用autograd求導的結果也是一樣的。

然後我們直接使用autograd的egrad功能,

grad_smape=egrad(smape_eval,1)
hessian_smape=egrad(egrad(smape_eval,1))

這裡的函數smape_eval是之前寫自定義評價指標用的,一開始實現smape一直有問題,所以用的mape,但是如果模型能夠直接優化評價指標,可坑效果要更好,比如我們auc這種評價指標難以計算其一階梯度和二階梯度,如果lgb可以直接針對其進行優化,可能效果會更好。

這裡的1用於指定我們要求導的變數,因為smape_eval里有兩個變數,通過1指定我們使用第二個變數,0就是指定第一個變數了:

grad_smape=egrad(smape_eval,1)
hessian_smape=egrad(egrad(smape_eval,1))
print(grad_smape(np.array([0,0,0]),np.array([0,0,0,])))
print(hessian_smape(np.array([1,1,1]),np.array([2,2,2,])))

可以看一下這裡的計算結果;

然後我們就可以開始愉快的定義自定義損失函數了:

def smape_loss(labels,preds):
#    masked_arr = ~((preds==0)&(labels==0))
#    preds, labels = preds[masked_arr], labels[masked_arr]
    grad = grad_smape(labels,preds)
    hess  = hessian_smape(labels,preds)
    grad[np.isnan(grad)]=0
    hess[np.isnan(hess)]=0
    return grad, hess

將smape_loss放到lgb的參數里:

lgb_params = {
          'min_child_weight': 0.03454472573214212, #祖傳參數
          'feature_fraction': 0.3797454081646243,
          'bagging_fraction': 0.4181193142567742,
          'min_data_in_leaf': 106,
          'objective': smape_loss,
          'max_depth': -1,
          'learning_rate': 0.006883242363721497,
          "boosting_type": "gbdt",
          "bagging_seed": 11,
          "verbosity": -1,
          'reg_alpha': 0.3899927210061127,
          'reg_lambda': 0.6485237330340494,
          'random_state': 47,
          'n_estimators':10000,
          'n_jobs':-1,
          #'device': 'gpu',
          #'gpu_platform_id' : 0,
          #'gpu_device_id' : 0,
           'metric':None
          
    
         }


from sklearn.model_selection import KFold
kf=KFold(5)

cols=list(y.columns)
best_score=100
y_pred0 = np.zeros((y.shape[0], y.shape[1]))
y_all_pred0 = np.zeros((n_bag, y_all.shape[0], y_all.shape[1]))
for i in range(y.shape[1]):
    for fold, (train_idx, test_idx) in enumerate(kf.split(X, y)):
        model =lgb.LGBMRegressor(**lgb_params)


        #model.fit(X,y[cols[i]])
        X_train,y_train = X.iloc[train_idx], y.iloc[train_idx][cols[i]] # 每一次單獨對一個時間步建立一個模型
        X_test, y_test = X.iloc[test_idx], y.iloc[test_idx][cols[i]]


        model.fit(X_train,y_train,eval_set=[(X_train,y_train),(X_test,y_test)],early_stopping_rounds=100,  \
                  eval_metric=smape_eval,verbose=50)
        print('done!')
        y_pred = model.predict(X_test)
        y_all_pred = model.predict(X_all)
        y_pred0[test_idx,i:i+1] = y_pred.reshape(-1,1)
        y_all_pred0[fold,:,i:i+1]  = y_all_pred.reshape(-1,1)

        y_pred += test2.AllVisits.values[test_idx]
        y_pred = np.expm1(y_pred)
        y_pred[y_pred < 0.5 * offset] = 0
        res = smape(test2[y_cols[i]].values[test_idx], y_pred)
        y_pred = offset*((y_pred / offset).round())
        res_round = smape(test2[y_cols[i]].values[test_idx], y_pred)

        y_all_pred += test_all2.AllVisits.values
        y_all_pred = np.expm1(y_all_pred)
        y_all_pred[y_all_pred < 0.5 * offset] = 0
        res_all = smape(test_all2[y_cols[i]], y_all_pred)
        y_all_pred = offset*((y_all_pred / offset).round())
        res_all_round = smape(test_all2[y_cols[i]], y_all_pred)
        print('smape train: %0.5f' % res, 'round: %0.5f' % res_round,
              '     smape LB: %0.5f' % res_all, 'round: %0.5f' % res_all_round)

        del X_train,y_train,X_test,y_test;gc.collect()

沒得問題了,有時間好好研究一下jax的用法~感動,不過簡單應用的話用autograd就可以了。