Skip to content

Commit

Permalink
Close DB connections after background tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
MWedl committed Jan 2, 2025
1 parent 7dcd318 commit 887d63a
Show file tree
Hide file tree
Showing 8 changed files with 967 additions and 34 deletions.
2 changes: 1 addition & 1 deletion api/src/reportcreator_api/api_utils/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ async def get(self, request, *args, **kwargs):

if res.status_code == 200:
# Run periodic tasks
run_in_background(PeriodicTask.objects.run_all_pending_tasks())
run_in_background(PeriodicTask.objects.run_all_pending_tasks)()

# Memory cleanup of worker process
gc.collect()
Expand Down
4 changes: 2 additions & 2 deletions api/src/reportcreator_api/tasks/querysets.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,15 @@ async def run_task(self, task_info):
res = await task_info.spec.func(task_info)
else:
res = await sync_to_async(task_info.spec.func)(task_info)
task_info.model.status = res or TaskStatus.SUCCESS
task_info.model.status = res if isinstance(res, TaskStatus) else TaskStatus.SUCCESS
except Exception:
logging.exception(f'Error while running periodic task "{task_info.id}"')
task_info.model.status = TaskStatus.FAILED

# Set completed time
task_info.model.completed = timezone.now()
if task_info.model.status == TaskStatus.SUCCESS:
task_info.model.completed = task_info.model.last_success
task_info.model.last_success = task_info.model.completed

log.info(f'Completed periodic task "{task_info.id}" with status "{task_info.model.status}"')

Expand Down
4 changes: 2 additions & 2 deletions api/src/reportcreator_api/tasks/rendering/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ async def weasyprint_start_process():

@log_timing(log_start=True)
async def weasyprint_render_to_pdf(proc, **kwargs) -> RenderStageResult:
@sync_to_async
@sync_to_async()
def encode_data():
return json.dumps(kwargs, cls=DjangoJSONEncoder).encode()

@sync_to_async
@sync_to_async()
def decode_data(stdout):
return RenderStageResult.from_dict(json.loads(stdout.decode()))

Expand Down
27 changes: 22 additions & 5 deletions api/src/reportcreator_api/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import asyncio
import logging
import uuid
from datetime import date
from itertools import groupby
from typing import Any, Iterable, OrderedDict

from asgiref.sync import sync_to_async
from django.db import close_old_connections, connections
from django.utils import dateparse, timezone


Expand Down Expand Up @@ -137,8 +140,22 @@ def groupby_to_dict(data: dict, key) -> dict:


_background_tasks = set()
def run_in_background(coro):
task = asyncio.create_task(coro)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)

def run_in_background(func):
def inner(*args, **kwargs):
@sync_to_async()
def task_finished():
if not connections['default'].in_atomic_block:
close_old_connections()

async def wrapper():
try:
await func(*args, **kwargs)
except Exception:
logging.exception(f'Error while running run_in_background({func.__name__})')
finally:
await task_finished()

task = asyncio.create_task(wrapper())
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
return inner
949 changes: 927 additions & 22 deletions packages/NOTICE

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion plugins/demoplugin/tests/test_plugin_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_delete(self):
class TestDemoPluginWebsocketConsumer:
@pytest_asyncio.fixture(autouse=True)
async def setUp(self):
@sync_to_async
@sync_to_async()
def setup_db():
self.user1 = create_user()
self.user2 = create_user()
Expand Down
11 changes: 11 additions & 0 deletions plugins/webhooks/tests/test_webhooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
from unittest import mock

import pytest
from asgiref.sync import async_to_sync
from django.apps import apps
from reportcreator_api.pentests.models import ArchivedProject
from reportcreator_api.tests.mock import create_finding, create_project, create_user
from reportcreator_api.utils import utils

from ..app import WebhooksPluginConfig
from ..models import WebhookEventType
Expand Down Expand Up @@ -36,6 +38,11 @@ def setUp(self):
self.project = create_project(members=[self.user])
with mock.patch('sysreptor_plugins.webhooks.utils.send_webhook_request') as self.mock:
yield

@async_to_sync()
async def wait_for_background_tasks(self):
for t in utils._background_tasks:
await t

@pytest.mark.parametrize(['event', 'trigger'], [
(WebhookEventType.PROJECT_CREATED, lambda s: create_project()),
Expand All @@ -50,6 +57,7 @@ def test_webhooks_called(self, event, trigger):
with override_webhook_settings(WEBHOOKS=[webhook_config, {'url': 'https://example.com/other', 'events': ['other']}]):
self.mock.assert_not_called()
trigger(self)
self.wait_for_background_tasks()
self.mock.assert_called_once()
call_args = self.mock.call_args[1]
assert call_args['webhook'] == webhook_config
Expand All @@ -60,10 +68,12 @@ def test_event_filter(self):
# Not subscribed to event
self.mock.assert_not_called()
update(self.project, readonly=True)
self.wait_for_background_tasks()
self.mock.assert_not_called()

# Subscribed to event: webhook called
create_project()
self.wait_for_background_tasks()
self.mock.assert_called_once()

@override_webhook_settings(WEBHOOKS=[
Expand All @@ -73,4 +83,5 @@ def test_event_filter(self):
def test_error_handling(self):
self.mock.side_effect = [Exception('Request failed'), mock.DEFAULT]
create_project()
self.wait_for_background_tasks()
assert self.mock.call_count == 2
2 changes: 1 addition & 1 deletion plugins/webhooks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,5 @@ async def send_webhooks(event_type: WebhookEventType, data):

webhooks_to_send = list(filter(lambda w: event_type in w.get('events', []), webhook_settings))
if webhooks_to_send:
run_in_background(send_webhook_requests(data | {'event': event_type.value}, webhooks_to_send))
run_in_background(send_webhook_requests)(data | {'event': event_type.value}, webhooks_to_send)

0 comments on commit 887d63a

Please sign in to comment.