trajdl.algorithms.abstract module


class trajdl.algorithms.abstract.BaseLightningModel(optimizer_type='adam', learning_rate=0.001)

Bases: LightningModule

A base Lightning model that encapsulates optimizer configuration.

Parameters:
  • optimizer_type (str, optional) – The type of optimizer to use (“adam”, “sgd”, or “rmsprop”). Default is “adam”.

  • learning_rate (float, optional) – The learning rate for the optimizer. Default is 1e-3.

configure_optimizers()

Configures the optimizer for the model based on the specified optimizer type.

Returns:

optimizer – The configured optimizer instance for training.

Return type:

torch.optim.Optimizer

Raises:

ValueError – If an unsupported optimizer type is provided.
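The selection logic described above can be sketched as a name-to-constructor lookup that raises ValueError for unknown names. This is a minimal sketch, not trajdl's actual implementation: the function build_optimizer, the REGISTRY mapping and the Fake* classes are hypothetical. To keep the sketch runnable without PyTorch, the optimizer classes are stand-ins that accept (params, lr=...), following the torch.optim calling convention. Whether trajdl matches optimizer_type case-sensitively is not stated in this page, so the sketch uses an exact match.

```python
from typing import Any, Callable, Dict, Iterable, Mapping

# Hypothetical factory type: any callable taking (params, lr=...),
# the same calling convention as torch.optim optimizer classes.
OptimizerFactory = Callable[..., Any]


def build_optimizer(
    optimizer_type: str,
    params: Iterable[Any],
    learning_rate: float,
    registry: Mapping[str, OptimizerFactory],
) -> Any:
    """Return registry[optimizer_type](params, lr=learning_rate).

    Raises ValueError for names missing from the registry, mirroring the
    documented behaviour of configure_optimizers().
    """
    try:
        factory = registry[optimizer_type]
    except KeyError:
        supported = ", ".join(sorted(registry))
        raise ValueError(
            f"Unsupported optimizer type: {optimizer_type!r} "
            f"(expected one of: {supported})"
        ) from None
    return factory(params, lr=learning_rate)


class _FakeOptimizer:
    """Stand-in for a torch.optim optimizer; only records its inputs."""

    def __init__(self, params: Iterable[Any], lr: float) -> None:
        self.params = list(params)
        self.lr = lr


class FakeAdam(_FakeOptimizer):
    """Placeholder for torch.optim.Adam."""


class FakeSGD(_FakeOptimizer):
    """Placeholder for torch.optim.SGD."""


class FakeRMSprop(_FakeOptimizer):
    """Placeholder for torch.optim.RMSprop."""


# Hypothetical registry covering the three documented optimizer names.
REGISTRY: Dict[str, OptimizerFactory] = {
    "adam": FakeAdam,
    "sgd": FakeSGD,
    "rmsprop": FakeRMSprop,
}

if __name__ == "__main__":
    opt = build_optimizer("adam", [1.0, 2.0], 1e-3, REGISTRY)
    print(type(opt).__name__, opt.lr)
    try:
        build_optimizer("adagrad", [], 1e-3, REGISTRY)
    except ValueError as err:
        print(err)
```

With PyTorch installed, the stand-ins would be replaced by torch.optim.Adam, torch.optim.SGD and torch.optim.RMSprop. Inside a LightningModule's configure_optimizers, the call would then pass self.parameters() and the configured learning rate, assuming the constructor arguments are stored on the instance. Using a dictionary keeps the list of supported names and the error message in one place.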