Source code for trajdl.common.enum

# Copyright 2024 All authors of TrajDL
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from typing import Any, List


class BaseEnum(Enum):
    """Base enumeration that adds value-based lookup helpers.

    Subclasses declare their members as usual; the class methods below let
    callers resolve a member from its raw value and list every member value.
    """

    @classmethod
    def from_string(cls, value: str) -> "BaseEnum":
        """Return the member of ``cls`` whose ``value`` equals ``value``.

        Args:
            value: Raw value to look up among the members of ``cls``.

        Returns:
            The first member (in definition order) whose value matches.

        Raises:
            ValueError: If no member of ``cls`` has the given value.
        """
        for item in cls:
            if item.value == value:
                return item
        raise ValueError(f"{value} is not a valid {cls.__name__}")

    @classmethod
    def parse(cls, value: Any) -> "BaseEnum":
        """Coerce ``value`` into a member of ``cls``.

        Args:
            value: Either a string holding a member value, or an existing
                member of ``cls``.

        Returns:
            The member resolved via :meth:`from_string` when ``value`` is a
            string, or ``value`` itself when it is already a member.

        Raises:
            ValueError: If ``value`` is neither a ``str`` nor an instance of
                ``cls``, or if it is a string matching no member.
        """
        # Strings are checked first, exactly as before, so str-mixin members
        # keep going through the value lookup.
        if isinstance(value, str):
            return cls.from_string(value)
        if isinstance(value, cls):
            return value
        raise ValueError(
            f"`value` should be a str or an instance of {cls.__name__}"
        )

    @classmethod
    def values(cls) -> List[Any]:
        """Return the values of all members of ``cls`` as a list.

        The list follows member definition order, which makes the result
        deterministic and consistent with the declared ``List[Any]`` type
        (previously a set was returned despite the annotation).

        Returns:
            A new list containing ``member.value`` for every member.
        """
        return [item.value for item in cls]
class Mode(BaseEnum):
    """Enumeration of the supported modes: pretrain, train and eval."""

    PRETRAIN = "pretrain"
    TRAIN = "train"
    EVAL = "eval"
class LossEnum(BaseEnum):
    """Enumeration with the members ``sum``, ``mean`` and ``none``.

    NOTE(review): presumably these select how a loss is reduced -- confirm
    against the code that consumes this enum.
    """

    SUM = "sum"
    MEAN = "mean"
    NONE = "none"
class ArrowColName(BaseEnum):
    """Enumeration of string column names.

    NOTE(review): judging by the class name these are column names of Arrow
    tables -- confirm against the dataset code that uses them.
    """

    SEQ = "seq"
    ENTITY_ID = "entity_id"
    TS_SEQ = "ts_seq"
    TS_DELTA = "ts_delta"
    DIS_DELTA = "dis_delta"
    START_TS = "start_ts"
class TokenEnum(BaseEnum):
    """Enumeration of the bracketed special-token strings."""

    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"
    MASK_TOKEN = "[MASK]"
class OpenSourceDatasetEnum(BaseEnum):
    """Enumeration whose member values are identical to the member names.

    NOTE(review): presumably these strings are keys used to look up dataset
    download URLs -- confirm against the caller.
    """

    GOWALLA_URL = "GOWALLA_URL"
    PORTO_URL = "PORTO_URL"
class ReturnASEnum(BaseEnum):
    """Enumeration with the members ``np``, ``pt``, ``py``, ``traj``, ``locseq``.

    NOTE(review): presumably each member names a format in which results can
    be returned -- confirm against the code that consumes this enum.
    """

    NP = "np"
    PT = "pt"
    PY = "py"
    TRAJ = "traj"
    LOCSEQ = "locseq"