Skip to content

Commit

Permalink
Merge branch 'fix-user-field-error' into 'main'
Browse files Browse the repository at this point in the history
Fix error while updating user fields via REST API

See merge request reportcreator/reportcreator!748
  • Loading branch information
MWedl committed Nov 5, 2024
2 parents 4b79ec2 + 2d5a2a2 commit 3ed99fb
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* Disable static file compression
* Allow to cancel PDF rendering requests
* Show PDF render timing information
* Fix error while updating user fields via REST API


## v2024.81 - 2024-10-25
Expand Down
3 changes: 2 additions & 1 deletion api/src/reportcreator_api/pentests/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,8 @@ def perform_update_key(self, obj, path, value, definition, content, **kwargs):
serializer_data = {path[0]: value}

# Update in DB
serializer = (ReportSectionSerializer if isinstance(obj, ReportSection) else PentestFindingSerializer)(instance=obj, data=serializer_data, partial=True)
serializer_class = ReportSectionSerializer if isinstance(obj, ReportSection) else PentestFindingSerializer
serializer = serializer_class(instance=obj, data=serializer_data, partial=True, context={'project': obj.project})
serializer.is_valid(raise_exception=True)
res = serializer.save()

Expand Down
12 changes: 9 additions & 3 deletions api/src/reportcreator_api/pentests/customfields/serializers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
from uuid import UUID

from django.db.models.query import Prefetch, prefetch_related_objects
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema_field
from rest_framework import serializers
Expand All @@ -12,6 +13,7 @@
FieldDefinition,
ObjectField,
)
from reportcreator_api.pentests.models import ProjectMemberInfo
from reportcreator_api.users.models import PentestUser


Expand Down Expand Up @@ -44,9 +46,13 @@ class UserField(serializers.PrimaryKeyRelatedField):

def to_internal_value(self, data):
if isinstance(data, (str, UUID)) and (project := self.context.get('project')):
if project.members and (user := next(filter(lambda u: str(data) == str(u.id), project.members), None)):
return str(user.id)
elif project.imported_members and (imported_user := next(filter(lambda u: data == u.get('id'), project.imported_members), None)):
if not getattr(project, '_prefetched_objects_cache', {}).get('members'):
# Prefetch members to avoid N+1 queries
prefetch_related_objects([project], Prefetch('members', ProjectMemberInfo.objects.select_related('user')))

if member := next(filter(lambda u: str(data) == str(u.user.id), project.members.all()), None):
return str(member.user.id)
elif imported_user := next(filter(lambda u: data == u.get('id'), project.imported_members), None):
return imported_user.get('id')

user = super().to_internal_value(data)
Expand Down
58 changes: 57 additions & 1 deletion api/src/reportcreator_api/tests/test_customfields.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pytest
from django.core.exceptions import ValidationError
from django.urls import reverse
from django.utils import timezone

from reportcreator_api.pentests.collab.text_transformations import SelectionRange
Expand Down Expand Up @@ -37,6 +38,7 @@
from reportcreator_api.pentests.models.project import Comment
from reportcreator_api.tasks.rendering.entry import format_template_field_object
from reportcreator_api.tests.mock import (
api_client,
create_comment,
create_finding,
create_project,
Expand Down Expand Up @@ -180,7 +182,61 @@ def test_field_values(valid, definition, value):
@pytest.mark.django_db()
def test_user_field_value():
user = create_user()
FieldValuesValidator(parse_field_definition([{'id': 'field_user', 'type': 'user', 'label': 'User Field'}]))({'field_user': str(user.id)})
definition = parse_field_definition([{'id': 'field_user', 'type': 'user', 'label': 'User Field'}])
FieldValuesValidator(definition)({'field_user': str(user.id)})


@pytest.mark.django_db()
def test_api_serializer():
user = create_user()
project = create_project(members=[user])
client = api_client(user)

field_data = {
'field_string': 'This is a string',
'field_markdown': 'Some **markdown**\n* String\n*List',
'field_cvss': 'CVSS:3.1/AV:N/AC:H/PR:N/UI:R/S:C/C:H/I:H/A:H',
'field_cwe': 'CWE-89',
'field_date': '2024-01-01',
'field_int': 17,
'field_bool': True,
'field_enum': 'enum1',
'field_combobox': 'value2',
'field_user': str(user.id),
'field_list': ['test'],
'field_object': {'nested1': 'val'},
}

res1 = client.patch(reverse('section-detail', kwargs={'project_pk': project.id, 'id': 'other'}), data={
'data': field_data,
})
assert res1.status_code == 200, res1.data

res2 = client.patch(reverse('finding-detail', kwargs={'project_pk': project.id, 'id': project.findings.first().finding_id}), data={
'data': field_data,
})
assert res2.status_code == 200, res2.data


@pytest.mark.django_db()
def test_api_serializer_user():
user = create_user()
user_imported = {
'id': str(uuid4()),
'name': 'Imported User',
}
project = create_project(members=[user], imported_members=[user_imported])
client = api_client(user)

def assert_valid_user_field_value(user_id, expected):
res = client.patch(reverse('section-detail', kwargs={'project_pk': project.id, 'id': 'other'}), data={
'data': {'field_user': user_id},
})
assert (res.status_code == 200) is expected

assert_valid_user_field_value(str(user.id), True) # Project member
assert_valid_user_field_value(user_imported['id'], True) # Imported member
assert_valid_user_field_value(str(uuid4()), False) # Nonexistent user


class CustomFieldsTestModel(CustomFieldsMixin):
Expand Down

0 comments on commit 3ed99fb

Please sign in to comment.