From 684db20d950442cfeaad583afa67950a39e18f63 Mon Sep 17 00:00:00 2001 From: "Terence D. Honles" Date: Tue, 14 Aug 2018 16:02:09 -0700 Subject: [PATCH] allow swagger and redoc settings to be overridden When using site middleware there is no way to configure swagger instances to have different login URLs or authentication mechanisms. This change updates any global references to app settings to prefer a local setting if it exists, and for backward compatibility it refers to the global setting by default. The following calls will need to be adjusted to account for the new swagger_settings parameter that is passed: - ``drf_yasg.utils.get_produces(...)`` - ViewInspector class (and user defined subclasses) - Generator class (and user defined subclasses) --- src/drf_yasg/generators.py | 24 ++++++++------ src/drf_yasg/inspectors/__init__.py | 6 ---- src/drf_yasg/inspectors/base.py | 45 ++++++++++++++++++++------ src/drf_yasg/inspectors/view.py | 5 +-- src/drf_yasg/renderers.py | 50 +++++++++++++++-------------- src/drf_yasg/utils.py | 4 +-- src/drf_yasg/views.py | 37 +++++++++++++++++---- 7 files changed, 111 insertions(+), 60 deletions(-) diff --git a/src/drf_yasg/generators.py b/src/drf_yasg/generators.py index 83646665..c3759b5f 100644 --- a/src/drf_yasg/generators.py +++ b/src/drf_yasg/generators.py @@ -13,7 +13,7 @@ from rest_framework.settings import api_settings as rest_framework_settings from . import openapi -from .app_settings import swagger_settings +from .app_settings import swagger_settings as _swagger_settings from .errors import SwaggerGenerationError from .inspectors.field import get_basic_type_info, get_queryset_field from .openapi import ReferenceResolver @@ -161,7 +161,7 @@ class OpenAPISchemaGenerator(object): """ endpoint_enumerator_class = EndpointEnumerator - def __init__(self, info, version='', url=None, patterns=None, urlconf=None): + def __init__(self, info, version='', url=None, patterns=None, urlconf=None, swagger_settings=_swagger_settings): """ :param .Info info: information about the API @@ -177,9 +177,11 @@ def __init__(self, info, version='', url=None, patterns=None, urlconf=None): :param patterns: if given, only these patterns will be enumerated for inclusion in the API spec :param urlconf: if patterns is not given, use this urlconf to enumerate patterns; if not given, the default urlconf is used + :param swagger_settings: if given global swagger_settings are overridden with local settings """ self._gen = SchemaGenerator(info.title, url, info.get('description', ''), patterns, urlconf) self.info = info + self.swagger_settings = swagger_settings self.version = version self.consumes = [] self.produces = [] @@ -211,16 +213,18 @@ def get_schema(self, request=None, public=False): endpoints = self.get_endpoints(request) components = ReferenceResolver(openapi.SCHEMA_DEFINITIONS) self.consumes = get_consumes(rest_framework_settings.DEFAULT_PARSER_CLASSES) - self.produces = get_produces(rest_framework_settings.DEFAULT_RENDERER_CLASSES) + self.produces = get_produces(rest_framework_settings.DEFAULT_RENDERER_CLASSES, self.swagger_settings) paths, prefix = self.get_paths(endpoints, components, request, public) - security_definitions = swagger_settings.SECURITY_DEFINITIONS + security_definitions = self.swagger_settings.SECURITY_DEFINITIONS if security_definitions is not None: - security_definitions = OrderedDict(sorted([(key, OrderedDict(sorted(sd.items()))) - for key, sd in swagger_settings.SECURITY_DEFINITIONS.items()])) - security_requirements = swagger_settings.SECURITY_REQUIREMENTS + security_definitions = OrderedDict(sorted( + (key, OrderedDict(sorted(sd.items()))) + for key, sd in self.swagger_settings.SECURITY_DEFINITIONS.items())) + security_requirements = self.swagger_settings.SECURITY_REQUIREMENTS if security_requirements is None: - security_requirements = [{security_scheme: []} for security_scheme in swagger_settings.SECURITY_DEFINITIONS] + security_requirements = [ + {security_scheme: []} for security_scheme in self.swagger_settings.SECURITY_DEFINITIONS] security_requirements = sorted(security_requirements, key=lambda od: list(sorted(od))) security_requirements = [OrderedDict(sorted(sr.items())) for sr in security_requirements] @@ -366,7 +370,7 @@ def get_operation(self, view, path, prefix, method, components, request): # the inspector class can be specified, in decreasing order of priorty, # 1. globaly via DEFAULT_AUTO_SCHEMA_CLASS - view_inspector_cls = swagger_settings.DEFAULT_AUTO_SCHEMA_CLASS + view_inspector_cls = self.swagger_settings.DEFAULT_AUTO_SCHEMA_CLASS # 2. on the view/viewset class view_inspector_cls = getattr(view, 'swagger_schema', view_inspector_cls) # 3. on the swagger_auto_schema decorator @@ -375,7 +379,7 @@ def get_operation(self, view, path, prefix, method, components, request): if view_inspector_cls is None: return None - view_inspector = view_inspector_cls(view, path, method, components, request, overrides) + view_inspector = view_inspector_cls(view, path, method, components, request, overrides, self.swagger_settings) operation = view_inspector.get_operation(operation_keys) if operation is None: return None diff --git a/src/drf_yasg/inspectors/__init__.py b/src/drf_yasg/inspectors/__init__.py index cb979c86..2f4c5234 100644 --- a/src/drf_yasg/inspectors/__init__.py +++ b/src/drf_yasg/inspectors/__init__.py @@ -1,4 +1,3 @@ -from ..app_settings import swagger_settings from .base import ( BaseInspector, FieldInspector, FilterInspector, NotHandled, PaginatorInspector, SerializerInspector, ViewInspector ) @@ -10,11 +9,6 @@ from .query import CoreAPICompatInspector, DjangoRestResponsePagination from .view import SwaggerAutoSchema -# these settings must be accessed only after defining/importing all the classes in this module to avoid ImportErrors -ViewInspector.field_inspectors = swagger_settings.DEFAULT_FIELD_INSPECTORS -ViewInspector.filter_inspectors = swagger_settings.DEFAULT_FILTER_INSPECTORS -ViewInspector.paginator_inspectors = swagger_settings.DEFAULT_PAGINATOR_INSPECTORS - __all__ = [ # base inspectors 'BaseInspector', 'FilterInspector', 'PaginatorInspector', 'FieldInspector', 'SerializerInspector', 'ViewInspector', diff --git a/src/drf_yasg/inspectors/base.py b/src/drf_yasg/inspectors/base.py index 35cdc75e..0879410e 100644 --- a/src/drf_yasg/inspectors/base.py +++ b/src/drf_yasg/inspectors/base.py @@ -4,6 +4,7 @@ from rest_framework import serializers from .. import openapi +from ..app_settings import swagger_settings as _swagger_settings from ..utils import force_real_str, get_field_default, is_list_view #: Sentinel value that inspectors must return to signal that they do not know how to handle an object @@ -13,19 +14,21 @@ class BaseInspector(object): - def __init__(self, view, path, method, components, request): + def __init__(self, view, path, method, components, request, swagger_settings=_swagger_settings): """ :param view: the view associated with this endpoint :param str path: the path component of the operation URL :param str method: the http method of the operation :param openapi.ReferenceResolver components: referenceable components :param Request request: the request made against the schema view; can be None + :param swagger_settings: if given global swagger_settings are overridden with local settings """ self.view = view self.path = path self.method = method self.components = components self.request = request + self.swagger_settings = swagger_settings def process_result(self, result, method_name, obj, **kwargs): """After an inspector handles an object (i.e. returns a value other than :data:`.NotHandled`), all inspectors @@ -130,8 +133,8 @@ def get_filter_parameters(self, filter_backend): class FieldInspector(BaseInspector): """Base inspector for serializers and serializer fields. """ - def __init__(self, view, path, method, components, request, field_inspectors): - super(FieldInspector, self).__init__(view, path, method, components, request) + def __init__(self, view, path, method, components, request, field_inspectors, swagger_settings=_swagger_settings): + super(FieldInspector, self).__init__(view, path, method, components, request, swagger_settings) self.field_inspectors = field_inspectors def add_manual_fields(self, serializer_or_field, schema): @@ -275,18 +278,42 @@ class ViewInspector(BaseInspector): #: methods that are assumed to require a request body determined by the view's ``serializer_class`` implicit_body_methods = ('PUT', 'PATCH', 'POST') - # real values set in __init__ to prevent import errors - field_inspectors = [] #: - filter_inspectors = [] #: - paginator_inspectors = [] #: + _field_inspectors = None + _filter_inspectors = None + _paginator_inspectors = None - def __init__(self, view, path, method, components, request, overrides): + @property + def field_inspectors(self): + return self._field_inspectors or self.swagger_settings.DEFAULT_FIELD_INSPECTORS + + @field_inspectors.setter + def field_inspectors(self, value): + self._field_inspectors = value + + @property + def filter_inspectors(self): + return self._filter_inspectors or self.swagger_settings.DEFAULT_FILTER_INSPECTORS + + @filter_inspectors.setter + def filter_inspectors(self, value): + self._filter_inspectors = value + + @property + def paginator_inspectors(self): + return self._paginator_inspectors or self.swagger_settings.DEFAULT_PAGINATOR_INSPECTORS + + @paginator_inspectors.setter + def paginator_inspectors(self, value): + self._paginator_inspectors = value + + def __init__(self, view, path, method, components, request, overrides, swagger_settings=_swagger_settings): """ Inspector class responsible for providing :class:`.Operation` definitions given a view, path and method. :param dict overrides: manual overrides as passed to :func:`@swagger_auto_schema <.swagger_auto_schema>` + :param swagger_settings: if given global swagger_settings are overridden with local settings """ - super(ViewInspector, self).__init__(view, path, method, components, request) + super(ViewInspector, self).__init__(view, path, method, components, request, swagger_settings) self.overrides = overrides self._prepend_inspector_overrides('field_inspectors') self._prepend_inspector_overrides('filter_inspectors') diff --git a/src/drf_yasg/inspectors/view.py b/src/drf_yasg/inspectors/view.py index c2d468ed..77ab74e8 100644 --- a/src/drf_yasg/inspectors/view.py +++ b/src/drf_yasg/inspectors/view.py @@ -6,6 +6,7 @@ from rest_framework.status import is_success from .. import openapi +from ..app_settings import swagger_settings as _swagger_settings from ..errors import SwaggerGenerationError from ..utils import ( force_real_str, force_serializer_instance, get_consumes, get_produces, guess_response_status, is_list_view, no_body, @@ -17,8 +18,8 @@ class SwaggerAutoSchema(ViewInspector): - def __init__(self, view, path, method, components, request, overrides): - super(SwaggerAutoSchema, self).__init__(view, path, method, components, request, overrides) + def __init__(self, view, path, method, components, request, overrides, swagger_settings=_swagger_settings): + super(SwaggerAutoSchema, self).__init__(view, path, method, components, request, overrides, swagger_settings) self._sch = AutoSchema() self._sch.view = view diff --git a/src/drf_yasg/renderers.py b/src/drf_yasg/renderers.py index bf92fed6..14dd2003 100644 --- a/src/drf_yasg/renderers.py +++ b/src/drf_yasg/renderers.py @@ -4,7 +4,7 @@ from rest_framework.renderers import BaseRenderer, JSONRenderer, TemplateHTMLRenderer from rest_framework.utils import json -from .app_settings import redoc_settings, swagger_settings +from .app_settings import redoc_settings as _redoc_settings, swagger_settings as _swagger_settings from .codecs import VALIDATORS, OpenAPICodecJson, OpenAPICodecYaml from .openapi import Swagger from .utils import filter_none @@ -59,6 +59,7 @@ class _UIRenderer(BaseRenderer): media_type = 'text/html' charset = 'utf-8' template = '' + swagger_settings = _swagger_settings def render(self, swagger, accepted_media_type=None, renderer_context=None): if not isinstance(swagger, Swagger): # pragma: no cover @@ -77,7 +78,7 @@ def set_context(self, renderer_context, swagger): renderer_context['title'] = swagger.info.title renderer_context['version'] = swagger.info.version renderer_context['oauth2_config'] = json.dumps(self.get_oauth2_config()) - renderer_context['USE_SESSION_AUTH'] = swagger_settings.USE_SESSION_AUTH + renderer_context['USE_SESSION_AUTH'] = self.swagger_settings.USE_SESSION_AUTH renderer_context.update(self.get_auth_urls()) def resolve_url(self, to): @@ -98,14 +99,14 @@ def resolve_url(self, to): def get_auth_urls(self): urls = { - 'LOGIN_URL': self.resolve_url(swagger_settings.LOGIN_URL), - 'LOGOUT_URL': self.resolve_url(swagger_settings.LOGOUT_URL), + 'LOGIN_URL': self.resolve_url(self.swagger_settings.LOGIN_URL), + 'LOGOUT_URL': self.resolve_url(self.swagger_settings.LOGOUT_URL), } return filter_none(urls) def get_oauth2_config(self): - data = swagger_settings.OAUTH2_CONFIG + data = self.swagger_settings.OAUTH2_CONFIG assert isinstance(data, dict), "OAUTH2_CONFIG must be a dict" return data @@ -121,23 +122,23 @@ def set_context(self, renderer_context, swagger): def get_swagger_ui_settings(self): data = { - 'url': self.resolve_url(swagger_settings.SPEC_URL), - 'operationsSorter': swagger_settings.OPERATIONS_SORTER, - 'tagsSorter': swagger_settings.TAGS_SORTER, - 'docExpansion': swagger_settings.DOC_EXPANSION, - 'deepLinking': swagger_settings.DEEP_LINKING, - 'showExtensions': swagger_settings.SHOW_EXTENSIONS, - 'defaultModelRendering': swagger_settings.DEFAULT_MODEL_RENDERING, - 'defaultModelExpandDepth': swagger_settings.DEFAULT_MODEL_DEPTH, - 'defaultModelsExpandDepth': swagger_settings.DEFAULT_MODEL_DEPTH, - 'showCommonExtensions': swagger_settings.SHOW_COMMON_EXTENSIONS, - 'oauth2RedirectUrl': swagger_settings.OAUTH2_REDIRECT_URL, - 'supportedSubmitMethods': swagger_settings.SUPPORTED_SUBMIT_METHODS, + 'url': self.resolve_url(self.swagger_settings.SPEC_URL), + 'operationsSorter': self.swagger_settings.OPERATIONS_SORTER, + 'tagsSorter': self.swagger_settings.TAGS_SORTER, + 'docExpansion': self.swagger_settings.DOC_EXPANSION, + 'deepLinking': self.swagger_settings.DEEP_LINKING, + 'showExtensions': self.swagger_settings.SHOW_EXTENSIONS, + 'defaultModelRendering': self.swagger_settings.DEFAULT_MODEL_RENDERING, + 'defaultModelExpandDepth': self.swagger_settings.DEFAULT_MODEL_DEPTH, + 'defaultModelsExpandDepth': self.swagger_settings.DEFAULT_MODEL_DEPTH, + 'showCommonExtensions': self.swagger_settings.SHOW_COMMON_EXTENSIONS, + 'oauth2RedirectUrl': self.swagger_settings.OAUTH2_REDIRECT_URL, + 'supportedSubmitMethods': self.swagger_settings.SUPPORTED_SUBMIT_METHODS, } data = filter_none(data) - if swagger_settings.VALIDATOR_URL != '': - data['validatorUrl'] = self.resolve_url(swagger_settings.VALIDATOR_URL) + if self.swagger_settings.VALIDATOR_URL != '': + data['validatorUrl'] = self.resolve_url(self.swagger_settings.VALIDATOR_URL) return data @@ -146,6 +147,7 @@ class ReDocRenderer(_UIRenderer): """Renders a ReDoc web interface for schema browisng.""" template = 'drf-yasg/redoc.html' format = 'redoc' + redoc_settings = _redoc_settings def set_context(self, renderer_context, swagger): super(ReDocRenderer, self).set_context(renderer_context, swagger) @@ -153,11 +155,11 @@ def set_context(self, renderer_context, swagger): def get_redoc_settings(self): data = { - 'url': self.resolve_url(redoc_settings.SPEC_URL), - 'lazyRendering': redoc_settings.LAZY_RENDERING, - 'hideHostname': redoc_settings.HIDE_HOSTNAME, - 'expandResponses': redoc_settings.EXPAND_RESPONSES, - 'pathInMiddle': redoc_settings.PATH_IN_MIDDLE, + 'url': self.resolve_url(self.redoc_settings.SPEC_URL), + 'lazyRendering': self.redoc_settings.LAZY_RENDERING, + 'hideHostname': self.redoc_settings.HIDE_HOSTNAME, + 'expandResponses': self.redoc_settings.EXPAND_RESPONSES, + 'pathInMiddle': self.redoc_settings.PATH_IN_MIDDLE, } return filter_none(data) diff --git a/src/drf_yasg/utils.py b/src/drf_yasg/utils.py index 0177e0b6..ff408d61 100644 --- a/src/drf_yasg/utils.py +++ b/src/drf_yasg/utils.py @@ -11,7 +11,7 @@ from rest_framework.utils import encoders, json from rest_framework.views import APIView -from .app_settings import swagger_settings +from .app_settings import swagger_settings as _swagger_settings logger = logging.getLogger(__name__) @@ -322,7 +322,7 @@ def get_consumes(parser_classes): return media_types -def get_produces(renderer_classes): +def get_produces(renderer_classes, swagger_settings=_swagger_settings): """Extract ``produces`` MIME types from a list of renderer classes. :param list renderer_classes: renderer classes diff --git a/src/drf_yasg/views.py b/src/drf_yasg/views.py index f6d58f42..6dc24bba 100644 --- a/src/drf_yasg/views.py +++ b/src/drf_yasg/views.py @@ -10,7 +10,7 @@ from rest_framework.settings import api_settings from rest_framework.views import APIView -from .app_settings import swagger_settings +from .app_settings import redoc_settings as _redoc_settings, swagger_settings as _swagger_settings from .renderers import ( OpenAPIRenderer, ReDocOldRenderer, ReDocRenderer, SwaggerJSONRenderer, SwaggerUIRenderer, SwaggerYAMLRenderer ) @@ -48,9 +48,9 @@ def callback(response): def get_schema_view(info=None, url=None, patterns=None, urlconf=None, public=False, validators=None, - generator_class=swagger_settings.DEFAULT_GENERATOR_CLASS, - authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES, - permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES): + generator_class=None, authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES, + permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES, swagger_settings=_swagger_settings, + redoc_settings=_redoc_settings): """Create a SchemaView class with default renderers and generators. :param .Info info: information about the API; if omitted, defaults to :ref:`DEFAULT_INFO ` @@ -66,7 +66,7 @@ def get_schema_view(info=None, url=None, patterns=None, urlconf=None, public=Fal :rtype: type[.SchemaView] """ _public = public - _generator_class = generator_class + _generator_class = generator_class or swagger_settings.DEFAULT_GENERATOR_CLASS _auth_classes = authentication_classes _perm_classes = permission_classes info = info or swagger_settings.DEFAULT_INFO @@ -83,7 +83,8 @@ class SchemaView(APIView): renderer_classes = _spec_renderers def get(self, request, version='', format=None): - generator = self.generator_class(info, request.version or version or '', url, patterns, urlconf) + generator = self.generator_class( + info, request.version or version or '', url, patterns, urlconf, swagger_settings) schema = generator.get_schema(request, self.public) if schema is None: raise exceptions.PermissionDenied() # pragma: no cover @@ -143,8 +144,30 @@ def with_ui(cls, renderer='swagger', cache_timeout=0, cache_kwargs=None): :return: a view instance """ assert renderer in UI_RENDERERS, "supported default renderers are " + ", ".join(UI_RENDERERS) - renderer_classes = UI_RENDERERS[renderer] + _spec_renderers + _local_swagger_settings = swagger_settings + _local_redoc_settings = redoc_settings + renderer_classes = [] + for renderer_class in UI_RENDERERS[renderer]: + if issubclass(renderer_class, SwaggerUIRenderer): + if _local_swagger_settings is _swagger_settings: + renderer_classes.append(renderer_class) + else: + class CustomSettingsSwaggerRenderer(renderer_class): + swagger_settings = _local_swagger_settings + + renderer_classes.append(CustomSettingsSwaggerRenderer) + + elif issubclass(renderer_class, ReDocRenderer): + if _local_redoc_settings is _redoc_settings: + renderer_classes.append(renderer_class) + else: + class CustomSettingsRedDocRenderer(renderer_class): + redoc_settings = _local_redoc_settings + + renderer_classes.append(CustomSettingsRedDocRenderer) + + renderer_classes.extend(_spec_renderers) return cls.as_cached_view(cache_timeout, cache_kwargs, renderer_classes=renderer_classes) return SchemaView