trajdl.datasets.modules.tuler module

class trajdl.datasets.modules.tuler.TULERDataModule(tokenizer: str | trajdl.tokenizers.abstract.AbstractTokenizer, train_parquet_path: str | None = None, val_parquet_path: str | None = None, test_parquet_path: str | None = None, train_table: pyarrow.lib.Table | polars.dataframe.frame.DataFrame | pandas.core.frame.DataFrame | trajdl.datasets.arrow.abstract.BaseArrowDataset | NoneType = None, val_table: pyarrow.lib.Table | polars.dataframe.frame.DataFrame | pandas.core.frame.DataFrame | trajdl.datasets.arrow.abstract.BaseArrowDataset | NoneType = None, test_table: pyarrow.lib.Table | polars.dataframe.frame.DataFrame | pandas.core.frame.DataFrame | trajdl.datasets.arrow.abstract.BaseArrowDataset | NoneType = None, train_batch_size: int = 2, val_batch_size: int = 2, train_sampler: torch.utils.data.sampler.Sampler | None = None, val_sampler: torch.utils.data.sampler.Sampler | None = None, num_cpus: int = 0, user_map: str | Dict[str, int] | NoneType = None)[source]#

Bases: BaseLocSeqDataModule

collate_function(ds: LocSeqDataset)[source]#

Collate function for the dataset. Different modules may require different implementations.

Parameters:

ds (LocSeqDataset) – The dataset to collate.

Returns:

The collated data.

Return type:

Any

setup(stage: str)[source]#

Set up the data module, loading the tokenizer and initializing datasets.

Parameters:

stage (str) – Stage of operation (e.g. 'fit', 'test').

user_map: str | Dict[str, int] | None = None#