# Source code for FRAME_FM.training.train

# SPDX-FileCopyrightText: 2026 FRAME-FM Contributors
#
# SPDX-License-Identifier: Apache-2.0

# src/FRAME_FM/training/train.py
from __future__ import annotations

from hydra import main as hydra_main
from hydra.utils import instantiate
from omegaconf import DictConfig, open_dict
import pytorch_lightning as pl
import torch.multiprocessing as mp


@hydra_main(version_base=None, config_path="../../../configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Train, and optionally test, the model described by a Hydra config.

    Config entries read:
        seed: RNG seed passed to ``pl.seed_everything`` (defaults to 42).
        data: instantiated (``_convert_="partial"``) into either a
            ``pl.LightningDataModule`` or an indexable object with
            ``"training"``, ``"validation"`` and (if testing) ``"test"``
            entries that are handed to the trainer as dataloaders.
        model: instantiated (``_convert_="partial"``) and passed to
            ``trainer.fit`` / ``trainer.test``.
        logging (optional): instantiated and passed as the trainer's
            ``logger``; when absent, ``logger=None`` is passed.
        trainer: instantiated into the trainer. Its optional ``run_test``
            flag is consumed by this function and is NOT forwarded to the
            trainer constructor.

    Args:
        cfg: The composed Hydra configuration.
    """
    # Seed global RNGs; workers=True also seeds dataloader workers.
    pl.seed_everything(cfg.get("seed", 42), workers=True)

    # Instantiate data and model from config.
    data = instantiate(cfg.data, _convert_="partial")
    model = instantiate(cfg.model, _convert_="partial")

    # Experiment logger, only if a `logging` section is configured.
    logger = None
    if "logging" in cfg:
        logger = instantiate(cfg.logging)

    # `run_test` is a flag for this script, not a trainer argument. Hydra's
    # `instantiate` forwards every non-underscore key to the target, so leaving
    # it in place would pass `run_test=...` to the trainer constructor
    # (pl.Trainer has no such parameter -> TypeError). Pop it first;
    # `open_dict` is needed because Hydra configs are in struct mode.
    with open_dict(cfg.trainer):
        run_test = cfg.trainer.pop("run_test", False)

    trainer = instantiate(cfg.trainer, logger=logger)

    is_datamodule = isinstance(data, pl.LightningDataModule)

    # Train.
    if is_datamodule:
        trainer.fit(model, datamodule=data)
    else:
        trainer.fit(
            model,
            train_dataloaders=data["training"],
            val_dataloaders=data["validation"],
        )

    # Optional: test after training.
    if run_test:
        if is_datamodule:
            trainer.test(model, datamodule=data)
        else:
            trainer.test(model, dataloaders=data["test"])
if __name__ == "__main__":
    # Force the "spawn" start method for child processes before training starts.
    # NOTE(review): presumably to avoid fork-related problems with CUDA in
    # subprocesses (e.g. dataloader workers) — confirm the intent.
    mp.set_start_method(method="spawn", force=True)
    main()