Skip to content

Commit

Permalink
Merge branch 'background-tasks' into 'main'
Browse files Browse the repository at this point in the history
Background tasks

See merge request reportcreator/reportcreator!767
  • Loading branch information
MWedl committed Nov 21, 2024
2 parents ef791ce + e5448f8 commit 7347a62
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 40 deletions.
31 changes: 23 additions & 8 deletions api/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ channels-redis = "^4.2.0"

requests = "^2.28.2"
httpx = "^0.27.0"
tenacity = "^9.0.0"
regex = "^2024.5.10"
jsonschema = "^4.17.3"
python-decouple = "^3.8"
Expand Down
9 changes: 7 additions & 2 deletions api/src/reportcreator_api/api_utils/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from reportcreator_api.users.models import AuthIdentity
from reportcreator_api.utils import license
from reportcreator_api.utils.api import StreamingHttpResponseAsync, ViewSetAsync
from reportcreator_api.utils.utils import copy_keys, remove_duplicates
from reportcreator_api.utils.utils import copy_keys, remove_duplicates, run_in_background

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -215,7 +215,7 @@ async def get(self, request, *args, **kwargs):

if res.status_code == 200:
# Run periodic tasks
await PeriodicTask.objects.run_all_pending_tasks()
run_in_background(PeriodicTask.objects.run_all_pending_tasks())

# Memory cleanup of worker process
gc.collect()
Expand All @@ -234,3 +234,8 @@ def get(self, request, *args, **kwargs):
plugin_api_root = request.build_absolute_uri(f'/api/plugins/{p.plugin_id}/api/')
out[p.name.split('.')[-1]] = plugin_api_root
return Response(data=out)


class PublicAPIRootView(routers.APIRootView):
authentication_classes = []
permission_classes = []
9 changes: 8 additions & 1 deletion api/src/reportcreator_api/conf/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@
from rest_framework.routers import DefaultRouter
from rest_framework_nested.routers import NestedSimpleRouter

from reportcreator_api.api_utils.views import HealthcheckApiView, PluginApiView, PublicUtilsViewSet, UtilsViewSet
from reportcreator_api.api_utils.views import (
HealthcheckApiView,
PluginApiView,
PublicAPIRootView,
PublicUtilsViewSet,
UtilsViewSet,
)
from reportcreator_api.conf import plugins
from reportcreator_api.notifications.views import NotificationViewSet
from reportcreator_api.pentests.collab.fallback import ConsumerHttpFallbackView
Expand Down Expand Up @@ -106,6 +112,7 @@


public_router = DefaultRouter()
public_router.APIRootView = PublicAPIRootView
public_router.trailing_slash = router.trailing_slash
public_router.include_format_suffixes = router.include_format_suffixes

Expand Down
7 changes: 4 additions & 3 deletions api/src/reportcreator_api/pentests/collab/consumer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from functools import cached_property
from typing import Any

import tenacity
from channels.db import database_sync_to_async
from channels.exceptions import DenyConnection, StopConsumer
from channels.generic.websocket import AsyncJsonWebsocketConsumer
Expand All @@ -31,7 +32,6 @@
from reportcreator_api.users.serializers import PentestUserSerializer
from reportcreator_api.utils.elasticapm import elasticapm_capture_websocket_transaction
from reportcreator_api.utils.history import history_context
from reportcreator_api.utils.utils import aretry

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -259,12 +259,13 @@ async def collab_event(self, event):
return

if event.get('id'):
# Retry fetching event from DB: DB transactions can cause the channels event to arrive before event data is commited to the DB
@tenacity.retry(retry=tenacity.retry_if_exception_type(CollabEvent.DoesNotExist), stop=tenacity.stop_after_delay(1), wait=tenacity.wait_fixed(0.1))
@database_sync_to_async
def get_collab_event(id):
return CollabEvent.objects.get(id=id)

# Retry fetching event from DB: DB transactions can cause the channels event to arrive before event data is commited to the DB
collab_event = await aretry(lambda: get_collab_event(event['id']), retry_for=CollabEvent.DoesNotExist)
collab_event = await get_collab_event(event['id'])
event_data = collab_event.to_dict()
elif isinstance(event.get('event'), dict):
event_data = event['event']
Expand Down
6 changes: 3 additions & 3 deletions api/src/reportcreator_api/pentests/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,15 +160,15 @@ async def cleanup_unreferenced_images_and_files(task_info):
await cleanup_template_files(task_info)


def reset_stale_archive_restores(task_info):
async def reset_stale_archive_restores(task_info):
"""
Deletes decrypted shamir keys from the database, when archive restore is stale (last decryption more than 3 days ago),
i.e. some users decrypted their key parts, but some are still missing.
Prevent decrypted shamir keys being stored in the DB forever.
"""
from reportcreator_api.pentests.models import ArchivedProjectKeyPart

ArchivedProjectKeyPart.objects \
await ArchivedProjectKeyPart.objects \
.filter(decrypted_at__isnull=False) \
.annotate(last_decrypted=Subquery(
ArchivedProjectKeyPart.objects
Expand All @@ -178,7 +178,7 @@ def reset_stale_archive_restores(task_info):
.values_list('last_decrypted'),
)) \
.filter(last_decrypted__lt=timezone.now() - settings.AUTOMATICALLY_RESET_STALE_ARCHIVE_RESTORES_AFTER) \
.update(decrypted_at=None, key_part=None)
.aupdate(decrypted_at=None, key_part=None)


async def automatically_archive_projects(task_info):
Expand Down
6 changes: 3 additions & 3 deletions api/src/reportcreator_api/tasks/querysets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@


class PeriodicTaskQuerySet(models.QuerySet):
def get_pending_tasks(self):
async def get_pending_tasks(self):
from reportcreator_api.tasks.models import TaskStatus
pending_tasks = {t['id']: t.copy() for t in settings.PERIODIC_TASKS}
for t in self.filter(id__in=pending_tasks.keys()):
async for t in self.filter(id__in=pending_tasks.keys()):
pending_tasks[t.id]['model'] = t
# Remove non-pending tasks
if (t.status == TaskStatus.RUNNING and t.started > timezone.now() - timedelta(minutes=10)) or \
Expand Down Expand Up @@ -75,5 +75,5 @@ async def run_task(self, task_info):
await task_info['model'].asave()

async def run_all_pending_tasks(self):
for t in await sync_to_async(self.get_pending_tasks)():
for t in await self.get_pending_tasks():
await self.run_task(t)
11 changes: 4 additions & 7 deletions api/src/reportcreator_api/tests/test_periodic_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
import pytest
from asgiref.sync import async_to_sync
from django.test import override_settings
from django.urls import reverse
from django.utils import timezone
from pytest_django.asserts import assertNumQueries
from rest_framework.test import APIClient

from reportcreator_api.pentests.models import ArchivedProject, CollabEvent, CollabEventType, PentestProject
from reportcreator_api.pentests.tasks import (
Expand Down Expand Up @@ -59,8 +57,7 @@ def setUp(self):
yield

def run_tasks(self):
res = APIClient().get(reverse('utils-healthcheck'))
assert res.status_code == 200, res.data
async_to_sync(PeriodicTask.objects.run_all_pending_tasks)()

def test_initial_run(self):
self.run_tasks()
Expand Down Expand Up @@ -348,7 +345,7 @@ def test_reset_stale(self):
keypart.key_part = {'key_id': 'shamir-key-id', 'key': 'dummy-key'}
keypart.save()

reset_stale_archive_restores(None)
async_to_sync(reset_stale_archive_restores)(None)

keypart.refresh_from_db()
assert not keypart.is_decrypted
Expand All @@ -368,7 +365,7 @@ def test_reset_not_stale(self):
keypart2.key_part = {'key_id': 'shamir-key-id-2', 'key': 'dummy-key2'}
keypart2.save()

reset_stale_archive_restores(None)
async_to_sync(reset_stale_archive_restores)(None)

keypart1.refresh_from_db()
assert keypart1.is_decrypted
Expand All @@ -391,7 +388,7 @@ def test_reset_one_but_not_other(self):
keypart2.key_part = {'key_id': 'shamir-key-id', 'key': 'dummy-key'}
keypart2.save()

reset_stale_archive_restores(None)
async_to_sync(reset_stale_archive_restores)(None)

keypart1.refresh_from_db()
assert not keypart1.is_decrypted
Expand Down
20 changes: 7 additions & 13 deletions api/src/reportcreator_api/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import uuid
from datetime import date, timedelta
from datetime import date
from itertools import groupby
from typing import Any, Iterable, OrderedDict

Expand Down Expand Up @@ -136,15 +136,9 @@ def groupby_to_dict(data: dict, key) -> dict:
return dict(map(lambda t: (t[0], list(t[1])), groupby(sorted(data, key=key), key=key)))


async def aretry(func, timeout=timedelta(seconds=1), interval=timedelta(seconds=0.1), retry_for=None):
timeout_abs = timezone.now() + timeout
while True:
try:
return await func()
except Exception as ex:
if retry_for and not isinstance(ex, retry_for):
raise
elif timezone.now() > timeout_abs:
raise
else:
await asyncio.sleep(interval.total_seconds())
_background_tasks = set()
def run_in_background(coro):
task = asyncio.create_task(coro)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)

0 comments on commit 7347a62

Please sign in to comment.