Collate Function

Collate Function#

Attention

在1.0.0版本发布前,当前文档的内容可能会发生变化。

因为TrajDL是基于Pytorch Lightning构建的,本质依赖Pytorch的基础工具。因此在模型训练的时候,需要用户定义DatasetDataLoader,有些时候还需要定义Sampler。我们在前面章节里面讲述的批数据,也就是BaseArrowDataset,这个就是Dataset的子类。用户构建好后需要了解的是如何在DataLoader里面操作它。

TrajDL当前提供的BaseArrowDataset都是May-style Dataset。这种Dataset可以直接通过下标进行数据的读取。因此这种数据集一般也不会太大,太大会导致内存放不下。

Tip

目前TrajDL支持的公开数据集都可以使用Map-style Dataset实现,因此未来TrajDL会尝试支持Iterable-style datasets,这在数据集非常大的场景非常必要。

接下来我们说一下DataLoaderDataset的关系:

  1. 在不指定Sampler的情况下,DataLoader在生成一个batch的时候会先生成一组index,然后用这一组index去Dataset里面获取数据。举个例子,当我们将batch_size设定为3的时候,且shuffle设定为False的时候,第一个batch的index是[0, 1, 2],所以DataLoader内部会使用这些下标去Dataset里面取数。然后如果DataLoader在发现取出来的数据恰好可以组成一个矩阵的时候,DataLoader会把它们拼接在一起,形成一个batch的torch.Tensor

  2. 但是在TrajDL的场景里面,我们一般加载的是原始的位置序列或者轨迹序列,这些序列的长度并不相同,而且我们可能还需要做一些额外的操作,比如裁剪或扰动,这时候DataLoader里面有一个叫做collate_function的概念就派上了用场。collate_function会在使用下标[0, 1, 2]Dataset取数之后发生调用,然后在最终的batch生成之前结束调用。而collate_function的返回值就是用户通过DataLoader拿到的batch数据。

基于上述信息,我们可以发现,TrajDL已经实现好了Dataset,而DataLoader里面最核心的工作就是collate_function,因此很多时候使用TrajDL进行新的算法开发的时候,一个重要的工作就是如何从BaseArrowDataset里面生成batch数据,即如何编写collate_function

Note

对于Map-style DatasetPytorch是如何通过给定下标[0, 1, 2]然后进行数据加载的呢?

参照官方文档:torch.utils.data.Dataset,文档提到所有Dataset的子类需要实现__getitem__方法,这个方法是Python自带的魔术方法,实现后就可以使用[idx]进行数据索引。然而文档内还提到一个__getitems__方法,这个并不是Python自身的魔术方法,而是Pytorch定义的一个方法,而且并不是必须要实现,这个方法的作用是加速批量数据的加载

我们可以找到DataLoader的源码,查看DataLoader是如何针对给定的indices进行数据加载的:fetch.py,从源码里面可以看到,Pytorch会优先使用Dataset__getitems__方法进行数据加载,其次才是使用__getitem__方法,取完之后的数据调用了collate_fn方法。

这里有一个细节:

  • 如果使用的是__getitems__,那么传给collate_fn方法的参数是self.dataset.__getitems__(possibly_batched_index)

  • 如果使用的是__getitem__,那么传给collate_fn方法的参数是[self.dataset[idx] for idx in possibly_batched_index]

TrajDL前面的文档可以知道,BaseArrowDataset在通过[idx]取值的时候,返回的仍然是BaseArrowDataset,即__getitem__方法的返回值类型是BaseArrowDataset。同时,BaseArrowDataset也实现了__getitems__方法,其返回值类型仍然是BaseArrowDataset。这两者的区别在于size不同。因此对于TrajDL提供的BaseArrowDataset,用户在使用DataLoader加载的时候会默认使用__getitems__方法,所以传入collate_fn的参数一定是一个BaseArrowDataset,而不是List[BaseArrowDataset]

接下来我们以TULER算法为例,演示如何编写一个DataLoadercollate_function

TULER模型在训练的时候,需要三项内容:

  1. 一个位置序列组成的batch,类型是torch.LongTensor,shape是(batch_size, num_timesteps)

  2. 一个List[int],表示每条序列的长度,其size是batch_size

  3. 一个标签的batch,类型是torch.LongTensor,shape是(batch_size,),每一项对应每条序列的实际标签。

这实际上是一个三元组。

那么我们要定义的collate_function就很简单了,只要接受一个batch的输入,返回一个三元组即可。

由于BaseArrowDataset实现了__getitems__方法,且返回一个BaseArrowDataset类型,这个batch的输入就是一个BaseArrowDataset,其size等于batch_size

import polars as pl
from tqdm.notebook import tqdm
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

from trajdl.datasets.open_source import GowallaDataset
from trajdl.datasets import LocSeq, LocSeqDataset
from trajdl.tokenizers import LocSeqTokenizer

# 取id小于5的用户的数据
df = (
    GowallaDataset().load(return_as="pl")
    .filter(pl.col("user_id").cast(pl.Int64) < 5)
    .with_columns(pl.col("check_in_time").dt.strftime("%Y%m%d").alias("ds"))
    .group_by("user_id", "ds")
    .agg(pl.col("loc_id").sort_by(pl.col("check_in_time")).alias("loc_seq"))
    .filter(pl.col("loc_seq").list.len() >= 5)
    .select(pl.col("user_id").alias("id"), "loc_seq")
)
print(df.shape)
df.head()
load dataset: gowalla
(56, 2)
shape: (5, 2)
idloc_seq
strlist[str]
"4"["992976", "57155", … "57155"]
"2"["59838", "2390585", … "2371113"]
"2"["1275325", "1295851", … "1431584"]
"2"["617959", "43088", … "1077605"]
"2"["1807417", "187436", … "363837"]
# 将user id存储到entity_id字段里面
loc_seqs = [LocSeq(seq=loc_seq, entity_id=user_id) for user_id, loc_seq in df.iter_rows()]

# 构建tokenizer
tokenizer = LocSeqTokenizer.build(loc_seqs)

# 构建一个BaseArrowDataset
train_ds = LocSeqDataset.init_from_loc_seqs(loc_seqs)

# 可以看到有56条序列
train_ds
LocSeqDataset(size=56)
def collate_function(batch: LocSeqDataset, tokenizer: LocSeqTokenizer):
    """
    将LocSeqDataset转换为TULER需要的样本
    - 序列(含padding)
    - 序列长度
    - 标签(用户id)
    """
    seqs: List[torch.LongTensor] = []
    lengths: List[int] = []
    labels: List[int] = []

    # 取出我们需要的两列
    seq_col = batch.seq
    entity_id_col = batch.entity_id

    # 按行遍历batch
    for line_idx in range(len(batch)):
        # 将位置序列使用tokenizer编码,以torch.LongTenso的类型返回
        seqs.append(tokenizer.tokenize_loc_seq(seq_col[line_idx], return_as="pt"))
        # 记录序列的长度
        lengths.append(seqs[-1].shape[0])
        # 添加标签
        labels.append(int(entity_id_col[line_idx].as_py()))

    # 返回三元组
    return (
        # 对序列添加padding,padding的值就用tokenizer维护的.pad属性
        pad_sequence(seqs, batch_first=True, padding_value=tokenizer.pad),
        lengths,
        torch.LongTensor(labels),
    )
# 定义DataLoader
train_loader = DataLoader(train_ds, batch_size=2, collate_fn=lambda x: collate_function(x, tokenizer))

# 读取一个batch看看
iterator = iter(train_loader)
next(iterator)
(tensor([[   2,    1,  154,   53,    1, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
          1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
          1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
          1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
          1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
          1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
          1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
          1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
          1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
          1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
          1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
          1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
          1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
          1966],
         [   0,  155,  156,  157,  158,  159,  160,  161,  162,  163,    6,  164,
           165,  166,  167,  168,  169,    3,   54,   55,   56,   57,   58,  170,
           171,  172,   59,   60,   61,  173,   62,   13,  174,   14,  175,   15,
           176,  177,  178,  179,  180,  181,  182,  183,  184,   63,  185,  186,
           187,  188,  189,  190,  191,  192,  193,  194,  195,  196,  197,  198,
           199,  200,  201,  202,  203,  204,  205,  206,  207,  208,  209,  210,
           211,  212,  213,  214,  215,  216,  217,  218,   64,  219,   65,   66,
           220,  221,  222,  223,  224,  225,  226,  227,  228,  229,  230,  231,
           232,  233,   16,  234,  235,    0,  236,  237,  238,  239,   67,  240,
           241,  242,  243,  244,  245,  246,  247,  248,  249,  250,   68,  251,
           252,  253,   69,  254,  255,  256,  257,   70,  258,  259,  260,  261,
           262,  263,  264,  265,  266,  267,  268,  269,  270,  271,  272,  273,
           274,  275,  276,   71,  277,  278,  279,  280,  281,  282,  283,  284,
           285]]),
 [5, 157],
 tensor([4, 2]))

当然,我们还可以封装一下上面的三元组:

from typing import List, Optional
from dataclasses import dataclass

@dataclass
class TULERSampleExample:
    seqs: torch.LongTensor
    lengths: List[int]
    labels: Optional[torch.LongTensor] = None

    @property
    def batch_size(self) -> int:
        return len(self.lengths)


def collate_function_v2(batch: LocSeqDataset, tokenizer: LocSeqTokenizer):
    """
    将collate_function返回值的三元组封装成一个TULERSampleExample
    """
    seqs, lengths, labels = collate_function(batch, tokenizer)
    
    # 返回TULERSampleExample
    return TULERSampleExample(
        # 对序列添加padding,padding的值就用tokenizer维护的.pad属性
        seqs=seqs,
        lengths=lengths,
        labels=labels,
    )
# 定义DataLoader
train_loader = DataLoader(train_ds, batch_size=2, collate_fn=lambda x: collate_function_v2(x, tokenizer))

# 读取一个batch看看
iterator = iter(train_loader)
batch = next(iterator)
batch
TULERSampleExample(seqs=tensor([[   2,    1,  154,   53,    1, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
         1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
         1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
         1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
         1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
         1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
         1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
         1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
         1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
         1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
         1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
         1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
         1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966, 1966,
         1966],
        [   0,  155,  156,  157,  158,  159,  160,  161,  162,  163,    6,  164,
          165,  166,  167,  168,  169,    3,   54,   55,   56,   57,   58,  170,
          171,  172,   59,   60,   61,  173,   62,   13,  174,   14,  175,   15,
          176,  177,  178,  179,  180,  181,  182,  183,  184,   63,  185,  186,
          187,  188,  189,  190,  191,  192,  193,  194,  195,  196,  197,  198,
          199,  200,  201,  202,  203,  204,  205,  206,  207,  208,  209,  210,
          211,  212,  213,  214,  215,  216,  217,  218,   64,  219,   65,   66,
          220,  221,  222,  223,  224,  225,  226,  227,  228,  229,  230,  231,
          232,  233,   16,  234,  235,    0,  236,  237,  238,  239,   67,  240,
          241,  242,  243,  244,  245,  246,  247,  248,  249,  250,   68,  251,
          252,  253,   69,  254,  255,  256,  257,   70,  258,  259,  260,  261,
          262,  263,  264,  265,  266,  267,  268,  269,  270,  271,  272,  273,
          274,  275,  276,   71,  277,  278,  279,  280,  281,  282,  283,  284,
          285]]), lengths=[5, 157], labels=tensor([4, 2]))
batch.batch_size
2
type(batch.seqs)
torch.Tensor
type(batch.lengths)
list
type(batch.labels)
torch.Tensor

Tip

TrajDL内支持的算法目前都已经提供了像上面TULERSampleExample一样的样本类,放在trajdl.common.samples目录下。

有一些算法需要的输入很多,比如5、6个参数,使用上述这样的样本类管理会比较清晰。