diff --git a/src/drf_yasg/generators.py b/src/drf_yasg/generators.py index 80489dee..11e8f9a7 100644 --- a/src/drf_yasg/generators.py +++ b/src/drf_yasg/generators.py @@ -14,7 +14,7 @@ from rest_framework.settings import api_settings from . import openapi -from .app_settings import swagger_settings +from .app_settings import swagger_settings as _swagger_settings from .errors import SwaggerGenerationError from .inspectors.field import get_basic_type_info, get_queryset_field, get_queryset_from_view from .openapi import ReferenceResolver, SwaggerDict @@ -190,7 +190,7 @@ class OpenAPISchemaGenerator: 'delete': 'destroy', } - def __init__(self, info, version='', url=None, patterns=None, urlconf=None): + def __init__(self, info, version='', url=None, patterns=None, urlconf=None, swagger_settings=_swagger_settings): """ :param openapi.Info info: information about the API @@ -206,9 +206,11 @@ def __init__(self, info, version='', url=None, patterns=None, urlconf=None): :param patterns: if given, only these patterns will be enumerated for inclusion in the API spec :param urlconf: if patterns is not given, use this urlconf to enumerate patterns; if not given, the default urlconf is used + :param swagger_settings: if given global swagger_settings are overridden with local settings """ self._gen = SchemaGenerator(info.title, url, info.get('description', ''), patterns, urlconf) self.info = info + self.swagger_settings = swagger_settings self.version = version self.consumes = [] self.produces = [] @@ -237,7 +239,7 @@ def get_security_definitions(self): :return: the security schemes usable with this API :rtype: dict[str,dict] or None """ - security_definitions = swagger_settings.SECURITY_DEFINITIONS + security_definitions = self.swagger_settings.SECURITY_DEFINITIONS if security_definitions is not None: security_definitions = SwaggerDict._as_odict(security_definitions, {}) @@ -251,7 +253,7 @@ def get_security_requirements(self, security_definitions): :return: the security schemes accepted by default :rtype: list[dict[str,list[str]]] or None """ - security_requirements = swagger_settings.SECURITY_REQUIREMENTS + security_requirements = self.swagger_settings.SECURITY_REQUIREMENTS if security_requirements is None: security_requirements = [{security_scheme: []} for security_scheme in security_definitions] @@ -272,7 +274,7 @@ def get_schema(self, request=None, public=False): endpoints = self.get_endpoints(request) components = self.reference_resolver_class(openapi.SCHEMA_DEFINITIONS, force_init=True) self.consumes = get_consumes(api_settings.DEFAULT_PARSER_CLASSES) - self.produces = get_produces(api_settings.DEFAULT_RENDERER_CLASSES) + self.produces = get_produces(api_settings.DEFAULT_RENDERER_CLASSES, self.swagger_settings) paths, prefix = self.get_paths(endpoints, components, request, public) security_definitions = self.get_security_definitions() @@ -511,7 +513,7 @@ def get_operation(self, view, path, prefix, method, components, request): # the inspector class can be specified, in decreasing order of priority, # 1. globally via DEFAULT_AUTO_SCHEMA_CLASS - view_inspector_cls = swagger_settings.DEFAULT_AUTO_SCHEMA_CLASS + view_inspector_cls = self.swagger_settings.DEFAULT_AUTO_SCHEMA_CLASS # 2. on the view/viewset class view_inspector_cls = getattr(view, 'swagger_schema', view_inspector_cls) # 3. on the swagger_auto_schema decorator @@ -520,7 +522,8 @@ def get_operation(self, view, path, prefix, method, components, request): if view_inspector_cls is None: return None - view_inspector = view_inspector_cls(view, path, method, components, request, overrides, operation_keys) + view_inspector = view_inspector_cls( + view, path, method, components, request, overrides, operation_keys, self.swagger_settings) operation = view_inspector.get_operation(operation_keys) if operation is None: return None diff --git a/src/drf_yasg/inspectors/__init__.py b/src/drf_yasg/inspectors/__init__.py index 70bde798..7bade5c9 100644 --- a/src/drf_yasg/inspectors/__init__.py +++ b/src/drf_yasg/inspectors/__init__.py @@ -1,4 +1,3 @@ -from ..app_settings import swagger_settings from .base import ( BaseInspector, FieldInspector, FilterInspector, NotHandled, PaginatorInspector, SerializerInspector, ViewInspector ) @@ -10,11 +9,6 @@ from .query import DrfAPICompatInspector, CoreAPICompatInspector, DjangoRestResponsePagination from .view import SwaggerAutoSchema -# these settings must be accessed only after defining/importing all the classes in this module to avoid ImportErrors -ViewInspector.field_inspectors = swagger_settings.DEFAULT_FIELD_INSPECTORS -ViewInspector.filter_inspectors = swagger_settings.DEFAULT_FILTER_INSPECTORS -ViewInspector.paginator_inspectors = swagger_settings.DEFAULT_PAGINATOR_INSPECTORS - __all__ = [ # base inspectors 'BaseInspector', 'FilterInspector', 'PaginatorInspector', 'FieldInspector', 'SerializerInspector', 'ViewInspector', diff --git a/src/drf_yasg/inspectors/base.py b/src/drf_yasg/inspectors/base.py index 2660fb75..d0bdcb42 100644 --- a/src/drf_yasg/inspectors/base.py +++ b/src/drf_yasg/inspectors/base.py @@ -4,6 +4,7 @@ from rest_framework import serializers from .. import openapi +from ..app_settings import swagger_settings as _swagger_settings from ..utils import force_real_str, get_field_default, get_object_classes, is_list_view #: Sentinel value that inspectors must return to signal that they do not know how to handle an object @@ -52,19 +53,21 @@ def call_view_method(view, method_name, fallback_attr=None, default=None): class BaseInspector: - def __init__(self, view, path, method, components, request): + def __init__(self, view, path, method, components, request, swagger_settings=_swagger_settings): """ :param rest_framework.views.APIView view: the view associated with this endpoint :param str path: the path component of the operation URL :param str method: the http method of the operation :param openapi.ReferenceResolver components: referenceable components :param rest_framework.request.Request request: the request made against the schema view; can be None + :param swagger_settings: if given global swagger_settings are overridden with local settings """ self.view = view self.path = path self.method = method self.components = components self.request = request + self.swagger_settings = swagger_settings def process_result(self, result, method_name, obj, **kwargs): """After an inspector handles an object (i.e. returns a value other than :data:`.NotHandled`), all inspectors @@ -185,8 +188,8 @@ def get_filter_parameters(self, filter_backend): class FieldInspector(BaseInspector): """Base inspector for serializers and serializer fields. """ - def __init__(self, view, path, method, components, request, field_inspectors): - super(FieldInspector, self).__init__(view, path, method, components, request) + def __init__(self, view, path, method, components, request, field_inspectors, swagger_settings=_swagger_settings): + super(FieldInspector, self).__init__(view, path, method, components, request, swagger_settings) self.field_inspectors = field_inspectors def add_manual_fields(self, serializer_or_field, schema): @@ -338,18 +341,42 @@ class ViewInspector(BaseInspector): #: methods which are assumed to return a list of objects when present on non-detail endpoints implicit_list_response_methods = ('GET',) - # real values set in __init__ to prevent import errors - field_inspectors = [] #: - filter_inspectors = [] #: - paginator_inspectors = [] #: + _field_inspectors = None + _filter_inspectors = None + _paginator_inspectors = None - def __init__(self, view, path, method, components, request, overrides): + @property + def field_inspectors(self): + return self._field_inspectors or self.swagger_settings.DEFAULT_FIELD_INSPECTORS + + @field_inspectors.setter + def field_inspectors(self, value): + self._field_inspectors = value + + @property + def filter_inspectors(self): + return self._filter_inspectors or self.swagger_settings.DEFAULT_FILTER_INSPECTORS + + @filter_inspectors.setter + def filter_inspectors(self, value): + self._filter_inspectors = value + + @property + def paginator_inspectors(self): + return self._paginator_inspectors or self.swagger_settings.DEFAULT_PAGINATOR_INSPECTORS + + @paginator_inspectors.setter + def paginator_inspectors(self, value): + self._paginator_inspectors = value + + def __init__(self, view, path, method, components, request, overrides, swagger_settings=_swagger_settings): """ Inspector class responsible for providing :class:`.Operation` definitions given a view, path and method. :param dict overrides: manual overrides as passed to :func:`@swagger_auto_schema <.swagger_auto_schema>` + :param swagger_settings: if given global swagger_settings are overridden with local settings """ - super(ViewInspector, self).__init__(view, path, method, components, request) + super(ViewInspector, self).__init__(view, path, method, components, request, swagger_settings) self.overrides = overrides self._prepend_inspector_overrides('field_inspectors') self._prepend_inspector_overrides('filter_inspectors') diff --git a/src/drf_yasg/inspectors/view.py b/src/drf_yasg/inspectors/view.py index 71bfbf35..2936ea84 100644 --- a/src/drf_yasg/inspectors/view.py +++ b/src/drf_yasg/inspectors/view.py @@ -6,6 +6,7 @@ from rest_framework.status import is_success from .. import openapi +from ..app_settings import swagger_settings as _swagger_settings from ..errors import SwaggerGenerationError from ..utils import ( filter_none, force_real_str, force_serializer_instance, get_consumes, get_produces, guess_response_status, @@ -17,8 +18,9 @@ class SwaggerAutoSchema(ViewInspector): - def __init__(self, view, path, method, components, request, overrides, operation_keys=None): - super(SwaggerAutoSchema, self).__init__(view, path, method, components, request, overrides) + def __init__(self, view, path, method, components, request, overrides, operation_keys=None, + swagger_settings=_swagger_settings): + super(SwaggerAutoSchema, self).__init__(view, path, method, components, request, overrides, swagger_settings) self._sch = AutoSchema() self._sch.view = view self.operation_keys = operation_keys @@ -405,4 +407,4 @@ def get_produces(self): :rtype: list[str] """ - return get_produces(self.get_renderer_classes()) + return get_produces(self.get_renderer_classes(), swagger_settings=self.swagger_settings) diff --git a/src/drf_yasg/renderers.py b/src/drf_yasg/renderers.py index cf32cdf7..64a52635 100644 --- a/src/drf_yasg/renderers.py +++ b/src/drf_yasg/renderers.py @@ -6,7 +6,7 @@ from rest_framework.renderers import BaseRenderer, JSONRenderer, TemplateHTMLRenderer from rest_framework.utils import encoders, json -from .app_settings import redoc_settings, swagger_settings +from .app_settings import redoc_settings as _redoc_settings, swagger_settings as _swagger_settings from .codecs import VALIDATORS, OpenAPICodecJson, OpenAPICodecYaml from .openapi import Swagger from .utils import filter_none @@ -62,6 +62,7 @@ class _UIRenderer(BaseRenderer): media_type = 'text/html' charset = 'utf-8' template = '' + swagger_settings = _swagger_settings def render(self, swagger, accepted_media_type=None, renderer_context=None): if not isinstance(swagger, Swagger): # pragma: no cover @@ -82,7 +83,7 @@ def set_context(self, renderer_context, swagger=None): renderer_context['title'] = swagger.info.title or '' if swagger else '' renderer_context['version'] = swagger.info.version or '' if swagger else '' renderer_context['oauth2_config'] = json.dumps(self.get_oauth2_config(), cls=encoders.JSONEncoder) - renderer_context['USE_SESSION_AUTH'] = swagger_settings.USE_SESSION_AUTH + renderer_context['USE_SESSION_AUTH'] = self.swagger_settings.USE_SESSION_AUTH renderer_context.update(self.get_auth_urls()) def resolve_url(self, to): @@ -106,14 +107,14 @@ def resolve_url(self, to): def get_auth_urls(self): urls = { - 'LOGIN_URL': self.resolve_url(swagger_settings.LOGIN_URL), - 'LOGOUT_URL': self.resolve_url(swagger_settings.LOGOUT_URL), + 'LOGIN_URL': self.resolve_url(self.swagger_settings.LOGIN_URL), + 'LOGOUT_URL': self.resolve_url(self.swagger_settings.LOGOUT_URL), } return filter_none(urls) def get_oauth2_config(self): - data = swagger_settings.OAUTH2_CONFIG + data = self.swagger_settings.OAUTH2_CONFIG assert isinstance(data, dict), "OAUTH2_CONFIG must be a dict" return data @@ -136,31 +137,31 @@ def set_context(self, renderer_context, swagger=None): def get_swagger_ui_settings(self): data = { - 'url': self.resolve_url(swagger_settings.SPEC_URL), - 'operationsSorter': swagger_settings.OPERATIONS_SORTER, - 'tagsSorter': swagger_settings.TAGS_SORTER, - 'docExpansion': swagger_settings.DOC_EXPANSION, - 'deepLinking': swagger_settings.DEEP_LINKING, - 'showExtensions': swagger_settings.SHOW_EXTENSIONS, - 'defaultModelRendering': swagger_settings.DEFAULT_MODEL_RENDERING, - 'defaultModelExpandDepth': swagger_settings.DEFAULT_MODEL_DEPTH, - 'defaultModelsExpandDepth': swagger_settings.DEFAULT_MODEL_DEPTH, - 'showCommonExtensions': swagger_settings.SHOW_COMMON_EXTENSIONS, - 'oauth2RedirectUrl': swagger_settings.OAUTH2_REDIRECT_URL, - 'supportedSubmitMethods': swagger_settings.SUPPORTED_SUBMIT_METHODS, - 'displayOperationId': swagger_settings.DISPLAY_OPERATION_ID, - 'persistAuth': swagger_settings.PERSIST_AUTH, - 'refetchWithAuth': swagger_settings.REFETCH_SCHEMA_WITH_AUTH, - 'refetchOnLogout': swagger_settings.REFETCH_SCHEMA_ON_LOGOUT, - 'fetchSchemaWithQuery': swagger_settings.FETCH_SCHEMA_WITH_QUERY, - 'csrfCookie': swagger_settings.CSRF_COOKIE_NAME, + 'url': self.resolve_url(self.swagger_settings.SPEC_URL), + 'operationsSorter': self.swagger_settings.OPERATIONS_SORTER, + 'tagsSorter': self.swagger_settings.TAGS_SORTER, + 'docExpansion': self.swagger_settings.DOC_EXPANSION, + 'deepLinking': self.swagger_settings.DEEP_LINKING, + 'showExtensions': self.swagger_settings.SHOW_EXTENSIONS, + 'defaultModelRendering': self.swagger_settings.DEFAULT_MODEL_RENDERING, + 'defaultModelExpandDepth': self.swagger_settings.DEFAULT_MODEL_DEPTH, + 'defaultModelsExpandDepth': self.swagger_settings.DEFAULT_MODEL_DEPTH, + 'showCommonExtensions': self.swagger_settings.SHOW_COMMON_EXTENSIONS, + 'oauth2RedirectUrl': self.swagger_settings.OAUTH2_REDIRECT_URL, + 'supportedSubmitMethods': self.swagger_settings.SUPPORTED_SUBMIT_METHODS, + 'displayOperationId': self.swagger_settings.DISPLAY_OPERATION_ID, + 'persistAuth': self.swagger_settings.PERSIST_AUTH, + 'refetchWithAuth': self.swagger_settings.REFETCH_SCHEMA_WITH_AUTH, + 'refetchOnLogout': self.swagger_settings.REFETCH_SCHEMA_ON_LOGOUT, + 'fetchSchemaWithQuery': self.swagger_settings.FETCH_SCHEMA_WITH_QUERY, + 'csrfCookie': self.swagger_settings.CSRF_COOKIE_NAME, # remove HTTP_ and convert underscores to dashes - 'csrfHeader': swagger_settings.CSRF_HEADER_NAME[5:].replace('_', '-'), + 'csrfHeader': self.swagger_settings.CSRF_HEADER_NAME[5:].replace('_', '-'), } data = filter_none(data) - if swagger_settings.VALIDATOR_URL != '': - data['validatorUrl'] = self.resolve_url(swagger_settings.VALIDATOR_URL) + if self.swagger_settings.VALIDATOR_URL != '': + data['validatorUrl'] = self.resolve_url(self.swagger_settings.VALIDATOR_URL) return data @@ -169,6 +170,7 @@ class ReDocRenderer(_UIRenderer): """Renders a ReDoc web interface for schema browsing.""" template = 'drf-yasg/redoc.html' format = 'redoc' + redoc_settings = _redoc_settings def set_context(self, renderer_context, swagger=None): super(ReDocRenderer, self).set_context(renderer_context, swagger) @@ -176,14 +178,14 @@ def set_context(self, renderer_context, swagger=None): def get_redoc_settings(self): data = { - 'url': self.resolve_url(redoc_settings.SPEC_URL), - 'lazyRendering': redoc_settings.LAZY_RENDERING, - 'hideHostname': redoc_settings.HIDE_HOSTNAME, - 'expandResponses': redoc_settings.EXPAND_RESPONSES, - 'pathInMiddlePanel': redoc_settings.PATH_IN_MIDDLE, - 'nativeScrollbars': redoc_settings.NATIVE_SCROLLBARS, - 'requiredPropsFirst': redoc_settings.REQUIRED_PROPS_FIRST, - 'fetchSchemaWithQuery': redoc_settings.FETCH_SCHEMA_WITH_QUERY, + 'url': self.resolve_url(self.redoc_settings.SPEC_URL), + 'lazyRendering': self.redoc_settings.LAZY_RENDERING, + 'hideHostname': self.redoc_settings.HIDE_HOSTNAME, + 'expandResponses': self.redoc_settings.EXPAND_RESPONSES, + 'pathInMiddlePanel': self.redoc_settings.PATH_IN_MIDDLE, + 'nativeScrollbars': self.redoc_settings.NATIVE_SCROLLBARS, + 'requiredPropsFirst': self.redoc_settings.REQUIRED_PROPS_FIRST, + 'fetchSchemaWithQuery': self.redoc_settings.FETCH_SCHEMA_WITH_QUERY, } return filter_none(data) diff --git a/src/drf_yasg/utils.py b/src/drf_yasg/utils.py index 494ce48c..c78a4821 100644 --- a/src/drf_yasg/utils.py +++ b/src/drf_yasg/utils.py @@ -15,7 +15,7 @@ from rest_framework.utils import encoders, json from rest_framework.views import APIView -from .app_settings import swagger_settings +from .app_settings import swagger_settings as _swagger_settings logger = logging.getLogger(__name__) @@ -387,7 +387,7 @@ def get_consumes(parser_classes): return non_form_media_types -def get_produces(renderer_classes): +def get_produces(renderer_classes, swagger_settings=_swagger_settings): """Extract ``produces`` MIME types from a list of renderer classes. :param list renderer_classes: renderer classes diff --git a/src/drf_yasg/views.py b/src/drf_yasg/views.py index 977f2f4a..a4b3015b 100644 --- a/src/drf_yasg/views.py +++ b/src/drf_yasg/views.py @@ -9,7 +9,7 @@ from rest_framework.settings import api_settings from rest_framework.views import APIView -from .app_settings import swagger_settings +from .app_settings import redoc_settings as _redoc_settings, swagger_settings as _swagger_settings from .renderers import ( ReDocOldRenderer, ReDocRenderer, @@ -17,7 +17,6 @@ _SpecRenderer, ) -SPEC_RENDERERS = swagger_settings.DEFAULT_SPEC_RENDERERS UI_RENDERERS = { 'swagger': (SwaggerUIRenderer, ReDocRenderer), 'redoc': (ReDocRenderer, SwaggerUIRenderer), @@ -50,7 +49,8 @@ def callback(response): def get_schema_view(info=None, url=None, patterns=None, urlconf=None, public=False, validators=None, - generator_class=None, authentication_classes=None, permission_classes=None): + generator_class=None, authentication_classes=None, permission_classes=None, + swagger_settings=_swagger_settings, redoc_settings=_redoc_settings): """Create a SchemaView class with default renderers and generators. :param Info info: information about the API; if omitted, defaults to :ref:`DEFAULT_INFO ` @@ -75,7 +75,10 @@ def get_schema_view(info=None, url=None, patterns=None, urlconf=None, public=Fal _perm_classes = api_settings.DEFAULT_PERMISSION_CLASSES info = info or swagger_settings.DEFAULT_INFO validators = validators or [] - _spec_renderers = tuple(renderer.with_validators(validators) for renderer in SPEC_RENDERERS) + _spec_renderers = tuple( + renderer.with_validators(validators) + for renderer in swagger_settings.DEFAULT_SPEC_RENDERERS + ) class SchemaView(APIView): _ignore_model_permissions = True @@ -89,9 +92,9 @@ class SchemaView(APIView): def get(self, request, version='', format=None): version = request.version or version or '' if isinstance(request.accepted_renderer, _SpecRenderer): - generator = self.generator_class(info, version, url, patterns, urlconf) + generator = self.generator_class(info, version, url, patterns, urlconf, swagger_settings) else: - generator = self.generator_class(info, version, url, patterns=[]) + generator = self.generator_class(info, version, url, patterns=[], swagger_settings=swagger_settings) schema = generator.get_schema(request, self.public) if schema is None: @@ -152,8 +155,30 @@ def with_ui(cls, renderer='swagger', cache_timeout=0, cache_kwargs=None): :return: a view instance """ assert renderer in UI_RENDERERS, "supported default renderers are " + ", ".join(UI_RENDERERS) - renderer_classes = UI_RENDERERS[renderer] + _spec_renderers + _local_swagger_settings = swagger_settings + _local_redoc_settings = redoc_settings + renderer_classes = [] + for renderer_class in UI_RENDERERS[renderer]: + if issubclass(renderer_class, SwaggerUIRenderer): + if _local_swagger_settings is _swagger_settings: + renderer_classes.append(renderer_class) + else: + class CustomSettingsSwaggerRenderer(renderer_class): + swagger_settings = _local_swagger_settings + + renderer_classes.append(CustomSettingsSwaggerRenderer) + + elif issubclass(renderer_class, ReDocRenderer): + if _local_redoc_settings is _redoc_settings: + renderer_classes.append(renderer_class) + else: + class CustomSettingsRedDocRenderer(renderer_class): + redoc_settings = _local_redoc_settings + + renderer_classes.append(CustomSettingsRedDocRenderer) + + renderer_classes.extend(_spec_renderers) return cls.as_cached_view(cache_timeout, cache_kwargs, renderer_classes=renderer_classes) return SchemaView diff --git a/tests/test_reference_schema.py b/tests/test_reference_schema.py index b860b582..eb019b47 100644 --- a/tests/test_reference_schema.py +++ b/tests/test_reference_schema.py @@ -36,8 +36,9 @@ def test_noop_inspectors(swagger_settings, mock_schema_request, codec_json, refe from drf_yasg import app_settings def set_inspectors(inspectors, setting_name): + existing = swagger_settings.get(setting_name, app_settings.SWAGGER_DEFAULTS[setting_name]) inspectors = [__name__ + '.' + inspector.__name__ for inspector in inspectors] - swagger_settings[setting_name] = inspectors + app_settings.SWAGGER_DEFAULTS[setting_name] + swagger_settings[setting_name] = inspectors + existing set_inspectors([NoOpFieldInspector, NoOpSerializerInspector], 'DEFAULT_FIELD_INSPECTORS') set_inspectors([NoOpFilterInspector], 'DEFAULT_FILTER_INSPECTORS') diff --git a/tests/test_schema_views.py b/tests/test_schema_views.py index 2e467c5a..2e4814e7 100644 --- a/tests/test_schema_views.py +++ b/tests/test_schema_views.py @@ -87,6 +87,7 @@ def test_paginator_schema(client, swagger_settings): swagger_settings['DEFAULT_PAGINATOR_INSPECTORS'] = [ 'drf_yasg.inspectors.CoreAPICompatInspector', 'drf_yasg.inspectors.DrfAPICompatInspector', + 'drf_yasg.inspectors.DjangoRestResponsePagination', ] response = client.get('/versioned/url/v1.0/swagger.yaml')