# Source code for src.FRAME_FM.transforms.transforms
# Define transforms
import xarray as xr
import cf_xarray # noqa: F401 - We just need to register the accessor for CF-compliant operations on xarray objects
import numpy as np
import torch
from FRAME_FM.utils.transform_utils import check_object_type
from FRAME_FM.utils.data_utils import convert_subset_selectors_to_slices
class BaseTransform:
    """Abstract base class for all transforms.

    Subclasses must implement ``__call__``; the base ``__init__`` accepts and
    ignores any arguments so subclasses are free to define their own signatures.
    """

    def __init__(self, *args, **kwargs):
        pass

    def __call__(self, sample):
        """Apply the transform to ``sample``.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError("Transform must implement the __call__ method.")
class FillMissingValueTransform(BaseTransform):
    """Fill missing (NaN) values in a Dataset or DataArray.

    Args:
        strategy: ``"constant"`` fills with ``fill_value``; ``"interpolate"``
            delegates to ``interpolate_na`` using ``method``.
        fill_value: Value to use when ``strategy == "constant"``.
        method: Interpolation method passed to ``interpolate_na``.
    """

    def __init__(self, strategy: str = "constant", fill_value: None | float = None,
                 method: None | str = "linear"):
        # NOTE(review): __init__ body reconstructed — the extraction dropped it.
        # Attributes are exactly those read in __call__.
        self.strategy = strategy
        self.fill_value = fill_value
        self.method = method

    def __call__(self, sample: DS | DA) -> DS | DA:
        """Return ``sample`` with NaNs filled per the configured strategy.

        Raises:
            ValueError: if the strategy is unknown, or "constant" is used
                without a ``fill_value``.
        """
        # DS / DA are module-level aliases (presumably xr.Dataset / xr.DataArray).
        check_object_type(sample, allowed_types=(DS, DA), caller=self.__class__.__name__)
        if self.strategy == "constant":
            if self.fill_value is None:
                raise ValueError("fill_value must be provided for 'constant' method.")
            filled = sample.fillna(self.fill_value)
        elif self.strategy == "interpolate":
            # NOTE(review): interpolate_na usually requires an explicit dim;
            # dim=None is kept from the original — confirm against the xarray version in use.
            filled = sample.interpolate_na(dim=None, method=self.method)  # type: ignore
        else:
            raise ValueError(f"Unsupported fill strategy: {self.strategy}")
        return filled
class NormalizeTransform(BaseTransform):
    """Standardize a DataArray: ``(sample - mean) / std``."""

    def __init__(self, mean: float, std: float):
        # NOTE(review): __init__ body reconstructed from the attributes
        # read in __call__ (the extraction dropped it).
        self.mean = mean
        self.std = std

    def __call__(self, sample: DA) -> DA:
        """Return the z-score-normalized DataArray."""
        check_object_type(sample, allowed_types=DA, caller=self.__class__.__name__)
        return (sample - self.mean) / self.std
class RenameTransform(BaseTransform):
    """Rename a single data variable of a Dataset."""

    def __init__(self, var_id: str, new_name: str):
        # NOTE(review): __init__ body reconstructed from attribute usage in __call__.
        self.var_id = var_id
        self.new_name = new_name

    def __call__(self, sample: DS) -> DS:
        """Return a Dataset with ``var_id`` renamed to ``new_name``."""
        check_object_type(sample, allowed_types=DS, caller=self.__class__.__name__)
        # rename_vars returns a new Dataset; the original is not mutated.
        sample = sample.rename_vars({self.var_id: self.new_name})
        return sample
class ResampleTransform(BaseTransform):
    """Downsample along one dimension.

    The "time" dimension uses xarray ``resample`` (frequency string);
    any other dimension uses ``coarsen`` (integer window, boundary="trim").

    Args:
        dim: Dimension to resample along.
        freq: Frequency string (time) or integer window size (spatial).
        method: Aggregation method — one of mean/sum/max/min/median.
    """

    def __init__(self, dim: str, freq: str | int, method: str = "mean"):
        # NOTE(review): __init__ body reconstructed from attribute usage in __call__.
        self.dim = dim
        self.freq = freq
        self.method = method

    def __call__(self, sample):
        """Return the aggregated Dataset/DataArray.

        Raises:
            ValueError: if ``method`` is not a supported aggregation.
        """
        check_object_type(sample, allowed_types=(DS, DA), caller=self.__class__.__name__)
        if self.method not in ["mean", "sum", "max", "min", "median"]:
            raise ValueError(f"Unsupported resampling method: {self.method}")
        # Use resample for the time dimension, coarsen for spatial dimensions.
        if self.dim == "time":
            resampled = sample.resample({self.dim: self.freq})
        else:
            resampled = sample.coarsen({self.dim: self.freq}, boundary="trim")
        # The whitelist above guarantees the aggregation method exists on both
        # resample and coarsen objects, so no further hasattr check is needed.
        return getattr(resampled, self.method)()
class ReshapeTransform(BaseTransform):
    """Convert a DataArray to a numpy array with the given shape."""

    def __init__(self, shape: tuple):
        # NOTE(review): __init__ body reconstructed from attribute usage in __call__.
        self.shape = shape

    def __call__(self, sample):
        """Return ``sample`` as a numpy ndarray reshaped to ``self.shape``.

        Note: this drops all xarray metadata (dims, coords, attrs).
        """
        check_object_type(sample, allowed_types=DA, caller=self.__class__.__name__)
        return sample.to_numpy().reshape(self.shape)
class RollTransform(BaseTransform):
    """Roll a Dataset along a dimension, typically to re-center longitudes.

    Args:
        dim: Dimension to roll along.
        shift: Number of positions to roll. If ``None``, a half-size roll is
            applied automatically when the coordinate spans roughly 0–360
            (detected as max > 350 and min < 10), otherwise no roll.
    """

    def __init__(self, dim: str, shift: None | int):
        # NOTE(review): __init__ body reconstructed from attribute usage in __call__.
        self.dim = dim
        self.shift = shift

    def __call__(self, sample):
        """Return the rolled Dataset with coordinates remapped to [-180, 180)."""
        check_object_type(sample, allowed_types=DS, caller=self.__class__.__name__)
        shift = self.shift
        if shift is None:
            # Auto-detect a 0–360 longitude axis that needs re-centering.
            if float(sample[self.dim].max()) > 350 and float(sample[self.dim].min()) < 10:
                shift = sample.sizes[self.dim] // 2
            else:
                shift = 0
        print(f"Rolling {self.dim} by {shift} positions.")
        rolled = sample.roll({self.dim: shift}, roll_coords=True)
        # Remap coordinate values >= 180 into the negative (western) range.
        coord_vals = rolled.coords[self.dim].values
        rolled.coords[self.dim] = np.where(coord_vals >= 180., coord_vals - 360., coord_vals)
        return rolled
class ReverseAxisTransform(BaseTransform):
    """Reverse the order of values along one dimension of a Dataset."""

    def __init__(self, dim: str):
        # NOTE(review): __init__ body reconstructed from attribute usage in __call__.
        self.dim = dim

    def __call__(self, sample):
        """Return the Dataset with ``self.dim`` reversed (step -1 slice)."""
        check_object_type(sample, allowed_types=DS, caller=self.__class__.__name__)
        ds_rev = sample.isel(**{self.dim: slice(None, None, -1)})
        return ds_rev
class SortAxisTransform(BaseTransform):
    """Sort a Dataset by the values of one coordinate/dimension."""

    def __init__(self, dim: str, ascending: bool = True):
        # NOTE(review): __init__ body reconstructed from attribute usage in __call__.
        self.dim = dim
        self.ascending = ascending

    def __call__(self, sample):
        """Return the Dataset sorted along ``self.dim``."""
        check_object_type(sample, allowed_types=DS, caller=self.__class__.__name__)
        sorted_sample = sample.sortby(self.dim, ascending=self.ascending)
        return sorted_sample
class SubsetTransform(BaseTransform):
    """Subset a Dataset/DataArray by label selectors and, optionally, variables.

    Keyword arguments are coordinate selectors passed to ``.sel()``. A special
    ``variables`` keyword (string or list) restricts a Dataset to those
    variables before selecting.
    """

    def __init__(self, **subset_selectors):
        if "variables" in subset_selectors:
            variables = subset_selectors.pop("variables")
            # Normalize a single variable name to a one-element list.
            self.variables = variables if isinstance(variables, (list, tuple)) else [variables]
        else:
            self.variables = None
        # NOTE(review): this assignment was lost in extraction and is
        # reconstructed; convert_subset_selectors_to_slices is imported by this
        # module but otherwise unused, so it is presumably applied here — confirm.
        self.subset_selectors = convert_subset_selectors_to_slices(subset_selectors)

    def __call__(self, sample):
        """Return the subsetted object; a new Dataset when variables are given."""
        check_object_type(sample, allowed_types=(DS, DA), caller=self.__class__.__name__)
        if self.variables is None:
            # No variable filter: apply the selectors to the whole object.
            return sample.sel(**self.subset_selectors)
        # Build a new Dataset containing only the requested variables, each
        # subset by the common selectors (if any).
        ds = xr.Dataset()
        ds.attrs.update(sample.attrs)
        for var_id in self.variables:
            if self.subset_selectors:
                ds[var_id] = sample[var_id].sel(**self.subset_selectors)
            else:
                ds[var_id] = sample[var_id]
        return ds
class SqueezeTransform(BaseTransform):
    """Drop all size-1 dimensions from a Dataset, DataArray, or Tensor."""

    def __call__(self, sample):
        """Return ``sample.squeeze()`` — all three allowed types expose it."""
        check_object_type(sample, allowed_types=(DS, DA, TT), caller=self.__class__.__name__)
        return sample.squeeze()
class TilerTransform(BaseTransform):
    """
    A transform that takes a Dataset or DataArray and breaks it into smaller tiles along specified dimensions.
    This uses the xarray `coarsen` + `construct` pattern to create non-overlapping tiles of the data, which can
    be useful for training models on large spatial datasets by reducing memory usage and allowing for batch
    processing of smaller chunks of data.

    Args:
        boundary: Passed to ``coarsen`` — how to handle dimensions that do not
            divide evenly by the tile size (default "pad").
        **dim_tile_sizes: Mapping of dimension name -> tile size along it.
    """

    def __init__(self, boundary: str = "pad", **dim_tile_sizes):
        # NOTE(review): __init__ body reconstructed — __call__ reads
        # self.boundary and self.tile_sizes.
        self.boundary = boundary
        self.tile_sizes = dim_tile_sizes

    def __call__(self, sample: DA) -> DA:
        """Return a tiled DataArray with a leading "batch_dim" of tiles.

        Tiling metadata needed to reassemble the original array is stored in
        the result's ``attrs`` under ``tiler_*`` keys.
        """
        check_object_type(sample, allowed_types=DA, caller=self.__class__.__name__)
        # Create the dictionary to send to the ".construct()" method, using a naming convention of
        # ("{dim}_coarse", "{dim}_fine") for the new dimensions created by the tiling process.
        tile_dims = {dim: (f"{dim}_coarse", f"{dim}_fine") for dim in self.tile_sizes}
        coarsened = sample.coarsen(**self.tile_sizes, boundary=self.boundary).construct(**tile_dims)  # type: ignore
        # Prepare a stacking regrouping of the original dimensions and the new dimensions
        batch_dims = []
        target_dims = []
        for dim in sample.dims:
            if dim in self.tile_sizes:
                batch_dims.append(f"{dim}_coarse")
                target_dims.append(f"{dim}_fine")
            else:
                target_dims.append(dim)
        stacked = coarsened.stack(batch_dim=batch_dims)
        # Reorder to have batch_dim first, followed by the original dimensions and then the fine tile dimensions
        tiled = stacked.transpose("batch_dim", *target_dims)
        # Store reverse-lookup metadata in attrs
        tiled.attrs.update({
            "tiler_tile_sizes": self.tile_sizes,
            "tiler_boundary": self.boundary,
            "tiler_original_sizes": {dim: sample.sizes[dim] for dim in self.tile_sizes},
            "tiler_original_coords": {dim: sample.coords[dim].values.tolist()
                                      for dim in self.tile_sizes if dim in sample.coords},
        })
        return tiled
class ToDataArray(BaseTransform):
    """Extract a DataArray from a single-variable Dataset (pass-through for DataArrays)."""

    def __init__(self, var_id: str):
        # NOTE(review): __init__ body reconstructed from attribute usage in __call__.
        self.var_id = var_id

    def __call__(self, sample: DS | DA) -> DA:
        """Return ``sample[self.var_id]`` for a Dataset, or ``sample`` unchanged.

        Raises:
            ValueError: if a Dataset has more than one data variable.
        """
        check_object_type(sample, allowed_types=(DS, DA), caller=self.__class__.__name__)
        if isinstance(sample, DS):
            # Only single-variable Datasets are supported; var_id must then be
            # that variable's name — presumably guaranteed by the caller.
            if len(sample.data_vars) != 1:
                raise ValueError("ToDataArrayTransform can only be applied to Datasets with a single variable.")
            return sample[self.var_id]
        return sample
class ToTensorTransform(BaseTransform):
    """Convert a DataArray or numpy array into a PyTorch tensor."""

    def __call__(self, sample: DA | np.ndarray) -> torch.Tensor:
        """Return ``torch.from_numpy`` of the underlying numpy data.

        Note: ``from_numpy`` shares memory with the source array.
        """
        check_object_type(sample, allowed_types=(DA, np.ndarray), caller=self.__class__.__name__)
        if isinstance(sample, DA):
            # Unwrap the DataArray to its backing numpy array first.
            sample = sample.values
        return torch.from_numpy(sample)
class TransposeTransform(BaseTransform):
    """Reverse the dimension order of a DataArray or Tensor."""

    def __call__(self, sample):
        """Return ``sample`` with all dimensions reversed."""
        check_object_type(sample, allowed_types=(DA, TT), caller=self.__class__.__name__)
        # BUG FIX: torch.Tensor.transpose() requires two dim arguments, so the
        # original bare call raised TypeError for tensors. Reverse every dim via
        # permute instead, matching numpy/xarray transpose semantics.
        # (TT is presumably an alias for torch.Tensor — confirm.)
        if isinstance(sample, torch.Tensor):
            return sample.permute(*reversed(range(sample.ndim)))
        return sample.transpose()
class VarsToDimensionTransform(BaseTransform):
    """
    A transform that takes a list of variables from a Dataset and stacks them into a
    new dimension, effectively converting the variable dimension into a coordinate
    dimension. This is useful for models that expect a single multi-channel input
    rather than separate variables.
    Since the purpose is to prepare the data for conversion to a Tensor, we assume
    that ancillary variables that are not genuine coordinates can be dropped.
    """

    # Ancillary/bookkeeping variable names that are never stacked.
    exclusion_vars = ["time_bounds", "lat_bounds", "lon_bounds",
                      "time_bnds", "lat_bnds", "lon_bnds",
                      "crs", "spatial_ref", "bounds", "bnds"]

    def __init__(self, variables: list, new_dim: str, only_vars_with_time: bool = True):
        # NOTE(review): __init__ body reconstructed from attribute usage in __call__.
        # ``variables`` may also be the sentinel string "__all__".
        self.variables = variables
        self.new_dim = new_dim
        self.only_vars_with_time = only_vars_with_time

    def __call__(self, sample):
        """Return the selected variables concatenated along ``self.new_dim``."""
        check_object_type(sample, allowed_types=DS, caller=self.__class__.__name__)
        # Special case: variables == "__all__" means take every data variable and
        # filter out those that are not suitable (bounds, coords, no time axis).
        if self.variables == "__all__":
            # Variables registered as CF bounds (first entry of each bounds list).
            bounds_vars = {b_list[0] for b_list in sample.cf.bounds.values()}
            if self.only_vars_with_time:
                vars_without_time = {var_id for var_id in sample.data_vars
                                     if not hasattr(sample[var_id], "time")}
            else:
                vars_without_time = set()
            exclusion_vars = {var_id for var_id in self.exclusion_vars if var_id in sample.data_vars}
            # Combine all exclusion criteria into a single set of variables to drop.
            all_exclusion_vars = bounds_vars | vars_without_time | exclusion_vars
            # BUG FIX: drop_vars returns a new Dataset; the original discarded
            # the result, leaving the excluded variables in place.
            sample = sample.drop_vars(all_exclusion_vars)
            # Remove those variables from the wish list.
            variables = set(sample.data_vars) - all_exclusion_vars
        else:
            variables = self.variables
        # Concatenate the chosen variables into a single stacked DataArray.
        arrays = [sample[var_id] for var_id in variables]
        stacked = xr.concat(arrays, dim=self.new_dim)
        return stacked
# Registry mapping config "type" strings to transform classes, used by
# resolve_transform.
# NOTE(review): FillNaNTransform and ScaleTransform are not defined in the
# visible portion of this module — confirm they are defined/imported elsewhere.
transform_mapping = {
    "fill_missing": FillMissingValueTransform,
    "fill_nan": FillNaNTransform,
    "normalize": NormalizeTransform,
    "rename": RenameTransform,
    "resample": ResampleTransform,
    "reshape": ReshapeTransform,
    "reverse_axis": ReverseAxisTransform,
    "roll": RollTransform,
    "scale": ScaleTransform,
    "sort_axis": SortAxisTransform,
    "squeeze": SqueezeTransform,
    "subset": SubsetTransform,
    "tiler": TilerTransform,
    "to_dataarray": ToDataArray,
    "to_tensor": ToTensorTransform,
    "transpose": TransposeTransform,
    "vars_to_dimension": VarsToDimensionTransform
}
def resolve_transform(transform_config: dict) -> BaseTransform:
    """
    If a transform is a dictionary with a "type" key, resolve it to the corresponding transform class instance.
    If it is already an instance of a transform class, return it as is.

    Args:
        transform_config (dict or BaseTransform): The transform configuration to resolve.

    Returns:
        BaseTransform: An instance of a transform class.

    Raises:
        ValueError: if the "type" key is missing from transform_mapping.
    """
    if isinstance(transform_config, BaseTransform):
        return transform_config
    transform_type = transform_config.get("type")
    if transform_type not in transform_mapping:
        raise ValueError(f"Unsupported transform type: {transform_type}")
    transform_class = transform_mapping[transform_type]
    # All remaining keys become constructor keyword arguments.
    return transform_class(**{k: v for k, v in transform_config.items() if k != "type"})
def apply_transforms(data: xr.Dataset | xr.DataArray, preprocessors: list) -> xr.Dataset | xr.DataArray:
    """
    Apply a list of preprocessing transforms to a data sample.

    Args:
        data (xr.Dataset | xr.DataArray): The input data sample to be transformed.
        preprocessors (list): A list of transform configurations to apply to the sample.

    Returns:
        xr.Dataset | xr.DataArray: The transformed data sample after applying all preprocessors.

    Raises:
        ValueError: if any preprocessor is not a dict with a "type" key.
    """
    for preprocessor in preprocessors:
        if not isinstance(preprocessor, dict) or "type" not in preprocessor:
            raise ValueError(f"Each preprocessor must be a dictionary with a 'type' key. Invalid preprocessor: {preprocessor}")
        # Resolve each config to an instance and apply it in sequence.
        data = resolve_transform(preprocessor)(data)
    return data
# Create `apply_preprocessors` as an alias for `apply_transforms` to allow for more intuitive naming when used
# in the context of preprocessing steps.
# NOTE(review): the alias assignment itself was missing from the extracted
# source; restored here to match the comment above.
apply_preprocessors = apply_transforms