Skip to content

Commit

Permalink
Refactor export/import serializers
Browse files Browse the repository at this point in the history
  • Loading branch information
MWedl committed Dec 17, 2024
1 parent 8de2812 commit b5eaa1a
Show file tree
Hide file tree
Showing 9 changed files with 1,025 additions and 930 deletions.
51 changes: 16 additions & 35 deletions api/src/reportcreator_api/pentests/import_export/import_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,10 @@

from reportcreator_api.pentests.consumers import send_collab_event_project, send_collab_event_user
from reportcreator_api.pentests.import_export.serializers import (
FindingTemplateExportImportSerializerV2,
FindingTemplateImportSerializerV1,
FindingTemplateExportImportSerializer,
NotesExportImportSerializer,
PentestProjectExportImportSerializerV1,
PentestProjectExportImportSerializerV2,
ProjectTypeExportImportSerializerV1,
ProjectTypeExportImportSerializerV2,
PentestProjectExportImportSerializer,
ProjectTypeExportImportSerializer,
)
from reportcreator_api.pentests.models import (
CollabEvent,
Expand Down Expand Up @@ -115,16 +112,15 @@ def export_archive_iter(data, serializer_class: Type[serializers.Serializer], co
}
for obj in data:
serializer = serializer_class(instance=obj, context=context)
data = serializer.export()
archive_data = json.dumps(data, cls=DjangoJSONEncoder).encode()
archive_data = json.dumps(serializer.data, cls=DjangoJSONEncoder).encode()
yield from _tarfile_addfile(
buffer=buffer,
archive=archive,
tarinfo=build_tarinfo(name=f'{obj.id}.json', size=len(archive_data)),
file_chunks=[archive_data],
)

for name, file in serializer.export_files():
for name, file in serializer.export_files(instance=obj):
yield from _tarfile_addfile(
buffer=buffer,
archive=archive,
Expand All @@ -151,7 +147,7 @@ def export_archive_iter(data, serializer_class: Type[serializers.Serializer], co
@transaction.atomic()
@history_context(history_change_reason='Imported')
@collab_context(prevent_events=True)
def import_archive(archive_file, serializer_classes: list[Type[serializers.Serializer]], context=None):
def import_archive(archive_file, serializer_class: Type[serializers.Serializer], context=None):
context = (context or {}) | {
'archive': None,
'storage_files': [],
Expand All @@ -177,22 +173,9 @@ def import_archive(archive_file, serializer_classes: list[Type[serializers.Seria
for m in to_import:
data = json.load(archive.extractfile(m))

serializer = None
error = None
for serializer_class in serializer_classes:
try:
serializer = serializer_class(data=data, context=context)
serializer.is_valid(raise_exception=True)
error = None
break
except Exception as ex:
serializer = None
# Use error of the first failing serializer_class
if not error:
error = ex
if error:
raise error
imported_obj = serializer.perform_import()
serializer = serializer_class(data=data, context=context)
serializer.is_valid(raise_exception=True)
imported_obj = serializer.save()
for obj in imported_obj if isinstance(imported_obj, list) else [imported_obj]:
log.info(f'Imported object {obj=} {obj.id}')
if isinstance(imported_obj, list):
Expand All @@ -217,11 +200,11 @@ def import_archive(archive_file, serializer_classes: list[Type[serializers.Seria


def export_templates(data: Iterable[FindingTemplate]):
return export_archive_iter(data, serializer_class=FindingTemplateExportImportSerializerV2)
return export_archive_iter(data, serializer_class=FindingTemplateExportImportSerializer)

def export_project_types(data: Iterable[ProjectType]):
prefetch_related_objects(data, 'assets')
return export_archive_iter(data, serializer_class=ProjectTypeExportImportSerializerV2, context={
return export_archive_iter(data, serializer_class=ProjectTypeExportImportSerializer, context={
'add_design_notice_file': True,
})

Expand All @@ -235,7 +218,7 @@ def export_projects(data: Iterable[PentestProject], export_all=False):
'images',
'project_type__assets',
)
return export_archive_iter(data, serializer_class=PentestProjectExportImportSerializerV2, context={
return export_archive_iter(data, serializer_class=PentestProjectExportImportSerializer, context={
'export_all': export_all,
'add_design_notice_file': True,
})
Expand Down Expand Up @@ -267,21 +250,19 @@ def get_children_recursive(note, all_notes):


def import_templates(archive_file):
return import_archive(archive_file, serializer_classes=[FindingTemplateExportImportSerializerV2, FindingTemplateImportSerializerV1])
return import_archive(archive_file, serializer_class=FindingTemplateExportImportSerializer)

def import_project_types(archive_file):
return import_archive(archive_file, serializer_classes=[
ProjectTypeExportImportSerializerV2,
ProjectTypeExportImportSerializerV1])
return import_archive(archive_file, serializer_class=ProjectTypeExportImportSerializer)

def import_projects(archive_file):
return import_archive(archive_file, serializer_classes=[PentestProjectExportImportSerializerV2, PentestProjectExportImportSerializerV1])
return import_archive(archive_file, serializer_class=PentestProjectExportImportSerializer)

def import_notes(archive_file, context):
if not context.get('project') and not context.get('user'):
raise ValueError('Either project or user must be provided')
# Import notes to DB
notes = import_archive(archive_file, serializer_classes=[NotesExportImportSerializer], context=context)
notes = import_archive(archive_file, serializer_class=NotesExportImportSerializer, context=context)

# Send collab events
sender_options = {
Expand Down
Loading

0 comments on commit b5eaa1a

Please sign in to comment.