Source code for trajdl.algorithms.embeddings.base

# Copyright 2024 All authors of TrajDL
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod

import torch
import torch.nn as nn
from gensim.models import Word2Vec
from tqdm import tqdm

from ...tokenizers.abstract import AbstractTokenizer


class BaseTokenEmbeddingLayer(nn.Module):
    """
    Shared interface for layers that turn token indices into embeddings.

    Subclasses provide ``embedding_dim`` and ``forward``. This class adds the
    bookkeeping needed to switch gradient tracking of the layer's parameters
    on and off, and remembers which of the two states was requested last.

    Properties
    ----------
    embedding_dim : int
        Size of the embedding vectors (supplied by the subclass).
    is_frozen : bool
        Whether ``freeze_parameters`` was the most recent freeze/unfreeze call.

    Methods
    -------
    forward(x: torch.LongTensor) -> torch.Tensor
        Computes the embedding for the input tokens (supplied by the subclass).
    freeze_parameters() -> None
        Disables gradient tracking for every parameter of the layer.
    unfreeze_parameters() -> None
        Enables gradient tracking for every parameter of the layer.
    """

    def __init__(self):
        super().__init__()
        # A freshly built layer is reported as trainable until it is
        # explicitly frozen.
        self._frozen = False

    @property
    @abstractmethod
    def embedding_dim(self) -> int:
        """Size of the embedding vectors; to be provided by subclasses."""
        raise NotImplementedError(
            "Subclasses should implement this method."
        )  # pragma: no cover

    @abstractmethod
    def forward(self, x: torch.LongTensor) -> torch.Tensor:
        """
        Compute embeddings for a tensor of token indices.

        Parameters
        ----------
        x : torch.LongTensor
            Token indices to embed.

        Returns
        -------
        torch.Tensor
            Embeddings for the given tokens.

        Raises
        ------
        NotImplementedError
            Always, unless a subclass overrides this method.
        """
        raise NotImplementedError("Must override forward method")  # pragma: no cover

    def freeze_parameters(self) -> None:
        """Turn off ``requires_grad`` on all parameters and mark the layer frozen."""
        # nn.Module.requires_grad_ updates every parameter of the module in place.
        self.requires_grad_(False)
        self._frozen = True

    def unfreeze_parameters(self) -> None:
        """Turn on ``requires_grad`` on all parameters and mark the layer trainable."""
        self.requires_grad_(True)
        self._frozen = False

    @property
    def is_frozen(self) -> bool:
        """
        Report the frozen flag.

        The flag is only changed by ``freeze_parameters`` and
        ``unfreeze_parameters``; it is not recomputed from the parameters.

        Returns
        -------
        bool
            True if the layer was last frozen, False otherwise.
        """
        return self._frozen
class SimpleEmbedding(BaseTokenEmbeddingLayer):
    """
    Trainable token embedding backed by ``torch.nn.Embedding``.

    The table has one row per tokenizer entry (``len(tokenizer)``), and the
    tokenizer's ``pad`` value is passed to ``nn.Embedding`` as ``padding_idx``.

    Parameters
    ----------
    tokenizer : AbstractTokenizer
        Tokenizer that defines the vocabulary size and the padding index.
    embedding_dim : int
        The dimensionality of the embeddings.

    Methods
    -------
    forward(x: torch.LongTensor) -> torch.Tensor
        Looks up the embeddings for the input tokens.
    """

    def __init__(self, tokenizer: AbstractTokenizer, embedding_dim: int):
        super().__init__()
        vocab_size = len(tokenizer)
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=tokenizer.pad,
        )

    @property
    def embedding_dim(self) -> int:
        """Dimensionality of the vectors stored in the embedding table."""
        table = self.embedding
        return table.embedding_dim

    def forward(self, x: torch.LongTensor) -> torch.Tensor:
        """
        Look up the embedding vector of every token index in ``x``.

        Parameters
        ----------
        x : torch.LongTensor
            Input tensor containing token indices.

        Returns
        -------
        torch.Tensor
            Result of applying the ``nn.Embedding`` table to ``x``.
        """
        vectors = self.embedding(x)
        return vectors
class Word2VecEmbedding(BaseTokenEmbeddingLayer):
    """
    Token embedding initialised from a saved Gensim Word2Vec model.

    The embedding table has one row per tokenizer entry. Rows of tokens found
    in the Word2Vec vocabulary are filled with the pretrained vectors; all
    remaining rows stay zero. The table is created with ``freeze=False``.

    Parameters
    ----------
    tokenizer : AbstractTokenizer
        Tokenizer used to size the table and to map words to row indices.
    model_path : str
        Path to the Word2Vec model file.

    Methods
    -------
    load_pretrained_word2vec_embeddings(tokenizer, word2vec_model_path) -> nn.Embedding
        Builds the embedding table from the Word2Vec model.
    forward(x: torch.LongTensor) -> torch.Tensor
        Looks up the embeddings for the input tokens.
    """

    def __init__(self, tokenizer: AbstractTokenizer, model_path: str):
        super().__init__()
        pretrained = self.load_pretrained_word2vec_embeddings(
            tokenizer=tokenizer, word2vec_model_path=model_path
        )
        self.embedding = pretrained

    @property
    def embedding_dim(self) -> int:
        """Dimensionality of the vectors stored in the embedding table."""
        table = self.embedding
        return table.embedding_dim

    def load_pretrained_word2vec_embeddings(
        self, tokenizer: AbstractTokenizer, word2vec_model_path: str
    ) -> nn.Embedding:
        """
        Build an ``nn.Embedding`` whose weights come from a Word2Vec model.

        Parameters
        ----------
        tokenizer : AbstractTokenizer
            Supplies the number of rows (``len(tokenizer)``) and the row index
            of each word (``tokenizer.loc2idx(word)``).
        word2vec_model_path : str
            Path handed to ``Word2Vec.load``.

        Returns
        -------
        nn.Embedding
            Embedding created by ``nn.Embedding.from_pretrained`` with
            ``freeze=False``.
        """
        w2v = Word2Vec.load(word2vec_model_path)
        keyed_vectors = w2v.wv
        n_rows = len(tokenizer)
        n_cols = w2v.vector_size

        # Start from zeros so rows without a pretrained vector remain zero.
        weights = torch.zeros((n_rows, n_cols), dtype=torch.float32)

        vocabulary = keyed_vectors.index_to_key
        for token in tqdm(vocabulary, desc="loading word2vec embeddings"):
            row = tokenizer.loc2idx(token)
            # Copy the numpy vector before wrapping it in a tensor, as the
            # original code does.
            vector = keyed_vectors[token].copy()
            weights[row] = torch.from_numpy(vector)

        return nn.Embedding.from_pretrained(weights, freeze=False)

    def forward(self, x: torch.LongTensor) -> torch.Tensor:
        """
        Look up the embedding vector of every token index in ``x``.

        Parameters
        ----------
        x : torch.LongTensor
            Input tensor containing token indices.

        Returns
        -------
        torch.Tensor
            Result of applying the ``nn.Embedding`` table to ``x``.
        """
        vectors = self.embedding(x)
        return vectors