BaseArrowDataSet


Attention

The content of this document may change before the 1.0.0 release.

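BaseArrowDataset is the main dataset type TrajDL uses for model training. All of its underlying data is stored in a pyarrow.Table, which means BaseArrowDataset can interoperate with any tool that supports Arrow. In the Python scientific computing ecosystem, the frameworks most closely tied to Arrow include Polars and Pandas, as well as Spark. Through their Arrow support, these frameworks can exchange data with BaseArrowDataset directly.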

1.  init_from_table

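BaseArrowDataset provides an init_from_table method that loads data directly from a polars.DataFrame, a pandas.DataFrame, or a pyarrow.Table into the dataset. To use it, you need to know the dataset's schema, i.e. the column names and Arrow types it expects.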

We use LocSeqDataset as an example. First, the .schema() method returns the schema of its underlying pyarrow.Table.

from trajdl.datasets import LocSeqDataset
LocSeqDataset.schema()
seq: large_list<item: large_string>
  child 0, item: large_string
entity_id: large_string
ts_seq: large_list<item: int64>
  child 0, item: int64
ts_delta: large_list<item: float>
  child 0, item: float
dis_delta: large_list<item: float>
  child 0, item: float
start_ts: int64

The schema has six columns, each with a predefined type. Given a pyarrow.Table whose columns follow this schema, we can build a LocSeqDataset from it directly. Here is a demonstration using Polars.

import polars as pl

# Build a DataFrame with only 2 columns
df = pl.DataFrame({
    "seq": [
        ["a", "b", "c"],
        ["b"],
        ["c", "d"],
    ],
    "entity_id": ["1", "2", "3"],
})
df.head()
shape: (3, 2)
seq             | entity_id
list[str]       | str
["a", "b", "c"] | "1"
["b"]           | "2"
["c", "d"]      | "3"
# Convert to a pyarrow.Table to inspect it
df.to_arrow()
pyarrow.Table
seq: large_list<item: large_string>
  child 0, item: large_string
entity_id: large_string
----
seq: [[["a","b","c"],["b"],["c","d"]]]
entity_id: [["1","2","3"]]
# init_from_arrow loads the Arrow table directly
ds = LocSeqDataset.init_from_arrow(df.to_arrow())
ds
/home/chaosong/Documents/coder_projects/TrajDL/venv/lib/python3.12/site-packages/trajdl/datasets/arrow/abstract.py:190: RuntimeWarning: Field pyarrow.Field<ts_seq: large_list<item: int64>> does not exist in the input table.
  warnings.warn(
/home/chaosong/Documents/coder_projects/TrajDL/venv/lib/python3.12/site-packages/trajdl/datasets/arrow/abstract.py:190: RuntimeWarning: Field pyarrow.Field<ts_delta: large_list<item: float>> does not exist in the input table.
  warnings.warn(
/home/chaosong/Documents/coder_projects/TrajDL/venv/lib/python3.12/site-packages/trajdl/datasets/arrow/abstract.py:190: RuntimeWarning: Field pyarrow.Field<dis_delta: large_list<item: float>> does not exist in the input table.
  warnings.warn(
/home/chaosong/Documents/coder_projects/TrajDL/venv/lib/python3.12/site-packages/trajdl/datasets/arrow/abstract.py:190: RuntimeWarning: Field pyarrow.Field<start_ts: int64> does not exist in the input table.
  warnings.warn(
LocSeqDataset(size=3)
ds.schema()
seq: large_list<item: large_string>
  child 0, item: large_string
entity_id: large_string
ts_seq: large_list<item: int64>
  child 0, item: int64
ts_delta: large_list<item: float>
  child 0, item: float
dis_delta: large_list<item: float>
  child 0, item: float
start_ts: int64
ds.seq
<pyarrow.lib.ChunkedArray object at 0x7f207ce19c00>
[
  [
    [
      "a",
      "b",
      "c"
    ],
    [
      "b"
    ],
    [
      "c",
      "d"
    ]
  ]
]
ds.entity_id
<pyarrow.lib.ChunkedArray object at 0x7f20d5244be0>
[
  [
    "1",
    "2",
    "3"
  ]
]

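Here we used Polars to build a small polars.DataFrame containing only 3 sequences and 2 columns, seq and entity_id, which match the corresponding columns defined by LocSeqDataset. The Arrow table underlying the Polars DataFrame can be imported into LocSeqDataset directly: each of the four schema columns missing from the input (ts_seq, ts_delta, dis_delta, start_ts) only triggers a RuntimeWarning, and ds.schema() still reports all six columns. Because the data stays in Arrow format, no copy of the underlying data is made, so the conversion is very fast.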

Likewise, the conversion works with other tools such as Pandas and PyArrow.

df.to_pandas()
         seq entity_id
0  [a, b, c]         1
1        [b]         2
2     [c, d]         3
import pyarrow as pa

# Load from pandas
arrow_table = pa.Table.from_pandas(df.to_pandas())
LocSeqDataset.init_from_arrow(arrow_table)
/home/chaosong/Documents/coder_projects/TrajDL/venv/lib/python3.12/site-packages/trajdl/datasets/arrow/abstract.py:190: RuntimeWarning: Field pyarrow.Field<ts_seq: large_list<item: int64>> does not exist in the input table.
  warnings.warn(
/home/chaosong/Documents/coder_projects/TrajDL/venv/lib/python3.12/site-packages/trajdl/datasets/arrow/abstract.py:190: RuntimeWarning: Field pyarrow.Field<ts_delta: large_list<item: float>> does not exist in the input table.
  warnings.warn(
/home/chaosong/Documents/coder_projects/TrajDL/venv/lib/python3.12/site-packages/trajdl/datasets/arrow/abstract.py:190: RuntimeWarning: Field pyarrow.Field<dis_delta: large_list<item: float>> does not exist in the input table.
  warnings.warn(
/home/chaosong/Documents/coder_projects/TrajDL/venv/lib/python3.12/site-packages/trajdl/datasets/arrow/abstract.py:190: RuntimeWarning: Field pyarrow.Field<start_ts: int64> does not exist in the input table.
  warnings.warn(
LocSeqDataset(size=3)

For convenience, TrajDL wraps the conversions above behind a single method, init_from_table, which accepts any of these table types:

help(LocSeqDataset.init_from_table)
Help on method init_from_table in module trajdl.datasets.arrow.abstract:

init_from_table(table: Union[pyarrow.lib.Table, polars.dataframe.frame.DataFrame, pandas.core.frame.DataFrame]) -> 'BaseArrowDataset' class method of trajdl.datasets.arrow.locseq.LocSeqDataset
    Initialize the dataset from a Polars, Pandas, or Arrow table.

    Parameters
    ----------
    table : Union[pa.Table, pl.DataFrame, pd.DataFrame]
        The input table to initialize the dataset.

    Returns
    -------
    BaseArrowDataset
        An instance of the dataset initialized from the provided table.

    Raises
    ------
    ValueError
        If the input is not one of the accepted table types.
LocSeqDataset.init_from_table(df)
LocSeqDataset(size=3)
LocSeqDataset.init_from_table(df.to_pandas())
LocSeqDataset(size=3)
LocSeqDataset.init_from_table(df.to_arrow())
LocSeqDataset(size=3)

Tip

TrajDL recommends Polars: it performs well on a single machine, offers a rich set of data types, and interoperates with Arrow in a clear, predictable way.

Next, let's try the same approach with TrajectoryDataset.

from trajdl.datasets.open_source.conf import PortoDataset

# Take only the first two rows for the demo
df = PortoDataset().load().head(2)
df.head()
load dataset: porto
shape: (2, 9)
TRIP_ID | CALL_TYPE | ORIGIN_CALL | ORIGIN_STAND | TAXI_ID | TIMESTAMP | DAY_TYPE | MISSING_DATA | POLYLINE
str | str | i64 | i64 | i64 | i64 | str | bool | list[array[f64, 2]]
"1372636858620000589" | "C" | null | null | 20000589 | 1372636858 | "A" | false | [[-8.618643, 41.141412], [-8.618499, 41.141376], … [-8.630838, 41.154489]]
"1372637303620000596" | "B" | null | 7 | 20000596 | 1372637303 | "A" | false | [[-8.639847, 41.159826], [-8.640351, 41.159871], … [-8.66574, 41.170671]]
from trajdl.datasets import TrajectoryDataset

TrajectoryDataset.schema()
seq: large_list<item: fixed_size_list<item: double>[2]>
  child 0, item: fixed_size_list<item: double>[2]
      child 0, item: double
entity_id: large_string
ts_seq: large_list<item: int64>
  child 0, item: int64
ts_delta: large_list<item: float>
  child 0, item: float
dis_delta: large_list<item: float>
  child 0, item: float
start_ts: int64
import polars as pl

# Extract the trajectories and use TAXI_ID as entity_id
new_df = df.select(pl.col("POLYLINE").alias("seq"), pl.col("TAXI_ID").cast(pl.String).alias("entity_id"))
new_df.head()
shape: (2, 2)
seq | entity_id
list[array[f64, 2]] | str
[[-8.618643, 41.141412], [-8.618499, 41.141376], … [-8.630838, 41.154489]] | "20000589"
[[-8.639847, 41.159826], [-8.640351, 41.159871], … [-8.66574, 41.170671]] | "20000596"
ds = TrajectoryDataset.init_from_table(new_df)
ds
/home/chaosong/Documents/coder_projects/TrajDL/venv/lib/python3.12/site-packages/trajdl/datasets/arrow/abstract.py:190: RuntimeWarning: Field pyarrow.Field<ts_seq: large_list<item: int64>> does not exist in the input table.
  warnings.warn(
/home/chaosong/Documents/coder_projects/TrajDL/venv/lib/python3.12/site-packages/trajdl/datasets/arrow/abstract.py:190: RuntimeWarning: Field pyarrow.Field<ts_delta: large_list<item: float>> does not exist in the input table.
  warnings.warn(
/home/chaosong/Documents/coder_projects/TrajDL/venv/lib/python3.12/site-packages/trajdl/datasets/arrow/abstract.py:190: RuntimeWarning: Field pyarrow.Field<dis_delta: large_list<item: float>> does not exist in the input table.
  warnings.warn(
/home/chaosong/Documents/coder_projects/TrajDL/venv/lib/python3.12/site-packages/trajdl/datasets/arrow/abstract.py:190: RuntimeWarning: Field pyarrow.Field<start_ts: int64> does not exist in the input table.
  warnings.warn(
TrajectoryDataset(size=2)
ds.seq
<pyarrow.lib.ChunkedArray object at 0x7f207ce961a0>
[
  [
    [
      [
        -8.618643,
        41.141412
      ],
      [
        -8.618499,
        41.141376
      ],
      ...
      [
        -8.630829,
        41.154498
      ],
      [
        -8.630838,
        41.154489
      ]
    ],
    [
      [
        -8.639847,
        41.159826
      ],
      [
        -8.640351,
        41.159871
      ],
      ...
      [
        -8.665767,
        41.170635
      ],
      [
        -8.66574,
        41.170671
      ]
    ]
  ]
]
ds.entity_id
<pyarrow.lib.ChunkedArray object at 0x7f206fb91000>
[
  [
    "20000589",
    "20000596"
  ]
]

Tip

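This page showed how BaseArrowDataset interoperates with Arrow-capable frameworks. Instead of constructing individual BaseSeq sequences one by one, the dataset is loaded directly from the framework's underlying Arrow data, which avoids data copies and greatly speeds up dataset construction. In addition, Arrow-backed datasets can significantly reduce memory usage when PyTorch's DataLoader runs with multiple worker processes.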