1
0

generators.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
  1. import copy
  2. import logging
  3. import re
  4. import urllib.parse as urlparse
  5. from collections import OrderedDict, defaultdict
  6. import uritemplate
  7. from django.urls import URLPattern, URLResolver
  8. from rest_framework import versioning
  9. from rest_framework.schemas import SchemaGenerator
  10. from rest_framework.schemas.generators import EndpointEnumerator as _EndpointEnumerator
  11. from rest_framework.schemas.generators import endpoint_ordering, get_pk_name
  12. from rest_framework.schemas.utils import get_pk_description
  13. from rest_framework.settings import api_settings
  14. from . import openapi
  15. from .app_settings import swagger_settings
  16. from .errors import SwaggerGenerationError
  17. from .inspectors.field import get_basic_type_info, get_queryset_field, get_queryset_from_view
  18. from .openapi import ReferenceResolver, SwaggerDict
  19. from .utils import force_real_str, get_consumes, get_produces
  20. logger = logging.getLogger(__name__)
  21. PATH_PARAMETER_RE = re.compile(r'{(?P<parameter>\w+)}')
  22. class EndpointEnumerator(_EndpointEnumerator):
  23. def __init__(self, patterns=None, urlconf=None, request=None):
  24. super(EndpointEnumerator, self).__init__(patterns, urlconf)
  25. self.request = request
  26. def get_path_from_regex(self, path_regex):
  27. if path_regex.endswith(')'):
  28. logger.warning("url pattern does not end in $ ('%s') - unexpected things might happen", path_regex)
  29. return self.unescape_path(super(EndpointEnumerator, self).get_path_from_regex(path_regex))
  30. def should_include_endpoint(self, path, callback, app_name='', namespace='', url_name=None):
  31. if not super(EndpointEnumerator, self).should_include_endpoint(path, callback):
  32. return False
  33. version = getattr(self.request, 'version', None)
  34. versioning_class = getattr(callback.cls, 'versioning_class', None)
  35. if versioning_class is not None and issubclass(versioning_class, versioning.NamespaceVersioning):
  36. if version and version not in namespace.split(':'):
  37. return False
  38. if getattr(callback.cls, 'swagger_schema', object()) is None:
  39. return False
  40. return True
  41. def replace_version(self, path, callback):
  42. """If ``request.version`` is not ``None`` and `callback` uses ``URLPathVersioning``, this function replaces
  43. the ``version`` parameter in `path` with the actual version.
  44. :param str path: the templated path
  45. :param callback: the view callback
  46. :rtype: str
  47. """
  48. versioning_class = getattr(callback.cls, 'versioning_class', None)
  49. if versioning_class is not None and issubclass(versioning_class, versioning.URLPathVersioning):
  50. version = getattr(self.request, 'version', None)
  51. if version:
  52. version_param = getattr(versioning_class, 'version_param', 'version')
  53. version_param = '{%s}' % version_param
  54. if version_param not in path:
  55. logger.info("view %s uses URLPathVersioning but URL %s has no param %s"
  56. % (callback.cls, path, version_param))
  57. path = path.replace(version_param, version)
  58. return path
  59. def get_api_endpoints(self, patterns=None, prefix='', app_name=None, namespace=None, ignored_endpoints=None):
  60. """
  61. Return a list of all available API endpoints by inspecting the URL conf.
  62. Copied entirely from super.
  63. """
  64. if patterns is None:
  65. patterns = self.patterns
  66. api_endpoints = []
  67. if ignored_endpoints is None:
  68. ignored_endpoints = set()
  69. for pattern in patterns:
  70. path_regex = prefix + str(pattern.pattern)
  71. if isinstance(pattern, URLPattern):
  72. try:
  73. path = self.get_path_from_regex(path_regex)
  74. callback = pattern.callback
  75. url_name = pattern.name
  76. if self.should_include_endpoint(path, callback, app_name or '', namespace or '', url_name):
  77. path = self.replace_version(path, callback)
  78. # avoid adding endpoints that have already been seen,
  79. # as Django resolves urls in top-down order
  80. if path in ignored_endpoints:
  81. continue
  82. ignored_endpoints.add(path)
  83. for method in self.get_allowed_methods(callback):
  84. endpoint = (path, method, callback)
  85. api_endpoints.append(endpoint)
  86. except Exception: # pragma: no cover
  87. logger.warning('failed to enumerate view', exc_info=True)
  88. elif isinstance(pattern, URLResolver):
  89. nested_endpoints = self.get_api_endpoints(
  90. patterns=pattern.url_patterns,
  91. prefix=path_regex,
  92. app_name="%s:%s" % (app_name, pattern.app_name) if app_name else pattern.app_name,
  93. namespace="%s:%s" % (namespace, pattern.namespace) if namespace else pattern.namespace,
  94. ignored_endpoints=ignored_endpoints
  95. )
  96. api_endpoints.extend(nested_endpoints)
  97. else:
  98. logger.warning("unknown pattern type {}".format(type(pattern)))
  99. api_endpoints = sorted(api_endpoints, key=endpoint_ordering)
  100. return api_endpoints
  101. def unescape(self, s):
  102. """Unescape all backslash escapes from `s`.
  103. :param str s: string with backslash escapes
  104. :rtype: str
  105. """
  106. # unlike .replace('\\', ''), this correctly transforms a double backslash into a single backslash
  107. return re.sub(r'\\(.)', r'\1', s)
  108. def unescape_path(self, path):
  109. """Remove backslashes escapes from all path components outside {parameters}. This is needed because
  110. ``simplify_regex`` does not handle this correctly.
  111. **NOTE:** this might destructively affect some url regex patterns that contain metacharacters (e.g. \\w, \\d)
  112. outside path parameter groups; if you are in this category, God help you
  113. :param str path: path possibly containing
  114. :return: the unescaped path
  115. :rtype: str
  116. """
  117. clean_path = ''
  118. while path:
  119. match = PATH_PARAMETER_RE.search(path)
  120. if not match:
  121. clean_path += self.unescape(path)
  122. break
  123. clean_path += self.unescape(path[:match.start()])
  124. clean_path += match.group()
  125. path = path[match.end():]
  126. return clean_path
  127. class OpenAPISchemaGenerator(object):
  128. """
  129. This class iterates over all registered API endpoints and returns an appropriate OpenAPI 2.0 compliant schema.
  130. Method implementations shamelessly stolen and adapted from rest-framework ``SchemaGenerator``.
  131. """
  132. endpoint_enumerator_class = EndpointEnumerator
  133. reference_resolver_class = ReferenceResolver
  134. def __init__(self, info, version='', url=None, patterns=None, urlconf=None):
  135. """
  136. :param openapi.Info info: information about the API
  137. :param str version: API version string; if omitted, `info.default_version` will be used
  138. :param str url: API scheme, host and port; if ``None`` is passed and ``DEFAULT_API_URL`` is not set, the url
  139. will be inferred from the request made against the schema view, so you should generally not need to set
  140. this parameter explicitly; if the empty string is passed, no host and scheme will be emitted
  141. If `url` is not ``None`` or the empty string, it must be a scheme-absolute uri (i.e. starting with http://
  142. or https://), and any path component is ignored;
  143. See also: :ref:`documentation on base URL construction <custom-spec-base-url>`
  144. :param patterns: if given, only these patterns will be enumerated for inclusion in the API spec
  145. :param urlconf: if patterns is not given, use this urlconf to enumerate patterns;
  146. if not given, the default urlconf is used
  147. """
  148. self._gen = SchemaGenerator(info.title, url, info.get('description', ''), patterns, urlconf)
  149. self.info = info
  150. self.version = version
  151. self.consumes = []
  152. self.produces = []
  153. if url is None and swagger_settings.DEFAULT_API_URL is not None:
  154. url = swagger_settings.DEFAULT_API_URL
  155. if url:
  156. parsed_url = urlparse.urlparse(url)
  157. if parsed_url.scheme not in ('http', 'https') or not parsed_url.netloc:
  158. raise SwaggerGenerationError("`url` must be an absolute HTTP(S) url")
  159. if parsed_url.path:
  160. logger.warning("path component of api base URL %s is ignored; use FORCE_SCRIPT_NAME instead" % url)
  161. else:
  162. self._gen.url = url
  163. @property
  164. def url(self):
  165. return self._gen.url
  166. def get_security_definitions(self):
  167. """Get the security schemes for this API. This determines what is usable in security requirements,
  168. and helps clients configure their authorization credentials.
  169. :return: the security schemes usable with this API
  170. :rtype: dict[str,dict] or None
  171. """
  172. security_definitions = swagger_settings.SECURITY_DEFINITIONS
  173. if security_definitions is not None:
  174. security_definitions = SwaggerDict._as_odict(security_definitions, {})
  175. return security_definitions
  176. def get_security_requirements(self, security_definitions):
  177. """Get the base (global) security requirements of the API. This is never called if
  178. :meth:`.get_security_definitions` returns `None`.
  179. :param security_definitions: security definitions as returned by :meth:`.get_security_definitions`
  180. :return: the security schemes accepted by default
  181. :rtype: list[dict[str,list[str]]] or None
  182. """
  183. security_requirements = swagger_settings.SECURITY_REQUIREMENTS
  184. if security_requirements is None:
  185. security_requirements = [{security_scheme: []} for security_scheme in security_definitions]
  186. security_requirements = [SwaggerDict._as_odict(sr, {}) for sr in security_requirements]
  187. security_requirements = sorted(security_requirements, key=list)
  188. return security_requirements
  189. def get_schema(self, request=None, public=False):
  190. """Generate a :class:`.Swagger` object representing the API schema.
  191. :param request: the request used for filtering accessible endpoints and finding the spec URI
  192. :type request: rest_framework.request.Request or None
  193. :param bool public: if True, all endpoints are included regardless of access through `request`
  194. :return: the generated Swagger specification
  195. :rtype: openapi.Swagger
  196. """
  197. endpoints = self.get_endpoints(request)
  198. components = self.reference_resolver_class(openapi.SCHEMA_DEFINITIONS, force_init=True)
  199. self.consumes = get_consumes(api_settings.DEFAULT_PARSER_CLASSES)
  200. self.produces = get_produces(api_settings.DEFAULT_RENDERER_CLASSES)
  201. paths, prefix = self.get_paths(endpoints, components, request, public)
  202. security_definitions = self.get_security_definitions()
  203. if security_definitions:
  204. security_requirements = self.get_security_requirements(security_definitions)
  205. else:
  206. security_requirements = None
  207. url = self.url
  208. if url is None and request is not None:
  209. url = request.build_absolute_uri()
  210. return openapi.Swagger(
  211. info=self.info, paths=paths, consumes=self.consumes or None, produces=self.produces or None,
  212. security_definitions=security_definitions, security=security_requirements,
  213. _url=url, _prefix=prefix, _version=self.version, **dict(components)
  214. )
  215. def create_view(self, callback, method, request=None):
  216. """Create a view instance from a view callback as registered in urlpatterns.
  217. :param callback: view callback registered in urlpatterns
  218. :param str method: HTTP method
  219. :param request: request to bind to the view
  220. :type request: rest_framework.request.Request or None
  221. :return: the view instance
  222. """
  223. view = self._gen.create_view(callback, method, request)
  224. overrides = getattr(callback, '_swagger_auto_schema', None)
  225. if overrides is not None:
  226. # decorated function based view must have its decorator information passed on to the re-instantiated view
  227. for method, _ in overrides.items():
  228. view_method = getattr(view, method, None)
  229. if view_method is not None: # pragma: no cover
  230. setattr(view_method.__func__, '_swagger_auto_schema', overrides)
  231. setattr(view, 'swagger_fake_view', True)
  232. return view
  233. def coerce_path(self, path, view):
  234. """Coerce {pk} path arguments into the name of the model field, where possible. This is cleaner for an
  235. external representation (i.e. "this is an identifier", not "this is a database primary key").
  236. :param str path: the path
  237. :param rest_framework.views.APIView view: associated view
  238. :rtype: str
  239. """
  240. if '{pk}' not in path:
  241. return path
  242. model = getattr(get_queryset_from_view(view), 'model', None)
  243. if model:
  244. field_name = get_pk_name(model)
  245. else:
  246. field_name = 'id'
  247. return path.replace('{pk}', '{%s}' % field_name)
  248. def get_endpoints(self, request):
  249. """Iterate over all the registered endpoints in the API and return a fake view with the right parameters.
  250. :param request: request to bind to the endpoint views
  251. :type request: rest_framework.request.Request or None
  252. :return: {path: (view_class, list[(http_method, view_instance)])
  253. :rtype: dict[str,(type,list[(str,rest_framework.views.APIView)])]
  254. """
  255. enumerator = self.endpoint_enumerator_class(self._gen.patterns, self._gen.urlconf, request=request)
  256. endpoints = enumerator.get_api_endpoints()
  257. view_paths = defaultdict(list)
  258. view_cls = {}
  259. for path, method, callback in endpoints:
  260. view = self.create_view(callback, method, request)
  261. path = self.coerce_path(path, view)
  262. view_paths[path].append((method, view))
  263. view_cls[path] = callback.cls
  264. return {path: (view_cls[path], methods) for path, methods in view_paths.items()}
  265. def get_operation_keys(self, subpath, method, view):
  266. """Return a list of keys that should be used to group an operation within the specification. ::
  267. /users/ ("users", "list"), ("users", "create")
  268. /users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete")
  269. /users/enabled/ ("users", "enabled") # custom viewset list action
  270. /users/{pk}/star/ ("users", "star") # custom viewset detail action
  271. /users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create")
  272. /users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update")
  273. :param str subpath: path to the operation with any common prefix/base path removed
  274. :param str method: HTTP method
  275. :param view: the view associated with the operation
  276. :rtype: list[str]
  277. """
  278. return self._gen.get_keys(subpath, method, view)
  279. def determine_path_prefix(self, paths):
  280. """
  281. Given a list of all paths, return the common prefix which should be
  282. discounted when generating a schema structure.
  283. This will be the longest common string that does not include that last
  284. component of the URL, or the last component before a path parameter.
  285. For example: ::
  286. /api/v1/users/
  287. /api/v1/users/{pk}/
  288. The path prefix is ``/api/v1/``.
  289. :param list[str] paths: list of paths
  290. :rtype: str
  291. """
  292. return self._gen.determine_path_prefix(paths)
  293. def should_include_endpoint(self, path, method, view, public):
  294. """Check if a given endpoint should be included in the resulting schema.
  295. :param str path: request path
  296. :param str method: http request method
  297. :param view: instantiated view callback
  298. :param bool public: if True, all endpoints are included regardless of access through `request`
  299. :returns: true if the view should be excluded
  300. :rtype: bool
  301. """
  302. return public or self._gen.has_view_permissions(path, method, view)
  303. def get_paths_object(self, paths):
  304. """Construct the Swagger Paths object.
  305. :param OrderedDict[str,openapi.PathItem] paths: mapping of paths to :class:`.PathItem` objects
  306. :returns: the :class:`.Paths` object
  307. :rtype: openapi.Paths
  308. """
  309. return openapi.Paths(paths=paths)
  310. def get_paths(self, endpoints, components, request, public):
  311. """Generate the Swagger Paths for the API from the given endpoints.
  312. :param dict endpoints: endpoints as returned by get_endpoints
  313. :param ReferenceResolver components: resolver/container for Swagger References
  314. :param Request request: the request made against the schema view; can be None
  315. :param bool public: if True, all endpoints are included regardless of access through `request`
  316. :returns: the :class:`.Paths` object and the longest common path prefix, as a 2-tuple
  317. :rtype: tuple[openapi.Paths,str]
  318. """
  319. if not endpoints:
  320. return openapi.Paths(paths={}), ''
  321. prefix = self.determine_path_prefix(list(endpoints.keys())) or ''
  322. assert '{' not in prefix, "base path cannot be templated in swagger 2.0"
  323. paths = OrderedDict()
  324. for path, (view_cls, methods) in sorted(endpoints.items()):
  325. operations = {}
  326. for method, view in methods:
  327. if not self.should_include_endpoint(path, method, view, public):
  328. continue
  329. operation = self.get_operation(view, path, prefix, method, components, request)
  330. if operation is not None:
  331. operations[method.lower()] = operation
  332. if operations:
  333. # since the common prefix is used as the API basePath, it must be stripped
  334. # from individual paths when writing them into the swagger document
  335. path_suffix = path[len(prefix):]
  336. if not path_suffix.startswith('/'):
  337. path_suffix = '/' + path_suffix
  338. paths[path_suffix] = self.get_path_item(path, view_cls, operations)
  339. return self.get_paths_object(paths), prefix
  340. def get_operation(self, view, path, prefix, method, components, request):
  341. """Get an :class:`.Operation` for the given API endpoint (path, method). This method delegates to
  342. :meth:`~.inspectors.ViewInspector.get_operation` of a :class:`~.inspectors.ViewInspector` determined
  343. according to settings and :func:`@swagger_auto_schema <.swagger_auto_schema>` overrides.
  344. :param view: the view associated with this endpoint
  345. :param str path: the path component of the operation URL
  346. :param str prefix: common path prefix among all endpoints
  347. :param str method: the http method of the operation
  348. :param openapi.ReferenceResolver components: referenceable components
  349. :param Request request: the request made against the schema view; can be None
  350. :rtype: openapi.Operation
  351. """
  352. operation_keys = self.get_operation_keys(path[len(prefix):], method, view)
  353. overrides = self.get_overrides(view, method)
  354. # the inspector class can be specified, in decreasing order of priority,
  355. # 1. globally via DEFAULT_AUTO_SCHEMA_CLASS
  356. view_inspector_cls = swagger_settings.DEFAULT_AUTO_SCHEMA_CLASS
  357. # 2. on the view/viewset class
  358. view_inspector_cls = getattr(view, 'swagger_schema', view_inspector_cls)
  359. # 3. on the swagger_auto_schema decorator
  360. view_inspector_cls = overrides.get('auto_schema', view_inspector_cls)
  361. if view_inspector_cls is None:
  362. return None
  363. view_inspector = view_inspector_cls(view, path, method, components, request, overrides, operation_keys)
  364. operation = view_inspector.get_operation(operation_keys)
  365. if operation is None:
  366. return None
  367. if 'consumes' in operation and set(operation.consumes) == set(self.consumes):
  368. del operation.consumes
  369. if 'produces' in operation and set(operation.produces) == set(self.produces):
  370. del operation.produces
  371. return operation
  372. def get_path_item(self, path, view_cls, operations):
  373. """Get a :class:`.PathItem` object that describes the parameters and operations related to a single path in the
  374. API.
  375. :param str path: the path
  376. :param type view_cls: the view that was bound to this path in urlpatterns
  377. :param dict[str,openapi.Operation] operations: operations defined on this path, keyed by lowercase HTTP method
  378. :rtype: openapi.PathItem
  379. """
  380. path_parameters = self.get_path_parameters(path, view_cls)
  381. return openapi.PathItem(parameters=path_parameters, **operations)
  382. def get_overrides(self, view, method):
  383. """Get overrides specified for a given operation.
  384. :param view: the view associated with the operation
  385. :param str method: HTTP method
  386. :return: a dictionary containing any overrides set by :func:`@swagger_auto_schema <.swagger_auto_schema>`
  387. :rtype: dict
  388. """
  389. method = method.lower()
  390. action = getattr(view, 'action', method)
  391. action_method = getattr(view, action, None)
  392. overrides = getattr(action_method, '_swagger_auto_schema', {})
  393. if method in overrides:
  394. overrides = overrides[method]
  395. return copy.deepcopy(overrides)
  396. def get_path_parameters(self, path, view_cls):
  397. """Return a list of Parameter instances corresponding to any templated path variables.
  398. :param str path: templated request path
  399. :param type view_cls: the view class associated with the path
  400. :return: path parameters
  401. :rtype: list[openapi.Parameter]
  402. """
  403. parameters = []
  404. queryset = get_queryset_from_view(view_cls)
  405. for variable in sorted(uritemplate.variables(path)):
  406. model, model_field = get_queryset_field(queryset, variable)
  407. attrs = get_basic_type_info(model_field) or {'type': openapi.TYPE_STRING}
  408. if getattr(view_cls, 'lookup_field', None) == variable and attrs['type'] == openapi.TYPE_STRING:
  409. attrs['pattern'] = getattr(view_cls, 'lookup_value_regex', attrs.get('pattern', None))
  410. if model_field and getattr(model_field, 'help_text', False):
  411. description = model_field.help_text
  412. elif model_field and getattr(model_field, 'primary_key', False):
  413. description = get_pk_description(model, model_field)
  414. else:
  415. description = None
  416. field = openapi.Parameter(
  417. name=variable,
  418. description=force_real_str(description),
  419. required=True,
  420. in_=openapi.IN_PATH,
  421. **attrs
  422. )
  423. parameters.append(field)
  424. return parameters