【论文笔记】基于强化学习的人机对话
- 2019 年 10 月 7 日
- 筆記
【导读】人机对话已经广泛应用于日常生活,比如苹果的siri,微软的小冰等等。本篇论文作者基于强化学习,针对给定环境和对话长度,设计了一种新的算法(RL-DAGM)。
论文地址:
https://arxiv.org/pdf/1807.07255.pdf
作者在本篇论文中提高了交互的多样性,并提高了人的参与度,因此作者自己准备了数据集,这个在后半部分会进行详细介绍。
01
对话行为分类
在阐述模型之前,先说一下作者在模型中使用的分类器。作者通过一个分类器将对话进行了分类,类别在实验部分的数据集内有简单介绍,分类器的实际例子如下所示:

分类器作者是基于双向RNN的GRU进行展开的。网络主要由embedding、biGRUs、MLP组成。
- 句子进行embedding嵌入

- 双向GRUs

- MLP

02
对话生成模型

监督学习
作者在论文中提到,该模型通过学习两个人之间的对话来区分对话内容的种类,然后在接下来的对话中通过分类生成对话。
作者的对话模型由策略网络和生成网络两部分组成,首先历史聊天记录通过策略网络来挑选对话类别,然后基于聊天记录和对话类别通过生成网络进行对话。对话模型公式:



第二个公式表示第i组对话的类别;r_i表示生成的对话,p_a,p_r分别表示策略网络和生成网络;u_k为对话,a_k为对应的对话类别。

在策略网络中,我们可以从图中(b)看出,上半部分的输入a通过词嵌入以及GRU的计算得到最后的隐藏状态:

下半部分通过双向GRU得到隐藏状态h以及输出t,并通过一系列计算得到最后的输出:

最后通过MLP:

在生成网络中,作者基于sequence to-sequence框架,与标准的encoder-decoder不同的是在编码时使用了attention机制。
将embedding之后的数据输入网络中,编码的第k个隐藏状态用v_{i,k}表示:


然后通过attention机制:

这里W和v都是参数,指的是编码GRU的第j-1个隐藏状态,作者在论文中给出了第j个隐藏状态的计算公式:

最后计算生成概率:

黄框中是一组只有一个1的向量,表示对应w的索引。
生成网络的输出pr:

以上介绍是通过监督学习实现的,然而这种模型对于长对话来说会存在弊端(作者训练集中45%的对话不超过5回合),并且策略网络是在不知晓未来对话的情况下进行优化的,因此作者提出使用强化学习来优化模型。

强化学习
策略网络与监督学习中的表示相同,用pa(ai|si)表示,并定义了奖励:

E[len(.)]与E[rel(.)]的计算公式:


这里m( , )是双向LSTM计算所得,是用来计算相关度的。目标函数是为了最大化奖励:

03
实验

数据集
为了提高交互的多样性,作者在数据标注上做了分类,分别为CM.S,CM.Q,CM.A,CS.S,CS.Q,CS.A,O:

作者从百度贴吧中选取500个对话作为训练集,每个对话分别对应3个labels,作者使用了一个分类器将对话进行了分类。

实验设置
作者实验中参考的代码:
S2SA: sequence to sequence with attention (Bahdanau et al., 2015)(https://github.com/mila-udem/blocks)
HRED: the hierarchical encoder-decoder model (https://github.com/julianser/hed-dlg-truncated)
VHRED: the hierarchical latent variable encoder-decoder model (https://github.com/julianser/hed-dlg-truncated)
RL-S2S: dialogue generation with reinforcement learning (https://github.com/liuyuemaicha/Deep-Reinforcement-Learning-for-Dialogue-Generation-in-tensorflow)

实验结果
为了测试 模型的效果,作者将每组对话的最后一句作为背景,将历史对话输入到不同的模型中,作者并对输出的回答分为3个等级,以此来直观的评判模型的效果。等级2:回答的句子不仅相关性强,而且内容有趣;等级1:回答勉强可以,但没有足够的信息;等级0:回答毫无意义,或者无法错误。评估结果如下:

实例展示:

从中可以看出 SL-DAGM 和 RL-DAGM都有在尝试转新的话题。