trajdl.datasets.modules.stlstm module
- class trajdl.datasets.modules.stlstm.HSTLSTMDataModule(tokenizer: str | trajdl.tokenizers.abstract.AbstractTokenizer, train_parquet_path: str | None = None, val_parquet_path: str | None = None, test_parquet_path: str | None = None, train_table: pyarrow.lib.Table | polars.dataframe.frame.DataFrame | pandas.core.frame.DataFrame | trajdl.datasets.arrow.abstract.BaseArrowDataset | NoneType = None, val_table: pyarrow.lib.Table | polars.dataframe.frame.DataFrame | pandas.core.frame.DataFrame | trajdl.datasets.arrow.abstract.BaseArrowDataset | NoneType = None, test_table: pyarrow.lib.Table | polars.dataframe.frame.DataFrame | pandas.core.frame.DataFrame | trajdl.datasets.arrow.abstract.BaseArrowDataset | NoneType = None, train_batch_size: int = 2, val_batch_size: int = 2, train_sampler: torch.utils.data.sampler.Sampler | None = None, val_sampler: torch.utils.data.sampler.Sampler | None = None, num_cpus: int = 0, num_train_batches: int = 10, num_val_batches: int = 10, num_train_buckets: int = 10, num_val_buckets: int = 10)
Bases: BaseLocSeqDataModule
- collate_function(ds: LocSeqDataset)
Collate function for the dataset. Different modules may require different implementations.
- Parameters:
ds (LocSeqDataset) – The dataset to collate.
- Returns:
The collated data.
- Return type:
Any
- num_train_batches: int = 10
- num_train_buckets: int = 10
- num_val_batches: int = 10
- num_val_buckets: int = 10
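The return type of `collate_function` is documented only as `Any`, and the attributes `num_train_buckets`, `num_train_batches` and their validation counterparts are not described further. Their names suggest that `HSTLSTMDataModule` groups location sequences of similar length into buckets before batching, which keeps padding small within each batch. The sketch below illustrates that general technique (length bucketing plus a padding collate) under stated assumptions. It is not trajdl's implementation: the helpers `bucket_by_length` and `pad_collate`, the padding id `0`, and the output keys `src`/`lengths` are all hypothetical.

```python
from typing import Dict, List, Sequence

PAD_ID = 0  # hypothetical padding token id, not taken from trajdl


def bucket_by_length(seqs: Sequence[Sequence[int]], num_buckets: int) -> List[List[int]]:
    """Group sequence indices into at most num_buckets buckets of similar length (hypothetical helper)."""
    if not seqs:
        return []
    # Sort indices by sequence length so neighbours have similar lengths
    order = sorted(range(len(seqs)), key=lambda i: len(seqs[i]))
    size = -(-len(order) // num_buckets)  # ceiling division
    return [order[k:k + size] for k in range(0, len(order), size)]


def pad_collate(batch: Sequence[Sequence[int]]) -> Dict[str, List]:
    """Right-pad a batch of location-id sequences and keep their true lengths (hypothetical helper)."""
    lengths = [len(s) for s in batch]
    max_len = max(lengths)
    padded = [list(s) + [PAD_ID] * (max_len - len(s)) for s in batch]
    return {"src": padded, "lengths": lengths}


# Usage: four toy sequences of location ids split into two length buckets
seqs = [[5, 7], [3, 4, 9, 2], [8], [1, 6, 2]]
buckets = bucket_by_length(seqs, num_buckets=2)  # [[2, 0], [3, 1]]
batch = pad_collate([seqs[i] for i in buckets[1]])
# batch == {"src": [[1, 6, 2, 0], [3, 4, 9, 2]], "lengths": [3, 4]}
```

Because each batch is drawn from a single bucket, the longest and shortest sequences in it differ by little, so fewer `PAD_ID` positions are wasted than with uniformly random batches.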
- class trajdl.datasets.modules.stlstm.STLSTMDataModule(tokenizer: str | trajdl.tokenizers.abstract.AbstractTokenizer, train_parquet_path: str | None = None, val_parquet_path: str | None = None, test_parquet_path: str | None = None, train_table: pyarrow.lib.Table | polars.dataframe.frame.DataFrame | pandas.core.frame.DataFrame | trajdl.datasets.arrow.abstract.BaseArrowDataset | NoneType = None, val_table: pyarrow.lib.Table | polars.dataframe.frame.DataFrame | pandas.core.frame.DataFrame | trajdl.datasets.arrow.abstract.BaseArrowDataset | NoneType = None, test_table: pyarrow.lib.Table | polars.dataframe.frame.DataFrame | pandas.core.frame.DataFrame | trajdl.datasets.arrow.abstract.BaseArrowDataset | NoneType = None, train_batch_size: int = 2, val_batch_size: int = 2, train_sampler: torch.utils.data.sampler.Sampler | None = None, val_sampler: torch.utils.data.sampler.Sampler | None = None, num_cpus: int = 0, num_train_batches: int = 10, num_val_batches: int = 10, num_train_buckets: int = 10, num_val_buckets: int = 10, ts_bucketizer: Any | None = None, loc_bucketizer: Any | None = None)
Bases: BaseLocSeqDataModule
- collate_function(batch_ds: LocSeqDataset)
Collate function for the dataset. Different modules may require different implementations.
- Parameters:
batch_ds (LocSeqDataset) – The dataset to collate.
- Returns:
The collated data.
- Return type:
Any
- loc_bucketizer: Any | None = None
- num_train_batches: int = 10
- num_train_buckets: int = 10
- num_val_batches: int = 10
- num_val_buckets: int = 10
- setup(stage: str)
Set up the data module, loading the tokenizer and initializing datasets.
- Parameters:
stage (str) – Stage of operation (e.g. "fit", "test").
- ts_bucketizer: Any | None = None
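Unlike `HSTLSTMDataModule`, `STLSTMDataModule` also accepts `ts_bucketizer` and `loc_bucketizer`, both typed only as `Any`. A common approach in spatio-temporal recurrent models is to discretize the time gap (and distance) between consecutive visits into a fixed number of slots. The sketch below shows one way a bucketizer could map continuous time gaps to slot indices. The class `EdgeBucketizer`, the helper `time_gaps`, the bin edges, and the assumption that the data module applies the bucketizer to per-step gaps in seconds are all hypothetical and not taken from trajdl.

```python
import bisect
from typing import List, Sequence


class EdgeBucketizer:
    """Map continuous values to bucket indices using sorted bin edges (hypothetical helper).

    Values below edges[0] fall into bucket 0; values >= edges[-1] fall into bucket len(edges).
    """

    def __init__(self, edges: Sequence[float]):
        if list(edges) != sorted(edges):
            raise ValueError("edges must be sorted in ascending order")
        self.edges = list(edges)

    @property
    def num_buckets(self) -> int:
        # n edges split the real line into n + 1 buckets
        return len(self.edges) + 1

    def __call__(self, values: Sequence[float]) -> List[int]:
        # bisect_right places a value equal to an edge in the bucket above that edge
        return [bisect.bisect_right(self.edges, v) for v in values]


def time_gaps(timestamps: Sequence[float]) -> List[float]:
    """Gaps between consecutive timestamps; the first step is assigned 0 (assumption)."""
    return [0.0] + [b - a for a, b in zip(timestamps, timestamps[1:])]


# Hypothetical edges: 10 minutes, 1 hour, 6 hours, 1 day (in seconds)
ts_bucketizer = EdgeBucketizer([600, 3600, 21600, 86400])
gaps = time_gaps([0, 300, 4000, 100000])  # [0.0, 300, 3700, 96000]
print(ts_bucketizer(gaps))  # prints [0, 0, 2, 4]
```

A distance bucketizer for `loc_bucketizer` could reuse the same class with edges in metres; keeping the bucketizer outside the model lets the slot boundaries be changed without touching model code.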