collect_fields.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. from typing import Any, Dict, List, Set, Union, cast
  2. from ..language import (
  3. FieldNode,
  4. FragmentDefinitionNode,
  5. FragmentSpreadNode,
  6. InlineFragmentNode,
  7. SelectionSetNode,
  8. )
  9. from ..type import (
  10. GraphQLAbstractType,
  11. GraphQLIncludeDirective,
  12. GraphQLObjectType,
  13. GraphQLSchema,
  14. GraphQLSkipDirective,
  15. is_abstract_type,
  16. )
  17. from ..utilities.type_from_ast import type_from_ast
  18. from .values import get_directive_values
  19. __all__ = ["collect_fields", "collect_sub_fields"]
  20. def collect_fields(
  21. schema: GraphQLSchema,
  22. fragments: Dict[str, FragmentDefinitionNode],
  23. variable_values: Dict[str, Any],
  24. runtime_type: GraphQLObjectType,
  25. selection_set: SelectionSetNode,
  26. ) -> Dict[str, List[FieldNode]]:
  27. """Collect fields.
  28. Given a selection_set, collects all the fields and returns them.
  29. collect_fields requires the "runtime type" of an object. For a field that
  30. returns an Interface or Union type, the "runtime type" will be the actual
  31. object type returned by that field.
  32. For internal use only.
  33. """
  34. fields: Dict[str, List[FieldNode]] = {}
  35. collect_fields_impl(
  36. schema, fragments, variable_values, runtime_type, selection_set, fields, set()
  37. )
  38. return fields
  39. def collect_sub_fields(
  40. schema: GraphQLSchema,
  41. fragments: Dict[str, FragmentDefinitionNode],
  42. variable_values: Dict[str, Any],
  43. return_type: GraphQLObjectType,
  44. field_nodes: List[FieldNode],
  45. ) -> Dict[str, List[FieldNode]]:
  46. """Collect sub fields.
  47. Given a list of field nodes, collects all the subfields of the passed in fields,
  48. and returns them at the end.
  49. collect_sub_fields requires the "return type" of an object. For a field that
  50. returns an Interface or Union type, the "return type" will be the actual
  51. object type returned by that field.
  52. For internal use only.
  53. """
  54. sub_field_nodes: Dict[str, List[FieldNode]] = {}
  55. visited_fragment_names: Set[str] = set()
  56. for node in field_nodes:
  57. if node.selection_set:
  58. collect_fields_impl(
  59. schema,
  60. fragments,
  61. variable_values,
  62. return_type,
  63. node.selection_set,
  64. sub_field_nodes,
  65. visited_fragment_names,
  66. )
  67. return sub_field_nodes
  68. def collect_fields_impl(
  69. schema: GraphQLSchema,
  70. fragments: Dict[str, FragmentDefinitionNode],
  71. variable_values: Dict[str, Any],
  72. runtime_type: GraphQLObjectType,
  73. selection_set: SelectionSetNode,
  74. fields: Dict[str, List[FieldNode]],
  75. visited_fragment_names: Set[str],
  76. ) -> None:
  77. """Collect fields (internal implementation)."""
  78. for selection in selection_set.selections:
  79. if isinstance(selection, FieldNode):
  80. if not should_include_node(variable_values, selection):
  81. continue
  82. name = get_field_entry_key(selection)
  83. fields.setdefault(name, []).append(selection)
  84. elif isinstance(selection, InlineFragmentNode):
  85. if not should_include_node(
  86. variable_values, selection
  87. ) or not does_fragment_condition_match(schema, selection, runtime_type):
  88. continue
  89. collect_fields_impl(
  90. schema,
  91. fragments,
  92. variable_values,
  93. runtime_type,
  94. selection.selection_set,
  95. fields,
  96. visited_fragment_names,
  97. )
  98. elif isinstance(selection, FragmentSpreadNode): # pragma: no cover else
  99. frag_name = selection.name.value
  100. if frag_name in visited_fragment_names or not should_include_node(
  101. variable_values, selection
  102. ):
  103. continue
  104. visited_fragment_names.add(frag_name)
  105. fragment = fragments.get(frag_name)
  106. if not fragment or not does_fragment_condition_match(
  107. schema, fragment, runtime_type
  108. ):
  109. continue
  110. collect_fields_impl(
  111. schema,
  112. fragments,
  113. variable_values,
  114. runtime_type,
  115. fragment.selection_set,
  116. fields,
  117. visited_fragment_names,
  118. )
  119. def should_include_node(
  120. variable_values: Dict[str, Any],
  121. node: Union[FragmentSpreadNode, FieldNode, InlineFragmentNode],
  122. ) -> bool:
  123. """Check if node should be included
  124. Determines if a field should be included based on the @include and @skip
  125. directives, where @skip has higher precedence than @include.
  126. """
  127. skip = get_directive_values(GraphQLSkipDirective, node, variable_values)
  128. if skip and skip["if"]:
  129. return False
  130. include = get_directive_values(GraphQLIncludeDirective, node, variable_values)
  131. if include and not include["if"]:
  132. return False
  133. return True
  134. def does_fragment_condition_match(
  135. schema: GraphQLSchema,
  136. fragment: Union[FragmentDefinitionNode, InlineFragmentNode],
  137. type_: GraphQLObjectType,
  138. ) -> bool:
  139. """Determine if a fragment is applicable to the given type."""
  140. type_condition_node = fragment.type_condition
  141. if not type_condition_node:
  142. return True
  143. conditional_type = type_from_ast(schema, type_condition_node)
  144. if conditional_type is type_:
  145. return True
  146. if is_abstract_type(conditional_type):
  147. return schema.is_sub_type(cast(GraphQLAbstractType, conditional_type), type_)
  148. return False
  149. def get_field_entry_key(node: FieldNode) -> str:
  150. """Implements the logic to compute the key of a given field's entry"""
  151. return node.alias.value if node.alias else node.name.value