Source code for trajdl.datasets.sampler.t2vec
# Copyright 2024 All authors of TrajDL
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Generator, List, Tuple
import numpy as np
import pyarrow as pa
from torch.utils.data import Sampler
from ... import trajdl_cpp
from ..arrow.ext.t2vec import T2VECDataset
[docs]
class T2VECSampler(Sampler):
def __init__(
self,
ds: T2VECDataset,
buckets_boundaries: List[Tuple[int, int]],
num_batches: int,
batch_size: int,
):
super().__init__()
self.num_batches = num_batches
self.batch_size = batch_size
src_lengths = pa.compute.list_value_length(
ds.src_table.column("src")
).to_numpy()
src_indices = ds.src_table.column("label_idx").to_numpy()
trg_lengths = pa.compute.list_value_length(
ds.trg_table.column("trg")
).to_numpy()
src_bound, trg_bound = zip(*buckets_boundaries)
buckets_map: Dict[int, List[int]] = trajdl_cpp.bucketize(
src_lengths, src_indices, trg_lengths, src_bound, trg_bound
)
num_out = 0
self.buckets = [None] * len(buckets_boundaries) * 2
for bucket_idx, bucket in buckets_map.items():
if 0 <= bucket_idx < len(self.buckets):
self.buckets[bucket_idx] = bucket
else:
num_out += len(bucket)
print(f"num out: {num_out}")
self.dist = np.array([len(b) for b in self.buckets]) / len(ds)
def __len__(self) -> int:
return self.num_batches
def __iter__(self) -> Generator[np.ndarray, None, None]:
for _ in range(self.num_batches):
sample = np.random.multinomial(1, self.dist)
bucket_idx = np.nonzero(sample)[0][0]
yield np.random.choice(len(self.buckets[bucket_idx]), self.batch_size)