Source code for trajdl.tokenizers.slot

# Copyright 2024 All authors of TrajDL
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pickle
import warnings
from pathlib import Path
from typing import Union

import torch


[docs] class Bucketizer: """ A class to create buckets for numerical data, allowing for indexing of values into buckets. Parameters ---------- lower_bound : float The lower bound of the range. upper_bound : float The upper bound of the range. num_buckets : int The number of buckets to create within the specified range. Attributes ---------- bucket_size : float The size of each bucket. Methods ------- get_bucket_index(value): Returns the index of the bucket that the given value falls into. """ def __init__(self, lower_bound: float, upper_bound: float, num_buckets: int): self._lower_bound = lower_bound self._upper_bound = upper_bound self._num_buckets = num_buckets self.bucket_size = (upper_bound - lower_bound) / num_buckets @property def lower_bound(self) -> float: return self._lower_bound @property def upper_bound(self) -> float: return self._upper_bound @property def num_buckets(self) -> int: return self._num_buckets
[docs] def get_bucket_index(self, value: float): """ Get the index of the bucket that the given value belongs to. Parameters ---------- value : float The value to be placed in a bucket. Returns ------- int The index of the bucket that contains the value. """ if value < self.lower_bound or value > self.upper_bound: warnings.warn( f"Value {value} is out of bounds ({self.lower_bound}, {self.upper_bound})", RuntimeWarning, ) if value <= self.lower_bound: return 0 if value >= self.upper_bound: return self.num_buckets - 1 # 计算桶的索引 idx = int((value - self.lower_bound) / self.bucket_size) # 确保idx不超过num_buckets - 1,这里会保证upper_bound属于最后一个桶 return min(idx, self.num_buckets - 1)
[docs] def get_bucket_indices(self, tensor: torch.Tensor) -> torch.LongTensor: """ Get the indices of the buckets for each value in the input tensor. """ indices = ( ((tensor - self.lower_bound) / self.bucket_size) .clamp(0, self.num_buckets - 1) .floor() ) return indices.long()
[docs] def save(self, path: Union[str, Path]): p = Path(path) folder = p.parent os.makedirs(folder, exist_ok=True) with open(p, "wb") as f: pickle.dump(self, f)
[docs] @staticmethod def load(path: Union[str, Path]) -> "Bucketizer": with open(path, "rb") as f: return pickle.load(f)