Time-aware LSTM

I was once asked in an interview how to incorporate time into a sequence model. Having seen little on the topic at the time, I simply answered "sample by time"; in hindsight, the approach described here seems to be the better answer.

Motivation and background

Long story short: the goal is to add a notion of time to the LSTM.

The traditional LSTM

i_t = \sigma_{i}(x_tW_{xi} + h_{t-1}W_{hi} + w_{ci} \odot c_{t-1} + b_i )

f_t = \sigma_{f}(x_tW_{xf} + h_{t-1}W_{hf} + w_{cf} \odot c_{t-1} + b_f)

c_t = f_t \odot c_{t-1} + i_t \odot \sigma_{c}(x_tW_{xc} + h_{t-1}W_{hc} + b_c)

o_t = \sigma_{o}(x_tW_{xo} + h_{t-1}W_{ho} + w_{co} \odot c_t + b_o )

h_t = o_t \odot \sigma_{h}(c_t)
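For reference, here is a minimal, illustrative PyTorch sketch of one step of these equations (assuming sigmoid gates and tanh for \sigma_c and \sigma_h; the dictionary p of weights is named after the symbols above and is only a stand-in for real parameters):

import torch

def lstm_step(x_t, h_prev, c_prev, p):
    # One step of the peephole LSTM equations above.
    i_t = torch.sigmoid(x_t @ p["W_xi"] + h_prev @ p["W_hi"] + p["w_ci"] * c_prev + p["b_i"])
    f_t = torch.sigmoid(x_t @ p["W_xf"] + h_prev @ p["W_hf"] + p["w_cf"] * c_prev + p["b_f"])
    c_t = f_t * c_prev + i_t * torch.tanh(x_t @ p["W_xc"] + h_prev @ p["W_hc"] + p["b_c"])
    o_t = torch.sigmoid(x_t @ p["W_xo"] + h_prev @ p["W_ho"] + p["w_co"] * c_t + p["b_o"])
    h_t = o_t * torch.tanh(c_t)
    return h_t, c_t

n_in, n_hid = 8, 16
p = {name: 0.1 * torch.randn(n_in if name.startswith("W_x") else n_hid, n_hid)
     for name in ["W_xi", "W_hi", "W_xf", "W_hf", "W_xc", "W_hc", "W_xo", "W_ho"]}
p.update({name: torch.zeros(n_hid)
          for name in ["w_ci", "w_cf", "w_co", "b_i", "b_f", "b_c", "b_o"]})
h, c = lstm_step(torch.randn(1, n_in), torch.zeros(1, n_hid), torch.zeros(1, n_hid), p)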

Phased LSTM

Phased LSTM adds a time gate. Whether this gate is open or closed is governed by three learnable parameters, and c_t and h_t are only allowed to update while the gate is open: \tau controls the period of the oscillation, r_{on} controls the fraction of each period during which the gate is open, and s controls the phase shift. The phase at time t is

\phi_{t} = \frac{(t-s) \bmod \tau}{\tau}

The openness of the time gate, k_t, is a piecewise-linear function of \phi_t:

k_t = \begin{cases} \frac{2\phi_t}{r_{on}}, & \phi_t < \frac{1}{2}r_{on} \\ 2 - \frac{2\phi_t}{r_{on}}, & \frac{1}{2}r_{on} < \phi_t < r_{on} \\ \alpha\,\phi_t, & \text{otherwise} \end{cases}

where \alpha is a small leak rate that lets a little information (and gradient) through while the gate is closed.
So, on top of the standard equations, Phased LSTM replaces the cell and hidden updates with gated versions:

\tilde{c}_t = f_t \odot c_{t-1} + i_t \odot \sigma_{c}(x_tW_{xc} + h_{t-1}W_{hc} + b_c)

c_t = k_t \odot \tilde{c}_t + (1 - k_t) \odot c_{t-1}

\tilde{h}_t = o_t \odot \sigma_{h}(\tilde{c}_t)

h_t = k_t \odot \tilde{h}_t + (1 - k_t) \odot h_{t-1}
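To make the gating concrete, here is a tiny numerical illustration with made-up values (\tau = 10, s = 2, r_{on} = 0.1, \alpha = 0.001):

import torch

t = torch.tensor([2.0, 2.3, 2.7, 5.0, 9.0])      # timestamps to evaluate
tau, s, r_on, alpha = 10.0, 2.0, 0.1, 0.001

phi = torch.fmod(t - s, tau) / tau                # position inside the period
k = torch.where(phi < r_on / 2, 2 * phi / r_on,           # rising half of the open phase
    torch.where(phi < r_on, 2 - 2 * phi / r_on,           # falling half
                alpha * phi))                              # closed: tiny leak
print(phi)   # tensor([0.0000, 0.0300, 0.0700, 0.3000, 0.7000])
print(k)     # ≈ [0.0, 0.6, 0.6, 0.0003, 0.0007]: open only within [s, s + r_on * tau)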
Phased LSTM has another nice property: it preserves early information much better. In a traditional LSTM, even if the forget gate is kept very close to one, say f = 1 - \epsilon, the memory decays geometrically:

c_n = f_n \odot c_{n-1} = (1-\epsilon) \odot (f_{n-1} \odot c_{n-2} ) = \dots = (1-\epsilon)^n \odot c_0

In contrast, while its time gate is closed, a Phased LSTM unit leaves the cell state essentially untouched, so old information survives across long spans.
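A quick back-of-the-envelope check of how fast that geometric decay bites, with an illustrative \epsilon = 0.01:

import torch

eps = 0.01
steps = torch.tensor([10.0, 100.0, 1000.0])
print((1 - eps) ** steps)   # ≈ [0.90, 0.37, 4.3e-05]: c_0 is essentially gone after 1000 steps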

Phased LSTM experimental results

(figure: benchmark results from the Phased LSTM paper)

On several of the benchmarks Phased LSTM clearly outperforms the traditional LSTM; see the original paper for the details.

Drawbacks of Phased LSTM

Of course, Phased LSTM also has some drawbacks.

  1. It only models points in time, not the intervals between them. In a recommender-system setting, Phased LSTM therefore only captures the user's active state and ignores the dormant state (Phased LSTM explicitly builds in a dormant, gate-closed phase).
  2. It cannot separate the influence of short-term interests from that of long-term interests.

Time LSTM

Time LSTM was proposed to overcome these problems with Phased LSTM. The setting is sequential recommendation: for each user we are given the sequence of items they consumed together with the timestamps of those interactions.


When actually modeling, the individual time points are converted into the time intervals between consecutive events, and it is those intervals that enter the model.
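For example (made-up timestamps), this preprocessing is just a first difference over each user's event times:

import torch

timestamps = torch.tensor([0.0, 3.0, 4.0, 10.0, 25.0])   # one user's event times
delta_t = timestamps[1:] - timestamps[:-1]                # tensor([ 3.,  1.,  6., 15.])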

Time LSTM variants

Time LSTM variant 1

(figures: cell diagram and update equations of Time LSTM variant 1; see the original paper)

This is the most straightforward modification of Phased LSTM: the absolute time point is simply replaced by the time interval, as sketched below, which makes it very easy to follow.
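Sketching the core of the update in my own paraphrase (the exact parameterization is in the paper): a single time gate T_m is computed from the input x_m and the interval \Delta t_m, and it rescales how much new content enters the cell:

T_m = \sigma_{t}(x_mW_{xt} + \sigma(\Delta t_mW_{tt}) + b_t)

c_m = f_m \odot c_{m-1} + i_m \odot T_m \odot \sigma_{c}(x_mW_{xc} + h_{m-1}W_{hc} + b_c)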

Time LSTM variant 2

(figures: cell diagram and update equations of Time LSTM variant 2; see the original paper)
Variant 2 uses two time gates. Time gate 1 uses the current time interval to help recommend the current item, while time gate 2 stores the time interval so that it can inform later recommendations.
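Roughly, in my paraphrase rather than the paper's exact equations, the cell state is split in two: a "current" state \tilde{c}_m gated by T1_m that is used to score the current recommendation, and a carried-over state c_m gated by T2_m that propagates the interval's effect to later steps:

\tilde{c}_m = f_m \odot c_{m-1} + i_m \odot T1_m \odot \sigma_{c}(x_mW_{xc} + h_{m-1}W_{hc} + b_c)

c_m = f_m \odot c_{m-1} + i_m \odot T2_m \odot \sigma_{c}(x_mW_{xc} + h_{m-1}W_{hc} + b_c)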

Time LSTM variant 3

(figures: cell diagram and update equations of Time LSTM variant 3; see the original paper)
Time LSTM variant 3 additionally uses coupled input and forget gates.
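In the generic coupled formulation (shown here for intuition, not as the paper's exact variant-3 equations), the forget gate is simply replaced by one minus the input gate:

c_t = (1 - i_t) \odot c_{t-1} + i_t \odot \tilde{c}_t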

Time LSTM experimental results

(figure: experimental results from the Time LSTM paper)
The results show that Time LSTM indeed outperforms the earlier time-related LSTM approaches.

Open-source PyTorch implementation of Phased LSTM

GitHub repository

import math

import torch
import torch.nn as nn



class PhasedLSTMCell(nn.Module):
    """Phased LSTM recurrent network cell.
    //arxiv.org/pdf/1610.09513v1.pdf
    """

    def __init__(
        self,
        hidden_size,
        leak=0.001,
        ratio_on=0.1,
        period_init_min=1.0,
        period_init_max=1000.0
    ):
        """
        Args:
            hidden_size: int, The number of units in the Phased LSTM cell.
            leak: float or scalar float Tensor with value in [0, 1]. Leak applied
                during training.
            ratio_on: float or scalar float Tensor with value in [0, 1]. Ratio of the
                period during which the gates are open.
            period_init_min: float or scalar float Tensor. With value > 0.
                Minimum value of the initialized period.
                The period values are initialized by drawing from the distribution:
                e^U(log(period_init_min), log(period_init_max))
                Where U(.,.) is the uniform distribution.
            period_init_max: float or scalar float Tensor.
                With value > period_init_min. Maximum value of the initialized period.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.ratio_on = ratio_on
        self.leak = leak

        # initialize time-gating parameters
        period = torch.exp(
            torch.Tensor(hidden_size).uniform_(
                math.log(period_init_min), math.log(period_init_max)
            )
        )
        self.tau = nn.Parameter(period)

        phase = torch.Tensor(hidden_size).uniform_() * period
        self.phase = nn.Parameter(phase)

    def _compute_phi(self, t):
        # phi = ((t - s) mod tau) / tau: position of each hidden unit
        # inside its own oscillation period, broadcast over the batch.
        t_ = t.view(-1, 1).repeat(1, self.hidden_size)
        phase_ = self.phase.view(1, -1).repeat(t.shape[0], 1)
        tau_ = self.tau.view(1, -1).repeat(t.shape[0], 1)

        phi = torch.fmod((t_ - phase_), tau_).detach()
        phi = torch.abs(phi) / tau_
        return phi

    def _mod(self, x, y):
        """Modulo function that propagates x gradients."""
        return x + (torch.fmod(x, y) - x).detach()

    def set_state(self, c, h):
        self.h0 = h
        self.c0 = c

    def forward(self, c_s, h_s, t):
        phi = self._compute_phi(t)

        # Piecewise-linear openness k of the time gate: rising half of the
        # open phase, falling half, and a small leak while the gate is closed.
        k_up = 2 * phi / self.ratio_on
        k_down = 2 - k_up
        k_closed = self.leak * phi

        k = torch.where(phi < self.ratio_on, k_down, k_closed)
        k = torch.where(phi < 0.5 * self.ratio_on, k_up, k)
        k = k.view(c_s.shape[0], t.shape[0], -1)

        # While the gate is (nearly) closed, the previous state stored in
        # self.c0 / self.h0 is carried over almost unchanged.
        c_s_new = k * c_s + (1 - k) * self.c0
        h_s_new = k * h_s + (1 - k) * self.h0

        return h_s_new, c_s_new


class PhasedLSTM(nn.Module):
    """Wrapper for multi-layer sequence forwarding via
       PhasedLSTMCell"""

    def __init__(
        self,
        input_size,
        hidden_size,
        bidirectional=True
    ):
        super().__init__()
        self.hidden_size = hidden_size

        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            bidirectional=bidirectional,
            batch_first=True
        )
        self.bi = 2 if bidirectional else 1

        self.phased_cell = PhasedLSTMCell(
            hidden_size=self.bi * hidden_size
        )

    def forward(self, u_sequence):
        """
        Args:
            sequence: The input sequence data of shape (batch, time, N)
            times: The timestamps corresponding to the data of shape (batch, time)
        """

        c0 = u_sequence.new_zeros((self.bi, u_sequence.size(0), self.hidden_size))
        h0 = u_sequence.new_zeros((self.bi, u_sequence.size(0), self.hidden_size))
        self.phased_cell.set_state(c0, h0)

        outputs = []
        for i in range(u_sequence.size(1)):
            u_t = u_sequence[:, i, :-1].unsqueeze(1)
            t_t = u_sequence[:, i, -1]

            # nn.LSTM expects and returns its recurrent state ordered as (h, c)
            out, (h_t, c_t) = self.lstm(u_t, (h0, c0))
            (h_s, c_s) = self.phased_cell(c_t, h_t, t_t)

            self.phased_cell.set_state(c_s, h_s)
            c0, h0 = c_s, h_s

            outputs.append(out)
        outputs = torch.cat(outputs, dim=1)

        return outputs
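Continuing from the code above, a minimal usage sketch with made-up shapes: the features and the per-step timestamp are packed into one tensor, with the timestamp in the last channel.

batch, seq_len, n_features = 4, 20, 8
x = torch.randn(batch, seq_len, n_features)
t = torch.arange(seq_len, dtype=torch.float32).expand(batch, seq_len)
u = torch.cat([x, t.unsqueeze(-1)], dim=-1)     # (batch, time, n_features + 1)

model = PhasedLSTM(input_size=n_features, hidden_size=16, bidirectional=False)
out = model(u)
print(out.shape)                                 # torch.Size([4, 20, 16])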