Skip to content

Commit

Permalink
Merge branch 'license-activation' into 'main'
Browse files Browse the repository at this point in the history
License activation

See merge request reportcreator/reportcreator!819
  • Loading branch information
MWedl committed Jan 14, 2025
2 parents 6969a0a + 3185811 commit cca3fba
Show file tree
Hide file tree
Showing 20 changed files with 294 additions and 107 deletions.
3 changes: 2 additions & 1 deletion api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ extend-select = [
ignore = [
"E741", # Ambiguous variable name: `l`
"DJ001", # Avoid using null=True on string-based fields in django models
"PT004", # Fixture setUp does not return anything, add leading underscore
"S101", # Use of `assert` detected
]

[tool.ruff.lint.flake8-pytest-style]
mark-parentheses = true
4 changes: 2 additions & 2 deletions api/src/reportcreator_api/api_utils/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from reportcreator_api.api_utils import backup_utils
from reportcreator_api.api_utils.healthchecks import run_healthchecks
from reportcreator_api.api_utils.models import BackupLog
from reportcreator_api.api_utils.permissions import IsAdminOrSystem, IsUserManagerOrSuperuserOrSystem
from reportcreator_api.api_utils.permissions import IsAdminOrSystem
from reportcreator_api.api_utils.serializers import (
BackupLogSerializer,
BackupSerializer,
Expand Down Expand Up @@ -115,7 +115,7 @@ def backuplogs(self, request, *args, **kwargs):
return self.get_paginated_response(serializer.data)

@extend_schema(responses=OpenApiTypes.OBJECT)
@action(detail=False, url_name='license', url_path='license', methods=['get'], permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES + [IsUserManagerOrSuperuserOrSystem])
@action(detail=False, url_name='license', url_path='license', methods=['get'])
async def license_info(self, request, *args, **kwargs):
return Response(data=await license.aget_license_info())

Expand Down
2 changes: 1 addition & 1 deletion api/src/reportcreator_api/conf/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def remove_empty_items(lst=None):
'img-src': [SELF, 'data:'],
'font-src': [SELF],
'worker-src': [SELF],
'connect-src': [SELF, 'data:'],
'connect-src': [SELF, 'data:', 'https://portal.sysreptor.com'],
'frame-src': [SELF],
'frame-ancestors': [SELF],
'form-action': [SELF],
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Generated by Django 5.1.4 on 2025-01-09 09:24

import reportcreator_api.utils.models
import uuid
from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('tasks', '0001_initial'),
]

operations = [
migrations.CreateModel(
name='LicenseActivationInfo',
fields=[
('id', models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)),
('created', models.DateTimeField(default=reportcreator_api.utils.models.now, editable=False)),
('updated', models.DateTimeField(auto_now=True)),
('license_type', models.CharField(choices=[('community', 'Community'), ('professional', 'Professional')], max_length=255)),
('license_hash', models.CharField(max_length=255, null=True)),
('last_activation_time', models.DateTimeField(blank=True, null=True)),
],
options={
'ordering': ['-created'],
'abstract': False,
},
),
]
14 changes: 13 additions & 1 deletion api/src/reportcreator_api/tasks/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from django.utils import timezone

from reportcreator_api.tasks import querysets
from reportcreator_api.utils import license
from reportcreator_api.utils.models import BaseModel


Expand All @@ -29,6 +30,8 @@ class PeriodicTaskSpec:
id: str
schedule: timedelta
func: callable
retry: timedelta = timedelta(minutes=10)
max_runtime: timedelta = timedelta(minutes=10)


@dataclasses.dataclass()
Expand Down Expand Up @@ -59,12 +62,21 @@ def unregister(self, task: PeriodicTaskSpec):
periodic_task_registry = PeriodicTaskRegistry()


def periodic_task(schedule: timedelta, id: str|None = None):
def periodic_task(schedule: timedelta, id: str|None = None, retry: timedelta|None = None):
def inner(func):
periodic_task_registry.register(PeriodicTaskSpec(
id=id or f'{func.__module__}.{func.__name__}',
schedule=schedule,
retry=retry or (schedule / 10),
func=func,
))
return func
return inner


class LicenseActivationInfo(BaseModel):
license_type = models.CharField(max_length=255, choices=license.LicenseType.choices)
license_hash = models.CharField(max_length=255, null=True)
last_activation_time = models.DateTimeField(null=True, blank=True)

objects = querysets.LicenseActivationInfoManager()
23 changes: 20 additions & 3 deletions api/src/reportcreator_api/tasks/querysets.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import logging
from datetime import timedelta

import elasticapm
from asgiref.sync import iscoroutinefunction, sync_to_async
from django.db import IntegrityError, models
from django.utils import timezone

from reportcreator_api.utils import license

log = logging.getLogger(__name__)


Expand All @@ -19,8 +20,8 @@ async def get_pending_tasks(self):
model = task_models.get(t_id)
# Remove non-pending tasks
if model and (
(model.status == TaskStatus.RUNNING and model.started > timezone.now() - timedelta(minutes=10)) or \
(model.status == TaskStatus.FAILED and model.started > timezone.now() - timedelta(minutes=10)) or \
(model.status == TaskStatus.RUNNING and model.started > timezone.now() - spec.max_runtime) or \
(model.status == TaskStatus.FAILED and model.started > timezone.now() - spec.retry) or \
(model.status == TaskStatus.SUCCESS and model.started > timezone.now() - spec.schedule)
):
continue
Expand Down Expand Up @@ -83,3 +84,19 @@ async def run_task(self, task_info):
async def run_all_pending_tasks(self):
for t in await self.get_pending_tasks():
await self.run_task(t)


class LicenseActivationInfoManager(models.Manager.from_queryset(models.QuerySet)):
def current(self):
obj = self.order_by('-created').first()

if not license.is_professional():
if not obj or obj.license_type != license.LicenseType.COMMUNITY:
obj = self.create(license_type=license.LicenseType.COMMUNITY)
else:
if not obj or obj.license_type != license.LicenseType.PROFESSIONAL or obj.license_hash != license.get_license_hash():
obj = self.create(
license_type=license.LicenseType.PROFESSIONAL,
license_hash=license.get_license_hash(),
)
return obj
39 changes: 38 additions & 1 deletion api/src/reportcreator_api/tasks/tasks.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,47 @@
import json
import logging
from datetime import timedelta

import httpx
from asgiref.sync import sync_to_async
from django.core.management import call_command
from django.core.serializers.json import DjangoJSONEncoder
from django.utils import dateparse

from reportcreator_api.tasks.models import periodic_task
from reportcreator_api.tasks.models import LicenseActivationInfo, TaskStatus, periodic_task
from reportcreator_api.utils import license


@periodic_task(id='clear_sessions', schedule=timedelta(days=1))
def clear_sessions(task_info):
call_command('clearsessions')


async def activate_license_request(license_info):
async with httpx.AsyncClient(timeout=10) as client:
res = await client.post(
url='https://portal.sysreptor.com/api/v1/licenses/activate/',
headers={'Content-Type': 'application/json'},
data=json.dumps(license_info, cls=DjangoJSONEncoder),
)
res.raise_for_status()
return res.json()


@periodic_task(id='activate_license', schedule=timedelta(days=1))
async def activate_license(task_info):
activation_info = await sync_to_async(LicenseActivationInfo.objects.current)()
if not await license.ais_professional():
return TaskStatus.SUCCESS

try:
res = await activate_license_request(await license.aget_license_info())
if res.get('status') == 'ok':
try:
activation_info.last_activation_time = dateparse.parse_datetime(res.get('license_info', {}).get('last_activation_time'))
except (TypeError, ValueError):
activation_info.last_activation_time = None
await activation_info.asave()
except httpx.TransportError as ex:
logging.warning(f'Failed to activate license: {ex}. Check your internet connection.')
return TaskStatus.FAILED
5 changes: 3 additions & 2 deletions api/src/reportcreator_api/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def guest_urls():
return [
('utils list', lambda s, c: c.get(reverse('utils-list'))),
('utils cwes', lambda s, c: c.get(reverse('utils-cwes'))),
('utils-license', lambda s, c: c.get(reverse('utils-license'))),

*viewset_urls('pentestuser', get_kwargs=lambda s, detail: {'pk': 'self'}, retrieve=True, update=True, update_partial=True),
*viewset_urls('pentestuser', get_kwargs=lambda s, detail: {}, list=True),
Expand Down Expand Up @@ -304,7 +305,6 @@ def user_manager_urls():
*viewset_urls('authidentity', get_kwargs=lambda s, detail: {'pentestuser_pk': s.user_other.pk} | ({'pk': s.user_other.auth_identities.first().pk} if detail else {}), list=True, retrieve=True, create=True, create_data={'identifier': 'other.identifier'}, update=True, update_partial=True, destroy=True),
*viewset_urls('apitoken', get_kwargs=lambda s, detail: {'pentestuser_pk': s.user_other.pk} | ({'pk': s.user_other.api_tokens.first().pk} if detail else {}), list=True, retrieve=True, destroy=True),
*viewset_urls('userpublickey', get_kwargs=lambda s, detail: {'pentestuser_pk': s.user_other.pk} | ({'pk': s.user_other.public_keys.first().pk} if detail else {}), list=True, retrieve=True),
('utils-license', lambda s, c: c.get(reverse('utils-license'))),
]


Expand Down Expand Up @@ -508,7 +508,8 @@ async def mock_render_pdf(*args, output=None, **kwargs):
'label': 'Dummy',
},
},
), mock.patch('reportcreator_api.tasks.rendering.render.render_pdf_impl', mock_render_pdf):
), mock.patch('reportcreator_api.tasks.rendering.render.render_pdf_impl', mock_render_pdf), \
mock.patch('reportcreator_api.tasks.tasks.activate_license', return_value=None):
user_map = {
'public': lambda: None,
'guest': lambda: ApiRequestsAndPermissionsTestData.create_user(is_guest=True),
Expand Down
107 changes: 76 additions & 31 deletions api/src/reportcreator_api/tests/test_license.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from uuid import uuid4

import pytest
from asgiref.sync import async_to_sync
from Cryptodome.Hash import SHA512
from Cryptodome.PublicKey import ECC
from Cryptodome.Signature import eddsa
Expand All @@ -16,6 +17,8 @@
from rest_framework import status
from rest_framework.test import APIClient

from reportcreator_api.tasks.models import LicenseActivationInfo
from reportcreator_api.tasks.tasks import activate_license
from reportcreator_api.tests.mock import (
api_client,
create_project,
Expand All @@ -33,6 +36,42 @@ def assert_api_license_error(res):
assert res.data['code'] == 'license'


def generate_signing_key():
private_key = ECC.generate(curve='ed25519')
public_key = {
'id': str(uuid4()),
'algorithm': 'ed25519',
'key': b64encode(private_key.public_key().export_key(format='DER')).decode(),
}
return private_key, public_key


def sign_license_data(license_data_str: str, public_key: dict, private_key):
signer = eddsa.new(key=private_key, mode='rfc8032')
signature = signer.sign(SHA512.new(license_data_str.encode()))
return {
'key_id': public_key['id'],
'algorithm': public_key['algorithm'],
'signature': b64encode(signature).decode(),
}


def sign_license(license_data, keys):
license_data_str = json.dumps(license_data)
return b64encode(json.dumps({
'data': license_data_str,
'signatures': [sign_license_data(license_data_str, k[0], k[1]) for k in keys],
}).encode()).decode()


def signed_license(keys, **kwargs):
return sign_license({
'users': 10,
'valid_from': (timezone.now() - timedelta(days=30)).date().isoformat(),
'valid_until': (timezone.now() + timedelta(days=30)).date().isoformat(),
} | kwargs, keys)


@pytest.mark.django_db()
class TestCommunityLicenseRestrictions:
@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -235,41 +274,12 @@ def test_user_count_limit(self):
class TestLicenseValidation:
@pytest.fixture(autouse=True)
def setUp(self):
self.license_private_key, self.license_public_key = self.generate_signing_key()
self.license_private_key, self.license_public_key = generate_signing_key()
with mock.patch('reportcreator_api.utils.license.LICENSE_VALIDATION_KEYS', new=[self.license_public_key]):
yield

def generate_signing_key(self):
private_key = ECC.generate(curve='ed25519')
public_key = {
'id': str(uuid4()),
'algorithm': 'ed25519',
'key': b64encode(private_key.public_key().export_key(format='DER')).decode(),
}
return private_key, public_key

def sign_license_data(self, license_data_str: str, public_key: dict, private_key):
signer = eddsa.new(key=private_key, mode='rfc8032')
signature = signer.sign(SHA512.new(license_data_str.encode()))
return {
'key_id': public_key['id'],
'algorithm': public_key['algorithm'],
'signature': b64encode(signature).decode(),
}

def sign_license(self, license_data, keys):
license_data_str = json.dumps(license_data)
return b64encode(json.dumps({
'data': license_data_str,
'signatures': [self.sign_license_data(license_data_str, k[0], k[1]) for k in keys],
}).encode()).decode()

def signed_license(self, **kwargs):
return self.sign_license({
'users': 10,
'valid_from': (timezone.now() - timedelta(days=30)).date().isoformat(),
'valid_until': (timezone.now() + timedelta(days=30)).date().isoformat(),
} | kwargs, [(self.license_public_key, self.license_private_key)])
return signed_license(keys=[(self.license_public_key, self.license_private_key)], **kwargs)

@pytest.mark.parametrize(('license_str', 'error'), [
(None, None),
Expand Down Expand Up @@ -339,3 +349,38 @@ def test_multiple_signatures_only_1_valid(self):
license_2 = b64encode(json.dumps(license_content).encode())
license_info = license.decode_and_validate_license(license_2)
assert license_info['type'] == license.LicenseType.PROFESSIONAL


@pytest.mark.django_db()
class TestLicenseActivationInfo:
@pytest.fixture(autouse=True)
def setUp(self):
self.license_private_key, self.license_public_key = generate_signing_key()
self.license_community = None
self.license_invalid = 'invalid license string'
self.license_professional = signed_license(keys=[(self.license_public_key, self.license_private_key)])
self.license_professional2 = signed_license(keys=[(self.license_public_key, self.license_private_key)], users=20)
self.license_expired = signed_license(keys=[(self.license_public_key, self.license_private_key)], valid_until=(timezone.now() - timedelta(days=1)).date().isoformat())

with mock.patch('reportcreator_api.utils.license.LICENSE_VALIDATION_KEYS', new=[self.license_public_key]), \
mock.patch('reportcreator_api.utils.license.check_license', new=lambda: license.decode_and_validate_license(settings.LICENSE)), \
mock.patch('reportcreator_api.tasks.tasks.activate_license_request', return_value={'status': 'ok', 'license_info': {'last_activation_time': timezone.now().isoformat()}}):
yield

@pytest.mark.parametrize(('name_old', 'name_new', 'created'), [
('license_community', 'license_community', False),
('license_community', 'license_invalid', False),
('license_community', 'license_professional', True),
('license_professional', 'license_professional', False),
('license_professional', 'license_professional2', True),
('license_professional', 'license_community', True),
('license_professional', 'license_expired', True),
])
def test_license_activation_info_created(self, name_old, name_new, created):
with override_settings(LICENSE=getattr(self, name_old)):
activation_info_old = LicenseActivationInfo.objects.current()
with override_settings(LICENSE=getattr(self, name_new)):
async_to_sync(activate_license)(None)
activation_info_new = LicenseActivationInfo.objects.order_by('-created').first()

assert (activation_info_old != activation_info_new) == created
1 change: 1 addition & 0 deletions api/src/reportcreator_api/users/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ def perform_login(self, request, user, can_reauth=True):
return Response({
'status': 'success',
'first_login': first_login,
'license': license.get_license_info(),
}, status=status.HTTP_200_OK)

@action(detail=False, url_path='login/oidc/(?P<oidc_provider>[a-zA-Z0-9]+)/begin', methods=['get'], permission_classes=[license.ProfessionalLicenseRequired])
Expand Down
Loading

0 comments on commit cca3fba

Please sign in to comment.