diff --git a/earth2grid/healpix.py b/earth2grid/healpix.py index ef0b9aa..8008416 100644 --- a/earth2grid/healpix.py +++ b/earth2grid/healpix.py @@ -56,7 +56,12 @@ except ImportError: healpixpad = None -__all__ = ["pad", "PixelOrder", "XY", "Compass", "Grid", "HEALPIX_PAD_XY", "conv2d"] +try: + import cuhpx +except ImportError: + cuhpx = None + +__all__ = ["pad", "PixelOrder", "XY", "Compass", "Grid", "HEALPIX_PAD_XY", "conv2d", "reorder"] def pad(x: torch.Tensor, padding: int) -> torch.Tensor: @@ -91,10 +96,57 @@ def pad(x: torch.Tensor, padding: int) -> torch.Tensor: return healpixpad.HEALPixPadFunction.apply(x.unsqueeze(2), padding).squeeze(2) +def _apply_cuhpx_remap(func, x, **kwargs): + shape = x.shape + x = x.view(-1, 1, shape[-1]) + nside = npix2nside(x.shape[-1]) + x = func(x.contiguous(), **kwargs, nside=nside) + x = x.contiguous() + x = x.view(shape[:-1] + (-1,)) + return x + + +def npix2nside(npix: int): + nside = math.sqrt(npix // 12) + return int(nside) + + +def npix2level(npix: int): + return nside2level(npix2nside(npix)) + + +def nside2level(nside: int): + return int(math.log2(nside)) + + class PixelOrder(Enum): RING = 0 NEST = 1 + def reorder_from_cuda(self, x, src: "PixelOrderT"): + if self == PixelOrder.RING: + return src.to_ring_cuda(x) + elif self == PixelOrder.NEST: + return src.to_nest_cuda(x) + + def to_ring_cuda(self, x: torch.Tensor): + if self == PixelOrder.RING: + return x + elif self == PixelOrder.NEST: + return _apply_cuhpx_remap(cuhpx.nest2ring, x) + + def to_nest_cuda(self, x: torch.Tensor): + if self == PixelOrder.RING: + return _apply_cuhpx_remap(cuhpx.ring2nest, x) + elif self == PixelOrder.NEST: + return x + + def to_xy_cuda(self, x: torch.Tensor, dest: "XY"): + if self == PixelOrder.RING: + return _apply_cuhpx_remap(cuhpx.ring2flat, x, clockwise=dest.clockwise, origin=dest.origin.name) + elif self == PixelOrder.NEST: + return _apply_cuhpx_remap(cuhpx.nest2flat, x, clockwise=dest.clockwise, origin=dest.origin.name) + class Compass(Enum): """Cardinal directions in counter clockwise order""" @@ -126,12 +178,42 @@ class XY: origin: Compass = Compass.S clockwise: bool = False + def reorder_from_cuda(self, x, src: "PixelOrderT"): + return src.to_xy_cuda(x, self) + + def to_xy_cuda(self, x: torch.Tensor, dest: "XY"): + return _apply_cuhpx_remap( + cuhpx.flat2flat, + x, + src_origin=self.origin.name, + src_clockwise=self.clockwise, + dest_origin=dest.origin.name, + dest_clockwise=dest.clockwise, + ) + + def to_ring_cuda(self, x: torch.Tensor): + return _apply_cuhpx_remap( + cuhpx.flat2ring, + x, + origin=self.origin.name, + clockwise=self.clockwise, + ) + + def to_nest_cuda(self, x: torch.Tensor): + return _apply_cuhpx_remap(cuhpx.flat2nest, x, origin=self.origin.name, clockwise=self.clockwise) + PixelOrderT = Union[PixelOrder, XY] HEALPIX_PAD_XY = XY(origin=Compass.N, clockwise=True) +def reorder(x: torch.Tensor, src_pixel_order: PixelOrderT, dest_pixel_order: PixelOrderT): + """Reorder x from one pixel order to another""" + grid = Grid(level=npix2level(x.size(-1)), pixel_order=src_pixel_order) + return grid.reorder(dest_pixel_order, x) + + def _convert_xyindex(nside: int, src: XY, dest: XY, i): if src.clockwise != dest.clockwise: i = _flip_xy(nside, i) @@ -261,13 +343,19 @@ def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray): def approximate_grid_length_meters(self): return approx_grid_length_meters(self._nside()) - def reorder(self, order: PixelOrderT, x: torch.Tensor) -> torch.Tensor: - """Rorder the pixels of ``x`` to have ``order``""" + def _reorder_cpu(self, x: torch.Tensor, order: PixelOrderT): output_grid = Grid(level=self.level, pixel_order=order) i_nest = output_grid._nest_ipix() i_me = self._nest2me(i_nest) return x[..., i_me] + def reorder(self, order: PixelOrderT, x: torch.Tensor) -> torch.Tensor: + """Rorder the pixels of ``x`` to have ``order``""" + if x.device.type == "cuda": + return order.reorder_from_cuda(x, self.pixel_order) + else: + return self._reorder_cpu(x, order) + def get_healpix_regridder(self, dest: "Grid"): if self.level != dest.level: return self.get_bilinear_regridder_to(dest.lat, dest.lon) diff --git a/tests/test_healpix.py b/tests/test_healpix.py index d35a4e4..9ca04e8 100644 --- a/tests/test_healpix.py +++ b/tests/test_healpix.py @@ -63,7 +63,7 @@ def test_rotate_index(rot): @pytest.mark.parametrize("origin", list(healpix.Compass)) @pytest.mark.parametrize("clockwise", [True, False]) -def test_reorder(tmp_path, origin, clockwise): +def test_Grid_reorder(tmp_path, origin, clockwise): src_grid = healpix.Grid(level=4, pixel_order=healpix.XY(origin=origin, clockwise=clockwise)) dest_grid = healpix.Grid(level=4, pixel_order=healpix.PixelOrder.NEST) @@ -145,6 +145,21 @@ def test_conv2d(): assert out.shape == (n, cout, 1, npix) +@pytest.mark.parametrize("nside", [16]) +@pytest.mark.parametrize("device", ["cuda", "cpu"]) +@pytest.mark.parametrize("src_pixel_order", [healpix.HEALPIX_PAD_XY, healpix.PixelOrder.RING, healpix.PixelOrder.NEST]) +@pytest.mark.parametrize("dest_pixel_order", [healpix.HEALPIX_PAD_XY, healpix.PixelOrder.RING, healpix.PixelOrder.NEST]) +def test_reorder(nside, src_pixel_order, dest_pixel_order, device): + # Generate some test data + if device == "cuda" and torch.cuda.device_count() == 0: + pytest.skip("no cuda devices available") + + data = torch.randn(1, 2, 12 * nside * nside, device=device) + out = healpix.reorder(data, src_pixel_order, dest_pixel_order) + out = healpix.reorder(out, dest_pixel_order, src_pixel_order) + assert torch.all(data == out), data - out + + def test_latlon_cuda_set_device_regression(): """See https://github.com/NVlabs/earth2grid/issues/6""" diff --git a/tests/test_latlon.py b/tests/test_latlon.py index 78ad304..95e9cd3 100644 --- a/tests/test_latlon.py +++ b/tests/test_latlon.py @@ -24,7 +24,7 @@ def test_lat_lon_bilinear_regrid_to(): regrid.float() lat = torch.broadcast_to(torch.tensor(src.lat), src.shape) - z = torch.tensor(lat).float() + z = lat.float() out = regrid(z) assert out.shape == dest.shape