Source code for trajdl.datasets.modules.stlstm

# Copyright 2024 All authors of TrajDL
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, List, Optional

import torch
from torch.nn.utils.rnn import pad_sequence

from ...common.samples import STLSTMSample
from ...utils import load_bucketizer, valid_lengths_to_mask
from ..arrow import LocSeqDataset
from ..sampler.bucket import BucketSampler
from ..sampler.session import SessionSampler
from .abstract import BaseLocSeqDataModule


@dataclass
class STLSTMDataModule(BaseLocSeqDataModule):
    """Data module that turns location sequences into ``STLSTMSample`` batches.

    Batches are drawn with a ``BucketSampler`` for both the training and the
    validation dataset. Each batch is collated into padded location tokens,
    four bucket-index sequences derived from the time and distance deltas,
    next-location labels and a validity mask.

    Attributes:
        num_train_batches: ``num_batches`` passed to the training sampler.
        num_val_batches: ``num_batches`` passed to the validation sampler.
        num_train_buckets: ``num_buckets`` passed to the training sampler.
        num_val_buckets: ``num_buckets`` passed to the validation sampler.
        ts_bucketizer: Value handed to ``load_bucketizer`` in :meth:`setup`;
            afterwards holds the loaded bucketizer used for time deltas.
        loc_bucketizer: Value handed to ``load_bucketizer`` in :meth:`setup`;
            afterwards holds the loaded bucketizer used for distance deltas.
    """

    num_train_batches: int = 10
    num_val_batches: int = 10
    num_train_buckets: int = 10
    num_val_buckets: int = 10
    ts_bucketizer: Optional[Any] = None
    loc_bucketizer: Optional[Any] = None

    def __post_init__(self):
        """Delegate dataclass post-initialisation to the base data module."""
        super().__post_init__()

    def setup(self, stage: str):
        """Create the bucket samplers and load both bucketizers.

        Args:
            stage: Stage name forwarded unchanged to the base class ``setup``.
        """
        super().setup(stage=stage)

        self.train_sampler = BucketSampler(
            ds=self.train_ds,
            num_buckets=self.num_train_buckets,
            num_batches=self.num_train_batches,
            batch_size=self.train_batch_size,
        )
        self.val_sampler = BucketSampler(
            ds=self.val_ds,
            num_buckets=self.num_val_buckets,
            num_batches=self.num_val_batches,
            batch_size=self.val_batch_size,
        )

        # NOTE(review): the attributes are overwritten with the loaded objects,
        # so a second call to setup() passes those objects back into
        # load_bucketizer -- confirm that helper accepts an already-loaded
        # bucketizer.
        self.ts_bucketizer = load_bucketizer(self.ts_bucketizer)
        self.loc_bucketizer = load_bucketizer(self.loc_bucketizer)

    def collate_function(self, batch_ds: LocSeqDataset):
        """Collate one batch of location sequences into an ``STLSTMSample``.

        For every row the tokenised sequence is split into inputs (all but the
        last token) and labels (all but the first token). The time and distance
        deltas are turned into "upper" bucket indices (``upper_bound - delta``)
        and "lower" bucket indices (``delta - lower_bound``), each with its
        last element dropped, mirroring the input slicing.

        Args:
            batch_ds: Dataset slice exposing ``seq``, ``ts_delta`` and
                ``dis_delta`` columns whose elements provide ``as_py()``.

        Returns:
            An ``STLSTMSample`` holding the padded inputs, the four padded
            bucket-index sequences, the per-row valid lengths, the padded
            labels and the mask produced by ``valid_lengths_to_mask``.

        Raises:
            ValueError: If a tokenised sequence has fewer than two tokens, so
                that no input position remains after dropping the last token.
        """
        loc_seqs, labels, valid_lengths = [], [], []
        ts_upper, ts_lower = [], []
        sd_upper, sd_lower = [], []

        ts_bucketizer = self.ts_bucketizer
        loc_bucketizer = self.loc_bucketizer

        for seq, tsd, disd in zip(
            batch_ds.seq, batch_ds.ts_delta, batch_ds.dis_delta
        ):
            tokens = self.tokenizer.tokenize_loc_seq(seq, return_as="pt")
            inputs = tokens[:-1]

            # Fail fast: reject unusable rows before doing any bucket work.
            if len(inputs) < 1:
                raise ValueError("The length of sequence must be greater than 1.")

            # Convert each Arrow list to a tensor once; the subtractions below
            # create new tensors, so sharing these values is safe.
            ts_values = torch.Tensor(tsd.as_py())
            sd_values = torch.Tensor(disd.as_py())

            loc_seqs.append(inputs)
            labels.append(tokens[1:])
            valid_lengths.append(len(inputs))

            ts_upper.append(
                ts_bucketizer.get_bucket_indices(
                    ts_bucketizer.upper_bound - ts_values
                )[:-1]
            )
            ts_lower.append(
                ts_bucketizer.get_bucket_indices(
                    ts_values - ts_bucketizer.lower_bound
                )[:-1]
            )
            sd_upper.append(
                loc_bucketizer.get_bucket_indices(
                    loc_bucketizer.upper_bound - sd_values
                )[:-1]
            )
            sd_lower.append(
                loc_bucketizer.get_bucket_indices(
                    sd_values - loc_bucketizer.lower_bound
                )[:-1]
            )

        mask = valid_lengths_to_mask(valid_lengths)

        return STLSTMSample(
            loc_seq=pad_sequence(
                loc_seqs, batch_first=True, padding_value=self.tokenizer.pad
            ),
            td_upper_seq=pad_sequence(ts_upper, batch_first=True, padding_value=0),
            td_lower_seq=pad_sequence(ts_lower, batch_first=True, padding_value=0),
            sd_upper_seq=pad_sequence(sd_upper, batch_first=True, padding_value=0),
            sd_lower_seq=pad_sequence(sd_lower, batch_first=True, padding_value=0),
            valid_lengths=valid_lengths,
            labels=pad_sequence(
                labels, batch_first=True, padding_value=self.tokenizer.pad
            ),
            mask=mask,
        )
@dataclass
class HSTLSTMDataModule(BaseLocSeqDataModule):
    """Data module that collates location sequences grouped into user sessions.

    Training batches are drawn with a ``SessionSampler``. Collation regroups
    the rows of a batch by user and emits one padded tensor per session
    position.

    Attributes:
        num_train_batches: ``num_batches`` passed to the training sampler.
        num_val_batches: Declared but not read anywhere in this class.
        num_train_buckets: Declared but not read anywhere in this class.
        num_val_buckets: Declared but not read anywhere in this class.
    """

    num_train_batches: int = 10
    num_val_batches: int = 10
    num_train_buckets: int = 10
    num_val_buckets: int = 10

    def __post_init__(self):
        """Delegate dataclass post-initialisation to the base data module."""
        super().__post_init__()

    def setup(self, stage: str):
        """Run the base setup and create the training session sampler.

        Args:
            stage: Stage name forwarded unchanged to the base class ``setup``.
        """
        super().setup(stage=stage)
        # NOTE(review): only a training sampler is built here; no validation
        # sampler is assigned -- confirm this is intentional.
        self.train_sampler = SessionSampler(
            ds=self.train_ds,
            num_batches=self.num_train_batches,
            batch_size=self.train_batch_size,
        )

    def collate_function(self, ds: LocSeqDataset):
        """Regroup a batch by user and pad each session position separately.

        Consecutive rows sharing the same ``entity_id`` are treated as the
        sessions of one user. The k-th output batch contains the k-th session
        of every user that has at least k + 1 sessions.

        Args:
            ds: Dataset slice exposing ``entity_id``, ``seq``, ``ts_delta`` and
                ``dis_delta``; must contain at least one row.

        Returns:
            A 4-tuple ``(loc_samples, ts_delta_samples, dis_delta_samples,
            valid_lengths)``. The first three are lists with one padded tensor
            per session position; the last is a list of lists holding the
            unpadded length of each location sequence in the matching batch.
        """
        # Each entry lists the row indices that belong to one user.
        rows_per_user: List[List[int]] = [[]]
        last_owner = ds.entity_id[0]
        for row in range(len(ds)):
            owner = ds.entity_id[row]
            if owner != last_owner:
                # A different user starts here: open a fresh group.
                rows_per_user.append([])
            rows_per_user[-1].append(row)
            last_owner = owner

        # Largest number of sessions any single user has in this batch.
        widest = max(map(len, rows_per_user))

        # One bucket per session position; every bucket becomes a batch.
        per_position = [[] for _ in range(widest)]

        # Walk each user's sessions and drop them into the bucket matching
        # their position within that user's history.
        for user_rows in rows_per_user:
            for position, row in enumerate(user_rows):
                tokens = self.tokenizer.tokenize_loc_seq(
                    ds.seq[row], return_as="pt"
                )
                entry = (
                    tokens,
                    ds.ts_delta[row].as_py(),
                    ds.dis_delta[row].as_py(),
                )
                per_position[position].append(entry)

        def _pad_deltas(rows):
            """Convert raw delta lists to tensors and zero-pad them together."""
            tensors = [torch.Tensor(values) for values in rows]
            return pad_sequence(tensors, batch_first=True, padding_value=0)

        # Pad locations, time deltas and distance deltas bucket by bucket.
        loc_samples = []
        ts_delta_samples = []
        dis_delta_samples = []
        valid_lengths = []
        for bucket in per_position:
            locations = [entry[0] for entry in bucket]
            time_deltas = [entry[1] for entry in bucket]
            distance_deltas = [entry[2] for entry in bucket]

            valid_lengths.append([len(tokens) for tokens in locations])
            loc_samples.append(
                pad_sequence(
                    locations,
                    batch_first=True,
                    padding_value=self.tokenizer.pad,
                )
            )
            ts_delta_samples.append(_pad_deltas(time_deltas))
            dis_delta_samples.append(_pad_deltas(distance_deltas))

        # Three List[Tensor] (one padded batch per session position) plus a
        # List[List[int]] with the valid length of every session.
        return loc_samples, ts_delta_samples, dis_delta_samples, valid_lengths