Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add StormCast model #164

Merged
merged 24 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
2a2d902
add stormcast
dallasfoster Nov 7, 2024
f71e56c
Merge branch 'main' into dallasf/stormcast
dallasfoster Nov 7, 2024
600a062
coord interpolation and stormcast unit tests
pzharrington Nov 12, 2024
225ebe4
stormcast bugfixes and test case updates
pzharrington Nov 13, 2024
b05eba7
add model package and speed up unit tests
pzharrington Dec 13, 2024
53b7b72
modify checkpoint path and update changelog
pzharrington Dec 13, 2024
2fc6046
Merge branch 'main' into pharring/stormcast-tests
pzharrington Dec 13, 2024
70f59a5
handle 1d coord interpolation in fetch_data
pzharrington Dec 13, 2024
9bf5509
Updates to grib files
NickGeneva Dec 16, 2024
b0d1acf
Updates
NickGeneva Dec 16, 2024
89d66e2
Adding note abot lead times
NickGeneva Dec 16, 2024
9ee220f
Caching index file
NickGeneva Dec 16, 2024
125e00e
Fix type
NickGeneva Dec 16, 2024
9e6652c
Fix type
NickGeneva Dec 16, 2024
badffce
Caching index file
NickGeneva Dec 16, 2024
6e311c1
Fixes
NickGeneva Dec 16, 2024
8b3e0ac
Type Fixes
NickGeneva Dec 16, 2024
9f00edf
Merge branch 'ngeneva/hrrr_extended' of https://github.com/NickGeneva…
pzharrington Dec 16, 2024
86ed976
Public package testing
pzharrington Dec 18, 2024
3d5f561
Merge branch 'main' into pharrington/stormcast-testing
pzharrington Dec 18, 2024
6764d7b
prep_data test fixes
pzharrington Dec 18, 2024
7536fee
Require conditioning data source in stormcast, update deps
pzharrington Dec 19, 2024
fd4f611
Add stormcast to docs rst
pzharrington Dec 19, 2024
4fb841c
Merge branch 'main' into pharring/stormcast-tests
pzharrington Dec 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add StormCast model to prognostic models

### Changed

- Interpolation between arbitrary lat-lon grids

### Changed
Expand Down
32 changes: 26 additions & 6 deletions earth2studio/data/rand.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,18 @@ def __init__(
):
self.domain_coords = domain_coords

# Check for regular vs. curvilinear coordinates
_, value = list(self.domain_coords.items()).pop()
value = np.array(value)
self.curv = len(value.shape) > 1
if self.curv:
self.domain_coord_shape = value.shape

def __call__(
self,
time: datetime | list[datetime] | TimeArray,
variable: str | list[str] | VariableArray,
lead_time: np.array = None,
) -> xr.DataArray:
"""Retrieve random gaussian data.

Expand All @@ -63,12 +71,24 @@ def __call__(

shape = [len(time), len(variable)]
coords = {"time": time, "variable": variable}
for key, value in self.domain_coords.items():
shape.append(len(value))
coords[key] = value

da = xr.DataArray(
data=np.random.randn(*shape), dims=list(coords), coords=coords
)
if self.curv:
shape.extend(self.domain_coord_shape)
dims = ["time", "variable", "y", "x"]
coords = coords | {
"lat": (("y", "x"), self.domain_coords["lat"]),
"lon": (("y", "x"), self.domain_coords["lon"]),
}
da = xr.DataArray(data=np.random.randn(*shape), dims=dims, coords=coords)

else:

for key, value in self.domain_coords.items():
shape.append(len(value))
coords[key] = value

da = xr.DataArray(
data=np.random.randn(*shape), dims=list(coords), coords=coords
)

return da
62 changes: 59 additions & 3 deletions earth2studio/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,12 @@ def fetch_data(
variable: VariableArray,
lead_time: LeadTimeArray = np.array([np.timedelta64(0, "h")]),
device: torch.device = "cpu",
interp_to: CoordSystem = None,
interp_method: str = "nearest",
NickGeneva marked this conversation as resolved.
Show resolved Hide resolved
) -> tuple[torch.Tensor, CoordSystem]:
"""Utility function to fetch data for models and load data on the target device.
If desired, xarray interpolation/regridding in the spatial domain can be used
by passing a target coordinate system via the optional `interp_to` argument.

Parameters
----------
Expand All @@ -59,6 +63,11 @@ def fetch_data(
np.array(np.timedelta64(0, "h"))
device : torch.device, optional
Torch devive to load data tensor to, by default "cpu"
interp_to : CoordSystem, optional
If provided, the fetched data will be interpolated to the coordinates
specified by lat/lon arrays in this CoordSystem
interp_method : str
Interpolation method to use with xarray (by default 'nearest')

Returns
-------
Expand All @@ -75,34 +84,81 @@ def fetch_data(
da0 = da0.assign_coords(time=time)
da.append(da0)

return prep_data_array(xr.concat(da, "lead_time"), device=device)
return prep_data_array(
xr.concat(da, "lead_time"),
device=device,
interp_to=interp_to,
interp_method=interp_method,
)


def prep_data_array(
da: xr.DataArray,
device: torch.device = "cpu",
interp_to: CoordSystem = None,
interp_method: str = "nearest",
) -> tuple[torch.Tensor, CoordSystem]:
"""Prepares a data array from a data source for inference workflows by converting
the data array to a torch tensor and the coordinate system to an OrderedDict.

If desired, xarray interpolation/regridding in the spatial domain can be used
by passing a target coordinate system via the optional `interp_to` argument.

Parameters
----------
da : xr.DataArray
Input data array
device : torch.device, optional
Torch devive to load data tensor to, by default "cpu"
interp_to : CoordSystem, optional
If provided, the fetched data will be interpolated to the coordinates
specified by lat/lon arrays in this CoordSystem
interp_method : str
Interpolation method to use with xarray (by default 'nearest')

Returns
-------
tuple[torch.Tensor, CoordSystem]
Tuple containing output tensor and coordinate OrderedDict
"""

if interp_to is not None:
NickGeneva marked this conversation as resolved.
Show resolved Hide resolved
if len(interp_to["lat"].shape) > 1 or len(interp_to["lon"].shape) > 1:
if len(interp_to["lat"].shape) != len(interp_to["lon"].shape):
raise ValueError(
"Discrepancy in interpolation coordinates: latitude has different shape than longitude"
)

# Curvilinear coordinates: define internal dims y, x
target_lat = xr.DataArray(interp_to["lat"], dims=["y", "x"])
target_lon = xr.DataArray(interp_to["lon"], dims=["y", "x"])
else:

target_lat = xr.DataArray(interp_to["lat"], dims=["lat"])
target_lon = xr.DataArray(interp_to["lon"], dims=["lon"])
target_grid = xr.Dataset({"latitude": target_lat, "longitude": target_lon})

da = da.interp(
lat=target_grid["latitude"],
lon=target_grid["longitude"],
method=interp_method,
)

out = torch.Tensor(da.values).to(device)

out_coords = OrderedDict()
for dim in da.coords.dims:
out_coords[dim] = np.array(da.coords[dim])

curvilinear = "lat" in da.coords and len(da.coords["lat"].shape) > 1
if curvilinear:
# Assume the lat, lon array are defined as DataArray coordinates
for dim in da.coords.dims:
if dim in ["time", "variable", "lead_time"]:
out_coords[dim] = np.array(da.coords[dim])
out_coords["lat"] = np.array(da.coords["lat"])
out_coords["lon"] = np.array(da.coords["lon"])
else:
for dim in da.coords.dims:
out_coords[dim] = np.array(da.coords[dim])

return out, out_coords

Expand Down
17 changes: 15 additions & 2 deletions earth2studio/models/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,28 @@ def _batch_wrap(self, func: Callable) -> Callable:
# TODO: Better typing for model object
@functools.wraps(func)
def _wrapper(
model: Any, x: torch.Tensor, coords: CoordSystem
model: Any,
x: torch.Tensor,
coords: CoordSystem,
conditioning: torch.Tensor | None = None,
conditioning_coords: CoordSystem | None = None,
) -> tuple[torch.Tensor, CoordSystem]:

x, flatten_coords, batched_coords, batched_shape = self._compress_batch(
model, x, coords
)

# Model forward
out, out_coords = func(model, x, flatten_coords)
if conditioning is not None:
out, out_coords = func(
model,
x,
flatten_coords,
conditioning=conditioning,
conditioning_coords=conditioning_coords,
)
else:
out, out_coords = func(model, x, flatten_coords)
dallasfoster marked this conversation as resolved.
Show resolved Hide resolved
out, out_coords = self._decompress_batch(
out, out_coords, batched_coords, batched_shape
)
Expand Down
1 change: 1 addition & 0 deletions earth2studio/models/px/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@
from earth2studio.models.px.pangu import Pangu3, Pangu6, Pangu24
from earth2studio.models.px.persistence import Persistence
from earth2studio.models.px.sfno import SFNO
from earth2studio.models.px.stormcast import StormCast
dallasfoster marked this conversation as resolved.
Show resolved Hide resolved
Loading