# Source code for trajdl.tokenizers.simple
# Copyright 2024 All authors of TrajDL
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Dict, Iterable, List, Union
import pyarrow as pa
from ..common.enum import TokenEnum
from ..datasets.base import LocSeq
from .abstract import AbstractLocSeqTokenizer
[docs]
class SimpleTokenizer(AbstractLocSeqTokenizer):
    """Tokenizer that maps each location string to an integer via a vocabulary.

    Unknown locations resolve to ``self._unk``. Both ``self.vocab`` and
    ``self._unk`` are not set in this class; presumably they are provided by
    ``AbstractLocSeqTokenizer`` -- confirm against the base class.
    """

    @classmethod
    def construct_vocab(cls, init_vocab: Dict[str, int]) -> Dict[str, int]:
        """Add the special tokens to ``init_vocab`` and return it.

        The BOS, EOS, UNK, MASK and PAD tokens (in that order) receive
        consecutive indices starting right after the largest index already in
        ``init_vocab``. A special token that the caller has already placed in
        ``init_vocab`` keeps its existing index; a ``RuntimeWarning`` is
        emitted and no new index is consumed for it.

        Args:
            init_vocab: Mapping from location to index. It is modified in
                place. May be empty, in which case indices start at 0.

        Returns:
            The same ``init_vocab`` object, extended with the special tokens.
        """
        special_tokens = (
            TokenEnum.BOS_TOKEN.value,
            TokenEnum.EOS_TOKEN.value,
            TokenEnum.UNK_TOKEN.value,
            TokenEnum.MASK_TOKEN.value,
            TokenEnum.PAD_TOKEN.value,
        )

        # default=-1 lets an empty vocabulary start numbering at 0 instead of
        # raising ValueError from max() on an empty sequence.
        next_idx = max(init_vocab.values(), default=-1) + 1

        for token in special_tokens:
            if token in init_vocab:
                warnings.warn(
                    f"Token '{token}' exist in vocab, tokenizer will not give this token a index automatically.",
                    RuntimeWarning,
                )
                # Honor the warning: leave the caller-supplied index untouched.
                continue
            init_vocab[token] = next_idx
            next_idx += 1

        return init_vocab

    @classmethod
    def build(cls, init_vocab: Dict[str, int]) -> "SimpleTokenizer":
        """Create a tokenizer from ``init_vocab`` after adding the special tokens.

        Args:
            init_vocab: Mapping from location to index. It is extended in
                place by :meth:`construct_vocab`.

        Returns:
            A new instance constructed with ``vocab=`` the extended mapping.
        """
        vocab = cls.construct_vocab(init_vocab)
        return cls(vocab=vocab)

    def loc2idx(self, loc: str) -> int:
        """Return the index of ``loc``, or ``self._unk`` if it is not in the vocabulary."""
        return self.vocab.get(loc, self._unk)

    def _tokenize_loc_seq_impl(
        self, loc_seq: Union[Iterable[str], LocSeq, pa.ListScalar]
    ) -> List[int]:
        """Convert a sequence of locations into a list of indices.

        Args:
            loc_seq: The locations to tokenize. A ``pyarrow.ListScalar`` is
                first converted to a Python list; any other input is iterated
                directly.

        Returns:
            One index per location, in input order, as given by :meth:`loc2idx`.
        """
        if isinstance(loc_seq, pa.ListScalar):
            loc_seq = loc_seq.as_py()
        return [self.loc2idx(loc) for loc in loc_seq]