Source code for src.FRAME_FM.datasets.base_gridded_dataset

import torch
from pathlib import Path

from FRAME_FM.datasets.base_dataset import BaseDataset
from FRAME_FM.utils.data_utils import load_data_from_uri, unify_transforms, get_main_vars
from FRAME_FM.transforms import resolve_transform, apply_preprocessors


class BaseGriddedDataset(BaseDataset):
    """Gridded dataset that yields one slice along the main variable's leading dimension.

    The "main" variable is the first entry returned by ``get_main_vars``; samples
    are taken by integer-indexing its first *dimension* (``dims[0]``), and the
    dataset length is that dimension's size, so ``__len__`` and ``__getitem__``
    always agree on the axis being iterated.
    """

    # Default transforms: stack all variables into a "variable" dimension,
    # then convert the result to a tensor.
    _transforms = [
        {"type": "vars_to_dimension", "variables": "__all__", "new_dim": "variable", "only_vars_with_time": False},
        {"type": "to_tensor"},
    ]

    def _setup_dataset(self):
        """Load the data and record the main variable and the dimension to index along.

        Raises:
            ValueError: if ``get_main_vars`` finds no main variable in the data.
        """
        self.data = load_data_from_uri(self.data_uri, chunks=self.chunks)
        main_vars = get_main_vars(self.data)
        if not main_vars:
            raise ValueError(f"No main data variable found in data loaded from {self.data_uri!r}")
        self.main_var = main_vars[0]
        # Use the first *dimension*, not the first coordinate: coordinate order does
        # not follow dimension order and may start with a non-dimension/scalar coord
        # (which ``isel`` cannot index). The attribute name is kept for compatibility.
        self.first_coord = self.data[self.main_var].dims[0]

    def __len__(self) -> int:
        """Return the size of the main variable's leading dimension."""
        return self.data[self.main_var].sizes[self.first_coord]

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return slice ``idx`` along the leading dimension, with runtime transforms applied."""
        sample = self.data.isel(**{self.first_coord: idx})
        # Apply runtime transforms in order.
        for transform in self.transforms:
            sample = resolve_transform(transform)(sample)
        return sample  # type: ignore
class BaseGeoTIFFDataset(BaseDataset):
    """Raster dataset that yields one band per sample.

    Expects ``self.data`` to contain a ``band_data`` variable with a ``band``
    dimension -- presumably as produced by a raster backend such as rioxarray;
    confirm against how ``BaseDataset`` loads the data.
    """

    # Default transforms: move ``band_data`` into a "variable" dimension,
    # then convert the result to a tensor.
    _transforms = [
        {"type": "vars_to_dimension", "variables": ["band_data"], "new_dim": "variable"},
        {"type": "to_tensor"},
    ]

    def __len__(self) -> int:
        """Return the length of ``band_data``'s leading dimension (assumed to be ``band``)."""
        band_array = self.data["band_data"]
        return len(band_array)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Select band ``idx`` and push it through each configured runtime transform."""
        result = self.data.isel(band=idx)
        for spec in self.transforms:
            transform_fn = resolve_transform(spec)
            result = transform_fn(result)
        return result  # type: ignore
class BaseASCIIGridDataset(BaseGeoTIFFDataset):
    """ASCII-grid dataset; inherits all behaviour from :class:`BaseGeoTIFFDataset`.

    No overrides: length, band indexing and default transforms are those of the
    GeoTIFF dataset (``band_data`` variable indexed along ``band``).
    """
class BaseGriddedTimeSeriesDataset(BaseDataset):
    """Gridded time-series dataset yielding non-overlapping windows along ``time``.

    Sample ``i`` covers time steps ``[i * time_stride, (i + 1) * time_stride)``.
    Only complete windows are counted, so any trailing partial window is dropped.
    """

    # Define transforms that are always appended to the end of the transforms list in any child class.
    # This ensures that the data is always converted to a tensor and has a "variable" dimension for
    # the model to work with, even if the user doesn't specify these transforms themselves.
    _transforms = [
        {"type": "vars_to_dimension", "variables": "__all__", "new_dim": "variable"},
        {"type": "to_tensor"},
    ]

    # Default chunking strategy to ensure Dask is used for time series data.
    # Copied per instance so mutating ``self.chunks`` never alters this class default.
    DEFAULT_CHUNKS = {"time": 64}

    def __init__(
        self,
        data_uri: str | Path | list | tuple,
        preprocessors: list | None = None,
        transforms: list | None = None,
        # time_range: tuple | None = None,
        time_stride: int = 16,
        chunks: dict | None = None,
        override_transforms: bool = False,
        cache_dir: None | Path | str = None,
        generate_stats: bool = True,
        force_recache: bool = False,
    ):
        """Configure time-windowing, then delegate the remaining setup to ``BaseDataset``.

        Args:
            time_stride: number of time steps per sample window; must be >= 1.
            chunks: chunking passed on to the base class; falsy values fall back
                to a copy of ``DEFAULT_CHUNKS``.
            Other arguments are forwarded unchanged to ``BaseDataset.__init__``.

        Raises:
            ValueError: if ``time_stride`` is less than 1.
        """
        if time_stride < 1:
            raise ValueError(f"time_stride must be >= 1, got {time_stride}")
        # Set instance variables specific to time series datasets before calling the
        # base initialiser, which may already need them during dataset setup.
        # self.time_range = time_range
        self.time_stride = time_stride
        self.chunks = chunks or dict(self.DEFAULT_CHUNKS)

        # Pass the *resolved* chunks so the base class cannot replace the default with None.
        super().__init__(
            data_uri=data_uri,
            preprocessors=preprocessors,
            transforms=transforms,
            chunks=self.chunks,
            override_transforms=override_transforms,
            cache_dir=cache_dir,
            generate_stats=generate_stats,
            force_recache=force_recache,
        )

    # def _setup_dataset(self):
    #     # Apply the time selection at the start, to allow any subsequent processing to focus within
    #     # the selected time range (if specified).
    #     # Load the dataset ready for training
    #     subset_selection = {"time": self.time_range} if self.time_range else {}
    #     self.data = load_data_from_uri(self.data_uri, chunks=self.chunks, subset_selection=subset_selection)

    def __len__(self) -> int:
        """Return the number of complete ``time_stride``-long windows."""
        return len(self.data["time"]) // self.time_stride

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return time window ``idx`` with runtime transforms applied.

        Negative indices count from the end. Raises ``IndexError`` when ``idx`` is
        out of range, which also lets plain ``for`` iteration terminate.
        """
        n_samples = len(self)
        position = idx + n_samples if idx < 0 else idx
        if not 0 <= position < n_samples:
            raise IndexError(f"Index {idx} out of range for dataset with {n_samples} samples")

        start = position * self.time_stride
        sample = self.data.isel(time=slice(start, start + self.time_stride))
        # Apply runtime transforms in order.
        for transform in self.transforms:
            sample = resolve_transform(transform)(sample)
        return sample  # type: ignore