src.FRAME_FM.utils.LightningDataModuleWrapper
Classes
BaseDataModule: Base class for all DataModules in FRAME-FM.
Module Contents
- class src.FRAME_FM.utils.LightningDataModuleWrapper.BaseDataModule(data_root: str, batch_size: int = 32, num_workers: int = 4, pin_memory: bool = True, persistent_workers: bool = False, split_strategy: str = 'fraction', train_split: float = 0.8, val_split: float = 0.2, test_split: float = 0.0, train_indices: Sequence[int] | None = None, val_indices: Sequence[int] | None = None, test_indices: Sequence[int] | None = None, train_sampler: torch.utils.data.Sampler[Any] | None = None, val_sampler: torch.utils.data.Sampler[Any] | None = None, test_sampler: torch.utils.data.Sampler[Any] | None = None, train_transforms: callable | None = None, val_transforms: callable | None = None, test_transforms: callable | None = None)[source]
Bases: pytorch_lightning.LightningDataModule, abc.ABC
Base class for all DataModules in FRAME-FM.
Standardises common arguments (data_root, batch_size, num_workers, etc.).
Provides consistent DataLoader construction.
Leaves actual dataset creation to subclasses so they can handle arbitrary data formats (shapefiles, tabular, NetCDF, etc.).
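As a hedged sketch of this contract: the base class owns DataLoader construction while subclasses implement dataset creation. The stand-in base class below is not the real BaseDataModule (which also inherits pytorch_lightning.LightningDataModule); the RandomTensorDataModule subclass, its hook body, and the toy tensors are invented for illustration.

```python
# Self-contained sketch of the subclassing contract: the base class owns
# DataLoader construction, the subclass implements _create_datasets.
# BaseDataModuleSketch mimics only the hooks relevant to the example.
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split


class BaseDataModuleSketch:
    def __init__(self, data_root: str, batch_size: int = 32):
        self.data_root = data_root
        self.batch_size = batch_size
        self.train_ds = self.val_ds = self.test_ds = None

    def setup(self, stage=None):
        # Delegate dataset creation to the subclass hook.
        self.train_ds, self.val_ds, self.test_ds = self._create_datasets()

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)


class RandomTensorDataModule(BaseDataModuleSketch):
    """Invented subclass: builds toy tensor datasets instead of real files."""

    def _create_datasets(self):
        full = TensorDataset(torch.randn(100, 4), torch.randint(0, 2, (100,)))
        train, val = random_split(full, [80, 20])
        return train, val, None


dm = RandomTensorDataModule(data_root="/tmp/data", batch_size=16)
dm.setup("fit")
xb, yb = next(iter(dm.train_dataloader()))
print(tuple(xb.shape))  # (16, 4)
```

A real subclass would read shapefiles, tabular files, or NetCDF under data_root inside _create_datasets; everything else stays in the base class.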
- setup(stage: str | None = None) None[source]
Called by Lightning at the beginning of training/validation/testing.
Use to:
- Optionally load raw data (once)
- Delegate to _create_datasets to build the train/val/test datasets
- train_dataloader() torch.utils.data.DataLoader[Any][source]
An iterable or collection of iterables specifying training samples.
For more information about multiple dataloaders, see the PyTorch Lightning documentation.
The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
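The download-in-prepare_data() / split-in-setup() pattern can be sketched as follows. The class name, file path, and stand-in "parsing" are invented; real code would perform an actual download in prepare_data().

```python
# Hedged sketch of the prepare_data()/setup() split: prepare_data() touches
# disk only (Lightning calls it on a single process), while setup() assigns
# state (Lightning calls it on every process).
import os
import tempfile

DATA_PATH = os.path.join(tempfile.gettempdir(), "sketch_data.csv")


class DownloadSplitSketch:
    def prepare_data(self):
        # Download/write once; deliberately assign NO attributes on self here.
        if not os.path.exists(DATA_PATH):
            with open(DATA_PATH, "w") as f:
                f.write("\n".join(str(i) for i in range(10)))

    def setup(self, stage=None):
        # Runs on every process: safe to assign state here.
        with open(DATA_PATH) as f:
            rows = [int(line) for line in f if line.strip()]
        cut = int(len(rows) * 0.8)
        self.train_rows, self.val_rows = rows[:cut], rows[cut:]


m = DownloadSplitSketch()
m.prepare_data()
m.setup("fit")
print(len(m.train_rows), len(m.val_rows))  # 8 2
```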
Warning
Do not assign state in prepare_data(): it runs on a single process, so attributes set there are not replicated to other processes. See also: fit(), prepare_data().
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- val_dataloader() torch.utils.data.DataLoader[Any][source]
An iterable or collection of iterables specifying validation samples.
For more information about multiple dataloaders, see the PyTorch Lightning documentation.
The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data(). See also: fit(), validate(), prepare_data().
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
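With split_strategy='fraction', the train_split/val_split/test_split arguments map fractions to per-subset dataset lengths. A minimal sketch of one plausible mapping (the helper name and the rounding rule, where the remainder is absorbed by the test set, are assumptions, not necessarily what BaseDataModule does):

```python
# Hypothetical helper: turn split fractions into integer subset lengths.
# Rounding rule (remainder goes to the test set) is an assumption.
def split_lengths(n: int, train: float = 0.8, val: float = 0.2, test: float = 0.0):
    n_train = int(n * train)
    n_val = int(n * val)
    n_test = n - n_train - n_val  # absorb the rounding remainder
    return n_train, n_val, n_test


print(split_lengths(1000))                 # (800, 200, 0)
print(split_lengths(101, 0.7, 0.2, 0.1))   # (70, 20, 11)
```

Integer lengths like these are what torch.utils.data.random_split expects when carving one dataset into train/val/test subsets.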
- test_dataloader() torch.utils.data.DataLoader[Any][source]
An iterable or collection of iterables specifying test samples.
For more information about multiple dataloaders, see the PyTorch Lightning documentation.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
Do not assign state in prepare_data(): it runs on a single process, so attributes set there are not replicated to other processes. See also: test(), prepare_data().
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a test dataset and a test_step(), you don’t need to implement this method.
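The sampler notes above also apply to the train_sampler/val_sampler/test_sampler constructor arguments: when an explicit sampler is supplied, the DataLoader must not also shuffle, because the sampler owns the iteration order. A minimal sketch with a plain torch sampler (the toy dataset and uniform weights are invented):

```python
# Sketch: passing an explicit sampler (as the *_sampler arguments allow)
# means shuffle is left unset on the DataLoader -- the sampler decides order.
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

ds = TensorDataset(torch.arange(10, dtype=torch.float32).unsqueeze(1))
sampler = WeightedRandomSampler(weights=torch.ones(10), num_samples=10,
                                replacement=True)
loader = DataLoader(ds, batch_size=5, sampler=sampler)  # no shuffle=True here
batches = list(loader)
print(len(batches))  # 2
```

In a distributed run, Lightning would normally inject a DistributedSampler for you; supplying your own sampler opts out of that default, which is why these arguments default to None.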