src.FRAME_FM.utils.LightningModuleWrapper

Classes

BaseModule

A thin wrapper around PyTorch Lightning's LightningModule to allow for future extensions and customizations specific to FRAME-FM project needs.

Module Contents

class src.FRAME_FM.utils.LightningModuleWrapper.BaseModule

Bases: pytorch_lightning.LightningModule

A thin wrapper around PyTorch Lightning’s LightningModule to allow for future extensions and customizations specific to FRAME-FM project needs.

Subclasses should implement training_step_body, validation_step_body, and test_step_body instead of overriding training_step, validation_step, and test_step directly.

Enforces consistent logging patterns across training, validation, and test steps; the loss is logged by default.
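A minimal, dependency-free sketch of the contract described above: each *_step method delegates to its abstract *_step_body hook, unpacks the (loss, logs_dict) result, and logs the loss through log_metrics with on_step=True and on_epoch=True. Everything here is hypothetical and is not the FRAME-FM source: the names StepTemplateSketch, _run, and ConstantLoss, the logged list (standing in for LightningModule.log so the sketch runs without PyTorch Lightning), the "train_"/"val_"/"test_" key prefixes, and the choice to also log the entries of logs_dict are all assumptions.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Tuple

StepResult = Tuple[Any, Dict[str, Any]]


class StepTemplateSketch(ABC):
    """Hypothetical stand-in for BaseModule's documented contract (not the real class)."""

    def __init__(self) -> None:
        # Records (name, value, on_step, on_epoch) for every logging call.
        # Assumption: replaces LightningModule.log so no Lightning install is needed.
        self.logged: List[Tuple[str, Any, bool, bool]] = []

    def log_metrics(self, name: str, value: Any, on_step: bool = True, on_epoch: bool = True) -> None:
        # A real wrapper would forward to self.log(name, value, on_step=..., on_epoch=...).
        self.logged.append((name, value, on_step, on_epoch))

    def _run(self, stage: str, body: Callable[[Any, int], StepResult], batch: Any, batch_idx: int) -> Any:
        loss, logs = body(batch, batch_idx)
        # The loss is always logged; the "<stage>_loss" key is an assumed naming scheme.
        self.log_metrics(f"{stage}_loss", loss)
        # Assumption: extra entries from logs_dict are logged with the same stage prefix.
        for key, value in logs.items():
            self.log_metrics(f"{stage}_{key}", value)
        # Lightning uses the loss returned from training_step for automatic optimization.
        return loss

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        return self._run("train", self.training_step_body, batch, batch_idx)

    def validation_step(self, batch: Any, batch_idx: int) -> Any:
        return self._run("val", self.validation_step_body, batch, batch_idx)

    def test_step(self, batch: Any, batch_idx: int) -> Any:
        return self._run("test", self.test_step_body, batch, batch_idx)

    @abstractmethod
    def training_step_body(self, batch: Any, batch_idx: int) -> StepResult: ...

    @abstractmethod
    def validation_step_body(self, batch: Any, batch_idx: int) -> StepResult: ...

    @abstractmethod
    def test_step_body(self, batch: Any, batch_idx: int) -> StepResult: ...


class ConstantLoss(StepTemplateSketch):
    """Hypothetical subclass: only the *_step_body hooks are implemented."""

    def training_step_body(self, batch: Any, batch_idx: int) -> StepResult:
        return 0.5, {"acc": 1.0}

    # Reuse the same body for validation and test in this toy example.
    validation_step_body = training_step_body
    test_step_body = training_step_body


module = ConstantLoss()
module.training_step(None, 0)
```

After the final call, module.logged holds ("train_loss", 0.5, True, True) followed by ("train_acc", 1.0, True, True). Because the abstract hooks are declared on the base class, forgetting to implement one raises a TypeError at instantiation rather than failing mid-training.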

log_metrics(name: str, value: Any, on_step: bool = True, on_epoch: bool = True)

Wrapper around self.log to enforce consistent logging defaults.

training_step(batch: Any, batch_idx: int) → Any

Default behaviour: calls the user-overridable training_step_body hook and logs the returned loss.

validation_step(batch: Any, batch_idx: int) → Any

Default behaviour: calls the user-overridable validation_step_body hook and logs the returned loss.

test_step(batch: Any, batch_idx: int) → Any

Default behaviour: calls the user-overridable test_step_body hook and logs the returned loss.

abstractmethod training_step_body(batch: Any, batch_idx: int) → tuple[Any, Dict[str, Any]]

Subclasses implement this instead of training_step. Should return (loss, logs_dict).

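abstractmethod validation_step_body(batch: Any, batch_idx: int) → tuple[Any, Dict[str, Any]]

Subclasses implement this instead of validation_step. Should return (loss, logs_dict).

abstractmethod test_step_body(batch: Any, batch_idx: int) → tuple[Any, Dict[str, Any]]

Subclasses implement this instead of test_step. Should return (loss, logs_dict).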
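A hypothetical example of implementing the three *_step_body hooks for a toy model y ≈ w * x + b, assuming each batch is an (xs, ys) pair of float sequences. The names ToyRegressorHooks, _loss_and_logs, Batch, and the "mae" key are illustrative only. In the project such a class would subclass BaseModule and operate on torch tensors; plain floats are used here so the sketch runs without PyTorch or the FRAME-FM package.

```python
from typing import Dict, Sequence, Tuple

# Assumed batch layout: (inputs, targets).
Batch = Tuple[Sequence[float], Sequence[float]]


class ToyRegressorHooks:
    """Hypothetical hook implementations; would subclass BaseModule in real use."""

    def __init__(self, w: float = 2.0, b: float = 0.0) -> None:
        self.w = w
        self.b = b

    def _loss_and_logs(self, batch: Batch) -> Tuple[float, Dict[str, float]]:
        xs, ys = batch
        errors = [self.w * x + self.b - y for x, y in zip(xs, ys)]
        mse = sum(e * e for e in errors) / len(errors)
        mae = sum(abs(e) for e in errors) / len(errors)
        # Return (loss, logs_dict): MSE is the loss, MAE is an extra metric to log.
        return mse, {"mae": mae}

    def training_step_body(self, batch: Batch, batch_idx: int) -> Tuple[float, Dict[str, float]]:
        return self._loss_and_logs(batch)

    def validation_step_body(self, batch: Batch, batch_idx: int) -> Tuple[float, Dict[str, float]]:
        return self._loss_and_logs(batch)

    def test_step_body(self, batch: Batch, batch_idx: int) -> Tuple[float, Dict[str, float]]:
        return self._loss_and_logs(batch)


hooks = ToyRegressorHooks(w=2.0, b=0.0)
loss, logs = hooks.training_step_body(([1.0, 2.0], [2.0, 5.0]), 0)
```

With w=2 and b=0 the predictions are 2.0 and 4.0, so the errors are 0.0 and -1.0, giving loss == 0.5 and logs == {"mae": 0.5}. Routing all three hooks through one helper is a design choice that keeps the (loss, logs_dict) shape identical across stages, in line with the consistent logging BaseModule aims for.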