From 684f423ce3235451ed288e658ea9dc2328d01787 Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Sun, 19 Jan 2025 18:32:22 -0500 Subject: [PATCH] Fix serialization, make custom serializers in pydantic work with csp.Structs Signed-off-by: Nijat Khanbabayev --- csp/impl/enum.py | 5 +- csp/impl/struct.py | 11 ++- csp/tests/impl/test_enum.py | 87 +++++++++++++++++++++ csp/tests/impl/test_struct.py | 137 +++++++++++++++++++++++++++++++++- 4 files changed, 232 insertions(+), 8 deletions(-) diff --git a/csp/impl/enum.py b/csp/impl/enum.py index 14d2f21ec..8eb25399e 100644 --- a/csp/impl/enum.py +++ b/csp/impl/enum.py @@ -75,8 +75,9 @@ def _validate(cls, v) -> "Enum": raise ValueError(f"Cannot convert value to enum: {v}") @staticmethod - def _serialize(value: "Enum") -> str: - return value.name + def _serialize(value: typing.Union[str, "Enum"]) -> str: + # We return value if it is str in case it has already been serialized + return value if isinstance(value, str) else value.name @classmethod def __get_pydantic_json_schema__(cls, _core_schema, handler): diff --git a/csp/impl/struct.py b/csp/impl/struct.py index 812b34d5b..5a52dcbd7 100644 --- a/csp/impl/struct.py +++ b/csp/impl/struct.py @@ -86,7 +86,7 @@ def layout(self, num_cols=8): @staticmethod def _get_pydantic_core_schema(cls, _source_type, handler): - """Tell Pydantic how to validate this Struct class.""" + """Tell Pydantic how to validate and serialize this Struct class.""" from pydantic import PydanticSchemaGenerationError from pydantic_core import core_schema @@ -131,12 +131,15 @@ def create_instance(validated_data): data_dict = validated_data[0] if isinstance(validated_data, tuple) else validated_data return cls(**data_dict) + def ser_func(val, handler): + new_val = val.to_dict() if isinstance(val, csp.Struct) else val + return handler(new_val) + return core_schema.no_info_after_validator_function( function=create_instance, schema=schema, - serialization=core_schema.plain_serializer_function_ser_schema( - function=lambda x: x.to_dict(), # Use the built-in to_dict method - return_schema=core_schema.dict_schema(), + serialization=core_schema.wrap_serializer_function_ser_schema( + function=ser_func, schema=fields_schema, when_used="always" ), ) diff --git a/csp/tests/impl/test_enum.py b/csp/tests/impl/test_enum.py index 427cf2365..87908ca51 100644 --- a/csp/tests/impl/test_enum.py +++ b/csp/tests/impl/test_enum.py @@ -1,6 +1,10 @@ import _pickle +import json +import pytest import unittest from datetime import datetime, timedelta +from pydantic import BaseModel, ConfigDict, RootModel +from typing import Dict, List import csp from csp import ts @@ -36,6 +40,22 @@ def s1(): MyDEnum = csp.DynamicEnum("MyDEnum", ["A", "B", "C"]) +class MyEnum3(csp.Enum): + FIELD1 = csp.Enum.auto() + FIELD2 = csp.Enum.auto() + + +class MyModel(BaseModel): + enum: MyEnum3 + enum_default: MyEnum3 = MyEnum3.FIELD1 + + +class MyDictModel(BaseModel): + model_config = ConfigDict(use_enum_values=True) + + enum_dict: Dict[MyEnum3, int] = None + + class TestCspEnum(unittest.TestCase): def test_basic(self): self.assertEqual(MyEnum("A"), MyEnum.A) @@ -152,6 +172,73 @@ class B(A): self.assertEqual("Cannot extend csp.Enum 'A': inheriting from an Enum is prohibited", str(cm.exception)) + def test_pydantic_validation(self): + assert MyModel(enum="FIELD2").enum == MyEnum3.FIELD2 + assert MyModel(enum=0).enum == MyEnum3.FIELD1 + assert MyModel(enum=MyEnum3.FIELD1).enum == MyEnum3.FIELD1 + with pytest.raises(ValueError): + MyModel(enum=3.14) + + def test_pydantic_dict(self): + assert dict(MyModel(enum=MyEnum3.FIELD2)) == {"enum": MyEnum3.FIELD2, "enum_default": MyEnum3.FIELD1} + assert MyModel(enum=MyEnum3.FIELD2).model_dump(mode="python") == { + "enum": MyEnum3.FIELD2, + "enum_default": MyEnum3.FIELD1, + } + assert MyModel(enum=MyEnum3.FIELD2).model_dump(mode="json") == {"enum": "FIELD2", "enum_default": "FIELD1"} + + def test_pydantic_serialization(self): + assert "enum" in MyModel.model_fields + assert "enum_default" in MyModel.model_fields + tm = MyModel(enum=MyEnum3.FIELD2) + assert json.loads(tm.model_dump_json()) == json.loads('{"enum": "FIELD2", "enum_default": "FIELD1"}') + + def test_enum_as_dict_key_json_serialization(self): + class DictWrapper(RootModel[Dict[MyEnum3, int]]): + model_config = ConfigDict(use_enum_values=True) + + def __getitem__(self, item): + return self.root[item] + + class MyDictWrapperModel(BaseModel): + model_config = ConfigDict(use_enum_values=True) + + enum_dict: DictWrapper + + dict_model = MyDictModel(enum_dict={MyEnum3.FIELD1: 8, MyEnum3.FIELD2: 19}) + assert dict_model.enum_dict[MyEnum3.FIELD1] == 8 + assert dict_model.enum_dict[MyEnum3.FIELD2] == 19 + + assert json.loads(dict_model.model_dump_json()) == json.loads('{"enum_dict":{"FIELD1":8,"FIELD2":19}}') + + dict_wrapper_model = MyDictWrapperModel(enum_dict=DictWrapper({MyEnum3.FIELD1: 8, MyEnum3.FIELD2: 19})) + + assert dict_wrapper_model.enum_dict[MyEnum3.FIELD1] == 8 + assert dict_wrapper_model.enum_dict[MyEnum3.FIELD2] == 19 + assert json.loads(dict_wrapper_model.model_dump_json()) == json.loads('{"enum_dict":{"FIELD1":8,"FIELD2":19}}') + + def test_json_schema_csp(self): + assert MyModel.model_json_schema() == { + "properties": { + "enum": { + "description": "An enumeration of MyEnum3", + "enum": ["FIELD1", "FIELD2"], + "title": "MyEnum3", + "type": "string", + }, + "enum_default": { + "default": "FIELD1", + "description": "An enumeration of MyEnum3", + "enum": ["FIELD1", "FIELD2"], + "title": "MyEnum3", + "type": "string", + }, + }, + "required": ["enum"], + "title": "MyModel", + "type": "object", + } + if __name__ == "__main__": unittest.main() diff --git a/csp/tests/impl/test_struct.py b/csp/tests/impl/test_struct.py index 65152486c..b91cb470d 100644 --- a/csp/tests/impl/test_struct.py +++ b/csp/tests/impl/test_struct.py @@ -3567,10 +3567,143 @@ class DataPoint(csp.Struct): self.assertIsInstance(result.history[2], BaseMetric) # Should be base # Test serialization and deserialization preserves specific types - json_data = result.to_json() - restored = TypeAdapter(DataPoint).validate_json(json_data) + json_data_csp = result.to_json() + json_data_pydantic = TypeAdapter(DataPoint).dump_json(result).decode() + self.assertEqual(json_data_csp, json_data_pydantic) + restored = TypeAdapter(DataPoint).validate_json(json_data_csp) self.assertEqual(restored, result) + def test_pydantic_custom_serialization(self): + """Test that CustomStruct correctly serializes integers with comma formatting""" + from pydantic.functional_serializers import PlainSerializer + + # Define the custom integer type with fancy formatting + FancyInt = Annotated[int, PlainSerializer(lambda x: f"{x:,}", return_type=str, when_used="always")] + + # Simple struct with just the FancyInt + class CustomStruct(csp.Struct): + value: FancyInt + + # Test different integer values + test_cases = [ + (1234, "1,234"), + (1000000, "1,000,000"), + (42, "42"), + ] + + for input_value, expected_output in test_cases: + # Create and serialize the struct + s = CustomStruct(value=input_value) + serialized = json.loads(TypeAdapter(CustomStruct).dump_json(s)) + + # Verify the serialization + self.assertEqual( + serialized["value"], + expected_output, + ) + + def test_pydantic_serialization_with_enums(self): + """Test serialization behavior with enums using both native and Pydantic approaches""" + + class Color(csp.Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + + class Shape(csp.Enum): + CIRCLE = 1 + SQUARE = 2 + TRIANGLE = 3 + + class DrawingStruct(csp.Struct): + color: Color + shape: Shape + colors: List[Color] + shapes: Dict[str, Shape] + + drawing = DrawingStruct( + color=Color.RED, + shape=Shape.CIRCLE, + colors=[Color.RED, Color.GREEN, Color.BLUE], + shapes={"a": Shape.SQUARE, "b": Shape.TRIANGLE}, + ) + + # Test native serialization + native_json = json.loads(drawing.to_json()) + self.assertEqual(native_json["color"], "RED") + self.assertEqual(native_json["shape"], "CIRCLE") + self.assertEqual(native_json["colors"], ["RED", "GREEN", "BLUE"]) + self.assertEqual(native_json["shapes"], {"a": "SQUARE", "b": "TRIANGLE"}) + + # Test Pydantic serialization + pydantic_json = json.loads(TypeAdapter(DrawingStruct).dump_json(drawing)) + self.assertEqual(pydantic_json, native_json) # Should be identical for enums + + # Test round-trip through both methods + native_restored = DrawingStruct.from_dict(json.loads(drawing.to_json())) + pydantic_restored = TypeAdapter(DrawingStruct).validate_json(TypeAdapter(DrawingStruct).dump_json(drawing)) + + self.assertEqual(native_restored, drawing) + self.assertEqual(pydantic_restored, drawing) + + def test_pydantic_serialization_vs_native(self): + """Test that Pydantic serialization matches CSP native serialization for basic types""" + from pydantic.functional_serializers import PlainSerializer + + class MyEnum(csp.Enum): + OPTION1 = csp.Enum.auto() + OPTION2 = csp.Enum.auto() + + # Define custom datetime serialization + # This is so that pydantic serializes datetime with the same precision as csp natively does + SimpleDatetime = Annotated[ + datetime, + PlainSerializer(lambda dt: dt.strftime("%Y-%m-%dT%H:%M:%S.%f+00:00"), return_type=str, when_used="always"), + ] + + class SimpleStruct(csp.Struct): + i: int = 123 + f: float = 3.14 + s: str = "test" + b: bool = True + # dt: datetime = datetime(2023, 1, 1) + dt: SimpleDatetime = datetime(2023, 1, 1) + l: List[int] = [1, 2, 3] + d: Dict[str, float] = {"a": 1.1, "b": 2.2} + e: MyEnum = MyEnum.OPTION1 + + # Test with default values + s1 = SimpleStruct() + json_native = s1.to_json() + json_pydantic = TypeAdapter(SimpleStruct).dump_json(s1).decode() + self.assertEqual(json_native, json_pydantic) + + # Test with custom values + s2 = SimpleStruct( + i=456, + f=2.718, + s="custom", + b=False, + dt=datetime(2024, 1, 1, tzinfo=pytz.UTC), + l=[4, 5, 6], + d={"x": 9.9, "y": 8.8}, + ) + json_native = s2.to_json() + json_pydantic = TypeAdapter(SimpleStruct).dump_json(s2).decode() + self.assertEqual(json.loads(json_native), json.loads(json_pydantic)) + + # Test with nested structs + class NestedStruct(csp.Struct): + name: str + simple: SimpleStruct + simples: List[SimpleStruct] + + nested = NestedStruct(name="test", simple=s1, simples=[s1, s2]) + + json_native = nested.to_json() + json_pydantic = TypeAdapter(NestedStruct).dump_json(nested).decode() + self.assertEqual(json.loads(json_native), json.loads(json_pydantic)) + if __name__ == "__main__": unittest.main()