From add37b6027def2b92e322089c1915ee03685d9b3 Mon Sep 17 00:00:00 2001 From: Jussi Leinonen Date: Wed, 11 Dec 2024 18:22:41 +0100 Subject: [PATCH] Interpolation for arbitrary grids (#159) * Add interpolation for arbitrary lat-lon grids * Add interpolation file and test * Update license, docstrings, type hints --- CHANGELOG.md | 2 + docs/modules/utils.rst | 6 + earth2studio/models/dx/corrdiff.py | 80 +----------- earth2studio/utils/interp.py | 195 +++++++++++++++++++++++++++++ test/utils/test_interp.py | 82 ++++++++++++ 5 files changed, 287 insertions(+), 78 deletions(-) create mode 100644 earth2studio/utils/interp.py create mode 100644 test/utils/test_interp.py diff --git a/CHANGELOG.md b/CHANGELOG.md index ae553112..c00eff17 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Interpolation between arbitrary lat-lon grids + ### Changed - Set zarr chunks for lead time to size 1 in examples. diff --git a/docs/modules/utils.rst b/docs/modules/utils.rst index 2c1a21e8..7c0a0593 100644 --- a/docs/modules/utils.rst +++ b/docs/modules/utils.rst @@ -25,6 +25,12 @@ A collection of utilities to manipulate and check coordinate systems dictionarie utils.coords.map_coords utils.coords.split_coords +.. 
autosummary:: + :toctree: generated/utils/ + :template: class.rst + + utils.interp.LatLonInterpolation + :mod:`earth2studio.utils`: Time ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/earth2studio/models/dx/corrdiff.py b/earth2studio/models/dx/corrdiff.py index a12e40b2..4d91ba4a 100644 --- a/earth2studio/models/dx/corrdiff.py +++ b/earth2studio/models/dx/corrdiff.py @@ -38,6 +38,7 @@ handshake_coords, handshake_dim, ) +from earth2studio.utils.interp import latlon_interpolation_regular from earth2studio.utils.type import CoordSystem VARIABLES = [ @@ -289,7 +290,7 @@ def _interpolate(self, x: torch.Tensor) -> torch.Tensor: """Interpolate from input lat/lon (self.lat, self.lon) onto output lat/lon (self.lat_grid, self.lon_grid) using bilinear interpolation.""" input_coords = self.input_coords() - return self.interpolate( + return latlon_interpolation_regular( x, torch.as_tensor(input_coords["lat"], device=x.device, dtype=torch.float32), torch.as_tensor(input_coords["lon"], device=x.device, dtype=torch.float32), @@ -467,80 +468,3 @@ def unet_regression( x_next = net(x_hat, x_lr, t_hat, class_labels).to(torch.float64) return x_next - - @staticmethod - def interpolate( - values: torch.Tensor, - lat0: torch.Tensor, - lon0: torch.Tensor, - lat1: torch.Tensor, - lon1: torch.Tensor, - ) -> torch.Tensor: - """Specialized form of bilinear interpolation intended for optimal use on GPU. - - In particular, the mapped values must be defined on a regular rectangular grid, - (lat0, lon0). Both lat0 and lon0 are vectors with equal spacing. - - lat1, lon1 are assumed to be 2-dimensional meshgrids with possibly unequal spacing. - - Parameters - ---------- - values : torch.Tensor [..., W_in, H_in] - Input values defined over (lat0, lon0) that will be interpolated onto - (lat1, lon1) grid. - lat0 : torch.Tensor [W_in, ] - Vector of input latitude coordinates, assumed to be increasing with - equal spacing. 
- lon0 : torch.Tensor [H_in, ] - Vector of input longitude coordinates, assumed to be increasing with - equal spacing. - lat1 : torch.Tensor [W_out, H_out] - Tensor of output latitude coordinates - lon1 : torch.Tensor [W_out, H_out] - Tensor of output longitude coordinates - - Returns - ------- - result : torch.Tensor [..., W_out, H_out] - Tensor of interpolated values onto lat1, lon1 grid. - """ - - # Get input grid shape and flatten - latshape, lonshape = lat1.shape - lat1 = lat1.flatten() - lon1 = lon1.flatten() - - # Get indices of nearest points - latinds = torch.searchsorted(lat0, lat1) - 1 - loninds = torch.searchsorted(lon0, lon1) - 1 - - # Get original grid spacing - dlat = lat0[1] - lat0[0] - dlon = lon0[1] - lon0[0] - - # Get unit distances - normed_lat_distance = (lat1 - lat0[latinds]) / dlat - normed_lon_distance = (lon1 - lon0[loninds]) / dlon - - # Apply bilinear mapping - result = ( - values[..., latinds, loninds] - * (1 - normed_lat_distance) - * (1 - normed_lon_distance) - ) - result += ( - values[..., latinds, loninds + 1] - * (1 - normed_lat_distance) - * (normed_lon_distance) - ) - result += ( - values[..., latinds + 1, loninds] - * (normed_lat_distance) - * (1 - normed_lon_distance) - ) - result += ( - values[..., latinds + 1, loninds + 1] - * (normed_lat_distance) - * (normed_lon_distance) - ) - return result.reshape(*values.shape[:-2], latshape, lonshape) diff --git a/earth2studio/utils/interp.py b/earth2studio/utils/interp.py new file mode 100644 index 00000000..68a56e2e --- /dev/null +++ b/earth2studio/utils/interp.py @@ -0,0 +1,195 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import torch
+from numpy.typing import ArrayLike
+from scipy.interpolate import LinearNDInterpolator
+from torch import Tensor, nn
+
+
+def latlon_interpolation_regular(
+    values: torch.Tensor,
+    lat0: torch.Tensor,
+    lon0: torch.Tensor,
+    lat1: torch.Tensor,
+    lon1: torch.Tensor,
+) -> torch.Tensor:
+    """Specialized form of bilinear interpolation intended for optimal use on GPU.
+
+    In particular, the mapped values must be defined on a regular rectangular grid,
+    (lat0, lon0). Both lat0 and lon0 are vectors with equal spacing.
+
+    lat1, lon1 are assumed to be 2-dimensional meshgrids with possibly unequal spacing.
+
+    Parameters
+    ----------
+    values : torch.Tensor [..., H_in, W_in]
+        Input values defined over (lat0, lon0) that will be interpolated onto
+        (lat1, lon1) grid.
+    lat0 : torch.Tensor [H_in, ]
+        Vector of input latitude coordinates, assumed to be increasing with
+        equal spacing.
+    lon0 : torch.Tensor [W_in, ]
+        Vector of input longitude coordinates, assumed to be increasing with
+        equal spacing.
+    lat1 : torch.Tensor [H_out, W_out]
+        Tensor of output latitude coordinates, expected within the lat0 range
+    lon1 : torch.Tensor [H_out, W_out]
+        Tensor of output longitude coordinates, expected within the lon0 range
+
+    Returns
+    -------
+    result : torch.Tensor [..., H_out, W_out]
+        Tensor of interpolated values onto lat1, lon1 grid.
+    """
+
+    # Get output grid shape and flatten
+    latshape, lonshape = lat1.shape
+    lat1 = lat1.flatten()
+    lon1 = lon1.flatten()
+
+    # Get lower-corner indices, clamped so both corners stay inside the input grid
+    latinds = (torch.searchsorted(lat0, lat1) - 1).clamp(0, values.shape[-2] - 2)
+    loninds = (torch.searchsorted(lon0, lon1) - 1).clamp(0, values.shape[-1] - 2)
+
+    # Get original grid spacing
+    dlat = lat0[1] - lat0[0]
+    dlon = lon0[1] - lon0[0]
+
+    # Get unit distances (outside [0, 1] only for points beyond the grid edge)
+    normed_lat_distance = (lat1 - lat0[latinds]) / dlat
+    normed_lon_distance = (lon1 - lon0[loninds]) / dlon
+
+    # Apply bilinear mapping
+    result = (
+        values[..., latinds, loninds]
+        * (1 - normed_lat_distance)
+        * (1 - normed_lon_distance)
+    )
+    result += (
+        values[..., latinds, loninds + 1]
+        * (1 - normed_lat_distance)
+        * (normed_lon_distance)
+    )
+    result += (
+        values[..., latinds + 1, loninds]
+        * (normed_lat_distance)
+        * (1 - normed_lon_distance)
+    )
+    result += (
+        values[..., latinds + 1, loninds + 1]
+        * (normed_lat_distance)
+        * (normed_lon_distance)
+    )
+    return result.reshape(*values.shape[:-2], latshape, lonshape)
+
+
+class LatLonInterpolation(nn.Module):
+    """Bilinear interpolation between arbitrary grids.
+
+    The mapped values can be on an arbitrary grid, but the output grid should be
+    contained within the input grid.
+
+    Initializing the interpolation object can be somewhat slow, but interpolation is
+    fast and can run on the GPU once initialized. Therefore, prefer to reuse the
+    interpolation object when possible.
+
+    To run the interpolation on the GPU, use the .to() method of the interpolator
+    to move it to the GPU before running the interpolation.
+
+    Parameters
+    ----------
+    lat_in : torch.Tensor | ArrayLike
+        Tensor [H_in, W_in] of input latitude coordinates
+    lon_in : torch.Tensor | ArrayLike
+        Tensor [H_in, W_in] of input longitude coordinates
+    lat_out : torch.Tensor | ArrayLike
+        Tensor [H_out, W_out] of output latitude coordinates
+    lon_out : torch.Tensor | ArrayLike
+        Tensor [H_out, W_out] of output longitude coordinates
+    """
+
+    def __init__(
+        self,
+        lat_in: torch.Tensor | ArrayLike,
+        lon_in: torch.Tensor | ArrayLike,
+        lat_out: torch.Tensor | ArrayLike,
+        lon_out: torch.Tensor | ArrayLike,
+    ):
+        super().__init__()
+
+        lat_in = (
+            lat_in.cpu().numpy() if isinstance(lat_in, Tensor) else np.array(lat_in)
+        )
+        lon_in = (
+            lon_in.cpu().numpy() if isinstance(lon_in, Tensor) else np.array(lon_in)
+        )
+        lat_out = (
+            lat_out.cpu().numpy() if isinstance(lat_out, Tensor) else np.array(lat_out)
+        )
+        lon_out = (
+            lon_out.cpu().numpy() if isinstance(lon_out, Tensor) else np.array(lon_out)
+        )
+
+        (i_in, j_in) = np.mgrid[: lat_in.shape[0], : lat_in.shape[1]]
+
+        in_points = np.stack((lat_in.ravel(), lon_in.ravel()), axis=-1)
+        i_interp = LinearNDInterpolator(in_points, i_in.ravel())
+        j_interp = LinearNDInterpolator(in_points, j_in.ravel())
+
+        out_points = np.stack((lat_out.ravel(), lon_out.ravel()), axis=-1)
+        i_map = i_interp(out_points).reshape(lat_out.shape)
+        j_map = j_interp(out_points).reshape(lat_out.shape)
+
+        i_map = torch.Tensor(i_map)
+        j_map = torch.Tensor(j_map)
+
+        self.register_buffer("i_map", i_map)
+        self.register_buffer("j_map", j_map)
+
+    @torch.inference_mode()
+    @torch.compile
+    def forward(self, values: Tensor) -> Tensor:
+        """Perform bilinear interpolation for values.
+
+        Parameters
+        ----------
+        values : torch.Tensor
+            Input values of shape [..., H_in, W_in] defined over (lat_in, lon_in)
+            that will be interpolated onto (lat_out, lon_out) grid.
+
+        Returns
+        -------
+        result : torch.Tensor
+            Tensor of shape [..., H_out, W_out] of values on (lat_out, lon_out) grid.
+        """
+        i = self.i_map
+        i0 = i.floor().to(dtype=torch.int64).clamp(min=0, max=values.shape[-2] - 2)
+        i1 = i0 + 1
+        j = self.j_map
+        j0 = j.floor().to(dtype=torch.int64).clamp(min=0, max=values.shape[-1] - 2)
+        j1 = j0 + 1
+
+        f00 = values[..., i0, j0]
+        f01 = values[..., i0, j1]
+        f10 = values[..., i1, j0]
+        f11 = values[..., i1, j1]
+
+        dj = (j - j0).to(values.dtype)  # lerp needs weights in the value dtype
+        f0 = torch.lerp(f00, f01, dj)
+        f1 = torch.lerp(f10, f11, dj)
+        return torch.lerp(f0, f1, (i - i0).to(values.dtype))
diff --git a/test/utils/test_interp.py b/test/utils/test_interp.py
new file mode 100644
index 00000000..f0be252d
--- /dev/null
+++ b/test/utils/test_interp.py
@@ -0,0 +1,82 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import pytest
+import torch
+
+from earth2studio.utils.interp import LatLonInterpolation
+
+
+@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
+@pytest.mark.parametrize("input_type", ["zeros", "random", "gradient"])
+def test_interpolation(device, input_type):
+    (lat_in, lon_in) = np.meshgrid(  # coarse input grid, 0.25 spacing
+        np.arange(35.0, 38.0, 0.25), np.arange(5.0, 8.0, 0.25), indexing="ij"
+    )
+    (lat_out, lon_out) = np.meshgrid(  # finer output grid inside the input domain
+        np.arange(36.0, 37.0, 0.1), np.arange(6.0, 7.0, 0.1), indexing="ij"
+    )
+
+    interp = LatLonInterpolation(lat_in, lon_in, lat_out, lon_out)
+    interp.to(device=device)  # moves the registered index-map buffers
+    if input_type == "zeros":
+        x = torch.zeros(lat_in.shape, device=device)
+    elif input_type == "random":
+        x = torch.rand(*lat_in.shape, device=device)  # uniform in [0, 1)
+    elif input_type == "gradient":
+        x = (  # increases along the last (lon) axis, constant along lat
+            torch.linspace(0, 1, lat_in.shape[1], device=device)
+            .unsqueeze(0)
+            .repeat(lat_in.shape[0], 1)
+        )
+
+    y = interp(x)
+
+    if input_type == "zeros":
+        assert (y == 0).all()
+    elif input_type == "random":
+        assert ((y >= 0) & (y <= 1)).all()  # interpolation stays within input range
+    elif input_type == "gradient":
+        assert (y[:, 1:] > y[:, :-1]).all()  # still strictly increasing along lon
+
+
+@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
+def test_interpolation_analytical(device):
+    lat_in = np.array([[0.0, 0.0], [1.0, 1.0]])  # single unit-square cell
+    lon_in = np.array([[0.0, 1.0], [0.0, 1.0]])
+
+    (lat_out, lon_out) = np.mgrid[:1.01:0.25, :1.01:0.25]  # 0, 0.25, ..., 1.0
+
+    interp = LatLonInterpolation(lat_in, lon_in, lat_out, lon_out)
+    interp.to(device=device)
+
+    x = torch.tensor([[0.0, 1.0], [1.0, 2.0]], device=device)  # x = lat + lon
+    y = interp(x)
+
+    y_correct = torch.tensor(  # lat_out + lon_out at every output point
+        [
+            [0.00, 0.25, 0.50, 0.75, 1.00],
+            [0.25, 0.50, 0.75, 1.00, 1.25],
+            [0.50, 0.75, 1.00, 1.25, 1.50],
+            [0.75, 1.00, 1.25, 1.50, 1.75],
+            [1.00, 1.25, 1.50, 1.75, 2.00],
+        ],
+        device=device,
+    )
+
+    epsilon = 1e-6  # allow for some FP roundoff
+    assert (abs(y - y_correct) < epsilon).all()