Skip to content

Commit

Permalink
Merge pull request #45 from tbeadle/router-viewsets
Browse files Browse the repository at this point in the history
Add routers that support sync and async viewsets.
  • Loading branch information
em1208 authored Oct 12, 2024
2 parents 226e2f0 + 6a8f6b8 commit 1d8e68f
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 5 deletions.
42 changes: 42 additions & 0 deletions adrf/routers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from rest_framework.routers import (
DefaultRouter as DRFDefaultRouter,
)
from rest_framework.routers import (
SimpleRouter as DRFSimpleRouter,
)


class SimpleRouter(DRFSimpleRouter):
sync_to_async_action_map = {
"list": "alist",
"create": "acreate",
"retrieve": "aretrieve",
"update": "aupdate",
"destroy": "adestroy",
"partial_update": "partial_aupdate",
}

def get_method_map(self, viewset, method_map):
"""
Given a viewset, and a mapping of http methods to actions,
return a new mapping which only includes any mappings that
are actually implemented by the viewset.
To allow the use of a single router that registers sync and async
viewsets, the actions defined in the routes' method maps are
updated to be the "a"-prefixed names for async viewsets.
"""
bound_methods = {}
if getattr(viewset, "view_is_async", False):
method_map = {
method: self.sync_to_async_action_map.get(action, action)
for method, action in method_map.items()
}
for method, action in method_map.items():
if hasattr(viewset, action):
bound_methods[method] = action
return bound_methods


class DefaultRouter(SimpleRouter, DRFDefaultRouter):
pass
9 changes: 4 additions & 5 deletions adrf/viewsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ class ViewSetMixin(DRFViewSetMixin):
the binding of HTTP methods to actions on the Resource.
For example, to create a concrete view binding the 'GET' and 'POST' methods
to the 'list' and 'create' actions...
to the 'alist' and 'acreate' actions...
view = MyViewSet.as_view({'get': 'list', 'post': 'create'})
view = MyViewSet.as_view({'get': 'alist', 'post': 'acreate'})
"""

@classonlymethod
Expand Down Expand Up @@ -155,14 +155,13 @@ def view_is_async(cls):
"""
Checks whether any viewset methods are coroutines.
"""
result = [
return any(
asyncio.iscoroutinefunction(function)
for name, function in getmembers(
cls, inspect.iscoroutinefunction, exclude_names=["view_is_async"]
)
if not name.startswith("__") and name not in cls._ASYNC_NON_DISPATCH_METHODS
]
return any(result)
)


class GenericViewSet(ViewSet, GenericAPIView):
Expand Down
131 changes: 131 additions & 0 deletions tests/test_routers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from asgiref.sync import async_to_sync
from django.contrib.auth.models import User
from django.test import Client, TestCase, override_settings
from rest_framework import status
from rest_framework.response import Response
from rest_framework.test import APIRequestFactory
from rest_framework.viewsets import ModelViewSet as DRFModelViewSet

from adrf.routers import SimpleRouter, DefaultRouter
from adrf.serializers import ModelSerializer
from adrf.viewsets import ModelViewSet as AsyncModelViewSet
from tests.test_views import JSON_ERROR, sanitise_json_error


class SyncViewSet(DRFModelViewSet):
def list(self, request):
return Response({"method": "GET", "async": False})

def create(self, request):
return Response({"method": "POST", "data": request.data, "async": False})

def retrieve(self, request, pk):
return Response({"method": "GET", "data": {"pk": pk}, "async": False})

def update(self, request, pk):
return Response({"method": "PUT", "data": request.data, "async": False})

def partial_update(self, request, pk):
return Response({"method": "PATCH", "data": request.data, "async": False})

def destroy(self, request, pk):
return Response({"method": "DELETE", "async": False})


class AsyncViewSet(AsyncModelViewSet):
async def alist(self, request):
return Response({"method": "GET", "async": True})

async def acreate(self, request):
return Response({"method": "POST", "data": request.data, "async": True})

async def aretrieve(self, request, pk):
return Response({"method": "GET", "data": {"pk": pk}, "async": True})

async def aupdate(self, request, pk):
return Response({"method": "PUT", "data": request.data, "async": True})

async def partial_aupdate(self, request, pk):
return Response({"method": "PATCH", "data": request.data, "async": True})

async def adestroy(self, request, pk):
return Response({"method": "DELETE", "async": True})


router = SimpleRouter()
router.register("sync", SyncViewSet, basename="sync")
router.register("async", AsyncViewSet, basename="async")
urlpatterns = router.urls


@override_settings(ROOT_URLCONF="tests.test_routers")
class _RouterIntegrationTests(TestCase):
use_async = None
__test__ = False

def setUp(self):
self.client = Client()
self.url = "/" + ("async" if self.use_async else "sync") + "/"
self.detail_url = self.url + "1/"

def test_list(self):
resp = self.client.get(self.url)
assert resp.status_code == 200
assert resp.data == {"method": "GET", "async": self.use_async}

def test_create(self):
resp = self.client.post(
self.url, {"foo": "bar"}, content_type="application/json"
)
assert resp.status_code == 200
assert resp.data == {
"method": "POST",
"data": {"foo": "bar"},
"async": self.use_async,
}

def test_retrieve(self):
resp = self.client.get(self.detail_url)
assert resp.status_code == 200
assert resp.data == {
"method": "GET",
"data": {"pk": "1"},
"async": self.use_async,
}

def test_update(self):
resp = self.client.put(
self.detail_url, {"foo": "bar"}, content_type="application/json"
)
assert resp.status_code == 200
assert resp.data == {
"method": "PUT",
"data": {"foo": "bar"},
"async": self.use_async,
}

def test_partial_update(self):
resp = self.client.patch(
self.detail_url, {"foo": "bar"}, content_type="application/json"
)
assert resp.status_code == 200
assert resp.data == {
"method": "PATCH",
"data": {"foo": "bar"},
"async": self.use_async,
}

def test_destroy(self):
resp = self.client.delete(self.detail_url)
assert resp.status_code == 200
assert resp.data == {"method": "DELETE", "async": self.use_async}


class TestSyncRouterIntegrationTests(_RouterIntegrationTests):
use_async = False
__test__ = True


class AsyncRouterIntegrationTests(_RouterIntegrationTests):
use_async = True
__test__ = True

0 comments on commit 1d8e68f

Please sign in to comment.