src.FRAME_FM.utils.embedders

Classes

BaseEmbedder

Base class for all neural network modules.

PatchEmbed

1-3D Image to Patch Embedding

STPatchEmbed

1-3D spatiotemporally located input to Patch Embedding

BoundedPatchEmbed

1-3D spatiotemporally located input to Patch Embedding

Functions

partition_embed_dim(→ numpy.ndarray)

Partitions integer embedding dimension into integer components.

calc_embed_omega(→ torch.Tensor)

Calculates appropriate angular frequencies for sin-cos embedding.

sincos_embed_coords(→ torch.Tensor)

Create a periodic sin-cos embedding of an array of 1D coordinates.

Module Contents

src.FRAME_FM.utils.embedders.partition_embed_dim(embed_dim: int, dim_ratio: tuple[int | float, ...]) → numpy.ndarray[source]

Partitions integer embedding dimension into integer components.

Parameters:
  • embed_dim (int) – Total number of dimensions.

  • dim_ratio (tuple[int | float, ...]) – Target ratio of component sizes.

Returns:

Array of integer components.

Return type:

numpy.ndarray
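The exact rounding scheme isn't documented; a minimal sketch of such a partition, assuming a largest-remainder rounding rule (hypothetical reimplementation, for illustration only):

```python
import numpy as np

def partition_embed_dim(embed_dim: int, dim_ratio: tuple) -> np.ndarray:
    """Split embed_dim into integer parts roughly proportional to dim_ratio.

    Hypothetical reimplementation: the library's exact rounding rule may
    differ; here the remainder goes to the largest fractional parts.
    """
    ratio = np.asarray(dim_ratio, dtype=float)
    raw = embed_dim * ratio / ratio.sum()      # ideal (fractional) sizes
    parts = np.floor(raw).astype(int)
    remainder = embed_dim - parts.sum()        # dims still to hand out
    order = np.argsort(raw - parts)[::-1]      # largest fractions first
    parts[order[:remainder]] += 1
    return parts

partition_embed_dim(64, (2, 1, 1))  # array([32, 16, 16])
```

Whatever the rounding rule, the components must sum to embed_dim exactly, since they are later concatenated into a single embedding.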

src.FRAME_FM.utils.embedders.calc_embed_omega(embed_dim: int, period: float = 40000.0, res_ratio: float = 10000.0) → torch.Tensor[source]

Calculates appropriate angular frequencies for sin-cos embedding.

Parameters:
  • embed_dim (int) – Number of dimensions in which to embed coordinates.

  • period (float, optional) – Coordinate distance over which the embedding is periodic. Defaults to 40,000.

  • res_ratio (float, optional) – Ratio between maximum and minimum resolutions of embedding. Defaults to 10,000.

Returns:

Array of angular frequencies

Return type:

torch.Tensor
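The construction isn't spelled out above; one plausible sketch consistent with the documented period and res_ratio semantics (NumPy stand-in; the real function returns a torch.Tensor and its ordering and scaling are assumptions):

```python
import numpy as np

def calc_embed_omega(embed_dim: int, period: float = 4e4,
                     res_ratio: float = 1e4) -> np.ndarray:
    """embed_dim // 2 angular frequencies, one per sin/cos pair,
    geometrically spaced so wavelengths run from `period` down to
    `period / res_ratio` (assumed interpretation of the parameters)."""
    k = embed_dim // 2
    exponents = np.arange(k) / max(k - 1, 1)   # 0 ... 1
    return (2 * np.pi / period) * res_ratio ** exponents
```

Under this reading, the slowest frequency completes one cycle over a coordinate distance of `period`, and the fastest resolves features `res_ratio` times smaller.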

src.FRAME_FM.utils.embedders.sincos_embed_coords(coordinates: torch.Tensor, omega: torch.Tensor) → torch.Tensor[source]

Create a periodic sin-cos embedding of an array of 1D coordinates.

Parameters:
  • coordinates (torch.Tensor) – Coordinates to embed.

  • omega (torch.Tensor) – Angular frequencies for the embedding, e.g. as returned by calc_embed_omega.

Returns:

Array with each row the sin-cos embedding of a coordinate.

Return type:

torch.Tensor
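A common layout for such an embedding concatenates the sines and cosines of the coordinate-frequency products; a NumPy sketch under that assumption (the library may interleave sin/cos pairs instead):

```python
import numpy as np

def sincos_embed_coords(coordinates: np.ndarray, omega: np.ndarray) -> np.ndarray:
    """Row i is [sin(c_i*w_0), ..., sin(c_i*w_{K-1}),
                 cos(c_i*w_0), ..., cos(c_i*w_{K-1})]  (assumed layout)."""
    phase = np.outer(coordinates, omega)                           # [n_coords, K]
    return np.concatenate([np.sin(phase), np.cos(phase)], axis=1)  # [n_coords, 2K]
```

Each coordinate thus maps to a 2K-dimensional point on a product of circles, which is what makes the embedding periodic in the coordinate.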

class src.FRAME_FM.utils.embedders.BaseEmbedder(*args: Any, **kwargs: Any)[source]

Bases: torch.nn.Module, abc.ABC

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

Note

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Variables:

training (bool) – Boolean represents whether this module is in training or evaluation mode.

n_patches: int[source]
embed_dim: int[source]
reconstruct_dim: int[source]
abstractmethod initialize_weights()[source]
abstractmethod tokenify(inpt: torch.Tensor) → torch.Tensor[source]
abstractmethod forward(inpt: torch.Tensor) → torch.Tensor[source]
abstractmethod reconstruct_tokens(embedding: torch.Tensor) → torch.Tensor[source]
abstractmethod untokenify(inpt: torch.Tensor) → torch.Tensor[source]
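The abstract interface can be summarized torch-free; a sketch of the contract (method names follow the list above, but this stand-in deliberately omits the torch.nn.Module base, so it is illustrative only):

```python
import abc

class EmbedderContract(abc.ABC):
    """Torch-free stand-in for the BaseEmbedder interface."""

    @abc.abstractmethod
    def initialize_weights(self): ...             # set up weights/parameters

    @abc.abstractmethod
    def tokenify(self, inpt): ...                 # raw input  -> patch tokens

    @abc.abstractmethod
    def forward(self, inpt): ...                  # raw input  -> embeddings

    @abc.abstractmethod
    def reconstruct_tokens(self, embedding): ...  # embeddings -> patch tokens

    @abc.abstractmethod
    def untokenify(self, inpt): ...               # patch tokens -> raw input

class IdentityEmbedder(EmbedderContract):
    """Degenerate concrete example: every mapping is the identity."""
    def initialize_weights(self): pass
    def tokenify(self, inpt): return inpt
    def forward(self, inpt): return inpt
    def reconstruct_tokens(self, embedding): return embedding
    def untokenify(self, inpt): return inpt
```

Instantiating the abstract base directly raises TypeError until all five methods are overridden, which is how subclasses are forced to provide the full tokenify → embed → reconstruct → untokenify round trip.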
class src.FRAME_FM.utils.embedders.PatchEmbed(input_shape: tuple[int, ...], patch_shape: tuple[int, ...], n_channels: int, embed_dim: int, reconstruct_dim: int, bias: bool = True, norm_layer: torch.nn.Module | None = None, **conv_kwargs)[source]

Bases: BaseEmbedder

1-3D Image to Patch Embedding

n_channels[source]
embed_dim[source]
proj[source]
reconstruct_dim[source]
reconstruct_layer[source]
initialize_weights()[source]

Set up embedder weights and parameters.

tokenify(inputs: torch.Tensor) → torch.Tensor[source]

Reshape batched input n-D images into sequences of patch tokens.

Shapes are given for the example of a batch of B, C-channel 2D images, with image size (Hh, Ww) and patch size (h, w).

Parameters:

inputs (torch.Tensor) – Batched sequence of images, shape e.g. [B, C, Hh, Ww]

Returns:

Batched sequences of patch tokens, shape e.g. [B, HW, hwC]

Return type:

torch.Tensor

untokenify(x: torch.Tensor, output_shape: tuple[int, ...] | None = None) → torch.Tensor[source]

Reshape batched sequences of patch tokens into n-D images.

Shapes are given for the example of a batch of B, C-channel 2D images, with image size (Hh, Ww) and patch size (h, w).

Parameters:
  • x (torch.Tensor) – Batched sequences of tokens, shape e.g. [B, HW, hwC]

  • output_shape (tuple[int, ...] | None, optional) – Required image shape, if different from input_shape in class instantiation. Defaults to None: use input_shape.

Returns:

Batched sequence of images, shape e.g. [B, C, Hh, Ww]

Return type:

torch.Tensor
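For the 2D case, the tokenify/untokenify round trip documented above amounts to a pair of inverse reshapes; a NumPy sketch (assumptions: the real methods handle 1-3D generically and may order the hwC axis differently):

```python
import numpy as np

def tokenify(x: np.ndarray, patch_shape: tuple) -> np.ndarray:
    """[B, C, Hh, Ww] -> [B, HW, hwC] for 2D inputs."""
    B, C, Hh, Ww = x.shape
    h, w = patch_shape
    H, W = Hh // h, Ww // w
    x = x.reshape(B, C, H, h, W, w)
    x = x.transpose(0, 2, 4, 3, 5, 1)          # [B, H, W, h, w, C]
    return x.reshape(B, H * W, h * w * C)

def untokenify(tokens: np.ndarray, patch_shape: tuple,
               output_shape: tuple, n_channels: int) -> np.ndarray:
    """[B, HW, hwC] -> [B, C, Hh, Ww]: exact inverse of tokenify above."""
    h, w = patch_shape
    Hh, Ww = output_shape
    H, W = Hh // h, Ww // w
    B = tokens.shape[0]
    x = tokens.reshape(B, H, W, h, w, n_channels)
    x = x.transpose(0, 5, 1, 3, 2, 4)          # [B, C, H, h, W, w]
    return x.reshape(B, n_channels, Hh, Ww)
```

Because both directions are pure reshape/transpose operations, untokenify(tokenify(x)) recovers x exactly whenever the image size is an integer multiple of the patch size.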

forward(x: torch.Tensor) → tuple[torch.Tensor, torch.Tensor][source]

Convert batched input n-D images into sequences of patch embeddings.

Shapes are given for the example of a batch of B, C-channel 2D images, with image size (Hh, Ww), patch size (h, w), and a D-dimensional embedding.

Parameters:

x (torch.Tensor) – Batched sequence of images, shape e.g. [B, C, Hh, Ww].

Returns:

  • Batched sequences of embeddings, shape e.g. [B, HW, D]

  • Batched sequences of positions, shape e.g. [B, HW, D]

Return type:

tuple[torch.Tensor, torch.Tensor]

reconstruct_tokens(x: torch.Tensor) → torch.Tensor[source]

Reconstruct patch tokens from an embedding.

Parameters:

x (torch.Tensor) – Batched sequences of embeddings, shape e.g. [B, HW, D].

Returns:

Batched sequences of reconstructed patch tokens.

Return type:

torch.Tensor

class src.FRAME_FM.utils.embedders.STPatchEmbed(input_shape: tuple[int, ...], patch_shape: tuple[int, ...], n_channels: int, position_space: tuple[tuple[float, float], ...], embed_dim: int, reconstruct_dim: int, pos_embed_ratio: tuple[float, ...], bias: bool = True, norm_layer: torch.nn.Module | None = None, **conv_kwargs)[source]

Bases: PatchEmbed

1-3D spatiotemporally located input to Patch Embedding

position_space[source]
pos_embed_ratio[source]
encoder_omegas[source]
decoder_omegas[source]
initialize_weights()[source]

Set up embedder weights and parameters.

forward(st_input: tuple[torch.Tensor, torch.Tensor]) → tuple[torch.Tensor, torch.Tensor][source]

Convert batched spatiotemporal inputs into sequences of patch embeddings.

Shapes are given for the example of a batch of B, C-channel 2D inputs in N-D space, with image size (Hh, Ww), patch size (h, w), and a D-dimensional embedding.

Parameters:

st_input (tuple[torch.Tensor, torch.Tensor]) – Spatiotemporal input, combining:

  • Batched sequence of values, shape e.g. [B, C, Hh, Ww].

  • Batched sequence of positions, shape e.g. [B, N, Hh, Ww].

Returns:

  • Batched sequences of embeddings, shape e.g. [B, HW, D]

  • Batched sequences of position embeddings, shape e.g. [B, HW, D_d]

Return type:

tuple[torch.Tensor, torch.Tensor]

tokenify(inputs: tuple[torch.Tensor, torch.Tensor]) → torch.Tensor[source]

Reshape batched n-D input values into sequences of patch tokens.

Shapes are given for the example of a batch of B, C-channel 2D images, with image size (Hh, Ww) and patch size (h, w).

Parameters:

inputs (tuple[torch.Tensor, torch.Tensor]) – Spatiotemporal input, combining:

  • Batched sequence of values, shape e.g. [B, C, Hh, Ww].

  • Batched sequence of positions, shape e.g. [B, N, Hh, Ww].

Returns:

Batched sequences of patch tokens, shape e.g. [B, HW, hwC]

Return type:

torch.Tensor
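How STPatchEmbed combines pos_embed_ratio with the sin-cos machinery isn't spelled out above; a hypothetical sketch that splits embed_dim across the N coordinate axes in proportion to pos_embed_ratio and concatenates per-axis sin-cos embeddings (every detail here is an assumption, including the even per-axis dimensions):

```python
import numpy as np

def st_position_embed(positions: np.ndarray, embed_dim: int,
                      pos_embed_ratio: tuple,
                      period: float = 4e4, res_ratio: float = 1e4) -> np.ndarray:
    """Hypothetical: [n_tokens, N] coordinates -> [n_tokens, embed_dim].

    Splits embed_dim over the N axes per pos_embed_ratio, sin-cos
    embeds each axis, concatenates.  Assumes each per-axis dimension
    comes out even so sin and cos halves match.
    """
    ratio = np.asarray(pos_embed_ratio, dtype=float)
    dims = np.round(embed_dim * ratio / ratio.sum()).astype(int)
    dims[-1] += embed_dim - dims.sum()          # absorb rounding error
    out = []
    for axis, d in enumerate(dims):
        k = d // 2                              # one frequency per sin/cos pair
        omega = (2 * np.pi / period) * res_ratio ** (np.arange(k) / max(k - 1, 1))
        phase = np.outer(positions[:, axis], omega)
        out.append(np.concatenate([np.sin(phase), np.cos(phase)], axis=1))
    return np.concatenate(out, axis=1)          # [n_tokens, embed_dim]
```

This would explain why the constructor takes pos_embed_ratio alongside position_space: each spatial or temporal axis gets its own slice of the embedding, sized by the ratio.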

class src.FRAME_FM.utils.embedders.BoundedPatchEmbed(input_shape: tuple[int, ...], patch_shape: tuple[int, ...], n_channels: int, position_space: tuple[tuple[float, float], ...], embed_dim: int, reconstruct_dim: int, pos_embed_ratio: tuple[float, ...], bias: bool = True, norm_layer: torch.nn.Module | None = None, **conv_kwargs)[source]

Bases: STPatchEmbed

1-3D spatiotemporally located input to Patch Embedding