diff --git a/examples/dev_sandbox/prof.py b/examples/dev_sandbox/prof.py index f4cb2a5..34a284a 100644 --- a/examples/dev_sandbox/prof.py +++ b/examples/dev_sandbox/prof.py @@ -2,20 +2,41 @@ import clearwater_modules import time import sys +import xarray as xr +import numpy as np -def main(iters: int): + +def main(iters: int, baseline: bool): ti = time.time() # define starting state values - state_i = { - 'water_temp_c': 40.0, - 'surface_area': 1.0, - 'volume': 1.0, - } + if baseline: + state_i = { + 'water_temp_c': 40.0, + 'surface_area': 1.0, + 'volume': 1.0, + } + else: + state_i = { + 'water_temp_c': xr.DataArray( + np.full(10, 40), + dims='cell', + coords={'cell': np.arange(10)}), + 'surface_area': xr.DataArray( + np.full(10, 1.0), + dims='cell', + coords={'cell': np.arange(10)}), + 'volume': xr.DataArray( + np.full(10, 1.0), + dims='cell', + coords={'cell': np.arange(10)}), + } # instantiate the TSM module tsm = clearwater_modules.tsm.EnergyBudget( + time_steps=iters, initial_state_values=state_i, meteo_parameters={'wind_c': 1.0}, + updateable_static_variables=['wind_c'] ) print(tsm.static_variable_values) t2 = time.time() @@ -35,4 +56,4 @@ def main(iters: int): print('No argument given, defaulting to 100 iteration.') iters = 100 - main(iters=iters) + main(iters=iters, baseline=True) diff --git a/src/clearwater_modules/base.py b/src/clearwater_modules/base.py index bc5081d..ef2501b 100644 --- a/src/clearwater_modules/base.py +++ b/src/clearwater_modules/base.py @@ -1,6 +1,7 @@ """Stored base types shared by all sub-modules.""" import warnings import xarray as xr +import numpy as np import clearwater_modules.utils as utils import clearwater_modules.sorter as sorter from clearwater_modules.shared.types import ( @@ -28,6 +29,7 @@ class Model(CanRegisterVariable): def __init__( self, + time_steps: int, initial_state_values: Optional[InitialVariablesDict] = None, static_variable_values: Optional[InitialVariablesDict] = None, updateable_static_variables: Optional[list[str]] = None, @@ -38,6 +40,7 @@ def __init__( """Initialize the model, should be accessed by subclasses. Args: + time_steps: An integer to indicate the number of timesteps to run. initial_state_values: A dict with variable names as keys, and initial state variables as values. static_variable_values: A dict with variable names as keys, and static @@ -75,6 +78,7 @@ def __init__( initial_state_values=self.initial_state_values, static_variable_values=self.static_variable_values, updateable_static_variables=self.updateable_static_variables, + time_steps=time_steps, ) elif isinstance(hotstart_dataset, xr.Dataset): @@ -96,6 +100,7 @@ def _init_dataset_from_dicts( initial_state_values: InitialVariablesDict, static_variable_values: InitialVariablesDict, updateable_static_variables: list[str], + time_steps: int, ) -> xr.Dataset: """Initialize Model.dataset from dicts.""" if not isinstance(initial_state_values, dict): @@ -125,10 +130,14 @@ def _init_dataset_from_dicts( initial_state_values[static] = static_variable_values.pop(static) # initialize the main model dataset - dataset: xr.Dataset = self._init_state_arrays(initial_state_values) + dataset: xr.Dataset = self._init_state_arrays( + initial_state_values, + time_steps, + ) dataset: xr.Dataset = self._init_static_arrays( dataset, static_variable_values, + time_steps, ) print('Model initialized from input dicts successfully!.') @@ -145,10 +154,13 @@ def _init_from_dataset(self, hotstart_dataset: xr.Dataset) -> xr.Dataset: def _init_state_arrays( self, initial_state_values: InitialVariablesDict, + time_steps: int, ) -> xr.Dataset: """Initializes the state arrays.""" match_dims: list[str] = [] data_arrays: dict[str, xr.DataArray] = {} + coords: dict = {} + add_data: list[str] = [] for k, v in initial_state_values.items(): if k not in (self.state_variables_names + self.updateable_static_variables): @@ -161,37 +173,72 @@ def _init_state_arrays( else: utils.validate_arrays(v, *list(data_arrays.values())) data_arrays[k] = v + coords = coords | dict(data_arrays[k].coords.items()) + add_data.append(k) if len(data_arrays) > 0: - array_i = list(data_arrays.values())[0] + ds = xr.Dataset( + data_vars={ + k: ( + data_arrays[k].dims + (self.time_dim,), + np.full( + tuple(data_arrays[k].sizes[dim] for dim in data_arrays[k].dims) + (time_steps,), + np.nan + ) + ) + for k in data_arrays.keys() + }, + coords={ + **coords, + self.time_dim: np.arange(time_steps), + } + ) else: - array_i = xr.DataArray( - [[1.0]], - dims=['x', 'y'], - coords=[[1.0], [1.0]], + ds = xr.Dataset( + data_vars={ + k: ( + ('x', 'y', self.time_dim), + np.full((1, 1, time_steps), np.nan) + ) + for k in match_dims + }, + coords={'x': [1.0], 'y': [1.0], self.time_dim: np.arange(time_steps)} ) - for var_name in match_dims: + + for var_name in match_dims + add_data: variable = self.get_variable(var_name) attrs = { 'long_name': variable.long_name, 'units': variable.units, 'description': variable.description, } - data_arrays[var_name] = xr.full_like( - array_i, - initial_state_values[var_name], - dtype=type(initial_state_values[var_name]), - ) - data_arrays[var_name].attrs = attrs - ds = xr.Dataset( - data_vars=data_arrays, - coords=array_i.coords, - ) - return ds.expand_dims({self.time_dim: [0]}) + + if var_name not in data_arrays.keys(): + ds[var_name] = xr.DataArray( + np.full( + tuple(ds.sizes[dim] for dim in ds.dims), + np.nan + ), + dims=ds.dims + ) + + ds[var_name].loc[{self.time_dim: 0}] = xr.full_like( + ds[var_name].isel({self.time_dim: 0}), + initial_state_values[var_name], + dtype=type(initial_state_values[var_name]), + ) + + else: + ds[var_name].loc[{self.time_dim: 0}] = initial_state_values[var_name] + + ds[var_name].attrs = attrs + + return ds # ds.expand_dims({self.time_dim: np.arange(time_steps)}) def _init_static_arrays( self, dataset: xr.Dataset, static_variable_values: InitialVariablesDict, + time_steps: int, ) -> xr.Dataset: """Broadcasts static variables to an existing dataset. diff --git a/src/clearwater_modules/tsm/model.py b/src/clearwater_modules/tsm/model.py index ce9f319..f1580e7 100644 --- a/src/clearwater_modules/tsm/model.py +++ b/src/clearwater_modules/tsm/model.py @@ -18,6 +18,7 @@ class EnergyBudget(base.Model): def __init__( self, + time_steps: int, initial_state_values: Optional[base.InitialVariablesDict] = None, updateable_static_variables: Optional[list[str]] = None, meteo_parameters: Optional[dict[str, float]] = None, @@ -58,6 +59,7 @@ def __init__( #static_variable_values['use_sed_temp'] = use_sed_temp super().__init__( + time_steps=time_steps, initial_state_values=initial_state_values, static_variable_values=static_variable_values, updateable_static_variables=updateable_static_variables,