diff --git a/restdoctor/rest_framework/fields.py b/restdoctor/rest_framework/fields.py index b0cc693..14c8b07 100644 --- a/restdoctor/rest_framework/fields.py +++ b/restdoctor/rest_framework/fields.py @@ -1,16 +1,20 @@ from __future__ import annotations import datetime -from typing import Optional, Union, Any +from typing import Optional, Union, Any, TYPE_CHECKING from django.db import models from rest_framework import ISO_8601 -from rest_framework.fields import DateTimeField as BaseDateTimeField +from rest_framework.fields import DateTimeField as BaseDateTimeField, UUIDField from rest_framework.relations import HyperlinkedIdentityField as BaseHyperlinkedIdentityField from rest_framework.request import Request from rest_framework.settings import api_settings from restdoctor.rest_framework.reverse import preserve_resource_params +if TYPE_CHECKING: + import uuid + from django.db.models import Model, QuerySet + class DateTimeField(BaseDateTimeField): def to_representation( @@ -36,3 +40,12 @@ class HyperlinkedIdentityField(BaseHyperlinkedIdentityField): def get_url(self, obj: models.Model, view_name: str, request: Request, *args: Any, **kwargs: Any) -> str: url = super().get_url(obj, view_name, request, *args, **kwargs) return preserve_resource_params(url, request) + + +class ModelFromUUIDField(UUIDField): + def __init__(self, queryset: QuerySet, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.queryset = queryset + + def to_representation(self, uuid: uuid.UUID) -> Optional[Model]: + return self.queryset.filter(uuid=uuid).first() diff --git a/tests/test_unit/test_fields.py b/tests/test_unit/test_fields.py index 2f03883..ea39d0e 100644 --- a/tests/test_unit/test_fields.py +++ b/tests/test_unit/test_fields.py @@ -4,7 +4,8 @@ import pytz from django.utils.timezone import make_aware -from restdoctor.rest_framework.fields import DateTimeField +from restdoctor.rest_framework.fields import DateTimeField, ModelFromUUIDField +from tests.stubs.models import MyModel @pytest.mark.parametrize( @@ -21,3 +22,12 @@ def test_datetime_field_to_representation( datetime_obj, expected_string_representation, ): assert DateTimeField().to_representation(datetime_obj) == expected_string_representation + + +@pytest.mark.django_db +def test_model_from_uuid_field_to_representation( + my_model, +): + queryset = MyModel.objects.all() + + assert ModelFromUUIDField(queryset=queryset).to_representation(my_model.uuid) == my_model