FRAME_FM.utils package
Shared utilities used across the project, such as I/O helpers, metrics, and transforms. This generic code supports models, datasets, dataloaders, and training logic.
Submodules
FRAME_FM.utils.LightningDataModuleWrapper module
- class FRAME_FM.utils.LightningDataModuleWrapper.BaseDataModule(data_root: str, batch_size: int = 32, num_workers: int = 4, pin_memory: bool = True, persistent_workers: bool = False, split_strategy: str = 'fraction', train_split: float = 0.8, val_split: float = 0.2, test_split: float = 0.0, train_indices: Sequence[int] | None = None, val_indices: Sequence[int] | None = None, test_indices: Sequence[int] | None = None, train_sampler: Sampler[Any] | None = None, val_sampler: Sampler[Any] | None = None, test_sampler: Sampler[Any] | None = None, train_transforms: callable | None = None, val_transforms: callable | None = None, test_transforms: callable | None = None)
Bases: LightningDataModule, ABC
Base class for all DataModules in FRAME-FM.
Standardises common arguments (data_root, batch_size, num_workers, etc.).
Provides consistent DataLoader construction.
Leaves actual dataset creation to subclasses so they can handle arbitrary data formats (shapefiles, tabular, NetCDF, etc.).
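The defaults split_strategy='fraction', train_split=0.8, val_split=0.2 and test_split=0.0 suggest that dataset indices are divided proportionally. The following is a minimal sketch of one way such a split could be computed. The helper split_by_fraction, its seed argument, and the shuffle-then-cut rule are assumptions for illustration and are not taken from FRAME_FM.

```python
import random
from typing import Sequence


def split_by_fraction(
    n_items: int,
    fractions: Sequence[float] = (0.8, 0.2, 0.0),
    seed: int = 0,
) -> list[list[int]]:
    """Hypothetical helper: shuffle range(n_items) and cut it into consecutive parts."""
    if abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError("fractions must sum to 1")
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)  # seeded, so the split is reproducible
    parts, start, cumulative = [], 0, 0.0
    for fraction in fractions:
        cumulative += fraction
        stop = round(cumulative * n_items)  # cut points follow the cumulative fraction
        parts.append(indices[start:stop])
        start = stop
    return parts


train_idx, val_idx, test_idx = split_by_fraction(10)
```

Cutting at cumulative fractions rather than rounding each part separately means every index lands in exactly one part.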
FRAME_FM.utils.LightningModuleWrapper module
- class FRAME_FM.utils.LightningModuleWrapper.BaseModule
Bases: LightningModule
A thin wrapper around PyTorch Lightning’s LightningModule to allow for future extensions and customizations specific to FRAME-FM project needs.
Subclasses should implement training_step_body and validation_step_body methods instead of training_step and validation_step directly.
Enforces consistent logging patterns across training and validation steps. Enforces logging of loss by default.
- log_metrics(name: str, value: Any, on_step: bool = True, on_epoch: bool = True)
Wrapper around self.log to enforce consistent logging defaults.
- test_step(batch: Any, batch_idx: int) -> Any
Default behaviour: call a user-overridable hook and log loss.
- training_step(batch: Any, batch_idx: int) -> Any
Default behaviour: call a user-overridable hook and log loss.
- training_step_body(batch: Any, batch_idx: int) -> tuple[Any, Dict[str, Any]]
Subclasses implement this instead of training_step. Should return (loss, logs_dict).
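The relationship between training_step and training_step_body can be illustrated with a dependency-free sketch. SketchBaseModule below is a hypothetical stand-in, not the FRAME_FM class: it records values in a dict instead of calling Lightning's self.log, and the "train/" metric prefix is an assumption.

```python
from typing import Any, Dict, Tuple


class SketchBaseModule:
    """Stand-in for BaseModule; the documented class subclasses LightningModule."""

    def __init__(self) -> None:
        self.logged: Dict[str, Any] = {}  # stand-in for Lightning's logger

    def log_metrics(self, name: str, value: Any, on_step: bool = True, on_epoch: bool = True) -> None:
        # The documented method wraps self.log; this sketch only records the value.
        self.logged[name] = value

    def training_step_body(self, batch: Any, batch_idx: int) -> Tuple[Any, Dict[str, Any]]:
        raise NotImplementedError  # subclasses supply the actual computation

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        loss, logs = self.training_step_body(batch, batch_idx)
        self.log_metrics("train/loss", loss)  # loss is always logged
        for name, value in logs.items():
            self.log_metrics(f"train/{name}", value)
        return loss


class MeanSquare(SketchBaseModule):
    def training_step_body(self, batch, batch_idx):
        loss = sum(x * x for x in batch) / len(batch)
        return loss, {"batch_size": len(batch)}


module = MeanSquare()
loss = module.training_step([1.0, 2.0, 3.0], batch_idx=0)
```

The subclass only returns (loss, logs_dict); the base class decides how and under which names the values are logged.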
FRAME_FM.utils.embedders module
- class FRAME_FM.utils.embedders.BaseEmbedder(*args: Any, **kwargs: Any)
Bases: Module, ABC
- embed_dim: int
- n_patches: int
- reconstruct_dim: int
- class FRAME_FM.utils.embedders.BoundedPatchEmbed(input_shape: dict[Any, int], patch_shape: dict[Any, int], n_channels: int, position_space: dict[Any, tuple[float, float]], embed_dim: int, reconstruct_dim: int, pos_embed_ratio: dict[Any, float], bias: bool = True, norm_layer: Module | None = None, **conv_kwargs)
Bases: STPatchEmbed
- class FRAME_FM.utils.embedders.PatchEmbed(input_shape: dict[Any, int], patch_shape: dict[Any, int], n_channels: int, embed_dim: int, reconstruct_dim: int, bias: bool = True, norm_layer: Module | None = None, **conv_kwargs)
Bases: BaseEmbedder
1-3D Image to Patch Embedding
- forward(x: Tensor) -> tuple[Tensor, Tensor]
Convert batched input n-D images into sequences of patch embeddings.
Shapes are given for the example of a batch of B, C-channel 2D images, with image size (Hh, Ww) and patch size (h, w), with a D-dimensional embedding.
- Parameters:
x (torch.Tensor) – Batched sequence of images, shape e.g. [B, C, Hh, Ww].
- Returns:
Batched sequences of embeddings, shape e.g. [B, HW, D]
Batched sequences of positions, shape e.g. [B, HW, D]
- Return type:
tuple[torch.Tensor, torch.Tensor]
- reconstruct_tokens(x: Tensor) -> Tensor
Reconstruct patch tokens from an embedding.
- Parameters:
x (torch.Tensor) – Batched sequences of embeddings from which to reconstruct patch tokens.
- Return type:
torch.Tensor
- tokenify(inputs: Tensor) -> Tensor
Reshape batched input n-D images into sequences of patch tokens.
Shapes are given for the example of a batch of B, C-channel 2D images, with image size (Hh, Ww) and patch size (h, w).
- Parameters:
inputs (torch.Tensor) – Batched sequence of images, shape e.g. [B, C, Hh, Ww]
- Returns:
Batched sequences of patch tokens, shape e.g. [B, HW, hwC]
- Return type:
torch.Tensor
- untokenify(x: Tensor, output_shape: dict[Any, int] | None = None) -> Tensor
Reshape batched sequences of patch tokens into n-D images.
Shapes are given for the example of a batch of B, C-channel 2D images, with image size (Hh, Ww) and patch size (h, w).
- Parameters:
x (torch.Tensor) – Batched sequences of tokens, shape e.g. [B, HW, hwC]
output_shape (dict[Any, int], optional) – Sizes of output dimensions required, if different from input_shape in class instantiation.
- Returns:
Batched sequence of images, shape e.g. [B, C, Hh, Ww]
- Return type:
torch.Tensor
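The shape notation used above, [B, C, Hh, Ww] to [B, HW, hwC], follows from simple arithmetic on input_shape and patch_shape. The sketch below works that arithmetic through in plain Python. The helper token_shapes, the dimension keys "lat" and "lon", and the requirement that each size be an exact multiple of its patch size are assumptions for illustration, not documented behaviour of PatchEmbed.

```python
from math import prod


def token_shapes(input_shape: dict, patch_shape: dict, n_channels: int) -> tuple[int, int]:
    """Hypothetical helper: return (number of patches, values per patch token)."""
    grid = {}
    for dim, size in input_shape.items():
        patch = patch_shape[dim]
        if size % patch:
            raise ValueError(f"{dim}: size {size} is not a multiple of patch size {patch}")
        grid[dim] = size // patch  # H or W in the notation above
    n_patches = prod(grid.values())  # HW
    token_dim = prod(patch_shape[dim] for dim in input_shape) * n_channels  # hwC
    return n_patches, token_dim


# A 3-channel 64 x 128 image cut into 8 x 16 patches.
n_patches, token_dim = token_shapes({"lat": 64, "lon": 128}, {"lat": 8, "lon": 16}, n_channels=3)
```

Here the image gives an 8 by 8 grid of patches, so tokenify would be expected to yield shape [B, 64, 384] under these assumptions.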
- class FRAME_FM.utils.embedders.STPatchEmbed(input_shape: dict[Any, int], patch_shape: dict[Any, int], n_channels: int, position_space: dict[Any, tuple[float, float]], embed_dim: int, reconstruct_dim: int, pos_embed_ratio: dict[Any, float], bias: bool = True, norm_layer: Module | None = None, **conv_kwargs)
Bases: PatchEmbed
1-3D spatiotemporally located input to Patch Embedding
- forward(st_input: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]
Convert batched spatiotemporal inputs into sequences of patch embeddings.
Shapes are given for the example of a batch of B, C-channel 2D inputs located in an N-dimensional space, with image size (Hh, Ww) and patch size (h, w), with a D-dimensional embedding.
- Parameters:
st_input (tuple[torch.Tensor, torch.Tensor]) – Spatiotemporal input, combining a batched sequence of values, shape e.g. [B, C, Hh, Ww], and a batched sequence of positions, shape e.g. [B, N, Hh, Ww].
- Returns:
Batched sequences of embeddings, shape e.g. [B, HW, D]
Batched sequences of position embeddings, shape e.g. [B, HW, D_d]
- Return type:
tuple[torch.Tensor, torch.Tensor]
- tokenify(inputs: tuple[Tensor, Tensor]) -> Tensor
Reshape batched n-D input values into sequences of patch tokens.
Shapes are given for the example of a batch of B, C-channel 2D images, with image size (Hh, Ww) and patch size (h, w).
- Parameters:
inputs (tuple[torch.Tensor, torch.Tensor]) – Spatiotemporal input, combining a batched sequence of values, shape e.g. [B, C, Hh, Ww], and a batched sequence of positions, shape e.g. [B, N, Hh, Ww].
- Returns:
Batched sequences of patch tokens, shape e.g. [B, HW, hwC]
- Return type:
torch.Tensor
- FRAME_FM.utils.embedders.calc_embed_omega(embed_dim: int, period: float = 40000.0, res_ratio: float = 10000.0) -> Tensor
Calculate appropriate angular frequencies for a sin-cos embedding.
- Parameters:
embed_dim (int) – Number of dimensions in which to embed coordinates.
period (float, optional) – Coordinate distance over which the embedding is periodic. Defaults to 40,000.
res_ratio (float, optional) – Ratio between maximum and minimum resolutions of embedding. Defaults to 10,000.
- Returns:
Array of angular frequencies
- Return type:
torch.Tensor
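The parameter descriptions above do not state the exact frequency schedule. As an illustration only, the sketch below assumes embed_dim // 2 frequencies (one sine and cosine pair each) whose wavelengths are spaced geometrically from period down to period / res_ratio. The function sketch_embed_omega is hypothetical and returns a plain list rather than a torch.Tensor; the real calc_embed_omega may use a different spacing.

```python
import math


def sketch_embed_omega(embed_dim: int, period: float = 40000.0, res_ratio: float = 10000.0) -> list[float]:
    """Assumed scheme: wavelengths spaced geometrically from period to period / res_ratio."""
    n_freq = embed_dim // 2  # one sine and one cosine component per frequency
    if n_freq < 1:
        raise ValueError("embed_dim must be at least 2")
    base = 2.0 * math.pi / period  # lowest angular frequency: one cycle per period
    if n_freq == 1:
        return [base]
    # Exponent runs from 0 to 1, so the highest frequency is base * res_ratio.
    return [base * res_ratio ** (k / (n_freq - 1)) for k in range(n_freq)]


omega = sketch_embed_omega(8)
```

Under this assumption the lowest frequency completes one cycle over period, and the highest resolves distances res_ratio times finer.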
- FRAME_FM.utils.embedders.partition_embed_dim(embed_dim: int, dim_ratio: list[int | float]) -> ndarray
Partitions integer embedding dimension into integer components.
- Parameters:
embed_dim (int) – Total number of dimensions.
dim_ratio (list[int | float]) – Target ratio of component sizes.
- Returns:
Array of integer components.
- Return type:
numpy.ndarray
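Splitting an integer into integer parts that follow a target ratio needs some rounding rule, which the description above does not specify. The sketch below uses the largest-remainder method as one plausible choice. The function sketch_partition is hypothetical, returns a plain list rather than a numpy.ndarray, and may round differently from partition_embed_dim.

```python
def sketch_partition(total: int, ratio: list[float]) -> list[int]:
    """Largest-remainder split of `total` into integers proportional to `ratio`."""
    scale = total / sum(ratio)
    exact = [r * scale for r in ratio]  # ideal, possibly fractional, shares
    parts = [int(x) for x in exact]  # round down; shares are non-negative
    leftover = total - sum(parts)
    # Hand the remaining units to the components with the largest fractional parts.
    order = sorted(range(len(ratio)), key=lambda i: exact[i] - parts[i], reverse=True)
    for i in order[:leftover]:
        parts[i] += 1
    return parts


parts = sketch_partition(10, [1, 1, 1])
even = sketch_partition(768, [2, 1, 1])
```

Whatever the rounding rule, the useful invariant is that the components always sum to the requested total.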
- FRAME_FM.utils.embedders.sincos_embed_coords(coordinates: Tensor, omega: Tensor) -> Tensor
Create a periodic sin-cos embedding of an array of 1D coordinates.
- Parameters:
coordinates (torch.Tensor) – Coordinates to embed.
omega (torch.Tensor) – Angular frequencies of the embedding, for example as calculated by calc_embed_omega.
- Returns:
Array with each row the sin-cos embedding of a coordinate.
- Return type:
torch.Tensor
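A sin-cos embedding maps each coordinate x to the sines and cosines of omega * x. The following plain-Python sketch shows the idea and checks the periodicity described above. The function sketch_sincos_embed is hypothetical, and the layout of each row (all sines followed by all cosines) is an assumption; sincos_embed_coords may order or interleave the components differently.

```python
import math


def sketch_sincos_embed(coordinates: list[float], omega: list[float]) -> list[list[float]]:
    """One row per coordinate: sines for every frequency, then cosines."""
    rows = []
    for x in coordinates:
        sines = [math.sin(w * x) for w in omega]
        cosines = [math.cos(w * x) for w in omega]
        rows.append(sines + cosines)
    return rows


period = 40000.0
# Two frequencies whose wavelengths divide the period exactly.
omega = [2.0 * math.pi / period, 4.0 * math.pi / period]
rows = sketch_sincos_embed([0.0, 1234.5, 1234.5 + period], omega)
```

Because every frequency here is an integer multiple of 2π / period, coordinates that differ by exactly one period receive the same embedding up to floating-point error.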