diff --git a/adrf/serializers.py b/adrf/serializers.py index a063ce6..2884db2 100644 --- a/adrf/serializers.py +++ b/adrf/serializers.py @@ -15,6 +15,15 @@ from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList +# NOTE This is the list of fields defined by DRF for which we need to call to_rapresentation. +DRF_FIELDS = list(DRFModelSerializer.serializer_field_mapping.values()) + [ + DRFModelSerializer.serializer_related_field, + DRFModelSerializer.serializer_related_to_field, + DRFModelSerializer.serializer_url_field, + DRFModelSerializer.serializer_choice_field, +] + + class BaseSerializer(DRFBaseSerializer): """ Base serializer class. @@ -148,17 +157,13 @@ async def ato_representation(self, instance): except SkipField: continue - is_drf_field = type(field) in list( - DRFModelSerializer.serializer_field_mapping.values() - ) + [DRFModelSerializer.serializer_choice_field] - check_for_none = ( attribute.pk if isinstance(attribute, models.Model) else attribute ) if check_for_none is None: ret[field.field_name] = None else: - if is_drf_field: + if type(field) in DRF_FIELDS: repr = field.to_representation(attribute) else: repr = await field.ato_representation(attribute) diff --git a/tests/conftest.py b/tests/conftest.py index 9139cce..bfa3699 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -41,6 +41,7 @@ def pytest_configure(config): "django.contrib.staticfiles", "rest_framework", "rest_framework.authtoken", + "tests", ), PASSWORD_HASHERS=("django.contrib.auth.hashers.MD5PasswordHasher",), ) diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..c29ba63 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,7 @@ +from django.db import models +from django.contrib.auth.models import User + + +class Order(models.Model): + name = models.TextField() + user = models.ForeignKey(User, on_delete=models.CASCADE) diff --git a/tests/test_serializers.py b/tests/test_serializers.py index 6ab2dca..6563dc1 100644 --- a/tests/test_serializers.py +++ b/tests/test_serializers.py @@ -1,12 +1,12 @@ from collections import ChainMap from asgiref.sync import sync_to_async -from django.contrib.auth.models import User from django.test import TestCase from adrf.serializers import ModelSerializer, Serializer from rest_framework import serializers from rest_framework.test import APIRequestFactory +from .test_models import User, Order factory = APIRequestFactory() @@ -20,7 +20,6 @@ def __init__(self, **kwargs): setattr(self, key, val) -# replace with django test case class TestSerializer(TestCase): def setUp(self): class SimpleSerializer(Serializer): @@ -39,14 +38,8 @@ async def acreate(self, validated_data): async def aupdate(self, instance, validated_data): return MockObject(**validated_data) - class MyModelSerializer(ModelSerializer): - class Meta: - model = User - fields = ("username",) - self.simple_serializer = SimpleSerializer self.crud_serializer = CrudSerializer - self.model_serializer = MyModelSerializer self.default_data = { "username": "test", @@ -66,15 +59,6 @@ async def test_serializer_valid(self): assert await serializer.adata == data assert serializer.errors == {} - async def test_modelserializer_valid(self): - data = { - "username": "test", - } - serializer = self.model_serializer(data=data) - assert await sync_to_async(serializer.is_valid)() - assert await serializer.adata == data - assert serializer.errors == {} - async def test_serializer_invalid(self): data = { "username": "test", @@ -236,3 +220,36 @@ def test_sync_serializer_valid(self): assert serializer.is_valid() assert serializer.data == data assert serializer.errors == {} + + +class TestModelSerializer(TestCase): + def setUp(self) -> None: + class UserSerializer(ModelSerializer): + class Meta: + model = User + fields = ("username",) + + class OrderSerializer(ModelSerializer): + class Meta: + model = Order + fields = ("id", "user", "name") + + self.user_serializer = UserSerializer + self.order_serializer = OrderSerializer + + async def test_user_serializer_valid(self): + data = { + "username": "test", + } + serializer = self.user_serializer(data=data) + assert await sync_to_async(serializer.is_valid)() + assert await serializer.adata == data + assert serializer.errors == {} + + async def test_order_serializer_valid(self): + user = await User.objects.acreate(username="test") + data = {"user": user.id, "name": "Test order"} + serializer = self.order_serializer(data=data) + assert await sync_to_async(serializer.is_valid)() + assert await serializer.adata == data + assert serializer.errors == {}