# Copyright (c) Matt Arran.
# This source code is adapted from code (c) Meta Platforms, Inc. and affiliates,
# licensed under the licence in the licences/LICENSE_MAE.txt file.
# --------------------------------------------------------
# References:
# MAE: https://github.com/facebookresearch/mae/tree/main
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# --------------------------------------------------------
from timm.models.vision_transformer import Block
import torch
from torch import nn
from ..utils.embedders import BaseEmbedder, PatchEmbed, STPatchEmbed, BoundedPatchEmbed
from ..utils.LightningModuleWrapper import BaseModule
def _select_embedder(input_shape: tuple[int, ...],
n_channel: int,
patch_shape: tuple[int, ...],
positioned: str = "",
pos_space: tuple[tuple[float, float], ...] | None = None,
embed_ratio: tuple[float, ...] | None = None,
embed_dim: int = 16,
reconstruct_dim: int = 16) -> BaseEmbedder:
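"""Select a patch embedder for a single input, according to how the input's positions are given.
Args:
input_shape (tuple[int, ...]): Shape of the model input.
n_channel (int): Number of channels in the model input.
patch_shape (tuple[int, ...]): Size of the patches into which to divide the input.
positioned (str): How positions are provided: "" (none), "pixels", or "bounds". Defaults to "".
pos_space (tuple[tuple[float, float], ...] | None): Space in which positions lie,
or None if the input has no positions. Defaults to None.
embed_ratio (tuple[float, ...] | None): Relative sizes of the position embedding dimensions,
or None if the input has no positions. Defaults to None.
embed_dim (int): Dimensions into which to embed each patch. Defaults to 16.
reconstruct_dim (int): Dimensions from which to reconstruct each patch. Defaults to 16.
Returns:
BaseEmbedder: A PatchEmbed, STPatchEmbed, or BoundedPatchEmbed, as appropriate.
Raises:
ValueError: If positioned is not "", "pixels", or "bounds".
"""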
if not positioned:
return PatchEmbed(
input_shape, patch_shape, n_channel, embed_dim, reconstruct_dim
)
assert pos_space is not None, \
f"If inputs of shape {input_shape} have positions, position_space must not be None."
assert embed_ratio is not None, \
f"If inputs of shape {input_shape} have positions, pos_embed_ratio must not be None."
if positioned == "pixels":
return STPatchEmbed(
input_shape, patch_shape, n_channel, pos_space, embed_dim, reconstruct_dim, embed_ratio
)
if positioned == "bounds":
return BoundedPatchEmbed(
input_shape, patch_shape, n_channel, pos_space, embed_dim, reconstruct_dim, embed_ratio
)
raise ValueError(f"Position specification ({positioned}) must be '', 'pixels', or 'bounds'.")
class MultimodalMaskedAutoencoder(BaseModule):
"""Masked Autoencoder with flexible multi-input embeddings and transformer backbone
"""
def __init__(self,
input_shapes: list[tuple[int, ...]],
n_channels: list[int],
patch_shapes: list[tuple[int, ...]],
inputs_positioned: list[str] | str = "",
position_space: tuple[tuple[float, float], ...] | None = None,
pos_embed_ratio: tuple[float, ...] | None = None,
encoder_embed_dim: int = 16,
encoder_depth: int = 24,
encoder_num_heads: int = 16,
decoder_embed_dim: int = 16,
decoder_depth: int = 8,
decoder_num_heads: int = 16,
mlp_ratio: float = 4.,
norm_layer: type[nn.LayerNorm] = nn.LayerNorm,
norm_token_loss: bool = False,
learning_rate: float = 1.e-3,
default_mask_ratio: float = 0.75):
"""Instantiate Multimodal Masked Autoencoder
Args:
input_shapes (list[tuple[int, ...]]): Shapes of each model input.
n_channels (list[int]): Numbers of channels in each model input.
patch_shapes (list[tuple[int, ...]]): Sizes of patches into which to divide each input.
inputs_positioned (list[str] | str): How positions of model inputs are provided:
"": no position data, "pixels": pixel coordinates, "bounds": coordinate bounds.
Any single string is taken to apply to all inputs. Defaults to "".
position_space (tuple[tuple[float, float], ...] | None): Space in which positions lie,
or None if no input has positions. Defaults to None.
pos_embed_ratio (tuple[float, ...] | None): Relative sizes of the position embedding dimensions,
or None if no input has positions. Defaults to None.
encoder_embed_dim (int): Dimensions into which to embed each patch. Defaults to 16.
encoder_depth (int, optional): Number of attention layers for encoding. Defaults to 24.
encoder_num_heads (int, optional): Number of attention heads per layer. Defaults to 16.
decoder_embed_dim (int): Dimensions from which to reconstruct each patch. Defaults to 16.
decoder_depth (int, optional): Number of attention layers for decoding. Defaults to 8.
decoder_num_heads (int, optional): Number of attention heads per layer. Defaults to 16.
mlp_ratio (float, optional): Ratio of MLP and embedding dimensions in attention blocks.
Defaults to 4.0.
norm_layer (type[nn.LayerNorm], optional): Layer class for [en/de]coder normalisation.
Defaults to nn.LayerNorm.
norm_token_loss (bool, optional): Whether to variance-normalise per-token loss.
Defaults to False.
learning_rate (float): Initial learning rate for Adam optimizer. Defaults to 1.e-3.
default_mask_ratio (float): Default proportion of token embeddings to mask per batch. Defaults to 0.75.
"""
super().__init__()
# --------------------------------------------------------------------------
if isinstance(inputs_positioned, str):
inputs_positioned = [inputs_positioned for _ in input_shapes]
input_properties = zip(input_shapes, patch_shapes, n_channels, inputs_positioned)
self.input_embedders = nn.ModuleList([
_select_embedder(
input_shape,
n_channel,
patch_shape,
positioned,
position_space,
pos_embed_ratio,
encoder_embed_dim,
decoder_embed_dim
)
for input_shape, patch_shape, n_channel, positioned in input_properties
]) # type: ignore
self.cls_token = nn.Parameter(torch.zeros(1, 1, encoder_embed_dim))
self.ec_mode_embeddings = nn.ParameterList(
[nn.Parameter(torch.zeros(1, 1, encoder_embed_dim)) for _ in input_shapes]
)  # one distinct, learnable mode embedding per input
self.blocks = nn.ModuleList([
Block(encoder_embed_dim, encoder_num_heads, mlp_ratio,
qkv_bias=True, norm_layer=norm_layer)
for _ in range(encoder_depth)
])
self.norm = norm_layer(encoder_embed_dim)
self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
self.dc_mode_embeddings = nn.ParameterList(
[nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) for _ in input_shapes]
)  # one distinct, learnable mode embedding per input
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
self.decoder_blocks = nn.ModuleList([
Block(decoder_embed_dim, decoder_num_heads, mlp_ratio,
qkv_bias=True, norm_layer=norm_layer)
for _ in range(decoder_depth)
])
self.decoder_norm = norm_layer(decoder_embed_dim)
# --------------------------------------------------------------------------
self.norm_token_loss = norm_token_loss
self.learning_rate = learning_rate
self.default_mask_ratio = default_mask_ratio
self.initialize_weights()
def initialize_weights(self):
"""Initialise layer weights and parameters, including in input embedders.
"""
# initialization
input_iter = zip(self.input_embedders, self.ec_mode_embeddings, self.dc_mode_embeddings)
for input_embedder, ec_mode_embedding, dc_mode_embedding in input_iter:
input_embedder.initialize_weights()
nn.init.normal_(ec_mode_embedding, std=.02)
nn.init.normal_(dc_mode_embedding, std=.02)
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
nn.init.normal_(self.cls_token, std=.02)
nn.init.normal_(self.mask_token, std=.02)
# initialize nn.Linear and nn.LayerNorm
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
# we use xavier_uniform following official JAX ViT:
nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def random_masking(self, x: torch.Tensor, mask_ratio: float
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Shuffle batched token embeddings and mask random selection.
Args:
x (torch.Tensor): Batched token embeddings, shape [B, L, D].
mask_ratio (float): Proportion p of token embeddings to mask per batch.
Returns:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
* Randomly selected token embeddings, shape [B, pL, D].
* Mask with 0 where token extracted, 1 otherwise, shape [B, L].
* IDs with which to restore original, unshuffled token embeddings.
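Example (illustrative shapes, with B=8, L=196, D=16 and mask_ratio=0.75):
x_masked, mask, ids_restore = model.random_masking(x, mask_ratio=0.75)
# x: [8, 196, 16] -> x_masked: [8, 49, 16]; mask and ids_restore: [8, 196]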
"""
B, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))
noise = torch.rand(B, L, device=x.device) # noise in [0, 1]
# sort noise for each sample
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([B, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore
def forward_encoder(self, inputs: list, mask_ratio: float
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Tokenise and embed inputs, randomly mask tokens, and encode using a transformer.
Args:
inputs (list): Batched model inputs, for conversion by input_embedders into
token and position embeddings of shapes ([B, L_i, D])_i and ([B, L_i, D_d])_i.
mask_ratio (float): Proportion p of token embeddings to mask per batch.
Returns:
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
* Encodings of randomly selected input embeddings, shape [B, 1 + (1-p)sum(L_i), D].
* Batched mode and position embeddings for the decoder, shape [B, sum(L_i), D_d].
* Mask with 0 where token extracted, 1 otherwise, shape [B, sum(L_i)].
* IDs with which to restore original, unshuffled token embeddings, shape [B, sum(L_i)].
"""
# embed inputs and positions
x, metadata_embed = [], []
iiter = zip(inputs, self.input_embedders, self.ec_mode_embeddings, self.dc_mode_embeddings)
for inpt, embedder, ec_mode_embedding, dc_mode_embedding in iiter:
input_embedding, decoder_pos_embedding = embedder(inpt)
B, L_i = input_embedding.shape[:2]
x.append(input_embedding + ec_mode_embedding.expand(B, L_i, -1))
metadata_embed.append(decoder_pos_embedding + dc_mode_embedding.expand(B, L_i, -1))
x = torch.cat(x, dim=1)
metadata_embed = torch.cat(metadata_embed, dim=1)
# masking: length -> length * mask_ratio
x, mask, ids_restore = self.random_masking(x, mask_ratio)
# append cls token
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
# apply Transformer blocks
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x, metadata_embed, mask, ids_restore
def forward_decoder(self, x: torch.Tensor, ids_restore: torch.Tensor,
metadata_embed: torch.Tensor) -> list[torch.Tensor]:
"""Transform encoding of masked inputs, decode using a transformer, and reconstruct tokens.
Args:
x (torch.Tensor): Encodings of shuffled, masked tokens, shape [B, 1 + (1-p)L, D].
ids_restore (torch.Tensor): IDs with which to restore original, unshuffled encodings,
shape [B, L].
metadata_embed (torch.Tensor): Encodings of input mode and positions, shape [B, L, D_d].
Returns:
list[torch.Tensor]: Decoded tokens for each input, as reconstructed by input_embedders,
shapes ([B, L_i, D_i])_i with sum(L_i) = L.
"""
# embed latent representation in decoder space
x = self.decoder_embed(x)
# append mask tokens to sequence, excluding cls token
mask_tokens = self.mask_token.expand(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], -1)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
# unshuffle
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
# add position embedding
x_ = x_ + metadata_embed
# append cls token
x = torch.cat([x[:, :1, :], x_], dim=1)
# apply Transformer blocks
for blk in self.decoder_blocks:
x = blk(x)
x = self.decoder_norm(x)
# predictor projection
preds, start_patch = [], 1
for embedder in self.input_embedders:
end_patch = start_patch + embedder.n_patches
preds.append(embedder.reconstruct_tokens(x[:, start_patch:end_patch]))
start_patch = end_patch
return preds
def forward_loss(self, inputs: list, predictions: list[torch.Tensor],
mask: torch.Tensor) -> torch.Tensor:
"""Calculate masked-token MSE between batched inputs and model predictions.
Args:
inputs (list): Batched model inputs, for conversion to tokens by input_embedders.
predictions (list[torch.Tensor]): Model predictions, shapes ([B, L_i, D_i])_i.
mask (torch.Tensor): Mask with 1 where token masked, shape [B, sum(L_i)].
Returns:
torch.Tensor: Average mean squared error over the batch, shape [1].
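Note:
Each token's squared reconstruction error is averaged over its values, optionally divided by
the variance of the token's target values when norm_token_loss is set; the result is then
averaged over masked tokens only, so unmasked tokens do not contribute to the loss.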
"""
losses = []
for ie, inpt, prediction in zip(self.input_embedders, inputs, predictions):
target = ie.tokenify(inpt)
norm = target.var(dim=-1, keepdim=True) if self.norm_token_loss else 1  # per-token variance
loss = (prediction - target) ** 2 / norm
losses.append(loss.mean(dim=-1))
loss = torch.concat(losses, dim=1)
loss = (loss * mask).sum() / mask.sum()
return loss
def forward(self, inputs: list, mask_ratio: float = 0.75
) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]:
"""Apply MMMAE to inputs and return the loss, predictions, and mask.
Args:
inputs (list): Batched model inputs.
mask_ratio (float, optional): Proportion of token embeddings to mask per batch.
Defaults to 0.75.
Returns:
tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]:
* Mean squared error of model predictions, over masked tokens, shape [1].
* Model predictions of input tokens, shapes ([B, L_i, D_i])_i.
* Mask with 0 where token extracted, 1 otherwise, shape [B, sum(L_i)].
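Example (a minimal sketch; inputs is a list with one batched input per modality, in the
format expected by the corresponding input embedder):
loss, preds, mask = model(inputs, mask_ratio=0.75)
loss.backward()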
"""
latent, metadata_embed, mask, ids_restore = self.forward_encoder(inputs, mask_ratio)
preds = self.forward_decoder(latent, ids_restore, metadata_embed)
loss = self.forward_loss(inputs, preds, mask)
return loss, preds, mask
def _sharedStep(self, inputs):
loss, _, _ = self(inputs, mask_ratio=self.default_mask_ratio)
return loss, {}
def training_step_body(self, batch, batch_idx):
return self._sharedStep(batch)
def validation_step_body(self, batch, batch_idx):
return self._sharedStep(batch)
def test_step_body(self, batch, batch_idx):
return self._sharedStep(batch)