trajdl.algorithms.loc_pred.rnn module


class trajdl.algorithms.loc_pred.rnn.RNNNextLocPredictor(embedding_layer: Embedding, rnn_hidden_size: int, fc_hidden_size: int, output_size: int, num_layers: int, padding_value: int, dropout: float = 0.0)

Bases: Module

forward(src: LongTensor, lengths: List[int])
Parameters:
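  • src (LongTensor) – padded batch of location-ID sequences, shape (B, T)

  • lengths (List[int]) – actual (unpadded) length of each sequence in src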
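forward takes a padded (B, T) batch together with the per-sequence lengths. The sketch below shows one way such inputs could be prepared from variable-length trajectories. It assumes right padding with the model's padding_value; pad_batch is a hypothetical helper, not part of trajdl.

```python
from typing import List, Sequence, Tuple


def pad_batch(
    seqs: Sequence[Sequence[int]], padding_value: int
) -> Tuple[List[List[int]], List[int]]:
    """Right-pad variable-length location-ID sequences into a (B, T) grid.

    Hypothetical helper (not part of trajdl). Assumes right padding,
    i.e. real elements come first and padding fills the tail.
    """
    # Actual (unpadded) length of each sequence, as forward's `lengths` expects
    lengths = [len(s) for s in seqs]
    # T is the longest sequence in the batch (0 for an empty batch)
    max_len = max(lengths, default=0)
    # Copy each sequence and append padding_value until it reaches length T
    padded = [list(s) + [padding_value] * (max_len - len(s)) for s in seqs]
    return padded, lengths


src, lengths = pad_batch([[3, 7, 9], [4, 2]], padding_value=0)
print(src)      # [[3, 7, 9], [4, 2, 0]]
print(lengths)  # [3, 2]
```

In PyTorch, `torch.tensor(src, dtype=torch.long)` would then give the `(B, T)` LongTensor that forward expects.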

trajdl.algorithms.loc_pred.rnn.select_last_k_elements(x: Tensor, lengths: List[int], k: int) → Tensor

้€‰ๅ‡บๅบๅˆ—้‡Œ้ขๆœ€ๅŽkไธชๅ…ƒ็ด ๏ผŒ่ฟ™้‡Œ่ฆ่€ƒ่™‘ๆฏไธชๅบๅˆ—็š„ๅฎž้™…้•ฟๅบฆ๏ผŒxๆ˜ฏpaddingๅŽ็š„ๅบๅˆ— shape of x is (B, T, *)