Skip to content

Commit

Permalink
Merge pull request #15 from capitec/feature/fix-default-value
Browse files Browse the repository at this point in the history
fix default values
  • Loading branch information
sjnarmstrong authored Jan 14, 2025
2 parents a4ffda5 + 605ddd3 commit 2d352fe
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 4 deletions.
11 changes: 8 additions & 3 deletions spockflow/components/tree/v1/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def set_default(self, value: TOutput):
raise ValueError(
f"Cannot set default as length of value ({len_value}) incompatible with tree {len_tree}."
)
self.default_value = value

def merge_into(self, other: Self):
len_subtree = len(other)
Expand All @@ -127,6 +128,9 @@ def merge_into(self, other: Self):
raise ValueError(
f"Cannot merge two subtrees both containing default values"
)
if other.default_value is not None:
self.set_default(other.default_value)

self.nodes.extend(other.nodes)


Expand Down Expand Up @@ -316,7 +320,7 @@ def set_default(self, output: TOutput, child_tree: ChildTree = None):
Notes:
- If `child_tree` is not provided, the default is set for `self.root`.
- Checks if a default value is already assigned to `child_tree`. If so, raises an error.
- Converts `output` to the root node of `output` if `output` is an instance of `Tree`.
- if `output` is an instance of `Tree`, add it as subtree.
- Sets `child_tree.default_value` to `output`, establishing it as the default action
when no specific conditions are met in the decision tree.
Expand All @@ -335,8 +339,9 @@ def set_default(self, output: TOutput, child_tree: ChildTree = None):
if child_tree.default_value is not None:
raise ValueError("Default value already set")
if isinstance(output, Tree):
output = output.root
child_tree.default_value = output
self.include_subtree(output, condition=None, child_tree=child_tree)
else:
child_tree.set_default(output)

def include_subtree(
self,
Expand Down
35 changes: 34 additions & 1 deletion tests/unit/test_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def test_set_default_twice(tree):
tree.set_default(value(888)) # Default value already set


def test_merge_subtrees_with_defaults(tree):
def test_merge_subtrees_with_only_defaults(tree):
subtree1 = Tree()
subtree1.set_default(value(100))

Expand All @@ -223,6 +223,22 @@ def test_merge_subtrees_with_defaults(tree):
tree.include_subtree(subtree2)


def test_merge_subtrees_with_defaults(tree):
subtree = Tree()
subtree.condition(output=value(100), condition="SubA")
subtree.condition(output=value(200), condition="SubB")
subtree.set_default(output=value(300))

subtree2 = Tree()
subtree2.condition(output=value(1000), condition="SubD")
subtree2.condition(output=value(2000), condition="SubE")
subtree2.set_default(output=value(3000))

tree.include_subtree(subtree)
with pytest.raises(ValueError):
tree.include_subtree(subtree2)


def test_circular_dep_tree(tree):

# TODO this should raise an error
Expand All @@ -237,3 +253,20 @@ def test_circular_dep_tree(tree):

with pytest.raises(ValueError):
subtree_2.condition(condition="A", output=subtree_1)


def test_merge_with_tree_as_default_value(tree):
tree0 = Tree()
tree0.condition(output=value(100), condition="SubA")
tree0.condition(output=value(200), condition="SubB")

tree.condition(output=value(300), condition="SubC")
tree.set_default(tree0)

assert len(tree.root.nodes) == 3
assert tree.root.nodes[0].value.loc[0, "value"] == 300
assert tree.root.nodes[0].condition == "SubC"
assert tree.root.nodes[1].value.loc[0, "value"] == 100
assert tree.root.nodes[1].condition == "SubA"
assert tree.root.nodes[2].value.loc[0, "value"] == 200
assert tree.root.nodes[2].condition == "SubB"

0 comments on commit 2d352fe

Please sign in to comment.