Skip to content

Commit

Permalink
Add tests for Aggregator
Browse files Browse the repository at this point in the history
  • Loading branch information
schmoelder committed Jan 29, 2025
1 parent 1f90cfc commit 7b70fc7
Showing 1 changed file with 81 additions and 1 deletion.
82 changes: 81 additions & 1 deletion tests/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
from CADETProcess.dataStructure import (
Structure,
Constant, Switch,
Typed, Integer, String, List,
Typed, Integer, Float, String, List,
Callable,
RangedFloat, UnsignedInteger,
SizedList, SizedNdArray,
SizedUnsignedList, SizedUnsignedNdArray,
Polynomial, NdPolynomial,
Vector, Matrix,
DependentlyModulatedUnsignedList,
Aggregator, SizedAggregator, SizedClassDependentAggregator,
)


Expand Down Expand Up @@ -592,5 +593,84 @@ def test_value(self):
self.model.matrix = [1, 2]


class TestAggregator(unittest.TestCase):
def setUp(self):
class DummyInstance(Structure):
float_param = Float(default=1.0)
sized_param = SizedNdArray(size=4)
sized_param_transposed = SizedNdArray(size=2)

class Model(Structure):
aggregator = Aggregator('float_param', 'container')
sized_aggregator = SizedAggregator('sized_param', 'container')
transposed_sized_aggregator = SizedAggregator(
'sized_param_transposed', 'container', transpose=True
)

def __init__(self):
self.container = [
DummyInstance(
float_param=i,
sized_param=[i*j for j in range(4)],
sized_param_transposed=[i*j for j in range(2)]
)
for i in range(3)
]

self.model = Model()

def test_value(self):
# Aggregator
self.assertAlmostEqual(self.model.aggregator, [0.0, 1.0, 2.0])

new_value = [1, 2, 3]
self.model.aggregator = new_value

self.assertAlmostEqual(self.model.aggregator, new_value)
for con, val in zip(self.model.container, new_value):
self.assertAlmostEqual(con.float_param, val)

# SizedAggregator
np.testing.assert_almost_equal(
self.model.sized_aggregator,
[
[0, 0, 0, 0],
[0, 1, 2, 3],
[0, 2, 4, 6],
]
)

new_value = [
[1, 2, 3, 4],
[2, 3, 4, 5],
[3, 4, 5, 6],
]
self.model.sized_aggregator = new_value

np.testing.assert_almost_equal(self.model.sized_aggregator, new_value)
for con, val in zip(self.model.container, new_value):
np.testing.assert_almost_equal(con.sized_param, val)

# Transposed SizedAggregator
np.testing.assert_almost_equal(
self.model.transposed_sized_aggregator,
[
[0, 0, 0],
[0, 1, 2]
]
)
new_value = [
[1, 2, 3],
[2, 3, 4]
]
self.model.transposed_sized_aggregator = new_value

np.testing.assert_almost_equal(
self.model.transposed_sized_aggregator, new_value
)
for con, val in zip(self.model.container, np.array(new_value).T):
np.testing.assert_almost_equal(con.sized_param_transposed, val)


if __name__ == '__main__':
unittest.main()

0 comments on commit 7b70fc7

Please sign in to comment.