Skip to content

Commit

Permalink
add perform* list & retrieve methods, upd generics (#164)
Browse files Browse the repository at this point in the history
  • Loading branch information
eveighty authored Mar 12, 2024
1 parent abb3939 commit e341227
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 47 deletions.
22 changes: 13 additions & 9 deletions restdoctor/rest_framework/generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,14 @@ def _simple_get_object(self) -> Model:
# Perform the lookup filtering.
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field

assert lookup_url_kwarg in self.kwargs, (
'Expected view %s to be called with a URL keyword argument '
'named "%s". Fix your URL conf, or set the `.lookup_field` '
'attribute on the view correctly.' % (self.__class__.__name__, lookup_url_kwarg)
)
if lookup_url_kwarg not in self.kwargs:
raise ImproperlyConfigured(
(
f'Expected view {self.__class__.__name__} to be called with a URL keyword argument '
f'named "{lookup_url_kwarg}". Fix your URL conf, or set the `.lookup_field` '
'attribute on the view correctly.'
)
)

filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]}
obj = get_object_or_404(queryset, **filter_kwargs)
Expand Down Expand Up @@ -82,8 +85,9 @@ def _check_lookup_configuration(self) -> None:
)
)

def _get_queryset_for_object(self) -> QuerySet:
qs = self.get_queryset()
def _get_queryset_for_object(self, queryset: QuerySet | None = None) -> QuerySet:
if queryset is None:
queryset = self.get_queryset()
if settings.API_IGNORE_FILTER_PARAMS_FOR_DETAIL:
return qs
return self.filter_queryset(qs)
return queryset
return self.filter_queryset(queryset)
20 changes: 14 additions & 6 deletions restdoctor/rest_framework/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,24 +63,26 @@ def list( # noqa: A003
meta = self.get_meta_serializer_data()
page = self.paginate_queryset(queryset)
if page is not None:
self.perform_list(page)
serializer = self.get_serializer(page, many=True)
prepare_page = self.perform_list(page, request_data=request_serializer.validated_data)
serializer = self.get_serializer(prepare_page, many=True)
response = self.get_paginated_response(serializer.data)
response.meta.update(meta)
return response

self.perform_list(queryset)
prepare_data = self.perform_list(queryset, request_data=request_serializer.validated_data)

serializer = self.get_serializer(queryset, many=True)
serializer = self.get_serializer(prepare_data, many=True)
return ResponseWithMeta(data=serializer.data, meta=meta)

def get_collection(
self, request_serializer: BaseSerializer
) -> typing.Union[typing.List, QuerySet]:
return self.filter_queryset(self.get_queryset())

def perform_list(self, data: typing.Union[typing.List, QuerySet]) -> None:
pass
def perform_list(
self, data: typing.Union[typing.List, QuerySet], request_data: dict = None
) -> typing.Union[typing.List, QuerySet]:
return data

def get_meta_data(self) -> typing.Dict[str, typing.Any]:
return {}
Expand All @@ -100,6 +102,7 @@ def retrieve(self, request: Request, *args: typing.Any, **kwargs: typing.Any) ->
request_serializer.is_valid(raise_exception=True)

item = self.get_item(request_serializer)
item = self.perform_retrieve(item)

serializer = self.get_serializer(item)
return Response(serializer.data)
Expand All @@ -112,6 +115,11 @@ def get_item(
def get_serializer(self, *args: typing.Any, **kwargs: typing.Any) -> BaseSerializer:
return self.get_response_serializer(*args, **kwargs)

def perform_retrieve(
self, item: typing.Union[dict, ModelObject]
) -> typing.Union[dict, ModelObject]:
return item


class UpdateModelMixin(BaseUpdateModelMixin):
def update(self, request: Request, *args: typing.Any, **kwargs: typing.Any) -> Response:
Expand Down
63 changes: 31 additions & 32 deletions restdoctor/rest_framework/pagination/cursor_uuid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@

import contextlib
import typing
import uuid
from typing import Any

from django.core.exceptions import ObjectDoesNotExist
from rest_framework.pagination import BasePagination
from rest_framework.utils.urls import replace_query_param, remove_query_param
from rest_framework.utils.urls import remove_query_param, replace_query_param

from restdoctor.constants import DEFAULT_MAX_PAGE_SIZE, DEFAULT_PAGE_SIZE
from restdoctor.rest_framework.pagination.mixins import SerializerClassPaginationMixin
from restdoctor.rest_framework.pagination.serializers import (
CursorUUIDRequestSerializer, CursorUUIDResponseSerializer, CursorUUIDUncountedResponseSerializer,
CursorUUIDRequestSerializer,
CursorUUIDResponseSerializer,
CursorUUIDUncountedResponseSerializer,
)
from restdoctor.rest_framework.response import ResponseWithMeta

Expand All @@ -33,21 +35,20 @@ def get_order(default_order: str, serializer_keys: typing.Sequence[str]) -> str:

def get_cursor(
queryset: QuerySet,
after_uuid: uuid.UUID = None,
before_uuid: uuid.UUID = None,
after_value: Any = None,
before_value: Any = None,
default_order: str = None,
cursor_field_name: str = 'uuid',
) -> typing.Tuple[typing.Any, str]:
cursor_obj = None
if after_uuid is not None:
with contextlib.suppress(ObjectDoesNotExist):
cursor_obj = queryset.get(uuid=after_uuid)
order = default_order or ''
with contextlib.suppress(ObjectDoesNotExist):
if after_value is not None:
cursor_obj = queryset.get(**{cursor_field_name: after_value})
order = 'after'
elif before_uuid is not None:
with contextlib.suppress(ObjectDoesNotExist):
cursor_obj = queryset.get(uuid=before_uuid)
elif before_value is not None:
cursor_obj = queryset.get(**{cursor_field_name: before_value})
order = 'before'
if cursor_obj is None:
order = default_order or ''
return cursor_obj, order


Expand All @@ -64,9 +65,7 @@ class CursorUUIDPagination(SerializerClassPaginationMixin, BasePagination):

serializer_class_map = {
'default': CursorUUIDRequestSerializer,
'pagination': {
'response': CursorUUIDResponseSerializer,
},
'pagination': {'response': CursorUUIDResponseSerializer},
}

max_page_size = DEFAULT_MAX_PAGE_SIZE
Expand All @@ -79,13 +78,15 @@ def get_lookup(self) -> OptionalLookup:
return {lookup_keyword: getattr(self.cursor_obj, self.lookup_by_field)}

def paginate_queryset(
self, queryset: QuerySet, request: Request, view: APIView = None,
self, queryset: QuerySet, request: Request, view: APIView = None
) -> OptionalList:
serializer_class = self.get_request_serializer_class()
serializer = serializer_class(data=request.query_params, max_per_page=self.max_page_size)
serializer.is_valid(raise_exception=True)

self.per_page = serializer.validated_data.get(self.page_size_query_param, self.default_page_size)
self.per_page = serializer.validated_data.get(
self.page_size_query_param, self.default_page_size
)

after_uuid = serializer.validated_data.get(self.after_query_param)
before_uuid = serializer.validated_data.get(self.before_query_param)
Expand Down Expand Up @@ -119,10 +120,7 @@ def paginate_queryset(
del paginated[-1]

if paginated:
self.page_boundaries = (
paginated[0].uuid,
paginated[-1].uuid,
)
self.page_boundaries = (paginated[0].uuid, paginated[-1].uuid)
elif self.cursor_obj:
self.page_boundaries = (self.cursor_obj.uuid, self.cursor_obj.uuid)
else:
Expand All @@ -135,7 +133,9 @@ def get_page_link_tmpl(self) -> str:
url_tmpl = replace_query_param(url_tmpl, self.page_size_query_param, self.per_page)
return url_tmpl

def get_page_link(self, after: typing.Any = None, before: typing.Any = None) -> typing.Optional[str]:
def get_page_link(
self, after: typing.Any = None, before: typing.Any = None
) -> typing.Optional[str]:
if after is not None:
base_url = remove_query_param(self.base_url, self.before_query_param)
return replace_query_param(base_url, self.after_query_param, after)
Expand All @@ -144,10 +144,7 @@ def get_page_link(self, after: typing.Any = None, before: typing.Any = None) ->
return replace_query_param(base_url, self.before_query_param, before)

def get_paginated_response(self, data: typing.Sequence[typing.Any]) -> ResponseWithMeta:
meta = {
self.page_size_query_param: self.per_page,
'has_next': self.has_next,
}
meta = {self.page_size_query_param: self.per_page, 'has_next': self.has_next}
cursor_uuid = None
if self.cursor_obj:
cursor_uuid = self.cursor_obj.uuid
Expand All @@ -161,8 +158,12 @@ def get_paginated_response(self, data: typing.Sequence[typing.Any]) -> ResponseW
if self.order == 'after':
after_idx, before_idx = 1, 0

meta['after_url'] = self.get_page_link(**{self.after_query_param: self.page_boundaries[after_idx] or ''})
meta['before_url'] = self.get_page_link(**{self.before_query_param: self.page_boundaries[before_idx] or ''})
meta['after_url'] = self.get_page_link(
**{self.after_query_param: self.page_boundaries[after_idx] or ''}
)
meta['before_url'] = self.get_page_link(
**{self.before_query_param: self.page_boundaries[before_idx] or ''}
)

return ResponseWithMeta(data=data, meta=meta)

Expand All @@ -172,7 +173,5 @@ class CursorUUIDUncountedPagination(CursorUUIDPagination):

serializer_class_map = {
'default': CursorUUIDRequestSerializer,
'pagination': {
'response': CursorUUIDUncountedResponseSerializer,
},
'pagination': {'response': CursorUUIDUncountedResponseSerializer},
}

0 comments on commit e341227

Please sign in to comment.