# Source code for src.FRAME_FM.transforms.transforms

# Define transforms
import xarray as xr
import cf_xarray  # noqa: F401 - We just need to register the accessor for CF-compliant operations on xarray objects
import numpy as np
import torch

from FRAME_FM.utils.transform_utils import check_object_type
from FRAME_FM.utils.data_utils import convert_subset_selectors_to_slices

# Short type aliases used throughout this module.
DA = xr.DataArray
DS = xr.Dataset
TT = torch.Tensor
class BaseTransform:
    """Abstract parent class for all transforms.

    Concrete transforms override ``__call__``; the constructor accepts and
    ignores arbitrary arguments so subclasses are free to define their own
    configuration signatures.
    """

    def __init__(self, *args, **kwargs):
        # Intentionally a no-op: configuration is handled by subclasses.
        pass

    def __call__(self, sample):
        # Every concrete transform must provide its own implementation.
        raise NotImplementedError("Transform must implement the __call__ method.")
class FillMissingValueTransform(BaseTransform):
    """Fill missing (NaN) values in a Dataset or DataArray.

    Two strategies are supported:

    - ``"constant"``: replace NaNs with ``fill_value`` via ``fillna``.
    - ``"interpolate"``: fill NaNs via ``interpolate_na`` using ``method``.
    """

    def __init__(self, strategy: str = "constant", fill_value: None | float = None, method: None | str = "linear"):
        # Which filling approach to use: "constant" or "interpolate".
        self.strategy = strategy
        # Replacement value for the "constant" strategy.
        self.fill_value = fill_value
        # Interpolation method for the "interpolate" strategy.
        self.method = method

    def __call__(self, sample: DS | DA) -> DS | DA:
        check_object_type(sample, allowed_types=(DS, DA), caller=self.__class__.__name__)
        if self.strategy == "constant":
            if self.fill_value is None:
                raise ValueError("fill_value must be provided for 'constant' method.")
            return sample.fillna(self.fill_value)
        if self.strategy == "interpolate":
            # NOTE(review): interpolate_na is called with dim=None; verify
            # xarray accepts this for the datasets used here — some versions
            # require an explicit dimension.
            return sample.interpolate_na(dim=None, method=self.method)  # type: ignore
        raise ValueError(f"Unsupported fill strategy: {self.strategy}")
class FillNaNTransform(FillMissingValueTransform):
    """Alias of :class:`FillMissingValueTransform` under a more explicit name."""
    pass
class NormalizeTransform(BaseTransform):
    """Standardise a DataArray with fixed statistics: ``(x - mean) / std``."""

    def __init__(self, mean: float, std: float):
        # Pre-computed statistics used for normalization.
        self.mean = mean
        self.std = std

    def __call__(self, sample: DA) -> DA:
        check_object_type(sample, allowed_types=DA, caller=self.__class__.__name__)
        centred = sample - self.mean
        return centred / self.std
class ScaleTransform(NormalizeTransform):
    """Alias of :class:`NormalizeTransform` for scaling-oriented configs."""
    pass
class RenameTransform(BaseTransform):
    """Rename a single data variable within a Dataset."""

    def __init__(self, var_id: str, new_name: str):
        # Current variable name.
        self.var_id = var_id
        # Name the variable will be given.
        self.new_name = new_name

    def __call__(self, sample: DS) -> DS:
        check_object_type(sample, allowed_types=DS, caller=self.__class__.__name__)
        return sample.rename_vars({self.var_id: self.new_name})
class ResampleTransform(BaseTransform):
    """Reduce resolution along one dimension.

    The time dimension is grouped with ``resample`` (frequency string such
    as ``"1D"``); any other dimension is reduced with ``coarsen`` (integer
    window, ``boundary="trim"``). The chosen aggregation method is then
    applied to the grouped object.
    """

    # Aggregations available on both resample and coarsen objects. The
    # original code validated the method twice (membership check plus a
    # hasattr check with a different error message); the second check was
    # unreachable and has been folded into this single validation.
    _VALID_METHODS = ("mean", "sum", "max", "min", "median")

    def __init__(self, dim: str, freq: str | int, method: str = "mean"):
        # Dimension to reduce along ("time" triggers resample semantics).
        self.dim = dim
        # Frequency string (time) or integer window size (other dims).
        self.freq = freq
        # Aggregation applied to each group/window.
        self.method = method

    def __call__(self, sample):
        check_object_type(sample, allowed_types=(DS, DA), caller=self.__class__.__name__)
        if self.method not in self._VALID_METHODS:
            raise ValueError(f"Unsupported resampling method: {self.method}")
        # Choose resample for the time dimension, coarsen for spatial dims.
        if self.dim == "time":
            grouped = sample.resample({self.dim: self.freq})
        else:
            grouped = sample.coarsen({self.dim: self.freq}, boundary="trim")
        return getattr(grouped, self.method)()
class ReshapeTransform(BaseTransform):
    """Reshape a DataArray's values into a target shape.

    Note: the result is a plain ``numpy.ndarray`` — coordinates and
    attributes are dropped by the conversion.
    """

    def __init__(self, shape: tuple):
        # Target numpy shape passed straight to ndarray.reshape.
        self.shape = shape

    def __call__(self, sample):
        check_object_type(sample, allowed_types=DA, caller=self.__class__.__name__)
        values = sample.to_numpy()
        return values.reshape(self.shape)
class RollTransform(BaseTransform):
    """Roll a Dataset along one dimension and re-centre its coordinates.

    When ``shift`` is None the transform auto-detects whether rolling is
    needed: if the coordinate spans both ends of a 0-360 range (presumably
    longitude — TODO confirm), it rolls by half the dimension length.
    Coordinate values >= 180 are then remapped into the negative half
    (e.g. 0..360 becomes -180..180).
    """

    def __init__(self, dim: str, shift: None | int):
        # Dimension to roll along.
        self.dim = dim
        # Explicit roll amount, or None to auto-detect.
        self.shift = shift

    def __call__(self, sample):
        check_object_type(sample, allowed_types=DS, caller=self.__class__.__name__)
        shift = self.shift
        if shift is None:
            # Auto-detect: a coordinate reaching above 350 and below 10
            # spans the 0/360 wrap point, so roll by half its length.
            spans_wrap = float(sample[self.dim].max()) > 350 and float(sample[self.dim].min()) < 10
            shift = sample.sizes[self.dim] // 2 if spans_wrap else 0
        print(f"Rolling {self.dim} by {shift} positions.")
        rolled = sample.roll({self.dim: shift}, roll_coords=True)
        # Remap coordinate values after rolling so they are monotonic again.
        coord_vals = rolled.coords[self.dim].values
        rolled.coords[self.dim] = np.where(coord_vals >= 180., coord_vals - 360., coord_vals)
        return rolled
class ReverseAxisTransform(BaseTransform):
    """Reverse the order of a Dataset along one dimension."""

    def __init__(self, dim: str):
        # Dimension whose index order will be reversed.
        self.dim = dim

    def __call__(self, sample):
        check_object_type(sample, allowed_types=DS, caller=self.__class__.__name__)
        # Negative-step slice flips the dimension without copying data.
        return sample.isel(**{self.dim: slice(None, None, -1)})
class SortAxisTransform(BaseTransform):
    """Sort a Dataset by the values of one coordinate."""

    def __init__(self, dim: str, ascending: bool = True):
        # Coordinate to sort by.
        self.dim = dim
        # Sort direction.
        self.ascending = ascending

    def __call__(self, sample):
        check_object_type(sample, allowed_types=DS, caller=self.__class__.__name__)
        return sample.sortby(self.dim, ascending=self.ascending)
class SubsetTransform(BaseTransform):
    """Subset a Dataset/DataArray by coordinate selectors and/or variables.

    The special ``variables`` keyword names which data variables to keep;
    all other keywords are coordinate selectors converted to slices and
    passed to ``.sel``.
    """

    def __init__(self, **subset_selectors):
        if "variables" in subset_selectors:
            variables = subset_selectors.pop("variables")
            # Normalise a single variable name into a one-element list.
            if isinstance(variables, (list, tuple)):
                self.variables = variables
            else:
                self.variables = [variables]
        else:
            self.variables = None
        self.subset_selectors = convert_subset_selectors_to_slices(subset_selectors)

    def __call__(self, sample):
        check_object_type(sample, allowed_types=(DS, DA), caller=self.__class__.__name__)
        if self.variables is None:
            # No variable filter: apply the selectors to the whole object.
            return sample.sel(**self.subset_selectors)
        # Build a new Dataset holding only the requested variables, carrying
        # over the original attributes.
        subset = xr.Dataset()
        subset.attrs.update(sample.attrs)
        for var_id in self.variables:
            if self.subset_selectors:
                subset[var_id] = sample[var_id].sel(**self.subset_selectors)
            else:
                subset[var_id] = sample[var_id]
        return subset
class SqueezeTransform(BaseTransform):
    """Drop all size-1 dimensions from a Dataset, DataArray, or Tensor."""

    def __call__(self, sample):
        check_object_type(sample, allowed_types=(DS, DA, TT), caller=self.__class__.__name__)
        # All three supported types expose a compatible .squeeze().
        return sample.squeeze()
class TilerTransform(BaseTransform):
    """Break a DataArray into smaller non-overlapping tiles.

    Uses the xarray ``coarsen`` + ``construct`` pattern to split the data
    along the requested dimensions. This is useful for training models on
    large spatial datasets: it reduces memory usage and allows batch
    processing of smaller chunks. Metadata needed to reverse the tiling is
    stored in the result's attrs.
    """

    def __init__(self, boundary: str = "pad", **dim_tile_sizes):
        # How coarsen handles dimensions not divisible by the tile size.
        self.boundary = boundary
        # Mapping of dimension name -> tile size along that dimension.
        self.tile_sizes = dim_tile_sizes

    def __call__(self, sample: DA) -> DA:
        check_object_type(sample, allowed_types=DA, caller=self.__class__.__name__)
        # Naming convention for the dimensions created by construct():
        # "{dim}_coarse" indexes tiles, "{dim}_fine" indexes within a tile.
        tile_dims = {d: (f"{d}_coarse", f"{d}_fine") for d in self.tile_sizes}
        coarsened = sample.coarsen(**self.tile_sizes, boundary=self.boundary).construct(**tile_dims)  # type: ignore
        # Split the original dims into the coarse (batch) group and the
        # remaining/fine (target) group, preserving original dim order.
        batch_dims = []
        target_dims = []
        for d in sample.dims:
            if d in self.tile_sizes:
                batch_dims.append(f"{d}_coarse")
                target_dims.append(f"{d}_fine")
            else:
                target_dims.append(d)
        # Collapse all coarse dims into one leading "batch_dim".
        stacked = coarsened.stack(batch_dim=batch_dims)
        tiled = stacked.transpose("batch_dim", *target_dims)
        # Record reverse-lookup metadata so the tiling can be undone later.
        tiled.attrs.update({
            "tiler_tile_sizes": self.tile_sizes,
            "tiler_boundary": self.boundary,
            "tiler_original_sizes": {d: sample.sizes[d] for d in self.tile_sizes},
            "tiler_original_coords": {d: sample.coords[d].values.tolist()
                                      for d in self.tile_sizes if d in sample.coords},
        })
        return tiled
class ToDataArray(BaseTransform):
    """Extract a single variable from a Dataset as a DataArray.

    A DataArray input passes through unchanged. For a Dataset input, the
    variable named by ``var_id`` is selected.
    """

    def __init__(self, var_id: str):
        # Name of the variable to extract from a Dataset input.
        self.var_id = var_id

    def __call__(self, sample: DS | DA) -> DA:
        check_object_type(sample, allowed_types=(DS, DA), caller=self.__class__.__name__)
        if isinstance(sample, DS):
            # Generalisation: the original rejected any Dataset with more
            # than one variable even though var_id identifies exactly which
            # one to extract. Validate the variable exists instead.
            if self.var_id not in sample.data_vars:
                raise ValueError(
                    f"Variable '{self.var_id}' not found in Dataset; "
                    f"available variables: {list(sample.data_vars)}"
                )
            return sample[self.var_id]
        return sample
class ToTensorTransform(BaseTransform):
    """Convert a DataArray or numpy array into a PyTorch tensor.

    Note: ``torch.from_numpy`` shares memory with the source array rather
    than copying it.
    """

    def __call__(self, sample: DA | np.ndarray) -> torch.Tensor:
        check_object_type(sample, allowed_types=(DA, np.ndarray), caller=self.__class__.__name__)
        array = sample.values if isinstance(sample, DA) else sample
        return torch.from_numpy(array)
class TransposeTransform(BaseTransform):
    """Reverse the dimension order of a DataArray or Tensor."""

    def __call__(self, sample):
        check_object_type(sample, allowed_types=(DA, TT), caller=self.__class__.__name__)
        if isinstance(sample, TT):
            # BUG FIX: torch.Tensor.transpose() requires two dimension
            # arguments, so the previous argument-less call raised a
            # TypeError for tensors even though TT is an allowed type.
            # permute with all dims reversed mirrors xarray's behaviour.
            return sample.permute(*reversed(range(sample.ndim)))
        # xarray's transpose() with no arguments reverses all dimensions.
        return sample.transpose()
class VarsToDimensionTransform(BaseTransform):
    """Stack a list of Dataset variables into a new dimension.

    Converts the variable "dimension" into a coordinate dimension, which is
    useful for models that expect a single multi-channel input rather than
    separate variables. Since the purpose is to prepare the data for
    conversion to a Tensor, ancillary variables that are not genuine
    coordinates are assumed droppable.
    """

    # Common ancillary/bounds variable names that are never useful channels.
    exclusion_vars = ["time_bounds", "lat_bounds", "lon_bounds", "time_bnds",
                      "lat_bnds", "lon_bnds", "crs", "spatial_ref", "bounds", "bnds"]

    def __init__(self, variables: list, new_dim: str, only_vars_with_time: bool = True):
        # Explicit variable list, or the sentinel "__all__" to auto-select.
        self.variables = variables
        # Name of the new stacking dimension.
        self.new_dim = new_dim
        # When auto-selecting, drop variables lacking a time coordinate.
        self.only_vars_with_time = only_vars_with_time

    def __call__(self, sample):
        check_object_type(sample, allowed_types=DS, caller=self.__class__.__name__)
        if self.variables == "__all__":
            # Exclude bounds variables reported by cf-xarray (first entry of
            # each bounds list, matching the original behaviour).
            bounds_vars = {b_list[0] for b_list in sample.cf.bounds.values()}
            if self.only_vars_with_time:
                vars_without_time = {var_id for var_id in sample.data_vars
                                     if not hasattr(sample[var_id], "time")}
            else:
                vars_without_time = set()
            known_ancillary = {var_id for var_id in self.exclusion_vars
                               if var_id in sample.data_vars}
            # Combine all exclusion criteria into one set of vars to drop.
            all_exclusion_vars = bounds_vars | vars_without_time | known_ancillary
            # BUG FIX: drop_vars returns a new Dataset; the result was
            # previously discarded, so the drop was a no-op.
            sample = sample.drop_vars(all_exclusion_vars)
            # BUG FIX: iterate data_vars (insertion order) instead of a set,
            # so the resulting channel order is deterministic.
            variables = [var_id for var_id in sample.data_vars
                         if var_id not in all_exclusion_vars]
        else:
            variables = self.variables
        # Concatenate the selected arrays along the new dimension.
        arrays = [sample[var_id] for var_id in variables]
        return xr.concat(arrays, dim=self.new_dim)
# Registry mapping config "type" strings to their transform classes.
transform_mapping = {
    "fill_missing": FillMissingValueTransform,
    "fill_nan": FillNaNTransform,
    "normalize": NormalizeTransform,
    "rename": RenameTransform,
    "resample": ResampleTransform,
    "reshape": ReshapeTransform,
    "reverse_axis": ReverseAxisTransform,
    "roll": RollTransform,
    "scale": ScaleTransform,
    "sort_axis": SortAxisTransform,
    "squeeze": SqueezeTransform,
    "subset": SubsetTransform,
    "tiler": TilerTransform,
    "to_dataarray": ToDataArray,
    "to_tensor": ToTensorTransform,
    "transpose": TransposeTransform,
    "vars_to_dimension": VarsToDimensionTransform,
}
def resolve_transform(transform_config: dict) -> BaseTransform:
    """Resolve a transform configuration into a transform instance.

    If ``transform_config`` is already a transform instance it is returned
    unchanged; otherwise it must be a dict whose "type" key names an entry
    in ``transform_mapping``, with the remaining keys passed to the
    transform's constructor.

    Args:
        transform_config (dict or BaseTransform): The configuration to resolve.

    Returns:
        BaseTransform: An instance of a transform class.

    Raises:
        ValueError: If "type" is missing or not a registered transform.
    """
    if isinstance(transform_config, BaseTransform):
        return transform_config
    transform_type = transform_config.get("type")
    if transform_type not in transform_mapping:
        raise ValueError(f"Unsupported transform type: {transform_type}")
    # Everything except "type" becomes constructor keyword arguments.
    kwargs = {key: value for key, value in transform_config.items() if key != "type"}
    return transform_mapping[transform_type](**kwargs)
def apply_transforms(data: xr.Dataset | xr.DataArray, preprocessors: list) -> xr.Dataset | xr.DataArray:
    """Apply a list of preprocessing transforms to a data object in order.

    Args:
        data (xr.Dataset | xr.DataArray): The input data to be transformed.
        preprocessors (list): Transform configurations (each a dict with a
            "type" key) applied sequentially to the data.

    Returns:
        xr.Dataset | xr.DataArray: The data after all preprocessors have run.

    Raises:
        ValueError: If a preprocessor is not a dict with a "type" key.
    """
    for preprocessor in preprocessors:
        # Validate the config shape before attempting to resolve it.
        if not isinstance(preprocessor, dict) or "type" not in preprocessor:
            raise ValueError(f"Each preprocessor must be a dictionary with a 'type' key. Invalid preprocessor: {preprocessor}")
        data = resolve_transform(preprocessor)(data)
    return data
# Alias for `apply_transforms`, giving call sites a more intuitive name in
# the context of preprocessing steps.
apply_preprocessors = apply_transforms