# Copyright 2024 All authors of TrajDL
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from torch import nn
from ...common.enum import LossEnum
from ...common.samples import STLSTMSample
from ...loss.sampled_softmax import SampledSoftmaxLoss
from ...metrics.acc import AccMetrics
from ...tokenizers import AbstractTokenizer
from ...tokenizers.slot import Bucketizer
from ...utils import load_bucketizer, load_tokenizer
from ..abstract import BaseLightningModel
from ..embeddings.base import BaseTokenEmbeddingLayer, SimpleEmbedding
from .rnn import select_last_k_elements
class STLSTM(nn.Module):
    """LSTM-style recurrent encoder with spatio-temporal gate terms.

    At every time step the input, forget and output gates receive, in addition
    to the usual input/hidden projections, an additive term computed from
    embeddings of temporal and spatial interval indices (see ``cell_step``).

    Parameters
    ----------
    tokenizer: AbstractTokenizer
        Used to build the default location embedding when ``loc_emb_layer``
        is not supplied.
    embedding_dim: int
        Width of the location and interval embeddings. It is also the input
        width of the gate projections, so a custom ``loc_emb_layer`` must
        produce vectors of this width.
    hidden_size: int
        Width of the hidden and cell states.
    ts_bucketizer: Bucketizer
        Supplies ``num_buckets``, ``upper_bound`` and ``lower_bound`` for the
        temporal intervals.
    loc_bucketizer: Bucketizer
        Supplies ``num_buckets``, ``upper_bound`` and ``lower_bound`` for the
        spatial intervals.
    loc_emb_layer: BaseTokenEmbeddingLayer, optional
        Location embedding layer to use instead of a fresh ``SimpleEmbedding``.
    """

    def __init__(
        self,
        tokenizer: AbstractTokenizer,
        embedding_dim: int,
        hidden_size: int,
        ts_bucketizer: Bucketizer,
        loc_bucketizer: Bucketizer,
        loc_emb_layer: Optional[BaseTokenEmbeddingLayer] = None,
    ):
        super().__init__()
        # NOTE: sub-modules are created in a fixed order so that parameter
        # initialisation and ``state_dict`` ordering stay stable.
        self.loc_emb = (
            loc_emb_layer
            if loc_emb_layer is not None
            else SimpleEmbedding(tokenizer=tokenizer, embedding_dim=embedding_dim)
        )
        self.temporal_upper_emb = nn.Embedding(
            num_embeddings=ts_bucketizer.num_buckets, embedding_dim=embedding_dim
        )
        self.temporal_lower_emb = nn.Embedding(
            num_embeddings=ts_bucketizer.num_buckets, embedding_dim=embedding_dim
        )
        self.spatial_upper_emb = nn.Embedding(
            num_embeddings=loc_bucketizer.num_buckets, embedding_dim=embedding_dim
        )
        self.spatial_lower_emb = nn.Embedding(
            num_embeddings=loc_bucketizer.num_buckets, embedding_dim=embedding_dim
        )
        # Interval embeddings -> additive terms for the i/f/o gates (3 * H).
        self.temporal_ln = nn.Linear(
            in_features=embedding_dim, out_features=3 * hidden_size, bias=False
        )
        self.spatial_ln = nn.Linear(
            in_features=embedding_dim, out_features=3 * hidden_size, bias=False
        )
        # Location embedding / previous hidden state -> i/f/o/g pre-activations.
        self.input_weight = nn.Linear(
            in_features=embedding_dim, out_features=4 * hidden_size, bias=True
        )
        self.hidden_weight = nn.Linear(
            in_features=hidden_size, out_features=4 * hidden_size, bias=False
        )
        self._hidden_size = hidden_size
        self._td_upper = ts_bucketizer.upper_bound
        self._td_lower = ts_bucketizer.lower_bound
        self._sd_upper = loc_bucketizer.upper_bound
        self._sd_lower = loc_bucketizer.lower_bound

    @property
    def hidden_size(self) -> int:
        """Width of the hidden and cell states."""
        return self._hidden_size

    @property
    def td_upper(self) -> float:
        """``upper_bound`` of the temporal bucketizer."""
        return self._td_upper

    @property
    def td_lower(self) -> float:
        """``lower_bound`` of the temporal bucketizer."""
        return self._td_lower

    @property
    def sd_upper(self) -> float:
        """``upper_bound`` of the spatial bucketizer."""
        return self._sd_upper

    @property
    def sd_lower(self) -> float:
        """``lower_bound`` of the spatial bucketizer."""
        return self._sd_lower

    def cell_step(
        self,
        loc: torch.LongTensor,
        td_upper: torch.LongTensor,
        td_lower: torch.LongTensor,
        sd_upper: torch.LongTensor,
        sd_lower: torch.LongTensor,
        hidden: Tuple[torch.Tensor, torch.Tensor],
        first_step: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Advance the recurrent state by one time step.

        Parameters
        ----------
        loc: torch.LongTensor
            Batch of locations, shape (B,).
        td_upper, td_lower: torch.LongTensor
            Shape (B,). Used as indices into the temporal upper/lower
            embedding tables. Per the original note they are derived from the
            time gap between two consecutive steps relative to the upper and
            lower bound -- TODO confirm against the sample builder.
        sd_upper, sd_lower: torch.LongTensor
            Shape (B,). Spatial counterparts of ``td_upper`` / ``td_lower``.
        hidden: Tuple[torch.Tensor, torch.Tensor]
            Previous ``(hidden, cell)`` states, shapes (B, H) and (B, H).
        first_step: bool, optional
            When True the temporal and spatial interval terms are replaced by
            zeros. Default is False.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            New ``(hidden, cell)`` states, shapes (B, H) and (B, H).
        """
        prev_h, prev_c = hidden[0], hidden[1]

        # (B, C)
        loc_emb = self.loc_emb(loc)
        # (B, 4 * H)
        gates = self.input_weight(loc_emb) + self.hidden_weight(prev_h)
        # (B, H) each
        i_t, f_t, o_t, g_t = gates.chunk(4, 1)

        # (B, C): temporal interval representation
        q = (self.temporal_upper_emb(td_upper) + self.temporal_lower_emb(td_lower)) / (
            self.td_upper - self.td_lower
        )
        if first_step:
            # zeros_like already keeps the device and dtype of ``q``.
            q = torch.zeros_like(q)

        # (B, C): spatial interval representation
        s = (self.spatial_upper_emb(sd_upper) + self.spatial_lower_emb(sd_lower)) / (
            self.sd_upper - self.sd_lower
        )
        if first_step:
            s = torch.zeros_like(s)

        # (B, 3 * H): additive interval terms for the i/f/o gates
        interval_terms = self.temporal_ln(q) + self.spatial_ln(s)
        # (B, H) each
        i_extra, f_extra, o_extra = interval_terms.chunk(3, 1)

        i_t = torch.sigmoid(i_t + i_extra)
        f_t = torch.sigmoid(f_t + f_extra)
        o_t = torch.sigmoid(o_t + o_extra)
        g_t = torch.tanh(g_t)

        # (B, H)
        c_t = f_t * prev_c + i_t * g_t
        # (B, H)
        h_t = o_t * torch.tanh(c_t)
        return h_t, c_t

    def forward(
        self,
        sample: STLSTMSample,
        hidden: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ):
        """Run the cell over every time step of ``sample``.

        Parameters
        ----------
        sample: STLSTMSample
            Input sequences. ``loc_seq`` and the four interval sequences are
            indexed as ``[:, t]``; ``valid_lengths`` gives per-sequence lengths.
        hidden: Tuple[torch.Tensor, torch.Tensor], optional
            Initial ``(hidden, cell)`` states, shapes (B, H) and (B, H).
            Zero states are used when omitted.

        Returns
        -------
        Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
            Hidden states of every step, shape (B, T, H), and the
            ``(hidden, cell)`` pair selected per sequence from its valid length.
        """
        batch_size, num_timesteps = sample.loc_seq.shape
        if hidden is None:
            hidden = self.init_hidden(
                batch_size=batch_size, device=sample.loc_seq.device
            )

        # Capture the initial state *before* the loop. The running state is
        # reassigned on every step, so reading it after the loop would put the
        # final state in slot 0 and give zero-length sequences the wrong state.
        h_steps = [hidden[0]]
        c_steps = [hidden[1]]
        state = hidden
        for ts_idx in range(num_timesteps):
            state = self.cell_step(
                loc=sample.loc_seq[:, ts_idx],
                td_upper=sample.td_upper_seq[:, ts_idx],
                td_lower=sample.td_lower_seq[:, ts_idx],
                sd_upper=sample.sd_upper_seq[:, ts_idx],
                sd_lower=sample.sd_lower_seq[:, ts_idx],
                hidden=state,
                first_step=ts_idx == 0,
            )
            h_steps.append(state[0])
            c_steps.append(state[1])

        # (B, T + 1, H); index 0 is the initial state
        all_h = torch.stack(h_steps, dim=1)
        # (B, T + 1, H)
        all_c = torch.stack(c_steps, dim=1)

        # Shift lengths by one to account for the prepended initial state.
        valid_lengths_p1 = [length + 1 for length in sample.valid_lengths]
        last_h = select_last_k_elements(x=all_h, lengths=valid_lengths_p1, k=1).squeeze(
            dim=1
        )
        last_c = select_last_k_elements(x=all_c, lengths=valid_lengths_p1, k=1).squeeze(
            dim=1
        )
        return all_h[:, 1:], (last_h, last_c)

    def init_hidden(self, batch_size: int, device: torch.device):
        """Return zero ``(hidden, cell)`` states, each of shape (batch_size, H)."""
        return (
            torch.zeros(size=(batch_size, self.hidden_size), device=device),
            torch.zeros(size=(batch_size, self.hidden_size), device=device),
        )
class STLSTMModule(BaseLightningModel):
    """Lightning wrapper that trains an ``STLSTM`` encoder to predict locations.

    Parameters
    ----------
    tokenizer: Union[str, AbstractTokenizer]
        Tokenizer, or a value accepted by ``load_tokenizer``.
    embedding_dim: int
        Embedding width. Overridden by ``loc_emb_layer.embedding_dim`` when a
        custom embedding layer is supplied.
    hidden_size: int
        Width of the recurrent hidden state.
    ts_bucketizer, loc_bucketizer: Union[str, Bucketizer]
        Bucketizers, or values accepted by ``load_bucketizer``.
    reduction: Union[str, LossEnum], optional
        Loss reduction, parsed with ``LossEnum.parse``. Default is "mean".
    use_sampled_softmax: bool, optional
        Use ``SampledSoftmaxLoss`` instead of a full projection followed by
        cross entropy. Default is True.
    num_neg_samples: int, optional
        Number of samples passed to ``SampledSoftmaxLoss``. Default is 64.
    loc_emb_layer: BaseTokenEmbeddingLayer, optional
        Custom location embedding layer.
    optimizer_type: str, optional
        Forwarded to ``BaseLightningModel``. Default is "adam".
    learning_rate: float, optional
        Forwarded to ``BaseLightningModel``. Default is 1e-3.
    """

    def __init__(
        self,
        tokenizer: Union[str, AbstractTokenizer],
        embedding_dim: int,
        hidden_size: int,
        ts_bucketizer: Union[str, Bucketizer],
        loc_bucketizer: Union[str, Bucketizer],
        reduction: Union[str, LossEnum] = "mean",
        use_sampled_softmax: bool = True,
        num_neg_samples: int = 64,
        loc_emb_layer: Optional[BaseTokenEmbeddingLayer] = None,
        optimizer_type: str = "adam",
        learning_rate: float = 1e-3,
    ):
        super().__init__(optimizer_type=optimizer_type, learning_rate=learning_rate)
        self.save_hyperparameters()
        self.metrics = AccMetrics()
        tokenizer = load_tokenizer(tokenizer=tokenizer)
        self._hidden_size = hidden_size
        # Explicit None check (same as STLSTM.__init__): a module's truthiness
        # may depend on ``__len__`` and is not a reliable "was it given" test.
        if loc_emb_layer is not None:
            embedding_dim = loc_emb_layer.embedding_dim
        self.stlstm = STLSTM(
            tokenizer=tokenizer,
            embedding_dim=embedding_dim,
            hidden_size=hidden_size,
            ts_bucketizer=load_bucketizer(ts_bucketizer),
            loc_bucketizer=load_bucketizer(loc_bucketizer),
            loc_emb_layer=loc_emb_layer,
        )
        self._loss_reduction = LossEnum.parse(reduction)
        self._num_locs = len(tokenizer)
        self._use_sampled_softmax = use_sampled_softmax
        if use_sampled_softmax:
            # Output projection shared with the sampled-softmax loss.
            self.w = torch.nn.Parameter(
                torch.randn(self._num_locs, hidden_size) / np.sqrt(hidden_size)
            )
            self.b = torch.nn.Parameter(torch.zeros(size=(self._num_locs, 1)))
            self.loss = SampledSoftmaxLoss(
                weights=self.w,
                bias=self.b,
                num_words=self._num_locs,
                num_samples=num_neg_samples,
                reduction=self._loss_reduction.value,
                use_sampled_softmax_in_eval=False,
            )
        else:
            self.projector = nn.Linear(
                in_features=hidden_size, out_features=len(tokenizer)
            )
            # Reduction is applied manually in ``compute_loss`` after masking.
            self.loss = nn.CrossEntropyLoss(reduction="none")

    @property
    def use_sampled_softmax(self) -> bool:
        """Whether the sampled-softmax loss/projection is in use."""
        return self._use_sampled_softmax

    @property
    def hidden_size(self) -> int:
        """Width of the recurrent hidden state."""
        return self._hidden_size

    @property
    def num_locs(self) -> int:
        """Number of locations, i.e. ``len(tokenizer)``."""
        return self._num_locs

    @property
    def loss_reduction(self) -> LossEnum:
        """Parsed loss reduction mode."""
        return self._loss_reduction

    def encode(self, sample: STLSTMSample) -> torch.Tensor:
        """Encode a batch, starting from zero hidden and cell states.

        Per the original note, the first N-1 locations of a session are the
        input and the hidden states for the last N-1 locations are the output.

        Parameters
        ----------
        sample: STLSTMSample
            Batch to encode.

        Returns
        -------
        torch.Tensor
            Hidden states of every time step, shape (B, T, H).
        """
        # output shape is (B, T, H); the final (hidden, cell) pair is unused
        output, _ = self.stlstm(sample)
        return output

    def forward(self, sample: STLSTMSample, k: int = 1) -> torch.Tensor:
        """Inference: score locations from the last ``k`` valid time steps.

        Parameters
        ----------
        sample: STLSTMSample
            Batch used for inference.
        k: int, optional
            Number of trailing time steps taken from each sequence.
            Default is 1.

        Returns
        -------
        torch.Tensor
            Softmax probabilities over locations; shape (B, num_locs) for the
            default ``k=1``.
        """
        # (B, T, H)
        output = self.encode(sample)
        valid_lengths = sample.valid_lengths
        # (B, H) for k == 1
        last_pred = select_last_k_elements(
            x=output, lengths=valid_lengths, k=k
        ).squeeze(dim=1)
        if self.use_sampled_softmax:
            # (B, num_locs)
            logits = torch.matmul(
                last_pred, self.w.transpose(0, 1)
            ) + self.b.transpose(0, 1)
        else:
            # (B, num_locs)
            logits = self.projector(last_pred)
        return torch.softmax(logits, dim=-1)

    def compute_loss(self, sample: STLSTMSample):
        """Compute the loss of one batch.

        Parameters
        ----------
        sample: STLSTMSample
            Batch whose loss is computed; ``labels`` and ``mask`` are read.

        Returns
        -------
        Tuple[torch.Tensor, int]
            The loss and ``sample.batch_size``.
        """
        # (B, T, H)
        hidden = self.encode(sample)
        # (B, T)
        label = sample.labels
        batch_size = sample.batch_size
        flat_label = label.reshape(-1)

        if self.use_sampled_softmax:
            loss = self.loss(
                hidden.reshape(-1, self.hidden_size),
                flat_label,
                sample.mask.reshape(-1),
            )
            return loss, batch_size

        # (B, T): per-position loss, zeroed where the mask is 0
        logits = self.projector(hidden).reshape(-1, self.num_locs)
        loss = self.loss(logits, flat_label).reshape(label.shape) * sample.mask
        if self.loss_reduction == LossEnum.SUM:
            loss = loss.sum()
        elif self.loss_reduction == LossEnum.MEAN:
            # Average over unmasked positions only.
            loss = loss.sum() / sample.mask.sum()
        # Any other reduction mode leaves the (B, T) tensor unreduced.
        return loss, batch_size

    def _update_metrics(self, batch: STLSTMSample) -> None:
        """Feed the last-step prediction and label of ``batch`` into the metrics."""
        # (B,)
        pred = self.forward(sample=batch).argmax(dim=-1)
        # (B,)
        label = select_last_k_elements(batch.labels, batch.valid_lengths, k=1).squeeze(
            dim=1
        )
        self.metrics.update(pred=pred, label=label)

    def _log_metrics(self) -> None:
        """Log every entry of the accumulated metrics."""
        for key, value in self.metrics.value().items():
            self.log(key, value, batch_size=1)

    def training_step(self, batch: STLSTMSample, batch_idx: int):
        """Compute and log the training loss of one batch."""
        loss, batch_size = self.compute_loss(batch)
        self.log("train_loss", loss, batch_size=batch_size)
        return loss

    def validation_step(self, batch: STLSTMSample, batch_idx: int) -> torch.Tensor:
        """Log the validation loss and update the accuracy metrics."""
        loss, batch_size = self.compute_loss(batch)
        self.log("val_loss", loss, batch_size=batch_size)
        self._update_metrics(batch)
        return loss

    def test_step(self, batch: STLSTMSample, batch_idx: int) -> None:
        """Update the accuracy metrics with one test batch."""
        self._update_metrics(batch)

    def on_validation_epoch_start(self):
        """Clear metrics accumulated by a previous epoch."""
        self.metrics.reset()

    def on_validation_epoch_end(self):
        """Log the metrics accumulated over the validation epoch."""
        self._log_metrics()

    def on_test_epoch_start(self):
        """Clear metrics accumulated by a previous epoch."""
        self.metrics.reset()

    def on_test_epoch_end(self):
        """Log the metrics accumulated over the test epoch.

        The metrics were previously computed and discarded here, so test runs
        never reported them; they are now logged like the validation metrics.
        """
        self._log_metrics()
class HSTLSTM(BaseLightningModel):
    """Skeleton of a hierarchical model: STLSTM encoder, LSTM, STLSTM decoder.

    Only the sub-modules are built. ``forward`` and ``training_step`` are
    placeholders that compute nothing and return ``None``.

    Parameters
    ----------
    tokenizer: AbstractTokenizer
        Passed to both ``STLSTM`` sub-modules.
    embedding_dim: int
        Embedding width of both ``STLSTM`` sub-modules.
    hidden_size: int
        Hidden width of all three recurrent sub-modules.
    ts_buckets: Bucketizer
        Temporal bucketizer for both ``STLSTM`` sub-modules.
    loc_buckets: Bucketizer
        Spatial bucketizer for both ``STLSTM`` sub-modules.
    optimizer_type: str, optional
        Forwarded to ``BaseLightningModel``. Default is "adam".
    learning_rate: float, optional
        Forwarded to ``BaseLightningModel``. Default is 1e-3.
    """

    def __init__(
        self,
        tokenizer: AbstractTokenizer,
        embedding_dim: int,
        hidden_size: int,
        ts_buckets: Bucketizer,
        loc_buckets: Bucketizer,
        optimizer_type: str = "adam",
        learning_rate: float = 1e-3,
    ):
        super().__init__(optimizer_type=optimizer_type, learning_rate=learning_rate)
        self.save_hyperparameters()

        def make_stlstm() -> STLSTM:
            # Encoder and decoder share one configuration but not weights.
            return STLSTM(
                tokenizer=tokenizer,
                embedding_dim=embedding_dim,
                hidden_size=hidden_size,
                ts_bucketizer=ts_buckets,
                loc_bucketizer=loc_buckets,
            )

        # Construction order (encoder, LSTM, decoder) is kept as-is so that
        # parameter initialisation and registration order do not change.
        self.stlstm_encoder = make_stlstm()
        self.lstm = nn.LSTM(
            input_size=hidden_size, hidden_size=hidden_size, batch_first=True
        )
        self.stlstm_decoder = make_stlstm()

    def forward(
        self,
        loc_sessions: List[torch.LongTensor],
        ts_upper_sessions: List[torch.LongTensor],
        ts_lower_sessions: List[torch.LongTensor],
        sd_upper_sessions: List[torch.LongTensor],
        sd_lower_sessions: List[torch.LongTensor],
        valid_lengths: List[List[int]],
    ):
        """Placeholder forward pass; currently returns ``None``.

        Only ``loc_sessions`` is read (for its length); the loop over all but
        the last session has an empty body.
        """
        # number of sessions
        session_count = len(loc_sessions)
        for _session_idx in range(1, session_count):
            # TODO: per-session computation is not implemented yet.
            pass
        return None

    def training_step(self, batch, batch_idx: int):
        """Placeholder training step; currently returns ``None``."""
        return None