Skip to content

Commit

Permalink
fix import, add 1d and 2d tests
Browse files Browse the repository at this point in the history
  • Loading branch information
simonbyrne committed Sep 27, 2024
1 parent 009f8a0 commit 1487746
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
4 changes: 2 additions & 2 deletions earth2grid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ def get_regridder(src: base.Grid, dest: base.Grid) -> torch.nn.Module:
"""Get a regridder from `src` to `dest`"""
if src == dest:
return Identity()
elif isinstance(src, latlon.LatLonGrid)
elif isinstance(src, latlon.LatLonGrid) and isinstance(dest, latlon.LatLonGrid):
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
elif isinstance(src, latlon.LatLonGrid) and isinstance(dest, healpix.Grid):
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
elif isinstance(src, lcc.LambertConformalConicGrid)
elif isinstance(src, lcc.LambertConformalConicGrid):
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
elif isinstance(src, healpix.Grid):
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
Expand Down
22 changes: 16 additions & 6 deletions tests/test_lcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,22 @@ def test_grid_slice():
assert slice_grid.lat == approx(lats)
assert slice_grid.lon == approx(lons)

def test_grid_slice():
def test_regrid_1d():
src = HRRR_CONUS_GRID
dest_lat = np.linspace(25.0, 33.0, 10)
dest_lon = np.linspace(-123, -98, 10)
regrid = src.get_bilinear_regridder_to(dest_lat, dest_lon)
src_lat = torch.broadcast_to(torch.tensor(src.lat), src.shape)
out_lat = regrid(src_lat)

assert torch.allclose(out_lat, torch.tensor(dest_lat))

def test_regrid_2d():
src = HRRR_CONUS_GRID
dest = HRRR_CONUS_GRID[1:-1, 1:-1]
regrid = src.get_bilinear_regridder_to(dest.lat, dest.lon)
lat = torch.broadcast_to(torch.tensor(src.lat), src.shape)
out = regrid(lat)
dest_lat, dest_lon = np.meshgrid(np.linspace(25.0, 33.0, 10), np.linspace(-123, -98, 10))
regrid = src.get_bilinear_regridder_to(dest_lat, dest_lon)
src_lat = torch.broadcast_to(torch.tensor(src.lat), src.shape)
out_lat = regrid(src_lat)

assert torch.allclose(out_lat, torch.tensor(dest_lat))

assert torch.allclose(out, lat[1:-1, 1:-1])

0 comments on commit 1487746

Please sign in to comment.