Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix gpu -> tensor bug #9

Merged
merged 1 commit into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions earth2grid/healpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def _npix(self):

def _nest_ipix(self):
"""convert to nested index number"""
i = torch.arange(self._npix())
i = torch.arange(self._npix(), device="cpu")
if isinstance(self.pixel_order, XY):
i_xy = _convert_xyindex(nside=self._nside(), src=self.pixel_order, dest=XY(), i=i)
i = xy2nest(self._nside(), i_xy)
Expand All @@ -195,29 +195,28 @@ def _nest_ipix(self):
pass
else:
raise ValueError(self.pixel_order)
return i.numpy()
return i

def _nest2me(self, ipix: np.ndarray) -> np.ndarray:
def _nest2me(self, ipix: torch.Tensor) -> torch.Tensor:
"""return the index in my PIXELORDER corresponding to ipix in NEST ordering"""
if isinstance(self.pixel_order, XY):
i_xy = nest2xy(self._nside(), ipix)
i_me = _convert_xyindex(nside=self._nside(), src=XY(), dest=self.pixel_order, i=i_xy)
elif self.pixel_order == PixelOrder.RING:
ipix_t = torch.from_numpy(ipix)
i_me = healpix_bare.nest2ring(self._nside(), ipix_t).numpy()
i_me = healpix_bare.nest2ring(self._nside(), ipix)
elif self.pixel_order == PixelOrder.NEST:
i_me = ipix
return i_me

@property
def lat(self):
ipix = torch.from_numpy(self._nest_ipix())
ipix = self._nest_ipix()
_, lat = healpix_bare.pix2ang(self._nside(), ipix, lonlat=True, nest=True)
return lat.numpy()

@property
def lon(self):
ipix = torch.from_numpy(self._nest_ipix())
ipix = self._nest_ipix()
lon, _ = healpix_bare.pix2ang(self._nside(), ipix, lonlat=True, nest=True)
return lon.numpy()

Expand Down Expand Up @@ -256,7 +255,7 @@ def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
lat, lon = np.broadcast_arrays(lat, lon)
i_ring, weights = healpix_bare.get_interp_weights(self._nside(), torch.tensor(lon), torch.tensor(lat))
i_nest = healpix_bare.ring2nest(self._nside(), i_ring.ravel())
i_me = torch.from_numpy(self._nest2me(i_nest.numpy())).view(i_ring.shape)
i_me = self._nest2me(i_nest).reshape(i_ring.shape)
return ApplyWeights(i_me, weights)

def approximate_grid_length_meters(self):
Expand Down
15 changes: 15 additions & 0 deletions tests/test_healpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,18 @@ def test_conv2d():
weight = torch.zeros(cout, cin, 3, 3)
out = healpix.conv2d(x, weight, padding=(1, 1))
assert out.shape == (n, cout, 1, npix)


def test_latlon_cuda_set_device_regression():
"""See https://github.com/NVlabs/earth2grid/issues/6"""

if torch.cuda.device_count() == 0:
pytest.skip()

default = torch.get_default_device()
try:
torch.set_default_device("cuda")
grid = healpix.Grid(4)
grid.lat
finally:
torch.set_default_device(default)
Loading