query.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. from collections import OrderedDict
  2. import coreschema
  3. from rest_framework.pagination import CursorPagination, LimitOffsetPagination, PageNumberPagination
  4. from .. import openapi
  5. from ..utils import force_real_str
  6. from .base import FilterInspector, PaginatorInspector
  7. class CoreAPICompatInspector(PaginatorInspector, FilterInspector):
  8. """Converts ``coreapi.Field``\\ s to :class:`.openapi.Parameter`\\ s for filters and paginators that implement a
  9. ``get_schema_fields`` method.
  10. """
  11. def get_paginator_parameters(self, paginator):
  12. fields = []
  13. if hasattr(paginator, 'get_schema_fields'):
  14. fields = paginator.get_schema_fields(self.view)
  15. return [self.coreapi_field_to_parameter(field) for field in fields]
  16. def get_filter_parameters(self, filter_backend):
  17. fields = []
  18. if hasattr(filter_backend, 'get_schema_fields'):
  19. fields = filter_backend.get_schema_fields(self.view)
  20. return [self.coreapi_field_to_parameter(field) for field in fields]
  21. def coreapi_field_to_parameter(self, field):
  22. """Convert an instance of `coreapi.Field` to a swagger :class:`.Parameter` object.
  23. :param coreapi.Field field:
  24. :rtype: openapi.Parameter
  25. """
  26. location_to_in = {
  27. 'query': openapi.IN_QUERY,
  28. 'path': openapi.IN_PATH,
  29. 'form': openapi.IN_FORM,
  30. 'body': openapi.IN_FORM,
  31. }
  32. coreapi_types = {
  33. coreschema.Integer: openapi.TYPE_INTEGER,
  34. coreschema.Number: openapi.TYPE_NUMBER,
  35. coreschema.String: openapi.TYPE_STRING,
  36. coreschema.Boolean: openapi.TYPE_BOOLEAN,
  37. }
  38. coreschema_attrs = ['format', 'pattern', 'enum', 'min_length', 'max_length']
  39. schema = field.schema
  40. return openapi.Parameter(
  41. name=field.name,
  42. in_=location_to_in[field.location],
  43. required=field.required,
  44. description=force_real_str(schema.description) if schema else None,
  45. type=coreapi_types.get(type(schema), openapi.TYPE_STRING),
  46. **OrderedDict((attr, getattr(schema, attr, None)) for attr in coreschema_attrs)
  47. )
  48. class DjangoRestResponsePagination(PaginatorInspector):
  49. """Provides response schema pagination wrapping for django-rest-framework's LimitOffsetPagination,
  50. PageNumberPagination and CursorPagination
  51. """
  52. def get_paginated_response(self, paginator, response_schema):
  53. assert response_schema.type == openapi.TYPE_ARRAY, "array return expected for paged response"
  54. paged_schema = None
  55. if isinstance(paginator, (LimitOffsetPagination, PageNumberPagination, CursorPagination)):
  56. has_count = not isinstance(paginator, CursorPagination)
  57. paged_schema = openapi.Schema(
  58. type=openapi.TYPE_OBJECT,
  59. properties=OrderedDict((
  60. ('count', openapi.Schema(type=openapi.TYPE_INTEGER) if has_count else None),
  61. ('next', openapi.Schema(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI, x_nullable=True)),
  62. ('previous', openapi.Schema(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI, x_nullable=True)),
  63. ('results', response_schema),
  64. )),
  65. required=['results']
  66. )
  67. if has_count:
  68. paged_schema.required.insert(0, 'count')
  69. return paged_schema