validation_context.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union, cast
  2. from ..error import GraphQLError
  3. from ..language import (
  4. DocumentNode,
  5. FragmentDefinitionNode,
  6. FragmentSpreadNode,
  7. OperationDefinitionNode,
  8. SelectionSetNode,
  9. VariableNode,
  10. Visitor,
  11. VisitorAction,
  12. visit,
  13. )
  14. from ..type import (
  15. GraphQLArgument,
  16. GraphQLCompositeType,
  17. GraphQLDirective,
  18. GraphQLEnumValue,
  19. GraphQLField,
  20. GraphQLInputType,
  21. GraphQLOutputType,
  22. GraphQLSchema,
  23. )
  24. from ..utilities import TypeInfo, TypeInfoVisitor
  25. __all__ = [
  26. "ASTValidationContext",
  27. "SDLValidationContext",
  28. "ValidationContext",
  29. "VariableUsage",
  30. "VariableUsageVisitor",
  31. ]
  32. NodeWithSelectionSet = Union[OperationDefinitionNode, FragmentDefinitionNode]
  33. class VariableUsage(NamedTuple):
  34. node: VariableNode
  35. type: Optional[GraphQLInputType]
  36. default_value: Any
  37. class VariableUsageVisitor(Visitor):
  38. """Visitor adding all variable usages to a given list."""
  39. usages: List[VariableUsage]
  40. def __init__(self, type_info: TypeInfo):
  41. super().__init__()
  42. self.usages = []
  43. self._append_usage = self.usages.append
  44. self._type_info = type_info
  45. def enter_variable_definition(self, *_args: Any) -> VisitorAction:
  46. return self.SKIP
  47. def enter_variable(self, node: VariableNode, *_args: Any) -> VisitorAction:
  48. type_info = self._type_info
  49. usage = VariableUsage(
  50. node, type_info.get_input_type(), type_info.get_default_value()
  51. )
  52. self._append_usage(usage)
  53. return None
  54. class ASTValidationContext:
  55. """Utility class providing a context for validation of an AST.
  56. An instance of this class is passed as the context attribute to all Validators,
  57. allowing access to commonly useful contextual information from within a validation
  58. rule.
  59. """
  60. document: DocumentNode
  61. _fragments: Optional[Dict[str, FragmentDefinitionNode]]
  62. _fragment_spreads: Dict[SelectionSetNode, List[FragmentSpreadNode]]
  63. _recursively_referenced_fragments: Dict[
  64. OperationDefinitionNode, List[FragmentDefinitionNode]
  65. ]
  66. def __init__(
  67. self, ast: DocumentNode, on_error: Callable[[GraphQLError], None]
  68. ) -> None:
  69. self.document = ast
  70. self.on_error = on_error # type: ignore
  71. self._fragments = None
  72. self._fragment_spreads = {}
  73. self._recursively_referenced_fragments = {}
  74. def on_error(self, error: GraphQLError) -> None:
  75. pass
  76. def report_error(self, error: GraphQLError) -> None:
  77. self.on_error(error)
  78. def get_fragment(self, name: str) -> Optional[FragmentDefinitionNode]:
  79. fragments = self._fragments
  80. if fragments is None:
  81. fragments = {
  82. statement.name.value: statement
  83. for statement in self.document.definitions
  84. if isinstance(statement, FragmentDefinitionNode)
  85. }
  86. self._fragments = fragments
  87. return fragments.get(name)
  88. def get_fragment_spreads(self, node: SelectionSetNode) -> List[FragmentSpreadNode]:
  89. spreads = self._fragment_spreads.get(node)
  90. if spreads is None:
  91. spreads = []
  92. append_spread = spreads.append
  93. sets_to_visit = [node]
  94. append_set = sets_to_visit.append
  95. pop_set = sets_to_visit.pop
  96. while sets_to_visit:
  97. visited_set = pop_set()
  98. for selection in visited_set.selections:
  99. if isinstance(selection, FragmentSpreadNode):
  100. append_spread(selection)
  101. else:
  102. set_to_visit = cast(
  103. NodeWithSelectionSet, selection
  104. ).selection_set
  105. if set_to_visit:
  106. append_set(set_to_visit)
  107. self._fragment_spreads[node] = spreads
  108. return spreads
  109. def get_recursively_referenced_fragments(
  110. self, operation: OperationDefinitionNode
  111. ) -> List[FragmentDefinitionNode]:
  112. fragments = self._recursively_referenced_fragments.get(operation)
  113. if fragments is None:
  114. fragments = []
  115. append_fragment = fragments.append
  116. collected_names: Set[str] = set()
  117. add_name = collected_names.add
  118. nodes_to_visit = [operation.selection_set]
  119. append_node = nodes_to_visit.append
  120. pop_node = nodes_to_visit.pop
  121. get_fragment = self.get_fragment
  122. get_fragment_spreads = self.get_fragment_spreads
  123. while nodes_to_visit:
  124. visited_node = pop_node()
  125. for spread in get_fragment_spreads(visited_node):
  126. frag_name = spread.name.value
  127. if frag_name not in collected_names:
  128. add_name(frag_name)
  129. fragment = get_fragment(frag_name)
  130. if fragment:
  131. append_fragment(fragment)
  132. append_node(fragment.selection_set)
  133. self._recursively_referenced_fragments[operation] = fragments
  134. return fragments
  135. class SDLValidationContext(ASTValidationContext):
  136. """Utility class providing a context for validation of an SDL AST.
  137. An instance of this class is passed as the context attribute to all Validators,
  138. allowing access to commonly useful contextual information from within a validation
  139. rule.
  140. """
  141. schema: Optional[GraphQLSchema]
  142. def __init__(
  143. self,
  144. ast: DocumentNode,
  145. schema: Optional[GraphQLSchema],
  146. on_error: Callable[[GraphQLError], None],
  147. ) -> None:
  148. super().__init__(ast, on_error)
  149. self.schema = schema
  150. class ValidationContext(ASTValidationContext):
  151. """Utility class providing a context for validation using a GraphQL schema.
  152. An instance of this class is passed as the context attribute to all Validators,
  153. allowing access to commonly useful contextual information from within a validation
  154. rule.
  155. """
  156. schema: GraphQLSchema
  157. _type_info: TypeInfo
  158. _variable_usages: Dict[NodeWithSelectionSet, List[VariableUsage]]
  159. _recursive_variable_usages: Dict[OperationDefinitionNode, List[VariableUsage]]
  160. def __init__(
  161. self,
  162. schema: GraphQLSchema,
  163. ast: DocumentNode,
  164. type_info: TypeInfo,
  165. on_error: Callable[[GraphQLError], None],
  166. ) -> None:
  167. super().__init__(ast, on_error)
  168. self.schema = schema
  169. self._type_info = type_info
  170. self._variable_usages = {}
  171. self._recursive_variable_usages = {}
  172. def get_variable_usages(self, node: NodeWithSelectionSet) -> List[VariableUsage]:
  173. usages = self._variable_usages.get(node)
  174. if usages is None:
  175. usage_visitor = VariableUsageVisitor(self._type_info)
  176. visit(node, TypeInfoVisitor(self._type_info, usage_visitor))
  177. usages = usage_visitor.usages
  178. self._variable_usages[node] = usages
  179. return usages
  180. def get_recursive_variable_usages(
  181. self, operation: OperationDefinitionNode
  182. ) -> List[VariableUsage]:
  183. usages = self._recursive_variable_usages.get(operation)
  184. if usages is None:
  185. get_variable_usages = self.get_variable_usages
  186. usages = get_variable_usages(operation)
  187. for fragment in self.get_recursively_referenced_fragments(operation):
  188. usages.extend(get_variable_usages(fragment))
  189. self._recursive_variable_usages[operation] = usages
  190. return usages
  191. def get_type(self) -> Optional[GraphQLOutputType]:
  192. return self._type_info.get_type()
  193. def get_parent_type(self) -> Optional[GraphQLCompositeType]:
  194. return self._type_info.get_parent_type()
  195. def get_input_type(self) -> Optional[GraphQLInputType]:
  196. return self._type_info.get_input_type()
  197. def get_parent_input_type(self) -> Optional[GraphQLInputType]:
  198. return self._type_info.get_parent_input_type()
  199. def get_field_def(self) -> Optional[GraphQLField]:
  200. return self._type_info.get_field_def()
  201. def get_directive(self) -> Optional[GraphQLDirective]:
  202. return self._type_info.get_directive()
  203. def get_argument(self) -> Optional[GraphQLArgument]:
  204. return self._type_info.get_argument()
  205. def get_enum_value(self) -> Optional[GraphQLEnumValue]:
  206. return self._type_info.get_enum_value()