trajdl.datasets.sampler.session module#
这个文件主要做一些session维度的sampler
- class trajdl.datasets.sampler.session.SessionSampler(ds: BaseArrowDataset, num_batches: int, batch_size: int, num_buckets: int = 10)[source]#
Bases:
Sampler这个sampler是为了ST-LSTM等算法设计的,需要在基础的数据集上根据用户的id进行聚合 sampler在构建的时候,要拿到数据集,比如LocSeqDataset或者TrajectoryDataset
- property batch_size#
- construct_user_samples_dict(ds: BaseArrowDataset) Dict[str, List[int]][source]#
这个方法是用来统计每个用户的序列的
- property num_buckets#