trajdl.algorithms.framework package#

class trajdl.algorithms.framework.PretrainTrainFramework(mode: str, optimizer_type='adam', learning_rate=0.001)[source]#

Bases: BaseLightningModel, ABC

预训练+训练框架

abstract compute_loss(*args, **kwargs)[source]#

这个方法需要根据mode进行loss的计算

abstract init_from_pretrained_ckpt()[source]#

这个方法是给定一个预训练checkpoint的目录,根据一些逻辑对训练阶段的模型进行初始化的工作

property mode: Mode#
set_mode(mode: Mode | str) None[source]#