Time-aware LSTM

I was actually asked something like this in an interview once: how do you incorporate time into a sequence model? Having seen little of it at the time, I simply answered "sample according to time"; in hindsight, the approaches described below seem to be the proper answer.

Motivation and background

Long story short: the goal is to introduce a notion of time into the LSTM.

The vanilla LSTM

i_t = \sigma_{i}(x_tW_{xi} + h_{t-1}W_{hi} + w_{ci} \odot c_{t-1} + b_i )

f_t = \sigma_{f}(x_tW_{xf} + h_{t-1}W_{hf} + w_{cf} \odot c_{t-1} + b_f)

c_t = f_t \odot c_{t-1} + i_t \odot \sigma_{c}(x_tW_{xc} + h_{t-1}W_{hc} + b_c)

o_t = \sigma_{o}(x_tW_{xo} + h_{t-1}W_{ho} + w_{co} \odot c_t + b_o )

h_t = o_t \odot \sigma_{h}(c_t)
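
As a reference point, here is a minimal PyTorch sketch of a single step implementing the equations above (the parameter names, the dict-based packing, and taking the gate nonlinearities as sigmoids with tanh for \sigma_c and \sigma_h are my own assumptions, not from any particular library):

import torch


def lstm_step(x_t, h_prev, c_prev, p):
    """One step of the peephole LSTM defined by the equations above.

    `p` is a dict of tensors: weight matrices W_x*, W_h*, peephole
    vectors w_c*, and biases b_*.
    """
    i_t = torch.sigmoid(x_t @ p["W_xi"] + h_prev @ p["W_hi"] + p["w_ci"] * c_prev + p["b_i"])
    f_t = torch.sigmoid(x_t @ p["W_xf"] + h_prev @ p["W_hf"] + p["w_cf"] * c_prev + p["b_f"])
    c_t = f_t * c_prev + i_t * torch.tanh(x_t @ p["W_xc"] + h_prev @ p["W_hc"] + p["b_c"])
    o_t = torch.sigmoid(x_t @ p["W_xo"] + h_prev @ p["W_ho"] + p["w_co"] * c_t + p["b_o"])
    h_t = o_t * torch.tanh(c_t)
    return h_t, c_t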

Phased LSTM

Phased LSTM adds a time gate. Whether this gate is open or closed is controlled by three parameters, all of which are learnable: \tau is the period of the oscillation, r_{on} is the fraction of each period during which the gate is open, and s is the phase shift. c_t and h_t are only allowed to be updated while the gate is open. The phase within the current period is

\phi_{t} = \frac{(t - s) \bmod \tau}{\tau}

On top of the vanilla LSTM equations, Phased LSTM therefore adds the time-gate update (here \tilde{c}_t and \tilde{h}_t denote the ordinary LSTM cell and hidden updates, and \alpha is a small leak rate applied while the gate is closed):

k_t = \begin{cases} \frac{2\phi_t}{r_{on}}, & \phi_t < \frac{1}{2}r_{on} \\ 2 - \frac{2\phi_t}{r_{on}}, & \frac{1}{2}r_{on} \le \phi_t < r_{on} \\ \alpha \phi_t, & \text{otherwise} \end{cases}

c_t = k_t \odot \tilde{c}_t + (1 - k_t) \odot c_{t-1}

h_t = k_t \odot \tilde{h}_t + (1 - k_t) \odot h_{t-1}
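
A tiny numeric illustration of how the gate openness k_t behaves over one period (the parameter values here are arbitrary, chosen only for the demo):

import torch

tau, r_on, s, leak = 5.0, 0.2, 0.0, 0.001  # arbitrary example values

t = torch.arange(0.0, 5.0, 0.5)
phi = torch.fmod(t - s, tau) / tau
k = torch.where(
    phi < 0.5 * r_on, 2 * phi / r_on,
    torch.where(phi < r_on, 2 - 2 * phi / r_on, leak * phi),
)
print(torch.stack([t, phi, k], dim=1))
# k ramps up towards 1 in the first half of the open phase, ramps back
# down in the second half, and sits at the tiny leak value otherwise.
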
Another advantage of Phased LSTM is that it preserves early information much better. In a vanilla LSTM, suppose the inputs contribute nothing and the forget gate sits at a constant 1 - \epsilon; the cell state then decays geometrically:

c_n = f_n \odot c_{n-1} = (1-\epsilon) \odot (f_{n-1} \odot c_{n-2}) = \dots = (1-\epsilon)^n \odot c_0

By contrast, while its time gate is closed, Phased LSTM barely touches the cell state, so information from early in the sequence is preserved almost perfectly.
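
A back-of-the-envelope comparison (the numbers are arbitrary, chosen only to illustrate the point): with \epsilon = 0.05, the vanilla cell keeps less than 1% of c_0 after 100 steps, while a closed time gate with leak \alpha = 0.001 keeps at least around 90% of it, since each step mixes in at most a fraction \alpha of new content.

eps, alpha, n = 0.05, 1e-3, 100

vanilla_retention = (1 - eps) ** n   # ~0.006: c_0 is essentially gone
phased_retention = (1 - alpha) ** n  # ~0.905: lower bound while the gate stays closed
print(vanilla_retention, phased_retention)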

Phased LSTM experimental results

[Figure: experimental results from the Phased LSTM paper]

On several of the benchmark tasks Phased LSTM clearly outperforms the vanilla LSTM; see the original paper for details.

Drawbacks of Phased LSTM

Of course, Phased LSTM also has some drawbacks.

  1. It only models the time points themselves, not the intervals between events. In a recommendation setting, Phased LSTM therefore captures a user's active periods but not the inactive periods in between (it simply treats them as the gate-closed state).
  2. It cannot distinguish the influence of short-term interests from that of long-term interests.

Time LSTM

Time LSTM was proposed to overcome these problems. Given a user's interaction history, i.e. a sequence of consumed items together with their timestamps, the model converts the individual time points into the time intervals between consecutive actions, and those intervals are what it actually conditions on.
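
For concreteness, the preprocessing boils down to something like the following (variable names and numbers are mine, purely for illustration):

import torch

# Hypothetical timestamps (in seconds) of one user's actions.
timestamps = torch.tensor([0.0, 30.0, 45.0, 600.0, 640.0])

# The intervals between consecutive actions are what Time LSTM feeds
# into its time gate(s) instead of the raw time points.
intervals = timestamps[1:] - timestamps[:-1]
print(intervals)  # tensor([ 30.,  15., 555.,  40.])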

Time LSTM variants

Time LSTM variant 1

[Figure and equations: Time LSTM variant 1 (TLSTM1); see the original paper]

This is the most direct modification of the Phased LSTM idea: the time point is simply replaced by the time interval, which makes it very easy to follow.
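
A minimal PyTorch sketch of that idea (my own simplification, assuming a single extra time gate driven by x_m and \Delta t_m; the paper's exact parameterization, e.g. peephole terms and constraints, is omitted):

import torch
import torch.nn as nn


class TimeGateCellSketch(nn.Module):
    """Sketch of the variant-1 idea: one time gate, computed from the
    input and the elapsed interval, scales what gets written to the cell."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.gates = nn.Linear(input_size + hidden_size, 3 * hidden_size)  # i, f, o
        self.cand = nn.Linear(input_size + hidden_size, hidden_size)       # candidate content
        self.x_to_t = nn.Linear(input_size, hidden_size)                   # time gate, input part
        self.dt_to_t = nn.Linear(1, hidden_size)                           # time gate, interval part

    def forward(self, x, dt, h, c):
        z = torch.cat([x, h], dim=-1)
        i, f, o = torch.sigmoid(self.gates(z)).chunk(3, dim=-1)
        g = torch.tanh(self.cand(z))
        t_gate = torch.sigmoid(self.x_to_t(x) + torch.sigmoid(self.dt_to_t(dt)))
        c_new = f * c + i * t_gate * g   # the interval modulates the write
        h_new = o * torch.tanh(c_new)
        return h_new, c_new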

Time LSTM variant 2

[Figure and equations: Time LSTM variant 2 (TLSTM2); see the original paper]

Variant 2 uses two time gates. Time gate 1 exploits the current time interval for the current item recommendation, while time gate 2 stores the interval's effect in the cell state to help later recommendations.
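
Relative to the variant-1 sketch above, this roughly amounts to adding a second time gate and splitting the cell update into two branches. A sketch under the same assumptions (the paper additionally imposes details such as a sign constraint on one gate's interval weights, which are omitted here):

import torch
import torch.nn as nn


class TwoTimeGateCellSketch(nn.Module):
    """Sketch of the variant-2 idea: the branch gated by t1 drives the
    current output, the branch gated by t2 is carried to the next step."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.gates = nn.Linear(input_size + hidden_size, 3 * hidden_size)  # i, f, o
        self.cand = nn.Linear(input_size + hidden_size, hidden_size)
        self.x_to_t1 = nn.Linear(input_size, hidden_size)
        self.dt_to_t1 = nn.Linear(1, hidden_size)
        self.x_to_t2 = nn.Linear(input_size, hidden_size)
        self.dt_to_t2 = nn.Linear(1, hidden_size)

    def forward(self, x, dt, h, c):
        z = torch.cat([x, h], dim=-1)
        i, f, o = torch.sigmoid(self.gates(z)).chunk(3, dim=-1)
        g = torch.tanh(self.cand(z))
        t1 = torch.sigmoid(self.x_to_t1(x) + torch.sigmoid(self.dt_to_t1(dt)))
        t2 = torch.sigmoid(self.x_to_t2(x) + torch.sigmoid(self.dt_to_t2(dt)))
        c_short = f * c + i * t1 * g   # used for the current recommendation
        c_new = f * c + i * t2 * g     # stored for later recommendations
        h_new = o * torch.tanh(c_short)
        return h_new, c_new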

Time LSTM variant 3

[Figure and equations: Time LSTM variant 3 (TLSTM3); see the original paper]

Time LSTM variant 3 uses coupled input and forget gates.
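
Coupling the gates means the forget gate is no longer learned separately; schematically (ignoring the time gates for a moment), the cell update becomes

c_m = (1 - i_m) \odot c_{m-1} + i_m \odot \sigma_c(x_m W_{xc} + h_{m-1} W_{hc} + b_c)

so writing more new content automatically forgets more of the old, and variant 3 applies this coupling together with the time gates.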

Time LSTM experimental results

[Figure: experimental comparison of the Time LSTM variants against baselines]

Time LSTM does indeed outperform the earlier time-related LSTM approaches; see the original paper for the full comparison.

Open-source PyTorch implementation of Phased LSTM

GitHub link

import math

import torch
import torch.nn as nn



class PhasedLSTMCell(nn.Module):
    """Phased LSTM recurrent network cell.
    https://arxiv.org/pdf/1610.09513v1.pdf
    """

    def __init__(
        self,
        hidden_size,
        leak=0.001,
        ratio_on=0.1,
        period_init_min=1.0,
        period_init_max=1000.0
    ):
        """
        Args:
            hidden_size: int, The number of units in the Phased LSTM cell.
            leak: float or scalar float Tensor with value in [0, 1]. Leak applied
                during training.
            ratio_on: float or scalar float Tensor with value in [0, 1]. Ratio of the
                period during which the gates are open.
            period_init_min: float or scalar float Tensor. With value > 0.
                Minimum value of the initialized period.
                The period values are initialized by drawing from the distribution:
                e^U(log(period_init_min), log(period_init_max))
                Where U(.,.) is the uniform distribution.
            period_init_max: float or scalar float Tensor.
                With value > period_init_min. Maximum value of the initialized period.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.ratio_on = ratio_on
        self.leak = leak

        # initialize time-gating parameters
        period = torch.exp(
            torch.Tensor(hidden_size).uniform_(
                math.log(period_init_min), math.log(period_init_max)
            )
        )
        self.tau = nn.Parameter(period)

        phase = torch.Tensor(hidden_size).uniform_() * period
        self.phase = nn.Parameter(phase)

    def _compute_phi(self, t):
        # Broadcast the per-sample timestamps t against the per-unit phase
        # shift and period, giving phi = ((t - s) mod tau) / tau for every
        # (sample, hidden unit) pair.
        t_ = t.view(-1, 1).repeat(1, self.hidden_size)
        phase_ = self.phase.view(1, -1).repeat(t.shape[0], 1)
        tau_ = self.tau.view(1, -1).repeat(t.shape[0], 1)

        phi = torch.fmod((t_ - phase_), tau_).detach()
        phi = torch.abs(phi) / tau_
        return phi

    def _mod(self, x, y):
        """Modulo function that propagates x gradients."""
        return x + (torch.fmod(x, y) - x).detach()

    def set_state(self, c, h):
        self.h0 = h
        self.c0 = c

    def forward(self, c_s, h_s, t):
        phi = self._compute_phi(t)

        # Piecewise-linear openness k of the time gate: it ramps up during
        # the first half of the open phase, ramps down during the second
        # half, and falls back to a small leak while the gate is closed.
        k_up = 2 * phi / self.ratio_on
        k_down = 2 - k_up
        k_closed = self.leak * phi

        k = torch.where(phi < self.ratio_on, k_down, k_closed)
        k = torch.where(phi < 0.5 * self.ratio_on, k_up, k)
        k = k.view(c_s.shape[0], t.shape[0], -1)

        # Blend the proposed states with the previously stored ones
        # according to the gate openness.
        c_s_new = k * c_s + (1 - k) * self.c0
        h_s_new = k * h_s + (1 - k) * self.h0

        return h_s_new, c_s_new


class PhasedLSTM(nn.Module):
    """Wrapper for multi-layer sequence forwarding via
       PhasedLSTMCell"""

    def __init__(
        self,
        input_size,
        hidden_size,
        bidirectional=True
    ):
        super().__init__()
        self.hidden_size = hidden_size

        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            bidirectional=bidirectional,
            batch_first=True
        )
        self.bi = 2 if bidirectional else 1

        self.phased_cell = PhasedLSTMCell(
            hidden_size=self.bi * hidden_size
        )

    def forward(self, u_sequence):
        """
        Args:
            u_sequence: input of shape (batch, time, N + 1), where the first
                N features are the observations and the last channel holds
                the timestamp of each step.
        """

        c0 = u_sequence.new_zeros((self.bi, u_sequence.size(0), self.hidden_size))
        h0 = u_sequence.new_zeros((self.bi, u_sequence.size(0), self.hidden_size))
        self.phased_cell.set_state(c0, h0)

        outputs = []
        for i in range(u_sequence.size(1)):
            u_t = u_sequence[:, i, :-1].unsqueeze(1)  # features of step i
            t_t = u_sequence[:, i, -1]                # timestamp of step i

            # nn.LSTM expects and returns its states in (h, c) order.
            out, (h_t, c_t) = self.lstm(u_t, (h0, c0))
            h_s, c_s = self.phased_cell(c_t, h_t, t_t)

            self.phased_cell.set_state(c_s, h_s)
            c0, h0 = c_s, h_s

            outputs.append(out)
        outputs = torch.cat(outputs, dim=1)

        return outputs
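
A minimal smoke test of the wrapper above (the shapes and values are my own, just to show the expected input layout with the timestamp in the last channel):

if __name__ == "__main__":
    batch, steps, n_features, hidden = 4, 10, 3, 8

    model = PhasedLSTM(input_size=n_features, hidden_size=hidden)

    x = torch.randn(batch, steps, n_features)
    t = torch.arange(steps, dtype=torch.float32).expand(batch, steps).unsqueeze(-1)

    out = model(torch.cat([x, t], dim=-1))
    print(out.shape)  # torch.Size([4, 10, 16]) with the bidirectional default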