diff --git a/docs/_static/getting-started/example_pipeline_light.drawio.svg b/docs/_static/getting-started/example_pipeline_light.drawio.svg
index 0cf3440..84feb97 100644
--- a/docs/_static/getting-started/example_pipeline_light.drawio.svg
+++ b/docs/_static/getting-started/example_pipeline_light.drawio.svg
@@ -1,4 +1,4 @@
-
+
@@ -53,13 +53,13 @@
- Criminal Record?
+ Client's Age?
- Criminal Record?
+ Client's Age?
diff --git a/spockflow/_serializable.py b/spockflow/_serializable.py
index a857c34..5438be3 100644
--- a/spockflow/_serializable.py
+++ b/spockflow/_serializable.py
@@ -1,82 +1,125 @@
 import pandas as pd
-from typing import Any, Union
+from typing import Any
 from pydantic_core import core_schema
 from typing_extensions import Annotated
 from pydantic import (
-    BaseModel,
     GetCoreSchemaHandler,
     GetJsonSchemaHandler,
-    ValidationError,
 )
 from pydantic.json_schema import JsonSchemaValue
 
 
-def dump_to_dict(instance: Union[pd.Series, pd.DataFrame]) -> dict:
-    if isinstance(instance, pd.DataFrame):
-        return {"type": "DataFrame", "values": instance.to_dict(orient="list")}
-    return {"values": instance.to_list(), "name": instance.name, "type": "Series"}
+values_schema = core_schema.dict_schema(keys_schema=core_schema.str_schema())
+
+dataframe_json_schema = core_schema.model_fields_schema(
+    {
+        "type": core_schema.model_field(core_schema.literal_schema(["DataFrame"])),
+        "values": core_schema.model_field(
+            core_schema.union_schema(
+                [values_schema, core_schema.list_schema(values_schema)]
+            )
+        ),
+        "dtypes": core_schema.model_field(
+            core_schema.with_default_schema(
+                core_schema.union_schema(
+                    [
+                        core_schema.none_schema(),
+                        core_schema.dict_schema(
+                            keys_schema=core_schema.str_schema(),
+                            values_schema=core_schema.str_schema(),
+                        ),
+                    ]
+                ),
+                default=None,
+            )
+        ),
+    }
+)
+series_json_schema = core_schema.model_fields_schema(
+    {
+        "type": core_schema.model_field(core_schema.literal_schema(["Series"])),
+        "values": core_schema.model_field(core_schema.list_schema()),
+        "name": core_schema.model_field(
+            core_schema.with_default_schema(
+                core_schema.union_schema(
+                    [core_schema.none_schema(), core_schema.str_schema()]
+                ),
+                default=None,
+            )
+        ),
+    }
+)
+
+
+def validate_to_dataframe_schema(value: dict) -> pd.DataFrame:
+    value, *_ = value
+    values = value["values"]
+    if isinstance(values, dict):
+        values = [values]
+    df = pd.DataFrame(values)
+    dtypes = value.get("dtypes")
+    if dtypes:
+        return df.astype(dtypes)
+    return df
+
+
+def validate_to_series_schema(value: dict) -> pd.Series:
+    value, *_ = value
+    return pd.Series(value["values"], name=value.get("name"))
+
+
+from_df_dict_schema = core_schema.chain_schema(
+    [
+        dataframe_json_schema,
+        core_schema.no_info_plain_validator_function(validate_to_dataframe_schema),
+    ]
+)
+from_series_dict_schema = core_schema.chain_schema(
+    [
+        series_json_schema,
+        core_schema.no_info_plain_validator_function(validate_to_series_schema),
+    ]
+)
 
-# class Series(pd.Series):
-#     @classmethod
-#     def __get_pydantic_core_schema__(
-#         cls, __source: type[Any], __handler: GetCoreSchemaHandler
-#     ) -> core_schema.CoreSchema:
-#         return core_schema.no_info_before_validator_function(
-#             pd.Series,
-#             core_schema.dict_schema(),
-#             serialization=core_schema.plain_serializer_function_ser_schema(dump_to_dict)
-#         )
 
-# class DataFrame(pd.DataFrame):
-#     @classmethod
-#     def __get_pydantic_core_schema__(
-#         cls, __source: type[Any], __handler: GetCoreSchemaHandler
-#     ) -> core_schema.CoreSchema:
-#         return core_schema.no_info_before_validator_function(
-#             pd.DataFrame,
-#             core_schema.dict_schema(),
-#             serialization=core_schema.plain_serializer_function_ser_schema(dump_to_dict)
-#         )
+def dump_df_to_dict(instance: pd.DataFrame) -> dict:
+    values = instance.to_dict(orient="records")
+    if len(values) == 1:
+        values = values[0]
+    return {
+        "type": "DataFrame",
+        "values": values,
+        "dtypes": {k: str(v) for k, v in instance.dtypes.items()},
+    }
 
-# This class allows Pydantic to serialise and deserialise pandas Dataframes and Series items
 
+def dump_series_to_dict(instance: pd.Series) -> dict:
+    return {"type": "Series", "values": instance.to_list(), "name": instance.name}
 
-class _PandasPydanticAnnotation:
+
+class _PandasDataFramePydanticAnnotation:
     @classmethod
     def __get_pydantic_core_schema__(
         cls,
         _source_type: Any,
         _handler: GetCoreSchemaHandler,
     ) -> core_schema.CoreSchema:
-        def validate_from_dict(value: dict) -> pd.Series:
-            data_type = value.get("type")
-            if data_type is None:
-                return pd.DataFrame(value)
-            if value.get("type") == "DataFrame":
-                return pd.DataFrame(value["values"])
-            return pd.Series(value["values"], name=value["name"])
-
-        from_int_schema = core_schema.chain_schema(
-            [
-                core_schema.dict_schema(),  # TODO make this more comprehensive
-                core_schema.no_info_plain_validator_function(validate_from_dict),
-            ]
-        )
-
         return core_schema.json_or_python_schema(
-            json_schema=from_int_schema,
+            json_schema=from_df_dict_schema,
             python_schema=core_schema.union_schema(
                 [
                     # check if it's an instance first before doing any further work
-                    core_schema.is_instance_schema(pd.Series),
-                    from_int_schema,
+                    core_schema.is_instance_schema(pd.DataFrame),
+                    from_df_dict_schema,
                 ]
             ),
             serialization=core_schema.plain_serializer_function_ser_schema(
-                dump_to_dict
+                dump_df_to_dict
             ),
         )
 
@@ -84,10 +127,37 @@
     def __get_pydantic_json_schema__(
         cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
     ) -> JsonSchemaValue:
-        return handler(core_schema.dict_schema())  # TODO make this more comprehensive
+        return handler(dataframe_json_schema)
+
+
+class _PandasSeriesPydanticAnnotation:
+    @classmethod
+    def __get_pydantic_core_schema__(
+        cls,
+        _source_type: Any,
+        _handler: GetCoreSchemaHandler,
+    ) -> core_schema.CoreSchema:
+        return core_schema.json_or_python_schema(
+            json_schema=from_series_dict_schema,
+            python_schema=core_schema.union_schema(
+                [
+                    # check if it's an instance first before doing any further work
+                    core_schema.is_instance_schema(pd.Series),
+                    from_series_dict_schema,
+                ]
+            ),
+            serialization=core_schema.plain_serializer_function_ser_schema(
+                dump_series_to_dict
+            ),
+        )
 
+    @classmethod
+    def __get_pydantic_json_schema__(
+        cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
+    ) -> JsonSchemaValue:
+        return handler(series_json_schema)
 
-DataFrame = Annotated[pd.DataFrame, _PandasPydanticAnnotation]
 
-Series = Annotated[pd.Series, _PandasPydanticAnnotation]
+DataFrame = Annotated[pd.DataFrame, _PandasDataFramePydanticAnnotation]
+Series = Annotated[pd.Series, _PandasSeriesPydanticAnnotation]
diff --git a/spockflow/components/tree/__init__.py b/spockflow/components/tree/__init__.py
index 31e8147..b36c7c3 100644
--- a/spockflow/components/tree/__init__.py
+++ b/spockflow/components/tree/__init__.py
@@ -1,6 +1,9 @@
 from typing import TypeVar, Generic
 import pandas as pd
-from spockflow.components.tree.v1.core import Tree as Tree
+from spockflow.components.tree.v1.core import (
+    Tree as Tree,
+    TableCondition as TableCondition,
+)
 
 T = TypeVar("T", bound=dict)
diff --git a/spockflow/components/tree/v1/compiled.py b/spockflow/components/tree/v1/compiled.py
index 2c46b1d..7bfb893 100644
--- a/spockflow/components/tree/v1/compiled.py
+++ b/spockflow/components/tree/v1/compiled.py
@@ -4,9 +4,12 @@
 from pydantic import BaseModel
 from dataclasses import dataclass
 from spockflow.components.tree.settings import settings
-from .core import ChildTree, Tree, TOutput, TCond
+from .core import ChildTree, Tree, TOutput, TCond, ConditionedNode, TableConditionedNode
 from spockflow.nodes import creates_node
 
+if typing.TYPE_CHECKING:
+    from spockflow.components.dtable import DecisionTable
+
 
 if typing.TYPE_CHECKING:
@@ -17,12 +20,14 @@
 class ConditionedOutput:
     output: TOutput
     conditions: typing.Tuple[TCond]
+    priority: int
 
 
 @dataclass
 class SymbolicConditionedOutput:
     output: str
     conditions: typing.Tuple[str]
+    priority: int
 
 
 @dataclass
@@ -37,6 +42,14 @@ class CompiledNumpyTree:
     def __init__(self, tree: Tree) -> None:
+        self.decision_tables: typing.Dict[str, "DecisionTable"] = {}
+        for k, v in tree.get_decision_tables().items():
+            assert list(v.outputs.keys()) == [
+                "value"
+            ], 'Decision tables nested in decision trees can only have one output named "value"'
+            # Decision tables don't need compilation; compile() is called to keep with convention
+            self.decision_tables[k] = v.compile()
+
         self.length = len(tree.root)
 
         flattened_tree = self._flatten_tree(
@@ -47,6 +60,10 @@ def __init__(self, tree: Tree) -> None:
         )
         flattened_tree = self._get_symbolic_tree(flattened_tree)
         self._flattened_tree = flattened_tree
+        self._flattened_priority = np.array(
+            [n.priority for n in self._flattened_tree.tree], dtype=np.int32
+        )
+        self._has_priority = any(self._flattened_priority != 1)
 
         (predefined_conditions, predefined_condition_names, execution_conditions) = (
             self._split_predefined(
@@ -222,51 +239,91 @@ def get_unique_name(v):
             name = get_unique_name(n.output)
             outputs[name] = n.output
             new_tree_nodes.append(
-                SymbolicConditionedOutput(name, tuple(new_conditions))
+                SymbolicConditionedOutput(name, tuple(new_conditions), n.priority)
             )
         return SymbolicFlatTree(
             outputs=outputs, conditions=conditions, tree=new_tree_nodes
         )
 
-    @classmethod
+    @staticmethod
+    def _merge_priority(
+        p1: typing.Optional[int], p2: typing.Optional[int]
+    ) -> typing.Optional[int]:
+        if p1 is None and p2 is None:
+            return None
+        return (p1 or 0) + (p2 or 0)  # a None priority is treated as 0
+
     def _flatten_tree(
-        cls,
+        self,
         sub_tree: ChildTree,
         current_conditions: typing.Tuple[TCond],
         seen: typing.Set[int],
         conditioned_outputs: typing.List[ConditionedOutput],
+        priority: typing.Optional[int] = None,
     ) -> typing.List[ConditionedOutput]:
         curr_id = id(sub_tree)
         if curr_id in seen:
             raise ValueError("Current tree contains loops. Cannot compile tree.")
         for n in sub_tree.nodes:
-            if n.value is None:
-                raise ValueError(
-                    "All nodes must have a value set to be a valid tree.\n"
-                    "Found a leaf with no value."
-                )
-            if n.condition is None:
-                raise ValueError(
-                    "All nodes must have a condition set to be a valid tree\n"
-                    "Found a leaf with no condition.\n"
-                    "If this is intended to be a default value please use set_default."
-                )
+            if isinstance(n, ConditionedNode):
+                if n.value is None:
+                    raise ValueError(
+                        "All nodes must have a value set to be a valid tree.\n"
+                        "Found a leaf with no value."
+                    )
+                if n.condition is None:
+                    raise ValueError(
+                        "All nodes must have a condition set to be a valid tree\n"
+                        "Found a leaf with no condition.\n"
+                        "If this is intended to be a default value please use set_default."
+                    )
 
-            n_conditions = current_conditions + (n.condition,)
-            if isinstance(n.value, ChildTree):
-                cls._flatten_tree(
-                    n.value,
-                    current_conditions=n_conditions,
-                    seen=seen.union([curr_id]),
-                    conditioned_outputs=conditioned_outputs,
-                )
+                n_conditions = current_conditions + (n.condition,)
+                if isinstance(n.value, ChildTree):
+                    self._flatten_tree(
+                        n.value,
+                        current_conditions=n_conditions,
+                        seen=seen.union([curr_id]),
+                        conditioned_outputs=conditioned_outputs,
+                        priority=self._merge_priority(priority, n.priority),
+                    )
+                else:
+                    conditioned_outputs.append(
+                        ConditionedOutput(
+                            n.value,
+                            n_conditions,
+                            max(self._merge_priority(priority, n.priority) or 0, 1),
+                        )
+                    )
             else:
-                conditioned_outputs.append(ConditionedOutput(n.value, n_conditions))
+                if n.condition_table not in self.decision_tables:
+                    raise ValueError(
+                        f"Found node conditioned on a table ({n.condition_table}) that is not registered within the tree."
+                    )
+                n._check_compatible_table(self.decision_tables[n.condition_table])
+                for i, v in enumerate(n.values):
+                    if v is None:
+                        raise ValueError(
+                            "All nodes must have a value set to be a valid tree.\n"
+                            "Found a leaf with no value."
+                        )
+                    n_priority = None if n.priority is None else n.priority[i]
+                    n_conditions = current_conditions + (f"{n.condition_table}_is_{i}",)
+                    if isinstance(v, ChildTree):
+                        self._flatten_tree(
+                            v,
+                            current_conditions=n_conditions,
+                            seen=seen.union([curr_id]),
+                            conditioned_outputs=conditioned_outputs,
+                            priority=self._merge_priority(priority, n_priority),
+                        )
+                    else:
+                        conditioned_outputs.append(
+                            ConditionedOutput(
+                                v,
+                                n_conditions,
+                                max(self._merge_priority(priority, n_priority) or 0, 1),
+                            )
+                        )
 
         if sub_tree.default_value is not None:
             conditioned_outputs.append(
-                ConditionedOutput(sub_tree.default_value, current_conditions)
+                ConditionedOutput(
+                    sub_tree.default_value, current_conditions, max(priority or 0, 1)
+                )
             )
         return conditioned_outputs
 
@@ -346,6 +403,13 @@ def conditions_met(self, format_inputs: TFormatData) -> np.ndarray:
         # The thresh will see where they are all true
         return (conditions @ self.truth_table) >= self.truth_table_thresh
 
+    @creates_node()
+    def prioritized_conditions(self, conditions_met: np.ndarray) -> np.ndarray:
+        if not self._has_priority:
+            return conditions_met
+        # [O] * [N,O] -> [N,O] (N batch rows, O flattened outputs)
+        return self._flattened_priority * conditions_met
+
     @creates_node()
     def condition_names(self, format_inputs: TFormatData) -> typing.List[str]:
         conditions = format_inputs[0]
@@ -389,10 +453,10 @@ def all(
     def get_results(
         self,
         format_inputs: TFormatData,
-        conditions_met: np.ndarray,
+        prioritized_conditions: np.ndarray,
     ) -> pd.DataFrame:
         _, outputs, *_ = format_inputs
-        condition_output_idx = np.argmax(conditions_met, axis=1)
+        condition_output_idx = np.argmax(prioritized_conditions, axis=1)
         return outputs.iloc[
             self._lookup_output_idx(format_inputs, condition_output_idx)
         ].reset_index(drop=True)
diff --git a/spockflow/components/tree/v1/core.py b/spockflow/components/tree/v1/core.py
index 7338cd1..1417bc1 100644
--- a/spockflow/components/tree/v1/core.py
+++ b/spockflow/components/tree/v1/core.py
@@ -1,61 +1,212 @@
 import typing
+import re
+import keyword
+import functools
 import pandas as pd
 import collections.abc
 from functools import partial
 from typing_extensions import Self
 from abc import ABC, abstractmethod
-from pydantic import BaseModel, Field, model_validator, ConfigDict
+from pydantic import (
+    BaseModel,
+    Field,
+    model_validator,
+    ConfigDict,
+    field_serializer,
+    PrivateAttr,
+    AfterValidator,
+    model_serializer,
+)
 from spockflow.nodes import VariableNode
+from ...dtable import DecisionTable
+from hamilton.node import DependencyType
+from spockflow._serializable import (
+    DataFrame,
+    Series,
+    dump_df_to_dict,
+    dump_series_to_dict,
+)
 
-TOutput = typing.Union[typing.Callable[..., pd.DataFrame], pd.DataFrame, str]
-TCond = typing.Union[typing.Callable[..., pd.Series], pd.Series, str]
+if typing.TYPE_CHECKING:
+    from .compiled import CompiledNumpyTree
+
+
+def _is_valid_function_name(name):
+    pattern = r"^[a-zA-Z_][a-zA-Z0-9_]*$"
+    assert (
+        re.match(pattern, name) and name not in keyword.kwlist
+    ), f"{name} must be a valid Python function name"
+    return name
+
+
+class TableCondition(BaseModel):
+    name: typing.Annotated[str, AfterValidator(_is_valid_function_name)]
+    table: DecisionTable
+
+
+TOutput = typing.Union[typing.Callable[..., pd.DataFrame], DataFrame, str]
+TCond = typing.Union[typing.Callable[..., pd.Series], Series, str]
+TCondRaw = typing.Union[typing.Callable[..., pd.Series], pd.Series, TableCondition, str]
+
+_TABLE_VALUE_KEY = "value"
+
+
+def _length_attr(attr):
+    if attr is None:
+        return 1
+    if isinstance(attr, str):
+        return 1
+    if isinstance(attr, TableCondition):
+        return 1
+    if not isinstance(attr, collections.abc.Sized):
+        return 1
+    return len(attr)
+
+
+def _serialize_value(value: typing.Union[TOutput, "ChildTree", None]):
+    if value is None:
+        return value
+    if isinstance(value, typing.Callable):
+        return value.__name__
+    if isinstance(value, pd.DataFrame):
+        return dump_df_to_dict(value)
+    return value
+
+
+class TableConditionedNode(BaseModel):
+    condition_type: typing.Literal["table"] = "table"
+    values: typing.List[typing.Union[TOutput, "ChildTree", None]]
+    condition_table: str
+    priority: typing.Optional[typing.List[int]] = None
+
+    @field_serializer("values")
+    def serialize_values(
+        self, values: typing.List[typing.Union[TOutput, "ChildTree", None]], _info
+    ):
+        return [_serialize_value(v) for v in values]
+
+    @model_validator(mode="after")
+    def check_compatible_lengths(self) -> Self:
+        self.ensure_length()
+        return self
+
+    def __len__(self):
+        len_values = (_length_attr(v) for v in self.values)
+        try:
+            return next(v for v in len_values if v != 1)
+        except StopIteration:
+            return 1
+
+    def ensure_length(self, tree_length: int = 1):
+        len_values = [_length_attr(v) for v in self.values] + [tree_length]
+        try:
+            non_unit_value = next(v for v in len_values if v != 1)
+        except StopIteration:
+            non_unit_value = 1
+        if not all(v == 1 or v == non_unit_value for v in len_values):
+            raise ValueError("Incompatible value lengths detected")
+
+    def _check_compatible_table(self, table: DecisionTable):
+        assert len(table.outputs) == 1, "Table must have exactly one output"
+        assert (
+            _TABLE_VALUE_KEY in table.outputs
+        ), f'Table must contain "{_TABLE_VALUE_KEY}" as an output key'
+        default_value = set()
+        if table.default_value is not None:
+            assert (
+                len(table.default_value.keys()) == 1
+            ), "Table must have exactly one output in the default values"
+            assert (
+                _TABLE_VALUE_KEY in table.default_value
+            ), f'Table must contain "{_TABLE_VALUE_KEY}" as an output key in the default values'
+            assert (
+                len(table.default_value) == 1
+            ), "Default value must be a dataframe of length 1."
+            default_value = {table.default_value[_TABLE_VALUE_KEY].values[0]}
+        table_output_values = set(table.outputs[_TABLE_VALUE_KEY]) | default_value
+        last_value = len(table_output_values)
+        assert (
+            set(range(0, last_value)) == table_output_values
+        ), "Table output values must be sequential integer indices"
+        assert last_value == len(
+            self.values
+        ), "There must be one output value for each index in the tree outputs"
+        if self.priority is not None:
+            assert last_value == len(
+                self.priority
+            ), "There must be one priority item for each index in the tree outputs"
 
 
 class ConditionedNode(BaseModel):
     # TODO fix for pd.DataFrame
-    model_config = ConfigDict(arbitrary_types_allowed=True)
+    condition_type: typing.Literal["base"] = "base"
     value: typing.Union[TOutput, "ChildTree", None] = None
     condition: typing.Optional[TCond] = None
+    priority: typing.Optional[int] = None
 
-    @staticmethod
-    def _length_attr(attr):
-        if attr is None:
-            return 1
-        if isinstance(attr, str):
-            return 1
-        if not isinstance(attr, collections.abc.Sized):
-            return 1
-        return len(attr)
+    @field_serializer("condition")
+    def serialize_condition(self, condition: typing.Optional[TCond], _info):
+        if condition is None:
+            return condition
+        if isinstance(condition, typing.Callable):
+            return condition.__name__
+        if isinstance(condition, pd.Series):
+            return dump_series_to_dict(condition)
+        return condition
+
+    @field_serializer("value")
+    def serialize_value(self, value: typing.Union[TOutput, "ChildTree", None], _info):
+        return _serialize_value(value)
 
     def __len__(self):
-        len_attr = self._length_attr(self.value)
+        len_attr = _length_attr(self.value)
         if len_attr == 1:
-            len_attr = self._length_attr(self.condition)
+            len_attr = _length_attr(self.condition)
         return len_attr
 
     @model_validator(mode="after")
     def check_compatible_lengths(self) -> Self:
-        len_value = self._length_attr(self.value)
-        if len_value == 1:
-            return self
-        len_condition = self._length_attr(self.condition)
-        if len_condition == 1:
-            return self
-        if len_condition != len_value:
-            raise ValueError("Condition and value lengths incompatible")
+        self.ensure_length()
         return self
 
+    def ensure_length(self, tree_length: int = 1):
+        len_value = _length_attr(self.value)
+        len_condition = _length_attr(self.condition)
+        non_unit_lengths = {
+            l for l in (len_value, len_condition, tree_length) if l != 1
+        }
+        if len(non_unit_lengths) > 1:
+            raise ValueError("Condition, value, and tree lengths are incompatible")
+
+
+TConditionedNode = typing.Annotated[
+    typing.Union[ConditionedNode, TableConditionedNode],
+    Field(discriminator="condition_type"),
+]
+
 
 class ChildTree(BaseModel):
-    model_config = ConfigDict(arbitrary_types_allowed=True)
-    nodes: typing.List[ConditionedNode] = Field(default_factory=list)
+    nodes: typing.List[TConditionedNode] = Field(default_factory=list)
     default_value: typing.Optional[TOutput] = None
+    _decision_tables: typing.Dict[str, DecisionTable] = PrivateAttr(
+        default_factory=dict
+    )
 
     def __len__(
         self,
     ):  # TODO could maybe cache this on bigger trees if the values havent changed
-        node_len = ConditionedNode._length_attr(self.default_value)
+        node_len = _length_attr(self.default_value)
        if node_len != 1:
            return node_len
        for node in self.nodes:
@@ -65,7 +216,7 @@ def __len__(
     @model_validator(mode="after")
     def check_compatible_lengths(self) -> Self:
-        child_tree_len = ConditionedNode._length_attr(self.default_value)
+        child_tree_len = _length_attr(self.default_value)
         for node in self.nodes:
             len_value = len(node)
             if len_value == 1:
@@ -78,33 +229,58 @@
             )
         return self
 
-    def add_node(self, value: TOutput, condition: TCond, **kwargs) -> ConditionedNode:
-        len_value = ConditionedNode._length_attr(value)
-        len_condition = ConditionedNode._length_attr(condition)
-        if len_value != 1 and len_condition != 1 and len_value != len_condition:
-            raise ValueError(
-                f"Cannot add node as the length of the value ({len_value}) is not compatible with the length of the condition ({len_condition})."
+    @staticmethod
+    def _merge_decision_tables(
+        to_be_updated: typing.Dict[str, DecisionTable],
+        other: typing.Dict[str, DecisionTable],
+    ):
+        for k, v in other.items():
+            if k in to_be_updated:
+                assert (
+                    to_be_updated[k] is v
+                ), f"Decision table {k} added twice with different values."
+            else:
+                to_be_updated[k] = v
+
+    def add_node(
+        self,
+        value: typing.Union[TOutput, typing.List[TOutput]],
+        condition: TCondRaw,
+        priority: typing.Union[int, typing.List[int], None] = None,
+        **kwargs,
+    ) -> ConditionedNode:
+
+        if isinstance(condition, TableCondition):
+            node = TableConditionedNode(
+                values=value,
+                condition_table=condition.name,
+                priority=priority,
+                **kwargs,
             )
-        if len_value != 1 or len_condition != 1:
-            # TODO adding this allows better validation but requires circular loops so hard for pydantic to serialise
-            # len_tree = len(self.root_tree.root)
-            len_tree = len(self)
-            if len_value != 1 and len_tree != 1 and len_value != len_tree:
-                raise ValueError(
-                    f"Cannot add node as the length of the value ({len_value}) incompatible with tree {len_tree}."
-                )
-            if len_condition != 1 and len_tree != 1 and len_condition != len_tree:
-                raise ValueError(
-                    f"Cannot add node as the length of the condition ({len_condition}) incompatible with tree {len_tree}."
-                )
-        node = ConditionedNode(value=value, condition=condition, **kwargs)
+            self._merge_decision_tables(
+                self._decision_tables, {condition.name: condition.table}
+            )
+            node._check_compatible_table(condition.table)
+        else:
+            node = ConditionedNode(
+                value=value, condition=condition, priority=priority, **kwargs
+            )
+        node.ensure_length(len(self))
         self.nodes.append(node)
         return node
 
     def set_default(self, value: TOutput):
         if self.default_value is not None:
             raise ValueError("Default value already set")
-        len_value = ConditionedNode._length_attr(value)
+        len_value = _length_attr(value)
         if len_value != 1:
             # TODO adding this allows better validation but requires circular loops so hard for pydantic to serialise
             # len_tree = len(self.root_tree.root)
@@ -128,11 +304,26 @@ def merge_into(self, other: Self):
             raise ValueError(
                 f"Cannot merge two subtrees both containing default values"
             )
+
+        self._merge_decision_tables(self._decision_tables, other._decision_tables)
+
         if other.default_value is not None:
             self.set_default(other.default_value)
         self.nodes.extend(other.nodes)
 
+    def get_all_decision_tables(self):
+        tables = self._decision_tables.copy()
+        for node in self.nodes:
+            if isinstance(node, TableConditionedNode):
+                node_values = node.values
+            else:
+                node_values = [node.value]
+            for nv in node_values:
+                if isinstance(nv, ChildTree):
+                    self._merge_decision_tables(tables, nv.get_all_decision_tables())
+        return tables
+
 
 class WrappedTreeFunction(ABC):
     @abstractmethod
@@ -149,18 +340,78 @@ def include_subtree(
     ): ...
 
 
+def _check_table_result_eq(value: int, **kwargs: pd.DataFrame) -> pd.Series:
+    table_result = kwargs[next(iter(kwargs.keys()))]
+    return table_result[_TABLE_VALUE_KEY] == value
+
+
 class Tree(VariableNode):
     doc: str = "This executes a user defined decision tree"
     root: ChildTree = Field(default_factory=ChildTree)
+    decision_tables: typing.Dict[str, DecisionTable] = Field(default_factory=dict)
+
+    def get_decision_tables(self):
+        decision_tables = self.decision_tables.copy()
+        self.root._merge_decision_tables(
+            decision_tables, self.root.get_all_decision_tables()
+        )
+        return decision_tables
+
+    @model_serializer()
+    def serialize_model(self):
+        return {
+            "doc": self.doc,
+            "root": self.root,
+            "decision_tables": self.get_decision_tables(),
+        }
 
     def compile(self):
         from .compiled import CompiledNumpyTree
 
         return CompiledNumpyTree(self)
 
+    def _generate_nodes(
+        self,
+        name: str,
+        config: "typing.Dict[str, typing.Any]",
+        include_runtime_nodes: bool = False,
+    ) -> "typing.List[node.Node]":
+        from hamilton import node
+
+        compiled_node = self.compile()
+        output_nodes = super()._generate_nodes(
+            name=name,
+            config=config,
+            include_runtime_nodes=include_runtime_nodes,
+            compiled_node_override=compiled_node,
+        )
+        base_node_ = node.Node.from_fn(_check_table_result_eq, name="temporary_node")
+        for table_name, compiled_table in compiled_node.decision_tables.items():
+            output_nodes.extend(
+                compiled_table._generate_nodes(table_name, config, True)
+            )
+            unique_values = set(compiled_table.outputs[_TABLE_VALUE_KEY])
+            for v in unique_values:
+                v = int(v)  # Just to be safe
+
+                output_nodes.append(
+                    base_node_.copy_with(
+                        name=f"{table_name}_is_{v}",
+                        doc_string=f"Determines whether the result of {table_name} is {v}; used as a condition in a decision tree.",
+                        callabl=functools.partial(_check_table_result_eq, value=v),
+                        input_types={
+                            table_name: (pd.DataFrame, DependencyType.REQUIRED)
+                        },
+                        include_refs=False,
+                    )
+                )
+
+        return output_nodes
+
     def _generate_runtime_nodes(
         self,
config: "typing.Dict[str, typing.Any]", compiled_node: "CompiledNumpyTree" ) -> "typing.List[node.Node]": + # This is used to extract the condition functions when running outside of hamilton context from hamilton import node return [ @@ -220,7 +471,14 @@ def _identify_loops(self, *nodes: "ConditionedNode"): if id(el) in seen: raise ValueError("Tree must not contain any loops") seen.add(id(el)) - if isinstance(el.value, ChildTree): + if isinstance(el, TableConditionedNode): + for v in el.values: + if isinstance(v, ChildTree): + q.extend(v.nodes) + if isinstance(v.default_value, ChildTree): + q.extend(v.default_value.nodes) + + elif isinstance(el.value, ChildTree): q.extend(el.value.nodes) if isinstance(el.value.default_value, ChildTree): q.extend(el.value.default_value.nodes) @@ -247,8 +505,9 @@ def copy(self, deep=True): def condition( self, output: TOutput = None, - condition: typing.Union[TCond, None] = None, + condition: typing.Optional[TCondRaw] = None, child_tree: ChildTree = None, + priority: typing.Optional[int] = None, **kwargs, ) -> typing.Callable[..., WrappedTreeFunction]: """ @@ -293,7 +552,9 @@ def wrapper(condition: TCond): nonlocal output if isinstance(output, Tree): output = output.root - node = child_tree.add_node(value=output, condition=condition, **kwargs) + node = child_tree.add_node( + value=output, condition=condition, priority=priority, **kwargs + ) try: self._identify_loops(node) except ValueError as e: @@ -461,21 +722,30 @@ def visualize(self, get_value_name=None, get_condition_name=None): dot.node(curr_name, curr_name) for node in curr.nodes: - node_condition_name = get_condition_name(node.condition) + node_condition_name = get_condition_name( + node.condition_table + if isinstance(node, TableConditionedNode) + else node.condition + ) dot.node(node_condition_name, node_condition_name) dot.edge(curr_name, node_condition_name) - if hasattr(node.value, "nodes"): - to_search.extend([(node.value, node_condition_name)]) - elif node.value is not None: - node_value_name = get_value_name(node.value) - dot.node( - node_value_name, - node_value_name, - style="filled", - fillcolor="#ADDFFF", - shape="rectangle", - ) - dot.edge(node_condition_name, node_value_name) + if isinstance(node, TableConditionedNode): + node_values = node.values + else: + node_values = [node.value] + for val in node_values: + if hasattr(val, "nodes"): + to_search.extend([(val, node_condition_name)]) + elif val is not None: + node_value_name = get_value_name(val) + dot.node( + node_value_name, + node_value_name, + style="filled", + fillcolor="#ADDFFF", + shape="rectangle", + ) + dot.edge(node_condition_name, node_value_name) if curr.default_value is not None: default_name = get_value_name(curr.default_value) diff --git a/spockflow/nodes.py b/spockflow/nodes.py index 75d921b..42d6113 100644 --- a/spockflow/nodes.py +++ b/spockflow/nodes.py @@ -322,6 +322,7 @@ def _generate_nodes( name: str, config: "typing.Dict[str, typing.Any]", include_runtime_nodes: bool = False, + compiled_node_override: typing.Optional[Self] = None, ) -> "typing.List[node.Node]": """Generate nodes for this class to be used in a hamilton dag @@ -333,7 +334,11 @@ def _generate_nodes( Returns: List[node.Node]: The resulting Hamilton nodes """ - compiled_variable_node = self.compile() + if compiled_node_override is None: + compiled_variable_node = self.compile() + else: + compiled_variable_node = self.compile() + node_functions = inspect.getmembers( compiled_variable_node, predicate=self._does_define_node )