Skip to content

Commit

Permalink
Interpolation for arbitrary grids (#159)
Browse files Browse the repository at this point in the history
* Add interpolation for arbitrary lat-lon grids
* Add interpolation file and test
* Update license, docstrings, type hints
  • Loading branch information
jleinonen authored Dec 11, 2024
1 parent a03e2a4 commit add37b6
Show file tree
Hide file tree
Showing 5 changed files with 287 additions and 78 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Interpolation between arbitrary lat-lon grids

### Changed

- Set zarr chunks for lead time to size 1 in examples.
Expand Down
6 changes: 6 additions & 0 deletions docs/modules/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ A collection of utilities to manipulate and check coordinate systems dictionarie
utils.coords.map_coords
utils.coords.split_coords

.. autosummary::
:toctree: generated/utils/
:template: class.rst

utils.interp.LatLonInterpolation

:mod:`earth2studio.utils`: Time
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
80 changes: 2 additions & 78 deletions earth2studio/models/dx/corrdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
handshake_coords,
handshake_dim,
)
from earth2studio.utils.interp import latlon_interpolation_regular
from earth2studio.utils.type import CoordSystem

VARIABLES = [
Expand Down Expand Up @@ -289,7 +290,7 @@ def _interpolate(self, x: torch.Tensor) -> torch.Tensor:
"""Interpolate from input lat/lon (self.lat, self.lon) onto output lat/lon
(self.lat_grid, self.lon_grid) using bilinear interpolation."""
input_coords = self.input_coords()
return self.interpolate(
return latlon_interpolation_regular(
x,
torch.as_tensor(input_coords["lat"], device=x.device, dtype=torch.float32),
torch.as_tensor(input_coords["lon"], device=x.device, dtype=torch.float32),
Expand Down Expand Up @@ -467,80 +468,3 @@ def unet_regression(
x_next = net(x_hat, x_lr, t_hat, class_labels).to(torch.float64)

return x_next

@staticmethod
def interpolate(
values: torch.Tensor,
lat0: torch.Tensor,
lon0: torch.Tensor,
lat1: torch.Tensor,
lon1: torch.Tensor,
) -> torch.Tensor:
"""Specialized form of bilinear interpolation intended for optimal use on GPU.
In particular, the mapped values must be defined on a regular rectangular grid,
(lat0, lon0). Both lat0 and lon0 are vectors with equal spacing.
lat1, lon1 are assumed to be 2-dimensional meshgrids with possibly unequal spacing.
Parameters
----------
values : torch.Tensor [..., W_in, H_in]
Input values defined over (lat0, lon0) that will be interpolated onto
(lat1, lon1) grid.
lat0 : torch.Tensor [W_in, ]
Vector of input latitude coordinates, assumed to be increasing with
equal spacing.
lon0 : torch.Tensor [H_in, ]
Vector of input longitude coordinates, assumed to be increasing with
equal spacing.
lat1 : torch.Tensor [W_out, H_out]
Tensor of output latitude coordinates
lon1 : torch.Tensor [W_out, H_out]
Tensor of output longitude coordinates
Returns
-------
result : torch.Tensor [..., W_out, H_out]
Tensor of interpolated values onto lat1, lon1 grid.
"""

# Get input grid shape and flatten
latshape, lonshape = lat1.shape
lat1 = lat1.flatten()
lon1 = lon1.flatten()

# Get indices of nearest points
latinds = torch.searchsorted(lat0, lat1) - 1
loninds = torch.searchsorted(lon0, lon1) - 1

# Get original grid spacing
dlat = lat0[1] - lat0[0]
dlon = lon0[1] - lon0[0]

# Get unit distances
normed_lat_distance = (lat1 - lat0[latinds]) / dlat
normed_lon_distance = (lon1 - lon0[loninds]) / dlon

# Apply bilinear mapping
result = (
values[..., latinds, loninds]
* (1 - normed_lat_distance)
* (1 - normed_lon_distance)
)
result += (
values[..., latinds, loninds + 1]
* (1 - normed_lat_distance)
* (normed_lon_distance)
)
result += (
values[..., latinds + 1, loninds]
* (normed_lat_distance)
* (1 - normed_lon_distance)
)
result += (
values[..., latinds + 1, loninds + 1]
* (normed_lat_distance)
* (normed_lon_distance)
)
return result.reshape(*values.shape[:-2], latshape, lonshape)
195 changes: 195 additions & 0 deletions earth2studio/utils/interp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch
from numpy.typing import ArrayLike
from scipy.interpolate import LinearNDInterpolator
from torch import Tensor, nn


def latlon_interpolation_regular(
values: torch.Tensor,
lat0: torch.Tensor,
lon0: torch.Tensor,
lat1: torch.Tensor,
lon1: torch.Tensor,
) -> torch.Tensor:
"""Specialized form of bilinear interpolation intended for optimal use on GPU.
In particular, the mapped values must be defined on a regular rectangular grid,
(lat0, lon0). Both lat0 and lon0 are vectors with equal spacing.
lat1, lon1 are assumed to be 2-dimensional meshgrids with possibly unequal spacing.
Parameters
----------
values : torch.Tensor [..., H_in, W_in]
Input values defined over (lat0, lon0) that will be interpolated onto
(lat1, lon1) grid.
lat0 : torch.Tensor [H_in, ]
Vector of input latitude coordinates, assumed to be increasing with
equal spacing.
lon0 : torch.Tensor [W_in, ]
Vector of input longitude coordinates, assumed to be increasing with
equal spacing.
lat1 : torch.Tensor [H_out, W_out]
Tensor of output latitude coordinates
lon1 : torch.Tensor [H_out, W_out]
Tensor of output longitude coordinates
Returns
-------
result : torch.Tensor [..., H_out, W_out]
Tensor of interpolated values onto lat1, lon1 grid.
"""

# Get input grid shape and flatten
latshape, lonshape = lat1.shape
lat1 = lat1.flatten()
lon1 = lon1.flatten()

# Get indices of nearest points
latinds = torch.searchsorted(lat0, lat1) - 1
loninds = torch.searchsorted(lon0, lon1) - 1

# Get original grid spacing
dlat = lat0[1] - lat0[0]
dlon = lon0[1] - lon0[0]

# Get unit distances
normed_lat_distance = (lat1 - lat0[latinds]) / dlat
normed_lon_distance = (lon1 - lon0[loninds]) / dlon

# Apply bilinear mapping
result = (
values[..., latinds, loninds]
* (1 - normed_lat_distance)
* (1 - normed_lon_distance)
)
result += (
values[..., latinds, loninds + 1]
* (1 - normed_lat_distance)
* (normed_lon_distance)
)
result += (
values[..., latinds + 1, loninds]
* (normed_lat_distance)
* (1 - normed_lon_distance)
)
result += (
values[..., latinds + 1, loninds + 1]
* (normed_lat_distance)
* (normed_lon_distance)
)
return result.reshape(*values.shape[:-2], latshape, lonshape)


class LatLonInterpolation(nn.Module):
"""Bilinear interpolation between arbitrary grids.
The mapped values can be on arbitrary grid, but the output grid should be
contained within the input grid.
Initializing the interpolation object can be somewhat slow, but interpolation is
fast and can run on the GPU once initialized. Therefore, prefer to reuse the
interpolation object when possible.
To run the interpolation on the GPU, use the .to() method of the interpolator
to move it to the GPU before running the interpolation.
Parameters
----------
lat_in : torch.Tensor | ArrayLike
Tensor [H_in, W_in] of input latitude coordinates
lon_in : torch.Tensor | ArrayLike
Tensor [H_in, W_in] of input longitude coordinates
lat_out : torch.Tensor | ArrayLike
Tensor [H_out, W_out] of output latitude coordinates
lon_out : torch.Tensor | ArrayLike
Tensor [H_out, W_out] of output longitude coordinates
"""

def __init__(
self,
lat_in: torch.Tensor | ArrayLike,
lon_in: torch.Tensor | ArrayLike,
lat_out: torch.Tensor | ArrayLike,
lon_out: torch.Tensor | ArrayLike,
):
super().__init__()

lat_in = (
lat_in.cpu().numpy() if isinstance(lat_in, Tensor) else np.array(lat_in)
)
lon_in = (
lon_in.cpu().numpy() if isinstance(lon_in, Tensor) else np.array(lon_in)
)
lat_out = (
lat_out.cpu().numpy() if isinstance(lat_out, Tensor) else np.array(lat_out)
)
lon_out = (
lon_out.cpu().numpy() if isinstance(lon_out, Tensor) else np.array(lon_out)
)

(i_in, j_in) = np.mgrid[: lat_in.shape[0], : lat_in.shape[1]]

in_points = np.stack((lat_in.ravel(), lon_in.ravel()), axis=-1)
i_interp = LinearNDInterpolator(in_points, i_in.ravel())
j_interp = LinearNDInterpolator(in_points, j_in.ravel())

out_points = np.stack((lat_out.ravel(), lon_out.ravel()), axis=-1)
i_map = i_interp(out_points).reshape(lat_out.shape)
j_map = j_interp(out_points).reshape(lat_out.shape)

i_map = torch.Tensor(i_map)
j_map = torch.Tensor(j_map)

self.register_buffer("i_map", i_map)
self.register_buffer("j_map", j_map)

@torch.inference_mode()
@torch.compile
def forward(self, values: Tensor) -> Tensor:
"""Perform bilinear interpolation for values.
Parameters
----------
values : torch.Tensor
Input values of shape [..., H_in, W_in] defined over (lat_in, lon_in)
that will be interpolated onto (lat_out, lon_out) grid.
Returns
-------
result : torch.Tensor
Tensor of shape [..., H_out, W_out] of interpolated values on lat1, lon1 grid.
"""
i = self.i_map
i0 = i.floor().to(dtype=torch.int64).clamp(min=0, max=values.shape[-2] - 2)
i1 = i0 + 1
j = self.j_map
j0 = j.floor().to(dtype=torch.int64).clamp(min=0, max=values.shape[-1] - 2)
j1 = j0 + 1

f00 = values[..., i0, j0]
f01 = values[..., i0, j1]
f10 = values[..., i1, j0]
f11 = values[..., i1, j1]

dj = j - j0
f0 = torch.lerp(f00, f01, dj)
f1 = torch.lerp(f10, f11, dj)
return torch.lerp(f0, f1, i - i0)
82 changes: 82 additions & 0 deletions test/utils/test_interp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest
import torch

from earth2studio.utils.interp import LatLonInterpolation


@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
@pytest.mark.parametrize("input_type", ["zeros", "random", "gradient"])
def test_interpolation(device, input_type):
(lat_in, lon_in) = np.meshgrid(
np.arange(35.0, 38.0, 0.25), np.arange(5.0, 8.0, 0.25), indexing="ij"
)
(lat_out, lon_out) = np.meshgrid(
np.arange(36.0, 37.0, 0.1), np.arange(6.0, 7.0, 0.1), indexing="ij"
)

interp = LatLonInterpolation(lat_in, lon_in, lat_out, lon_out)
interp.to(device=device)
if input_type == "zeros":
x = torch.zeros(lat_in.shape, device=device)
elif input_type == "random":
x = torch.rand(*lat_in.shape, device=device)
elif input_type == "gradient":
x = (
torch.linspace(0, 1, lat_in.shape[1], device=device)
.unsqueeze(0)
.repeat(lat_in.shape[0], 1)
)

y = interp(x)

if input_type == "zeros":
assert (y == 0).all()
elif input_type == "random":
assert ((y >= 0) & (y <= 1)).all()
elif input_type == "gradient":
assert (y[:, 1:] > y[:, :-1]).all()


@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
def test_interpolation_analytical(device):
lat_in = np.array([[0.0, 0.0], [1.0, 1.0]])
lon_in = np.array([[0.0, 1.0], [0.0, 1.0]])

(lat_out, lon_out) = np.mgrid[:1.01:0.25, :1.01:0.25]

interp = LatLonInterpolation(lat_in, lon_in, lat_out, lon_out)
interp.to(device=device)

x = torch.tensor([[0.0, 1.0], [1.0, 2.0]], device=device)
y = interp(x)

y_correct = torch.tensor(
[
[0.00, 0.25, 0.50, 0.75, 1.00],
[0.25, 0.50, 0.75, 1.00, 1.25],
[0.50, 0.75, 1.00, 1.25, 1.50],
[0.75, 1.00, 1.25, 1.50, 1.75],
[1.00, 1.25, 1.50, 1.75, 2.00],
],
device=device,
)

epsilon = 1e-6 # allow for some FP roundoff
assert (abs(y - y_correct) < epsilon).all()

0 comments on commit add37b6

Please sign in to comment.