FRAME_FM.utils package

Shared utilities used across the project, such as I/O helpers, metrics, and transforms. This is generic code that supports models, datasets, dataloaders, and training logic.

Submodules

FRAME_FM.utils.LightningDataModuleWrapper module

class FRAME_FM.utils.LightningDataModuleWrapper.BaseDataModule(data_root: str, batch_size: int = 32, num_workers: int = 4, pin_memory: bool = True, persistent_workers: bool = False, split_strategy: str = 'fraction', train_split: float = 0.8, val_split: float = 0.2, test_split: float = 0.0, train_indices: Sequence[int] | None = None, val_indices: Sequence[int] | None = None, test_indices: Sequence[int] | None = None, train_sampler: Sampler[Any] | None = None, val_sampler: Sampler[Any] | None = None, test_sampler: Sampler[Any] | None = None, train_transforms: callable | None = None, val_transforms: callable | None = None, test_transforms: callable | None = None)[source]

Bases: LightningDataModule, ABC

Base class for all DataModules in FRAME-FM.

  • Standardises common arguments (data_root, batch_size, num_workers, etc.).

  • Provides consistent DataLoader construction.

  • Leaves actual dataset creation to subclasses so they can handle arbitrary data formats (shapefiles, tabular, NetCDF, etc.).

setup(stage: str | None = None) None[source]

Called by Lightning at the beginning of training/validation/testing.

Use to:

  • Optionally load raw data (once).

  • Delegate to _create_datasets to build train/val/test datasets.
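The 'fraction' split strategy named in the constructor can be pictured with the sketch below. It is a standalone illustration using only the standard library, not the FRAME-FM implementation: the function name fraction_split, the seed argument, the shuffling, and the requirement that the fractions sum to 1 are all assumptions made for this example.

```python
import random


def fraction_split(n_samples, train_split=0.8, val_split=0.2, test_split=0.0, seed=0):
    """Hypothetical helper: split sample indices by fraction (not FRAME-FM code)."""
    fractions = (train_split, val_split, test_split)
    # Assumption: fractions are non-negative and sum to 1.
    if any(f < 0 for f in fractions) or abs(sum(fractions) - 1.0) > 1e-6:
        raise ValueError("split fractions must be non-negative and sum to 1")

    indices = list(range(n_samples))
    # Assumption: shuffle reproducibly before splitting.
    random.Random(seed).shuffle(indices)

    n_train = int(round(train_split * n_samples))
    n_val = min(int(round(val_split * n_samples)), n_samples - n_train)

    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]  # whatever remains goes to the test set
    return train, val, test


train_idx, val_idx, test_idx = fraction_split(10)
print(len(train_idx), len(val_idx), len(test_idx))
```

Index lists like these correspond in spirit to the train_indices, val_indices, and test_indices arguments, which presumably let a caller supply an explicit split instead.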

test_dataloader() DataLoader[Any][source]
train_dataloader() DataLoader[Any][source]
val_dataloader() DataLoader[Any][source]

FRAME_FM.utils.LightningModuleWrapper module

class FRAME_FM.utils.LightningModuleWrapper.BaseModule[source]

Bases: LightningModule

A thin wrapper around PyTorch Lightning’s LightningModule to allow for future extensions and customizations specific to FRAME-FM project needs.

Subclasses should implement training_step_body and validation_step_body methods instead of training_step and validation_step directly.

Enforces consistent logging patterns across training and validation steps. Enforces logging of loss by default.

log_metrics(name: str, value: Any, on_step: bool = True, on_epoch: bool = True)[source]

Wrapper around self.log to enforce consistent logging defaults.

test_step(batch: Any, batch_idx: int) Any[source]

Default behaviour: call a user-overridable hook and log loss.

test_step_body(batch: Any, batch_idx: int) tuple[Any, Dict[str, Any]][source]
training_step(batch: Any, batch_idx: int) Any[source]

Default behaviour: call a user-overridable hook and log loss.

training_step_body(batch: Any, batch_idx: int) tuple[Any, Dict[str, Any]][source]

Subclasses implement this instead of training_step. Should return (loss, logs_dict).

validation_step(batch: Any, batch_idx: int) Any[source]

Default behaviour: call a user-overridable hook and log loss.

validation_step_body(batch: Any, batch_idx: int) tuple[Any, Dict[str, Any]][source]
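The hook pattern described above, where subclasses implement a *_step_body method and the wrapper handles logging, can be sketched without Lightning as follows. The classes StepHookSketch and MeanSquaredErrorSketch and the metric keys such as "train/loss" are hypothetical names chosen for illustration; the real BaseModule may name and log things differently.

```python
class StepHookSketch:
    """Stand-in for the wrapper pattern; NOT the FRAME-FM BaseModule."""

    def __init__(self):
        self.logged = {}  # records what would be passed to self.log

    def log_metrics(self, name, value, on_step=True, on_epoch=True):
        # A real module would forward to LightningModule.log here.
        self.logged[name] = value

    def training_step(self, batch, batch_idx):
        # Call the user hook, then always log the loss plus any extra metrics.
        loss, logs = self.training_step_body(batch, batch_idx)
        self.log_metrics("train/loss", loss)  # key name is an assumption
        for name, value in logs.items():
            self.log_metrics(f"train/{name}", value)
        return loss

    def training_step_body(self, batch, batch_idx):
        raise NotImplementedError("subclasses return (loss, logs_dict)")


class MeanSquaredErrorSketch(StepHookSketch):
    def training_step_body(self, batch, batch_idx):
        predictions, targets = batch
        errors = [(p - t) ** 2 for p, t in zip(predictions, targets)]
        loss = sum(errors) / len(errors)
        return loss, {"max_error": max(errors)}


module = MeanSquaredErrorSketch()
print(module.training_step(([1.0, 2.0], [1.0, 4.0]), batch_idx=0))
print(module.logged)
```

Keeping the logging in the wrapper means a subclass cannot forget to log its loss, which appears to be the intent of the "enforces logging of loss" note.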

FRAME_FM.utils.embedders module

class FRAME_FM.utils.embedders.BaseEmbedder(*args: Any, **kwargs: Any)[source]

Bases: Module, ABC

embed_dim: int
abstractmethod forward(inpt: Tensor) Tensor[source]
abstractmethod initialize_weights()[source]
n_patches: int
reconstruct_dim: int
abstractmethod reconstruct_tokens(embedding: Tensor) Tensor[source]
abstractmethod tokenify(inpt: Tensor) Tensor[source]
abstractmethod untokenify(inpt: Tensor) Tensor[source]
class FRAME_FM.utils.embedders.BoundedPatchEmbed(input_shape: dict[Any, int], patch_shape: dict[Any, int], n_channels: int, position_space: dict[Any, tuple[float, float]], embed_dim: int, reconstruct_dim: int, pos_embed_ratio: dict[Any, float], bias: bool = True, norm_layer: Module | None = None, **conv_kwargs)[source]

Bases: STPatchEmbed

class FRAME_FM.utils.embedders.PatchEmbed(input_shape: dict[Any, int], patch_shape: dict[Any, int], n_channels: int, embed_dim: int, reconstruct_dim: int, bias: bool = True, norm_layer: Module | None = None, **conv_kwargs)[source]

Bases: BaseEmbedder

1D to 3D image to patch embedding.

forward(x: Tensor) tuple[Tensor, Tensor][source]

Convert batched input n-D images into sequences of patch embeddings.

Shapes are given for the example of a batch of B C-channel 2D images, with image size (Hh, Ww) and patch size (h, w), with a D-dimensional embedding.

Parameters:

x (torch.Tensor) – Batched sequence of images, shape e.g. [B, C, Hh, Ww].

Returns:

  • Batched sequences of embeddings, shape e.g. [B, HW, D]

  • Batched sequences of positions, shape e.g. [B, HW, D]

Return type:

tuple[torch.Tensor, torch.Tensor]

initialize_weights()[source]

Set up embedder weights and parameters.

reconstruct_tokens(x: Tensor) Tensor[source]

Reconstruct patch tokens from an embedding.

Parameters:

x (torch.Tensor) – Batched sequences of embeddings, shape e.g. [B, HW, D].

Return type:

torch.Tensor

tokenify(inputs: Tensor) Tensor[source]

Reshape batched input n-D images into sequences of patch tokens.

Shapes are given for the example of a batch of B, C-channel 2D images, with image size (Hh, Ww) and patch size (h, w).

Parameters:

inputs (torch.Tensor) – Batched sequence of images, shape e.g. [B, C, Hh, Ww]

Returns:

Batched sequences of patch tokens, shape e.g. [B, HW, hwC]

Return type:

torch.Tensor
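The shape bookkeeping behind [B, C, Hh, Ww] → [B, HW, hwC] can be checked with a small calculation. The helper patch_token_shape below is hypothetical, and the requirement that each input size be an exact multiple of the patch size is an assumption rather than documented behaviour.

```python
from math import prod


def patch_token_shape(input_shape, patch_shape, n_channels):
    """Hypothetical helper: return (number of patches, values per token)."""
    n_patches = 1
    for dim, size in input_shape.items():
        patch = patch_shape[dim]
        # Assumption: the input divides evenly into patches along each dimension.
        if size % patch != 0:
            raise ValueError(f"size {size} of {dim!r} is not a multiple of {patch}")
        n_patches *= size // patch
    token_length = prod(patch_shape[dim] for dim in input_shape) * n_channels
    return n_patches, token_length


# A 3-channel input of size 32 x 48 cut into 8 x 16 patches.
print(patch_token_shape({"y": 32, "x": 48}, {"y": 8, "x": 16}, n_channels=3))
```

Here H = 4 and W = 3, so each image yields HW = 12 tokens of hwC = 8 × 16 × 3 = 384 values.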

untokenify(x: Tensor, output_shape: dict[Any, int] | None = None) Tensor[source]

Reshape batched sequences of patch tokens into n-D images.

Shapes are given for the example of a batch of B, C-channel 2D images, with image size (Hh, Ww) and patch size (h, w).

Parameters:
  • x (torch.Tensor) – Batched sequences of tokens, shape e.g. [B, HW, hwC]

  • output_shape (dict[Any, int], optional) – Sizes of output dimensions required, if different from input_shape in class instantiation.

Returns:

Batched sequence of images, shape e.g. [B, C, Hh, Ww]

Return type:

torch.Tensor
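The round trip between images and patch tokens can be illustrated on nested Python lists for a single unbatched 2D image. This is a sketch only: the function names tokenify_2d and untokenify_2d are hypothetical, and the ordering of values within a token (patch row, then patch column, then channel) is an assumption suggested by the hwC notation, not confirmed from the source.

```python
def tokenify_2d(image, patch_h, patch_w):
    """Cut image[c][y][x] into row-major patches, flattening each to a token."""
    n_channels = len(image)
    height, width = len(image[0]), len(image[0][0])
    tokens = []
    for top in range(0, height, patch_h):
        for left in range(0, width, patch_w):
            tokens.append([
                image[c][top + dy][left + dx]
                for dy in range(patch_h)
                for dx in range(patch_w)
                for c in range(n_channels)
            ])
    return tokens


def untokenify_2d(tokens, height, width, patch_h, patch_w, n_channels):
    """Inverse of tokenify_2d: scatter token values back into image[c][y][x]."""
    image = [[[None] * width for _ in range(height)] for _ in range(n_channels)]
    patches_per_row = width // patch_w
    for index, token in enumerate(tokens):
        top = (index // patches_per_row) * patch_h
        left = (index % patches_per_row) * patch_w
        values = iter(token)
        for dy in range(patch_h):
            for dx in range(patch_w):
                for c in range(n_channels):
                    image[c][top + dy][left + dx] = next(values)
    return image


# Two channels, 4 x 4 pixels, 2 x 2 patches: 4 tokens of 8 values each.
image = [[[c * 100 + y * 10 + x for x in range(4)] for y in range(4)] for c in range(2)]
tokens = tokenify_2d(image, 2, 2)
restored = untokenify_2d(tokens, 4, 4, 2, 2, 2)
print(len(tokens), len(tokens[0]), restored == image)
```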

class FRAME_FM.utils.embedders.STPatchEmbed(input_shape: dict[Any, int], patch_shape: dict[Any, int], n_channels: int, position_space: dict[Any, tuple[float, float]], embed_dim: int, reconstruct_dim: int, pos_embed_ratio: dict[Any, float], bias: bool = True, norm_layer: Module | None = None, **conv_kwargs)[source]

Bases: PatchEmbed

1D to 3D spatiotemporally located input to patch embedding.

forward(st_input: tuple[Tensor, Tensor]) tuple[Tensor, Tensor][source]

Convert batched spatiotemporal inputs into sequences of patch embeddings.

Shapes are given for the example of a batch of B C-channel 2D inputs in N-dimensional space, with image size (Hh, Ww) and patch size (h, w), with a D-dimensional embedding.

Parameters:

st_input (tuple[torch.Tensor, torch.Tensor]) – Spatiotemporal input, combining:

  • Batched sequence of values, shape e.g. [B, C, Hh, Ww].

  • Batched sequence of positions, shape e.g. [B, N, Hh, Ww].

Returns:

  • Batched sequences of embeddings, shape e.g. [B, HW, D]

  • Batched sequences of position embeddings, shape e.g. [B, HW, D_d]

Return type:

tuple[torch.Tensor, torch.Tensor]

initialize_weights()[source]

Set up embedder weights and parameters.

tokenify(inputs: tuple[Tensor, Tensor]) Tensor[source]

Reshape batched n-D input values into sequences of patch tokens.

Shapes are given for the example of a batch of B, C-channel 2D images, with image size (Hh, Ww) and patch size (h, w).

Parameters:

inputs (tuple[torch.Tensor, torch.Tensor]) – Spatiotemporal input, combining:

  • Batched sequence of values, shape e.g. [B, C, Hh, Ww].

  • Batched sequence of positions, shape e.g. [B, N, Hh, Ww].

Returns:

Batched sequences of patch tokens, shape e.g. [B, HW, hwC]

Return type:

torch.Tensor

FRAME_FM.utils.embedders.calc_embed_omega(embed_dim: int, period: float = 40000.0, res_ratio: float = 10000.0) Tensor[source]

Calculates appropriate angular frequencies for a sin-cos embedding.

Parameters:
  • embed_dim (int) – Number of dimensions in which to embed coordinates.

  • period (float, optional) – Coordinate distance over which the embedding is periodic. Defaults to 40,000.

  • res_ratio (float, optional) – Ratio between maximum and minimum resolutions of embedding. Defaults to 10,000.

Returns:

Array of angular frequencies

Return type:

torch.Tensor
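One plausible reading of period and res_ratio is sketched below: the lowest angular frequency completes one cycle over period, and the frequencies are spaced geometrically up to res_ratio times that value. This is an assumption made for illustration, as is the choice of embed_dim // 2 frequencies (one sine and one cosine per frequency); the actual formula in calc_embed_omega may differ. The function embed_omega_sketch is hypothetical and returns a plain list rather than a torch.Tensor.

```python
import math


def embed_omega_sketch(embed_dim, period=40000.0, res_ratio=10000.0):
    """Hypothetical stand-in for calc_embed_omega using plain Python lists."""
    # Assumption: half the dimensions hold sines and half hold cosines.
    if embed_dim <= 0 or embed_dim % 2:
        raise ValueError("embed_dim must be a positive even integer")
    n_freqs = embed_dim // 2
    base = 2.0 * math.pi / period  # one full cycle over `period`
    if n_freqs == 1:
        return [base]
    # Geometric spacing from `base` up to `base * res_ratio` (assumed).
    return [base * res_ratio ** (k / (n_freqs - 1)) for k in range(n_freqs)]


omega = embed_omega_sketch(8)
print(len(omega), omega[-1] / omega[0])
```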

FRAME_FM.utils.embedders.partition_embed_dim(embed_dim: int, dim_ratio: list[int | float]) ndarray[source]

Partitions integer embedding dimension into integer components.

Parameters:
  • embed_dim (int) – Total number of dimensions.

  • dim_ratio (list[int | float]) – Target ratio of component sizes.

Returns:

Array of integer components.

Return type:

numpy.ndarray
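Splitting an integer total according to a ratio requires a rounding rule so that the parts still sum to the total. The sketch below uses the largest-remainder method, which is one common choice and not necessarily what partition_embed_dim does (for instance, the real function might force even components). The name partition_dim_sketch is hypothetical, and it returns a list rather than a numpy.ndarray.

```python
def partition_dim_sketch(embed_dim, dim_ratio):
    """Hypothetical largest-remainder partition of embed_dim by dim_ratio."""
    total = sum(dim_ratio)
    if total <= 0 or any(r < 0 for r in dim_ratio):
        raise ValueError("dim_ratio must be non-negative with a positive sum")

    exact = [embed_dim * r / total for r in dim_ratio]  # ideal real-valued shares
    parts = [int(x) for x in exact]                     # round each share down
    leftover = embed_dim - sum(parts)

    # Hand the leftover units to the shares with the largest fractional parts.
    order = sorted(range(len(exact)), key=lambda i: exact[i] - parts[i], reverse=True)
    for i in order[:leftover]:
        parts[i] += 1
    return parts


print(partition_dim_sketch(64, [2, 1, 1]))
print(partition_dim_sketch(10, [1, 1, 1]))
```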

FRAME_FM.utils.embedders.sincos_embed_coords(coordinates: Tensor, omega: Tensor) Tensor[source]

Create a periodic sin-cos embedding of an array of 1D coordinates.

Parameters:
  • coordinates (torch.Tensor) – Coordinates to embed.

  • omega (torch.Tensor) – Angular frequencies of the embedding, e.g. as returned by calc_embed_omega.

Returns:

Array with each row the sin-cos embedding of a coordinate.

Return type:

torch.Tensor
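A sin-cos embedding maps each coordinate x to the sines and cosines of ω·x for every angular frequency ω. The sketch below shows the idea on plain lists; the function name sincos_embed_sketch is hypothetical, and the layout of each row (all sines followed by all cosines) is an assumption, since the real function may interleave them.

```python
import math


def sincos_embed_sketch(coordinates, omega):
    """Hypothetical list-based version: one embedding row per coordinate."""
    return [
        [math.sin(w * x) for w in omega] + [math.cos(w * x) for w in omega]
        for x in coordinates
    ]


# Two coordinates, two frequencies: each row has 2 sines then 2 cosines.
rows = sincos_embed_sketch([0.0, 1.0], [math.pi / 2, math.pi])
print(len(rows), len(rows[0]))
```

Because each frequency contributes one sine and one cosine, the row length is twice the number of frequencies.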

Module contents