Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stormcast docs and example touch up #184

Merged
merged 3 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions earth2studio/models/px/sfno.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,16 @@ class SFNO(torch.nn.Module, AutoModelMixin, PrognosticMixin):
Consists of a single model with a time-step size of 6 hours.
FourCastNet operates on 0.25 degree lat-lon grid (south-pole excluding)
equirectangular grid with 73 variables.

Note
----
This model and checkpoint are trained using Modulus-Makani. For more information
see the following references:

- https://arxiv.org/abs/2306.03838
- https://github.com/NVIDIA/modulus-makani
- https://catalog.ngc.nvidia.com/orgs/nvidia/teams/modulus/models/sfno_73ch_small

Parameters
----------
core_model : torch.nn.Module
Expand Down
62 changes: 42 additions & 20 deletions earth2studio/models/px/stormcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,33 +71,51 @@


class StormCast(torch.nn.Module, AutoModelMixin, PrognosticMixin):
"""StormCast generative convection-allowing model for regional forecasts
Consists of two core models, a regression and diffusion model
This class implements StormCastV1, the model released in:
https://arxiv.org/abs/2408.10958
Model time step size is 1 hour, taking as input:
"""StormCast generative convection-allowing model for regional forecasts consists of
two core models: a regression and diffusion model. Model time step size is 1 hour,
taking as input:
- High-resolution (3km) HRRR state over the central United States (99 vars)
- High-resolution land-sea mask and orography invariants
- Coarse resolution (25km) global state (26 vars)
The high-resolution grid is the HRRR Lambert conformal projection
Coarse-resolution inputs are regridded to the HRRR grid internally.

Note
----
For more information see the following references:

- https://arxiv.org/abs/2408.10958

Parameters
----------
regression_model (torch.nn.Module): Deterministic model used to make an initial prediction
diffusion_model (torch.nn.Module): Generative model correcting the deterministic prediciton
lat (np.array): Latitude array (2D) of the domain
lon (np.array): Latitude array (2D) of the domain
means (torch.Tensor): Mean value of each input high-resolution variable
stds (torch.Tensor): Standard deviation of each input high-resolution variable
invariants (torch.Tensor): Static invariant quantities
variables (np.array, optional): High-resolution variables Defaults to np.array(VARIABLES).
conditioning_means (torch.Tensor | None, optional): Means to normalize conditioning data. Defaults to None.
conditioning_stds (torch.Tensor | None, optional): Stds to normalize conditioning data Defaults to None.
conditioning_variables (np.array, optional): Global variables for conditioning. Defaults to np.array(CONDITIONING_VARIABLES).
conditioning_data_source (DataSource | None, optional): Data Source to use for global conditoining. Defaults to None. Required for running in iterator mode
sampler_args (dict[str, float | int], optional): Arguments to pass to the diffusion sampler. Defaults to {}.
interp_method (str, optional): Interpolation method to use when regridding coarse conditoining data. Defaults to "linear".
regression_model : torch.nn.Module
Deterministic model used to make an initial prediction
diffusion_model : torch.nn.Module
Generative model correcting the deterministic prediciton
lat : np.array
Latitude array (2D) of the domain
lon : np.array
Longitude array (2D) of the domain
means : torch.Tensor
Mean value of each input high-resolution variable
stds : torch.Tensor
Standard deviation of each input high-resolution variable
invariants : torch.Tensor
Static invariant quantities
variables : np.array, optional
High-resolution variables, by default np.array(VARIABLES)
conditioning_means : torch.Tensor | None, optional
Means to normalize conditioning data, by default None
conditioning_stds : torch.Tensor | None, optional
Standard deviations to normalize conditioning data, by default None
conditioning_variables : np.array, optional
Global variables for conditioning, by default np.array(CONDITIONING_VARIABLES)
conditioning_data_source : DataSource | None, optional
Data Source to use for global conditoining. Required for running in iterator mode, by default None
sampler_args : dict[str, float | int], optional
Arguments to pass to the diffusion sampler, by default {}
interp_method : str, optional
Interpolation method to use when regridding coarse conditoining data, by default "linear"
"""

def __init__(
Expand Down Expand Up @@ -231,7 +249,11 @@ def load_model(cls, package: Package) -> DiagnosticModel:
f"nvidia-modulus @ git+https://github.com/NVIDIA/modulus.git"
)

OmegaConf.register_new_resolver("eval", eval)
try:
OmegaConf.register_new_resolver("eval", eval)
except ValueError:
# Likely already registered so skip
pass

# load model registry:
config = OmegaConf.load(package.resolve("model.yaml"))
Expand Down
18 changes: 8 additions & 10 deletions examples/09_stormcast_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from datetime import datetime

from loguru import logger
from tqdm import tqdm

logger.remove()
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True)


# sphinx - deterministic start
# %%
"""
Running StormCast Inference
Expand Down Expand Up @@ -64,6 +54,14 @@
# could also be used with appropriate time stamps.

# %%
from datetime import datetime

from loguru import logger
from tqdm import tqdm

logger.remove()
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True)

import os

os.makedirs("outputs", exist_ok=True)
Expand Down
17 changes: 7 additions & 10 deletions examples/10_stormcast_ensemble_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from loguru import logger
from tqdm import tqdm

logger.remove()
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True)


# sphinx - ensemble start
# %%
"""
Running StormCast Ensemble Inference
Expand All @@ -37,7 +28,6 @@
- https://arxiv.org/abs/2408.10958

"""

# %%
# Set Up
# ------
Expand Down Expand Up @@ -65,6 +55,13 @@
# could also be used with appropriate time stamps.

# %%
import numpy as np
from loguru import logger
from tqdm import tqdm

logger.remove()
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True)

import os

os.makedirs("outputs", exist_ok=True)
Expand Down