From 6a8f6b889ab425d77eaf03ae97cc01a808559a29 Mon Sep 17 00:00:00 2001 From: Tommy Beadle Date: Mon, 19 Aug 2024 16:11:00 -0400 Subject: [PATCH] Add routers that support sync and async viewsets. The SimpleRouter included in DRF does not attempt to look for "alist", "acreate", etc. methods on async viewsets. Using this new router will cause it to use those methods for async viewsets and the non-"a"-prefixed methods for non-async viewsets. --- adrf/routers.py | 42 ++++++++++++++ adrf/viewsets.py | 9 ++- tests/test_routers.py | 131 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 177 insertions(+), 5 deletions(-) create mode 100644 adrf/routers.py create mode 100644 tests/test_routers.py diff --git a/adrf/routers.py b/adrf/routers.py new file mode 100644 index 0000000..388fad8 --- /dev/null +++ b/adrf/routers.py @@ -0,0 +1,42 @@ +from rest_framework.routers import ( + DefaultRouter as DRFDefaultRouter, +) +from rest_framework.routers import ( + SimpleRouter as DRFSimpleRouter, +) + + +class SimpleRouter(DRFSimpleRouter): + sync_to_async_action_map = { + "list": "alist", + "create": "acreate", + "retrieve": "aretrieve", + "update": "aupdate", + "destroy": "adestroy", + "partial_update": "partial_aupdate", + } + + def get_method_map(self, viewset, method_map): + """ + Given a viewset, and a mapping of http methods to actions, + return a new mapping which only includes any mappings that + are actually implemented by the viewset. + + To allow the use of a single router that registers sync and async + viewsets, the actions defined in the routes' method maps are + updated to be the "a"-prefixed names for async viewsets. + """ + bound_methods = {} + if getattr(viewset, "view_is_async", False): + method_map = { + method: self.sync_to_async_action_map.get(action, action) + for method, action in method_map.items() + } + for method, action in method_map.items(): + if hasattr(viewset, action): + bound_methods[method] = action + return bound_methods + + +class DefaultRouter(SimpleRouter, DRFDefaultRouter): + pass diff --git a/adrf/viewsets.py b/adrf/viewsets.py index be1c4f7..3f5b8db 100644 --- a/adrf/viewsets.py +++ b/adrf/viewsets.py @@ -20,9 +20,9 @@ class ViewSetMixin(DRFViewSetMixin): the binding of HTTP methods to actions on the Resource. For example, to create a concrete view binding the 'GET' and 'POST' methods - to the 'list' and 'create' actions... + to the 'alist' and 'acreate' actions... - view = MyViewSet.as_view({'get': 'list', 'post': 'create'}) + view = MyViewSet.as_view({'get': 'alist', 'post': 'acreate'}) """ @classonlymethod @@ -155,14 +155,13 @@ def view_is_async(cls): """ Checks whether any viewset methods are coroutines. """ - result = [ + return any( asyncio.iscoroutinefunction(function) for name, function in getmembers( cls, inspect.iscoroutinefunction, exclude_names=["view_is_async"] ) if not name.startswith("__") and name not in cls._ASYNC_NON_DISPATCH_METHODS - ] - return any(result) + ) class GenericViewSet(ViewSet, GenericAPIView): diff --git a/tests/test_routers.py b/tests/test_routers.py new file mode 100644 index 0000000..edb7058 --- /dev/null +++ b/tests/test_routers.py @@ -0,0 +1,131 @@ +from asgiref.sync import async_to_sync +from django.contrib.auth.models import User +from django.test import Client, TestCase, override_settings +from rest_framework import status +from rest_framework.response import Response +from rest_framework.test import APIRequestFactory +from rest_framework.viewsets import ModelViewSet as DRFModelViewSet + +from adrf.routers import SimpleRouter, DefaultRouter +from adrf.serializers import ModelSerializer +from adrf.viewsets import ModelViewSet as AsyncModelViewSet +from tests.test_views import JSON_ERROR, sanitise_json_error + + +class SyncViewSet(DRFModelViewSet): + def list(self, request): + return Response({"method": "GET", "async": False}) + + def create(self, request): + return Response({"method": "POST", "data": request.data, "async": False}) + + def retrieve(self, request, pk): + return Response({"method": "GET", "data": {"pk": pk}, "async": False}) + + def update(self, request, pk): + return Response({"method": "PUT", "data": request.data, "async": False}) + + def partial_update(self, request, pk): + return Response({"method": "PATCH", "data": request.data, "async": False}) + + def destroy(self, request, pk): + return Response({"method": "DELETE", "async": False}) + + +class AsyncViewSet(AsyncModelViewSet): + async def alist(self, request): + return Response({"method": "GET", "async": True}) + + async def acreate(self, request): + return Response({"method": "POST", "data": request.data, "async": True}) + + async def aretrieve(self, request, pk): + return Response({"method": "GET", "data": {"pk": pk}, "async": True}) + + async def aupdate(self, request, pk): + return Response({"method": "PUT", "data": request.data, "async": True}) + + async def partial_aupdate(self, request, pk): + return Response({"method": "PATCH", "data": request.data, "async": True}) + + async def adestroy(self, request, pk): + return Response({"method": "DELETE", "async": True}) + + +router = SimpleRouter() +router.register("sync", SyncViewSet, basename="sync") +router.register("async", AsyncViewSet, basename="async") +urlpatterns = router.urls + + +@override_settings(ROOT_URLCONF="tests.test_routers") +class _RouterIntegrationTests(TestCase): + use_async = None + __test__ = False + + def setUp(self): + self.client = Client() + self.url = "/" + ("async" if self.use_async else "sync") + "/" + self.detail_url = self.url + "1/" + + def test_list(self): + resp = self.client.get(self.url) + assert resp.status_code == 200 + assert resp.data == {"method": "GET", "async": self.use_async} + + def test_create(self): + resp = self.client.post( + self.url, {"foo": "bar"}, content_type="application/json" + ) + assert resp.status_code == 200 + assert resp.data == { + "method": "POST", + "data": {"foo": "bar"}, + "async": self.use_async, + } + + def test_retrieve(self): + resp = self.client.get(self.detail_url) + assert resp.status_code == 200 + assert resp.data == { + "method": "GET", + "data": {"pk": "1"}, + "async": self.use_async, + } + + def test_update(self): + resp = self.client.put( + self.detail_url, {"foo": "bar"}, content_type="application/json" + ) + assert resp.status_code == 200 + assert resp.data == { + "method": "PUT", + "data": {"foo": "bar"}, + "async": self.use_async, + } + + def test_partial_update(self): + resp = self.client.patch( + self.detail_url, {"foo": "bar"}, content_type="application/json" + ) + assert resp.status_code == 200 + assert resp.data == { + "method": "PATCH", + "data": {"foo": "bar"}, + "async": self.use_async, + } + + def test_destroy(self): + resp = self.client.delete(self.detail_url) + assert resp.status_code == 200 + assert resp.data == {"method": "DELETE", "async": self.use_async} + + +class TestSyncRouterIntegrationTests(_RouterIntegrationTests): + use_async = False + __test__ = True + + +class AsyncRouterIntegrationTests(_RouterIntegrationTests): + use_async = True + __test__ = True