diff --git a/earth2grid/healpix.py b/earth2grid/healpix.py index 16857ab..ef0b9aa 100644 --- a/earth2grid/healpix.py +++ b/earth2grid/healpix.py @@ -185,7 +185,7 @@ def _npix(self): def _nest_ipix(self): """convert to nested index number""" - i = torch.arange(self._npix()) + i = torch.arange(self._npix(), device="cpu") if isinstance(self.pixel_order, XY): i_xy = _convert_xyindex(nside=self._nside(), src=self.pixel_order, dest=XY(), i=i) i = xy2nest(self._nside(), i_xy) @@ -195,29 +195,28 @@ def _nest_ipix(self): pass else: raise ValueError(self.pixel_order) - return i.numpy() + return i - def _nest2me(self, ipix: np.ndarray) -> np.ndarray: + def _nest2me(self, ipix: torch.Tensor) -> torch.Tensor: """return the index in my PIXELORDER corresponding to ipix in NEST ordering""" if isinstance(self.pixel_order, XY): i_xy = nest2xy(self._nside(), ipix) i_me = _convert_xyindex(nside=self._nside(), src=XY(), dest=self.pixel_order, i=i_xy) elif self.pixel_order == PixelOrder.RING: - ipix_t = torch.from_numpy(ipix) - i_me = healpix_bare.nest2ring(self._nside(), ipix_t).numpy() + i_me = healpix_bare.nest2ring(self._nside(), ipix) elif self.pixel_order == PixelOrder.NEST: i_me = ipix return i_me @property def lat(self): - ipix = torch.from_numpy(self._nest_ipix()) + ipix = self._nest_ipix() _, lat = healpix_bare.pix2ang(self._nside(), ipix, lonlat=True, nest=True) return lat.numpy() @property def lon(self): - ipix = torch.from_numpy(self._nest_ipix()) + ipix = self._nest_ipix() lon, _ = healpix_bare.pix2ang(self._nside(), ipix, lonlat=True, nest=True) return lon.numpy() @@ -256,7 +255,7 @@ def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray): lat, lon = np.broadcast_arrays(lat, lon) i_ring, weights = healpix_bare.get_interp_weights(self._nside(), torch.tensor(lon), torch.tensor(lat)) i_nest = healpix_bare.ring2nest(self._nside(), i_ring.ravel()) - i_me = torch.from_numpy(self._nest2me(i_nest.numpy())).view(i_ring.shape) + i_me = self._nest2me(i_nest).reshape(i_ring.shape) return ApplyWeights(i_me, weights) def approximate_grid_length_meters(self): diff --git a/tests/test_healpix.py b/tests/test_healpix.py index d7d0e63..d35a4e4 100644 --- a/tests/test_healpix.py +++ b/tests/test_healpix.py @@ -143,3 +143,18 @@ def test_conv2d(): weight = torch.zeros(cout, cin, 3, 3) out = healpix.conv2d(x, weight, padding=(1, 1)) assert out.shape == (n, cout, 1, npix) + + +def test_latlon_cuda_set_device_regression(): + """See https://github.com/NVlabs/earth2grid/issues/6""" + + if torch.cuda.device_count() == 0: + pytest.skip() + + default = torch.get_default_device() + try: + torch.set_default_device("cuda") + grid = healpix.Grid(4) + grid.lat + finally: + torch.set_default_device(default)