trajdl.algorithms.embeddings.base module
- class trajdl.algorithms.embeddings.base.BaseTokenEmbeddingLayer
Bases: Module
Base class for token embedding layers.
- freeze_parameters() → None
Freezes the parameters of the embedding layer, preventing them from being trained.
- unfreeze_parameters() → None
Unfreezes the parameters of the embedding layer, allowing them to be trained.
- abstract property embedding_dim: int
- abstract forward(x: LongTensor) → Tensor
Computes embeddings for the input tokens. Subclasses must override this method.
- Parameters:
x (torch.LongTensor) – Input tensor containing token indices.
- Returns:
Embedding tensor for the input tokens, with an added trailing embedding dimension.
- Return type:
torch.Tensor
- property is_frozen: bool
Checks whether the parameters are frozen.
- Returns:
True if parameters are frozen, otherwise False.
- Return type:
bool
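The interface above can be illustrated with a minimal sketch in plain PyTorch. It does not reproduce trajdl's implementation: the class names `TokenEmbeddingSketch` and `TableEmbeddingSketch` are hypothetical, and freezing is assumed to work by toggling `requires_grad` on the layer's parameters, which is the usual PyTorch approach.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn


class TokenEmbeddingSketch(nn.Module, ABC):
    """Hypothetical stand-in for a token embedding base class (not trajdl code)."""

    def freeze_parameters(self) -> None:
        # Disable gradient tracking so an optimizer leaves the weights unchanged.
        for p in self.parameters():
            p.requires_grad_(False)

    def unfreeze_parameters(self) -> None:
        # Re-enable gradient tracking so the weights can be trained again.
        for p in self.parameters():
            p.requires_grad_(True)

    @property
    def is_frozen(self) -> bool:
        # Frozen means that no parameter requires a gradient.
        return all(not p.requires_grad for p in self.parameters())

    @property
    @abstractmethod
    def embedding_dim(self) -> int: ...

    @abstractmethod
    def forward(self, x: torch.LongTensor) -> torch.Tensor: ...


class TableEmbeddingSketch(TokenEmbeddingSketch):
    """Hypothetical concrete subclass backed by a lookup table."""

    def __init__(self, vocab_size: int, dim: int) -> None:
        super().__init__()
        self.table = nn.Embedding(vocab_size, dim)

    @property
    def embedding_dim(self) -> int:
        return self.table.embedding_dim

    def forward(self, x: torch.LongTensor) -> torch.Tensor:
        return self.table(x)


layer = TableEmbeddingSketch(vocab_size=10, dim=4)
layer.freeze_parameters()
print(layer.is_frozen)  # True

out = layer(torch.zeros(2, 3, dtype=torch.long))
print(out.shape)  # torch.Size([2, 3, 4])
```

Because `embedding_dim` and `forward` are abstract in this sketch, instantiating `TokenEmbeddingSketch` directly raises a `TypeError`; only subclasses that implement both can be constructed.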
- class trajdl.algorithms.embeddings.base.SimpleEmbedding(tokenizer: AbstractTokenizer, embedding_dim: int)
Bases: BaseTokenEmbeddingLayer
Token embedding layer that uses PyTorch's nn.Embedding.
- Parameters:
tokenizer (AbstractTokenizer) – Tokenizer
embedding_dim (int) – The dimensionality of the embeddings.
- property embedding_dim: int
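As a rough illustration of the lookup that `nn.Embedding` performs, the sketch below uses PyTorch directly rather than `SimpleEmbedding`. The value of `vocab_size` is an assumption; in trajdl it would presumably be derived from the tokenizer, which is not shown here.

```python
import torch
from torch import nn

vocab_size = 100    # assumed vocabulary size (hypothetical value)
embedding_dim = 16  # dimensionality of each token vector

embedding = nn.Embedding(vocab_size, embedding_dim)

# A batch of 2 sequences, each holding 5 token indices.
token_ids = torch.randint(0, vocab_size, (2, 5))
vectors = embedding(token_ids)

print(vectors.shape)  # torch.Size([2, 5, 16])
```

The output keeps the shape of the index tensor and appends one trailing dimension of size `embedding_dim`.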
- class trajdl.algorithms.embeddings.base.Word2VecEmbedding(tokenizer: AbstractTokenizer, model_path: str)
Bases: BaseTokenEmbeddingLayer
Token embedding layer that uses a Gensim Word2Vec model.
- Parameters:
tokenizer (AbstractTokenizer) – Tokenizer
model_path (str) – Path to the Word2Vec model file.
- property embedding_dim: int
- forward(x: LongTensor) → Tensor
Computes the Word2Vec embeddings for the input tokens.
- Parameters:
x (torch.LongTensor) – Input tensor containing token indices.
- Returns:
Word2Vec embeddings for the input tokens, with an added trailing embedding dimension.
- Return type:
torch.Tensor
- load_pretrained_word2vec_embeddings(tokenizer: AbstractTokenizer, word2vec_model_path: str) → Embedding
Loads pretrained Word2Vec embeddings from word2vec_model_path and returns them as an Embedding layer.
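One common way to turn pretrained vectors into an `Embedding` is `nn.Embedding.from_pretrained`. The sketch below shows that pattern only; it is not trajdl's implementation. The small hand-written `pretrained` matrix stands in for vectors read from a Word2Vec model, and it assumes that row i holds the vector for token index i.

```python
import torch
from torch import nn

# Stand-in for vectors read from a Word2Vec model: 4 tokens, 3 dimensions each.
# Assumption: row i is the vector for token index i.
pretrained = torch.tensor([
    [0.0, 0.0, 0.0],
    [0.1, 0.2, 0.3],
    [0.4, 0.5, 0.6],
    [0.7, 0.8, 0.9],
])

# freeze=True (the default) keeps the pretrained weights fixed during training.
embedding = nn.Embedding.from_pretrained(pretrained, freeze=True)

token_ids = torch.tensor([[1, 3]])
vectors = embedding(token_ids)

print(vectors.shape)                   # torch.Size([1, 2, 3])
print(embedding.weight.requires_grad)  # False
```

Passing `freeze=False` instead would let the pretrained vectors be fine-tuned along with the rest of the model.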