view.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
  1. import logging
  2. from collections import OrderedDict
  3. from rest_framework.request import is_form_media_type
  4. from rest_framework.schemas import AutoSchema
  5. from rest_framework.status import is_success
  6. from .. import openapi
  7. from ..errors import SwaggerGenerationError
  8. from ..utils import (
  9. filter_none, force_real_str, force_serializer_instance, get_consumes, get_produces, guess_response_status,
  10. merge_params, no_body, param_list_to_odict
  11. )
  12. from .base import ViewInspector, call_view_method
  13. logger = logging.getLogger(__name__)
  14. class SwaggerAutoSchema(ViewInspector):
  15. def __init__(self, view, path, method, components, request, overrides, operation_keys=None):
  16. super(SwaggerAutoSchema, self).__init__(view, path, method, components, request, overrides)
  17. self._sch = AutoSchema()
  18. self._sch.view = view
  19. self.operation_keys = operation_keys
  20. def get_operation(self, operation_keys=None):
  21. operation_keys = operation_keys or self.operation_keys
  22. consumes = self.get_consumes()
  23. produces = self.get_produces()
  24. body = self.get_request_body_parameters(consumes)
  25. query = self.get_query_parameters()
  26. parameters = body + query
  27. parameters = filter_none(parameters)
  28. parameters = self.add_manual_parameters(parameters)
  29. operation_id = self.get_operation_id(operation_keys)
  30. summary, description = self.get_summary_and_description()
  31. security = self.get_security()
  32. assert security is None or isinstance(security, list), "security must be a list of security requirement objects"
  33. deprecated = self.is_deprecated()
  34. tags = self.get_tags(operation_keys)
  35. responses = self.get_responses()
  36. return openapi.Operation(
  37. operation_id=operation_id,
  38. description=force_real_str(description),
  39. summary=force_real_str(summary),
  40. responses=responses,
  41. parameters=parameters,
  42. consumes=consumes,
  43. produces=produces,
  44. tags=tags,
  45. security=security,
  46. deprecated=deprecated
  47. )
  48. def get_request_body_parameters(self, consumes):
  49. """Return the request body parameters for this view. |br|
  50. This is either:
  51. - a list with a single object Parameter with a :class:`.Schema` derived from the request serializer
  52. - a list of primitive Parameters parsed as form data
  53. :param list[str] consumes: a list of accepted MIME types as returned by :meth:`.get_consumes`
  54. :return: a (potentially empty) list of :class:`.Parameter`\\ s either ``in: body`` or ``in: formData``
  55. :rtype: list[openapi.Parameter]
  56. """
  57. serializer = self.get_request_serializer()
  58. schema = None
  59. if serializer is None:
  60. return []
  61. if isinstance(serializer, openapi.Schema.OR_REF):
  62. schema = serializer
  63. if any(is_form_media_type(encoding) for encoding in consumes):
  64. if schema is not None:
  65. raise SwaggerGenerationError("form request body cannot be a Schema")
  66. return self.get_request_form_parameters(serializer)
  67. else:
  68. if schema is None:
  69. schema = self.get_request_body_schema(serializer)
  70. return [self.make_body_parameter(schema)] if schema is not None else []
  71. def get_view_serializer(self):
  72. """Return the serializer as defined by the view's ``get_serializer()`` method.
  73. :return: the view's ``Serializer``
  74. :rtype: rest_framework.serializers.Serializer
  75. """
  76. return call_view_method(self.view, 'get_serializer')
  77. def _get_request_body_override(self):
  78. """Parse the request_body key in the override dict. This method is not public API."""
  79. body_override = self.overrides.get('request_body', None)
  80. if body_override is not None:
  81. if body_override is no_body:
  82. return no_body
  83. if self.method not in self.body_methods:
  84. raise SwaggerGenerationError("request_body can only be applied to (" + ','.join(self.body_methods) +
  85. "); are you looking for query_serializer or manual_parameters?")
  86. if isinstance(body_override, openapi.Schema.OR_REF):
  87. return body_override
  88. return force_serializer_instance(body_override)
  89. return body_override
  90. def get_request_serializer(self):
  91. """Return the request serializer (used for parsing the request payload) for this endpoint.
  92. :return: the request serializer, or one of :class:`.Schema`, :class:`.SchemaRef`, ``None``
  93. :rtype: rest_framework.serializers.Serializer
  94. """
  95. body_override = self._get_request_body_override()
  96. if body_override is None and self.method in self.implicit_body_methods:
  97. return self.get_view_serializer()
  98. if body_override is no_body:
  99. return None
  100. return body_override
  101. def get_request_form_parameters(self, serializer):
  102. """Given a Serializer, return a list of ``in: formData`` :class:`.Parameter`\\ s.
  103. :param serializer: the view's request serializer as returned by :meth:`.get_request_serializer`
  104. :rtype: list[openapi.Parameter]
  105. """
  106. return self.serializer_to_parameters(serializer, in_=openapi.IN_FORM)
  107. def get_request_body_schema(self, serializer):
  108. """Return the :class:`.Schema` for a given request's body data. Only applies to PUT, PATCH and POST requests.
  109. :param serializer: the view's request serializer as returned by :meth:`.get_request_serializer`
  110. :rtype: openapi.Schema
  111. """
  112. return self.serializer_to_schema(serializer)
  113. def make_body_parameter(self, schema):
  114. """Given a :class:`.Schema` object, create an ``in: body`` :class:`.Parameter`.
  115. :param openapi.Schema schema: the request body schema
  116. :rtype: openapi.Parameter
  117. """
  118. return openapi.Parameter(name='data', in_=openapi.IN_BODY, required=True, schema=schema)
  119. def add_manual_parameters(self, parameters):
  120. """Add/replace parameters from the given list of automatically generated request parameters.
  121. :param list[openapi.Parameter] parameters: generated parameters
  122. :return: modified parameters
  123. :rtype: list[openapi.Parameter]
  124. """
  125. manual_parameters = self.overrides.get('manual_parameters', None) or []
  126. if any(param.in_ == openapi.IN_BODY for param in manual_parameters): # pragma: no cover
  127. raise SwaggerGenerationError("specify the body parameter as a Schema or Serializer in request_body")
  128. if any(param.in_ == openapi.IN_FORM for param in manual_parameters): # pragma: no cover
  129. has_body_parameter = any(param.in_ == openapi.IN_BODY for param in parameters)
  130. if has_body_parameter or not any(is_form_media_type(encoding) for encoding in self.get_consumes()):
  131. raise SwaggerGenerationError("cannot add form parameters when the request has a request body; "
  132. "did you forget to set an appropriate parser class on the view?")
  133. if self.method not in self.body_methods:
  134. raise SwaggerGenerationError("form parameters can only be applied to "
  135. "(" + ','.join(self.body_methods) + ") HTTP methods")
  136. return merge_params(parameters, manual_parameters)
  137. def get_responses(self):
  138. """Get the possible responses for this view as a swagger :class:`.Responses` object.
  139. :return: the documented responses
  140. :rtype: openapi.Responses
  141. """
  142. response_serializers = self.get_response_serializers()
  143. return openapi.Responses(
  144. responses=self.get_response_schemas(response_serializers)
  145. )
  146. def get_default_response_serializer(self):
  147. """Return the default response serializer for this endpoint. This is derived from either the ``request_body``
  148. override or the request serializer (:meth:`.get_view_serializer`).
  149. :return: response serializer, :class:`.Schema`, :class:`.SchemaRef`, ``None``
  150. """
  151. body_override = self._get_request_body_override()
  152. if body_override and body_override is not no_body:
  153. return body_override
  154. return self.get_view_serializer()
  155. def get_default_responses(self):
  156. """Get the default responses determined for this view from the request serializer and request method.
  157. :type: dict[str, openapi.Schema]
  158. """
  159. method = self.method.lower()
  160. default_status = guess_response_status(method)
  161. default_schema = ''
  162. if method in ('get', 'post', 'put', 'patch'):
  163. default_schema = self.get_default_response_serializer()
  164. default_schema = default_schema or ''
  165. if default_schema and not isinstance(default_schema, openapi.Schema):
  166. default_schema = self.serializer_to_schema(default_schema) or ''
  167. if default_schema:
  168. if self.has_list_response():
  169. default_schema = openapi.Schema(type=openapi.TYPE_ARRAY, items=default_schema)
  170. if self.should_page():
  171. default_schema = self.get_paginated_response(default_schema) or default_schema
  172. return OrderedDict({str(default_status): default_schema})
  173. def get_response_serializers(self):
  174. """Return the response codes that this view is expected to return, and the serializer for each response body.
  175. The return value should be a dict where the keys are possible status codes, and values are either strings,
  176. ``Serializer``\\ s, :class:`.Schema`, :class:`.SchemaRef` or :class:`.Response` objects. See
  177. :func:`@swagger_auto_schema <.swagger_auto_schema>` for more details.
  178. :return: the response serializers
  179. :rtype: dict
  180. """
  181. manual_responses = self.overrides.get('responses', None) or {}
  182. manual_responses = OrderedDict((str(sc), resp) for sc, resp in manual_responses.items())
  183. responses = OrderedDict()
  184. if not any(is_success(int(sc)) for sc in manual_responses if sc != 'default'):
  185. responses = self.get_default_responses()
  186. responses.update((str(sc), resp) for sc, resp in manual_responses.items())
  187. return responses
  188. def get_response_schemas(self, response_serializers):
  189. """Return the :class:`.openapi.Response` objects calculated for this view.
  190. :param dict response_serializers: response serializers as returned by :meth:`.get_response_serializers`
  191. :return: a dictionary of status code to :class:`.Response` object
  192. :rtype: dict[str, openapi.Response]
  193. """
  194. responses = OrderedDict()
  195. for sc, serializer in response_serializers.items():
  196. if isinstance(serializer, str):
  197. response = openapi.Response(
  198. description=force_real_str(serializer)
  199. )
  200. elif not serializer:
  201. continue
  202. elif isinstance(serializer, openapi.Response):
  203. response = serializer
  204. if hasattr(response, 'schema') and not isinstance(response.schema, openapi.Schema.OR_REF):
  205. serializer = force_serializer_instance(response.schema)
  206. response.schema = self.serializer_to_schema(serializer)
  207. elif isinstance(serializer, openapi.Schema.OR_REF):
  208. response = openapi.Response(
  209. description='',
  210. schema=serializer,
  211. )
  212. elif isinstance(serializer, openapi._Ref):
  213. response = serializer
  214. else:
  215. serializer = force_serializer_instance(serializer)
  216. response = openapi.Response(
  217. description='',
  218. schema=self.serializer_to_schema(serializer),
  219. )
  220. responses[str(sc)] = response
  221. return responses
  222. def get_query_serializer(self):
  223. """Return the query serializer (used for parsing query parameters) for this endpoint.
  224. :return: the query serializer, or ``None``
  225. """
  226. query_serializer = self.overrides.get('query_serializer', None)
  227. if query_serializer is not None:
  228. query_serializer = force_serializer_instance(query_serializer)
  229. return query_serializer
  230. def get_query_parameters(self):
  231. """Return the query parameters accepted by this view.
  232. :rtype: list[openapi.Parameter]
  233. """
  234. natural_parameters = self.get_filter_parameters() + self.get_pagination_parameters()
  235. query_serializer = self.get_query_serializer()
  236. serializer_parameters = []
  237. if query_serializer is not None:
  238. serializer_parameters = self.serializer_to_parameters(query_serializer, in_=openapi.IN_QUERY)
  239. if len(set(param_list_to_odict(natural_parameters)) & set(param_list_to_odict(serializer_parameters))) != 0:
  240. raise SwaggerGenerationError(
  241. "your query_serializer contains fields that conflict with the "
  242. "filter_backend or paginator_class on the view - %s %s" % (self.method, self.path)
  243. )
  244. return natural_parameters + serializer_parameters
  245. def get_operation_id(self, operation_keys=None):
  246. """Return an unique ID for this operation. The ID must be unique across
  247. all :class:`.Operation` objects in the API.
  248. :param tuple[str] operation_keys: an array of keys derived from the path describing the hierarchical layout
  249. of this view in the API; e.g. ``('snippets', 'list')``, ``('snippets', 'retrieve')``, etc.
  250. :rtype: str
  251. """
  252. operation_keys = operation_keys or self.operation_keys
  253. operation_id = self.overrides.get('operation_id', '')
  254. if not operation_id:
  255. operation_id = '_'.join(operation_keys)
  256. return operation_id
  257. def split_summary_from_description(self, description):
  258. """Decide if and how to split a summary out of the given description. The default implementation
  259. uses the first paragraph of the description as a summary if it is less than 120 characters long.
  260. :param description: the full description to be analyzed
  261. :return: summary and description
  262. :rtype: (str,str)
  263. """
  264. # https://www.python.org/dev/peps/pep-0257/#multi-line-docstrings
  265. summary = None
  266. summary_max_len = 120 # OpenAPI 2.0 spec says summary should be under 120 characters
  267. sections = description.split('\n\n', 1)
  268. if len(sections) == 2:
  269. sections[0] = sections[0].strip()
  270. if len(sections[0]) < summary_max_len:
  271. summary, description = sections
  272. description = description.strip()
  273. return summary, description
  274. def get_summary_and_description(self):
  275. """Return an operation summary and description determined from the view's docstring.
  276. :return: summary and description
  277. :rtype: (str,str)
  278. """
  279. description = self.overrides.get('operation_description', None)
  280. summary = self.overrides.get('operation_summary', None)
  281. if description is None:
  282. description = self._sch.get_description(self.path, self.method) or ''
  283. description = description.strip().replace('\r', '')
  284. if description and (summary is None):
  285. # description from docstring... do summary magic
  286. summary, description = self.split_summary_from_description(description)
  287. return summary, description
  288. def get_security(self):
  289. """Return a list of security requirements for this operation.
  290. Returning an empty list marks the endpoint as unauthenticated (i.e. removes all accepted
  291. authentication schemes). Returning ``None`` will inherit the top-level security requirements.
  292. :return: security requirements
  293. :rtype: list[dict[str,list[str]]]"""
  294. return self.overrides.get('security', None)
  295. def is_deprecated(self):
  296. """Return ``True`` if this operation is to be marked as deprecated.
  297. :return: deprecation status
  298. :rtype: bool
  299. """
  300. return self.overrides.get('deprecated', None)
  301. def get_tags(self, operation_keys=None):
  302. """Get a list of tags for this operation. Tags determine how operations relate with each other, and in the UI
  303. each tag will show as a group containing the operations that use it. If not provided in overrides,
  304. tags will be inferred from the operation url.
  305. :param tuple[str] operation_keys: an array of keys derived from the path describing the hierarchical layout
  306. of this view in the API; e.g. ``('snippets', 'list')``, ``('snippets', 'retrieve')``, etc.
  307. :rtype: list[str]
  308. """
  309. operation_keys = operation_keys or self.operation_keys
  310. tags = self.overrides.get('tags')
  311. if not tags:
  312. tags = [operation_keys[0]]
  313. return tags
  314. def get_consumes(self):
  315. """Return the MIME types this endpoint can consume.
  316. :rtype: list[str]
  317. """
  318. return get_consumes(self.get_parser_classes())
  319. def get_produces(self):
  320. """Return the MIME types this endpoint can produce.
  321. :rtype: list[str]
  322. """
  323. return get_produces(self.get_renderer_classes())