Skip to content

Commit

Permalink
Fix resolve_type for containers of generic types
Browse files Browse the repository at this point in the history
  • Loading branch information
bcalvert-graft committed Jan 21, 2025
1 parent 082e60e commit 775fba7
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 14 deletions.
28 changes: 28 additions & 0 deletions sematic/types/types/tests/test_dataclass_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@


T = TypeVar("T")
U = TypeVar("U")


@dataclass
Expand Down Expand Up @@ -47,6 +48,23 @@ class DerivedIntListHasGeneric(HasGeneric[list[int]]):
pass


@dataclass
class HasGenericContainers(Generic[T, U]):
items_list: list[T]
items_tuple: tuple[T]
items_set: set[U]


@dataclass
class DerivedIntFloat(HasGenericContainers[int, float]):
pass


@dataclass
class DerivedIntStr(HasGenericContainers[int, str]):
pass


@pytest.mark.parametrize(
"from_type, to_type, expected_can_cast, expected_error",
(
Expand Down Expand Up @@ -82,6 +100,16 @@ class DerivedIntListHasGeneric(HasGeneric[list[int]]):
r"field 'x' cannot cast.*int.*to list.*int.*"
),
),
(DerivedIntFloat, DerivedIntFloat, True, None),
(
DerivedIntFloat,
DerivedIntStr,
False,
(
r"Cannot cast.*DerivedIntFloat.*DerivedIntStr.*"
r"Can't cast set.*float.*to set.*str.*float.*cannot cast to str"
),
),
),
)
def test_can_cast_type(from_type, to_type, expected_can_cast, expected_error):
Expand Down
33 changes: 30 additions & 3 deletions sematic/utils/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
from sematic.utils.types import resolve_type


def test_resolve_type():
"""Test the resolve_type utility."""
T = TypeVar("T")
U = TypeVar("U")


T = TypeVar("T")
def test_resolve_type_for_basic_type():
"""Test the resolve_type utility."""

@dataclass
class A(Generic[T]):
Expand Down Expand Up @@ -47,3 +49,28 @@ class Concrete:
):
with pytest.raises(ValueError, match=match):
resolve_type(cls, attr_name)


def test_resolve_type_for_container_types():
"""Test the resolve_type utility."""

@dataclass
class HasContainers(Generic[T, U]):
items_list: list[T]
items_tuple: tuple[T]
items_set: set[T]
maps: dict[T, U]

@dataclass
class IntFloat(HasContainers[int, float]):
pass

@dataclass
class Nested(HasContainers[int, HasContainers[int, float]]):
pass

assert resolve_type(IntFloat, "items_list") == list[int]
assert resolve_type(IntFloat, "items_tuple") == tuple[int]
assert resolve_type(IntFloat, "items_set") == set[int]
assert resolve_type(IntFloat, "maps") == dict[int, float]
assert resolve_type(Nested, "maps") == dict[int, HasContainers[int, float]]
37 changes: 26 additions & 11 deletions sematic/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,20 +67,35 @@ def resolve_type(cls: type, attribute: str):
f"The class '{cls.__name__}' does not have the '{attribute}' attribute"
)
raise ValueError(error_msg)
# And if it's a TypeVar....
if isinstance(field_type, TypeVar):
# Iterate through the bases to find the matching original type
return _resolve_generic_type(cls=cls, type_=field_type, attribute=attribute)


def _resolve_generic_type(
cls: type, type_: type | TypeVar, attribute: str
) -> type | TypeVar:
origin = get_origin(type_)
if origin is not None: # It's a generic like list, dict, etc.
args = tuple(
_resolve_generic_type(cls=cls, type_=arg, attribute=attribute)
for arg in get_args(type_)
)
return origin[args] if args else origin
elif isinstance(type_, TypeVar): # Resolve TypeVar
# Resolve the TypeVar from the class's __orig_bases__
for base in get_original_bases(cls):
origin = get_origin(base)
if origin is None:
base_origin = get_origin(base)
if base_origin is None:
raise ValueError(f"Found no origin for the base: {base.__name__}")
elif origin is Generic:
elif base_origin is Generic:
error_msg = f"The annotation for '{attribute}' has not been parametrized"
raise ValueError(error_msg)
else:
type_args = get_args(base)
# Map TypeVars to their actual types
type_var_mapping = dict(zip(origin.__parameters__, type_args))
if field_type in type_var_mapping:
return type_var_mapping[field_type]
return field_type
type_var_mapping = dict(zip(base_origin.__parameters__, type_args))
if type_ in type_var_mapping:
return type_var_mapping[type_]
# Unresolved TypeVar
return type_
else:
# Non-generic type (e.g., int, str)
return type_

0 comments on commit 775fba7

Please sign in to comment.