utils.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
  1. import inspect
  2. import logging
  3. import sys
  4. import textwrap
  5. from collections import OrderedDict
  6. from decimal import Decimal
  7. import pytz
  8. from django.db import models
  9. from django.utils.encoding import force_str
  10. from rest_framework import serializers, status
  11. from rest_framework.mixins import DestroyModelMixin, ListModelMixin, RetrieveModelMixin, UpdateModelMixin
  12. from rest_framework.parsers import FileUploadParser
  13. from rest_framework.request import is_form_media_type
  14. from rest_framework.settings import api_settings as rest_framework_settings
  15. from rest_framework.utils import encoders, json
  16. from rest_framework.views import APIView
  17. from .app_settings import swagger_settings
  18. logger = logging.getLogger(__name__)
  19. class no_body(object):
  20. """Used as a sentinel value to forcibly remove the body of a request via :func:`.swagger_auto_schema`."""
  21. pass
  22. class unset(object):
  23. """Used as a sentinel value for function parameters not set by the caller where ``None`` would be a valid value."""
  24. pass
  25. def swagger_auto_schema(method=None, methods=None, auto_schema=unset, request_body=None, query_serializer=None,
  26. manual_parameters=None, operation_id=None, operation_description=None, operation_summary=None,
  27. security=None, deprecated=None, responses=None, field_inspectors=None, filter_inspectors=None,
  28. paginator_inspectors=None, tags=None, **extra_overrides):
  29. """Decorate a view method to customize the :class:`.Operation` object generated from it.
  30. `method` and `methods` are mutually exclusive and must only be present when decorating a view method that accepts
  31. more than one HTTP request method.
  32. The `auto_schema` and `operation_description` arguments take precedence over view- or method-level values.
  33. :param str method: for multi-method views, the http method the options should apply to
  34. :param list[str] methods: for multi-method views, the http methods the options should apply to
  35. :param drf_yasg.inspectors.SwaggerAutoSchema auto_schema: custom class to use for generating the Operation object;
  36. this overrides both the class-level ``swagger_schema`` attribute and the ``DEFAULT_AUTO_SCHEMA_CLASS``
  37. setting, and can be set to ``None`` to prevent this operation from being generated
  38. :param request_body: custom request body which will be used as the ``schema`` property of a
  39. :class:`.Parameter` with ``in: 'body'``.
  40. A Schema or SchemaRef is not valid if this request consumes form-data, because ``form`` and ``body`` parameters
  41. are mutually exclusive in an :class:`.Operation`. If you need to set custom ``form`` parameters, you can use
  42. the `manual_parameters` argument.
  43. If a ``Serializer`` class or instance is given, it will be automatically converted into a :class:`.Schema`
  44. used as a ``body`` :class:`.Parameter`, or into a list of ``form`` :class:`.Parameter`\\ s, as appropriate.
  45. :type request_body: drf_yasg.openapi.Schema or drf_yasg.openapi.SchemaRef or rest_framework.serializers.Serializer
  46. or type[no_body]
  47. :param rest_framework.serializers.Serializer query_serializer: if you use a ``Serializer`` to parse query
  48. parameters, you can pass it here and have :class:`.Parameter` objects be generated automatically from it.
  49. If any ``Field`` on the serializer cannot be represented as a ``query`` :class:`.Parameter`
  50. (e.g. nested Serializers, file fields, ...), the schema generation will fail with an error.
  51. Schema generation will also fail if the name of any Field on the `query_serializer` conflicts with parameters
  52. generated by ``filter_backends`` or ``paginator``.
  53. :param list[drf_yasg.openapi.Parameter] manual_parameters: a list of manual parameters to override the
  54. automatically generated ones
  55. :class:`.Parameter`\\ s are identified by their (``name``, ``in``) combination, and any parameters given
  56. here will fully override automatically generated parameters if they collide.
  57. It is an error to supply ``form`` parameters when the request does not consume form-data.
  58. :param str operation_id: operation ID override; the operation ID must be unique across the whole API
  59. :param str operation_description: operation description override
  60. :param str operation_summary: operation summary string
  61. :param list[dict] security: security requirements override; used to specify which authentication mechanism
  62. is required to call this API; an empty list marks the endpoint as unauthenticated (i.e. removes all accepted
  63. authentication schemes), and ``None`` will inherit the top-level security requirements
  64. :param bool deprecated: deprecation status for operation
  65. :param responses: a dict of documented manual responses
  66. keyed on response status code. If no success (``2xx``) response is given, one will automatically be
  67. generated from the request body and http method. If any ``2xx`` response is given the automatic response is
  68. suppressed.
  69. * if a plain string is given as value, a :class:`.Response` with no body and that string as its description
  70. will be generated
  71. * if ``None`` is given as a value, the response is ignored; this is mainly useful for disabling default
  72. 2xx responses, i.e. ``responses={200: None, 302: 'something'}``
  73. * if a :class:`.Schema`, :class:`.SchemaRef` is given, a :class:`.Response` with the schema as its body and
  74. an empty description will be generated
  75. * a ``Serializer`` class or instance will be converted into a :class:`.Schema` and treated as above
  76. * a :class:`.Response` object will be used as-is; however if its ``schema`` attribute is a ``Serializer``,
  77. it will automatically be converted into a :class:`.Schema`
  78. :type responses: dict[int or str, (drf_yasg.openapi.Schema or drf_yasg.openapi.SchemaRef or
  79. drf_yasg.openapi.Response or str or rest_framework.serializers.Serializer)]
  80. :param list[type[drf_yasg.inspectors.FieldInspector]] field_inspectors: extra serializer and field inspectors; these
  81. will be tried before :attr:`.ViewInspector.field_inspectors` on the :class:`.inspectors.SwaggerAutoSchema`
  82. :param list[type[drf_yasg.inspectors.FilterInspector]] filter_inspectors: extra filter inspectors; these will be
  83. tried before :attr:`.ViewInspector.filter_inspectors` on the :class:`.inspectors.SwaggerAutoSchema`
  84. :param list[type[drf_yasg.inspectors.PaginatorInspector]] paginator_inspectors: extra paginator inspectors; these
  85. will be tried before :attr:`.ViewInspector.paginator_inspectors` on the :class:`.inspectors.SwaggerAutoSchema`
  86. :param list[str] tags: tags override
  87. :param extra_overrides: extra values that will be saved into the ``overrides`` dict; these values will be available
  88. in the handling :class:`.inspectors.SwaggerAutoSchema` instance via ``self.overrides``
  89. """
  90. def decorator(view_method):
  91. assert not any(hm in extra_overrides for hm in APIView.http_method_names), "HTTP method names not allowed here"
  92. data = {
  93. 'request_body': request_body,
  94. 'query_serializer': query_serializer,
  95. 'manual_parameters': manual_parameters,
  96. 'operation_id': operation_id,
  97. 'operation_summary': operation_summary,
  98. 'deprecated': deprecated,
  99. 'operation_description': operation_description,
  100. 'security': security,
  101. 'responses': responses,
  102. 'filter_inspectors': list(filter_inspectors) if filter_inspectors else None,
  103. 'paginator_inspectors': list(paginator_inspectors) if paginator_inspectors else None,
  104. 'field_inspectors': list(field_inspectors) if field_inspectors else None,
  105. 'tags': list(tags) if tags else None,
  106. }
  107. data = filter_none(data)
  108. if auto_schema is not unset:
  109. data['auto_schema'] = auto_schema
  110. data.update(extra_overrides)
  111. if not data: # pragma: no cover
  112. # no overrides to set, no use in doing more work
  113. return view_method
  114. # if the method is an @action, it will have a bind_to_methods attribute, or a mapping attribute for drf>3.8
  115. bind_to_methods = getattr(view_method, 'bind_to_methods', [])
  116. mapping = getattr(view_method, 'mapping', {})
  117. mapping_methods = [mth for mth, name in mapping.items() if name == view_method.__name__]
  118. action_http_methods = bind_to_methods + mapping_methods
  119. # if the method is actually a function based view (@api_view), it will have a 'cls' attribute
  120. view_cls = getattr(view_method, 'cls', None)
  121. api_view_http_methods = [m for m in getattr(view_cls, 'http_method_names', []) if hasattr(view_cls, m)]
  122. available_http_methods = api_view_http_methods + action_http_methods
  123. existing_data = getattr(view_method, '_swagger_auto_schema', {})
  124. _methods = methods
  125. if methods or method:
  126. assert available_http_methods, "`method` or `methods` can only be specified on @action or @api_view views"
  127. assert bool(methods) != bool(method), "specify either method or methods"
  128. assert not isinstance(methods, str), "`methods` expects to receive a list of methods;" \
  129. " use `method` for a single argument"
  130. if method:
  131. _methods = [method.lower()]
  132. else:
  133. _methods = [mth.lower() for mth in methods]
  134. assert all(mth in available_http_methods for mth in _methods), "http method not bound to view"
  135. assert not any(mth in existing_data for mth in _methods), "http method defined multiple times"
  136. if available_http_methods:
  137. # action or api_view
  138. assert bool(api_view_http_methods) != bool(action_http_methods), "this should never happen"
  139. if len(available_http_methods) > 1:
  140. assert _methods, \
  141. "on multi-method api_view or action, you must specify " \
  142. "swagger_auto_schema on a per-method basis using one of the `method` or `methods` arguments"
  143. else:
  144. # for a single-method view we assume that single method as the decorator target
  145. _methods = _methods or available_http_methods
  146. assert not any(hasattr(getattr(view_cls, mth, None), '_swagger_auto_schema') for mth in _methods), \
  147. "swagger_auto_schema applied twice to method"
  148. assert not any(mth in existing_data for mth in _methods), "swagger_auto_schema applied twice to method"
  149. existing_data.update((mth.lower(), data) for mth in _methods)
  150. view_method._swagger_auto_schema = existing_data
  151. else:
  152. assert not _methods, \
  153. "the methods argument should only be specified when decorating an action; " \
  154. "you should also ensure that you put the swagger_auto_schema decorator " \
  155. "AFTER (above) the _route decorator"
  156. assert not existing_data, "swagger_auto_schema applied twice to method"
  157. view_method._swagger_auto_schema = data
  158. return view_method
  159. return decorator
  160. def swagger_serializer_method(serializer_or_field):
  161. """
  162. Decorates the method of a serializers.SerializerMethodField
  163. to hint as to how Swagger should be generated for this field.
  164. :param serializer_or_field: ``Serializer``/``Field`` class or instance
  165. :return:
  166. """
  167. def decorator(serializer_method):
  168. # stash the serializer for SerializerMethodFieldInspector to find
  169. serializer_method._swagger_serializer = serializer_or_field
  170. return serializer_method
  171. return decorator
  172. def is_list_view(path, method, view):
  173. """Check if the given path/method appears to represent a list view (as opposed to a detail/instance view).
  174. :param str path: view path
  175. :param str method: http method
  176. :param APIView view: target view
  177. :rtype: bool
  178. """
  179. # for ViewSets, it could be the default 'list' action, or an @action(detail=False)
  180. action = getattr(view, 'action', '')
  181. method = getattr(view, action, None) or method
  182. detail = getattr(method, 'detail', None)
  183. suffix = getattr(view, 'suffix', None)
  184. if action in ('list', 'create') or detail is False or suffix == 'List':
  185. return True
  186. if action in ('retrieve', 'update', 'partial_update', 'destroy') or detail is True or suffix == 'Instance':
  187. # a detail action is surely not a list route
  188. return False
  189. if isinstance(view, ListModelMixin):
  190. return True
  191. # for GenericAPIView, if it's a detail view it can't also be a list view
  192. if isinstance(view, (RetrieveModelMixin, UpdateModelMixin, DestroyModelMixin)):
  193. return False
  194. # if the last component in the path is parameterized it's probably not a list view
  195. path_components = path.strip('/').split('/')
  196. if path_components and '{' in path_components[-1]:
  197. return False
  198. # otherwise assume it's a list view
  199. return True
  200. def guess_response_status(method):
  201. if method == 'post':
  202. return status.HTTP_201_CREATED
  203. elif method == 'delete':
  204. return status.HTTP_204_NO_CONTENT
  205. else:
  206. return status.HTTP_200_OK
  207. def param_list_to_odict(parameters):
  208. """Transform a list of :class:`.Parameter` objects into an ``OrderedDict`` keyed on the ``(name, in_)`` tuple of
  209. each parameter.
  210. Raises an ``AssertionError`` if `parameters` contains duplicate parameters (by their name + in combination).
  211. :param list[drf_yasg.openapi.Parameter] parameters: the list of parameters
  212. :return: `parameters` keyed by ``(name, in_)``
  213. :rtype: dict[(str,str),drf_yasg.openapi.Parameter]
  214. """
  215. result = OrderedDict(((param.name, param.in_), param) for param in parameters)
  216. assert len(result) == len(parameters), "duplicate Parameters found"
  217. return result
  218. def merge_params(parameters, overrides):
  219. """Merge `overrides` into `parameters`. This is the same as appending `overrides` to `parameters`, but any element
  220. of `parameters` whose ``(name, in_)`` tuple collides with an element in `overrides` is replaced by it.
  221. Raises an ``AssertionError`` if either list contains duplicate parameters.
  222. :param list[drf_yasg.openapi.Parameter] parameters: initial parameters
  223. :param list[drf_yasg.openapi.Parameter] overrides: overriding parameters
  224. :return: merged list
  225. :rtype: list[drf_yasg.openapi.Parameter]
  226. """
  227. parameters = param_list_to_odict(parameters)
  228. parameters.update(param_list_to_odict(overrides))
  229. return list(parameters.values())
  230. def filter_none(obj):
  231. """Remove ``None`` values from tuples, lists or dictionaries. Return other objects as-is.
  232. :param obj: the object
  233. :return: collection with ``None`` values removed
  234. """
  235. if obj is None:
  236. return None
  237. new_obj = None
  238. if isinstance(obj, dict):
  239. new_obj = type(obj)((k, v) for k, v in obj.items() if k is not None and v is not None)
  240. if isinstance(obj, (list, tuple)):
  241. new_obj = type(obj)(v for v in obj if v is not None)
  242. if new_obj is not None and len(new_obj) != len(obj):
  243. return new_obj # pragma: no cover
  244. return obj
  245. def force_serializer_instance(serializer):
  246. """Force `serializer` into a ``Serializer`` instance. If it is not a ``Serializer`` class or instance, raises
  247. an assertion error.
  248. :param serializer: serializer class or instance
  249. :type serializer: serializers.BaseSerializer or type[serializers.BaseSerializer]
  250. :return: serializer instance
  251. :rtype: serializers.BaseSerializer
  252. """
  253. if inspect.isclass(serializer):
  254. assert issubclass(serializer, serializers.BaseSerializer), "Serializer required, not %s" % serializer.__name__
  255. return serializer()
  256. assert isinstance(serializer, serializers.BaseSerializer), \
  257. "Serializer class or instance required, not %s" % type(serializer).__name__
  258. return serializer
  259. def get_serializer_class(serializer):
  260. """Given a ``Serializer`` class or instance, return the ``Serializer`` class. If `serializer` is not a ``Serializer``
  261. class or instance, raises an assertion error.
  262. :param serializer: serializer class or instance, or ``None``
  263. :return: serializer class
  264. :rtype: type[serializers.BaseSerializer]
  265. """
  266. if serializer is None:
  267. return None
  268. if inspect.isclass(serializer):
  269. assert issubclass(serializer, serializers.BaseSerializer), "Serializer required, not %s" % serializer.__name__
  270. return serializer
  271. assert isinstance(serializer, serializers.BaseSerializer), \
  272. "Serializer class or instance required, not %s" % type(serializer).__name__
  273. return type(serializer)
  274. def get_object_classes(classes_or_instances, expected_base_class=None):
  275. """Given a list of instances or class objects, return the list of their classes.
  276. :param classes_or_instances: mixed list to parse
  277. :type classes_or_instances: list[type or object]
  278. :param expected_base_class: if given, only subclasses or instances of this type will be returned
  279. :type expected_base_class: type
  280. :return: list of classes
  281. :rtype: list
  282. """
  283. classes_or_instances = classes_or_instances or []
  284. result = []
  285. for obj in classes_or_instances:
  286. if inspect.isclass(obj):
  287. if not expected_base_class or issubclass(obj, expected_base_class):
  288. result.append(obj)
  289. else:
  290. if not expected_base_class or isinstance(obj, expected_base_class):
  291. result.append(type(obj))
  292. return result
  293. def get_consumes(parser_classes):
  294. """Extract ``consumes`` MIME types from a list of parser classes.
  295. :param list parser_classes: parser classes
  296. :type parser_classes: list[rest_framework.parsers.BaseParser or type[rest_framework.parsers.BaseParser]]
  297. :return: MIME types for ``consumes``
  298. :rtype: list[str]
  299. """
  300. parser_classes = get_object_classes(parser_classes)
  301. parser_classes = [pc for pc in parser_classes if not issubclass(pc, FileUploadParser)]
  302. media_types = [parser.media_type for parser in parser_classes or []]
  303. non_form_media_types = [encoding for encoding in media_types if not is_form_media_type(encoding)]
  304. # Because swagger Parameter objects don't support complex data types (nested objects, arrays),
  305. # we can't use those unless we are sure the view *only* accepts form data
  306. # This means that a view won't support file upload in swagger unless it explicitly
  307. # sets its parser classes to include only form parsers
  308. if len(non_form_media_types) == 0:
  309. return media_types
  310. # If the form accepts both form data and another type, like json (which is the default config),
  311. # we will render its input as a Schema and thus it file parameters will be read-only
  312. return non_form_media_types
  313. def get_produces(renderer_classes):
  314. """Extract ``produces`` MIME types from a list of renderer classes.
  315. :param list renderer_classes: renderer classes
  316. :type renderer_classes: list[rest_framework.renderers.BaseRenderer or type[rest_framework.renderers.BaseRenderer]]
  317. :return: MIME types for ``produces``
  318. :rtype: list[str]
  319. """
  320. renderer_classes = get_object_classes(renderer_classes)
  321. media_types = [renderer.media_type for renderer in renderer_classes or []]
  322. media_types = [encoding for encoding in media_types
  323. if not any(excluded in encoding for excluded in swagger_settings.EXCLUDED_MEDIA_TYPES)]
  324. return media_types
  325. def decimal_as_float(field):
  326. """Returns true if ``field`` is a django-rest-framework DecimalField and its ``coerce_to_string`` attribute or the
  327. ``COERCE_DECIMAL_TO_STRING`` setting is set to ``False``.
  328. :rtype: bool
  329. """
  330. if isinstance(field, serializers.DecimalField) or isinstance(field, models.DecimalField):
  331. return not getattr(field, 'coerce_to_string', rest_framework_settings.COERCE_DECIMAL_TO_STRING)
  332. return False
  333. def get_serializer_ref_name(serializer):
  334. """Get serializer's ref_name (or None for ModelSerializer if it is named 'NestedSerializer')
  335. :param serializer: Serializer instance
  336. :return: Serializer's ``ref_name`` or ``None`` for inline serializer
  337. :rtype: str or None
  338. """
  339. serializer_meta = getattr(serializer, 'Meta', None)
  340. serializer_name = type(serializer).__name__
  341. if hasattr(serializer_meta, 'ref_name'):
  342. ref_name = serializer_meta.ref_name
  343. elif serializer_name == 'NestedSerializer' and isinstance(serializer, serializers.ModelSerializer):
  344. logger.debug("Forcing inline output for ModelSerializer named 'NestedSerializer':\n" + str(serializer))
  345. ref_name = None
  346. else:
  347. ref_name = serializer_name
  348. if ref_name.endswith('Serializer'):
  349. ref_name = ref_name[:-len('Serializer')]
  350. return ref_name
  351. def force_real_str(s, encoding='utf-8', strings_only=False, errors='strict'):
  352. """
  353. Force `s` into a ``str`` instance.
  354. Fix for https://github.com/axnsan12/drf-yasg/issues/159
  355. """
  356. if s is not None:
  357. s = force_str(s, encoding, strings_only, errors)
  358. if type(s) != str:
  359. s = '' + s
  360. # Remove common indentation to get the correct Markdown rendering
  361. s = textwrap.dedent(s)
  362. return s
  363. def field_value_to_representation(field, value):
  364. """Convert a python value related to a field (default, choices, etc.) into its OpenAPI-compatible representation.
  365. :param serializers.Field field: field associated with the value
  366. :param object value: value
  367. :return: the converted value
  368. """
  369. value = field.to_representation(value)
  370. if isinstance(value, Decimal):
  371. if decimal_as_float(field):
  372. value = float(value)
  373. else:
  374. value = str(value)
  375. if isinstance(value, pytz.BaseTzInfo):
  376. value = str(value)
  377. # JSON roundtrip ensures that the value is valid JSON;
  378. # for example, sets and tuples get transformed into lists
  379. return json.loads(json.dumps(value, cls=encoders.JSONEncoder))
  380. def get_field_default(field):
  381. """
  382. Get the default value for a field, converted to a JSON-compatible value while properly handling callables.
  383. :param field: field instance
  384. :return: default value
  385. """
  386. default = getattr(field, 'default', serializers.empty)
  387. if default is not serializers.empty:
  388. if callable(default):
  389. try:
  390. if hasattr(default, 'set_context'):
  391. default.set_context(field)
  392. if getattr(default, 'requires_context', False):
  393. default = default(field)
  394. else:
  395. default = default()
  396. except Exception: # pragma: no cover
  397. logger.warning("default for %s is callable but it raised an exception when "
  398. "called; 'default' will not be set on schema", field, exc_info=True)
  399. default = serializers.empty
  400. if default is not serializers.empty and default is not None:
  401. try:
  402. default = field_value_to_representation(field, default)
  403. except Exception: # pragma: no cover
  404. logger.warning("'default' on schema for %s will not be set because "
  405. "to_representation raised an exception", field, exc_info=True)
  406. default = serializers.empty
  407. return default
  408. def dict_has_ordered_keys(obj):
  409. """Check if a given object is a dict that maintains insertion order.
  410. :param obj: the dict object to check
  411. :rtype: bool
  412. """
  413. if sys.version_info >= (3, 7):
  414. # the Python 3.7 language spec says that dict must maintain insertion order.
  415. return isinstance(obj, dict)
  416. return isinstance(obj, OrderedDict)