trajdl.datasets.modules.gmvsae module

class trajdl.datasets.modules.gmvsae.GMVSAEDataModule(tokenizer: str | trajdl.tokenizers.abstract.AbstractTokenizer, train_parquet_path: str | None = None, val_parquet_path: str | None = None, test_parquet_path: str | None = None, train_table: pyarrow.lib.Table | polars.dataframe.frame.DataFrame | pandas.core.frame.DataFrame | trajdl.datasets.arrow.abstract.BaseArrowDataset | NoneType = None, val_table: pyarrow.lib.Table | polars.dataframe.frame.DataFrame | pandas.core.frame.DataFrame | trajdl.datasets.arrow.abstract.BaseArrowDataset | NoneType = None, test_table: pyarrow.lib.Table | polars.dataframe.frame.DataFrame | pandas.core.frame.DataFrame | trajdl.datasets.arrow.abstract.BaseArrowDataset | NoneType = None, train_batch_size: int = 2, val_batch_size: int = 2, train_sampler: torch.utils.data.sampler.Sampler | None = None, val_sampler: torch.utils.data.sampler.Sampler | None = None, num_cpus: int = 0, num_train_batches: int = 10, num_val_batches: int = 10, num_train_buckets: int = 10, num_val_buckets: int = 10)

Bases: BaseLocSeqDataModule

collate_function(ds: LocSeqDataset) → GMVSAESample

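Returns six items:

1. The encoder input sequence.
2. The length of the encoder input sequence.
3. The decoder input sequence.
4. The length of the decoder input sequence.
5. The labels for the sequence the decoder should produce.
6. The mask marking which decoder positions contribute to the loss.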
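The six fields listed above could be assembled roughly as in the following sketch. It is a plain-Python illustration, not trajdl's implementation: `ToyBatch`, `toy_collate`, the special-token ids `PAD`, `BOS`, `EOS`, and the choice to shift the decoder input by one position are all assumptions made for the example. The real `GMVSAESample` layout and padding rules may differ.

```python
from typing import List, NamedTuple

# Hypothetical special-token ids; the real tokenizer defines its own.
PAD, BOS, EOS = 0, 1, 2


class ToyBatch(NamedTuple):
    """Illustrative stand-in for the six returned items."""
    enc_seq: List[List[int]]    # 1. padded encoder sequences
    enc_len: List[int]          # 2. encoder sequence lengths
    dec_input: List[List[int]]  # 3. padded decoder inputs (BOS + sequence)
    dec_len: List[int]          # 4. decoder input lengths
    dec_label: List[List[int]]  # 5. padded decoder labels (sequence + EOS)
    dec_mask: List[List[int]]   # 6. 1 where the loss is computed, 0 on padding


def pad(seqs: List[List[int]], value: int) -> List[List[int]]:
    """Right-pad every sequence to the length of the longest one."""
    width = max(len(s) for s in seqs)
    return [s + [value] * (width - len(s)) for s in seqs]


def toy_collate(seqs: List[List[int]]) -> ToyBatch:
    dec_in = [[BOS] + s for s in seqs]    # decoder sees BOS first
    dec_out = [s + [EOS] for s in seqs]   # and must predict up to EOS
    dec_len = [len(s) for s in dec_in]
    width = max(dec_len)
    mask = [[1] * n + [0] * (width - n) for n in dec_len]
    return ToyBatch(
        enc_seq=pad(seqs, PAD),
        enc_len=[len(s) for s in seqs],
        dec_input=pad(dec_in, PAD),
        dec_len=dec_len,
        dec_label=pad(dec_out, PAD),
        dec_mask=mask,
    )


batch = toy_collate([[5, 6, 7], [8, 9]])
```

In this sketch the mask is what keeps padded positions out of the reconstruction loss; a loss function would typically multiply per-token losses by `dec_mask` before averaging.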

num_train_batches: int = 10
num_train_buckets: int = 10
num_val_batches: int = 10
num_val_buckets: int = 10

setup(stage: str)

Set up the data module, loading the tokenizer and initializing datasets.

Parameters:

stage (str) – Stage of operation (e.g. ‘fit’, ‘test’).
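A stage-dependent `setup` is often written so that only the splits needed for the requested stage are prepared. The sketch below illustrates that general pattern with a hypothetical `ToyDataModule`; it does not reproduce what `GMVSAEDataModule.setup` does internally, and the handling of stage values other than 'fit' and 'test' is an assumption.

```python
from typing import Dict, List, Optional

Split = List[List[int]]


class ToyDataModule:
    """Hypothetical stand-in used only to show stage dispatch."""

    def __init__(self, splits: Dict[str, Split]) -> None:
        self._splits = splits
        self.train: Optional[Split] = None
        self.val: Optional[Split] = None
        self.test: Optional[Split] = None

    def setup(self, stage: str) -> None:
        if stage == "fit":
            # Training needs both the train and validation splits.
            self.train = self._splits.get("train")
            self.val = self._splits.get("val")
        elif stage == "test":
            # Testing only needs the held-out split.
            self.test = self._splits.get("test")
        else:
            # Assumption: this toy rejects stages it does not know.
            raise ValueError(f"unsupported stage: {stage!r}")


dm = ToyDataModule({"train": [[1, 2]], "val": [[3]], "test": [[4, 5]]})
dm.setup("fit")
```

After `dm.setup("fit")` only `dm.train` and `dm.val` are populated; `dm.test` stays `None` until `dm.setup("test")` is called.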