Skip to content

Commit

Permalink
BANG-483: PydanticSerializer for query params (#169)
Browse files Browse the repository at this point in the history
* BANG-483: add map_pydantic_query_serializer method to SerializerSchema class

* BANG-483: add map_pydantic_query_serializer method to SerializerSchemaProtocol

* BANG-483: add call of map_pydantic_query_serializer method to RestDoctorSchema.get_request_serializer_filter_parameters

* BANG-483: add test for map_pydantic_query_serializer method

---------

Co-authored-by: Alexander Munz <[email protected]>
  • Loading branch information
Westerling13 and Alexander Munz authored Mar 19, 2024
1 parent 3185b00 commit ed98f4d
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 2 deletions.
9 changes: 8 additions & 1 deletion restdoctor/rest_framework/schema/custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from rest_framework.serializers import BaseSerializer
from semver import VersionInfo

from restdoctor.rest_framework.serializers import PydanticSerializer

OpenAPISchema = t.Dict[str, 'OpenAPISchema'] # type: ignore
OpenAPISchemaParameter = t.Dict[str, t.Any]
LocalRefs = t.Dict[t.Tuple[str, ...], t.Any]
Expand Down Expand Up @@ -66,6 +68,11 @@ def map_serializer(
def map_query_serializer(self, serializer: BaseSerializer) -> t.List[OpenAPISchema]:
...

def map_pydantic_query_serializer(
self, serializer: PydanticSerializer
) -> t.List[OpenAPISchema]:
...


class FieldSchemaBase:
view_schema: ViewSchemaBase
Expand All @@ -75,7 +82,7 @@ class SerializerSchemaBase:
view_schema: ViewSchemaBase


class ViewSchemaBase(abc.ABC): # noqa B024
class ViewSchemaBase(abc.ABC): # noqa: B024
generator: t.Optional[SchemaGenerator] = None
serializer_schema: SerializerSchemaProtocol
field_schema: FieldSchemaProtocol
Expand Down
5 changes: 4 additions & 1 deletion restdoctor/rest_framework/schema/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
get_app_prefix,
normalize_action_schema,
)
from restdoctor.rest_framework.serializers import EmptySerializer
from restdoctor.rest_framework.serializers import EmptySerializer, PydanticSerializer
from restdoctor.rest_framework.views import SerializerClassMapApiView

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -138,6 +138,9 @@ def get_request_serializer_filter_parameters(
request_serializer_class = self.view.get_request_serializer_class(use_default=False)
request_serializer = request_serializer_class()
if not isinstance(request_serializer, EmptySerializer):
if isinstance(request_serializer, PydanticSerializer):
return self.serializer_schema.map_pydantic_query_serializer(request_serializer)

for field in request_serializer.fields.values():
field_schema = self.field_schema.get_field_schema(field)
parameters.append(
Expand Down
18 changes: 18 additions & 0 deletions restdoctor/rest_framework/schema/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,21 @@ def map_query_serializer(self, serializer: BaseSerializer) -> typing.List[OpenAP
props.append(prop)

return props

def map_pydantic_query_serializer(
self, serializer: PydanticSerializer
) -> typing.List[OpenAPISchema]:
props = []
schema_dict = fix_pydantic_title(serializer.pydantic_model_class.schema())
required_fields = schema_dict.get('required', [])
for field_name, field_schema in schema_dict['properties'].items():
props.append(
{
'name': field_name,
'in': 'query',
'required': field_name in required_fields,
'schema': field_schema,
}
)

return props
52 changes: 52 additions & 0 deletions tests/test_unit/test_schema/test_pydantic_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import copy
import datetime
import typing
from uuid import UUID

import pytest
from pydantic import BaseModel, Field, StrictInt, StrictStr
Expand All @@ -19,6 +21,13 @@ class PydanticTestModel(BaseModel):
title: str


class PydanticTestQueryModel(BaseModel):
boolean_param: typing.Optional[bool] = Field(description='Boolean filter param')
string_param: str
integer_param: typing.Optional[int]
uuid_list_param: typing.Optional[typing.List[UUID]]


class PydanticNestedTestModel(BaseModel):
created_at: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
nested_field: PydanticTestModel
Expand All @@ -32,6 +41,11 @@ class NestedPydanticTestSerializer(PydanticSerializer):
pydantic_model = PydanticNestedTestModel


class PydanticTestQuerySerializer(PydanticSerializer):
class Meta:
pydantic_model = PydanticTestQueryModel


@pytest.fixture()
def test_model_schema():
return {
Expand Down Expand Up @@ -117,3 +131,41 @@ def test_map_serializer_with_refs_generator_with_nested_serializer_success_case(
== test_nested_model_schema_without_definitions
)
assert schema_generator.local_refs_registry.get_local_ref(nested_ref) == test_model_schema


def test__serializer_schema():
serializer_schema = RestDoctorSchema().serializer_schema
expected_data = [
{
'in': 'query',
'name': 'boolean_param',
'required': False,
'schema': {'description': 'Boolean filter param', 'type': 'boolean'},
},
{
'in': 'query',
'name': 'string_param',
'required': True,
'schema': {'description': 'String Param', 'type': 'string'},
},
{
'in': 'query',
'name': 'integer_param',
'required': False,
'schema': {'description': 'Integer Param', 'type': 'integer'},
},
{
'in': 'query',
'name': 'uuid_list_param',
'required': False,
'schema': {
'description': 'Uuid List Param',
'items': {'format': 'uuid', 'type': 'string'},
'type': 'array',
},
},
]

result = serializer_schema.map_pydantic_query_serializer(PydanticTestQuerySerializer())

assert result == expected_data

0 comments on commit ed98f4d

Please sign in to comment.