Source code for trajdl.datasets.arrow.abstract

# Copyright 2024 All authors of TrajDL
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import operator
import os
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Iterable, Iterator, List, Union

import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import pyarrow.parquet as pq

from ...common.enum import ArrowColName
from ..base import BaseSeq


class BaseArrowDataset(ABC):
    """
    Abstract base class for datasets backed by a PyArrow table.

    Subclasses must provide the expected Arrow schema (``schema``), a
    validation routine for incoming tables (``check_table``), and a
    conversion from one row of the table into a sequence object
    (``sub_classes_construction``).
    """

    def __init__(self, table: pa.Table, unsafe: bool = False):
        """
        Initialize the BaseArrowDataset.

        Parameters
        ----------
        table : pa.Table
            The input Arrow table to initialize the dataset.
        unsafe : bool, optional
            If set to True, the correctness of the input table will not be checked.

        Raises
        ------
        ValueError
            If unsafe is False and the table fails the validation check.
        """
        super().__init__()
        if not unsafe:
            self.check_table(table)
        self.table = table

    @property
    def seq(self) -> pa.lib.ChunkedArray:
        """
        Get the SEQ column from the Arrow table.

        Returns
        -------
        pa.lib.ChunkedArray
            The SEQ column as a ChunkedArray.
        """
        return self.table.column(ArrowColName.SEQ.value)

    @property
    def entity_id(self) -> pa.lib.ChunkedArray:
        """
        Get the ENTITY_ID column from the Arrow table.

        Returns
        -------
        pa.lib.ChunkedArray
            The ENTITY_ID column as a ChunkedArray.
        """
        return self.table.column(ArrowColName.ENTITY_ID.value)

    @property
    def ts_seq(self) -> pa.lib.ChunkedArray:
        """
        Get the TS_SEQ column from the Arrow table.

        Returns
        -------
        pa.lib.ChunkedArray
            The TS_SEQ column as a ChunkedArray.
        """
        return self.table.column(ArrowColName.TS_SEQ.value)

    @property
    def ts_delta(self) -> pa.lib.ChunkedArray:
        """
        Get the TS_DELTA column from the Arrow table.

        Returns
        -------
        pa.lib.ChunkedArray
            The TS_DELTA column as a ChunkedArray.
        """
        return self.table.column(ArrowColName.TS_DELTA.value)

    @property
    def dis_delta(self) -> pa.lib.ChunkedArray:
        """
        Get the DIS_DELTA column from the Arrow table.

        Returns
        -------
        pa.lib.ChunkedArray
            The DIS_DELTA column as a ChunkedArray.
        """
        return self.table.column(ArrowColName.DIS_DELTA.value)

    @property
    def start_ts(self) -> pa.lib.ChunkedArray:
        """
        Get the START_TS column from the Arrow table.

        Returns
        -------
        pa.lib.ChunkedArray
            The START_TS column as a ChunkedArray.
        """
        return self.table.column(ArrowColName.START_TS.value)

    @classmethod
    @abstractmethod
    def schema(cls) -> pa.lib.Schema:
        """
        Abstract method to define the Arrow schema for the dataset.

        Returns
        -------
        pa.lib.Schema
            The schema definition of the dataset.

        Raises
        ------
        NotImplementedError
            If the method is not implemented in a subclass.
        """
        raise NotImplementedError(
            "Subclasses should implement this method."
        )  # pragma: no cover

    @abstractmethod
    def check_table(self, table: pa.Table) -> None:
        """
        Check the validity of the input table.

        Parameters
        ----------
        table : pa.Table
            The table to validate.

        Raises
        ------
        ValueError
            If the table type is not correct.
        NotImplementedError
            If the method is not implemented in a subclass.
        """
        raise NotImplementedError(
            "Subclasses should implement this method."
        )  # pragma: no cover

    @classmethod
    def init_from_arrow(cls, table: pa.Table) -> "BaseArrowDataset":
        """
        Initialize the dataset from an Arrow table, conforming it to ``cls.schema()``.

        Only the columns named in the schema are kept, in schema order; any
        extra columns of the input are dropped. Schema fields missing from the
        input are filled with an all-null column of the field's type.

        Parameters
        ----------
        table : pa.Table
            The Arrow table to initialize the dataset.

        Returns
        -------
        BaseArrowDataset
            An instance of the dataset initialized from the Arrow table
            (validated through ``check_table``).

        Warns
        -----
        RuntimeWarning
            If any field in the schema does not exist in the input table.
        """
        schema = cls.schema()
        available = set(table.column_names)
        arrays = []
        for field in schema:
            if field.name in available:
                arrays.append(table[field.name])
                continue
            warnings.warn(
                f"Field {field} does not exist in the input table.", RuntimeWarning
            )
            # Pad the missing column with nulls so the result matches the schema.
            arrays.append(pa.nulls(len(table), type=field.type))
        new_table = pa.Table.from_arrays(arrays, schema=schema)
        return cls(table=new_table)

    @classmethod
    def init_from_table(
        cls, table: Union[pa.Table, pl.DataFrame, pd.DataFrame]
    ) -> "BaseArrowDataset":
        """
        Initialize the dataset from a Polars, Pandas, or Arrow table.

        Parameters
        ----------
        table : Union[pa.Table, pl.DataFrame, pd.DataFrame]
            The input table to initialize the dataset.

        Returns
        -------
        BaseArrowDataset
            An instance of the dataset initialized from the provided table.

        Raises
        ------
        ValueError
            If the input is not one of the accepted table types.
        """
        if isinstance(table, pl.DataFrame):
            arrow_table = table.to_arrow()
        elif isinstance(table, pd.DataFrame):
            arrow_table = pa.Table.from_pandas(table)
        elif isinstance(table, pa.Table):
            arrow_table = table
        else:
            raise ValueError(
                "`table` should be an instance of {'pyarrow.Table', 'polars.DataFrame', 'pandas.DataFrame'}"
            )
        return cls.init_from_arrow(table=arrow_table)

    @classmethod
    def init_from_parquet(cls, path: Union[str, Path]) -> "BaseArrowDataset":
        """
        Initialize the dataset from a Parquet file.

        Parameters
        ----------
        path : Union[str, Path]
            The file path to the Parquet file.

        Returns
        -------
        BaseArrowDataset
            An instance of the dataset initialized from the Parquet file.

        Notes
        -----
        Due to differences in handling List[List[Float32]] across different frameworks,
        it is recommended to read the file using PyArrow and try to convert types where necessary.
        """
        arrow_table = pq.read_table(path)
        return cls.init_from_arrow(arrow_table)

    def __len__(self) -> int:
        """
        Get the number of rows in the dataset.

        Returns
        -------
        int
            The number of rows in the dataset.
        """
        return len(self.table)

    def _resolve_position(self, idx: Union[int, np.integer]) -> int:
        """
        Convert a (possibly negative) integer index into a valid row position.

        Raises
        ------
        TypeError
            If ``idx`` is not an integer type.
        IndexError
            If ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_rows = len(self.table)
        position = operator.index(idx)
        if position < 0:
            position += n_rows
        # pa.Table.slice clamps out-of-range offsets to an empty result instead
        # of raising, so the bounds check has to happen here.
        if not 0 <= position < n_rows:
            raise IndexError(
                f"index {idx} is out of range for a dataset with {n_rows} rows"
            )
        return position

    def __getitem__(self, idx: Union[int, np.int64]) -> "BaseArrowDataset":
        """
        Retrieve a single row from the dataset.

        Parameters
        ----------
        idx : Union[int, np.int64]
            The index of the row to retrieve. Negative values count from the end.

        Returns
        -------
        BaseArrowDataset
            A new dataset containing only the requested row.

        Raises
        ------
        IndexError
            If ``idx`` is out of range. Raising here also lets the sequence
            iteration protocol (``for row in dataset``) terminate.
        """
        position = self._resolve_position(idx)
        return self.__class__(table=self.table.slice(position, length=1), unsafe=True)

    def __getitems__(
        self, indices: List[Union[int, Iterable[int], Iterable[np.int64], np.ndarray]]
    ) -> "BaseArrowDataset":
        """
        Retrieve multiple rows from the dataset.

        NOTE(review): ``__getitems__`` is the batched-fetch hook used by some
        data loaders (e.g. PyTorch's); confirm against the actual callers.

        Parameters
        ----------
        indices : List[Union[int, Iterable[int], Iterable[np.int64], np.ndarray]]
            The indices of the rows to retrieve. Each item is either a single
            integer or an iterable of integers; iterables are flattened in order.

        Returns
        -------
        BaseArrowDataset
            A new dataset containing the requested rows, in request order
            (duplicates are kept). An empty request yields an empty dataset.

        Raises
        ------
        IndexError
            If any index is out of range.
        """
        positions: List[int] = []
        for item in indices:
            if isinstance(item, (int, np.integer)):
                positions.append(self._resolve_position(item))
            else:
                positions.extend(self._resolve_position(sub) for sub in item)
        # A single vectorized take replaces concatenating one 1-row slice per
        # index, which produced a table with one chunk per requested row.
        table = self.table.take(pa.array(positions, type=pa.int64()))
        return self.__class__(table=table, unsafe=True)

    def save(self, path: Union[str, Path]) -> None:
        """
        Save the dataset to a Parquet file.

        Parameters
        ----------
        path : Union[str, Path]
            The file path to save the dataset.

        Notes
        -----
        If the provided path does not end with '.parquet', its suffix is set
        to '.parquet': a path without a suffix gets one appended, while an
        existing different suffix is replaced (``data.arrow`` -> ``data.parquet``).
        Missing parent directories are created.
        """
        target = Path(path)
        if target.suffix != ".parquet":
            target = target.with_suffix(".parquet")
        target.parent.mkdir(parents=True, exist_ok=True)
        pq.write_table(self.table, target)

    def to_polars(self) -> pl.DataFrame:
        """
        Convert the underlying Arrow table to a Polars DataFrame.

        Returns
        -------
        pl.DataFrame
            The result of ``pl.from_arrow`` on the dataset's table.
        """
        return pl.from_arrow(self.table)

    @abstractmethod
    def sub_classes_construction(self, *args, **kwargs) -> BaseSeq:
        """
        Convert one row of the table into a single sequence instance.

        ``iter_as_seqs`` calls this with one keyword argument per column of
        ``schema()``, each holding that row's value converted to a Python object.
        """
        raise NotImplementedError(
            "Subclasses should implement this method."
        )  # pragma: no cover

    def iter_as_seqs(self, batch_size: int = 1024) -> Iterator[BaseSeq]:
        """
        Lazily yield one sequence object per row of the dataset.

        Parameters
        ----------
        batch_size : int, optional
            Maximum number of rows converted to Python objects at a time.
            Bounds the memory used while iterating.

        Yields
        ------
        BaseSeq
            The output of ``sub_classes_construction`` for each row, called with
            the row's values keyed by the column names of ``schema()``.

        Raises
        ------
        ValueError
            If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("`batch_size` should be a positive integer")
        names = self.schema().names
        selected = self.table.select(names)
        for batch in selected.to_batches(max_chunksize=batch_size):
            # Column-wise conversion of a bounded batch avoids boxing every cell
            # through a separate ChunkedArray scalar lookup.
            columns = [column.to_pylist() for column in batch.columns]
            for row in range(batch.num_rows):
                yield self.sub_classes_construction(
                    **{name: values[row] for name, values in zip(names, columns)}
                )