diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5f29b953..23c6032b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

+- Add StormCast model to prognostic models
- Interpolation between arbitrary lat-lon grids
- Added hybrid level support to HRRR data source
diff --git a/docs/modules/models.rst b/docs/modules/models.rst
index bab4a8fc..638f9e51 100644
--- a/docs/modules/models.rst
+++ b/docs/modules/models.rst
@@ -34,6 +34,7 @@ Thus are typically used to generate forecast predictions.
   models.px.Pangu3
   models.px.Persistence
   models.px.SFNO
+   models.px.StormCast

.. _earth2studio.models.dx:
diff --git a/earth2studio/data/rand.py b/earth2studio/data/rand.py
index 5feacafb..6e3730fd 100644
--- a/earth2studio/data/rand.py
+++ b/earth2studio/data/rand.py
@@ -39,10 +39,18 @@ def __init__(
    ):
        self.domain_coords = domain_coords

+        # Check for regular vs. curvilinear coordinates
+        _, value = list(self.domain_coords.items()).pop()
+        value = np.array(value)
+        self.curv = len(value.shape) > 1
+        if self.curv:
+            self.domain_coord_shape = value.shape
+
    def __call__(
        self,
        time: datetime | list[datetime] | TimeArray,
        variable: str | list[str] | VariableArray,
+        lead_time: np.ndarray | None = None,
    ) -> xr.DataArray:
        """Retrieve random Gaussian data.
@@ -63,12 +71,24 @@
        shape = [len(time), len(variable)]
        coords = {"time": time, "variable": variable}

-        for key, value in self.domain_coords.items():
-            shape.append(len(value))
-            coords[key] = value

-        da = xr.DataArray(
-            data=np.random.randn(*shape), dims=list(coords), coords=coords
-        )
+        if self.curv:
+            shape.extend(self.domain_coord_shape)
+            dims = ["time", "variable", "y", "x"]
+            coords = coords | {
+                "lat": (("y", "x"), self.domain_coords["lat"]),
+                "lon": (("y", "x"), self.domain_coords["lon"]),
+            }
+            da = xr.DataArray(data=np.random.randn(*shape), dims=dims, coords=coords)
+
+        else:
+            for key, value in self.domain_coords.items():
+                shape.append(len(value))
+                coords[key] = value
+
+            da = xr.DataArray(
+                data=np.random.randn(*shape), dims=list(coords), coords=coords
+            )

        return da
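For review context, a minimal sketch of the new curvilinear path in `Random` (grid values here are illustrative):

```python
from collections import OrderedDict
from datetime import datetime

import numpy as np

from earth2studio.data import Random

# 2D lat/lon arrays mark the domain as curvilinear; data then comes back
# on internal ("y", "x") dims with lat/lon attached as 2D coordinates.
lat, lon = np.meshgrid(
    np.linspace(60, 20, 256), np.linspace(230, 300, 512), indexing="ij"
)
source = Random(OrderedDict({"lat": lat, "lon": lon}))

da = source(datetime(2024, 1, 1), ["t2m"])
assert da.dims == ("time", "variable", "y", "x")
assert da["lat"].shape == (256, 512)
```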
diff --git a/earth2studio/data/utils.py b/earth2studio/data/utils.py
index da9edcd0..2123552f 100644
--- a/earth2studio/data/utils.py
+++ b/earth2studio/data/utils.py
@@ -29,6 +29,7 @@
from loguru import logger

from earth2studio.data.base import DataSource
+from earth2studio.utils.interp import LatLonInterpolation
from earth2studio.utils.time import (
    leadtimearray_to_timedelta,
    timearray_to_datetime,
@@ -43,8 +44,12 @@ def fetch_data(
    variable: VariableArray,
    lead_time: LeadTimeArray = np.array([np.timedelta64(0, "h")]),
    device: torch.device = "cpu",
+    interp_to: CoordSystem | None = None,
+    interp_method: str = "nearest",
) -> tuple[torch.Tensor, CoordSystem]:
    """Utility function to fetch data for models and load data on the target device.
+    If desired, xarray interpolation/regridding in the spatial domain can be used
+    by passing a target coordinate system via the optional `interp_to` argument.

    Parameters
    ----------
@@ -59,6 +64,11 @@
        np.array(np.timedelta64(0, "h"))
    device : torch.device, optional
        Torch device to load data tensor to, by default "cpu"
+    interp_to : CoordSystem, optional
+        If provided, the fetched data will be interpolated to the coordinates
+        specified by lat/lon arrays in this CoordSystem
+    interp_method : str, optional
+        Interpolation method to use with xarray (by default 'nearest')

    Returns
    -------
@@ -75,22 +85,37 @@
        da0 = da0.assign_coords(time=time)
        da.append(da0)

-    return prep_data_array(xr.concat(da, "lead_time"), device=device)
+    return prep_data_array(
+        xr.concat(da, "lead_time"),
+        device=device,
+        interp_to=interp_to,
+        interp_method=interp_method,
+    )


def prep_data_array(
    da: xr.DataArray,
    device: torch.device = "cpu",
+    interp_to: CoordSystem | None = None,
+    interp_method: str = "nearest",
) -> tuple[torch.Tensor, CoordSystem]:
    """Prepares a data array from a data source for inference workflows by
    converting the data array to a torch tensor and the coordinate system to an
    OrderedDict.

+    If desired, xarray interpolation/regridding in the spatial domain can be used
+    by passing a target coordinate system via the optional `interp_to` argument.
+
    Parameters
    ----------
    da : xr.DataArray
        Input data array
    device : torch.device, optional
        Torch device to load data tensor to, by default "cpu"
+    interp_to : CoordSystem, optional
+        If provided, the fetched data will be interpolated to the coordinates
+        specified by lat/lon arrays in this CoordSystem
+    interp_method : str, optional
+        Interpolation method to use with xarray (by default 'nearest')

    Returns
    -------
@@ -98,11 +123,65 @@
        Tuple containing output tensor and coordinate OrderedDict
    """

-    out = torch.Tensor(da.values).to(device)
-
+    # Initialize the output CoordSystem
    out_coords = OrderedDict()
    for dim in da.coords.dims:
-        out_coords[dim] = np.array(da.coords[dim])
+        if dim in ["time", "lead_time", "variable"]:
+            out_coords[dim] = np.array(da.coords[dim])
+
+    # Fetch data and regrid if necessary
+    if interp_to is not None:
+        if len(interp_to["lat"].shape) != len(interp_to["lon"].shape):
+            raise ValueError(
+                "Discrepancy in interpolation coordinates: latitude has different number of dims than longitude"
+            )
+
+        if "lat" not in da.dims:
+            # Data source uses curvilinear coordinates
+            if interp_method != "linear":
+                raise ValueError(
+                    "fetch_data does not support interpolation methods other than linear when data source has a curvilinear grid"
+                )
+            interp = LatLonInterpolation(
+                lat_in=da["lat"].values,
+                lon_in=da["lon"].values,
+                lat_out=interp_to["lat"],
+                lon_out=interp_to["lon"],
+            ).to(device)
+            data = torch.Tensor(da.values).to(device)
+            out = interp(data)
+
+        else:
+            if len(interp_to["lat"].shape) > 1 or len(interp_to["lon"].shape) > 1:
+                # Target grid uses curvilinear coordinates: define internal dims y, x
+                target_lat = xr.DataArray(interp_to["lat"], dims=["y", "x"])
+                target_lon = xr.DataArray(interp_to["lon"], dims=["y", "x"])
+            else:
+                target_lat = xr.DataArray(interp_to["lat"], dims=["lat"])
+                target_lon = xr.DataArray(interp_to["lon"], dims=["lon"])
+
+            da = da.interp(
+                lat=target_lat,
+                lon=target_lon,
+                method=interp_method,
+            )
+
+            out = torch.Tensor(da.values).to(device)
+
+        out_coords["lat"] = interp_to["lat"]
+        out_coords["lon"] = interp_to["lon"]
+
+    else:
+        out = torch.Tensor(da.values).to(device)
+        if "lat" in da.coords and "lat" not in da.coords.dims:
+            # Curvilinear grid case: lat/lon coords are 2D arrays, not in dims
+            out_coords["lat"] = da.coords["lat"].values
+            out_coords["lon"] = da.coords["lon"].values
+        else:
+            for dim in da.coords.dims:
+                if dim not in ["time", "lead_time", "variable"]:
+                    out_coords[dim] = np.array(da.coords[dim])

    return out, out_coords
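A short usage sketch of the new `interp_to` path, mirroring the tests added below (the source and target grids are illustrative):

```python
from collections import OrderedDict

import numpy as np

from earth2studio.data import Random, fetch_data

# Source: regular 0.25-degree global grid (Random stands in for a real source)
source = Random(
    OrderedDict(
        {
            "lat": np.linspace(90, -90, 721, endpoint=True),
            "lon": np.linspace(0, 360, 1440),
        }
    )
)

# Target: a regional box; 2D (curvilinear) lat/lon arrays work the same way
target = OrderedDict(
    {"lat": np.linspace(60, 20, num=256), "lon": np.linspace(130, 60, num=512)}
)

x, coords = fetch_data(
    source,
    time=np.array([np.datetime64("1993-04-05T00:00")]),
    variable=np.array(["t2m"]),
    interp_to=target,  # regrid onto the target box
    interp_method="linear",  # or "nearest"
)
assert coords["lat"].shape == (256,) and coords["lon"].shape == (512,)
```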
diff --git a/earth2studio/lexicon/hrrr.py b/earth2studio/lexicon/hrrr.py
index fc28a1dd..a27338d6 100644
--- a/earth2studio/lexicon/hrrr.py
+++ b/earth2studio/lexicon/hrrr.py
@@ -92,16 +92,17 @@ def build_vocab() -> dict[str, str]:
            975,
            1000,
        ]
-        prs_names = ["UGRD", "VGRD", "HGT", "TMP", "RH", "SPFH"]
-        e2s_id = ["u", "v", "z", "t", "r", "q"]
+
+        prs_names = ["UGRD", "VGRD", "HGT", "TMP", "RH", "SPFH", "HGT"]
+        e2s_id = ["u", "v", "z", "t", "r", "q", "Z"]
        prs_variables = {}
        for (id, variable) in zip(e2s_id, prs_names):
            for level in prs_levels:
                prs_variables[f"{id}{level:d}"] = f"prs::anl::{level} mb::{variable}"

        hybrid_levels = list(range(1, 51))
-        hybrid_names = ["UGRD", "VGRD", "HGT", "TMP", "SPFH", "PRES"]
-        e2s_id = ["u", "v", "z", "t", "q", "p"]
+        hybrid_names = ["UGRD", "VGRD", "HGT", "TMP", "SPFH", "PRES", "HGT"]
+        e2s_id = ["u", "v", "z", "t", "q", "p", "Z"]
        hybrid_variables = {}
        for (id, variable) in zip(e2s_id, hybrid_names):
            for level in hybrid_levels:
@@ -187,8 +188,8 @@ def build_vocab() -> dict[str, str]:
                prs_variables[f"{id}{level:d}"] = f"sfc::fcst::{level} mb::{variable}"

        hybrid_levels = list(range(1, 51))
-        hybrid_names = ["UGRD", "VGRD", "HGT", "TMP", "SPFH", "PRES"]
-        e2s_id = ["u", "v", "z", "t", "q", "p"]
+        hybrid_names = ["UGRD", "VGRD", "HGT", "TMP", "SPFH", "PRES", "HGT"]
+        e2s_id = ["u", "v", "z", "t", "q", "p", "Z"]
        hybrid_variables = {}
        for (id, variable) in zip(e2s_id, hybrid_names):
            for level in hybrid_levels:
diff --git a/earth2studio/models/batch.py b/earth2studio/models/batch.py
index 4ad14ce3..346accf0 100644
--- a/earth2studio/models/batch.py
+++ b/earth2studio/models/batch.py
@@ -164,7 +164,9 @@ def _batch_wrap(self, func: Callable) -> Callable:
        # TODO: Better typing for model object
        @functools.wraps(func)
        def _wrapper(
-            model: Any, x: torch.Tensor, coords: CoordSystem
+            model: Any,
+            x: torch.Tensor,
+            coords: CoordSystem,
        ) -> tuple[torch.Tensor, CoordSystem]:

            x, flatten_coords, batched_coords, batched_shape = self._compress_batch(
diff --git a/earth2studio/models/px/__init__.py b/earth2studio/models/px/__init__.py
index 88545f6c..cdca211a 100644
--- a/earth2studio/models/px/__init__.py
+++ b/earth2studio/models/px/__init__.py
@@ -22,3 +22,4 @@
from earth2studio.models.px.pangu import Pangu3, Pangu6, Pangu24
from earth2studio.models.px.persistence import Persistence
from earth2studio.models.px.sfno import SFNO
+from earth2studio.models.px.stormcast import StormCast
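For reviewers, a comment-only sketch of the naming these lexicon changes enable, assuming the lexicon's usual `VOCAB` mapping (the hybrid-level right-hand strings are elided from this hunk, so only the keys are shown):

```python
from earth2studio.lexicon import HRRRLexicon

# Pressure-level keys are generated as f"{id}{level:d}"; the new capitalized
# "Z" alias resolves to the same GRIB variable (HGT) as the existing "z":
assert HRRRLexicon.VOCAB["z500"] == "prs::anl::500 mb::HGT"
assert HRRRLexicon.VOCAB["Z500"] == "prs::anl::500 mb::HGT"

# Hybrid-level keys cover model levels 1-50 with an "hl" suffix, e.g.
# "u1hl", "t20hl", "p30hl" and the new "Z10hl" (see test_hrrr_lexicon.py below).
```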
diff --git a/earth2studio/models/px/stormcast.py b/earth2studio/models/px/stormcast.py
new file mode 100644
index 00000000..430b009d
--- /dev/null
+++ b/earth2studio/models/px/stormcast.py
@@ -0,0 +1,407 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from collections import OrderedDict
+from collections.abc import Generator, Iterator
+from itertools import product
+
+import modulus
+import numpy as np
+import torch
+import xarray as xr
+from modulus.models import Module
+from modulus.utils.generative import deterministic_sampler
+from omegaconf import OmegaConf
+from packaging.version import Version
+
+from earth2studio.data import DataSource, fetch_data
+from earth2studio.models.auto import AutoModelMixin, Package
+from earth2studio.models.batch import batch_coords, batch_func
+from earth2studio.models.px.base import PrognosticModel
+from earth2studio.models.px.utils import PrognosticMixin
+from earth2studio.utils import (
+    handshake_coords,
+    handshake_dim,
+)
+from earth2studio.utils.type import CoordSystem
+
+# Variables used in StormCastV1 paper
+VARIABLES = (
+    ["u10m", "v10m", "t2m", "mslp"]
+    + [
+        var + str(level)
+        for var, level in product(
+            ["u", "v", "t", "q", "Z", "p"],
+            map(
+                lambda x: str(x) + "hl",
+                [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 20, 25, 30],
+            ),
+        )
+        if not ((var == "p") and (int(level.replace("hl", "")) > 20))
+    ]
+    + [
+        "refc",
+    ]
+)
+
+CONDITIONING_VARIABLES = ["u10m", "v10m", "t2m", "tcwv", "mslp", "sp"] + [
+    var + str(level)
+    for var, level in product(["u", "v", "z", "t", "q"], [1000, 850, 500, 250])
+]
+
+INVARIANTS = ["lsm", "orography"]
+
+# Extent of domain in StormCastV1 paper (HRRR Lambert projection indices)
+X_START, X_END = 579, 1219
+Y_START, Y_END = 273, 785
+
+
+class StormCast(torch.nn.Module, AutoModelMixin, PrognosticMixin):
+    """StormCast generative convection-allowing model for regional forecasts.
+    Consists of two core models: a regression model and a diffusion model.
+    This class implements StormCastV1, the model released in:
+    https://arxiv.org/abs/2408.10958
+
+    Model time step size is 1 hour, taking as input:
+    - High-resolution (3km) HRRR state over the central United States (99 vars)
+    - High-resolution land-sea mask and orography invariants
+    - Coarse resolution (25km) global state (26 vars)
+
+    The high-resolution grid is the HRRR Lambert conformal projection.
+    Coarse-resolution inputs are regridded to the HRRR grid internally.
+
+    Parameters
+    ----------
+    regression_model (torch.nn.Module): Deterministic model used to make an initial prediction
+    diffusion_model (torch.nn.Module): Generative model correcting the deterministic prediction
+    lat (np.ndarray): Latitude array (2D) of the domain
+    lon (np.ndarray): Longitude array (2D) of the domain
+    means (torch.Tensor): Mean value of each input high-resolution variable
+    stds (torch.Tensor): Standard deviation of each input high-resolution variable
+    invariants (torch.Tensor): Static invariant quantities
+    variables (np.ndarray, optional): High-resolution variables. Defaults to np.array(VARIABLES).
+    conditioning_means (torch.Tensor | None, optional): Means to normalize conditioning data. Defaults to None.
+    conditioning_stds (torch.Tensor | None, optional): Stds to normalize conditioning data. Defaults to None.
+    conditioning_variables (np.ndarray, optional): Global variables for conditioning. Defaults to np.array(CONDITIONING_VARIABLES).
+    conditioning_data_source (DataSource | None, optional): Data source to use for global conditioning. Defaults to None. Required for running in iterator mode.
+    sampler_args (dict[str, float | int], optional): Arguments to pass to the diffusion sampler. Defaults to {}.
+    interp_method (str, optional): Interpolation method to use when regridding coarse conditioning data. Defaults to "linear".
+    """
+
+    def __init__(
+        self,
+        regression_model: torch.nn.Module,
+        diffusion_model: torch.nn.Module,
+        lat: np.ndarray,
+        lon: np.ndarray,
+        means: torch.Tensor,
+        stds: torch.Tensor,
+        invariants: torch.Tensor,
+        variables: np.ndarray = np.array(VARIABLES),
+        conditioning_means: torch.Tensor | None = None,
+        conditioning_stds: torch.Tensor | None = None,
+        conditioning_variables: np.ndarray = np.array(CONDITIONING_VARIABLES),
+        conditioning_data_source: DataSource | None = None,
+        sampler_args: dict[str, float | int] = {},
+        interp_method: str = "linear",
+    ):
+        super().__init__()
+        self.regression_model = regression_model
+        self.diffusion_model = diffusion_model
+
+        self.lat = lat
+        self.lon = lon
+        self.register_buffer("means", means)
+        self.register_buffer("stds", stds)
+        self.register_buffer("invariants", invariants)
+        self.interp_method = interp_method
+        self.sampler_args = sampler_args
+
+        self.variables = variables
+
+        self.conditioning_variables = conditioning_variables
+        self.conditioning_data_source = conditioning_data_source
+        if conditioning_data_source is None:
+            warnings.warn(
+                "No conditioning data source was provided to StormCast, "
+                + "set the conditioning_data_source attribute of the model "
+                + "before running inference."
+            )
+
+        if conditioning_means is not None:
+            self.register_buffer("conditioning_means", conditioning_means)
+
+        if conditioning_stds is not None:
+            self.register_buffer("conditioning_stds", conditioning_stds)
+
+    def input_coords(self) -> CoordSystem:
+        """Input coordinate system"""
+        return OrderedDict(
+            {
+                "batch": np.empty(0),
+                "time": np.empty(0),
+                "lead_time": np.array([np.timedelta64(0, "h")]),
+                "variable": np.array(self.variables),
+                "lat": self.lat,
+                "lon": self.lon,
+            }
+        )
+
+    @batch_coords()
+    def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
+        """Output coordinate system of the prognostic model
+
+        Parameters
+        ----------
+        input_coords : CoordSystem
+            Input coordinate system to transform into output_coords
+            by default None, will use self.input_coords.
+
+        Returns
+        -------
+        CoordSystem
+            Coordinate system dictionary
+        """
+
+        output_coords = OrderedDict(
+            {
+                "batch": np.empty(0),
+                "time": np.empty(0),
+                "lead_time": np.array([np.timedelta64(1, "h")]),
+                "variable": np.array(self.variables),
+                "lat": self.lat,
+                "lon": self.lon,
+            }
+        )
+        if input_coords is None:
+            return output_coords
+
+        target_input_coords = self.input_coords()
+
+        handshake_dim(input_coords, "lon", 5)
+        handshake_dim(input_coords, "lat", 4)
+        handshake_dim(input_coords, "variable", 3)
+        handshake_coords(input_coords, target_input_coords, "lon")
+        handshake_coords(input_coords, target_input_coords, "lat")
+        handshake_coords(input_coords, target_input_coords, "variable")
+
+        output_coords["batch"] = input_coords["batch"]
+        output_coords["time"] = input_coords["time"]
+        output_coords["lead_time"] = (
+            output_coords["lead_time"] + input_coords["lead_time"]
+        )
+        return output_coords
+
+    @classmethod
+    def load_default_package(cls) -> Package:
+        """Load prognostic package"""
+        package = Package(
+            "ngc://models/nvidia/modulus/stormcast-v1-era5-hrrr@1.0.1",
+            cache_options={
+                "cache_storage": Package.default_cache("stormcast"),
+                "same_names": True,
+            },
+        )
+        return package
+
+    @classmethod
+    def load_model(cls, package: Package) -> PrognosticModel:
+        """Load StormCast model."""
+
+        # Require appropriate modulus version (compare parsed versions, not strings)
+        installed_version = modulus.__version__
+        if Version(installed_version) < Version("0.10.0a0"):
+            raise RuntimeError(
+                f"modulus version 0.10.0a0 or later is required "
+                f"to load the StormCast package from NGC, "
+                f"but version {installed_version} is installed. "
+                f"Please pip install "
+                f"nvidia-modulus @ git+https://github.com/NVIDIA/modulus.git"
+            )
+
+        OmegaConf.register_new_resolver("eval", eval)
+
+        # load model registry:
+        config = OmegaConf.load(package.resolve("model.yaml"))
+
+        regression = Module.from_checkpoint(package.resolve("StormCastUNet.0.0.mdlus"))
+        diffusion = Module.from_checkpoint(package.resolve("EDMPrecond.0.0.mdlus"))
+
+        # Load metadata: means, stds, grid
+        metadata = xr.open_zarr(package.resolve("metadata.zarr.zip"))
+
+        variables = metadata["variable"].values
+        lat = metadata.coords["lat"].values
+        lon = metadata.coords["lon"].values
+        conditioning_variables = metadata["conditioning_variable"].values
+
+        # Expand dims and tensorify normalization buffers
+        means = torch.from_numpy(metadata["means"].values[None, :, None, None])
+        stds = torch.from_numpy(metadata["stds"].values[None, :, None, None])
+        conditioning_means = torch.from_numpy(
+            metadata["conditioning_means"].values[None, :, None, None]
+        )
+        conditioning_stds = torch.from_numpy(
+            metadata["conditioning_stds"].values[None, :, None, None]
+        )
+
+        # Load invariants
+        invariants = metadata["invariants"].sel(invariant=config.data.invariants).values
+        invariants = torch.from_numpy(invariants).repeat(1, 1, 1, 1)
+
+        # EDM sampler arguments
+        if config.sampler_args is not None:
+            sampler_args = config.sampler_args
+        else:
+            sampler_args = {}
+
+        return cls(
+            regression,
+            diffusion,
+            lat,
+            lon,
+            means,
+            stds,
+            invariants,
+            variables=variables,
+            conditioning_means=conditioning_means,
+            conditioning_stds=conditioning_stds,
+            conditioning_variables=conditioning_variables,
+            sampler_args=sampler_args,
+        )
+
+    @torch.inference_mode()
+    def _forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
+
+        # Scale data
+        if "conditioning_means" in self._buffers:
+            conditioning = conditioning - self.conditioning_means
+        if "conditioning_stds" in self._buffers:
+            conditioning = conditioning / self.conditioning_stds
+
+        x = (x - self.means) / self.stds
+
+        # Run regression model
+        invariant_tensor = self.invariants.repeat(x.shape[0], 1, 1, 1)
+        concats = torch.cat((x, conditioning, invariant_tensor), dim=1)
+
+        out = self.regression_model(concats)
+
+        # Concat for diffusion conditioning
+        condition = torch.cat((x, out, invariant_tensor), dim=1)
+        latents = torch.randn_like(x)
+
+        # Run diffusion model
+        edm_out = deterministic_sampler(
+            self.diffusion_model,
+            latents=latents,
+            img_lr=condition,
+            **self.sampler_args,
+        )
+
+        out += edm_out
+
+        out = out * self.stds + self.means
+
+        return out
+
+    @torch.inference_mode()
+    @batch_func()
+    def __call__(
+        self,
+        x: torch.Tensor,
+        coords: CoordSystem,
+    ) -> tuple[torch.Tensor, CoordSystem]:
+        """Runs the prognostic model one step"""
+
+        if self.conditioning_data_source is None:
+            raise RuntimeError(
+                "StormCast has been called without initializing the model's conditioning_data_source"
+            )
+
+        conditioning, conditioning_coords = fetch_data(
+            self.conditioning_data_source,
+            time=coords["time"],
+            variable=self.conditioning_variables,
+            lead_time=coords["lead_time"],
+            device=x.device,
+            interp_to=coords,
+            interp_method=self.interp_method,
+        )
+        # Add a batch dim
+        conditioning = conditioning.repeat(x.shape[0], 1, 1, 1, 1, 1)
+        conditioning_coords.update({"batch": np.empty(0)})
+        conditioning_coords.move_to_end("batch", last=False)
+
+        # Handshake conditioning coords
+        handshake_coords(conditioning_coords, coords, "lon")
+        handshake_coords(conditioning_coords, coords, "lat")
+        handshake_coords(conditioning_coords, coords, "lead_time")
+        handshake_coords(conditioning_coords, coords, "time")
+
+        output_coords = self.output_coords(coords)
+
+        for i, _ in enumerate(coords["batch"]):
+            for j, _ in enumerate(coords["time"]):
+                for k, _ in enumerate(coords["lead_time"]):
+                    x[i, j, k : k + 1] = self._forward(
+                        x[i, j, k : k + 1], conditioning[i, j, k : k + 1]
+                    )
+
+        return x, output_coords
+
+    @batch_func()
+    def _default_generator(
+        self,
+        x: torch.Tensor,
+        coords: CoordSystem,
+    ) -> Generator[tuple[torch.Tensor, CoordSystem], None, None]:
+
+        coords = coords.copy()
+        self.output_coords(coords)
+        yield x, coords
+
+        if self.conditioning_data_source is None:
+            raise ValueError(
+                "A conditioning data source must be available for the iterator to function."
+            )
+
+        while True:
+            # Front hook
+            x, coords = self.front_hook(x, coords)
+            # Forward pass (single model step)
+            x, coords = self.__call__(x, coords)
+            # Rear hook
+            x, coords = self.rear_hook(x, coords)
+            yield x, coords.copy()
+
+    def create_iterator(
+        self, x: torch.Tensor, coords: CoordSystem
+    ) -> Iterator[tuple[torch.Tensor, CoordSystem]]:
+        """Creates an iterator which can be used to perform time-integration of the
+        prognostic model. Will return the initial condition first (0th step).
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor
+        coords : CoordSystem
+            Input coordinate system
+
+        Yields
+        ------
+        Iterator[tuple[torch.Tensor, CoordSystem]]
+            Iterator that generates time-steps of the prognostic model containing the
+            output data tensor and coordinate system dictionary.
+        """
+        yield from self._default_generator(x, coords)
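As review context, a minimal end-to-end sketch of driving this class, following the package test added below (Random sources stand in for real HRRR state and ERA5 conditioning data; the time and the reduced sampler steps are illustrative):

```python
from collections import OrderedDict

import numpy as np

from earth2studio.data import Random, fetch_data
from earth2studio.models.px import StormCast

# Load the packaged model (downloads from NGC on first use)
model = StormCast.load_model(StormCast.load_default_package()).to("cuda:0")

# Random sources stand in for the HRRR state and the global conditioning data
state_source = Random(OrderedDict({"lat": model.lat, "lon": model.lon}))
model.conditioning_data_source = Random(
    OrderedDict(
        {
            "lat": np.linspace(90, -90, num=721, endpoint=True),
            "lon": np.linspace(0, 360, num=1440),
        }
    )
)
model.sampler_args = {"num_steps": 2}  # fewer EDM steps, e.g. for smoke tests

# Fetch the initial state and take one 1-hour step
time = np.array([np.datetime64("2020-04-05T00:00")])
x, coords = fetch_data(
    state_source,
    time,
    model.input_coords()["variable"],
    model.input_coords()["lead_time"],
    device="cuda:0",
)
out, out_coords = model(x, coords)  # shape (1, 1, 99, 512, 640)

# For a multi-step forecast, iterate model.create_iterator(x, coords)
```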
diff --git a/earth2studio/utils/coords.py b/earth2studio/utils/coords.py
index df337b3f..682e9402 100644
--- a/earth2studio/utils/coords.py
+++ b/earth2studio/utils/coords.py
@@ -111,6 +111,11 @@ def handshake_coords(
            f"Required dimension {required_dim} not found in target coordinates"
        )

+    if input_coords[required_dim].shape != target_coords[required_dim].shape:
+        raise ValueError(
+            f"Coordinate shapes for required dim {required_dim} are not the same"
+        )
+
    if not np.all(input_coords[required_dim] == target_coords[required_dim]):
        raise ValueError(
            f"Coordinate systems for required dim {required_dim} are not the same"
        )
@@ -170,7 +175,9 @@ def map_coords(
    """A basic interpolation util to map between coordinate systems with common
    dimensions. Namely, `output_coords` should consist of keys are present in
    `input_coords`. Note that `output_coords` do not need have all the dimensions of the
-    `input_coords`.
+    `input_coords`. Does not support more advanced interpolation, such as between a regular
+    and curvilinear grid. For such use-cases, use `fetch_data` or `prep_data_array` from
+    `data/utils`.

    Parameters
    ----------
@@ -194,6 +201,7 @@
        If output coordinate has a dimension not in the input coordinate
    ValueError
        If value in non-numeric output coordinate is not in input coordinate
+        If asked to interpolate between 2D lat/lon (curvilinear) coordinates
    """
    mapped_coords = input_coords.copy()
@@ -217,6 +225,14 @@
            # skip interpolation if input and output coords are identical
            continue

+        if key in ["lat", "lon"] and (len(inc.shape) > 1 or len(outc.shape) > 1):
+            # Guard against 2D lat/lon grids (curvilinear case)
+            raise ValueError(
+                f"Coordinate dim {key} in input or mapped coords is "
+                "two-dimensional; please use fetch_data or "
+                "prep_data_array to regrid/interpolate first."
+            )
+
        indx = np.where(inc == outc[0])[0][0]
        inc_slice = inc[indx : indx + outc.shape[0]]
        # Slice condition
diff --git a/examples/05_ensemble_workflow_extend.py b/examples/05_ensemble_workflow_extend.py
index 35fc863c..7b16c264 100644
--- a/examples/05_ensemble_workflow_extend.py
+++ b/examples/05_ensemble_workflow_extend.py
@@ -44,7 +44,8 @@
# %%
# .. literalinclude:: ../../earth2studio/run.py
# :language: python
-# :lines: 116-156
+# :start-after: # sphinx - ensemble start
+# :end-before: # sphinx - ensemble end

# %%
# We need the following:
@@ -84,6 +85,7 @@
# applies the same noise amplitude to every variable. We can create a custom wrapper
# that only applies the perturbation method to a particular variable instead.
+ # %% class ApplyToVariable: """Apply a perturbation to only a particular variable.""" diff --git a/pyproject.toml b/pyproject.toml index 779592fc..90d944d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "netCDF4>=1.6.4", "ngcsdk>=3.48.0,<3.55.0", "numpy>=1.24.0", - "nvidia-modulus>=0.6.0", + "nvidia-modulus@git+https://github.com/NVIDIA/modulus.git", "python-dotenv", "s3fs>=2023.5.0", "setuptools>=67.6.0", diff --git a/test/data/test_data_utils.py b/test/data/test_data_utils.py index 15cdf771..68a6d896 100644 --- a/test/data/test_data_utils.py +++ b/test/data/test_data_utils.py @@ -104,6 +104,135 @@ def test_fetch_data(time, lead_time, device): assert not torch.isnan(x).any() +@pytest.mark.parametrize( + "time", + [ + np.array([np.datetime64("1993-04-05T00:00")]), + np.array( + [ + np.datetime64("1999-10-11T12:00"), + np.datetime64("2001-06-04T00:00"), + ] + ), + ], +) +@pytest.mark.parametrize( + "lead_time", + [ + np.array([np.timedelta64(0, "h")]), + np.array([np.timedelta64(-6, "h"), np.timedelta64(0, "h")]), + ], +) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_fetch_data_interp(time, lead_time, device): + # Original (source) domain + variable = np.array(["a", "b", "c"]) + domain = OrderedDict( + { + "lat": np.linspace(90, -90, 721, endpoint=True), + "lon": np.linspace(0, 360, 1440), + } + ) + r = Random(domain) + + # Target domain, 1d lat/lon coords + lat = np.linspace(60, 20, num=256) + lon = np.linspace(130, 60, num=512) + target_coords = OrderedDict( + { + "lat": lat, + "lon": lon, + } + ) + + # nearest neighbor interp + x, coords = fetch_data( + r, + time, + variable, + lead_time, + device=device, + interp_to=target_coords, + interp_method="nearest", + ) + + assert x.device == torch.device(device) + assert np.all(coords["time"] == time) + assert np.all(coords["lead_time"] == lead_time) + assert np.all(coords["variable"] == variable) + assert coords["lat"].shape == (256,) + assert coords["lon"].shape == (512,) + assert not torch.isnan(x).any() + + # bilinear interp + x, coords = fetch_data( + r, + time, + variable, + lead_time, + device=device, + interp_to=target_coords, + interp_method="linear", + ) + + assert x.device == torch.device(device) + assert np.all(coords["time"] == time) + assert np.all(coords["lead_time"] == lead_time) + assert np.all(coords["variable"] == variable) + assert coords["lat"].shape == (256,) + assert coords["lon"].shape == (512,) + assert not torch.isnan(x).any() + + # Target domain, 2d lat/lon coords + lat = np.linspace(60, 20, num=256) + lon = np.linspace(130, 60, num=512) + lat2d, lon2d = np.meshgrid(lat, lon, indexing="ij") + target_coords = OrderedDict( + { + "lat": lat2d, + "lon": lon2d, + } + ) + + # nearest neighbor interp + x, coords = fetch_data( + r, + time, + variable, + lead_time, + device=device, + interp_to=target_coords, + interp_method="nearest", + ) + + assert x.device == torch.device(device) + assert np.all(coords["time"] == time) + assert np.all(coords["lead_time"] == lead_time) + assert np.all(coords["variable"] == variable) + assert coords["lat"].shape == (256, 512) + assert coords["lon"].shape == (256, 512) + assert not torch.isnan(x).any() + + # bilinear interp + x, coords = fetch_data( + r, + time, + variable, + lead_time, + device=device, + interp_to=target_coords, + interp_method="linear", + ) + + assert x.device == torch.device(device) + assert np.all(coords["time"] == time) + assert np.all(coords["lead_time"] == lead_time) + assert np.all(coords["variable"] == variable) 
+ assert coords["lat"].shape == (256, 512) + assert coords["lon"].shape == (256, 512) + assert not torch.isnan(x).any() + + @pytest.mark.parametrize( "time", [ diff --git a/test/data/test_hrrr.py b/test/data/test_hrrr.py index 50f6d5dc..50b5abbc 100644 --- a/test/data/test_hrrr.py +++ b/test/data/test_hrrr.py @@ -37,7 +37,7 @@ ], ], ) -@pytest.mark.parametrize("variable", ["t2m", ["u10m", "u100"]]) +@pytest.mark.parametrize("variable", ["t2m", ["u10m", "u100"], ["u1hl"]]) def test_hrrr_fetch(time, variable): ds = HRRR(cache=False) @@ -115,7 +115,7 @@ def test_hrrr_fx_fetch(time, lead_time): np.array([np.datetime64("2024-01-01T00:00")]), ], ) -@pytest.mark.parametrize("variable", [["t2m", "sp"]]) +@pytest.mark.parametrize("variable", [["t2m", "sp", "t10hl"]]) @pytest.mark.parametrize("cache", [True, False]) def test_hrrr_cache(time, variable, cache): diff --git a/test/data/test_random.py b/test/data/test_random.py index 15030f46..0d135a40 100644 --- a/test/data/test_random.py +++ b/test/data/test_random.py @@ -55,3 +55,25 @@ def test_random(time, variable, lat, lon): assert shape[2] == len(coords["lat"]) assert shape[3] == len(coords["lon"]) assert not np.isnan(data.values).any() + + # Curvilinear coordinates (2d lat/lon arrays) + lat, lon = np.meshgrid(lat, lon, indexing="ij") + + coords = OrderedDict({"lat": lat, "lon": lon}) + + data_source = Random(coords) + + data = data_source(time, variable) + shape = data.shape + + if isinstance(variable, str): + variable = [variable] + + if isinstance(time, datetime.datetime): + time = [time] + + assert shape[0] == len(time) + assert shape[1] == len(variable) + assert shape[2] == coords["lat"].shape[0] + assert shape[3] == coords["lon"].shape[1] + assert not np.isnan(data.values).any() diff --git a/test/lexicon/test_hrrr_lexicon.py b/test/lexicon/test_hrrr_lexicon.py index b03b6bee..fa10affe 100644 --- a/test/lexicon/test_hrrr_lexicon.py +++ b/test/lexicon/test_hrrr_lexicon.py @@ -21,7 +21,13 @@ @pytest.mark.parametrize( - "variable", [["t2m"], ["u10m", "v200"], ["u80m", "z500", "q700"]] + "variable", + [ + ["t2m"], + ["u10m", "v200"], + ["u80m", "z500", "q700"], + ["u1hl", "v4hl", "t20hl", "p30hl"], + ], ) @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) def test_run_deterministic(variable, device): diff --git a/test/models/px/test_stormcast.py b/test/models/px/test_stormcast.py new file mode 100644 index 00000000..c7c22b96 --- /dev/null +++ b/test/models/px/test_stormcast.py @@ -0,0 +1,322 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from collections import OrderedDict +from collections.abc import Iterable + +import numpy as np +import pytest +import torch + +from earth2studio.data import Random, fetch_data +from earth2studio.models.px import StormCast +from earth2studio.utils import handshake_dim + + +# Spoof models with same call signature +class PhooStormCastRegressionModel(torch.nn.Module): + def __init__(self, out_vars=3): + super().__init__() + self.out_vars = out_vars + + def forward(self, x): + return x[:, : self.out_vars, :, :] + + +class PhooStormCastDiffusionModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.sigma_min = 0.0 + self.sigma_max = 88.0 + + def forward(self, x, noise, class_labels=None, condition=None): + return x + + def round_sigma(self, sigma): + return torch.as_tensor(sigma) + + +@pytest.mark.parametrize( + "time", + [ + np.array([np.datetime64("2020-04-05T00:00")]), + np.array( + [ + np.datetime64("2020-10-11T12:00"), + np.datetime64("2020-06-04T00:00"), + ] + ), + ], +) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_stormcast_call(time, device): + + # Spoof models + regression = PhooStormCastRegressionModel() + diffusion = PhooStormCastDiffusionModel() + + # Init data sources + nvar, nvar_cond, nlat, nlon = 3, 5, 128, 160 + lat, lon = np.meshgrid( + np.linspace(30, 46, num=nlat), np.linspace(250, 275, num=nlon), indexing="ij" + ) + dc = OrderedDict([("lat", lat), ("lon", lon)]) + r = Random(dc) + r_condition = Random( + OrderedDict( + [ + ("lat", np.linspace(90, -90, num=181, endpoint=True)), + ("lon", np.linspace(0, 360, num=360)), + ] + ) + ) + + # Spoof variable names + variables = np.array(["u%02d" % i for i in range(nvar)]) + + # Init model with explicit conditioning data in constructor + means = torch.zeros(1, nvar, 1, 1) + stds = torch.ones(1, nvar, 1, 1) + invariants = torch.randn(1, 2, nlat, nlon) + conditioning_means = torch.randn(1, nvar_cond, 1, 1, device=device) + conditioning_stds = torch.randn(1, nvar_cond, 1, 1, device=device) + conditioning_variables = np.array(["u%02d" % i for i in range(nvar_cond)]) + p = StormCast( + regression, + diffusion, + lat, + lon, + means, + stds, + invariants, + variables=variables, + conditioning_means=conditioning_means, + conditioning_stds=conditioning_stds, + conditioning_variables=conditioning_variables, + conditioning_data_source=r_condition, + sampler_args={"num_steps": 2}, + ).to(device) + + # Get Data and convert to tensor, coords + lead_time = p.input_coords()["lead_time"] + variable = p.input_coords()["variable"] + x, coords = fetch_data(r, time, variable, lead_time, device=device) + + out, out_coords = p(x, coords) + + if not isinstance(time, Iterable): + time = [time] + + assert out.shape == torch.Size([len(time), 1, nvar, nlat, nlon]) + assert (out_coords["variable"] == p.output_coords(coords)["variable"]).all() + assert np.all(out_coords["time"] == time) + handshake_dim(out_coords, "lon", 4) + handshake_dim(out_coords, "lat", 3) + handshake_dim(out_coords, "variable", 2) + handshake_dim(out_coords, "lead_time", 1) + handshake_dim(out_coords, "time", 0) + + +@pytest.mark.parametrize( + "ensemble", + [1, 2], +) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_stormcast_iter(ensemble, device): + + time = np.array([np.datetime64("2020-04-05T00:00")]) + + # Spoof models + regression = PhooStormCastRegressionModel() + diffusion = PhooStormCastDiffusionModel() + + # Init data sources + nvar, nvar_cond, nlat, nlon = 3, 5, 128, 160 + lat, lon = np.meshgrid( + np.linspace(30, 46, 
num=nlat), np.linspace(250, 275, num=nlon), indexing="ij" + ) + dc = OrderedDict([("lat", lat), ("lon", lon)]) + r = Random(dc) + r_condition = Random( + OrderedDict( + [ + ("lat", np.linspace(90, -90, num=181, endpoint=True)), + ("lon", np.linspace(0, 360, num=360)), + ] + ) + ) + + # Init model with explicit conditioning data in constructor + variables = np.array(["u%02d" % i for i in range(nvar)]) + means = torch.zeros(1, nvar, 1, 1) + stds = torch.ones(1, nvar, 1, 1) + invariants = torch.randn(1, 2, nlat, nlon) + conditioning_means = torch.randn(1, nvar_cond, 1, 1, device=device) + conditioning_stds = torch.randn(1, nvar_cond, 1, 1, device=device) + conditioning_variables = np.array(["u%02d" % i for i in range(nvar_cond)]) + p = StormCast( + regression, + diffusion, + lat, + lon, + means, + stds, + invariants, + variables, + conditioning_means=conditioning_means, + conditioning_stds=conditioning_stds, + conditioning_variables=conditioning_variables, + conditioning_data_source=r_condition, + sampler_args={"num_steps": 2}, + ).to(device) + + # Get Data and convert to tensor, coords + lead_time = p.input_coords()["lead_time"] + variable = p.input_coords()["variable"] + x, coords = fetch_data(r, time, variable, lead_time, device=device) + + # Add ensemble to front + x = x.unsqueeze(0).repeat(ensemble, 1, 1, 1, 1, 1) + coords.update({"ensemble": np.arange(ensemble)}) + coords.move_to_end("ensemble", last=False) + + p_iter = p.create_iterator(x, coords) + + if not isinstance(time, Iterable): + time = [time] + + # Get generator + next(p_iter) # Skip first which should return the input + for i, (out, out_coords) in enumerate(p_iter): + assert len(out.shape) == 6 + assert out.shape == torch.Size([ensemble, len(time), 1, nvar, nlat, nlon]) + assert ( + out_coords["variable"] == p.output_coords(p.input_coords())["variable"] + ).all() + assert (out_coords["ensemble"] == np.arange(ensemble)).all() + assert out_coords["lead_time"][0] == np.timedelta64(i + 1, "h") + + if i > 5: + break + + +@pytest.mark.parametrize( + "dc", + [ + OrderedDict( + { + "lat": np.random.randn(312, 640), + "lon": np.random.randn(312, 640), + } + ), + ], +) +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_stormcast_exceptions(dc, device): + time = np.array([np.datetime64("2020-04-05T00:00")]) + + regression = PhooStormCastRegressionModel() + diffusion = PhooStormCastDiffusionModel() + + # Build model with correct coords but no conditioning info + lat, lon = np.meshgrid( + np.linspace(30, 46, num=512), np.linspace(250, 275, num=640), indexing="ij" + ) + r = Random(OrderedDict([("lat", lat), ("lon", lon)])) + means = torch.zeros(1, 99, 1, 1) + stds = torch.ones(1, 99, 1, 1) + invariants = torch.randn(1, 2, 512, 640) + p = StormCast( + regression, + diffusion, + lat, + lon, + means, + stds, + invariants, + ).to(device) + + # Get Data and convert to tensor, coords + lead_time = p.input_coords()["lead_time"] + variable = p.input_coords()["variable"] + x, coords = fetch_data(r, time, variable, lead_time, device=device) + + with pytest.raises(RuntimeError): + # Calling with no conditioning info should fail + p(x, coords) + + # Create iterator and consume first batch (initial condition) + p_iter = p.create_iterator(x, coords) + next(p_iter) + with pytest.raises(ValueError): + # Using the generator with no built-in conditioning should fail + next(p_iter) + + +@pytest.fixture(scope="module") +def model(model_cache_context) -> StormCast: + # Test only on cuda device + with model_cache_context(): + package = 
StormCast.load_default_package() + p = StormCast.load_model(package) + return p + + +# @pytest.mark.ci_cache +@pytest.mark.timeout(360) +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_stormcast_package(device, model): + torch.cuda.empty_cache() + time = np.array([np.datetime64("2020-04-05T00:00")]) + # Test the cached model package StormCast + p = model.to(device) + + # Create random data sources + dc = OrderedDict([("lat", p.lat), ("lon", p.lon)]) + r = Random(dc) + r_condition = Random( + OrderedDict( + [ + ("lat", np.linspace(90, -90, num=721, endpoint=True)), + ("lon", np.linspace(0, 360, num=1440)), + ] + ) + ) + + # Manually set the condition data source (necessary as NGC package doesn't specify) + p.conditioning_data_source = r_condition + + # Decrease the number of edm sampling steps to speed up the test + p.sampler_args = {"num_steps": 2} + + # Get Data and convert to tensor, coords + lead_time = p.input_coords()["lead_time"] + variable = p.input_coords()["variable"] + x, coords = fetch_data(r, time, variable, lead_time, device=device) + + out, out_coords = p(x, coords) + + if not isinstance(time, Iterable): + time = [time] + + assert out.shape == torch.Size([len(time), 1, 99, 512, 640]) + assert (out_coords["variable"] == p.output_coords(coords)["variable"]).all() + assert np.all(out_coords["time"] == time) + handshake_dim(out_coords, "lon", 4) + handshake_dim(out_coords, "lat", 3) + handshake_dim(out_coords, "variable", 2) + handshake_dim(out_coords, "lead_time", 1) + handshake_dim(out_coords, "time", 0) diff --git a/test/statistics/test_acc.py b/test/statistics/test_acc.py index 70cabeb4..9fd5962e 100644 --- a/test/statistics/test_acc.py +++ b/test/statistics/test_acc.py @@ -119,7 +119,7 @@ def test_acc(reduction_weights: tuple[list[str], np.ndarray], device: str) -> No { "ensemble": np.arange(10), "time": np.array([np.datetime64("1993-04-05T00:00")]), - "variable": ["t2m", "tcwv"], + "variable": np.array(["t2m", "tcwv"]), "lat": np.linspace(-90.0, 90.0, 361), "lon": np.linspace(0.0, 360.0, 720, endpoint=False), } @@ -182,7 +182,7 @@ def test_acc_leadtime( "ensemble": np.arange(10), "time": np.array([np.datetime64("1993-04-05T00:00")]), "lead_time": np.array([np.timedelta64(6, "h")]), - "variable": ["t2m", "tcwv"], + "variable": np.array(["t2m", "tcwv"]), "lat": np.linspace(-90.0, 90.0, 361), "lon": np.linspace(0.0, 360.0, 720, endpoint=False), } @@ -250,7 +250,7 @@ def test_acc_failures( "ensemble": np.arange(10), "time": np.array([np.datetime64("1993-04-05T00:00")]), "lead_time": np.array([np.timedelta64(6, "h")]), - "variable": ["t2m", "tcwv"], + "variable": np.array(["t2m", "tcwv"]), "lat": np.linspace(-90.0, 90.0, 361), "lon": np.linspace(0.0, 360.0, 720, endpoint=False), } diff --git a/test/statistics/test_crps.py b/test/statistics/test_crps.py index 5b0d1598..d9929974 100644 --- a/test/statistics/test_crps.py +++ b/test/statistics/test_crps.py @@ -57,7 +57,7 @@ def test_crps( { "ensemble": np.arange(10), "time": np.array([np.datetime64("1993-04-05T00:00")]), - "variable": ["t2m", "tcwv"], + "variable": np.array(["t2m", "tcwv"]), "lat": np.linspace(-90.0, 90.0, 361), "lon": np.linspace(0.0, 360.0, 720, endpoint=False), } @@ -96,7 +96,7 @@ def test_crps_failures(device: str) -> None: { "ensemble": np.arange(10), "time": np.array([np.datetime64("1993-04-05T00:00")]), - "variable": ["t2m", "tcwv"], + "variable": np.array(["t2m", "tcwv"]), "lat": np.linspace(-90.0, 90.0, 361), "lon": np.linspace(0.0, 360.0, 720, endpoint=False), } diff --git 
a/test/statistics/test_metrics.py b/test/statistics/test_metrics.py index 0664ae1a..6adb9acb 100644 --- a/test/statistics/test_metrics.py +++ b/test/statistics/test_metrics.py @@ -45,7 +45,7 @@ def test_rmse(reduction_weights: tuple[list[str], np.ndarray], device: str) -> N { "ensemble": np.arange(10), "time": np.array([np.datetime64("1993-04-05T00:00")]), - "variable": ["t2m", "tcwv"], + "variable": np.array(["t2m", "tcwv"]), "lat": np.linspace(-90.0, 90.0, 361), "lon": np.linspace(0.0, 360.0, 720, endpoint=False), } @@ -140,7 +140,7 @@ def test_spread_skill( "time": np.array( [np.datetime64("1993-04-05T00:00"), np.datetime64("1993-04-06T00:00")] ), - "variable": ["t2m", "tcwv"], + "variable": np.array(["t2m", "tcwv"]), "lat": np.linspace(-90.0, 90.0, 361), "lon": np.linspace(0.0, 360.0, 720, endpoint=False), } @@ -191,7 +191,7 @@ def test_skill_spread( "time": np.array( [np.datetime64("1993-04-05T00:00"), np.datetime64("1993-04-06T00:00")] ), - "variable": ["t2m", "tcwv"], + "variable": np.array(["t2m", "tcwv"]), "lat": np.linspace(-90.0, 90.0, 361), "lon": np.linspace(0.0, 360.0, 720, endpoint=False), } diff --git a/test/statistics/test_moments.py b/test/statistics/test_moments.py index b6918249..2c632d4c 100644 --- a/test/statistics/test_moments.py +++ b/test/statistics/test_moments.py @@ -42,7 +42,7 @@ def test_mean(reduction_weights: tuple[list[str], np.ndarray], device: str) -> N { "ensemble": np.arange(10), "time": np.array([np.datetime64("1993-04-05T00:00")]), - "variable": ["t2m", "tcwv"], + "variable": np.array(["t2m", "tcwv"]), "lat": np.linspace(-90.0, 90.0, 361), "lon": np.linspace(0.0, 360.0, 720, endpoint=False), } @@ -136,7 +136,7 @@ def test_var(reduction_weights: tuple[list[str], np.ndarray], device: str) -> No { "ensemble": np.arange(10), "time": np.array([np.datetime64("1993-04-05T00:00")]), - "variable": ["t2m", "tcwv"], + "variable": np.array(["t2m", "tcwv"]), "lat": np.linspace(-90.0, 90.0, 361), "lon": np.linspace(0.0, 360.0, 720, endpoint=False), } @@ -176,7 +176,7 @@ def test_std(reduction_weights: tuple[list[str], np.ndarray], device: str) -> No { "ensemble": np.arange(10), "time": np.array([np.datetime64("1993-04-05T00:00")]), - "variable": ["t2m", "tcwv"], + "variable": np.array(["t2m", "tcwv"]), "lat": np.linspace(-90.0, 90.0, 361), "lon": np.linspace(0.0, 360.0, 720, endpoint=False), } diff --git a/test/statistics/test_ranks.py b/test/statistics/test_ranks.py index fae184ff..70154a92 100644 --- a/test/statistics/test_ranks.py +++ b/test/statistics/test_ranks.py @@ -41,7 +41,7 @@ def test_rank_histogram(ensemble_dimension: str, device: str) -> None: { "ensemble": np.arange(10), "time": np.array([np.datetime64("1993-04-05T00:00")]), - "variable": ["t2m", "tcwv"], + "variable": np.array(["t2m", "tcwv"]), "lat": np.linspace(-90.0, 90.0, 361), "lon": np.linspace(0.0, 360.0, 720, endpoint=False), } diff --git a/test/utils/test_coords.py b/test/utils/test_coords.py index 8bea118c..734f261d 100644 --- a/test/utils/test_coords.py +++ b/test/utils/test_coords.py @@ -158,3 +158,12 @@ def test_map_errors(): with pytest.raises(ValueError): map_coords(data, coords, OrderedDict([("variable", np.array(["d"]))])) + + curv_coords = OrderedDict( + [ + ("variable", np.array(["a", "b", "c"])), + ("lat", np.array([[1, 2, 3], [4, 5, 6]])), + ] + ) + with pytest.raises(ValueError): + map_coords(data, coords, curv_coords)
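
A short demonstration of the new `map_coords` guard, mirroring the test above (data and coordinate values are illustrative):

```python
from collections import OrderedDict

import numpy as np
import torch

from earth2studio.utils.coords import map_coords

# map_coords now rejects 2D (curvilinear) lat/lon mappings up front
data = torch.randn(3, 2, 3)
coords = OrderedDict(
    {
        "variable": np.array(["a", "b", "c"]),
        "lat": np.array([1, 2]),
        "lon": np.array([10, 20, 30]),
    }
)
curv_coords = OrderedDict({"lat": np.array([[1, 2, 3], [4, 5, 6]])})

try:
    map_coords(data, coords, curv_coords)
except ValueError:
    print("Use fetch_data/prep_data_array for curvilinear regridding")
```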