From d71418c3252a1d353d4592dc944cc2b22c037f58 Mon Sep 17 00:00:00 2001
From: Sholto Armstrong
Date: Tue, 14 Jan 2025 11:38:00 +0200
Subject: [PATCH 1/2] add tree priority

---
 .../example_pipeline_light.drawio.svg     |   6 +-
 spockflow/_serializable.py                | 144 ++++---
 spockflow/components/tree/__init__.py     |   5 +-
 spockflow/components/tree/v1/compiled.py  | 109 ++++--
 spockflow/components/tree/v1/core.py      | 359 ++++++++++++++----
 spockflow/nodes.py                        |   7 +-
 6 files changed, 482 insertions(+), 148 deletions(-)

diff --git a/docs/_static/getting-started/example_pipeline_light.drawio.svg b/docs/_static/getting-started/example_pipeline_light.drawio.svg
index 0cf3440..84feb97 100644
--- a/docs/_static/getting-started/example_pipeline_light.drawio.svg
+++ b/docs/_static/getting-started/example_pipeline_light.drawio.svg
@@ -1,4 +1,4 @@
-
+
@@ -53,13 +53,13 @@
- Criminal Record?
+ Clients Age?
</div>
- Criminal Record? + Clients Age? diff --git a/spockflow/_serializable.py b/spockflow/_serializable.py index a857c34..395b879 100644 --- a/spockflow/_serializable.py +++ b/spockflow/_serializable.py @@ -1,82 +1,127 @@ import pandas as pd -from typing import Any, Union +from typing import Any from pydantic_core import core_schema from typing_extensions import Annotated from pydantic import ( - BaseModel, GetCoreSchemaHandler, GetJsonSchemaHandler, - ValidationError, ) from pydantic.json_schema import JsonSchemaValue -def dump_to_dict(instance: Union[pd.Series, pd.DataFrame]) -> dict: - if isinstance(instance, pd.DataFrame): - return {"type": "DataFrame", "values": instance.to_dict(orient="list")} - return {"values": instance.to_list(), "name": instance.name, "type": "Series"} +values_schema = core_schema.dict_schema(keys_schema=core_schema.str_schema()) +dataframe_json_schema = core_schema.model_fields_schema({ + "type": core_schema.model_field(core_schema.literal_schema(["DataFrame"])), + "values": core_schema.model_field(core_schema.union_schema([values_schema, core_schema.list_schema(values_schema)])), + "dtypes": core_schema.model_field( + core_schema.with_default_schema( + core_schema.union_schema([core_schema.none_schema(), core_schema.dict_schema(keys_schema=core_schema.str_schema(), values_schema=core_schema.str_schema())]), + default=None, + ) + ) +}) +series_json_schema = core_schema.model_fields_schema({ + "type": core_schema.model_field(core_schema.literal_schema(["Series"])), + "values": core_schema.model_field(core_schema.list_schema()), + "name": core_schema.model_field( + core_schema.with_default_schema( + core_schema.union_schema([core_schema.none_schema(), core_schema.str_schema()]), + default=None, + ) + ) +}) + + +def validate_to_dataframe_schema(value: dict) -> pd.DataFrame: + value, *_ = value + values = value["values"] + if isinstance(values, dict): + values = [values] + df = pd.DataFrame(values) + dtypes = value.get("dtypes") + if dtypes: + return df.astype(dtypes) + return df + + +def validate_to_series_schema(value: dict) -> pd.Series: + value, *_ = value + return pd.Series(value["values"], name=value.get("name")) + +from_df_dict_schema = core_schema.chain_schema( + [ + dataframe_json_schema, + # core_schema.dict_schema(), + core_schema.no_info_plain_validator_function(validate_to_dataframe_schema), + ] +) +from_series_dict_schema = core_schema.chain_schema( + [ + series_json_schema, + # core_schema.dict_schema(), + core_schema.no_info_plain_validator_function(validate_to_series_schema), + ] +) -# class Series(pd.Series): -# @classmethod -# def __get_pydantic_core_schema__( -# cls, __source: type[Any], __handler: GetCoreSchemaHandler -# ) -> core_schema.CoreSchema: -# return core_schema.no_info_before_validator_function( -# pd.Series, -# core_schema.dict_schema(), -# serialization=core_schema.plain_serializer_function_ser_schema(dump_to_dict) -# ) +def dump_df_to_dict(instance: pd.DataFrame) -> dict: + values = instance.to_dict(orient="records") + if len(values) == 1: + values = values[0] + return {"type": "DataFrame", "values": values, "dtypes": {k: str(v) for k,v in instance.dtypes.items()}} -# class DataFrame(pd.DataFrame): -# @classmethod -# def __get_pydantic_core_schema__( -# cls, __source: type[Any], __handler: GetCoreSchemaHandler -# ) -> core_schema.CoreSchema: -# return core_schema.no_info_before_validator_function( -# pd.DataFrame, -# core_schema.dict_schema(), -# serialization=core_schema.plain_serializer_function_ser_schema(dump_to_dict) -# ) +def 
dump_series_to_dict(instance: pd.Series) -> dict: + return {"type": "Series", "values": instance.to_list(), "name": instance.name} -# This class allows Pydantic to serialise and deserialise pandas Dataframes and Series items +class _PandasDataFramePydanticAnnotation: + @classmethod + def __get_pydantic_core_schema__( + cls, + _source_type: Any, + _handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: + return core_schema.json_or_python_schema( + json_schema=from_df_dict_schema, + python_schema=core_schema.union_schema( + [ + # check if it's an instance first before doing any further work + core_schema.is_instance_schema(pd.DataFrame), + from_df_dict_schema, + ] + ), + serialization=core_schema.plain_serializer_function_ser_schema( + dump_df_to_dict + ), + ) -class _PandasPydanticAnnotation: + @classmethod + def __get_pydantic_json_schema__( + cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + return handler(dataframe_json_schema) + +class _PandasSeriesPydanticAnnotation: @classmethod def __get_pydantic_core_schema__( cls, _source_type: Any, _handler: GetCoreSchemaHandler, ) -> core_schema.CoreSchema: - def validate_from_dict(value: dict) -> pd.Series: - data_type = value.get("type") - if data_type is None: - return pd.DataFrame(value) - if value.get("type") == "DataFrame": - return pd.DataFrame(value["values"]) - return pd.Series(value["values"], name=value["name"]) - - from_int_schema = core_schema.chain_schema( - [ - core_schema.dict_schema(), # TODO make this more comprehensive - core_schema.no_info_plain_validator_function(validate_from_dict), - ] - ) return core_schema.json_or_python_schema( - json_schema=from_int_schema, + json_schema=from_series_dict_schema, python_schema=core_schema.union_schema( [ # check if it's an instance first before doing any further work core_schema.is_instance_schema(pd.Series), - from_int_schema, + from_series_dict_schema, ] ), serialization=core_schema.plain_serializer_function_ser_schema( - dump_to_dict + dump_series_to_dict ), ) @@ -84,10 +129,9 @@ def validate_from_dict(value: dict) -> pd.Series: def __get_pydantic_json_schema__( cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler ) -> JsonSchemaValue: - return handler(core_schema.dict_schema()) # TODO make this more comprehensive - + return handler(dataframe_json_schema) -DataFrame = Annotated[pd.DataFrame, _PandasPydanticAnnotation] +DataFrame = Annotated[pd.DataFrame, _PandasDataFramePydanticAnnotation] +Series = Annotated[pd.Series, _PandasSeriesPydanticAnnotation] -Series = Annotated[pd.Series, _PandasPydanticAnnotation] diff --git a/spockflow/components/tree/__init__.py b/spockflow/components/tree/__init__.py index 31e8147..33e98c2 100644 --- a/spockflow/components/tree/__init__.py +++ b/spockflow/components/tree/__init__.py @@ -1,6 +1,9 @@ from typing import TypeVar, Generic import pandas as pd -from spockflow.components.tree.v1.core import Tree as Tree +from spockflow.components.tree.v1.core import ( + Tree as Tree, + TableCondition as TableCondition +) T = TypeVar("T", bound=dict) diff --git a/spockflow/components/tree/v1/compiled.py b/spockflow/components/tree/v1/compiled.py index 2c46b1d..5303990 100644 --- a/spockflow/components/tree/v1/compiled.py +++ b/spockflow/components/tree/v1/compiled.py @@ -4,8 +4,10 @@ from pydantic import BaseModel from dataclasses import dataclass from spockflow.components.tree.settings import settings -from .core import ChildTree, Tree, TOutput, TCond +from .core import ChildTree, Tree, 
TOutput, TCond, ConditionedNode, TableConditionedNode from spockflow.nodes import creates_node +if typing.TYPE_CHECKING: + from spockflow.components.dtable import DecisionTable # from pandas.core.groupby import DataFrameGroupBy @@ -17,12 +19,14 @@ class ConditionedOutput: output: TOutput conditions: typing.Tuple[TCond] + priority: int @dataclass class SymbolicConditionedOutput: output: str conditions: typing.Tuple[str] + priority: int @dataclass @@ -37,6 +41,13 @@ class SymbolicFlatTree: class CompiledNumpyTree: def __init__(self, tree: Tree) -> None: + self.decision_tables: typing.Dict[str, "DecisionTable"] = {} + for k,v in tree.get_decision_tables().items(): + assert list(v.outputs.keys()) == ["value"], \ + f"Decision tables nested in decision trees can only have one output named \"value\"" + # Decision tables dont need compilation but keeping with convention + self.decision_tables[k] = v.compile() + self.length = len(tree.root) flattened_tree = self._flatten_tree( @@ -47,6 +58,8 @@ def __init__(self, tree: Tree) -> None: ) flattened_tree = self._get_symbolic_tree(flattened_tree) self._flattened_tree = flattened_tree + self._flattened_priority = np.array([n.priority for n in self._flattened_tree.tree], dtype=np.int32) + self._has_priority = any(self._flattened_priority != 1) (predefined_conditions, predefined_condition_names, execution_conditions) = ( self._split_predefined( @@ -222,51 +235,83 @@ def get_unique_name(v): name = get_unique_name(n.output) outputs[name] = n.output new_tree_nodes.append( - SymbolicConditionedOutput(name, tuple(new_conditions)) + SymbolicConditionedOutput(name, tuple(new_conditions), n.priority) ) return SymbolicFlatTree( outputs=outputs, conditions=conditions, tree=new_tree_nodes ) - @classmethod + @staticmethod + def _merge_priority(p1: typing.Optional[int], p2: typing.Optional[int]) -> typing.Optional[int]: + if p1 is None and p2 is None: + return None + return (p1 or 0) + (p2 or 0) # Also caps values to > 0 + def _flatten_tree( - cls, + self, sub_tree: ChildTree, current_conditions: typing.Tuple[TCond], seen: typing.Set[int], conditioned_outputs: typing.List[ConditionedOutput], + priority: typing.Optional[int]=None, ) -> typing.List[ConditionedOutput]: curr_id = id(sub_tree) if curr_id in seen: raise ValueError("Current tree contains loops. Cannot compile tree.") for n in sub_tree.nodes: - if n.value is None: - raise ValueError( - "All nodes must have a value set to be a valid tree.\n" - "Found a leaf with no value." - ) - if n.condition is None: - raise ValueError( - "All nodes must have a condition set to be a valid tree\n" - "Found a leaf with no condition.\n" - "If this is intended to be a default value please use set_default." - ) + if isinstance(n, ConditionedNode): + if n.value is None: + raise ValueError( + "All nodes must have a value set to be a valid tree.\n" + "Found a leaf with no value." + ) + if n.condition is None: + raise ValueError( + "All nodes must have a condition set to be a valid tree\n" + "Found a leaf with no condition.\n" + "If this is intended to be a default value please use set_default." 
+ ) - n_conditions = current_conditions + (n.condition,) - if isinstance(n.value, ChildTree): - cls._flatten_tree( - n.value, - current_conditions=n_conditions, - seen=seen.union([curr_id]), - conditioned_outputs=conditioned_outputs, - ) + n_conditions = current_conditions + (n.condition,) + if isinstance(n.value, ChildTree): + self._flatten_tree( + n.value, + current_conditions=n_conditions, + seen=seen.union([curr_id]), + conditioned_outputs=conditioned_outputs, + priority=self._merge_priority(priority, n.priority) + ) + else: + conditioned_outputs.append(ConditionedOutput(n.value, n_conditions, max(priority or 0,1))) else: - conditioned_outputs.append(ConditionedOutput(n.value, n_conditions)) - + if n.condition_table not in self.decision_tables: + raise ValueError( + f"Found node conditioned on a table ({n.condition_table}) that is not registered within the tree." + ) + n._check_compatible_table(self.decision_tables[n.condition_table]) + for i,v in enumerate(n.values): + if v is None: + raise ValueError( + "All nodes must have a value set to be a valid tree.\n" + "Found a leaf with no value." + ) + n_priority = None if n.priority is None else n.priority[i] + n_conditions = current_conditions + (f"{n.condition_table}_is_{i}",) + if isinstance(v, ChildTree): + self._flatten_tree( + v, + current_conditions=n_conditions, + seen=seen.union([curr_id]), + conditioned_outputs=conditioned_outputs, + priority=self._merge_priority(priority, n_priority) + ) + else: + conditioned_outputs.append(ConditionedOutput(v, n_conditions, max(priority or 0,1))) + if sub_tree.default_value is not None: conditioned_outputs.append( - ConditionedOutput(sub_tree.default_value, current_conditions) + ConditionedOutput(sub_tree.default_value, current_conditions, max(priority or 0,1)) ) return conditioned_outputs @@ -345,7 +390,13 @@ def conditions_met(self, format_inputs: TFormatData) -> np.ndarray: # [O,C]@[C,N] => [O,N] (Matrix multiplication should be the same as performing a count of all true statements) # The thresh will see where they are all true return (conditions @ self.truth_table) >= self.truth_table_thresh - + + @creates_node() + def prioritized_conditions(self, conditions_met: np.ndarray) -> np.ndarray: + if not self._has_priority: return conditions_met + # [O] * [O,N] -> [O,N] (outputs, batch) + return self._flattened_priority * conditions_met + @creates_node() def condition_names(self, format_inputs: TFormatData) -> typing.List[str]: conditions = format_inputs[0] @@ -389,10 +440,10 @@ def all( def get_results( self, format_inputs: TFormatData, - conditions_met: np.ndarray, + prioritized_conditions: np.ndarray, ) -> pd.DataFrame: _, outputs, *_ = format_inputs - condition_output_idx = np.argmax(conditions_met, axis=1) + condition_output_idx = np.argmax(prioritized_conditions, axis=1) return outputs.iloc[ self._lookup_output_idx(format_inputs, condition_output_idx) ].reset_index(drop=True) diff --git a/spockflow/components/tree/v1/core.py b/spockflow/components/tree/v1/core.py index 065fa8b..0297b9b 100644 --- a/spockflow/components/tree/v1/core.py +++ b/spockflow/components/tree/v1/core.py @@ -1,61 +1,172 @@ import typing +import re +import keyword +import functools import pandas as pd import collections.abc from functools import partial from typing_extensions import Self from abc import ABC, abstractmethod -from pydantic import BaseModel, Field, model_validator, ConfigDict +from pydantic import BaseModel, Field, model_validator, ConfigDict, field_serializer, PrivateAttr, AfterValidator, 
model_serializer from spockflow.nodes import VariableNode +from ...dtable import DecisionTable +from hamilton.node import DependencyType +from spockflow._serializable import DataFrame, Series, dump_df_to_dict, dump_series_to_dict -TOutput = typing.Union[typing.Callable[..., pd.DataFrame], pd.DataFrame, str] -TCond = typing.Union[typing.Callable[..., pd.Series], pd.Series, str] +if typing.TYPE_CHECKING: + from .compiled import CompiledNumpyTree + + + +def _is_valid_function_name(name): + pattern = r'^[a-zA-Z_][a-zA-Z0-9_]*$' + assert re.match(pattern, name) and name not in keyword.kwlist, f"{name} must be a valid python function name" + return name + + +class TableCondition(BaseModel): + name: typing.Annotated[str, AfterValidator(_is_valid_function_name)] + table: DecisionTable + +TOutput = typing.Union[typing.Callable[..., pd.DataFrame], DataFrame, str] +TCond = typing.Union[typing.Callable[..., pd.Series], Series, str] +TCondRaw = typing.Union[typing.Callable[..., pd.Series], pd.Series, TableCondition, str] + +_TABLE_VALUE_KEY = "value" + +def _length_attr(attr): + if attr is None: + return 1 + if isinstance(attr, str): + return 1 + if isinstance(attr, TableCondition): + return 1 + if not isinstance(attr, collections.abc.Sized): + return 1 + return len(attr) + +def _serialize_value(value: typing.Union[TOutput, "ChildTree", None]): + if value is None: return value + if isinstance(value, typing.Callable): + return value.__name__ + if isinstance(value, pd.DataFrame): + return dump_df_to_dict(value) + res = value.to_dict(orient='records') + if len(res) == 1: + return res[0] + return res + return value + +class TableConditionedNode(BaseModel): + condition_type: typing.Literal["table"] = "table" + # model_config = ConfigDict(arbitrary_types_allowed=True) + values: typing.List[typing.Union[TOutput, "ChildTree", None]] = None + condition_table: str + priority: typing.Optional[typing.List[int]] = None + + + @field_serializer('values') + def serialize_values(self, values: typing.List[typing.Union[TOutput, "ChildTree", None]], _info): + return [_serialize_value(v) for v in values] + + + @model_validator(mode="after") + def check_compatible_lengths(self) -> Self: + self.ensure_length() + return self + + def __len__(self): + len_values = (_length_attr(v) for v in self.values) + try: + return next(v for v in len_values if v !=1) + except StopIteration: + return 1 + + def ensure_length(self, tree_length: int=1): + len_values = [_length_attr(v) for v in self.values] + [tree_length] + try: + non_unit_value = next(v for v in len_values if v !=1) + except StopIteration: + non_unit_value = 1 + if not all(v==1 or v==non_unit_value for v in len_values): + raise ValueError("Incompatible value lengths detected") + + def _check_compatible_table(self, table:DecisionTable): + assert len(table.outputs) == 1, "Table must have exactly one output" + assert _TABLE_VALUE_KEY in table.outputs, f"Table must contain \"{_TABLE_VALUE_KEY}\" as an output key" + default_value = set() + if table.default_value is not None: + assert len(table.default_value.keys()) == 1, "Table must have exactly one output in the default values" + assert _TABLE_VALUE_KEY in table.default_value, f"Table must contain \"{_TABLE_VALUE_KEY}\" as an output key in the default values" + assert len(table.default_value) == 1, f"Default value must be a dataframe of length 1." 
+ default_value = {table.default_value[_TABLE_VALUE_KEY].values[0]} + table_output_values = set(table.outputs[_TABLE_VALUE_KEY]) | default_value + last_value = len(table_output_values) + assert set(range(0,last_value)) == table_output_values, "Table output values must be sequential integer indicies" + assert last_value == len(self.values), "There must be one output value for each index in the tree outputs" + if self.priority is not None: + assert last_value == len(self.priority), "There must be one priority item for each index in the tree outputs" class ConditionedNode(BaseModel): # TODO fix for pd.DataFrame - model_config = ConfigDict(arbitrary_types_allowed=True) + condition_type: typing.Literal["base"] = "base" + # model_config = ConfigDict(arbitrary_types_allowed=True) value: typing.Union[TOutput, "ChildTree", None] = None condition: typing.Optional[TCond] = None + priority: typing.Optional[int] = None + + + @field_serializer('condition') + def serialize_condition(self, condition: typing.Optional[TCond], _info): + if condition is None: return condition + if isinstance(condition, typing.Callable): + return condition.__name__ + if isinstance(condition, pd.Series): + return dump_series_to_dict(condition) + values = condition.tolist() + return {condition.name: values if len(values) > 1 else values[0]} + return condition + + @field_serializer('value') + def serialize_value(self, value: typing.Union[TOutput, "ChildTree", None], _info): + return _serialize_value(value) - @staticmethod - def _length_attr(attr): - if attr is None: - return 1 - if isinstance(attr, str): - return 1 - if not isinstance(attr, collections.abc.Sized): - return 1 - return len(attr) def __len__(self): - len_attr = self._length_attr(self.value) + len_attr = _length_attr(self.value) if len_attr == 1: - len_attr = self._length_attr(self.condition) + len_attr = _length_attr(self.condition) return len_attr @model_validator(mode="after") def check_compatible_lengths(self) -> Self: - len_value = self._length_attr(self.value) - if len_value == 1: - return self - len_condition = self._length_attr(self.condition) - if len_condition == 1: - return self - if len_condition != len_value: - raise ValueError("Condition and value lengths incompatible") + self.ensure_length() return self + + def ensure_length(self, tree_length: int=1): + len_value = _length_attr(self.value) + len_condition = _length_attr(self.condition) + count_unit_length = (len_value==1) + (len_condition==1) + (tree_length==1) + if count_unit_length >= 2: return + if (len_value == len_condition == tree_length): return + raise ValueError("Condition and value and tree lengths are incompatible") + +TConditionedNode = typing.Annotated[typing.Union[ConditionedNode, TableConditionedNode], Field(discriminator='condition_type')] + class ChildTree(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - nodes: typing.List[ConditionedNode] = Field(default_factory=list) + # model_config = ConfigDict(arbitrary_types_allowed=True) + nodes: typing.List[TConditionedNode] = Field(default_factory=list) default_value: typing.Optional[TOutput] = None + _decision_tables: typing.Dict[str,DecisionTable] = PrivateAttr(default_factory=dict) def __len__( self, ): # TODO could maybe cache this on bigger trees if the values havent changed - node_len = ConditionedNode._length_attr(self.default_value) + node_len = _length_attr(self.default_value) if node_len != 1: return node_len for node in self.nodes: @@ -65,7 +176,7 @@ def __len__( @model_validator(mode="after") def 
check_compatible_lengths(self) -> Self: - child_tree_len = ConditionedNode._length_attr(self.default_value) + child_tree_len = _length_attr(self.default_value) for node in self.nodes: len_value = len(node) if len_value == 1: @@ -77,34 +188,59 @@ def check_compatible_lengths(self) -> Self: f"Lengths of values or conditions in the tree is incompatible. Found {child_tree_len} != {len_value}." ) return self - - def add_node(self, value: TOutput, condition: TCond, **kwargs) -> ConditionedNode: - len_value = ConditionedNode._length_attr(value) - len_condition = ConditionedNode._length_attr(condition) - if len_value != 1 and len_condition != 1 and len_value != len_condition: - raise ValueError( - f"Cannot add node as the length of the value ({len_value}) is not compatible with the length of the condition ({len_condition})." + + @staticmethod + def _merge_decision_tables(to_be_updated: typing.Dict[str,DecisionTable], other: typing.Dict[str,DecisionTable]): + for k,v in other.items(): + if k in to_be_updated: + assert to_be_updated[k] is v, f"Decision table {k} added twice with different values." + else: + to_be_updated[k] = v + + # def _get_condition_from_table(self, condition: TCondRaw): + # if not isinstance(condition, TableCondition): return condition + # self._merge_decision_tables( + # self._decision_tables, + # {condition.name: condition.table} + # ) + # return TableConditionReference(table=condition.name) + + + def add_node( + self, + value: typing.Union[TOutput, typing.List[TOutput]], + condition: TCondRaw, + priority: typing.Union[int, typing.List[int], None]=None, + **kwargs + ) -> ConditionedNode: + + if isinstance(condition, TableCondition): + node = TableConditionedNode( + values=value, + condition_table=condition.name, + priority=priority, + **kwargs ) - if len_value != 1 or len_condition != 1: - # TODO adding this allows better validation but requires circular loops so hard for pydantic to serialise - # len_tree = len(self.root_tree.root) - len_tree = len(self) - if len_value != 1 and len_tree != 1 and len_value != len_tree: - raise ValueError( - f"Cannot add node as the length of the value ({len_value}) incompatible with tree {len_tree}." - ) - if len_condition != 1 and len_tree != 1 and len_condition != len_tree: - raise ValueError( - f"Cannot add node as the length of the condition ({len_condition}) incompatible with tree {len_tree}." 
- ) - node = ConditionedNode(value=value, condition=condition, **kwargs) + self._merge_decision_tables( + self._decision_tables, + {condition.name: condition.table} + ) + node._check_compatible_table(condition.table) + else: + node = ConditionedNode( + value=value, + condition=condition, + priority=priority, + **kwargs + ) + node.ensure_length(len(self)) self.nodes.append(node) return node def set_default(self, value: TOutput): if self.default_value is not None: raise ValueError("Default value already set") - len_value = ConditionedNode._length_attr(value) + len_value = _length_attr(value) if len_value != 1: # TODO adding this allows better validation but requires circular loops so hard for pydantic to serialise # len_tree = len(self.root_tree.root) @@ -127,8 +263,28 @@ def merge_into(self, other: Self): raise ValueError( f"Cannot merge two subtrees both containing default values" ) + + self._merge_decision_tables( + self._decision_tables, + other._decision_tables + ) + self.nodes.extend(other.nodes) + def get_all_decision_tables(self): + tables = self._decision_tables.copy() + for node in self.nodes: + if isinstance(node, TableConditionedNode): + node_values = node.values + else: + node_values = [node.value] + for nv in node_values: + if isinstance(nv, ChildTree): + self._merge_decision_tables( + tables, + nv.get_all_decision_tables() + ) + return tables class WrappedTreeFunction(ABC): @abstractmethod @@ -145,18 +301,76 @@ def include_subtree( ): ... +def _check_table_result_eq(value: int, **kwargs: pd.DataFrame) -> pd.Series: + table_result = kwargs[next(iter(kwargs.keys()))] + return table_result[_TABLE_VALUE_KEY] == value + class Tree(VariableNode): doc: str = "This executes a user defined decision tree" root: ChildTree = Field(default_factory=ChildTree) + decision_tables: typing.Dict[str,DecisionTable] = Field(default_factory=dict) + + def get_decision_tables(self): + decision_tables = self.decision_tables.copy() + self.root._merge_decision_tables( + decision_tables, + self.root.get_all_decision_tables() + ) + return decision_tables + + @model_serializer() + def serialize_model(self): + return { + 'doc': self.doc, + 'root': self.root, + "decision_tables": self.get_decision_tables() + } def compile(self): from .compiled import CompiledNumpyTree return CompiledNumpyTree(self) + + def _generate_nodes( + self, + name: str, + config: "typing.Dict[str, typing.Any]", + include_runtime_nodes: bool = False, + ) -> "typing.List[node.Node]": + from hamilton import node + compiled_node = self.compile() + output_nodes = super()._generate_nodes( + name = name, + config = config, + include_runtime_nodes = include_runtime_nodes, + compiled_node_override=compiled_node + ) + base_node_ = node.Node.from_fn( + _check_table_result_eq, name="temporary_node" + ) + for table_name, compiled_table in compiled_node.decision_tables.items(): + output_nodes.extend(compiled_table._generate_nodes(table_name, config, True)) + unique_values = set(compiled_table.outputs[_TABLE_VALUE_KEY]) + for v in unique_values: + v = int(v) # Just to be safe + + output_nodes.append(base_node_.copy_with( + name=f"{table_name}_is_{v}", + doc_string=f"This is a function used to determine if the result of {table_name} is {v} to be used in a decision tree.", + callabl=functools.partial( + _check_table_result_eq, + value = v + ), + input_types={table_name: (pd.DataFrame, DependencyType.REQUIRED)}, + include_refs=False, + )) + + return output_nodes def _generate_runtime_nodes( self, config: "typing.Dict[str, typing.Any]", compiled_node: 
"CompiledNumpyTree" ) -> "typing.List[node.Node]": + # This is used to extract the condition functions when running outside of hamilton context from hamilton import node return [ @@ -216,7 +430,14 @@ def _identify_loops(self, *nodes: "ConditionedNode"): if id(el) in seen: raise ValueError("Tree must not contain any loops") seen.add(id(el)) - if isinstance(el.value, ChildTree): + if isinstance(el, TableConditionedNode): + for v in el.values: + if isinstance(v, ChildTree): + q.extend(v.nodes) + if isinstance(v.default_value, ChildTree): + q.extend(v.default_value.nodes) + + elif isinstance(el.value, ChildTree): q.extend(el.value.nodes) if isinstance(el.value.default_value, ChildTree): q.extend(el.value.default_value.nodes) @@ -242,8 +463,9 @@ def copy(self, deep=True): def condition( self, output: TOutput = None, - condition: typing.Union[TCond, None] = None, + condition: typing.Optional[TCondRaw] = None, child_tree: ChildTree = None, + priority: typing.Optional[int] = None, **kwargs, ) -> typing.Callable[..., WrappedTreeFunction]: """ @@ -288,7 +510,7 @@ def wrapper(condition: TCond): nonlocal output if isinstance(output, Tree): output = output.root - node = child_tree.add_node(value=output, condition=condition, **kwargs) + node = child_tree.add_node(value=output, condition=condition, priority=priority, **kwargs) try: self._identify_loops(node) except ValueError as e: @@ -455,21 +677,30 @@ def visualize(self, get_value_name=None, get_condition_name=None): dot.node(curr_name, curr_name) for node in curr.nodes: - node_condition_name = get_condition_name(node.condition) + node_condition_name = get_condition_name( + node.condition_table + if isinstance(node, TableConditionedNode) + else node.condition + ) dot.node(node_condition_name, node_condition_name) dot.edge(curr_name, node_condition_name) - if hasattr(node.value, "nodes"): - to_search.extend([(node.value, node_condition_name)]) - elif node.value is not None: - node_value_name = get_value_name(node.value) - dot.node( - node_value_name, - node_value_name, - style="filled", - fillcolor="#ADDFFF", - shape="rectangle", - ) - dot.edge(node_condition_name, node_value_name) + if isinstance(node, TableConditionedNode): + node_values = node.values + else: + node_values = [node.value] + for val in node_values: + if hasattr(val, "nodes"): + to_search.extend([(val, node_condition_name)]) + elif val is not None: + node_value_name = get_value_name(val) + dot.node( + node_value_name, + node_value_name, + style="filled", + fillcolor="#ADDFFF", + shape="rectangle", + ) + dot.edge(node_condition_name, node_value_name) if curr.default_value is not None: default_name = get_value_name(curr.default_value) diff --git a/spockflow/nodes.py b/spockflow/nodes.py index 75d921b..3eef337 100644 --- a/spockflow/nodes.py +++ b/spockflow/nodes.py @@ -322,6 +322,7 @@ def _generate_nodes( name: str, config: "typing.Dict[str, typing.Any]", include_runtime_nodes: bool = False, + compiled_node_override: typing.Optional[Self] = None ) -> "typing.List[node.Node]": """Generate nodes for this class to be used in a hamilton dag @@ -333,7 +334,11 @@ def _generate_nodes( Returns: List[node.Node]: The resulting Hamilton nodes """ - compiled_variable_node = self.compile() + if compiled_node_override is None: + compiled_variable_node = self.compile() + else: + compiled_variable_node = self.compile() + node_functions = inspect.getmembers( compiled_variable_node, predicate=self._does_define_node ) From ea4b1c631602d9de0c0ced1c74f61230e164be9a Mon Sep 17 00:00:00 2001 From: Sholto 
Armstrong Date: Tue, 14 Jan 2025 11:59:12 +0200 Subject: [PATCH 2/2] black formatting --- spockflow/_serializable.py | 76 +++++--- spockflow/components/tree/__init__.py | 2 +- spockflow/components/tree/v1/compiled.py | 47 +++-- spockflow/components/tree/v1/core.py | 237 +++++++++++++---------- spockflow/nodes.py | 2 +- 5 files changed, 221 insertions(+), 143 deletions(-) diff --git a/spockflow/_serializable.py b/spockflow/_serializable.py index 395b879..5438be3 100644 --- a/spockflow/_serializable.py +++ b/spockflow/_serializable.py @@ -13,26 +13,44 @@ values_schema = core_schema.dict_schema(keys_schema=core_schema.str_schema()) -dataframe_json_schema = core_schema.model_fields_schema({ - "type": core_schema.model_field(core_schema.literal_schema(["DataFrame"])), - "values": core_schema.model_field(core_schema.union_schema([values_schema, core_schema.list_schema(values_schema)])), - "dtypes": core_schema.model_field( - core_schema.with_default_schema( - core_schema.union_schema([core_schema.none_schema(), core_schema.dict_schema(keys_schema=core_schema.str_schema(), values_schema=core_schema.str_schema())]), - default=None, - ) - ) -}) -series_json_schema = core_schema.model_fields_schema({ - "type": core_schema.model_field(core_schema.literal_schema(["Series"])), - "values": core_schema.model_field(core_schema.list_schema()), - "name": core_schema.model_field( - core_schema.with_default_schema( - core_schema.union_schema([core_schema.none_schema(), core_schema.str_schema()]), - default=None, - ) - ) -}) +dataframe_json_schema = core_schema.model_fields_schema( + { + "type": core_schema.model_field(core_schema.literal_schema(["DataFrame"])), + "values": core_schema.model_field( + core_schema.union_schema( + [values_schema, core_schema.list_schema(values_schema)] + ) + ), + "dtypes": core_schema.model_field( + core_schema.with_default_schema( + core_schema.union_schema( + [ + core_schema.none_schema(), + core_schema.dict_schema( + keys_schema=core_schema.str_schema(), + values_schema=core_schema.str_schema(), + ), + ] + ), + default=None, + ) + ), + } +) +series_json_schema = core_schema.model_fields_schema( + { + "type": core_schema.model_field(core_schema.literal_schema(["Series"])), + "values": core_schema.model_field(core_schema.list_schema()), + "name": core_schema.model_field( + core_schema.with_default_schema( + core_schema.union_schema( + [core_schema.none_schema(), core_schema.str_schema()] + ), + default=None, + ) + ), + } +) def validate_to_dataframe_schema(value: dict) -> pd.DataFrame: @@ -51,6 +69,7 @@ def validate_to_series_schema(value: dict) -> pd.Series: value, *_ = value return pd.Series(value["values"], name=value.get("name")) + from_df_dict_schema = core_schema.chain_schema( [ dataframe_json_schema, @@ -66,15 +85,22 @@ def validate_to_series_schema(value: dict) -> pd.Series: ] ) + def dump_df_to_dict(instance: pd.DataFrame) -> dict: values = instance.to_dict(orient="records") if len(values) == 1: values = values[0] - return {"type": "DataFrame", "values": values, "dtypes": {k: str(v) for k,v in instance.dtypes.items()}} + return { + "type": "DataFrame", + "values": values, + "dtypes": {k: str(v) for k, v in instance.dtypes.items()}, + } + def dump_series_to_dict(instance: pd.Series) -> dict: return {"type": "Series", "values": instance.to_list(), "name": instance.name} + class _PandasDataFramePydanticAnnotation: @classmethod def __get_pydantic_core_schema__( @@ -101,8 +127,9 @@ def __get_pydantic_core_schema__( def __get_pydantic_json_schema__( cls, _core_schema: 
core_schema.CoreSchema, handler: GetJsonSchemaHandler ) -> JsonSchemaValue: - return handler(dataframe_json_schema) - + return handler(dataframe_json_schema) + + class _PandasSeriesPydanticAnnotation: @classmethod def __get_pydantic_core_schema__( @@ -129,9 +156,8 @@ def __get_pydantic_core_schema__( def __get_pydantic_json_schema__( cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler ) -> JsonSchemaValue: - return handler(dataframe_json_schema) + return handler(dataframe_json_schema) DataFrame = Annotated[pd.DataFrame, _PandasDataFramePydanticAnnotation] Series = Annotated[pd.Series, _PandasSeriesPydanticAnnotation] - diff --git a/spockflow/components/tree/__init__.py b/spockflow/components/tree/__init__.py index 33e98c2..b36c7c3 100644 --- a/spockflow/components/tree/__init__.py +++ b/spockflow/components/tree/__init__.py @@ -2,7 +2,7 @@ import pandas as pd from spockflow.components.tree.v1.core import ( Tree as Tree, - TableCondition as TableCondition + TableCondition as TableCondition, ) T = TypeVar("T", bound=dict) diff --git a/spockflow/components/tree/v1/compiled.py b/spockflow/components/tree/v1/compiled.py index 5303990..7bfb893 100644 --- a/spockflow/components/tree/v1/compiled.py +++ b/spockflow/components/tree/v1/compiled.py @@ -6,6 +6,7 @@ from spockflow.components.tree.settings import settings from .core import ChildTree, Tree, TOutput, TCond, ConditionedNode, TableConditionedNode from spockflow.nodes import creates_node + if typing.TYPE_CHECKING: from spockflow.components.dtable import DecisionTable @@ -42,9 +43,10 @@ class SymbolicFlatTree: class CompiledNumpyTree: def __init__(self, tree: Tree) -> None: self.decision_tables: typing.Dict[str, "DecisionTable"] = {} - for k,v in tree.get_decision_tables().items(): - assert list(v.outputs.keys()) == ["value"], \ - f"Decision tables nested in decision trees can only have one output named \"value\"" + for k, v in tree.get_decision_tables().items(): + assert list(v.outputs.keys()) == [ + "value" + ], f'Decision tables nested in decision trees can only have one output named "value"' # Decision tables dont need compilation but keeping with convention self.decision_tables[k] = v.compile() @@ -58,7 +60,9 @@ def __init__(self, tree: Tree) -> None: ) flattened_tree = self._get_symbolic_tree(flattened_tree) self._flattened_tree = flattened_tree - self._flattened_priority = np.array([n.priority for n in self._flattened_tree.tree], dtype=np.int32) + self._flattened_priority = np.array( + [n.priority for n in self._flattened_tree.tree], dtype=np.int32 + ) self._has_priority = any(self._flattened_priority != 1) (predefined_conditions, predefined_condition_names, execution_conditions) = ( @@ -243,10 +247,12 @@ def get_unique_name(v): ) @staticmethod - def _merge_priority(p1: typing.Optional[int], p2: typing.Optional[int]) -> typing.Optional[int]: + def _merge_priority( + p1: typing.Optional[int], p2: typing.Optional[int] + ) -> typing.Optional[int]: if p1 is None and p2 is None: return None - return (p1 or 0) + (p2 or 0) # Also caps values to > 0 + return (p1 or 0) + (p2 or 0) # Also caps values to > 0 def _flatten_tree( self, @@ -254,7 +260,7 @@ def _flatten_tree( current_conditions: typing.Tuple[TCond], seen: typing.Set[int], conditioned_outputs: typing.List[ConditionedOutput], - priority: typing.Optional[int]=None, + priority: typing.Optional[int] = None, ) -> typing.List[ConditionedOutput]: curr_id = id(sub_tree) if curr_id in seen: @@ -280,17 +286,19 @@ def _flatten_tree( current_conditions=n_conditions, 
seen=seen.union([curr_id]), conditioned_outputs=conditioned_outputs, - priority=self._merge_priority(priority, n.priority) + priority=self._merge_priority(priority, n.priority), ) else: - conditioned_outputs.append(ConditionedOutput(n.value, n_conditions, max(priority or 0,1))) + conditioned_outputs.append( + ConditionedOutput(n.value, n_conditions, max(priority or 0, 1)) + ) else: if n.condition_table not in self.decision_tables: raise ValueError( f"Found node conditioned on a table ({n.condition_table}) that is not registered within the tree." ) n._check_compatible_table(self.decision_tables[n.condition_table]) - for i,v in enumerate(n.values): + for i, v in enumerate(n.values): if v is None: raise ValueError( "All nodes must have a value set to be a valid tree.\n" @@ -304,14 +312,18 @@ def _flatten_tree( current_conditions=n_conditions, seen=seen.union([curr_id]), conditioned_outputs=conditioned_outputs, - priority=self._merge_priority(priority, n_priority) + priority=self._merge_priority(priority, n_priority), ) else: - conditioned_outputs.append(ConditionedOutput(v, n_conditions, max(priority or 0,1))) - + conditioned_outputs.append( + ConditionedOutput(v, n_conditions, max(priority or 0, 1)) + ) + if sub_tree.default_value is not None: conditioned_outputs.append( - ConditionedOutput(sub_tree.default_value, current_conditions, max(priority or 0,1)) + ConditionedOutput( + sub_tree.default_value, current_conditions, max(priority or 0, 1) + ) ) return conditioned_outputs @@ -390,13 +402,14 @@ def conditions_met(self, format_inputs: TFormatData) -> np.ndarray: # [O,C]@[C,N] => [O,N] (Matrix multiplication should be the same as performing a count of all true statements) # The thresh will see where they are all true return (conditions @ self.truth_table) >= self.truth_table_thresh - + @creates_node() def prioritized_conditions(self, conditions_met: np.ndarray) -> np.ndarray: - if not self._has_priority: return conditions_met + if not self._has_priority: + return conditions_met # [O] * [O,N] -> [O,N] (outputs, batch) return self._flattened_priority * conditions_met - + @creates_node() def condition_names(self, format_inputs: TFormatData) -> typing.List[str]: conditions = format_inputs[0] diff --git a/spockflow/components/tree/v1/core.py b/spockflow/components/tree/v1/core.py index 72602df..1417bc1 100644 --- a/spockflow/components/tree/v1/core.py +++ b/spockflow/components/tree/v1/core.py @@ -7,34 +7,51 @@ from functools import partial from typing_extensions import Self from abc import ABC, abstractmethod -from pydantic import BaseModel, Field, model_validator, ConfigDict, field_serializer, PrivateAttr, AfterValidator, model_serializer +from pydantic import ( + BaseModel, + Field, + model_validator, + ConfigDict, + field_serializer, + PrivateAttr, + AfterValidator, + model_serializer, +) from spockflow.nodes import VariableNode from ...dtable import DecisionTable from hamilton.node import DependencyType -from spockflow._serializable import DataFrame, Series, dump_df_to_dict, dump_series_to_dict +from spockflow._serializable import ( + DataFrame, + Series, + dump_df_to_dict, + dump_series_to_dict, +) if typing.TYPE_CHECKING: from .compiled import CompiledNumpyTree - def _is_valid_function_name(name): - pattern = r'^[a-zA-Z_][a-zA-Z0-9_]*$' - assert re.match(pattern, name) and name not in keyword.kwlist, f"{name} must be a valid python function name" + pattern = r"^[a-zA-Z_][a-zA-Z0-9_]*$" + assert ( + re.match(pattern, name) and name not in keyword.kwlist + ), f"{name} must be a valid python 
function name" return name class TableCondition(BaseModel): - name: typing.Annotated[str, AfterValidator(_is_valid_function_name)] + name: typing.Annotated[str, AfterValidator(_is_valid_function_name)] table: DecisionTable + TOutput = typing.Union[typing.Callable[..., pd.DataFrame], DataFrame, str] TCond = typing.Union[typing.Callable[..., pd.Series], Series, str] TCondRaw = typing.Union[typing.Callable[..., pd.Series], pd.Series, TableCondition, str] _TABLE_VALUE_KEY = "value" + def _length_attr(attr): if attr is None: return 1 @@ -46,18 +63,21 @@ def _length_attr(attr): return 1 return len(attr) + def _serialize_value(value: typing.Union[TOutput, "ChildTree", None]): - if value is None: return value + if value is None: + return value if isinstance(value, typing.Callable): return value.__name__ if isinstance(value, pd.DataFrame): return dump_df_to_dict(value) - res = value.to_dict(orient='records') + res = value.to_dict(orient="records") if len(res) == 1: return res[0] return res return value + class TableConditionedNode(BaseModel): condition_type: typing.Literal["table"] = "table" # model_config = ConfigDict(arbitrary_types_allowed=True) @@ -65,48 +85,62 @@ class TableConditionedNode(BaseModel): condition_table: str priority: typing.Optional[typing.List[int]] = None - - @field_serializer('values') - def serialize_values(self, values: typing.List[typing.Union[TOutput, "ChildTree", None]], _info): + @field_serializer("values") + def serialize_values( + self, values: typing.List[typing.Union[TOutput, "ChildTree", None]], _info + ): return [_serialize_value(v) for v in values] - @model_validator(mode="after") def check_compatible_lengths(self) -> Self: self.ensure_length() return self - + def __len__(self): len_values = (_length_attr(v) for v in self.values) try: - return next(v for v in len_values if v !=1) + return next(v for v in len_values if v != 1) except StopIteration: return 1 - - def ensure_length(self, tree_length: int=1): + + def ensure_length(self, tree_length: int = 1): len_values = [_length_attr(v) for v in self.values] + [tree_length] try: - non_unit_value = next(v for v in len_values if v !=1) + non_unit_value = next(v for v in len_values if v != 1) except StopIteration: non_unit_value = 1 - if not all(v==1 or v==non_unit_value for v in len_values): + if not all(v == 1 or v == non_unit_value for v in len_values): raise ValueError("Incompatible value lengths detected") - def _check_compatible_table(self, table:DecisionTable): + def _check_compatible_table(self, table: DecisionTable): assert len(table.outputs) == 1, "Table must have exactly one output" - assert _TABLE_VALUE_KEY in table.outputs, f"Table must contain \"{_TABLE_VALUE_KEY}\" as an output key" + assert ( + _TABLE_VALUE_KEY in table.outputs + ), f'Table must contain "{_TABLE_VALUE_KEY}" as an output key' default_value = set() if table.default_value is not None: - assert len(table.default_value.keys()) == 1, "Table must have exactly one output in the default values" - assert _TABLE_VALUE_KEY in table.default_value, f"Table must contain \"{_TABLE_VALUE_KEY}\" as an output key in the default values" - assert len(table.default_value) == 1, f"Default value must be a dataframe of length 1." 
+ assert ( + len(table.default_value.keys()) == 1 + ), "Table must have exactly one output in the default values" + assert ( + _TABLE_VALUE_KEY in table.default_value + ), f'Table must contain "{_TABLE_VALUE_KEY}" as an output key in the default values' + assert ( + len(table.default_value) == 1 + ), f"Default value must be a dataframe of length 1." default_value = {table.default_value[_TABLE_VALUE_KEY].values[0]} table_output_values = set(table.outputs[_TABLE_VALUE_KEY]) | default_value last_value = len(table_output_values) - assert set(range(0,last_value)) == table_output_values, "Table output values must be sequential integer indicies" - assert last_value == len(self.values), "There must be one output value for each index in the tree outputs" + assert ( + set(range(0, last_value)) == table_output_values + ), "Table output values must be sequential integer indicies" + assert last_value == len( + self.values + ), "There must be one output value for each index in the tree outputs" if self.priority is not None: - assert last_value == len(self.priority), "There must be one priority item for each index in the tree outputs" + assert last_value == len( + self.priority + ), "There must be one priority item for each index in the tree outputs" class ConditionedNode(BaseModel): @@ -117,10 +151,10 @@ class ConditionedNode(BaseModel): condition: typing.Optional[TCond] = None priority: typing.Optional[int] = None - - @field_serializer('condition') + @field_serializer("condition") def serialize_condition(self, condition: typing.Optional[TCond], _info): - if condition is None: return condition + if condition is None: + return condition if isinstance(condition, typing.Callable): return condition.__name__ if isinstance(condition, pd.Series): @@ -128,12 +162,11 @@ def serialize_condition(self, condition: typing.Optional[TCond], _info): values = condition.tolist() return {condition.name: values if len(values) > 1 else values[0]} return condition - - @field_serializer('value') + + @field_serializer("value") def serialize_value(self, value: typing.Union[TOutput, "ChildTree", None], _info): return _serialize_value(value) - def __len__(self): len_attr = _length_attr(self.value) if len_attr == 1: @@ -144,24 +177,31 @@ def __len__(self): def check_compatible_lengths(self) -> Self: self.ensure_length() return self - - def ensure_length(self, tree_length: int=1): + + def ensure_length(self, tree_length: int = 1): len_value = _length_attr(self.value) len_condition = _length_attr(self.condition) - count_unit_length = (len_value==1) + (len_condition==1) + (tree_length==1) - if count_unit_length >= 2: return - if (len_value == len_condition == tree_length): return + count_unit_length = (len_value == 1) + (len_condition == 1) + (tree_length == 1) + if count_unit_length >= 2: + return + if len_value == len_condition == tree_length: + return raise ValueError("Condition and value and tree lengths are incompatible") +TConditionedNode = typing.Annotated[ + typing.Union[ConditionedNode, TableConditionedNode], + Field(discriminator="condition_type"), +] -TConditionedNode = typing.Annotated[typing.Union[ConditionedNode, TableConditionedNode], Field(discriminator='condition_type')] class ChildTree(BaseModel): # model_config = ConfigDict(arbitrary_types_allowed=True) nodes: typing.List[TConditionedNode] = Field(default_factory=list) default_value: typing.Optional[TOutput] = None - _decision_tables: typing.Dict[str,DecisionTable] = PrivateAttr(default_factory=dict) + _decision_tables: typing.Dict[str, DecisionTable] = PrivateAttr( 
+ default_factory=dict + ) def __len__( self, @@ -188,15 +228,20 @@ def check_compatible_lengths(self) -> Self: f"Lengths of values or conditions in the tree is incompatible. Found {child_tree_len} != {len_value}." ) return self - + @staticmethod - def _merge_decision_tables(to_be_updated: typing.Dict[str,DecisionTable], other: typing.Dict[str,DecisionTable]): - for k,v in other.items(): + def _merge_decision_tables( + to_be_updated: typing.Dict[str, DecisionTable], + other: typing.Dict[str, DecisionTable], + ): + for k, v in other.items(): if k in to_be_updated: - assert to_be_updated[k] is v, f"Decision table {k} added twice with different values." + assert ( + to_be_updated[k] is v + ), f"Decision table {k} added twice with different values." else: to_be_updated[k] = v - + # def _get_condition_from_table(self, condition: TCondRaw): # if not isinstance(condition, TableCondition): return condition # self._merge_decision_tables( @@ -205,33 +250,28 @@ def _merge_decision_tables(to_be_updated: typing.Dict[str,DecisionTable], other: # ) # return TableConditionReference(table=condition.name) - def add_node( - self, - value: typing.Union[TOutput, typing.List[TOutput]], - condition: TCondRaw, - priority: typing.Union[int, typing.List[int], None]=None, - **kwargs + self, + value: typing.Union[TOutput, typing.List[TOutput]], + condition: TCondRaw, + priority: typing.Union[int, typing.List[int], None] = None, + **kwargs, ) -> ConditionedNode: if isinstance(condition, TableCondition): node = TableConditionedNode( - values=value, + values=value, condition_table=condition.name, - priority=priority, - **kwargs + priority=priority, + **kwargs, ) self._merge_decision_tables( - self._decision_tables, - {condition.name: condition.table} + self._decision_tables, {condition.name: condition.table} ) node._check_compatible_table(condition.table) else: node = ConditionedNode( - value=value, - condition=condition, - priority=priority, - **kwargs + value=value, condition=condition, priority=priority, **kwargs ) node.ensure_length(len(self)) self.nodes.append(node) @@ -265,10 +305,7 @@ def merge_into(self, other: Self): f"Cannot merge two subtrees both containing default values" ) - self._merge_decision_tables( - self._decision_tables, - other._decision_tables - ) + self._merge_decision_tables(self._decision_tables, other._decision_tables) if other.default_value is not None: self.set_default(other.default_value) @@ -284,12 +321,10 @@ def get_all_decision_tables(self): node_values = [node.value] for nv in node_values: if isinstance(nv, ChildTree): - self._merge_decision_tables( - tables, - nv.get_all_decision_tables() - ) + self._merge_decision_tables(tables, nv.get_all_decision_tables()) return tables + class WrappedTreeFunction(ABC): @abstractmethod def __call__(self, *args: typing.Any, **kwds: typing.Any) -> typing.Any: ... 
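
Illustrative sketch (not part of the diffs; the variable names and the
row/column orientation are assumptions chosen for readability): the priority
mechanism added by prioritized_conditions reduces to a masked argmax. Every
flattened leaf carries an int32 priority defaulting to 1, matches are scaled
by that priority, and get_results then picks the winning leaf per row, so
ties still resolve in tree order:

    import numpy as np

    # One priority per flattened leaf; leaves without an explicit priority get 1.
    leaf_priority = np.array([1, 3, 2], dtype=np.int32)

    # conditions_met[row, leaf] is True when every condition on the path to
    # that leaf holds for that input row.
    conditions_met = np.array(
        [
            [True, True, False],   # row 0 matches leaves 0 and 1
            [True, False, True],   # row 1 matches leaves 0 and 2
            [True, False, False],  # row 2 matches only leaf 0
        ]
    )

    # Broadcasting turns each match into its priority; argmax then selects the
    # highest-priority matching leaf, ties falling back to column (tree) order.
    winning_leaf = np.argmax(leaf_priority * conditions_met, axis=1)
    print(winning_leaf)  # -> [1 2 0]
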
@@ -309,32 +344,32 @@ def _check_table_result_eq(value: int, **kwargs: pd.DataFrame) -> pd.Series: table_result = kwargs[next(iter(kwargs.keys()))] return table_result[_TABLE_VALUE_KEY] == value + class Tree(VariableNode): doc: str = "This executes a user defined decision tree" root: ChildTree = Field(default_factory=ChildTree) - decision_tables: typing.Dict[str,DecisionTable] = Field(default_factory=dict) + decision_tables: typing.Dict[str, DecisionTable] = Field(default_factory=dict) def get_decision_tables(self): decision_tables = self.decision_tables.copy() self.root._merge_decision_tables( - decision_tables, - self.root.get_all_decision_tables() + decision_tables, self.root.get_all_decision_tables() ) return decision_tables @model_serializer() def serialize_model(self): return { - 'doc': self.doc, - 'root': self.root, - "decision_tables": self.get_decision_tables() + "doc": self.doc, + "root": self.root, + "decision_tables": self.get_decision_tables(), } def compile(self): from .compiled import CompiledNumpyTree return CompiledNumpyTree(self) - + def _generate_nodes( self, name: str, @@ -342,33 +377,35 @@ def _generate_nodes( include_runtime_nodes: bool = False, ) -> "typing.List[node.Node]": from hamilton import node + compiled_node = self.compile() output_nodes = super()._generate_nodes( - name = name, - config = config, - include_runtime_nodes = include_runtime_nodes, - compiled_node_override=compiled_node - ) - base_node_ = node.Node.from_fn( - _check_table_result_eq, name="temporary_node" + name=name, + config=config, + include_runtime_nodes=include_runtime_nodes, + compiled_node_override=compiled_node, ) + base_node_ = node.Node.from_fn(_check_table_result_eq, name="temporary_node") for table_name, compiled_table in compiled_node.decision_tables.items(): - output_nodes.extend(compiled_table._generate_nodes(table_name, config, True)) + output_nodes.extend( + compiled_table._generate_nodes(table_name, config, True) + ) unique_values = set(compiled_table.outputs[_TABLE_VALUE_KEY]) for v in unique_values: - v = int(v) # Just to be safe - - output_nodes.append(base_node_.copy_with( - name=f"{table_name}_is_{v}", - doc_string=f"This is a function used to determine if the result of {table_name} is {v} to be used in a decision tree.", - callabl=functools.partial( - _check_table_result_eq, - value = v - ), - input_types={table_name: (pd.DataFrame, DependencyType.REQUIRED)}, - include_refs=False, - )) - + v = int(v) # Just to be safe + + output_nodes.append( + base_node_.copy_with( + name=f"{table_name}_is_{v}", + doc_string=f"This is a function used to determine if the result of {table_name} is {v} to be used in a decision tree.", + callabl=functools.partial(_check_table_result_eq, value=v), + input_types={ + table_name: (pd.DataFrame, DependencyType.REQUIRED) + }, + include_refs=False, + ) + ) + return output_nodes def _generate_runtime_nodes( @@ -440,7 +477,7 @@ def _identify_loops(self, *nodes: "ConditionedNode"): q.extend(v.nodes) if isinstance(v.default_value, ChildTree): q.extend(v.default_value.nodes) - + elif isinstance(el.value, ChildTree): q.extend(el.value.nodes) if isinstance(el.value.default_value, ChildTree): @@ -515,7 +552,9 @@ def wrapper(condition: TCond): nonlocal output if isinstance(output, Tree): output = output.root - node = child_tree.add_node(value=output, condition=condition, priority=priority, **kwargs) + node = child_tree.add_node( + value=output, condition=condition, priority=priority, **kwargs + ) try: self._identify_loops(node) except ValueError as e: @@ 
-684,8 +723,8 @@ def visualize(self, get_value_name=None, get_condition_name=None): for node in curr.nodes: node_condition_name = get_condition_name( - node.condition_table - if isinstance(node, TableConditionedNode) + node.condition_table + if isinstance(node, TableConditionedNode) else node.condition ) dot.node(node_condition_name, node_condition_name) diff --git a/spockflow/nodes.py b/spockflow/nodes.py index 3eef337..42d6113 100644 --- a/spockflow/nodes.py +++ b/spockflow/nodes.py @@ -322,7 +322,7 @@ def _generate_nodes( name: str, config: "typing.Dict[str, typing.Any]", include_runtime_nodes: bool = False, - compiled_node_override: typing.Optional[Self] = None + compiled_node_override: typing.Optional[Self] = None, ) -> "typing.List[node.Node]": """Generate nodes for this class to be used in a hamilton dag
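
A usage sketch of the API this series adds (illustrative only: age_band_table
is a hypothetical DecisionTable, and the decorator form of Tree.condition is
inferred from the wrapper(condition) closure shown in core.py above):

    import pandas as pd
    from spockflow.components.tree import Tree, TableCondition

    tree = Tree()

    # Plain conditioned leaf: priority=2 outranks the implicit priority of 1,
    # so this leaf wins any row where it and a priority-1 leaf both match.
    @tree.condition(output=pd.DataFrame({"offer": ["gold"]}), priority=2)
    def is_existing_client(client_type: pd.Series) -> pd.Series:
        return client_type == "existing"

    # Table-conditioned node: the referenced table must expose exactly one
    # output named "value" holding sequential integer indices 0..N-1, and a
    # row for which the table yields i receives values[i]. One Hamilton node
    # per index is generated, named "age_band_is_<i>".
    tree.root.add_node(
        value=[
            pd.DataFrame({"offer": ["bronze"]}),  # rows where age_band yields 0
            pd.DataFrame({"offer": ["silver"]}),  # rows where age_band yields 1
        ],
        condition=TableCondition(name="age_band", table=age_band_table),  # hypothetical table
        priority=[1, 1],
    )

    tree.root.set_default(pd.DataFrame({"offer": ["none"]}))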