_trio_backend.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
  2. """trio async I/O library query support"""
  3. import socket
  4. import trio
  5. import trio.socket # type: ignore
  6. import dns._asyncbackend
  7. import dns.exception
  8. import dns.inet
  9. def _maybe_timeout(timeout):
  10. if timeout:
  11. return trio.move_on_after(timeout)
  12. else:
  13. return dns._asyncbackend.NullContext()
  14. # for brevity
  15. _lltuple = dns.inet.low_level_address_tuple
  16. # pylint: disable=redefined-outer-name
  17. class DatagramSocket(dns._asyncbackend.DatagramSocket):
  18. def __init__(self, socket):
  19. self.socket = socket
  20. self.family = socket.family
  21. async def sendto(self, what, destination, timeout):
  22. with _maybe_timeout(timeout):
  23. return await self.socket.sendto(what, destination)
  24. raise dns.exception.Timeout(timeout=timeout) # pragma: no cover
  25. async def recvfrom(self, size, timeout):
  26. with _maybe_timeout(timeout):
  27. return await self.socket.recvfrom(size)
  28. raise dns.exception.Timeout(timeout=timeout)
  29. async def close(self):
  30. self.socket.close()
  31. async def getpeername(self):
  32. return self.socket.getpeername()
  33. async def getsockname(self):
  34. return self.socket.getsockname()
  35. class StreamSocket(dns._asyncbackend.StreamSocket):
  36. def __init__(self, family, stream, tls=False):
  37. self.family = family
  38. self.stream = stream
  39. self.tls = tls
  40. async def sendall(self, what, timeout):
  41. with _maybe_timeout(timeout):
  42. return await self.stream.send_all(what)
  43. raise dns.exception.Timeout(timeout=timeout)
  44. async def recv(self, size, timeout):
  45. with _maybe_timeout(timeout):
  46. return await self.stream.receive_some(size)
  47. raise dns.exception.Timeout(timeout=timeout)
  48. async def close(self):
  49. await self.stream.aclose()
  50. async def getpeername(self):
  51. if self.tls:
  52. return self.stream.transport_stream.socket.getpeername()
  53. else:
  54. return self.stream.socket.getpeername()
  55. async def getsockname(self):
  56. if self.tls:
  57. return self.stream.transport_stream.socket.getsockname()
  58. else:
  59. return self.stream.socket.getsockname()
  60. class Backend(dns._asyncbackend.Backend):
  61. def name(self):
  62. return 'trio'
  63. async def make_socket(self, af, socktype, proto=0, source=None,
  64. destination=None, timeout=None,
  65. ssl_context=None, server_hostname=None):
  66. s = trio.socket.socket(af, socktype, proto)
  67. stream = None
  68. try:
  69. if source:
  70. await s.bind(_lltuple(source, af))
  71. if socktype == socket.SOCK_STREAM:
  72. with _maybe_timeout(timeout):
  73. await s.connect(_lltuple(destination, af))
  74. except Exception: # pragma: no cover
  75. s.close()
  76. raise
  77. if socktype == socket.SOCK_DGRAM:
  78. return DatagramSocket(s)
  79. elif socktype == socket.SOCK_STREAM:
  80. stream = trio.SocketStream(s)
  81. s = None
  82. tls = False
  83. if ssl_context:
  84. tls = True
  85. try:
  86. stream = trio.SSLStream(stream, ssl_context,
  87. server_hostname=server_hostname)
  88. except Exception: # pragma: no cover
  89. await stream.aclose()
  90. raise
  91. return StreamSocket(af, stream, tls)
  92. raise NotImplementedError('unsupported socket ' +
  93. f'type {socktype}') # pragma: no cover
  94. async def sleep(self, interval):
  95. await trio.sleep(interval)