From 2766f1fe12d77ad75faf948fdc785fd0fc3450e5 Mon Sep 17 00:00:00 2001 From: Marcos Schroh <2828842+marcosschroh@users.noreply.github.com> Date: Fri, 12 Apr 2024 18:29:32 -0300 Subject: [PATCH] fix: custom types with extra annotation. Closes #598 (#601) --- dataclasses_avroschema/fields/fields.py | 14 ++++++++++---- dataclasses_avroschema/types.py | 20 ++++++++++++++++---- dataclasses_avroschema/utils.py | 6 +++++- tests/fields/consts.py | 20 +++++++++++++++----- 4 files changed, 46 insertions(+), 14 deletions(-) diff --git a/dataclasses_avroschema/fields/fields.py b/dataclasses_avroschema/fields/fields.py index a3bcdada..4239f023 100644 --- a/dataclasses_avroschema/fields/fields.py +++ b/dataclasses_avroschema/fields/fields.py @@ -845,12 +845,18 @@ def field_factory( if native_type is None: native_type = type(None) - if utils.is_annotated(native_type) and native_type not in ALL_TYPES_FIELD_CLASSES: + if utils.is_annotated(native_type): a_type, *extra_args = get_args(native_type) field_info = next((arg for arg in extra_args if isinstance(arg, types.FieldInfo)), None) - # it means that it is a custom type defined by us `Int32`, `Float32`,`TimeMicro` or `DateTimeMicro` - # or a known type Annotated with the end user - native_type = a_type + + if field_info is not None: + # it means that it is a custom type defined by us `Int32`, `Float32`,`TimeMicro`, `DateTimeMicro` + # confixed or condecimal + native_type = utils.rebuild_annotation(a_type, field_info) + + if native_type not in ALL_TYPES_FIELD_CLASSES: + # type Annotated with the end user + native_type = a_type if native_type in IMMUTABLE_FIELDS_CLASSES: klass = IMMUTABLE_FIELDS_CLASSES[native_type] diff --git a/dataclasses_avroschema/types.py b/dataclasses_avroschema/types.py index 7f22bfc5..d8b3c99c 100644 --- a/dataclasses_avroschema/types.py +++ b/dataclasses_avroschema/types.py @@ -51,6 +51,18 @@ def __repr__(self) -> str: return f"FixedFieldInfo(size={self.size}, aliases={self.aliases}, namespace={self.namespace})" +class Int32FieldInfo(FieldInfo): ... + + +class Float32FieldInfo(FieldInfo): ... + + +class TimeMicroFieldInfo(FieldInfo): ... + + +class DateTimeMicro2FieldInfo(FieldInfo): ... + + def confixed( *, size, @@ -67,10 +79,10 @@ def condecimal(*, max_digits: int, decimal_places: int) -> typing.Type[decimal.D ] # type: ignore[return-value] -Int32 = Annotated[int, "Int32"] -Float32 = Annotated[float, "Float32"] -TimeMicro = Annotated[datetime.time, "TimeMicro"] -DateTimeMicro = Annotated[datetime.datetime, "DateTimeMicro"] +Int32 = Annotated[int, Int32FieldInfo()] +Float32 = Annotated[float, Float32FieldInfo()] +TimeMicro = Annotated[datetime.time, TimeMicroFieldInfo()] +DateTimeMicro = Annotated[datetime.datetime, DateTimeMicro2FieldInfo()] CUSTOM_TYPES = ( Int32, diff --git a/dataclasses_avroschema/utils.py b/dataclasses_avroschema/utils.py index 01affc6b..0111db78 100644 --- a/dataclasses_avroschema/utils.py +++ b/dataclasses_avroschema/utils.py @@ -5,7 +5,7 @@ from typing_extensions import Annotated, get_origin -from .types import JsonDict +from .types import FieldInfo, JsonDict try: import pydantic # pragma: no cover @@ -64,6 +64,10 @@ def is_annotated(a_type: typing.Any) -> bool: return origin is not None and isinstance(origin, type) and issubclass(origin, Annotated) # type: ignore[arg-type] +def rebuild_annotation(a_type: typing.Any, field_info: FieldInfo) -> typing.Type: + return Annotated[a_type, field_info] # type: ignore[return-value] + + def standardize_custom_type(value: typing.Any) -> typing.Any: if isinstance(value, dict): return {k: standardize_custom_type(v) for k, v in value.items()} diff --git a/tests/fields/consts.py b/tests/fields/consts.py index 18656567..d481fbe3 100644 --- a/tests/fields/consts.py +++ b/tests/fields/consts.py @@ -7,6 +7,7 @@ import pytest from typing_extensions import Annotated +from dataclasses_avroschema import types from dataclasses_avroschema.fields import field_utils PY_VER = sys.version_info @@ -21,6 +22,10 @@ (bytes, field_utils.BYTES), (None, field_utils.NULL), (type(None), field_utils.NULL), + (types.Int32, field_utils.INT), + (Annotated[types.Int32, "ExtraAnnotation"], field_utils.INT), + (types.Float32, field_utils.FLOAT), + (Annotated[types.Float32, "ExtraAnnotation"], field_utils.FLOAT), (Annotated[str, "string"], field_utils.STRING), (Annotated[int, "integer"], field_utils.LONG), (Annotated[bool, "boolean"], field_utils.BOOLEAN), @@ -83,6 +88,8 @@ (bytes, b"test"), (None, None), (type(None), None), + (types.Int32, 10), + (types.Float32, 10.7), (Annotated[str, "string"], "test"), (Annotated[int, "int"], 1), (Annotated[bool, "boolean"], True), @@ -105,7 +112,9 @@ LOGICAL_TYPES = ( (datetime.date, field_utils.LOGICAL_DATE, now.date()), (datetime.time, field_utils.LOGICAL_TIME_MILIS, now.time()), + (types.TimeMicro, field_utils.LOGICAL_TIME_MICROS, now.time()), (datetime.datetime, field_utils.LOGICAL_DATETIME_MILIS, now), + (types.DateTimeMicro, field_utils.LOGICAL_DATETIME_MICROS, now), (uuid.UUID, field_utils.LOGICAL_UUID, uuid.uuid4()), (Annotated[datetime.date, "date"], field_utils.LOGICAL_DATE, now.date()), (Annotated[datetime.time, "time"], field_utils.LOGICAL_TIME_MILIS, now.time()), @@ -399,14 +408,15 @@ def xfail_annotation(typ): # Represent the logical types # (python_type, avro_type) LOGICAL_TYPES = ( - (datetime.date, {"type": field_utils.INT, "logicalType": field_utils.DATE}), - (datetime.time, {"type": field_utils.INT, "logicalType": field_utils.TIME_MILLIS}), + (datetime.date, field_utils.LOGICAL_DATE), + (datetime.time, field_utils.LOGICAL_TIME_MILIS), + (types.TimeMicro, field_utils.LOGICAL_TIME_MICROS), ( datetime.datetime, - {"type": field_utils.LONG, "logicalType": field_utils.TIMESTAMP_MILLIS}, + field_utils.LOGICAL_DATETIME_MILIS, ), - (uuid.uuid4, {"type": field_utils.STRING, "logicalType": field_utils.UUID}), - (uuid.UUID, {"type": field_utils.STRING, "logicalType": field_utils.UUID}), + (uuid.uuid4, field_utils.LOGICAL_UUID), + (uuid.UUID, field_utils.LOGICAL_UUID), ) LOGICAL_TYPES_AND_INVALID_DEFAULTS = (