trajdl.datasets.modules.ctle module

trajdl.datasets.modules.ctle module

class trajdl.datasets.modules.ctle.CTLEDataModule(tokenizer: str | trajdl.tokenizers.abstract.AbstractTokenizer, train_parquet_path: str | None = None, val_parquet_path: str | None = None, test_parquet_path: str | None = None, train_table: pyarrow.lib.Table | polars.dataframe.frame.DataFrame | pandas.core.frame.DataFrame | trajdl.datasets.arrow.abstract.BaseArrowDataset | NoneType = None, val_table: pyarrow.lib.Table | polars.dataframe.frame.DataFrame | pandas.core.frame.DataFrame | trajdl.datasets.arrow.abstract.BaseArrowDataset | NoneType = None, test_table: pyarrow.lib.Table | polars.dataframe.frame.DataFrame | pandas.core.frame.DataFrame | trajdl.datasets.arrow.abstract.BaseArrowDataset | NoneType = None, train_batch_size: int = 2, val_batch_size: int = 2, train_sampler: torch.utils.data.sampler.Sampler | None = None, val_sampler: torch.utils.data.sampler.Sampler | None = None, num_cpus: int = 0, mask_prob: float = 0.2)

Bases: BaseLocSeqDataModule

collate_function(ds: LocSeqDataset)

Collate function for the dataset. Different modules may require different implementations.

Parameters:

ds (LocSeqDataset) – The dataset to collate.

Returns:

The collated data.

Return type:

Any

mask_prob: float = 0.2