From 186e718ab378b5beaef8880dd5d5e370f8fc24b9 Mon Sep 17 00:00:00 2001 From: Wessel Bruinsma Date: Thu, 12 Sep 2024 17:31:04 +0200 Subject: [PATCH] Add `Batch.regrid` --- aurora/batch.py | 101 ++++++++++++++++++++++++++++++++++++++++++++ pyproject.toml | 1 + tests/conftest.py | 84 ++++++++++++++++++++++++++++++++++++ tests/test_batch.py | 37 ++++++++++++++++ tests/test_model.py | 88 +++----------------------------------- 5 files changed, 228 insertions(+), 83 deletions(-) create mode 100644 tests/conftest.py create mode 100644 tests/test_batch.py diff --git a/aurora/batch.py b/aurora/batch.py index 69a2d65..3be2371 100644 --- a/aurora/batch.py +++ b/aurora/batch.py @@ -2,9 +2,12 @@ import dataclasses from datetime import datetime +from functools import partial from typing import Callable +import numpy as np import torch +from scipy.interpolate import RegularGridInterpolator as RGI from aurora.normalisation import ( normalise_atmos_var, @@ -145,3 +148,101 @@ def to(self, device: str | torch.device) -> "Batch": def type(self, t: type) -> "Batch": """Convert everything to type `t`.""" return self._fmap(lambda x: x.type(t)) + + def regrid(self, res: float) -> "Batch": + """Regrid the batch to a `res` degrees resolution. + + This results in `float32` data on the CPU. + + This function is not optimised for either speed or accuracy. Use at your own risk. + """ + + shape = (round(180 / res) + 1, round(360 / res)) + lat_new = torch.from_numpy(np.linspace(90, -90, shape[0])) + lon_new = torch.from_numpy(np.linspace(0, 360, shape[1], endpoint=False)) + interpolate_res = partial( + interpolate, + lat=self.metadata.lat, + lon=self.metadata.lon, + lat_new=lat_new, + lon_new=lon_new, + ) + + return Batch( + surf_vars={k: interpolate_res(v) for k, v in self.surf_vars.items()}, + static_vars={k: interpolate_res(v) for k, v in self.static_vars.items()}, + atmos_vars={k: interpolate_res(v) for k, v in self.atmos_vars.items()}, + metadata=Metadata( + lat=lat_new, + lon=lon_new, + atmos_levels=self.metadata.atmos_levels, + time=self.metadata.time, + rollout_step=self.metadata.rollout_step, + ), + ) + + +def interpolate( + v: torch.Tensor, + lat: torch.Tensor, + lon: torch.Tensor, + lat_new: torch.Tensor, + lon_new: torch.Tensor, +) -> torch.Tensor: + """Interpolate a variable `v` with latitudes `lat` and longitudes `lon` to new latitudes + `lat_new` and new longitudes `lon_new`.""" + # Perform the interpolation in double precision. + return torch.from_numpy( + interpolate_numpy( + v.double().numpy(), + lat.double().numpy(), + lon.double().numpy(), + lat_new.double().numpy(), + lon_new.double().numpy(), + ) + ).float() + + +def interpolate_numpy( + v: np.ndarray, + lat: np.ndarray, + lon: np.ndarray, + lat_new: np.ndarray, + lon_new: np.ndarray, +) -> np.ndarray: + """Like :func:`.interpolate`, but for NumPy tensors.""" + + # Implement periodic longitudes in `lon`. + assert (np.diff(lon) > 0).all() + lon = np.concatenate((lon[-1:] - 360, lon, lon[:1] + 360)) + + # Merge all batch dimensions into one. + batch_shape = v.shape[:-2] + v = v.reshape(-1, *v.shape[-2:]) + + # Loop over all batch elements. + vs_regridded = [] + for vi in v: + # Implement periodic longitudes in `vi`. + vi = np.concatenate((vi[:, -1:], vi, vi[:, :1]), axis=1) + + rgi = RGI( + (lat, lon), + vi, + method="linear", + bounds_error=False, # Allow out of bounds, for the latitudes. + fill_value=None, # Extrapolate latitudes if they are out of bounds. + ) + lat_new_grid, lon_new_grid = np.meshgrid( + lat_new, + lon_new, + indexing="ij", + sparse=True, + ) + vs_regridded.append(rgi((lat_new_grid, lon_new_grid))) + + # Recreate the batch dimensions. + v_regridded = np.stack(vs_regridded, axis=0) + v_regridded = v_regridded.reshape(*batch_shape, lat_new.shape[0], lon_new.shape[0]) + + return v_regridded diff --git a/pyproject.toml b/pyproject.toml index e889efd..92c8861 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dynamic = ["version"] requires-python = ">=3.10" dependencies = [ "numpy", + "scipy", "torch", "einops", "timm==0.6.13", diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..91056c1 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,84 @@ +"""Copyright (c) Microsoft Corporation. Licensed under the MIT license.""" + +import pickle +from datetime import datetime +from typing import Generator, TypedDict + +import numpy as np +import pytest +import torch +from huggingface_hub import hf_hub_download + +from aurora import Batch, Metadata +from aurora.batch import interpolate_numpy + + +class SavedMetadata(TypedDict): + """Type of metadata of a saved test batch.""" + + lat: np.ndarray + lon: np.ndarray + time: list[datetime] + atmos_levels: list[int | float] + + +class SavedBatch(TypedDict): + """Type of a saved test batch.""" + + surf_vars: dict[str, np.ndarray] + static_vars: dict[str, np.ndarray] + atmos_vars: dict[str, np.ndarray] + metadata: SavedMetadata + + +@pytest.fixture() +def test_input_output() -> Generator[tuple[Batch, SavedBatch], None, None]: + # Load test input. + path = hf_hub_download( + repo_id="microsoft/aurora", + filename="aurora-0.25-small-pretrained-test-input.pickle", + ) + with open(path, "rb") as f: + test_input: SavedBatch = pickle.load(f) + + # Load static variables. + path = hf_hub_download( + repo_id="microsoft/aurora", + filename="aurora-0.25-static.pickle", + ) + with open(path, "rb") as f: + static_vars: dict[str, np.ndarray] = pickle.load(f) + + static_vars = { + k: interpolate_numpy( + v, + np.linspace(90, -90, v.shape[0]), + np.linspace(0, 360, v.shape[1], endpoint=False), + test_input["metadata"]["lat"], + test_input["metadata"]["lon"], + ) + for k, v in static_vars.items() + } + + # Construct a proper batch from the test input. + batch = Batch( + surf_vars={k: torch.from_numpy(v) for k, v in test_input["surf_vars"].items()}, + static_vars={k: torch.from_numpy(v) for k, v in static_vars.items()}, + atmos_vars={k: torch.from_numpy(v) for k, v in test_input["atmos_vars"].items()}, + metadata=Metadata( + lat=torch.from_numpy(test_input["metadata"]["lat"]), + lon=torch.from_numpy(test_input["metadata"]["lon"]), + atmos_levels=tuple(test_input["metadata"]["atmos_levels"]), + time=tuple(test_input["metadata"]["time"]), + ), + ) + + # Load test output. + path = hf_hub_download( + repo_id="microsoft/aurora", + filename="aurora-0.25-small-pretrained-test-output.pickle", + ) + with open(path, "rb") as f: + test_output: SavedBatch = pickle.load(f) + + yield batch, test_output diff --git a/tests/test_batch.py b/tests/test_batch.py new file mode 100644 index 0000000..7c68716 --- /dev/null +++ b/tests/test_batch.py @@ -0,0 +1,37 @@ +"""Copyright (c) Microsoft Corporation. Licensed under the MIT license.""" + +import numpy as np + +from tests.conftest import SavedBatch + +from aurora import Batch + + +def test_interpolation(test_input_output: tuple[Batch, SavedBatch]) -> None: + batch, _ = test_input_output + + # Regridding to the same resolution shouldn't change the data. + batch_regridded = batch.regrid(0.45) + batch_regridded = batch_regridded.crop(4) # Regridding added the south pole. Remove it again. + + for k in batch.surf_vars: + np.testing.assert_allclose( + batch.surf_vars[k], + batch_regridded.surf_vars[k], + rtol=5e-6, + ) + for k in batch.static_vars: + np.testing.assert_allclose( + batch.static_vars[k], + batch_regridded.static_vars[k], + atol=1e-7, + ) + for k in batch.atmos_vars: + np.testing.assert_allclose( + batch.atmos_vars[k], + batch_regridded.atmos_vars[k], + rtol=5e-6, + ) + + np.testing.assert_allclose(batch.metadata.lat, batch_regridded.metadata.lat, atol=1e-10) + np.testing.assert_allclose(batch.metadata.lon, batch_regridded.metadata.lon, atol=1e-10) diff --git a/tests/test_model.py b/tests/test_model.py index c3df873..7720d2d 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,96 +1,18 @@ """Copyright (c) Microsoft Corporation. Licensed under the MIT license.""" -import pickle -from datetime import datetime -from typing import TypedDict - import numpy as np import torch -from huggingface_hub import hf_hub_download -from scipy.interpolate import RegularGridInterpolator as RGI - -from aurora import AuroraSmall, Batch, Metadata - - -class SavedMetadata(TypedDict): - """Type of metadata of a saved test batch.""" - - lat: np.ndarray - lon: np.ndarray - time: list[datetime] - atmos_levels: list[int | float] +from tests.conftest import SavedBatch -class SavedBatch(TypedDict): - """Type of a saved test batch.""" +from aurora import AuroraSmall, Batch - surf_vars: dict[str, np.ndarray] - static_vars: dict[str, np.ndarray] - atmos_vars: dict[str, np.ndarray] - metadata: SavedMetadata +def test_aurora_small(test_input_output: tuple[Batch, SavedBatch]) -> None: + batch, test_output = test_input_output -def test_aurora_small() -> None: model = AuroraSmall(use_lora=True) - # Load test input. - path = hf_hub_download( - repo_id="microsoft/aurora", - filename="aurora-0.25-small-pretrained-test-input.pickle", - ) - with open(path, "rb") as f: - test_input: SavedBatch = pickle.load(f) - - # Load test output. - path = hf_hub_download( - repo_id="microsoft/aurora", - filename="aurora-0.25-small-pretrained-test-output.pickle", - ) - with open(path, "rb") as f: - test_output: SavedBatch = pickle.load(f) - - # Load static variables. - path = hf_hub_download( - repo_id="microsoft/aurora", - filename="aurora-0.25-static.pickle", - ) - with open(path, "rb") as f: - static_vars: dict[str, np.ndarray] = pickle.load(f) - - def interpolate(v: np.ndarray) -> np.ndarray: - """Interpolate a static variable `v` to the grid of the test data.""" - rgi = RGI( - ( - np.linspace(90, -90, v.shape[0]), - np.linspace(0, 360, v.shape[1], endpoint=False), - ), - v, - method="linear", - bounds_error=False, - ) - lat_new, lon_new = np.meshgrid( - test_input["metadata"]["lat"], - test_input["metadata"]["lon"], - indexing="ij", - sparse=True, - ) - return rgi((lat_new, lon_new)) - - static_vars = {k: interpolate(v) for k, v in static_vars.items()} - - # Construct a proper batch from the test input. - batch = Batch( - surf_vars={k: torch.from_numpy(v) for k, v in test_input["surf_vars"].items()}, - static_vars={k: torch.from_numpy(v) for k, v in static_vars.items()}, - atmos_vars={k: torch.from_numpy(v) for k, v in test_input["atmos_vars"].items()}, - metadata=Metadata( - lat=torch.from_numpy(test_input["metadata"]["lat"]), - lon=torch.from_numpy(test_input["metadata"]["lon"]), - atmos_levels=tuple(test_input["metadata"]["atmos_levels"]), - time=tuple(test_input["metadata"]["time"]), - ), - ) - # Load the checkpoint and run the model. model.load_checkpoint( "microsoft/aurora", @@ -130,7 +52,7 @@ def assert_approx_equality(v_out: np.ndarray, v_ref: np.ndarray, tol: float) -> for k in pred.static_vars: assert_approx_equality( pred.static_vars[k].numpy(), - static_vars[k], + batch.static_vars[k].numpy(), 1e-10, # These should be exactly equal. ) for k in pred.atmos_vars: