#
#
#
#
#
#
# This source code is adapted from Masked Autoencoder (MAE) code, copyright
# Meta Platforms, Inc. and affiliates, licensed under the licence found in the
# LICENSES/CC-BY-NC-4.0.txt file.
# PatchEmbed is inspired by the equivalent class in PyTorch Image Models (timm).
# --------------------------------------------------------
# Embedding utils
# --------------------------------------------------------
# References:
# MAE: https://github.com/facebookresearch/mae/tree/main
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# --------------------------------------------------------
from abc import ABC, abstractmethod
from collections.abc import Callable
from math import prod
import numpy as np
import torch
from typing import Any
_Conv_dim_dict: dict[int, type[torch.nn.modules.conv._ConvNd]] = {
1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d
}
_conv_dim_dict: dict[int, Callable[..., torch.Tensor]] = {
1: torch.nn.functional.conv1d, 2: torch.nn.functional.conv2d, 3: torch.nn.functional.conv3d
}
[docs]
def partition_embed_dim(embed_dim: int, dim_ratio: list[int | float]) -> np.ndarray:
"""Partitions integer embedding dimension into integer components.
Args:
embed_dim (int): Total number of dimensions.
dim_ratio (list[int | float]): Target ratio of component sizes.
Returns:
numpy.ndarray: Array of integer components.
"""
embed_dims = np.round(
embed_dim // 2 * np.array(dim_ratio) / sum(dim_ratio)
).astype(int)
embed_dims[embed_dims.argmin()] += embed_dim // 2 - embed_dims.sum()
return embed_dims
[docs]
def calc_embed_omega(embed_dim: int, period: float = 4e4, res_ratio: float = 1e4) -> torch.Tensor:
"""Calculates appropriate angular frequencies for sincos embedding
Args:
embed_dim (int): Number of dimensions in which to embed coordinates.
period (float, optional): Coordinate distance over which the embedding is periodic.
Defaults to 40,000.
res_ratio (float, optional): Ratio between maximum and minimum resolutions of embedding.
Defaults to 10,000.
Returns:
torch.Tensor: Array of angular frequencies
"""
n_periods = torch.exp(torch.linspace(0, np.log(res_ratio), embed_dim)).round() # (D,)
return 2 * np.pi * n_periods / period # (D,)
[docs]
def sincos_embed_coords(coordinates: torch.Tensor, omega: torch.Tensor) -> torch.Tensor:
"""Create a periodic sin-cos embedding of an array of 1D coordinates.
Args:
coordinates (torch.Tensor): Coordinates to embed.
embed_dim (int): Number of dimensions in which to embed coordinates.
period (float, optional): Coordinate distance over which the embedding is periodic.
Defaults to 40,000.
res_ratio (float, optional): Ratio between maximum and minimum resolutions of embedding.
Defaults to 10,000.
Returns:
torch.Tensor: Array with each row the sin-cos embedding of a coordinate.
"""
phases = torch.einsum('m,d->md', coordinates, omega) # (M, D), outer product
return torch.cat([torch.sin(phases), torch.cos(phases)], dim=1) # (M, 2D)
[docs]
class BaseEmbedder(torch.nn.Module, ABC):
n_patches: int
embed_dim: int
reconstruct_dim: int
[docs]
@abstractmethod
def initialize_weights(self):
pass
[docs]
@abstractmethod
def tokenify(self, inpt: torch.Tensor) -> torch.Tensor:
pass
[docs]
@abstractmethod
def forward(self, inpt: torch.Tensor) -> torch.Tensor:
pass
[docs]
@abstractmethod
def reconstruct_tokens(self, embedding: torch.Tensor) -> torch.Tensor:
pass
[docs]
@abstractmethod
def untokenify(self, inpt: torch.Tensor) -> torch.Tensor:
pass
[docs]
class PatchEmbed(BaseEmbedder):
""" 1-3D Image to Patch Embedding
"""
def __init__(self,
input_shape: dict[Any, int],
patch_shape: dict[Any, int],
n_channels: int,
embed_dim: int,
reconstruct_dim: int,
bias: bool = True,
norm_layer: torch.nn.Module | None = None,
**conv_kwargs):
"""Instantiate embedder for patches in 1-3D, n-channel images.
Args:
input_shape (dict[Any, int]): Sizes of input dimensions.
patch_shape (dict[Any, int]): Shape of patches into which input will be divided.
n_channels (int): Number of channels recorded in input.
embed_dim (int): Number of dimensions into which to embed each patch.
reconstruct_dim (int): Number of embedding dimensions from which to reconstruct patch.
bias (bool, optional): Whether to include bias in patch embedding. Defaults to True.
norm_layer (torch.nn.Module | None, optional): Layer with which to normalise embedding.
Defaults to None: no normalisation.
**conv_kwargs: Keyword arguments to pass for convolution layer instantiation.
"""
super().__init__()
if len(input_shape) not in _Conv_dim_dict.keys():
raise ValueError(f"{len(input_shape)}D input not supported")
if input_shape.keys() != patch_shape.keys():
raise ValueError(
f"input_shape dims {input_shape.keys()}"
f" must equal patch_shape dims {patch_shape.keys()}."
)
self.input_shape = input_shape
self.patch_shape = patch_shape
self.grid_shape, self.n_patches = self._count_patches(self.input_shape)
self.n_channels = n_channels
self.embed_dim = embed_dim
conv_class = _Conv_dim_dict[len(input_shape)]
kernel_size = tuple(patch_shape[dim] for dim in input_shape.keys())
self.proj = conv_class(
n_channels,
embed_dim,
kernel_size=kernel_size,
stride=kernel_size,
bias=bias,
**conv_kwargs
)
if norm_layer is None:
self.norm = torch.nn.Identity()
else:
self.norm = norm_layer(embed_dim)
self.reconstruct_dim = reconstruct_dim
self.reconstruct_layer = torch.nn.Linear(
reconstruct_dim, prod(kernel_size) * n_channels, bias=True
)
def _count_patches(self, input_shape: dict[Any, int]) -> tuple[dict[Any, int], int]:
grid_shape = {dim: s_i // self.patch_shape[dim] for dim, s_i in input_shape.items()}
n_patches = prod(grid_shape.values())
return grid_shape, n_patches
def _define_position_embedding(self, embed_dim: int) -> torch.nn.Parameter:
"""Create a sin-cos embedding of positions on an n-dimensional grid.
Args:
embed_dim (int): Number of dimensions into which to embed grid positions.
Returns:
torch.Parameter: Array with each row the sin-cos embedding of a grid position.
"""
# Divide sin and cos embeddings according to size of grid in each dimension
embed_dims = partition_embed_dim(embed_dim, dim_ratio=list(self.grid_shape.values()))
# Define grid: len(grid_shape)-tuple of np.arrays of shape grid_shape
grid = torch.meshgrid([
torch.arange(grid_s, dtype=torch.float32) for grid_s in self.grid_shape.values()
], indexing='ij')
# Create sincos embedding: np.array of shape (prod(grid_shape), 2 * sum(embed_dims))
omegas = [calc_embed_omega(dim) for dim in embed_dims]
embedding = torch.cat([
sincos_embed_coords(coords.flatten(), omega) for coords, omega in zip(grid, omegas)
], dim=1)
return torch.nn.Parameter(embedding.float().unsqueeze(0), requires_grad=False)
[docs]
def initialize_weights(self):
"""Set up embedder weights and parameters.
"""
# define fixed sin-cos embeddings of patch position within image
self.pos_embed = self._define_position_embedding(self.embed_dim)
self.decoder_pos_embed = self._define_position_embedding(self.reconstruct_dim)
# initialize projection like nn.Linear (instead of nn.Conv2d)
w = self.proj.weight.data
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
[docs]
def tokenify(self, inputs: torch.Tensor) -> torch.Tensor:
"""Reshape batched input n-D images into sequences of patch tokens.
Shapes are given for the example of a batch of B, C-channel 2D images,
with image size (Hh, Ww) and patch size (h, w).
Args:
inputs (torch.Tensor): Batched sequence of images, shape e.g. [B, C, Hh, Ww]
Returns:
torch.Tensor: Batched sequences of patch tokens, shape e.g. [B, HW, hwC]
"""
input_dims, n_dim = self.input_shape.keys(), len(self.input_shape)
if len(inputs.shape) - 2 != n_dim:
raise ValueError(
f"{len(inputs.shape) - 2}-D input not divisible into {n_dim}-D patches"
)
for dim, s_i in zip(input_dims, inputs.shape[2:]):
s_p = self.patch_shape[dim]
if s_i % s_p != 0:
raise ValueError(
f"Input dimension {dim} not divisible into patches ({s_i} % {s_p} != 0)"
)
input_shape = dict(zip(input_dims, inputs.shape[2:]))
grid_shape, n_patches = self._count_patches(input_shape)
x = inputs.reshape(shape=(
(inputs.shape[0], self.n_channels)
+ sum([(grid_shape[dim], self.patch_shape[dim]) for dim in input_dims], ())
))
# (B, C, H, h, W, w) for 2D patches
x = x.permute(
0, *range(2, 2 + 2 * n_dim, 2), *range(3, 3 + 2 * n_dim, 2), 1
)
# (B, H, W, h, w, C) for 2D patches
patch_size = prod(self.patch_shape.values())
x = x.reshape(shape=(inputs.shape[0], n_patches, patch_size * self.n_channels))
# (B, H W, h w C) for 2D patches
return x
[docs]
def untokenify(self, x: torch.Tensor,
output_shape: dict[Any, int] | None = None) -> torch.Tensor:
"""Reshape batched sequences of patch tokens into n-D images.
Shapes are given for the example of a batch of B, C-channel 2D images,
with image size (Hh, Ww) and patch size (h, w).
Args:
x (torch.Tensor): Batched sequences of tokens, shape e.g. [B, HW, hwC]
output_shape (dict[Any, int], optional): Sizes of output dimensions required,
if different from input_shape in class instantiation.
Returns:
torch.Tensor: Batched sequence of images, shape e.g. [B, C, Hh, Ww]
"""
input_dims, n_dim = self.input_shape.keys(), len(self.input_shape)
patch_shape = tuple(self.patch_shape[dim] for dim in input_dims)
if output_shape is None:
grid_shape, n_patches = self.grid_shape, self.n_patches
else:
if len(output_shape) != n_dim:
raise ValueError(
f"{len(output_shape)}-D output not formable from {n_dim}-D patches"
)
if self.input_shape.keys() != output_shape.keys():
raise ValueError(
f"output_shape dims {output_shape.keys()}"
f" must equal input_shape dims {self.input_shape.keys()}."
)
output_shape = {dim: output_shape[dim] for dim in input_dims}
grid_shape, n_patches = self._count_patches(output_shape)
grid_shape = tuple(grid_shape[dim] for dim in input_dims)
if x.shape[1] != n_patches:
raise ValueError(
f"Grid shape {grid_shape} not formable from {x.shape[1]} patches"
)
if x.shape[2] != prod(patch_shape) * self.n_channels:
raise ValueError(
f"{self.n_channels}-channel {patch_shape} patch not formable"
f" from {x.shape[2]} values"
)
x = x.reshape(shape=(x.shape[0],) + grid_shape + patch_shape + (self.n_channels,))
# (N, H, W, h, w, C) for 2D patches
x = x.permute(
0, -1, *sum([(id, n_dim + id) for id in range(1, 1 + n_dim)], ())
)
# (N, C, H, h, W, w) for 2D patches
imgs = x.reshape(shape=(
(x.shape[0], self.n_channels)
+ sum([(s_g * s_p,) for s_g, s_p in zip(grid_shape, patch_shape)], ())
))
# (N, C, H h, W w) for 2D patches
return imgs
[docs]
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Convert batched input n-D images into sequences of patch embeddings.
Shapes are given for the example of a batch of B, C-channel 2D images,
with image size (Hh, Ww) and patch size (h, w), with a D-D embedding.
Args:
x (torch.Tensor): Batched sequence of images, shape e.g. [B, C, Hh, Ww].
Returns:
tuple[torch.Tensor, torch.Tensor]:
* Batched sequences of embeddings, shape e.g. [B, HW, D]
* Batched sequences of positions, shape e.g. [B, HW, D]
"""
for (dim, s_expected), s_actual in zip(self.input_shape.items(), x.shape[2:]):
if s_actual != s_expected:
raise ValueError(
f"Input dimension {dim} ({s_actual}) doesn't match specification"
f" ({s_expected})"
)
x = self.proj(x) # Project each patch into embedding, by convolution
x = x.flatten(start_dim=2).transpose(1, 2) # (B, D, H, W) -> (B, HW, D) for 2D patches
x = self.norm(x)
x = x + self.pos_embed
return x, self.decoder_pos_embed.expand(x.shape[0], -1, -1)
[docs]
def reconstruct_tokens(self, x: torch.Tensor) -> torch.Tensor:
"""Reconstruct patch tokens from an embedding.
Args:
x (torch.Tensor): Batched sequence of images, shape e.g. [B, C, Hh, Ww].
Returns:
tuple[torch.Tensor, torch.Tensor]:
"""
return self.reconstruct_layer(x)
[docs]
class STPatchEmbed(PatchEmbed):
""" 1-3D spatiotemporally located input to Patch Embedding
"""
def __init__(self,
input_shape: dict[Any, int],
patch_shape: dict[Any, int],
n_channels: int,
position_space: dict[Any, tuple[float, float]],
embed_dim: int,
reconstruct_dim: int,
pos_embed_ratio: dict[Any, float],
bias: bool = True,
norm_layer: torch.nn.Module | None = None,
**conv_kwargs):
"""Instantiate embedder for patches in 1-3D, n-channel images.
Args:
input_shape (dict[Any, int]): Sizes of input dimensions.
patch_shape (dict[Any, int]): Shape of patches into which input will be divided.
n_channels (int): Number of channels recorded in image.
position_space (dict[Any, tuple[float, float]]): Space in which pixels are positioned.
For example, for coords x (X-periodic) and y (aperiodic with minimum Y, range ΔY):
{'x': (0, X), 'y': (Y, Y + 2ΔY)}
embed_dim (int): Number of dimensions into which to embed each patch.
reconstruct_dim (int): Number of embedding dimensions from which to reconstruct patch.
pos_embed_ratio (dict[Any, float]): Relative sizes of position embedding dimensions.
bias (bool, optional): Whether to include bias in patch embedding. Defaults to True.
norm_layer (torch.nn.Module | None, optional): Layer with which to normalise embedding.
Defaults to None: no normalisation.
**conv_kwargs: Keyword arguments to pass for convolution layer instantiation.
"""
super().__init__(
input_shape=input_shape,
patch_shape=patch_shape,
n_channels=n_channels,
embed_dim=embed_dim,
reconstruct_dim=reconstruct_dim,
bias=bias,
norm_layer=norm_layer,
**conv_kwargs,
)
self.position_space = position_space
self.pos_embed_ratio = [pos_embed_ratio[dim] for dim in self.position_space.keys()]
self.encoder_omegas = torch.nn.ParameterList()
self.decoder_omegas = torch.nn.ParameterList()
def _define_position_embedding(self, embed_dim: int, omegas: torch.nn.ParameterList
) -> Callable[[torch.Tensor], torch.Tensor]:
# Divide sin and cos embeddings according to size of grid in each dimension
embed_dims = partition_embed_dim(embed_dim, dim_ratio=self.pos_embed_ratio)
st_dim = len(self.position_space)
conv_fn = _conv_dim_dict[len(self.input_shape)]
kernel_shape = tuple(self.patch_shape[dim] for dim in self.input_shape.keys())
self.pos_conv_kernel = torch.nn.Parameter(
torch.ones((st_dim, 1) + kernel_shape) / prod(kernel_shape), requires_grad=False,
)
omegas.extend([
torch.nn.Parameter(calc_embed_omega(dim, period=x_max - x_min), requires_grad=False)
for dim, (x_min, x_max) in zip(embed_dims, self.position_space)
])
def embedding(pos: torch.Tensor) -> torch.Tensor:
# pos shape B, 2, Hh, Ww for 2D patches
batch_size = pos.shape[0]
if pos.shape[1] != len(self.position_space):
raise ValueError(
f"{pos.shape[1]}-D position space doesn't match spec."
f" ({len(self.position_space)})"
)
for (dim, s_spec), s_pos in zip(self.input_shape.items(), pos.shape[2:]):
if s_pos != s_spec:
raise ValueError(
f"Input positions dimension {dim} ({s_pos}) doesn't match spec. ({s_spec})"
)
pos = conv_fn(
pos, self.pos_conv_kernel, stride=kernel_shape, groups=st_dim
) # B, 2, H, W for 2D patches
pos = pos.transpose(0, 1).flatten(start_dim=1) # 2, BHW for 2D patches
embeddings = torch.cat([
sincos_embed_coords(coords, omega) for coords, omega in zip(pos, omegas)
], dim=1) # BHW, D for 2D patches
return embeddings.reshape([batch_size, -1, embed_dim]) # B, HW, D for 2D patches
return embedding
[docs]
def initialize_weights(self):
"""Set up embedder weights and parameters.
"""
# define fixed sin-cos embedding functions for patch positions
self.pos_embed = self._define_position_embedding(self.embed_dim, self.encoder_omegas)
self.decoder_pos_embed = self._define_position_embedding(self.reconstruct_dim, self.decoder_omegas)
# initialize projection like nn.Linear (instead of nn.Conv2d)
w = self.proj.weight.data
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
[docs]
def forward(self, st_input: tuple[torch.Tensor, torch.Tensor]
) -> tuple[torch.Tensor, torch.Tensor]:
"""Convert batched spatiotemporal inputs into sequences of patch embeddings.
Shapes are given for the example of a batch of B, C-channel 2D inputs in N-D space,
with image size (Hh, Ww) and patch size (h, w), with a D-D embedding.
Args:
st_input (tuple[torch.Tensor, torch.Tensor]): Spatiotemporal input, combining:
* Batched sequence of values, shape e.g. [B, C, Hh, Ww].
* Batched sequence of positions, shape e.g. [B, N, Hh, Ww].
Returns:
tuple[torch.Tensor, torch.Tensor]:
* Batched sequences of embeddings, shape e.g. [B, HW, D]
* Batched sequences of position embeddings, shape e.g. [B, HW, D_d]
"""
x, pos = st_input
if x.shape[1] != self.n_channels:
raise ValueError(
f"# of input channels ({x.shape[1]}) doesn't match spec. ({self.n_channels})"
)
for (dim, s_spec), s_x in zip(self.input_shape.items(), x.shape[2:]):
if s_x != s_spec:
raise ValueError(
f"Input values dimension {dim} ({s_x}) doesn't match spec. ({s_spec})"
)
x = self.proj(x) # Project each patch into embedding, by convolution
x = x.flatten(start_dim=2).transpose(1, 2) # (B, D, H, W) -> (B, HW, D) for 2D patches
x = self.norm(x)
x = x + self.pos_embed(pos)
return x, self.decoder_pos_embed(pos)
[docs]
def tokenify(self, inputs: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
"""Reshape batched n-D input values into sequences of patch tokens.
Shapes are given for the example of a batch of B, C-channel 2D images,
with image size (Hh, Ww) and patch size (h, w).
Args:
st_input (tuple[torch.Tensor, torch.Tensor]): Spatiotemporal input, combining:
* Batched sequence of values, shape e.g. [B, C, Hh, Ww].
* Batched sequence of positions, shape e.g. [B, N, Hh, Ww].
Returns:
torch.Tensor: Batched sequences of patch tokens, shape e.g. [B, HW, hwC]
"""
return PatchEmbed.tokenify(self, inputs[0])
[docs]
class BoundedPatchEmbed(STPatchEmbed):
def _define_position_embedding(self, embed_dim: int, omegas: torch.nn.ParameterList
) -> Callable[[torch.Tensor], torch.Tensor]:
# Divide sin and cos embeddings according to size of grid in each dimension
embed_dims = partition_embed_dim(embed_dim, dim_ratio=self.pos_embed_ratio)
self.patch_grid_coords = torch.nn.ParameterList([
torch.nn.Parameter(
torch.arange(0.5, self.grid_shape[dim]) / self.grid_shape[dim]
if dim in self.grid_shape else torch.tensor([0.5]),
requires_grad=False
)
for dim in self.position_space.keys()
])
omegas.extend([
torch.nn.Parameter(calc_embed_omega(ndim, period=x_max - x_min), requires_grad=False)
for ndim, (x_min, x_max) in zip(embed_dims, self.position_space.values())
])
def embedding(bounds_batch: torch.Tensor) -> torch.Tensor:
# [[[bottom, top], [left, right]], ...] for 2D patches
batch_size, ndim = bounds_batch.shape[0], len(self.position_space)
if bounds_batch.shape[1] != ndim:
raise ValueError(
f"{bounds_batch.shape[1]}-D position space doesn't match {ndim}-D spec."
)
if bounds_batch.shape[2] != 2:
raise ValueError(
f"Input position bounds (e.g. {bounds_batch[0, 0]}) must be 2-element ranges."
)
pos = torch.cat([
torch.cartesian_prod(*[
b_min + (b_max - b_min) * patch_coords
for (b_min, b_max), patch_coords in zip(bounds, self.patch_grid_coords)
])
for bounds in bounds_batch
], dim=0).transpose(0, 1) # 2, BHW for 2D patches
embeddings = torch.cat([
sincos_embed_coords(coords, omega) for coords, omega in zip(pos, omegas)
], dim=1) # BHW, D for 2D patches
return embeddings.reshape([batch_size, -1, embed_dim]) # B, HW, D for 2D patches
return embedding