From 4e2c73458993e4f93dd0ad787306230f8fd1e0f5 Mon Sep 17 00:00:00 2001 From: "Terence D. Honles" Date: Tue, 14 Aug 2018 16:02:09 -0700 Subject: [PATCH] allow swagger and redoc settings to be overridden When using site middleware there is no way to configure swagger instances to have different login URLs or authentication mechanisms. This change updates any global references to app settings to prefer a local setting if it exists, and for backward compatibility it refers to the global setting by default. The following calls will need to be adjusted to account for the new swagger_settings parameter that is passed: - ``drf_yasg.utils.get_produces(...)`` - ViewInspector class (and user defined subclasses) - Generator class (and user defined subclasses) --- src/drf_yasg/generators.py | 17 +++++--- src/drf_yasg/inspectors/__init__.py | 6 --- src/drf_yasg/inspectors/base.py | 45 ++++++++++++++++---- src/drf_yasg/inspectors/view.py | 6 ++- src/drf_yasg/renderers.py | 66 +++++++++++++++-------------- src/drf_yasg/utils.py | 4 +- src/drf_yasg/views.py | 39 ++++++++++++++--- tests/test_reference_schema.py | 3 +- 8 files changed, 120 insertions(+), 66 deletions(-) diff --git a/src/drf_yasg/generators.py b/src/drf_yasg/generators.py index 80489dee..11e8f9a7 100644 --- a/src/drf_yasg/generators.py +++ b/src/drf_yasg/generators.py @@ -14,7 +14,7 @@ from rest_framework.settings import api_settings from . import openapi -from .app_settings import swagger_settings +from .app_settings import swagger_settings as _swagger_settings from .errors import SwaggerGenerationError from .inspectors.field import get_basic_type_info, get_queryset_field, get_queryset_from_view from .openapi import ReferenceResolver, SwaggerDict @@ -190,7 +190,7 @@ class OpenAPISchemaGenerator: 'delete': 'destroy', } - def __init__(self, info, version='', url=None, patterns=None, urlconf=None): + def __init__(self, info, version='', url=None, patterns=None, urlconf=None, swagger_settings=_swagger_settings): """ :param openapi.Info info: information about the API @@ -206,9 +206,11 @@ def __init__(self, info, version='', url=None, patterns=None, urlconf=None): :param patterns: if given, only these patterns will be enumerated for inclusion in the API spec :param urlconf: if patterns is not given, use this urlconf to enumerate patterns; if not given, the default urlconf is used + :param swagger_settings: if given global swagger_settings are overridden with local settings """ self._gen = SchemaGenerator(info.title, url, info.get('description', ''), patterns, urlconf) self.info = info + self.swagger_settings = swagger_settings self.version = version self.consumes = [] self.produces = [] @@ -237,7 +239,7 @@ def get_security_definitions(self): :return: the security schemes usable with this API :rtype: dict[str,dict] or None """ - security_definitions = swagger_settings.SECURITY_DEFINITIONS + security_definitions = self.swagger_settings.SECURITY_DEFINITIONS if security_definitions is not None: security_definitions = SwaggerDict._as_odict(security_definitions, {}) @@ -251,7 +253,7 @@ def get_security_requirements(self, security_definitions): :return: the security schemes accepted by default :rtype: list[dict[str,list[str]]] or None """ - security_requirements = swagger_settings.SECURITY_REQUIREMENTS + security_requirements = self.swagger_settings.SECURITY_REQUIREMENTS if security_requirements is None: security_requirements = [{security_scheme: []} for security_scheme in security_definitions] @@ -272,7 +274,7 @@ def get_schema(self, request=None, public=False): endpoints = self.get_endpoints(request) components = self.reference_resolver_class(openapi.SCHEMA_DEFINITIONS, force_init=True) self.consumes = get_consumes(api_settings.DEFAULT_PARSER_CLASSES) - self.produces = get_produces(api_settings.DEFAULT_RENDERER_CLASSES) + self.produces = get_produces(api_settings.DEFAULT_RENDERER_CLASSES, self.swagger_settings) paths, prefix = self.get_paths(endpoints, components, request, public) security_definitions = self.get_security_definitions() @@ -511,7 +513,7 @@ def get_operation(self, view, path, prefix, method, components, request): # the inspector class can be specified, in decreasing order of priority, # 1. globally via DEFAULT_AUTO_SCHEMA_CLASS - view_inspector_cls = swagger_settings.DEFAULT_AUTO_SCHEMA_CLASS + view_inspector_cls = self.swagger_settings.DEFAULT_AUTO_SCHEMA_CLASS # 2. on the view/viewset class view_inspector_cls = getattr(view, 'swagger_schema', view_inspector_cls) # 3. on the swagger_auto_schema decorator @@ -520,7 +522,8 @@ def get_operation(self, view, path, prefix, method, components, request): if view_inspector_cls is None: return None - view_inspector = view_inspector_cls(view, path, method, components, request, overrides, operation_keys) + view_inspector = view_inspector_cls( + view, path, method, components, request, overrides, operation_keys, self.swagger_settings) operation = view_inspector.get_operation(operation_keys) if operation is None: return None diff --git a/src/drf_yasg/inspectors/__init__.py b/src/drf_yasg/inspectors/__init__.py index 70bde798..7bade5c9 100644 --- a/src/drf_yasg/inspectors/__init__.py +++ b/src/drf_yasg/inspectors/__init__.py @@ -1,4 +1,3 @@ -from ..app_settings import swagger_settings from .base import ( BaseInspector, FieldInspector, FilterInspector, NotHandled, PaginatorInspector, SerializerInspector, ViewInspector ) @@ -10,11 +9,6 @@ from .query import DrfAPICompatInspector, CoreAPICompatInspector, DjangoRestResponsePagination from .view import SwaggerAutoSchema -# these settings must be accessed only after defining/importing all the classes in this module to avoid ImportErrors -ViewInspector.field_inspectors = swagger_settings.DEFAULT_FIELD_INSPECTORS -ViewInspector.filter_inspectors = swagger_settings.DEFAULT_FILTER_INSPECTORS -ViewInspector.paginator_inspectors = swagger_settings.DEFAULT_PAGINATOR_INSPECTORS - __all__ = [ # base inspectors 'BaseInspector', 'FilterInspector', 'PaginatorInspector', 'FieldInspector', 'SerializerInspector', 'ViewInspector', diff --git a/src/drf_yasg/inspectors/base.py b/src/drf_yasg/inspectors/base.py index 2660fb75..d0bdcb42 100644 --- a/src/drf_yasg/inspectors/base.py +++ b/src/drf_yasg/inspectors/base.py @@ -4,6 +4,7 @@ from rest_framework import serializers from .. import openapi +from ..app_settings import swagger_settings as _swagger_settings from ..utils import force_real_str, get_field_default, get_object_classes, is_list_view #: Sentinel value that inspectors must return to signal that they do not know how to handle an object @@ -52,19 +53,21 @@ def call_view_method(view, method_name, fallback_attr=None, default=None): class BaseInspector: - def __init__(self, view, path, method, components, request): + def __init__(self, view, path, method, components, request, swagger_settings=_swagger_settings): """ :param rest_framework.views.APIView view: the view associated with this endpoint :param str path: the path component of the operation URL :param str method: the http method of the operation :param openapi.ReferenceResolver components: referenceable components :param rest_framework.request.Request request: the request made against the schema view; can be None + :param swagger_settings: if given global swagger_settings are overridden with local settings """ self.view = view self.path = path self.method = method self.components = components self.request = request + self.swagger_settings = swagger_settings def process_result(self, result, method_name, obj, **kwargs): """After an inspector handles an object (i.e. returns a value other than :data:`.NotHandled`), all inspectors @@ -185,8 +188,8 @@ def get_filter_parameters(self, filter_backend): class FieldInspector(BaseInspector): """Base inspector for serializers and serializer fields. """ - def __init__(self, view, path, method, components, request, field_inspectors): - super(FieldInspector, self).__init__(view, path, method, components, request) + def __init__(self, view, path, method, components, request, field_inspectors, swagger_settings=_swagger_settings): + super(FieldInspector, self).__init__(view, path, method, components, request, swagger_settings) self.field_inspectors = field_inspectors def add_manual_fields(self, serializer_or_field, schema): @@ -338,18 +341,42 @@ class ViewInspector(BaseInspector): #: methods which are assumed to return a list of objects when present on non-detail endpoints implicit_list_response_methods = ('GET',) - # real values set in __init__ to prevent import errors - field_inspectors = [] #: - filter_inspectors = [] #: - paginator_inspectors = [] #: + _field_inspectors = None + _filter_inspectors = None + _paginator_inspectors = None - def __init__(self, view, path, method, components, request, overrides): + @property + def field_inspectors(self): + return self._field_inspectors or self.swagger_settings.DEFAULT_FIELD_INSPECTORS + + @field_inspectors.setter + def field_inspectors(self, value): + self._field_inspectors = value + + @property + def filter_inspectors(self): + return self._filter_inspectors or self.swagger_settings.DEFAULT_FILTER_INSPECTORS + + @filter_inspectors.setter + def filter_inspectors(self, value): + self._filter_inspectors = value + + @property + def paginator_inspectors(self): + return self._paginator_inspectors or self.swagger_settings.DEFAULT_PAGINATOR_INSPECTORS + + @paginator_inspectors.setter + def paginator_inspectors(self, value): + self._paginator_inspectors = value + + def __init__(self, view, path, method, components, request, overrides, swagger_settings=_swagger_settings): """ Inspector class responsible for providing :class:`.Operation` definitions given a view, path and method. :param dict overrides: manual overrides as passed to :func:`@swagger_auto_schema <.swagger_auto_schema>` + :param swagger_settings: if given global swagger_settings are overridden with local settings """ - super(ViewInspector, self).__init__(view, path, method, components, request) + super(ViewInspector, self).__init__(view, path, method, components, request, swagger_settings) self.overrides = overrides self._prepend_inspector_overrides('field_inspectors') self._prepend_inspector_overrides('filter_inspectors') diff --git a/src/drf_yasg/inspectors/view.py b/src/drf_yasg/inspectors/view.py index 71bfbf35..bbad9197 100644 --- a/src/drf_yasg/inspectors/view.py +++ b/src/drf_yasg/inspectors/view.py @@ -6,6 +6,7 @@ from rest_framework.status import is_success from .. import openapi +from ..app_settings import swagger_settings as _swagger_settings from ..errors import SwaggerGenerationError from ..utils import ( filter_none, force_real_str, force_serializer_instance, get_consumes, get_produces, guess_response_status, @@ -17,8 +18,9 @@ class SwaggerAutoSchema(ViewInspector): - def __init__(self, view, path, method, components, request, overrides, operation_keys=None): - super(SwaggerAutoSchema, self).__init__(view, path, method, components, request, overrides) + def __init__(self, view, path, method, components, request, overrides, operation_keys=None, + swagger_settings=_swagger_settings): + super(SwaggerAutoSchema, self).__init__(view, path, method, components, request, overrides, swagger_settings) self._sch = AutoSchema() self._sch.view = view self.operation_keys = operation_keys diff --git a/src/drf_yasg/renderers.py b/src/drf_yasg/renderers.py index 7d79aaf6..96845785 100644 --- a/src/drf_yasg/renderers.py +++ b/src/drf_yasg/renderers.py @@ -6,7 +6,7 @@ from rest_framework.renderers import BaseRenderer, JSONRenderer, TemplateHTMLRenderer from rest_framework.utils import encoders, json -from .app_settings import redoc_settings, swagger_settings +from .app_settings import redoc_settings as _redoc_settings, swagger_settings as _swagger_settings from .codecs import VALIDATORS, OpenAPICodecJson, OpenAPICodecYaml from .openapi import Swagger from .utils import filter_none @@ -62,6 +62,7 @@ class _UIRenderer(BaseRenderer): media_type = 'text/html' charset = 'utf-8' template = '' + swagger_settings = _swagger_settings def render(self, swagger, accepted_media_type=None, renderer_context=None): if not isinstance(swagger, Swagger): # pragma: no cover @@ -82,7 +83,7 @@ def set_context(self, renderer_context, swagger=None): renderer_context['title'] = swagger.info.title or '' if swagger else '' renderer_context['version'] = swagger.info.version or '' if swagger else '' renderer_context['oauth2_config'] = json.dumps(self.get_oauth2_config(), cls=encoders.JSONEncoder) - renderer_context['USE_SESSION_AUTH'] = swagger_settings.USE_SESSION_AUTH + renderer_context['USE_SESSION_AUTH'] = self.swagger_settings.USE_SESSION_AUTH renderer_context.update(self.get_auth_urls()) def resolve_url(self, to): @@ -106,14 +107,14 @@ def resolve_url(self, to): def get_auth_urls(self): urls = { - 'LOGIN_URL': self.resolve_url(swagger_settings.LOGIN_URL), - 'LOGOUT_URL': self.resolve_url(swagger_settings.LOGOUT_URL), + 'LOGIN_URL': self.resolve_url(self.swagger_settings.LOGIN_URL), + 'LOGOUT_URL': self.resolve_url(self.swagger_settings.LOGOUT_URL), } return filter_none(urls) def get_oauth2_config(self): - data = swagger_settings.OAUTH2_CONFIG + data = self.swagger_settings.OAUTH2_CONFIG assert isinstance(data, dict), "OAUTH2_CONFIG must be a dict" return data @@ -136,28 +137,28 @@ def set_context(self, renderer_context, swagger=None): def get_swagger_ui_settings(self): data = { - 'url': self.resolve_url(swagger_settings.SPEC_URL), - 'operationsSorter': swagger_settings.OPERATIONS_SORTER, - 'tagsSorter': swagger_settings.TAGS_SORTER, - 'docExpansion': swagger_settings.DOC_EXPANSION, - 'deepLinking': swagger_settings.DEEP_LINKING, - 'showExtensions': swagger_settings.SHOW_EXTENSIONS, - 'defaultModelRendering': swagger_settings.DEFAULT_MODEL_RENDERING, - 'defaultModelExpandDepth': swagger_settings.DEFAULT_MODEL_DEPTH, - 'defaultModelsExpandDepth': swagger_settings.DEFAULT_MODEL_DEPTH, - 'showCommonExtensions': swagger_settings.SHOW_COMMON_EXTENSIONS, - 'oauth2RedirectUrl': swagger_settings.OAUTH2_REDIRECT_URL, - 'supportedSubmitMethods': swagger_settings.SUPPORTED_SUBMIT_METHODS, - 'displayOperationId': swagger_settings.DISPLAY_OPERATION_ID, - 'persistAuth': swagger_settings.PERSIST_AUTH, - 'refetchWithAuth': swagger_settings.REFETCH_SCHEMA_WITH_AUTH, - 'refetchOnLogout': swagger_settings.REFETCH_SCHEMA_ON_LOGOUT, - 'fetchSchemaWithQuery': swagger_settings.FETCH_SCHEMA_WITH_QUERY, + 'url': self.resolve_url(self.swagger_settings.SPEC_URL), + 'operationsSorter': self.swagger_settings.OPERATIONS_SORTER, + 'tagsSorter': self.swagger_settings.TAGS_SORTER, + 'docExpansion': self.swagger_settings.DOC_EXPANSION, + 'deepLinking': self.swagger_settings.DEEP_LINKING, + 'showExtensions': self.swagger_settings.SHOW_EXTENSIONS, + 'defaultModelRendering': self.swagger_settings.DEFAULT_MODEL_RENDERING, + 'defaultModelExpandDepth': self.swagger_settings.DEFAULT_MODEL_DEPTH, + 'defaultModelsExpandDepth': self.swagger_settings.DEFAULT_MODEL_DEPTH, + 'showCommonExtensions': self.swagger_settings.SHOW_COMMON_EXTENSIONS, + 'oauth2RedirectUrl': self.swagger_settings.OAUTH2_REDIRECT_URL, + 'supportedSubmitMethods': self.swagger_settings.SUPPORTED_SUBMIT_METHODS, + 'displayOperationId': self.swagger_settings.DISPLAY_OPERATION_ID, + 'persistAuth': self.swagger_settings.PERSIST_AUTH, + 'refetchWithAuth': self.swagger_settings.REFETCH_SCHEMA_WITH_AUTH, + 'refetchOnLogout': self.swagger_settings.REFETCH_SCHEMA_ON_LOGOUT, + 'fetchSchemaWithQuery': self.swagger_settings.FETCH_SCHEMA_WITH_QUERY, } data = filter_none(data) - if swagger_settings.VALIDATOR_URL != '': - data['validatorUrl'] = self.resolve_url(swagger_settings.VALIDATOR_URL) + if self.swagger_settings.VALIDATOR_URL != '': + data['validatorUrl'] = self.resolve_url(self.swagger_settings.VALIDATOR_URL) return data @@ -166,6 +167,7 @@ class ReDocRenderer(_UIRenderer): """Renders a ReDoc web interface for schema browsing.""" template = 'drf-yasg/redoc.html' format = 'redoc' + redoc_settings = _redoc_settings def set_context(self, renderer_context, swagger=None): super(ReDocRenderer, self).set_context(renderer_context, swagger) @@ -173,14 +175,14 @@ def set_context(self, renderer_context, swagger=None): def get_redoc_settings(self): data = { - 'url': self.resolve_url(redoc_settings.SPEC_URL), - 'lazyRendering': redoc_settings.LAZY_RENDERING, - 'hideHostname': redoc_settings.HIDE_HOSTNAME, - 'expandResponses': redoc_settings.EXPAND_RESPONSES, - 'pathInMiddlePanel': redoc_settings.PATH_IN_MIDDLE, - 'nativeScrollbars': redoc_settings.NATIVE_SCROLLBARS, - 'requiredPropsFirst': redoc_settings.REQUIRED_PROPS_FIRST, - 'fetchSchemaWithQuery': redoc_settings.FETCH_SCHEMA_WITH_QUERY, + 'url': self.resolve_url(self.redoc_settings.SPEC_URL), + 'lazyRendering': self.redoc_settings.LAZY_RENDERING, + 'hideHostname': self.redoc_settings.HIDE_HOSTNAME, + 'expandResponses': self.redoc_settings.EXPAND_RESPONSES, + 'pathInMiddlePanel': self.redoc_settings.PATH_IN_MIDDLE, + 'nativeScrollbars': self.redoc_settings.NATIVE_SCROLLBARS, + 'requiredPropsFirst': self.redoc_settings.REQUIRED_PROPS_FIRST, + 'fetchSchemaWithQuery': self.redoc_settings.FETCH_SCHEMA_WITH_QUERY, } return filter_none(data) diff --git a/src/drf_yasg/utils.py b/src/drf_yasg/utils.py index 494ce48c..c78a4821 100644 --- a/src/drf_yasg/utils.py +++ b/src/drf_yasg/utils.py @@ -15,7 +15,7 @@ from rest_framework.utils import encoders, json from rest_framework.views import APIView -from .app_settings import swagger_settings +from .app_settings import swagger_settings as _swagger_settings logger = logging.getLogger(__name__) @@ -387,7 +387,7 @@ def get_consumes(parser_classes): return non_form_media_types -def get_produces(renderer_classes): +def get_produces(renderer_classes, swagger_settings=_swagger_settings): """Extract ``produces`` MIME types from a list of renderer classes. :param list renderer_classes: renderer classes diff --git a/src/drf_yasg/views.py b/src/drf_yasg/views.py index 977f2f4a..a4b3015b 100644 --- a/src/drf_yasg/views.py +++ b/src/drf_yasg/views.py @@ -9,7 +9,7 @@ from rest_framework.settings import api_settings from rest_framework.views import APIView -from .app_settings import swagger_settings +from .app_settings import redoc_settings as _redoc_settings, swagger_settings as _swagger_settings from .renderers import ( ReDocOldRenderer, ReDocRenderer, @@ -17,7 +17,6 @@ _SpecRenderer, ) -SPEC_RENDERERS = swagger_settings.DEFAULT_SPEC_RENDERERS UI_RENDERERS = { 'swagger': (SwaggerUIRenderer, ReDocRenderer), 'redoc': (ReDocRenderer, SwaggerUIRenderer), @@ -50,7 +49,8 @@ def callback(response): def get_schema_view(info=None, url=None, patterns=None, urlconf=None, public=False, validators=None, - generator_class=None, authentication_classes=None, permission_classes=None): + generator_class=None, authentication_classes=None, permission_classes=None, + swagger_settings=_swagger_settings, redoc_settings=_redoc_settings): """Create a SchemaView class with default renderers and generators. :param Info info: information about the API; if omitted, defaults to :ref:`DEFAULT_INFO ` @@ -75,7 +75,10 @@ def get_schema_view(info=None, url=None, patterns=None, urlconf=None, public=Fal _perm_classes = api_settings.DEFAULT_PERMISSION_CLASSES info = info or swagger_settings.DEFAULT_INFO validators = validators or [] - _spec_renderers = tuple(renderer.with_validators(validators) for renderer in SPEC_RENDERERS) + _spec_renderers = tuple( + renderer.with_validators(validators) + for renderer in swagger_settings.DEFAULT_SPEC_RENDERERS + ) class SchemaView(APIView): _ignore_model_permissions = True @@ -89,9 +92,9 @@ class SchemaView(APIView): def get(self, request, version='', format=None): version = request.version or version or '' if isinstance(request.accepted_renderer, _SpecRenderer): - generator = self.generator_class(info, version, url, patterns, urlconf) + generator = self.generator_class(info, version, url, patterns, urlconf, swagger_settings) else: - generator = self.generator_class(info, version, url, patterns=[]) + generator = self.generator_class(info, version, url, patterns=[], swagger_settings=swagger_settings) schema = generator.get_schema(request, self.public) if schema is None: @@ -152,8 +155,30 @@ def with_ui(cls, renderer='swagger', cache_timeout=0, cache_kwargs=None): :return: a view instance """ assert renderer in UI_RENDERERS, "supported default renderers are " + ", ".join(UI_RENDERERS) - renderer_classes = UI_RENDERERS[renderer] + _spec_renderers + _local_swagger_settings = swagger_settings + _local_redoc_settings = redoc_settings + renderer_classes = [] + for renderer_class in UI_RENDERERS[renderer]: + if issubclass(renderer_class, SwaggerUIRenderer): + if _local_swagger_settings is _swagger_settings: + renderer_classes.append(renderer_class) + else: + class CustomSettingsSwaggerRenderer(renderer_class): + swagger_settings = _local_swagger_settings + + renderer_classes.append(CustomSettingsSwaggerRenderer) + + elif issubclass(renderer_class, ReDocRenderer): + if _local_redoc_settings is _redoc_settings: + renderer_classes.append(renderer_class) + else: + class CustomSettingsRedDocRenderer(renderer_class): + redoc_settings = _local_redoc_settings + + renderer_classes.append(CustomSettingsRedDocRenderer) + + renderer_classes.extend(_spec_renderers) return cls.as_cached_view(cache_timeout, cache_kwargs, renderer_classes=renderer_classes) return SchemaView diff --git a/tests/test_reference_schema.py b/tests/test_reference_schema.py index b860b582..eb019b47 100644 --- a/tests/test_reference_schema.py +++ b/tests/test_reference_schema.py @@ -36,8 +36,9 @@ def test_noop_inspectors(swagger_settings, mock_schema_request, codec_json, refe from drf_yasg import app_settings def set_inspectors(inspectors, setting_name): + existing = swagger_settings.get(setting_name, app_settings.SWAGGER_DEFAULTS[setting_name]) inspectors = [__name__ + '.' + inspector.__name__ for inspector in inspectors] - swagger_settings[setting_name] = inspectors + app_settings.SWAGGER_DEFAULTS[setting_name] + swagger_settings[setting_name] = inspectors + existing set_inspectors([NoOpFieldInspector, NoOpSerializerInspector], 'DEFAULT_FIELD_INSPECTORS') set_inspectors([NoOpFilterInspector], 'DEFAULT_FILTER_INSPECTORS')