trajdl.common.samples module
- class trajdl.common.samples.GMVSAESample(encoder_seq: LongTensor, encoder_lengths: List[int], decoder_seq: LongTensor | None = None, decoder_lengths: List[int] | None = None, decoder_labels: LongTensor | None = None, mask: Tensor | None = None)
Bases: object
Input sample for GMVSAE.
- Parameters:
encoder_seq (torch.LongTensor) – shape is (B, T)
encoder_lengths (List[int]) – length of each sequence
decoder_seq (Optional[torch.LongTensor], optional) – shape is (B, T + 1); each sequence is prefixed with a BOS token; default is None
decoder_lengths (Optional[List[int]], optional) – length of each decoder sequence, i.e. the corresponding encoder_lengths entry + 1; default is None
decoder_labels (Optional[torch.LongTensor], optional) – shape is (B, T + 1); each sequence is suffixed with an EOS token; default is None
mask (Optional[torch.Tensor], optional) – shape is (B, T + 1); default is None
- property batch_size: int
- decoder_labels: LongTensor | None = None
- decoder_lengths: List[int] | None = None
- decoder_seq: LongTensor | None = None
- encoder_lengths: List[int]
- encoder_seq: LongTensor
- mask: Tensor | None = None
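The docstring specifies only how the decoder fields relate to the encoder fields in shape. The sketch below is a minimal illustration of those relations, not trajdl's own preprocessing: the helper `build_decoder_fields` and the token ids `PAD`, `BOS`, `EOS` are hypothetical, the batch is assumed to be right-padded, and treating `mask` as 1 on the n + 1 real decoder positions is an assumption, since the docstring does not define its semantics. Plain Python lists stand in for tensors so the example runs without torch.

```python
from typing import List, Tuple

# Hypothetical special-token ids; trajdl's actual vocabulary may differ.
PAD, BOS, EOS = 0, 1, 2


def build_decoder_fields(
    encoder_seq: List[List[int]], encoder_lengths: List[int]
) -> Tuple[List[List[int]], List[int], List[List[int]], List[List[int]]]:
    """Derive decoder_seq, decoder_lengths, decoder_labels and mask from a
    right-padded (B, T) encoder batch (hypothetical helper)."""
    T = len(encoder_seq[0]) if encoder_seq else 0
    decoder_seq, decoder_labels, mask = [], [], []
    for row, n in zip(encoder_seq, encoder_lengths):
        tokens = row[:n]          # drop the right padding
        pad = [PAD] * (T - n)     # keeps every decoder row at width T + 1
        decoder_seq.append([BOS] + tokens + pad)     # BOS prepended
        decoder_labels.append(tokens + [EOS] + pad)  # EOS appended
        mask.append([1] * (n + 1) + [0] * (T - n))   # assumed: 1 on real positions
    decoder_lengths = [n + 1 for n in encoder_lengths]  # encoder_lengths + 1
    return decoder_seq, decoder_lengths, decoder_labels, mask
```

To build an actual `GMVSAESample`, the nested lists would be converted with `torch.tensor(..., dtype=torch.long)` before being passed as `encoder_seq`, `decoder_seq`, and `decoder_labels`.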
- class trajdl.common.samples.STLSTMSample(loc_seq: LongTensor, td_upper_seq: LongTensor, td_lower_seq: LongTensor, sd_upper_seq: LongTensor, sd_lower_seq: LongTensor, valid_lengths: List[int], labels: LongTensor | None = None, mask: LongTensor | None = None)
Bases: object
Input sample for ST-LSTM.
- Parameters:
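loc_seq (torch.LongTensor) – shape is (B, T); the location sequence
td_upper_seq (torch.LongTensor) – shape is (B, T); for each pair of consecutive time steps, the upper bound minus the time difference between them
td_lower_seq (torch.LongTensor) – shape is (B, T); for each pair of consecutive time steps, the time difference between them minus the lower bound
sd_upper_seq (torch.LongTensor) – shape is (B, T); for each pair of consecutive time steps, the upper bound minus the spatial displacement between them
sd_lower_seq (torch.LongTensor) – shape is (B, T); for each pair of consecutive time steps, the spatial displacement between them minus the lower bound
valid_lengths (List[int]) – actual (unpadded) length of each sequence
labels (torch.LongTensor, optional) – shape is (B, T); passed during training; the targets for the LSTM outputs
mask (torch.LongTensor, optional) – shape is (B, T); passed during training; mask over the LSTM outputs, where 1 means the loss is computed at that position and 0 means it is skipped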
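The four bound fields encode how far each time difference or spatial displacement sits from a lower and an upper bound. The docstring does not say where those bounds come from, so the sketch below assumes, purely for illustration, that each value falls into a fixed-width bin whose edges serve as the bounds (the interval-interpolation idea used in ST-RNN-style models). The helper `bound_offsets` and the `bin_width` parameter are hypothetical; trajdl may compute its bounds differently. Padded positions beyond `valid_lengths` are filled with 0.

```python
from typing import List, Tuple


def bound_offsets(
    deltas: List[List[int]], valid_lengths: List[int], bin_width: int
) -> Tuple[List[List[int]], List[List[int]]]:
    """Return (upper - delta, delta - lower) per position, assuming each delta
    is bracketed by a fixed-width bin (hypothetical helper)."""
    upper_seq, lower_seq = [], []
    for row, n in zip(deltas, valid_lengths):
        up_row, low_row = [], []
        for t, d in enumerate(row):
            if t >= n:  # padded position: fill with 0
                up_row.append(0)
                low_row.append(0)
                continue
            lower = (d // bin_width) * bin_width  # lower edge of d's bin
            upper = lower + bin_width             # upper edge of d's bin
            up_row.append(upper - d)              # "upper bound minus delta"
            low_row.append(d - lower)             # "delta minus lower bound"
        upper_seq.append(up_row)
        lower_seq.append(low_row)
    return upper_seq, lower_seq
```

Applied to time differences it would yield candidates for `td_upper_seq` / `td_lower_seq`, and applied to spatial displacements, candidates for `sd_upper_seq` / `sd_lower_seq`; integer inputs keep the results compatible with `torch.LongTensor`.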
- property batch_size: int
- labels: LongTensor | None = None
- loc_seq: LongTensor
- mask: LongTensor | None = None
- sd_lower_seq: LongTensor
- sd_upper_seq: LongTensor
- td_lower_seq: LongTensor
- td_upper_seq: LongTensor
- valid_lengths: List[int]
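The `mask` field marks which LSTM outputs contribute to the loss (1 = computed, 0 = skipped). A minimal sketch of how such a mask is typically consumed, assuming (the docstring does not state this) that it is derived from `valid_lengths` so only real steps count. The helpers `mask_from_lengths` and `masked_mean` are hypothetical and use plain lists instead of tensors.

```python
from typing import List


def mask_from_lengths(valid_lengths: List[int], max_len: int) -> List[List[int]]:
    """1 where the position holds a real step (loss computed), 0 on padding."""
    return [[1 if t < n else 0 for t in range(max_len)] for n in valid_lengths]


def masked_mean(values: List[List[float]], mask: List[List[int]]) -> float:
    """Average per-position values (e.g. losses) over positions where mask == 1."""
    total = sum(
        v * m for row_v, row_m in zip(values, mask) for v, m in zip(row_v, row_m)
    )
    count = sum(sum(row) for row in mask)
    return total / count if count else 0.0
```

With tensors, the same reduction is usually written `(loss * mask).sum() / mask.sum()`.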
- class trajdl.common.samples.T2VECSample(src: LongTensor, lengths: List[int], target: LongTensor | None = None)
Bases: object
Input sample for T2VEC.
- Parameters:
src (torch.LongTensor) – shape is (batch_size, seq_length)
lengths (List[int]) – length of each src sequence
target (torch.LongTensor, optional) – shape is (batch_size, seq_length'), where seq_length' may differ from seq_length; default is None
- property batch_size: int
- lengths: List[int]
- src: LongTensor
- target: LongTensor | None = None
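`src` and `target` are padded independently, which is why their second dimensions (seq_length and seq_length') can differ. A minimal sketch of producing `src` and `lengths` from variable-length token sequences, assuming right padding with a hypothetical `PAD` id of 0; `pad_batch` is a hypothetical helper, not part of trajdl.

```python
from typing import List, Tuple

PAD = 0  # hypothetical padding id


def pad_batch(seqs: List[List[int]], pad: int = PAD) -> Tuple[List[List[int]], List[int]]:
    """Right-pad token sequences to the longest one; return (padded, lengths)."""
    lengths = [len(s) for s in seqs]
    width = max(lengths, default=0)
    padded = [s + [pad] * (width - len(s)) for s in seqs]
    return padded, lengths
```

Calling it once on the source sequences and once on the target sequences gives two batches of possibly different widths; `torch.tensor(src, dtype=torch.long)` would then yield the (batch_size, seq_length) tensor for the `src` argument.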
- class trajdl.common.samples.TULERSample(src: LongTensor, seq_len: List[int], labels: LongTensor | None = None)
Bases: object
Input sample for TULER.
- Parameters:
src (torch.LongTensor) – shape is (B, T)
seq_len (List[int]) – length of each sequence
labels (torch.LongTensor, optional) – shape is (B,), label of each sequence, default is None
- property batch_size: int
- labels: LongTensor | None = None
- seq_len: List[int]
- src: LongTensor
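Unlike the per-step `labels` of `STLSTMSample`, `TULERSample.labels` holds one label per sequence, giving shape (B,). A minimal collate sketch, assuming each training item is a (location sequence, label) pair where the label is, for example, a user id, and assuming right padding with a hypothetical `PAD` id of 0; `collate_tuler` is a hypothetical helper, not part of trajdl.

```python
from typing import List, Sequence, Tuple

PAD = 0  # hypothetical padding id


def collate_tuler(
    batch: Sequence[Tuple[List[int], int]]
) -> Tuple[List[List[int]], List[int], List[int]]:
    """Turn (location sequence, label) pairs into (src, seq_len, labels)."""
    seq_len = [len(seq) for seq, _ in batch]
    width = max(seq_len, default=0)
    src = [list(seq) + [PAD] * (width - len(seq)) for seq, _ in batch]  # (B, T)
    labels = [label for _, label in batch]  # one label per sequence -> (B,)
    return src, seq_len, labels
```

Converting `src` and `labels` with `torch.tensor(..., dtype=torch.long)` would give arguments matching `TULERSample(src=..., seq_len=..., labels=...)`.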