# Source code for src.FRAME_FM.datasets.InputOnly_Dataset

# src/FRAME_FM/datasets/InputOnly_Dataset.py
"""
Lightweight Dataset wrappers that apply an optional transform to the input tile only.
"""
from typing import Optional, Any
from torch.utils.data import Dataset


class TransformedInputDataset(Dataset):
    """Wrap an input-only dataset and optionally transform each item.

    Intended for input-only style datasets (e.g. for visual autoencoders),
    where ``base[idx]`` returns a single tile rather than an
    ``(input, target)`` pair. Any scaling or normalisation is the job of
    ``transform``; this class performs no scaling of its own.

    TODO: transform settings are not finalised yet; this interface may
    change once they are.

    Args:
        base: Underlying dataset. ``base[idx]`` must return the tile itself.
        transform: Optional callable applied to every tile returned by
            ``base``. When ``None``, tiles are returned unchanged.
    """

    def __init__(self, base: Dataset, transform: Optional[Any] = None) -> None:
        self.base = base
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of items in the wrapped dataset."""
        return len(self.base)

    def __getitem__(self, idx: int):
        """Return item ``idx`` of ``base``, passed through ``transform`` if one is set."""
        tile = self.base[idx]
        # Expected dimensions are (C x H x W) or (T x C x H x W).
        if self.transform is not None:
            tile = self.transform(tile)
        return tile
class TransformedInputCoordsDataset(Dataset):
    """Wrap a dataset of ``(tile, coordinates)`` pairs and transform only the tile.

    The grid tile coordinates are passed through untouched; only the tile is
    given to ``transform``.

    Args:
        base: Underlying dataset whose items unpack into exactly two values,
            ``(tile, coordinates)``.
        transform: Optional callable applied to the tile. When ``None``, the
            tile is returned unchanged.
    """

    def __init__(self, base: Dataset, transform: Optional[Any] = None) -> None:
        self.base, self.transform = base, transform

    def __len__(self) -> int:
        """Return the size of the wrapped dataset."""
        return len(self.base)

    def __getitem__(self, idx: int):
        """Return ``(tile, coordinates)`` for ``idx``, with the tile optionally transformed."""
        sample_tile, sample_coords = self.base[idx]
        # Tile dimensions are expected to be (C x H x W) or (T x C x H x W).
        out_tile = sample_tile if self.transform is None else self.transform(sample_tile)
        return out_tile, sample_coords