# Copyright 2024 All authors of TrajDL
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from ..common.enum import Mode
from ..common.samples import GMVSAESample
from ..loss.sampled_softmax import SampledSoftmaxLoss
from ..utils import find_best_checkpoint, load_tokenizer, tiny_value_of_dtype
from .framework import PretrainTrainFramework
[docs]
class Encoder(nn.Module):
    """
    GRU encoder that turns a padded batch of token ids into the final hidden
    state of every GRU layer.

    Parameters
    ----------
    embedding_layer: embedding module used to look up token vectors; its
        ``embedding_dim`` becomes the GRU input size
    hidden_size: rnn hidden size
    num_layers: default 1, num layers of rnn
    dropout: default 0.0, dropout forwarded to ``nn.GRU``
    """

    def __init__(
        self,
        embedding_layer: nn.Embedding,
        hidden_size: int,
        num_layers: int = 1,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.emb = embedding_layer
        gru_kwargs = {
            "input_size": embedding_layer.embedding_dim,
            "hidden_size": hidden_size,
            "num_layers": num_layers,
            "dropout": dropout,
            "batch_first": True,
        }
        self.rnn = nn.GRU(**gru_kwargs)

    def forward(self, seq: torch.LongTensor, lengths: List[int]):
        """
        seq shape is (B, T), lengths holds the number of valid tokens per row.

        Returns the last hidden state of the GRU, shape is (num_layers, B, H).
        """
        # (B, T, C)
        token_vectors = self.emb(seq)
        # packing makes the GRU stop at each sequence's true length, so the
        # padded positions never influence the returned state
        packed = pack_padded_sequence(
            token_vectors, lengths, batch_first=True, enforce_sorted=False
        )
        # index 1 of the GRU result is h_n: (num_layers, B, H)
        final_hidden = self.rnn(packed)[1]
        return final_hidden
[docs]
class LatentSpace(nn.Module):
    """
    Gaussian-mixture latent space placed between the encoder and the decoder.

    The encoder state is mapped to a diagonal Gaussian posterior
    (``mean_z``, ``log_var_z``), a latent vector ``z`` is sampled from it with
    the reparameterisation trick, and two regularisation terms are computed
    against ``c`` mixture components (``mean_c``, ``log_var_c``).

    Parameters
    ----------
    hidden_size: size of the latent vector and of each mixture component
    num_layers: number of encoder rnn layers; the flattened encoder state has
        ``hidden_size * num_layers`` features
    c: number of mixture components
    reduction: default "none"; "mean" averages both losses over the batch,
        any other value leaves them per-sample
    """

    def __init__(
        self,
        hidden_size: int,
        num_layers: int,
        c: int,
        reduction: str = "none",
    ):
        super().__init__()
        self.mean_linear = nn.Linear(
            in_features=hidden_size * num_layers, out_features=hidden_size
        )
        self.log_var_linear = nn.Linear(
            in_features=hidden_size * num_layers, out_features=hidden_size
        )
        # component means are learnable, component log-variances are frozen at 0
        # (i.e. unit variance)
        self.mean_c = nn.Parameter(torch.rand(size=(c, hidden_size)))
        self.log_var_c = nn.Parameter(
            torch.zeros(size=(c, hidden_size)), requires_grad=False
        )
        self.c = c
        self.reduction = reduction

    def get_mean_c(self, idx: int) -> torch.Tensor:
        """
        Return the mean of mixture component ``idx``, shape is (1, H).
        """
        assert 0 <= idx < self.c
        # slice (not index) so the leading dimension is kept
        return self.mean_c[idx : idx + 1]

    def forward(self, encoder_state: torch.Tensor):
        """
        encoder_state shape is (num_layers, B, H)

        Returns
        ----------
        z: torch.Tensor
            sampled latent vector, shape is (B, H)
        batch_gaussian_loss: torch.Tensor
            KL-style term between the posterior and the mixture components,
            shape is (B,) or a scalar when reduction is "mean"
        batch_uniform_loss: torch.Tensor
            negative entropy of the component posterior, shape is (B,) or a
            scalar when reduction is "mean"
        """
        batch_size = encoder_state.shape[1]
        # (B, num_layers * H)
        state = encoder_state.swapaxes(0, 1).reshape(batch_size, -1)
        # (B, H)
        mean_z = self.mean_linear(state)
        # (B, H)
        log_var_z = self.log_var_linear(state)
        # (B, H)
        eps_z = torch.normal(mean=0.0, std=1.0, size=mean_z.shape, device=mean_z.device)
        # reparameterisation: the layer predicts a log-variance, so the standard
        # deviation is exp(0.5 * log_var), consistent with the exp(log_var)
        # variance used in the Gaussian loss below
        z = eps_z * torch.exp(0.5 * log_var_z) + mean_z

        # component statistics broadcast over the batch: (1, c, H)
        mu_c = self.mean_c.unsqueeze(dim=0)
        log_var_c = self.log_var_c.unsqueeze(dim=0)
        var_c = torch.exp(log_var_c)

        # (B, c): negative scaled squared distance, so the closest component
        # receives the largest posterior probability after the softmax
        pi_post_logits = -torch.sum((z.unsqueeze(dim=1) - mu_c) ** 2 / var_c, dim=-1)
        # (B, c); the tiny offset keeps log(pi_post) finite
        pi_post = torch.softmax(pi_post_logits, dim=-1) + tiny_value_of_dtype(
            pi_post_logits.dtype
        )

        # posterior statistics broadcast over the components: (B, 1, H)
        # the KL term is taken on the posterior mean, not on the noisy sample z
        mu_z = mean_z.unsqueeze(dim=1)
        var_z = torch.exp(log_var_z).unsqueeze(dim=1)
        # (B,)
        batch_gaussian_loss = 0.5 * torch.sum(
            pi_post
            * torch.mean(
                log_var_c + var_z / var_c + (mu_z - mu_c) ** 2 / var_c,
                dim=-1,
            ),
            dim=-1,
        ) - 0.5 * torch.mean(1 + log_var_z, dim=-1)
        if self.reduction == "mean":
            batch_gaussian_loss = batch_gaussian_loss.mean()

        # (B,): computed per sample (the component posterior is not averaged
        # over the batch before taking the entropy)
        batch_uniform_loss = torch.mean(pi_post * torch.log(pi_post), dim=1)
        if self.reduction == "mean":
            batch_uniform_loss = batch_uniform_loss.mean()
        return z, batch_gaussian_loss, batch_uniform_loss
[docs]
class Decoder(nn.Module):
    """
    GRU decoder that unrolls over a padded token batch starting from a given
    hidden state and returns the per-step hidden outputs.

    Parameters
    ----------
    emb: embedding module used to look up token vectors; its ``embedding_dim``
        becomes the GRU input size
    hidden_size: rnn hidden size
    padding_value: value written into the output at positions beyond each
        sequence's length
    num_layers: default 1, num layers of rnn
    dropout: default 0.0, dropout forwarded to ``nn.GRU``
    """

    def __init__(
        self,
        emb: nn.Embedding,
        hidden_size: int,
        padding_value: float,
        num_layers: int = 1,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.emb = emb
        self.rnn = nn.GRU(
            input_size=self.emb.embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
        )
        self._padding_value = padding_value
        self._num_layers = num_layers

    def forward(
        self, seq: torch.LongTensor, lengths: List[int], init_hidden: torch.Tensor
    ):
        """
        seq shape is (B, T), lengths holds the number of valid tokens per row,
        init_hidden shape is (B, H)

        Returns the GRU outputs, shape is (B, T, H); positions at or beyond a
        sequence's length are filled with ``padding_value``.
        """
        # (B, T, C)
        emb = self.emb(seq)
        x = pack_padded_sequence(
            emb, lengths=lengths, batch_first=True, enforce_sorted=False
        )
        # every layer starts from the same state: (num_layers, B, H)
        h0 = init_hidden.unsqueeze(dim=0).repeat(self._num_layers, 1, 1)
        output, _ = self.rnn(x, h0)
        # total_length pins the time axis to the input width; without it the
        # output shrinks to max(lengths) and no longer lines up with labels and
        # masks that were built for the full padded width
        output, _ = pad_packed_sequence(
            output,
            batch_first=True,
            padding_value=self._padding_value,
            total_length=seq.shape[1],
        )
        # (B, T, H)
        return output
[docs]
class GMVSAE(PretrainTrainFramework):
    """
    GMVSAE

    Sequence autoencoder with a Gaussian-mixture latent space. An ``Encoder``
    produces a hidden state, ``LatentSpace`` samples a latent vector ``z`` from
    it, and a ``Decoder`` started from ``z`` reconstructs the sequence. The
    encoder and decoder share one embedding table.
    """

    def __init__(
        self,
        tokenizer: Union[str, Any],
        embedding_dim: int,
        hidden_size: int,
        mem_num: int,
        mode: str,
        num_layers: int = 1,
        num_neg_samples: int = 64,
        init_mu_c_pretrained_path: Optional[str] = None,
        pretrain_ckpt_folder: Optional[str] = None,
    ):
        """
        Parameters
        ----------
        tokenizer: path of tokenizer or a tokenizer instance
        embedding_dim: size of emb dim
        hidden_size: rnn hidden size
        mem_num: num types of route
        num_layers: default 1, num layers of rnn
        mode: {"pretrain", "train", "eval"}
        num_neg_samples: default 64, num negative samples in sampled softmax loss
        init_mu_c_pretrained_path: default None, path of pretrained init_mu_c
        pretrain_ckpt_folder: default None, path of pretrained ckpt folder
        """
        super().__init__(mode=mode)
        self.save_hyperparameters()
        tokenizer = load_tokenizer(tokenizer)
        num_vocab = len(tokenizer)
        # one embedding table shared by the encoder and the decoder
        self.emb = nn.Embedding(num_vocab, embedding_dim)
        self.encoder = Encoder(
            embedding_layer=self.emb, hidden_size=hidden_size, num_layers=num_layers
        )
        # reduction="mean": both latent losses come back as batch-averaged scalars
        self.latent = LatentSpace(
            hidden_size=hidden_size,
            num_layers=num_layers,
            c=mem_num,
            reduction="mean",
        )
        self.decoder = Decoder(
            emb=self.emb,
            hidden_size=hidden_size,
            num_layers=num_layers,
            padding_value=tokenizer.pad,
        )
        # weights and bias project the rnn output into label space;
        # weights are drawn with std = 1.0 / sqrt(fan_in), bias starts at zero
        self.w = torch.nn.Parameter(
            torch.randn(num_vocab, hidden_size) / np.sqrt(hidden_size)
        )
        self.b = torch.nn.Parameter(torch.zeros(size=(num_vocab, 1)))
        self.reconstruct_loss = SampledSoftmaxLoss(
            weights=self.w,
            bias=self.b,
            num_words=num_vocab,
            num_samples=num_neg_samples,
            reduction="mean",
            use_sampled_softmax_in_eval=True,
        )
        self.hidden_size = hidden_size
        self._mem_num = mem_num
        self._pretrain_ckpt_folder = pretrain_ckpt_folder
        self._init_mu_c_pretrained_path = init_mu_c_pretrained_path

    @property
    def mem_num(self) -> int:
        """Number of mixture components (route types)."""
        return self._mem_num

    def init_from_pretrained_ckpt(self) -> None:
        """
        When the model is in train mode, load parameters from the pretrained
        checkpoint and, if configured, the pretrained mixture-component means.

        Does nothing in any other mode.
        """
        if self.mode != Mode.TRAIN:
            return
        if self._pretrain_ckpt_folder:
            ckpt_filename = find_best_checkpoint(
                self._pretrain_ckpt_folder, is_maximizing=False
            )
            print(f"load weights from {ckpt_filename}")
            # map_location="cpu" lets a checkpoint written on a GPU host load on
            # a CPU-only one; load_state_dict then copies the values onto
            # whatever device the parameters already live on.
            # NOTE: torch.load unpickles the file, only load trusted checkpoints.
            checkpoint = torch.load(
                os.path.join(self._pretrain_ckpt_folder, ckpt_filename),
                map_location="cpu",
            )
            self.load_state_dict(checkpoint["state_dict"])
        if self._init_mu_c_pretrained_path:
            print(f"load init_mu_c from {self._init_mu_c_pretrained_path}")
            init_mu_c = np.load(self._init_mu_c_pretrained_path)
            # in-place copy on a parameter must not be tracked by autograd
            with torch.no_grad():
                self.latent.mean_c.copy_(
                    torch.from_numpy(init_mu_c).to(self.latent.mean_c.dtype)
                )

    def init_decoder_state_for_inference(
        self, batch_size: int, c_idx: int
    ) -> torch.Tensor:
        """
        init a decoder state for inference: the mean of mixture component
        ``c_idx`` repeated for every sequence in the batch, shape is (B, H)
        """
        assert 0 <= c_idx < self.mem_num
        # shape is (1, hidden_size)
        mean_c = self.latent.get_mean_c(c_idx)
        # shape is (B, hidden_size)
        return mean_c.repeat(batch_size, 1)

    def decode(
        self,
        init_state: torch.Tensor,
        decoder_seq: torch.LongTensor,
        decoder_lengths: List[int],
        decoder_labels: torch.LongTensor,
        mask: torch.BoolTensor,
    ) -> torch.Tensor:
        """
        Score each sequence by the masked mean of the sigmoid of the decoder's
        score for the expected label at every step.

        init_state: shape is (B, H)
        decoder_seq: shape is (B, T + 1), each of seq added BOS
        decoder_lengths: List[int], encoder_lengths + 1
        decoder_labels: shape is (B, T + 1), each of seq added EOS
        mask: shape is (B, T + 1)

        Returns a tensor of shape (B,).
        """
        # (B, T, H)
        decoder_output = self.decoder(decoder_seq, decoder_lengths, init_state)
        # (B, T): dot product with the label's output vector plus the label's bias
        label_scores = (decoder_output * self.w[decoder_labels]).sum(
            dim=-1
        ) + self.b[decoder_labels].squeeze(dim=-1)
        # (B, T)
        label_probs = torch.sigmoid(label_scores)
        # (B,): average over valid positions only
        batch_likelihood = (label_probs * mask).sum(dim=1) / mask.sum(dim=1)
        return batch_likelihood

    def forward(self, batch: GMVSAESample) -> Optional[torch.Tensor]:
        """
        Inference logic; the result differs between the pretrain stage and the
        eval stage.

        1. In pretrain mode the latent variable z is generated.
        2. In eval mode a score is produced for every sequence.

        Parameters
        ----------
        batch: GMVSAESample
            input sample

        Returns
        ----------
        z: torch.Tensor
            returned in pretrain mode, shape is (B, H), the latent variable z
        inference_result: torch.Tensor
            returned in eval mode, shape is (B,), the score of each sequence
            (see ``abnormal_detect``)
        None
            returned in any other mode
        """
        if self.mode == Mode.PRETRAIN:
            # (B, H)
            return self.generate_z(batch.encoder_seq, batch.encoder_lengths)
        if self.mode == Mode.EVAL:
            return self.abnormal_detect(
                decoder_seq=batch.decoder_seq,
                decoder_lengths=batch.decoder_lengths,
                decoder_labels=batch.decoder_labels,
                mask=batch.mask,
            )
        # no forward output is defined for the remaining modes
        return None

    def abnormal_detect(
        self,
        decoder_seq: torch.LongTensor,
        decoder_lengths: List[int],
        decoder_labels: torch.LongTensor,
        mask: torch.BoolTensor,
    ) -> torch.Tensor:
        """
        Use the decoder to score sequences for anomaly detection.

        The decoder is run once per mixture component, each time started from
        that component's mean, and the best (largest) per-sequence score over
        all components is kept.

        Parameters
        ----------
        decoder_seq: torch.LongTensor
            input sequences of the decoder
        decoder_lengths: List[int]
            length of each decoder input sequence
        decoder_labels: torch.LongTensor
            expected outputs of the decoder
        mask: torch.BoolTensor
            mask matrix matching the decoder input lengths

        Returns
        ----------
        torch.Tensor
            shape is (B,), the score of each sequence.
            NOTE(review): the value is a maximum mean likelihood, so a larger
            value means a better fit; confirm how callers turn it into an
            anomaly score.
        """
        batch_size = len(decoder_lengths)
        # one (B, 1) column per mixture component
        per_component = [
            self.decode(
                self.init_decoder_state_for_inference(
                    batch_size=batch_size, c_idx=c_idx
                ),
                decoder_seq,
                decoder_lengths,
                decoder_labels,
                mask,
            ).unsqueeze(dim=1)
            for c_idx in range(self.mem_num)
        ]
        # (B, c)
        scores = torch.cat(per_component, dim=1)
        # (B,)
        return scores.max(dim=1).values

    def generate_z(
        self,
        encoder_seq: torch.LongTensor,
        encoder_lengths: List[int],
        return_loss: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
        """
        Encoding stage: generate the latent variable z.

        Parameters
        ----------
        encoder_seq: torch.LongTensor
            input sequences of the encoder, shape is (B, T)
        encoder_lengths: List[int]
            length of each encoder input sequence
        return_loss: bool, optional
            whether to return the latent losses as well, default is False

        Returns
        ----------
        (z, batch_gaussian_loss, batch_uniform_loss): Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            returned when return_loss is True; z has shape (B, H) and the two
            losses are scalars (the latent space is built with reduction="mean")
        z: torch.Tensor
            returned when return_loss is False, shape is (B, H)
        """
        # (num_layers, B, H)
        encoder_last_hidden = self.encoder(encoder_seq, encoder_lengths)
        # (B, H), scalar, scalar
        z, batch_gaussian_loss, batch_uniform_loss = self.latent(encoder_last_hidden)
        if return_loss:
            return z, batch_gaussian_loss, batch_uniform_loss
        return z

    def compute_loss(self, batch: GMVSAESample):
        """
        Compute the loss.

        In pretrain mode only the reconstruction loss is returned; in train
        mode the two latent losses are added to it.

        Parameters
        ----------
        batch: GMVSAESample
            input sample

        Returns
        ----------
        loss: torch.Tensor
            scalar loss
        batch_size: int
            batch size

        Raises
        ----------
        RuntimeError
            when the mode is neither pretrain nor train
        """
        batch_size = batch.encoder_seq.shape[0]
        # (B, H), scalar, scalar
        z, batch_gaussian_loss, batch_uniform_loss = self.generate_z(
            encoder_seq=batch.encoder_seq,
            encoder_lengths=batch.encoder_lengths,
            return_loss=True,
        )
        # (B, T, H)
        decoder_output = self.decoder(batch.decoder_seq, batch.decoder_lengths, z)
        # flatten batch and time so every position is one classification example
        # (B * T, H)
        decoder_output_reshape = decoder_output.reshape(-1, self.hidden_size)
        # (B * T,)
        decoder_labels_reshape = batch.decoder_labels.reshape(-1)
        # (B * T,)
        mask_reshape = batch.mask.reshape(-1)
        reconstruct_loss = self.reconstruct_loss(
            decoder_output_reshape, decoder_labels_reshape, mask_reshape
        )
        if self.mode == Mode.PRETRAIN:
            return reconstruct_loss, batch_size
        if self.mode == Mode.TRAIN:
            return (
                reconstruct_loss + batch_gaussian_loss + batch_uniform_loss,
                batch_size,
            )
        raise RuntimeError(f"Invalid model value: {self.mode}")

    def training_step(self, batch: GMVSAESample, batch_idx: int):
        """Compute the loss for one training batch and log it as ``train_loss``."""
        loss, batch_size = self.compute_loss(batch)
        self.log("train_loss", loss, batch_size=batch_size)
        return loss

    def validation_step(self, batch: GMVSAESample, batch_idx: int):
        """Compute the loss for one validation batch and log it as ``val_loss``."""
        loss, batch_size = self.compute_loss(batch)
        self.log("val_loss", loss, batch_size=batch_size)
        return loss