Skip to content

Commit

Permalink
Fix for PrimaryKeyRelatedField serializer bug
Browse files Browse the repository at this point in the history
  • Loading branch information
em1208 committed Aug 2, 2024
1 parent e518559 commit 5fcc2d8
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 22 deletions.
15 changes: 10 additions & 5 deletions adrf/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@
from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList


# NOTE This is the list of fields defined by DRF for which we need to call to_rapresentation.
DRF_FIELDS = list(DRFModelSerializer.serializer_field_mapping.values()) + [
DRFModelSerializer.serializer_related_field,
DRFModelSerializer.serializer_related_to_field,
DRFModelSerializer.serializer_url_field,
DRFModelSerializer.serializer_choice_field,
]


class BaseSerializer(DRFBaseSerializer):
"""
Base serializer class.
Expand Down Expand Up @@ -148,17 +157,13 @@ async def ato_representation(self, instance):
except SkipField:
continue

is_drf_field = type(field) in list(
DRFModelSerializer.serializer_field_mapping.values()
) + [DRFModelSerializer.serializer_choice_field]

check_for_none = (
attribute.pk if isinstance(attribute, models.Model) else attribute
)
if check_for_none is None:
ret[field.field_name] = None
else:
if is_drf_field:
if type(field) in DRF_FIELDS:
repr = field.to_representation(attribute)
else:
repr = await field.ato_representation(attribute)
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def pytest_configure(config):
"django.contrib.staticfiles",
"rest_framework",
"rest_framework.authtoken",
"tests",
),
PASSWORD_HASHERS=("django.contrib.auth.hashers.MD5PasswordHasher",),
)
Expand Down
7 changes: 7 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from django.db import models
from django.contrib.auth.models import User


class Order(models.Model):
name = models.TextField()
user = models.ForeignKey(User, on_delete=models.CASCADE)
51 changes: 34 additions & 17 deletions tests/test_serializers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from collections import ChainMap

from asgiref.sync import sync_to_async
from django.contrib.auth.models import User
from django.test import TestCase

from adrf.serializers import ModelSerializer, Serializer
from rest_framework import serializers
from rest_framework.test import APIRequestFactory
from .test_models import User, Order

factory = APIRequestFactory()

Expand All @@ -20,7 +20,6 @@ def __init__(self, **kwargs):
setattr(self, key, val)


# replace with django test case
class TestSerializer(TestCase):
def setUp(self):
class SimpleSerializer(Serializer):
Expand All @@ -39,14 +38,8 @@ async def acreate(self, validated_data):
async def aupdate(self, instance, validated_data):
return MockObject(**validated_data)

class MyModelSerializer(ModelSerializer):
class Meta:
model = User
fields = ("username",)

self.simple_serializer = SimpleSerializer
self.crud_serializer = CrudSerializer
self.model_serializer = MyModelSerializer

self.default_data = {
"username": "test",
Expand All @@ -66,15 +59,6 @@ async def test_serializer_valid(self):
assert await serializer.adata == data
assert serializer.errors == {}

async def test_modelserializer_valid(self):
data = {
"username": "test",
}
serializer = self.model_serializer(data=data)
assert await sync_to_async(serializer.is_valid)()
assert await serializer.adata == data
assert serializer.errors == {}

async def test_serializer_invalid(self):
data = {
"username": "test",
Expand Down Expand Up @@ -236,3 +220,36 @@ def test_sync_serializer_valid(self):
assert serializer.is_valid()
assert serializer.data == data
assert serializer.errors == {}


class TestModelSerializer(TestCase):
def setUp(self) -> None:
class UserSerializer(ModelSerializer):
class Meta:
model = User
fields = ("username",)

class OrderSerializer(ModelSerializer):
class Meta:
model = Order
fields = ("id", "user", "name")

self.user_serializer = UserSerializer
self.order_serializer = OrderSerializer

async def test_user_serializer_valid(self):
data = {
"username": "test",
}
serializer = self.user_serializer(data=data)
assert await sync_to_async(serializer.is_valid)()
assert await serializer.adata == data
assert serializer.errors == {}

async def test_order_serializer_valid(self):
user = await User.objects.acreate(username="test")
data = {"user": user.id, "name": "Test order"}
serializer = self.order_serializer(data=data)
assert await sync_to_async(serializer.is_valid)()
assert await serializer.adata == data
assert serializer.errors == {}

0 comments on commit 5fcc2d8

Please sign in to comment.