Skip to content

Commit

Permalink
Fix setter for SizedAggregator
Browse files Browse the repository at this point in the history
  • Loading branch information
schmoelder committed Jan 28, 2025
1 parent 25a0f07 commit f9a5108
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 31 deletions.
96 changes: 86 additions & 10 deletions CADETProcess/dataStructure/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,29 @@
class SizedAggregator(Aggregator):
"""Aggregator for sized parameters."""

def __init__(self, *args, transpose=False, **kwargs):
"""
Initialize a SizedAggregator instance.
Parameters
----------
*args : Any
Variable length argument list.
transpose : bool, options
If False, the parameter shape will be ((n_instances, ) + parameter_shape).
Else, it will be (parameter_shape + (n_instances, ))
The default is False.
**kwargs : Any
Arbitrary keyword arguments.
"""
self.transpose = transpose

super().__init__(*args, **kwargs)

def _parameter_shape(self, instance):
values = self._get_parameter_values_from_container(instance)
shapes = [np.array(el, ndmin=1).shape for el in values]
values = self._get_parameter_values_from_container(instance, transpose=False)

shapes = [el.shape for el in values]

if len(set(shapes)) > 1:
raise ValueError("Inconsistent parameter shapes.")
Expand All @@ -19,35 +39,91 @@ def _parameter_shape(self, instance):
return shapes[0]

def _expected_shape(self, instance):
return (self._n_instances(instance), ) + self._parameter_shape(instance)
if self.transpose:
return self._parameter_shape(instance) + (self._n_instances(instance), )
else:
return (self._n_instances(instance), ) + self._parameter_shape(instance)

def _get_parameter_values_from_container(self, instance):
def _get_parameter_values_from_container(self, instance, transpose=False):
value = super()._get_parameter_values_from_container(instance)

if value is None or len(value) == 0:
return

value = np.array(value, ndmin=2).T
value = np.array(value, ndmin=2)
if transpose and self.transpose:
value = value.T

return value

def _check(self, instance, value, recursive=False):
def _check(self, instance, value, transpose=True, recursive=False):
value_array = np.array(value, ndmin=2)
if transpose and self.transpose:
value_array = value_array.T

value_shape = value_array.shape
expected_shape = self._expected_shape(instance)
if value.shape != expected_shape:

if value_shape != expected_shape:
raise ValueError(
f"Expected a array with shape {expected_shape}, got {value.shape}"
f"Expected a array with shape {expected_shape}, got {value_shape}"
)

if recursive:
super()._check(instance, value, recursive)

def _prepare(self, instance, value, recursive=False):
value = np.array(value)
def _prepare(self, instance, value, transpose=False, recursive=False):
value = np.array(value, ndmin=2)

if transpose and self.transpose:
value = value.T

if recursive:
value = super()._prepare(instance, value, recursive)

return value

def __get__(self, instance, cls):
"""
Retrieve the descriptor value for the given instance.
Parameters
----------
instance : Any
Instance to retrieve the descriptor value for.
cls : Type[Any], optional
Class to which the descriptor belongs. By default None.
Returns
-------
np.array
Descriptor values aggregated in a numpy array.
"""
if instance is None:
return self

value = super().__get__(instance, cls)

if value is not None:
if self.transpose:
value = value.T

return value

def __set__(self, instance, value):
"""
Set the descriptor value for the given instance.
Parameters
----------
instance : Any
Instance to set the descriptor value for.
value : Any
Value to set.
"""
value = self._prepare(instance, value, transpose=True)
super().__set__(instance, value)


class ClassDependentAggregator(Aggregator):
"""Aggregator where parameter name changes depending on instance type."""
Expand Down
27 changes: 19 additions & 8 deletions CADETProcess/dataStructure/dataStructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _container_obj(self, instance):
return container

def _n_instances(self, instance):
return len(self._get_parameter_values_from_container(instance))
return len(self._container_obj(instance))

def _get_parameter_values_from_container(self, instance):
container = self._container_obj(instance)
Expand Down Expand Up @@ -131,6 +131,7 @@ def __get__(self, instance, cls):
value = self._get_parameter_values_from_container(instance)

if value is not None:
value = self._prepare(instance, value, recursive=True)
self._check(instance, value, recursive=True)

return value
Expand All @@ -143,17 +144,18 @@ def __set__(self, instance, value):
----------
instance : Any
Instance to set the descriptor value for.
value : Any
Value to set.
value : Iterable
Value to set. Note, this assumes that each element of the value maps to
each element of the container.
"""
if value is not None:
value = self._prepare(instance, value, recursive=True)
self._check(instance, value, recursive=True)

container = self._container_obj(instance)
containers = self._container_obj(instance)

for i, el in enumerate(container):
setattr(el, self.parameter_name, value[i])
for container_value, container in zip(value, containers):
setattr(container, self.parameter_name, container_value)

def _prepare(self, instance, value, recursive=False):
"""
Expand Down Expand Up @@ -185,12 +187,21 @@ def _check(self, instance, value, recursive=False):
----------
instance : Any
Instance to retrieve the descriptor value for.
value : Any
Value to check.
value : Iterable
Value to set. Note, this assumes that each element of the value maps to
each element of the container.
recursive : bool, optional
If True, perform the check recursively. Defaults to False.
"""
container = self._container_obj(instance)

if len(value) != len(container):
raise ValueError(
"Unexpected length. "
f"Expected {len(container)} entries, got {len(value)}."
)

return


Expand Down
39 changes: 26 additions & 13 deletions CADETProcess/processModel/reaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class Reaction(Structure):
The equilibrium constant for the reaction.
"""
is_kinetic = Bool(default=True)
stoich = SizedNdArray(size='n_comp', default=0)
stoich = SizedNdArray(size='n_comp')
k_fwd = UnsignedFloat()
k_bwd = UnsignedFloat()
k_fwd_min = UnsignedFloat(default=100)
Expand Down Expand Up @@ -492,9 +492,9 @@ class MassActionLaw(BulkReactionBase):

k_fwd = Aggregator('k_fwd', 'reactions')
k_bwd = Aggregator('k_bwd', 'reactions')
stoich = SizedAggregator('stoich', 'reactions')
exponents_fwd = SizedAggregator('exponents_fwd', 'reactions')
exponents_bwd = SizedAggregator('exponents_bwd', 'reactions')
stoich = SizedAggregator('stoich', 'reactions', transpose=True)
exponents_fwd = SizedAggregator('exponents_fwd', 'reactions', transpose=True)
exponents_bwd = SizedAggregator('exponents_bwd', 'reactions', transpose=True)

_parameters = ['stoich', 'exponents_fwd', 'exponents_bwd', 'k_fwd', 'k_bwd']

Expand Down Expand Up @@ -549,49 +549,62 @@ class MassActionLawParticle(ParticleReactionBase):
)
k_fwd_liquid = Aggregator('k_fwd', 'liquid_reactions')
k_bwd_liquid = Aggregator('k_bwd', 'liquid_reactions')
exponents_fwd_liquid = SizedAggregator('exponents_fwd', 'liquid_reactions')
exponents_bwd_liquid = SizedAggregator('exponents_bwd', 'liquid_reactions')
exponents_fwd_liquid = SizedAggregator(
'exponents_fwd', 'liquid_reactions', transpose=True
)
exponents_bwd_liquid = SizedAggregator(
'exponents_bwd', 'liquid_reactions', transpose=True
)

stoich_solid = SizedClassDependentAggregator(
'stoich_solid', 'solid_reactions',
mapping={
CrossPhaseReaction: 'stoich_solid',
None: 'stoich'
}
},
transpose=True
)
k_fwd_solid = Aggregator('k_fwd', 'solid_reactions')
k_bwd_solid = Aggregator('k_bwd', 'solid_reactions')
exponents_fwd_solid = SizedAggregator('exponents_fwd', 'solid_reactions')
exponents_bwd_solid = SizedAggregator('exponents_bwd', 'solid_reactions')
exponents_fwd_solid = SizedAggregator(
'exponents_fwd', 'solid_reactions', transpose=True
)
exponents_bwd_solid = SizedAggregator(
'exponents_bwd', 'solid_reactions', transpose=True
)

exponents_fwd_liquid_modsolid = SizedClassDependentAggregator(
'exponents_fwd_liquid_modsolid', 'liquid_reactions',
mapping={
CrossPhaseReaction: 'exponents_fwd_liquid_modsolid',
None: None
}
},
transpose=True,
)
exponents_bwd_liquid_modsolid = SizedClassDependentAggregator(
'exponents_bwd_liquid_modsolid', 'liquid_reactions',
mapping={
CrossPhaseReaction: 'exponents_bwd_liquid_modsolid',
None: None
}
},
transpose=True,
)

exponents_fwd_solid_modliquid = SizedClassDependentAggregator(
'exponents_fwd_solid_modliquid', 'solid_reactions',
mapping={
CrossPhaseReaction: 'exponents_fwd_solid_modliquid',
None: None
}
},
transpose=True,
)
exponents_bwd_solid_modliquid = SizedClassDependentAggregator(
'exponents_bwd_solid_modliquid', 'solid_reactions',
mapping={
CrossPhaseReaction: 'exponents_bwd_solid_modliquid',
None: None
}
},
transpose=True,
)

_parameters = [
Expand Down

0 comments on commit f9a5108

Please sign in to comment.