Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update get_serializer_class to consider the request method #77

Merged
merged 10 commits into from
May 27, 2024
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,10 @@ tests/__init__.py

# Development task artifacts
*.db
Pipfile

# Visual Studio Code config files
.vscode

# pytest
.pytest_cache/
.pytest_cache/
49 changes: 45 additions & 4 deletions drf_rw_serializers/generics.py
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GenericAPIViewGetSerializerClassTests tests the changes made in this file.

Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@

from rest_framework import generics, mixins

from .mixins import CreateModelMixin, ListModelMixin, RetrieveModelMixin, UpdateModelMixin
from .mixins import (
CreateModelMixin,
ListModelMixin,
RetrieveModelMixin,
UpdateModelMixin,
)


class GenericAPIView(generics.GenericAPIView):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

def get_serializer_class(self):
def _get_serializer_class(self):
"""
Return the class to use for the serializer.
Defaults to using `self.serializer_class`.
Expand All @@ -25,6 +30,42 @@ def get_serializer_class(self):

return self.serializer_class

def get_serializer_class(self):
"""
Return the class to use for the serializer.
Defaults to using `self.serializer_class`.
If the request method is GET, it tries to use `self.read_serializer_class`.
If the request method is not GET, it tries to use `self.write_serializer_class`.
If the specific serializer class for the request method is not set, it falls back to
`self.serializer_class`.
You may want to override this if you need to provide different
serializations depending on the incoming request.
(Eg. admins get full serialization, others get basic serialization)
"""
if hasattr(self, "request"):
if self.request.method in ["GET", "HEAD", "OPTIONS", "TRACE"]:
assert (
getattr(self, "read_serializer_class", None) is not None
or self.serializer_class is not None
), (
"'%s' should either include a `read_serializer_class` or `serializer_class` "
"attribute, or override the `get_read_serializer_class()` or "
"`get_serializer_class()` method." % self.__class__.__name__
)
return self.get_read_serializer_class()
elif self.request.method in ["POST", "PUT", "PATCH", "DELETE"]:
assert (
getattr(self, "write_serializer_class", None) is not None
or self.serializer_class is not None
), (
"'%s' should either include a `write_serializer_class` or `serializer_class` "
"attribute, or override the `get_write_serializer_class()` or "
"`get_serializer_class()` method." % self.__class__.__name__
)
return self.get_write_serializer_class()

return self._get_serializer_class()
pamella marked this conversation as resolved.
Show resolved Hide resolved

def get_read_serializer(self, *args, **kwargs):
"""
Return the serializer instance that should be used for serializing output.
Expand All @@ -42,7 +83,7 @@ def get_read_serializer_class(self):
(Eg. admins get full serialization, others get basic serialization)
"""
if getattr(self, "read_serializer_class", None) is None:
return self.get_serializer_class()
return self._get_serializer_class()

return self.read_serializer_class

Expand All @@ -64,7 +105,7 @@ def get_write_serializer_class(self):
(Eg. admins can send extra fields, others cannot)
"""
if getattr(self, "write_serializer_class", None) is None:
return self.get_serializer_class()
return self._get_serializer_class()

return self.write_serializer_class

Expand Down
100 changes: 68 additions & 32 deletions test_utils/base_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ def setUp(self):


class TestListRequestSuccess(object):
def test_get_serializer_class(self):
response = self.auth_client.get(self.view_url, format="json")
view = response.renderer_context["view"]
self.assertEqual(view.get_serializer_class(), self.list_serializer_class)

def test_list_request_success(self):
orders = baker.make("example_app.Order", _quantity=3)
for order in orders:
Expand All @@ -36,6 +41,11 @@ def test_list_request_success(self):


class TestRetrieveRequestSuccess(object):
def test_get_serializer_class(self):
response = self.auth_client.get(self.view_url, format="json")
view = response.renderer_context["view"]
self.assertEqual(view.get_serializer_class(), self.retrieve_serializer_class)

def test_list_request_success(self):
order = baker.make("example_app.Order")
baker.make("example_app.OrderedMeal", order=order, _quantity=2)
Expand All @@ -45,54 +55,80 @@ def test_list_request_success(self):


class TestCreateRequestSuccess(object):
data = {
"table_number": 100,
"ordered_meals": [
{
"quantity": 1,
"meal": None,
},
{
"quantity": 2,
"meal": None,
},
],
}

def setUp(self):
self.data["ordered_meals"][0]["meal"] = self.meals[0].id
self.data["ordered_meals"][1]["meal"] = self.meals[1].id

def test_get_serializer_class(self):
response = self.auth_client.post(self.view_url, self.data, format="json")
view = response.renderer_context["view"]
self.assertEqual(view.get_serializer_class(), self.create_in_serializer_class)

def test_create_request_success(self):
data = {
"table_number": 100,
"ordered_meals": [
{
"quantity": 1,
"meal": self.meals[0].id,
},
{
"quantity": 2,
"meal": self.meals[1].id,
},
],
}
response = self.auth_client.post(self.view_url, data, format="json")
response = self.auth_client.post(self.view_url, self.data, format="json")
self.assertEqual(response.status_code, 201)
order = Order.objects.get(id=response.data["id"])
self.assertEqual(response.data, self.create_out_serializer_class(order).data)
self.assertEqual(order.table_number, data["table_number"])
self.assertEqual(order.table_number, self.data["table_number"])

for ordered_meal_dict in data["ordered_meals"]:
for ordered_meal_dict in self.data["ordered_meals"]:
ordered_meal = order.ordered_meals.filter(meal__id=ordered_meal_dict["meal"]).first()
self.assertIsNotNone(ordered_meal)
self.assertEqual(ordered_meal.quantity, ordered_meal_dict["quantity"])


class TestUpdateRequestSuccess(object):
data = {
"table_number": 2,
"ordered_meals": [
{
"quantity": 10,
"meal": None,
},
{
"quantity": 20,
"meal": None,
},
],
}

def setUp(self):
self.data["ordered_meals"][0]["meal"] = self.meals[0].id
self.data["ordered_meals"][1]["meal"] = self.meals[1].id

def test_get_serializer_class(self):
# PUT request
response = self.auth_client.put(self.view_url, self.data, format="json")
view = response.renderer_context["view"]
self.assertEqual(view.get_serializer_class(), self.update_in_serializer_class)

# PATCH request
response = self.auth_client.patch(self.view_url, self.data, format="json")
view = response.renderer_context["view"]
self.assertEqual(view.get_serializer_class(), self.update_in_serializer_class)

def test_update_request_success(self):
data = {
"table_number": 2,
"ordered_meals": [
{
"quantity": 10,
"meal": self.meals[0].id,
},
{
"quantity": 20,
"meal": self.meals[1].id,
},
],
}
response = self.auth_client.put(self.view_url, data, format="json")
response = self.auth_client.put(self.view_url, self.data, format="json")
self.object.refresh_from_db()
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data, self.update_out_serializer_class(self.object).data)
self.assertEqual(self.object.table_number, data["table_number"])
self.assertEqual(self.object.table_number, self.data["table_number"])

for ordered_meal_dict in data["ordered_meals"]:
for ordered_meal_dict in self.data["ordered_meals"]:
ordered_meal = self.object.ordered_meals.filter(
meal__id=ordered_meal_dict["meal"]
).first()
Expand Down
Loading