sync.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532
  1. import asyncio
  2. import asyncio.coroutines
  3. import contextvars
  4. import functools
  5. import inspect
  6. import os
  7. import sys
  8. import threading
  9. import warnings
  10. import weakref
  11. from concurrent.futures import Future, ThreadPoolExecutor
  12. from typing import Any, Callable, Dict, Optional, overload
  13. from .current_thread_executor import CurrentThreadExecutor
  14. from .local import Local
  15. def _restore_context(context):
  16. # Check for changes in contextvars, and set them to the current
  17. # context for downstream consumers
  18. for cvar in context:
  19. try:
  20. if cvar.get() != context.get(cvar):
  21. cvar.set(context.get(cvar))
  22. except LookupError:
  23. cvar.set(context.get(cvar))
  24. def _iscoroutinefunction_or_partial(func: Any) -> bool:
  25. # Python < 3.8 does not correctly determine partially wrapped
  26. # coroutine functions are coroutine functions, hence the need for
  27. # this to exist. Code taken from CPython.
  28. if sys.version_info >= (3, 8):
  29. return asyncio.iscoroutinefunction(func)
  30. else:
  31. while inspect.ismethod(func):
  32. func = func.__func__
  33. while isinstance(func, functools.partial):
  34. func = func.func
  35. return asyncio.iscoroutinefunction(func)
  36. class ThreadSensitiveContext:
  37. """Async context manager to manage context for thread sensitive mode
  38. This context manager controls which thread pool executor is used when in
  39. thread sensitive mode. By default, a single thread pool executor is shared
  40. within a process.
  41. In Python 3.7+, the ThreadSensitiveContext() context manager may be used to
  42. specify a thread pool per context.
  43. This context manager is re-entrant, so only the outer-most call to
  44. ThreadSensitiveContext will set the context.
  45. Usage:
  46. >>> import time
  47. >>> async with ThreadSensitiveContext():
  48. ... await sync_to_async(time.sleep, 1)()
  49. """
  50. def __init__(self):
  51. self.token = None
  52. async def __aenter__(self):
  53. try:
  54. SyncToAsync.thread_sensitive_context.get()
  55. except LookupError:
  56. self.token = SyncToAsync.thread_sensitive_context.set(self)
  57. return self
  58. async def __aexit__(self, exc, value, tb):
  59. if not self.token:
  60. return
  61. executor = SyncToAsync.context_to_thread_executor.pop(self, None)
  62. if executor:
  63. executor.shutdown()
  64. SyncToAsync.thread_sensitive_context.reset(self.token)
  65. class AsyncToSync:
  66. """
  67. Utility class which turns an awaitable that only works on the thread with
  68. the event loop into a synchronous callable that works in a subthread.
  69. If the call stack contains an async loop, the code runs there.
  70. Otherwise, the code runs in a new loop in a new thread.
  71. Either way, this thread then pauses and waits to run any thread_sensitive
  72. code called from further down the call stack using SyncToAsync, before
  73. finally exiting once the async task returns.
  74. """
  75. # Maps launched Tasks to the threads that launched them (for locals impl)
  76. launch_map: "Dict[asyncio.Task[object], threading.Thread]" = {}
  77. # Keeps track of which CurrentThreadExecutor to use. This uses an asgiref
  78. # Local, not a threadlocal, so that tasks can work out what their parent used.
  79. executors = Local()
  80. # When we can't find a CurrentThreadExecutor from the context, such as
  81. # inside create_task, we'll look it up here from the running event loop.
  82. loop_thread_executors: "Dict[asyncio.AbstractEventLoop, CurrentThreadExecutor]" = {}
  83. def __init__(self, awaitable, force_new_loop=False):
  84. if not callable(awaitable) or (
  85. not _iscoroutinefunction_or_partial(awaitable)
  86. and not _iscoroutinefunction_or_partial(
  87. getattr(awaitable, "__call__", awaitable)
  88. )
  89. ):
  90. # Python does not have very reliable detection of async functions
  91. # (lots of false negatives) so this is just a warning.
  92. warnings.warn(
  93. "async_to_sync was passed a non-async-marked callable", stacklevel=2
  94. )
  95. self.awaitable = awaitable
  96. try:
  97. self.__self__ = self.awaitable.__self__
  98. except AttributeError:
  99. pass
  100. if force_new_loop:
  101. # They have asked that we always run in a new sub-loop.
  102. self.main_event_loop = None
  103. else:
  104. try:
  105. self.main_event_loop = asyncio.get_running_loop()
  106. except RuntimeError:
  107. # There's no event loop in this thread. Look for the threadlocal if
  108. # we're inside SyncToAsync
  109. main_event_loop_pid = getattr(
  110. SyncToAsync.threadlocal, "main_event_loop_pid", None
  111. )
  112. # We make sure the parent loop is from the same process - if
  113. # they've forked, this is not going to be valid any more (#194)
  114. if main_event_loop_pid and main_event_loop_pid == os.getpid():
  115. self.main_event_loop = getattr(
  116. SyncToAsync.threadlocal, "main_event_loop", None
  117. )
  118. else:
  119. self.main_event_loop = None
  120. def __call__(self, *args, **kwargs):
  121. # You can't call AsyncToSync from a thread with a running event loop
  122. try:
  123. event_loop = asyncio.get_running_loop()
  124. except RuntimeError:
  125. pass
  126. else:
  127. if event_loop.is_running():
  128. raise RuntimeError(
  129. "You cannot use AsyncToSync in the same thread as an async event loop - "
  130. "just await the async function directly."
  131. )
  132. # Wrapping context in list so it can be reassigned from within
  133. # `main_wrap`.
  134. context = [contextvars.copy_context()]
  135. # Make a future for the return information
  136. call_result = Future()
  137. # Get the source thread
  138. source_thread = threading.current_thread()
  139. # Make a CurrentThreadExecutor we'll use to idle in this thread - we
  140. # need one for every sync frame, even if there's one above us in the
  141. # same thread.
  142. if hasattr(self.executors, "current"):
  143. old_current_executor = self.executors.current
  144. else:
  145. old_current_executor = None
  146. current_executor = CurrentThreadExecutor()
  147. self.executors.current = current_executor
  148. loop = None
  149. # Use call_soon_threadsafe to schedule a synchronous callback on the
  150. # main event loop's thread if it's there, otherwise make a new loop
  151. # in this thread.
  152. try:
  153. awaitable = self.main_wrap(
  154. args, kwargs, call_result, source_thread, sys.exc_info(), context
  155. )
  156. if not (self.main_event_loop and self.main_event_loop.is_running()):
  157. # Make our own event loop - in a new thread - and run inside that.
  158. loop = asyncio.new_event_loop()
  159. self.loop_thread_executors[loop] = current_executor
  160. loop_executor = ThreadPoolExecutor(max_workers=1)
  161. loop_future = loop_executor.submit(
  162. self._run_event_loop, loop, awaitable
  163. )
  164. if current_executor:
  165. # Run the CurrentThreadExecutor until the future is done
  166. current_executor.run_until_future(loop_future)
  167. # Wait for future and/or allow for exception propagation
  168. loop_future.result()
  169. else:
  170. # Call it inside the existing loop
  171. self.main_event_loop.call_soon_threadsafe(
  172. self.main_event_loop.create_task, awaitable
  173. )
  174. if current_executor:
  175. # Run the CurrentThreadExecutor until the future is done
  176. current_executor.run_until_future(call_result)
  177. finally:
  178. # Clean up any executor we were running
  179. if loop is not None:
  180. del self.loop_thread_executors[loop]
  181. if hasattr(self.executors, "current"):
  182. del self.executors.current
  183. if old_current_executor:
  184. self.executors.current = old_current_executor
  185. _restore_context(context[0])
  186. # Wait for results from the future.
  187. return call_result.result()
  188. def _run_event_loop(self, loop, coro):
  189. """
  190. Runs the given event loop (designed to be called in a thread).
  191. """
  192. asyncio.set_event_loop(loop)
  193. try:
  194. loop.run_until_complete(coro)
  195. finally:
  196. try:
  197. # mimic asyncio.run() behavior
  198. # cancel unexhausted async generators
  199. tasks = asyncio.all_tasks(loop)
  200. for task in tasks:
  201. task.cancel()
  202. async def gather():
  203. await asyncio.gather(*tasks, return_exceptions=True)
  204. loop.run_until_complete(gather())
  205. for task in tasks:
  206. if task.cancelled():
  207. continue
  208. if task.exception() is not None:
  209. loop.call_exception_handler(
  210. {
  211. "message": "unhandled exception during loop shutdown",
  212. "exception": task.exception(),
  213. "task": task,
  214. }
  215. )
  216. if hasattr(loop, "shutdown_asyncgens"):
  217. loop.run_until_complete(loop.shutdown_asyncgens())
  218. finally:
  219. loop.close()
  220. asyncio.set_event_loop(self.main_event_loop)
  221. def __get__(self, parent, objtype):
  222. """
  223. Include self for methods
  224. """
  225. func = functools.partial(self.__call__, parent)
  226. return functools.update_wrapper(func, self.awaitable)
  227. async def main_wrap(
  228. self, args, kwargs, call_result, source_thread, exc_info, context
  229. ):
  230. """
  231. Wraps the awaitable with something that puts the result into the
  232. result/exception future.
  233. """
  234. if context is not None:
  235. _restore_context(context[0])
  236. current_task = SyncToAsync.get_current_task()
  237. self.launch_map[current_task] = source_thread
  238. try:
  239. # If we have an exception, run the function inside the except block
  240. # after raising it so exc_info is correctly populated.
  241. if exc_info[1]:
  242. try:
  243. raise exc_info[1]
  244. except BaseException:
  245. result = await self.awaitable(*args, **kwargs)
  246. else:
  247. result = await self.awaitable(*args, **kwargs)
  248. except BaseException as e:
  249. call_result.set_exception(e)
  250. else:
  251. call_result.set_result(result)
  252. finally:
  253. del self.launch_map[current_task]
  254. context[0] = contextvars.copy_context()
  255. class SyncToAsync:
  256. """
  257. Utility class which turns a synchronous callable into an awaitable that
  258. runs in a threadpool. It also sets a threadlocal inside the thread so
  259. calls to AsyncToSync can escape it.
  260. If thread_sensitive is passed, the code will run in the same thread as any
  261. outer code. This is needed for underlying Python code that is not
  262. threadsafe (for example, code which handles SQLite database connections).
  263. If the outermost program is async (i.e. SyncToAsync is outermost), then
  264. this will be a dedicated single sub-thread that all sync code runs in,
  265. one after the other. If the outermost program is sync (i.e. AsyncToSync is
  266. outermost), this will just be the main thread. This is achieved by idling
  267. with a CurrentThreadExecutor while AsyncToSync is blocking its sync parent,
  268. rather than just blocking.
  269. If executor is passed in, that will be used instead of the loop's default executor.
  270. In order to pass in an executor, thread_sensitive must be set to False, otherwise
  271. a TypeError will be raised.
  272. """
  273. # If they've set ASGI_THREADS, update the default asyncio executor for now
  274. if "ASGI_THREADS" in os.environ:
  275. # We use get_event_loop here - not get_running_loop - as this will
  276. # be run at import time, and we want to update the main thread's loop.
  277. loop = asyncio.get_event_loop()
  278. loop.set_default_executor(
  279. ThreadPoolExecutor(max_workers=int(os.environ["ASGI_THREADS"]))
  280. )
  281. # Maps launched threads to the coroutines that spawned them
  282. launch_map: "Dict[threading.Thread, asyncio.Task[object]]" = {}
  283. # Storage for main event loop references
  284. threadlocal = threading.local()
  285. # Single-thread executor for thread-sensitive code
  286. single_thread_executor = ThreadPoolExecutor(max_workers=1)
  287. # Maintain a contextvar for the current execution context. Optionally used
  288. # for thread sensitive mode.
  289. thread_sensitive_context: "contextvars.ContextVar[str]" = contextvars.ContextVar(
  290. "thread_sensitive_context"
  291. )
  292. # Contextvar that is used to detect if the single thread executor
  293. # would be awaited on while already being used in the same context
  294. deadlock_context: "contextvars.ContextVar[bool]" = contextvars.ContextVar(
  295. "deadlock_context"
  296. )
  297. # Maintaining a weak reference to the context ensures that thread pools are
  298. # erased once the context goes out of scope. This terminates the thread pool.
  299. context_to_thread_executor: "weakref.WeakKeyDictionary[object, ThreadPoolExecutor]" = (
  300. weakref.WeakKeyDictionary()
  301. )
  302. def __init__(
  303. self,
  304. func: Callable[..., Any],
  305. thread_sensitive: bool = True,
  306. executor: Optional["ThreadPoolExecutor"] = None,
  307. ) -> None:
  308. if (
  309. not callable(func)
  310. or _iscoroutinefunction_or_partial(func)
  311. or _iscoroutinefunction_or_partial(getattr(func, "__call__", func))
  312. ):
  313. raise TypeError("sync_to_async can only be applied to sync functions.")
  314. self.func = func
  315. functools.update_wrapper(self, func)
  316. self._thread_sensitive = thread_sensitive
  317. self._is_coroutine = asyncio.coroutines._is_coroutine # type: ignore
  318. if thread_sensitive and executor is not None:
  319. raise TypeError("executor must not be set when thread_sensitive is True")
  320. self._executor = executor
  321. try:
  322. self.__self__ = func.__self__ # type: ignore
  323. except AttributeError:
  324. pass
  325. async def __call__(self, *args, **kwargs):
  326. loop = asyncio.get_running_loop()
  327. # Work out what thread to run the code in
  328. if self._thread_sensitive:
  329. if hasattr(AsyncToSync.executors, "current"):
  330. # If we have a parent sync thread above somewhere, use that
  331. executor = AsyncToSync.executors.current
  332. elif self.thread_sensitive_context and self.thread_sensitive_context.get(
  333. None
  334. ):
  335. # If we have a way of retrieving the current context, attempt
  336. # to use a per-context thread pool executor
  337. thread_sensitive_context = self.thread_sensitive_context.get()
  338. if thread_sensitive_context in self.context_to_thread_executor:
  339. # Re-use thread executor in current context
  340. executor = self.context_to_thread_executor[thread_sensitive_context]
  341. else:
  342. # Create new thread executor in current context
  343. executor = ThreadPoolExecutor(max_workers=1)
  344. self.context_to_thread_executor[thread_sensitive_context] = executor
  345. elif loop in AsyncToSync.loop_thread_executors:
  346. # Re-use thread executor for running loop
  347. executor = AsyncToSync.loop_thread_executors[loop]
  348. elif self.deadlock_context and self.deadlock_context.get(False):
  349. raise RuntimeError(
  350. "Single thread executor already being used, would deadlock"
  351. )
  352. else:
  353. # Otherwise, we run it in a fixed single thread
  354. executor = self.single_thread_executor
  355. if self.deadlock_context:
  356. self.deadlock_context.set(True)
  357. else:
  358. # Use the passed in executor, or the loop's default if it is None
  359. executor = self._executor
  360. context = contextvars.copy_context()
  361. child = functools.partial(self.func, *args, **kwargs)
  362. func = context.run
  363. args = (child,)
  364. kwargs = {}
  365. try:
  366. # Run the code in the right thread
  367. future = loop.run_in_executor(
  368. executor,
  369. functools.partial(
  370. self.thread_handler,
  371. loop,
  372. self.get_current_task(),
  373. sys.exc_info(),
  374. func,
  375. *args,
  376. **kwargs,
  377. ),
  378. )
  379. ret = await asyncio.wait_for(future, timeout=None)
  380. finally:
  381. _restore_context(context)
  382. if self.deadlock_context:
  383. self.deadlock_context.set(False)
  384. return ret
  385. def __get__(self, parent, objtype):
  386. """
  387. Include self for methods
  388. """
  389. return functools.partial(self.__call__, parent)
  390. def thread_handler(self, loop, source_task, exc_info, func, *args, **kwargs):
  391. """
  392. Wraps the sync application with exception handling.
  393. """
  394. # Set the threadlocal for AsyncToSync
  395. self.threadlocal.main_event_loop = loop
  396. self.threadlocal.main_event_loop_pid = os.getpid()
  397. # Set the task mapping (used for the locals module)
  398. current_thread = threading.current_thread()
  399. if AsyncToSync.launch_map.get(source_task) == current_thread:
  400. # Our parent task was launched from this same thread, so don't make
  401. # a launch map entry - let it shortcut over us! (and stop infinite loops)
  402. parent_set = False
  403. else:
  404. self.launch_map[current_thread] = source_task
  405. parent_set = True
  406. # Run the function
  407. try:
  408. # If we have an exception, run the function inside the except block
  409. # after raising it so exc_info is correctly populated.
  410. if exc_info[1]:
  411. try:
  412. raise exc_info[1]
  413. except BaseException:
  414. return func(*args, **kwargs)
  415. else:
  416. return func(*args, **kwargs)
  417. finally:
  418. # Only delete the launch_map parent if we set it, otherwise it is
  419. # from someone else.
  420. if parent_set:
  421. del self.launch_map[current_thread]
  422. @staticmethod
  423. def get_current_task():
  424. """
  425. Implementation of asyncio.current_task()
  426. that returns None if there is no task.
  427. """
  428. try:
  429. return asyncio.current_task()
  430. except RuntimeError:
  431. return None
# Lowercase aliases (and decorator friendliness): ``async_to_sync`` is the same
# class as ``AsyncToSync``, so ``@async_to_sync`` wraps the decorated function
# via ``AsyncToSync(func)``.
async_to_sync = AsyncToSync
  434. @overload
  435. def sync_to_async(
  436. func: None = None,
  437. thread_sensitive: bool = True,
  438. executor: Optional["ThreadPoolExecutor"] = None,
  439. ) -> Callable[[Callable[..., Any]], SyncToAsync]:
  440. ...
  441. @overload
  442. def sync_to_async(
  443. func: Callable[..., Any],
  444. thread_sensitive: bool = True,
  445. executor: Optional["ThreadPoolExecutor"] = None,
  446. ) -> SyncToAsync:
  447. ...
  448. def sync_to_async(
  449. func=None,
  450. thread_sensitive=True,
  451. executor=None,
  452. ):
  453. if func is None:
  454. return lambda f: SyncToAsync(
  455. f,
  456. thread_sensitive=thread_sensitive,
  457. executor=executor,
  458. )
  459. return SyncToAsync(
  460. func,
  461. thread_sensitive=thread_sensitive,
  462. executor=executor,
  463. )