【論文筆記】基於強化學習的人機對話
- 2019 年 10 月 7 日
- 筆記
【導讀】人機對話已經廣泛應用於日常生活,比如蘋果的siri,微軟的小冰等等。本篇論文作者基於強化學習,針對給定環境和對話長度,設計了一種新的演算法(RL-DAGM)。
論文地址:
https://arxiv.org/pdf/1807.07255.pdf
作者在本篇論文中提高了交互的多樣性,並提高了人的參與度,因此作者自己準備了數據集,這個在後半部分會進行詳細介紹。
01
對話行為分類
在闡述模型之前,先說一下作者在模型中使用的分類器。作者通過一個分類器將對話進行了分類,類別在實驗部分的數據集內有簡單介紹,分類器的實際例子如下所示:

分類器作者是基於雙向RNN的GRU進行展開的。網路主要由embedding、biGRUs、MLP組成。
- 句子進行embedding嵌入

- 雙向GRUs

- MLP

02
對話生成模型

監督學習
作者在論文中提到,該模型通過學習兩個人之間的對話來區分對話內容的種類,然後在接下來的對話中通過分類生成對話。
作者的對話模型由策略網路和生成網路兩部分組成,首先歷史聊天記錄通過策略網路來挑選對話類別,然後基於聊天記錄和對話類別通過生成網路進行對話。對話模型公式:



第二個公式表示第i組對話的類別;r_i表示生成的對話,p_a,p_r分別表示策略網路和生成網路;u_k為對話,a_k為對應的對話類別。

在策略網路中,我們可以從圖中(b)看出,上半部分的輸入a通過詞嵌入以及GRU的計算得到最後的隱藏狀態:

下半部分通過雙向GRU得到隱藏狀態h以及輸出t,並通過一系列計算得到最後的輸出:

最後通過MLP:

在生成網路中,作者基於sequence to-sequence框架,與標準的encoder-decoder不同的是在編碼時使用了attention機制。
將embedding之後的數據輸入網路中,編碼的第k個隱藏狀態用v_{i,k}表示:


然後通過attention機制:

這裡W和v都是參數,指的是編碼GRU的第j-1個隱藏狀態,作者在論文中給出了第j個隱藏狀態的計算公式:

最後計算生成概率:

黃框中是一組只有一個1的向量,表示對應w的索引。
生成網路的輸出pr:

以上介紹是通過監督學習實現的,然而這種模型對於長對話來說會存在弊端(作者訓練集中45%的對話不超過5回合),並且策略網路是在不知曉未來對話的情況下進行優化的,因此作者提出使用強化學習來優化模型。

強化學習
策略網路與監督學習中的表示相同,用pa(ai|si)表示,並定義了獎勵:

E[len(.)]與E[rel(.)]的計算公式:


這裡m( , )是雙向LSTM計算所得,是用來計算相關度的。目標函數是為了最大化獎勵:

03
實驗

數據集
為了提高交互的多樣性,作者在數據標註上做了分類,分別為CM.S,CM.Q,CM.A,CS.S,CS.Q,CS.A,O:

作者從百度貼吧中選取500個對話作為訓練集,每個對話分別對應3個labels,作者使用了一個分類器將對話進行了分類。

實驗設置
作者實驗中參考的程式碼:
S2SA: sequence to sequence with attention (Bahdanau et al., 2015)(https://github.com/mila-udem/blocks)
HRED: the hierarchical encoder-decoder model (https://github.com/julianser/hed-dlg-truncated)
VHRED: the hierarchical latent variable encoder-decoder model (https://github.com/julianser/hed-dlg-truncated)
RL-S2S: dialogue generation with reinforcement learning (https://github.com/liuyuemaicha/Deep-Reinforcement-Learning-for-Dialogue-Generation-in-tensorflow)

實驗結果
為了測試 模型的效果,作者將每組對話的最後一句作為背景,將歷史對話輸入到不同的模型中,作者並對輸出的回答分為3個等級,以此來直觀的評判模型的效果。等級2:回答的句子不僅相關性強,而且內容有趣;等級1:回答勉強可以,但沒有足夠的資訊;等級0:回答毫無意義,或者無法錯誤。評估結果如下:

實例展示:

從中可以看出 SL-DAGM 和 RL-DAGM都有在嘗試轉新的話題。