From f9a510895cbe4075ce4fcae49da4ac45cb962d5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20Schm=C3=B6lder?= Date: Mon, 27 Jan 2025 16:41:34 +0100 Subject: [PATCH] Fix setter for SizedAggregator --- CADETProcess/dataStructure/aggregator.py | 96 ++++++++++++++++++--- CADETProcess/dataStructure/dataStructure.py | 27 ++++-- CADETProcess/processModel/reaction.py | 39 ++++++--- 3 files changed, 131 insertions(+), 31 deletions(-) diff --git a/CADETProcess/dataStructure/aggregator.py b/CADETProcess/dataStructure/aggregator.py index 03a86b1e..b8c865f1 100644 --- a/CADETProcess/dataStructure/aggregator.py +++ b/CADETProcess/dataStructure/aggregator.py @@ -6,9 +6,29 @@ class SizedAggregator(Aggregator): """Aggregator for sized parameters.""" + def __init__(self, *args, transpose=False, **kwargs): + """ + Initialize a SizedAggregator instance. + + Parameters + ---------- + *args : Any + Variable length argument list. + transpose : bool, options + If False, the parameter shape will be ((n_instances, ) + parameter_shape). + Else, it will be (parameter_shape + (n_instances, )) + The default is False. + **kwargs : Any + Arbitrary keyword arguments. + """ + self.transpose = transpose + + super().__init__(*args, **kwargs) + def _parameter_shape(self, instance): - values = self._get_parameter_values_from_container(instance) - shapes = [np.array(el, ndmin=1).shape for el in values] + values = self._get_parameter_values_from_container(instance, transpose=False) + + shapes = [el.shape for el in values] if len(set(shapes)) > 1: raise ValueError("Inconsistent parameter shapes.") @@ -19,35 +39,91 @@ def _parameter_shape(self, instance): return shapes[0] def _expected_shape(self, instance): - return (self._n_instances(instance), ) + self._parameter_shape(instance) + if self.transpose: + return self._parameter_shape(instance) + (self._n_instances(instance), ) + else: + return (self._n_instances(instance), ) + self._parameter_shape(instance) - def _get_parameter_values_from_container(self, instance): + def _get_parameter_values_from_container(self, instance, transpose=False): value = super()._get_parameter_values_from_container(instance) if value is None or len(value) == 0: return - value = np.array(value, ndmin=2).T + value = np.array(value, ndmin=2) + if transpose and self.transpose: + value = value.T + return value - def _check(self, instance, value, recursive=False): + def _check(self, instance, value, transpose=True, recursive=False): + value_array = np.array(value, ndmin=2) + if transpose and self.transpose: + value_array = value_array.T + + value_shape = value_array.shape expected_shape = self._expected_shape(instance) - if value.shape != expected_shape: + + if value_shape != expected_shape: raise ValueError( - f"Expected a array with shape {expected_shape}, got {value.shape}" + f"Expected a array with shape {expected_shape}, got {value_shape}" ) if recursive: super()._check(instance, value, recursive) - def _prepare(self, instance, value, recursive=False): - value = np.array(value) + def _prepare(self, instance, value, transpose=False, recursive=False): + value = np.array(value, ndmin=2) + + if transpose and self.transpose: + value = value.T if recursive: value = super()._prepare(instance, value, recursive) return value + def __get__(self, instance, cls): + """ + Retrieve the descriptor value for the given instance. + + Parameters + ---------- + instance : Any + Instance to retrieve the descriptor value for. + cls : Type[Any], optional + Class to which the descriptor belongs. By default None. + + Returns + ------- + np.array + Descriptor values aggregated in a numpy array. + """ + if instance is None: + return self + + value = super().__get__(instance, cls) + + if value is not None: + if self.transpose: + value = value.T + + return value + + def __set__(self, instance, value): + """ + Set the descriptor value for the given instance. + + Parameters + ---------- + instance : Any + Instance to set the descriptor value for. + value : Any + Value to set. + """ + value = self._prepare(instance, value, transpose=True) + super().__set__(instance, value) + class ClassDependentAggregator(Aggregator): """Aggregator where parameter name changes depending on instance type.""" diff --git a/CADETProcess/dataStructure/dataStructure.py b/CADETProcess/dataStructure/dataStructure.py index cb4d577f..9c2a0ddb 100644 --- a/CADETProcess/dataStructure/dataStructure.py +++ b/CADETProcess/dataStructure/dataStructure.py @@ -98,7 +98,7 @@ def _container_obj(self, instance): return container def _n_instances(self, instance): - return len(self._get_parameter_values_from_container(instance)) + return len(self._container_obj(instance)) def _get_parameter_values_from_container(self, instance): container = self._container_obj(instance) @@ -131,6 +131,7 @@ def __get__(self, instance, cls): value = self._get_parameter_values_from_container(instance) if value is not None: + value = self._prepare(instance, value, recursive=True) self._check(instance, value, recursive=True) return value @@ -143,17 +144,18 @@ def __set__(self, instance, value): ---------- instance : Any Instance to set the descriptor value for. - value : Any - Value to set. + value : Iterable + Value to set. Note, this assumes that each element of the value maps to + each element of the container. """ if value is not None: value = self._prepare(instance, value, recursive=True) self._check(instance, value, recursive=True) - container = self._container_obj(instance) + containers = self._container_obj(instance) - for i, el in enumerate(container): - setattr(el, self.parameter_name, value[i]) + for container_value, container in zip(value, containers): + setattr(container, self.parameter_name, container_value) def _prepare(self, instance, value, recursive=False): """ @@ -185,12 +187,21 @@ def _check(self, instance, value, recursive=False): ---------- instance : Any Instance to retrieve the descriptor value for. - value : Any - Value to check. + value : Iterable + Value to set. Note, this assumes that each element of the value maps to + each element of the container. recursive : bool, optional If True, perform the check recursively. Defaults to False. """ + container = self._container_obj(instance) + + if len(value) != len(container): + raise ValueError( + "Unexpected length. " + f"Expected {len(container)} entries, got {len(value)}." + ) + return diff --git a/CADETProcess/processModel/reaction.py b/CADETProcess/processModel/reaction.py index 33d4aec4..80166f2c 100755 --- a/CADETProcess/processModel/reaction.py +++ b/CADETProcess/processModel/reaction.py @@ -50,7 +50,7 @@ class Reaction(Structure): The equilibrium constant for the reaction. """ is_kinetic = Bool(default=True) - stoich = SizedNdArray(size='n_comp', default=0) + stoich = SizedNdArray(size='n_comp') k_fwd = UnsignedFloat() k_bwd = UnsignedFloat() k_fwd_min = UnsignedFloat(default=100) @@ -492,9 +492,9 @@ class MassActionLaw(BulkReactionBase): k_fwd = Aggregator('k_fwd', 'reactions') k_bwd = Aggregator('k_bwd', 'reactions') - stoich = SizedAggregator('stoich', 'reactions') - exponents_fwd = SizedAggregator('exponents_fwd', 'reactions') - exponents_bwd = SizedAggregator('exponents_bwd', 'reactions') + stoich = SizedAggregator('stoich', 'reactions', transpose=True) + exponents_fwd = SizedAggregator('exponents_fwd', 'reactions', transpose=True) + exponents_bwd = SizedAggregator('exponents_bwd', 'reactions', transpose=True) _parameters = ['stoich', 'exponents_fwd', 'exponents_bwd', 'k_fwd', 'k_bwd'] @@ -549,34 +549,45 @@ class MassActionLawParticle(ParticleReactionBase): ) k_fwd_liquid = Aggregator('k_fwd', 'liquid_reactions') k_bwd_liquid = Aggregator('k_bwd', 'liquid_reactions') - exponents_fwd_liquid = SizedAggregator('exponents_fwd', 'liquid_reactions') - exponents_bwd_liquid = SizedAggregator('exponents_bwd', 'liquid_reactions') + exponents_fwd_liquid = SizedAggregator( + 'exponents_fwd', 'liquid_reactions', transpose=True + ) + exponents_bwd_liquid = SizedAggregator( + 'exponents_bwd', 'liquid_reactions', transpose=True + ) stoich_solid = SizedClassDependentAggregator( 'stoich_solid', 'solid_reactions', mapping={ CrossPhaseReaction: 'stoich_solid', None: 'stoich' - } + }, + transpose=True ) k_fwd_solid = Aggregator('k_fwd', 'solid_reactions') k_bwd_solid = Aggregator('k_bwd', 'solid_reactions') - exponents_fwd_solid = SizedAggregator('exponents_fwd', 'solid_reactions') - exponents_bwd_solid = SizedAggregator('exponents_bwd', 'solid_reactions') + exponents_fwd_solid = SizedAggregator( + 'exponents_fwd', 'solid_reactions', transpose=True + ) + exponents_bwd_solid = SizedAggregator( + 'exponents_bwd', 'solid_reactions', transpose=True + ) exponents_fwd_liquid_modsolid = SizedClassDependentAggregator( 'exponents_fwd_liquid_modsolid', 'liquid_reactions', mapping={ CrossPhaseReaction: 'exponents_fwd_liquid_modsolid', None: None - } + }, + transpose=True, ) exponents_bwd_liquid_modsolid = SizedClassDependentAggregator( 'exponents_bwd_liquid_modsolid', 'liquid_reactions', mapping={ CrossPhaseReaction: 'exponents_bwd_liquid_modsolid', None: None - } + }, + transpose=True, ) exponents_fwd_solid_modliquid = SizedClassDependentAggregator( @@ -584,14 +595,16 @@ class MassActionLawParticle(ParticleReactionBase): mapping={ CrossPhaseReaction: 'exponents_fwd_solid_modliquid', None: None - } + }, + transpose=True, ) exponents_bwd_solid_modliquid = SizedClassDependentAggregator( 'exponents_bwd_solid_modliquid', 'solid_reactions', mapping={ CrossPhaseReaction: 'exponents_bwd_solid_modliquid', None: None - } + }, + transpose=True, ) _parameters = [