123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121 |
- # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
- """trio async I/O library query support"""
- import socket
- import trio
- import trio.socket # type: ignore
- import dns._asyncbackend
- import dns.exception
- import dns.inet
def _maybe_timeout(timeout):
    """Return a context manager bounding the enclosed block by *timeout*.

    If *timeout* is ``None`` the block is unbounded and a
    ``dns._asyncbackend.NullContext`` is returned.  Otherwise a
    ``trio.move_on_after`` cancel scope is returned.  When its deadline
    expires, the awaited operation is cancelled and execution resumes
    silently after the ``with`` statement.  Callers must therefore detect
    that case themselves, e.g. by raising ``dns.exception.Timeout`` after
    the block.

    A timeout of ``0`` is honoured as an already-expired deadline rather
    than being mistaken for "no timeout".  A caller whose remaining time
    has run out therefore times out instead of blocking forever.
    """
    if timeout is not None:
        return trio.move_on_after(timeout)
    return dns._asyncbackend.NullContext()
# for brevity: short alias for dns.inet.low_level_address_tuple, used below to
# build the address argument passed to bind()/connect() for a given address
# family.
_lltuple = dns.inet.low_level_address_tuple
- # pylint: disable=redefined-outer-name
class DatagramSocket(dns._asyncbackend.DatagramSocket):
    """Datagram socket for the trio backend, wrapping a trio socket object.

    The I/O methods take a *timeout*; ``None`` means wait without a bound.
    If the wait is cut short, ``dns.exception.Timeout`` is raised.
    """

    def __init__(self, socket):
        # *socket* is the trio socket being wrapped; its family is cached.
        self.socket = socket
        self.family = socket.family

    async def sendto(self, what, destination, timeout):
        """Send *what* to *destination*; return the underlying result."""
        with _maybe_timeout(timeout):
            sent = await self.socket.sendto(what, destination)
            return sent
        # Reached only when the cancel scope swallowed an expired deadline.
        raise dns.exception.Timeout(timeout=timeout)  # pragma: no cover

    async def recvfrom(self, size, timeout):
        """Receive up to *size* bytes; return the underlying recvfrom result."""
        with _maybe_timeout(timeout):
            received = await self.socket.recvfrom(size)
            return received
        # Reached only when the cancel scope swallowed an expired deadline.
        raise dns.exception.Timeout(timeout=timeout)

    async def close(self):
        """Close the wrapped socket."""
        self.socket.close()

    async def getpeername(self):
        """Return the remote address of the wrapped socket."""
        peer = self.socket.getpeername()
        return peer

    async def getsockname(self):
        """Return the local address of the wrapped socket."""
        local = self.socket.getsockname()
        return local
class StreamSocket(dns._asyncbackend.StreamSocket):
    """Stream socket for the trio backend, wrapping a trio stream.

    As constructed by ``Backend.make_socket``, *stream* is a
    ``trio.SocketStream``, or a ``trio.SSLStream`` layered over one when
    *tls* is true.
    """

    def __init__(self, family, stream, tls=False):
        self.family = family
        self.stream = stream
        self.tls = tls

    async def sendall(self, what, timeout):
        """Send all of *what*; raise ``dns.exception.Timeout`` on expiry."""
        with _maybe_timeout(timeout):
            result = await self.stream.send_all(what)
            return result
        raise dns.exception.Timeout(timeout=timeout)

    async def recv(self, size, timeout):
        """Receive up to *size* bytes; raise ``dns.exception.Timeout`` on
        expiry."""
        with _maybe_timeout(timeout):
            data = await self.stream.receive_some(size)
            return data
        raise dns.exception.Timeout(timeout=timeout)

    async def close(self):
        """Close the stream (and, for TLS, the stream beneath it)."""
        await self.stream.aclose()

    def _raw_socket(self):
        # With TLS, the socket lives on the transport stream the SSL layer
        # wraps; otherwise it is directly on the stream.
        carrier = self.stream.transport_stream if self.tls else self.stream
        return carrier.socket

    async def getpeername(self):
        """Return the remote address of the underlying socket."""
        return self._raw_socket().getpeername()

    async def getsockname(self):
        """Return the local address of the underlying socket."""
        return self._raw_socket().getsockname()
class Backend(dns._asyncbackend.Backend):
    """dnspython async backend implemented on top of trio."""

    def name(self):
        return 'trio'

    async def make_socket(self, af, socktype, proto=0, source=None,
                          destination=None, timeout=None,
                          ssl_context=None, server_hostname=None):
        """Create, optionally bind, and (for streams) connect a socket.

        *af*, *socktype* and *proto* are passed to ``trio.socket.socket``.
        If *source* is given, the socket is bound to it.  A ``SOCK_STREAM``
        socket is connected to *destination* within *timeout*.  If
        *ssl_context* is given, the resulting stream is wrapped in a
        ``trio.SSLStream`` using *server_hostname*.

        Returns a ``DatagramSocket`` for ``SOCK_DGRAM`` and a
        ``StreamSocket`` for ``SOCK_STREAM``.

        Raises ``dns.exception.Timeout`` if the connect does not complete
        within *timeout*, and ``NotImplementedError`` for any other socket
        type.  The socket is closed if binding or connecting fails, and
        when the socket type is unsupported.
        """
        s = trio.socket.socket(af, socktype, proto)
        try:
            if source:
                await s.bind(_lltuple(source, af))
            if socktype == socket.SOCK_STREAM:
                connected = False
                with _maybe_timeout(timeout):
                    await s.connect(_lltuple(destination, af))
                    connected = True
                # move_on_after swallows its own cancellation, so without
                # this check an expired connect would fall through and hand
                # back an unconnected socket.
                if not connected:
                    raise dns.exception.Timeout(timeout=timeout)
        except BaseException:
            # BaseException also covers trio.Cancelled from an enclosing
            # cancel scope (not an Exception subclass): close, then re-raise.
            s.close()
            raise
        if socktype == socket.SOCK_DGRAM:
            return DatagramSocket(s)
        if socktype == socket.SOCK_STREAM:
            stream = trio.SocketStream(s)
            tls = False
            if ssl_context:
                tls = True
                try:
                    stream = trio.SSLStream(stream, ssl_context,
                                            server_hostname=server_hostname)
                except Exception:  # pragma: no cover
                    # *stream* is still the plain SocketStream here.
                    await stream.aclose()
                    raise
            return StreamSocket(af, stream, tls)
        # Unsupported socket type: don't leak the socket created above.
        s.close()
        raise NotImplementedError('unsupported socket ' +
                                  f'type {socktype}')  # pragma: no cover

    async def sleep(self, interval):
        await trio.sleep(interval)
|