diff --git a/api/pyproject.toml b/api/pyproject.toml index 9d5f55e6c..7c8e2d0e3 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -82,7 +82,8 @@ extend-select = [ ignore = [ "E741", # Ambiguous variable name: `l` "DJ001", # Avoid using null=True on string-based fields in django models - "PT004", # Fixture setUp does not return anything, add leading underscore "S101", # Use of `assert` detected ] +[tool.ruff.lint.flake8-pytest-style] +mark-parentheses = true \ No newline at end of file diff --git a/api/src/reportcreator_api/tasks/migrations/0002_licenseactivationinfo.py b/api/src/reportcreator_api/tasks/migrations/0002_licenseactivationinfo.py new file mode 100644 index 000000000..db8a78d02 --- /dev/null +++ b/api/src/reportcreator_api/tasks/migrations/0002_licenseactivationinfo.py @@ -0,0 +1,30 @@ +# Generated by Django 5.1.4 on 2025-01-09 09:24 + +import reportcreator_api.utils.models +import uuid +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('tasks', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='LicenseActivationInfo', + fields=[ + ('id', models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)), + ('created', models.DateTimeField(default=reportcreator_api.utils.models.now, editable=False)), + ('updated', models.DateTimeField(auto_now=True)), + ('license_type', models.CharField(choices=[('community', 'Community'), ('professional', 'Professional')], max_length=255)), + ('license_hash', models.CharField(max_length=255, null=True)), + ('last_activation_time', models.DateTimeField(blank=True, null=True)), + ], + options={ + 'ordering': ['-created'], + 'abstract': False, + }, + ), + ] diff --git a/api/src/reportcreator_api/tasks/models.py b/api/src/reportcreator_api/tasks/models.py index bdc4d6f1a..de8f293ce 100644 --- a/api/src/reportcreator_api/tasks/models.py +++ b/api/src/reportcreator_api/tasks/models.py @@ -5,6 +5,7 @@ from django.utils import timezone from reportcreator_api.tasks import querysets +from reportcreator_api.utils import license from reportcreator_api.utils.models import BaseModel @@ -71,3 +72,11 @@ def inner(func): )) return func return inner + + +class LicenseActivationInfo(BaseModel): + license_type = models.CharField(max_length=255, choices=license.LicenseType.choices) + license_hash = models.CharField(max_length=255, null=True) + last_activation_time = models.DateTimeField(null=True, blank=True) + + objects = querysets.LicenseActivationInfoManager() diff --git a/api/src/reportcreator_api/tasks/querysets.py b/api/src/reportcreator_api/tasks/querysets.py index dd6c125f8..088b99b0b 100644 --- a/api/src/reportcreator_api/tasks/querysets.py +++ b/api/src/reportcreator_api/tasks/querysets.py @@ -5,6 +5,8 @@ from django.db import IntegrityError, models from django.utils import timezone +from reportcreator_api.utils import license + log = logging.getLogger(__name__) @@ -82,3 +84,19 @@ async def run_task(self, task_info): async def run_all_pending_tasks(self): for t in await self.get_pending_tasks(): await self.run_task(t) + + +class LicenseActivationInfoManager(models.Manager.from_queryset(models.QuerySet)): + def current(self): + obj = self.order_by('-created').first() + + if not license.is_professional(): + if not obj or obj.license_type != license.LicenseType.COMMUNITY: + obj = self.create(license_type=license.LicenseType.COMMUNITY) + else: + if not obj or obj.license_type != license.LicenseType.PROFESSIONAL or obj.license_hash != license.get_license_hash(): + obj = self.create( + license_type=license.LicenseType.PROFESSIONAL, + license_hash=license.get_license_hash(), + ) + return obj diff --git a/api/src/reportcreator_api/tasks/tasks.py b/api/src/reportcreator_api/tasks/tasks.py index cc1d1da95..28a0c19c6 100644 --- a/api/src/reportcreator_api/tasks/tasks.py +++ b/api/src/reportcreator_api/tasks/tasks.py @@ -1,10 +1,47 @@ +import json +import logging from datetime import timedelta +import httpx +from asgiref.sync import sync_to_async from django.core.management import call_command +from django.core.serializers.json import DjangoJSONEncoder +from django.utils import dateparse -from reportcreator_api.tasks.models import periodic_task +from reportcreator_api.tasks.models import LicenseActivationInfo, TaskStatus, periodic_task +from reportcreator_api.utils import license @periodic_task(id='clear_sessions', schedule=timedelta(days=1)) def clear_sessions(task_info): call_command('clearsessions') + + +async def activate_license_request(license_info): + async with httpx.AsyncClient(timeout=10) as client: + res = await client.post( + url='https://panel.sysreptor.com/api/v1/licenses/activate/', + headers={'Content-Type': 'application/json'}, + data=json.dumps(license_info, cls=DjangoJSONEncoder), + ) + res.raise_for_status() + return res.json() + + +@periodic_task(id='activate_license', schedule=timedelta(days=1)) +async def activate_license(task_info): + activation_info = await sync_to_async(LicenseActivationInfo.objects.current)() + if not await license.ais_professional(): + return TaskStatus.SUCCESS + + try: + res = await activate_license_request(await license.aget_license_info()) + if res.get('status') == 'ok': + try: + activation_info.last_activation_time = dateparse.parse_datetime(res.get('license_info', {}).get('last_activation_time')) + except (TypeError, ValueError): + activation_info.last_activation_time = None + await activation_info.asave() + except httpx.TransportError as ex: + logging.warning(f'Failed to activate license: {ex}. Check your internet connection.') + return TaskStatus.FAILED diff --git a/api/src/reportcreator_api/tests/test_api.py b/api/src/reportcreator_api/tests/test_api.py index 5515d3545..7d2b07237 100644 --- a/api/src/reportcreator_api/tests/test_api.py +++ b/api/src/reportcreator_api/tests/test_api.py @@ -508,7 +508,8 @@ async def mock_render_pdf(*args, output=None, **kwargs): 'label': 'Dummy', }, }, - ), mock.patch('reportcreator_api.tasks.rendering.render.render_pdf_impl', mock_render_pdf): + ), mock.patch('reportcreator_api.tasks.rendering.render.render_pdf_impl', mock_render_pdf), \ + mock.patch('reportcreator_api.tasks.tasks.activate_license', return_value=None): user_map = { 'public': lambda: None, 'guest': lambda: ApiRequestsAndPermissionsTestData.create_user(is_guest=True), diff --git a/api/src/reportcreator_api/tests/test_license.py b/api/src/reportcreator_api/tests/test_license.py index c747a9d51..982eb3fa9 100644 --- a/api/src/reportcreator_api/tests/test_license.py +++ b/api/src/reportcreator_api/tests/test_license.py @@ -5,6 +5,7 @@ from uuid import uuid4 import pytest +from asgiref.sync import async_to_sync from Cryptodome.Hash import SHA512 from Cryptodome.PublicKey import ECC from Cryptodome.Signature import eddsa @@ -16,6 +17,8 @@ from rest_framework import status from rest_framework.test import APIClient +from reportcreator_api.tasks.models import LicenseActivationInfo +from reportcreator_api.tasks.tasks import activate_license from reportcreator_api.tests.mock import ( api_client, create_project, @@ -33,6 +36,42 @@ def assert_api_license_error(res): assert res.data['code'] == 'license' +def generate_signing_key(): + private_key = ECC.generate(curve='ed25519') + public_key = { + 'id': str(uuid4()), + 'algorithm': 'ed25519', + 'key': b64encode(private_key.public_key().export_key(format='DER')).decode(), + } + return private_key, public_key + + +def sign_license_data(license_data_str: str, public_key: dict, private_key): + signer = eddsa.new(key=private_key, mode='rfc8032') + signature = signer.sign(SHA512.new(license_data_str.encode())) + return { + 'key_id': public_key['id'], + 'algorithm': public_key['algorithm'], + 'signature': b64encode(signature).decode(), + } + + +def sign_license(license_data, keys): + license_data_str = json.dumps(license_data) + return b64encode(json.dumps({ + 'data': license_data_str, + 'signatures': [sign_license_data(license_data_str, k[0], k[1]) for k in keys], + }).encode()).decode() + + +def signed_license(keys, **kwargs): + return sign_license({ + 'users': 10, + 'valid_from': (timezone.now() - timedelta(days=30)).date().isoformat(), + 'valid_until': (timezone.now() + timedelta(days=30)).date().isoformat(), + } | kwargs, keys) + + @pytest.mark.django_db() class TestCommunityLicenseRestrictions: @pytest.fixture(autouse=True) @@ -235,41 +274,12 @@ def test_user_count_limit(self): class TestLicenseValidation: @pytest.fixture(autouse=True) def setUp(self): - self.license_private_key, self.license_public_key = self.generate_signing_key() + self.license_private_key, self.license_public_key = generate_signing_key() with mock.patch('reportcreator_api.utils.license.LICENSE_VALIDATION_KEYS', new=[self.license_public_key]): yield - def generate_signing_key(self): - private_key = ECC.generate(curve='ed25519') - public_key = { - 'id': str(uuid4()), - 'algorithm': 'ed25519', - 'key': b64encode(private_key.public_key().export_key(format='DER')).decode(), - } - return private_key, public_key - - def sign_license_data(self, license_data_str: str, public_key: dict, private_key): - signer = eddsa.new(key=private_key, mode='rfc8032') - signature = signer.sign(SHA512.new(license_data_str.encode())) - return { - 'key_id': public_key['id'], - 'algorithm': public_key['algorithm'], - 'signature': b64encode(signature).decode(), - } - - def sign_license(self, license_data, keys): - license_data_str = json.dumps(license_data) - return b64encode(json.dumps({ - 'data': license_data_str, - 'signatures': [self.sign_license_data(license_data_str, k[0], k[1]) for k in keys], - }).encode()).decode() - def signed_license(self, **kwargs): - return self.sign_license({ - 'users': 10, - 'valid_from': (timezone.now() - timedelta(days=30)).date().isoformat(), - 'valid_until': (timezone.now() + timedelta(days=30)).date().isoformat(), - } | kwargs, [(self.license_public_key, self.license_private_key)]) + return signed_license(keys=[(self.license_public_key, self.license_private_key)], **kwargs) @pytest.mark.parametrize(('license_str', 'error'), [ (None, None), @@ -339,3 +349,38 @@ def test_multiple_signatures_only_1_valid(self): license_2 = b64encode(json.dumps(license_content).encode()) license_info = license.decode_and_validate_license(license_2) assert license_info['type'] == license.LicenseType.PROFESSIONAL + + +@pytest.mark.django_db() +class TestLicenseActivationInfo: + @pytest.fixture(autouse=True) + def setUp(self): + self.license_private_key, self.license_public_key = generate_signing_key() + self.license_community = None + self.license_invalid = 'invalid license string' + self.license_professional = signed_license(keys=[(self.license_public_key, self.license_private_key)]) + self.license_professional2 = signed_license(keys=[(self.license_public_key, self.license_private_key)], users=20) + self.license_expired = signed_license(keys=[(self.license_public_key, self.license_private_key)], valid_until=(timezone.now() - timedelta(days=1)).date().isoformat()) + + with mock.patch('reportcreator_api.utils.license.LICENSE_VALIDATION_KEYS', new=[self.license_public_key]), \ + mock.patch('reportcreator_api.utils.license.check_license', new=lambda: license.decode_and_validate_license(settings.LICENSE)), \ + mock.patch('reportcreator_api.tasks.tasks.activate_license_request', return_value={'status': 'ok', 'license_info': {'last_activation_time': timezone.now().isoformat()}}): + yield + + @pytest.mark.parametrize(('name_old', 'name_new', 'created'), [ + ('license_community', 'license_community', False), + ('license_community', 'license_invalid', False), + ('license_community', 'license_professional', True), + ('license_professional', 'license_professional', False), + ('license_professional', 'license_professional2', True), + ('license_professional', 'license_community', True), + ('license_professional', 'license_expired', True), + ]) + def test_license_activation_info_created(self, name_old, name_new, created): + with override_settings(LICENSE=getattr(self, name_old)): + activation_info_old = LicenseActivationInfo.objects.current() + with override_settings(LICENSE=getattr(self, name_new)): + async_to_sync(activate_license)(None) + activation_info_new = LicenseActivationInfo.objects.order_by('-created').first() + + assert (activation_info_old != activation_info_new) == created diff --git a/api/src/reportcreator_api/users/views.py b/api/src/reportcreator_api/users/views.py index 065528755..b0d7f8bfe 100644 --- a/api/src/reportcreator_api/users/views.py +++ b/api/src/reportcreator_api/users/views.py @@ -413,6 +413,7 @@ def perform_login(self, request, user, can_reauth=True): return Response({ 'status': 'success', 'first_login': first_login, + 'license': license.get_license_info(), }, status=status.HTTP_200_OK) @action(detail=False, url_path='login/oidc/(?P[a-zA-Z0-9]+)/begin', methods=['get'], permission_classes=[license.ProfessionalLicenseRequired]) diff --git a/api/src/reportcreator_api/utils/license.py b/api/src/reportcreator_api/utils/license.py index 1467712fa..4a95c1de2 100644 --- a/api/src/reportcreator_api/utils/license.py +++ b/api/src/reportcreator_api/utils/license.py @@ -1,4 +1,5 @@ import base64 +import hashlib import json import logging @@ -123,6 +124,15 @@ def decode_and_validate_license(license, skip_db_checks=False, skip_limit_valida } +def get_license_hash(): + if not settings.LICENSE: + return None + try: + return 'sha3_256$$' + hashlib.sha3_256(base64.b64decode(settings.LICENSE)).hexdigest() + except Exception: + return None + + @cache('license.license_info', timeout=10 * 60) def check_license(**kwargs): return decode_and_validate_license(license=settings.LICENSE, **kwargs) @@ -134,14 +144,22 @@ async def acheck_license(**kwargs): def get_license_info(): from reportcreator_api.conf import plugins + from reportcreator_api.tasks.models import LicenseActivationInfo from reportcreator_api.users.models import PentestUser + activation_info = LicenseActivationInfo.objects.current() return check_license() | { + 'license_hash': get_license_hash(), 'active_users': PentestUser.objects.get_licensed_user_count(), 'total_users': PentestUser.objects.get_total_user_count(), 'installation_id': settings.INSTALLATION_ID, 'software_version': settings.VERSION, 'plugins': [p.name.split('.')[-1] for p in plugins.enabled_plugins], + 'activation_info': { + 'created': activation_info.created, + 'license_hash': activation_info.license_hash, + 'last_activation_time': activation_info.last_activation_time, + }, } async def aget_license_info(): diff --git a/plugins/webhooks/tests/test_webhooks.py b/plugins/webhooks/tests/test_webhooks.py index c2b1b2124..872e0ae99 100644 --- a/plugins/webhooks/tests/test_webhooks.py +++ b/plugins/webhooks/tests/test_webhooks.py @@ -30,7 +30,7 @@ def update(obj, **kwargs): return obj -@pytest.mark.django_db +@pytest.mark.django_db() class TestWebhooksCalled: @pytest.fixture(autouse=True) def setUp(self):