From f54b73d13134b791e12630ed8233721d7d007d94 Mon Sep 17 00:00:00 2001 From: Michael Wedl Date: Tue, 19 Nov 2024 14:38:37 +0100 Subject: [PATCH 1/2] Run periodic tasks in background --- api/src/reportcreator_api/api_utils/views.py | 9 +++++++-- api/src/reportcreator_api/conf/urls.py | 9 ++++++++- api/src/reportcreator_api/pentests/tasks.py | 6 +++--- api/src/reportcreator_api/tasks/querysets.py | 6 +++--- .../reportcreator_api/tests/test_periodic_tasks.py | 11 ++++------- api/src/reportcreator_api/utils/utils.py | 8 ++++++++ 6 files changed, 33 insertions(+), 16 deletions(-) diff --git a/api/src/reportcreator_api/api_utils/views.py b/api/src/reportcreator_api/api_utils/views.py index 73c6843b7..4485c152b 100644 --- a/api/src/reportcreator_api/api_utils/views.py +++ b/api/src/reportcreator_api/api_utils/views.py @@ -31,7 +31,7 @@ from reportcreator_api.users.models import AuthIdentity from reportcreator_api.utils import license from reportcreator_api.utils.api import StreamingHttpResponseAsync, ViewSetAsync -from reportcreator_api.utils.utils import copy_keys, remove_duplicates +from reportcreator_api.utils.utils import copy_keys, remove_duplicates, run_in_background log = logging.getLogger(__name__) @@ -215,7 +215,7 @@ async def get(self, request, *args, **kwargs): if res.status_code == 200: # Run periodic tasks - await PeriodicTask.objects.run_all_pending_tasks() + run_in_background(PeriodicTask.objects.run_all_pending_tasks()) # Memory cleanup of worker process gc.collect() @@ -234,3 +234,8 @@ def get(self, request, *args, **kwargs): plugin_api_root = request.build_absolute_uri(f'/api/plugins/{p.plugin_id}/api/') out[p.name.split('.')[-1]] = plugin_api_root return Response(data=out) + + +class PublicAPIRootView(routers.APIRootView): + authentication_classes = [] + permission_classes = [] diff --git a/api/src/reportcreator_api/conf/urls.py b/api/src/reportcreator_api/conf/urls.py index 26ec61a74..1a47eba41 100644 --- a/api/src/reportcreator_api/conf/urls.py +++ b/api/src/reportcreator_api/conf/urls.py @@ -9,7 +9,13 @@ from rest_framework.routers import DefaultRouter from rest_framework_nested.routers import NestedSimpleRouter -from reportcreator_api.api_utils.views import HealthcheckApiView, PluginApiView, PublicUtilsViewSet, UtilsViewSet +from reportcreator_api.api_utils.views import ( + HealthcheckApiView, + PluginApiView, + PublicAPIRootView, + PublicUtilsViewSet, + UtilsViewSet, +) from reportcreator_api.conf import plugins from reportcreator_api.notifications.views import NotificationViewSet from reportcreator_api.pentests.collab.fallback import ConsumerHttpFallbackView @@ -106,6 +112,7 @@ public_router = DefaultRouter() +public_router.APIRootView = PublicAPIRootView public_router.trailing_slash = router.trailing_slash public_router.include_format_suffixes = router.include_format_suffixes diff --git a/api/src/reportcreator_api/pentests/tasks.py b/api/src/reportcreator_api/pentests/tasks.py index 39d8a8976..bf4e00991 100644 --- a/api/src/reportcreator_api/pentests/tasks.py +++ b/api/src/reportcreator_api/pentests/tasks.py @@ -160,7 +160,7 @@ async def cleanup_unreferenced_images_and_files(task_info): await cleanup_template_files(task_info) -def reset_stale_archive_restores(task_info): +async def reset_stale_archive_restores(task_info): """ Deletes decrypted shamir keys from the database, when archive restore is stale (last decryption more than 3 days ago), i.e. some users decrypted their key parts, but some are still missing. 
@@ -168,7 +168,7 @@ def reset_stale_archive_restores(task_info): """ from reportcreator_api.pentests.models import ArchivedProjectKeyPart - ArchivedProjectKeyPart.objects \ + await ArchivedProjectKeyPart.objects \ .filter(decrypted_at__isnull=False) \ .annotate(last_decrypted=Subquery( ArchivedProjectKeyPart.objects @@ -178,7 +178,7 @@ def reset_stale_archive_restores(task_info): .values_list('last_decrypted'), )) \ .filter(last_decrypted__lt=timezone.now() - settings.AUTOMATICALLY_RESET_STALE_ARCHIVE_RESTORES_AFTER) \ - .update(decrypted_at=None, key_part=None) + .aupdate(decrypted_at=None, key_part=None) async def automatically_archive_projects(task_info): diff --git a/api/src/reportcreator_api/tasks/querysets.py b/api/src/reportcreator_api/tasks/querysets.py index fc003837e..1054ea795 100644 --- a/api/src/reportcreator_api/tasks/querysets.py +++ b/api/src/reportcreator_api/tasks/querysets.py @@ -12,10 +12,10 @@ class PeriodicTaskQuerySet(models.QuerySet): - def get_pending_tasks(self): + async def get_pending_tasks(self): from reportcreator_api.tasks.models import TaskStatus pending_tasks = {t['id']: t.copy() for t in settings.PERIODIC_TASKS} - for t in self.filter(id__in=pending_tasks.keys()): + async for t in self.filter(id__in=pending_tasks.keys()): pending_tasks[t.id]['model'] = t # Remove non-pending tasks if (t.status == TaskStatus.RUNNING and t.started > timezone.now() - timedelta(minutes=10)) or \ @@ -75,5 +75,5 @@ async def run_task(self, task_info): await task_info['model'].asave() async def run_all_pending_tasks(self): - for t in await sync_to_async(self.get_pending_tasks)(): + for t in await self.get_pending_tasks(): await self.run_task(t) diff --git a/api/src/reportcreator_api/tests/test_periodic_tasks.py b/api/src/reportcreator_api/tests/test_periodic_tasks.py index 41919719a..0d9e545e4 100644 --- a/api/src/reportcreator_api/tests/test_periodic_tasks.py +++ b/api/src/reportcreator_api/tests/test_periodic_tasks.py @@ -5,10 +5,8 @@ import pytest from asgiref.sync import async_to_sync from django.test import override_settings -from django.urls import reverse from django.utils import timezone from pytest_django.asserts import assertNumQueries -from rest_framework.test import APIClient from reportcreator_api.pentests.models import ArchivedProject, CollabEvent, CollabEventType, PentestProject from reportcreator_api.pentests.tasks import ( @@ -59,8 +57,7 @@ def setUp(self): yield def run_tasks(self): - res = APIClient().get(reverse('utils-healthcheck')) - assert res.status_code == 200, res.data + async_to_sync(PeriodicTask.objects.run_all_pending_tasks)() def test_initial_run(self): self.run_tasks() @@ -346,7 +343,7 @@ def test_reset_stale(self): keypart.key_part = {'key_id': 'shamir-key-id', 'key': 'dummy-key'} keypart.save() - reset_stale_archive_restores(None) + async_to_sync(reset_stale_archive_restores)(None) keypart.refresh_from_db() assert not keypart.is_decrypted @@ -366,7 +363,7 @@ def test_reset_not_stale(self): keypart2.key_part = {'key_id': 'shamir-key-id-2', 'key': 'dummy-key2'} keypart2.save() - reset_stale_archive_restores(None) + async_to_sync(reset_stale_archive_restores)(None) keypart1.refresh_from_db() assert keypart1.is_decrypted @@ -389,7 +386,7 @@ def test_reset_one_but_not_other(self): keypart2.key_part = {'key_id': 'shamir-key-id', 'key': 'dummy-key'} keypart2.save() - reset_stale_archive_restores(None) + async_to_sync(reset_stale_archive_restores)(None) keypart1.refresh_from_db() assert not keypart1.is_decrypted diff --git 
a/api/src/reportcreator_api/utils/utils.py b/api/src/reportcreator_api/utils/utils.py index be12186ad..a163d2b7d 100644 --- a/api/src/reportcreator_api/utils/utils.py +++ b/api/src/reportcreator_api/utils/utils.py @@ -148,3 +148,11 @@ async def aretry(func, timeout=timedelta(seconds=1), interval=timedelta(seconds= raise else: await asyncio.sleep(interval.total_seconds()) + + +_background_tasks = set() +def run_in_background(coro): + task = asyncio.create_task(coro) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) + From e5448f8fdd3b7b1cd9b731b27fd8c65f0bb900c5 Mon Sep 17 00:00:00 2001 From: Michael Wedl Date: Tue, 19 Nov 2024 14:40:09 +0100 Subject: [PATCH 2/2] Use library for retrying --- api/poetry.lock | 31 ++++++++++++++----- api/pyproject.toml | 1 + .../pentests/collab/consumer_base.py | 7 +++-- api/src/reportcreator_api/utils/utils.py | 16 +--------- 4 files changed, 29 insertions(+), 26 deletions(-) diff --git a/api/poetry.lock b/api/poetry.lock index 30f3dc1f2..9d49d8a63 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -221,17 +221,17 @@ files = [ [[package]] name = "boto3" -version = "1.35.63" +version = "1.35.64" description = "The AWS SDK for Python" optional = false python-versions = ">=3.8" files = [ - {file = "boto3-1.35.63-py3-none-any.whl", hash = "sha256:d0f938d4f6f392b6ffc5e75fff14a42e5bbb5228675a0367c8af55398abadbec"}, - {file = "boto3-1.35.63.tar.gz", hash = "sha256:deb593d9a0fb240deb4c43e4da8e6626d7c36be7b2fd2fe28f49d44d395b7de0"}, + {file = "boto3-1.35.64-py3-none-any.whl", hash = "sha256:cdacf03fc750caa3aa0dbf6158166def9922c9d67b4160999ff8fc350662facc"}, + {file = "boto3-1.35.64.tar.gz", hash = "sha256:bc3fc12b41fa2c91e51ab140f74fb1544408a2b1e00f88a4c2369a66d18ddf20"}, ] [package.dependencies] -botocore = ">=1.35.63,<1.36.0" +botocore = ">=1.35.64,<1.36.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.10.0,<0.11.0" @@ -240,13 +240,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.35.63" +version = "1.35.64" description = "Low-level, data-driven core of boto 3." 
optional = false python-versions = ">=3.8" files = [ - {file = "botocore-1.35.63-py3-none-any.whl", hash = "sha256:0ca1200694a4c0a3fa846795d8e8a08404c214e21195eb9e010c4b8a4ca78a4a"}, - {file = "botocore-1.35.63.tar.gz", hash = "sha256:2b8196bab0a997d206c3d490b52e779ef47dffb68c57c685443f77293aca1589"}, + {file = "botocore-1.35.64-py3-none-any.whl", hash = "sha256:bbd96bf7f442b1d5e35b36f501076e4a588c83d8d84a1952e9ee1d767e5efb3e"}, + {file = "botocore-1.35.64.tar.gz", hash = "sha256:2f95c83f31c9e38a66995c88810fc638c829790e125032ba00ab081a2cf48cb9"}, ] [package.dependencies] @@ -3088,6 +3088,21 @@ files = [ dev = ["build", "hatch"] doc = ["sphinx"] +[[package]] +name = "tenacity" +version = "9.0.0" +description = "Retry code until it succeeds" +optional = false +python-versions = ">=3.8" +files = [ + {file = "tenacity-9.0.0-py3-none-any.whl", hash = "sha256:93de0c98785b27fcf659856aa9f54bfbd399e29969b0621bc7f762bd441b4539"}, + {file = "tenacity-9.0.0.tar.gz", hash = "sha256:807f37ca97d62aa361264d497b0e31e92b8027044942bfa756160d908320d73b"}, +] + +[package.extras] +doc = ["reno", "sphinx"] +test = ["pytest", "tornado (>=4.5)", "typeguard"] + [[package]] name = "tinycss2" version = "1.4.0" @@ -3806,4 +3821,4 @@ test = ["pytest"] [metadata] lock-version = "2.0" python-versions = "~3.12" -content-hash = "6824ed33100c7a9b9e65da3f70dd58a60a1583407b1043c209fc116a8eced673" +content-hash = "8d511a6cf36534218de2456fcecc45fae00999ef4a3a620de79efc600ed004e5" diff --git a/api/pyproject.toml b/api/pyproject.toml index 8c027b489..f21cb68bd 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -31,6 +31,7 @@ channels-redis = "^4.2.0" requests = "^2.28.2" httpx = "^0.27.0" +tenacity = "^9.0.0" regex = "^2024.5.10" jsonschema = "^4.17.3" python-decouple = "^3.8" diff --git a/api/src/reportcreator_api/pentests/collab/consumer_base.py b/api/src/reportcreator_api/pentests/collab/consumer_base.py index 1862ec5f7..b72eff05f 100644 --- a/api/src/reportcreator_api/pentests/collab/consumer_base.py +++ b/api/src/reportcreator_api/pentests/collab/consumer_base.py @@ -5,6 +5,7 @@ from functools import cached_property from typing import Any +import tenacity from channels.db import database_sync_to_async from channels.exceptions import DenyConnection, StopConsumer from channels.generic.websocket import AsyncJsonWebsocketConsumer @@ -31,7 +32,6 @@ from reportcreator_api.users.serializers import PentestUserSerializer from reportcreator_api.utils.elasticapm import elasticapm_capture_websocket_transaction from reportcreator_api.utils.history import history_context -from reportcreator_api.utils.utils import aretry log = logging.getLogger(__name__) @@ -259,12 +259,13 @@ async def collab_event(self, event): return if event.get('id'): + # Retry fetching event from DB: DB transactions can cause the channels event to arrive before event data is commited to the DB + @tenacity.retry(retry=tenacity.retry_if_exception_type(CollabEvent.DoesNotExist), stop=tenacity.stop_after_delay(1), wait=tenacity.wait_fixed(0.1)) @database_sync_to_async def get_collab_event(id): return CollabEvent.objects.get(id=id) - # Retry fetching event from DB: DB transactions can cause the channels event to arrive before event data is commited to the DB - collab_event = await aretry(lambda: get_collab_event(event['id']), retry_for=CollabEvent.DoesNotExist) + collab_event = await get_collab_event(event['id']) event_data = collab_event.to_dict() elif isinstance(event.get('event'), dict): event_data = event['event'] diff --git 
a/api/src/reportcreator_api/utils/utils.py b/api/src/reportcreator_api/utils/utils.py index a163d2b7d..85bd19d85 100644 --- a/api/src/reportcreator_api/utils/utils.py +++ b/api/src/reportcreator_api/utils/utils.py @@ -1,6 +1,6 @@ import asyncio import uuid -from datetime import date, timedelta +from datetime import date from itertools import groupby from typing import Any, Iterable, OrderedDict @@ -136,20 +136,6 @@ def groupby_to_dict(data: dict, key) -> dict: return dict(map(lambda t: (t[0], list(t[1])), groupby(sorted(data, key=key), key=key))) -async def aretry(func, timeout=timedelta(seconds=1), interval=timedelta(seconds=0.1), retry_for=None): - timeout_abs = timezone.now() + timeout - while True: - try: - return await func() - except Exception as ex: - if retry_for and not isinstance(ex, retry_for): - raise - elif timezone.now() > timeout_abs: - raise - else: - await asyncio.sleep(interval.total_seconds()) - - _background_tasks = set() def run_in_background(coro): task = asyncio.create_task(coro)
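
A minimal, self-contained sketch of the two techniques these patches rely on: fire-and-forget background tasks that keep a strong reference so the event loop cannot garbage-collect them mid-run, and a tenacity retry policy for errors that resolve themselves shortly afterwards (mirroring the retry_if_exception_type / stop_after_delay / wait_fixed combination used in consumer_base.py). The NotReadyYet exception, fetch_when_ready function, the returned task, and the timing values in main() are illustrative assumptions, not part of the patches.

import asyncio

import tenacity

# asyncio only keeps weak references to tasks, so a fire-and-forget task with
# no other reference can be garbage-collected before it finishes; the set
# keeps a strong reference until the task is done.
_background_tasks = set()

def run_in_background(coro):
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task  # returning the task is an assumption; the patch discards the result


class NotReadyYet(Exception):
    """Illustrative transient error standing in for CollabEvent.DoesNotExist."""


@tenacity.retry(
    retry=tenacity.retry_if_exception_type(NotReadyYet),  # only retry this error type
    stop=tenacity.stop_after_delay(1),                    # give up after ~1 second
    wait=tenacity.wait_fixed(0.1),                        # wait 100 ms between attempts
)
async def fetch_when_ready(state: dict):
    if not state.get('ready'):
        raise NotReadyYet()
    return state['value']


async def main():
    # Fire-and-forget: the caller is not blocked on the scheduled work.
    run_in_background(asyncio.sleep(0))

    state = {'ready': False, 'value': 42}
    # Flip the flag a little later, so the retry loop succeeds on a later attempt.
    asyncio.get_running_loop().call_later(0.25, state.update, {'ready': True})
    print(await fetch_when_ready(state))


if __name__ == '__main__':
    asyncio.run(main())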