From 9b06e86a9ff2eb41c6cf2d4648d276ecb8401d53 Mon Sep 17 00:00:00 2001 From: Sarah Jordan Date: Tue, 27 Feb 2024 16:17:05 -0600 Subject: [PATCH] update increment timestep #68 todo: deal with timesteps and potentially change typehints update tests --- .../dev_sandbox/performance_profiling_tsm.py | 11 +-- examples/dev_sandbox/prof.py | 19 +++-- src/clearwater_modules/base.py | 84 ++++++++++++------- 3 files changed, 74 insertions(+), 40 deletions(-) diff --git a/examples/dev_sandbox/performance_profiling_tsm.py b/examples/dev_sandbox/performance_profiling_tsm.py index ffe6b58..767f863 100644 --- a/examples/dev_sandbox/performance_profiling_tsm.py +++ b/examples/dev_sandbox/performance_profiling_tsm.py @@ -61,6 +61,7 @@ def run_performance_test( # instantiate the TSM module tsm = clearwater_modules.tsm.EnergyBudget( + time_steps=iters, initial_state_values=state_i, meteo_parameters=meteo_parameters, ) @@ -89,9 +90,9 @@ def run_performance_test( sys.exit(1) log_file = sys.argv[1] - # iterations_list = [1, 10, 100, 1000, 10000, 100000] - # gridsize_list = [1, 1000, 10000] - iterations_list = [10000] - gridsize_list = [10000] - detailed_profile = True + iterations_list = [1, 10, 100, 1000, 10000, 100000] + gridsize_list = [1, 1000, 10000] + # iterations_list = [10000] + # gridsize_list = [10000] + detailed_profile = False run_performance_test(iterations_list, gridsize_list, log_file, detailed_profile) diff --git a/examples/dev_sandbox/prof.py b/examples/dev_sandbox/prof.py index 34a284a..cb9d538 100644 --- a/examples/dev_sandbox/prof.py +++ b/examples/dev_sandbox/prof.py @@ -6,16 +6,16 @@ import numpy as np -def main(iters: int, baseline: bool): +def main(iters: int, type: str): ti = time.time() # define starting state values - if baseline: + if type == 'baseline': state_i = { 'water_temp_c': 40.0, 'surface_area': 1.0, 'volume': 1.0, } - else: + elif type in ['arrays', 'hotstart']: state_i = { 'water_temp_c': xr.DataArray( np.full(10, 40), @@ -30,6 +30,7 @@ def main(iters: int, baseline: bool): dims='cell', coords={'cell': np.arange(10)}), } + # instantiate the TSM module tsm = clearwater_modules.tsm.EnergyBudget( @@ -38,8 +39,16 @@ def main(iters: int, baseline: bool): meteo_parameters={'wind_c': 1.0}, updateable_static_variables=['wind_c'] ) - print(tsm.static_variable_values) + t2 = time.time() + + if type == 'hotstart': + tsm = clearwater_modules.tsm.EnergyBudget( + time_steps=iters, + hotstart_dataset=tsm.dataset, + ) + t2 = time.time() + for _ in range(iters): tsm.increment_timestep() print(f'Increment timestep speed (average of {iters}): {(time.time() - t2) / 100}') @@ -56,4 +65,4 @@ def main(iters: int, baseline: bool): print('No argument given, defaulting to 100 iteration.') iters = 100 - main(iters=iters, baseline=True) + main(iters=iters, type='baseline') diff --git a/src/clearwater_modules/base.py b/src/clearwater_modules/base.py index ef2501b..c63dde4 100644 --- a/src/clearwater_modules/base.py +++ b/src/clearwater_modules/base.py @@ -36,6 +36,7 @@ def __init__( track_dynamic_variables: bool = True, hotstart_dataset: Optional[xr.Dataset] = None, time_dim: Optional[str] = None, + timestep: Optional[int] = 0, ) -> None: """Initialize the model, should be accessed by subclasses. @@ -62,6 +63,8 @@ def __init__( self.static_variable_values = static_variable_values self.hotstart_dataset = hotstart_dataset self.track_dynamic_variables = track_dynamic_variables + self.timestep = timestep + self.time_steps = time_steps + 1 # xarray indexing if not time_dim: time_dim = 'time_step' @@ -78,7 +81,7 @@ def __init__( initial_state_values=self.initial_state_values, static_variable_values=self.static_variable_values, updateable_static_variables=self.updateable_static_variables, - time_steps=time_steps, + time_steps=self.time_steps, ) elif isinstance(hotstart_dataset, xr.Dataset): @@ -374,6 +377,21 @@ def _non_updateable_static_variables(self) -> list[str]: var.name for var in self.static_variables if var.name not in self.updateable_static_variables ] return self.__non_updateable_static_variables + + def _iter_computations(self): + inputs = map( + lambda x: utils._prep_inputs( + self.dataset.isel({self.time_dim: self.timestep}), + x), + self.computation_order + ) + for name, func, arrays in inputs: + array: np.ndarray = func(*arrays) + dims = self.dataset[name].dims + if self.time_dim in dims: + self.dataset[name].loc[{self.time_dim: self.timestep}] = array + else: + self.dataset[name] = (dims, array) def increment_timestep( @@ -381,14 +399,17 @@ def increment_timestep( update_state_values: Optional[dict[str, xr.DataArray]] = None, ) -> xr.Dataset: """Run the process.""" + self.timestep +=1 + if update_state_values is None: update_state_values = {} - - # get the last timestep as a xr.DataArray - last_timestep: int = self.dataset[self.time_dim].values[-1] - timestep_ds: xr.Dataset = self.dataset.isel( - {self.time_dim: -1}, - ).copy(deep=True) + + # by default, set current timestep equal to last timestep + self.dataset[self.state_variables_names + self.updateable_static_variables].loc[ + {self.time_dim: self.timestep} + ] = self.dataset[self.state_variables_names + self.updateable_static_variables].isel( + {self.time_dim: self.timestep - 1} + ) # update the state variables as necessary (i.e. interacting w/ other models) for var_name, value in update_state_values.items(): @@ -396,30 +417,35 @@ def increment_timestep( raise ValueError( f'Variable {var_name} cannot be updated between timesteps, skipping.', ) - utils.validate_arrays(value, timestep_ds[var_name]) - timestep_ds[var_name] = value + utils.validate_arrays( + value, + self.dataset[var_name].isel( + {self.time_dim: self.timestep} + ) + ) + self.dataset[var_name].loc[{self.time_dim: self.timestep}] = value + + # add dynamic variables to ds + for dynamic_variable in self.dynamic_variables_names: + self.dataset[dynamic_variable] = xr.DataArray( + np.full( + tuple( + self.dataset[self.static_variables_names[0]].sizes[dim] + for dim in self.dataset[self.static_variables_names[0]].dims), + np.nan + ), + dims=self.dataset[self.static_variables_names[0]].dims + ) # compute the dynamic variables in order - timestep_ds = utils.iter_computations( - timestep_ds, - self.computation_order, - ) - if not self.track_dynamic_variables: - timestep_ds = timestep_ds.drop_vars(self.dynamic_variables_names) - - timestep_ds = timestep_ds.drop_vars(self._non_updateable_static_variables) - timestep_ds = timestep_ds.expand_dims( - {self.time_dim: [last_timestep + 1]}, - ) + self._iter_computations() - self.dataset = xr.concat( - [ - self.dataset, - timestep_ds, - ], - dim=self.time_dim, - data_vars='minimal', - ) + if not self.track_dynamic_variables: + self.dataset.loc[{self.time_dim: self.timestep}] = self.dataset.isel( + {self.time_dim: self.timestep} + ).drop_vars( + self.dynamic_variables_names + ) # add dynamic variable attributes if self.track_dynamic_variables: @@ -431,8 +457,6 @@ def increment_timestep( 'description': var.description, } - return self.dataset - def register_variable(models: CanRegisterVariable | Iterable[CanRegisterVariable]): """A decorator to register a variable with a model."""