From 57bc785435636438c856239d5f5c1efa05cc0228 Mon Sep 17 00:00:00 2001 From: Henri Perso Date: Mon, 13 Jan 2025 11:40:40 +0100 Subject: [PATCH 1/2] fix default values --- spockflow/components/tree/v1/core.py | 11 +++++++--- tests/unit/test_tree.py | 33 +++++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/spockflow/components/tree/v1/core.py b/spockflow/components/tree/v1/core.py index 065fa8b..147cf11 100644 --- a/spockflow/components/tree/v1/core.py +++ b/spockflow/components/tree/v1/core.py @@ -113,6 +113,7 @@ def set_default(self, value: TOutput): raise ValueError( f"Cannot set default as length of value ({len_value}) incompatible with tree {len_tree}." ) + self.default_value = value def merge_into(self, other: Self): len_subtree = len(other) @@ -127,6 +128,9 @@ def merge_into(self, other: Self): raise ValueError( f"Cannot merge two subtrees both containing default values" ) + if other.default_value is not None: + self.set_default(other.default_value) + self.nodes.extend(other.nodes) @@ -316,7 +320,7 @@ def set_default(self, output: TOutput, child_tree: ChildTree = None): Notes: - If `child_tree` is not provided, the default is set for `self.root`. - Checks if a default value is already assigned to `child_tree`. If so, raises an error. - - Converts `output` to the root node of `output` if `output` is an instance of `Tree`. + - if `output` is an instance of `Tree`, add it as subtree. - Sets `child_tree.default_value` to `output`, establishing it as the default action when no specific conditions are met in the decision tree. @@ -335,8 +339,9 @@ def set_default(self, output: TOutput, child_tree: ChildTree = None): if child_tree.default_value is not None: raise ValueError("Default value already set") if isinstance(output, Tree): - output = output.root - child_tree.default_value = output + self.include_subtree(output, condition=None, child_tree=child_tree) + else: + child_tree.set_default(output) def include_subtree( self, diff --git a/tests/unit/test_tree.py b/tests/unit/test_tree.py index aad0dd9..fef0df6 100644 --- a/tests/unit/test_tree.py +++ b/tests/unit/test_tree.py @@ -208,7 +208,7 @@ def test_set_default_twice(tree): tree.set_default(value(888)) # Default value already set -def test_merge_subtrees_with_defaults(tree): +def test_merge_subtrees_with_only_defaults(tree): subtree1 = Tree() subtree1.set_default(value(100)) @@ -222,6 +222,21 @@ def test_merge_subtrees_with_defaults(tree): with pytest.raises(ValueError): tree.include_subtree(subtree2) +def test_merge_subtrees_with_defaults(tree): + subtree = Tree() + subtree.condition(output=value(100), condition="SubA") + subtree.condition(output=value(200), condition="SubB") + subtree.set_default(output=value(300)) + + subtree2 = Tree() + subtree2.condition(output=value(1000), condition="SubD") + subtree2.condition(output=value(2000), condition="SubE") + subtree2.set_default(output=value(3000)) + + tree.include_subtree(subtree) + with pytest.raises(ValueError): + tree.include_subtree(subtree2) + def test_circular_dep_tree(tree): @@ -237,3 +252,19 @@ def test_circular_dep_tree(tree): with pytest.raises(ValueError): subtree_2.condition(condition="A", output=subtree_1) + +def test_merge_with_tree_as_default_value(tree): + tree0 = Tree() + tree0.condition(output=value(100), condition="SubA") + tree0.condition(output=value(200), condition="SubB") + + tree.condition(output=value(300), condition="SubC") + tree.set_default(tree0) + + assert len(tree.root.nodes) == 3 + assert tree.root.nodes[0].value.loc[0, "value"] == 300 + assert tree.root.nodes[0].condition == "SubC" + assert tree.root.nodes[1].value.loc[0, "value"] == 100 + assert tree.root.nodes[1].condition == "SubA" + assert tree.root.nodes[2].value.loc[0, "value"] == 200 + assert tree.root.nodes[2].condition == "SubB" From 605ddd3c4300c098defdd378ca36a7477c82e816 Mon Sep 17 00:00:00 2001 From: Henri Perso Date: Mon, 13 Jan 2025 12:18:53 +0100 Subject: [PATCH 2/2] run black --- spockflow/components/tree/v1/core.py | 2 +- tests/unit/test_tree.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/spockflow/components/tree/v1/core.py b/spockflow/components/tree/v1/core.py index 147cf11..5e8ec37 100644 --- a/spockflow/components/tree/v1/core.py +++ b/spockflow/components/tree/v1/core.py @@ -130,7 +130,7 @@ def merge_into(self, other: Self): ) if other.default_value is not None: self.set_default(other.default_value) - + self.nodes.extend(other.nodes) diff --git a/tests/unit/test_tree.py b/tests/unit/test_tree.py index fef0df6..5ff879c 100644 --- a/tests/unit/test_tree.py +++ b/tests/unit/test_tree.py @@ -222,6 +222,7 @@ def test_merge_subtrees_with_only_defaults(tree): with pytest.raises(ValueError): tree.include_subtree(subtree2) + def test_merge_subtrees_with_defaults(tree): subtree = Tree() subtree.condition(output=value(100), condition="SubA") @@ -253,6 +254,7 @@ def test_circular_dep_tree(tree): with pytest.raises(ValueError): subtree_2.condition(condition="A", output=subtree_1) + def test_merge_with_tree_as_default_value(tree): tree0 = Tree() tree0.condition(output=value(100), condition="SubA")