diff --git a/earth2studio/models/px/sfno.py b/earth2studio/models/px/sfno.py index 70bb4cde..82f11358 100644 --- a/earth2studio/models/px/sfno.py +++ b/earth2studio/models/px/sfno.py @@ -116,13 +116,16 @@ class SFNO(torch.nn.Module, AutoModelMixin, PrognosticMixin): Consists of a single model with a time-step size of 6 hours. FourCastNet operates on 0.25 degree lat-lon grid (south-pole excluding) equirectangular grid with 73 variables. + Note ---- This model and checkpoint are trained using Modulus-Makani. For more information see the following references: + - https://arxiv.org/abs/2306.03838 - https://github.com/NVIDIA/modulus-makani - https://catalog.ngc.nvidia.com/orgs/nvidia/teams/modulus/models/sfno_73ch_small + Parameters ---------- core_model : torch.nn.Module diff --git a/earth2studio/models/px/stormcast.py b/earth2studio/models/px/stormcast.py index 3eae07c0..77400154 100644 --- a/earth2studio/models/px/stormcast.py +++ b/earth2studio/models/px/stormcast.py @@ -71,33 +71,51 @@ class StormCast(torch.nn.Module, AutoModelMixin, PrognosticMixin): - """StormCast generative convection-allowing model for regional forecasts - Consists of two core models, a regression and diffusion model - This class implements StormCastV1, the model released in: - https://arxiv.org/abs/2408.10958 - Model time step size is 1 hour, taking as input: + """StormCast generative convection-allowing model for regional forecasts consists of + two core models: a regression and diffusion model. Model time step size is 1 hour, + taking as input: - High-resolution (3km) HRRR state over the central United States (99 vars) - High-resolution land-sea mask and orography invariants - Coarse resolution (25km) global state (26 vars) The high-resolution grid is the HRRR Lambert conformal projection Coarse-resolution inputs are regridded to the HRRR grid internally. + Note + ---- + For more information see the following references: + + - https://arxiv.org/abs/2408.10958 + Parameters ---------- - regression_model (torch.nn.Module): Deterministic model used to make an initial prediction - diffusion_model (torch.nn.Module): Generative model correcting the deterministic prediciton - lat (np.array): Latitude array (2D) of the domain - lon (np.array): Latitude array (2D) of the domain - means (torch.Tensor): Mean value of each input high-resolution variable - stds (torch.Tensor): Standard deviation of each input high-resolution variable - invariants (torch.Tensor): Static invariant quantities - variables (np.array, optional): High-resolution variables Defaults to np.array(VARIABLES). - conditioning_means (torch.Tensor | None, optional): Means to normalize conditioning data. Defaults to None. - conditioning_stds (torch.Tensor | None, optional): Stds to normalize conditioning data Defaults to None. - conditioning_variables (np.array, optional): Global variables for conditioning. Defaults to np.array(CONDITIONING_VARIABLES). - conditioning_data_source (DataSource | None, optional): Data Source to use for global conditoining. Defaults to None. Required for running in iterator mode - sampler_args (dict[str, float | int], optional): Arguments to pass to the diffusion sampler. Defaults to {}. - interp_method (str, optional): Interpolation method to use when regridding coarse conditoining data. Defaults to "linear". + regression_model : torch.nn.Module + Deterministic model used to make an initial prediction + diffusion_model : torch.nn.Module + Generative model correcting the deterministic prediciton + lat : np.array + Latitude array (2D) of the domain + lon : np.array + Longitude array (2D) of the domain + means : torch.Tensor + Mean value of each input high-resolution variable + stds : torch.Tensor + Standard deviation of each input high-resolution variable + invariants : torch.Tensor + Static invariant quantities + variables : np.array, optional + High-resolution variables, by default np.array(VARIABLES) + conditioning_means : torch.Tensor | None, optional + Means to normalize conditioning data, by default None + conditioning_stds : torch.Tensor | None, optional + Standard deviations to normalize conditioning data, by default None + conditioning_variables : np.array, optional + Global variables for conditioning, by default np.array(CONDITIONING_VARIABLES) + conditioning_data_source : DataSource | None, optional + Data Source to use for global conditoining. Required for running in iterator mode, by default None + sampler_args : dict[str, float | int], optional + Arguments to pass to the diffusion sampler, by default {} + interp_method : str, optional + Interpolation method to use when regridding coarse conditoining data, by default "linear" """ def __init__( @@ -231,7 +249,11 @@ def load_model(cls, package: Package) -> DiagnosticModel: f"nvidia-modulus @ git+https://github.com/NVIDIA/modulus.git" ) - OmegaConf.register_new_resolver("eval", eval) + try: + OmegaConf.register_new_resolver("eval", eval) + except ValueError: + # Likely already registered so skip + pass # load model registry: config = OmegaConf.load(package.resolve("model.yaml")) diff --git a/examples/09_stormcast_example.py b/examples/09_stormcast_example.py index 503295f0..5457596a 100644 --- a/examples/09_stormcast_example.py +++ b/examples/09_stormcast_example.py @@ -14,16 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from datetime import datetime - -from loguru import logger -from tqdm import tqdm - -logger.remove() -logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True) - - -# sphinx - deterministic start # %% """ Running StormCast Inference @@ -64,6 +54,14 @@ # could also be used with appropriate time stamps. # %% +from datetime import datetime + +from loguru import logger +from tqdm import tqdm + +logger.remove() +logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True) + import os os.makedirs("outputs", exist_ok=True) diff --git a/examples/10_stormcast_ensemble_example.py b/examples/10_stormcast_ensemble_example.py index ac560d78..a4c95038 100644 --- a/examples/10_stormcast_ensemble_example.py +++ b/examples/10_stormcast_ensemble_example.py @@ -14,15 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np -from loguru import logger -from tqdm import tqdm - -logger.remove() -logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True) - - -# sphinx - ensemble start # %% """ Running StormCast Ensemble Inference @@ -37,7 +28,6 @@ - https://arxiv.org/abs/2408.10958 """ - # %% # Set Up # ------ @@ -65,6 +55,13 @@ # could also be used with appropriate time stamps. # %% +import numpy as np +from loguru import logger +from tqdm import tqdm + +logger.remove() +logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True) + import os os.makedirs("outputs", exist_ok=True)