Skip to content

Commit

Permalink
Fix parametrized generics for dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
bcalvert-graft committed Jan 19, 2025
1 parent 742f0fe commit 4c2e643
Show file tree
Hide file tree
Showing 5 changed files with 318 additions and 22 deletions.
26 changes: 22 additions & 4 deletions sematic/types/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

# Third-party
import cloudpickle # type: ignore
from typing_extensions import get_original_bases

# Sematic
from sematic.types.generic_type import GenericType
Expand All @@ -30,6 +31,7 @@
is_sematic_parametrized_generic_type,
is_supported_type_annotation,
)
from sematic.utils.types import resolve_type


# VALUE SERIALIZATION
Expand Down Expand Up @@ -193,7 +195,7 @@ def type_from_json_encodable(json_encodable: typing.Any) -> typing.Any:
def _type_repr(
type_: typing.Any,
) -> typing.Tuple[str, str, typing.Dict[str, typing.Any]]:
return (_get_category(type_), _get_key(type_), _get_parameters(type_))
return _get_category(type_), _get_key(type_), _get_parameters(type_)


_BUILTINS = (float, int, str, bool, type(None), bytes)
Expand Down Expand Up @@ -271,11 +273,11 @@ def _get_parameters(type_: typing.Any) -> typing.Dict[str, typing.Any]:
return {"args": [_parameter_repr(arg) for arg in typing.get_args(type_)]}

if _is_dataclass(type_):
field_names = [field.name for field in dataclasses.fields(type_)]
return {
"import_path": type_.__module__,
"fields": {
name: _parameter_repr(field.type)
for name, field in type_.__dataclass_fields__.items()
name: _parameter_repr(resolve_type(type_, name)) for name in field_names
},
}

Expand Down Expand Up @@ -320,8 +322,21 @@ def _is_scalar(v):

def _populate_registry(type_: typing.Any, registry: typing.Dict[str, typing.Any]) -> None:
def _include_in_registry(t) -> bool:
if _has_unparametrized_type_vars(t):
return False
return t not in (object, abc.ABC, GenericType)

def _has_unparametrized_type_vars(cls: type):
try:
return any(
isinstance(T, typing.TypeVar)
for base in get_original_bases(cls)
for T in typing.get_args(base)
)
except TypeError:
# Then we're dealing with a parametrized GenericAlias, like dict[str, str]
return False

if not _include_in_registry(type_):
return

Expand All @@ -333,7 +348,10 @@ def _include_in_registry(t) -> bool:

if _is_dataclass(type_):
_populate_registry_from_parameters(
{name: field.type for name, field in type_.__dataclass_fields__.items()},
{
name: resolve_type(type_, name)
for name in type_.__dataclass_fields__.keys()
},
registry,
)

Expand Down
36 changes: 19 additions & 17 deletions sematic/types/types/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
value_from_json_encodable,
value_to_json_encodable,
)
from sematic.utils.types import resolve_type


@register_safe_cast(DataclassKey)
Expand All @@ -57,8 +58,8 @@ def _safe_cast_dataclass(value: Any, type_: Any) -> Tuple[Any, Optional[str]]:
# Otherwise we make sure the subclass is conserved, including
# potential additional fields.
cast_value = copy.deepcopy(value)

for name, field in type_.__dataclass_fields__.items():
field_names = [field.name for field in dataclasses.fields(type_)]
for name in field_names:
try:
# First we attempt to access the property
field_value = getattr(value, name)
Expand All @@ -70,8 +71,8 @@ def _safe_cast_dataclass(value: Any, type_: Any) -> Tuple[Any, Optional[str]]:
return None, "Cannot cast {} to {}: Field {} is missing".format(
repr(value), type_, repr(name)
)

cast_field, error = safe_cast(field_value, field.type)
field_type = resolve_type(type_, name)
cast_field, error = safe_cast(field_value, field_type)
if error is not None:
return None, "Cannot cast field '{}' of {} to {}: {}".format(
name, repr(value), type_, error
Expand Down Expand Up @@ -108,7 +109,9 @@ def _can_cast_to_dataclass(from_type: Any, to_type: Any) -> Tuple[bool, Optional
return False, "{}: missing fields: {}".format(prefix, repr(missing_fields))

for name, field in to_fields.items():
can_cast, error = can_cast_type(from_fields[name].type, field.type)
from_attr_type = resolve_type(cls=from_type, attribute=name)
to_attr_type = resolve_type(cls=to_type, attribute=name)
can_cast, error = can_cast_type(from_attr_type, to_attr_type)
if not can_cast:
return False, "{}: field {} cannot cast: {}".format(prefix, repr(name), error)

Expand Down Expand Up @@ -136,11 +139,9 @@ def _dataclass_from_json_encodable(value: Any, type_: Any) -> Any:
)

kwargs = {}

fields: Dict[str, dataclasses.Field] = root_type.__dataclass_fields__

for name, field in fields.items():
field_type = field.type
field_names = [field.name for field in dataclasses.fields(root_type)]
for name in field_names:
field_type = resolve_type(root_type, name)
if name in types:
field_type = type_from_json_encodable(types[name])

Expand Down Expand Up @@ -179,7 +180,7 @@ def _serialize_dataclass(serializer: Callable, value: Any, _) -> SummaryOutput:
# The actual value type can be different from the field type if
# the value is an instance of a subclass
value_type = type(field_value)
field_type = value_serialization_type = field.type
field_type = value_serialization_type = resolve_type(type_, name)

# Only if the value type is different (e.g. subclass) do we persist the type
# serialization
Expand Down Expand Up @@ -245,19 +246,20 @@ def fromdict(dataclass_type: Type[T], as_dict: Dict[str, Any]) -> T:
if not field.init:
continue
dict_value = as_dict[name]
if dataclasses.is_dataclass(field.type):
kwargs[name] = fromdict(field.type, dict_value) # type: ignore
field_type = resolve_type(dataclass_type, name)
if dataclasses.is_dataclass(field_type):
kwargs[name] = fromdict(field_type, dict_value) # type: ignore
continue
if get_origin(field.type) is list:
element_type = get_args(field.type)[0]
if get_origin(field_type) is list:
element_type = get_args(field_type)[0]
if dataclasses.is_dataclass(element_type):
kwargs[name] = [ # type: ignore
fromdict(element_type, element) # type: ignore
for element in dict_value
]
continue
if get_origin(field.type) is dict:
value_type = get_args(field.type)[1]
if get_origin(field_type) is dict:
value_type = get_args(field_type)[1]
if dataclasses.is_dataclass(value_type):
kwargs[name] = { # type: ignore
key: fromdict(value_type, value) # type: ignore
Expand Down
184 changes: 184 additions & 0 deletions sematic/types/types/tests/test_dataclass_generic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
"""Module providing tests of Dataclass support for Generic dataclasses."""

# Standard Library
import re
from dataclasses import dataclass
from typing import Any, Generic, TypeVar

# Third-party
import pytest
from typing_extensions import get_original_bases

# Sematic
from sematic.types.casting import can_cast_type, safe_cast
from sematic.types.serialization import (
type_from_json_encodable,
type_to_json_encodable,
)


T = TypeVar("T")


@dataclass
class HasGeneric(Generic[T]):
x: T


@dataclass
class DerivedIntHasGeneric(HasGeneric[int]):
pass


@dataclass
class DerivedFloatHasGeneric(HasGeneric[float]):
pass


@dataclass
class DerivedStrHasGeneric(HasGeneric[str]):
pass


@dataclass
class DerivedIntListHasGeneric(HasGeneric[list[int]]):
pass


@pytest.mark.parametrize(
"from_type, to_type, expected_can_cast, expected_error",
(
# When the underlying types can cast to one another, it's fine
(
DerivedIntHasGeneric,
DerivedFloatHasGeneric,
True,
None,
),
(
DerivedFloatHasGeneric,
DerivedIntHasGeneric,
True,
None,
),
# But if the types can't cast, it breaks
(
DerivedFloatHasGeneric,
DerivedStrHasGeneric,
False,
(
r"Cannot cast.*DerivedFloatHasGeneric.*to.*DerivedStrHasGeneric.*"
r"field 'x' cannot cast.*float.*cannot cast.*str"
),
),
(
DerivedIntHasGeneric,
DerivedIntListHasGeneric,
False,
(
r"Cannot cast.*DerivedIntHasGeneric.*to.*DerivedIntListHasGeneric.*"
r"field 'x' cannot cast.*int.*to list.*int.*"
),
),
),
)
def test_can_cast_type(from_type, to_type, expected_can_cast, expected_error):
can_cast, error = can_cast_type(from_type, to_type)
assert can_cast is expected_can_cast, error
if expected_error is None:
assert error is None
else:
assert re.match(expected_error, error)


@pytest.mark.parametrize(
"value, type_, expected_type, expected_value, expected_error",
(
# When the underlying types can cast to one another, it's fine
(
DerivedIntHasGeneric(x=1),
DerivedFloatHasGeneric,
DerivedFloatHasGeneric,
DerivedFloatHasGeneric(x=1.0),
None,
),
(
DerivedFloatHasGeneric(x=1.0),
DerivedIntHasGeneric,
DerivedIntHasGeneric,
DerivedIntHasGeneric(x=1),
None,
),
# But if the types can't cast, it breaks
(
DerivedFloatHasGeneric(x=1),
DerivedStrHasGeneric,
DerivedStrHasGeneric,
None,
(
r"Cannot cast field 'x' of DerivedFloatHasGeneric.*to.*"
r"DerivedStrHasGeneric.*Cannot cast 1 to.*str"
),
),
(
DerivedIntHasGeneric(x=1),
DerivedIntListHasGeneric,
DerivedIntListHasGeneric,
None,
(
r"Cannot cast field 'x' of DerivedIntHasGeneric.*to.*"
r"DerivedIntListHasGeneric.*1 not an iterable"
),
),
),
)
def test_safe_cast(
value: Any, type_: type, expected_type: type, expected_value: Any, expected_error: str
):
cast_value, error = safe_cast(value, type_)
if expected_error is None:
assert isinstance(cast_value, expected_type), error
assert cast_value == expected_value
assert error is None
else:
assert error is not None
assert re.match(expected_error, error)


def test_type_to_json_encodable():
"""Test casting a dataclass to JSON-encodable."""
result = type_to_json_encodable(DerivedIntHasGeneric)
category, key, parameters = result["type"]
assert category == "dataclass"
assert key == DerivedIntHasGeneric.__name__
expected_parameters = {
"import_path": DerivedIntHasGeneric.__module__,
"fields": {"x": {"type": ("builtin", "int", {})}},
}
assert parameters == expected_parameters
registry = result["registry"]
for base in get_original_bases(DerivedIntHasGeneric):
assert base.__name__ not in registry


@dataclass
class HasDerivedGeneric:
derived_generic: DerivedIntHasGeneric


GENERIC_TYPES_TO_TEST = (
DerivedIntHasGeneric,
DerivedFloatHasGeneric,
DerivedStrHasGeneric,
DerivedIntListHasGeneric,
HasDerivedGeneric,
)


@pytest.mark.parametrize(
"type_",
GENERIC_TYPES_TO_TEST,
)
def test_type_from_json_encodable(type_: type):
json_encodable = type_to_json_encodable(type_)
assert type_from_json_encodable(json_encodable) is type_
49 changes: 49 additions & 0 deletions sematic/utils/tests/test_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Standard Library
from dataclasses import dataclass
from typing import Generic, TypeVar

# Third-party
import pytest

# Sematic
from sematic.utils.types import resolve_type


def test_resolve_type():
"""Test the resolve_type utility."""

T = TypeVar("T")

@dataclass
class A(Generic[T]):
val: T

def test(self, a: T) -> T:
return self.val

@dataclass
class B(A[int]):
def test(self, a: int) -> int:
return a + self.val

@dataclass
class C(A[A[int]]):
pass

@dataclass
class Concrete:
x: int

assert resolve_type(B, "val") is int
assert resolve_type(C, "val") is A[int]
assert resolve_type(Concrete, "x") is int
for cls, attr_name, match in (
(
A,
"val",
"The annotation for 'val' has not been parametrized",
),
(Concrete, "y", "The class 'Concrete' does not have the 'y' attribute"),
):
with pytest.raises(ValueError, match=match):
resolve_type(cls, attr_name)
Loading

0 comments on commit 4c2e643

Please sign in to comment.