BaseArrowDataset
Attention
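The content of this document may change before version 1.0.0 is released.
BaseArrowDataset is the main dataset class TrajDL uses for model training. All of its underlying data is stored in pyarrow.Table objects, so BaseArrowDataset can interoperate with any tool that supports Arrow. In the Python scientific computing ecosystem, frameworks closely tied to Arrow include Polars, Pandas, and Spark. Through their Arrow support, these frameworks can exchange data with BaseArrowDataset directly.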
1. init_from_table
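BaseArrowDataset provides an init_from_table method that loads data directly from a polars.DataFrame, pandas.DataFrame, or pyarrow.Table into the dataset. Using it requires knowing the dataset's schema.
We use LocSeqDataset as an example. First, the schema of its underlying pyarrow.Table can be retrieved with the .schema() method.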
from trajdl.datasets import LocSeqDataset
LocSeqDataset.schema()
seq: large_list<item: large_string>
child 0, item: large_string
entity_id: large_string
ts_seq: large_list<item: int64>
child 0, item: int64
ts_delta: large_list<item: float>
child 0, item: float
dis_delta: large_list<item: float>
child 0, item: float
start_ts: int64
There are six columns, each with a predefined type. Given a pyarrow.Table whose columns follow this schema, we can build a LocSeqDataset from it directly. The following demonstration uses Polars.
import polars as pl
# Build a DataFrame with only 2 columns
df = pl.DataFrame({
    "seq": [
        ["a", "b", "c"],
        ["b"],
        ["c", "d"],
    ],
    "entity_id": ["1", "2", "3"],
})
df.head()
| seq | entity_id |
|---|---|
| list[str] | str |
| ["a", "b", "c"] | "1" |
| ["b"] | "2" |
| ["c", "d"] | "3" |
# Convert to a pyarrow.Table to inspect it
df.to_arrow()
pyarrow.Table
seq: large_list<item: large_string>
child 0, item: large_string
entity_id: large_string
----
seq: [[["a","b","c"],["b"],["c","d"]]]
entity_id: [["1","2","3"]]
# init_from_arrow loads the Arrow table directly
ds = LocSeqDataset.init_from_arrow(df.to_arrow())
ds
/home/chaosong/Documents/coder_projects/TrajDL/venv/lib/python3.12/site-packages/trajdl/datasets/arrow/abstract.py:190: RuntimeWarning: Field pyarrow.Field<ts_seq: large_list<item: int64>> does not exist in the input table.
warnings.warn(
/home/chaosong/Documents/coder_projects/TrajDL/venv/lib/python3.12/site-packages/trajdl/datasets/arrow/abstract.py:190: RuntimeWarning: Field pyarrow.Field<ts_delta: large_list<item: float>> does not exist in the input table.
warnings.warn(
/home/chaosong/Documents/coder_projects/TrajDL/venv/lib/python3.12/site-packages/trajdl/datasets/arrow/abstract.py:190: RuntimeWarning: Field pyarrow.Field<dis_delta: large_list<item: float>> does not exist in the input table.
warnings.warn(
/home/chaosong/Documents/coder_projects/TrajDL/venv/lib/python3.12/site-packages/trajdl/datasets/arrow/abstract.py:190: RuntimeWarning: Field pyarrow.Field<start_ts: int64> does not exist in the input table.
warnings.warn(
LocSeqDataset(size=3)
ds.schema()
seq: large_list<item: large_string>
child 0, item: large_string
entity_id: large_string
ts_seq: large_list<item: int64>
child 0, item: int64
ts_delta: large_list<item: float>
child 0, item: float
dis_delta: large_list<item: float>
child 0, item: float
start_ts: int64
ds.seq
<pyarrow.lib.ChunkedArray object at 0x7f207ce19c00>
[
[
[
"a",
"b",
"c"
],
[
"b"
],
[
"c",
"d"
]
]
]
ds.entity_id
<pyarrow.lib.ChunkedArray object at 0x7f20d5244be0>
[
[
"1",
"2",
"3"
]
]
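Here we built a simple polars.DataFrame containing only 3 sequences and 2 columns, seq and entity_id, whose names and types match the LocSeqDataset schema. The Arrow table underlying the Polars DataFrame can therefore be imported into LocSeqDataset directly. Each of the four schema columns the input lacks (ts_seq, ts_delta, dis_delta and start_ts) triggers a RuntimeWarning, but the dataset is still constructed. Because both sides use Arrow, the underlying data is not copied, so the conversion is very fast.
The same kind of conversion works with Pandas, PyArrow, and other tools.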
df.to_pandas()
|   | seq | entity_id |
|---|---|---|
| 0 | [a, b, c] | 1 |
| 1 | [b] | 2 |
| 2 | [c, d] | 3 |
import pyarrow as pa
# Load from pandas
arrow_table = pa.Table.from_pandas(df.to_pandas())
LocSeqDataset.init_from_arrow(arrow_table)
/home/chaosong/Documents/coder_projects/TrajDL/venv/lib/python3.12/site-packages/trajdl/datasets/arrow/abstract.py:190: RuntimeWarning: Field pyarrow.Field<ts_seq: large_list<item: int64>> does not exist in the input table.
warnings.warn(
/home/chaosong/Documents/coder_projects/TrajDL/venv/lib/python3.12/site-packages/trajdl/datasets/arrow/abstract.py:190: RuntimeWarning: Field pyarrow.Field<ts_delta: large_list<item: float>> does not exist in the input table.
warnings.warn(
/home/chaosong/Documents/coder_projects/TrajDL/venv/lib/python3.12/site-packages/trajdl/datasets/arrow/abstract.py:190: RuntimeWarning: Field pyarrow.Field<dis_delta: large_list<item: float>> does not exist in the input table.
warnings.warn(
/home/chaosong/Documents/coder_projects/TrajDL/venv/lib/python3.12/site-packages/trajdl/datasets/arrow/abstract.py:190: RuntimeWarning: Field pyarrow.Field<start_ts: int64> does not exist in the input table.
warnings.warn(
LocSeqDataset(size=3)
For convenience, TrajDL wraps the conversions above into a single method, init_from_table, which accepts any of these table types.
help(LocSeqDataset.init_from_table)
Help on method init_from_table in module trajdl.datasets.arrow.abstract:
init_from_table(table: Union[pyarrow.lib.Table, polars.dataframe.frame.DataFrame, pandas.core.frame.DataFrame]) -> 'BaseArrowDataset' class method of trajdl.datasets.arrow.locseq.LocSeqDataset
Initialize the dataset from a Polars, Pandas, or Arrow table.
Parameters
----------
table : Union[pa.Table, pl.DataFrame, pd.DataFrame]
The input table to initialize the dataset.
Returns
-------
BaseArrowDataset
An instance of the dataset initialized from the provided table.
Raises
------
ValueError
If the input is not one of the accepted table types.
LocSeqDataset.init_from_table(df)
LocSeqDataset(size=3)
LocSeqDataset.init_from_table(df.to_pandas())
LocSeqDataset(size=3)
LocSeqDataset.init_from_table(df.to_arrow())
LocSeqDataset(size=3)
Tip
TrajDL recommends Polars: it performs well on a single machine, supports a wide range of data types, and interoperates cleanly with Arrow.
Next, let's try the same approach with TrajectoryDataset.
from trajdl.datasets.open_source.conf import PortoDataset
# Keep only the first two records for the demo
df = PortoDataset().load().head(2)
df.head()
load dataset: porto
| TRIP_ID | CALL_TYPE | ORIGIN_CALL | ORIGIN_STAND | TAXI_ID | TIMESTAMP | DAY_TYPE | MISSING_DATA | POLYLINE |
|---|---|---|---|---|---|---|---|---|
| str | str | i64 | i64 | i64 | i64 | str | bool | list[array[f64, 2]] |
| "1372636858620000589" | "C" | null | null | 20000589 | 1372636858 | "A" | false | [[-8.618643, 41.141412], [-8.618499, 41.141376], … [-8.630838, 41.154489]] |
| "1372637303620000596" | "B" | null | 7 | 20000596 | 1372637303 | "A" | false | [[-8.639847, 41.159826], [-8.640351, 41.159871], … [-8.66574, 41.170671]] |
from trajdl.datasets import TrajectoryDataset
TrajectoryDataset.schema()
seq: large_list<item: fixed_size_list<item: double>[2]>
child 0, item: fixed_size_list<item: double>[2]
child 0, item: double
entity_id: large_string
ts_seq: large_list<item: int64>
child 0, item: int64
ts_delta: large_list<item: float>
child 0, item: float
dis_delta: large_list<item: float>
child 0, item: float
start_ts: int64
import polars as pl
# Take the trajectories and use TAXI_ID as entity_id
new_df = df.select(pl.col("POLYLINE").alias("seq"), pl.col("TAXI_ID").cast(pl.String).alias("entity_id"))
new_df.head()
| seq | entity_id |
|---|---|
| list[array[f64, 2]] | str |
| [[-8.618643, 41.141412], [-8.618499, 41.141376], … [-8.630838, 41.154489]] | "20000589" |
| [[-8.639847, 41.159826], [-8.640351, 41.159871], … [-8.66574, 41.170671]] | "20000596" |
ds = TrajectoryDataset.init_from_table(new_df)
ds
/home/chaosong/Documents/coder_projects/TrajDL/venv/lib/python3.12/site-packages/trajdl/datasets/arrow/abstract.py:190: RuntimeWarning: Field pyarrow.Field<ts_seq: large_list<item: int64>> does not exist in the input table.
warnings.warn(
/home/chaosong/Documents/coder_projects/TrajDL/venv/lib/python3.12/site-packages/trajdl/datasets/arrow/abstract.py:190: RuntimeWarning: Field pyarrow.Field<ts_delta: large_list<item: float>> does not exist in the input table.
warnings.warn(
/home/chaosong/Documents/coder_projects/TrajDL/venv/lib/python3.12/site-packages/trajdl/datasets/arrow/abstract.py:190: RuntimeWarning: Field pyarrow.Field<dis_delta: large_list<item: float>> does not exist in the input table.
warnings.warn(
/home/chaosong/Documents/coder_projects/TrajDL/venv/lib/python3.12/site-packages/trajdl/datasets/arrow/abstract.py:190: RuntimeWarning: Field pyarrow.Field<start_ts: int64> does not exist in the input table.
warnings.warn(
TrajectoryDataset(size=2)
ds.seq
<pyarrow.lib.ChunkedArray object at 0x7f207ce961a0>
[
[
[
[
-8.618643,
41.141412
],
[
-8.618499,
41.141376
],
...
[
-8.630829,
41.154498
],
[
-8.630838,
41.154489
]
],
[
[
-8.639847,
41.159826
],
[
-8.640351,
41.159871
],
...
[
-8.665767,
41.170635
],
[
-8.66574,
41.170671
]
]
]
]
ds.entity_id
<pyarrow.lib.ChunkedArray object at 0x7f206fb91000>
[
[
"20000589",
"20000596"
]
]
Tip
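This article showed how BaseArrowDataset interacts with Arrow-enabled frameworks. This approach bypasses per-sequence BaseSeq objects and loads data directly from the underlying memory of the scientific computing framework, which avoids data copies and greatly speeds up dataset construction. In addition, Arrow-based datasets can significantly reduce memory usage when used with a multi-process PyTorch DataLoader.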