diff --git a/api/src/reportcreator_api/conf/plugins.py b/api/src/reportcreator_api/conf/plugins.py index afa1461b..9c84796a 100644 --- a/api/src/reportcreator_api/conf/plugins.py +++ b/api/src/reportcreator_api/conf/plugins.py @@ -6,7 +6,9 @@ from importlib import import_module from pathlib import Path +from decouple import config from django.apps import AppConfig, apps +from django.conf import settings from django.contrib.staticfiles import finders from django.contrib.staticfiles.finders import AppDirectoriesFinder, FileSystemFinder from django.core.exceptions import ImproperlyConfigured @@ -14,6 +16,8 @@ from django.utils.functional import classproperty from django.utils.module_loading import module_has_submodule +from reportcreator_api.utils import license + enabled_plugins = [] @@ -30,6 +34,11 @@ class PluginConfig(AppConfig): The plugin_id is used internally to uniquely identify the plugin and it's resources (e.g. DB tables, API endpoints, etc.). """ + professional_only: bool = False + """ + Indicates whether the plugin is only available in SysReptor professional or also in SysReptor community edition. + """ + frontend_settings = {} def __init__(self, *args, **kwargs) -> None: @@ -83,6 +92,9 @@ def _frontend_entry(self) -> str|None: return None def get_frontend_settings(self, request) -> dict: + """ + Dictionary with settings passed to the plugin's frontend implementation + """ return self.frontend_settings @@ -133,6 +145,12 @@ def remove_entry(path: Path): path.unlink(missing_ok=True) +def can_load_professional_plugins(): + license_text = getattr(settings, 'LICENSE', config('LICENSE', default=None)) + return license.decode_and_validate_license(license=license_text, skip_db_checks=True) \ + .get('type') == license.LicenseType.PROFESSIONAL + + def collect_plugins(dst: Path, srcs: list[Path]): # Collect plugins from all plugin directories all_module_dirs = [] @@ -233,6 +251,10 @@ def load_plugins(plugin_dirs: list[Path], enabled_plugins: list[str]): # Add to installed_apps app_class = plugin_config_class.__module__ + '.' + plugin_config_class.__name__ app_label = plugin_config_class.label + + if plugin_config_class.professional_only and not can_load_professional_plugins(): + logging.warning(f'Plugin "{plugin_name}" requires a professional license. Not enabling plugin.') + continue if app_class not in installed_apps: installed_apps.append(app_class) logging.info(f'Enabling plugin {plugin_name} ({plugin_id=}, {app_label=}, {app_class=})') diff --git a/api/src/reportcreator_api/conf/settings.py b/api/src/reportcreator_api/conf/settings.py index 7e81af39..c3520ca9 100644 --- a/api/src/reportcreator_api/conf/settings.py +++ b/api/src/reportcreator_api/conf/settings.py @@ -683,7 +683,6 @@ def __bool__(self): {'id': 'silver', 'algorithm': 'ed25519', 'key': 'MCowBQYDK2VwAyEAwu/cl0CZSSBFOzFSz/hhUQQjHIKiT4RS3ekPevSKn7w='}, {'id': 'magenta', 'algorithm': 'ed25519', 'key': 'MCowBQYDK2VwAyEAd10mgfTx0fuPO6KwcYU98RLhreCF+BQCeI6CAs0YztA='}, ] -LICENSE_COMMUNITY_MAX_USERS = 3 INSTALLATION_ID_PATH = MEDIA_ROOT / 'installation_id' if not INSTALLATION_ID_PATH.exists(): diff --git a/api/src/reportcreator_api/conf/settings_test.py b/api/src/reportcreator_api/conf/settings_test.py index 4c4219c5..e5515913 100644 --- a/api/src/reportcreator_api/conf/settings_test.py +++ b/api/src/reportcreator_api/conf/settings_test.py @@ -62,13 +62,15 @@ BACKUP_KEY = 'dummy-backup-key-used-in-unit-test' -# Always enable some plugins during tests -ENABLED_PLUGINS += ['demoplugin'] -enable_test_plugins = load_plugins(PLUGIN_DIRS, ENABLED_PLUGINS) -INSTALLED_APPS += [p for p in enable_test_plugins if p not in INSTALLED_APPS] - - # Disable license check +from reportcreator_api.conf import plugins # noqa: E402 from reportcreator_api.utils import license # noqa: E402 license.check_license = lambda **kwargs: {'type': license.LicenseType.PROFESSIONAL, 'users': 1000, 'name': 'Company Name'} +plugins.can_load_professional_plugins = lambda: True + + +# Always enable some plugins during tests +ENABLED_PLUGINS += ['demoplugin'] +enable_test_plugins = load_plugins(PLUGIN_DIRS, ENABLED_PLUGINS) +INSTALLED_APPS += [p for p in enable_test_plugins if p not in INSTALLED_APPS] diff --git a/api/src/reportcreator_api/pentests/collab/consumer_base.py b/api/src/reportcreator_api/pentests/collab/consumer_base.py index 4cef63bc..57fb3528 100644 --- a/api/src/reportcreator_api/pentests/collab/consumer_base.py +++ b/api/src/reportcreator_api/pentests/collab/consumer_base.py @@ -16,6 +16,7 @@ from django.utils.crypto import get_random_string from randomcolor import RandomColor from rest_framework.exceptions import ValidationError as RestFrameworkValidationError +from uvicorn.protocols.utils import ClientDisconnected from reportcreator_api.pentests.collab.text_transformations import ( EditorSelection, @@ -58,7 +59,7 @@ async def dispatch(self, message): with history_context(history_user=self.user): await super().dispatch(message) - except StopConsumer: + except (StopConsumer, ClientDisconnected): raise except Exception as ex: log.exception(ex) diff --git a/api/src/reportcreator_api/pentests/customfields/utils.py b/api/src/reportcreator_api/pentests/customfields/utils.py index a4b3c764..4b666501 100644 --- a/api/src/reportcreator_api/pentests/customfields/utils.py +++ b/api/src/reportcreator_api/pentests/customfields/utils.py @@ -25,7 +25,7 @@ def contains(a: dict, b: dict) -> bool: if not b: return True - if type(a) != type(b): + if type(a) is not type(b): return False for k, v in b.items(): @@ -42,7 +42,7 @@ def contains(a: dict, b: dict) -> bool: def has_field_structure_changed(old: FieldDefinition|ObjectField, new: FieldDefinition|ObjectField) -> bool: - if type(old) != type(new): + if type(old) is not type(new): return True old_fields = old.field_dict diff --git a/api/src/reportcreator_api/tests/test_plugins.py b/api/src/reportcreator_api/tests/test_plugins.py index ab290e97..f5b40a6f 100644 --- a/api/src/reportcreator_api/tests/test_plugins.py +++ b/api/src/reportcreator_api/tests/test_plugins.py @@ -1,6 +1,7 @@ import io from contextlib import contextmanager from pathlib import Path +from unittest import mock import pytest from django.apps import apps @@ -11,6 +12,7 @@ from django.test import override_settings from django.urls import reverse +from reportcreator_api.conf.plugins import load_plugins from reportcreator_api.management.commands import restorebackup from reportcreator_api.tests.mock import api_client, create_user from reportcreator_api.utils.utils import omit_keys @@ -23,7 +25,7 @@ def enable_demoplugin(): # Import config to check if plugin exists try: - from sysreptor_plugins.demoplugin.app import DemoPluginConfig + from sysreptor_plugins.demoplugin.app import DemoPluginConfig # type: ignore except ImportError: pytest.skip('DemoPlugin not found') @@ -49,42 +51,52 @@ def disable_demoplugin(): def create_demopluginmodel(**kwargs): - from sysreptor_plugins.demoplugin.models import DemoPluginModel + from sysreptor_plugins.demoplugin.models import DemoPluginModel # type: ignore return DemoPluginModel.objects.create(**kwargs) @pytest.mark.django_db() -@enable_demoplugin() -def test_plugin_loading(): - # Test django app of plugin is installed - assert apps.is_installed('sysreptor_plugins.demoplugin') - assert apps.get_app_config(DEMOPLUGIN_APPLABEL) is not None - - # Models registered - model = apps.get_model(DEMOPLUGIN_APPLABEL, 'DemoPluginModel') - obj = model.objects.create(name='test') - assert model.objects.filter(pk=obj.pk).exists() - - # Static files - # Create dummy file when the frontend was not built yet - from sysreptor_plugins import demoplugin # noqa: I001 - pluginjs_path = (Path(demoplugin.__path__[0]) / 'static' / 'plugin.js').resolve() - if not pluginjs_path.exists(): - pluginjs_path.parent.mkdir(parents=True, exist_ok=True) - pluginjs_path.touch() - finders.get_finder.cache_clear() - - res = finders.find(f'plugins/{DEMOPLUGIN_ID}/plugin.js') is not None - - # URLs registered - assert api_client().get(reverse(f'{DEMOPLUGIN_APPLABEL}:helloworld')).status_code == 200 - - # Plugin config in api settings - res = api_client().get(reverse('publicutils-settings')) - assert res.status_code == 200 - demoplugin_config = next(filter(lambda p: p['id'] == DEMOPLUGIN_ID, res.data['plugins'])) - assert omit_keys(demoplugin_config, ['frontend_entry']) == {'id': DEMOPLUGIN_ID, 'name': 'demoplugin', 'frontend_settings': {}} - +class TestPluginLoading: + @enable_demoplugin() + def test_plugin_loading(self): + # Test django app of plugin is installed + assert apps.is_installed('sysreptor_plugins.demoplugin') + assert apps.get_app_config(DEMOPLUGIN_APPLABEL) is not None + + # Models registered + model = apps.get_model(DEMOPLUGIN_APPLABEL, 'DemoPluginModel') + obj = model.objects.create(name='test') + assert model.objects.filter(pk=obj.pk).exists() + + # Static files + # Create dummy file when the frontend was not built yet + from sysreptor_plugins import demoplugin # noqa: I001 + pluginjs_path = (Path(demoplugin.__path__[0]) / 'static' / 'plugin.js').resolve() + if not pluginjs_path.exists(): + pluginjs_path.parent.mkdir(parents=True, exist_ok=True) + pluginjs_path.touch() + finders.get_finder.cache_clear() + + res = finders.find(f'plugins/{DEMOPLUGIN_ID}/plugin.js') is not None + + # URLs registered + assert api_client().get(reverse(f'{DEMOPLUGIN_APPLABEL}:helloworld')).status_code == 200 + + # Plugin config in api settings + res = api_client().get(reverse('publicutils-settings')) + assert res.status_code == 200 + demoplugin_config = next(filter(lambda p: p['id'] == DEMOPLUGIN_ID, res.data['plugins'])) + assert omit_keys(demoplugin_config, ['frontend_entry']) == {'id': DEMOPLUGIN_ID, 'name': 'demoplugin', 'frontend_settings': {}} + + def test_load_professional_only(self): + from sysreptor_plugins.demoplugin.app import DemoPluginConfig # type: ignore + + try: + DemoPluginConfig.professional_only = True + with mock.patch('reportcreator_api.conf.plugins.can_load_professional_plugins', return_value=False): + assert load_plugins(plugin_dirs=settings.PLUGIN_DIRS, enabled_plugins=['demoplugin']) == [] + finally: + DemoPluginConfig.professional_only = False @pytest.mark.django_db() class TestPluginBackupRestore: diff --git a/api/src/reportcreator_api/utils/license.py b/api/src/reportcreator_api/utils/license.py index ee3e799d..0ccfcd66 100644 --- a/api/src/reportcreator_api/utils/license.py +++ b/api/src/reportcreator_api/utils/license.py @@ -11,7 +11,6 @@ from django.utils import dateparse, timezone from rest_framework import permissions -from reportcreator_api.conf import plugins from reportcreator_api.utils.decorators import cache @@ -79,8 +78,6 @@ def decode_license(license): def decode_and_validate_license(license, skip_db_checks=False, skip_limit_validation=False): - from reportcreator_api.users.models import PentestUser - try: if not license: raise LicenseError(None) @@ -96,6 +93,7 @@ def decode_and_validate_license(license, skip_db_checks=False, skip_limit_valida # Validate license limits not exceeded if not skip_db_checks: + from reportcreator_api.users.models import PentestUser current_user_count = PentestUser.objects.get_licensed_user_count() if current_user_count > license_data['users']: raise LicenseError(license_data | { @@ -115,7 +113,7 @@ def decode_and_validate_license(license, skip_db_checks=False, skip_limit_valida error_details = ex.detail if isinstance(ex.detail, dict) else {'error': ex.detail} return error_details | { 'type': LicenseType.COMMUNITY, - 'users': settings.LICENSE_COMMUNITY_MAX_USERS, + 'users': 3, } @@ -129,6 +127,7 @@ async def acheck_license(**kwargs): def get_license_info(): + from reportcreator_api.conf import plugins from reportcreator_api.users.models import PentestUser return check_license() | { diff --git a/docs/docs/setup/plugins.md b/docs/docs/setup/plugins.md index 0a977323..785c4e78 100644 --- a/docs/docs/setup/plugins.md +++ b/docs/docs/setup/plugins.md @@ -11,15 +11,15 @@ All plugins are disabled by default. To enable a plugin, add it to the [`ENABLED Official plugins are maintained by the SysReptor team and are included in official docker images. -| Plugin | Description | -| ------ | ----------- | -| [cyberchef](https://github.com/Syslifters/sysreptor/tree/main/plugins/cyberchef) | CyberChef integration | -| [graphqlvoyager](https://github.com/Syslifters/sysreptor/tree/main/plugins/graphqlvoyager) | GraphQL Voyager integration | -| [checkthehash](https://github.com/Syslifters/sysreptor/tree/main/plugins/checkthehash) | Hash identifier | -| [customizetheme](https://github.com/Syslifters/sysreptor/tree/main/plugins/customizetheme) | Customize UI themes per instance | -| [demoplugin](https://github.com/Syslifters/sysreptor/tree/main/plugins/demoplugin) | A demo plugin that demonstrates the plugin system | -| [projectnumber](https://github.com/Syslifters/sysreptor/tree/main/plugins/projectnumber) | Automatically adds an incremental project number to new projects | -| [webhooks](https://github.com/Syslifters/sysreptor/tree/main/plugins/webhooks) | Send webhooks on certain events | +| Plugin | Description | | +| ------ | ----------- | --- | +| [cyberchef](https://github.com/Syslifters/sysreptor/tree/main/plugins/cyberchef) | CyberChef integration | | +| [graphqlvoyager](https://github.com/Syslifters/sysreptor/tree/main/plugins/graphqlvoyager) | GraphQL Voyager integration | | +| [checkthehash](https://github.com/Syslifters/sysreptor/tree/main/plugins/checkthehash) | Hash identifier | | +| [customizetheme](https://github.com/Syslifters/sysreptor/tree/main/plugins/customizetheme) | Customize UI themes per instance | | +| [demoplugin](https://github.com/Syslifters/sysreptor/tree/main/plugins/demoplugin) | A demo plugin that demonstrates the plugin system | | +| [projectnumber](https://github.com/Syslifters/sysreptor/tree/main/plugins/projectnumber) | Automatically adds an incremental project number to new projects | | +| [webhooks](https://github.com/Syslifters/sysreptor/tree/main/plugins/webhooks) | Send webhooks on certain events | :octicons-heart-fill-24: Pro only |