diff --git a/earth2grid/_regrid.py b/earth2grid/_regrid.py index 564a7cd..daba8bb 100644 --- a/earth2grid/_regrid.py +++ b/earth2grid/_regrid.py @@ -44,7 +44,7 @@ def forward(self, z): weight = self.weight.view(-1, p) # using embedding bag is 2x faster on cpu and 4x on gpu. - output = torch.nn.functional.embedding_bag(index, zrs, per_sample_weights=weight, mode='sum') + output = torch.nn.functional.embedding_bag(index, zrs, per_sample_weights=weight, mode="sum") output = output.T.view(*shape, -1) return output.reshape(list(shape) + output_shape) @@ -173,11 +173,11 @@ def forward(self, z: torch.Tensor): *shape, y, x = z.shape zrs = z.view(-1, y * x).T # using embedding bag is 2x faster on cpu and 4x on gpu. - output = torch.nn.functional.embedding_bag(self.index, zrs, per_sample_weights=self.weights, mode='sum') + output = torch.nn.functional.embedding_bag(self.index, zrs, per_sample_weights=self.weights, mode="sum") interpolated = torch.full( [self.mask.numel(), zrs.shape[1]], fill_value=self.fill_value, dtype=z.dtype, device=z.device ) - interpolated.masked_scatter_(self.mask.view(-1,1), output) + interpolated.masked_scatter_(self.mask.view(-1, 1), output) interpolated = interpolated.T.view(*shape, *self.mask.shape) return interpolated diff --git a/earth2grid/lcc.py b/earth2grid/lcc.py index abac978..55146d0 100644 --- a/earth2grid/lcc.py +++ b/earth2grid/lcc.py @@ -34,7 +34,7 @@ class LambertConformalConicProjection: def __init__(self, lat0: float, lon0: float, lat1: float, lat2: float, radius: float): """ - + Args: lat0: latitude of origin (degrees) lon0: longitude of origin (degrees) @@ -50,7 +50,6 @@ def __init__(self, lat0: float, lon0: float, lat1: float, lat2: float, radius: f self.lat2 = lat2 self.radius = radius - c1 = np.cos(np.deg2rad(lat1)) c2 = np.cos(np.deg2rad(lat2)) t1 = np.tan(np.pi / 4 + np.deg2rad(lat1) / 2) @@ -58,8 +57,8 @@ def __init__(self, lat0: float, lon0: float, lat1: float, lat2: float, radius: f if np.abs(lat1 - lat2) < 1e-8: self.n = np.sin(np.deg2rad(lat1)) - else: - self.n = np.log(c1/c2) / np.log(t2/t1) + else: + self.n = np.log(c1 / c2) / np.log(t2 / t1) self.RF = radius * c1 * np.power(t1, self.n) / self.n self.rho0 = self._rho(lat0) @@ -78,17 +77,16 @@ def _theta(self, lon): """ # center about reference longitude delta_lon = lon - self.lon0 - delta_lon = delta_lon - np.round(delta_lon/360) * 360 # convert to [-180, 180] + delta_lon = delta_lon - np.round(delta_lon / 360) * 360 # convert to [-180, 180] return self.n * np.deg2rad(delta_lon) - def project(self, lat, lon): """ Compute the projected x,y from lat,lon. """ rho = self._rho(lat) theta = self._theta(lon) - + x = rho * np.sin(theta) y = self.rho0 - rho * np.cos(theta) return x, y @@ -99,26 +97,21 @@ def inverse_project(self, x, y): """ rho = np.hypot(x, self.rho0 - y) theta = np.arctan2(x, self.rho0 - y) - - lat = np.rad2deg(2 * np.arctan(np.power(self.RF/rho, 1/self.n))) - 90 + + lat = np.rad2deg(2 * np.arctan(np.power(self.RF / rho, 1 / self.n))) - 90 lon = self.lon0 + np.rad2deg(theta / self.n) return lat, lon + # Projection used by HRRR CONUS (Continental US) data # https://rapidrefresh.noaa.gov/hrrr/HRRR_conus.domain.txt -HRRR_CONUS_PROJECTION = LambertConformalConicProjection( - lon0 = -97.5, - lat0 = 38.5, - lat1 = 38.5, - lat2 = 38.5, - radius = 6371229.0 -) +HRRR_CONUS_PROJECTION = LambertConformalConicProjection(lon0=-97.5, lat0=38.5, lat1=38.5, lat2=38.5, radius=6371229.0) class LambertConformalConicGrid(base.Grid): # nothing here is specific to the projection, so could be shared by any projected rectilinear grid def __init__(self, projection: LambertConformalConicProjection, x, y): - """ + """ Args: projection: LambertConformalConicProjection object x: range of x values @@ -155,12 +148,13 @@ def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray): """Get regridder to the specified lat and lon points""" x, y = self.projection.project(lat, lon) - + return BilinearInterpolator( - x_coords = torch.from_numpy(self.x), - y_coords = torch.from_numpy(self.y), - x_query = torch.from_numpy(x), - y_query = torch.from_numpy(y)) + x_coords=torch.from_numpy(self.x), + y_coords=torch.from_numpy(self.y), + x_query=torch.from_numpy(x), + y_query=torch.from_numpy(y), + ) def visualize(self, data): raise NotImplementedError() @@ -176,7 +170,8 @@ def to_pyvista(self): grid = pv.StructuredGrid(x, y, z) return grid -def hrrr_conus_grid(ix0 = 0, iy0 = 0, nx = 1799, ny = 1059): + +def hrrr_conus_grid(ix0=0, iy0=0, nx=1799, ny=1059): # coordinates of point in top-left corner lat0 = 21.138123 lon0 = 237.280472 @@ -185,10 +180,11 @@ def hrrr_conus_grid(ix0 = 0, iy0 = 0, nx = 1799, ny = 1059): # coordinates on projected space x0, y0 = HRRR_CONUS_PROJECTION.project(lat0, lon0) - x = [x0 + i * scale for i in range(ix0, ix0+nx)] - y = [y0 + i * scale for i in range(iy0, iy0+ny)] + x = [x0 + i * scale for i in range(ix0, ix0 + nx)] + y = [y0 + i * scale for i in range(iy0, iy0 + ny)] return LambertConformalConicGrid(HRRR_CONUS_PROJECTION, x, y) + # Grid used by HRRR CONUS (Continental US) data HRRR_CONUS_GRID = hrrr_conus_grid() diff --git a/tests/test_lcc.py b/tests/test_lcc.py index 73dcacb..536068a 100644 --- a/tests/test_lcc.py +++ b/tests/test_lcc.py @@ -13,38 +13,49 @@ # See the License for the specific language governing permissions and # limitations under the License. -#%% -from earth2grid.lcc import HRRR_CONUS_GRID +# %% import numpy as np -import torch import pytest +import torch + +from earth2grid.lcc import HRRR_CONUS_GRID + def test_grid_shape(): - assert HRRR_CONUS_GRID.lat.shape == HRRR_CONUS_GRID.shape + assert HRRR_CONUS_GRID.lat.shape == HRRR_CONUS_GRID.shape assert HRRR_CONUS_GRID.lon.shape == HRRR_CONUS_GRID.shape -lats = np.array([ + +lats = np.array( + [ [21.138123, 21.801926, 22.393631, 22.911015], - [23.636763, 24.328228, 24.944668, 25.48374 ], + [23.636763, 24.328228, 24.944668, 25.48374], [26.155672, 26.875362, 27.517046, 28.078257], - [28.69017 , 29.438608, 30.106009, 30.68978 ]]) + [28.69017, 29.438608, 30.106009, 30.68978], + ] +) -lons = np.array([ - [-122.71953 , -120.03195 , -117.304596, -114.54146 ], - [-123.491356, -120.72898 , -117.92319 , -115.07828 ], - [-124.310524, -121.469505, -118.58098 , -115.649574], - [-125.181404, -122.25762 , -119.28173 , -116.25871 ]]) +lons = np.array( + [ + [-122.71953, -120.03195, -117.304596, -114.54146], + [-123.491356, -120.72898, -117.92319, -115.07828], + [-124.310524, -121.469505, -118.58098, -115.649574], + [-125.181404, -122.25762, -119.28173, -116.25871], + ] +) def test_grid_vals(): - assert HRRR_CONUS_GRID.lat[0:400:100,0:400:100] == pytest.approx(lats) - assert HRRR_CONUS_GRID.lon[0:400:100,0:400:100] == pytest.approx(lons) + assert HRRR_CONUS_GRID.lat[0:400:100, 0:400:100] == pytest.approx(lats) + assert HRRR_CONUS_GRID.lon[0:400:100, 0:400:100] == pytest.approx(lons) + def test_grid_slice(): - slice_grid = HRRR_CONUS_GRID[0:400:100,0:400:100] + slice_grid = HRRR_CONUS_GRID[0:400:100, 0:400:100] assert slice_grid.lat == pytest.approx(lats) assert slice_grid.lon == pytest.approx(lons) + def test_regrid_1d(): src = HRRR_CONUS_GRID dest_lat = np.linspace(25.0, 33.0, 10) @@ -55,6 +66,7 @@ def test_regrid_1d(): assert torch.allclose(out_lat, torch.tensor(dest_lat)) + def test_regrid_2d(): src = HRRR_CONUS_GRID dest_lat, dest_lon = np.meshgrid(np.linspace(25.0, 33.0, 10), np.linspace(-123, -98, 12))