_curio_backend.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
  2. """curio async I/O library query support"""
  3. import socket
  4. import curio
  5. import curio.socket # type: ignore
  6. import dns._asyncbackend
  7. import dns.exception
  8. import dns.inet
  9. def _maybe_timeout(timeout):
  10. if timeout:
  11. return curio.ignore_after(timeout)
  12. else:
  13. return dns._asyncbackend.NullContext()
# Short module-local alias for dns.inet.low_level_address_tuple, kept for
# brevity; Backend.make_socket() uses it to convert the ``source`` argument
# before binding / connecting.
_lltuple = dns.inet.low_level_address_tuple
  16. # pylint: disable=redefined-outer-name
  17. class DatagramSocket(dns._asyncbackend.DatagramSocket):
  18. def __init__(self, socket):
  19. self.socket = socket
  20. self.family = socket.family
  21. async def sendto(self, what, destination, timeout):
  22. async with _maybe_timeout(timeout):
  23. return await self.socket.sendto(what, destination)
  24. raise dns.exception.Timeout(timeout=timeout) # pragma: no cover
  25. async def recvfrom(self, size, timeout):
  26. async with _maybe_timeout(timeout):
  27. return await self.socket.recvfrom(size)
  28. raise dns.exception.Timeout(timeout=timeout)
  29. async def close(self):
  30. await self.socket.close()
  31. async def getpeername(self):
  32. return self.socket.getpeername()
  33. async def getsockname(self):
  34. return self.socket.getsockname()
  35. class StreamSocket(dns._asyncbackend.StreamSocket):
  36. def __init__(self, socket):
  37. self.socket = socket
  38. self.family = socket.family
  39. async def sendall(self, what, timeout):
  40. async with _maybe_timeout(timeout):
  41. return await self.socket.sendall(what)
  42. raise dns.exception.Timeout(timeout=timeout)
  43. async def recv(self, size, timeout):
  44. async with _maybe_timeout(timeout):
  45. return await self.socket.recv(size)
  46. raise dns.exception.Timeout(timeout=timeout)
  47. async def close(self):
  48. await self.socket.close()
  49. async def getpeername(self):
  50. return self.socket.getpeername()
  51. async def getsockname(self):
  52. return self.socket.getsockname()
  53. class Backend(dns._asyncbackend.Backend):
  54. def name(self):
  55. return 'curio'
  56. async def make_socket(self, af, socktype, proto=0,
  57. source=None, destination=None, timeout=None,
  58. ssl_context=None, server_hostname=None):
  59. if socktype == socket.SOCK_DGRAM:
  60. s = curio.socket.socket(af, socktype, proto)
  61. try:
  62. if source:
  63. s.bind(_lltuple(source, af))
  64. except Exception: # pragma: no cover
  65. await s.close()
  66. raise
  67. return DatagramSocket(s)
  68. elif socktype == socket.SOCK_STREAM:
  69. if source:
  70. source_addr = _lltuple(source, af)
  71. else:
  72. source_addr = None
  73. async with _maybe_timeout(timeout):
  74. s = await curio.open_connection(destination[0], destination[1],
  75. ssl=ssl_context,
  76. source_addr=source_addr,
  77. server_hostname=server_hostname)
  78. return StreamSocket(s)
  79. raise NotImplementedError('unsupported socket ' +
  80. f'type {socktype}') # pragma: no cover
  81. async def sleep(self, interval):
  82. await curio.sleep(interval)