# Source code for src.FRAME_FM.models.demo_simple_conv

from typing import Sequence
from FRAME_FM.utils.embedders import BaseEmbedder, PatchEmbed
from FRAME_FM.utils.LightningModuleWrapper import BaseModule


class MyModel(BaseModule):
    """Empty model stub: inherits all behavior unchanged from BaseModule.

    NOTE(review): BaseModule is the project's Lightning wrapper
    (FRAME_FM.utils.LightningModuleWrapper); this subclass adds nothing yet.
    """
    pass
import numpy as np import torch import torch.nn as nn import torch.nn.functional as F # import torch.multiprocessing as mp # mp.set_start_method("spawn", force=True)
[docs] class SpatialCollapse(nn.Module): """ Input: [B, V, H, W] Output: [B] """ def __init__(self, in_variables, hidden_channels=32, dropout=0.2): super().__init__() # --- Conv stack ---
[docs] self.conv1 = nn.Conv2d(in_variables, hidden_channels, kernel_size=3, padding=1)
[docs] self.bn1 = nn.BatchNorm2d(hidden_channels)
[docs] self.conv2 = nn.Conv2d(hidden_channels, hidden_channels * 2, kernel_size=3, padding=1)
[docs] self.bn2 = nn.BatchNorm2d(hidden_channels * 2)
[docs] self.conv3 = nn.Conv2d(hidden_channels * 2, hidden_channels * 4, kernel_size=3, padding=1)
[docs] self.bn3 = nn.BatchNorm2d(hidden_channels * 4)
[docs] self.dropout = nn.Dropout2d(dropout)
[docs] self.pool = nn.AdaptiveAvgPool2d((1, 1))
[docs] self.head = nn.Linear(hidden_channels * 4, 1)
self._init_weights() def _init_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, nonlinearity="relu") if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.Linear): nn.init.kaiming_normal_(m.weight, nonlinearity="relu") nn.init.zeros_(m.bias)
[docs] def forward(self, x): # x: [B, V, H, W] x = F.relu(self.bn1(self.conv1(x))) x = self.dropout(x) x = F.relu(self.bn2(self.conv2(x))) x = self.dropout(x) x = F.relu(self.bn3(self.conv3(x))) x = self.pool(x) # [B, C, 1, 1] x = x.view(x.size(0), -1) # [B, C] x = self.head(x) # [B, 1] return x.squeeze(-1) # [B]
from pathlib import Path import xarray as xr import pandas as pd
# Zipped Kerchunk reference file used by the demo script below; from the
# filename this is presumably an ECMWF ERA5 single-level (2 m) test fixture —
# path is relative to the repository root, so run the demo from there.
KERCHUNK_ZIP = "tests/transforms/fixtures/ecmwf-era5X_oper_an_sfc_2000_2020_2d_repack.kr1.0.json.zip"
from FRAME_FM.utils.data_utils import load_data_from_uri
from FRAME_FM.transforms import transform_mapping, ToTensorTransform

if __name__ == "__main__":
    n_variables = 5

    # Smoke-test the model on random input.
    model = SpatialCollapse(in_variables=n_variables)
    x = torch.randn(16, n_variables, 128, 128)
    y = model(x)
    print(y.shape)

    # Read the zipped Kerchunk file, apply the transforms, and run a basic
    # PyTorch training loop with the result.
    ds = load_data_from_uri(KERCHUNK_ZIP)
    transforms = [
        {"type": "reverse_axis", "dim": "latitude"},
        {
            "type": "subset",
            "variables": "d2m",
            "time": ("2001-01-01", "2001-01-01T00:05:00"),
            "latitude": (60, -30),
            "longitude": (40, 160),
        },
    ]
    # NOTE(review): transforms are applied in REVERSE declaration order,
    # matching the original behavior — confirm this is intended.
    for spec in reversed(transforms):
        if spec["type"] not in transform_mapping:
            raise ValueError(f"Unsupported transform type: {spec['type']}")
        transform_class = transform_mapping[spec["type"]]
        transform = transform_class(**{k: v for k, v in spec.items() if k != "type"})
        ds = transform(ds)

    # Per-time-step spatial mean and std of the d2m variable.
    mean = ds["d2m"].mean(dim=("latitude", "longitude"))
    std = ds["d2m"].std(dim=("latitude", "longitude"))

    # Write them to a CSV file via pandas.
    csv_file = "d2m_mean_std.csv"
    pd.DataFrame(
        {"time": ds.time.values, "mean": mean.values, "std": std.values}
    ).to_csv(csv_file, index=False)
    print(f"Saved mean and std to {csv_file}")

    # Duplicate the single variable n_variables times to fake a "variable"
    # dimension for the model (np.repeat replaces the manual append loop).
    data = np.repeat(ds["d2m"].values[None, ...], n_variables, axis=0)  # [V, T, lat, lon]
    data = data.transpose(1, 0, 2, 3)  # [T, V, lat, lon]
    data = ToTensorTransform()(data)  # convert to a PyTorch tensor
    print(data.shape)  # [T, V, lat, lon]

    model = SpatialCollapse(in_variables=n_variables)
    x = data[:16]  # batch of up to 16 time steps
    # BUG FIX: the target must not carry an autograd graph. Previously
    # y = model(x) tracked gradients, the first loss.backward() freed y's
    # graph, and the second epoch's backward() raised
    # "Trying to backward through the graph a second time".
    with torch.no_grad():
        y = model(x)
    print(y.shape)

    # 5-epoch training loop fitting the model to its own initial predictions.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()
    for epoch in range(5):
        model.train()
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

    # Save the trained model weights.
    torch.save(model.state_dict(), "spatial_collapse_model.pth")