import matplotlib.pyplot as plt
from pathlib import Path
import pyproj
import rioxarray as rxr
import torch
from torch.utils.data import TensorDataset
from typing import Callable
import xarray
from ..utils.LightningDataModuleWrapper import BaseDataModule
from ..datasets.InputOnly_Dataset import TransformedInputCoordsDataset
[docs]
def convert_to_long_lat(x, y, src_crs, dst_crs="EPSG:4326"):
# convert to lat/lon if the dataset has a CRS (Coordinate Reference System) defined
if src_crs is not None:
transformer = pyproj.Transformer.from_crs(src_crs, dst_crs, always_xy=True)
longitude, latitude = transformer.transform(x, y)
return list(zip(longitude, latitude))
else:
raise ValueError("CRS not defined in dataset attributes, cannot convert to lat/lon")
[docs]
class XarrayStaticDataModule(BaseDataModule):
'''
A simple DataModule for loading tiled static data from a geotiff file using xarray.
'''
[docs]
train_dataset: TransformedInputCoordsDataset
[docs]
val_dataset: TransformedInputCoordsDataset
[docs]
test_dataset: TransformedInputCoordsDataset | None
def __init__(self,
data_root: str,
batch_size: int = 32,
num_workers: int = 4,
pin_memory: bool = True,
persistent_workers: bool = False,
train_split: float = 0.85,
val_split: float = 0.15,
test_split: float = 0.0,
split_strategy: str = "fraction",
train_transforms: Callable | None = None,
val_transforms: Callable | None = None,
test_transforms: Callable | None = None,
tile_size: int = 256,
) -> None:
super().__init__(
data_root=data_root,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=persistent_workers,
train_split=train_split,
val_split=val_split,
test_split=test_split,
split_strategy=split_strategy,
train_transforms=train_transforms,
val_transforms=val_transforms,
test_transforms=test_transforms,
)
[docs]
self.tile_size = tile_size
def _load_raw_data(self) -> xarray.DataArray:
# currently reading a single file
ds = rxr.open_rasterio(self.data_root)
if isinstance(ds, xarray.DataArray):
return ds
else:
raise ValueError("data_root must be a single filename")
[docs]
def tile_array(self, array: xarray.DataArray) -> xarray.DataArray:
# sanity check
_, nY, nX = array.shape
if nY < self.tile_size or nX < self.tile_size:
raise ValueError(
"DataArray is smaller than minimal tile size required for encoding: "
f"{nY}x{nX} < {self.tile_size}x{self.tile_size}"
)
# tile
tiles = array.coarsen(
x=self.tile_size, y=self.tile_size, boundary='pad'
).construct(
x=("tile_xid", "x"), y=("tile_yid", "y")
).stack(
batch_dim=("tile_xid", "tile_yid")
).transpose("batch_dim", "band", "y", "x")
# THIS WILL CHANGE
# replace nans with zeros (required for PCA for latent space visualization)
tiles = tiles.fillna(0)
return tiles
def _create_datasets(self, stage: str | None = None) -> None:
"""
Reads the DataArray from the attributes, tiles it into patches along x
and y axis, and outputs stacked tiles. This dataset contains only
inputs to b
AK: the tiling in this way does not preserve relative positions of
tiles, so for the exrension to multiple layers the bands need to be
stacked first, then tiled.
"""
tiles = self.tile_array(self._raw_data)
# to tensor dataset
dataset = TensorDataset(torch.tensor(tiles.values, dtype=torch.float32))
# split into subsets
train_base, val_base, test_base = self._split_dataset(dataset)
# transform datasets
self.train_dataset = TransformedInputCoordsDataset(train_base, self.train_transforms)
self.val_dataset = TransformedInputCoordsDataset(val_base, self.val_transforms)
self.test_dataset = None if test_base is None else TransformedInputCoordsDataset(
test_base, self.test_transforms
)
[docs]
class GeotiffSpatialDataModule(XarrayStaticDataModule):
'''
Module for loading tiled values and their positions from a geotiff file, using xarray.
'''
def _create_datasets(self, stage: str | None = None) -> None:
tiles = self.tile_array(self._raw_data)
# to tensor dataset
spatial_dataset = TensorDataset(
torch.tensor(tiles.values, dtype=torch.float32),
self.extract_position_tensor(tiles),
)
# split into subsets
train_base, val_base, test_base = self._split_dataset(spatial_dataset)
# transform datasets
self.train_dataset = TransformedInputCoordsDataset(train_base, self.train_transforms)
self.val_dataset = TransformedInputCoordsDataset(val_base, self.val_transforms)
self.test_dataset = None if test_base is None else TransformedInputCoordsDataset(
test_base, self.test_transforms
)
[docs]
class GeotiffBoundedDataModule(GeotiffSpatialDataModule):
'''
Module for loading tiled values and tile coordinate bounds from a geotiff file, using xarray.
'''
[docs]
def main():
"""
Try creating the data module and plotting examples. Currently not using hydra.
"""
PLOTTING = True
DEBUG = True
# example usage
geotiff_path = Path(
"/gws/ssde/j25b/eds_ai/frame-fm/data/inputs/land_cover_map_2015/data/"
"LCM2015_GB_1km_percent_cover_aggregate_class.tif"
)
# try initializing the dataloader
tile_size = 128
data_module = GeotiffSpatialDataModule(
data_root=geotiff_path.as_posix(), tile_size=tile_size
)
data_module.setup()
train_dataloader = data_module.train_dataloader()
for batch_id, (batch_values, batch_positions) in enumerate(train_dataloader):
non_zero_tile = (batch_values != 0).any(dim=(1, 2, 3))
if non_zero_tile.any():
tile_id = non_zero_tile.int().argmax()
break
tile_values, tile_positions = batch_values[tile_id], batch_positions[tile_id]
bounded_data_module = GeotiffBoundedDataModule(
data_root=geotiff_path.as_posix(), tile_size=tile_size
)
bounded_data_module.setup()
train_dataloader = bounded_data_module.train_dataloader()
for bbatch_values, batch_bounds in train_dataloader:
tile_matches = (bbatch_values == tile_values.unsqueeze(0)).all(dim=(1, 2, 3))
if tile_matches.any():
tile_id = tile_matches.int().argmax()
break
tile_bounds = batch_bounds[tile_id]
if DEBUG:
print("GeotiffSpatialDataModule:")
print("_________________")
print(f"Train dataloader batches: {len(data_module.train_dataset)}")
print(f"Validation dataloader batches: {len(data_module.val_dataset)}")
if data_module.test_dataset is not None:
print(f"Test dataloader batches: {len(data_module.test_dataset)}")
print(f"First non-zero training batch #{batch_id}")
print(f"First non-zero tile #{tile_id} of {batch_values.shape[0]} in batch")
print(f"Values shape (should be nBands, {tile_size}, {tile_size}): {tile_values.shape}")
print(f"Positions shape (should be 2, {tile_size}, {tile_size}): {tile_positions.shape}")
print()
print("GeotiffBoundedDataModule:")
print("_________________")
print(f"Bounds shape (should be 2, 2): {tile_bounds.shape}")
if PLOTTING:
fig_path = Path("./experiments/figures/")
fig_path.mkdir(parents=True, exist_ok=True)
# see what the values are...
nBands, _, _ = tile_values.shape
vmin, vmax = tile_values.min(), tile_values.max()
fig, axs = plt.subplots(2, nBands // 2, figsize=(15, 6))
for band, ax in enumerate(axs.flatten()):
sm = ax.pcolormesh(
tile_positions[1], tile_positions[0], tile_values[band],
cmap="viridis", vmin=vmin, vmax=vmax, shading='nearest'
)
ax.set_title(f"Band {band}")
fig.colorbar(sm, ax=axs, label="% coverage")
x_bounds, y_bounds = tile_bounds[0].tolist(), tile_bounds[1].tolist()
fig.suptitle(f"Land cover in OSGB bounds {x_bounds}, {y_bounds}")
fig.savefig(fig_path / "geotiff_bands.png")
# ^ this is meant to be in percentages of 10 aggregate classes
return None
if __name__ == "__main__":
main()