Skip to content

Commit

Permalink
allow swagger and redoc settings to be overridden
Browse files Browse the repository at this point in the history
When using site middleware there is no way to configure swagger
instances to have different login URLs or authentication mechanisms.
This change updates any global references to app settings to prefer a
local setting if it exists, and for backward compatibility it refers to
the global setting by default.

The following calls will need to be adjusted to account for the new
swagger_settings parameter that is passed:

- ``drf_yasg.utils.get_produces(...)``
- ViewInspector class (and user defined subclasses)
- Generator class (and user defined subclasses)
  • Loading branch information
terencehonles committed Sep 15, 2018
1 parent 120c4dd commit ff5d574
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 60 deletions.
24 changes: 14 additions & 10 deletions src/drf_yasg/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from rest_framework.settings import api_settings as rest_framework_settings

from . import openapi
from .app_settings import swagger_settings
from .app_settings import swagger_settings as _swagger_settings
from .errors import SwaggerGenerationError
from .inspectors.field import get_basic_type_info, get_queryset_field, get_queryset_from_view
from .openapi import ReferenceResolver
Expand Down Expand Up @@ -161,7 +161,7 @@ class OpenAPISchemaGenerator(object):
"""
endpoint_enumerator_class = EndpointEnumerator

def __init__(self, info, version='', url=None, patterns=None, urlconf=None):
def __init__(self, info, version='', url=None, patterns=None, urlconf=None, swagger_settings=_swagger_settings):
"""
:param .Info info: information about the API
Expand All @@ -177,9 +177,11 @@ def __init__(self, info, version='', url=None, patterns=None, urlconf=None):
:param patterns: if given, only these patterns will be enumerated for inclusion in the API spec
:param urlconf: if patterns is not given, use this urlconf to enumerate patterns;
if not given, the default urlconf is used
:param swagger_settings: if given global swagger_settings are overridden with local settings
"""
self._gen = SchemaGenerator(info.title, url, info.get('description', ''), patterns, urlconf)
self.info = info
self.swagger_settings = swagger_settings
self.version = version
self.consumes = []
self.produces = []
Expand Down Expand Up @@ -211,16 +213,18 @@ def get_schema(self, request=None, public=False):
endpoints = self.get_endpoints(request)
components = ReferenceResolver(openapi.SCHEMA_DEFINITIONS)
self.consumes = get_consumes(rest_framework_settings.DEFAULT_PARSER_CLASSES)
self.produces = get_produces(rest_framework_settings.DEFAULT_RENDERER_CLASSES)
self.produces = get_produces(rest_framework_settings.DEFAULT_RENDERER_CLASSES, self.swagger_settings)
paths, prefix = self.get_paths(endpoints, components, request, public)

security_definitions = swagger_settings.SECURITY_DEFINITIONS
security_definitions = self.swagger_settings.SECURITY_DEFINITIONS
if security_definitions is not None:
security_definitions = OrderedDict(sorted([(key, OrderedDict(sorted(sd.items())))
for key, sd in swagger_settings.SECURITY_DEFINITIONS.items()]))
security_requirements = swagger_settings.SECURITY_REQUIREMENTS
security_definitions = OrderedDict(sorted(
(key, OrderedDict(sorted(sd.items())))
for key, sd in self.swagger_settings.SECURITY_DEFINITIONS.items()))
security_requirements = self.swagger_settings.SECURITY_REQUIREMENTS
if security_requirements is None:
security_requirements = [{security_scheme: []} for security_scheme in swagger_settings.SECURITY_DEFINITIONS]
security_requirements = [
{security_scheme: []} for security_scheme in self.swagger_settings.SECURITY_DEFINITIONS]

security_requirements = sorted(security_requirements, key=lambda od: list(sorted(od)))
security_requirements = [OrderedDict(sorted(sr.items())) for sr in security_requirements]
Expand Down Expand Up @@ -366,7 +370,7 @@ def get_operation(self, view, path, prefix, method, components, request):

# the inspector class can be specified, in decreasing order of priorty,
# 1. globaly via DEFAULT_AUTO_SCHEMA_CLASS
view_inspector_cls = swagger_settings.DEFAULT_AUTO_SCHEMA_CLASS
view_inspector_cls = self.swagger_settings.DEFAULT_AUTO_SCHEMA_CLASS
# 2. on the view/viewset class
view_inspector_cls = getattr(view, 'swagger_schema', view_inspector_cls)
# 3. on the swagger_auto_schema decorator
Expand All @@ -375,7 +379,7 @@ def get_operation(self, view, path, prefix, method, components, request):
if view_inspector_cls is None:
return None

view_inspector = view_inspector_cls(view, path, method, components, request, overrides)
view_inspector = view_inspector_cls(view, path, method, components, request, overrides, self.swagger_settings)
operation = view_inspector.get_operation(operation_keys)
if operation is None:
return None
Expand Down
6 changes: 0 additions & 6 deletions src/drf_yasg/inspectors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from ..app_settings import swagger_settings
from .base import (
BaseInspector, FieldInspector, FilterInspector, NotHandled, PaginatorInspector, SerializerInspector, ViewInspector
)
Expand All @@ -10,11 +9,6 @@
from .query import CoreAPICompatInspector, DjangoRestResponsePagination
from .view import SwaggerAutoSchema

# these settings must be accessed only after defining/importing all the classes in this module to avoid ImportErrors
ViewInspector.field_inspectors = swagger_settings.DEFAULT_FIELD_INSPECTORS
ViewInspector.filter_inspectors = swagger_settings.DEFAULT_FILTER_INSPECTORS
ViewInspector.paginator_inspectors = swagger_settings.DEFAULT_PAGINATOR_INSPECTORS

__all__ = [
# base inspectors
'BaseInspector', 'FilterInspector', 'PaginatorInspector', 'FieldInspector', 'SerializerInspector', 'ViewInspector',
Expand Down
45 changes: 36 additions & 9 deletions src/drf_yasg/inspectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from rest_framework import serializers

from .. import openapi
from ..app_settings import swagger_settings as _swagger_settings
from ..utils import force_real_str, get_field_default, is_list_view

#: Sentinel value that inspectors must return to signal that they do not know how to handle an object
Expand All @@ -13,19 +14,21 @@


class BaseInspector(object):
def __init__(self, view, path, method, components, request):
def __init__(self, view, path, method, components, request, swagger_settings=_swagger_settings):
"""
:param view: the view associated with this endpoint
:param str path: the path component of the operation URL
:param str method: the http method of the operation
:param openapi.ReferenceResolver components: referenceable components
:param Request request: the request made against the schema view; can be None
:param swagger_settings: if given global swagger_settings are overridden with local settings
"""
self.view = view
self.path = path
self.method = method
self.components = components
self.request = request
self.swagger_settings = swagger_settings

def process_result(self, result, method_name, obj, **kwargs):
"""After an inspector handles an object (i.e. returns a value other than :data:`.NotHandled`), all inspectors
Expand Down Expand Up @@ -130,8 +133,8 @@ def get_filter_parameters(self, filter_backend):
class FieldInspector(BaseInspector):
"""Base inspector for serializers and serializer fields. """

def __init__(self, view, path, method, components, request, field_inspectors):
super(FieldInspector, self).__init__(view, path, method, components, request)
def __init__(self, view, path, method, components, request, field_inspectors, swagger_settings=_swagger_settings):
super(FieldInspector, self).__init__(view, path, method, components, request, swagger_settings)
self.field_inspectors = field_inspectors

def add_manual_fields(self, serializer_or_field, schema):
Expand Down Expand Up @@ -275,18 +278,42 @@ class ViewInspector(BaseInspector):
#: methods that are assumed to require a request body determined by the view's ``serializer_class``
implicit_body_methods = ('PUT', 'PATCH', 'POST')

# real values set in __init__ to prevent import errors
field_inspectors = [] #:
filter_inspectors = [] #:
paginator_inspectors = [] #:
_field_inspectors = None
_filter_inspectors = None
_paginator_inspectors = None

def __init__(self, view, path, method, components, request, overrides):
@property
def field_inspectors(self):
return self._field_inspectors or self.swagger_settings.DEFAULT_FIELD_INSPECTORS

@field_inspectors.setter
def field_inspectors(self, value):
self._field_inspectors = value

@property
def filter_inspectors(self):
return self._filter_inspectors or self.swagger_settings.DEFAULT_FILTER_INSPECTORS

@filter_inspectors.setter
def filter_inspectors(self, value):
self._filter_inspectors = value

@property
def paginator_inspectors(self):
return self._paginator_inspectors or self.swagger_settings.DEFAULT_PAGINATOR_INSPECTORS

@paginator_inspectors.setter
def paginator_inspectors(self, value):
self._paginator_inspectors = value

def __init__(self, view, path, method, components, request, overrides, swagger_settings=_swagger_settings):
"""
Inspector class responsible for providing :class:`.Operation` definitions given a view, path and method.
:param dict overrides: manual overrides as passed to :func:`@swagger_auto_schema <.swagger_auto_schema>`
:param swagger_settings: if given global swagger_settings are overridden with local settings
"""
super(ViewInspector, self).__init__(view, path, method, components, request)
super(ViewInspector, self).__init__(view, path, method, components, request, swagger_settings)
self.overrides = overrides
self._prepend_inspector_overrides('field_inspectors')
self._prepend_inspector_overrides('filter_inspectors')
Expand Down
5 changes: 3 additions & 2 deletions src/drf_yasg/inspectors/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from rest_framework.status import is_success

from .. import openapi
from ..app_settings import swagger_settings as _swagger_settings
from ..errors import SwaggerGenerationError
from ..utils import (
force_real_str, force_serializer_instance, get_consumes, get_produces, guess_response_status, is_list_view,
Expand All @@ -17,8 +18,8 @@


class SwaggerAutoSchema(ViewInspector):
def __init__(self, view, path, method, components, request, overrides):
super(SwaggerAutoSchema, self).__init__(view, path, method, components, request, overrides)
def __init__(self, view, path, method, components, request, overrides, swagger_settings=_swagger_settings):
super(SwaggerAutoSchema, self).__init__(view, path, method, components, request, overrides, swagger_settings)
self._sch = AutoSchema()
self._sch.view = view

Expand Down
56 changes: 29 additions & 27 deletions src/drf_yasg/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from rest_framework.renderers import BaseRenderer, JSONRenderer, TemplateHTMLRenderer
from rest_framework.utils import json

from .app_settings import redoc_settings, swagger_settings
from .app_settings import redoc_settings as _redoc_settings, swagger_settings as _swagger_settings
from .codecs import VALIDATORS, OpenAPICodecJson, OpenAPICodecYaml
from .openapi import Swagger
from .utils import filter_none
Expand Down Expand Up @@ -60,6 +60,7 @@ class _UIRenderer(BaseRenderer):
media_type = 'text/html'
charset = 'utf-8'
template = ''
swagger_settings = _swagger_settings

def render(self, swagger, accepted_media_type=None, renderer_context=None):
if not isinstance(swagger, Swagger): # pragma: no cover
Expand All @@ -78,7 +79,7 @@ def set_context(self, renderer_context, swagger):
renderer_context['title'] = swagger.info.title
renderer_context['version'] = swagger.info.version
renderer_context['oauth2_config'] = json.dumps(self.get_oauth2_config())
renderer_context['USE_SESSION_AUTH'] = swagger_settings.USE_SESSION_AUTH
renderer_context['USE_SESSION_AUTH'] = self.swagger_settings.USE_SESSION_AUTH
renderer_context.update(self.get_auth_urls())

def resolve_url(self, to):
Expand All @@ -102,14 +103,14 @@ def resolve_url(self, to):

def get_auth_urls(self):
urls = {
'LOGIN_URL': self.resolve_url(swagger_settings.LOGIN_URL),
'LOGOUT_URL': self.resolve_url(swagger_settings.LOGOUT_URL),
'LOGIN_URL': self.resolve_url(self.swagger_settings.LOGIN_URL),
'LOGOUT_URL': self.resolve_url(self.swagger_settings.LOGOUT_URL),
}

return filter_none(urls)

def get_oauth2_config(self):
data = swagger_settings.OAUTH2_CONFIG
data = self.swagger_settings.OAUTH2_CONFIG
assert isinstance(data, dict), "OAUTH2_CONFIG must be a dict"
return data

Expand All @@ -125,24 +126,24 @@ def set_context(self, renderer_context, swagger):

def get_swagger_ui_settings(self):
data = {
'url': self.resolve_url(swagger_settings.SPEC_URL),
'operationsSorter': swagger_settings.OPERATIONS_SORTER,
'tagsSorter': swagger_settings.TAGS_SORTER,
'docExpansion': swagger_settings.DOC_EXPANSION,
'deepLinking': swagger_settings.DEEP_LINKING,
'showExtensions': swagger_settings.SHOW_EXTENSIONS,
'defaultModelRendering': swagger_settings.DEFAULT_MODEL_RENDERING,
'defaultModelExpandDepth': swagger_settings.DEFAULT_MODEL_DEPTH,
'defaultModelsExpandDepth': swagger_settings.DEFAULT_MODEL_DEPTH,
'showCommonExtensions': swagger_settings.SHOW_COMMON_EXTENSIONS,
'oauth2RedirectUrl': swagger_settings.OAUTH2_REDIRECT_URL,
'supportedSubmitMethods': swagger_settings.SUPPORTED_SUBMIT_METHODS,
'displayOperationId': swagger_settings.DISPLAY_OPERATION_ID,
'url': self.resolve_url(self.swagger_settings.SPEC_URL),
'operationsSorter': self.swagger_settings.OPERATIONS_SORTER,
'tagsSorter': self.swagger_settings.TAGS_SORTER,
'docExpansion': self.swagger_settings.DOC_EXPANSION,
'deepLinking': self.swagger_settings.DEEP_LINKING,
'showExtensions': self.swagger_settings.SHOW_EXTENSIONS,
'defaultModelRendering': self.swagger_settings.DEFAULT_MODEL_RENDERING,
'defaultModelExpandDepth': self.swagger_settings.DEFAULT_MODEL_DEPTH,
'defaultModelsExpandDepth': self.swagger_settings.DEFAULT_MODEL_DEPTH,
'showCommonExtensions': self.swagger_settings.SHOW_COMMON_EXTENSIONS,
'oauth2RedirectUrl': self.swagger_settings.OAUTH2_REDIRECT_URL,
'supportedSubmitMethods': self.swagger_settings.SUPPORTED_SUBMIT_METHODS,
'displayOperationId': self.swagger_settings.DISPLAY_OPERATION_ID,
}

data = filter_none(data)
if swagger_settings.VALIDATOR_URL != '':
data['validatorUrl'] = self.resolve_url(swagger_settings.VALIDATOR_URL)
if self.swagger_settings.VALIDATOR_URL != '':
data['validatorUrl'] = self.resolve_url(self.swagger_settings.VALIDATOR_URL)

return data

Expand All @@ -151,20 +152,21 @@ class ReDocRenderer(_UIRenderer):
"""Renders a ReDoc web interface for schema browisng."""
template = 'drf-yasg/redoc.html'
format = 'redoc'
redoc_settings = _redoc_settings

def set_context(self, renderer_context, swagger):
super(ReDocRenderer, self).set_context(renderer_context, swagger)
renderer_context['redoc_settings'] = json.dumps(self.get_redoc_settings())

def get_redoc_settings(self):
data = {
'url': self.resolve_url(redoc_settings.SPEC_URL),
'lazyRendering': redoc_settings.LAZY_RENDERING,
'hideHostname': redoc_settings.HIDE_HOSTNAME,
'expandResponses': redoc_settings.EXPAND_RESPONSES,
'pathInMiddlePanel': redoc_settings.PATH_IN_MIDDLE,
'nativeScrollbars': redoc_settings.NATIVE_SCROLLBARS,
'requiredPropsFirst': redoc_settings.REQUIRED_PROPS_FIRST,
'url': self.resolve_url(self.redoc_settings.SPEC_URL),
'lazyRendering': self.redoc_settings.LAZY_RENDERING,
'hideHostname': self.redoc_settings.HIDE_HOSTNAME,
'expandResponses': self.redoc_settings.EXPAND_RESPONSES,
'pathInMiddlePanel': self.redoc_settings.PATH_IN_MIDDLE,
'nativeScrollbars': self.redoc_settings.NATIVE_SCROLLBARS,
'requiredPropsFirst': self.redoc_settings.REQUIRED_PROPS_FIRST,
}

return filter_none(data)
Expand Down
4 changes: 2 additions & 2 deletions src/drf_yasg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from rest_framework.utils import encoders, json
from rest_framework.views import APIView

from .app_settings import swagger_settings
from .app_settings import swagger_settings as _swagger_settings

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -340,7 +340,7 @@ def get_consumes(parser_classes):
return media_types


def get_produces(renderer_classes):
def get_produces(renderer_classes, swagger_settings=_swagger_settings):
"""Extract ``produces`` MIME types from a list of renderer classes.
:param list renderer_classes: renderer classes
Expand Down
32 changes: 28 additions & 4 deletions src/drf_yasg/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from rest_framework.settings import api_settings
from rest_framework.views import APIView

from .app_settings import swagger_settings
from .app_settings import redoc_settings as _redoc_settings, swagger_settings as _swagger_settings
from .renderers import (
OpenAPIRenderer, ReDocOldRenderer, ReDocRenderer, SwaggerJSONRenderer, SwaggerUIRenderer, SwaggerYAMLRenderer
)
Expand Down Expand Up @@ -48,7 +48,8 @@ def callback(response):


def get_schema_view(info=None, url=None, patterns=None, urlconf=None, public=False, validators=None,
generator_class=None, authentication_classes=None, permission_classes=None):
generator_class=None, authentication_classes=None, permission_classes=None,
swagger_settings=_swagger_settings, redoc_settings=_redoc_settings):
"""Create a SchemaView class with default renderers and generators.
:param .Info info: information about the API; if omitted, defaults to :ref:`DEFAULT_INFO <default-swagger-settings>`
Expand Down Expand Up @@ -85,7 +86,8 @@ class SchemaView(APIView):
renderer_classes = _spec_renderers

def get(self, request, version='', format=None):
generator = self.generator_class(info, request.version or version or '', url, patterns, urlconf)
generator = self.generator_class(
info, request.version or version or '', url, patterns, urlconf, swagger_settings)
schema = generator.get_schema(request, self.public)
if schema is None:
raise exceptions.PermissionDenied() # pragma: no cover
Expand Down Expand Up @@ -145,8 +147,30 @@ def with_ui(cls, renderer='swagger', cache_timeout=0, cache_kwargs=None):
:return: a view instance
"""
assert renderer in UI_RENDERERS, "supported default renderers are " + ", ".join(UI_RENDERERS)
renderer_classes = UI_RENDERERS[renderer] + _spec_renderers

_local_swagger_settings = swagger_settings
_local_redoc_settings = redoc_settings
renderer_classes = []
for renderer_class in UI_RENDERERS[renderer]:
if issubclass(renderer_class, SwaggerUIRenderer):
if _local_swagger_settings is _swagger_settings:
renderer_classes.append(renderer_class)
else:
class CustomSettingsSwaggerRenderer(renderer_class):
swagger_settings = _local_swagger_settings

renderer_classes.append(CustomSettingsSwaggerRenderer)

elif issubclass(renderer_class, ReDocRenderer):
if _local_redoc_settings is _redoc_settings:
renderer_classes.append(renderer_class)
else:
class CustomSettingsRedDocRenderer(renderer_class):
redoc_settings = _local_redoc_settings

renderer_classes.append(CustomSettingsRedDocRenderer)

renderer_classes.extend(_spec_renderers)
return cls.as_cached_view(cache_timeout, cache_kwargs, renderer_classes=renderer_classes)

return SchemaView

0 comments on commit ff5d574

Please sign in to comment.