-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add routers that support sync and async viewsets.
The SimpleRouter included in DRF does not attempt to look for "alist", "acreate", etc. methods on async viewsets. Using this new router will cause it to use those methods for async viewsets and the non-"a"-prefixed methods for non-async viewsets.
- Loading branch information
Showing
3 changed files
with
177 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from rest_framework.routers import ( | ||
DefaultRouter as DRFDefaultRouter, | ||
) | ||
from rest_framework.routers import ( | ||
SimpleRouter as DRFSimpleRouter, | ||
) | ||
|
||
|
||
class SimpleRouter(DRFSimpleRouter): | ||
sync_to_async_action_map = { | ||
"list": "alist", | ||
"create": "acreate", | ||
"retrieve": "aretrieve", | ||
"update": "aupdate", | ||
"destroy": "adestroy", | ||
"partial_update": "partial_aupdate", | ||
} | ||
|
||
def get_method_map(self, viewset, method_map): | ||
""" | ||
Given a viewset, and a mapping of http methods to actions, | ||
return a new mapping which only includes any mappings that | ||
are actually implemented by the viewset. | ||
To allow the use of a single router that registers sync and async | ||
viewsets, the actions defined in the routes' method maps are | ||
updated to be the "a"-prefixed names for async viewsets. | ||
""" | ||
bound_methods = {} | ||
if getattr(viewset, "view_is_async", False): | ||
method_map = { | ||
method: self.sync_to_async_action_map.get(action, action) | ||
for method, action in method_map.items() | ||
} | ||
for method, action in method_map.items(): | ||
if hasattr(viewset, action): | ||
bound_methods[method] = action | ||
return bound_methods | ||
|
||
|
||
class DefaultRouter(SimpleRouter, DRFDefaultRouter): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
from asgiref.sync import async_to_sync | ||
from django.contrib.auth.models import User | ||
from django.test import Client, TestCase, override_settings | ||
from rest_framework import status | ||
from rest_framework.response import Response | ||
from rest_framework.test import APIRequestFactory | ||
from rest_framework.viewsets import ModelViewSet as DRFModelViewSet | ||
|
||
from adrf.routers import SimpleRouter, DefaultRouter | ||
from adrf.serializers import ModelSerializer | ||
from adrf.viewsets import ModelViewSet as AsyncModelViewSet | ||
from tests.test_views import JSON_ERROR, sanitise_json_error | ||
|
||
|
||
class SyncViewSet(DRFModelViewSet): | ||
def list(self, request): | ||
return Response({"method": "GET", "async": False}) | ||
|
||
def create(self, request): | ||
return Response({"method": "POST", "data": request.data, "async": False}) | ||
|
||
def retrieve(self, request, pk): | ||
return Response({"method": "GET", "data": {"pk": pk}, "async": False}) | ||
|
||
def update(self, request, pk): | ||
return Response({"method": "PUT", "data": request.data, "async": False}) | ||
|
||
def partial_update(self, request, pk): | ||
return Response({"method": "PATCH", "data": request.data, "async": False}) | ||
|
||
def destroy(self, request, pk): | ||
return Response({"method": "DELETE", "async": False}) | ||
|
||
|
||
class AsyncViewSet(AsyncModelViewSet): | ||
async def alist(self, request): | ||
return Response({"method": "GET", "async": True}) | ||
|
||
async def acreate(self, request): | ||
return Response({"method": "POST", "data": request.data, "async": True}) | ||
|
||
async def aretrieve(self, request, pk): | ||
return Response({"method": "GET", "data": {"pk": pk}, "async": True}) | ||
|
||
async def aupdate(self, request, pk): | ||
return Response({"method": "PUT", "data": request.data, "async": True}) | ||
|
||
async def partial_aupdate(self, request, pk): | ||
return Response({"method": "PATCH", "data": request.data, "async": True}) | ||
|
||
async def adestroy(self, request, pk): | ||
return Response({"method": "DELETE", "async": True}) | ||
|
||
|
||
router = SimpleRouter() | ||
router.register("sync", SyncViewSet, basename="sync") | ||
router.register("async", AsyncViewSet, basename="async") | ||
urlpatterns = router.urls | ||
|
||
|
||
@override_settings(ROOT_URLCONF="tests.test_routers") | ||
class _RouterIntegrationTests(TestCase): | ||
use_async = None | ||
__test__ = False | ||
|
||
def setUp(self): | ||
self.client = Client() | ||
self.url = "/" + ("async" if self.use_async else "sync") + "/" | ||
self.detail_url = self.url + "1/" | ||
|
||
def test_list(self): | ||
resp = self.client.get(self.url) | ||
assert resp.status_code == 200 | ||
assert resp.data == {"method": "GET", "async": self.use_async} | ||
|
||
def test_create(self): | ||
resp = self.client.post( | ||
self.url, {"foo": "bar"}, content_type="application/json" | ||
) | ||
assert resp.status_code == 200 | ||
assert resp.data == { | ||
"method": "POST", | ||
"data": {"foo": "bar"}, | ||
"async": self.use_async, | ||
} | ||
|
||
def test_retrieve(self): | ||
resp = self.client.get(self.detail_url) | ||
assert resp.status_code == 200 | ||
assert resp.data == { | ||
"method": "GET", | ||
"data": {"pk": "1"}, | ||
"async": self.use_async, | ||
} | ||
|
||
def test_update(self): | ||
resp = self.client.put( | ||
self.detail_url, {"foo": "bar"}, content_type="application/json" | ||
) | ||
assert resp.status_code == 200 | ||
assert resp.data == { | ||
"method": "PUT", | ||
"data": {"foo": "bar"}, | ||
"async": self.use_async, | ||
} | ||
|
||
def test_partial_update(self): | ||
resp = self.client.patch( | ||
self.detail_url, {"foo": "bar"}, content_type="application/json" | ||
) | ||
assert resp.status_code == 200 | ||
assert resp.data == { | ||
"method": "PATCH", | ||
"data": {"foo": "bar"}, | ||
"async": self.use_async, | ||
} | ||
|
||
def test_destroy(self): | ||
resp = self.client.delete(self.detail_url) | ||
assert resp.status_code == 200 | ||
assert resp.data == {"method": "DELETE", "async": self.use_async} | ||
|
||
|
||
class TestSyncRouterIntegrationTests(_RouterIntegrationTests): | ||
use_async = False | ||
__test__ = True | ||
|
||
|
||
class AsyncRouterIntegrationTests(_RouterIntegrationTests): | ||
use_async = True | ||
__test__ = True |