Source code for src.FRAME_FM.utils.embedders

# Copyright (c) Matt Arran.

# This source code is adapted from Masked Autoencoder (MAE) code, copyright
# Meta Platforms, Inc. and affiliates, licensed under the licence found in the
# licences/LICENSE_MAE.txt file.
# PatchEmbed is inspired by the equivalent class in PyTorch Image Models (timm).
# --------------------------------------------------------
# Embedding utils
# --------------------------------------------------------
# References:
# MAE: https://github.com/facebookresearch/mae/tree/main
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# --------------------------------------------------------

from abc import ABC, abstractmethod
from collections.abc import Callable
from math import prod
import numpy as np
import torch

# Convolution layer classes, keyed by the number of spatial dimensions convolved over.
_Conv_dim_dict: dict[int, type[torch.nn.modules.conv._ConvNd]] = {
    1: torch.nn.Conv1d,
    2: torch.nn.Conv2d,
    3: torch.nn.Conv3d,
}
# Functional convolutions, keyed by the number of spatial dimensions convolved over.
_conv_dim_dict: dict[int, Callable[..., torch.Tensor]] = {
    1: torch.nn.functional.conv1d,
    2: torch.nn.functional.conv2d,
    3: torch.nn.functional.conv3d,
}


[docs] def partition_embed_dim(embed_dim: int, dim_ratio: tuple[int | float, ...]) -> np.ndarray: """Partitions integer embedding dimension into integer components. Args: embed_dim (int): Total number of dimensions. dim_ratio (tuple[int | float, ...]): Target ratio of component sizes. Returns: numpy.ndarray: Array of integer components. """ embed_dims = np.round( embed_dim // 2 * np.array(dim_ratio) / sum(dim_ratio) ).astype(int) embed_dims[embed_dims.argmin()] += embed_dim // 2 - embed_dims.sum() return embed_dims
def calc_embed_omega(embed_dim: int, period: float = 4e4, res_ratio: float = 1e4) -> torch.Tensor:
    """Calculate angular frequencies for a periodic sin-cos embedding.

    Frequencies are whole numbers of cycles per ``period``, spaced
    logarithmically between 1 and ``res_ratio`` cycles and then rounded.

    Args:
        embed_dim (int): Number of angular frequencies to generate.
        period (float, optional): Coordinate distance over which the embedding is periodic.
            Defaults to 40,000.
        res_ratio (float, optional): Ratio between maximum and minimum resolutions of embedding.
            Defaults to 10,000.

    Returns:
        torch.Tensor: Array of angular frequencies, shape (embed_dim,)
    """
    log_cycles = torch.linspace(0, np.log(res_ratio), embed_dim)  # (D,)
    # Rounding to whole cycles keeps every component periodic over `period`
    cycles_per_period = torch.round(torch.exp(log_cycles))  # (D,)
    angle_per_period = cycles_per_period * (2 * np.pi)  # (D,)
    return angle_per_period / period  # (D,)
def sincos_embed_coords(coordinates: torch.Tensor, omega: torch.Tensor) -> torch.Tensor:
    """Create a periodic sin-cos embedding of an array of 1D coordinates.

    Args:
        coordinates (torch.Tensor): Coordinates to embed, shape (M,).
        omega (torch.Tensor): Angular frequencies at which to embed, shape (D,).

    Returns:
        torch.Tensor: Array of shape (M, 2D), each row holding the sines of a
            coordinate's D phases followed by their cosines.
    """
    # Phase of every coordinate at every frequency: outer product, (M, D)
    angles = torch.einsum('m,d->md', coordinates, omega)
    sines = torch.sin(angles)  # (M, D)
    cosines = torch.cos(angles)  # (M, D)
    return torch.cat([sines, cosines], dim=1)  # (M, 2D)
class BaseEmbedder(torch.nn.Module, ABC):
    """Abstract interface for modules converting inputs to and from token embeddings.

    Concrete subclasses must implement every method below and are expected to
    provide the three annotated attributes.
    """

    # Number of patches (tokens) into which each input is divided
    n_patches: int
    # Number of dimensions into which each patch is embedded
    embed_dim: int
    # Number of embedding dimensions from which each patch is reconstructed
    reconstruct_dim: int

    @abstractmethod
    def initialize_weights(self):
        """Set up embedder weights and parameters."""

    @abstractmethod
    def tokenify(self, inpt: torch.Tensor) -> torch.Tensor:
        """Reshape a batched input into sequences of tokens."""

    @abstractmethod
    def forward(self, inpt: torch.Tensor) -> torch.Tensor:
        """Convert a batched input into sequences of embeddings."""

    @abstractmethod
    def reconstruct_tokens(self, embedding: torch.Tensor) -> torch.Tensor:
        """Reconstruct tokens from embeddings."""

    @abstractmethod
    def untokenify(self, inpt: torch.Tensor) -> torch.Tensor:
        """Reshape sequences of tokens back into the input layout."""
class PatchEmbed(BaseEmbedder):
    """1-3D image to patch embedding.

    Images are divided into non-overlapping patches, each of which is projected
    into an embedding by a strided convolution and offset by a fixed sin-cos
    embedding of its grid position. `initialize_weights` must be called before
    `forward`, since it creates the position embeddings.
    """

    def __init__(self,
                 input_shape: tuple[int, ...],
                 patch_shape: tuple[int, ...],
                 n_channels: int,
                 embed_dim: int,
                 reconstruct_dim: int,
                 bias: bool = True,
                 norm_layer: Callable[..., torch.nn.Module] | None = None,
                 **conv_kwargs):
        """Instantiate embedder for patches in 1-3D, n-channel images.

        Args:
            input_shape (tuple[int, ...]): Shape of input image.
            patch_shape (tuple[int, ...]): Shape of patches into which image will be divided.
            n_channels (int): Number of channels recorded in image.
            embed_dim (int): Number of dimensions into which to embed each patch.
            reconstruct_dim (int): Number of embedding dimensions from which to reconstruct patch.
            bias (bool, optional): Whether to include bias in patch embedding. Defaults to True.
            norm_layer (Callable[..., torch.nn.Module] | None, optional): Factory called with
                embed_dim to create the layer with which to normalise embedding.
                Defaults to None: no normalisation.
            **conv_kwargs: Keyword arguments to pass for convolution layer instantiation.
        """
        super().__init__()
        assert len(input_shape) in _Conv_dim_dict.keys(), \
            f"{len(input_shape)}D input not supported"
        # Without this check, zip() silently truncates to the shorter shape
        assert len(patch_shape) == len(input_shape), \
            f"{len(patch_shape)}D patches not applicable to {len(input_shape)}D input"
        self.input_shape, self.patch_shape = tuple(input_shape), tuple(patch_shape)
        self.grid_shape, self.n_patches = self._count_patches(input_shape)
        self.n_channels = n_channels
        self.embed_dim = embed_dim
        # Non-overlapping patches: kernel and stride both equal the patch shape
        conv_class = _Conv_dim_dict[len(input_shape)]
        self.proj = conv_class(
            n_channels, embed_dim,
            kernel_size=self.patch_shape, stride=self.patch_shape, bias=bias, **conv_kwargs
        )
        if norm_layer is None:
            self.norm = torch.nn.Identity()
        else:
            self.norm = norm_layer(embed_dim)
        self.reconstruct_dim = reconstruct_dim
        self.reconstruct_layer = torch.nn.Linear(
            reconstruct_dim, prod(patch_shape) * n_channels, bias=True
        )

    def _count_patches(self, input_shape):
        """Count whole patches fitting into an image of the given shape.

        Args:
            input_shape: Shape of image, excluding batch and channel dimensions.

        Returns:
            tuple: Number of whole patches along each dimension, and their product.
        """
        grid_shape = tuple(s_i // s_p for s_i, s_p in zip(input_shape, self.patch_shape))
        n_patches = prod(grid_shape)
        return grid_shape, n_patches

    def _define_position_embedding(self, embed_dim: int) -> torch.nn.Parameter:
        """Create a sin-cos embedding of positions on an n-dimensional grid.

        Args:
            embed_dim (int): Number of dimensions into which to embed grid positions.

        Returns:
            torch.Parameter: Fixed (non-trainable) array of shape
                (1, prod(grid_shape), 2 * (embed_dim // 2)), with each row the sin-cos
                embedding of a grid position.
        """
        # Divide sin and cos embeddings according to size of grid in each dimension
        embed_dims = partition_embed_dim(embed_dim, dim_ratio=self.grid_shape)
        # Define grid: len(grid_shape)-tuple of tensors of shape grid_shape
        grid = torch.meshgrid([
            torch.arange(grid_s, dtype=torch.float32) for grid_s in self.grid_shape
        ], indexing='ij')
        # Create sincos embedding: tensor of shape (prod(grid_shape), 2 * sum(embed_dims))
        omegas = [calc_embed_omega(dim) for dim in embed_dims]
        embedding = torch.cat([
            sincos_embed_coords(coords.flatten(), omega) for coords, omega in zip(grid, omegas)
        ], dim=1)
        return torch.nn.Parameter(embedding.float().unsqueeze(0), requires_grad=False)

    def initialize_weights(self):
        """Set up embedder weights and parameters."""
        # define fixed sin-cos embeddings of patch position within image
        self.pos_embed = self._define_position_embedding(self.embed_dim)
        self.decoder_pos_embed = self._define_position_embedding(self.reconstruct_dim)
        # initialize projection like nn.Linear (instead of nn.Conv2d)
        w = self.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

    def tokenify(self, inputs: torch.Tensor) -> torch.Tensor:
        """Reshape batched input n-D images into sequences of patch tokens.

        Shapes are given for the example of a batch of B, C-channel 2D images,
        with image size (Hh, Ww) and patch size (h, w).

        Args:
            inputs (torch.Tensor): Batched sequence of images, shape e.g. [B, C, Hh, Ww]

        Returns:
            torch.Tensor: Batched sequences of patch tokens, shape e.g. [B, HW, hwC]
        """
        patch_shape, n_dim = self.patch_shape, len(self.patch_shape)
        assert len(inputs.shape) - 2 == n_dim, \
            f"{len(inputs.shape) - 2}-D input not divisible into {n_dim}-D patches"
        for dim, (s_i, s_p) in enumerate(zip(inputs.shape[2:], patch_shape)):
            assert s_i % s_p == 0, \
                f"Input dimension {dim} not divisible into patches ({s_i} % {s_p} != 0)"
        grid_shape, n_patches = self._count_patches(inputs.shape[2:])
        # Split every spatial dimension into (grid size, patch size)
        split_shape = tuple(size for pair in zip(grid_shape, patch_shape) for size in pair)
        x = inputs.reshape(
            (inputs.shape[0], self.n_channels) + split_shape
        )  # (B, C, H, h, W, w) for 2D patches
        x = x.permute(
            0, *range(2, 2 + 2 * n_dim, 2), *range(3, 3 + 2 * n_dim, 2), 1
        )  # (B, H, W, h, w, C) for 2D patches
        x = x.reshape(
            (inputs.shape[0], n_patches, prod(patch_shape) * self.n_channels)
        )  # (B, H W, h w C) for 2D patches
        return x

    def untokenify(self,
                   x: torch.Tensor,
                   output_shape: tuple[int, ...] | None = None) -> torch.Tensor:
        """Reshape batched sequences of patch tokens into n-D images.

        Shapes are given for the example of a batch of B, C-channel 2D images,
        with image size (Hh, Ww) and patch size (h, w).

        Args:
            x (torch.Tensor): Batched sequences of tokens, shape e.g. [B, HW, hwC]
            output_shape (tuple[int, ...] | None, optional): Required image shape, if different
                from input_shape in class instantiation. Sizes are floored to a whole number
                of patches. Defaults to None: use input_shape.

        Returns:
            torch.Tensor: Batched sequence of images, shape e.g. [B, C, Hh, Ww]
        """
        patch_shape, n_dim = self.patch_shape, len(self.patch_shape)
        n_channels = self.n_channels
        if output_shape is None:
            grid_shape, n_patches = self.grid_shape, self.n_patches
        else:
            assert len(output_shape) == n_dim, \
                f"{len(output_shape)}-D output not formable from {n_dim}-D patches"
            grid_shape, n_patches = self._count_patches(output_shape)
        assert x.shape[1] == n_patches, \
            f"Grid shape {grid_shape} not formable from {x.shape[1]} patches"
        assert x.shape[2] == prod(patch_shape) * self.n_channels, \
            f"{n_channels}-channel {patch_shape} patch not formable from {x.shape[2]} values"
        x = x.reshape(
            (x.shape[0],) + grid_shape + patch_shape + (n_channels,)
        )  # (N, H, W, h, w, C) for 2D patches
        # Pair each grid axis with the matching within-patch axis
        paired_axes = tuple(
            axis for grid_axis in range(1, 1 + n_dim) for axis in (grid_axis, n_dim + grid_axis)
        )
        x = x.permute(0, -1, *paired_axes)  # (N, C, H, h, W, w) for 2D patches
        image_shape = tuple(s_g * s_p for s_g, s_p in zip(grid_shape, patch_shape))
        imgs = x.reshape(
            (x.shape[0], n_channels) + image_shape
        )  # (N, C, H h, W w) for 2D patches
        return imgs

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Convert batched input n-D images into sequences of patch embeddings.

        Shapes are given for the example of a batch of B, C-channel 2D images,
        with image size (Hh, Ww) and patch size (h, w), with a D-D embedding.

        Args:
            x (torch.Tensor): Batched sequence of images, shape e.g. [B, C, Hh, Ww].

        Returns:
            tuple[torch.Tensor, torch.Tensor]:
                * Batched sequences of embeddings, shape e.g. [B, HW, D]
                * Batched sequences of position embeddings for reconstruction,
                  one row per patch
        """
        # zip() below would silently truncate for inputs of the wrong rank
        assert x.dim() == len(self.input_shape) + 2, \
            f"{x.dim()}-D input doesn't match specification ({len(self.input_shape) + 2}-D)"
        assert x.shape[1] == self.n_channels, \
            f"# of input channels ({x.shape[1]}) doesn't match spec. ({self.n_channels})"
        for dim, (s_actual, s_expected) in enumerate(zip(x.shape[2:], self.input_shape)):
            assert s_actual == s_expected, \
                f"Input dimension {dim} ({s_actual}) doesn't match specification ({s_expected})"
        x = self.proj(x)  # Project each patch into embedding, by convolution
        x = x.flatten(start_dim=2).transpose(1, 2)  # (B, D, H, W) -> (B, HW, D) for 2D patches
        x = self.norm(x)
        x = x + self.pos_embed
        return x, self.decoder_pos_embed.expand(x.shape[0], -1, -1)

    def reconstruct_tokens(self, x: torch.Tensor) -> torch.Tensor:
        """Reconstruct patch tokens from an embedding.

        Args:
            x (torch.Tensor): Embeddings whose last dimension has size reconstruct_dim.

        Returns:
            torch.Tensor: Patch tokens, with last dimension of size
                prod(patch_shape) * n_channels.
        """
        return self.reconstruct_layer(x)
class STPatchEmbed(PatchEmbed):
    """1-3D spatiotemporally located input to patch embedding.

    In place of a fixed grid-position embedding, each patch is offset by a
    sin-cos embedding of its mean position, computed from positions supplied
    alongside the input values.
    """

    def __init__(self,
                 input_shape: tuple[int, ...],
                 patch_shape: tuple[int, ...],
                 n_channels: int,
                 position_space: tuple[tuple[float, float], ...],
                 embed_dim: int,
                 reconstruct_dim: int,
                 pos_embed_ratio: tuple[float, ...],
                 bias: bool = True,
                 norm_layer: Callable[..., torch.nn.Module] | None = None,
                 **conv_kwargs):
        """Instantiate embedder for patches in 1-3D, n-channel, positioned inputs.

        Args:
            input_shape (tuple[int, ...]): Shape of input image.
            patch_shape (tuple[int, ...]): Shape of patches into which image will be divided.
            n_channels (int): Number of channels recorded in image.
            position_space (tuple[tuple[float, float], ...]): Space in which pixels are positioned.
                (range(x_0), ...) for coordinates x_i.
            embed_dim (int): Number of dimensions into which to embed each patch.
            reconstruct_dim (int): Number of embedding dimensions from which to reconstruct patch.
            pos_embed_ratio (tuple[float, ...]): Relative sizes of position embedding dimensions,
                paired in order with the coordinates of position_space.
            bias (bool, optional): Whether to include bias in patch embedding. Defaults to True.
            norm_layer (Callable[..., torch.nn.Module] | None, optional): Factory called with
                embed_dim to create the layer with which to normalise embedding.
                Defaults to None: no normalisation.
            **conv_kwargs: Keyword arguments to pass for convolution layer instantiation.
        """
        super().__init__(
            input_shape=input_shape,
            patch_shape=patch_shape,
            n_channels=n_channels,
            embed_dim=embed_dim,
            reconstruct_dim=reconstruct_dim,
            bias=bias,
            norm_layer=norm_layer,
            **conv_kwargs,
        )
        self.position_space = position_space
        self.pos_embed_ratio = pos_embed_ratio
        # Angular frequencies per position coordinate, filled by initialize_weights
        self.encoder_omegas = torch.nn.ParameterList()
        self.decoder_omegas = torch.nn.ParameterList()

    def _define_position_embedding(self,
                                   embed_dim: int,
                                   omegas: torch.nn.ParameterList
                                   ) -> Callable[[torch.Tensor], torch.Tensor]:
        """Create a function giving sin-cos embeddings of patch positions.

        Args:
            embed_dim (int): Number of dimensions into which to embed patch positions.
            omegas (torch.nn.ParameterList): List to which to append the angular frequencies
                for each position coordinate. Modified in place.

        Returns:
            Callable[[torch.Tensor], torch.Tensor]: Function mapping batched pixel positions,
                shape e.g. [B, N, Hh, Ww], to position embeddings, shape e.g. [B, HW, D].
        """
        # Divide sin and cos embeddings between position coordinates in the requested ratio
        embed_dims = partition_embed_dim(embed_dim, dim_ratio=self.pos_embed_ratio)
        st_dim = len(self.position_space)
        conv_fn = _conv_dim_dict[len(self.input_shape)]
        # Uniform kernel: convolving with it takes the mean position over each patch
        self.pos_conv_kernel = torch.nn.Parameter(
            torch.ones((st_dim, 1) + self.patch_shape) / prod(self.patch_shape),
            requires_grad=False,
        )
        # Each coordinate's embedding is periodic over the extent of its range
        omegas.extend([
            torch.nn.Parameter(calc_embed_omega(dim, period=x_max - x_min), requires_grad=False)
            for dim, (x_min, x_max) in zip(embed_dims, self.position_space)
        ])

        # NOTE(review): this closure holds references to `self` and `omegas`, so a
        # copied module would presumably keep using this instance's parameters -
        # confirm before copying or replicating the module.
        def embedding(pos: torch.Tensor) -> torch.Tensor:
            # pos shape B, 2, Hh, Ww for 2D patches
            batch_size = pos.shape[0]
            assert pos.shape[1] == len(self.position_space), \
                f"{pos.shape[1]}-D position space doesn't match spec. ({len(self.position_space)})"
            for dim, (s_pos, s_spec) in enumerate(zip(pos.shape[2:], self.input_shape)):
                assert s_pos == s_spec, \
                    f"Input positions dimension {dim} ({s_pos}) doesn't match spec. ({s_spec})"
            # One group per coordinate, so that coordinates are averaged independently
            pos = conv_fn(
                pos, self.pos_conv_kernel, stride=self.patch_shape, groups=st_dim
            )  # B, 2, H, W for 2D patches
            pos = pos.transpose(0, 1).flatten(start_dim=1)  # 2, BHW for 2D patches
            embeddings = torch.cat([
                sincos_embed_coords(coords, omega) for coords, omega in zip(pos, omegas)
            ], dim=1)  # BHW, D for 2D patches
            return embeddings.reshape([batch_size, -1, embed_dim])  # B, HW, D for 2D patches

        return embedding

    def initialize_weights(self):
        """Set up embedder weights and parameters.

        Safe to call repeatedly: the frequency lists are rebuilt from scratch.
        """
        # Start from empty lists, so that repeated initialisation doesn't accumulate
        # duplicate frequency parameters
        self.encoder_omegas = torch.nn.ParameterList()
        self.decoder_omegas = torch.nn.ParameterList()
        # define fixed sin-cos embedding functions for patch positions
        self.pos_embed = self._define_position_embedding(self.embed_dim, self.encoder_omegas)
        self.decoder_pos_embed = self._define_position_embedding(self.reconstruct_dim,
                                                                 self.decoder_omegas)
        # initialize projection like nn.Linear (instead of nn.Conv2d)
        w = self.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

    def forward(self,
                st_input: tuple[torch.Tensor, torch.Tensor]
                ) -> tuple[torch.Tensor, torch.Tensor]:
        """Convert batched spatiotemporal inputs into sequences of patch embeddings.

        Shapes are given for the example of a batch of B, C-channel 2D inputs in N-D space,
        with image size (Hh, Ww) and patch size (h, w), with a D-D embedding.

        Args:
            st_input (tuple[torch.Tensor, torch.Tensor]): Spatiotemporal input, combining:
                * Batched sequence of values, shape e.g. [B, C, Hh, Ww].
                * Batched sequence of positions, shape e.g. [B, N, Hh, Ww].

        Returns:
            tuple[torch.Tensor, torch.Tensor]:
                * Batched sequences of embeddings, shape e.g. [B, HW, D]
                * Batched sequences of position embeddings, shape e.g. [B, HW, D_d]
        """
        x, pos = st_input
        assert x.shape[1] == self.n_channels, \
            f"# of input channels ({x.shape[1]}) doesn't match spec. ({self.n_channels})"
        for dim, (s_x, s_spec) in enumerate(zip(x.shape[2:], self.input_shape)):
            assert s_x == s_spec, \
                f"Input values dimension {dim} ({s_x}) doesn't match spec. ({s_spec})"
        x = self.proj(x)  # Project each patch into embedding, by convolution
        x = x.flatten(start_dim=2).transpose(1, 2)  # (B, D, H, W) -> (B, HW, D) for 2D patches
        x = self.norm(x)
        x = x + self.pos_embed(pos)
        return x, self.decoder_pos_embed(pos)

    def tokenify(self, inputs: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """Reshape batched n-D input values into sequences of patch tokens.

        Shapes are given for the example of a batch of B, C-channel 2D images,
        with image size (Hh, Ww) and patch size (h, w).

        Args:
            inputs (tuple[torch.Tensor, torch.Tensor]): Spatiotemporal input, combining:
                * Batched sequence of values, shape e.g. [B, C, Hh, Ww].
                * Batched sequence of positions, shape e.g. [B, N, Hh, Ww].
                Only the values are tokenified.

        Returns:
            torch.Tensor: Batched sequences of patch tokens, shape e.g. [B, HW, hwC]
        """
        return PatchEmbed.tokenify(self, inputs[0])
class BoundedPatchEmbed(STPatchEmbed):
    """1-3D input to patch embedding, with positions given by per-sample bounds.

    Rather than a position for every pixel, each sample supplies the range it
    spans in each position coordinate. Patch positions are taken as the patch
    centres, evenly spaced within those bounds.
    """

    def _define_position_embedding(self,
                                   embed_dim: int,
                                   omegas: torch.nn.ParameterList
                                   ) -> Callable[[torch.Tensor], torch.Tensor]:
        """Create a function giving sin-cos embeddings of patch positions within bounds.

        Args:
            embed_dim (int): Number of dimensions into which to embed patch positions.
            omegas (torch.nn.ParameterList): List to which to append the angular frequencies
                for each position coordinate. Modified in place.

        Returns:
            Callable[[torch.Tensor], torch.Tensor]: Function mapping batched position bounds,
                shape [B, N, 2], to position embeddings, shape e.g. [B, HW, D].
        """
        # Divide sin and cos embeddings between position coordinates in the requested ratio
        embed_dims = partition_embed_dim(embed_dim, dim_ratio=self.pos_embed_ratio)
        # Patch centres along each grid axis, as fractions of the axis extent
        self.patch_grid_coords = torch.nn.ParameterList([
            torch.nn.Parameter(torch.arange(0.5, grid_s) / grid_s, requires_grad=False)
            for grid_s in self.grid_shape
        ])
        # Each coordinate's embedding is periodic over the extent of its range
        omegas.extend([
            torch.nn.Parameter(calc_embed_omega(dim, period=x_max - x_min), requires_grad=False)
            for dim, (x_min, x_max) in zip(embed_dims, self.position_space)
        ])

        def sample_positions(bounds: torch.Tensor) -> torch.Tensor:
            # Positions of all patch centres for one sample: HW, 2 for 2D patches
            axis_coords = [
                b_min + (b_max - b_min) * patch_coords
                for (b_min, b_max), patch_coords in zip(bounds, self.patch_grid_coords)
            ]
            # cartesian_prod returns a 1-D tensor when given a single tensor, so restore
            # the coordinate dimension explicitly to support 1D patches
            return torch.cartesian_prod(*axis_coords).reshape(-1, len(axis_coords))

        def embedding(bounds_batch: torch.Tensor) -> torch.Tensor:
            # [[[bottom, top], [left, right]], ...] for 2D patches
            batch_size = bounds_batch.shape[0]
            assert bounds_batch.shape[1] == len(self.position_space), \
                f"{bounds_batch.shape[1]}-D position space doesn't match spec. ({len(self.position_space)})"
            assert bounds_batch.shape[2] == 2, \
                f"Input position bounds (e.g. {bounds_batch[0, 0]}) must be 2-element ranges."
            pos = torch.cat(
                [sample_positions(bounds) for bounds in bounds_batch], dim=0
            ).transpose(0, 1)  # 2, BHW for 2D patches
            embeddings = torch.cat([
                sincos_embed_coords(coords, omega) for coords, omega in zip(pos, omegas)
            ], dim=1)  # BHW, D for 2D patches
            return embeddings.reshape([batch_size, -1, embed_dim])  # B, HW, D for 2D patches

        return embedding