123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520 |
- import copy
- import logging
- import re
- import urllib.parse as urlparse
- from collections import OrderedDict, defaultdict
- import uritemplate
- from django.urls import URLPattern, URLResolver
- from rest_framework import versioning
- from rest_framework.schemas import SchemaGenerator
- from rest_framework.schemas.generators import EndpointEnumerator as _EndpointEnumerator
- from rest_framework.schemas.generators import endpoint_ordering, get_pk_name
- from rest_framework.schemas.utils import get_pk_description
- from rest_framework.settings import api_settings
- from . import openapi
- from .app_settings import swagger_settings
- from .errors import SwaggerGenerationError
- from .inspectors.field import get_basic_type_info, get_queryset_field, get_queryset_from_view
- from .openapi import ReferenceResolver, SwaggerDict
- from .utils import force_real_str, get_consumes, get_produces
- logger = logging.getLogger(__name__)
- PATH_PARAMETER_RE = re.compile(r'{(?P<parameter>\w+)}')
- class EndpointEnumerator(_EndpointEnumerator):
- def __init__(self, patterns=None, urlconf=None, request=None):
- super(EndpointEnumerator, self).__init__(patterns, urlconf)
- self.request = request
- def get_path_from_regex(self, path_regex):
- if path_regex.endswith(')'):
- logger.warning("url pattern does not end in $ ('%s') - unexpected things might happen", path_regex)
- return self.unescape_path(super(EndpointEnumerator, self).get_path_from_regex(path_regex))
- def should_include_endpoint(self, path, callback, app_name='', namespace='', url_name=None):
- if not super(EndpointEnumerator, self).should_include_endpoint(path, callback):
- return False
- version = getattr(self.request, 'version', None)
- versioning_class = getattr(callback.cls, 'versioning_class', None)
- if versioning_class is not None and issubclass(versioning_class, versioning.NamespaceVersioning):
- if version and version not in namespace.split(':'):
- return False
- if getattr(callback.cls, 'swagger_schema', object()) is None:
- return False
- return True
- def replace_version(self, path, callback):
- """If ``request.version`` is not ``None`` and `callback` uses ``URLPathVersioning``, this function replaces
- the ``version`` parameter in `path` with the actual version.
- :param str path: the templated path
- :param callback: the view callback
- :rtype: str
- """
- versioning_class = getattr(callback.cls, 'versioning_class', None)
- if versioning_class is not None and issubclass(versioning_class, versioning.URLPathVersioning):
- version = getattr(self.request, 'version', None)
- if version:
- version_param = getattr(versioning_class, 'version_param', 'version')
- version_param = '{%s}' % version_param
- if version_param not in path:
- logger.info("view %s uses URLPathVersioning but URL %s has no param %s"
- % (callback.cls, path, version_param))
- path = path.replace(version_param, version)
- return path
- def get_api_endpoints(self, patterns=None, prefix='', app_name=None, namespace=None, ignored_endpoints=None):
- """
- Return a list of all available API endpoints by inspecting the URL conf.
- Copied entirely from super.
- """
- if patterns is None:
- patterns = self.patterns
- api_endpoints = []
- if ignored_endpoints is None:
- ignored_endpoints = set()
- for pattern in patterns:
- path_regex = prefix + str(pattern.pattern)
- if isinstance(pattern, URLPattern):
- try:
- path = self.get_path_from_regex(path_regex)
- callback = pattern.callback
- url_name = pattern.name
- if self.should_include_endpoint(path, callback, app_name or '', namespace or '', url_name):
- path = self.replace_version(path, callback)
- # avoid adding endpoints that have already been seen,
- # as Django resolves urls in top-down order
- if path in ignored_endpoints:
- continue
- ignored_endpoints.add(path)
- for method in self.get_allowed_methods(callback):
- endpoint = (path, method, callback)
- api_endpoints.append(endpoint)
- except Exception: # pragma: no cover
- logger.warning('failed to enumerate view', exc_info=True)
- elif isinstance(pattern, URLResolver):
- nested_endpoints = self.get_api_endpoints(
- patterns=pattern.url_patterns,
- prefix=path_regex,
- app_name="%s:%s" % (app_name, pattern.app_name) if app_name else pattern.app_name,
- namespace="%s:%s" % (namespace, pattern.namespace) if namespace else pattern.namespace,
- ignored_endpoints=ignored_endpoints
- )
- api_endpoints.extend(nested_endpoints)
- else:
- logger.warning("unknown pattern type {}".format(type(pattern)))
- api_endpoints = sorted(api_endpoints, key=endpoint_ordering)
- return api_endpoints
- def unescape(self, s):
- """Unescape all backslash escapes from `s`.
- :param str s: string with backslash escapes
- :rtype: str
- """
- # unlike .replace('\\', ''), this correctly transforms a double backslash into a single backslash
- return re.sub(r'\\(.)', r'\1', s)
- def unescape_path(self, path):
- """Remove backslashes escapes from all path components outside {parameters}. This is needed because
- ``simplify_regex`` does not handle this correctly.
- **NOTE:** this might destructively affect some url regex patterns that contain metacharacters (e.g. \\w, \\d)
- outside path parameter groups; if you are in this category, God help you
- :param str path: path possibly containing
- :return: the unescaped path
- :rtype: str
- """
- clean_path = ''
- while path:
- match = PATH_PARAMETER_RE.search(path)
- if not match:
- clean_path += self.unescape(path)
- break
- clean_path += self.unescape(path[:match.start()])
- clean_path += match.group()
- path = path[match.end():]
- return clean_path
- class OpenAPISchemaGenerator(object):
- """
- This class iterates over all registered API endpoints and returns an appropriate OpenAPI 2.0 compliant schema.
- Method implementations shamelessly stolen and adapted from rest-framework ``SchemaGenerator``.
- """
- endpoint_enumerator_class = EndpointEnumerator
- reference_resolver_class = ReferenceResolver
- def __init__(self, info, version='', url=None, patterns=None, urlconf=None):
- """
- :param openapi.Info info: information about the API
- :param str version: API version string; if omitted, `info.default_version` will be used
- :param str url: API scheme, host and port; if ``None`` is passed and ``DEFAULT_API_URL`` is not set, the url
- will be inferred from the request made against the schema view, so you should generally not need to set
- this parameter explicitly; if the empty string is passed, no host and scheme will be emitted
- If `url` is not ``None`` or the empty string, it must be a scheme-absolute uri (i.e. starting with http://
- or https://), and any path component is ignored;
- See also: :ref:`documentation on base URL construction <custom-spec-base-url>`
- :param patterns: if given, only these patterns will be enumerated for inclusion in the API spec
- :param urlconf: if patterns is not given, use this urlconf to enumerate patterns;
- if not given, the default urlconf is used
- """
- self._gen = SchemaGenerator(info.title, url, info.get('description', ''), patterns, urlconf)
- self.info = info
- self.version = version
- self.consumes = []
- self.produces = []
- if url is None and swagger_settings.DEFAULT_API_URL is not None:
- url = swagger_settings.DEFAULT_API_URL
- if url:
- parsed_url = urlparse.urlparse(url)
- if parsed_url.scheme not in ('http', 'https') or not parsed_url.netloc:
- raise SwaggerGenerationError("`url` must be an absolute HTTP(S) url")
- if parsed_url.path:
- logger.warning("path component of api base URL %s is ignored; use FORCE_SCRIPT_NAME instead" % url)
- else:
- self._gen.url = url
- @property
- def url(self):
- return self._gen.url
- def get_security_definitions(self):
- """Get the security schemes for this API. This determines what is usable in security requirements,
- and helps clients configure their authorization credentials.
- :return: the security schemes usable with this API
- :rtype: dict[str,dict] or None
- """
- security_definitions = swagger_settings.SECURITY_DEFINITIONS
- if security_definitions is not None:
- security_definitions = SwaggerDict._as_odict(security_definitions, {})
- return security_definitions
- def get_security_requirements(self, security_definitions):
- """Get the base (global) security requirements of the API. This is never called if
- :meth:`.get_security_definitions` returns `None`.
- :param security_definitions: security definitions as returned by :meth:`.get_security_definitions`
- :return: the security schemes accepted by default
- :rtype: list[dict[str,list[str]]] or None
- """
- security_requirements = swagger_settings.SECURITY_REQUIREMENTS
- if security_requirements is None:
- security_requirements = [{security_scheme: []} for security_scheme in security_definitions]
- security_requirements = [SwaggerDict._as_odict(sr, {}) for sr in security_requirements]
- security_requirements = sorted(security_requirements, key=list)
- return security_requirements
- def get_schema(self, request=None, public=False):
- """Generate a :class:`.Swagger` object representing the API schema.
- :param request: the request used for filtering accessible endpoints and finding the spec URI
- :type request: rest_framework.request.Request or None
- :param bool public: if True, all endpoints are included regardless of access through `request`
- :return: the generated Swagger specification
- :rtype: openapi.Swagger
- """
- endpoints = self.get_endpoints(request)
- components = self.reference_resolver_class(openapi.SCHEMA_DEFINITIONS, force_init=True)
- self.consumes = get_consumes(api_settings.DEFAULT_PARSER_CLASSES)
- self.produces = get_produces(api_settings.DEFAULT_RENDERER_CLASSES)
- paths, prefix = self.get_paths(endpoints, components, request, public)
- security_definitions = self.get_security_definitions()
- if security_definitions:
- security_requirements = self.get_security_requirements(security_definitions)
- else:
- security_requirements = None
- url = self.url
- if url is None and request is not None:
- url = request.build_absolute_uri()
- return openapi.Swagger(
- info=self.info, paths=paths, consumes=self.consumes or None, produces=self.produces or None,
- security_definitions=security_definitions, security=security_requirements,
- _url=url, _prefix=prefix, _version=self.version, **dict(components)
- )
- def create_view(self, callback, method, request=None):
- """Create a view instance from a view callback as registered in urlpatterns.
- :param callback: view callback registered in urlpatterns
- :param str method: HTTP method
- :param request: request to bind to the view
- :type request: rest_framework.request.Request or None
- :return: the view instance
- """
- view = self._gen.create_view(callback, method, request)
- overrides = getattr(callback, '_swagger_auto_schema', None)
- if overrides is not None:
- # decorated function based view must have its decorator information passed on to the re-instantiated view
- for method, _ in overrides.items():
- view_method = getattr(view, method, None)
- if view_method is not None: # pragma: no cover
- setattr(view_method.__func__, '_swagger_auto_schema', overrides)
- setattr(view, 'swagger_fake_view', True)
- return view
- def coerce_path(self, path, view):
- """Coerce {pk} path arguments into the name of the model field, where possible. This is cleaner for an
- external representation (i.e. "this is an identifier", not "this is a database primary key").
- :param str path: the path
- :param rest_framework.views.APIView view: associated view
- :rtype: str
- """
- if '{pk}' not in path:
- return path
- model = getattr(get_queryset_from_view(view), 'model', None)
- if model:
- field_name = get_pk_name(model)
- else:
- field_name = 'id'
- return path.replace('{pk}', '{%s}' % field_name)
- def get_endpoints(self, request):
- """Iterate over all the registered endpoints in the API and return a fake view with the right parameters.
- :param request: request to bind to the endpoint views
- :type request: rest_framework.request.Request or None
- :return: {path: (view_class, list[(http_method, view_instance)])
- :rtype: dict[str,(type,list[(str,rest_framework.views.APIView)])]
- """
- enumerator = self.endpoint_enumerator_class(self._gen.patterns, self._gen.urlconf, request=request)
- endpoints = enumerator.get_api_endpoints()
- view_paths = defaultdict(list)
- view_cls = {}
- for path, method, callback in endpoints:
- view = self.create_view(callback, method, request)
- path = self.coerce_path(path, view)
- view_paths[path].append((method, view))
- view_cls[path] = callback.cls
- return {path: (view_cls[path], methods) for path, methods in view_paths.items()}
- def get_operation_keys(self, subpath, method, view):
- """Return a list of keys that should be used to group an operation within the specification. ::
- /users/ ("users", "list"), ("users", "create")
- /users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete")
- /users/enabled/ ("users", "enabled") # custom viewset list action
- /users/{pk}/star/ ("users", "star") # custom viewset detail action
- /users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create")
- /users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update")
- :param str subpath: path to the operation with any common prefix/base path removed
- :param str method: HTTP method
- :param view: the view associated with the operation
- :rtype: list[str]
- """
- return self._gen.get_keys(subpath, method, view)
- def determine_path_prefix(self, paths):
- """
- Given a list of all paths, return the common prefix which should be
- discounted when generating a schema structure.
- This will be the longest common string that does not include that last
- component of the URL, or the last component before a path parameter.
- For example: ::
- /api/v1/users/
- /api/v1/users/{pk}/
- The path prefix is ``/api/v1/``.
- :param list[str] paths: list of paths
- :rtype: str
- """
- return self._gen.determine_path_prefix(paths)
- def should_include_endpoint(self, path, method, view, public):
- """Check if a given endpoint should be included in the resulting schema.
- :param str path: request path
- :param str method: http request method
- :param view: instantiated view callback
- :param bool public: if True, all endpoints are included regardless of access through `request`
- :returns: true if the view should be excluded
- :rtype: bool
- """
- return public or self._gen.has_view_permissions(path, method, view)
- def get_paths_object(self, paths):
- """Construct the Swagger Paths object.
- :param OrderedDict[str,openapi.PathItem] paths: mapping of paths to :class:`.PathItem` objects
- :returns: the :class:`.Paths` object
- :rtype: openapi.Paths
- """
- return openapi.Paths(paths=paths)
- def get_paths(self, endpoints, components, request, public):
- """Generate the Swagger Paths for the API from the given endpoints.
- :param dict endpoints: endpoints as returned by get_endpoints
- :param ReferenceResolver components: resolver/container for Swagger References
- :param Request request: the request made against the schema view; can be None
- :param bool public: if True, all endpoints are included regardless of access through `request`
- :returns: the :class:`.Paths` object and the longest common path prefix, as a 2-tuple
- :rtype: tuple[openapi.Paths,str]
- """
- if not endpoints:
- return openapi.Paths(paths={}), ''
- prefix = self.determine_path_prefix(list(endpoints.keys())) or ''
- assert '{' not in prefix, "base path cannot be templated in swagger 2.0"
- paths = OrderedDict()
- for path, (view_cls, methods) in sorted(endpoints.items()):
- operations = {}
- for method, view in methods:
- if not self.should_include_endpoint(path, method, view, public):
- continue
- operation = self.get_operation(view, path, prefix, method, components, request)
- if operation is not None:
- operations[method.lower()] = operation
- if operations:
- # since the common prefix is used as the API basePath, it must be stripped
- # from individual paths when writing them into the swagger document
- path_suffix = path[len(prefix):]
- if not path_suffix.startswith('/'):
- path_suffix = '/' + path_suffix
- paths[path_suffix] = self.get_path_item(path, view_cls, operations)
- return self.get_paths_object(paths), prefix
- def get_operation(self, view, path, prefix, method, components, request):
- """Get an :class:`.Operation` for the given API endpoint (path, method). This method delegates to
- :meth:`~.inspectors.ViewInspector.get_operation` of a :class:`~.inspectors.ViewInspector` determined
- according to settings and :func:`@swagger_auto_schema <.swagger_auto_schema>` overrides.
- :param view: the view associated with this endpoint
- :param str path: the path component of the operation URL
- :param str prefix: common path prefix among all endpoints
- :param str method: the http method of the operation
- :param openapi.ReferenceResolver components: referenceable components
- :param Request request: the request made against the schema view; can be None
- :rtype: openapi.Operation
- """
- operation_keys = self.get_operation_keys(path[len(prefix):], method, view)
- overrides = self.get_overrides(view, method)
- # the inspector class can be specified, in decreasing order of priority,
- # 1. globally via DEFAULT_AUTO_SCHEMA_CLASS
- view_inspector_cls = swagger_settings.DEFAULT_AUTO_SCHEMA_CLASS
- # 2. on the view/viewset class
- view_inspector_cls = getattr(view, 'swagger_schema', view_inspector_cls)
- # 3. on the swagger_auto_schema decorator
- view_inspector_cls = overrides.get('auto_schema', view_inspector_cls)
- if view_inspector_cls is None:
- return None
- view_inspector = view_inspector_cls(view, path, method, components, request, overrides, operation_keys)
- operation = view_inspector.get_operation(operation_keys)
- if operation is None:
- return None
- if 'consumes' in operation and set(operation.consumes) == set(self.consumes):
- del operation.consumes
- if 'produces' in operation and set(operation.produces) == set(self.produces):
- del operation.produces
- return operation
- def get_path_item(self, path, view_cls, operations):
- """Get a :class:`.PathItem` object that describes the parameters and operations related to a single path in the
- API.
- :param str path: the path
- :param type view_cls: the view that was bound to this path in urlpatterns
- :param dict[str,openapi.Operation] operations: operations defined on this path, keyed by lowercase HTTP method
- :rtype: openapi.PathItem
- """
- path_parameters = self.get_path_parameters(path, view_cls)
- return openapi.PathItem(parameters=path_parameters, **operations)
- def get_overrides(self, view, method):
- """Get overrides specified for a given operation.
- :param view: the view associated with the operation
- :param str method: HTTP method
- :return: a dictionary containing any overrides set by :func:`@swagger_auto_schema <.swagger_auto_schema>`
- :rtype: dict
- """
- method = method.lower()
- action = getattr(view, 'action', method)
- action_method = getattr(view, action, None)
- overrides = getattr(action_method, '_swagger_auto_schema', {})
- if method in overrides:
- overrides = overrides[method]
- return copy.deepcopy(overrides)
- def get_path_parameters(self, path, view_cls):
- """Return a list of Parameter instances corresponding to any templated path variables.
- :param str path: templated request path
- :param type view_cls: the view class associated with the path
- :return: path parameters
- :rtype: list[openapi.Parameter]
- """
- parameters = []
- queryset = get_queryset_from_view(view_cls)
- for variable in sorted(uritemplate.variables(path)):
- model, model_field = get_queryset_field(queryset, variable)
- attrs = get_basic_type_info(model_field) or {'type': openapi.TYPE_STRING}
- if getattr(view_cls, 'lookup_field', None) == variable and attrs['type'] == openapi.TYPE_STRING:
- attrs['pattern'] = getattr(view_cls, 'lookup_value_regex', attrs.get('pattern', None))
- if model_field and getattr(model_field, 'help_text', False):
- description = model_field.help_text
- elif model_field and getattr(model_field, 'primary_key', False):
- description = get_pk_description(model, model_field)
- else:
- description = None
- field = openapi.Parameter(
- name=variable,
- description=force_real_str(description),
- required=True,
- in_=openapi.IN_PATH,
- **attrs
- )
- parameters.append(field)
- return parameters
|