ST-LSTM
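This section shows how to reproduce the ST-LSTM algorithm with TrajDL and run its experiment on the Gowalla dataset. It covers:
Preprocessing the Gowalla dataset
Training and running inference with the ST-LSTM model in TrajDL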
Note
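The ST-LSTM authors did not release source code, so TrajDL reimplements the model and the experiment in PyTorch based on the paper's description. TrajDL implements the complete ST-LSTM experiment pipeline; HST-LSTM is planned for a later release.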
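1. Preprocessing the Gowalla dataset
As described in Open Source Dataset, Gowalla is a check-in dataset collected from a social network. Each record in the raw dataset contains a location ID together with the latitude and longitude of that location. Quick Start already showed how to load the Gowalla dataset through the TrajDL API and convert it into Trajectory and LocSeq objects, so those preprocessing steps are not repeated here.
ST-LSTM additionally needs the time interval and the distance interval between each pair of adjacent locations in a trajectory. The code below walks through this preprocessing.
First, load the Gowalla dataset and extract the check-in sequences: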
from trajdl.datasets.open_source import GowallaDataset
import polars as pl

original_df = (
    GowallaDataset()
    .load(return_as="pl")
    .sort(["user_id", "check_in_time"])
    # ds is the calendar day of the check-in
    .with_columns(ds=pl.col("check_in_time").dt.strftime("%Y%m%d"))
    # tmp_id is the 0-based position of the check-in within its (user, day) group
    .with_columns(tmp_id=pl.int_range(pl.len()).over("user_id", "ds"))
    # keep only the first 50000 rows
    .limit(50000)
)
original_df.head(7)
load dataset: gowalla
| user_id | check_in_time | lat | lng | loc_id | ds | tmp_id |
|---|---|---|---|---|---|---|
| str | datetime[μs] | f64 | f64 | str | str | i64 |
| "0" | 2010-05-22 02:49:04 | 30.248924 | -97.749626 | "608105" | "20100522" | 0 |
| "0" | 2010-05-22 17:50:55 | 39.297443 | -94.716053 | "8977" | "20100522" | 1 |
| "0" | 2010-05-22 19:13:12 | 38.985246 | -94.605919 | "18574" | "20100522" | 2 |
| "0" | 2010-05-23 16:50:50 | 39.093533 | -94.593174 | "17269" | "20100523" | 0 |
| "0" | 2010-05-23 17:51:57 | 39.093258 | -94.593871 | "1161876" | "20100523" | 1 |
| "0" | 2010-05-23 22:40:45 | 39.000447 | -94.577709 | "1163401" | "20100523" | 2 |
| "0" | 2010-05-23 23:13:17 | 39.001931 | -94.579466 | "668208" | "20100523" | 3 |
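Compute the time difference, in seconds, between each pair of adjacent check-ins:
df_with_ts = original_df.with_columns(
    ts_delta=(
        pl.col("check_in_time") - pl.col("check_in_time").shift(1)
    ).dt.total_seconds()
).with_columns(
    # the first check-in of each (user, day) group has no predecessor, so its delta is 0
    ts_delta=pl.when(pl.col("tmp_id") == 0)
    .then(pl.lit(0))
    .otherwise(pl.col("ts_delta"))
)
df_with_ts.head(7)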
| user_id | check_in_time | lat | lng | loc_id | ds | tmp_id | ts_delta |
|---|---|---|---|---|---|---|---|
| str | datetime[μs] | f64 | f64 | str | str | i64 | i64 |
| "0" | 2010-05-22 02:49:04 | 30.248924 | -97.749626 | "608105" | "20100522" | 0 | 0 |
| "0" | 2010-05-22 17:50:55 | 39.297443 | -94.716053 | "8977" | "20100522" | 1 | 54111 |
| "0" | 2010-05-22 19:13:12 | 38.985246 | -94.605919 | "18574" | "20100522" | 2 | 4937 |
| "0" | 2010-05-23 16:50:50 | 39.093533 | -94.593174 | "17269" | "20100523" | 0 | 0 |
| "0" | 2010-05-23 17:51:57 | 39.093258 | -94.593871 | "1161876" | "20100523" | 1 | 3667 |
| "0" | 2010-05-23 22:40:45 | 39.000447 | -94.577709 | "1163401" | "20100523" | 2 | 17328 |
| "0" | 2010-05-23 23:13:17 | 39.001931 | -94.579466 | "668208" | "20100523" | 3 | 1952 |
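As a sanity check on the ts_delta column, the second row of the table can be reproduced with the standard library alone. This is a small sketch that hard-codes two timestamps copied from the table above; it is not part of the TrajDL pipeline.

```python
from datetime import datetime

# Two consecutive check-ins of user "0" on 2010-05-22, copied from the table above.
first = datetime(2010, 5, 22, 2, 49, 4)
second = datetime(2010, 5, 22, 17, 50, 55)

# Same quantity as ts_delta: elapsed whole seconds between adjacent check-ins.
ts_delta = int((second - first).total_seconds())
print(ts_delta)  # 54111, matching the ts_delta value in the second row
```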
Compute the distance between each pair of adjacent check-ins with the haversine formula. The result is in kilometers, using an Earth radius of 6371 km:
import math

df_with_dis = (
    df_with_ts.with_columns(
        # convert degrees to radians
        lat1=pl.col("lat") * math.pi / 180, lng1=pl.col("lng") * math.pi / 180
    )
    # coordinates of the previous check-in
    .with_columns(lat2=pl.col("lat1").shift(1), lng2=pl.col("lng1").shift(1))
    .with_columns(
        dlat=(pl.col("lat2") - pl.col("lat1")),
        dlng=(pl.col("lng2") - pl.col("lng1")),
    )
    .with_columns(
        a=((pl.col("dlat") / 2).sin() ** 2)
        + pl.col("lat1").cos()
        * pl.col("lat2").cos()
        * ((pl.col("dlng") / 2).sin() ** 2)
    )
    .with_columns(dis_delta=2 * (pl.col("a") ** 0.5).arcsin() * 6371)
    .with_columns(
        # the first check-in of each (user, day) group has no predecessor, so its distance is 0
        dis_delta=pl.when(pl.col("tmp_id") == 0)
        .then(pl.lit(0))
        .otherwise(pl.col("dis_delta"))
    )
)
# Drop the intermediate columns that are no longer needed
df_processed = df_with_dis.drop("lat1", "lat2", "lng1", "lng2", "dlat", "dlng", "a", "tmp_id")
df_processed.head(7)
| user_id | check_in_time | lat | lng | loc_id | ds | ts_delta | dis_delta |
|---|---|---|---|---|---|---|---|
| str | datetime[μs] | f64 | f64 | str | str | i64 | f64 |
| "0" | 2010-05-22 02:49:04 | 30.248924 | -97.749626 | "608105" | "20100522" | 0 | 0.0 |
| "0" | 2010-05-22 17:50:55 | 39.297443 | -94.716053 | "8977" | "20100522" | 54111 | 1043.413827 |
| "0" | 2010-05-22 19:13:12 | 38.985246 | -94.605919 | "18574" | "20100522" | 4937 | 35.990638 |
| "0" | 2010-05-23 16:50:50 | 39.093533 | -94.593174 | "17269" | "20100523" | 0 | 0.0 |
| "0" | 2010-05-23 17:51:57 | 39.093258 | -94.593871 | "1161876" | "20100523" | 3667 | 0.067434 |
| "0" | 2010-05-23 22:40:45 | 39.000447 | -94.577709 | "1163401" | "20100523" | 17328 | 10.41405 |
| "0" | 2010-05-23 23:13:17 | 39.001931 | -94.579466 | "668208" | "20100523" | 1952 | 0.224238 |
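The Polars expression above is the haversine great-circle formula written column by column. For reference, here is the same computation as a plain Python function. It is a sketch for checking individual values, with haversine_km as a hypothetical helper name. The coordinates are the rounded ones shown in the table, so the result should come out very close to, but not necessarily identical with, the 1043.41 km in the second row.

```python
import math

EARTH_RADIUS_KM = 6371  # same constant as in the Polars expression above


def haversine_km(lat1, lng1, lat2, lng2):
    """Great-circle distance in kilometers between two (lat, lng) points in degrees."""
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = phi2 - phi1
    dlambda = math.radians(lng2) - math.radians(lng1)
    a = (
        math.sin(dphi / 2) ** 2
        + math.cos(phi1) * math.cos(phi2) * math.sin(dlambda / 2) ** 2
    )
    return 2 * EARTH_RADIUS_KM * math.asin(math.sqrt(a))


# First two check-ins of user "0", using the rounded coordinates from the table.
d = haversine_km(30.248924, -97.749626, 39.297443, -94.716053)
print(round(d, 2))
```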
Split the check-ins into visit sessions by day, keeping only sessions with more than 3 and at most 10 check-ins:
sessions = (
    df_processed.group_by("user_id", "ds")
    .agg(
        pl.len().alias("seq_len"),
        pl.col("loc_id").sort_by("check_in_time").alias("loc_seq"),
        pl.col("ts_delta").sort_by("check_in_time").alias("ts_delta"),
        pl.col("dis_delta").sort_by("check_in_time").alias("dis_delta"),
        # session start time as a Unix timestamp in seconds
        (pl.col("check_in_time").min().dt.timestamp() / 1e6).alias("start_ts"),
    )
    .filter(pl.col("seq_len") > 3)
    .filter(pl.col("seq_len") <= 10)
    .select("user_id", "loc_seq", "ts_delta", "dis_delta", "start_ts")
)
sessions.head(5)
| user_id | loc_seq | ts_delta | dis_delta | start_ts |
|---|---|---|---|---|
| str | list[str] | list[i64] | list[f64] | f64 |
| "100800" | ["254901", "166144", … "347439"] | [0, 6808, … 6259] | [0.0, 4.608195, … 7.622445] | 1.2867e9 |
| "100365" | ["748327", "748454", … "748696"] | [0, 697, … 25991] | [0.0, 1.326659, … 2.68612] | 1.2690e9 |
| "100516" | ["112875", "1295130", … "218492"] | [0, 3531, … 149] | [0.0, 0.36733, … 0.32299] | 1.2822e9 |
| "100116" | ["225602", "11102", … "821938"] | [0, 5204, … 18600] | [0.0, 0.2617, … 0.110696] | 1.2710e9 |
| "10001" | ["328495", "115312", … "263225"] | [0, 1278, … 3235] | [0.0, 3.070297, … 1.536069] | 1.2627e9 |
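The grouping rule can also be illustrated without Polars. The sketch below applies the same "one session per user per calendar day, length in (3, 10]" rule to the seven check-ins shown in the first table; split_into_daily_sessions is a hypothetical helper written for this illustration only.

```python
from collections import defaultdict
from datetime import datetime


def split_into_daily_sessions(check_ins, min_len=4, max_len=10):
    """Group (user_id, check_in_time, loc_id) rows by user and calendar day."""
    sessions = defaultdict(list)
    for user_id, ts, loc_id in sorted(check_ins, key=lambda row: (row[0], row[1])):
        sessions[(user_id, ts.strftime("%Y%m%d"))].append(loc_id)
    # Keep sessions with more than 3 and at most 10 check-ins.
    return {key: seq for key, seq in sessions.items() if min_len <= len(seq) <= max_len}


# The seven check-ins of user "0" shown in the first table.
check_ins = [
    ("0", datetime(2010, 5, 22, 2, 49, 4), "608105"),
    ("0", datetime(2010, 5, 22, 17, 50, 55), "8977"),
    ("0", datetime(2010, 5, 22, 19, 13, 12), "18574"),
    ("0", datetime(2010, 5, 23, 16, 50, 50), "17269"),
    ("0", datetime(2010, 5, 23, 17, 51, 57), "1161876"),
    ("0", datetime(2010, 5, 23, 22, 40, 45), "1163401"),
    ("0", datetime(2010, 5, 23, 23, 13, 17), "668208"),
]
daily = split_into_daily_sessions(check_ins)
print(daily)  # {('0', '20100523'): ['17269', '1161876', '1163401', '668208']}
```

The session on 2010-05-22 has only three check-ins, so it is filtered out.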
ST-LSTM computes features per visit session, so we first number each user's sessions in chronological order and count how many sessions each user has:
tmp_df = (
    sessions.sort(["user_id", "start_ts"])
    .with_columns(
        # 0-based index of the session within its user, ordered by start time
        tmp_id=pl.int_range(pl.len()).over("user_id", order_by="start_ts")
    )
    .join(
        sessions.group_by("user_id").agg(pl.len().alias("num_sessions")),
        how="left",
        on=["user_id"],
    )
)
tmp_df.head(5)
| user_id | loc_seq | ts_delta | dis_delta | start_ts | tmp_id | num_sessions |
|---|---|---|---|---|---|---|
| str | list[str] | list[i64] | list[f64] | f64 | i64 | u32 |
| "0" | ["17269", "1161876", … "668208"] | [0, 3667, … 1952] | [0.0, 0.067434, … 0.224238] | 1.2746e9 | 0 | 12 |
| "0" | ["10259", "19542", … "1221889"] | [0, 16334, … 1765] | [0.0, 2352.661212, … 0.333267] | 1.2758e9 | 1 | 12 |
| "0" | ["1251533", "210176", … "1221889"] | [0, 8682, … 10856] | [0.0, 0.35967, … 0.373425] | 1.2761e9 | 2 | 12 |
| "0" | ["232554", "69230", … "23256"] | [0, 5861, … 10867] | [0.0, 0.001444, … 1247.163358] | 1.2797e9 | 3 | 12 |
| "0" | ["23256", "1303715", … "9410"] | [0, 2679, … 10834] | [0.0, 21.831889, … 1247.163358] | 1.2801e9 | 4 | 12 |
Split the sessions into training, validation, and test sets at a ratio of roughly 6:2:2. The split is made per user at visit-session granularity, in chronological order:
# Train : validation : test = 6 : 2 : 2, based on each user's 0-based session index
train_df = tmp_df.filter(pl.col("tmp_id") <= pl.col("num_sessions") * 0.6).select(
    "user_id", "loc_seq", "ts_delta", "dis_delta", "start_ts"
)
val_df = tmp_df.filter(
    (pl.col("tmp_id") > pl.col("num_sessions") * 0.6)
    & (pl.col("tmp_id") <= pl.col("num_sessions") * 0.8)
).select("user_id", "loc_seq", "ts_delta", "dis_delta", "start_ts")
test_df = tmp_df.filter(pl.col("tmp_id") > pl.col("num_sessions") * 0.8).select(
    "user_id", "loc_seq", "ts_delta", "dis_delta", "start_ts"
)
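Because tmp_id is 0-based and the filters above use inclusive bounds (<=), the training share ends up slightly above 60% for some users. The following sketch replays the same comparisons for a single user with plain Python; split_session_ids and its strict-bound variant are hypothetical helpers for illustration, not part of TrajDL.

```python
def split_session_ids(num_sessions, inclusive=True):
    """Assign 0-based session indices to train/val/test using 0.6 and 0.8 cut points."""
    train, val, test = [], [], []
    for idx in range(num_sessions):
        if inclusive:
            # Mirrors the filters above: idx <= 0.6 * n goes to train.
            in_train = idx <= num_sessions * 0.6
            in_val = not in_train and idx <= num_sessions * 0.8
        else:
            # Hypothetical variant with strict bounds.
            in_train = idx < num_sessions * 0.6
            in_val = not in_train and idx < num_sessions * 0.8
        (train if in_train else val if in_val else test).append(idx)
    return train, val, test


print([len(part) for part in split_session_ids(10)])                   # [7, 2, 1]
print([len(part) for part in split_session_ids(10, inclusive=False)])  # [6, 2, 2]
```

For a user with 10 sessions, the inclusive bounds give 7:2:1, while strict bounds would give exactly 6:2:2. Which behavior is preferable depends on the experiment; the outputs on this page were produced with the inclusive version.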
Build the training, validation, and test LocSeqDataset objects:
from tqdm.notebook import tqdm
from trajdl.datasets import LocSeq, LocSeqDataset


def generate_loc_seqs(df):
    # Convert each row (one visit session) into a LocSeq
    return [
        LocSeq(
            seq=loc_seq,
            entity_id=user_id,
            ts_delta=ts_delta,
            dis_delta=dis_delta,
            start_ts=start_ts,
        )
        for user_id, loc_seq, ts_delta, dis_delta, start_ts in tqdm(
            df.iter_rows(), total=df.height
        )
    ]


train_ds, val_ds, test_ds = (
    LocSeqDataset.init_from_loc_seqs(generate_loc_seqs(train_df)),
    LocSeqDataset.init_from_loc_seqs(generate_loc_seqs(val_df)),
    LocSeqDataset.init_from_loc_seqs(generate_loc_seqs(test_df)),
)
Build a LocSeqTokenizer from the training sequences. Every check-in location in Gowalla already has a unique ID, so there is no need to discretize trajectories with TrajDL's GridSystem.
from trajdl.tokenizers import LocSeqTokenizer
train_locseqs = generate_loc_seqs(train_df)
tokenizer = LocSeqTokenizer.build(loc_seqs=train_locseqs)
Compute the lower and upper bounds of the time intervals and distance intervals, then build a Bucketizer for each. Given a single value or a collection of values, a Bucketizer returns the corresponding bucket through get_bucket_index or get_bucket_indices. A bucket here corresponds to a slot in the paper.
# Lower and upper bounds of ts_delta in the training set
tsd_stats = train_df.select(ts_delta=pl.col("ts_delta").explode()).select(
    ts_delta_max=pl.col("ts_delta").max(), ts_delta_min=pl.col("ts_delta").min()
)
# Lower and upper bounds of dis_delta in the training set
disd_stats = train_df.select(dis_delta=pl.col("dis_delta").explode()).select(
    dis_delta_max=pl.col("dis_delta").max(), dis_delta_min=pl.col("dis_delta").min()
)
# Round outwards so that every observed value falls inside [lower, upper]
ts_delta_lower = math.floor(tsd_stats["ts_delta_min"][0])
ts_delta_upper = math.ceil(tsd_stats["ts_delta_max"][0])
dis_delta_lower = math.floor(disd_stats["dis_delta_min"][0])
dis_delta_upper = math.ceil(disd_stats["dis_delta_max"][0])
print(ts_delta_lower, ts_delta_upper, dis_delta_lower, dis_delta_upper)
0 82128 0 8656
Build the Bucketizers:
from trajdl.tokenizers.slot import Bucketizer

# Buckets for the time intervals (ts_delta)
time_bucketizer = Bucketizer(
    lower_bound=ts_delta_lower, upper_bound=ts_delta_upper, num_buckets=10
)
# Buckets for the distance intervals (dis_delta)
loc_bucketizer = Bucketizer(
    lower_bound=dis_delta_lower, upper_bound=dis_delta_upper, num_buckets=10
)
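TrajDL's Bucketizer implementation is not shown on this page. To illustrate the general idea of mapping a continuous interval to a slot index, here is a minimal sketch that assumes equal-width buckets and clamps out-of-range values. equal_width_bucket_index is a hypothetical helper, and its behavior may differ from what get_bucket_index actually does.

```python
def equal_width_bucket_index(value, lower_bound, upper_bound, num_buckets):
    """Map a value to a 0-based bucket index, assuming equal-width buckets."""
    if upper_bound <= lower_bound:
        raise ValueError("upper_bound must be greater than lower_bound")
    width = (upper_bound - lower_bound) / num_buckets
    index = int((value - lower_bound) // width)
    # Clamp so that out-of-range values and value == upper_bound stay in range.
    return min(max(index, 0), num_buckets - 1)


# Time-interval bounds printed above: 0 to 82128 seconds, split into 10 buckets.
for seconds in (0, 54111, 82128):
    print(seconds, equal_width_bucket_index(seconds, 0, 82128, 10))
# 0 0
# 54111 6
# 82128 9
```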
2. DataModule
from trajdl.datasets.modules.stlstm import STLSTMDataModule

data_module = STLSTMDataModule(
    tokenizer=tokenizer,
    train_table=train_ds,
    val_table=val_ds,
    test_table=test_ds,
    ts_bucketizer=time_bucketizer,
    loc_bucketizer=loc_bucketizer,
    train_batch_size=4,
    val_batch_size=4,
    num_train_batches=50,
    num_val_batches=20,
    num_cpus=-1,
)
data_module.setup("fit")
train_dataloader = data_module.train_dataloader()
# Peek at one training batch
next(iter(train_dataloader))
/home/chaosong/miniconda3/lib/python3.12/multiprocessing/popen_fork.py:66: RuntimeWarning: Using fork() can cause Polars to deadlock in the child process.
In addition, using fork() with Python in general is a recipe for mysterious
deadlocks and crashes.
The most likely reason you are seeing this error is because you are using the
multiprocessing module on Linux, which uses fork() by default. This will be
fixed in Python 3.14. Until then, you want to use the "spawn" context instead.
See https://docs.pola.rs/user-guide/misc/multiprocessing/ for details.
If you really know what your doing, you can silence this warning with the warning module
or by setting POLARS_ALLOW_FORKING_THREAD=1.
self.pid = os.fork()
STLSTMSample(loc_seq=tensor([[2740, 2741, 2742, 2743, 2744, 2745],
[ 373, 635, 62, 77, 634, 104],
[5532, 5533, 5534, 5535, 508, 505],
[ 40, 482, 4861, 485, 4862, 4863]]), td_upper_seq=tensor([[9, 9, 9, 9, 9, 9],
[9, 9, 6, 9, 9, 9],
[9, 9, 9, 9, 9, 9],
[9, 9, 9, 9, 9, 9]]), td_lower_seq=tensor([[0, 0, 0, 0, 0, 0],
[0, 0, 3, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]]), sd_upper_seq=tensor([[9, 9, 9, 9, 9, 9],
[9, 9, 9, 9, 9, 9],
[9, 9, 9, 9, 9, 9],
[9, 9, 9, 9, 9, 9]]), sd_lower_seq=tensor([[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]]), valid_lengths=[6, 6, 6, 6], labels=tensor([[2741, 2742, 2743, 2744, 2745, 2746],
[ 635, 62, 77, 634, 104, 1764],
[5533, 5534, 5535, 508, 505, 323],
[ 482, 4861, 485, 4862, 4863, 11]]), mask=tensor([[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.]]))
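In the printed batch, each row of labels equals the corresponding row of loc_seq shifted left by one position, which is the usual setup for next-location prediction. The sketch below illustrates that relationship on a plain Python list. make_next_location_pairs is a hypothetical helper, not TrajDL's actual batching code.

```python
def make_next_location_pairs(token_ids):
    """Return (inputs, labels) where labels[i] is the token that follows inputs[i]."""
    if len(token_ids) < 2:
        raise ValueError("need at least two locations to form a prediction target")
    return token_ids[:-1], token_ids[1:]


# Token IDs from the first row of the batch above (the inputs plus the final label).
inputs, labels = make_next_location_pairs([2740, 2741, 2742, 2743, 2744, 2745, 2746])
print(inputs)  # [2740, 2741, 2742, 2743, 2744, 2745]
print(labels)  # [2741, 2742, 2743, 2744, 2745, 2746]
```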
3. ST-LSTM model
from trajdl.algorithms.loc_pred.stlstm import STLSTMModule

# Build the ST-LSTM model
model = STLSTMModule(
    tokenizer=tokenizer,
    embedding_dim=128,
    hidden_size=256,
    ts_bucketizer=time_bucketizer,
    loc_bucketizer=loc_bucketizer,
)
model
STLSTMModule(
(stlstm): STLSTM(
(loc_emb): SimpleEmbedding(
(embedding): Embedding(6586, 128, padding_idx=6585)
)
(temporal_upper_emb): Embedding(10, 128)
(temporal_lower_emb): Embedding(10, 128)
(spatial_upper_emb): Embedding(10, 128)
(spatial_lower_emb): Embedding(10, 128)
(temporal_ln): Linear(in_features=128, out_features=768, bias=False)
(spatial_ln): Linear(in_features=128, out_features=768, bias=False)
(input_weight): Linear(in_features=128, out_features=1024, bias=True)
(hidden_weight): Linear(in_features=256, out_features=1024, bias=False)
)
(loss): SampledSoftmaxLoss(
(eval_loss): CrossEntropyLoss()
)
)
4. Training
Run the following code to start training:
import lightning as L

trainer = L.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
trainer.fit(model, data_module)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4060 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params | Mode
------------------------------------------------------
0 | stlstm | STLSTM | 1.4 M | train
1 | loss | SampledSoftmaxLoss | 1.7 M | train
------------------------------------------------------
3.1 M Trainable params
0 Non-trainable params
3.1 M Total params
12.526 Total estimated model params size (MB)
13 Modules in train mode
0 Modules in eval mode
acc: 0.00%
`Trainer.fit` stopped: `max_epochs=1` reached.
acc: 0.00%
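The parameter counts in the summary can be checked by hand from the layer shapes printed in section 3. The sketch below does that arithmetic. The shape assumed for the loss module (a 6586 x 256 output projection plus a bias vector) is an assumption inferred from the reported totals, not something this page states.

```python
vocab_size, embedding_dim, hidden_size, num_buckets = 6586, 128, 256, 10

stlstm_params = (
    vocab_size * embedding_dim          # loc_emb
    + 4 * num_buckets * embedding_dim   # four temporal/spatial slot embeddings
    + 2 * embedding_dim * 768           # temporal_ln and spatial_ln, no bias
    + embedding_dim * 1024 + 1024       # input_weight, with bias
    + hidden_size * 1024                # hidden_weight, no bias
)
# Assumption: the sampled-softmax loss holds a (vocab_size x hidden_size) weight and a bias.
loss_params = vocab_size * hidden_size + vocab_size
total_params = stlstm_params + loss_params

print(stlstm_params, loss_params, total_params)  # 1438976 1692602 3131578
print(round(total_params * 4 / 1e6, 3))          # 12.526 (float32 uses 4 bytes per parameter)
```

These figures agree with the 1.4 M, 1.7 M, 3.1 M, and 12.526 MB reported in the summary.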
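5. Inference
The forward function of STLSTMModule serves as the inference function. Its output is the predicted probability of each candidate next location, and applying argmax yields the most likely next location:
import torch

data_module.setup("test")
test_loader = data_module.test_dataloader()
with torch.inference_mode():
    predictions = trainer.predict(model, test_loader)
# One predicted location index per sequence in the first test batch
print(predictions[0].argmax(dim=-1))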
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
tensor([4954, 998, 4954, 4954])
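To make the argmax step concrete without a trained model, the following sketch applies a softmax and an argmax to a hand-written score vector using only the standard library. The scores are made up for illustration and do not come from the model above.

```python
import math


def softmax(scores):
    """Convert raw scores into probabilities that sum to 1."""
    shift = max(scores)  # subtract the maximum for numerical stability
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up scores over a vocabulary of four locations.
scores = [0.1, 2.3, -1.0, 0.7]
probs = softmax(scores)

# argmax: index of the location with the highest probability.
predicted = max(range(len(probs)), key=probs.__getitem__)
print(predicted)  # 1
```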
Tip
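This page covered dataset preprocessing, spatio-temporal feature processing, and model training and inference for the ST-LSTM algorithm.
The code on this page relies heavily on Polars for data processing. We recommend using Polars for data processing when working with TrajDL as well.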