Skip to content

Commit

Permalink
Merge pull request #16 from capitec/feature/tree-priority
Browse files Browse the repository at this point in the history
Add node priority to decision trees
  • Loading branch information
sjnarmstrong authored Jan 14, 2025
2 parents 0bc66cf + ea4b1c6 commit 37a7d68
Show file tree
Hide file tree
Showing 6 changed files with 557 additions and 145 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
170 changes: 120 additions & 50 deletions spockflow/_serializable.py
Original file line number Diff line number Diff line change
@@ -1,93 +1,163 @@
import pandas as pd
from typing import Any, Union
from typing import Any

from pydantic_core import core_schema
from typing_extensions import Annotated

from pydantic import (
BaseModel,
GetCoreSchemaHandler,
GetJsonSchemaHandler,
ValidationError,
)
from pydantic.json_schema import JsonSchemaValue


def dump_to_dict(instance: Union[pd.Series, pd.DataFrame]) -> dict:
if isinstance(instance, pd.DataFrame):
return {"type": "DataFrame", "values": instance.to_dict(orient="list")}
return {"values": instance.to_list(), "name": instance.name, "type": "Series"}
values_schema = core_schema.dict_schema(keys_schema=core_schema.str_schema())

dataframe_json_schema = core_schema.model_fields_schema(
{
"type": core_schema.model_field(core_schema.literal_schema(["DataFrame"])),
"values": core_schema.model_field(
core_schema.union_schema(
[values_schema, core_schema.list_schema(values_schema)]
)
),
"dtypes": core_schema.model_field(
core_schema.with_default_schema(
core_schema.union_schema(
[
core_schema.none_schema(),
core_schema.dict_schema(
keys_schema=core_schema.str_schema(),
values_schema=core_schema.str_schema(),
),
]
),
default=None,
)
),
}
)
series_json_schema = core_schema.model_fields_schema(
{
"type": core_schema.model_field(core_schema.literal_schema(["Series"])),
"values": core_schema.model_field(core_schema.list_schema()),
"name": core_schema.model_field(
core_schema.with_default_schema(
core_schema.union_schema(
[core_schema.none_schema(), core_schema.str_schema()]
),
default=None,
)
),
}
)


def validate_to_dataframe_schema(value: dict) -> pd.DataFrame:
value, *_ = value
values = value["values"]
if isinstance(values, dict):
values = [values]
df = pd.DataFrame(values)
dtypes = value.get("dtypes")
if dtypes:
return df.astype(dtypes)
return df


def validate_to_series_schema(value: dict) -> pd.Series:
value, *_ = value
return pd.Series(value["values"], name=value.get("name"))


from_df_dict_schema = core_schema.chain_schema(
[
dataframe_json_schema,
# core_schema.dict_schema(),
core_schema.no_info_plain_validator_function(validate_to_dataframe_schema),
]
)
from_series_dict_schema = core_schema.chain_schema(
[
series_json_schema,
# core_schema.dict_schema(),
core_schema.no_info_plain_validator_function(validate_to_series_schema),
]
)

# class Series(pd.Series):
# @classmethod
# def __get_pydantic_core_schema__(
# cls, __source: type[Any], __handler: GetCoreSchemaHandler
# ) -> core_schema.CoreSchema:
# return core_schema.no_info_before_validator_function(
# pd.Series,
# core_schema.dict_schema(),
# serialization=core_schema.plain_serializer_function_ser_schema(dump_to_dict)
# )

# class DataFrame(pd.DataFrame):
# @classmethod
# def __get_pydantic_core_schema__(
# cls, __source: type[Any], __handler: GetCoreSchemaHandler
# ) -> core_schema.CoreSchema:
# return core_schema.no_info_before_validator_function(
# pd.DataFrame,
# core_schema.dict_schema(),
# serialization=core_schema.plain_serializer_function_ser_schema(dump_to_dict)
# )
def dump_df_to_dict(instance: pd.DataFrame) -> dict:
values = instance.to_dict(orient="records")
if len(values) == 1:
values = values[0]
return {
"type": "DataFrame",
"values": values,
"dtypes": {k: str(v) for k, v in instance.dtypes.items()},
}

# This class allows Pydantic to serialise and deserialise pandas Dataframes and Series items

def dump_series_to_dict(instance: pd.Series) -> dict:
return {"type": "Series", "values": instance.to_list(), "name": instance.name}

class _PandasPydanticAnnotation:

class _PandasDataFramePydanticAnnotation:
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
) -> core_schema.CoreSchema:
def validate_from_dict(value: dict) -> pd.Series:
data_type = value.get("type")
if data_type is None:
return pd.DataFrame(value)
if value.get("type") == "DataFrame":
return pd.DataFrame(value["values"])
return pd.Series(value["values"], name=value["name"])

from_int_schema = core_schema.chain_schema(
[
core_schema.dict_schema(), # TODO make this more comprehensive
core_schema.no_info_plain_validator_function(validate_from_dict),
]
)

return core_schema.json_or_python_schema(
json_schema=from_int_schema,
json_schema=from_df_dict_schema,
python_schema=core_schema.union_schema(
[
# check if it's an instance first before doing any further work
core_schema.is_instance_schema(pd.Series),
from_int_schema,
core_schema.is_instance_schema(pd.DataFrame),
from_df_dict_schema,
]
),
serialization=core_schema.plain_serializer_function_ser_schema(
dump_to_dict
dump_df_to_dict
),
)

@classmethod
def __get_pydantic_json_schema__(
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
return handler(core_schema.dict_schema()) # TODO make this more comprehensive
return handler(dataframe_json_schema)


class _PandasSeriesPydanticAnnotation:
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
) -> core_schema.CoreSchema:

return core_schema.json_or_python_schema(
json_schema=from_series_dict_schema,
python_schema=core_schema.union_schema(
[
# check if it's an instance first before doing any further work
core_schema.is_instance_schema(pd.Series),
from_series_dict_schema,
]
),
serialization=core_schema.plain_serializer_function_ser_schema(
dump_series_to_dict
),
)

DataFrame = Annotated[pd.DataFrame, _PandasPydanticAnnotation]
@classmethod
def __get_pydantic_json_schema__(
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
return handler(dataframe_json_schema)


Series = Annotated[pd.Series, _PandasPydanticAnnotation]
DataFrame = Annotated[pd.DataFrame, _PandasDataFramePydanticAnnotation]
Series = Annotated[pd.Series, _PandasSeriesPydanticAnnotation]
5 changes: 4 additions & 1 deletion spockflow/components/tree/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import TypeVar, Generic
import pandas as pd
from spockflow.components.tree.v1.core import Tree as Tree
from spockflow.components.tree.v1.core import (
Tree as Tree,
TableCondition as TableCondition,
)

T = TypeVar("T", bound=dict)

Expand Down
Loading

0 comments on commit 37a7d68

Please sign in to comment.