# Source code for src.FRAME_FM.datasets.InputOnly_Dataset
# src/FRAME_FM/datasets/InputOnly_Dataset.py
"""
Lightweight Dataset wrappers that apply an optional transform to input tiles only.
"""
from typing import Optional, Any
from torch.utils.data import Dataset
class TransformedInputDataset(Dataset):
    """
    Wrap an input-only dataset and apply an optional transform to each item.

    Useful for input-only style datasets, such as those used to train visual
    autoencoders, where each item is a single tile with no separate target.

    NOTE: the transform is currently intended to scale the input by a
    coefficient; this is expected to change once the transform settings are
    finalized.

    Args:
        base: The wrapped dataset. ``base[idx]`` is expected to return a tile
            with dimensions (C x H x W) or (T x C x H x W).
        transform: Optional callable applied to each tile. If ``None``, tiles
            are returned unchanged.
    """

    def __init__(self, base: Dataset, transform: Optional[Any] = None) -> None:
        super().__init__()
        self.base = base
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of items in the wrapped dataset.

        Samplers and the default ``DataLoader`` configuration call ``len()``
        on map-style datasets, so the size is delegated to ``base``.
        """
        return len(self.base)

    def __getitem__(self, idx: int):
        """Return ``base[idx]``, passed through ``transform`` if one is set."""
        tile = self.base[idx]  # expected dimensions are (C x H x W) or (T x C x H x W)
        if self.transform is not None:
            tile = self.transform(tile)
        return tile
class TransformedInputCoordsDataset(Dataset):
    """
    Wrap a dataset of ``(tile, coordinates)`` pairs and transform only the tile.

    The grid tile coordinates are passed through unchanged, so the optional
    transform affects the input data but never its location information.

    Args:
        base: The wrapped dataset. ``base[idx]`` must return a
            ``(tile, coordinates)`` pair; the tile is expected to have
            dimensions (C x H x W) or (T x C x H x W).
        transform: Optional callable applied to each tile. If ``None``, tiles
            are returned unchanged.
    """

    def __init__(self, base: Dataset, transform: Optional[Any] = None) -> None:
        super().__init__()
        self.base = base
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of items in the wrapped dataset.

        Samplers and the default ``DataLoader`` configuration call ``len()``
        on map-style datasets, so the size is delegated to ``base``.
        """
        return len(self.base)

    def __getitem__(self, idx: int):
        """Return ``(tile, coordinates)`` with ``transform`` applied to the tile only."""
        tile, coordinates = self.base[idx]  # expected dimensions are (C x H x W) or (T x C x H x W)
        if self.transform is not None:
            tile = self.transform(tile)
        return tile, coordinates