trajdl.datasets.modules.t2vec module

class trajdl.datasets.modules.t2vec.T2VECDataModule(tokenizer: str | AbstractTokenizer, train_src_path: str, train_trg_path: str, val_src_path: str, val_trg_path: str, train_batch_size: int, val_batch_size: int, num_train_batches: int, buckets_boundaries: List[Tuple[int, int]], num_cpus: int = -1)

Bases: LightningDataModule

DataModule for T2VEC, used for training and validation.

collate_fn_train(batch: List[Tuple[ListScalar, ListScalar]]) → T2VECSample

batch: List[Tuple[pyarrow.ListScalar, pyarrow.ListScalar]]

setup(stage: str)

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters:

stage – either 'fit', 'validate', 'test', or 'predict'

Example:

class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this: state assigned here is not available on other processes
        self.something = ...

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)

train_dataloader()

An iterable or collection of iterables specifying training samples.

For more information about returning multiple dataloaders, see the Lightning documentation.

The dataloader you return will not be reloaded unless you set the Trainer argument reload_dataloaders_every_n_epochs to a positive integer.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data().
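
The download/process split described above can be sketched without Lightning at all. This is a minimal illustration, not the T2VECDataModule implementation: SketchDataModule, the raw.json file name, and the 80/20 split are all hypothetical, and a real module would subclass LightningDataModule.

```python
import json
import tempfile
from pathlib import Path


class SketchDataModule:
    """Plain-Python stand-in that mirrors the prepare_data()/setup() split.

    Hypothetical class for illustration only; not part of trajdl or Lightning.
    """

    def __init__(self, data_dir: str):
        self.data_dir = Path(data_dir)
        self.train_items = None
        self.val_items = None

    def prepare_data(self) -> None:
        # Only write artifacts to disk here; no attributes are assigned.
        raw = list(range(10))
        (self.data_dir / "raw.json").write_text(json.dumps(raw))

    def setup(self, stage: str) -> None:
        # State is assigned here, because this hook runs on every process.
        raw = json.loads((self.data_dir / "raw.json").read_text())
        split = int(len(raw) * 0.8)
        self.train_items, self.val_items = raw[:split], raw[split:]


with tempfile.TemporaryDirectory() as tmp:
    dm = SketchDataModule(tmp)
    dm.prepare_data()
    dm.setup("fit")
    train_items, val_items = dm.train_items, dm.val_items
```

Because prepare_data() only touches the filesystem, a process that never ran it can still call setup() and end up with the same state.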

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

val_dataloader()

An iterable or collection of iterables specifying validation samples.

For more information about returning multiple dataloaders, see the Lightning documentation.

The dataloader you return will not be reloaded unless you set the Trainer argument reload_dataloaders_every_n_epochs to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data().

  • fit()

  • validate()

  • prepare_data()

  • setup()

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

class trajdl.datasets.modules.t2vec.T2VECDataModuleV2(tokenizer: str | AbstractTokenizer, train_parquet_path: str | None = None, val_parquet_path: str | None = None, test_parquet_path: str | None = None, train_table: Table | DataFrame | DataFrame | BaseArrowDataset | None = None, val_table: Table | DataFrame | DataFrame | BaseArrowDataset | None = None, test_table: Table | DataFrame | DataFrame | BaseArrowDataset | None = None, train_batch_size: int = 2, val_batch_size: int = 2, train_sampler: Sampler | None = None, val_sampler: Sampler | None = None, num_cpus: int = 0, num_train_batches: int = 10, num_val_batches: int = 10, num_train_buckets: int = 10, num_val_buckets: int = 10, k: int = 1)

Bases: BaseTrajectoryDataModule

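k controls the negative-sample multiplier: when k is 1, each positive sample is paired with one negative sample; when k is 2, each positive sample is paired with two negative samples.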
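
The effect of k can be illustrated with a small stand-alone sketch. How T2VECDataModuleV2 actually draws its negatives is not documented on this page, so the helper below (pair_with_negatives, a hypothetical name) simply samples k other items uniformly at random for each positive.

```python
import random
from typing import List, Tuple


def pair_with_negatives(
    items: List[str], k: int, seed: int = 0
) -> List[Tuple[str, List[str]]]:
    """Pair every item (the positive) with k randomly chosen other items.

    Hypothetical helper: uniform sampling is an assumption made for
    illustration, not the strategy used by the library.
    """
    rng = random.Random(seed)
    pairs = []
    for i, positive in enumerate(items):
        # Negatives are drawn from every item except the positive itself.
        candidates = items[:i] + items[i + 1:]
        negatives = rng.sample(candidates, k)
        pairs.append((positive, negatives))
    return pairs


items = ["a", "b", "c", "d"]
pairs_k1 = pair_with_negatives(items, k=1)
pairs_k2 = pair_with_negatives(items, k=2)
```

With k=1 each positive carries one negative; with k=2 it carries two, so the number of negative samples doubles while the number of positives stays the same.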

collate_function(ds: TrajectoryDataset) → T2VECSample

Collate function for the dataset. Different modules may require different implementations.

Parameters:

ds (BaseArrowDataset) – The dataset to collate.

Returns:

The collated data.

Return type:

Any

k: int = 1

num_train_batches: int = 10

num_train_buckets: int = 10

num_val_batches: int = 10

num_val_buckets: int = 10

setup(stage: str)

Set up the data module, loading the tokenizer and initializing datasets.

Parameters:

stage (str) – Stage of operation (e.g. ‘fit’, ‘test’).

trajdl.datasets.modules.t2vec.downsampling_distort(traj: Trajectory) → Trajectory

Given a trajectory, randomly downsample and distort it, and return a new trajectory.

Parameters:

traj (Trajectory) – The input trajectory.

Returns:

traj – The trajectory after downsampling and distortion.

Return type:

Trajectory
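
The general idea of random downsampling plus distortion can be sketched on plain coordinate lists. This is not the library's implementation: the function name, the drop_rate, distort_rate and noise_std parameters, the Gaussian noise, and the choice to always keep the endpoints are all assumptions made for illustration, and the real function operates on a Trajectory object rather than a list of tuples.

```python
import random
from typing import List, Optional, Tuple

Point = Tuple[float, float]


def downsample_and_distort(
    points: List[Point],
    drop_rate: float = 0.3,
    distort_rate: float = 0.3,
    noise_std: float = 1e-4,
    seed: Optional[int] = None,
) -> List[Point]:
    """Randomly drop interior points, then add Gaussian noise to some of the rest.

    Hypothetical sketch; all parameters are illustrative assumptions.
    """
    rng = random.Random(seed)

    # Downsampling: keep the first and last point, drop interior points
    # independently with probability drop_rate.
    if len(points) <= 2:
        kept = list(points)
    else:
        middle = [p for p in points[1:-1] if rng.random() >= drop_rate]
        kept = [points[0]] + middle + [points[-1]]

    # Distortion: perturb each kept point with probability distort_rate.
    distorted = []
    for x, y in kept:
        if rng.random() < distort_rate:
            x += rng.gauss(0.0, noise_std)
            y += rng.gauss(0.0, noise_std)
        distorted.append((x, y))
    return distorted


original = [(116.30 + 0.01 * i, 39.90 + 0.01 * i) for i in range(20)]
noisy = downsample_and_distort(original, seed=42)
unchanged = downsample_and_distort(original, drop_rate=0.0, distort_rate=0.0)
```

The input list is never modified; a new list is returned, which matches the documented behaviour of returning a new trajectory.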

trajdl.datasets.modules.t2vec.generate_samples(tokenizer: T2VECTokenizer, traj: Trajectory) → Tuple[LongTensor, int, LongTensor]