trajdl.datasets.modules.gmvsae module
- class trajdl.datasets.modules.gmvsae.GMVSAEDataModule(tokenizer: str | trajdl.tokenizers.abstract.AbstractTokenizer, train_parquet_path: str | None = None, val_parquet_path: str | None = None, test_parquet_path: str | None = None, train_table: pyarrow.lib.Table | polars.dataframe.frame.DataFrame | pandas.core.frame.DataFrame | trajdl.datasets.arrow.abstract.BaseArrowDataset | NoneType = None, val_table: pyarrow.lib.Table | polars.dataframe.frame.DataFrame | pandas.core.frame.DataFrame | trajdl.datasets.arrow.abstract.BaseArrowDataset | NoneType = None, test_table: pyarrow.lib.Table | polars.dataframe.frame.DataFrame | pandas.core.frame.DataFrame | trajdl.datasets.arrow.abstract.BaseArrowDataset | NoneType = None, train_batch_size: int = 2, val_batch_size: int = 2, train_sampler: torch.utils.data.sampler.Sampler | None = None, val_sampler: torch.utils.data.sampler.Sampler | None = None, num_cpus: int = 0, num_train_batches: int = 10, num_val_batches: int = 10, num_train_buckets: int = 10, num_val_buckets: int = 10)
Bases: BaseLocSeqDataModule
- collate_function(ds: LocSeqDataset) → GMVSAESample
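Returns six items:
1. The encoder input sequence
2. The length of the encoder input sequence
3. The decoder input sequence
4. The length of the decoder input sequence
5. The labels for the sequence the decoder produces
6. The mask marking which decoder positions contribute to the loss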
- num_train_batches: int = 10
- num_train_buckets: int = 10
- num_val_batches: int = 10
- num_val_buckets: int = 10
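
The six items returned by collate_function can be illustrated with a small, library-independent sketch. This is not the trajdl implementation: `SketchSample`, `sketch_collate`, `pad`, the token ids `PAD`, `BOS` and `EOS`, and the choice to shift the decoder input by one position are all assumptions made for illustration, and plain Python lists stand in for tensors.

```python
from typing import List, NamedTuple, Sequence

# Hypothetical special token ids; the real tokenizer may use different values.
PAD, BOS, EOS = 0, 1, 2


class SketchSample(NamedTuple):
    """Illustrative stand-in for the six collated items (not GMVSAESample)."""
    src: List[List[int]]          # 1. encoder sequences, padded
    src_lengths: List[int]        # 2. encoder sequence lengths
    tgt_input: List[List[int]]    # 3. decoder input sequences, padded
    tgt_lengths: List[int]        # 4. decoder input sequence lengths
    labels: List[List[int]]       # 5. decoder labels, padded
    loss_mask: List[List[bool]]   # 6. True where the loss is computed


def pad(seqs: Sequence[Sequence[int]], pad_value: int) -> List[List[int]]:
    """Right-pad every sequence to the length of the longest one."""
    width = max(len(s) for s in seqs)
    return [list(s) + [pad_value] * (width - len(s)) for s in seqs]


def sketch_collate(batch: Sequence[Sequence[int]]) -> SketchSample:
    """Build the six items from a batch of tokenized location sequences."""
    src_lengths = [len(s) for s in batch]
    # Assumption: decoder input is the sequence prefixed with BOS,
    # and the label is the same sequence followed by EOS.
    dec_inputs = [[BOS] + list(s) for s in batch]
    dec_labels = [list(s) + [EOS] for s in batch]
    tgt_lengths = [len(s) for s in dec_inputs]
    width = max(tgt_lengths)
    # Padded positions are excluded from the loss.
    loss_mask = [[i < n for i in range(width)] for n in tgt_lengths]
    return SketchSample(
        src=pad(batch, PAD),
        src_lengths=src_lengths,
        tgt_input=pad(dec_inputs, PAD),
        tgt_lengths=tgt_lengths,
        labels=pad(dec_labels, PAD),
        loss_mask=loss_mask,
    )


sample = sketch_collate([[5, 6, 7], [8, 9]])
```

In this sketch the mask is derived from the decoder lengths alone, so padded label positions never contribute to the loss; the actual module may build its mask differently.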