Transforms

The “transforms” directory holds all transformation classes and the utilities used to integrate them into the FRAME-FM framework.

Overview of Transforms

PyTorch Dataset classes typically employ a sequence of transform objects that modify input data into a form suitable for model training or inference.

This usually happens in the __getitem__() method, but it may also be appropriate to use transforms elsewhere, such as in the __init__() method when initial modifications are required.
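The difference between the two places can be sketched with a small, dependency-free example. Everything in it is hypothetical: `ToyDataset`, `init_transforms` and `item_transforms` are illustrative names, not part of FRAME_FM, and a real dataset would normally subclass `torch.utils.data.Dataset` rather than a plain class.

```python
class ToyDataset:
    """Map-style dataset sketch (hypothetical; not a FRAME_FM class)."""

    def __init__(self, records, init_transforms=(), item_transforms=()):
        # One-off modifications are applied once, when the dataset is built.
        records = list(records)
        for transform in init_transforms:
            records = [transform(record) for record in records]
        self.records = records
        self.item_transforms = list(item_transforms)

    def __len__(self):
        return len(self.records)

    def __getitem__(self, index):
        # Per-sample transforms run every time a sample is requested.
        sample = self.records[index]
        for transform in self.item_transforms:
            sample = transform(sample)
        return sample


dataset = ToyDataset(
    records=[1.0, 2.0, 3.0],
    init_transforms=[lambda x: x * 10.0],  # applied once in __init__
    item_transforms=[lambda x: x + 1.0],   # applied on each access
)
first_sample = dataset[0]
```

Transforms placed in `__init__()` pay their cost once, while those in `__getitem__()` are re-run on every access, which suits random augmentations or steps too memory-hungry to precompute.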

Within the FRAME_FM package, transforms are all children of the FRAME_FM.transforms.BaseTransform class. They are typically run as follows:

import xarray as xr

from FRAME_FM.transforms import StandardizeTransform

ds = xr.load_dataset(<some_dataset>)
da = ds["d2m"].isel(time=0)

standardize_transform = StandardizeTransform()
standardized_da = standardize_transform(
    da,
    mean=float(da.mean()),
    std=float(da.std())
)

Using pre-existing Transforms

Other PyTorch transforms, such as those in the torchvision library, can be included in a sequence of transforms alongside FRAME_FM transforms.
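A minimal sketch of such a sequence follows. To stay dependency-free it defines its own `Compose`, which chains callables in the same way `torchvision.transforms.Compose` does. `ScaleTransform` and `clip_to_unit` are hypothetical stand-ins for an in-house transform and a third-party one; whether a given torchvision transform accepts the data type your FRAME_FM transforms produce is something to check case by case.

```python
class Compose:
    """Minimal stand-in that applies each transform in order."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, sample):
        for transform in self.transforms:
            sample = transform(sample)
        return sample


class ScaleTransform:
    """Hypothetical in-house transform: multiplies every value by a factor."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, sample):
        return [value * self.factor for value in sample]


def clip_to_unit(sample):
    """Hypothetical third-party-style transform: clips values to [0, 1]."""
    return [min(max(value, 0.0), 1.0) for value in sample]


# Any callable that takes a sample and returns a sample can join the sequence.
pipeline = Compose([ScaleTransform(0.5), clip_to_unit])
result = pipeline([0.4, 1.0, 4.0])
```

Because the only contract is "callable in, sample out", class-based transforms and plain functions can be mixed freely, provided each step returns a type the next step accepts.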

Defining your own Transforms

You can create your own transforms by inheriting from the BaseTransform class and following this structure:

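import xarray as xr

from FRAME_FM.transforms import BaseTransform
# check_object_type must also be imported from the FRAME_FM utilities
# (its module path is not shown here).

DS = xr.Dataset  # Short alias used in the type hints below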

class Add1Transform(BaseTransform):
    """
    Adds 1 to the variable `var_id` of the dataset.
    """
    def __init__(self, var_id: str):
        # You can define your own input parameters
        self.var_id = var_id

    def __call__(self, sample: DS) -> DS:
        # This must receive a data object, such as an xarray Dataset/DataArray
        check_object_type(sample, allowed_types=DS, caller=self.__class__.__name__)

        # Add 1 to the selected variable and store it back in the dataset
        sample[self.var_id] = sample[self.var_id] + 1
        return sample