trajdl.datasets.sampler.bucket module
- class trajdl.datasets.sampler.bucket.BucketSampler(ds: BaseArrowDataset, num_buckets: int, num_batches: int, batch_size: int, seed=None)
Bases:
Sampler
Sampler that produces batches from randomly selected buckets.
- ds
The dataset to sample from.
- Type:
BaseArrowDataset
- num_buckets
The number of buckets.
- Type:
int
- num_batches
The number of batches to yield.
- Type:
int
- batch_size
The size of each batch.
- Type:
int
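The behaviour described above (each batch drawn from one randomly selected bucket) can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the BucketSampler implementation: the function name `iter_bucket_batches` is hypothetical, buckets are given directly as lists of sample indices, and sampling within a bucket without replacement is an assumption.

```python
import random
from typing import Iterator, List, Optional, Sequence


def iter_bucket_batches(
    buckets: Sequence[Sequence[int]],
    num_batches: int,
    batch_size: int,
    seed: Optional[int] = None,
) -> Iterator[List[int]]:
    """Yield num_batches batches, each drawn from one randomly chosen bucket.

    Hypothetical helper for illustration; not part of trajdl.
    """
    rng = random.Random(seed)  # private RNG, so a fixed seed is reproducible
    non_empty = [list(b) for b in buckets if len(b) > 0]
    if not non_empty:
        raise ValueError("at least one non-empty bucket is required")
    for _ in range(num_batches):
        bucket = rng.choice(non_empty)
        # Assumption: sample without replacement; a small bucket gives a short batch.
        yield rng.sample(bucket, min(batch_size, len(bucket)))


# Three buckets of sample indices, e.g. grouped by sequence length.
buckets = [[0, 3, 5], [1, 4, 7, 8], [2, 6]]
batches = list(iter_bucket_batches(buckets, num_batches=4, batch_size=2, seed=0))
```

Because every batch comes from a single bucket, sequences in a batch have similar lengths, which keeps padding small.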
- class trajdl.datasets.sampler.bucket.SeqInfo(sample_idx: int, size: int)
Bases:
object
- sample_idx: int
- size: int
- trajdl.datasets.sampler.bucket.generate_buckets(seqs: BaseArrowDataset, num_buckets: int) -> List[ndarray]
Generate buckets from sequences.
- Parameters:
seqs (BaseArrowDataset) – A LocSeqDataset or a TrajectoryDataset.
num_buckets (int) – The desired number of buckets.
- Returns:
A list of NumPy arrays, where each array contains sample indices for a bucket.
- Return type:
List[np.ndarray]
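One common way to build such buckets is to sort sample indices by sequence length and cut the sorted order into contiguous chunks. The sketch below assumes that strategy; the documentation above does not state how generate_buckets groups samples. The name `bucket_indices_by_length` is hypothetical, plain Python lists stand in for the dataset and for the NumPy arrays, and rejecting `num_buckets < 1` is a choice made for this sketch.

```python
from typing import List, Sequence


def bucket_indices_by_length(
    seqs: Sequence[Sequence[object]], num_buckets: int
) -> List[List[int]]:
    """Group sample indices into num_buckets buckets of similar sequence length.

    Hypothetical helper for illustration; not part of trajdl.
    """
    if num_buckets < 1:
        raise ValueError("num_buckets must be at least 1")
    # Sort sample indices by sequence length (stable, so ties keep input order).
    order = sorted(range(len(seqs)), key=lambda i: len(seqs[i]))
    # Split the sorted indices into num_buckets nearly equal contiguous chunks.
    base, extra = divmod(len(order), num_buckets)
    buckets: List[List[int]] = []
    start = 0
    for b in range(num_buckets):
        stop = start + base + (1 if b < extra else 0)
        buckets.append(order[start:stop])
        start = stop
    return buckets


seqs = [[1, 2, 3], [4], [5, 6], [7, 8, 9, 10], [11, 12], [13]]
index_buckets = bucket_indices_by_length(seqs, num_buckets=3)
# index_buckets == [[1, 5], [2, 4], [0, 3]]
```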
- trajdl.datasets.sampler.bucket.generate_buckets_by_stats(stats: List[SeqInfo], num_buckets: int) -> List[ndarray]
- trajdl.datasets.sampler.bucket.generate_sub_buckets(stats: List[SeqInfo], num_buckets: int = 10) -> List[List[SeqInfo]]
Generate sub-buckets from statistics.
- Parameters:
stats (List[SeqInfo]) – A list of SeqInfo where each instance contains a sample index and length.
num_buckets (int, optional) – The desired number of buckets (default is 10).
- Returns:
A list of sub-buckets containing the statistics.
- Return type:
List[List[SeqInfo]]
- Raises:
ValueError – If num_buckets is less than 1.
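The contract above (a list of SeqInfo in, a list of sub-buckets out, ValueError for `num_buckets` below 1) can be illustrated with a small stand-in. The grouping rule used here, equal-width size ranges, is an assumption chosen for the example and may differ from what generate_sub_buckets actually does. The `SeqInfo` dataclass is a local stand-in with the two documented fields, and `split_by_size_range` is a hypothetical name.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class SeqInfo:
    """Local stand-in mirroring the documented fields; not imported from trajdl."""

    sample_idx: int
    size: int


def split_by_size_range(
    stats: List[SeqInfo], num_buckets: int = 10
) -> List[List[SeqInfo]]:
    """Assign each SeqInfo to one of num_buckets equal-width size ranges."""
    if num_buckets < 1:
        raise ValueError("num_buckets must be at least 1")
    result: List[List[SeqInfo]] = [[] for _ in range(num_buckets)]
    if not stats:
        return result
    lo = min(s.size for s in stats)
    hi = max(s.size for s in stats)
    span = hi - lo + 1  # count of integer sizes between lo and hi, inclusive
    for s in stats:
        # Integer arithmetic keeps the index within [0, num_buckets - 1].
        result[(s.size - lo) * num_buckets // span].append(s)
    return result


stats = [SeqInfo(i, size) for i, size in enumerate([1, 2, 5, 9, 10])]
sub_buckets = split_by_size_range(stats, num_buckets=3)
sizes_per_bucket = [[s.size for s in b] for b in sub_buckets]
# sizes_per_bucket == [[1, 2], [5], [9, 10]]
```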