Skip to content

Commit

Permalink
fix: improve pydantic model translation (#582)
Browse files Browse the repository at this point in the history
* Add support for translating pydantic descriptions to record docs

* Render record title as schema doc

* Accept kwargs on AvroModel.avro_schema and pass to json.dumps

* Address PR comments
  - added test case
  - reverted changes to PydanticParser
  - added PydanticParser.get_field_metadata staticmethod called from parse_fields() to obtain description as field doc
  - added PydanticParser.generate_documentation to obtain title as record doc

* Fix doc generation and typing

---------

Co-authored-by: Marcos Schroh <[email protected]>
  • Loading branch information
kevinjacobs-delfi and marcosschroh authored Mar 26, 2024
1 parent b8fcdb8 commit 82f53e1
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 6 deletions.
25 changes: 24 additions & 1 deletion dataclasses_avroschema/pydantic/parser.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,25 @@
from __future__ import annotations

import dataclasses
import typing

from pydantic.fields import FieldInfo

from dataclasses_avroschema.fields.base import Field
from dataclasses_avroschema.fields.fields import AvroField
from dataclasses_avroschema.parser import Parser


class PydanticParser(Parser):
@staticmethod
def get_field_metadata(field_info: FieldInfo) -> dict[str, typing.Any]:
metadata: dict[str, typing.Any] = (
field_info.json_schema_extra.get("metadata", {}) if field_info.json_schema_extra else {} # type: ignore
)
if field_info.description:
metadata["doc"] = field_info.description # type: ignore
return metadata

def parse_fields(self, exclude: typing.List) -> typing.List[Field]:
return [
AvroField(
Expand All @@ -16,10 +29,20 @@ def parse_fields(self, exclude: typing.List) -> typing.List[Field]:
if field_info.is_required() or field_info.default_factory
else field_info.default,
default_factory=field_info.default_factory,
metadata=field_info.json_schema_extra.get("metadata", {}) if field_info.json_schema_extra else {},
metadata=self.get_field_metadata(field_info),
model_metadata=self.metadata,
parent=self.parent,
)
for field_name, field_info in self.type.model_fields.items()
if field_name not in exclude and field_name != "model_config"
]

def generate_documentation(self) -> typing.Optional[str]:
    """
    Resolve the documentation string for the record.

    Order of precedence:
      1. ``self.metadata.schema_doc`` when it is a string;
      2. the ``"title"`` entry of the model's ``model_config`` when present;
      3. whatever the parent parser's ``generate_documentation`` returns.
    """
    schema_doc = self.metadata.schema_doc
    if isinstance(schema_doc, str):
        return schema_doc

    model_config = self.type.model_config
    if model_config and "title" in model_config:
        return model_config["title"]

    return super().generate_documentation()
10 changes: 5 additions & 5 deletions dataclasses_avroschema/schema_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def _generate_parser(cls: Type[CT]) -> Parser:
return Parser(type=cls._klass, metadata=cls._metadata, parent=cls._parent or cls)

@classmethod
def avro_schema(cls: Type[CT], case_type: Optional[str] = None) -> str:
return json.dumps(cls.avro_schema_to_python(case_type=case_type))
def avro_schema(cls: Type[CT], case_type: Optional[str] = None, **kwargs) -> str:
return json.dumps(cls.avro_schema_to_python(case_type=case_type), **kwargs)

@classmethod
def avro_schema_to_python(
Expand Down Expand Up @@ -155,7 +155,7 @@ def deserialize(

@classmethod
def parse_obj(cls: Type[CT], data: Dict) -> CT:
return from_dict(data_class=cls, data=data, config=cls.config())
return from_dict(data_class=cls, data=data, config=cls.dacite_config())

def validate(self) -> bool:
schema = self.avro_schema_to_python()
Expand All @@ -171,7 +171,7 @@ def to_json(self, **kwargs: Any) -> str:
return json.dumps(data, **kwargs)

@classmethod
def config(cls: Type[CT]) -> Config:
def dacite_config(cls: Type[CT]) -> Config:
"""
Get the default config for dacite and always include the self reference
"""
Expand Down Expand Up @@ -211,4 +211,4 @@ def fake(cls: Type[CT], **data: Any) -> CT:
payload = {field.name: field.fake() for field in cls.get_fields() if field.name not in data.keys()}
payload.update(data)

return from_dict(data_class=cls, data=payload, config=cls.config())
return from_dict(data_class=cls, data=payload, config=cls.dacite_config())
16 changes: 16 additions & 0 deletions tests/schemas/pydantic/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
AmqpDsn,
AwareDatetime,
CockroachDsn,
ConfigDict,
EmailStr,
Field,
FutureDate,
Expand Down Expand Up @@ -75,6 +76,21 @@ class Meta:
assert User.avro_schema() == json.dumps(expected_schema)


def test_pydantic_record_schema_with_description():
    """A config ``title`` is rendered as the record doc and a field description as the field doc."""

    class User(AvroBaseModel):
        model_config = ConfigDict(title="User doc")

        name: str = Field(description="bar")

    # Key order matters: the comparison below is on the serialized JSON string.
    name_field = {"doc": "bar", "name": "name", "type": "string"}
    expected_schema = {
        "type": "record",
        "name": "User",
        "fields": [name_field],
        "doc": "User doc",
    }

    rendered = User.avro_schema()
    assert rendered == json.dumps(expected_schema)


def test_pydantic_record_schema_complex_types(user_advance_avro_json, color_enum):
class UserAdvance(AvroBaseModel):
name: str
Expand Down

0 comments on commit 82f53e1

Please sign in to comment.