From 271b5a1c93a491f2fa9c80cb22969b8fb2015a3b Mon Sep 17 00:00:00 2001
From: David Pruitt <dpruitt@nvidia.com>
Date: Fri, 23 Aug 2024 21:25:38 +0000
Subject: [PATCH] Move cuda padding routines to healpix.py, check for cuda
 before attempting cuda install

---
 earth2grid/healpix.py                    | 87 ++++++++++++++++++++--
 earth2grid/healpixpad.py                 | 91 ------------------------
 earth2grid/third_party/zephyr/healpix.py | 11 ---
 setup.py                                 | 35 ++++++---
 4 files changed, 105 insertions(+), 119 deletions(-)
 delete mode 100644 earth2grid/healpixpad.py

diff --git a/earth2grid/healpix.py b/earth2grid/healpix.py
index 8d80737..44a86c6 100644
--- a/earth2grid/healpix.py
+++ b/earth2grid/healpix.py
@@ -33,6 +33,7 @@
 """
 
 import math
+import warnings
 from dataclasses import dataclass
 from enum import Enum
 from typing import Union
@@ -48,8 +49,17 @@
 except ImportError:
     pv = None
 
-from earth2grid import base, healpixpad
-from earth2grid.third_party.zephyr.healpix import healpix_pad
+try:
+    import healpixpad_cuda
+
+    healpixpad_cuda_avail = True
+except ImportError:
+    healpixpad_cuda_avail = False
+    warnings.warn("healpixpad_cuda module not available, reverting to CPU for all padding routines")
+
+
+from earth2grid import base
+from earth2grid.third_party.zephyr.healpix import healpix_pad as heapixpad_cpu
 
 __all__ = ["pad", "PixelOrder", "XY", "Compass", "Grid", "HEALPIX_PAD_XY", "conv2d"]
 
@@ -59,7 +69,7 @@ def pad(x: torch.Tensor, padding: int) -> torch.Tensor:
     Pad each face consistently with its according neighbors in the HEALPix
 
     Args:
-        x: The input tensor of shape [N, F, H, W]
+        x: The input tensor of shape [N, F, H, W] or [N, F, C, H, W]
         padding: the amount of padding
 
     Returns:
@@ -80,10 +90,12 @@ def pad(x: torch.Tensor, padding: int) -> torch.Tensor:
         torch.Size([1, 12, 18, 18])
 
     """
-    if x.device.type != 'cuda':
-        return healpix_pad(x, padding)
+    if x.device.type != 'cuda' or not healpixpad_cuda_avail:
+        return heapixpad_cpu(x, padding)
+    elif x.ndim == 5:
+        return HEALPixPadFunction.apply(x, padding)
     else:
-        return healpixpad.HEALPixPadFunction.apply(x.unsqueeze(2), padding).squeeze(2)
+        return HEALPixPadFunction.apply(x.unsqueeze(2), padding).squeeze(2)
 
 
 class PixelOrder(Enum):
@@ -293,6 +305,69 @@ def to_image(self, x: torch.Tensor, fill_value=torch.nan) -> torch.Tensor:
         return output
 
 
+class HEALPixPadFunction(torch.autograd.Function):
+    """
+    A torch autograd class that pads a healpixpad xy tensor
+    """
+
+    @staticmethod
+    def forward(ctx, input, pad):
+        """
+        The forward pass of the padding class
+
+        Parameters
+        ----------
+        input: torch.tensor
+            The tensor to pad, must have 5 dimensions and be in (B, F, C, H, W) format
+            where F == 12 and H == W
+        pad: int
+            The amount to pad each face of the tensor
+
+        Returns
+        -------
+        torch.tensor: The padded tensor
+        """
+        ctx.pad = pad
+        if input.ndim != 5:
+            raise ValueError(
+                f"Input tensor must be have 5 dimensions (B, F, C, H, W), got {len(input.shape)} dimensions instead"
+            )
+        if input.shape[1] != 12:
+            raise ValueError(
+                f"Input tensor must be have 5 dimensions (B, F, C, H, W), with F == 12, got {input.shape[1]}"
+            )
+        if input.shape[3] != input.shape[4]:
+            raise ValueError(
+                f"Input tensor must be have 5 dimensions (B, F, C, H, W), with H == @, got {input.shape[3]},  {input.shape[4]}"
+            )
+        # make contiguous
+        input = input.contiguous()
+        out = healpixpad_cuda.forward(input, pad)[0]
+        return out
+
+    @staticmethod
+    def backward(ctx, grad):
+        """
+        The forward pass of the padding class
+
+        Parameters
+        ----------
+        input: torch.tensor
+            The tensor to pad, must have 5 dimensions and be in (B, F, C, H, W) format
+            where F == 12 and H == W
+        pad: int
+            The amount to pad each face of the tensor
+
+        Returns
+        -------
+        torch.tensor: The padded tensor
+        """
+        pad = ctx.pad
+        grad = grad.contiguous()
+        out = healpixpad_cuda.backward(grad, pad)[0]
+        return out, None
+
+
 # nside = 2^ZOOM_LEVELS
 ZOOM_LEVELS = 20
 
diff --git a/earth2grid/healpixpad.py b/earth2grid/healpixpad.py
deleted file mode 100644
index 60df9b7..0000000
--- a/earth2grid/healpixpad.py
+++ /dev/null
@@ -1,91 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License
-#
-# Written by Mauro Bisson <maurob@nvidia.com> and Thorsten Kurth <tkurth@nvidia.com>.
-
-
-import healpixpad_cuda
-import torch
-
-
-class HEALPixPadFunction(torch.autograd.Function):
-    """
-    A torch autograd class that pads a healpixpad xy tensor
-    """
-
-    @staticmethod
-    def forward(ctx, input, pad):
-        """
-        The forward pass of the padding class
-
-        Parameters
-        ----------
-        input: torch.tensor
-            The tensor to pad, must have 5 dimensions and be in (B, F, C, H, W) format
-            where F == 12 and H == W
-        pad: int
-            The amount to pad each face of the tensor
-
-        Returns
-        -------
-        torch.tensor: The padded tensor
-        """
-        ctx.pad = pad
-        if len(input.shape) != 5:
-            raise ValueError(
-                f"Input tensor must be have 4 dimensions (B, F, C, H, W), got {len(input.shape)} dimensions instead"
-            )
-        # make contiguous
-        input = input.contiguous()
-        out = healpixpad_cuda.forward(input, pad)[0]
-        return out
-
-    @staticmethod
-    def backward(ctx, grad):
-        pad = ctx.pad
-        grad = grad.contiguous()
-        out = healpixpad_cuda.backward(grad, pad)[0]
-        return out, None
-
-
-class HEALPixPad(torch.nn.Module):
-    """
-    A torch module that handles padding of healpixpad xy tensors
-
-    Paramaeters
-    -----------
-    padding: int
-        The amount to pad the tensors
-    """
-
-    def __init__(self, padding: int):
-        super(HEALPixPad, self).__init__()
-        self.padding = padding
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        """
-        The forward pass of the padding class
-
-        Parameters
-        ----------
-        input: torch.tensor
-            The tensor to pad, must have 5 dimensions and be in (B, F, C, H, W) format
-            where F == 12 and H == W
-
-        Returns
-        -------
-        torch.tensor: The padded tensor
-        """
-        return HEALPixPadFunction.apply(input, self.padding)
diff --git a/earth2grid/third_party/zephyr/healpix.py b/earth2grid/third_party/zephyr/healpix.py
index fe70810..fdcd739 100644
--- a/earth2grid/third_party/zephyr/healpix.py
+++ b/earth2grid/third_party/zephyr/healpix.py
@@ -23,21 +23,10 @@
 
 """
 
-import sys
 
 import torch
 import torch as th
 
-sys.path.append('/home/disk/quicksilver/nacc/dlesm/HealPixPad')
-have_healpixpad = False
-try:
-    from healpixpad import HEALPixPad  # noqa
-
-    have_healpixpad = True
-except ImportError:
-    print("Warning, cannot find healpixpad module")
-    have_healpixpad = False
-
 
 def healpix_pad(x: torch.Tensor, padding: int, enable_nhwc: bool = False) -> torch.Tensor:
     """
diff --git a/setup.py b/setup.py
index 9bcfce0..dc5d496 100644
--- a/setup.py
+++ b/setup.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 import os
 import subprocess
+import warnings
 from typing import List
 
 from setuptools import setup
@@ -59,20 +60,32 @@ def get_compiler():
     "earth2grid/csrc/healpixpad/healpixpad_cuda_fwd.cu",
     "earth2grid/csrc/healpixpad/healpixpad_cuda_bwd.cu",
 ]
-setup(
-    name='earth2grid',
-    ext_modules=[
-        cpp_extension.CppExtension(
-            'earth2grid._healpix_bare',
-            src_files,
-            extra_compile_args=extra_compile_args,
-            include_dirs=[os.path.abspath("earth2grid/csrc"), os.path.abspath("earth2grid/third_party/healpix_bare")],
-        ),
-        cpp_extension.CUDAExtension(
+
+ext_modules = [
+    cpp_extension.CppExtension(
+        'earth2grid._healpix_bare',
+        src_files,
+        extra_compile_args=extra_compile_args,
+        include_dirs=[os.path.abspath("earth2grid/csrc"), os.path.abspath("earth2grid/third_party/healpix_bare")],
+    ),
+]
+
+try:
+    from torch.utils.cpp_extension import CUDAExtension
+
+    ext_modules.append(
+        CUDAExtension(
             name='healpixpad_cuda',
             sources=cuda_src_files,
             extra_compile_args={'nvcc': ['-O2']},
         ),
-    ],
+    )
+except ImportError:
+    warnings.warn("Cuda extensions for torch not found, skipping cuda healpix padding module")
+
+
+setup(
+    name='earth2grid',
+    ext_modules=ext_modules,
     cmdclass={'build_ext': cpp_extension.BuildExtension},
 )