subscribe.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. from inspect import isawaitable
  2. from typing import (
  3. Any,
  4. AsyncIterable,
  5. AsyncIterator,
  6. Dict,
  7. Optional,
  8. Union,
  9. )
  10. from ..error import GraphQLError, located_error
  11. from ..execution.collect_fields import collect_fields
  12. from ..execution.execute import (
  13. assert_valid_execution_arguments,
  14. execute,
  15. get_field_def,
  16. ExecutionContext,
  17. ExecutionResult,
  18. )
  19. from ..execution.values import get_argument_values
  20. from ..language import DocumentNode
  21. from ..pyutils import Path, inspect
  22. from ..type import GraphQLFieldResolver, GraphQLSchema
  23. from .map_async_iterator import MapAsyncIterator
  24. __all__ = ["subscribe", "create_source_event_stream"]
  25. async def subscribe(
  26. schema: GraphQLSchema,
  27. document: DocumentNode,
  28. root_value: Any = None,
  29. context_value: Any = None,
  30. variable_values: Optional[Dict[str, Any]] = None,
  31. operation_name: Optional[str] = None,
  32. field_resolver: Optional[GraphQLFieldResolver] = None,
  33. subscribe_field_resolver: Optional[GraphQLFieldResolver] = None,
  34. ) -> Union[AsyncIterator[ExecutionResult], ExecutionResult]:
  35. """Create a GraphQL subscription.
  36. Implements the "Subscribe" algorithm described in the GraphQL spec.
  37. Returns a coroutine object which yields either an AsyncIterator (if successful) or
  38. an ExecutionResult (client error). The coroutine will raise an exception if a server
  39. error occurs.
  40. If the client-provided arguments to this function do not result in a compliant
  41. subscription, a GraphQL Response (ExecutionResult) with descriptive errors and no
  42. data will be returned.
  43. If the source stream could not be created due to faulty subscription resolver logic
  44. or underlying systems, the coroutine object will yield a single ExecutionResult
  45. containing ``errors`` and no ``data``.
  46. If the operation succeeded, the coroutine will yield an AsyncIterator, which yields
  47. a stream of ExecutionResults representing the response stream.
  48. """
  49. result_or_stream = await create_source_event_stream(
  50. schema,
  51. document,
  52. root_value,
  53. context_value,
  54. variable_values,
  55. operation_name,
  56. subscribe_field_resolver,
  57. )
  58. if isinstance(result_or_stream, ExecutionResult):
  59. return result_or_stream
  60. async def map_source_to_response(payload: Any) -> ExecutionResult:
  61. """Map source to response.
  62. For each payload yielded from a subscription, map it over the normal GraphQL
  63. :func:`~graphql.execute` function, with ``payload`` as the ``root_value``.
  64. This implements the "MapSourceToResponseEvent" algorithm described in the
  65. GraphQL specification. The :func:`~graphql.execute` function provides the
  66. "ExecuteSubscriptionEvent" algorithm, as it is nearly identical to the
  67. "ExecuteQuery" algorithm, for which :func:`~graphql.execute` is also used.
  68. """
  69. result = execute(
  70. schema,
  71. document,
  72. payload,
  73. context_value,
  74. variable_values,
  75. operation_name,
  76. field_resolver,
  77. )
  78. return await result if isawaitable(result) else result
  79. # Map every source value to a ExecutionResult value as described above.
  80. return MapAsyncIterator(result_or_stream, map_source_to_response)
  81. async def create_source_event_stream(
  82. schema: GraphQLSchema,
  83. document: DocumentNode,
  84. root_value: Any = None,
  85. context_value: Any = None,
  86. variable_values: Optional[Dict[str, Any]] = None,
  87. operation_name: Optional[str] = None,
  88. subscribe_field_resolver: Optional[GraphQLFieldResolver] = None,
  89. ) -> Union[AsyncIterable[Any], ExecutionResult]:
  90. """Create source event stream
  91. Implements the "CreateSourceEventStream" algorithm described in the GraphQL
  92. specification, resolving the subscription source event stream.
  93. Returns a coroutine that yields an AsyncIterable.
  94. If the client-provided arguments to this function do not result in a compliant
  95. subscription, a GraphQL Response (ExecutionResult) with descriptive errors and no
  96. data will be returned.
  97. If the source stream could not be created due to faulty subscription resolver logic
  98. or underlying systems, the coroutine object will yield a single ExecutionResult
  99. containing ``errors`` and no ``data``.
  100. A source event stream represents a sequence of events, each of which triggers a
  101. GraphQL execution for that event.
  102. This may be useful when hosting the stateful subscription service in a different
  103. process or machine than the stateless GraphQL execution engine, or otherwise
  104. separating these two steps. For more on this, see the "Supporting Subscriptions
  105. at Scale" information in the GraphQL spec.
  106. """
  107. # If arguments are missing or incorrectly typed, this is an internal developer
  108. # mistake which should throw an early error.
  109. assert_valid_execution_arguments(schema, document, variable_values)
  110. # If a valid context cannot be created due to incorrect arguments,
  111. # a "Response" with only errors is returned.
  112. context = ExecutionContext.build(
  113. schema,
  114. document,
  115. root_value,
  116. context_value,
  117. variable_values,
  118. operation_name,
  119. subscribe_field_resolver=subscribe_field_resolver,
  120. )
  121. # Return early errors if execution context failed.
  122. if isinstance(context, list):
  123. return ExecutionResult(data=None, errors=context)
  124. try:
  125. event_stream = await execute_subscription(context)
  126. # Assert field returned an event stream, otherwise yield an error.
  127. if not isinstance(event_stream, AsyncIterable):
  128. raise TypeError(
  129. "Subscription field must return AsyncIterable."
  130. f" Received: {inspect(event_stream)}."
  131. )
  132. return event_stream
  133. except GraphQLError as error:
  134. # Report it as an ExecutionResult, containing only errors and no data.
  135. return ExecutionResult(data=None, errors=[error])
  136. async def execute_subscription(context: ExecutionContext) -> AsyncIterable[Any]:
  137. schema = context.schema
  138. root_type = schema.subscription_type
  139. if root_type is None:
  140. raise GraphQLError(
  141. "Schema is not configured to execute subscription operation.",
  142. context.operation,
  143. )
  144. root_fields = collect_fields(
  145. schema,
  146. context.fragments,
  147. context.variable_values,
  148. root_type,
  149. context.operation.selection_set,
  150. )
  151. response_name, field_nodes = next(iter(root_fields.items()))
  152. field_def = get_field_def(schema, root_type, field_nodes[0])
  153. if not field_def:
  154. field_name = field_nodes[0].name.value
  155. raise GraphQLError(
  156. f"The subscription field '{field_name}' is not defined.", field_nodes
  157. )
  158. path = Path(None, response_name, root_type.name)
  159. info = context.build_resolve_info(field_def, field_nodes, root_type, path)
  160. # Implements the "ResolveFieldEventStream" algorithm from GraphQL specification.
  161. # It differs from "ResolveFieldValue" due to providing a different `resolveFn`.
  162. try:
  163. # Build a dictionary of arguments from the field.arguments AST, using the
  164. # variables scope to fulfill any variable references.
  165. args = get_argument_values(field_def, field_nodes[0], context.variable_values)
  166. # Call the `subscribe()` resolver or the default resolver to produce an
  167. # AsyncIterable yielding raw payloads.
  168. resolve_fn = field_def.subscribe or context.subscribe_field_resolver
  169. event_stream = resolve_fn(context.root_value, info, **args)
  170. if context.is_awaitable(event_stream):
  171. event_stream = await event_stream
  172. if isinstance(event_stream, Exception):
  173. raise event_stream
  174. return event_stream
  175. except Exception as error:
  176. raise located_error(error, field_nodes, path.as_list())