Skip to content

Commit

Permalink
Add ProxyList/Array to allow indexed modification of aggregated prope…
Browse files Browse the repository at this point in the history
…rties
  • Loading branch information
schmoelder committed Jan 29, 2025
1 parent 8a25232 commit 3c5ffc7
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 22 deletions.
95 changes: 83 additions & 12 deletions CADETProcess/dataStructure/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,69 @@

from .dataStructure import Aggregator

import numpy as np


class NumpyProxyArray(np.ndarray):
"""A numpy array that dynamically updates attributes of container elements."""

def __new__(cls, aggregator, instance):
values = aggregator._get_values_from_container(instance, transpose=True)

if values is None:
return

obj = values.view(cls)

obj.aggregator = aggregator
obj.instance = instance

return obj

def _get_values_from_aggregator(self):
"""Refresh data from the underlying container."""
return self.aggregator._get_values_from_container(
self.instance, transpose=True, check=True
)

def __getitem__(self, index):
"""Retrieve an item from the aggregated parameter array."""
return self._get_values_from_aggregator()[index]

def __setitem__(self, index, value):
"""
Modify an individual element in the aggregated parameter list.
This ensures changes are propagated back to the objects.
"""
current_value = self._get_values_from_aggregator()
current_value[index] = value
self.aggregator.__set__(self.instance, current_value)

def __array_finalize__(self, obj):
"""Ensure attributes are copied when creating a new view or slice."""
if obj is None:
self.aggregator = None
self.instance = None
return

if not isinstance(obj, NumpyProxyArray):
return np.asarray(obj)

self.aggregator = getattr(obj, 'aggregator', None)
self.instance = getattr(obj, 'instance', None)

def __array_function__(self, func, types, *args, **kwargs):
"""
Ensures that high-level NumPy functions (like np.dot, np.linalg.norm) return a
normal np.ndarray.
"""
result = super().__array_function__(func, types, *args, **kwargs)
return np.asarray(result)

def __repr__(self):
"""Return a fresh representation that reflects live data."""
return f"NumpyProxyArray({self._get_values_from_aggregator().__repr__()})"


class SizedAggregator(Aggregator):
"""Aggregator for sized parameters."""
Expand All @@ -26,7 +89,7 @@ def __init__(self, *args, transpose=False, **kwargs):
super().__init__(*args, **kwargs)

def _parameter_shape(self, instance):
values = self._get_parameter_values_from_container(instance, transpose=False)
values = self._get_values_from_container(instance, transpose=False)

shapes = [el.shape for el in values]

Expand All @@ -44,13 +107,18 @@ def _expected_shape(self, instance):
else:
return (self._n_instances(instance), ) + self._parameter_shape(instance)

def _get_parameter_values_from_container(self, instance, transpose=False):
value = super()._get_parameter_values_from_container(instance)
def _get_values_from_container(self, instance, transpose=False, check=False):
value = super()._get_values_from_container(instance, check=False)

if value is None or len(value) == 0:
return

value = np.array(value, ndmin=2)

if check:
value = self._prepare(instance, value, transpose=False, recursive=True)
self._check(instance, value, transpose=True, recursive=True)

if transpose and self.transpose:
value = value.T

Expand Down Expand Up @@ -92,19 +160,17 @@ def __get__(self, instance, cls):
instance : Any
Instance to retrieve the descriptor value for.
cls : Type[Any], optional
Class to which the descriptor belongs. By default None.
Class to which the descriptor belongs.
Returns
-------
np.array
Descriptor values aggregated in a numpy array.
NumpyProxyArray
A numpy-like array that modifies the underlying objects when changed.
"""
value = super().__get__(instance, cls)
if value is not None and value is not self:
if self.transpose:
value = value.T
if instance is None:
return self

return value
return NumpyProxyArray(self, instance)

def __set__(self, instance, value):
"""
Expand Down Expand Up @@ -142,7 +208,7 @@ def __init__(self, *args, mapping, **kwargs):

super().__init__(*args, **kwargs)

def _get_parameter_values_from_container(self, instance):
def _get_values_from_container(self, instance, check=False):
container = self._container_obj(instance)

values = []
Expand All @@ -160,7 +226,12 @@ def _get_parameter_values_from_container(self, instance):
if len(values) == 0:
return

if check:
values = self._prepare(instance, values, transpose=False, recursive=True)
self._check(instance, values, transpose=True, recursive=True)

return values


class SizedClassDependentAggregator(SizedAggregator, ClassDependentAggregator):
pass
91 changes: 81 additions & 10 deletions CADETProcess/dataStructure/dataStructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,81 @@ def __delete__(self, instance):
del instance.__dict__[self.name]


class ProxyList(list):
"""A proxy list that dynamically updates attributes of container elements."""

def __init__(self, aggregator, instance):
values = aggregator._get_values_from_container(instance)

if values is None:
values = [] # Ensure we have a valid default

self.aggregator = aggregator
self.instance = instance

super().__init__(values)

def _get_values_from_aggregator(self):
"""Fetch the latest values from the aggregator."""
return self.aggregator._get_values_from_container(
self.instance, check=True
)

def __getitem__(self, index):
"""Retrieve an item from the aggregated parameter list (live view)."""
return self._get_values_from_aggregator()[index]

def __setitem__(self, index, value):
"""
Modify an individual element in the aggregated parameter list.
Ensures changes propagate to the underlying objects.
"""
current_value = self._get_values_from_aggregator()
current_value[index] = value
self.aggregator.__set__(self.instance, current_value)
super().__setitem__(index, value) # Update the proxy list as well

def __iter__(self):
"""Iterate over aggregated values."""
return iter(self._get_values_from_aggregator())

def __len__(self):
"""Return the length of the container."""
return len(self._get_values_from_aggregator())

def __repr__(self):
"""String representation for debugging."""
return f"ProxyList({self._get_values_from_aggregator().__repr__()})"

def __eq__(self, other):
"""Equality comparison."""
return list(self._get_values_from_aggregator()) == other

def append(self, value):
"""Prevent appending to the proxy list."""
raise NotImplementedError("Appending elements is not allowed.")

def extend(self, values):
"""Prevent extending the proxy list."""
raise NotImplementedError("Extending elements is not allowed.")

def insert(self, index, value):
"""Prevent inserting into the proxy list."""
raise NotImplementedError("Inserting elements is not allowed.")

def pop(self, index=-1):
"""Prevent removing elements."""
raise NotImplementedError("Popping elements is not allowed.")

def remove(self, value):
"""Prevent removing elements."""
raise NotImplementedError("Removing elements is not allowed.")

def clear(self):
"""Prevent clearing elements."""
raise NotImplementedError("Clearing elements is not allowed.")


class Aggregator():
"""Descriptor aggregating parameters from iterable container of other objects."""

Expand Down Expand Up @@ -100,12 +175,14 @@ def _container_obj(self, instance):
def _n_instances(self, instance):
return len(self._container_obj(instance))

def _get_parameter_values_from_container(self, instance):
def _get_values_from_container(self, instance, check=False):
container = self._container_obj(instance)

value = [getattr(el, self.parameter_name) for el in container]

if len(value) == 0:
return
if check:
value = self._prepare(instance, value, recursive=True)
self._check(instance, value, recursive=True)

return value

Expand All @@ -128,13 +205,7 @@ def __get__(self, instance, cls):
if instance is None:
return self

value = self._get_parameter_values_from_container(instance)

if value is not None:
value = self._prepare(instance, value, recursive=True)
self._check(instance, value, recursive=True)

return value
return ProxyList(self, instance)

def __set__(self, instance, value):
"""
Expand Down

0 comments on commit 3c5ffc7

Please sign in to comment.