trajdl.datasets.modules.stlstm module

class trajdl.datasets.modules.stlstm.HSTLSTMDataModule(tokenizer: str | trajdl.tokenizers.abstract.AbstractTokenizer, train_parquet_path: str | None = None, val_parquet_path: str | None = None, test_parquet_path: str | None = None, train_table: pyarrow.lib.Table | polars.dataframe.frame.DataFrame | pandas.core.frame.DataFrame | trajdl.datasets.arrow.abstract.BaseArrowDataset | NoneType = None, val_table: pyarrow.lib.Table | polars.dataframe.frame.DataFrame | pandas.core.frame.DataFrame | trajdl.datasets.arrow.abstract.BaseArrowDataset | NoneType = None, test_table: pyarrow.lib.Table | polars.dataframe.frame.DataFrame | pandas.core.frame.DataFrame | trajdl.datasets.arrow.abstract.BaseArrowDataset | NoneType = None, train_batch_size: int = 2, val_batch_size: int = 2, train_sampler: torch.utils.data.sampler.Sampler | None = None, val_sampler: torch.utils.data.sampler.Sampler | None = None, num_cpus: int = 0, num_train_batches: int = 10, num_val_batches: int = 10, num_train_buckets: int = 10, num_val_buckets: int = 10)

Bases: BaseLocSeqDataModule

collate_function(ds: LocSeqDataset)

Collate function for the dataset. Different modules may require different implementations.

Parameters:

ds (LocSeqDataset) – The dataset to collate.

Returns:

The collated data.

Return type:

Any
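To make the batching step concrete, here is a minimal pure-Python sketch of what a collate function commonly does for sequence data: right-pad variable-length location sequences to a common length and keep the original lengths so a model can mask out padding. It is not trajdl's implementation. The sample format (one list of integer location token IDs per trajectory), the helper name pad_collate, and the padding ID PAD_ID = 0 are all hypothetical.

```python
from typing import List, Tuple

PAD_ID = 0  # hypothetical padding token ID; a real tokenizer may define its own


def pad_collate(
    samples: List[List[int]], pad_id: int = PAD_ID
) -> Tuple[List[List[int]], List[int]]:
    """Right-pad variable-length token sequences to the longest one in the batch.

    Returns the padded batch and the original lengths, which a model can use
    to ignore padding positions.
    """
    lengths = [len(s) for s in samples]
    max_len = max(lengths, default=0)
    padded = [s + [pad_id] * (max_len - len(s)) for s in samples]
    return padded, lengths


batch, lengths = pad_collate([[5, 8, 2], [7], [3, 9]])
# batch   -> [[5, 8, 2], [7, 0, 0], [3, 9, 0]]
# lengths -> [3, 1, 2]
```

A production collate function would usually return tensors rather than nested lists, for example by padding with torch.nn.utils.rnn.pad_sequence.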

num_train_batches: int = 10
num_train_buckets: int = 10
num_val_batches: int = 10
num_val_buckets: int = 10
setup(stage: str)

Set up the data module, loading the tokenizer and initializing datasets.

Parameters:

stage (str) – Stage of operation (e.g. ‘fit’, ‘test’).

class trajdl.datasets.modules.stlstm.STLSTMDataModule(tokenizer: str | trajdl.tokenizers.abstract.AbstractTokenizer, train_parquet_path: str | None = None, val_parquet_path: str | None = None, test_parquet_path: str | None = None, train_table: pyarrow.lib.Table | polars.dataframe.frame.DataFrame | pandas.core.frame.DataFrame | trajdl.datasets.arrow.abstract.BaseArrowDataset | NoneType = None, val_table: pyarrow.lib.Table | polars.dataframe.frame.DataFrame | pandas.core.frame.DataFrame | trajdl.datasets.arrow.abstract.BaseArrowDataset | NoneType = None, test_table: pyarrow.lib.Table | polars.dataframe.frame.DataFrame | pandas.core.frame.DataFrame | trajdl.datasets.arrow.abstract.BaseArrowDataset | NoneType = None, train_batch_size: int = 2, val_batch_size: int = 2, train_sampler: torch.utils.data.sampler.Sampler | None = None, val_sampler: torch.utils.data.sampler.Sampler | None = None, num_cpus: int = 0, num_train_batches: int = 10, num_val_batches: int = 10, num_train_buckets: int = 10, num_val_buckets: int = 10, ts_bucketizer: Any | None = None, loc_bucketizer: Any | None = None)

Bases: BaseLocSeqDataModule

collate_function(batch_ds: LocSeqDataset)

Collate function for the dataset. Different modules may require different implementations.

Parameters:

batch_ds (LocSeqDataset) – The dataset to collate.

Returns:

The collated data.

Return type:

Any

loc_bucketizer: Any | None = None
num_train_batches: int = 10
num_train_buckets: int = 10
num_val_batches: int = 10
num_val_buckets: int = 10
setup(stage: str)

Set up the data module, loading the tokenizer and initializing datasets.

Parameters:

stage (str) – Stage of operation (e.g. ‘fit’, ‘test’).

ts_bucketizer: Any | None = None
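STLSTMDataModule accepts ts_bucketizer and loc_bucketizer typed only as Any, so the interface trajdl expects from these objects is not documented here. As an illustration of the general idea of bucketizing, the sketch below maps a continuous value (for example, seconds elapsed since the previous visit) to a discrete bucket index using bisect. The class IntervalBucketizer and the chosen boundaries are hypothetical; check trajdl's source for the actual contract before passing any object as ts_bucketizer or loc_bucketizer.

```python
import bisect
from typing import List, Sequence


class IntervalBucketizer:
    """Hypothetical bucketizer: maps a continuous value to a bucket index.

    With n sorted boundaries there are n + 1 buckets:
    (-inf, b0), [b0, b1), ..., [b_{n-1}, +inf).
    """

    def __init__(self, boundaries: Sequence[float]) -> None:
        bounds: List[float] = list(boundaries)
        if bounds != sorted(bounds):
            raise ValueError("boundaries must be sorted in ascending order")
        self.boundaries = bounds

    @property
    def num_buckets(self) -> int:
        return len(self.boundaries) + 1

    def __call__(self, value: float) -> int:
        # bisect_right sends a value equal to a boundary into the upper bucket.
        return bisect.bisect_right(self.boundaries, value)


# Hypothetical time-interval buckets in seconds: 1 minute, 1 hour, 1 day.
ts_bucketizer = IntervalBucketizer([60, 3600, 86400])
buckets = [ts_bucketizer(v) for v in [30, 60, 1800, 7200, 100000]]
# buckets -> [0, 1, 1, 2, 3]
```

Using bisect keeps each lookup at O(log n) in the number of boundaries, and the same class could discretize distances for a location-side bucketizer by choosing boundaries in metres instead of seconds.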