Time-aware LSTM
I was actually asked something like this in an interview once: how do you fuse time into the model? I had not seen much of this at the time and answered with time-based sampling; in hindsight, the right answer was probably the family of methods below.
Motivation and background
Long story short: the goal is to build a notion of time into the LSTM.
The standard LSTM
i_t = \sigma_{i}(x_tW_{xi} + h_{t-1}W_{hi} + w_{ci} \odot c_{t-1} + b_i )
f_t = \sigma_{f}(x_tW_{xf} + h_{t-1}W_{hf} + w_{cf} \odot c_{t-1} + b_f)
c_t = f_t \odot c_{t-1} + i_t \odot \sigma_{c}(x_tW_{xc} + h_{t-1}W_{hc} + b_c)
o_t = \sigma_{o}(x_tW_{xo} + h_{t-1}W_{ho} + w_{co} \odot c_t + b_o )
h_t = o_t \odot \sigma_{h}(c_t)
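For concreteness, here is a minimal PyTorch sketch of one step of these peephole equations (all names are mine; \sigma_c and \sigma_h are taken to be tanh, as is standard):

import torch

def lstm_step(x_t, h_prev, c_prev, p):
    """One step of the peephole LSTM above; p maps weight names to tensors."""
    i_t = torch.sigmoid(x_t @ p["W_xi"] + h_prev @ p["W_hi"] + p["w_ci"] * c_prev + p["b_i"])
    f_t = torch.sigmoid(x_t @ p["W_xf"] + h_prev @ p["W_hf"] + p["w_cf"] * c_prev + p["b_f"])
    c_t = f_t * c_prev + i_t * torch.tanh(x_t @ p["W_xc"] + h_prev @ p["W_hc"] + p["b_c"])
    o_t = torch.sigmoid(x_t @ p["W_xo"] + h_prev @ p["W_ho"] + p["w_co"] * c_t + p["b_o"])
    h_t = o_t * torch.tanh(c_t)
    return h_t, c_t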
Phased LSTM
Phased LSTM adds a time gate, whose opening and closing are controlled by three parameters. c_t and h_t may only be updated while the gate is open. \tau controls the period of the oscillation, r_{on} (ratio_on in the code below) controls the fraction of each period during which the gate is open, and s controls the phase shift. All of these parameters are learnable.
\phi_{t} = \frac{(t - s) \bmod \tau}{\tau}
So on top of the standard equations, Phased LSTM also needs the time-gate update:
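From the paper, the gate value k_t is a piecewise-linear function of the phase \phi_t (rising, then falling, then closed with a small leak \alpha), and it blends the proposed updates \tilde{c}_t and \tilde{h}_t produced by the standard equations with the previous states:

k_t = \begin{cases} \frac{2\phi_t}{r_{on}}, & \phi_t < \frac{1}{2}r_{on} \\ 2 - \frac{2\phi_t}{r_{on}}, & \frac{1}{2}r_{on} < \phi_t < r_{on} \\ \alpha \phi_t, & \text{otherwise} \end{cases}

c_t = k_t \odot \tilde{c}_t + (1 - k_t) \odot c_{t-1}

h_t = k_t \odot \tilde{h}_t + (1 - k_t) \odot h_{t-1}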
Phased LSTM has another benefit: it preserves early information much better.
With a standard LSTM, suppose the forget gate stays at f_t = 1 - \epsilon for some small \epsilon > 0; then after n steps:

c_n = f_n \odot c_{n-1} = (1-\epsilon) \odot (f_{n-1} \odot c_{n-2}) = \cdots = (1-\epsilon)^n \odot c_0
By contrast, while its gate is closed, Phased LSTM preserves the stored information almost perfectly.
Phased LSTM experimental results
On some datasets it beats the standard LSTM by a wide margin; see the original paper for details.
Drawbacks of Phased LSTM
Of course, Phased LSTM also has some drawbacks.
- It models only time points, not time intervals. In a recommender-system setting, Phased LSTM captures only the user's active states and ignores the inactive ones (because Phased LSTM hard-codes the inactive state).
- It cannot distinguish the influence of short-term interests from that of long-term interests.
Time LSTM
Time LSTM was proposed to overcome these problems with Phased LSTM. Given a user's sequence of actions and their timestamps, the model converts each single time point into the interval between consecutive actions and models those intervals instead.
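A minimal sketch of that preprocessing step (the variable names are mine; here \Delta t is simply the gap between consecutive timestamps):

import torch

# Hypothetical example: one user's action times.
timestamps = torch.tensor([0.0, 3.0, 4.5, 10.0])
# Convert absolute time points into the intervals the model consumes.
intervals = timestamps[1:] - timestamps[:-1]
print(intervals)  # tensor([3.0000, 1.5000, 5.5000])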
Time LSTM variants
Time LSTM variant 1
This is the simplest change relative to Phased LSTM: the time point is simply replaced by the time interval. Very clear and easy to follow.
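As a sketch in the notation of the standard LSTM above (my reconstruction of the variant's single time gate, not verbatim from the paper): a time gate T_m driven by the interval \Delta t_m multiplies the input term of the cell update,

T_m = \sigma_t(x_m W_{xt} + \sigma(\Delta t_m W_{tt}) + b_t)

c_m = f_m \odot c_{m-1} + i_m \odot T_m \odot \sigma_c(x_m W_{xc} + h_{m-1} W_{hc} + b_c)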
Time LSTM variant 2
It uses two time gates. Time gate 1 mainly uses the current time interval for recommending the current item, while time gate 2 mainly stores the time interval to prepare for later recommendations.
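Sketched in the same notation (again my reconstruction): T1_m modulates a short-term cell state \tilde{c}_m that feeds the current prediction, while T2_m modulates the cell state c_m that is carried forward,

\tilde{c}_m = f_m \odot c_{m-1} + i_m \odot T1_m \odot \sigma_c(\cdot)

c_m = f_m \odot c_{m-1} + i_m \odot T2_m \odot \sigma_c(\cdot)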
Time LSTM variant 3
Time LSTM variant 3 uses coupled input and forget gates, i.e. the forget gate is replaced by one minus the input term.
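For reference, with coupling the carried-forward update of variant 2 becomes (my reconstruction; see the paper for the exact form):

c_m = (1 - i_m) \odot c_{m-1} + i_m \odot T2_m \odot \sigma_c(\cdot)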
Time LSTM experimental results
The experiments show that Time LSTM does outperform the earlier time-related LSTM methods.
An open-source PyTorch implementation of Phased LSTM
import math

import torch
import torch.nn as nn


class PhasedLSTMCell(nn.Module):
    """Phased LSTM recurrent network cell.

    https://arxiv.org/pdf/1610.09513v1.pdf
    """
    def __init__(
        self,
        hidden_size,
        leak=0.001,
        ratio_on=0.1,
        period_init_min=1.0,
        period_init_max=1000.0
    ):
        """
        Args:
            hidden_size: int, The number of units in the Phased LSTM cell.
            leak: float or scalar float Tensor with value in [0, 1]. Leak applied
                during training.
            ratio_on: float or scalar float Tensor with value in [0, 1]. Ratio of the
                period during which the gates are open.
            period_init_min: float or scalar float Tensor. With value > 0.
                Minimum value of the initialized period.
                The period values are initialized by drawing from the distribution:
                e^U(log(period_init_min), log(period_init_max))
                Where U(.,.) is the uniform distribution.
            period_init_max: float or scalar float Tensor.
                With value > period_init_min. Maximum value of the initialized period.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.ratio_on = ratio_on
        self.leak = leak

        # initialize time-gating parameters
        period = torch.exp(
            torch.Tensor(hidden_size).uniform_(
                math.log(period_init_min), math.log(period_init_max)
            )
        )
        self.tau = nn.Parameter(period)

        phase = torch.Tensor(hidden_size).uniform_() * period
        self.phase = nn.Parameter(phase)
    def _compute_phi(self, t):
        # phi = ((t - s) mod tau) / tau, broadcast over all hidden units.
        t_ = t.view(-1, 1).repeat(1, self.hidden_size)
        phase_ = self.phase.view(1, -1).repeat(t.shape[0], 1)
        tau_ = self.tau.view(1, -1).repeat(t.shape[0], 1)

        # _mod propagates gradients through its first argument, so the learned
        # phase shift receives gradients; tau still gets gradients via the
        # division below.
        phi = self._mod(t_ - phase_, tau_)
        phi = torch.abs(phi) / tau_
        return phi

    def _mod(self, x, y):
        """Modulo function that propagates x gradients."""
        return x + (torch.fmod(x, y) - x).detach()

    def set_state(self, c, h):
        self.h0 = h
        self.c0 = c
    def forward(self, c_s, h_s, t):
        phi = self._compute_phi(t)

        # Piecewise-linear time gate k from the paper: rising (k_up) and
        # falling (k_down) while open, and a small leak (k_closed) while closed.
        k_up = 2 * phi / self.ratio_on
        k_down = 2 - k_up
        k_closed = self.leak * phi

        k = torch.where(phi < self.ratio_on, k_down, k_closed)
        k = torch.where(phi < 0.5 * self.ratio_on, k_up, k)
        # Reshape (batch, directions * hidden) -> (directions, batch, hidden)
        # so the gate lines up with the LSTM state layout per batch element.
        k = k.view(t.shape[0], c_s.shape[0], -1).transpose(0, 1)

        c_s_new = k * c_s + (1 - k) * self.c0
        h_s_new = k * h_s + (1 - k) * self.h0

        return h_s_new, c_s_new
class PhasedLSTM(nn.Module):
    """Wrapper for multi-layer sequence forwarding via PhasedLSTMCell."""

    def __init__(
        self,
        input_size,
        hidden_size,
        bidirectional=True
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            bidirectional=bidirectional,
            batch_first=True
        )
        self.bi = 2 if bidirectional else 1
        self.phased_cell = PhasedLSTMCell(
            hidden_size=self.bi * hidden_size
        )
    def forward(self, u_sequence):
        """
        Args:
            u_sequence: input of shape (batch, time, input_size + 1), where the
                last feature of every step is that step's timestamp and the
                remaining features are the actual inputs.
        """
        c0 = u_sequence.new_zeros((self.bi, u_sequence.size(0), self.hidden_size))
        h0 = u_sequence.new_zeros((self.bi, u_sequence.size(0), self.hidden_size))
        self.phased_cell.set_state(c0, h0)

        outputs = []
        for i in range(u_sequence.size(1)):
            u_t = u_sequence[:, i, :-1].unsqueeze(1)
            t_t = u_sequence[:, i, -1]

            # nn.LSTM takes and returns its states in (h, c) order.
            out, (h_t, c_t) = self.lstm(u_t, (h0, c0))
            # The time gate blends the new states with the previous ones.
            h_s, c_s = self.phased_cell(c_t, h_t, t_t)

            self.phased_cell.set_state(c_s, h_s)
            h0, c0 = h_s, c_s
            outputs.append(out)

        outputs = torch.cat(outputs, dim=1)
        return outputs
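A quick smoke test of the wrapper (the shapes follow the forward docstring; the data is random and purely illustrative):

# 4 sequences of length 10, with 8 input features plus 1 timestamp feature.
model = PhasedLSTM(input_size=8, hidden_size=16, bidirectional=True)
x = torch.randn(4, 10, 8)
t = torch.arange(10.0).expand(4, 10).unsqueeze(-1)  # monotonically increasing times
out = model(torch.cat([x, t], dim=-1))
print(out.shape)  # torch.Size([4, 10, 32]) = (batch, time, 2 directions * hidden_size)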