trajdl.datasets.modules.t2vec module
- class trajdl.datasets.modules.t2vec.T2VECDataModule(tokenizer: str | AbstractTokenizer, train_src_path: str, train_trg_path: str, val_src_path: str, val_trg_path: str, train_batch_size: int, val_batch_size: int, num_train_batches: int, buckets_boundaries: List[Tuple[int, int]], num_cpus: int = -1)
Bases: LightningDataModule
DataModule for T2VEC, used for training and validation.
- collate_fn_train(batch: List[Tuple[ListScalar, ListScalar]]) → T2VECSample
batch: List[Tuple[pyarrow.ListScalar, pyarrow.ListScalar]]
- setup(stage: str)
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Parameters:
stage – either 'fit', 'validate', 'test', or 'predict'
Example:
    class LitModel(...):
        def __init__(self):
            self.l1 = None

        def prepare_data(self):
            download_data()
            tokenize()

            # don't do this
            self.something = else

        def setup(self, stage):
            data = load_data(...)
            self.l1 = nn.Linear(28, data.num_classes)
- train_dataloader()
An iterable or collection of iterables specifying training samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set the Trainer argument reload_dataloaders_every_n_epochs to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
Do not assign state in prepare_data().
- fit()
- prepare_data()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
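The prepare_data()/setup() split described above can be sketched as follows. This is a plain-Python stand-in that only mirrors the hook names; it does not subclass LightningDataModule, and `SketchDataModule`, `raw.json`, and `val_fraction` are hypothetical names chosen for illustration.

```python
import json
import os
import tempfile


class SketchDataModule:
    """Plain-Python stand-in mirroring the hook names; not a real LightningDataModule."""

    def __init__(self, data_dir: str, val_fraction: float = 0.2):
        self.data_dir = data_dir
        self.val_fraction = val_fraction
        # State is assigned in setup(), never in prepare_data().
        self.train_items = None
        self.val_items = None

    def _raw_path(self) -> str:
        return os.path.join(self.data_dir, "raw.json")

    def prepare_data(self) -> None:
        # "Download" step: write to disk only and assign nothing on self.
        if not os.path.exists(self._raw_path()):
            with open(self._raw_path(), "w") as f:
                json.dump(list(range(10)), f)

    def setup(self, stage: str) -> None:
        # Process and split: this is where state is assigned.
        with open(self._raw_path()) as f:
            items = json.load(f)
        n_val = int(len(items) * self.val_fraction)
        self.val_items = items[:n_val]
        self.train_items = items[n_val:]


with tempfile.TemporaryDirectory() as tmp:
    dm = SketchDataModule(tmp)
    dm.prepare_data()
    dm.setup("fit")
    n_train, n_val = len(dm.train_items), len(dm.val_items)
```

Keeping prepare_data() free of assignments matters because anything it sets on `self` would not be visible to processes that only run setup().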
- val_dataloader()
An iterable or collection of iterables specifying validation samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set the Trainer argument reload_dataloaders_every_n_epochs to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
- fit()
- validate()
- prepare_data()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
- class trajdl.datasets.modules.t2vec.T2VECDataModuleV2(tokenizer: str | AbstractTokenizer, train_parquet_path: str | None = None, val_parquet_path: str | None = None, test_parquet_path: str | None = None, train_table: Table | DataFrame | DataFrame | BaseArrowDataset | None = None, val_table: Table | DataFrame | DataFrame | BaseArrowDataset | None = None, test_table: Table | DataFrame | DataFrame | BaseArrowDataset | None = None, train_batch_size: int = 2, val_batch_size: int = 2, train_sampler: Sampler | None = None, val_sampler: Sampler | None = None, num_cpus: int = 0, num_train_batches: int = 10, num_val_batches: int = 10, num_train_buckets: int = 10, num_val_buckets: int = 10, k: int = 1)
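Bases: BaseTrajectoryDataModule
k controls the negative-sample multiplier: when k is 1, each positive sample is paired with one negative sample; when k is 2, each positive sample is paired with two negative samples.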
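As a rough illustration of what a negative-sample multiplier means, the sketch below draws k negatives per positive from the remaining items. It is not the library's implementation: `sample_negatives` and its arguments are hypothetical, and how T2VECDataModuleV2 actually builds its negatives is not documented here.

```python
import random
from typing import Dict, List, Sequence


def sample_negatives(items: Sequence[str], k: int, seed: int = 0) -> Dict[str, List[str]]:
    """Map each item (treated as a positive) to k distinct negatives.

    Hypothetical helper: negatives are drawn uniformly from the other items.
    """
    rng = random.Random(seed)
    result: Dict[str, List[str]] = {}
    for positive in items:
        # Every other item is a candidate negative for this positive.
        candidates = [x for x in items if x != positive]
        result[positive] = rng.sample(candidates, k)
    return result


# With k=2, every positive ends up with exactly two negatives.
pairs = sample_negatives(["a", "b", "c", "d"], k=2)
```

Note that `rng.sample` requires k to be no larger than the number of candidates, so this sketch only works when there are more than k items.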
- collate_function(ds: TrajectoryDataset) → T2VECSample
Collate function for the dataset. Different modules may require different implementations.
- Parameters:
ds (BaseArrowDataset) – The dataset to collate.
- Returns:
The collated data.
- Return type:
Any
- k: int = 1
- num_train_batches: int = 10
- num_train_buckets: int = 10
- num_val_batches: int = 10
- num_val_buckets: int = 10
- trajdl.datasets.modules.t2vec.downsampling_distort(traj: Trajectory) → Trajectory
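Given a trajectory, randomly downsample and distort it, and return a new trajectory.
- Parameters:
traj (Trajectory) – the input trajectory
- Returns:
traj – the trajectory after downsampling and distortion
- Return type:
Trajectory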
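To make "downsampling and distortion" concrete, here is a minimal sketch on a plain list of (lon, lat) tuples. It does not use the library's Trajectory type, and the function name, the rates, the Gaussian noise scale, and the choice to always keep the endpoints are assumptions for illustration, not the behaviour of downsampling_distort itself.

```python
import random
from typing import List, Optional, Sequence, Tuple

Point = Tuple[float, float]  # (lon, lat); a stand-in for the library's Trajectory


def downsample_and_distort(
    points: Sequence[Point],
    drop_rate: float = 0.3,      # assumed probability of dropping an interior point
    distort_rate: float = 0.3,   # assumed probability of perturbing a kept point
    noise_std: float = 0.0005,   # assumed Gaussian noise scale, in degrees
    seed: Optional[int] = None,
) -> List[Point]:
    rng = random.Random(seed)

    # Downsampling: randomly drop interior points, always keep both endpoints.
    if len(points) <= 2:
        kept = list(points)
    else:
        middle = [p for p in points[1:-1] if rng.random() >= drop_rate]
        kept = [points[0], *middle, points[-1]]

    # Distortion: add Gaussian noise to a random subset of the kept points.
    distorted: List[Point] = []
    for lon, lat in kept:
        if rng.random() < distort_rate:
            lon += rng.gauss(0.0, noise_std)
            lat += rng.gauss(0.0, noise_std)
        distorted.append((lon, lat))
    return distorted


original = [(116.30 + 0.001 * i, 39.90 + 0.001 * i) for i in range(10)]
new = downsample_and_distort(original, seed=42)
```

The result is never longer than the input and, for inputs with at least two points, never shorter than two points.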
- trajdl.datasets.modules.t2vec.generate_samples(tokenizer: T2VECTokenizer, traj: Trajectory) → Tuple[LongTensor, int, LongTensor]