Skip to content

Commit

Permalink
fix: Model and Parser protocols added
Browse files Browse the repository at this point in the history
  • Loading branch information
marcosschroh committed Nov 28, 2024
1 parent 5e65d6c commit 7a68b3f
Show file tree
Hide file tree
Showing 16 changed files with 257 additions and 111 deletions.
6 changes: 3 additions & 3 deletions dataclasses_avroschema/dacite_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,16 @@ def generate_dacite_config(model: typing.Type["AvroModel"]) -> Config:
"""
Get the default config for dacite and always include the self reference
"""
# We need to make sure that the `avro schemas` has been generated, otherwise cls._klass is empty
# We need to make sure that the `avro schemas` has been generated, otherwise cls._dataclass is empty
# It won't affect the performance because the rendered schema will be store in model._rendered_schema
model.generate_schema()
dacite_user_config = model._metadata.dacite_config # type: ignore
dacite_user_config = model._parser.metadata.dacite_config # type: ignore

dacite_config = {
"check_types": False,
"cast": [],
"forward_references": {
model._klass.__name__: model._klass, # type: ignore
model._parser.dataclass.__name__: model._parser.dataclass, # type: ignore
},
"type_hooks": {
datetime: parse_datetime,
Expand Down
7 changes: 2 additions & 5 deletions dataclasses_avroschema/faust/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
raise Exception("faust-streaming must be installed in order to use AvroRecord") from ex # pragma: no cover


CT = typing.TypeVar("CT", bound="AvroRecord")


class AvroRecord(Record, AvroModel): # type: ignore
def validate_avro(self) -> bool:
"""
Expand Down Expand Up @@ -56,5 +53,5 @@ def to_dict(self) -> JsonDict:
return self.standardize_type(include_type=False)

@classmethod
def _generate_parser(cls: typing.Type[CT]) -> FaustParser:
return FaustParser(type=cls._klass, metadata=cls.get_metadata(), parent=cls._parent or cls)
def _generate_parser(cls: typing.Type["AvroRecord"]) -> FaustParser:
return FaustParser(type=cls, parent=cls._parent or cls)
12 changes: 12 additions & 0 deletions dataclasses_avroschema/faust/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,20 @@
from dataclasses_avroschema.fields.fields import AvroField
from dataclasses_avroschema.parser import Parser

if typing.TYPE_CHECKING:
from .main import AvroRecord # pragma: no cover


class FaustParser(Parser):
def __init__(
self,
type,
parent,
):
super().__init__(type, parent)
self.type: typing.Type["AvroRecord"]
self.parent: typing.Type["AvroRecord"]

def parse_fields(self, exclude: typing.List) -> typing.List[Field]:
schema_fields = []

Expand Down
5 changes: 4 additions & 1 deletion dataclasses_avroschema/fields/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,10 @@ def get_avro_type(self) -> typing.Union[str, typing.List, typing.Dict]:
meta = getattr(self.type, "Meta", type)
metadata = utils.SchemaMetadata.create(meta)

alias = self.parent._metadata.get_alias_nested_items(self.name) or metadata.get_alias_nested_items(self.name) # type: ignore # noqa E501
if self.parent is not None and self.parent._parser is not None:
alias = self.parent._parser.metadata.get_alias_nested_items(self.name)
else:
alias = metadata.get_alias_nested_items(self.name) # type: ignore # noqa E501

# The priority for the schema name
# 1. Check if exists an alias_nested_items in parent llass or Meta class of own model
Expand Down
70 changes: 24 additions & 46 deletions dataclasses_avroschema/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import inspect
import json
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Set, Type, TypeVar, Union
from typing import Any, Dict, List, Optional, Set, Type, Union

from dacite import Config, from_dict
from fastavro.validation import validate
Expand All @@ -12,41 +12,20 @@
from .fields.base import Field
from .parser import Parser
from .types import JsonDict
from .utils import SchemaMetadata, UserDefinedType, standardize_custom_type

CT = TypeVar("CT", bound="AvroModel")

from .utils import UserDefinedType, standardize_custom_type

_schemas_cache: Dict["Type[AvroModel]", dict] = {}
_dacite_config_cache: Dict["Type[AvroModel]", Config] = {}


class AvroModel:
_parser: Optional[Parser] = None
_klass: Optional[Type] = None
_metadata: Optional[SchemaMetadata] = None
_parent: Any = None
_parent: Optional[Type["AvroModel"]] = None
_user_defined_types: Set[UserDefinedType] = set()
_rendered_schema: OrderedDict = dataclasses.field(default_factory=OrderedDict)

@classmethod
def generate_dataclass(cls: "Type[CT]") -> "Type[CT]":
if cls is AvroModel:
raise AttributeError("Schema generation must be called on a subclass of AvroModel, not AvroModel itself.")

if dataclasses.is_dataclass(cls):
return cls # type: ignore
return dataclasses.dataclass(cls)

@classmethod
def get_metadata(cls: "Type[CT]") -> SchemaMetadata:
if cls._metadata is None:
meta = getattr(cls._klass, "Meta", type)
cls._metadata = SchemaMetadata.create(meta)
return cls._metadata

@classmethod
def get_fullname(cls) -> str:
def get_fullname(cls: Type["AvroModel"]) -> str:
"""
Fullname is composed of two parts: a name and a namespace
separated by a dot. A namespace is a dot-separated sequence of such names.
Expand All @@ -56,26 +35,25 @@ def get_fullname(cls) -> str:
"""
# we need to make sure that the schema has been generated
cls.generate_schema()
metadata = cls.get_metadata()
assert cls._parser
metadata = cls._parser.metadata

if metadata.namespace:
# if the current record has a namespace we use it
return f"{metadata.namespace}.{cls.__name__}"
elif cls._parent is not None:
# if the record has a parent then we try to use the parent namespace
parent_metadata = cls._parent.get_metadata()
assert cls._parent._parser
parent_metadata = cls._parent._parser.metadata
if parent_metadata.namespace:
return f"{parent_metadata.namespace}.{cls.__name__}"
return cls.__name__

@classmethod
def generate_schema(
cls: "Type[CT]", schema_type: serialization.SerializationType = "avro"
cls: Type["AvroModel"], schema_type: serialization.SerializationType = "avro"
) -> Optional[OrderedDict]:
if cls._parser is None:
# Generate dataclass and metadata
cls._klass = cls.generate_dataclass()

# let's live open the possibility to define different
# schema definitions like json
if schema_type == "avro":
Expand All @@ -98,17 +76,17 @@ def _get_serialization_context(cls) -> JsonDict:
return {user_type.model.__name__: user_type.model for user_type in cls._user_defined_types}

@classmethod
def _generate_parser(cls: "Type[CT]") -> Parser:
return Parser(type=cls._klass, metadata=cls.get_metadata(), parent=cls._parent or cls)
def _generate_parser(cls: Type["AvroModel"]) -> Parser:
return Parser(type=cls, parent=cls._parent or cls)

@classmethod
def avro_schema(cls: "Type[CT]", case_type: Optional[str] = None, **kwargs) -> str:
def avro_schema(cls: Type["AvroModel"], case_type: Optional[str] = None, **kwargs) -> str:
return json.dumps(cls.avro_schema_to_python(case_type=case_type), **kwargs)

@classmethod
def avro_schema_to_python(
cls: "Type[CT]",
parent: Optional["CT"] = None,
cls: Type["AvroModel"],
parent: Optional[Type["AvroModel"]] = None,
case_type: Optional[str] = None,
) -> Dict[str, Any]:
if parent is not None:
Expand All @@ -135,13 +113,13 @@ def avro_schema_to_python(
return json.loads(json.dumps(avro_schema))

@classmethod
def get_fields(cls: "Type[CT]") -> List[Field]:
def get_fields(cls: Type["AvroModel"]) -> List[Field]:
if cls._parser is None:
cls.generate_schema()
return cls._parser.fields # type: ignore

@classmethod
def _reset_parser(cls: "Type[CT]") -> None:
def _reset_parser(cls: Type["AvroModel"]) -> None:
"""
Reset all the values to original state.
"""
Expand All @@ -151,12 +129,12 @@ def _reset_parser(cls: "Type[CT]") -> None:

@classmethod
def deserialize(
cls: "Type[CT]",
cls: Type["AvroModel"],
data: bytes,
serialization_type: serialization.SerializationType = "avro",
create_instance: bool = True,
writer_schema: Optional[Union[JsonDict, "Type[CT]"]] = None,
) -> Union[JsonDict, CT]:
writer_schema: Optional[Union[JsonDict, Type["AvroModel"]]] = None,
) -> Union[JsonDict, "AvroModel"]:
payload = cls.deserialize_to_python(data, serialization_type, writer_schema)
obj = cls.parse_obj(payload)

Expand All @@ -166,10 +144,10 @@ def deserialize(

@classmethod
def deserialize_to_python( # This can be used straight with a pydantic dataclass to bypass dacite
cls: "Type[CT]",
cls: Type["AvroModel"],
data: bytes,
serialization_type: serialization.SerializationType = "avro",
writer_schema: Union[JsonDict, "Type[CT]", None] = None,
writer_schema: Union[JsonDict, Type["AvroModel"], None] = None,
) -> dict:
if inspect.isclass(writer_schema) and issubclass(writer_schema, AvroModel):
# mypy does not understand redefinitions
Expand All @@ -188,15 +166,15 @@ def deserialize_to_python( # This can be used straight with a pydantic dataclas
)

@classmethod
def parse_obj(cls: "Type[CT]", data: Dict) -> CT:
def parse_obj(cls: Type["AvroModel"], data: Dict) -> "AvroModel":
config = _dacite_config_cache.get(cls)
if config is None:
config = generate_dacite_config(cls)
_dacite_config_cache[cls] = config
return from_dict(data_class=cls, data=data, config=config)

@classmethod
def fake(cls: "Type[CT]", **data: Any) -> CT:
def fake(cls: Type["AvroModel"], **data: Any) -> "AvroModel":
"""
Creates a fake instance of the model.
Expand All @@ -215,7 +193,7 @@ def asdict(self) -> JsonDict:
field.name: standardize_custom_type(
field_name=field.name, value=getattr(self, field.name), model=self, base_class=AvroModel
)
for field in dataclasses.fields(self) # type: ignore
for field in dataclasses.fields(self) # type: ignore[arg-type]
}

def serialize(self, serialization_type: serialization.SerializationType = "avro") -> bytes:
Expand Down
44 changes: 30 additions & 14 deletions dataclasses_avroschema/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import typing
from collections import OrderedDict

from . import utils
from .fields.base import Field
from .version import PY_VERSION
from .utils import SchemaMetadata

KWARGS = {"slots": True} if PY_VERSION >= (3, 10) else {}
if typing.TYPE_CHECKING:
from .main import AvroModel

Check warning on line 10 in dataclasses_avroschema/parser.py

View check run for this annotation

Codecov / codecov/patch

dataclasses_avroschema/parser.py#L10

Added line #L10 was not covered by tests
# from .protocol import ModelProtocol


@dataclasses.dataclass(**KWARGS) # type: ignore
class Parser:
"""
Parse python dataclasses to represent it as an avro schema.
Expand All @@ -19,18 +19,34 @@ class Parser:
be represented as an avro type.
"""

type: typing.Any
parent: typing.Any
metadata: utils.SchemaMetadata
fields: typing.List[Field] = dataclasses.field(default_factory=list)
# mapping of field_name: Field
fields_map: typing.Dict[str, Field] = dataclasses.field(default_factory=dict)
def __init__(
self,
type: typing.Type["AvroModel"],
parent: typing.Type["AvroModel"],
):
self.type = type
self.parent = parent

def __post_init__(self) -> None:
# generate the dataclass for thr given type
self.dataclass = self.generate_dataclass()

meta = getattr(type, "Meta", type)
self.metadata = SchemaMetadata.create(meta)
exclude = self.metadata.exclude

self.fields = self.parse_fields(exclude=exclude)
self.fields_map = {field.name: field for field in self.fields}

def generate_dataclass(self) -> typing.Type:
from .main import AvroModel

if self.type is AvroModel:
raise AttributeError("Schema generation must be called on a subclass of AvroModel, not AvroModel itself.")

if dataclasses.is_dataclass(self.type):
return self.type
return dataclasses.dataclass(self.type)

def parse_fields(self, exclude: typing.List) -> typing.List[Field]:
from .fields.fields import AvroField

Expand All @@ -44,19 +60,19 @@ def parse_fields(self, exclude: typing.List) -> typing.List[Field]:
model_metadata=self.metadata,
parent=self.parent,
)
for dataclass_field in dataclasses.fields(self.type)
for dataclass_field in dataclasses.fields(self.dataclass)
if dataclass_field.name not in exclude
]

def get_fields_map(self) -> typing.Dict[str, Field]:
return self.fields_map

def get_schema_name(self) -> str:
return self.type._metadata.schema_name or self.type.__name__
return self.metadata.schema_name or self.type.__name__

def generate_documentation(self) -> typing.Optional[str]:
if isinstance(self.metadata.schema_doc, str):
doc = self.metadata.schema_doc
doc: typing.Optional[str] = self.metadata.schema_doc
else:
doc = self.type.__doc__
# dataclasses create a (in avro context) useless docstring by default,
Expand Down
Loading

0 comments on commit 7a68b3f

Please sign in to comment.