Skip to content

Commit

Permalink
Support arbitrary classes, passthrough if the struct is already the r…
Browse files Browse the repository at this point in the history
…ight type

Signed-off-by: Nijat Khanbabayev <[email protected]>
  • Loading branch information
NeejWeej committed Jan 17, 2025
1 parent 8f83df4 commit c252851
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 10 deletions.
4 changes: 2 additions & 2 deletions csp/impl/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __str__(self):
return f"{type(self).__name__}.{self.name}"

@classmethod
def validate(cls, v) -> "Enum":
def _validate(cls, v) -> "Enum":
if isinstance(v, cls):
return v
elif isinstance(v, str):
Expand Down Expand Up @@ -100,7 +100,7 @@ def __get_pydantic_core_schema__(
from pydantic_core import core_schema

return core_schema.no_info_before_validator_function(
cls.validate,
cls._validate,
core_schema.any_schema(),
serialization=core_schema.plain_serializer_function_ser_schema(
cls._serialize, info_arg=False, return_schema=core_schema.str_schema(), when_used="json"
Expand Down
23 changes: 20 additions & 3 deletions csp/impl/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,16 @@ def layout(self, num_cols=8):
@staticmethod
def _get_pydantic_core_schema(cls, _source_type, handler):
"""Tell Pydantic how to validate this Struct class."""
from pydantic import PydanticSchemaGenerationError
from pydantic_core import core_schema

fields = {}

for field_name, field_type in cls.__full_metadata_typed__.items():
field_schema = handler.generate_schema(field_type)
try:
field_schema = handler.generate_schema(field_type)
except PydanticSchemaGenerationError: # for classes we dont have a schema for
field_schema = core_schema.is_instance_schema(field_type)

if field_name in cls.__defaults__:
field_schema = core_schema.with_default_schema(
Expand All @@ -104,19 +108,32 @@ def _get_pydantic_core_schema(cls, _source_type, handler):
required=False, # Make all fields optional
)

# Use typed_dict_schema instead of model_fields_schema
# Schema for dictionary inputs
fields_schema = core_schema.typed_dict_schema(
fields=fields,
total=False, # Allow missing fields
)
# Schema for direct class instances
instance_schema = core_schema.is_instance_schema(cls)
# Use union schema to handle both cases
schema = core_schema.union_schema(
[
instance_schema,
fields_schema,
]
)

def create_instance(validated_data):
# We choose to not revalidate, this is the default behavior in pydantic
if isinstance(validated_data, cls):
return validated_data

data_dict = validated_data[0] if isinstance(validated_data, tuple) else validated_data
return cls(**data_dict)

return core_schema.no_info_after_validator_function(
function=create_instance,
schema=fields_schema,
schema=schema,
serialization=core_schema.plain_serializer_function_ser_schema(
function=lambda x: x.to_dict(), # Use the built-in to_dict method
return_schema=core_schema.dict_schema(),
Expand Down
44 changes: 39 additions & 5 deletions csp/tests/impl/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -3013,6 +3013,12 @@ class SimpleStruct(csp.Struct):
self.assertEqual(result.name, "ya")
self.assertEqual(result.scores, [1.1, 2.2, 3.3])

# Test that we can validate existing structs
existing = SimpleStruct(value=1, scores=[1])
new = TypeAdapter(SimpleStruct).validate_python(existing)
self.assertTrue(existing is new) # we do not revalidate
self.assertEqual(existing.value, 1)

# Test type coercion
coercion_data = {
"value": "42", # string should convert to int
Expand Down Expand Up @@ -3067,6 +3073,17 @@ class EnumStruct(csp.Struct):
self.assertEqual(result.enum_field, MyEnum.A)
self.assertEqual(result.enum_list, [MyEnum.A, MyEnum.B, MyEnum.A])

# 6. test with arbitrary class
class DummyBlankClass: ...

class StructWithDummy(csp.Struct):
x: int
y: DummyBlankClass

val = DummyBlankClass()
new_struct = TypeAdapter(StructWithDummy).validate_python(dict(x=12, y=val))
self.assertTrue(new_struct.y is val)

def test_pydantic_validation_complex(self):
"""Test Pydantic validation with complex nested types and serialization"""

Expand Down Expand Up @@ -3304,11 +3321,14 @@ def test_struct_with_annotated_validation(self):
"""Test CSP Struct with Annotated fields and validators"""
from pydantic import BeforeValidator, WrapValidator

# Simple validator that modifies the value
# Simple validator that modifies the value and enforces value > 0
def value_validator(v: Any) -> int:
if isinstance(v, str):
return int(v) * 2
return v
v = int(v)
v = int(v)
if v <= 0:
raise ValueError("value must be positive")
return v * 2

# Wrap validator that can modify the whole struct
def struct_validator(val, handler) -> Any:
Expand All @@ -3333,6 +3353,20 @@ class OuterStruct(csp.Struct):
self.assertEqual(inner.description, "default")
self.assertFalse(hasattr(inner, "z"))

# test existing instance
inner_new = TypeAdapter(InnerStruct).validate_python(inner)
self.assertTrue(inner is inner_new)
# No revalidation
self.assertEqual(inner_new.value, 42)

# Test validation with invalid value in existing instance
inner.value = -5 # Set invalid value
# No revalidation, no error
self.assertTrue(inner is TypeAdapter(InnerStruct).validate_python(inner))
with self.assertRaises(ValidationError) as cm:
TypeAdapter(InnerStruct).validate_python(inner.to_dict())
self.assertIn("value must be positive", str(cm.exception))

# Test simple value validation
inner = TypeAdapter(InnerStruct).validate_python({"value": "21", "z": 17})
self.assertEqual(inner.value, 42) # "21" -> 21 -> 42
Expand All @@ -3341,15 +3375,15 @@ class OuterStruct(csp.Struct):

# Test struct validation with expansion
outer = TypeAdapter(OuterStruct).validate_python({"name": "test", "inner": {"value": 10, "z": 12}})
self.assertEqual(outer.inner.value, 10) # not a string so not doubled
self.assertEqual(outer.inner.value, 20) # 10 -> 20 (doubled)
self.assertEqual(outer.inner.description, "auto_generated")
self.assertEqual(outer.inner.z, 12)

# Test normal full structure still works
outer = TypeAdapter(OuterStruct).validate_python(
{"name": "test", "inner": {"value": "5", "description": "custom"}}
)
self.assertEqual(outer.inner.value, 10)
self.assertEqual(outer.inner.value, 10) # "5" -> 5 -> 10 (doubled)
self.assertEqual(outer.inner.description, "custom")
self.assertFalse(hasattr(outer.inner, "z")) # make sure z is not set

Expand Down

0 comments on commit c252851

Please sign in to comment.