From 81b8a47cd8f1fd0e581262a594ab31e5e4df3283 Mon Sep 17 00:00:00 2001
From: "Noah D. Brenowitz" <nbrenowitz@nvidia.com>
Date: Mon, 9 Dec 2024 21:54:15 -0500
Subject: [PATCH] Add ying yang grid

This PR adds a new projection based grid, the ying yang grid. It
restructures some of the lambert conformal logic a bit, so Simon should
take a look.
---
 .gitignore                |   7 +++
 earth2grid/_regrid.py     |  10 +++-
 earth2grid/latlon.py      |  22 +++++++--
 earth2grid/lcc.py         |  74 +++-------------------------
 earth2grid/projections.py | 100 ++++++++++++++++++++++++++++++++++++++
 earth2grid/spatial.py     |   7 +++
 earth2grid/yinyang.py     |  94 +++++++++++++++++++++++++++++++++++
 examples/yinyang.py       |  65 +++++++++++++++++++++++++
 pyproject.toml            |   4 +-
 tests/test_spatial.py     |  41 ++++++++++++++++
 tests/test_yingyang.py    |  54 ++++++++++++++++++++
 11 files changed, 403 insertions(+), 75 deletions(-)
 create mode 100644 earth2grid/projections.py
 create mode 100644 earth2grid/yinyang.py
 create mode 100644 examples/yinyang.py
 create mode 100644 tests/test_spatial.py
 create mode 100644 tests/test_yingyang.py

diff --git a/.gitignore b/.gitignore
index e1f3e35..bb66842 100644
--- a/.gitignore
+++ b/.gitignore
@@ -116,7 +116,14 @@ test_grid_visualize.png
 *.png
 *.jpg
 *.jpeg
+*.gif
 public/
 
 a.out
 *.o
+
+# editor backup files
+# helix
+\#*\#
+# emacs
+*~
diff --git a/earth2grid/_regrid.py b/earth2grid/_regrid.py
index daba8bb..48c6cd6 100644
--- a/earth2grid/_regrid.py
+++ b/earth2grid/_regrid.py
@@ -16,8 +16,13 @@
 from typing import Dict, Sequence
 
 import einops
-import netCDF4 as nc
 import torch
+
+try:
+    import netCDF4 as nc
+except ImportError:
+    nc = None
+
 from scipy import spatial
 
 from earth2grid.spatial import ang2vec, haversine_distance
@@ -59,6 +64,9 @@ def from_state_dict(d: Dict[str, torch.Tensor]) -> "Regridder":
 class TempestRegridder(torch.nn.Module):
     def __init__(self, file_path):
         super().__init__()
+        if nc is None:
+            raise ImportError("netCDF4 not imported. Please install for this feature.")
+
         dataset = nc.Dataset(file_path)
         self.lat = dataset["latc_b"][:]
         self.lon = dataset["lonc_b"][:]
diff --git a/earth2grid/latlon.py b/earth2grid/latlon.py
index 4b52669..c6abcbe 100644
--- a/earth2grid/latlon.py
+++ b/earth2grid/latlon.py
@@ -25,14 +25,18 @@
 
 
 class LatLonGrid(base.Grid):
-    def __init__(self, lat: list[float], lon: list[float]):
+    def __init__(self, lat: list[float], lon: list[float], cylinder: bool = True):
         """
         Args:
             lat: center of lat cells
             lon: center of lon cells
+            cylinder: if true, then lon is considered a periodic coordinate
+                on cylinder so that interpolation wraps around the edge.
+                Otherwise, it is assumed to be a finite plane.
         """
         self._lat = lat
         self._lon = lon
+        self.cylinder = cylinder
 
     @property
     def lat(self):
@@ -48,7 +52,7 @@ def shape(self):
 
     def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
         """Get regridder to the specified lat and lon points"""
-        return _RegridFromLatLon(self, lat, lon)
+        return _RegridFromLatLon(self, lat, lon, cylinder=self.cylinder)
 
     def _lonb(self):
         edges = (self.lon[1:] + self.lon[:-1]) / 2
@@ -78,15 +82,22 @@ def to_pyvista(self):
 class _RegridFromLatLon(torch.nn.Module):
     """Regrid from lat-lon to unstructured grid with bilinear interpolation"""
 
-    def __init__(self, src: LatLonGrid, lat: np.ndarray, lon: np.ndarray):
+    def __init__(self, src: LatLonGrid, lat: np.ndarray, lon: np.ndarray, cylinder: bool = True):
+        """
+        Args:
+            cylinder: if True than lon is assumed to be periodic
+        """
         super().__init__()
+        self.cylinder = cylinder
 
         lat, lon = np.broadcast_arrays(lat, lon)
         self.shape = lat.shape
 
         # TODO add device switching logic (maybe use torch registers for this
         # info)
-        long = np.concatenate([src.lon.ravel(), [360]], axis=-1)
+        long = src.lon.ravel()
+        if self.cylinder:
+            long = np.concatenate([long, [360]], axis=-1)
         long_t = torch.from_numpy(long)
 
         # flip the order latg since bilinear only works with increasing coordinate values
@@ -104,7 +115,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # pad z in lon direction
         # only works for a global grid
         # TODO generalize this to local grids and add options for padding
-        x = torch.cat([x, x[..., 0:1]], axis=-1)
+        if self.cylinder:
+            x = torch.cat([x, x[..., 0:1]], axis=-1)
         out = self._bilinear(x)
         return out.view(out.shape[:-1] + self.shape)
 
diff --git a/earth2grid/lcc.py b/earth2grid/lcc.py
index 55146d0..7db6368 100644
--- a/earth2grid/lcc.py
+++ b/earth2grid/lcc.py
@@ -13,10 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import numpy as np
-import torch
 
-from earth2grid import base
-from earth2grid._regrid import BilinearInterpolator
+from earth2grid import projections
 
 try:
     import pyvista as pv
@@ -31,7 +29,10 @@
 ]
 
 
-class LambertConformalConicProjection:
+LambertConformalConicGrid = projections.Grid
+
+
+class LambertConformalConicProjection(projections.Projection):
     def __init__(self, lat0: float, lon0: float, lat1: float, lat2: float, radius: float):
         """
 
@@ -108,69 +109,6 @@ def inverse_project(self, x, y):
 HRRR_CONUS_PROJECTION = LambertConformalConicProjection(lon0=-97.5, lat0=38.5, lat1=38.5, lat2=38.5, radius=6371229.0)
 
 
-class LambertConformalConicGrid(base.Grid):
-    # nothing here is specific to the projection, so could be shared by any projected rectilinear grid
-    def __init__(self, projection: LambertConformalConicProjection, x, y):
-        """
-        Args:
-            projection: LambertConformalConicProjection object
-            x: range of x values
-            y: range of y values
-
-        """
-        self.projection = projection
-
-        self.x = np.array(x)
-        self.y = np.array(y)
-
-    @property
-    def lat_lon(self):
-        mesh_x, mesh_y = np.meshgrid(self.x, self.y)
-        return self.projection.inverse_project(mesh_x, mesh_y)
-
-    @property
-    def lat(self):
-        return self.lat_lon[0]
-
-    @property
-    def lon(self):
-        return self.lat_lon[1]
-
-    @property
-    def shape(self):
-        return (len(self.y), len(self.x))
-
-    def __getitem__(self, idxs):
-        yidxs, xidxs = idxs
-        return LambertConformalConicGrid(self.projection, x=self.x[xidxs], y=self.y[yidxs])
-
-    def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
-        """Get regridder to the specified lat and lon points"""
-
-        x, y = self.projection.project(lat, lon)
-
-        return BilinearInterpolator(
-            x_coords=torch.from_numpy(self.x),
-            y_coords=torch.from_numpy(self.y),
-            x_query=torch.from_numpy(x),
-            y_query=torch.from_numpy(y),
-        )
-
-    def visualize(self, data):
-        raise NotImplementedError()
-
-    def to_pyvista(self):
-        if pv is None:
-            raise ImportError("Need to install pyvista")
-
-        lat, lon = self.lat_lon
-        y = np.cos(np.deg2rad(lat)) * np.sin(np.deg2rad(lon))
-        x = np.cos(np.deg2rad(lat)) * np.cos(np.deg2rad(lon))
-        z = np.sin(np.deg2rad(lat))
-        grid = pv.StructuredGrid(x, y, z)
-        return grid
-
-
 def hrrr_conus_grid(ix0=0, iy0=0, nx=1799, ny=1059):
     # coordinates of point in top-left corner
     lat0 = 21.138123
@@ -183,7 +121,7 @@ def hrrr_conus_grid(ix0=0, iy0=0, nx=1799, ny=1059):
     x = [x0 + i * scale for i in range(ix0, ix0 + nx)]
     y = [y0 + i * scale for i in range(iy0, iy0 + ny)]
 
-    return LambertConformalConicGrid(HRRR_CONUS_PROJECTION, x, y)
+    return projections.Grid(HRRR_CONUS_PROJECTION, x, y)
 
 
 # Grid used by HRRR CONUS (Continental US) data
diff --git a/earth2grid/projections.py b/earth2grid/projections.py
new file mode 100644
index 0000000..4a71dfc
--- /dev/null
+++ b/earth2grid/projections.py
@@ -0,0 +1,100 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import abc
+
+import numpy as np
+import torch
+
+from earth2grid import base
+from earth2grid._regrid import BilinearInterpolator
+
+try:
+    import pyvista as pv
+except ImportError:
+    pv = None
+
+
+class Projection(abc.ABC):
+    @abc.abstractmethod
+    def project(self, lat: np.ndarray, lon: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+        """
+        Compute the projected x,y from lat,lon.
+        """
+        pass
+
+    @abc.abstractmethod
+    def inverse_project(self, x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+        """
+        Compute the lat,lon from the projected x,y.
+        """
+        pass
+
+
+class Grid(base.Grid):
+    # nothing here is specific to the projection, so could be shared by any projected rectilinear grid
+    def __init__(self, projection: Projection, x, y):
+        """
+        Args:
+            x: range of x values
+            y: range of y values
+
+        """
+        self.projection = projection
+
+        self.x = np.array(x)
+        self.y = np.array(y)
+
+    @property
+    def lat_lon(self):
+        mesh_x, mesh_y = np.meshgrid(self.x, self.y, indexing='ij')
+        return self.projection.inverse_project(mesh_x, mesh_y)
+
+    @property
+    def lat(self):
+        return self.lat_lon[0]
+
+    @property
+    def lon(self):
+        return self.lat_lon[1]
+
+    @property
+    def shape(self):
+        return (len(self.x), len(self.y))
+
+    def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
+        """Get regridder to the specified lat and lon points"""
+
+        x, y = self.projection.project(lat, lon)
+
+        return BilinearInterpolator(
+            x_coords=torch.from_numpy(self.x),
+            y_coords=torch.from_numpy(self.y),
+            x_query=torch.from_numpy(x),
+            y_query=torch.from_numpy(y),
+        )
+
+    def visualize(self, data):
+        raise NotImplementedError()
+
+    def to_pyvista(self):
+        if pv is None:
+            raise ImportError("Need to install pyvista")
+
+        lat, lon = self.lat_lon
+        y = np.cos(np.deg2rad(lat)) * np.sin(np.deg2rad(lon))
+        x = np.cos(np.deg2rad(lat)) * np.cos(np.deg2rad(lon))
+        z = np.sin(np.deg2rad(lat))
+        grid = pv.StructuredGrid(x, y, z)
+        return grid
diff --git a/earth2grid/spatial.py b/earth2grid/spatial.py
index 974108e..87a161b 100644
--- a/earth2grid/spatial.py
+++ b/earth2grid/spatial.py
@@ -44,3 +44,10 @@ def ang2vec(lon, lat):
     y = torch.cos(lat) * torch.sin(lon)
     z = torch.sin(lat)
     return (x, y, z)
+
+
+def vec2ang(x, y, z):
+    """convert lon,lat in radians to cartesian coordinates"""
+    lat = torch.asin(z)
+    lon = torch.atan2(y, x)
+    return lon, lat
diff --git a/earth2grid/yinyang.py b/earth2grid/yinyang.py
new file mode 100644
index 0000000..258d13a
--- /dev/null
+++ b/earth2grid/yinyang.py
@@ -0,0 +1,94 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Yin Yang
+
+the ying yang grid is an overset grid for the sphere containing two faces
+- Yin: a normal lat lon grid for 2/3 of lon, and 2/3 of lat
+- Yang: Yin but with pole along x
+
+
+Key facts
+
+ying
+lon: [-3 pi /4  - delta, 3 pi / 4 + delta ]
+lat: [-pi / 4 - delta, pi / 4 + delta]
+
+ying to yang transformation: alpha = 0, beta = 90, gamma = 180
+
+(x, y, z) - > (-x, z, y)
+
+"""
+import math
+
+import numpy as np
+import torch
+
+from earth2grid import latlon, projections, spatial
+
+
+def Ying(nlat: int, nlon: int, delta: int):
+    """The ying grid
+
+    nlat, and nlon are as in the latlon.equiangular_latlon_grid and
+    refer to full sphere.
+
+    ``nlat`` includes the poles [90, -90], and ``nlon`` is [0, 2 pi).
+
+    ``delta`` is the amount of overlap in terms of number of grid points.
+
+    """
+    # TODO test that min(lat) = -max(lat), and for lon too
+
+    dlat = 180 / (nlat - 1)
+    dlon = 360 / nlon
+
+    n = math.ceil(3 * nlon / 8)
+    lon = np.arange(-n - delta, n + delta + 1) * dlon
+    lat = np.arange(-(nlat - 1) // 4 - delta, (nlat + 1) // 4 + delta + 1) * dlat
+
+    return latlon.LatLonGrid(lat.tolist(), lon.tolist(), cylinder=False)
+
+
+class YangProjection(projections.Projection):
+    def project(self, lat: np.ndarray, lon: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+        """
+        Compute the projected x,y from lat,lon.
+        """
+        lat = torch.from_numpy(lat)
+        lon = torch.from_numpy(lon)
+
+        lat = torch.deg2rad(lat)
+        lon = torch.deg2rad(lon)
+
+        x, y, z = spatial.ang2vec(lat=lat, lon=lon)
+        x, y, z = -x, z, y
+        lon, lat = spatial.vec2ang(x, y, z)
+
+        lat = torch.rad2deg(lat)
+        lon = torch.rad2deg(lon)
+
+        return lat.numpy(), lon.numpy()
+
+    def inverse_project(self, x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+        """
+        Compute the lat,lon from the projected x,y.
+        """
+        # ying-yang is its own inverse
+        return self.project(x, y)
+
+
+def Yang(nlat, nlon, delta):
+    ying = Ying(nlat, nlon, delta)
+    return projections.Grid(YangProjection(), ying.lat, ying.lon)
diff --git a/examples/yinyang.py b/examples/yinyang.py
new file mode 100644
index 0000000..a03500d
--- /dev/null
+++ b/examples/yinyang.py
@@ -0,0 +1,65 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.from earth2grid.yinyang import Ying, Yang, YangProjection
+import matplotlib.pyplot as plt
+import numpy as np
+import pyvista as pv
+import torch
+
+from earth2grid.yinyang import Yang, Ying
+
+nlat = 721
+nlon = 1440
+delta = 64
+
+nlat = 37
+nlon = 72
+delta = 0
+
+ying = Ying(nlat, nlon, delta)
+yang = Yang(nlat, nlon, delta)
+
+
+def structured_grid(lon, lat):
+    y = np.cos(np.deg2rad(lat)) * np.sin(np.deg2rad(lon))
+    x = np.cos(np.deg2rad(lat)) * np.cos(np.deg2rad(lon))
+    z = np.sin(np.deg2rad(lat))
+    grid = pv.StructuredGrid(x, y, z)
+    return grid
+
+
+lon, lat = np.meshgrid(ying.lon, ying.lat)
+ying_g = structured_grid(lon, lat)
+yang_g = structured_grid(yang.lon, yang.lat)
+
+pl = pv.Plotter()
+pl.add_mesh(ying_g, show_edges=True)
+# scale slightly so yang is on top
+pl.add_mesh(yang_g.scale(1.002), show_edges=True, color="red", opacity=0.5)
+pl.show()
+
+
+y2y = ying.get_bilinear_regridder_to(yang.lat, yang.lon)
+y2y.float()
+
+x = torch.ones(ying.shape)
+y = y2y(x)
+y = y.reshape(yang.shape)
+print("mask", torch.isnan(y).sum() / y.numel())
+
+plt.figure()
+# TODO fix yang.shape, it is the opposite it should be
+plt.imshow(y.reshape(*ying.shape))
+plt.colorbar()
+plt.show()
diff --git a/pyproject.toml b/pyproject.toml
index 0b54520..0b4892a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,7 +24,6 @@ classifiers = [
 
 dependencies = [
     "einops>=0.7.0",
-    "netCDF4>=1.6.5",
     "numpy>=1.23.3",
     "torch>=2.0.1",
     "scipy"
@@ -35,6 +34,9 @@ dependencies = [
 
 
 [project.optional-dependencies]
+all = [
+    "netCDF4>=1.6.5",
+]
 viz = [
     "pyvista>=0.43.2",
     "matplotlib",
diff --git a/tests/test_spatial.py b/tests/test_spatial.py
new file mode 100644
index 0000000..ffb772f
--- /dev/null
+++ b/tests/test_spatial.py
@@ -0,0 +1,41 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+import pytest
+import torch
+
+from earth2grid import spatial
+
+
+def test_vec2ang2vec():
+    vec = torch.randn(3)
+    vec /= torch.norm(vec)
+    x, y, z = vec
+
+    lon, lat = spatial.vec2ang(x, y, z)
+    x1, y1, z1 = spatial.ang2vec(lon, lat)
+    assert torch.allclose(torch.stack([x1, y1, z1]), torch.stack([x, y, z]))
+
+
+def test_vec2ang():
+    lon, lat = spatial.vec2ang(torch.tensor(0), torch.tensor(0), torch.tensor(1))
+    assert lat == pytest.approx(math.pi / 2)
+
+    lon, _ = spatial.vec2ang(torch.tensor(1), torch.tensor(0), torch.tensor(0))
+    assert lon == pytest.approx(0)
+
+    lon, _ = spatial.vec2ang(torch.tensor(0), torch.tensor(1), torch.tensor(0))
+    assert lon == pytest.approx(math.pi / 2)
diff --git a/tests/test_yingyang.py b/tests/test_yingyang.py
new file mode 100644
index 0000000..cdae764
--- /dev/null
+++ b/tests/test_yingyang.py
@@ -0,0 +1,54 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.from earth2grid.yinyang import Ying, Yang, YangProjection
+import numpy as np
+import pytest
+import torch
+
+from earth2grid.yinyang import Yang, Ying
+
+
+def test_yingyang():
+    nlat = 721
+    nlon = 1440
+    delta = 64
+
+    nlat = 37
+    nlon = 72
+    delta = 0
+
+    ying = Ying(nlat, nlon, delta)
+    yang = Yang(nlat, nlon, delta)
+
+    assert ying.lat.min() == pytest.approx(-45)
+    assert ying.lat.max() == pytest.approx(45)
+    assert ying.lat.min() == -ying.lat.max()
+    assert ying.lon.min() == -ying.lon.max()
+    y2y = ying.get_bilinear_regridder_to(yang.lat, yang.lon)
+    y2y.float()
+
+    x = torch.ones(ying.shape)
+    y = y2y(x)
+    mask = ~torch.isnan(y)
+    # this is a regression check. will need to verify and change for different res
+    fraction_missing = 1 - mask.sum().item() / mask.numel()
+    assert fraction_missing == pytest.approx(0.8038, abs=0.01)
+    assert torch.allclose(y[mask], torch.tensor(1).float())
+
+    # more complex check
+    lat, lon = np.meshgrid(ying.lat, ying.lon, indexing='ij')
+    x = torch.as_tensor(lat, dtype=torch.float).deg2rad().cos()
+    y = y2y(x)
+    expected = torch.as_tensor(yang.lat).deg2rad().cos().float()
+    assert torch.allclose(y[mask], expected[mask], atol=0.01)