Skip to content

Commit

Permalink
Add routers that support sync and async viewsets.
Browse files Browse the repository at this point in the history
The SimpleRouter included in DRF does not attempt to look for "alist",
"acreate", etc. methods on async viewsets. Using this new router will
cause it to use those methods for async viewsets and the
non-"a"-prefixed methods for non-async viewsets.
  • Loading branch information
tbeadle committed Aug 19, 2024
1 parent 226e2f0 commit 6a8f6b8
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 5 deletions.
42 changes: 42 additions & 0 deletions adrf/routers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from rest_framework.routers import (
DefaultRouter as DRFDefaultRouter,
)
from rest_framework.routers import (
SimpleRouter as DRFSimpleRouter,
)


class SimpleRouter(DRFSimpleRouter):
sync_to_async_action_map = {
"list": "alist",
"create": "acreate",
"retrieve": "aretrieve",
"update": "aupdate",
"destroy": "adestroy",
"partial_update": "partial_aupdate",
}

def get_method_map(self, viewset, method_map):
"""
Given a viewset, and a mapping of http methods to actions,
return a new mapping which only includes any mappings that
are actually implemented by the viewset.
To allow the use of a single router that registers sync and async
viewsets, the actions defined in the routes' method maps are
updated to be the "a"-prefixed names for async viewsets.
"""
bound_methods = {}
if getattr(viewset, "view_is_async", False):
method_map = {
method: self.sync_to_async_action_map.get(action, action)
for method, action in method_map.items()
}
for method, action in method_map.items():
if hasattr(viewset, action):
bound_methods[method] = action
return bound_methods


class DefaultRouter(SimpleRouter, DRFDefaultRouter):
pass
9 changes: 4 additions & 5 deletions adrf/viewsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ class ViewSetMixin(DRFViewSetMixin):
the binding of HTTP methods to actions on the Resource.
For example, to create a concrete view binding the 'GET' and 'POST' methods
to the 'list' and 'create' actions...
to the 'alist' and 'acreate' actions...
view = MyViewSet.as_view({'get': 'list', 'post': 'create'})
view = MyViewSet.as_view({'get': 'alist', 'post': 'acreate'})
"""

@classonlymethod
Expand Down Expand Up @@ -155,14 +155,13 @@ def view_is_async(cls):
"""
Checks whether any viewset methods are coroutines.
"""
result = [
return any(
asyncio.iscoroutinefunction(function)
for name, function in getmembers(
cls, inspect.iscoroutinefunction, exclude_names=["view_is_async"]
)
if not name.startswith("__") and name not in cls._ASYNC_NON_DISPATCH_METHODS
]
return any(result)
)


class GenericViewSet(ViewSet, GenericAPIView):
Expand Down
131 changes: 131 additions & 0 deletions tests/test_routers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from asgiref.sync import async_to_sync
from django.contrib.auth.models import User
from django.test import Client, TestCase, override_settings
from rest_framework import status
from rest_framework.response import Response
from rest_framework.test import APIRequestFactory
from rest_framework.viewsets import ModelViewSet as DRFModelViewSet

from adrf.routers import SimpleRouter, DefaultRouter
from adrf.serializers import ModelSerializer
from adrf.viewsets import ModelViewSet as AsyncModelViewSet
from tests.test_views import JSON_ERROR, sanitise_json_error


class SyncViewSet(DRFModelViewSet):
def list(self, request):
return Response({"method": "GET", "async": False})

def create(self, request):
return Response({"method": "POST", "data": request.data, "async": False})

def retrieve(self, request, pk):
return Response({"method": "GET", "data": {"pk": pk}, "async": False})

def update(self, request, pk):
return Response({"method": "PUT", "data": request.data, "async": False})

def partial_update(self, request, pk):
return Response({"method": "PATCH", "data": request.data, "async": False})

def destroy(self, request, pk):
return Response({"method": "DELETE", "async": False})


class AsyncViewSet(AsyncModelViewSet):
async def alist(self, request):
return Response({"method": "GET", "async": True})

async def acreate(self, request):
return Response({"method": "POST", "data": request.data, "async": True})

async def aretrieve(self, request, pk):
return Response({"method": "GET", "data": {"pk": pk}, "async": True})

async def aupdate(self, request, pk):
return Response({"method": "PUT", "data": request.data, "async": True})

async def partial_aupdate(self, request, pk):
return Response({"method": "PATCH", "data": request.data, "async": True})

async def adestroy(self, request, pk):
return Response({"method": "DELETE", "async": True})


router = SimpleRouter()
router.register("sync", SyncViewSet, basename="sync")
router.register("async", AsyncViewSet, basename="async")
urlpatterns = router.urls


@override_settings(ROOT_URLCONF="tests.test_routers")
class _RouterIntegrationTests(TestCase):
use_async = None
__test__ = False

def setUp(self):
self.client = Client()
self.url = "/" + ("async" if self.use_async else "sync") + "/"
self.detail_url = self.url + "1/"

def test_list(self):
resp = self.client.get(self.url)
assert resp.status_code == 200
assert resp.data == {"method": "GET", "async": self.use_async}

def test_create(self):
resp = self.client.post(
self.url, {"foo": "bar"}, content_type="application/json"
)
assert resp.status_code == 200
assert resp.data == {
"method": "POST",
"data": {"foo": "bar"},
"async": self.use_async,
}

def test_retrieve(self):
resp = self.client.get(self.detail_url)
assert resp.status_code == 200
assert resp.data == {
"method": "GET",
"data": {"pk": "1"},
"async": self.use_async,
}

def test_update(self):
resp = self.client.put(
self.detail_url, {"foo": "bar"}, content_type="application/json"
)
assert resp.status_code == 200
assert resp.data == {
"method": "PUT",
"data": {"foo": "bar"},
"async": self.use_async,
}

def test_partial_update(self):
resp = self.client.patch(
self.detail_url, {"foo": "bar"}, content_type="application/json"
)
assert resp.status_code == 200
assert resp.data == {
"method": "PATCH",
"data": {"foo": "bar"},
"async": self.use_async,
}

def test_destroy(self):
resp = self.client.delete(self.detail_url)
assert resp.status_code == 200
assert resp.data == {"method": "DELETE", "async": self.use_async}


class TestSyncRouterIntegrationTests(_RouterIntegrationTests):
use_async = False
__test__ = True


class AsyncRouterIntegrationTests(_RouterIntegrationTests):
use_async = True
__test__ = True

0 comments on commit 6a8f6b8

Please sign in to comment.