Time-aware LSTM
I was actually asked a similar question in an interview once: how do you fuse time into a model? Having seen little of this at the time, I just answered "sample by time"; in hindsight, the approach described here seems to be the right answer.
Paper motivation and background
Long story short: the goal is to build the notion of time into the LSTM.
The traditional LSTM
i_t = \sigma_{i}(x_tW_{xi} + h_{t-1}W_{hi} + w_{ci} \odot c_{t-1} + b_i )
f_t = \sigma_{f}(x_tW_{xf} + h_{t-1}W_{hf} + w_{cf} \odot c_{t-1} + b_f)
c_t = f_t \odot c_{t-1} + i_t \odot \sigma_{c}(x_tW_{xc} + h_{t-1}W_{hc} + b_c)
o_t = \sigma_{o}(x_tW_{xo} + h_{t-1}W_{ho} + w_{co} \odot c_t + b_o )
h_t = o_t \odot \sigma_{h}(c_t)
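To make the notation concrete, here is a minimal single-step sketch of these peephole equations (the parameter names mirror the formulas above; I take \sigma_c and \sigma_h to be tanh and the gate activations to be sigmoids, the standard choices):

```python
import torch

def peephole_lstm_step(x_t, h_prev, c_prev, p):
    # p: dict of parameters named after the equations above, e.g. p["W_xi"].
    # Peephole terms (w_ci, w_cf, w_co) are elementwise products with the cell.
    i_t = torch.sigmoid(x_t @ p["W_xi"] + h_prev @ p["W_hi"] + p["w_ci"] * c_prev + p["b_i"])
    f_t = torch.sigmoid(x_t @ p["W_xf"] + h_prev @ p["W_hf"] + p["w_cf"] * c_prev + p["b_f"])
    c_t = f_t * c_prev + i_t * torch.tanh(x_t @ p["W_xc"] + h_prev @ p["W_hc"] + p["b_c"])
    o_t = torch.sigmoid(x_t @ p["W_xo"] + h_prev @ p["W_ho"] + p["w_co"] * c_t + p["b_o"])
    h_t = o_t * torch.tanh(c_t)
    return h_t, c_t
```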
Phased LSTM
Phased LSTM adds a time gate, whose opening and closing are governed by three parameters.
c_t and h_t are only allowed to update while the gate is open. \tau controls the period of the oscillation, r_{on} (the ratio_on of the code below) controls the fraction of each period during which the gate is open, and s controls the phase shift. All three parameters are learnable.
\phi_{t} = \frac{(t - s) \bmod \tau}{\tau}
So on top of the standard equations, Phased LSTM also needs the time-gate update. The gate value k_t is a piecewise-linear function of the phase (this is exactly what the k_up/k_down/k_closed branches compute in the implementation at the end):

k_t = \begin{cases} \frac{2\phi_t}{r_{on}}, & \phi_t < \frac{1}{2}r_{on} \\ 2 - \frac{2\phi_t}{r_{on}}, & \frac{1}{2}r_{on} \le \phi_t < r_{on} \\ \alpha \phi_t, & \text{otherwise} \end{cases}

c_t = k_t \odot \tilde{c}_t + (1 - k_t) \odot c_{t-1}
h_t = k_t \odot \tilde{h}_t + (1 - k_t) \odot h_{t-1}

where \tilde{c}_t and \tilde{h}_t are the ordinary LSTM proposals and \alpha is a small leak rate.
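A tiny runnable sketch of the phase and gate computations (parameter values are illustrative; the branching mirrors the k_up/k_down/k_closed logic in the full implementation below):

```python
import torch

def phase(t, s, tau):
    # position of timestamp t within its oscillation period, in [0, 1)
    return torch.fmod(t - s, tau) / tau

def time_gate(phi, r_on=0.1, alpha=0.001):
    k_up = 2 * phi / r_on    # first half of the open phase: ramps 0 -> 1
    k_down = 2 - k_up        # second half of the open phase: ramps 1 -> 0
    k_closed = alpha * phi   # closed phase: tiny leak keeps gradients alive
    k = torch.where(phi < r_on, k_down, k_closed)
    k = torch.where(phi < 0.5 * r_on, k_up, k)
    return k

t = torch.tensor([0.2, 0.8, 5.0])
print(time_gate(phase(t, s=0.0, tau=10.0)))  # tensor([0.4000, 0.4000, 0.0005])
```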
Phased LSTM has another benefit: it preserves early information much better.
In a traditional LSTM, suppose the forget gate stays nearly open, f_n \approx 1-\epsilon for some small \epsilon > 0. The cell state still decays exponentially:
c_n = f_n \odot c_{n-1} = (1-\epsilon) \odot (f_{n-1} \odot c_{n-2}) = \dots = (1-\epsilon)^n \odot c_0
By contrast, while its time gate is closed, the Phased LSTM leaves the cell state untouched, so early information decays only during the open phases.
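A quick numeric check of this argument (\epsilon and the open ratio are illustrative values): if the gate is open for only 5% of the steps, the leaky forget gate acts far less often, so far more of c_0 survives:

```python
eps, n, r_on = 0.05, 1000, 0.05
standard = (1 - eps) ** n           # forget gate applied at every step
phased = (1 - eps) ** (n * r_on)    # forget gate only acts while the gate is open
print(f"{standard:.2e} vs {phased:.3f}")  # 5.29e-23 vs 0.077
```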
Phased LSTM experimental results
On several datasets it clearly outperforms the traditional LSTM; see the original paper for details.
Drawbacks of Phased LSTM
Of course, Phased LSTM also has some drawbacks.
- It models time points only, not time intervals. In a recommendation setting, Phased LSTM therefore captures only the user's active state and ignores the inactive one (since Phased LSTM hard-wires a dormant, gate-closed phase).
- It cannot separate the influence of short-term interests from that of long-term interests.
Time LSTM
Time LSTM was proposed to overcome these problems with Phased LSTM. Given a sequence in which each consumed item x_m comes with a timestamp t_m, the model converts the isolated time points into the time intervals between consecutive actions and models those instead, as sketched below.
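A minimal sketch of the point-to-interval conversion (padding the first step with a zero interval is my own assumption):

```python
import torch

def to_intervals(timestamps):
    # timestamps: (batch, seq_len) of non-decreasing absolute times
    dt = timestamps[:, 1:] - timestamps[:, :-1]
    # pad the first step with a zero interval to keep the sequence length
    return torch.cat([torch.zeros_like(timestamps[:, :1]), dt], dim=1)

print(to_intervals(torch.tensor([[0.0, 2.0, 3.5, 10.0]])))
# tensor([[0.0000, 2.0000, 1.5000, 6.5000]])
```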
Time LSTM variants
Time LSTM variant 1
This is the simplest change relative to Phased LSTM: the time point is simply replaced by the time interval, which makes it very easy to follow. A schematic sketch follows.
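A schematic of a single interval-driven time gate in that spirit (the names and shapes are my own, not the paper's exact parameterization):

```python
import torch

def interval_time_gate(x_t, dt, W_xt, w_tt, b_t):
    # a gate analogous to the input/forget gates, but conditioned on the
    # elapsed interval dt rather than on an absolute timestamp
    return torch.sigmoid(x_t @ W_xt + torch.sigmoid(dt.unsqueeze(-1) * w_tt) + b_t)
```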
Time LSTM variant 2
It uses two time gates. Time gate 1 mainly uses the current time interval to recommend the current item, while time gate 2 stores the time interval to prepare for later recommendations (sketched below).
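A heavily schematic sketch of the two-gate split (the names T1/T2 and the two cell paths are my reading of the description above, not the paper's exact equations):

```python
import torch

def tlstm2_cell_sketch(i_t, f_t, T1, T2, c_prev, c_cand):
    # T1 gates the candidate along the path used for the *current* prediction;
    # T2 gates what is written into the cell carried forward for *later* steps
    c_now = f_t * c_prev + i_t * T1 * c_cand
    c_next = f_t * c_prev + i_t * T2 * c_cand
    return c_now, c_next
```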
Time LSTM variant 3
Variant 3 uses coupled input and forget gates, i.e. the forget weight is tied to the input gate rather than computed separately (see the sketch below).
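A minimal sketch of the coupled-gate idea (illustrative values; the point is that f_t is replaced by 1 - i_t):

```python
import torch

def coupled_cell_update(i_t, c_prev, c_cand):
    # with f_t tied to (1 - i_t), whatever is written to the cell
    # displaces an equal share of the old memory
    return (1 - i_t) * c_prev + i_t * c_cand

print(coupled_cell_update(torch.tensor(0.3), torch.tensor(1.0), torch.tensor(-0.5)))
# tensor(0.5500)
```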
Time LSTM experimental results
The experiments show that Time LSTM indeed outperforms the earlier time-related LSTM methods; see the original paper for the numbers.
An open-source PyTorch implementation of Phased LSTM
```python
import math
import torch
import torch.nn as nn
class PhasedLSTMCell(nn.Module):
"""Phased LSTM recurrent network cell.
    https://arxiv.org/pdf/1610.09513v1.pdf
"""
def __init__(
self,
hidden_size,
leak=0.001,
ratio_on=0.1,
period_init_min=1.0,
period_init_max=1000.0
):
"""
Args:
hidden_size: int, The number of units in the Phased LSTM cell.
leak: float or scalar float Tensor with value in [0, 1]. Leak applied
during training.
ratio_on: float or scalar float Tensor with value in [0, 1]. Ratio of the
period during which the gates are open.
period_init_min: float or scalar float Tensor. With value > 0.
Minimum value of the initialized period.
The period values are initialized by drawing from the distribution:
e^U(log(period_init_min), log(period_init_max))
Where U(.,.) is the uniform distribution.
period_init_max: float or scalar float Tensor.
With value > period_init_min. Maximum value of the initialized period.
"""
super().__init__()
self.hidden_size = hidden_size
self.ratio_on = ratio_on
self.leak = leak
# initialize time-gating parameters
period = torch.exp(
torch.Tensor(hidden_size).uniform_(
math.log(period_init_min), math.log(period_init_max)
)
)
self.tau = nn.Parameter(period)
phase = torch.Tensor(hidden_size).uniform_() * period
self.phase = nn.Parameter(phase)
def _compute_phi(self, t):
t_ = t.view(-1, 1).repeat(1, self.hidden_size)
phase_ = self.phase.view(1, -1).repeat(t.shape[0], 1)
tau_ = self.tau.view(1, -1).repeat(t.shape[0], 1)
        # gradient-propagating modulo keeps tau and phase learnable
        phi = self._mod(t_ - phase_, tau_)
        phi = torch.abs(phi) / tau_
return phi
def _mod(self, x, y):
"""Modulo function that propagates x gradients."""
return x + (torch.fmod(x, y) - x).detach()
def set_state(self, c, h):
self.h0 = h
self.c0 = c
    def forward(self, c_s, h_s, t):
phi = self._compute_phi(t)
        # Piecewise time gate k: ramps 0 -> 1 during the first half of the
        # open phase, 1 -> 0 during the second half, small leak while closed.
        k_up = 2 * phi / self.ratio_on
        k_down = 2 - k_up
        k_closed = self.leak * phi
        k = torch.where(phi < self.ratio_on, k_down, k_closed)
        k = torch.where(phi < 0.5 * self.ratio_on, k_up, k)
        # reshape to the (num_directions, batch, hidden) layout of the states
        k = k.view(c_s.shape[0], t.shape[0], -1)
c_s_new = k * c_s + (1 - k) * self.c0
h_s_new = k * h_s + (1 - k) * self.h0
return h_s_new, c_s_new
class PhasedLSTM(nn.Module):
"""Wrapper for multi-layer sequence forwarding via
PhasedLSTMCell"""
def __init__(
self,
input_size,
hidden_size,
bidirectional=True
):
super().__init__()
self.hidden_size = hidden_size
self.lstm = nn.LSTM(
input_size=input_size,
hidden_size=hidden_size,
bidirectional=bidirectional,
batch_first=True
)
self.bi = 2 if bidirectional else 1
self.phased_cell = PhasedLSTMCell(
hidden_size=self.bi * hidden_size
)
    def forward(self, u_sequence):
        """
        Args:
            u_sequence: tensor of shape (batch, time, input_size + 1); the last
                feature channel holds the timestamp of each step and the
                remaining channels hold the input features.
        """
c0 = u_sequence.new_zeros((self.bi, u_sequence.size(0), self.hidden_size))
h0 = u_sequence.new_zeros((self.bi, u_sequence.size(0), self.hidden_size))
self.phased_cell.set_state(c0, h0)
outputs = []
for i in range(u_sequence.size(1)):
            u_t = u_sequence[:, i, :-1].unsqueeze(1)  # input features at step i
            t_t = u_sequence[:, i, -1]                # timestamp at step i
            # nn.LSTM consumes and returns states as (h, c); keep the order straight
            out, (h_t, c_t) = self.lstm(u_t, (h0, c0))
            h_s, c_s = self.phased_cell(c_t, h_t, t_t)
            self.phased_cell.set_state(c_s, h_s)
            h0, c0 = h_s, c_s
outputs.append(out)
outputs = torch.cat(outputs, dim=1)
        return outputs
```
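A quick smoke test of the wrapper above (the shapes are my assumption: the last input channel carries the timestamp, matching the slicing in forward):

```python
# batch of 4 sequences, 10 steps, 6 feature channels + 1 timestamp channel
x = torch.randn(4, 10, 7)
x[..., -1] = torch.linspace(0, 9, 10)  # monotonically increasing timestamps

model = PhasedLSTM(input_size=6, hidden_size=32, bidirectional=True)
out = model(x)
print(out.shape)  # torch.Size([4, 10, 64]) -- 2 directions * hidden_size
```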