diff --git a/api/pyproject.toml b/api/pyproject.toml index 9d5f55e6c..7c8e2d0e3 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -82,7 +82,8 @@ extend-select = [ ignore = [ "E741", # Ambiguous variable name: `l` "DJ001", # Avoid using null=True on string-based fields in django models - "PT004", # Fixture setUp does not return anything, add leading underscore "S101", # Use of `assert` detected ] +[tool.ruff.lint.flake8-pytest-style] +mark-parentheses = true \ No newline at end of file diff --git a/api/src/reportcreator_api/api_utils/views.py b/api/src/reportcreator_api/api_utils/views.py index affde5b73..1af8da653 100644 --- a/api/src/reportcreator_api/api_utils/views.py +++ b/api/src/reportcreator_api/api_utils/views.py @@ -16,7 +16,7 @@ from reportcreator_api.api_utils import backup_utils from reportcreator_api.api_utils.healthchecks import run_healthchecks from reportcreator_api.api_utils.models import BackupLog -from reportcreator_api.api_utils.permissions import IsAdminOrSystem, IsUserManagerOrSuperuserOrSystem +from reportcreator_api.api_utils.permissions import IsAdminOrSystem from reportcreator_api.api_utils.serializers import ( BackupLogSerializer, BackupSerializer, @@ -115,7 +115,7 @@ def backuplogs(self, request, *args, **kwargs): return self.get_paginated_response(serializer.data) @extend_schema(responses=OpenApiTypes.OBJECT) - @action(detail=False, url_name='license', url_path='license', methods=['get'], permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES + [IsUserManagerOrSuperuserOrSystem]) + @action(detail=False, url_name='license', url_path='license', methods=['get']) async def license_info(self, request, *args, **kwargs): return Response(data=await license.aget_license_info()) diff --git a/api/src/reportcreator_api/conf/settings.py b/api/src/reportcreator_api/conf/settings.py index ff4dc8bcd..da15c211b 100644 --- a/api/src/reportcreator_api/conf/settings.py +++ b/api/src/reportcreator_api/conf/settings.py @@ -428,7 +428,7 @@ def remove_empty_items(lst=None): 'img-src': [SELF, 'data:'], 'font-src': [SELF], 'worker-src': [SELF], - 'connect-src': [SELF, 'data:'], + 'connect-src': [SELF, 'data:', 'https://portal.sysreptor.com'], 'frame-src': [SELF], 'frame-ancestors': [SELF], 'form-action': [SELF], diff --git a/api/src/reportcreator_api/tasks/migrations/0002_licenseactivationinfo.py b/api/src/reportcreator_api/tasks/migrations/0002_licenseactivationinfo.py new file mode 100644 index 000000000..db8a78d02 --- /dev/null +++ b/api/src/reportcreator_api/tasks/migrations/0002_licenseactivationinfo.py @@ -0,0 +1,30 @@ +# Generated by Django 5.1.4 on 2025-01-09 09:24 + +import reportcreator_api.utils.models +import uuid +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('tasks', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='LicenseActivationInfo', + fields=[ + ('id', models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)), + ('created', models.DateTimeField(default=reportcreator_api.utils.models.now, editable=False)), + ('updated', models.DateTimeField(auto_now=True)), + ('license_type', models.CharField(choices=[('community', 'Community'), ('professional', 'Professional')], max_length=255)), + ('license_hash', models.CharField(max_length=255, null=True)), + ('last_activation_time', models.DateTimeField(blank=True, null=True)), + ], + options={ + 'ordering': ['-created'], + 'abstract': False, + }, + ), + ] diff --git a/api/src/reportcreator_api/tasks/models.py b/api/src/reportcreator_api/tasks/models.py index 934fa6c45..de8f293ce 100644 --- a/api/src/reportcreator_api/tasks/models.py +++ b/api/src/reportcreator_api/tasks/models.py @@ -5,6 +5,7 @@ from django.utils import timezone from reportcreator_api.tasks import querysets +from reportcreator_api.utils import license from reportcreator_api.utils.models import BaseModel @@ -29,6 +30,8 @@ class PeriodicTaskSpec: id: str schedule: timedelta func: callable + retry: timedelta = timedelta(minutes=10) + max_runtime: timedelta = timedelta(minutes=10) @dataclasses.dataclass() @@ -59,12 +62,21 @@ def unregister(self, task: PeriodicTaskSpec): periodic_task_registry = PeriodicTaskRegistry() -def periodic_task(schedule: timedelta, id: str|None = None): +def periodic_task(schedule: timedelta, id: str|None = None, retry: timedelta|None = None): def inner(func): periodic_task_registry.register(PeriodicTaskSpec( id=id or f'{func.__module__}.{func.__name__}', schedule=schedule, + retry=retry or (schedule / 10), func=func, )) return func return inner + + +class LicenseActivationInfo(BaseModel): + license_type = models.CharField(max_length=255, choices=license.LicenseType.choices) + license_hash = models.CharField(max_length=255, null=True) + last_activation_time = models.DateTimeField(null=True, blank=True) + + objects = querysets.LicenseActivationInfoManager() diff --git a/api/src/reportcreator_api/tasks/querysets.py b/api/src/reportcreator_api/tasks/querysets.py index d9af1d9e9..088b99b0b 100644 --- a/api/src/reportcreator_api/tasks/querysets.py +++ b/api/src/reportcreator_api/tasks/querysets.py @@ -1,11 +1,12 @@ import logging -from datetime import timedelta import elasticapm from asgiref.sync import iscoroutinefunction, sync_to_async from django.db import IntegrityError, models from django.utils import timezone +from reportcreator_api.utils import license + log = logging.getLogger(__name__) @@ -19,8 +20,8 @@ async def get_pending_tasks(self): model = task_models.get(t_id) # Remove non-pending tasks if model and ( - (model.status == TaskStatus.RUNNING and model.started > timezone.now() - timedelta(minutes=10)) or \ - (model.status == TaskStatus.FAILED and model.started > timezone.now() - timedelta(minutes=10)) or \ + (model.status == TaskStatus.RUNNING and model.started > timezone.now() - spec.max_runtime) or \ + (model.status == TaskStatus.FAILED and model.started > timezone.now() - spec.retry) or \ (model.status == TaskStatus.SUCCESS and model.started > timezone.now() - spec.schedule) ): continue @@ -83,3 +84,19 @@ async def run_task(self, task_info): async def run_all_pending_tasks(self): for t in await self.get_pending_tasks(): await self.run_task(t) + + +class LicenseActivationInfoManager(models.Manager.from_queryset(models.QuerySet)): + def current(self): + obj = self.order_by('-created').first() + + if not license.is_professional(): + if not obj or obj.license_type != license.LicenseType.COMMUNITY: + obj = self.create(license_type=license.LicenseType.COMMUNITY) + else: + if not obj or obj.license_type != license.LicenseType.PROFESSIONAL or obj.license_hash != license.get_license_hash(): + obj = self.create( + license_type=license.LicenseType.PROFESSIONAL, + license_hash=license.get_license_hash(), + ) + return obj diff --git a/api/src/reportcreator_api/tasks/tasks.py b/api/src/reportcreator_api/tasks/tasks.py index cc1d1da95..c84ee663a 100644 --- a/api/src/reportcreator_api/tasks/tasks.py +++ b/api/src/reportcreator_api/tasks/tasks.py @@ -1,10 +1,47 @@ +import json +import logging from datetime import timedelta +import httpx +from asgiref.sync import sync_to_async from django.core.management import call_command +from django.core.serializers.json import DjangoJSONEncoder +from django.utils import dateparse -from reportcreator_api.tasks.models import periodic_task +from reportcreator_api.tasks.models import LicenseActivationInfo, TaskStatus, periodic_task +from reportcreator_api.utils import license @periodic_task(id='clear_sessions', schedule=timedelta(days=1)) def clear_sessions(task_info): call_command('clearsessions') + + +async def activate_license_request(license_info): + async with httpx.AsyncClient(timeout=10) as client: + res = await client.post( + url='https://portal.sysreptor.com/api/v1/licenses/activate/', + headers={'Content-Type': 'application/json'}, + data=json.dumps(license_info, cls=DjangoJSONEncoder), + ) + res.raise_for_status() + return res.json() + + +@periodic_task(id='activate_license', schedule=timedelta(days=1)) +async def activate_license(task_info): + activation_info = await sync_to_async(LicenseActivationInfo.objects.current)() + if not await license.ais_professional(): + return TaskStatus.SUCCESS + + try: + res = await activate_license_request(await license.aget_license_info()) + if res.get('status') == 'ok': + try: + activation_info.last_activation_time = dateparse.parse_datetime(res.get('license_info', {}).get('last_activation_time')) + except (TypeError, ValueError): + activation_info.last_activation_time = None + await activation_info.asave() + except httpx.TransportError as ex: + logging.warning(f'Failed to activate license: {ex}. Check your internet connection.') + return TaskStatus.FAILED diff --git a/api/src/reportcreator_api/tests/test_api.py b/api/src/reportcreator_api/tests/test_api.py index 5515d3545..a6d56a6ad 100644 --- a/api/src/reportcreator_api/tests/test_api.py +++ b/api/src/reportcreator_api/tests/test_api.py @@ -218,6 +218,7 @@ def guest_urls(): return [ ('utils list', lambda s, c: c.get(reverse('utils-list'))), ('utils cwes', lambda s, c: c.get(reverse('utils-cwes'))), + ('utils-license', lambda s, c: c.get(reverse('utils-license'))), *viewset_urls('pentestuser', get_kwargs=lambda s, detail: {'pk': 'self'}, retrieve=True, update=True, update_partial=True), *viewset_urls('pentestuser', get_kwargs=lambda s, detail: {}, list=True), @@ -304,7 +305,6 @@ def user_manager_urls(): *viewset_urls('authidentity', get_kwargs=lambda s, detail: {'pentestuser_pk': s.user_other.pk} | ({'pk': s.user_other.auth_identities.first().pk} if detail else {}), list=True, retrieve=True, create=True, create_data={'identifier': 'other.identifier'}, update=True, update_partial=True, destroy=True), *viewset_urls('apitoken', get_kwargs=lambda s, detail: {'pentestuser_pk': s.user_other.pk} | ({'pk': s.user_other.api_tokens.first().pk} if detail else {}), list=True, retrieve=True, destroy=True), *viewset_urls('userpublickey', get_kwargs=lambda s, detail: {'pentestuser_pk': s.user_other.pk} | ({'pk': s.user_other.public_keys.first().pk} if detail else {}), list=True, retrieve=True), - ('utils-license', lambda s, c: c.get(reverse('utils-license'))), ] @@ -508,7 +508,8 @@ async def mock_render_pdf(*args, output=None, **kwargs): 'label': 'Dummy', }, }, - ), mock.patch('reportcreator_api.tasks.rendering.render.render_pdf_impl', mock_render_pdf): + ), mock.patch('reportcreator_api.tasks.rendering.render.render_pdf_impl', mock_render_pdf), \ + mock.patch('reportcreator_api.tasks.tasks.activate_license', return_value=None): user_map = { 'public': lambda: None, 'guest': lambda: ApiRequestsAndPermissionsTestData.create_user(is_guest=True), diff --git a/api/src/reportcreator_api/tests/test_license.py b/api/src/reportcreator_api/tests/test_license.py index c747a9d51..982eb3fa9 100644 --- a/api/src/reportcreator_api/tests/test_license.py +++ b/api/src/reportcreator_api/tests/test_license.py @@ -5,6 +5,7 @@ from uuid import uuid4 import pytest +from asgiref.sync import async_to_sync from Cryptodome.Hash import SHA512 from Cryptodome.PublicKey import ECC from Cryptodome.Signature import eddsa @@ -16,6 +17,8 @@ from rest_framework import status from rest_framework.test import APIClient +from reportcreator_api.tasks.models import LicenseActivationInfo +from reportcreator_api.tasks.tasks import activate_license from reportcreator_api.tests.mock import ( api_client, create_project, @@ -33,6 +36,42 @@ def assert_api_license_error(res): assert res.data['code'] == 'license' +def generate_signing_key(): + private_key = ECC.generate(curve='ed25519') + public_key = { + 'id': str(uuid4()), + 'algorithm': 'ed25519', + 'key': b64encode(private_key.public_key().export_key(format='DER')).decode(), + } + return private_key, public_key + + +def sign_license_data(license_data_str: str, public_key: dict, private_key): + signer = eddsa.new(key=private_key, mode='rfc8032') + signature = signer.sign(SHA512.new(license_data_str.encode())) + return { + 'key_id': public_key['id'], + 'algorithm': public_key['algorithm'], + 'signature': b64encode(signature).decode(), + } + + +def sign_license(license_data, keys): + license_data_str = json.dumps(license_data) + return b64encode(json.dumps({ + 'data': license_data_str, + 'signatures': [sign_license_data(license_data_str, k[0], k[1]) for k in keys], + }).encode()).decode() + + +def signed_license(keys, **kwargs): + return sign_license({ + 'users': 10, + 'valid_from': (timezone.now() - timedelta(days=30)).date().isoformat(), + 'valid_until': (timezone.now() + timedelta(days=30)).date().isoformat(), + } | kwargs, keys) + + @pytest.mark.django_db() class TestCommunityLicenseRestrictions: @pytest.fixture(autouse=True) @@ -235,41 +274,12 @@ def test_user_count_limit(self): class TestLicenseValidation: @pytest.fixture(autouse=True) def setUp(self): - self.license_private_key, self.license_public_key = self.generate_signing_key() + self.license_private_key, self.license_public_key = generate_signing_key() with mock.patch('reportcreator_api.utils.license.LICENSE_VALIDATION_KEYS', new=[self.license_public_key]): yield - def generate_signing_key(self): - private_key = ECC.generate(curve='ed25519') - public_key = { - 'id': str(uuid4()), - 'algorithm': 'ed25519', - 'key': b64encode(private_key.public_key().export_key(format='DER')).decode(), - } - return private_key, public_key - - def sign_license_data(self, license_data_str: str, public_key: dict, private_key): - signer = eddsa.new(key=private_key, mode='rfc8032') - signature = signer.sign(SHA512.new(license_data_str.encode())) - return { - 'key_id': public_key['id'], - 'algorithm': public_key['algorithm'], - 'signature': b64encode(signature).decode(), - } - - def sign_license(self, license_data, keys): - license_data_str = json.dumps(license_data) - return b64encode(json.dumps({ - 'data': license_data_str, - 'signatures': [self.sign_license_data(license_data_str, k[0], k[1]) for k in keys], - }).encode()).decode() - def signed_license(self, **kwargs): - return self.sign_license({ - 'users': 10, - 'valid_from': (timezone.now() - timedelta(days=30)).date().isoformat(), - 'valid_until': (timezone.now() + timedelta(days=30)).date().isoformat(), - } | kwargs, [(self.license_public_key, self.license_private_key)]) + return signed_license(keys=[(self.license_public_key, self.license_private_key)], **kwargs) @pytest.mark.parametrize(('license_str', 'error'), [ (None, None), @@ -339,3 +349,38 @@ def test_multiple_signatures_only_1_valid(self): license_2 = b64encode(json.dumps(license_content).encode()) license_info = license.decode_and_validate_license(license_2) assert license_info['type'] == license.LicenseType.PROFESSIONAL + + +@pytest.mark.django_db() +class TestLicenseActivationInfo: + @pytest.fixture(autouse=True) + def setUp(self): + self.license_private_key, self.license_public_key = generate_signing_key() + self.license_community = None + self.license_invalid = 'invalid license string' + self.license_professional = signed_license(keys=[(self.license_public_key, self.license_private_key)]) + self.license_professional2 = signed_license(keys=[(self.license_public_key, self.license_private_key)], users=20) + self.license_expired = signed_license(keys=[(self.license_public_key, self.license_private_key)], valid_until=(timezone.now() - timedelta(days=1)).date().isoformat()) + + with mock.patch('reportcreator_api.utils.license.LICENSE_VALIDATION_KEYS', new=[self.license_public_key]), \ + mock.patch('reportcreator_api.utils.license.check_license', new=lambda: license.decode_and_validate_license(settings.LICENSE)), \ + mock.patch('reportcreator_api.tasks.tasks.activate_license_request', return_value={'status': 'ok', 'license_info': {'last_activation_time': timezone.now().isoformat()}}): + yield + + @pytest.mark.parametrize(('name_old', 'name_new', 'created'), [ + ('license_community', 'license_community', False), + ('license_community', 'license_invalid', False), + ('license_community', 'license_professional', True), + ('license_professional', 'license_professional', False), + ('license_professional', 'license_professional2', True), + ('license_professional', 'license_community', True), + ('license_professional', 'license_expired', True), + ]) + def test_license_activation_info_created(self, name_old, name_new, created): + with override_settings(LICENSE=getattr(self, name_old)): + activation_info_old = LicenseActivationInfo.objects.current() + with override_settings(LICENSE=getattr(self, name_new)): + async_to_sync(activate_license)(None) + activation_info_new = LicenseActivationInfo.objects.order_by('-created').first() + + assert (activation_info_old != activation_info_new) == created diff --git a/api/src/reportcreator_api/users/views.py b/api/src/reportcreator_api/users/views.py index 065528755..b0d7f8bfe 100644 --- a/api/src/reportcreator_api/users/views.py +++ b/api/src/reportcreator_api/users/views.py @@ -413,6 +413,7 @@ def perform_login(self, request, user, can_reauth=True): return Response({ 'status': 'success', 'first_login': first_login, + 'license': license.get_license_info(), }, status=status.HTTP_200_OK) @action(detail=False, url_path='login/oidc/(?P[a-zA-Z0-9]+)/begin', methods=['get'], permission_classes=[license.ProfessionalLicenseRequired]) diff --git a/api/src/reportcreator_api/utils/license.py b/api/src/reportcreator_api/utils/license.py index 1467712fa..4a95c1de2 100644 --- a/api/src/reportcreator_api/utils/license.py +++ b/api/src/reportcreator_api/utils/license.py @@ -1,4 +1,5 @@ import base64 +import hashlib import json import logging @@ -123,6 +124,15 @@ def decode_and_validate_license(license, skip_db_checks=False, skip_limit_valida } +def get_license_hash(): + if not settings.LICENSE: + return None + try: + return 'sha3_256$$' + hashlib.sha3_256(base64.b64decode(settings.LICENSE)).hexdigest() + except Exception: + return None + + @cache('license.license_info', timeout=10 * 60) def check_license(**kwargs): return decode_and_validate_license(license=settings.LICENSE, **kwargs) @@ -134,14 +144,22 @@ async def acheck_license(**kwargs): def get_license_info(): from reportcreator_api.conf import plugins + from reportcreator_api.tasks.models import LicenseActivationInfo from reportcreator_api.users.models import PentestUser + activation_info = LicenseActivationInfo.objects.current() return check_license() | { + 'license_hash': get_license_hash(), 'active_users': PentestUser.objects.get_licensed_user_count(), 'total_users': PentestUser.objects.get_total_user_count(), 'installation_id': settings.INSTALLATION_ID, 'software_version': settings.VERSION, 'plugins': [p.name.split('.')[-1] for p in plugins.enabled_plugins], + 'activation_info': { + 'created': activation_info.created, + 'license_hash': activation_info.license_hash, + 'last_activation_time': activation_info.last_activation_time, + }, } async def aget_license_info(): diff --git a/packages/frontend/src/components/LicenseInfoMenuItem.vue b/packages/frontend/src/components/LicenseInfoMenuItem.vue index 2fd5ea238..78f0e5262 100644 --- a/packages/frontend/src/components/LicenseInfoMenuItem.vue +++ b/packages/frontend/src/components/LicenseInfoMenuItem.vue @@ -1,5 +1,5 @@ - + + + + + + + diff --git a/packages/frontend/src/pages/login/auto.vue b/packages/frontend/src/pages/login/auto.vue index 3e1fa8882..cd21dff64 100644 --- a/packages/frontend/src/pages/login/auto.vue +++ b/packages/frontend/src/pages/login/auto.vue @@ -55,7 +55,7 @@ const { error } = useAsyncData(async () => { } else if (res.status !== LoginResponseStatus.SUCCESS) { throw new Error(`Login failed: ${res.status}`); } - await auth.fetchUser(); + await auth.finishLogin(res); await auth.redirect(); } catch (error: any) { if (error?.data?.detail) { diff --git a/packages/frontend/src/pages/login/oidc/[authProviderId]/callback.vue b/packages/frontend/src/pages/login/oidc/[authProviderId]/callback.vue index 7af50c722..7c9a275fd 100644 --- a/packages/frontend/src/pages/login/oidc/[authProviderId]/callback.vue +++ b/packages/frontend/src/pages/login/oidc/[authProviderId]/callback.vue @@ -44,11 +44,11 @@ const route = useRoute(); const auth = useAuth(); const { error } = useAsyncData(async () => { try { - await $fetch(`/api/v1/auth/login/oidc/${route.params.authProviderId}/complete/`, { + const res = await $fetch(`/api/v1/auth/login/oidc/${route.params.authProviderId}/complete/`, { method: 'GET', params: route.query, }); - await auth.fetchUser(); + await auth.finishLogin(res); await auth.redirect(); } catch (error: any) { if (error?.data?.detail) { diff --git a/packages/nuxt-base-layer/src/composables/apisettings.ts b/packages/nuxt-base-layer/src/composables/apisettings.ts index 13d4d4333..b40456bc6 100644 --- a/packages/nuxt-base-layer/src/composables/apisettings.ts +++ b/packages/nuxt-base-layer/src/composables/apisettings.ts @@ -1,4 +1,4 @@ -import { type ApiSettings, type AuthProvider, type CWE, AuthProviderType } from '#imports'; +import { type ApiSettings, type AuthProvider, type CWE, type LicenseInfoDetails, AuthProviderType } from '#imports'; export const useApiSettings = defineStore('apisettings', { state: () => ({ @@ -6,6 +6,8 @@ export const useApiSettings = defineStore('apisettings', { getSettingsSync: null as Promise | null, cwes: null as CWE[]|null, getCwesSync: null as Promise|null, + licenseInfo: null as LicenseInfoDetails|null, + getLicenseInfoSync: null as Promise|null, }), actions: { async fetchSettings() : Promise { @@ -47,6 +49,25 @@ export const useApiSettings = defineStore('apisettings', { this.getCwesSync = null; } } + }, + async fetchLicenseInfo(): Promise { + this.licenseInfo = await $fetch('/api/v1/utils/license/', { method: 'GET' }); + return this.licenseInfo!; + }, + async getLicenseInfo(): Promise { + if (this.licenseInfo) { + return this.licenseInfo; + } else if (this.getLicenseInfoSync) { + return await this.getLicenseInfoSync; + } else { + try { + this.getLicenseInfoSync = this.fetchLicenseInfo(); + return await this.getLicenseInfoSync; + } finally { + this.getLicenseInfoSync = null; + } + } + } }, getters: { diff --git a/packages/nuxt-base-layer/src/composables/auth.ts b/packages/nuxt-base-layer/src/composables/auth.ts index 5cc18c1e2..4040179b1 100644 --- a/packages/nuxt-base-layer/src/composables/auth.ts +++ b/packages/nuxt-base-layer/src/composables/auth.ts @@ -1,6 +1,6 @@ import type { NavigateToOptions } from '#app/composables/router' import type { LocationQueryValue } from "#vue-router"; -import { AuthProviderType, type AuthProvider, type User, useApiSettings } from "#imports"; +import { AuthProviderType, type AuthProvider, type User, useApiSettings, type LoginResponse, LoginResponseStatus } from "#imports"; export const useAuthStore = defineStore('auth', { state: () => ({ @@ -24,7 +24,6 @@ export const useAuthStore = defineStore('auth', { edit_projects: (state.user && (!state.user.is_guest || state.user.scope.includes('project_admin') || (state.user.is_guest && apiSettings.settings?.guest_permissions.edit_projects))) || false, share_notes: (apiSettings.settings?.features.sharing && state.user && (!state.user.is_guest || state.user.scope.includes('project_admin') || (state.user.is_guest && apiSettings.settings?.guest_permissions.share_notes))) || false, archive_projects: (apiSettings.settings?.features.archiving && state.user && (!state.user.is_guest || state.user.scope.includes('admin') || (state.user.is_guest && apiSettings.settings?.guest_permissions.update_project_settings))) || false, - view_license: state.user?.is_superuser || state.user?.is_user_manager || state.user?.is_system_user || false, view_backup: (apiSettings.isProfessionalLicense && state.user?.scope.includes('admin')) || false, }; }, @@ -100,16 +99,26 @@ export function useAuth() { store.user = null; } + async function finishLogin(response: LoginResponse) { + if (response.status !== LoginResponseStatus.SUCCESS) { + throw new Error('Login failed'); + } + const apiSettings = useApiSettings(); + apiSettings.licenseInfo = response.license!; + + return await fetchUser(); + } + async function authProviderLoginBegin(authProvider: AuthProvider, options = { reauth: false }) { if (authProvider.type === AuthProviderType.LOCAL) { await navigateTo('/login/local/'); } else if (authProvider.type === AuthProviderType.REMOTEUSER) { try { - await $fetch('/api/v1/auth/login/remoteuser/', { + const res = await $fetch('/api/v1/auth/login/remoteuser/', { method: 'POST', body: {} }); - await fetchUser(); + await finishLogin(res); await redirect(); } catch (error) { requestErrorToast({ error, message: 'Login failed' }); @@ -133,5 +142,6 @@ export function useAuth() { redirectToReAuth, fetchUser, authProviderLoginBegin, + finishLogin, }; } diff --git a/packages/nuxt-base-layer/src/utils/types.ts b/packages/nuxt-base-layer/src/utils/types.ts index 371fc5bea..517468b39 100644 --- a/packages/nuxt-base-layer/src/utils/types.ts +++ b/packages/nuxt-base-layer/src/utils/types.ts @@ -204,6 +204,7 @@ export type LoginResponse = { status: LoginResponseStatus, first_login?: boolean, mfa?: MfaMethod[], + license?: LicenseInfoDetails, } export type UserNotification = BaseModel & { diff --git a/plugins/webhooks/tests/test_webhooks.py b/plugins/webhooks/tests/test_webhooks.py index c2b1b2124..872e0ae99 100644 --- a/plugins/webhooks/tests/test_webhooks.py +++ b/plugins/webhooks/tests/test_webhooks.py @@ -30,7 +30,7 @@ def update(obj, **kwargs): return obj -@pytest.mark.django_db +@pytest.mark.django_db() class TestWebhooksCalled: @pytest.fixture(autouse=True) def setUp(self):