Source code for src.FRAME_FM.utils.LightningModuleWrapper

# src/FRAME_FM/utils/LightningModuleWrapper.py
from typing import Any, Callable, Dict, Optional

import pytorch_lightning as pl


class BaseModule(pl.LightningModule):
    """
    A thin wrapper around PyTorch Lightning's LightningModule to allow for
    future extensions and customizations specific to FRAME-FM project needs.

    Subclasses should implement ``training_step_body``,
    ``validation_step_body`` and ``test_step_body`` instead of overriding
    ``training_step``, ``validation_step`` and ``test_step`` directly.
    Each hook returns ``(loss, logs_dict)``. The wrapper logs the loss under
    ``<stage>/loss`` (shown on the progress bar) and every entry of
    ``logs_dict`` under ``<stage>/<key>``, where stage is ``train``, ``val``
    or ``test``. This enforces consistent logging patterns across stages and
    guarantees the loss is always logged.
    """

    def __init__(self):
        super().__init__()
        self.save_hyperparameters()

    # ------ OVERWRITES -----

    def log_metrics(
        self, name: str, value: Any, on_step: bool = True, on_epoch: bool = True
    ):
        """Wrapper around self.log to enforce consistent logging defaults.

        Metrics are always sent to the logger and never shown on the
        progress bar.
        """
        self.log(
            name, value, on_step=on_step, on_epoch=on_epoch, prog_bar=False, logger=True
        )

    def _run_logged_step(
        self,
        stage: str,
        body: Callable[[Any, int], tuple[Any, Dict[str, Any]]],
        batch: Any,
        batch_idx: int,
        *,
        on_step: bool,
    ) -> Any:
        """Run one ``*_step_body`` hook and log its outputs under ``stage``.

        Args:
            stage: Prefix for logged keys (``"train"``, ``"val"`` or ``"test"``).
            body: The subclass hook to call with ``(batch, batch_idx)``.
            batch: The batch passed in by Lightning.
            batch_idx: Index of the batch within the current epoch.
            on_step: Whether to log per-step values in addition to epoch
                aggregates. Epoch aggregates are always logged.

        Returns:
            The loss returned by ``body``, unchanged.
        """
        loss, logs = body(batch, batch_idx)
        # The loss is the one metric surfaced on the progress bar.
        self.log(
            f"{stage}/loss", loss, on_step=on_step, on_epoch=True, prog_bar=True, logger=True
        )
        # Treat a missing logs dict (None) as "no extra metrics" rather than crashing.
        for key, value in (logs or {}).items():
            self.log_metrics(f"{stage}/{key}", value, on_step=on_step, on_epoch=True)
        return loss

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        """Default behaviour: call ``training_step_body`` and log its outputs.

        Values are logged per step and aggregated per epoch under ``train/``.
        """
        return self._run_logged_step(
            "train", self.training_step_body, batch, batch_idx, on_step=True
        )

    def validation_step(self, batch: Any, batch_idx: int) -> Any:
        """Default behaviour: call ``validation_step_body`` and log its outputs.

        Values are logged as epoch aggregates only, under ``val/``.
        """
        return self._run_logged_step(
            "val", self.validation_step_body, batch, batch_idx, on_step=False
        )

    def test_step(self, batch: Any, batch_idx: int) -> Any:
        """Default behaviour: call ``test_step_body`` and log its outputs.

        Values are logged as epoch aggregates only, under ``test/``.
        """
        # Previously this dispatched to a non-existent ``testing_step_body``,
        # so every test run failed with AttributeError.
        return self._run_logged_step(
            "test", self.test_step_body, batch, batch_idx, on_step=False
        )

    # ---- Hooks model developers are expected to override ----

    def training_step_body(
        self, batch: Any, batch_idx: int
    ) -> tuple[Any, Dict[str, Any]]:
        """
        Subclasses implement this instead of training_step.
        Should return (loss, logs_dict).
        """
        raise NotImplementedError

    def validation_step_body(
        self, batch: Any, batch_idx: int
    ) -> tuple[Any, Dict[str, Any]]:
        """
        Subclasses implement this instead of validation_step.
        Should return (loss, logs_dict).
        """
        raise NotImplementedError

    def test_step_body(self, batch: Any, batch_idx: int) -> tuple[Any, Dict[str, Any]]:
        """
        Subclasses implement this instead of test_step.
        Should return (loss, logs_dict).
        """
        # Backward compatibility: test_step used to call ``testing_step_body``,
        # so subclasses may have implemented that name instead. Honour it.
        legacy_hook = getattr(self, "testing_step_body", None)
        if legacy_hook is not None:
            return legacy_hook(batch, batch_idx)
        raise NotImplementedError