Skip to content

Commit

Permalink
Fix serialization, make custom serializers in pydantic work with csp.Structs
Browse files Browse the repository at this point in the history

Signed-off-by: Nijat Khanbabayev <[email protected]>
  • Loading branch information
NeejWeej committed Jan 19, 2025
1 parent c252851 commit 684f423
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 8 deletions.
5 changes: 3 additions & 2 deletions csp/impl/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,9 @@ def _validate(cls, v) -> "Enum":
raise ValueError(f"Cannot convert value to enum: {v}")

@staticmethod
def _serialize(value: "Enum") -> str:
return value.name
def _serialize(value: typing.Union[str, "Enum"]) -> str:
# We return value if it is str in case it has already been serialized
return value if isinstance(value, str) else value.name

@classmethod
def __get_pydantic_json_schema__(cls, _core_schema, handler):
Expand Down
11 changes: 7 additions & 4 deletions csp/impl/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def layout(self, num_cols=8):

@staticmethod
def _get_pydantic_core_schema(cls, _source_type, handler):
"""Tell Pydantic how to validate this Struct class."""
"""Tell Pydantic how to validate and serialize this Struct class."""
from pydantic import PydanticSchemaGenerationError
from pydantic_core import core_schema

Expand Down Expand Up @@ -131,12 +131,15 @@ def create_instance(validated_data):
data_dict = validated_data[0] if isinstance(validated_data, tuple) else validated_data
return cls(**data_dict)

def ser_func(val, handler):
    """Wrap-serializer hook: give ``handler`` a plain dict for Struct inputs.

    A ``csp.Struct`` instance is converted with ``to_dict()`` before being
    passed on; any other value is forwarded to ``handler`` unchanged.
    """
    if isinstance(val, csp.Struct):
        return handler(val.to_dict())
    return handler(val)

return core_schema.no_info_after_validator_function(
function=create_instance,
schema=schema,
serialization=core_schema.plain_serializer_function_ser_schema(
function=lambda x: x.to_dict(), # Use the built-in to_dict method
return_schema=core_schema.dict_schema(),
serialization=core_schema.wrap_serializer_function_ser_schema(
function=ser_func, schema=fields_schema, when_used="always"
),
)

Expand Down
87 changes: 87 additions & 0 deletions csp/tests/impl/test_enum.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import _pickle
import json
import pytest
import unittest
from datetime import datetime, timedelta
from pydantic import BaseModel, ConfigDict, RootModel
from typing import Dict, List

import csp
from csp import ts
Expand Down Expand Up @@ -36,6 +40,22 @@ def s1():
# Enum created at runtime through csp.DynamicEnum; presumably "MyDEnum" is the enum's name and
# "A", "B", "C" are its member names — confirm against the DynamicEnum signature.
MyDEnum = csp.DynamicEnum("MyDEnum", ["A", "B", "C"])


# Two-member csp.Enum used by the pydantic model fixtures and tests in this module;
# both members take their values from csp.Enum.auto().
class MyEnum3(csp.Enum):
FIELD1 = csp.Enum.auto()
FIELD2 = csp.Enum.auto()


# Pydantic model fixture: `enum` is a required MyEnum3 field, while `enum_default`
# falls back to MyEnum3.FIELD1 when it is not supplied.
class MyModel(BaseModel):
enum: MyEnum3
enum_default: MyEnum3 = MyEnum3.FIELD1


# Pydantic model fixture holding a dict keyed by MyEnum3, configured with use_enum_values=True.
# NOTE(review): `enum_dict` defaults to None although its annotation is not Optional —
# confirm this is intended.
class MyDictModel(BaseModel):
model_config = ConfigDict(use_enum_values=True)

enum_dict: Dict[MyEnum3, int] = None


class TestCspEnum(unittest.TestCase):
def test_basic(self):
self.assertEqual(MyEnum("A"), MyEnum.A)
Expand Down Expand Up @@ -152,6 +172,73 @@ class B(A):

self.assertEqual("Cannot extend csp.Enum 'A': inheriting from an Enum is prohibited", str(cm.exception))

def test_pydantic_validation(self):
    """Check which inputs MyModel accepts for a MyEnum3 field.

    A member name, an integer value and a member instance must each validate
    to the matching MyEnum3 member; a float must be rejected.

    Uses TestCase assertion methods (rather than bare ``assert`` and
    ``pytest.raises``) so the checks still run under ``python -O`` and match
    the style of the other tests in this class.
    """
    self.assertEqual(MyModel(enum="FIELD2").enum, MyEnum3.FIELD2)
    self.assertEqual(MyModel(enum=0).enum, MyEnum3.FIELD1)
    self.assertEqual(MyModel(enum=MyEnum3.FIELD1).enum, MyEnum3.FIELD1)
    with self.assertRaises(ValueError):
        MyModel(enum=3.14)

def test_pydantic_dict(self):
    """Check the dict views of a MyModel instance.

    ``dict(model)`` and ``model_dump(mode="python")`` must keep MyEnum3
    members as values, while ``model_dump(mode="json")`` must emit the
    member names.

    Uses TestCase assertion methods so the checks still run under
    ``python -O`` and match the style of the other tests in this class.
    """
    members = {"enum": MyEnum3.FIELD2, "enum_default": MyEnum3.FIELD1}
    names = {"enum": "FIELD2", "enum_default": "FIELD1"}

    self.assertEqual(dict(MyModel(enum=MyEnum3.FIELD2)), members)
    self.assertEqual(MyModel(enum=MyEnum3.FIELD2).model_dump(mode="python"), members)
    self.assertEqual(MyModel(enum=MyEnum3.FIELD2).model_dump(mode="json"), names)

def test_pydantic_serialization(self):
    """Check that MyModel exposes both fields and dumps them to JSON by member name.

    Uses TestCase assertion methods so the checks still run under
    ``python -O`` and match the style of the other tests in this class.
    """
    self.assertIn("enum", MyModel.model_fields)
    self.assertIn("enum_default", MyModel.model_fields)

    tm = MyModel(enum=MyEnum3.FIELD2)
    self.assertEqual(
        json.loads(tm.model_dump_json()),
        json.loads('{"enum": "FIELD2", "enum_default": "FIELD1"}'),
    )

def test_enum_as_dict_key_json_serialization(self):
    """Check JSON dumping of dicts keyed by MyEnum3 members.

    Covers a plain ``Dict[MyEnum3, int]`` field (MyDictModel) and the same
    mapping wrapped in a RootModel subclass. In both cases lookups by enum
    member must work and the JSON dump must be keyed by member name.

    Uses TestCase assertion methods so the checks still run under
    ``python -O`` and match the style of the other tests in this class.
    """

    class DictWrapper(RootModel[Dict[MyEnum3, int]]):
        model_config = ConfigDict(use_enum_values=True)

        # Lets the test index the wrapper directly instead of going through `.root`.
        def __getitem__(self, item):
            return self.root[item]

    class MyDictWrapperModel(BaseModel):
        model_config = ConfigDict(use_enum_values=True)

        enum_dict: DictWrapper

    expected_json = json.loads('{"enum_dict":{"FIELD1":8,"FIELD2":19}}')

    dict_model = MyDictModel(enum_dict={MyEnum3.FIELD1: 8, MyEnum3.FIELD2: 19})
    self.assertEqual(dict_model.enum_dict[MyEnum3.FIELD1], 8)
    self.assertEqual(dict_model.enum_dict[MyEnum3.FIELD2], 19)
    self.assertEqual(json.loads(dict_model.model_dump_json()), expected_json)

    dict_wrapper_model = MyDictWrapperModel(enum_dict=DictWrapper({MyEnum3.FIELD1: 8, MyEnum3.FIELD2: 19}))
    self.assertEqual(dict_wrapper_model.enum_dict[MyEnum3.FIELD1], 8)
    self.assertEqual(dict_wrapper_model.enum_dict[MyEnum3.FIELD2], 19)
    self.assertEqual(json.loads(dict_wrapper_model.model_dump_json()), expected_json)

def test_json_schema_csp(self):
    """Check the exact JSON schema generated for MyModel.

    Both MyEnum3 fields must appear as string enumerations of the member
    names, and only ``enum`` may be listed as required.

    Uses ``assertEqual`` so the check still runs under ``python -O`` and
    matches the style of the other tests in this class.
    """
    expected_schema = {
        "properties": {
            "enum": {
                "description": "An enumeration of MyEnum3",
                "enum": ["FIELD1", "FIELD2"],
                "title": "MyEnum3",
                "type": "string",
            },
            "enum_default": {
                "default": "FIELD1",
                "description": "An enumeration of MyEnum3",
                "enum": ["FIELD1", "FIELD2"],
                "title": "MyEnum3",
                "type": "string",
            },
        },
        "required": ["enum"],
        "title": "MyModel",
        "type": "object",
    }
    self.assertEqual(MyModel.model_json_schema(), expected_schema)


# Run the tests through unittest when this module is executed as a script.
if __name__ == "__main__":
unittest.main()
137 changes: 135 additions & 2 deletions csp/tests/impl/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -3567,10 +3567,143 @@ class DataPoint(csp.Struct):
self.assertIsInstance(result.history[2], BaseMetric) # Should be base

# Test serialization and deserialization preserves specific types
json_data = result.to_json()
restored = TypeAdapter(DataPoint).validate_json(json_data)
json_data_csp = result.to_json()
json_data_pydantic = TypeAdapter(DataPoint).dump_json(result).decode()
self.assertEqual(json_data_csp, json_data_pydantic)
restored = TypeAdapter(DataPoint).validate_json(json_data_csp)
self.assertEqual(restored, result)

def test_pydantic_custom_serialization(self):
    """A PlainSerializer attached to a Struct field must be honoured when dumping to JSON."""
    from pydantic.functional_serializers import PlainSerializer

    def with_thousands_separator(number):
        return f"{number:,}"

    # Integer type whose serialized form is a comma-grouped string.
    FancyInt = Annotated[int, PlainSerializer(with_thousands_separator, return_type=str, when_used="always")]

    # Minimal struct carrying a single FancyInt field.
    class CustomStruct(csp.Struct):
        value: FancyInt

    adapter = TypeAdapter(CustomStruct)
    expected_by_input = {
        1234: "1,234",
        1000000: "1,000,000",
        42: "42",
    }

    for raw, formatted in expected_by_input.items():
        # Build the struct, dump it through pydantic and read the JSON back.
        dumped = json.loads(adapter.dump_json(CustomStruct(value=raw)))
        self.assertEqual(dumped["value"], formatted)

def test_pydantic_serialization_with_enums(self):
    """Enum-valued Struct fields must dump identically through csp and through pydantic."""

    class Color(csp.Enum):
        RED = 1
        GREEN = 2
        BLUE = 3

    class Shape(csp.Enum):
        CIRCLE = 1
        SQUARE = 2
        TRIANGLE = 3

    class DrawingStruct(csp.Struct):
        color: Color
        shape: Shape
        colors: List[Color]
        shapes: Dict[str, Shape]

    drawing = DrawingStruct(
        color=Color.RED,
        shape=Shape.CIRCLE,
        colors=[Color.RED, Color.GREEN, Color.BLUE],
        shapes={"a": Shape.SQUARE, "b": Shape.TRIANGLE},
    )
    adapter = TypeAdapter(DrawingStruct)

    # csp's own JSON output: every enum appears by member name.
    native_json = json.loads(drawing.to_json())
    expected_fields = {
        "color": "RED",
        "shape": "CIRCLE",
        "colors": ["RED", "GREEN", "BLUE"],
        "shapes": {"a": "SQUARE", "b": "TRIANGLE"},
    }
    for field_name, expected_value in expected_fields.items():
        self.assertEqual(native_json[field_name], expected_value)

    # pydantic's JSON output must be the same document.
    pydantic_json = json.loads(adapter.dump_json(drawing))
    self.assertEqual(pydantic_json, native_json)

    # Both serialized forms must load back into an equal struct.
    native_restored = DrawingStruct.from_dict(json.loads(drawing.to_json()))
    pydantic_restored = adapter.validate_json(adapter.dump_json(drawing))

    self.assertEqual(native_restored, drawing)
    self.assertEqual(pydantic_restored, drawing)

def test_pydantic_serialization_vs_native(self):
    """pydantic's JSON dump of a Struct must agree with the Struct's own to_json()."""
    from pydantic.functional_serializers import PlainSerializer

    class MyEnum(csp.Enum):
        OPTION1 = csp.Enum.auto()
        OPTION2 = csp.Enum.auto()

    def csp_style_timestamp(dt):
        return dt.strftime("%Y-%m-%dT%H:%M:%S.%f+00:00")

    # A custom serializer is attached so that pydantic writes datetimes with
    # the same precision that csp uses natively.
    SimpleDatetime = Annotated[
        datetime,
        PlainSerializer(csp_style_timestamp, return_type=str, when_used="always"),
    ]

    class SimpleStruct(csp.Struct):
        i: int = 123
        f: float = 3.14
        s: str = "test"
        b: bool = True
        dt: SimpleDatetime = datetime(2023, 1, 1)
        l: List[int] = [1, 2, 3]
        d: Dict[str, float] = {"a": 1.1, "b": 2.2}
        e: MyEnum = MyEnum.OPTION1

    simple_adapter = TypeAdapter(SimpleStruct)

    # All defaults: the two JSON strings are compared verbatim.
    s1 = SimpleStruct()
    self.assertEqual(s1.to_json(), simple_adapter.dump_json(s1).decode())

    # Explicit values: the parsed JSON documents are compared.
    s2 = SimpleStruct(
        i=456,
        f=2.718,
        s="custom",
        b=False,
        dt=datetime(2024, 1, 1, tzinfo=pytz.UTC),
        l=[4, 5, 6],
        d={"x": 9.9, "y": 8.8},
    )
    self.assertEqual(
        json.loads(s2.to_json()),
        json.loads(simple_adapter.dump_json(s2).decode()),
    )

    # Structs nested directly and inside a list.
    class NestedStruct(csp.Struct):
        name: str
        simple: SimpleStruct
        simples: List[SimpleStruct]

    nested = NestedStruct(name="test", simple=s1, simples=[s1, s2])
    self.assertEqual(
        json.loads(nested.to_json()),
        json.loads(TypeAdapter(NestedStruct).dump_json(nested).decode()),
    )


# Run the tests through unittest when this module is executed as a script.
if __name__ == "__main__":
unittest.main()

0 comments on commit 684f423

Please sign in to comment.