StormCast training code example (#724)
* adding stormcast raw files

* major cleanup, refactor and consolidation

Signed-off-by: Peter Harrington <[email protected]>

* More cleanup and init readme

Signed-off-by: Peter Harrington <[email protected]>

* port command line args to standard argparse

Signed-off-by: Peter Harrington <[email protected]>

* remove unused network and loss wrappers

Signed-off-by: Peter Harrington <[email protected]>

* add torchrun instructions

Signed-off-by: Peter Harrington <[email protected]>

* drop dnnlib utils

Signed-off-by: Peter Harrington <[email protected]>

* use Modulus DistributedManager, streamline cmd args

Signed-off-by: Peter Harrington <[email protected]>

* Use standard torch checkpoints instead of pickles

Signed-off-by: Peter Harrington <[email protected]>

* Standardize model configs and channel selection across training and inference

Signed-off-by: Peter Harrington <[email protected]>

* checkpoint format standardization for train/inference

Signed-off-by: Peter Harrington <[email protected]>

* finalize additional deps

Signed-off-by: Peter Harrington <[email protected]>

* format and linting

Signed-off-by: Peter Harrington <[email protected]>

* drop docker and update changelog

Signed-off-by: Peter Harrington <[email protected]>

* Address feedback

Signed-off-by: Peter Harrington <[email protected]>

* add variables to readme, rename network types

Signed-off-by: Peter Harrington <[email protected]>

---------

Signed-off-by: Peter Harrington <[email protected]>
Co-authored-by: nvssh nssswitch user account <[email protected]>
pzharrington and nvssh nssswitch user account authored Nov 25, 2024
1 parent 9e96ddf commit 7f739f7
Showing 19 changed files with 4,073 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- The XAeroNet model.
- Incorporated CorrDiff-GEFS-HRRR model into CorrDiff, with lead-time aware SongUNet and
cross entropy loss.
- Added StormCast model training and simple inference to examples

### Changed

Binary file added docs/img/stormcast_rollout.gif
3 changes: 3 additions & 0 deletions examples/generative/stormcast/.gitignore
@@ -0,0 +1,3 @@
# Extra paths to avoid tracking here

rundir/
145 changes: 145 additions & 0 deletions examples/generative/stormcast/README.md
@@ -0,0 +1,145 @@
<!-- markdownlint-disable -->
## StormCast: Kilometer-Scale Convection Allowing Model Emulation using Generative Diffusion Modeling

**Note: this example is an initial release of the StormCast code and will be heavily refactored in future releases**

## Problem overview

Convection-allowing models (CAMs) are essential tools for forecasting severe thunderstorms and
mesoscale convective systems, which are responsible for some of the most extreme weather events.
By resolving kilometer-scale convective dynamics, these models provide the precision needed for
accurate hazard prediction. However, modeling the atmosphere at this scale is both challenging
and expensive.

This example demonstrates how to run training and simple inference for [StormCast](https://arxiv.org/abs/2408.10958),
a generative diffusion model designed to emulate NOAA’s High-Resolution Rapid Refresh (HRRR) model, a 3km
operational CAM. StormCast autoregressively predicts multiple atmospheric state variables with remarkable
accuracy, demonstrating the ability to replicate storm dynamics, observed radar reflectivity, and realistic
atmospheric structure via deep learning-based CAM emulation. StormCast enables high-resolution ML-driven
regional weather forecasting and climate risk analysis.


<p align="center">
<img src="../../../docs/img/stormcast_rollout.gif"/>
</p>

## Getting started

### Preliminaries
Start by installing Modulus (if not already installed) and copying this folder (`examples/generative/stormcast`) to a system with a GPU available. Also, prepare a combined HRRR/ERA5 dataset in the form specified in `utils/data_loader_hrrr_era5.py` (**Note: subsequent versions of this example will include more detailed dataset preparation instructions**).

### Configuration basics

StormCast training is handled by `train.py` and controlled by a YAML configuration file (`config/config.yaml`) plus command-line arguments. You can choose the configuration file with the `--config_file` option and a specific configuration within that file with the `--config_name` option. The main configuration file specifies the training dataset, the model configuration, and the training options. To change a configuration option, you can either edit the existing configurations directly or create new ones that inherit from the existing configs and override specific options. For example, one could train the StormCast diffusion model with a larger batch size by adding a new config that inherits from the existing `diffusion` config in `config/config.yaml`:
```yaml
diffusion_bs64:
  <<: *diffusion
  batch_size: 64
```

The basic configuration file currently contains configurations for just the `regression` and `diffusion` components of StormCast. Note that any diffusion model you train will need a pretrained regression model, due to how StormCast is designed (refer to the paper for more details). Thus, two config items must be defined to train a diffusion model (a minimal sketch of the expected checkpoint layout follows this list):
1. `regression_weights` -- the path to a checkpoint with model weights for the regression model. This file should be a PyTorch checkpoint saved by your training script, with the `state_dict` for the regression network saved under the `net` key.
2. `regression_config` -- the config name used to train this regression model.
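
The exact checkpoint contents are produced by `train.py`; the snippet below is only a minimal sketch of the expected layout (paths and model objects are placeholders), illustrating that the regression weights live under the `net` key:

```python
import torch

def save_regression_checkpoint(model: torch.nn.Module, path: str) -> None:
    # The diffusion config expects the regression state_dict under the "net" key
    torch.save({"net": model.state_dict()}, path)

def load_regression_weights(model: torch.nn.Module, path: str) -> torch.nn.Module:
    # Restore weights saved in the layout above
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["net"])
    return model
```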

All configuration items related to the dataset are also contained in `config/config.yaml`, most importantly the location on the filesystem of the prepared HRRR/ERA5 Dataset (see [Dataset section](#dataset) for details).

There is also a model registry, `config/registry.json`, which indexes different model versions for use in inference/evaluation. For simplicity, only a single model version is specified there currently, matching the StormCast model used to generate results in the paper.
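
For illustration, the registry can be read with the standard library; the keys below match the example `stormcast` entry in `config/registry.json`:

```python
import json

# Load the model registry and pull out the example "stormcast" entry
with open("config/registry.json") as f:
    registry = json.load(f)

entry = registry["models"]["stormcast"]
print(entry["description"])
print(entry["regression_checkpoint_path"], entry["edm_checkpoint_path"])
```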

### Training the regression model
To train the StormCast regression model, we use the default configuration file `config.yaml` and specify the `regression` config, along with the `--outdir` argument to choose where training logs and checkpoints should be saved.
We can also use command-line options defined in `train.py` to specify other details, like a unique run ID for the experiment (`--run_id`). On a single-GPU machine, for example, run:
```bash
python train.py --outdir rundir --config_file ./config/config.yaml --config_name regression --run_id 0
```

This will initialize the training experiment and launch the main training loop, which is defined in `utils/diffusions/training_loop.py`.

### Training the diffusion model

The method for launching a diffusion model training looks almost identical; we just have to change the configuration name appropriately. However, since diffusion training requires a pretrained regression model, this config must define `regression_weights` to point to a compatible PyTorch checkpoint with network weights for the regression model, as described above. Once that is taken care of, launching diffusion training looks nearly identical to before:
```bash
python train.py --outdir rundir --config_file ./config/config.yaml --config_name diffusion --run_id 0
```

Note that the full training pipeline for StormCast is fairly lengthy, requiring about 120 hours on 64 NVIDIA H100 GPUs. However, more lightweight trainings can still produce decent models if the diffusion model is not trained for as long.

Both regression and diffusion training can be distributed easily with data parallelism via `torchrun`. One just needs to ensure the configuration being run has a large enough batch size to be distributed over the number of available GPUs/processes. The example `regression` and `diffusion` configs in `config/config.yaml` just use a batch size of 1 for simplicity, but new configs can be easily added [as described above](#configuration-basics). For example, distributed training over 8 GPUs on one node would look something like:
```bash
torchrun --standalone --nnodes=1 --nproc_per_node=8 train.py --outdir rundir --config_file ./config/config.yaml --config_name <your_distributed_training_config> --run_id 0
```
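
Internally, the example uses Modulus' `DistributedManager` for distributed setup (per the commit notes above). A minimal sketch of that pattern, assuming the `modulus.distributed` API of your installed Modulus version, looks like:

```python
from modulus.distributed import DistributedManager

# Initialize process-group state from the environment variables set by torchrun
DistributedManager.initialize()
dist = DistributedManager()

# Each process sees its own rank and device; batches are sharded across world_size ranks
print(f"rank {dist.rank} of {dist.world_size}, using device {dist.device}")
```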

Once training is complete, you can add a new model entry to `config/registry.json` that points to your checkpoints (the `.pt` files in your training output directory), and you are ready to run inference.
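
The registry can be edited by hand, or programmatically as in the hedged sketch below (the entry name and checkpoint paths are placeholders for your own run):

```python
import json

with open("config/registry.json") as f:
    registry = json.load(f)

# Placeholder name and paths -- point these at the .pt files from your training runs
registry["models"]["stormcast_mymodel"] = {
    "edm_checkpoint_path": "rundir/diffusion_ckpt.pt",
    "edm_config_file": "./config/config.yaml",
    "edm_config_name": "diffusion",
    "regression_checkpoint_path": "rundir/regression_ckpt.pt",
    "regression_config_file": "./config/config.yaml",
    "regression_config_name": "regression",
    "description": "My retrained StormCast model",
}

with open("config/registry.json", "w") as f:
    json.dump(registry, f, indent=4)
```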

### Inference

A simple demonstrative inference script is given in `inference.py`, which loads a pretrained model from a local directory named `stormcast_checkpoints`.
You should update this path to point to the checkpoints saved by the training runs you want to run inference with.
The `inference.py` script will run a 12-hour forecast and save outputs as a `zarr` file along with a few plots saved as `png` files.

To run inference, simply do:

```bash
python inference.py
```
This inference script is configured by the contents of a model registry, which specifies the config files and names to use for the diffusion and regression networks, along with other inference options such as the architecture types and a short description of the model. The `inference.py` script will automatically use the default model registry (`config/registry.json`) and evaluate the `stormcast` example model, but you can configure it to run your desired inference case(s) with the following command-line options:
```bash
--outdir DIR Where to save the results
--registry_file FILE Path to model registry file
--model_name MODEL Name of model to evaluate from the registry
```

We also recommend bringing your checkpoints to [earth2studio](https://github.com/NVIDIA/earth2studio)
for further analysis and visualizations.
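
To quickly inspect the `zarr` forecast output mentioned above, something like the following works (the store path is a placeholder; use whatever `inference.py` wrote under your chosen `--outdir`):

```python
import xarray as xr

# Placeholder path -- substitute the zarr store written by inference.py
ds = xr.open_zarr("rundir/inference_output.zarr")
print(ds)                       # dimensions, coordinates, and data variables
print(list(ds.data_vars)[:10])  # a few of the predicted channels
```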


## Dataset

In this example, StormCast is trained on the [HRRR dataset](https://rapidrefresh.noaa.gov/hrrr/),
conditioned on the [ERA5 dataset](https://www.ecmwf.int/en/forecasts/dataset/ecmwf-reanalysis-v5).
The datapipe in this example is tailored specifically for the domain and problem setting posed in the
[original StormCast preprint](https://arxiv.org/abs/2408.10958), namely a subset of HRRR and ERA5 variables
in a region over the Central US with spatial extent 1536km x 1920km.


A custom dataset object is defined in `utils/data_loader_hrrr_era5.py`, which loads temporally aligned samples from HRRR and ERA5, interpolated to the same grid and normalized appropriately. This data pipeline requires the HRRR and ERA5 data to abide by a specific `zarr` format; for other datasets, you will need to create a custom datapipe. The tables below list the variables used to train StormCast -- in total there are 26 ERA5 variables and 99 HRRR variables (along with 2 static HRRR invariants, the land/water mask and orography).

#### ERA5 Variables
| Parameter | Pressure Levels (hPa) | Height Levels (m) |
|---------------------------------------|---------------------------|--------------------|
| Zonal Wind (u) | 1000, 850, 500, 250 | 10 |
| Meridional Wind (v) | 1000, 850, 500, 250 | 10 |
| Geopotential Height (z) | 1000, 850, 500, 250 | None |
| Temperature (t) | 1000, 850, 500, 250 | 2 |
| Humidity (q) | 1000, 850, 500, 250 | None |
| Total Column of Water Vapour (tcwv) | Integrated | - |
| Mean Sea Level Pressure (mslp) | Surface | - |
| Surface Pressure (sp) | Surface | - |


#### HRRR Variables
| Parameter | Hybrid Model Levels (Index) | Height Levels (m) |
|---------------------------------------|-----------------------------------------------------------|--------------------|
| Zonal Wind (u) | 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 20, 25, 30 | 10 |
| Meridional Wind (v) | 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 20, 25, 30 | 10 |
| Geopotential Height (z) | 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 20, 25, 30 | None |
| Temperature (t) | 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 20, 25, 30 | 2 |
| Humidity (q) | 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 20, 25, 30 | None |
| Pressure (p) | 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 20 | None |
| Max. Composite Radar Reflectivity | - | Integrated |
| Mean Sea Level Pressure (mslp) | - | Surface |
| Orography | - | Surface |
| Land/Water Mask | - | Surface |
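
As a quick sanity check on the counts quoted above, tallying the level lists in the two tables reproduces the 26 ERA5 and 99 HRRR variables:

```python
# ERA5: u, v, t have 4 pressure levels + 1 height level; z, q have 4 pressure levels;
# tcwv, mslp, sp are single-level fields.
era5 = 3 * (4 + 1) + 2 * 4 + 3

# HRRR: u, v, t have 16 hybrid levels + 1 height level; z, q have 16 hybrid levels;
# p has 14 hybrid levels; composite reflectivity and mslp are single-level fields.
hrrr = 3 * (16 + 1) + 2 * 16 + 14 + 2

assert (era5, hrrr) == (26, 99)  # plus 2 static invariants: orography, land/water mask
```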


## Logging

These scripts use Weights & Biases for experiment tracking, which can be enabled by passing the `--log_to_wandb` argument to `train.py`. Free academic accounts can be created at [wandb.ai](https://wandb.ai/).
Once you have an account set up, you can adjust `entity` and `project` in `train.py` to the appropriate names for your `wandb` workspace.
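
For reference, a hedged sketch of a typical `wandb` setup is shown below; the entity, project, and run names are placeholders, and in this example the equivalent values are set inside `train.py`:

```python
import wandb

run = wandb.init(
    entity="my-team",         # placeholder: your wandb entity
    project="stormcast",      # placeholder: your wandb project
    name="regression-run-0",  # placeholder: an optional run name
)
wandb.log({"train/loss": 0.123})  # metrics logged during training show up in the run
run.finish()
```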


## References

- [Kilometer-Scale Convection Allowing Model Emulation using Generative Diffusion Modeling](https://arxiv.org/abs/2408.10958)
- [Elucidating the design space of diffusion-based generative models](https://openreview.net/pdf?id=k7FuTOWMOc7)
- [Score-Based Generative Modeling through Stochastic Differential Equations](https://arxiv.org/pdf/2011.13456.pdf)

75 changes: 75 additions & 0 deletions examples/generative/stormcast/config/config.yaml
@@ -0,0 +1,75 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

base: &base_config
  # data config
  num_data_workers: 4 # number of dataloader worker threads per proc
  location: 'data' # Path to the dataset
  dt: 1
  log_to_wandb: !!bool False
  hrrr_img_size: [512, 640]
  n_hrrr_channels: 127
  n_era5_channels: 26
  invariants: ["lsm", "orog"]
  conus_dataset_name: 'hrrr_v3'
  hrrr_stats: 'stats_v3_2019_2021' # stats files changed to reflect dropped samples from 2017 and half of 2018
  input_channels: 'all' # 'all' or list of channels to condition on
  exclude_channels: ['u35', 'u40', 'v35', 'v40', 't35', 't40', 'q35', 'q40', 'w1', 'w2', 'w3', 'w4', 'w5', 'w6', 'w7', 'w8', 'w9', 'w10', 'w11', 'w13', 'w15', 'w20', 'w25', 'w30', 'w35', 'w40', 'p25', 'p30', 'p35', 'p40', 'z35', 'z40', 'tcwv', 'vil']
  diffusion_channels: "all"
  boundary_padding_pixels: 0 # set this to 0 for no padding, 32 for 32 pixels of padding in each direction etc.
  train_years: [2018, 2019, 2020, 2021]
  valid_years: [2022]

  # hyperparameters
  batch_size: 64
  lr: 4E-4
  total_kimg: 100000
  img_per_tick: 1000
  clip_grad_norm: None
  residual: !!bool True
  previous_step_conditioning: !!bool False
  pure_diffusion: !!bool False
  spatial_pos_embed: !!bool False
  P_mean: -1.2 # default edm value
  use_regression_net: !!bool True
  attn_resolutions: []
  ema_freq_kimg: 10


# ----------------------------------------------------------------------
regression: &regression
  <<: *base_config
  batch_size: 1
  use_regression_net: !!bool False
  loss: 'regression_v2'
  validate_every: 1
  total_kimg: 1
  img_per_tick: 1
# ----------------------------------------------------------------------


# ----------------------------------------------------------------------
diffusion: &diffusion
  <<: *base_config
  batch_size: 1
  use_regression_net: !!bool True
  regression_weights: "stormcast_checkpoints/regression_chkpt.pt"
  regression_config: "regression"
  previous_step_conditioning: !!bool True
  spatial_pos_embed: !!bool True
  loss: 'edm'
  validate_every: 1
# ----------------------------------------------------------------------
13 changes: 13 additions & 0 deletions examples/generative/stormcast/config/registry.json
@@ -0,0 +1,13 @@
{
    "models": {
        "stormcast": {
            "edm_checkpoint_path": "stormcast_checkpoints/diffusion_chkpt.pt",
            "edm_config_file": "./config/config.yaml",
            "edm_config_name": "diffusion",
            "regression_checkpoint_path": "stormcast_checkpoints/regression_chkpt.pt",
            "regression_config_file": "./config/config.yaml",
            "regression_config_name": "regression",
            "description": "Example stormcast inference config"
        }
    }
}