# dataloader.py

from collections import namedtuple

try:
    from collections.abc import Iterable
except ImportError:
    from collections import Iterable
from functools import partial
from threading import local

from .promise import Promise, async_instance, get_default_scheduler

if False:
    from typing import (
        Any,
        List,
        Sized,
        Callable,
        Optional,
        Tuple,
        Union,
        Iterator,
        Hashable,
    )  # flake8: noqa


def get_chunks(iterable_obj, chunk_size=1):
    # type: (List[Loader], int) -> Iterator
    chunk_size = max(1, chunk_size)
    return (
        iterable_obj[i : i + chunk_size]
        for i in range(0, len(iterable_obj), chunk_size)
    )
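

# For illustration (this comment is not part of the original module):
# get_chunks splits the queue into consecutive slices of at most
# chunk_size items, with a possibly shorter final slice, e.g.
#
#     list(get_chunks([1, 2, 3, 4, 5], 2))  # -> [[1, 2], [3, 4], [5]]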


Loader = namedtuple("Loader", "key,resolve,reject")


class DataLoader(local):

    batch = True
    max_batch_size = None  # type: int
    cache = True

    def __init__(
        self,
        batch_load_fn=None,  # type: Callable
        batch=None,  # type: Optional[Any]
        max_batch_size=None,  # type: Optional[int]
        cache=None,  # type: Optional[Any]
        get_cache_key=None,  # type: Optional[Any]
        cache_map=None,  # type: Optional[Any]
        scheduler=None,  # type: Optional[Any]
    ):
        # type: (...) -> None
        if batch_load_fn is not None:
            self.batch_load_fn = batch_load_fn

        if not callable(self.batch_load_fn):
            raise TypeError(
                (
                    "DataLoader must have a batch_load_fn which accepts "
                    "List<key> and returns Promise<List<value>>, but got: {}."
                ).format(batch_load_fn)
            )

        if batch is not None:
            self.batch = batch

        if max_batch_size is not None:
            self.max_batch_size = max_batch_size

        if cache is not None:
            self.cache = cache

        self.get_cache_key = get_cache_key or (lambda x: x)

        self._promise_cache = cache_map or {}
        self._queue = []  # type: List[Loader]
        self._scheduler = scheduler

    def load(self, key=None):
        # type: (Hashable) -> Promise
        """
        Loads a key, returning a `Promise` for the value represented by that key.
        """
        if key is None:
            raise TypeError(
                (
                    "The loader.load() function must be called with a value, "
                    "but got: {}."
                ).format(key)
            )

        cache_key = self.get_cache_key(key)

        # If caching and there is a cache-hit, return cached Promise.
        if self.cache:
            cached_promise = self._promise_cache.get(cache_key)
            if cached_promise:
                return cached_promise

        # Otherwise, produce a new Promise for this value.
        promise = Promise(partial(self.do_resolve_reject, key))  # type: ignore

        # If caching, cache this promise.
        if self.cache:
            self._promise_cache[cache_key] = promise

        return promise

    def do_resolve_reject(self, key, resolve, reject):
        # type: (Hashable, Callable, Callable) -> None
        # Enqueue this Promise to be dispatched.
        self._queue.append(Loader(key=key, resolve=resolve, reject=reject))
        # Determine if a dispatch of this queue should be scheduled.
        # A single dispatch should be scheduled per queue at the time when the
        # queue changes from "empty" to "full".
        if len(self._queue) == 1:
            if self.batch:
                # If batching, schedule a task to dispatch the queue.
                enqueue_post_promise_job(partial(dispatch_queue, self), self._scheduler)
            else:
                # Otherwise dispatch the (queue of one) immediately.
                dispatch_queue(self)

    def load_many(self, keys):
        # type: (Iterable[Hashable]) -> Promise
        """
        Loads multiple keys, promising an array of values:

        >>> a, b = await my_loader.load_many([ 'a', 'b' ])

        This is equivalent to the more verbose:

        >>> a, b = await Promise.all([
        >>>     my_loader.load('a'),
        >>>     my_loader.load('b')
        >>> ])
        """
        if not isinstance(keys, Iterable):
            raise TypeError(
                (
                    "The loader.load_many() function must be called with Array<key> "
                    "but got: {}."
                ).format(keys)
            )

        return Promise.all([self.load(key) for key in keys])

    def clear(self, key):
        # type: (Hashable) -> DataLoader
        """
        Clears the value at `key` from the cache, if it exists. Returns itself for
        method chaining.
        """
        cache_key = self.get_cache_key(key)
        self._promise_cache.pop(cache_key, None)
        return self

    def clear_all(self):
        # type: () -> DataLoader
        """
        Clears the entire cache. To be used when some event results in unknown
        invalidations across this particular `DataLoader`. Returns itself for
        method chaining.
        """
        self._promise_cache.clear()
        return self

    def prime(self, key, value):
        # type: (Hashable, Any) -> DataLoader
        """
        Adds the provided key and value to the cache. If the key already exists, no
        change is made. Returns itself for method chaining.
        """
        cache_key = self.get_cache_key(key)

        # Only add the key if it does not already exist.
        if cache_key not in self._promise_cache:
            # Cache a rejected promise if the value is an Error, in order to match
            # the behavior of load(key).
            if isinstance(value, Exception):
                promise = Promise.reject(value)
            else:
                promise = Promise.resolve(value)

            self._promise_cache[cache_key] = promise

        return self


# Private: Enqueue a Job to be executed after all "PromiseJobs" Jobs.
#
# ES6 JavaScript uses the concepts Job and JobQueue to schedule work to occur
# after the current execution context has completed:
# http://www.ecma-international.org/ecma-262/6.0/#sec-jobs-and-job-queues
#
# Node.js uses the `process.nextTick` mechanism to implement the concept of a
# Job, maintaining a global FIFO JobQueue for all Jobs, which is flushed after
# the current call stack ends.
#
# When calling `then` on a Promise, it enqueues a Job on a specific
# "PromiseJobs" JobQueue which is flushed in Node as a single Job on the
# global JobQueue.
#
# DataLoader batches all loads which occur in a single frame of execution, but
# should include in the batch all loads which occur during the flushing of the
# "PromiseJobs" JobQueue after that same execution frame.
#
# In order to avoid the DataLoader dispatch Job occurring before "PromiseJobs",
# a Promise Job is created with the sole purpose of enqueuing a global Job,
# ensuring that it always occurs after "PromiseJobs" ends.

# Private: cached resolved Promise instance
cache = local()


def enqueue_post_promise_job(fn, scheduler):
    # type: (Callable, Any) -> None
    global cache
    if not hasattr(cache, 'resolved_promise'):
        cache.resolved_promise = Promise.resolve(None)
    if not scheduler:
        scheduler = get_default_scheduler()

    def on_promise_resolve(v):
        # type: (Any) -> None
        async_instance.invoke(fn, scheduler)

    cache.resolved_promise.then(on_promise_resolve)


def dispatch_queue(loader):
    # type: (DataLoader) -> None
    """
    Given the current state of a Loader instance, perform a batch load
    from its current queue.
    """
    # Take the current loader queue, replacing it with an empty queue.
    queue = loader._queue
    loader._queue = []

    # If a max_batch_size was provided and the queue is longer, then segment the
    # queue into multiple batches, otherwise treat the queue as a single batch.
    max_batch_size = loader.max_batch_size

    if max_batch_size and max_batch_size < len(queue):
        chunks = get_chunks(queue, max_batch_size)
        for chunk in chunks:
            dispatch_queue_batch(loader, chunk)
    else:
        dispatch_queue_batch(loader, queue)


def dispatch_queue_batch(loader, queue):
    # type: (DataLoader, List[Loader]) -> None
    # Collect all keys to be loaded in this dispatch
    keys = [l.key for l in queue]

    # Call the provided batch_load_fn for this loader with the loader queue's keys.
    try:
        batch_promise = loader.batch_load_fn(keys)
    except Exception as e:
        failed_dispatch(loader, queue, e)
        return None

    # Assert the expected response from batch_load_fn
    if not batch_promise or not isinstance(batch_promise, Promise):
        failed_dispatch(
            loader,
            queue,
            TypeError(
                (
                    "DataLoader must be constructed with a function which accepts "
                    "Array<key> and returns Promise<Array<value>>, but the function did "
                    "not return a Promise: {}."
                ).format(batch_promise)
            ),
        )
        return None

    def batch_promise_resolved(values):
        # type: (Sized) -> None
        # Assert the expected resolution from batch_load_fn.
        if not isinstance(values, Iterable):
            raise TypeError(
                (
                    "DataLoader must be constructed with a function which accepts "
                    "Array<key> and returns Promise<Array<value>>, but the function did "
                    "not return a Promise of an Array: {}."
                ).format(values)
            )

        if len(values) != len(keys):
            raise TypeError(
                (
                    "DataLoader must be constructed with a function which accepts "
                    "Array<key> and returns Promise<Array<value>>, but the function did "
                    "not return a Promise of an Array of the same length as the Array "
                    "of keys."
                    "\n\nKeys:\n{}"
                    "\n\nValues:\n{}"
                ).format(keys, values)
            )

        # Step through the values, resolving or rejecting each Promise in the
        # loaded queue.
        for l, value in zip(queue, values):
            if isinstance(value, Exception):
                l.reject(value)
            else:
                l.resolve(value)

    batch_promise.then(batch_promise_resolved).catch(
        partial(failed_dispatch, loader, queue)
    )


def failed_dispatch(loader, queue, error):
    # type: (DataLoader, Iterable[Loader], Exception) -> None
    """
    Do not cache individual loads if the entire batch dispatch fails,
    but still reject each request so they do not hang.
    """
    for l in queue:
        loader.clear(l.key)
        l.reject(error)
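

# A minimal usage sketch, not part of the original module. It assumes this
# file is importable in its package context (the relative import above
# requires it); the names `batch_double` and `double_loader` are
# hypothetical, chosen only for illustration.
if __name__ == "__main__":

    def batch_double(keys):
        # The batch function receives every key queued during one frame of
        # execution and must return a Promise resolving to a list of values
        # in the same order, and of the same length, as `keys`.
        return Promise.resolve([key * 2 for key in keys])

    double_loader = DataLoader(batch_double)

    # Both loads below are collected into a single call to batch_double.
    result = Promise.all([double_loader.load(1), double_loader.load(2)]).get()
    print(result)  # [2, 4]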