Lightning DataModule#
Attention
在1.0.0版本发布前,当前文档的内容可能会发生变化。
LightningDataModule是Lightning里面的一个重要组成部分。使用Lightning框架训练的时候需要定义一个LightningDataModule用来管理数据。
See also
用户可以自行查阅LightningDataModule的官方文档来了解LightningDataModule的作用。
TrajDL提供了一个适配序列数据和轨迹数据的抽象子类BaseSeqDataModule。这个LightningDataModule已经针对TrajDL的BaseArrowDataset提供了一些基础功能的封装,针对序列数据提供了BaseLocSeqDataModule,针对轨迹数据提供了BaseTrajectoryDataModule。用户可以快速基于这两个子类进行LightningDataModule的构建。
简单来说具体使用流程是:
根据任务判断需要使用的是位置序列还是轨迹序列。
根据数据集快速构建训练、验证、测试集,这些数据集使用
BaseArrowDataset进行存储。根据任务编写
collate_function。选择
BaseSeqDataModule在其基础上编写自己的LightningDataModule,一般只要实现抽象方法的collate_function即可。
我们以TULER为例快速实现一个LightningDataModule,我们要完成如下步骤:
判断数据类型:使用
Gowalla数据集,用的是位置序列。基于
Polars快速构建训练、验证、测试集,构建Tokenizer编写
collate_function。基于
BaseLocSeqDataModule编写TULERDataModuleExample。
1. 构建训练、验证、测试集、Tokenizer#
from tqdm.notebook import tqdm
import polars as pl
import numpy as np
from trajdl.datasets.open_source import GowallaDataset
# 下载并加载Gowalla数据集
df = GowallaDataset().load(return_as="pl")
# 取id小于20的用户的数据,为了演示,我们把序列的长度也设置一下
df = (
df
.filter(pl.col("user_id").cast(pl.Int64) < 20)
.with_columns(pl.col("check_in_time").dt.strftime("%Y%m%d").alias("ds"))
.group_by("user_id", "ds")
.agg(pl.col("loc_id").sort_by(pl.col("check_in_time")).alias("loc_seq"))
.filter((pl.col("loc_seq").list.len() >= 5) & (pl.col("loc_seq").list.len() < 10))
.select("user_id", "ds", "loc_seq")
)
df.head()
load dataset: gowalla
| user_id | ds | loc_seq |
|---|---|---|
| str | str | list[str] |
| "4" | "20100613" | ["10677", "17313", … "25818"] |
| "0" | "20100610" | ["1251533", "210176", … "1221889"] |
| "17" | "20100221" | ["30595", "469525", … "433950"] |
| "0" | "20100819" | ["14515", "22765", … "19542"] |
| "18" | "20090924" | ["20172", "26341", … "26539"] |
# 构建一个user_map,key是user的id(字符串类型),value是user id转换后的下标(int类型)
user_map = {
user_id: idx for idx, user_id in enumerate(df.select(pl.col("user_id").unique())["user_id"])
}
user_map
{'4': 0,
'9': 1,
'15': 2,
'2': 3,
'10': 4,
'18': 5,
'0': 6,
'19': 7,
'17': 8,
'5': 9,
'14': 10,
'13': 11,
'7': 12}
# 添加一列叫sample_idx,表示当前序列的日期在这个用户所有序列里面的排名
add_sample_idx = df.with_columns(sample_idx=pl.int_range(pl.len()).over("user_id", order_by="ds"))
# 统计每个用户的位置序列数
num_locseqs_by_user = df.group_by("user_id").agg(pl.len().alias("num_locseqs"))
# 使用join,将每个用户的序列数join到第一个DataFrame上
tmp_df = add_sample_idx.join(num_locseqs_by_user, how="left", on=["user_id"])
tmp_df.head()
| user_id | ds | loc_seq | sample_idx | num_locseqs |
|---|---|---|---|---|
| str | str | list[str] | i64 | u32 |
| "4" | "20100613" | ["10677", "17313", … "25818"] | 10 | 11 |
| "0" | "20100610" | ["1251533", "210176", … "1221889"] | 0 | 7 |
| "17" | "20100221" | ["30595", "469525", … "433950"] | 1 | 6 |
| "0" | "20100819" | ["14515", "22765", … "19542"] | 1 | 7 |
| "18" | "20090924" | ["20172", "26341", … "26539"] | 1 | 2 |
# 针对每个用户,按ds划分训练、验证、测试集,比例是6: 2: 2
train_df = tmp_df.filter(pl.col("sample_idx") < pl.col("num_locseqs") * 0.6).select("user_id", "loc_seq")
val_df = tmp_df.filter((pl.col("sample_idx") >= pl.col("num_locseqs") * 0.6) & (pl.col("sample_idx") < pl.col("num_locseqs") * 0.8)).select("user_id", "loc_seq")
test_df = tmp_df.filter(pl.col("sample_idx") >= pl.col("num_locseqs") * 0.8).select("user_id", "loc_seq")
# 打印训练集,验证集和测试集的样本数,一行是一个样本
train_df.shape, val_df.shape, test_df.shape
((41, 2), (9, 2), (6, 2))
# 使用LocSeqDataset装载三个数据集,构建Tokenizer
from trajdl.datasets import LocSeq, LocSeqDataset
from trajdl.tokenizers import LocSeqTokenizer
def transform_dataframe_into_dataset(df: pl.DataFrame) -> LocSeqDataset:
"""
将一个Polars DataFrame转换为LocSeqDataset
"""
locseqs = [LocSeq(seq=loc_seq, entity_id=user_id) for user_id, loc_seq in df.iter_rows()]
return LocSeqDataset.init_from_loc_seqs(locseqs)
# 构建三个数据集
train_ds = transform_dataframe_into_dataset(train_df)
val_ds = transform_dataframe_into_dataset(val_df)
test_ds = transform_dataframe_into_dataset(test_df)
print("datasets:", train_ds, val_ds, test_ds)
# iter_as_seqs方法可以将BaseArrowDataset转换为单条序列的实例
tokenizer = LocSeqTokenizer.build(train_ds.iter_as_seqs())
datasets: LocSeqDataset(size=41) LocSeqDataset(size=9) LocSeqDataset(size=6)
2. 编写collate_function#
import torch
from typing import List, Optional, Dict
from dataclasses import dataclass
from torch.nn.utils.rnn import pad_sequence
from trajdl.common.samples import TULERSample
def collate_function(batch: LocSeqDataset, user_map: Dict[str, int], tokenizer: LocSeqTokenizer) -> TULERSample:
"""
将LocSeqDataset转换为TULER需要的样本
- 序列(含padding)
- 序列长度
- 标签(用户id)
"""
seqs: List[torch.LongTensor] = []
lengths: List[int] = []
labels: List[int] = []
# 取出我们需要的两列
seq_col = batch.seq
entity_id_col = batch.entity_id
# 按行遍历batch
for line_idx in range(len(batch)):
# 将位置序列使用tokenizer编码,以torch.LongTenso的类型返回
seqs.append(tokenizer.tokenize_loc_seq(seq_col[line_idx], return_as="pt"))
# 记录序列的长度
lengths.append(seqs[-1].shape[0])
# 添加标签,这里要用user_map将用户的id转换为idx
labels.append(user_map[entity_id_col[line_idx].as_py()])
# 返回样本
return TULERSample(
# 对序列添加padding,padding的值就用tokenizer维护的.pad属性
src=pad_sequence(seqs, batch_first=True, padding_value=tokenizer.pad),
seq_len=lengths,
labels=torch.LongTensor(labels),
)
3. 编写LightningDataModule#
因为我们使用的是位置序列,所以选择BaseLocSeqDataModule作为我们基座。
BaseLocSeqDataModule已经定义好了一些参数。
from trajdl.datasets.modules.abstract import BaseLocSeqDataModule
help(BaseLocSeqDataModule.__init__)
Help on function __init__ in module trajdl.datasets.modules.abstract:
__init__(self, tokenizer: Union[str, trajdl.tokenizers.abstract.AbstractTokenizer], train_parquet_path: Optional[str] = None, val_parquet_path: Optional[str] = None, test_parquet_path: Optional[str] = None, train_table: Union[pyarrow.lib.Table, polars.dataframe.frame.DataFrame, pandas.core.frame.DataFrame, trajdl.datasets.arrow.abstract.BaseArrowDataset, NoneType] = None, val_table: Union[pyarrow.lib.Table, polars.dataframe.frame.DataFrame, pandas.core.frame.DataFrame, trajdl.datasets.arrow.abstract.BaseArrowDataset, NoneType] = None, test_table: Union[pyarrow.lib.Table, polars.dataframe.frame.DataFrame, pandas.core.frame.DataFrame, trajdl.datasets.arrow.abstract.BaseArrowDataset, NoneType] = None, train_batch_size: int = 2, val_batch_size: int = 2, train_sampler: Optional[torch.utils.data.sampler.Sampler] = None, val_sampler: Optional[torch.utils.data.sampler.Sampler] = None, num_cpus: int = 0) -> None
Attributes:
prepare_data_per_node:
If True, each LOCAL_RANK=0 will call prepare data.
Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.
allow_zero_length_dataloader_with_multiple_devices:
If True, dataloader with zero length within local rank is allowed.
Default value is False.
tokenizer:
Tokenizer,这个直接传入即可train_parquet_path, val_parquet_path, test_parquet_path:可选参数,这些是数据集的路径,
BaseArrowDataset有一个save方法可以将数据集存储到文件,会以parquet的格式存储。train_table, val_table, test_table:可选参数:这些是数据集实例,也就是
BaseSeqDataModule是同时支持传入文件路径或直接传入BaseArrowDataset实例进行数据集配置的。train_batch_size, val_batch_size:训练集和验证集的batch_size,测试集的batch_size会使用
val_batch_size。train_sampler, val_sampler:可选参数:是否要传入
Sampler,测试集不会使用Sampler。num_cpus:可选参数,有多少个CPU就会将
DataLoader设置为多少个进程。
我们基于这个基类构建我们的TULERDataModuleExample,只要补充一个额外的user_map参数就好了,因为在collate_function里面要构建标签。
@dataclass
class TULERDataModuleExample(BaseLocSeqDataModule):
user_map: Optional[Dict[str, int]] = None
def __post_init__(self):
# 先调用父类的后处理,因为使用的是dataclass,所以需要做这一步
super().__post_init__()
# 检查一下user_map这个参数的类型
if not isinstance(self.user_map, dict):
raise ValueError(
"`user_map` should be a Dict[str, int] instance."
)
# 这个collate_function是一个抽象方法,子类必需实现
def collate_function(self, batch: LocSeqDataset) -> TULERSample:
# 因为父类已经存储了tokenizer,这里只要通过self.tokenizer即可获取
return collate_function(batch, self.user_map, self.tokenizer)
这个TULERDataModuleExample就编写完了,很简单,实际上就是增加一个user_map的参数,然后再编写一个collate_function,接下来我们测试一下,训练集的batch_size设置为2,验证集设置为3。
datamodule = TULERDataModuleExample(
tokenizer=tokenizer,
train_table=train_ds,
val_table=val_ds,
test_table=test_ds,
train_batch_size=2,
val_batch_size=3,
user_map=user_map)
datamodule.setup("fit")
train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()
test_loader = datamodule.test_dataloader()
next(iter(train_loader))
TULERSample(src=tensor([[171, 22, 172, 173, 174, 175, 176, 177],
[182, 183, 184, 185, 186, 218, 218, 218]]), seq_len=[8, 5], labels=tensor([ 3, 11]))
next(iter(val_loader))
TULERSample(src=tensor([[216, 216, 109, 216, 216, 216, 218],
[109, 216, 216, 216, 216, 216, 216],
[216, 216, 216, 216, 216, 216, 218]]), seq_len=[6, 7, 6], labels=tensor([2, 6, 3]))
next(iter(test_loader))
TULERSample(src=tensor([[216, 216, 216, 216, 216, 218],
[ 16, 3, 216, 3, 216, 216],
[216, 216, 201, 216, 216, 218]]), seq_len=[5, 6, 5], labels=tensor([0, 6, 3]))
BaseSeqDataModule抽象了TrajDL在训练、验证、测试过程中数据的pipeline:
其提供了训练集、验证集、测试集的统一加载方式,也就是用户通过对数据集加工得到
BaseArrowDataset后,可以直接放入BaseSeqDataModule里面,或者持久化之后又BaseSeqDataModule自动加载提供训练、验证、测试集的batch_size的配置,提供
Sampler的支持,提供Tokenizer的管理自动加载
BaseArrowDataset并构建DataLoader用户在继承其子类(
BaseLocSeqDataModule,BaseTrajectoryDataModule)的时候只要增加一些参数和自定义的collate_function即可。
Tip
LightningDataModule并不是必需使用的,因为LightningDataModule只是一个DataLoader的管理工具。用户可以根据自己的喜好自行定义训练流程,比如使用Pytorch原生的训练流程、或者使用Lightning Fabric,这些方式都可以自己定义DataLoader,不受TrajDL的约束。