Skip to content

Commit

Permalink
Add Batch.regrid
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Sep 12, 2024
1 parent c91c3ab commit 186e718
Show file tree
Hide file tree
Showing 5 changed files with 228 additions and 83 deletions.
101 changes: 101 additions & 0 deletions aurora/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

import dataclasses
from datetime import datetime
from functools import partial
from typing import Callable

import numpy as np
import torch
from scipy.interpolate import RegularGridInterpolator as RGI

from aurora.normalisation import (
normalise_atmos_var,
Expand Down Expand Up @@ -145,3 +148,101 @@ def to(self, device: str | torch.device) -> "Batch":
def type(self, t: type) -> "Batch":
"""Convert everything to type `t`."""
return self._fmap(lambda x: x.type(t))

def regrid(self, res: float) -> "Batch":
"""Regrid the batch to a `res` degrees resolution.
This results in `float32` data on the CPU.
This function is not optimised for either speed or accuracy. Use at your own risk.
"""

shape = (round(180 / res) + 1, round(360 / res))
lat_new = torch.from_numpy(np.linspace(90, -90, shape[0]))
lon_new = torch.from_numpy(np.linspace(0, 360, shape[1], endpoint=False))
interpolate_res = partial(
interpolate,
lat=self.metadata.lat,
lon=self.metadata.lon,
lat_new=lat_new,
lon_new=lon_new,
)

return Batch(
surf_vars={k: interpolate_res(v) for k, v in self.surf_vars.items()},
static_vars={k: interpolate_res(v) for k, v in self.static_vars.items()},
atmos_vars={k: interpolate_res(v) for k, v in self.atmos_vars.items()},
metadata=Metadata(
lat=lat_new,
lon=lon_new,
atmos_levels=self.metadata.atmos_levels,
time=self.metadata.time,
rollout_step=self.metadata.rollout_step,
),
)


def interpolate(
v: torch.Tensor,
lat: torch.Tensor,
lon: torch.Tensor,
lat_new: torch.Tensor,
lon_new: torch.Tensor,
) -> torch.Tensor:
"""Interpolate a variable `v` with latitudes `lat` and longitudes `lon` to new latitudes
`lat_new` and new longitudes `lon_new`."""
# Perform the interpolation in double precision.
return torch.from_numpy(
interpolate_numpy(
v.double().numpy(),
lat.double().numpy(),
lon.double().numpy(),
lat_new.double().numpy(),
lon_new.double().numpy(),
)
).float()


def interpolate_numpy(
v: np.ndarray,
lat: np.ndarray,
lon: np.ndarray,
lat_new: np.ndarray,
lon_new: np.ndarray,
) -> np.ndarray:
"""Like :func:`.interpolate`, but for NumPy tensors."""

# Implement periodic longitudes in `lon`.
assert (np.diff(lon) > 0).all()
lon = np.concatenate((lon[-1:] - 360, lon, lon[:1] + 360))

# Merge all batch dimensions into one.
batch_shape = v.shape[:-2]
v = v.reshape(-1, *v.shape[-2:])

# Loop over all batch elements.
vs_regridded = []
for vi in v:
# Implement periodic longitudes in `vi`.
vi = np.concatenate((vi[:, -1:], vi, vi[:, :1]), axis=1)

rgi = RGI(
(lat, lon),
vi,
method="linear",
bounds_error=False, # Allow out of bounds, for the latitudes.
fill_value=None, # Extrapolate latitudes if they are out of bounds.
)
lat_new_grid, lon_new_grid = np.meshgrid(
lat_new,
lon_new,
indexing="ij",
sparse=True,
)
vs_regridded.append(rgi((lat_new_grid, lon_new_grid)))

# Recreate the batch dimensions.
v_regridded = np.stack(vs_regridded, axis=0)
v_regridded = v_regridded.reshape(*batch_shape, lat_new.shape[0], lon_new.shape[0])

return v_regridded
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ dynamic = ["version"]
requires-python = ">=3.10"
dependencies = [
"numpy",
"scipy",
"torch",
"einops",
"timm==0.6.13",
Expand Down
84 changes: 84 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""

import pickle
from datetime import datetime
from typing import Generator, TypedDict

import numpy as np
import pytest
import torch
from huggingface_hub import hf_hub_download

from aurora import Batch, Metadata
from aurora.batch import interpolate_numpy


class SavedMetadata(TypedDict):
"""Type of metadata of a saved test batch."""

lat: np.ndarray
lon: np.ndarray
time: list[datetime]
atmos_levels: list[int | float]


class SavedBatch(TypedDict):
"""Type of a saved test batch."""

surf_vars: dict[str, np.ndarray]
static_vars: dict[str, np.ndarray]
atmos_vars: dict[str, np.ndarray]
metadata: SavedMetadata


@pytest.fixture()
def test_input_output() -> Generator[tuple[Batch, SavedBatch], None, None]:
# Load test input.
path = hf_hub_download(
repo_id="microsoft/aurora",
filename="aurora-0.25-small-pretrained-test-input.pickle",
)
with open(path, "rb") as f:
test_input: SavedBatch = pickle.load(f)

# Load static variables.
path = hf_hub_download(
repo_id="microsoft/aurora",
filename="aurora-0.25-static.pickle",
)
with open(path, "rb") as f:
static_vars: dict[str, np.ndarray] = pickle.load(f)

static_vars = {
k: interpolate_numpy(
v,
np.linspace(90, -90, v.shape[0]),
np.linspace(0, 360, v.shape[1], endpoint=False),
test_input["metadata"]["lat"],
test_input["metadata"]["lon"],
)
for k, v in static_vars.items()
}

# Construct a proper batch from the test input.
batch = Batch(
surf_vars={k: torch.from_numpy(v) for k, v in test_input["surf_vars"].items()},
static_vars={k: torch.from_numpy(v) for k, v in static_vars.items()},
atmos_vars={k: torch.from_numpy(v) for k, v in test_input["atmos_vars"].items()},
metadata=Metadata(
lat=torch.from_numpy(test_input["metadata"]["lat"]),
lon=torch.from_numpy(test_input["metadata"]["lon"]),
atmos_levels=tuple(test_input["metadata"]["atmos_levels"]),
time=tuple(test_input["metadata"]["time"]),
),
)

# Load test output.
path = hf_hub_download(
repo_id="microsoft/aurora",
filename="aurora-0.25-small-pretrained-test-output.pickle",
)
with open(path, "rb") as f:
test_output: SavedBatch = pickle.load(f)

yield batch, test_output
37 changes: 37 additions & 0 deletions tests/test_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""

import numpy as np

from tests.conftest import SavedBatch

from aurora import Batch


def test_interpolation(test_input_output: tuple[Batch, SavedBatch]) -> None:
batch, _ = test_input_output

# Regridding to the same resolution shouldn't change the data.
batch_regridded = batch.regrid(0.45)
batch_regridded = batch_regridded.crop(4) # Regridding added the south pole. Remove it again.

for k in batch.surf_vars:
np.testing.assert_allclose(
batch.surf_vars[k],
batch_regridded.surf_vars[k],
rtol=5e-6,
)
for k in batch.static_vars:
np.testing.assert_allclose(
batch.static_vars[k],
batch_regridded.static_vars[k],
atol=1e-7,
)
for k in batch.atmos_vars:
np.testing.assert_allclose(
batch.atmos_vars[k],
batch_regridded.atmos_vars[k],
rtol=5e-6,
)

np.testing.assert_allclose(batch.metadata.lat, batch_regridded.metadata.lat, atol=1e-10)
np.testing.assert_allclose(batch.metadata.lon, batch_regridded.metadata.lon, atol=1e-10)
88 changes: 5 additions & 83 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,96 +1,18 @@
"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""

import pickle
from datetime import datetime
from typing import TypedDict

import numpy as np
import torch
from huggingface_hub import hf_hub_download
from scipy.interpolate import RegularGridInterpolator as RGI

from aurora import AuroraSmall, Batch, Metadata


class SavedMetadata(TypedDict):
"""Type of metadata of a saved test batch."""

lat: np.ndarray
lon: np.ndarray
time: list[datetime]
atmos_levels: list[int | float]

from tests.conftest import SavedBatch

class SavedBatch(TypedDict):
"""Type of a saved test batch."""
from aurora import AuroraSmall, Batch

surf_vars: dict[str, np.ndarray]
static_vars: dict[str, np.ndarray]
atmos_vars: dict[str, np.ndarray]
metadata: SavedMetadata

def test_aurora_small(test_input_output: tuple[Batch, SavedBatch]) -> None:
batch, test_output = test_input_output

def test_aurora_small() -> None:
model = AuroraSmall(use_lora=True)

# Load test input.
path = hf_hub_download(
repo_id="microsoft/aurora",
filename="aurora-0.25-small-pretrained-test-input.pickle",
)
with open(path, "rb") as f:
test_input: SavedBatch = pickle.load(f)

# Load test output.
path = hf_hub_download(
repo_id="microsoft/aurora",
filename="aurora-0.25-small-pretrained-test-output.pickle",
)
with open(path, "rb") as f:
test_output: SavedBatch = pickle.load(f)

# Load static variables.
path = hf_hub_download(
repo_id="microsoft/aurora",
filename="aurora-0.25-static.pickle",
)
with open(path, "rb") as f:
static_vars: dict[str, np.ndarray] = pickle.load(f)

def interpolate(v: np.ndarray) -> np.ndarray:
"""Interpolate a static variable `v` to the grid of the test data."""
rgi = RGI(
(
np.linspace(90, -90, v.shape[0]),
np.linspace(0, 360, v.shape[1], endpoint=False),
),
v,
method="linear",
bounds_error=False,
)
lat_new, lon_new = np.meshgrid(
test_input["metadata"]["lat"],
test_input["metadata"]["lon"],
indexing="ij",
sparse=True,
)
return rgi((lat_new, lon_new))

static_vars = {k: interpolate(v) for k, v in static_vars.items()}

# Construct a proper batch from the test input.
batch = Batch(
surf_vars={k: torch.from_numpy(v) for k, v in test_input["surf_vars"].items()},
static_vars={k: torch.from_numpy(v) for k, v in static_vars.items()},
atmos_vars={k: torch.from_numpy(v) for k, v in test_input["atmos_vars"].items()},
metadata=Metadata(
lat=torch.from_numpy(test_input["metadata"]["lat"]),
lon=torch.from_numpy(test_input["metadata"]["lon"]),
atmos_levels=tuple(test_input["metadata"]["atmos_levels"]),
time=tuple(test_input["metadata"]["time"]),
),
)

# Load the checkpoint and run the model.
model.load_checkpoint(
"microsoft/aurora",
Expand Down Expand Up @@ -130,7 +52,7 @@ def assert_approx_equality(v_out: np.ndarray, v_ref: np.ndarray, tol: float) ->
for k in pred.static_vars:
assert_approx_equality(
pred.static_vars[k].numpy(),
static_vars[k],
batch.static_vars[k].numpy(),
1e-10, # These should be exactly equal.
)
for k in pred.atmos_vars:
Expand Down

0 comments on commit 186e718

Please sign in to comment.