Skip to content

Commit

Permalink
License activation info
Browse files Browse the repository at this point in the history
MWedl committed Jan 14, 2025
1 parent cd04770 commit 0a87f39
Showing 10 changed files with 195 additions and 35 deletions.
3 changes: 2 additions & 1 deletion api/pyproject.toml
Original file line number Diff line number Diff line change
@@ -82,7 +82,8 @@ extend-select = [
ignore = [
"E741", # Ambiguous variable name: `l`
"DJ001", # Avoid using null=True on string-based fields in django models
"PT004", # Fixture setUp does not return anything, add leading underscore
"S101", # Use of `assert` detected
]

[tool.ruff.lint.flake8-pytest-style]
mark-parentheses = true
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Generated by Django 5.1.4 on 2025-01-09 09:24

import reportcreator_api.utils.models
import uuid
from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('tasks', '0001_initial'),
]

operations = [
migrations.CreateModel(
name='LicenseActivationInfo',
fields=[
('id', models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)),
('created', models.DateTimeField(default=reportcreator_api.utils.models.now, editable=False)),
('updated', models.DateTimeField(auto_now=True)),
('license_type', models.CharField(choices=[('community', 'Community'), ('professional', 'Professional')], max_length=255)),
('license_hash', models.CharField(max_length=255, null=True)),
('last_activation_time', models.DateTimeField(blank=True, null=True)),
],
options={
'ordering': ['-created'],
'abstract': False,
},
),
]
9 changes: 9 additions & 0 deletions api/src/reportcreator_api/tasks/models.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@
from django.utils import timezone

from reportcreator_api.tasks import querysets
from reportcreator_api.utils import license
from reportcreator_api.utils.models import BaseModel


@@ -71,3 +72,11 @@ def inner(func):
))
return func
return inner


class LicenseActivationInfo(BaseModel):
license_type = models.CharField(max_length=255, choices=license.LicenseType.choices)
license_hash = models.CharField(max_length=255, null=True)
last_activation_time = models.DateTimeField(null=True, blank=True)

objects = querysets.LicenseActivationInfoManager()
18 changes: 18 additions & 0 deletions api/src/reportcreator_api/tasks/querysets.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,8 @@
from django.db import IntegrityError, models
from django.utils import timezone

from reportcreator_api.utils import license

log = logging.getLogger(__name__)


@@ -82,3 +84,19 @@ async def run_task(self, task_info):
async def run_all_pending_tasks(self):
for t in await self.get_pending_tasks():
await self.run_task(t)


class LicenseActivationInfoManager(models.Manager.from_queryset(models.QuerySet)):
def current(self):
obj = self.order_by('-created').first()

if not license.is_professional():
if not obj or obj.license_type != license.LicenseType.COMMUNITY:
obj = self.create(license_type=license.LicenseType.COMMUNITY)
else:
if not obj or obj.license_type != license.LicenseType.PROFESSIONAL or obj.license_hash != license.get_license_hash():
obj = self.create(
license_type=license.LicenseType.PROFESSIONAL,
license_hash=license.get_license_hash(),
)
return obj
39 changes: 38 additions & 1 deletion api/src/reportcreator_api/tasks/tasks.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,47 @@
import json
import logging
from datetime import timedelta

import httpx
from asgiref.sync import sync_to_async
from django.core.management import call_command
from django.core.serializers.json import DjangoJSONEncoder
from django.utils import dateparse

from reportcreator_api.tasks.models import periodic_task
from reportcreator_api.tasks.models import LicenseActivationInfo, TaskStatus, periodic_task
from reportcreator_api.utils import license


@periodic_task(id='clear_sessions', schedule=timedelta(days=1))
def clear_sessions(task_info):
call_command('clearsessions')


async def activate_license_request(license_info):
async with httpx.AsyncClient(timeout=10) as client:
res = await client.post(
url='https://panel.sysreptor.com/api/v1/licenses/activate/',
headers={'Content-Type': 'application/json'},
data=json.dumps(license_info, cls=DjangoJSONEncoder),
)
res.raise_for_status()
return res.json()


@periodic_task(id='activate_license', schedule=timedelta(days=1))
async def activate_license(task_info):
activation_info = await sync_to_async(LicenseActivationInfo.objects.current)()
if not await license.ais_professional():
return TaskStatus.SUCCESS

try:
res = await activate_license_request(await license.aget_license_info())
if res.get('status') == 'ok':
try:
activation_info.last_activation_time = dateparse.parse_datetime(res.get('license_info', {}).get('last_activation_time'))
except (TypeError, ValueError):
activation_info.last_activation_time = None
await activation_info.asave()
except httpx.TransportError as ex:
logging.warning(f'Failed to activate license: {ex}. Check your internet connection.')
return TaskStatus.FAILED
3 changes: 2 additions & 1 deletion api/src/reportcreator_api/tests/test_api.py
Original file line number Diff line number Diff line change
@@ -508,7 +508,8 @@ async def mock_render_pdf(*args, output=None, **kwargs):
'label': 'Dummy',
},
},
), mock.patch('reportcreator_api.tasks.rendering.render.render_pdf_impl', mock_render_pdf):
), mock.patch('reportcreator_api.tasks.rendering.render.render_pdf_impl', mock_render_pdf), \
mock.patch('reportcreator_api.tasks.tasks.activate_license', return_value=None):
user_map = {
'public': lambda: None,
'guest': lambda: ApiRequestsAndPermissionsTestData.create_user(is_guest=True),
107 changes: 76 additions & 31 deletions api/src/reportcreator_api/tests/test_license.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@
from uuid import uuid4

import pytest
from asgiref.sync import async_to_sync
from Cryptodome.Hash import SHA512
from Cryptodome.PublicKey import ECC
from Cryptodome.Signature import eddsa
@@ -16,6 +17,8 @@
from rest_framework import status
from rest_framework.test import APIClient

from reportcreator_api.tasks.models import LicenseActivationInfo
from reportcreator_api.tasks.tasks import activate_license
from reportcreator_api.tests.mock import (
api_client,
create_project,
@@ -33,6 +36,42 @@ def assert_api_license_error(res):
assert res.data['code'] == 'license'


def generate_signing_key():
private_key = ECC.generate(curve='ed25519')
public_key = {
'id': str(uuid4()),
'algorithm': 'ed25519',
'key': b64encode(private_key.public_key().export_key(format='DER')).decode(),
}
return private_key, public_key


def sign_license_data(license_data_str: str, public_key: dict, private_key):
signer = eddsa.new(key=private_key, mode='rfc8032')
signature = signer.sign(SHA512.new(license_data_str.encode()))
return {
'key_id': public_key['id'],
'algorithm': public_key['algorithm'],
'signature': b64encode(signature).decode(),
}


def sign_license(license_data, keys):
license_data_str = json.dumps(license_data)
return b64encode(json.dumps({
'data': license_data_str,
'signatures': [sign_license_data(license_data_str, k[0], k[1]) for k in keys],
}).encode()).decode()


def signed_license(keys, **kwargs):
return sign_license({
'users': 10,
'valid_from': (timezone.now() - timedelta(days=30)).date().isoformat(),
'valid_until': (timezone.now() + timedelta(days=30)).date().isoformat(),
} | kwargs, keys)


@pytest.mark.django_db()
class TestCommunityLicenseRestrictions:
@pytest.fixture(autouse=True)
@@ -235,41 +274,12 @@ def test_user_count_limit(self):
class TestLicenseValidation:
@pytest.fixture(autouse=True)
def setUp(self):
self.license_private_key, self.license_public_key = self.generate_signing_key()
self.license_private_key, self.license_public_key = generate_signing_key()
with mock.patch('reportcreator_api.utils.license.LICENSE_VALIDATION_KEYS', new=[self.license_public_key]):
yield

def generate_signing_key(self):
private_key = ECC.generate(curve='ed25519')
public_key = {
'id': str(uuid4()),
'algorithm': 'ed25519',
'key': b64encode(private_key.public_key().export_key(format='DER')).decode(),
}
return private_key, public_key

def sign_license_data(self, license_data_str: str, public_key: dict, private_key):
signer = eddsa.new(key=private_key, mode='rfc8032')
signature = signer.sign(SHA512.new(license_data_str.encode()))
return {
'key_id': public_key['id'],
'algorithm': public_key['algorithm'],
'signature': b64encode(signature).decode(),
}

def sign_license(self, license_data, keys):
license_data_str = json.dumps(license_data)
return b64encode(json.dumps({
'data': license_data_str,
'signatures': [self.sign_license_data(license_data_str, k[0], k[1]) for k in keys],
}).encode()).decode()

def signed_license(self, **kwargs):
return self.sign_license({
'users': 10,
'valid_from': (timezone.now() - timedelta(days=30)).date().isoformat(),
'valid_until': (timezone.now() + timedelta(days=30)).date().isoformat(),
} | kwargs, [(self.license_public_key, self.license_private_key)])
return signed_license(keys=[(self.license_public_key, self.license_private_key)], **kwargs)

@pytest.mark.parametrize(('license_str', 'error'), [
(None, None),
@@ -339,3 +349,38 @@ def test_multiple_signatures_only_1_valid(self):
license_2 = b64encode(json.dumps(license_content).encode())
license_info = license.decode_and_validate_license(license_2)
assert license_info['type'] == license.LicenseType.PROFESSIONAL


@pytest.mark.django_db()
class TestLicenseActivationInfo:
@pytest.fixture(autouse=True)
def setUp(self):
self.license_private_key, self.license_public_key = generate_signing_key()
self.license_community = None
self.license_invalid = 'invalid license string'
self.license_professional = signed_license(keys=[(self.license_public_key, self.license_private_key)])
self.license_professional2 = signed_license(keys=[(self.license_public_key, self.license_private_key)], users=20)
self.license_expired = signed_license(keys=[(self.license_public_key, self.license_private_key)], valid_until=(timezone.now() - timedelta(days=1)).date().isoformat())

with mock.patch('reportcreator_api.utils.license.LICENSE_VALIDATION_KEYS', new=[self.license_public_key]), \
mock.patch('reportcreator_api.utils.license.check_license', new=lambda: license.decode_and_validate_license(settings.LICENSE)), \
mock.patch('reportcreator_api.tasks.tasks.activate_license_request', return_value={'status': 'ok', 'license_info': {'last_activation_time': timezone.now().isoformat()}}):
yield

@pytest.mark.parametrize(('name_old', 'name_new', 'created'), [
('license_community', 'license_community', False),
('license_community', 'license_invalid', False),
('license_community', 'license_professional', True),
('license_professional', 'license_professional', False),
('license_professional', 'license_professional2', True),
('license_professional', 'license_community', True),
('license_professional', 'license_expired', True),
])
def test_license_activation_info_created(self, name_old, name_new, created):
with override_settings(LICENSE=getattr(self, name_old)):
activation_info_old = LicenseActivationInfo.objects.current()
with override_settings(LICENSE=getattr(self, name_new)):
async_to_sync(activate_license)(None)
activation_info_new = LicenseActivationInfo.objects.order_by('-created').first()

assert (activation_info_old != activation_info_new) == created
1 change: 1 addition & 0 deletions api/src/reportcreator_api/users/views.py
Original file line number Diff line number Diff line change
@@ -413,6 +413,7 @@ def perform_login(self, request, user, can_reauth=True):
return Response({
'status': 'success',
'first_login': first_login,
'license': license.get_license_info(),
}, status=status.HTTP_200_OK)

@action(detail=False, url_path='login/oidc/(?P<oidc_provider>[a-zA-Z0-9]+)/begin', methods=['get'], permission_classes=[license.ProfessionalLicenseRequired])
18 changes: 18 additions & 0 deletions api/src/reportcreator_api/utils/license.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64
import hashlib
import json
import logging

@@ -123,6 +124,15 @@ def decode_and_validate_license(license, skip_db_checks=False, skip_limit_valida
}


def get_license_hash():
if not settings.LICENSE:
return None
try:
return 'sha3_256$$' + hashlib.sha3_256(base64.b64decode(settings.LICENSE)).hexdigest()
except Exception:
return None


@cache('license.license_info', timeout=10 * 60)
def check_license(**kwargs):
return decode_and_validate_license(license=settings.LICENSE, **kwargs)
@@ -134,14 +144,22 @@ async def acheck_license(**kwargs):

def get_license_info():
from reportcreator_api.conf import plugins
from reportcreator_api.tasks.models import LicenseActivationInfo
from reportcreator_api.users.models import PentestUser

activation_info = LicenseActivationInfo.objects.current()
return check_license() | {
'license_hash': get_license_hash(),
'active_users': PentestUser.objects.get_licensed_user_count(),
'total_users': PentestUser.objects.get_total_user_count(),
'installation_id': settings.INSTALLATION_ID,
'software_version': settings.VERSION,
'plugins': [p.name.split('.')[-1] for p in plugins.enabled_plugins],
'activation_info': {
'created': activation_info.created,
'license_hash': activation_info.license_hash,
'last_activation_time': activation_info.last_activation_time,
},
}

async def aget_license_info():
2 changes: 1 addition & 1 deletion plugins/webhooks/tests/test_webhooks.py
Original file line number Diff line number Diff line change
@@ -30,7 +30,7 @@ def update(obj, **kwargs):
return obj


@pytest.mark.django_db
@pytest.mark.django_db()
class TestWebhooksCalled:
@pytest.fixture(autouse=True)
def setUp(self):

0 comments on commit 0a87f39

Please sign in to comment.