Skip to content

Commit

Permalink
allow swagger and redoc settings to be overridden
Browse files Browse the repository at this point in the history
When using site middleware there is no way to configure swagger
instances to have different login URLs or authentication mechanisms.
This change updates any global references to app settings to prefer a
local setting if it exists, and for backward compatibility it refers to
the global setting by default.

The following calls will need to be adjusted to account for the new
swagger_settings parameter that is passed:

- ``drf_yasg.utils.get_produces(...)``
- ViewInspector class (and user defined subclasses)
- Generator class (and user defined subclasses)
  • Loading branch information
terencehonles committed May 11, 2019
1 parent 65a9b35 commit f8471c1
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 63 deletions.
16 changes: 9 additions & 7 deletions src/drf_yasg/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from rest_framework.settings import api_settings

from . import openapi
from .app_settings import swagger_settings
from .app_settings import swagger_settings as _swagger_settings
from .errors import SwaggerGenerationError
from .inspectors.field import get_basic_type_info, get_queryset_field, get_queryset_from_view
from .openapi import ReferenceResolver, SwaggerDict
Expand Down Expand Up @@ -161,7 +161,7 @@ class OpenAPISchemaGenerator(object):
"""
endpoint_enumerator_class = EndpointEnumerator

def __init__(self, info, version='', url=None, patterns=None, urlconf=None):
def __init__(self, info, version='', url=None, patterns=None, urlconf=None, swagger_settings=_swagger_settings):
"""
:param openapi.Info info: information about the API
Expand All @@ -177,9 +177,11 @@ def __init__(self, info, version='', url=None, patterns=None, urlconf=None):
:param patterns: if given, only these patterns will be enumerated for inclusion in the API spec
:param urlconf: if patterns is not given, use this urlconf to enumerate patterns;
if not given, the default urlconf is used
:param swagger_settings: if given global swagger_settings are overridden with local settings
"""
self._gen = SchemaGenerator(info.title, url, info.get('description', ''), patterns, urlconf)
self.info = info
self.swagger_settings = swagger_settings
self.version = version
self.consumes = []
self.produces = []
Expand All @@ -205,7 +207,7 @@ def get_security_definitions(self):
:return: the security schemes usable with this API
:rtype: dict[str,dict] or None
"""
security_definitions = swagger_settings.SECURITY_DEFINITIONS
security_definitions = self.swagger_settings.SECURITY_DEFINITIONS
if security_definitions is not None:
security_definitions = SwaggerDict._as_odict(security_definitions, {})

Expand All @@ -219,7 +221,7 @@ def get_security_requirements(self, security_definitions):
:return: the security schemes accepted by default
:rtype: list[dict[str,list[str]]] or None
"""
security_requirements = swagger_settings.SECURITY_REQUIREMENTS
security_requirements = self.swagger_settings.SECURITY_REQUIREMENTS
if security_requirements is None:
security_requirements = [{security_scheme: []} for security_scheme in security_definitions]

Expand All @@ -240,7 +242,7 @@ def get_schema(self, request=None, public=False):
endpoints = self.get_endpoints(request)
components = ReferenceResolver(openapi.SCHEMA_DEFINITIONS, force_init=True)
self.consumes = get_consumes(api_settings.DEFAULT_PARSER_CLASSES)
self.produces = get_produces(api_settings.DEFAULT_RENDERER_CLASSES)
self.produces = get_produces(api_settings.DEFAULT_RENDERER_CLASSES, self.swagger_settings)
paths, prefix = self.get_paths(endpoints, components, request, public)

security_definitions = self.get_security_definitions()
Expand Down Expand Up @@ -431,7 +433,7 @@ def get_operation(self, view, path, prefix, method, components, request):

# the inspector class can be specified, in decreasing order of priorty,
# 1. globaly via DEFAULT_AUTO_SCHEMA_CLASS
view_inspector_cls = swagger_settings.DEFAULT_AUTO_SCHEMA_CLASS
view_inspector_cls = self.swagger_settings.DEFAULT_AUTO_SCHEMA_CLASS
# 2. on the view/viewset class
view_inspector_cls = getattr(view, 'swagger_schema', view_inspector_cls)
# 3. on the swagger_auto_schema decorator
Expand All @@ -440,7 +442,7 @@ def get_operation(self, view, path, prefix, method, components, request):
if view_inspector_cls is None:
return None

view_inspector = view_inspector_cls(view, path, method, components, request, overrides)
view_inspector = view_inspector_cls(view, path, method, components, request, overrides, self.swagger_settings)
operation = view_inspector.get_operation(operation_keys)
if operation is None:
return None
Expand Down
6 changes: 0 additions & 6 deletions src/drf_yasg/inspectors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from ..app_settings import swagger_settings
from .base import (
BaseInspector, FieldInspector, FilterInspector, NotHandled, PaginatorInspector, SerializerInspector, ViewInspector
)
Expand All @@ -10,11 +9,6 @@
from .query import CoreAPICompatInspector, DjangoRestResponsePagination
from .view import SwaggerAutoSchema

# these settings must be accessed only after defining/importing all the classes in this module to avoid ImportErrors
ViewInspector.field_inspectors = swagger_settings.DEFAULT_FIELD_INSPECTORS
ViewInspector.filter_inspectors = swagger_settings.DEFAULT_FILTER_INSPECTORS
ViewInspector.paginator_inspectors = swagger_settings.DEFAULT_PAGINATOR_INSPECTORS

__all__ = [
# base inspectors
'BaseInspector', 'FilterInspector', 'PaginatorInspector', 'FieldInspector', 'SerializerInspector', 'ViewInspector',
Expand Down
45 changes: 36 additions & 9 deletions src/drf_yasg/inspectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from rest_framework import serializers

from .. import openapi
from ..app_settings import swagger_settings as _swagger_settings
from ..utils import force_real_str, get_field_default, get_object_classes, is_list_view

#: Sentinel value that inspectors must return to signal that they do not know how to handle an object
Expand Down Expand Up @@ -60,19 +61,21 @@ def call_view_method(view, method_name, fallback_attr=None, default=None):


class BaseInspector(object):
def __init__(self, view, path, method, components, request):
def __init__(self, view, path, method, components, request, swagger_settings=_swagger_settings):
"""
:param rest_framework.views.APIView view: the view associated with this endpoint
:param str path: the path component of the operation URL
:param str method: the http method of the operation
:param openapi.ReferenceResolver components: referenceable components
:param rest_framework.request.Request request: the request made against the schema view; can be None
:param swagger_settings: if given global swagger_settings are overridden with local settings
"""
self.view = view
self.path = path
self.method = method
self.components = components
self.request = request
self.swagger_settings = swagger_settings

def process_result(self, result, method_name, obj, **kwargs):
"""After an inspector handles an object (i.e. returns a value other than :data:`.NotHandled`), all inspectors
Expand Down Expand Up @@ -193,8 +196,8 @@ def get_filter_parameters(self, filter_backend):
class FieldInspector(BaseInspector):
"""Base inspector for serializers and serializer fields. """

def __init__(self, view, path, method, components, request, field_inspectors):
super(FieldInspector, self).__init__(view, path, method, components, request)
def __init__(self, view, path, method, components, request, field_inspectors, swagger_settings=_swagger_settings):
super(FieldInspector, self).__init__(view, path, method, components, request, swagger_settings)
self.field_inspectors = field_inspectors

def add_manual_fields(self, serializer_or_field, schema):
Expand Down Expand Up @@ -345,18 +348,42 @@ class ViewInspector(BaseInspector):
#: methods which are assumed to return a list of objects when present on non-detail endpoints
implicit_list_response_methods = ('GET',)

# real values set in __init__ to prevent import errors
field_inspectors = [] #:
filter_inspectors = [] #:
paginator_inspectors = [] #:
_field_inspectors = None
_filter_inspectors = None
_paginator_inspectors = None

def __init__(self, view, path, method, components, request, overrides):
@property
def field_inspectors(self):
return self._field_inspectors or self.swagger_settings.DEFAULT_FIELD_INSPECTORS

@field_inspectors.setter
def field_inspectors(self, value):
self._field_inspectors = value

@property
def filter_inspectors(self):
return self._filter_inspectors or self.swagger_settings.DEFAULT_FILTER_INSPECTORS

@filter_inspectors.setter
def filter_inspectors(self, value):
self._filter_inspectors = value

@property
def paginator_inspectors(self):
return self._paginator_inspectors or self.swagger_settings.DEFAULT_PAGINATOR_INSPECTORS

@paginator_inspectors.setter
def paginator_inspectors(self, value):
self._paginator_inspectors = value

def __init__(self, view, path, method, components, request, overrides, swagger_settings=_swagger_settings):
"""
Inspector class responsible for providing :class:`.Operation` definitions given a view, path and method.
:param dict overrides: manual overrides as passed to :func:`@swagger_auto_schema <.swagger_auto_schema>`
:param swagger_settings: if given global swagger_settings are overridden with local settings
"""
super(ViewInspector, self).__init__(view, path, method, components, request)
super(ViewInspector, self).__init__(view, path, method, components, request, swagger_settings)
self.overrides = overrides
self._prepend_inspector_overrides('field_inspectors')
self._prepend_inspector_overrides('filter_inspectors')
Expand Down
5 changes: 3 additions & 2 deletions src/drf_yasg/inspectors/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from rest_framework.status import is_success

from .. import openapi
from ..app_settings import swagger_settings as _swagger_settings
from ..errors import SwaggerGenerationError
from ..utils import (
filter_none, force_real_str, force_serializer_instance, get_consumes, get_produces, guess_response_status,
Expand All @@ -17,8 +18,8 @@


class SwaggerAutoSchema(ViewInspector):
def __init__(self, view, path, method, components, request, overrides):
super(SwaggerAutoSchema, self).__init__(view, path, method, components, request, overrides)
def __init__(self, view, path, method, components, request, overrides, swagger_settings=_swagger_settings):
super(SwaggerAutoSchema, self).__init__(view, path, method, components, request, overrides, swagger_settings)
self._sch = AutoSchema()
self._sch.view = view

Expand Down
66 changes: 34 additions & 32 deletions src/drf_yasg/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from rest_framework.renderers import BaseRenderer, JSONRenderer, TemplateHTMLRenderer
from rest_framework.utils import encoders, json

from .app_settings import redoc_settings, swagger_settings
from .app_settings import redoc_settings as _redoc_settings, swagger_settings as _swagger_settings
from .codecs import VALIDATORS, OpenAPICodecJson, OpenAPICodecYaml
from .openapi import Swagger
from .utils import filter_none
Expand Down Expand Up @@ -63,6 +63,7 @@ class _UIRenderer(BaseRenderer):
media_type = 'text/html'
charset = 'utf-8'
template = ''
swagger_settings = _swagger_settings

def render(self, swagger, accepted_media_type=None, renderer_context=None):
if not isinstance(swagger, Swagger): # pragma: no cover
Expand All @@ -78,7 +79,7 @@ def set_context(self, renderer_context, swagger=None):
renderer_context['title'] = swagger.info.title or '' if swagger else ''
renderer_context['version'] = swagger.info.version or '' if swagger else ''
renderer_context['oauth2_config'] = json.dumps(self.get_oauth2_config(), cls=encoders.JSONEncoder)
renderer_context['USE_SESSION_AUTH'] = swagger_settings.USE_SESSION_AUTH
renderer_context['USE_SESSION_AUTH'] = self.swagger_settings.USE_SESSION_AUTH
renderer_context.update(self.get_auth_urls())

def resolve_url(self, to):
Expand All @@ -102,14 +103,14 @@ def resolve_url(self, to):

def get_auth_urls(self):
urls = {
'LOGIN_URL': self.resolve_url(swagger_settings.LOGIN_URL),
'LOGOUT_URL': self.resolve_url(swagger_settings.LOGOUT_URL),
'LOGIN_URL': self.resolve_url(self.swagger_settings.LOGIN_URL),
'LOGOUT_URL': self.resolve_url(self.swagger_settings.LOGOUT_URL),
}

return filter_none(urls)

def get_oauth2_config(self):
data = swagger_settings.OAUTH2_CONFIG
data = self.swagger_settings.OAUTH2_CONFIG
assert isinstance(data, dict), "OAUTH2_CONFIG must be a dict"
return data

Expand All @@ -132,28 +133,28 @@ def set_context(self, renderer_context, swagger=None):

def get_swagger_ui_settings(self):
data = {
'url': self.resolve_url(swagger_settings.SPEC_URL),
'operationsSorter': swagger_settings.OPERATIONS_SORTER,
'tagsSorter': swagger_settings.TAGS_SORTER,
'docExpansion': swagger_settings.DOC_EXPANSION,
'deepLinking': swagger_settings.DEEP_LINKING,
'showExtensions': swagger_settings.SHOW_EXTENSIONS,
'defaultModelRendering': swagger_settings.DEFAULT_MODEL_RENDERING,
'defaultModelExpandDepth': swagger_settings.DEFAULT_MODEL_DEPTH,
'defaultModelsExpandDepth': swagger_settings.DEFAULT_MODEL_DEPTH,
'showCommonExtensions': swagger_settings.SHOW_COMMON_EXTENSIONS,
'oauth2RedirectUrl': swagger_settings.OAUTH2_REDIRECT_URL,
'supportedSubmitMethods': swagger_settings.SUPPORTED_SUBMIT_METHODS,
'displayOperationId': swagger_settings.DISPLAY_OPERATION_ID,
'persistAuth': swagger_settings.PERSIST_AUTH,
'refetchWithAuth': swagger_settings.REFETCH_SCHEMA_WITH_AUTH,
'refetchOnLogout': swagger_settings.REFETCH_SCHEMA_ON_LOGOUT,
'fetchSchemaWithQuery': swagger_settings.FETCH_SCHEMA_WITH_QUERY,
'url': self.resolve_url(self.swagger_settings.SPEC_URL),
'operationsSorter': self.swagger_settings.OPERATIONS_SORTER,
'tagsSorter': self.swagger_settings.TAGS_SORTER,
'docExpansion': self.swagger_settings.DOC_EXPANSION,
'deepLinking': self.swagger_settings.DEEP_LINKING,
'showExtensions': self.swagger_settings.SHOW_EXTENSIONS,
'defaultModelRendering': self.swagger_settings.DEFAULT_MODEL_RENDERING,
'defaultModelExpandDepth': self.swagger_settings.DEFAULT_MODEL_DEPTH,
'defaultModelsExpandDepth': self.swagger_settings.DEFAULT_MODEL_DEPTH,
'showCommonExtensions': self.swagger_settings.SHOW_COMMON_EXTENSIONS,
'oauth2RedirectUrl': self.swagger_settings.OAUTH2_REDIRECT_URL,
'supportedSubmitMethods': self.swagger_settings.SUPPORTED_SUBMIT_METHODS,
'displayOperationId': self.swagger_settings.DISPLAY_OPERATION_ID,
'persistAuth': self.swagger_settings.PERSIST_AUTH,
'refetchWithAuth': self.swagger_settings.REFETCH_SCHEMA_WITH_AUTH,
'refetchOnLogout': self.swagger_settings.REFETCH_SCHEMA_ON_LOGOUT,
'fetchSchemaWithQuery': self.swagger_settings.FETCH_SCHEMA_WITH_QUERY,
}

data = filter_none(data)
if swagger_settings.VALIDATOR_URL != '':
data['validatorUrl'] = self.resolve_url(swagger_settings.VALIDATOR_URL)
if self.swagger_settings.VALIDATOR_URL != '':
data['validatorUrl'] = self.resolve_url(self.swagger_settings.VALIDATOR_URL)

return data

Expand All @@ -162,21 +163,22 @@ class ReDocRenderer(_UIRenderer):
"""Renders a ReDoc web interface for schema browisng."""
template = 'drf-yasg/redoc.html'
format = 'redoc'
redoc_settings = _redoc_settings

def set_context(self, renderer_context, swagger=None):
super(ReDocRenderer, self).set_context(renderer_context, swagger)
renderer_context['redoc_settings'] = json.dumps(self.get_redoc_settings(), cls=encoders.JSONEncoder)

def get_redoc_settings(self):
data = {
'url': self.resolve_url(redoc_settings.SPEC_URL),
'lazyRendering': redoc_settings.LAZY_RENDERING,
'hideHostname': redoc_settings.HIDE_HOSTNAME,
'expandResponses': redoc_settings.EXPAND_RESPONSES,
'pathInMiddlePanel': redoc_settings.PATH_IN_MIDDLE,
'nativeScrollbars': redoc_settings.NATIVE_SCROLLBARS,
'requiredPropsFirst': redoc_settings.REQUIRED_PROPS_FIRST,
'fetchSchemaWithQuery': redoc_settings.FETCH_SCHEMA_WITH_QUERY,
'url': self.resolve_url(self.redoc_settings.SPEC_URL),
'lazyRendering': self.redoc_settings.LAZY_RENDERING,
'hideHostname': self.redoc_settings.HIDE_HOSTNAME,
'expandResponses': self.redoc_settings.EXPAND_RESPONSES,
'pathInMiddlePanel': self.redoc_settings.PATH_IN_MIDDLE,
'nativeScrollbars': self.redoc_settings.NATIVE_SCROLLBARS,
'requiredPropsFirst': self.redoc_settings.REQUIRED_PROPS_FIRST,
'fetchSchemaWithQuery': self.redoc_settings.FETCH_SCHEMA_WITH_QUERY,
}

return filter_none(data)
Expand Down
4 changes: 2 additions & 2 deletions src/drf_yasg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from rest_framework.utils import encoders, json
from rest_framework.views import APIView

from .app_settings import swagger_settings
from .app_settings import swagger_settings as _swagger_settings

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -380,7 +380,7 @@ def get_consumes(parser_classes):
return non_form_media_types


def get_produces(renderer_classes):
def get_produces(renderer_classes, swagger_settings=_swagger_settings):
"""Extract ``produces`` MIME types from a list of renderer classes.
:param list renderer_classes: renderer classes
Expand Down
Loading

0 comments on commit f8471c1

Please sign in to comment.