asyncbackend.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
  2. import dns.exception
  3. # pylint: disable=unused-import
  4. from dns._asyncbackend import Socket, DatagramSocket, \
  5. StreamSocket, Backend # noqa:
  6. # pylint: enable=unused-import
  7. _default_backend = None
  8. _backends = {}
  9. # Allow sniffio import to be disabled for testing purposes
  10. _no_sniffio = False
  11. class AsyncLibraryNotFoundError(dns.exception.DNSException):
  12. pass
  13. def get_backend(name):
  14. """Get the specified asynchronous backend.
  15. *name*, a ``str``, the name of the backend. Currently the "trio",
  16. "curio", and "asyncio" backends are available.
  17. Raises NotImplementError if an unknown backend name is specified.
  18. """
  19. # pylint: disable=import-outside-toplevel,redefined-outer-name
  20. backend = _backends.get(name)
  21. if backend:
  22. return backend
  23. if name == 'trio':
  24. import dns._trio_backend
  25. backend = dns._trio_backend.Backend()
  26. elif name == 'curio':
  27. import dns._curio_backend
  28. backend = dns._curio_backend.Backend()
  29. elif name == 'asyncio':
  30. import dns._asyncio_backend
  31. backend = dns._asyncio_backend.Backend()
  32. else:
  33. raise NotImplementedError(f'unimplemented async backend {name}')
  34. _backends[name] = backend
  35. return backend
  36. def sniff():
  37. """Attempt to determine the in-use asynchronous I/O library by using
  38. the ``sniffio`` module if it is available.
  39. Returns the name of the library, or raises AsyncLibraryNotFoundError
  40. if the library cannot be determined.
  41. """
  42. # pylint: disable=import-outside-toplevel
  43. try:
  44. if _no_sniffio:
  45. raise ImportError
  46. import sniffio
  47. try:
  48. return sniffio.current_async_library()
  49. except sniffio.AsyncLibraryNotFoundError:
  50. raise AsyncLibraryNotFoundError('sniffio cannot determine ' +
  51. 'async library')
  52. except ImportError:
  53. import asyncio
  54. try:
  55. asyncio.get_running_loop()
  56. return 'asyncio'
  57. except RuntimeError:
  58. raise AsyncLibraryNotFoundError('no async library detected')
  59. except AttributeError: # pragma: no cover
  60. # we have to check current_task on 3.6
  61. if not asyncio.Task.current_task():
  62. raise AsyncLibraryNotFoundError('no async library detected')
  63. return 'asyncio'
  64. def get_default_backend():
  65. """Get the default backend, initializing it if necessary.
  66. """
  67. if _default_backend:
  68. return _default_backend
  69. return set_default_backend(sniff())
  70. def set_default_backend(name):
  71. """Set the default backend.
  72. It's not normally necessary to call this method, as
  73. ``get_default_backend()`` will initialize the backend
  74. appropriately in many cases. If ``sniffio`` is not installed, or
  75. in testing situations, this function allows the backend to be set
  76. explicitly.
  77. """
  78. global _default_backend
  79. _default_backend = get_backend(name)
  80. return _default_backend