src.FRAME_FM.utils.LightningModuleWrapper
Classes
BaseModule: A thin wrapper around PyTorch Lightning's LightningModule to allow for future extensions.
Module Contents
- class src.FRAME_FM.utils.LightningModuleWrapper.BaseModule
Bases: pytorch_lightning.LightningModule
A thin wrapper around PyTorch Lightning’s LightningModule to allow for future extensions and customizations specific to FRAME-FM project needs.
Subclasses should implement training_step_body and validation_step_body methods instead of training_step and validation_step directly.
Enforces consistent logging patterns across training and validation steps, and logs the loss by default.
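The subclassing contract described above can be sketched as follows. This is a minimal illustration only, not the FRAME-FM implementation: StandInBaseModule is a hypothetical pure-Python stand-in so the snippet runs without PyTorch installed, the metric name "train_loss" is an assumption, and the real log_metrics wraps self.log rather than appending to a list.

```python
from typing import Any


class StandInBaseModule:
    """Hypothetical stand-in for BaseModule (assumption, not the real class)."""

    def __init__(self) -> None:
        # Records what the real class would send to LightningModule.log.
        self.logged = []

    def log_metrics(self, name: str, value: Any,
                    on_step: bool = True, on_epoch: bool = True) -> None:
        # Stand-in behaviour: remember the call instead of logging it.
        self.logged.append((name, value))

    def training_step_body(self, batch: Any, batch_idx: int) -> Any:
        # Subclasses override this hook instead of training_step.
        raise NotImplementedError

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        # Call the user hook, then log the loss it returned.
        loss = self.training_step_body(batch, batch_idx)
        self.log_metrics("train_loss", loss)  # metric name is an assumption
        return loss


class MeanSquareModel(StandInBaseModule):
    """Example subclass: only the *_body hook is implemented."""

    def training_step_body(self, batch: Any, batch_idx: int) -> Any:
        # Toy loss: mean of squared inputs.
        return sum(x * x for x in batch) / len(batch)


model = MeanSquareModel()
loss = model.training_step([1.0, 2.0, 3.0], batch_idx=0)
```

Because training_step stays in the base class, every subclass gets the loss logged in the same way without repeating the logging call.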
- log_metrics(name: str, value: Any, on_step: bool = True, on_epoch: bool = True)
Wrapper around self.log to enforce consistent logging defaults.
- training_step(batch: Any, batch_idx: int) -> Any
Default behaviour: call the user-overridable training_step_body hook and log the loss.
- validation_step(batch: Any, batch_idx: int) -> Any
Default behaviour: call the user-overridable validation_step_body hook and log the loss.
- test_step(batch: Any, batch_idx: int) -> Any
Default behaviour: call a user-overridable hook and log loss.
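The logging defaults of log_metrics can be illustrated with a short sketch. It assumes the method simply forwards its arguments to self.log, which the page does not confirm. FakeLightningModule and SketchModule are hypothetical names, and the fake log method mirrors only the on_step and on_epoch keyword arguments of LightningModule.log, which default to None there so that Lightning chooses a value per hook.

```python
from typing import Any, Optional


class FakeLightningModule:
    """Hypothetical stand-in exposing a reduced `log` signature."""

    def __init__(self) -> None:
        self.calls = []

    def log(self, name: str, value: Any,
            on_step: Optional[bool] = None,
            on_epoch: Optional[bool] = None) -> None:
        # Record the arguments so the forwarded defaults can be inspected.
        self.calls.append(
            {"name": name, "value": value,
             "on_step": on_step, "on_epoch": on_epoch}
        )


class SketchModule(FakeLightningModule):
    def log_metrics(self, name: str, value: Any,
                    on_step: bool = True, on_epoch: bool = True) -> None:
        # Assumed behaviour: forward with explicit defaults instead of None.
        self.log(name, value, on_step=on_step, on_epoch=on_epoch)


module = SketchModule()
module.log_metrics("val_loss", 0.25)               # uses both defaults
module.log_metrics("val_acc", 0.9, on_step=False)  # per-call override
```

Pinning both flags to True in one place means a metric is reported per step and aggregated per epoch regardless of which hook logs it, while callers can still override either flag for an individual metric.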