Source code for trajdl.datasets.sampler.bucket

# Copyright 2024 All authors of TrajDL
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import dataclass
from typing import Generator, List

import numpy as np
from torch.utils.data import Sampler

from ..arrow.abstract import BaseArrowDataset


@dataclass
class SeqInfo:
    """Length record for a single sequence.

    Attributes
    ----------
    sample_idx : int
        Index identifying the sample.
    size : int
        Length of the sequence.
    """

    sample_idx: int
    size: int
def mean_size(bucket: List[SeqInfo]) -> float:
    """Return the average sequence length of a bucket.

    Parameters
    ----------
    bucket : List[SeqInfo]
        Entries whose ``size`` fields are averaged.

    Returns
    -------
    float
        The arithmetic mean of the ``size`` values as a Python float.
    """
    lengths = [entry.size for entry in bucket]
    average = np.mean(lengths)
    # ``.item()`` converts the NumPy scalar into a built-in float.
    return average.item()
def generate_sub_buckets(
    stats: List[SeqInfo], num_buckets: int = 10
) -> List[List[SeqInfo]]:
    """Split sequence statistics into length-sorted sub-buckets.

    The entries are sorted by ``(size, sample_idx)`` and cut into contiguous
    runs, so every bucket groups samples of similar length.

    Parameters
    ----------
    stats : List[SeqInfo]
        A list of SeqInfo where each instance contains a sample index and
        length.
    num_buckets : int, optional
        The desired number of buckets (default is 10). It is capped at
        ``len(stats)`` so that no bucket is empty.

    Returns
    -------
    List[List[SeqInfo]]
        The sub-buckets, ordered from shortest to longest sequences. Every
        entry of ``stats`` appears in exactly one bucket and bucket sizes
        differ by at most one. An empty ``stats`` yields an empty list.

    Raises
    ------
    ValueError
        If num_buckets is less than 1.
    """
    if num_buckets < 1:
        raise ValueError("`num_buckets` must be at least 1.")

    num_samples = len(stats)
    if num_samples == 0:
        # Nothing to bucket; also avoids dividing by zero below.
        return []
    num_buckets = min(num_buckets, num_samples)

    sorted_stats = sorted(stats, key=lambda x: (x.size, x.sample_idx))

    # Spread the remainder over the first buckets so that no sample is
    # dropped when num_samples is not a multiple of num_buckets.
    base_size, remainder = divmod(num_samples, num_buckets)
    buckets: List[List[SeqInfo]] = []
    start = 0
    for bucket_idx in range(num_buckets):
        stop = start + base_size + (1 if bucket_idx < remainder else 0)
        buckets.append(sorted_stats[start:stop])
        start = stop
    return buckets
def generate_buckets_by_stats(
    stats: List[SeqInfo], num_buckets: int
) -> List[np.ndarray]:
    """Build arrays of sample indices, one per sub-bucket.

    Parameters
    ----------
    stats : List[SeqInfo]
        Per-sample statistics passed on to ``generate_sub_buckets``.
    num_buckets : int
        The desired number of buckets.

    Returns
    -------
    List[np.ndarray]
        One integer array per sub-bucket holding the ``sample_idx`` values of
        the entries in that sub-bucket.
    """
    index_arrays = []
    for sub_bucket in generate_sub_buckets(stats, num_buckets=num_buckets):
        sample_indices = [entry.sample_idx for entry in sub_bucket]
        index_arrays.append(np.array(sample_indices, dtype=int))
    return index_arrays
def generate_buckets(
    seqs: BaseArrowDataset,
    num_buckets: int,
) -> List[np.ndarray]:
    """Group the samples of a dataset into buckets by sequence length.

    Parameters
    ----------
    seqs : BaseArrowDataset
        A LocSeqDataset or a TrajectoryDataset. Its ``seq`` attribute is
        iterated and ``len()`` is taken of every item.
    num_buckets : int
        The desired number of buckets.

    Returns
    -------
    List[np.ndarray]
        A list of NumPy arrays, where each array contains sample indices for
        a bucket.
    """
    sequences = seqs.seq

    # Record each sequence's position together with its length.
    length_records = []
    for position, sequence in enumerate(sequences):
        length_records.append(SeqInfo(sample_idx=position, size=len(sequence)))

    return generate_buckets_by_stats(stats=length_records, num_buckets=num_buckets)
class BucketSampler(Sampler):
    """Sampler that produces batches from randomly selected buckets.

    For every batch one bucket is picked at random and ``batch_size`` sample
    indices are drawn from it with ``choice`` (NumPy's default there is
    sampling with replacement).

    Attributes
    ----------
    buckets : List[np.ndarray]
        Arrays of sample indices, one per bucket.
    num_buckets : int
        The number of buckets actually built.
    num_batches : int
        The number of batches to yield.
    batch_size : int
        The size of each batch.
    """

    def __init__(
        self,
        ds: BaseArrowDataset,
        num_buckets: int,
        num_batches: int,
        batch_size: int,
        seed=None,
    ):
        """Initialize the BucketSampler.

        Parameters
        ----------
        ds: BaseArrowDataset
            Dataset.
        num_buckets: int
            The number of buckets.
        num_batches : int
            The number of batches to yield from the sampler.
        batch_size : int
            The size of each batch.
        seed : int, optional
            Random seed for reproducibility (default is None). When given,
            the sampler draws from its own generator and leaves NumPy's
            global random state untouched.
        """
        self.buckets = generate_buckets(seqs=ds, num_buckets=num_buckets)
        self.num_buckets = len(self.buckets)
        self.batch_size = batch_size
        self.num_batches = num_batches

        # A private RandomState gives the same stream as seeding the legacy
        # global generator, without clobbering state shared with other code.
        # None means "fall back to the global generator" at iteration time.
        self._rng = np.random.RandomState(seed) if seed is not None else None

    def __len__(self) -> int:
        """Return the number of batches."""
        return self.num_batches

    def __iter__(self) -> Generator[np.ndarray, None, None]:
        """Yield batches of sample indices."""
        rng = np.random if self._rng is None else self._rng
        for _ in range(self.num_batches):
            random_bucket = self.buckets[rng.choice(self.num_buckets)]
            yield rng.choice(random_bucket, self.batch_size)