from collections import namedtuple

try:
    from collections.abc import Iterable
except ImportError:
    from collections import Iterable
from functools import partial
from threading import local

from .promise import Promise, async_instance, get_default_scheduler

if False:
    from typing import (
        Any,
        List,
        Sized,
        Callable,
        Optional,
        Tuple,
        Union,
        Iterator,
        Hashable,
    )  # flake8: noqa


def get_chunks(iterable_obj, chunk_size=1):
    # type: (List[Loader], int) -> Iterator
    chunk_size = max(1, chunk_size)
    return (
        iterable_obj[i : i + chunk_size]
        for i in range(0, len(iterable_obj), chunk_size)
    )


Loader = namedtuple("Loader", "key,resolve,reject")
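

# Illustrative sketch (not part of the library API): get_chunks slices a list
# into fixed-size batches, which dispatch_queue uses when max_batch_size is
# set. For example, five queued items with chunk_size=2 produce three batches:
#
#   >>> list(get_chunks([1, 2, 3, 4, 5], chunk_size=2))
#   [[1, 2], [3, 4], [5]]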


class DataLoader(local):

    batch = True
    max_batch_size = None  # type: int
    cache = True

    def __init__(
        self,
        batch_load_fn=None,  # type: Callable
        batch=None,  # type: Optional[Any]
        max_batch_size=None,  # type: Optional[int]
        cache=None,  # type: Optional[Any]
        get_cache_key=None,  # type: Optional[Any]
        cache_map=None,  # type: Optional[Any]
        scheduler=None,  # type: Optional[Any]
    ):
        # type: (...) -> None
        if batch_load_fn is not None:
            self.batch_load_fn = batch_load_fn

        if not callable(self.batch_load_fn):
            raise TypeError(
                (
                    "DataLoader must have a batch_load_fn which accepts "
                    "List<key> and returns Promise<List<value>>, but got: {}."
                ).format(batch_load_fn)
            )

        if batch is not None:
            self.batch = batch

        if max_batch_size is not None:
            self.max_batch_size = max_batch_size

        if cache is not None:
            self.cache = cache

        self.get_cache_key = get_cache_key or (lambda x: x)

        self._promise_cache = cache_map or {}
        self._queue = []  # type: List[Loader]
        self._scheduler = scheduler

    def load(self, key=None):
        # type: (Hashable) -> Promise
        """
        Loads a key, returning a `Promise` for the value represented by that key.
        """
        if key is None:
            raise TypeError(
                (
                    "The loader.load() function must be called with a value, "
                    + "but got: {}."
                ).format(key)
            )

        cache_key = self.get_cache_key(key)

        # If caching and there is a cache-hit, return cached Promise.
        if self.cache:
            cached_promise = self._promise_cache.get(cache_key)
            if cached_promise:
                return cached_promise

        # Otherwise, produce a new Promise for this value.
        promise = Promise(partial(self.do_resolve_reject, key))  # type: ignore

        # If caching, cache this promise.
        if self.cache:
            self._promise_cache[cache_key] = promise

        return promise

    def do_resolve_reject(self, key, resolve, reject):
        # type: (Hashable, Callable, Callable) -> None
        # Enqueue this Promise to be dispatched.
        self._queue.append(Loader(key=key, resolve=resolve, reject=reject))
        # Determine if a dispatch of this queue should be scheduled.
        # A single dispatch should be scheduled per queue at the time when the
        # queue changes from "empty" to "full".
        if len(self._queue) == 1:
            if self.batch:
                # If batching, schedule a task to dispatch the queue.
                enqueue_post_promise_job(
                    partial(dispatch_queue, self), self._scheduler
                )
            else:
                # Otherwise dispatch the (queue of one) immediately.
                dispatch_queue(self)

    def load_many(self, keys):
        # type: (Iterable[Hashable]) -> Promise
        """
        Loads multiple keys, promising an array of values.

        >>> a, b = await my_loader.load_many([ 'a', 'b' ])

        This is equivalent to the more verbose:

        >>> a, b = await Promise.all([
        >>>     my_loader.load('a'),
        >>>     my_loader.load('b')
        >>> ])
        """
        if not isinstance(keys, Iterable):
            raise TypeError(
                (
                    "The loader.loadMany() function must be called with Array<key> "
                    + "but got: {}."
                ).format(keys)
            )

        return Promise.all([self.load(key) for key in keys])

    def clear(self, key):
        # type: (Hashable) -> DataLoader
        """
        Clears the value at `key` from the cache, if it exists.

        Returns itself for method chaining.
        """
        cache_key = self.get_cache_key(key)
        self._promise_cache.pop(cache_key, None)
        return self

    def clear_all(self):
        # type: () -> DataLoader
        """
        Clears the entire cache. To be used when some event results in unknown
        invalidations across this particular `DataLoader`.

        Returns itself for method chaining.
        """
        self._promise_cache.clear()
        return self

    def prime(self, key, value):
        # type: (Hashable, Any) -> DataLoader
        """
        Adds the provided key and value to the cache. If the key already
        exists, no change is made.

        Returns itself for method chaining.
        """
        cache_key = self.get_cache_key(key)

        # Only add the key if it does not already exist.
        if cache_key not in self._promise_cache:
            # Cache a rejected promise if the value is an Error, in order to match
            # the behavior of load(key).
            if isinstance(value, Exception):
                promise = Promise.reject(value)
            else:
                promise = Promise.resolve(value)

            self._promise_cache[cache_key] = promise

        return self
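

# Usage sketch (illustrative only, not part of the library API): the
# `fetch_users` helper below is a hypothetical placeholder for an
# application's own batched fetch. A batch_load_fn receives a list of keys
# and must return a Promise that resolves to a list of values in the same
# order as the keys.
def _example_usage():  # pragma: no cover
    def fetch_users(keys):
        # Hypothetical stand-in: one record per key, in key order.
        return [{"id": key} for key in keys]

    def batch_load_users(keys):
        return Promise.resolve(fetch_users(keys))

    user_loader = DataLoader(batch_load_users)

    # Loads issued in the same frame of execution are dispatched as one batch.
    return Promise.all([user_loader.load(1), user_loader.load(2)])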


# Private: Enqueue a Job to be executed after all "PromiseJobs" Jobs.
#
# ES6 JavaScript uses the concepts Job and JobQueue to schedule work to occur
# after the current execution context has completed:
# http://www.ecma-international.org/ecma-262/6.0/#sec-jobs-and-job-queues
#
# Node.js uses the `process.nextTick` mechanism to implement the concept of a
# Job, maintaining a global FIFO JobQueue for all Jobs, which is flushed after
# the current call stack ends.
#
# When calling `then` on a Promise, it enqueues a Job on a specific
# "PromiseJobs" JobQueue which is flushed in Node as a single Job on the
# global JobQueue.
#
# DataLoader batches all loads which occur in a single frame of execution, but
# should include in the batch all loads which occur during the flushing of the
# "PromiseJobs" JobQueue after that same execution frame.
#
# In order to avoid the DataLoader dispatch Job occurring before "PromiseJobs",
# a Promise Job is created with the sole purpose of enqueuing a global Job,
# ensuring that it always occurs after "PromiseJobs" ends.

# Private: cached resolved Promise instance
cache = local()


def enqueue_post_promise_job(fn, scheduler):
    # type: (Callable, Any) -> None
    global cache
    if not hasattr(cache, 'resolved_promise'):
        cache.resolved_promise = Promise.resolve(None)
    if not scheduler:
        scheduler = get_default_scheduler()

    def on_promise_resolve(v):
        # type: (Any) -> None
        async_instance.invoke(fn, scheduler)

    cache.resolved_promise.then(on_promise_resolve)


def dispatch_queue(loader):
    # type: (DataLoader) -> None
    """
    Given the current state of a Loader instance, perform a batch load from
    its current queue.
    """
    # Take the current loader queue, replacing it with an empty queue.
    queue = loader._queue
    loader._queue = []

    # If a maxBatchSize was provided and the queue is longer, then segment the
    # queue into multiple batches, otherwise treat the queue as a single batch.
    max_batch_size = loader.max_batch_size

    if max_batch_size and max_batch_size < len(queue):
        chunks = get_chunks(queue, max_batch_size)
        for chunk in chunks:
            dispatch_queue_batch(loader, chunk)
    else:
        dispatch_queue_batch(loader, queue)


def dispatch_queue_batch(loader, queue):
    # type: (DataLoader, List[Loader]) -> None
    # Collect all keys to be loaded in this dispatch.
    keys = [l.key for l in queue]

    # Call the provided batch_load_fn for this loader with the loader queue's keys.
    try:
        batch_promise = loader.batch_load_fn(keys)
    except Exception as e:
        failed_dispatch(loader, queue, e)
        return None

    # Assert the expected response from batch_load_fn.
    if not batch_promise or not isinstance(batch_promise, Promise):
        failed_dispatch(
            loader,
            queue,
            TypeError(
                (
                    "DataLoader must be constructed with a function which accepts "
                    "Array<key> and returns Promise<Array<value>>, but the function did "
                    "not return a Promise: {}."
                ).format(batch_promise)
            ),
        )
        return None

    def batch_promise_resolved(values):
        # type: (Sized) -> None
        # Assert the expected resolution from batch_load_fn.
        if not isinstance(values, Iterable):
            raise TypeError(
                (
                    "DataLoader must be constructed with a function which accepts "
                    "Array<key> and returns Promise<Array<value>>, but the function did "
                    "not return a Promise of an Array: {}."
                ).format(values)
            )

        if len(values) != len(keys):
            raise TypeError(
                (
                    "DataLoader must be constructed with a function which accepts "
                    "Array<key> and returns Promise<Array<value>>, but the function did "
                    "not return a Promise of an Array of the same length as the Array "
                    "of keys."
                    "\n\nKeys:\n{}"
                    "\n\nValues:\n{}"
                ).format(keys, values)
            )

        # Step through the values, resolving or rejecting each Promise in the
        # loaded queue.
        for l, value in zip(queue, values):
            if isinstance(value, Exception):
                l.reject(value)
            else:
                l.resolve(value)

    batch_promise.then(batch_promise_resolved).catch(
        partial(failed_dispatch, loader, queue)
    )


def failed_dispatch(loader, queue, error):
    # type: (DataLoader, Iterable[Loader], Exception) -> None
    """
    Do not cache individual loads if the entire batch dispatch fails, but still
    reject each request so they do not hang.
    """
    for l in queue:
        loader.clear(l.key)
        l.reject(error)
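

# Illustrative sketch (the `find_record` helper is hypothetical, not part of
# this module): a batch_load_fn may resolve individual positions to Exception
# instances; dispatch_queue_batch then rejects only those keys' Promises while
# resolving the rest.
#
#   >>> def batch_load(keys):
#   ...     records = [find_record(key) for key in keys]
#   ...     return Promise.resolve([
#   ...         record if record is not None else KeyError(key)
#   ...         for key, record in zip(keys, records)
#   ...     ])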