diff --git a/dataclasses_avroschema/utils.py b/dataclasses_avroschema/utils.py index 03e94497..d4fd92f3 100644 --- a/dataclasses_avroschema/utils.py +++ b/dataclasses_avroschema/utils.py @@ -4,6 +4,7 @@ from datetime import datetime, timezone from functools import lru_cache +import typing_extensions from typing_extensions import Annotated, get_origin from .protocol import ModelProtocol # pragma: no cover @@ -22,6 +23,24 @@ faust = None # type: ignore # pragma: no cover +@lru_cache(maxsize=None) +def _get_typing_objects_by_name_of(name: str) -> tuple[typing.Any, ...]: + """Get the member named `name` from both `typing` and `typing-extensions` (if it exists).""" + result = tuple(getattr(module, name) for module in (typing, typing_extensions) if hasattr(module, name)) + if not result: + raise ValueError(f'Neither `typing` nor `typing_extensions` has an object called {name!r}') + return result + + +def _is_typing_name(obj: object, name: str) -> bool: + """Return whether `obj` is the member of the typing modules (includes the `typing-extensions` one) named `name`.""" + # Using `any()` is slower: + for thing in _get_typing_objects_by_name_of(name): + if obj is thing: + return True + return False + + @lru_cache(maxsize=None) def is_pydantic_model(klass: typing.Type[ModelProtocol]) -> bool: if pydantic is not None: @@ -82,8 +101,16 @@ class User(...) def is_annotated(a_type: typing.Type) -> bool: - origin = get_origin(a_type) - return origin is not None and isinstance(origin, type) and issubclass(origin, Annotated) # type: ignore[arg-type] + """ + Given a python type, return True if is typing.Annotated, otherwise False + + Arguments: + a_type (typing.Any): python type + + Returns: + bool + """ + return _is_typing_name(get_origin(a_type), name="Annotated") def rebuild_annotation(a_type: typing.Type, field_info: FieldInfo) -> typing.Type: