深入理解seq2seq—attention的前置重要知识
- 2021 年 3 月 23 日
- AI
为了更好的了解seq2seq以及attention的思想,特意买了本:

看,不得不说,这本书对机器翻译的讲解非常的全面系统,不过我不做机器翻译的,主要对seq2seq部分感兴趣,看完之后做个总结。
1、机器翻译的任务是一种序列到序列,即sequence to sequence形式的问题,输入是一个序列,输出也是一个序列,本质上来说是一种序列转换的任务,因此使用seq2seq的框架结构非常的适合处理这类问题,当然除了机器翻译之外,还有许多任务也符合sequence to sequence的形式,例如nlp中的:

以及表格数据中的多元多步时间序列问题。(这意味着transformer可以稍加经过改造来适配时间序列问题,当然已经有人做了,还发了论文放了源代码:
忆臻:AAAI21最佳论文Informer:效果远超Transformer的长序列预测神器!
个人暂时还没有把transformer的结构适配到时间序列预测问题中,不过在mxnet的gluonts里,transformer已经纳入进来了:
以及fastai出的torch版时间序列library:
不过整体的完整性上来看,gluonts应该算是目前实现模型最全的,虽然底层使用mxnet,不过别担心,封装的非常完整,不太需要接触mxnet的语法除非你希望能够自定义model,当然mxnet本身的语法也很简单,和keras的设计非常类似,上手也是很快的
)
2、在transformer之前,seq2seq后续的改进基本上是三个方向:
(1) 魔改encoder和decoder的结构,例如encoder部分使用LSTM,双向LSTM,CNN等;
(2)在encoder或者decoder中加入各种网络组件帮助训练,例如resnet+deep lstm作为encoder来保证深层的lstm拥有更强大的抽象能力的同时,训练的效率也得到提升;
(3)encoder和decoder部分的不同设计策略
下面稍微展开一下:
(1)encoder部分的采用过反向LSTM、双向LSTM和CNN看下面三个图:



其中图6-14的结构实际上是LSTM+CNN的复合结构,先使用LSTM(单向or双向)得到每一个时刻的隐含状态(或者说语义编码),而后在语义编码向量的基础上再引入CNN的卷积+池化最终得到更高阶的句子向量;
而decoder部分基本上差不了太多,以RNN为主。这一阶段,seq2seq整体的训练策略比较简单,

基本上是encoder生成源句子的高阶向量化表示,然后仅使用这个高阶向量来decode decoder部分的输出。
(2). encoder和decoder部分的各类复杂的网络组件,例如resnet ,maxout,zoneout,dropout、glu 等等本质上是一种换汤不换药的设计,他们其实和网络本身的基础架构关系不大,大都是为了增强模型的训练效率和泛化能力等而引入的组件,换句话说,即使不使用seq2seq的结构,在常规的网络设计上,这类组件也可以应用起来;
(3).encoder和decoder的不同涉及策略
发现这块之前总结过,github上有一个用keras写的seq2seq的项目还有谷歌的seq2seq都是这个阶段的各类seq2seq的encoder和decoder的设计模式的实现:
//github.com/farizrahman4u/seq2seq
可以看到,基本上是基于decoder的解码方式进行的各种魔改。
而对于decoder的设计模式可以划分为:
1、greedy search:
图片来源,之前写的:

greed search的思路很贪心,解码的过程中,选择一个softmax预测的概率输出最高的词作为当前的翻译结果,然后再将这个词作为下一个目标词的输入和上一个词的LSTM对应的step阶段的hidden state一起预测下一个目标词,贪心搜索的缺点比较明显,就是误差累积的问题,这个和基于seq2seq的时间序列预测任务存在的问题也是一样,一旦某一步decode的预测出错,则后面所有的decode的预测都是在上一步误差的基础上进行累计,从而使得整体的误差很高。
2、为了弥补greedy search的问题,beam search被提出

beam search的思路也很简单,就是取topk个预测概率最高的目标词同时进入下一个的decode,然后在最终的输出中才进行最终的top1的选择。
而与时间序列相关的seq2seq结构无法使用beam search,因为回归型的seq2seq的输出没有概率分布的概念而是输出一个单值,针对于这个问题实际上时间序列预测 of seq2seq 舍弃了beam search的思路,使用了teacher forcing的扩展方法——curriculum learning,对应的采样方法叫做scheduled sampling。
[论文解读]Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks
具体可见上。
3、ensemble decoding,使用了集成学习的思想:


前面我们提到了greedy search和beam search,他们的共同点就是解码的时候,终究还是使用t时刻的预测结果来帮助预测t+1时刻的标签,这样始终无法脱离累计误差的问题,
图片来源于上:

这种模式也称之为 free running mode;
为了避免累计误差的问题,于是乎teacher forcing mode的解码策略引入:
它是一种网络训练方法,对于开发用于机器翻译,文本摘要,图像字幕的深度学习语言模型以及许多其他应用程序至关重要。它每次不使用上一个state的输出作为下一个state的输入,而是直接使用训练数据的标准答案(ground truth)的对应上一项作为下一个state的输入。
看一下大佬们对它的评价:
存在把输出返回到模型输入中的这种循环连接单元的模型可以通过Teacher Forcing机制进行训练。
这种技术最初被作为反向传播的替代技术进行宣传与开发
在动态监督学习任务中经常使用的一种有趣的技术是,在计算过程中用教师信号 d(t)替换上一个单元的实际输出 y ( t )。我们称这种技术为Teacher Forcing。
Teacher Forcing工作原理: 在训练过程的 t时刻,使用训练数据集的期望输出或实际输出: y(t), 作为下一时间步骤的输入: x(t+1),而不是使用模型生成的输出 h(t)。
在训练过程中接收ground truth的输出 y(t) 作为 t+1时刻的输入

思路很简单,我们则编码阶段使用ground truth(就是真实的decode的目标词的embedding来辅助decode而不是使用上一步的预测结果来辅助decode)
单纯使用teacher forcing策略的缺点在于:
Teacher Forcing同样存在缺点: 一直靠老师带的孩子是走不远的。
因为依赖标签数据,在训练过程中,模型会有较好的效果,但是在测试的时候因为不能得到ground truth的支持,(预测的时候没有ground truth可用,所以实际的过程还是用上一个目标词的hidden state参与生成下一个目标词)所以如果目前生成的序列在训练过程中有很大不同,模型就会变得脆弱。
也就是说,这种模型的cross-domain能力会更差,也就是如果测试数据集与训练数据集来自不同的领域,模型的performance就会变差。
beam search在一定程度上缓解了这个问题。
接下来就是重磅,attention机制的引入,这一时期,各类attention-based model层出不穷,专治各种花里胡哨,attention引入的缘由主要在于RNN的问题:
1、对于长文本,encoder部分,越早进入的word,其信息越难以被保留下来,因为encoder部分的rnn,每当t时刻接受一个新词的embedding,则包含t时刻之前的word的hidden state会继续经过变换和当前词的embedding经过RNN转换后的信息进行融合产生新的隐含状态,这种变换和融合的过程很容易造成前面已经添加的词的信息的丢失或改变,越早添加进来的词,则最终的隐含状态越难以保留,

例如上图,如果需要源句子的“economic”影响目标句子中“经济”这个词的产生,那么“economic”这个词的信息首先需要保存在编码器的第一个隐藏状态中,并且要一直保证这个信息能够不发生变化直到编码器的最后一个隐藏状态,这个是比较费劲的。
其实这个过程可以理解为传话游戏,我们要保证第一个人说的话到达最后一个人的时候不发生变化还是比较困难的。
同时,反向传播的时候,“经济”这个词的损失也同样要跋山涉水经过同样长的反向路径才能传播回来影响“economic”对应的参数。
2、encoder的最后一个编码器的输出,一般称之为context vector,即对encoder部分的输入的源句子的高阶抽象的文本向量,一般大小是固定的,固定长度的向量的记忆能力有限,特别是输入的源句子很长的时候,固定长度的context vector的性能会变得很长,如下图:


因此,引入了 attention机制,接下来的部分就是attention的描述了,之前也已经总结过了:
就不再赘述了。
最后补充一下关于cnn+seq2seq的网络结构。
cnn版的seq2seq的例子相对于RNN少得多,git上找得到的demo也基本上是RNN+CNN这样的复合结构,直到facebook发表了《Convolutional Sequence to Sequence Learning》,提出了完全使用CNN来构成Seq2Seq模型,用于机器翻译,超越了谷歌创造的基于LSTM机器翻译的效果,我才得以完全窥探CNN+seq2seq的结构,之前wavenet的seq2seq的结构也一直没有完全理解,因为接受的初始的观念基本上是基于RNN+seq2seq结构的,而纯CNN应用于seq2seq中的机制有一些不同。
关于这部分,打算写一篇新的介绍cnn+seq2seq的文章来总结,Convolutional Sequence to Sequence Learning和wavenet的cnn seq2seq架构方式。