trajdl.datasets.sampler.session module#

这个文件主要做一些session维度的sampler

class trajdl.datasets.sampler.session.SessionSampler(ds: BaseArrowDataset, num_batches: int, batch_size: int, num_buckets: int = 10)[source]#

Bases: Sampler

这个sampler是为了ST-LSTM等算法设计的,需要在基础的数据集上根据用户的id进行聚合 sampler在构建的时候,要拿到数据集,比如LocSeqDataset或者TrajectoryDataset

property batch_size#
construct_user_samples_dict(ds: BaseArrowDataset) Dict[str, List[int]][source]#

这个方法是用来统计每个用户的序列的

property num_buckets#