123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263 |
- from functools import partial, reduce
- from inspect import isfunction
- from typing import Callable, Iterator, Dict, List, Tuple, Any, Optional
# Public names exported by ``from <module> import *``.
__all__ = ["MiddlewareManager"]
# Alias used in the annotations of this module: a field resolver is any
# callable, accepting arbitrary arguments and returning anything.
GraphQLFieldResolver = Callable[..., Any]
class MiddlewareManager:
    """Manager for the middleware chain.

    This class helps to wrap resolver functions with the provided middleware functions
    and/or objects. The functions take the next middleware function as first argument.
    If middleware is provided as an object, it must provide a method ``resolve`` that is
    used as the middleware function.

    Note that since resolvers return "AwaitableOrValue"s, all middleware functions
    must be aware of this and check whether values are awaitable before awaiting them.
    """

    # allow custom attributes (not used internally)
    __slots__ = "__dict__", "middlewares", "_middleware_resolvers", "_cached_resolvers"

    # Wrapped resolvers, keyed by the original (hashable) field resolver.
    _cached_resolvers: Dict[GraphQLFieldResolver, GraphQLFieldResolver]
    # Middleware callables in the order given; None when no middleware was passed.
    _middleware_resolvers: Optional[List[Callable]]

    def __init__(self, *middlewares: Any):
        self.middlewares = middlewares
        # None (instead of an empty list) lets get_field_resolver return the
        # resolver unchanged without touching the cache.
        self._middleware_resolvers = (
            list(get_middleware_resolvers(middlewares)) if middlewares else None
        )
        self._cached_resolvers = {}

    def get_field_resolver(
        self, field_resolver: GraphQLFieldResolver
    ) -> GraphQLFieldResolver:
        """Wrap the provided resolver with the middleware.

        Returns a function that chains the middleware functions with the provided
        resolver function. Each middleware is bound (via ``functools.partial``)
        to the chain built so far, so the last middleware in the list is the
        outermost wrapper: it is called first and receives the rest of the chain
        as its first argument.

        If no middleware was given, ``field_resolver`` is returned unchanged.
        Wrapped resolvers are cached per resolver. A resolver that cannot be used
        as a dict key (an unhashable callable) is wrapped afresh on each call
        instead of raising ``TypeError``.
        """
        middleware_resolvers = self._middleware_resolvers
        if middleware_resolvers is None:
            return field_resolver
        cache = self._cached_resolvers
        try:
            wrapped = cache.get(field_resolver)
            cacheable = True
        except TypeError:
            # Unhashable resolver (e.g. a callable dataclass instance with
            # eq=True): it cannot be a dict key, so skip the cache entirely.
            wrapped, cacheable = None, False
        if wrapped is None:
            wrapped = reduce(
                lambda chained_fns, next_fn: partial(next_fn, chained_fns),
                middleware_resolvers,
                field_resolver,
            )
            if cacheable:
                cache[field_resolver] = wrapped
        return wrapped
def get_middleware_resolvers(middlewares: Tuple[Any, ...]) -> Iterator[Callable]:
    """Get a list of resolver functions from a list of classes or functions.

    Each middleware is turned into the callable used in the resolver chain,
    checked in this order:

    * a plain Python function is used as is;
    * an object (or class) with a non-None ``resolve`` attribute contributes
      that attribute;
    * any other callable that is not a class (bound method,
      ``functools.partial``, builtin, instance defining ``__call__``) is used
      as is.

    Anything else, such as a class without ``resolve`` or a non-callable object
    without ``resolve``, is skipped.
    """
    for middleware in middlewares:
        if isfunction(middleware):
            yield middleware
            continue
        # middleware provided as object with 'resolve' method
        resolver_func = getattr(middleware, "resolve", None)
        if resolver_func is not None:
            yield resolver_func
        elif callable(middleware) and not isinstance(middleware, type):
            # ``isfunction`` rejects these callables, but they can be invoked
            # exactly like a middleware function. Classes are excluded: calling
            # one would construct an instance rather than resolve a value.
            yield middleware
|