ELECTRA:用Bert欺騙Bert

  • 2020 年 3 月 31 日
  • 筆記

18年Bert橫空出世,刷新了各大榜單之後,對齊的改進層出不窮,Ernie, Albert大多數改進都沒有新意,無非就是大力出奇蹟的堆疊參數。ICLR2020 斯坦福和Google為我們提供了一種新思路,用Bert來"欺騙"Bert。今天我們來為大家介紹ELECTRA的思路。

背景

目前以Bert為代表的state of art的預訓練模型都是基於MLM(Masked Language modeling)來進行預訓練的,這些模型將輸入的句子中15%的Mask掉,然後使用模型去預測被mask掉的原始內容。當然這些模型也面臨一個問題就是模型因為參數過多,每次訓練只能學習到訓練數據的15%的內容,從而導致計算量過大的問題。

因此文章中提出了一種新的訓練方法:隨機替換句子中的token使用模型去判斷這個token是否被替換過。ELECTRA的效果有多兇殘呢,我們看下圖,左圖是右圖的方法版,橫軸是預訓練的FLOPs(floating point operations),TF中的浮點數計算統計量,縱軸是GLUE的分數,在同等計算量的情況下,ELECTRA一直碾壓Bert,在訓練到一定程度之後可以到達RoBERTa的效果。

訓練方法

ELECTRA的主要貢獻是在預訓練中將MLM(Masked Language Model)替換為RTD任務(Replace Token Detection)預測Token是否被替換。這個任務由兩個模型來實現,Generator和Discriminator。Generator負責生成被替換的Token,用Discriminator去判斷每個Token是否是被替換過的。

那麼如何替換Token以及判斷呢?下面我們來分別講一下Generator和Discriminator.

1. Generator

有部落客嘗試過隨機替換Token的方法,但是效果並不好,因為隨機替換太簡單了。那麼文中是怎麼做的呢?在MLM任務中我們會隨機Mask掉一部分位置的Token並訓練模型去預測這一部分,文中也借鑒了這個思想,使用Generator訓練了一小的MLM任務,然後Discriminator去判斷這些Token是否被Generator替換過。

對於特定位置t,我們假設該位置被mask掉了,那麼該位置被預測為x_{t}的概率為:

其中x=[x_{1},x_{2},…,x_{n}] 是輸入的Token序列,h(x)=[h_{1},h_{2},…,h_{n}] 是經過MLM之後的輸出,其中e(x_{t})^{T} 是token的embedding。

對於Generator的loss如何計算呢?可以看到文中是這麼計算的,對我們mask的Token的概率進行log然後求期望。

現在這麼說或許還是有點模糊,具體來看一下ELECTRA是如何實現的吧。首先對輸入的sequence按照一定概率進行mask,輸入模型的config,預訓練的input,不能被mask的位置和已經被mask的位置,然後我們對預訓練的輸入按照一定的概率產生返回結果我們對輸入的數據進行mask,最後對input進行處理轉換成一個dict,裡面存儲了input_id, masked_lmposition等內容。

def mask(config: configure_pretraining.PretrainingConfig,           inputs: pretrain_data.Inputs, mask_prob, proposal_distribution=1.0,           disallow_from_mask=None, already_masked=None):      # Get the batch size, sequence length, and max masked-out tokens    N = config.max_predictions_per_seq    B, L = modeling.get_shape_list(inputs.input_ids)      # Find indices where masking out a token is allowed    vocab = tokenization.FullTokenizer(        config.vocab_file, do_lower_case=config.do_lower_case).vocab    candidates_mask = _get_candidates_mask(inputs, vocab, disallow_from_mask)      # Set the number of tokens to mask out per example    num_tokens = tf.cast(tf.reduce_sum(inputs.input_mask, -1), tf.float32)    num_to_predict = tf.maximum(1, tf.minimum(        N, tf.cast(tf.round(num_tokens * mask_prob), tf.int32)))    masked_lm_weights = tf.cast(tf.sequence_mask(num_to_predict, N), tf.float32)    if already_masked is not None:      masked_lm_weights *= (1 - already_masked)      # Get a probability of masking each position in the sequence    candidate_mask_float = tf.cast(candidates_mask, tf.float32)    sample_prob = (proposal_distribution * candidate_mask_float)    sample_prob /= tf.reduce_sum(sample_prob, axis=-1, keepdims=True)      # Sample the positions to mask out    sample_prob = tf.stop_gradient(sample_prob)    sample_logits = tf.log(sample_prob)    masked_lm_positions = tf.random.categorical(        sample_logits, N, dtype=tf.int32)    masked_lm_positions *= tf.cast(masked_lm_weights, tf.int32)      # Get the ids of the masked-out tokens    shift = tf.expand_dims(L * tf.range(B), -1)    flat_positions = tf.reshape(masked_lm_positions + shift, [-1, 1])    masked_lm_ids = tf.gather_nd(tf.reshape(inputs.input_ids, [-1]),                                 flat_positions)    masked_lm_ids = tf.reshape(masked_lm_ids, [B, -1])    masked_lm_ids *= tf.cast(masked_lm_weights, tf.int32)      # Update the input ids    replace_with_mask_positions = masked_lm_positions * tf.cast(        tf.less(tf.random.uniform([B, N]), 0.85), tf.int32)    inputs_ids, _ = scatter_update(        inputs.input_ids, tf.fill([B, N], vocab["[MASK]"]),        replace_with_mask_positions)      return pretrain_data.get_updated_inputs(        inputs,        input_ids=tf.stop_gradient(inputs_ids),        masked_lm_positions=masked_lm_positions,        masked_lm_ids=masked_lm_ids,        masked_lm_weights=masked_lm_weights    )

文中的Generator採用了Bert模型,也正如文中所說,文中使用Bert對Mask的Token進行預測,對每一個位置的Mask計算loss最後求和,loss的計算過程如下。輸入是我們剛才處理過的maskedinputs和Generator。為了將Bert計算的logits轉換為預測的Label,程式碼在Generator之後加了一層全連接層和sofmax,然後將預測的label轉為one_hot編碼,然後採用上述公式計算Mask部分的loss。

def _get_masked_lm_output(self, inputs: pretrain_data.Inputs, model):      """Masked language modeling softmax layer."""      masked_lm_weights = inputs.masked_lm_weights      with tf.variable_scope("generator_predictions"):        if self._config.uniform_generator:          logits = tf.zeros(self._bert_config.vocab_size)          logits_tiled = tf.zeros(              modeling.get_shape_list(inputs.masked_lm_ids) +              [self._bert_config.vocab_size])          logits_tiled += tf.reshape(logits, [1, 1, self._bert_config.vocab_size])          logits = logits_tiled        else:          relevant_hidden = pretrain_helpers.gather_positions(              model.get_sequence_output(), inputs.masked_lm_positions)          hidden = tf.layers.dense(              relevant_hidden,              units=modeling.get_shape_list(model.get_embedding_table())[-1],              activation=modeling.get_activation(self._bert_config.hidden_act),              kernel_initializer=modeling.create_initializer(                  self._bert_config.initializer_range))          hidden = modeling.layer_norm(hidden)          output_bias = tf.get_variable(              "output_bias",              shape=[self._bert_config.vocab_size],              initializer=tf.zeros_initializer())          logits = tf.matmul(hidden, model.get_embedding_table(),                             transpose_b=True)          logits = tf.nn.bias_add(logits, output_bias)          oh_labels = tf.one_hot(            inputs.masked_lm_ids, depth=self._bert_config.vocab_size,            dtype=tf.float32)          probs = tf.nn.softmax(logits)        log_probs = tf.nn.log_softmax(logits)        label_log_probs = -tf.reduce_sum(log_probs * oh_labels, axis=-1)          numerator = tf.reduce_sum(inputs.masked_lm_weights * label_log_probs)        denominator = tf.reduce_sum(masked_lm_weights) + 1e-6        loss = numerator / denominator        preds = tf.argmax(log_probs, axis=-1, output_type=tf.int32)          MLMOutput = collections.namedtuple(            "MLMOutput", ["logits", "probs", "loss", "per_example_loss", "preds"])        return MLMOutput(            logits=logits, probs=probs, per_example_loss=label_log_probs,            loss=loss, preds=preds)

當然文中的思路是很清晰,但是我有一點疑惑,就是在預訓練過程中的loss是Generator和Discriminator的loss求和,當然為了保證效果loss肯定是希望變小的。不過為了保證隨機生成的效果,這裡應該是預測的和原文本出入比較大比較好,那麼如何去平衡loss變小的問題呢?或許可以借鑒NSP(Next Sentence Prediction)任務中隊token進行跨領域的替換?或者在最終的loss中按照權重,增大Discriminator的權重?這裡還是比較讓人迷惑的,而且文中說的天花亂墜的思路看來只不過是Bert的二次利用,感覺有一丟丟受騙的感覺。

2. Discriminator

Disctiminator根據generator生成的輸入去判斷是否是生成的Token,這樣就將問題轉化為一個二分類問題,對每一個Token我們將其分成是生成Token或者不是,那麼如此使用交叉熵來表示就是個非常好的選擇了。事實上,論文中對Discriminator的loss採用的也是交叉熵。

下面我們來看一下程式碼中如何實現Discriminator,如論文中所述,Discriminator和Generator都採用Bert,不同於Generator,Discriminator的輸入是經過Generator生成之後的fake_input,label表示Token是否是fake。

def _get_discriminator_output(self, inputs, discriminator, labels):      """Discriminator binary classifier."""      with tf.variable_scope("discriminator_predictions"):        hidden = tf.layers.dense(            discriminator.get_sequence_output(),            units=self._bert_config.hidden_size,            activation=modeling.get_activation(self._bert_config.hidden_act),            kernel_initializer=modeling.create_initializer(                self._bert_config.initializer_range))        logits = tf.squeeze(tf.layers.dense(hidden, units=1), -1)        weights = tf.cast(inputs.input_mask, tf.float32)        labelsf = tf.cast(labels, tf.float32)        losses = tf.nn.sigmoid_cross_entropy_with_logits(            logits=logits, labels=labelsf) * weights        per_example_loss = (tf.reduce_sum(losses, axis=-1) /                            (1e-6 + tf.reduce_sum(weights, axis=-1)))        loss = tf.reduce_sum(losses) / (1e-6 + tf.reduce_sum(weights))        probs = tf.nn.sigmoid(logits)        preds = tf.cast(tf.round((tf.sign(logits) + 1) / 2), tf.int32)        DiscOutput = collections.namedtuple(            "DiscOutput", ["loss", "per_example_loss", "probs", "preds",                           "labels"])        return DiscOutput(            loss=loss, per_example_loss=per_example_loss, probs=probs,            preds=preds, labels=labels,        )

在上一部分討論Generator的時候,我談到了loss的問題,果然在這一部分就被打臉,因為我們目標就是使整體的loss最小化,而最小的loss果然給DIscriminator賦予了權值。而實際上在實現過程中Generator也有相應的權值:self.total_loss 的計算首先是Generator的權重乘以loss然後加上Discriminator的權重乘以loss。如此嚴謹有理有據也讓我表示自己沒有想多哈哈,內心還是有點小竊喜呢,嘻嘻嘻。

def __init__(self, config: configure_pretraining.PretrainingConfig,                 features, is_training):      # Set up model config      self._config = config      self._bert_config = training_utils.get_bert_config(config)      if config.debug:        self._bert_config.num_hidden_layers = 3        self._bert_config.hidden_size = 144        self._bert_config.intermediate_size = 144 * 4        self._bert_config.num_attention_heads = 4        # Mask the input      masked_inputs = pretrain_helpers.mask(          config, pretrain_data.features_to_inputs(features), config.mask_prob)        # Generator      embedding_size = (          self._bert_config.hidden_size if config.embedding_size is None else          config.embedding_size)      if config.uniform_generator:        mlm_output = self._get_masked_lm_output(masked_inputs, None)      elif config.electra_objective and config.untied_generator:        generator = self._build_transformer(            masked_inputs, is_training,            bert_config=get_generator_config(config, self._bert_config),            embedding_size=(None if config.untied_generator_embeddings                            else embedding_size),            untied_embeddings=config.untied_generator_embeddings,            name="generator")        mlm_output = self._get_masked_lm_output(masked_inputs, generator)      else:        generator = self._build_transformer(            masked_inputs, is_training, embedding_size=embedding_size)        mlm_output = self._get_masked_lm_output(masked_inputs, generator)      fake_data = self._get_fake_data(masked_inputs, mlm_output.logits)      self.mlm_output = mlm_output      self.total_loss = config.gen_weight * mlm_output.loss

3. GAN

相信看到這裡的小夥伴們和我心理都有個疑問,這個和GAN的區別是什麼呢?文中對此也做了解釋:

這裡我們已經大概了解了ELECTRA的設計思想了,文章認為MLM任務比較簡單,而且Mask的Token位置比較少,導致模型學習到的內容有限。而基於此觀點,文章設計了比較精巧的生成器-判別器模式,而為了保證模型可以學習到語料的全部內容,生成器也避免簡單地隨機替換。這種類似於GAN又和GAN有所區別的思路,初看讓人激動不已。然而看了源碼,我的激動也逐漸理性,相比於文章的花哨,源碼看上去更像是對Bert的一次組裝,而生成器和判別器一起訓練,Loss一起計算也讓我有一丟疑惑。但是整體來說也是一次讓人激動地嘗試。那麼模型的效果如何呢?我們將在下一篇文章中進行解釋,請大家保持期待。

實驗結果

1. 共享參數(Weight Sharing)

Generator和Discriminator應該共享參數嗎?文中嘗試了3中方法,不共享參數,共享embedding和共享所有參數,從效果來看共享所有參數的效果是最優的,但是這也意味著生成器和分辨器要一樣大,這真的有必要嗎?

生成器的工作是預測Mask掉的Token,至於對不對並不重要,而且從某種角度上來說,預測的越離譜,可能越適合分辨器學習。而Discriminator面對的是一整句話,要逐Token的判斷該Token是原生的還是非原生的,他要學習的東西相比於Generator不僅龐大而且複雜,讓Discriminator和Generator一樣大對於任重而道遠的Discriminator過於殘忍,而如果讓Generator和Discriminator一樣大,又過於浪費。

begin{array}[b] {|c|} hline 實驗方法 & GLUE score\ hline 不共享參數 & 83.6 \ hline 共享embedding & 84.3 \ hline 共享所有參數 & 84.4 \ hline end{array}\

因此文中所採用的是共享embedding的方法。

2. Small generators

作者在保持hidden size的情況下降低層數,從而降低Generator的大小,Discriminator嘗試了hidden size為256,512和768,這裡我們發現秉承著大力出奇蹟的原則,同等情況下,Discriminator的hidden size為768的效果最好。與此同時我們還發現在Discriminator保持不變的情況下,Generator的大小並不是越大越好,當Generator的大小是Discriminator的1/4~1/2的時候實驗效果最好。

這是為什麼呢?當Generator變得複雜,可能會有兩種情況發生,Generator對Token的預測都非常有效,沒有起到欺騙的作用,或者Generator過擬合導致任務對Discriminator過於複雜而降低了學習效率。(Discriminator:救救孩子吧)

3. Training Algorithms

  1. Two-stage訓練:訓練完Generator之後使用Generator的權重初始化Discriminator,然後訓練Discriminator。
  2. Adversarial Contrastive Estimation: 上一篇文章中我們介紹了ELECTRA和GAN的區別,Discriminator的梯度無法傳遞到Generator,文中嘗試了用強化學習的方法來訓練模型。作者將Generator最小化MLM Loss替換為最大化Discriminator被替換Token的RTD LOSS。這時作者又面臨了問題,新的loss無法使用梯度上升去最大化loss,於是作者採用了policy gradient reinforcement learning來尋找最優的分辨器。

對於Two-stage的方法中我們需要注意用Generator的參數初始化Discriminator,那麼兩個大小應該是一樣。既然如此,那麼應該可以預測這樣做的效果並不會很好,為什麼呀?因為上一部分我們不是提到了Generator是Discriminator的1/4~1/2時效果最好嗎。除此以外,作者還嘗試了另一種訓練方法。

對於Adversarial Contrastive Estimation的看法大家可以在回復中留言討論。

最後的實驗結果也顯示原始的訓練方法效果最好。

4. Small model? Big model?

文中訓練了Small Electra和Big Electra,模型規模大家可以通過下標判斷。相比於Bert Base, Small Electra的參數都進行了縮小,Big Electra和Bert large的超參數保持一致,同時訓練的時間要更長一點。

begin{array}[b] {|c|} hline & Small Electra & Big Electra & Bert Base\ hline Sequence Length & 128 & 512 & 512\ hline Batch size & 128 & 2048 & 256\ hline Hidden size & 256 & 1024 & 768\ hline Embedding size & 128 & 1024 & 768\ hline end{array}\

那麼Electra的效果如何呢?我們先看Small Electra,效果可以說是一騎絕塵。同等規模的情況下效果遠超其他模型,Small Electra的Glue score要比Bert small高4.8個點,但是當規模變大,Electra Base的效果比Bert Base高2.9個點。沒有Electra small那麼亮眼,但也足夠激動人心。畢竟Bert縮小版Albert也只是在xxLarge的情況下打敗了Bert,但是Electra卻贏得毫不費力。

下面我們來看Electra Large的效果,是否如Small一樣激動人心呢?結果依然不負眾望,在SST, MRPC任務中Electra Large略遜於RoBERTa,其他任務都取得了很不錯的成績。需要注意的是在Electra之用了1/4的計算量就打敗了RoBERTa。

4. Efficiency Analysis

前文中提到Bert只計算了被替換的Token的loss,Electra使用了全部的Token,作者進一步做了一些實驗來探討那種方法更好。作者進行了一下3個實驗

  • ELECTRA 15%:使用Electra計算15%的loss
  • Replace MLM: 使用Bert訓練在預訓練的時候輸入不用MASK而是用其他生成器的輸出替換
  • All-TokensMLM:結合了Bert和Electra,Bert的預測變成了預測所有Token

實驗結果如下:

可以看到Electra 15%的效果和Bert相似,因此Bert之前只學習15%的Token的做法對於輸入是有很大的資訊損失的,而Electra的做法也彌補了這一損失。這也證明了之前作者的看法:Bert只學習15%的Token是不夠的。Replace MLM的效果和Electra 15%的效果相差不大,這說明MASK的內容其實並不重要,重要的是要學習全部的輸入序列。All-Tokens MLM的效果也解釋了這一點。

通過這些實驗結果,我想我們也可以理解為什麼作者認為分辨器的任務比生成器複雜,分辨器的規模也要比生成器大。因為Input is all.(這句話是我加的)

一些想法

前文中提到我個人對Electra的loss不是很認同的看法,這部分我們看到作者對Loss也做了一些新的嘗試,例如最大化分辨器被替換的Token的RTD Loss,雖然如此但我個人還是不是特別滿意。因為我認為為了保證分辨器的效果,生成器要給他儘可能複雜的替換效果。但是這裡的RTD Loss關注的是分辨器,生成器在整個過程中所受到的關注並不多。當然文中實驗也提到過,生成器規模變大,分辨器的效果會相對變差。那麼文中此時也沒有提到生成器規模變大,是否會導致生成器的loss降低,預測效果提高?而如果證明這一點,那麼我也可以確定我之前的猜測並沒有錯:分辨器還是需要足夠的噪音以提高效果

當然整體而言,這篇文章給了我們在Bert基礎上進行嘗試的工匠一個亮眼的提示,其他的論文大都想著對結構,超參等進行修改。而這篇論文回歸數據本身,關注我們的輸入,用Bert"欺騙'Bert,想想都讓人興奮呢。那麼下一次讓人驚喜的創意又會是什麼呢?希望大家和我一起期待。