trajdl.datasets.sampler.bucket module

class trajdl.datasets.sampler.bucket.BucketSampler(ds: BaseArrowDataset, num_buckets: int, num_batches: int, batch_size: int, seed=None)

Bases: Sampler

Sampler that produces batches from randomly selected buckets.

ds

The dataset to sample from.

Type:

BaseArrowDataset

num_buckets

The number of buckets.

Type:

int

num_batches

The number of batches to yield.

Type:

int

batch_size

The size of each batch.

Type:

int
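This page does not say how BucketSampler picks buckets or fills a batch. The following is a minimal, pure-Python sketch of the general idea described above (each batch is drawn from one randomly selected bucket), assuming the buckets are already available as lists of sample indices. `ToyBucketSampler` is hypothetical: it takes precomputed buckets instead of a BaseArrowDataset, samples without replacement inside a bucket, and is not trajdl's implementation.

```python
import random
from typing import Iterator, List, Optional, Sequence


class ToyBucketSampler:
    """Hypothetical stand-in for BucketSampler (not the trajdl class).

    Yields `num_batches` batches of sample indices; every batch is drawn
    from a single, randomly chosen bucket.
    """

    def __init__(self, buckets: Sequence[Sequence[int]], num_batches: int,
                 batch_size: int, seed: Optional[int] = None) -> None:
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        # Keep only non-empty buckets so a random choice always has indices to draw.
        self.buckets = [list(b) for b in buckets if len(b) > 0]
        if not self.buckets:
            raise ValueError("at least one non-empty bucket is required")
        self.num_batches = num_batches
        self.batch_size = batch_size
        # A private RNG makes the batch sequence reproducible for a given seed.
        self.rng = random.Random(seed)

    def __len__(self) -> int:
        return self.num_batches

    def __iter__(self) -> Iterator[List[int]]:
        for _ in range(self.num_batches):
            bucket = self.rng.choice(self.buckets)
            # Sample without replacement; a bucket smaller than batch_size
            # yields a shorter batch in this sketch.
            k = min(self.batch_size, len(bucket))
            yield self.rng.sample(bucket, k)


sampler = ToyBucketSampler([[0, 1, 2, 3], [4, 5, 6, 7, 8]], num_batches=3, batch_size=2, seed=0)
batches = list(sampler)
```

Because a sampler of this kind yields whole batches rather than single indices, it would typically be handed to a PyTorch DataLoader through its `batch_sampler` argument; whether trajdl's BucketSampler is meant to be used that way is an assumption worth checking against its source.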

class trajdl.datasets.sampler.bucket.SeqInfo(sample_idx: int, size: int)

Bases: object

sample_idx: int
size: int
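SeqInfo pairs a sample index with a size. As an illustration only, the sketch below mirrors that record with a dataclass and builds one entry per sequence, assuming `size` is the sequence length. `SeqInfoLike` and `collect_stats` are hypothetical names, not part of trajdl.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class SeqInfoLike:
    """Hypothetical mirror of SeqInfo: one record per sample."""
    sample_idx: int  # position of the sample in the dataset
    size: int        # assumed to be the sequence length


def collect_stats(seqs: Sequence[Sequence[object]]) -> List[SeqInfoLike]:
    # Pair every sample's index with its length.
    return [SeqInfoLike(sample_idx=i, size=len(s)) for i, s in enumerate(seqs)]


stats = collect_stats([[1, 2, 3], [4], [5, 6]])
```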

trajdl.datasets.sampler.bucket.generate_buckets(seqs: BaseArrowDataset, num_buckets: int) → List[ndarray]

Generate buckets from sequences.

Parameters:
  • seqs (BaseArrowDataset) – A LocSeqDataset or a TrajectoryDataset.

  • num_buckets (int) – The desired number of buckets.

Returns:

A list of NumPy arrays, where each array contains sample indices for a bucket.

Return type:

List[np.ndarray]
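The bucketing rule itself is not documented here. One common approach, shown below as a minimal sketch, is equal-count bucketing over sequence lengths: sort the sample indices by length and split them into `num_buckets` contiguous, nearly equal chunks, so each bucket holds sequences of similar length. `bucket_by_length` is hypothetical, takes a plain list of lengths instead of a BaseArrowDataset, and may differ from what generate_buckets actually does.

```python
from typing import List, Sequence

import numpy as np


def bucket_by_length(lengths: Sequence[int], num_buckets: int) -> List[np.ndarray]:
    """Hypothetical equal-count bucketing (not trajdl's implementation)."""
    if num_buckets < 1:
        raise ValueError("num_buckets must be at least 1")
    # Stable sort so samples of equal length keep their original order.
    order = np.argsort(np.asarray(lengths), kind="stable")
    # Split the sorted indices into num_buckets nearly equal chunks and drop
    # empty chunks, which occur when there are fewer samples than buckets.
    return [chunk for chunk in np.array_split(order, num_buckets) if chunk.size > 0]


buckets = bucket_by_length([5, 1, 4, 2, 3, 6], num_buckets=3)
```

Here the samples of lengths 1 and 2 land in the first bucket, 3 and 4 in the second, and 5 and 6 in the third, so padding within any batch drawn from one bucket stays small.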

trajdl.datasets.sampler.bucket.generate_buckets_by_stats(stats: List[SeqInfo], num_buckets: int) → List[ndarray]

trajdl.datasets.sampler.bucket.generate_sub_buckets(stats: List[SeqInfo], num_buckets: int = 10) → List[List[SeqInfo]]

Generate sub-buckets from per-sample statistics.

Parameters:
  • stats (List[SeqInfo]) – A list of SeqInfo instances, each holding a sample index and its size (sequence length).

  • num_buckets (int, optional) – The desired number of buckets (default is 10).

Returns:

A list of sub-buckets, each a list of the SeqInfo entries assigned to it.

Return type:

List[List[SeqInfo]]

Raises:

ValueError – If num_buckets is less than 1.
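How generate_sub_buckets partitions the statistics is not stated. As a hedged illustration, the sketch below sorts SeqInfo-like records by size and splits them into at most `num_buckets` groups whose sizes differ by at most one, raising ValueError for `num_buckets < 1` as the documentation above describes. `SeqInfoLike` and `split_into_sub_buckets` are hypothetical.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class SeqInfoLike:
    """Hypothetical mirror of SeqInfo."""
    sample_idx: int
    size: int


def split_into_sub_buckets(stats: List[SeqInfoLike], num_buckets: int = 10) -> List[List[SeqInfoLike]]:
    """Hypothetical balanced split of size-sorted records (not trajdl's code)."""
    if num_buckets < 1:
        raise ValueError("num_buckets must be at least 1")
    ordered = sorted(stats, key=lambda s: s.size)
    # The first r groups get one extra record so group sizes differ by at most one.
    q, r = divmod(len(ordered), num_buckets)
    groups: List[List[SeqInfoLike]] = []
    start = 0
    for b in range(num_buckets):
        end = start + q + (1 if b < r else 0)
        if end > start:  # skip empty groups when there are fewer records than buckets
            groups.append(ordered[start:end])
        start = end
    return groups


stats = [SeqInfoLike(i, s) for i, s in enumerate([7, 3, 9, 1, 5])]
sub = split_into_sub_buckets(stats, num_buckets=2)
```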

trajdl.datasets.sampler.bucket.mean_size(bucket: List[SeqInfo]) → float
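mean_size has no docstring. Judging only from its name and signature, it presumably returns the average `size` of the SeqInfo entries in a bucket; the sketch below implements that reading. `SeqInfoLike` and `mean_bucket_size` are hypothetical, and the handling of an empty bucket (raising ValueError) is an assumption.

```python
from dataclasses import dataclass
from statistics import fmean
from typing import List


@dataclass
class SeqInfoLike:
    """Hypothetical mirror of SeqInfo."""
    sample_idx: int
    size: int


def mean_bucket_size(bucket: List[SeqInfoLike]) -> float:
    # Assumption: an empty bucket has no meaningful mean, so reject it.
    if not bucket:
        raise ValueError("bucket is empty")
    return fmean([s.size for s in bucket])


bucket = [SeqInfoLike(0, 2), SeqInfoLike(1, 4), SeqInfoLike(2, 9)]
avg = mean_bucket_size(bucket)
```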