Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add StormCast model #164

Merged
merged 24 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
2a2d902
add stormcast
dallasfoster Nov 7, 2024
f71e56c
Merge branch 'main' into dallasf/stormcast
dallasfoster Nov 7, 2024
600a062
coord interpolation and stormcast unit tests
pzharrington Nov 12, 2024
225ebe4
stormcast bugfixes and test case updates
pzharrington Nov 13, 2024
b05eba7
add model package and speed up unit tests
pzharrington Dec 13, 2024
53b7b72
modify checkpoint path and update changelog
pzharrington Dec 13, 2024
2fc6046
Merge branch 'main' into pharring/stormcast-tests
pzharrington Dec 13, 2024
70f59a5
handle 1d coord interpolation in fetch_data
pzharrington Dec 13, 2024
9bf5509
Updates to grib files
NickGeneva Dec 16, 2024
b0d1acf
Updates
NickGeneva Dec 16, 2024
89d66e2
Adding note about lead times
NickGeneva Dec 16, 2024
9ee220f
Caching index file
NickGeneva Dec 16, 2024
125e00e
Fix type
NickGeneva Dec 16, 2024
9e6652c
Fix type
NickGeneva Dec 16, 2024
badffce
Caching index file
NickGeneva Dec 16, 2024
6e311c1
Fixes
NickGeneva Dec 16, 2024
8b3e0ac
Type Fixes
NickGeneva Dec 16, 2024
9f00edf
Merge branch 'ngeneva/hrrr_extended' of https://github.com/NickGeneva…
pzharrington Dec 16, 2024
86ed976
Public package testing
pzharrington Dec 18, 2024
3d5f561
Merge branch 'main' into pharrington/stormcast-testing
pzharrington Dec 18, 2024
6764d7b
prep_data test fixes
pzharrington Dec 18, 2024
7536fee
Require conditioning data source in stormcast, update deps
pzharrington Dec 19, 2024
fd4f611
Add stormcast to docs rst
pzharrington Dec 19, 2024
4fb841c
Merge branch 'main' into pharring/stormcast-tests
pzharrington Dec 19, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add StormCast model to prognostic models
- Interpolation between arbitrary lat-lon grids
- Added hybrid level support to HRRR data source

Expand Down
1 change: 1 addition & 0 deletions docs/modules/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ Thus are typically used to generate forecast predictions.
models.px.Pangu3
models.px.Persistence
models.px.SFNO
models.px.StormCast

.. _earth2studio.models.dx:

Expand Down
32 changes: 26 additions & 6 deletions earth2studio/data/rand.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,18 @@ def __init__(
):
self.domain_coords = domain_coords

# Check for regular vs. curvilinear coordinates
_, value = list(self.domain_coords.items()).pop()
value = np.array(value)
self.curv = len(value.shape) > 1
if self.curv:
self.domain_coord_shape = value.shape

def __call__(
    self,
    time: datetime | list[datetime] | TimeArray,
    variable: str | list[str] | VariableArray,
    lead_time: np.ndarray | None = None,
) -> xr.DataArray:
    """Retrieve random gaussian data.

    Parameters
    ----------
    time : datetime | list[datetime] | TimeArray
        Timestamps to generate data for.
    variable : str | list[str] | VariableArray
        Variable names to generate data for.
    lead_time : np.ndarray | None, optional
        Accepted for signature compatibility with forecast-like sources; the
        generated random data does not depend on it, by default None.

    Returns
    -------
    xr.DataArray
        Random gaussian data array over the requested times/variables and
        the domain coordinates supplied at construction.
    """
    shape = [len(time), len(variable)]
    coords: dict = {"time": time, "variable": variable}

    if self.curv:
        # Curvilinear grid: lat/lon are 2D arrays laid over internal
        # dims (y, x) rather than being dimensions themselves
        shape.extend(self.domain_coord_shape)
        dims = ["time", "variable", "y", "x"]
        coords = coords | {
            "lat": (("y", "x"), self.domain_coords["lat"]),
            "lon": (("y", "x"), self.domain_coords["lon"]),
        }
        da = xr.DataArray(
            data=np.random.randn(*shape), dims=dims, coords=coords
        )
    else:
        # Regular grid: each domain coordinate is its own 1D dimension
        for key, value in self.domain_coords.items():
            shape.append(len(value))
            coords[key] = value

        da = xr.DataArray(
            data=np.random.randn(*shape), dims=list(coords), coords=coords
        )

    return da
87 changes: 83 additions & 4 deletions earth2studio/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from loguru import logger

from earth2studio.data.base import DataSource
from earth2studio.utils.interp import LatLonInterpolation
from earth2studio.utils.time import (
leadtimearray_to_timedelta,
timearray_to_datetime,
Expand All @@ -43,8 +44,12 @@ def fetch_data(
variable: VariableArray,
lead_time: LeadTimeArray = np.array([np.timedelta64(0, "h")]),
device: torch.device = "cpu",
interp_to: CoordSystem = None,
interp_method: str = "nearest",
) -> tuple[torch.Tensor, CoordSystem]:
"""Utility function to fetch data for models and load data on the target device.
If desired, xarray interpolation/regridding in the spatial domain can be used
by passing a target coordinate system via the optional `interp_to` argument.

Parameters
----------
Expand All @@ -59,6 +64,11 @@ def fetch_data(
np.array(np.timedelta64(0, "h"))
device : torch.device, optional
Torch devive to load data tensor to, by default "cpu"
interp_to : CoordSystem, optional
If provided, the fetched data will be interpolated to the coordinates
specified by lat/lon arrays in this CoordSystem
interp_method : str
Interpolation method to use with xarray (by default 'nearest')

Returns
-------
Expand All @@ -75,34 +85,103 @@ def fetch_data(
da0 = da0.assign_coords(time=time)
da.append(da0)

return prep_data_array(xr.concat(da, "lead_time"), device=device)
return prep_data_array(
xr.concat(da, "lead_time"),
device=device,
interp_to=interp_to,
interp_method=interp_method,
)


def prep_data_array(
    da: xr.DataArray,
    device: torch.device = "cpu",
    interp_to: CoordSystem | None = None,
    interp_method: str = "nearest",
) -> tuple[torch.Tensor, CoordSystem]:
    """Prepares a data array from a data source for inference workflows by
    converting the data array to a torch tensor and the coordinate system to
    an OrderedDict.

    If desired, xarray interpolation/regridding in the spatial domain can be
    used by passing a target coordinate system via the optional `interp_to`
    argument.

    Parameters
    ----------
    da : xr.DataArray
        Input data array
    device : torch.device, optional
        Torch device to load data tensor to, by default "cpu"
    interp_to : CoordSystem, optional
        If provided, the fetched data will be interpolated to the coordinates
        specified by lat/lon arrays in this CoordSystem
    interp_method : str
        Interpolation method to use with xarray (by default 'nearest')

    Returns
    -------
    tuple[torch.Tensor, CoordSystem]
        Tuple containing output tensor and coordinate OrderedDict

    Raises
    ------
    ValueError
        If `interp_to` lat/lon have mismatched ranks, or a non-linear
        interpolation method is requested for a curvilinear source grid.
    """

    # Initialize the output CoordSystem with the non-spatial dims first;
    # spatial coordinates are filled in below depending on interpolation
    # and on whether the grid is regular or curvilinear
    out_coords = OrderedDict()
    for dim in da.coords.dims:
        if dim in ["time", "lead_time", "variable"]:
            out_coords[dim] = np.array(da.coords[dim])

    # Fetch data and regrid if necessary
    if interp_to is not None:
        if len(interp_to["lat"].shape) != len(interp_to["lon"].shape):
            raise ValueError(
                "Discrepancy in interpolation coordinates: latitude has different number of dims than longitude"
            )

        if "lat" not in da.dims:
            # Data source uses curvilinear coordinates: xarray cannot
            # interpolate from a 2D lat/lon grid, so use the bilinear
            # LatLonInterpolation helper instead
            if interp_method != "linear":
                raise ValueError(
                    "fetch_data does not support interpolation methods other than linear when data source has a curvilinear grid"
                )
            interp = LatLonInterpolation(
                lat_in=da["lat"].values,
                lon_in=da["lon"].values,
                lat_out=interp_to["lat"],
                lon_out=interp_to["lon"],
            ).to(device)
            data = torch.Tensor(da.values).to(device)
            out = interp(data)
        else:
            if len(interp_to["lat"].shape) > 1 or len(interp_to["lon"].shape) > 1:
                # Target grid uses curvilinear coordinates: define internal dims y, x
                target_lat = xr.DataArray(interp_to["lat"], dims=["y", "x"])
                target_lon = xr.DataArray(interp_to["lon"], dims=["y", "x"])
            else:
                target_lat = xr.DataArray(interp_to["lat"], dims=["lat"])
                target_lon = xr.DataArray(interp_to["lon"], dims=["lon"])

            da = da.interp(
                lat=target_lat,
                lon=target_lon,
                method=interp_method,
            )
            out = torch.Tensor(da.values).to(device)

        # Output spatial coordinates are those of the target grid
        out_coords["lat"] = interp_to["lat"]
        out_coords["lon"] = interp_to["lon"]
    else:
        out = torch.Tensor(da.values).to(device)
        if "lat" in da.coords and "lat" not in da.coords.dims:
            # Curvilinear grid case: lat/lon coords are 2D arrays, not in dims
            out_coords["lat"] = da.coords["lat"].values
            out_coords["lon"] = da.coords["lon"].values
        else:
            for dim in da.coords.dims:
                if dim not in ["time", "lead_time", "variable"]:
                    out_coords[dim] = np.array(da.coords[dim])

    return out, out_coords

Expand Down
13 changes: 7 additions & 6 deletions earth2studio/lexicon/hrrr.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,17 @@ def build_vocab() -> dict[str, str]:
975,
1000,
]
prs_names = ["UGRD", "VGRD", "HGT", "TMP", "RH", "SPFH"]
e2s_id = ["u", "v", "z", "t", "r", "q"]

prs_names = ["UGRD", "VGRD", "HGT", "TMP", "RH", "SPFH", "HGT"]
e2s_id = ["u", "v", "z", "t", "r", "q", "Z"]
prs_variables = {}
for (id, variable) in zip(e2s_id, prs_names):
for level in prs_levels:
prs_variables[f"{id}{level:d}"] = f"prs::anl::{level} mb::{variable}"

hybrid_levels = list(range(1, 51))
hybrid_names = ["UGRD", "VGRD", "HGT", "TMP", "SPFH", "PRES"]
e2s_id = ["u", "v", "z", "t", "q", "p"]
hybrid_names = ["UGRD", "VGRD", "HGT", "TMP", "SPFH", "PRES", "HGT"]
e2s_id = ["u", "v", "z", "t", "q", "p", "Z"]
hybrid_variables = {}
for (id, variable) in zip(e2s_id, hybrid_names):
for level in hybrid_levels:
Expand Down Expand Up @@ -187,8 +188,8 @@ def build_vocab() -> dict[str, str]:
prs_variables[f"{id}{level:d}"] = f"sfc::fcst::{level} mb::{variable}"

hybrid_levels = list(range(1, 51))
hybrid_names = ["UGRD", "VGRD", "HGT", "TMP", "SPFH", "PRES"]
e2s_id = ["u", "v", "z", "t", "q", "p"]
hybrid_names = ["UGRD", "VGRD", "HGT", "TMP", "SPFH", "PRES", "HGT"]
e2s_id = ["u", "v", "z", "t", "q", "p", "Z"]
hybrid_variables = {}
for (id, variable) in zip(e2s_id, hybrid_names):
for level in hybrid_levels:
Expand Down
4 changes: 3 additions & 1 deletion earth2studio/models/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,9 @@ def _batch_wrap(self, func: Callable) -> Callable:
# TODO: Better typing for model object
@functools.wraps(func)
def _wrapper(
model: Any, x: torch.Tensor, coords: CoordSystem
model: Any,
x: torch.Tensor,
coords: CoordSystem,
) -> tuple[torch.Tensor, CoordSystem]:

x, flatten_coords, batched_coords, batched_shape = self._compress_batch(
Expand Down
1 change: 1 addition & 0 deletions earth2studio/models/px/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@
from earth2studio.models.px.pangu import Pangu3, Pangu6, Pangu24
from earth2studio.models.px.persistence import Persistence
from earth2studio.models.px.sfno import SFNO
from earth2studio.models.px.stormcast import StormCast
Loading