field.py 34 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855
  1. import datetime
  2. import inspect
  3. import logging
  4. import operator
  5. import typing
  6. import uuid
  7. import pkg_resources
  8. from packaging import version
  9. from collections import OrderedDict
  10. from decimal import Decimal
  11. from inspect import signature as inspect_signature
  12. from django.core import validators
  13. from django.db import models
  14. from rest_framework import serializers
  15. from rest_framework.settings import api_settings as rest_framework_settings
  16. from .. import openapi
  17. from ..errors import SwaggerGenerationError
  18. from ..utils import (
  19. decimal_as_float, field_value_to_representation, filter_none, get_serializer_class, get_serializer_ref_name
  20. )
  21. from .base import FieldInspector, NotHandled, SerializerInspector, call_view_method
  22. drf_version = pkg_resources.get_distribution("djangorestframework").version
  23. logger = logging.getLogger(__name__)
  24. class InlineSerializerInspector(SerializerInspector):
  25. """Provides serializer conversions using :meth:`.FieldInspector.field_to_swagger_object`."""
  26. #: whether to output :class:`.Schema` definitions inline or into the ``definitions`` section
  27. use_definitions = False
  28. def get_schema(self, serializer):
  29. return self.probe_field_inspectors(serializer, openapi.Schema, self.use_definitions)
  30. def add_manual_parameters(self, serializer, parameters):
  31. """Add/replace parameters from the given list of automatically generated request parameters. This method
  32. is called only when the serializer is converted into a list of parameters for use in a form data request.
  33. :param serializer: serializer instance
  34. :param list[openapi.Parameter] parameters: generated parameters
  35. :return: modified parameters
  36. :rtype: list[openapi.Parameter]
  37. """
  38. return parameters
  39. def get_request_parameters(self, serializer, in_):
  40. fields = getattr(serializer, 'fields', {})
  41. parameters = [
  42. self.probe_field_inspectors(
  43. value, openapi.Parameter, self.use_definitions,
  44. name=self.get_parameter_name(key), in_=in_
  45. )
  46. for key, value
  47. in fields.items()
  48. if not getattr(value, 'read_only', False)
  49. ]
  50. return self.add_manual_parameters(serializer, parameters)
  51. def get_property_name(self, field_name):
  52. return field_name
  53. def get_parameter_name(self, field_name):
  54. return field_name
  55. def get_serializer_ref_name(self, serializer):
  56. return get_serializer_ref_name(serializer)
  57. def _has_ref_name(self, serializer):
  58. serializer_meta = getattr(serializer, 'Meta', None)
  59. return hasattr(serializer_meta, 'ref_name')
  60. def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
  61. SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
  62. if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
  63. child_schema = self.probe_field_inspectors(field.child, ChildSwaggerType, use_references)
  64. limits = find_limits(field) or {}
  65. return SwaggerType(
  66. type=openapi.TYPE_ARRAY,
  67. items=child_schema,
  68. **limits
  69. )
  70. elif isinstance(field, serializers.Serializer):
  71. if swagger_object_type != openapi.Schema:
  72. raise SwaggerGenerationError("cannot instantiate nested serializer as " + swagger_object_type.__name__)
  73. ref_name = self.get_serializer_ref_name(field)
  74. def make_schema_definition(serializer=field):
  75. properties = OrderedDict()
  76. required = []
  77. for property_name, child in serializer.fields.items():
  78. property_name = self.get_property_name(property_name)
  79. prop_kwargs = {
  80. 'read_only': bool(child.read_only) or None
  81. }
  82. prop_kwargs = filter_none(prop_kwargs)
  83. child_schema = self.probe_field_inspectors(
  84. child, ChildSwaggerType, use_references, **prop_kwargs
  85. )
  86. properties[property_name] = child_schema
  87. if child.required and not getattr(child_schema, 'read_only', False):
  88. required.append(property_name)
  89. result = SwaggerType(
  90. # the title is derived from the field name and is better to
  91. # be omitted from models
  92. use_field_title=False,
  93. type=openapi.TYPE_OBJECT,
  94. properties=properties,
  95. required=required or None,
  96. )
  97. setattr(result, '_NP_serializer', get_serializer_class(serializer))
  98. return result
  99. if not ref_name or not use_references:
  100. return make_schema_definition()
  101. definitions = self.components.with_scope(openapi.SCHEMA_DEFINITIONS)
  102. actual_schema = definitions.setdefault(ref_name, make_schema_definition)
  103. actual_schema._remove_read_only()
  104. actual_serializer = getattr(actual_schema, '_NP_serializer', None)
  105. this_serializer = get_serializer_class(field)
  106. if actual_serializer and actual_serializer != this_serializer: # pragma: no cover
  107. explicit_refs = self._has_ref_name(actual_serializer) and self._has_ref_name(this_serializer)
  108. if not explicit_refs:
  109. raise SwaggerGenerationError(
  110. "Schema for %s would override distinct serializer %s because they implicitly share the same "
  111. "ref_name; explicitly set the ref_name attribute on both serializers' Meta classes"
  112. % (actual_serializer, this_serializer))
  113. return openapi.SchemaRef(definitions, ref_name)
  114. return NotHandled
  115. class ReferencingSerializerInspector(InlineSerializerInspector):
  116. use_definitions = True
  117. def get_queryset_field(queryset, field_name):
  118. """Try to get information about a model and model field from a queryset.
  119. :param queryset: the queryset
  120. :param field_name: target field name
  121. :returns: the model and target field from the queryset as a 2-tuple; both elements can be ``None``
  122. :rtype: tuple
  123. """
  124. model = getattr(queryset, 'model', None)
  125. model_field = get_model_field(model, field_name)
  126. return model, model_field
  127. def get_model_field(model, field_name):
  128. """Try to get the given field from a django db model.
  129. :param model: the model
  130. :param field_name: target field name
  131. :return: model field or ``None``
  132. """
  133. try:
  134. if field_name == 'pk':
  135. return model._meta.pk
  136. else:
  137. return model._meta.get_field(field_name)
  138. except Exception: # pragma: no cover
  139. return None
  140. def get_queryset_from_view(view, serializer=None):
  141. """Try to get the queryset of the given view
  142. :param view: the view instance or class
  143. :param serializer: if given, will check that the view's get_serializer_class return matches this serializer
  144. :return: queryset or ``None``
  145. """
  146. try:
  147. queryset = call_view_method(view, 'get_queryset', 'queryset')
  148. if queryset is not None and serializer is not None:
  149. # make sure the view is actually using *this* serializer
  150. assert type(serializer) == call_view_method(view, 'get_serializer_class', 'serializer_class')
  151. return queryset
  152. except Exception: # pragma: no cover
  153. return None
  154. def get_parent_serializer(field):
  155. """Get the nearest parent ``Serializer`` instance for the given field.
  156. :return: ``Serializer`` or ``None``
  157. """
  158. while field is not None:
  159. if isinstance(field, serializers.Serializer):
  160. return field
  161. field = field.parent
  162. return None # pragma: no cover
  163. def get_related_model(model, source):
  164. """Try to find the other side of a model relationship given the name of a related field.
  165. :param model: one side of the relationship
  166. :param str source: related field name
  167. :return: related model or ``None``
  168. """
  169. try:
  170. descriptor = getattr(model, source)
  171. try:
  172. return descriptor.rel.related_model
  173. except Exception:
  174. return descriptor.field.remote_field.model
  175. except Exception: # pragma: no cover
  176. return None
  177. class RelatedFieldInspector(FieldInspector):
  178. """Provides conversions for ``RelatedField``\\ s."""
  179. def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
  180. SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
  181. if isinstance(field, serializers.ManyRelatedField):
  182. child_schema = self.probe_field_inspectors(field.child_relation, ChildSwaggerType, use_references)
  183. return SwaggerType(
  184. type=openapi.TYPE_ARRAY,
  185. items=child_schema,
  186. unique_items=True,
  187. )
  188. if not isinstance(field, serializers.RelatedField):
  189. return NotHandled
  190. field_queryset = getattr(field, 'queryset', None)
  191. if isinstance(field, (serializers.PrimaryKeyRelatedField, serializers.SlugRelatedField)):
  192. if getattr(field, 'pk_field', ''):
  193. # a PrimaryKeyRelatedField can have a `pk_field` attribute which is a
  194. # serializer field that will convert the PK value
  195. result = self.probe_field_inspectors(field.pk_field, swagger_object_type, use_references, **kwargs)
  196. # take the type, format, etc from `pk_field`, and the field-level information
  197. # like title, description, default from the PrimaryKeyRelatedField
  198. return SwaggerType(existing_object=result)
  199. target_field = getattr(field, 'slug_field', 'pk')
  200. if field_queryset is not None:
  201. # if the RelatedField has a queryset, try to get the related model field from there
  202. model, model_field = get_queryset_field(field_queryset, target_field)
  203. else:
  204. # if the RelatedField has no queryset (e.g. read only), try to find the target model
  205. # from the view queryset or ModelSerializer model, if present
  206. parent_serializer = get_parent_serializer(field)
  207. serializer_meta = getattr(parent_serializer, 'Meta', None)
  208. this_model = getattr(serializer_meta, 'model', None)
  209. if not this_model:
  210. view_queryset = get_queryset_from_view(self.view, parent_serializer)
  211. this_model = getattr(view_queryset, 'model', None)
  212. source = getattr(field, 'source', '') or field.field_name
  213. if not source and isinstance(field.parent, serializers.ManyRelatedField):
  214. source = field.parent.field_name
  215. model = get_related_model(this_model, source)
  216. model_field = get_model_field(model, target_field)
  217. attrs = get_basic_type_info(model_field) or {'type': openapi.TYPE_STRING}
  218. return SwaggerType(**attrs)
  219. elif isinstance(field, serializers.HyperlinkedRelatedField):
  220. return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI)
  221. return SwaggerType(type=openapi.TYPE_STRING)
  222. def find_regex(regex_field):
  223. """Given a ``Field``, look for a ``RegexValidator`` and try to extract its pattern and return it as a string.
  224. :param serializers.Field regex_field: the field instance
  225. :return: the extracted pattern, or ``None``
  226. :rtype: str
  227. """
  228. regex_validator = None
  229. for validator in regex_field.validators:
  230. if isinstance(validator, validators.RegexValidator):
  231. if isinstance(validator, validators.URLValidator) or validator == validators.validate_ipv4_address:
  232. # skip the default url and IP regexes because they are complex and unhelpful
  233. # validate_ipv4_address is a RegexValidator instance in Django 1.11
  234. continue
  235. if regex_validator is not None:
  236. # bail if multiple validators are found - no obvious way to choose
  237. return None # pragma: no cover
  238. regex_validator = validator
  239. # regex_validator.regex should be a compiled re object...
  240. try:
  241. pattern = getattr(getattr(regex_validator, 'regex', None), 'pattern', None)
  242. except Exception: # pragma: no cover
  243. logger.warning('failed to compile regex validator of ' + str(regex_field), exc_info=True)
  244. return None
  245. if pattern:
  246. # attempt some basic cleanup to remove regex constructs not supported by JavaScript
  247. # -- swagger uses javascript-style regexes - see https://github.com/swagger-api/swagger-editor/issues/1601
  248. if pattern.endswith('\\Z') or pattern.endswith('\\z'):
  249. pattern = pattern[:-2] + '$'
  250. return pattern
  251. numeric_fields = (serializers.IntegerField, serializers.FloatField, serializers.DecimalField)
  252. limit_validators = [
  253. # minimum and maximum apply to numbers
  254. (validators.MinValueValidator, numeric_fields, 'minimum', operator.__gt__),
  255. (validators.MaxValueValidator, numeric_fields, 'maximum', operator.__lt__),
  256. # minLength and maxLength apply to strings
  257. (validators.MinLengthValidator, serializers.CharField, 'min_length', operator.__gt__),
  258. (validators.MaxLengthValidator, serializers.CharField, 'max_length', operator.__lt__),
  259. # minItems and maxItems apply to lists
  260. (validators.MinLengthValidator, (serializers.ListField, serializers.ListSerializer), 'min_items', operator.__gt__),
  261. (validators.MaxLengthValidator, (serializers.ListField, serializers.ListSerializer), 'max_items', operator.__lt__),
  262. ]
  263. def find_limits(field):
  264. """Given a ``Field``, look for min/max value/length validators and return appropriate limit validation attributes.
  265. :param serializers.Field field: the field instance
  266. :return: the extracted limits
  267. :rtype: OrderedDict
  268. """
  269. limits = {}
  270. applicable_limits = [
  271. (validator, attr, improves)
  272. for validator, field_class, attr, improves in limit_validators
  273. if isinstance(field, field_class)
  274. ]
  275. if isinstance(field, serializers.DecimalField) and not decimal_as_float(field):
  276. return limits
  277. for validator in field.validators:
  278. if not hasattr(validator, 'limit_value'):
  279. continue
  280. limit_value = validator.limit_value
  281. if isinstance(limit_value, Decimal) and decimal_as_float(field):
  282. limit_value = float(limit_value)
  283. for validator_class, attr, improves in applicable_limits:
  284. if isinstance(validator, validator_class):
  285. if attr not in limits or improves(limit_value, limits[attr]):
  286. limits[attr] = limit_value
  287. if hasattr(field, "allow_blank") and not field.allow_blank:
  288. if limits.get('min_length', 0) < 1:
  289. limits['min_length'] = 1
  290. return OrderedDict(sorted(limits.items()))
  291. def decimal_field_type(field):
  292. return openapi.TYPE_NUMBER if decimal_as_float(field) else openapi.TYPE_STRING
  293. model_field_to_basic_type = [
  294. (models.AutoField, (openapi.TYPE_INTEGER, None)),
  295. (models.BinaryField, (openapi.TYPE_STRING, openapi.FORMAT_BINARY)),
  296. (models.BooleanField, (openapi.TYPE_BOOLEAN, None)),
  297. (models.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)),
  298. (models.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)),
  299. (models.DecimalField, (decimal_field_type, openapi.FORMAT_DECIMAL)),
  300. (models.DurationField, (openapi.TYPE_STRING, None)),
  301. (models.FloatField, (openapi.TYPE_NUMBER, None)),
  302. (models.IntegerField, (openapi.TYPE_INTEGER, None)),
  303. (models.IPAddressField, (openapi.TYPE_STRING, openapi.FORMAT_IPV4)),
  304. (models.GenericIPAddressField, (openapi.TYPE_STRING, openapi.FORMAT_IPV6)),
  305. (models.SlugField, (openapi.TYPE_STRING, openapi.FORMAT_SLUG)),
  306. (models.TextField, (openapi.TYPE_STRING, None)),
  307. (models.TimeField, (openapi.TYPE_STRING, None)),
  308. (models.UUIDField, (openapi.TYPE_STRING, openapi.FORMAT_UUID)),
  309. (models.CharField, (openapi.TYPE_STRING, None)),
  310. ]
  311. ip_format = {'ipv4': openapi.FORMAT_IPV4, 'ipv6': openapi.FORMAT_IPV6}
  312. serializer_field_to_basic_type = [
  313. (serializers.EmailField, (openapi.TYPE_STRING, openapi.FORMAT_EMAIL)),
  314. (serializers.SlugField, (openapi.TYPE_STRING, openapi.FORMAT_SLUG)),
  315. (serializers.URLField, (openapi.TYPE_STRING, openapi.FORMAT_URI)),
  316. (serializers.IPAddressField, (openapi.TYPE_STRING, lambda field: ip_format.get(field.protocol, None))),
  317. (serializers.UUIDField, (openapi.TYPE_STRING, openapi.FORMAT_UUID)),
  318. (serializers.RegexField, (openapi.TYPE_STRING, None)),
  319. (serializers.CharField, (openapi.TYPE_STRING, None)),
  320. (serializers.BooleanField, (openapi.TYPE_BOOLEAN, None)),
  321. (serializers.IntegerField, (openapi.TYPE_INTEGER, None)),
  322. (serializers.FloatField, (openapi.TYPE_NUMBER, None)),
  323. (serializers.DecimalField, (decimal_field_type, openapi.FORMAT_DECIMAL)),
  324. (serializers.DurationField, (openapi.TYPE_STRING, None)),
  325. (serializers.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)),
  326. (serializers.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)),
  327. (serializers.ModelField, (openapi.TYPE_STRING, None)),
  328. ]
  329. if version.parse(drf_version) < version.parse("3.14.0"):
  330. model_field_to_basic_type.append(
  331. (models.NullBooleanField, (openapi.TYPE_BOOLEAN, None))
  332. )
  333. serializer_field_to_basic_type.append(
  334. (serializers.NullBooleanField, (openapi.TYPE_BOOLEAN, None)),
  335. )
  336. basic_type_info = serializer_field_to_basic_type + model_field_to_basic_type
  337. def get_basic_type_info(field):
  338. """Given a serializer or model ``Field``, return its basic type information - ``type``, ``format``, ``pattern``,
  339. and any applicable min/max limit values.
  340. :param field: the field instance
  341. :return: the extracted attributes as a dictionary, or ``None`` if the field type is not known
  342. :rtype: OrderedDict
  343. """
  344. if field is None:
  345. return None
  346. for field_class, type_format in basic_type_info:
  347. if isinstance(field, field_class):
  348. swagger_type, format = type_format
  349. if callable(swagger_type):
  350. swagger_type = swagger_type(field)
  351. if callable(format):
  352. format = format(field)
  353. break
  354. else: # pragma: no cover
  355. return None
  356. pattern = None
  357. if swagger_type == openapi.TYPE_STRING:
  358. pattern = find_regex(field)
  359. limits = find_limits(field)
  360. result = OrderedDict([
  361. ('type', swagger_type),
  362. ('format', format),
  363. ('pattern', pattern)
  364. ])
  365. result.update(limits)
  366. result = filter_none(result)
  367. return result
  368. def decimal_return_type():
  369. return openapi.TYPE_STRING if rest_framework_settings.COERCE_DECIMAL_TO_STRING else openapi.TYPE_NUMBER
  370. def get_origin_type(hint_class):
  371. return getattr(hint_class, '__origin__', None) or hint_class
  372. def hint_class_issubclass(hint_class, check_class):
  373. origin_type = get_origin_type(hint_class)
  374. return inspect.isclass(origin_type) and issubclass(origin_type, check_class)
  375. hinting_type_info = [
  376. (bool, (openapi.TYPE_BOOLEAN, None)),
  377. (int, (openapi.TYPE_INTEGER, None)),
  378. (str, (openapi.TYPE_STRING, None)),
  379. (float, (openapi.TYPE_NUMBER, None)),
  380. (dict, (openapi.TYPE_OBJECT, None)),
  381. (Decimal, (decimal_return_type, openapi.FORMAT_DECIMAL)),
  382. (uuid.UUID, (openapi.TYPE_STRING, openapi.FORMAT_UUID)),
  383. (datetime.datetime, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)),
  384. (datetime.date, (openapi.TYPE_STRING, openapi.FORMAT_DATE)),
  385. ]
  386. if hasattr(typing, 'get_args'):
  387. # python >=3.8
  388. typing_get_args = typing.get_args
  389. else:
  390. # python <3.8
  391. def typing_get_args(tp):
  392. return getattr(tp, '__args__', ())
  393. def inspect_collection_hint_class(hint_class):
  394. args = typing_get_args(hint_class)
  395. child_class = args[0] if args else str
  396. child_type_info = get_basic_type_info_from_hint(child_class) or {'type': openapi.TYPE_STRING}
  397. return OrderedDict([
  398. ('type', openapi.TYPE_ARRAY),
  399. ('items', openapi.Items(**child_type_info)),
  400. ])
  401. hinting_type_info.append(((typing.Sequence, typing.AbstractSet), inspect_collection_hint_class))
  402. def _get_union_types(hint_class):
  403. origin_type = get_origin_type(hint_class)
  404. if origin_type is typing.Union:
  405. return hint_class.__args__
  406. def get_basic_type_info_from_hint(hint_class):
  407. """Given a class (eg from a SerializerMethodField's return type hint,
  408. return its basic type information - ``type``, ``format``, ``pattern``,
  409. and any applicable min/max limit values.
  410. :param hint_class: the class
  411. :return: the extracted attributes as a dictionary, or ``None`` if the field type is not known
  412. :rtype: OrderedDict
  413. """
  414. union_types = _get_union_types(hint_class)
  415. if union_types:
  416. # Optional is implemented as Union[T, None]
  417. if len(union_types) == 2 and isinstance(None, union_types[1]):
  418. result = get_basic_type_info_from_hint(union_types[0])
  419. if result:
  420. result['x-nullable'] = True
  421. return result
  422. return None
  423. for check_class, info in hinting_type_info:
  424. if hint_class_issubclass(hint_class, check_class):
  425. if callable(info):
  426. return info(hint_class)
  427. swagger_type, format = info
  428. if callable(swagger_type):
  429. swagger_type = swagger_type()
  430. return OrderedDict([
  431. ('type', swagger_type),
  432. ('format', format),
  433. ])
  434. return None
  435. class SerializerMethodFieldInspector(FieldInspector):
  436. """Provides conversion for SerializerMethodField, optionally using information from the swagger_serializer_method
  437. decorator.
  438. """
  439. def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
  440. if not isinstance(field, serializers.SerializerMethodField):
  441. return NotHandled
  442. method = getattr(field.parent, field.method_name, None)
  443. if method is None:
  444. return NotHandled
  445. # attribute added by the swagger_serializer_method decorator
  446. serializer = getattr(method, "_swagger_serializer", None)
  447. if serializer:
  448. # in order of preference for description, use:
  449. # 1) field.help_text from SerializerMethodField(help_text)
  450. # 2) serializer.help_text from swagger_serializer_method(serializer)
  451. # 3) method's docstring
  452. description = field.help_text
  453. if description is None:
  454. description = getattr(serializer, 'help_text', None)
  455. if description is None:
  456. description = method.__doc__
  457. label = field.label
  458. if label is None:
  459. label = getattr(serializer, 'label', None)
  460. if inspect.isclass(serializer):
  461. serializer_kwargs = {
  462. "help_text": description,
  463. "label": label,
  464. "read_only": True,
  465. }
  466. serializer = method._swagger_serializer(**serializer_kwargs)
  467. else:
  468. serializer.help_text = description
  469. serializer.label = label
  470. serializer.read_only = True
  471. return self.probe_field_inspectors(serializer, swagger_object_type, use_references, read_only=True)
  472. else:
  473. # look for Python 3.5+ style type hinting of the return value
  474. hint_class = inspect_signature(method).return_annotation
  475. if not inspect.isclass(hint_class) and hasattr(hint_class, '__args__'):
  476. hint_class = hint_class.__args__[0]
  477. if inspect.isclass(hint_class) and not issubclass(hint_class, inspect._empty):
  478. type_info = get_basic_type_info_from_hint(hint_class)
  479. if type_info is not None:
  480. SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type,
  481. use_references, **kwargs)
  482. return SwaggerType(**type_info)
  483. return NotHandled
  484. class SimpleFieldInspector(FieldInspector):
  485. """Provides conversions for fields which can be described using just ``type``, ``format``, ``pattern``
  486. and min/max validators.
  487. """
  488. def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
  489. type_info = get_basic_type_info(field)
  490. if type_info is None:
  491. return NotHandled
  492. SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
  493. return SwaggerType(**type_info)
  494. class ChoiceFieldInspector(FieldInspector):
  495. """Provides conversions for ``ChoiceField`` and ``MultipleChoiceField``."""
  496. def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
  497. SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
  498. if isinstance(field, serializers.ChoiceField):
  499. enum_type = openapi.TYPE_STRING
  500. enum_values = []
  501. for choice in field.choices.keys():
  502. if isinstance(field, serializers.MultipleChoiceField):
  503. choice = field_value_to_representation(field, [choice])[0]
  504. else:
  505. choice = field_value_to_representation(field, choice)
  506. enum_values.append(choice)
  507. # for ModelSerializer, try to infer the type from the associated model field
  508. serializer = get_parent_serializer(field)
  509. if isinstance(serializer, serializers.ModelSerializer):
  510. model = getattr(getattr(serializer, 'Meta'), 'model')
  511. # Use the parent source for nested fields
  512. model_field = get_model_field(model, field.source or field.parent.source)
  513. # If the field has a base_field its type must be used
  514. if getattr(model_field, "base_field", None):
  515. model_field = model_field.base_field
  516. if model_field:
  517. model_type = get_basic_type_info(model_field)
  518. if model_type:
  519. enum_type = model_type.get('type', enum_type)
  520. else:
  521. # Try to infer field type based on enum values
  522. enum_value_types = {type(v) for v in enum_values}
  523. if len(enum_value_types) == 1:
  524. values_type = get_basic_type_info_from_hint(next(iter(enum_value_types)))
  525. if values_type:
  526. enum_type = values_type.get('type', enum_type)
  527. if isinstance(field, serializers.MultipleChoiceField):
  528. result = SwaggerType(
  529. type=openapi.TYPE_ARRAY,
  530. items=ChildSwaggerType(
  531. type=enum_type,
  532. enum=enum_values
  533. )
  534. )
  535. if swagger_object_type == openapi.Parameter:
  536. if result['in'] in (openapi.IN_FORM, openapi.IN_QUERY):
  537. result.collection_format = 'multi'
  538. else:
  539. result = SwaggerType(type=enum_type, enum=enum_values)
  540. return result
  541. return NotHandled
  542. class FileFieldInspector(FieldInspector):
  543. """Provides conversions for ``FileField``\\ s."""
  544. def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
  545. SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
  546. if isinstance(field, serializers.FileField):
  547. # swagger 2.0 does not support specifics about file fields, so ImageFile gets no special treatment
  548. # OpenAPI 3.0 does support it, so a future implementation could handle this better
  549. err = SwaggerGenerationError("FileField is supported only in a formData Parameter or response Schema")
  550. if swagger_object_type == openapi.Schema:
  551. # FileField.to_representation returns URL or file name
  552. result = SwaggerType(type=openapi.TYPE_STRING, read_only=True)
  553. if getattr(field, 'use_url', rest_framework_settings.UPLOADED_FILES_USE_URL):
  554. result.format = openapi.FORMAT_URI
  555. return result
  556. elif swagger_object_type == openapi.Parameter:
  557. param = SwaggerType(type=openapi.TYPE_FILE)
  558. if param['in'] != openapi.IN_FORM:
  559. raise err # pragma: no cover
  560. return param
  561. else:
  562. raise err # pragma: no cover
  563. return NotHandled
  564. class DictFieldInspector(FieldInspector):
  565. """Provides conversion for ``DictField``."""
  566. def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
  567. SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
  568. if isinstance(field, serializers.DictField) and swagger_object_type == openapi.Schema:
  569. child_schema = self.probe_field_inspectors(field.child, ChildSwaggerType, use_references)
  570. return SwaggerType(
  571. type=openapi.TYPE_OBJECT,
  572. additional_properties=child_schema
  573. )
  574. return NotHandled
  575. class HiddenFieldInspector(FieldInspector):
  576. """Hide ``HiddenField``."""
  577. def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
  578. if isinstance(field, serializers.HiddenField):
  579. return None
  580. return NotHandled
  581. class JSONFieldInspector(FieldInspector):
  582. """Provides conversion for ``JSONField``."""
  583. def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
  584. SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
  585. if isinstance(field, serializers.JSONField) and swagger_object_type == openapi.Schema:
  586. return SwaggerType(type=openapi.TYPE_OBJECT)
  587. return NotHandled
  588. class StringDefaultFieldInspector(FieldInspector):
  589. """For otherwise unhandled fields, return them as plain :data:`.TYPE_STRING` objects."""
  590. def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs): # pragma: no cover
  591. # TODO unhandled fields: TimeField
  592. SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
  593. return SwaggerType(type=openapi.TYPE_STRING)
  594. try:
  595. from djangorestframework_camel_case.parser import CamelCaseJSONParser
  596. from djangorestframework_camel_case.render import CamelCaseJSONRenderer, camelize
  597. except ImportError: # pragma: no cover
  598. CamelCaseJSONParser = CamelCaseJSONRenderer = None
  599. def camelize(data):
  600. return data
  601. class CamelCaseJSONFilter(FieldInspector):
  602. """Converts property names to camelCase if ``djangorestframework_camel_case`` is used."""
  603. def camelize_string(self, s):
  604. """Hack to force ``djangorestframework_camel_case`` to camelize a plain string.
  605. :param str s: the string
  606. :return: camelized string
  607. :rtype: str
  608. """
  609. return next(iter(camelize({s: ''})))
  610. def camelize_schema(self, schema):
  611. """Recursively camelize property names for the given schema using ``djangorestframework_camel_case``.
  612. The target schema object must be modified in-place.
  613. :param openapi.Schema schema: the :class:`.Schema` object
  614. """
  615. if getattr(schema, 'properties', {}):
  616. schema.properties = OrderedDict(
  617. (self.camelize_string(key), self.camelize_schema(openapi.resolve_ref(val, self.components)) or val)
  618. for key, val in schema.properties.items()
  619. )
  620. if getattr(schema, 'required', []):
  621. schema.required = [self.camelize_string(p) for p in schema.required]
  622. def process_result(self, result, method_name, obj, **kwargs):
  623. if isinstance(result, openapi.Schema.OR_REF) and self.is_camel_case():
  624. schema = openapi.resolve_ref(result, self.components)
  625. self.camelize_schema(schema)
  626. return result
  627. if CamelCaseJSONParser and CamelCaseJSONRenderer:
  628. def is_camel_case(self):
  629. return (
  630. any(issubclass(parser, CamelCaseJSONParser) for parser in self.get_parser_classes()) or
  631. any(issubclass(renderer, CamelCaseJSONRenderer) for renderer in self.get_renderer_classes())
  632. )
  633. else:
  634. def is_camel_case(self):
  635. return False
  636. try:
  637. from rest_framework_recursive.fields import RecursiveField
  638. except ImportError: # pragma: no cover
  639. class RecursiveFieldInspector(FieldInspector):
  640. """Provides conversion for RecursiveField (https://github.com/heywbj/django-rest-framework-recursive)"""
  641. pass
  642. else:
  643. class RecursiveFieldInspector(FieldInspector):
  644. """Provides conversion for RecursiveField (https://github.com/heywbj/django-rest-framework-recursive)"""
  645. def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
  646. if isinstance(field, RecursiveField) and swagger_object_type == openapi.Schema:
  647. assert use_references is True, "Can not create schema for RecursiveField when use_references is False"
  648. proxied = field.proxied
  649. if isinstance(field.proxied, serializers.ListSerializer):
  650. proxied = proxied.child
  651. ref_name = get_serializer_ref_name(proxied)
  652. assert ref_name is not None, "Can't create RecursiveField schema for inline " + str(type(proxied))
  653. definitions = self.components.with_scope(openapi.SCHEMA_DEFINITIONS)
  654. ref = openapi.SchemaRef(definitions, ref_name, ignore_unresolved=True)
  655. if isinstance(field.proxied, serializers.ListSerializer):
  656. ref = openapi.Items(type=openapi.TYPE_ARRAY, items=ref)
  657. return ref
  658. return NotHandled