Source code for trajdl.datasets.base.locseq
# Copyright 2024 All authors of TrajDL
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Iterator, List, Union
from .abstract import BaseSeq
[docs]
class LocSeq(BaseSeq):
    """Location sequence representation.

    Attributes
    ----------
    loc_seq : List[str]
        The list of locations in the sequence.
    entity_id : Union[str, None]
        An optional identifier, which can store the sequence ID or user ID.
    ts_seq : Union[List[int], None]
        A sequence of timestamps (in seconds or milliseconds), defined by the
        user without constraints.

    Notes
    -----
    The methods below read ``self._seq``, ``self._ts``, ``self.entity_id`` and
    ``len(self)``, none of which are defined in this class; presumably they
    are provided by ``BaseSeq`` (not shown here) -- confirm there.
    """

    @staticmethod
    def check_seq(seq: Any) -> List[str]:
        """Check if the sequence is valid.

        Parameters
        ----------
        seq : Any
            The sequence to be checked.

        Returns
        -------
        List[str]
            The validated location sequence (the same list object, not a copy).

        Raises
        ------
        ValueError
            If ``seq`` is not a ``list`` or any of its elements is not a ``str``.
        """
        if not isinstance(seq, list) or any(not isinstance(loc, str) for loc in seq):
            raise ValueError("`seq` must be a List[str]")
        return seq

    def __getitem__(self, idx: int) -> str:
        """Get the location at the specified index.

        Parameters
        ----------
        idx : int
            The index of the location in the sequence.

        Returns
        -------
        str
            The location at the given index.
        """
        return self._seq[idx]

    def __iter__(self) -> Iterator[str]:
        """Iterate over the location sequence.

        Yields
        ------
        str
            Each location in the sequence, in order.
        """
        yield from self._seq

    @property
    def o(self) -> str:
        """Return the starting (origin) location, i.e. the first element.

        Indexing an empty sequence raises ``IndexError``.
        """
        return self[0]

    @property
    def d(self) -> str:
        """Return the destination location, i.e. the last element.

        Indexing an empty sequence raises ``IndexError``.
        """
        return self[-1]

    def _loc_expr(
        self,
        seq: Union[List[str], List[int]],
        length: Union[int, None] = None,
    ) -> str:
        """Create a short string preview of a sequence for display.

        Parameters
        ----------
        seq : Union[List[str], List[int]]
            The sequence to preview. String elements are wrapped in single
            quotes; other elements are rendered with ``str``.
        length : Union[int, None], optional
            The length used to decide whether an ellipsis is appended.
            Defaults to ``len(seq)``, i.e. the length of the sequence that is
            actually being rendered.

        Returns
        -------
        str
            The first three elements joined by ``", "``, followed by
            ``", ..."`` if the length exceeds three.
        """
        if length is None:
            length = len(seq)
        preview = ", ".join(
            f"'{item}'" if isinstance(item, str) else str(item) for item in seq[:3]
        )
        if length > 3:
            preview += ", ..."
        return preview

    def __repr__(self) -> str:
        """Return a string representation of the sequence object.

        The size is always shown. ``entity_id`` is shown whenever it is not
        ``None``. The location and timestamp previews are shown only when
        the corresponding sequences are non-empty.

        Returns
        -------
        str
            For example ``LocSeq(size=4, entity_id='u1', loc_seq='a', 'b', 'c', ...)``.
        """
        length = len(self)
        parts = [f"size={length}"]
        # Compare against None so that falsy-but-present ids (e.g. "" or 0)
        # are still displayed.
        if self.entity_id is not None:
            parts.append(f"entity_id='{self.entity_id}'")
        if self._seq:
            parts.append(f"loc_seq={self._loc_expr(self._seq, length)}")
        if self._ts:
            # Timestamps are user-defined "without constraints", so their length
            # may differ from the location count; size the preview by the
            # timestamp sequence itself.
            parts.append(f"ts_seq={self._loc_expr(self._ts)}")
        # Use the runtime class name so subclasses report their own type.
        return f"{type(self).__name__}({', '.join(parts)})"