Collate Function#
Attention
在1.0.0版本发布前,当前文档的内容可能会发生变化。
因为TrajDL是基于Pytorch Lightning构建的,本质依赖Pytorch的基础工具。因此在模型训练的时候,需要用户定义Dataset和DataLoader,有些时候还需要定义Sampler。我们在前面章节里面讲述的批数据,也就是BaseArrowDataset,这个就是Dataset的子类。用户构建好后需要了解的是如何在DataLoader里面操作它。
TrajDL当前提供的BaseArrowDataset都是May-style Dataset。这种Dataset可以直接通过下标进行数据的读取。因此这种数据集一般也不会太大,太大会导致内存放不下。
Tip
目前TrajDL支持的公开数据集都可以使用Map-style Dataset实现,因此未来TrajDL会尝试支持Iterable-style datasets,这在数据集非常大的场景非常必要。
接下来我们说一下DataLoader与Dataset的关系:
在不指定
Sampler的情况下,DataLoader在生成一个batch的时候会先生成一组index,然后用这一组index去Dataset里面获取数据。举个例子,当我们将batch_size设定为3的时候,且shuffle设定为False的时候,第一个batch的index是[0, 1, 2],所以DataLoader内部会使用这些下标去Dataset里面取数。然后如果DataLoader在发现取出来的数据恰好可以组成一个矩阵的时候,DataLoader会把它们拼接在一起,形成一个batch的torch.Tensor。但是在
TrajDL的场景里面,我们一般加载的是原始的位置序列或者轨迹序列,这些序列的长度并不相同,而且我们可能还需要做一些额外的操作,比如裁剪或扰动,这时候DataLoader里面有一个叫做collate_function的概念就派上了用场。collate_function会在使用下标[0, 1, 2]从Dataset取数之后发生调用,然后在最终的batch生成之前结束调用。而collate_function的返回值就是用户通过DataLoader拿到的batch数据。
基于上述信息,我们可以发现,TrajDL已经实现好了Dataset,而DataLoader里面最核心的工作就是collate_function,因此很多时候使用TrajDL进行新的算法开发的时候,一个重要的工作就是如何从BaseArrowDataset里面生成batch数据,即如何编写collate_function。
Note
对于Map-style Dataset,Pytorch是如何通过给定下标[0, 1, 2]然后进行数据加载的呢?
参照官方文档:torch.utils.data.Dataset,文档提到所有Dataset的子类需要实现__getitem__方法,这个方法是Python自带的魔术方法,实现后就可以使用[idx]进行数据索引。然而文档内还提到一个__getitems__方法,这个并不是Python自身的魔术方法,而是Pytorch定义的一个方法,而且并不是必须要实现,这个方法的作用是加速批量数据的加载。
我们可以找到DataLoader的源码,查看DataLoader是如何针对给定的indices进行数据加载的:fetch.py,从源码里面可以看到,Pytorch会优先使用Dataset的__getitems__方法进行数据加载,其次才是使用__getitem__方法,取完之后的数据调用了collate_fn方法。
这里有一个细节:
如果使用的是
__getitems__,那么传给collate_fn方法的参数是self.dataset.__getitems__(possibly_batched_index)如果使用的是
__getitem__,那么传给collate_fn方法的参数是[self.dataset[idx] for idx in possibly_batched_index]
从TrajDL前面的文档可以知道,BaseArrowDataset在通过[idx]取值的时候,返回的仍然是BaseArrowDataset,即__getitem__方法的返回值类型是BaseArrowDataset。同时,BaseArrowDataset也实现了__getitems__方法,其返回值类型仍然是BaseArrowDataset。这两者的区别在于size不同。因此对于TrajDL提供的BaseArrowDataset,用户在使用DataLoader加载的时候会默认使用__getitems__方法,所以传入collate_fn的参数一定是一个BaseArrowDataset,而不是List[BaseArrowDataset]。
接下来我们以TULER算法为例,演示如何编写一个DataLoader的collate_function。
TULER模型在训练的时候,需要三项内容:
一个位置序列组成的batch,类型是
torch.LongTensor,shape是(batch_size, num_timesteps)一个
List[int],表示每条序列的长度,其size是batch_size一个标签的batch,类型是
torch.LongTensor,shape是(batch_size,),每一项对应每条序列的实际标签。
这实际上是一个三元组。
那么我们要定义的collate_function就很简单了,只要接受一个batch的输入,返回一个三元组即可。
由于BaseArrowDataset实现了__getitems__方法,且返回一个BaseArrowDataset类型,这个batch的输入就是一个BaseArrowDataset,其size等于batch_size。
import polars as pl
from tqdm.notebook import tqdm
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from trajdl.datasets.open_source import GowallaDataset
from trajdl.datasets import LocSeq, LocSeqDataset
from trajdl.tokenizers import LocSeqTokenizer
# 取id小于5的用户的数据
df = (
GowallaDataset().load(return_as="pl")
.filter(pl.col("user_id").cast(pl.Int64) < 5)
.with_columns(pl.col("check_in_time").dt.strftime("%Y%m%d").alias("ds"))
.group_by("user_id", "ds")
.agg(pl.col("loc_id").sort_by(pl.col("check_in_time")).alias("loc_seq"))
.filter(pl.col("loc_seq").list.len() >= 5)
.select(pl.col("user_id").alias("id"), "loc_seq")
)
print(df.shape)
df.head()
load dataset: gowalla
(56, 2)
| id | loc_seq |
|---|---|
| str | list[str] |
| "4" | ["992976", "57155", … "57155"] |
| "2" | ["59838", "2390585", … "2371113"] |
| "2" | ["1275325", "1295851", … "1431584"] |
| "2" | ["617959", "43088", … "1077605"] |
| "2" | ["1807417", "187436", … "363837"] |
# 将user id存储到entity_id字段里面
loc_seqs = [LocSeq(seq=loc_seq, entity_id=user_id) for user_id, loc_seq in df.iter_rows()]
# 构建tokenizer
tokenizer = LocSeqTokenizer.build(loc_seqs)
# 构建一个BaseArrowDataset
train_ds = LocSeqDataset.init_from_loc_seqs(loc_seqs)
# 可以看到有56条序列
train_ds
LocSeqDataset(size=56)
def collate_function(batch: LocSeqDataset, tokenizer: LocSeqTokenizer):
"""
将LocSeqDataset转换为TULER需要的样本
- 序列(含padding)
- 序列长度
- 标签(用户id)
"""
seqs: List[torch.LongTensor] = []
lengths: List[int] = []
labels: List[int] = []
# 取出我们需要的两列
seq_col = batch.seq
entity_id_col = batch.entity_id
# 按行遍历batch
for line_idx in range(len(batch)):
# 将位置序列使用tokenizer编码,以torch.LongTenso的类型返回
seqs.append(tokenizer.tokenize_loc_seq(seq_col[line_idx], return_as="pt"))
# 记录序列的长度
lengths.append(seqs[-1].shape[0])
# 添加标签
labels.append(int(entity_id_col[line_idx].as_py()))
# 返回三元组
return (
# 对序列添加padding,padding的值就用tokenizer维护的.pad属性
pad_sequence(seqs, batch_first=True, padding_value=tokenizer.pad),
lengths,
torch.LongTensor(labels),
)
# 定义DataLoader
train_loader = DataLoader(train_ds, batch_size=2, collate_fn=lambda x: collate_function(x, tokenizer))
# 读取一个batch看看
iterator = iter(train_loader)
next(iterator)
(tensor([[ 2, 1, 154, 53, 1, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
1966],
[ 0, 155, 156, 157, 158, 159, 160, 161, 162, 163, 6, 164,
165, 166, 167, 168, 169, 3, 54, 55, 56, 57, 58, 170,
171, 172, 59, 60, 61, 173, 62, 13, 174, 14, 175, 15,
176, 177, 178, 179, 180, 181, 182, 183, 184, 63, 185, 186,
187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198,
199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210,
211, 212, 213, 214, 215, 216, 217, 218, 64, 219, 65, 66,
220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231,
232, 233, 16, 234, 235, 0, 236, 237, 238, 239, 67, 240,
241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 68, 251,
252, 253, 69, 254, 255, 256, 257, 70, 258, 259, 260, 261,
262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273,
274, 275, 276, 71, 277, 278, 279, 280, 281, 282, 283, 284,
285]]),
[5, 157],
tensor([4, 2]))
当然,我们还可以封装一下上面的三元组:
from typing import List, Optional
from dataclasses import dataclass
@dataclass
class TULERSampleExample:
seqs: torch.LongTensor
lengths: List[int]
labels: Optional[torch.LongTensor] = None
@property
def batch_size(self) -> int:
return len(self.lengths)
def collate_function_v2(batch: LocSeqDataset, tokenizer: LocSeqTokenizer):
"""
将collate_function返回值的三元组封装成一个TULERSampleExample
"""
seqs, lengths, labels = collate_function(batch, tokenizer)
# 返回TULERSampleExample
return TULERSampleExample(
# 对序列添加padding,padding的值就用tokenizer维护的.pad属性
seqs=seqs,
lengths=lengths,
labels=labels,
)
# 定义DataLoader
train_loader = DataLoader(train_ds, batch_size=2, collate_fn=lambda x: collate_function_v2(x, tokenizer))
# 读取一个batch看看
iterator = iter(train_loader)
batch = next(iterator)
batch
TULERSampleExample(seqs=tensor([[ 2, 1, 154, 53, 1, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
1966],
[ 0, 155, 156, 157, 158, 159, 160, 161, 162, 163, 6, 164,
165, 166, 167, 168, 169, 3, 54, 55, 56, 57, 58, 170,
171, 172, 59, 60, 61, 173, 62, 13, 174, 14, 175, 15,
176, 177, 178, 179, 180, 181, 182, 183, 184, 63, 185, 186,
187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198,
199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210,
211, 212, 213, 214, 215, 216, 217, 218, 64, 219, 65, 66,
220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231,
232, 233, 16, 234, 235, 0, 236, 237, 238, 239, 67, 240,
241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 68, 251,
252, 253, 69, 254, 255, 256, 257, 70, 258, 259, 260, 261,
262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273,
274, 275, 276, 71, 277, 278, 279, 280, 281, 282, 283, 284,
285]]), lengths=[5, 157], labels=tensor([4, 2]))
batch.batch_size
2
type(batch.seqs)
torch.Tensor
type(batch.lengths)
list
type(batch.labels)
torch.Tensor
Tip
TrajDL内支持的算法目前都已经提供了像上面TULERSampleExample一样的样本类,放在trajdl.common.samples目录下。
有一些算法需要的输入很多,比如5、6个参数,使用上述这样的样本类管理会比较清晰。