From c752937e23270770fdfe2b2f1a58f6fe7e7a50ec Mon Sep 17 00:00:00 2001 From: Nick Geneva Date: Tue, 9 Apr 2024 17:22:01 -0700 Subject: [PATCH 1/7] Version update --- earth2studio/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/earth2studio/__init__.py b/earth2studio/__init__.py index 0f6e4707..e1db7aba 100644 --- a/earth2studio/__init__.py +++ b/earth2studio/__init__.py @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.1.0a0" +__version__ = "0.1.0" From 5bf3b9abdf53891d0f5265b7a21ce9e822fb90fa Mon Sep 17 00:00:00 2001 From: Nicholas Geneva <5533524+NickGeneva@users.noreply.github.com> Date: Wed, 10 Apr 2024 13:58:45 -0700 Subject: [PATCH 2/7] Fix GFS time limit (#4) * Updating GFS limit * Change log --- CHANGELOG.md | 12 ++++++++---- earth2studio/data/gfs.py | 4 ++-- test/data/test_gfs.py | 2 +- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ccec11a6..7638c424 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,4 @@ + # Changelog All notable changes to this project will be documented in this file. @@ -5,13 +6,10 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [0.1.0] - 2024-04-XX +## [0.2.0] - 2024-xx-xx ### Added -- Initial Release of earth2studio -- SFNO model `sfno_73ch_small`. - ### Changed ### Deprecated @@ -23,3 +21,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Security ### Dependencies + +## [0.1.0] - 2024-04-22 + +### Added + +- Initial Release of earth2studio diff --git a/earth2studio/data/gfs.py b/earth2studio/data/gfs.py index 4faf116a..b9f10a5b 100644 --- a/earth2studio/data/gfs.py +++ b/earth2studio/data/gfs.py @@ -220,9 +220,9 @@ def _validate_time(self, times: list[datetime]) -> None: f"Requested date time {time} needs to be 6 hour interval for GFS" ) # To update search "gfs." at https://noaa-gfs-bdp-pds.s3.amazonaws.com/index.html - if time < datetime(year=2021, month=2, day=20): + if time < datetime(year=2021, month=2, day=17): raise ValueError( - f"Requested date time {time} needs to be after February 20th, 2021 for GFS" + f"Requested date time {time} needs to be after February 17th, 2021 for GFS" ) if not self.available(time): diff --git a/test/data/test_gfs.py b/test/data/test_gfs.py index 0c306242..22b699bc 100644 --- a/test/data/test_gfs.py +++ b/test/data/test_gfs.py @@ -106,7 +106,7 @@ def test_gfs_cache(time, variable, cache): @pytest.mark.parametrize( "time", [ - datetime.datetime(year=2021, month=2, day=19), + datetime.datetime(year=2021, month=2, day=16), datetime.datetime(year=2023, month=1, day=1, hour=13), datetime.datetime.now(), ], From b784a81e0b72f2828ad49ae13f11665103266b28 Mon Sep 17 00:00:00 2001 From: Dallas Foster Date: Wed, 10 Apr 2024 14:46:08 -0700 Subject: [PATCH 3/7] add model perturbation example --- examples/02_ensemble_workflow.py | 2 +- examples/05_model_perturbation_hook.py | 362 +++++++++++++++++++++++++ 2 files changed, 363 insertions(+), 1 deletion(-) create mode 100644 examples/05_model_perturbation_hook.py diff --git a/examples/02_ensemble_workflow.py b/examples/02_ensemble_workflow.py index 8d4d3cd2..09630406 100644 --- a/examples/02_ensemble_workflow.py +++ b/examples/02_ensemble_workflow.py @@ -22,7 +22,7 @@ Simple ensemble inference workflow. This example will demonstrate how to run a simple inference workflow to generate a -simple ensemble forecast using one of the built in models of Earth-2 Inference +ensemble forecast using one of the built in models of Earth-2 Inference Studio. In this example you will learn: diff --git a/examples/05_model_perturbation_hook.py b/examples/05_model_perturbation_hook.py new file mode 100644 index 00000000..73e8978e --- /dev/null +++ b/examples/05_model_perturbation_hook.py @@ -0,0 +1,362 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# %% +""" +Running Ensemble Inference +========================== + +Extended ensemble inference workflow. + +This example will demonstrate how to run a an ensemble inference workflow to generate a +perturbed ensemble forecast. This perturbation is done by injecting code into the model +front and rear hooks. These hooks are applied to the tensor data before/after the model forward call. + +This example also illustrates how you can subselect data for IO. In this example we will only output +two variables: total column water vapour (tcwv) and 500 hPa geopotential (z500). To run this make +sure that the model selected predicts these variables are change appropriately. + +In this example you will learn: + +- How to instantiate a built in prognostic model +- Creating a data source and IO object +- Changing the model forward/rear hooks +- Choose a subselection of coordinates to save to an IO object. +- Post-processing results +""" + +# %% +# Creating an Ensemble Workflow +# ----------------------------------- +# +# To start lets begin with creating a ensemble workflow to use. We encourage +# users to explore and experiment with their own custom workflows that borrow ideas from +# built in workflows inside :py:obj:`earth2studio.run` or the examples. +# +# Creating our own generalizable ensemble workflow is easy when we rely on the component +# interfaces defined in Earth2Studio (use dependency injection). Here we create a run +# method that accepts the following: +# +# - time: Input list of datetimes / strings to run inference for +# - nsteps: Number of forecast steps to predict +# - nensemble: Number of ensembles to run for +# - prognostic: Our initialized prognostic model +# - data: Initialized data source to fetch initial conditions from +# - io: io store that data is written to. +# - output_coords: CoordSystem of output coordinates that should be saved. Should be +# a proper subset of model output coordinates. + +# %% +from collections import OrderedDict +from datetime import datetime +from dotenv import load_dotenv + +load_dotenv() # TODO: make common example prep function + +import numpy as np +import torch +from loguru import logger +from tqdm import tqdm +import os +from earth2studio.data import DataSource, fetch_data +from earth2studio.io import IOBackend +from earth2studio.models.px import PrognosticModel +from earth2studio.utils.coords import CoordSystem, map_coords, extract_coords +from earth2studio.utils.time import to_time_array + +logger.remove() +logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True) + + +def run_ensemble( + time: list[str] | list[datetime] | list[np.datetime64], + nsteps: int, + nensemble: int, + prognostic: PrognosticModel, + data: DataSource, + io: IOBackend, + output_coords: CoordSystem = OrderedDict({}), +) -> IOBackend: + """Ensemble workflow + + Parameters + ---------- + time : list[str] | list[datetime] | list[np.datetime64] + List of string, datetimes or np.datetime64 + nsteps : int + Number of forecast steps + nensemble : int + Number of ensemble members to run inference for. + prognostic : PrognosticModel + Prognostic models + data : DataSource + Data source + io : IOBackend + IO object + + Returns + ------- + IOBackend + Output IO object + """ + logger.info("Running ensemble inference!") + + # Load model onto the device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + logger.info(f"Inference device: {device}") + prognostic = prognostic.to(device) + + # Fetch data from data source and load onto device + time = to_time_array(time) + x, coords = fetch_data( + source=data, + time=time, + lead_time=prognostic.input_coords["lead_time"], + variable=prognostic.input_coords["variable"], + device=device, + ) + logger.success(f"Fetched data from {data.__class__.__name__}") + + # Expand x, coords for ensemble + x = x.unsqueeze(0).repeat(nensemble, *([1] * x.ndim)) + coords = {"ensemble": np.arange(nensemble)} | coords + + # Set up IO backend with information from output_coords (if applicable). + total_coords = coords.copy() + total_coords["lead_time"] = np.asarray( + [prognostic.output_coords["lead_time"] * i for i in range(nsteps + 1)] + ).flatten() + for key, value in total_coords.items(): + total_coords[key] = output_coords.get(key, value) + + variables_to_save = total_coords.pop("variable") + io.add_array(total_coords, variables_to_save) + + # Map lat and lon if needed + x, coords = map_coords(x, coords, prognostic.input_coords) + + # Create prognostic iterator + model = prognostic.create_iterator(x, coords) + + logger.info("Inference starting!") + with tqdm(total=nsteps + 1, desc="Running inference") as pbar: + for step, (x, coords) in enumerate(model): + # Subselect domain/variables as indicated in output_coords + x, coords = map_coords(x, coords, output_coords) + io.write(*extract_coords(x, coords)) + pbar.update(1) + if step == nsteps: + break + + logger.success("Inference complete") + return io + + +# %% +# Set Up +# ------ +# With the ensemble workflow defined, we now need to create the indivdual components. +# +# We need the following: +# +# - Prognostic Model: Use the built in FourCastNet model :py:class:`earth2studio.models.px.FCN`. +# - Datasource: Pull data from the GFS data api :py:class:`earth2studio.data.GFS`. +# - IO Backend: Lets save the outputs into a Zarr store :py:class:`earth2studio.io.ZarrBackend`. +# +# We will first run the ensemble workflow using an unmodified function, that is a model that has the +# default (identity) forward and rear hooks. Then we will define new hooks for the model and rerun the +# inference request. +# %% +import numpy as np +import torch + +from earth2studio.models.px import DLWP +from earth2studio.data import GFS +from earth2studio.io import ZarrBackend + +# Load the default model package which downloads the check point from NGC +package = DLWP.load_default_package() +model = DLWP.load_model(package) + +# Create the data source +data = GFS() + +# Create the IO handler, store in memory +chunks = {"ensemble": 1, "time": 1} +io_unperturbed = ZarrBackend(file_name="outputs/ensemble.zarr", chunks=chunks) + + +# %% +# Execute the Workflow +# -------------------- +# With all componments intialized, running the workflow is a single line of Python code. +# Workflow will return the provided IO object back to the user, which can be used to +# then post process. Some have additional APIs that can be handy for post-processing or +# saving to file. Check the API docs for more information. +# +# %% + +nsteps = 4 * 12 +nensemble = 16 +forecast_date = "2024-01-30" +output_coords = { + "lat": np.arange(25.0, 60.0, 0.25), + "lon": np.arange(230.0, 300.0, 0.25), + "variable": np.array(["tcwv", "z500"]), +} + +# Forst run the unperturbed model forcast +# io_unperturbed = run_ensemble([forecast_date], nsteps, nensemble, model, data, io_unperturbed, output_coords=output_coords) + +# Introduce slight model perturbation +# front_hook / rear_hook map (x, coords) -> (x, coords) +model.front_hook = lambda x, coords: ( + x + - 0.05 + * x.var(dim=0) + * (x - model.center.unsqueeze(-1)) + / (model.scale.unsqueeze(-1)) ** 2 + + 0.1 * (x - x.mean(dim=0)), + coords, +) +# Also could use model.rear_hook = ... + +io_perturbed = ZarrBackend( + file_name="outputs/ensemble_model_perturbation.zarr", chunks=chunks +) +# io_perturbed = run_ensemble([forecast_date], nsteps, nensemble, model, data, io_perturbed, output_coords=output_coords) + +# %% +# Post Processing +# --------------- +# The last step is to post process our results. Cartopy is a greate library for plotting +# fields on projects of a sphere. Here we plot and compare the ensemble mean and standard +# deviation from using a unperturbed/perturbed model. +# +# Notice that the Zarr IO function has additional APIs to interact with the stored data. + +#%% +import matplotlib.pyplot as plt +import cartopy.crs as ccrs +import matplotlib.animation as animation + + +plt.close("all") +fig = plt.figure(figsize=(20, 10), tight_layout=True) +ax0 = fig.add_subplot(2, 2, 1, projection=ccrs.PlateCarree()) +ax1 = fig.add_subplot(2, 2, 2, projection=ccrs.PlateCarree()) +ax2 = fig.add_subplot(2, 2, 3, projection=ccrs.PlateCarree()) +ax3 = fig.add_subplot(2, 2, 4, projection=ccrs.PlateCarree()) + +levels_unperturbed = np.linspace(0, io_unperturbed["tcwv"][:].max()) +levels_perturbed = np.linspace(0, io_perturbed["tcwv"][:].max()) + +std_levels_perturbed = np.linspace(0, io_perturbed["tcwv"][:].std(axis=0).max()) + + +def update(frame): + """This function updates the frame with a new lead time for animation.""" + ax0.clear() + ax1.clear() + ax2.clear() + ax3.clear() + + ## Update unperturbed image + im0 = ax0.contourf( + io_unperturbed["lon"][:], + io_unperturbed["lat"][:], + io_unperturbed["tcwv"][:, 0, frame].mean(axis=0), + transform=ccrs.PlateCarree(), + cmap="Blues", + levels=levels_unperturbed, + ) + ax0.coastlines() + ax0.gridlines() + + im1 = ax1.contourf( + io_unperturbed["lon"][:], + io_unperturbed["lat"][:], + io_unperturbed["tcwv"][:, 0, frame].std(axis=0), + transform=ccrs.PlateCarree(), + cmap="RdPu_r", + levels=std_levels_perturbed, + ) + ax1.coastlines() + ax1.gridlines() + + im2 = ax2.contourf( + io_perturbed["lon"][:], + io_perturbed["lat"][:], + io_perturbed["tcwv"][:, 0, frame].mean(axis=0), + transform=ccrs.PlateCarree(), + cmap="Blues", + levels=levels_perturbed, + ) + ax2.coastlines() + ax2.gridlines() + + im3 = ax3.contourf( + io_perturbed["lon"][:], + io_perturbed["lat"][:], + io_perturbed["tcwv"][:, 0, frame].std(axis=0), + transform=ccrs.PlateCarree(), + cmap="RdPu_r", + levels=std_levels_perturbed, + ) + ax3.coastlines() + ax3.gridlines() + + for i in range(16): + ax0.contour( + io_unperturbed["lon"][:], + io_unperturbed["lat"][:], + io_unperturbed["z500"][i, 0, frame] / 100.0, + transform=ccrs.PlateCarree(), + levels=np.arange(485, 580, 15), + colors="black", + linestyle="dashed", + ) + + ax2.contour( + io_perturbed["lon"][:], + io_perturbed["lat"][:], + io_perturbed["z500"][i, 0, frame] / 100.0, + transform=ccrs.PlateCarree(), + levels=np.arange(485, 580, 15), + colors="black", + linestyle="dashed", + ) + plt.suptitle( + f'Forecast Starting on {forecast_date} - Lead Time - {io_perturbed["lead_time"][frame]}' + ) + + if frame == 0: + ax0.set_title("Unperturbed Ensemble Mean - tcwv + z500 countors") + ax1.set_title("Unperturbed Ensemble Std - tcwv") + ax2.set_title("Perturbed Ensemble Mean - tcwv + z500 contours") + ax2.set_title("Perturbed Ensemble Std - tcwv") + + plt.colorbar(im0, ax=ax0, shrink=0.5, label="kg m^-2") + plt.colorbar(im1, ax=ax1, shrink=0.5, label="kg m^-2") + plt.colorbar(im2, ax=ax2, shrink=0.5, label="kg m^-2") + plt.colorbar(im3, ax=ax3, shrink=0.5, label="kg m^-2") + + +update(0) +ani = animation.FuncAnimation( + fig=fig, func=update, frames=range(1, nsteps), cache_frame_data=False +) +ani.save(f"outputs/model_perturbation_{forecast_date}.gif", dpi=300) From c3564fe2f8965e5013b0848674e003e77cae2177 Mon Sep 17 00:00:00 2001 From: Dallas Foster Date: Wed, 10 Apr 2024 14:48:39 -0700 Subject: [PATCH 4/7] update changelog --- CHANGELOG.md | 1 + examples/05_model_perturbation_hook.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b414133..13793c15 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Initial Release of earth2studio +- Added model perturbation example. ### Changed diff --git a/examples/05_model_perturbation_hook.py b/examples/05_model_perturbation_hook.py index 73e8978e..0e34d360 100644 --- a/examples/05_model_perturbation_hook.py +++ b/examples/05_model_perturbation_hook.py @@ -70,7 +70,6 @@ import torch from loguru import logger from tqdm import tqdm -import os from earth2studio.data import DataSource, fetch_data from earth2studio.io import IOBackend from earth2studio.models.px import PrognosticModel From caf864494ad0b7ee59a5bc198f2db40470feee46 Mon Sep 17 00:00:00 2001 From: Nicholas Geneva <5533524+NickGeneva@users.noreply.github.com> Date: Thu, 11 Apr 2024 13:55:03 -0700 Subject: [PATCH 5/7] GFS fix (#6) * GFS fix valid date check --- earth2studio/data/gfs.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/earth2studio/data/gfs.py b/earth2studio/data/gfs.py index b9f10a5b..1690ce88 100644 --- a/earth2studio/data/gfs.py +++ b/earth2studio/data/gfs.py @@ -206,8 +206,9 @@ def modifier(x: np.array) -> np.array: return gfsda - def _validate_time(self, times: list[datetime]) -> None: - """Verify if date time is valid for GFS + @classmethod + def _validate_time(cls, times: list[datetime]) -> None: + """Verify if date time is valid for GFS based on offline knowledge Parameters ---------- @@ -215,18 +216,19 @@ def _validate_time(self, times: list[datetime]) -> None: list of date times to fetch data """ for time in times: - if not time.hour % 6 == 0: + if not time.hour % 6 == 0 or not time.minute == 0 or not time.second == 0: raise ValueError( f"Requested date time {time} needs to be 6 hour interval for GFS" ) # To update search "gfs." at https://noaa-gfs-bdp-pds.s3.amazonaws.com/index.html + # They are slowly adding more data if time < datetime(year=2021, month=2, day=17): raise ValueError( f"Requested date time {time} needs to be after February 17th, 2021 for GFS" ) - if not self.available(time): - raise ValueError(f"Requested date time {time} not available in GFS") + # if not self.available(time): + # raise ValueError(f"Requested date time {time} not available in GFS") def _fetch_index(self, time: datetime) -> dict[str, tuple[int, int]]: """Fetch GFS atmospheric index file @@ -331,6 +333,12 @@ def available( _ds = np.timedelta64(1, "s") time = datetime.utcfromtimestamp((time - _unix) / _ds) + # Offline checks + try: + cls._validate_time([time]) + except ValueError: + return False + s3 = boto3.client( "s3", config=botocore.config.Config(signature_version=UNSIGNED) ) From f57c43ca5f18cca1fdc2f8fb4a2dfeb98e3c233a Mon Sep 17 00:00:00 2001 From: Dallas Foster Date: Thu, 11 Apr 2024 14:57:10 -0700 Subject: [PATCH 6/7] fix typos --- examples/05_model_perturbation_hook.py | 62 +++++++++++++++++--------- 1 file changed, 42 insertions(+), 20 deletions(-) diff --git a/examples/05_model_perturbation_hook.py b/examples/05_model_perturbation_hook.py index 0e34d360..0ff8d55e 100644 --- a/examples/05_model_perturbation_hook.py +++ b/examples/05_model_perturbation_hook.py @@ -16,10 +16,10 @@ # %% """ -Running Ensemble Inference +Model Hook Injection: Perturbation ========================== -Extended ensemble inference workflow. +Adding model noise by using custom hooks. This example will demonstrate how to run a an ensemble inference workflow to generate a perturbed ensemble forecast. This perturbation is done by injecting code into the model @@ -42,7 +42,7 @@ # Creating an Ensemble Workflow # ----------------------------------- # -# To start lets begin with creating a ensemble workflow to use. We encourage +# To start lets begin with creating an ensemble workflow to use. We encourage # users to explore and experiment with their own custom workflows that borrow ideas from # built in workflows inside :py:obj:`earth2studio.run` or the examples. # @@ -70,6 +70,7 @@ import torch from loguru import logger from tqdm import tqdm + from earth2studio.data import DataSource, fetch_data from earth2studio.io import IOBackend from earth2studio.models.px import PrognosticModel @@ -171,7 +172,7 @@ def run_ensemble( # # We need the following: # -# - Prognostic Model: Use the built in FourCastNet model :py:class:`earth2studio.models.px.FCN`. +# - Prognostic Model: Use the built in FourCastNet model :py:class:`earth2studio.models.px.DLWP`. # - Datasource: Pull data from the GFS data api :py:class:`earth2studio.data.GFS`. # - IO Backend: Lets save the outputs into a Zarr store :py:class:`earth2studio.io.ZarrBackend`. # @@ -251,8 +252,14 @@ def run_ensemble( import matplotlib.pyplot as plt import cartopy.crs as ccrs import matplotlib.animation as animation +from matplotlib.colors import LogNorm +levels_unperturbed = np.linspace(0, io_unperturbed["tcwv"][:].max()) +levels_perturbed = np.linspace(0, io_perturbed["tcwv"][:].max()) + +std_levels_perturbed = np.linspace(0, io_perturbed["tcwv"][:].std(axis=0).max()) + plt.close("all") fig = plt.figure(figsize=(20, 10), tight_layout=True) ax0 = fig.add_subplot(2, 2, 1, projection=ccrs.PlateCarree()) @@ -260,11 +267,6 @@ def run_ensemble( ax2 = fig.add_subplot(2, 2, 3, projection=ccrs.PlateCarree()) ax3 = fig.add_subplot(2, 2, 4, projection=ccrs.PlateCarree()) -levels_unperturbed = np.linspace(0, io_unperturbed["tcwv"][:].max()) -levels_perturbed = np.linspace(0, io_perturbed["tcwv"][:].max()) - -std_levels_perturbed = np.linspace(0, io_perturbed["tcwv"][:].std(axis=0).max()) - def update(frame): """This function updates the frame with a new lead time for animation.""" @@ -290,8 +292,9 @@ def update(frame): io_unperturbed["lat"][:], io_unperturbed["tcwv"][:, 0, frame].std(axis=0), transform=ccrs.PlateCarree(), - cmap="RdPu_r", + cmap="RdPu", levels=std_levels_perturbed, + norm=LogNorm(vmin=1e-1, vmax=std_levels_perturbed[-1]), ) ax1.coastlines() ax1.gridlines() @@ -312,8 +315,9 @@ def update(frame): io_perturbed["lat"][:], io_perturbed["tcwv"][:, 0, frame].std(axis=0), transform=ccrs.PlateCarree(), - cmap="RdPu_r", + cmap="RdPu", levels=std_levels_perturbed, + norm=LogNorm(vmin=1e-1, vmax=std_levels_perturbed[-1]), ) ax3.coastlines() ax3.gridlines() @@ -348,14 +352,32 @@ def update(frame): ax2.set_title("Perturbed Ensemble Mean - tcwv + z500 contours") ax2.set_title("Perturbed Ensemble Std - tcwv") - plt.colorbar(im0, ax=ax0, shrink=0.5, label="kg m^-2") - plt.colorbar(im1, ax=ax1, shrink=0.5, label="kg m^-2") - plt.colorbar(im2, ax=ax2, shrink=0.5, label="kg m^-2") - plt.colorbar(im3, ax=ax3, shrink=0.5, label="kg m^-2") + plt.colorbar( + im0, ax=ax0, shrink=0.75, pad=0.04, label="kg m^-2", format="%2.1f" + ) + plt.colorbar( + im1, ax=ax1, shrink=0.75, pad=0.04, label="kg m^-2", format="%1.2e" + ) + plt.colorbar( + im2, ax=ax2, shrink=0.75, pad=0.04, label="kg m^-2", format="%2.1f" + ) + plt.colorbar( + im3, ax=ax3, shrink=0.75, pad=0.04, label="kg m^-2", format="%1.2e" + ) -update(0) -ani = animation.FuncAnimation( - fig=fig, func=update, frames=range(1, nsteps), cache_frame_data=False -) -ani.save(f"outputs/model_perturbation_{forecast_date}.gif", dpi=300) +# Uncomment this for animation +# update(0) +# ani = animation.FuncAnimation( +# fig=fig, func=update, frames=range(1, nsteps), cache_frame_data=False +# ) +# ani.save(f"outputs/model_perturbation_{forecast_date}.gif", dpi=300) + +# Here we plot a handful of images +for lt in [0, 10, 20, 30, 40]: + update(lt) + plt.savefig( + f"outputs/model_perturbation_{forecast_date}_leadtime_{lt}.png", + dpi=300, + bbox_inches="tight", + ) From bb8a1c2e15d62066156f91254681c6409e779b64 Mon Sep 17 00:00:00 2001 From: Dallas Foster Date: Thu, 11 Apr 2024 15:20:15 -0700 Subject: [PATCH 7/7] fix typos --- examples/05_model_perturbation_hook.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/examples/05_model_perturbation_hook.py b/examples/05_model_perturbation_hook.py index 0ff8d55e..060936cd 100644 --- a/examples/05_model_perturbation_hook.py +++ b/examples/05_model_perturbation_hook.py @@ -219,7 +219,15 @@ def run_ensemble( } # Forst run the unperturbed model forcast -# io_unperturbed = run_ensemble([forecast_date], nsteps, nensemble, model, data, io_unperturbed, output_coords=output_coords) +io_unperturbed = run_ensemble( + [forecast_date], + nsteps, + nensemble, + model, + data, + io_unperturbed, + output_coords=output_coords, +) # Introduce slight model perturbation # front_hook / rear_hook map (x, coords) -> (x, coords) @@ -237,7 +245,15 @@ def run_ensemble( io_perturbed = ZarrBackend( file_name="outputs/ensemble_model_perturbation.zarr", chunks=chunks ) -# io_perturbed = run_ensemble([forecast_date], nsteps, nensemble, model, data, io_perturbed, output_coords=output_coords) +io_perturbed = run_ensemble( + [forecast_date], + nsteps, + nensemble, + model, + data, + io_perturbed, + output_coords=output_coords, +) # %% # Post Processing