Skip to content

Commit

Permalink
Merge pull request #8 from NVlabs/cuhpx-reordering
Browse files Browse the repository at this point in the history
use cuhpx for reordering on gpu
  • Loading branch information
nbren12 authored Aug 24, 2024
2 parents 44297de + ae73fe0 commit 5b9b294
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 5 deletions.
94 changes: 91 additions & 3 deletions earth2grid/healpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@
except ImportError:
healpixpad = None

__all__ = ["pad", "PixelOrder", "XY", "Compass", "Grid", "HEALPIX_PAD_XY", "conv2d"]
try:
import cuhpx
except ImportError:
cuhpx = None

__all__ = ["pad", "PixelOrder", "XY", "Compass", "Grid", "HEALPIX_PAD_XY", "conv2d", "reorder"]


def pad(x: torch.Tensor, padding: int) -> torch.Tensor:
Expand Down Expand Up @@ -91,10 +96,57 @@ def pad(x: torch.Tensor, padding: int) -> torch.Tensor:
return healpixpad.HEALPixPadFunction.apply(x.unsqueeze(2), padding).squeeze(2)


def _apply_cuhpx_remap(func, x, **kwargs):
shape = x.shape
x = x.view(-1, 1, shape[-1])
nside = npix2nside(x.shape[-1])
x = func(x.contiguous(), **kwargs, nside=nside)
x = x.contiguous()
x = x.view(shape[:-1] + (-1,))
return x


def npix2nside(npix: int):
nside = math.sqrt(npix // 12)
return int(nside)


def npix2level(npix: int):
return nside2level(npix2nside(npix))


def nside2level(nside: int):
return int(math.log2(nside))


class PixelOrder(Enum):
RING = 0
NEST = 1

def reorder_from_cuda(self, x, src: "PixelOrderT"):
if self == PixelOrder.RING:
return src.to_ring_cuda(x)
elif self == PixelOrder.NEST:
return src.to_nest_cuda(x)

def to_ring_cuda(self, x: torch.Tensor):
if self == PixelOrder.RING:
return x
elif self == PixelOrder.NEST:
return _apply_cuhpx_remap(cuhpx.nest2ring, x)

def to_nest_cuda(self, x: torch.Tensor):
if self == PixelOrder.RING:
return _apply_cuhpx_remap(cuhpx.ring2nest, x)
elif self == PixelOrder.NEST:
return x

def to_xy_cuda(self, x: torch.Tensor, dest: "XY"):
if self == PixelOrder.RING:
return _apply_cuhpx_remap(cuhpx.ring2flat, x, clockwise=dest.clockwise, origin=dest.origin.name)
elif self == PixelOrder.NEST:
return _apply_cuhpx_remap(cuhpx.nest2flat, x, clockwise=dest.clockwise, origin=dest.origin.name)


class Compass(Enum):
"""Cardinal directions in counter clockwise order"""
Expand Down Expand Up @@ -126,12 +178,42 @@ class XY:
origin: Compass = Compass.S
clockwise: bool = False

def reorder_from_cuda(self, x, src: "PixelOrderT"):
return src.to_xy_cuda(x, self)

def to_xy_cuda(self, x: torch.Tensor, dest: "XY"):
return _apply_cuhpx_remap(
cuhpx.flat2flat,
x,
src_origin=self.origin.name,
src_clockwise=self.clockwise,
dest_origin=dest.origin.name,
dest_clockwise=dest.clockwise,
)

def to_ring_cuda(self, x: torch.Tensor):
return _apply_cuhpx_remap(
cuhpx.flat2ring,
x,
origin=self.origin.name,
clockwise=self.clockwise,
)

def to_nest_cuda(self, x: torch.Tensor):
return _apply_cuhpx_remap(cuhpx.flat2nest, x, origin=self.origin.name, clockwise=self.clockwise)


PixelOrderT = Union[PixelOrder, XY]

HEALPIX_PAD_XY = XY(origin=Compass.N, clockwise=True)


def reorder(x: torch.Tensor, src_pixel_order: PixelOrderT, dest_pixel_order: PixelOrderT):
"""Reorder x from one pixel order to another"""
grid = Grid(level=npix2level(x.size(-1)), pixel_order=src_pixel_order)
return grid.reorder(dest_pixel_order, x)


def _convert_xyindex(nside: int, src: XY, dest: XY, i):
if src.clockwise != dest.clockwise:
i = _flip_xy(nside, i)
Expand Down Expand Up @@ -261,13 +343,19 @@ def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
def approximate_grid_length_meters(self):
return approx_grid_length_meters(self._nside())

def reorder(self, order: PixelOrderT, x: torch.Tensor) -> torch.Tensor:
"""Rorder the pixels of ``x`` to have ``order``"""
def _reorder_cpu(self, x: torch.Tensor, order: PixelOrderT):
output_grid = Grid(level=self.level, pixel_order=order)
i_nest = output_grid._nest_ipix()
i_me = self._nest2me(i_nest)
return x[..., i_me]

def reorder(self, order: PixelOrderT, x: torch.Tensor) -> torch.Tensor:
"""Rorder the pixels of ``x`` to have ``order``"""
if x.device.type == "cuda":
return order.reorder_from_cuda(x, self.pixel_order)
else:
return self._reorder_cpu(x, order)

def get_healpix_regridder(self, dest: "Grid"):
if self.level != dest.level:
return self.get_bilinear_regridder_to(dest.lat, dest.lon)
Expand Down
17 changes: 16 additions & 1 deletion tests/test_healpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_rotate_index(rot):

@pytest.mark.parametrize("origin", list(healpix.Compass))
@pytest.mark.parametrize("clockwise", [True, False])
def test_reorder(tmp_path, origin, clockwise):
def test_Grid_reorder(tmp_path, origin, clockwise):
src_grid = healpix.Grid(level=4, pixel_order=healpix.XY(origin=origin, clockwise=clockwise))
dest_grid = healpix.Grid(level=4, pixel_order=healpix.PixelOrder.NEST)

Expand Down Expand Up @@ -145,6 +145,21 @@ def test_conv2d():
assert out.shape == (n, cout, 1, npix)


@pytest.mark.parametrize("nside", [16])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("src_pixel_order", [healpix.HEALPIX_PAD_XY, healpix.PixelOrder.RING, healpix.PixelOrder.NEST])
@pytest.mark.parametrize("dest_pixel_order", [healpix.HEALPIX_PAD_XY, healpix.PixelOrder.RING, healpix.PixelOrder.NEST])
def test_reorder(nside, src_pixel_order, dest_pixel_order, device):
# Generate some test data
if device == "cuda" and torch.cuda.device_count() == 0:
pytest.skip("no cuda devices available")

data = torch.randn(1, 2, 12 * nside * nside, device=device)
out = healpix.reorder(data, src_pixel_order, dest_pixel_order)
out = healpix.reorder(out, dest_pixel_order, src_pixel_order)
assert torch.all(data == out), data - out


def test_latlon_cuda_set_device_regression():
"""See https://github.com/NVlabs/earth2grid/issues/6"""

Expand Down
2 changes: 1 addition & 1 deletion tests/test_latlon.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_lat_lon_bilinear_regrid_to():

regrid.float()
lat = torch.broadcast_to(torch.tensor(src.lat), src.shape)
z = torch.tensor(lat).float()
z = lat.float()

out = regrid(z)
assert out.shape == dest.shape

0 comments on commit 5b9b294

Please sign in to comment.