query.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154
  1. # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
  2. # Copyright (C) 2003-2017 Nominum, Inc.
  3. #
  4. # Permission to use, copy, modify, and distribute this software and its
  5. # documentation for any purpose with or without fee is hereby granted,
  6. # provided that the above copyright notice and this permission notice
  7. # appear in all copies.
  8. #
  9. # THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
  10. # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
  11. # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
  12. # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
  13. # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
  14. # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
  15. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  16. """Talk to a DNS server."""
  17. import contextlib
  18. import enum
  19. import errno
  20. import os
  21. import selectors
  22. import socket
  23. import struct
  24. import time
  25. import base64
  26. import urllib.parse
  27. import dns.exception
  28. import dns.inet
  29. import dns.name
  30. import dns.message
  31. import dns.rcode
  32. import dns.rdataclass
  33. import dns.rdatatype
  34. import dns.serial
  35. import dns.xfr
  36. try:
  37. import requests
  38. from requests_toolbelt.adapters.source import SourceAddressAdapter
  39. from requests_toolbelt.adapters.host_header_ssl import HostHeaderSSLAdapter
  40. _have_requests = True
  41. except ImportError: # pragma: no cover
  42. _have_requests = False
  43. _have_httpx = False
  44. _have_http2 = False
  45. try:
  46. import httpx
  47. _have_httpx = True
  48. try:
  49. # See if http2 support is available.
  50. with httpx.Client(http2=True):
  51. _have_http2 = True
  52. except Exception:
  53. pass
  54. except ImportError: # pragma: no cover
  55. pass
  56. have_doh = _have_requests or _have_httpx
  57. try:
  58. import ssl
  59. except ImportError: # pragma: no cover
  60. class ssl: # type: ignore
  61. class WantReadException(Exception):
  62. pass
  63. class WantWriteException(Exception):
  64. pass
  65. class SSLSocket:
  66. pass
  67. def create_default_context(self, *args, **kwargs):
  68. raise Exception('no ssl support')
  69. # Function used to create a socket. Can be overridden if needed in special
  70. # situations.
  71. socket_factory = socket.socket
  72. class UnexpectedSource(dns.exception.DNSException):
  73. """A DNS query response came from an unexpected address or port."""
  74. class BadResponse(dns.exception.FormError):
  75. """A DNS query response does not respond to the question asked."""
  76. class NoDOH(dns.exception.DNSException):
  77. """DNS over HTTPS (DOH) was requested but the requests module is not
  78. available."""
  79. # for backwards compatibility
  80. TransferError = dns.xfr.TransferError
  81. def _compute_times(timeout):
  82. now = time.time()
  83. if timeout is None:
  84. return (now, None)
  85. else:
  86. return (now, now + timeout)
  87. def _wait_for(fd, readable, writable, _, expiration):
  88. # Use the selected selector class to wait for any of the specified
  89. # events. An "expiration" absolute time is converted into a relative
  90. # timeout.
  91. #
  92. # The unused parameter is 'error', which is always set when
  93. # selecting for read or write, and we have no error-only selects.
  94. if readable and isinstance(fd, ssl.SSLSocket) and fd.pending() > 0:
  95. return True
  96. sel = _selector_class()
  97. events = 0
  98. if readable:
  99. events |= selectors.EVENT_READ
  100. if writable:
  101. events |= selectors.EVENT_WRITE
  102. if events:
  103. sel.register(fd, events)
  104. if expiration is None:
  105. timeout = None
  106. else:
  107. timeout = expiration - time.time()
  108. if timeout <= 0.0:
  109. raise dns.exception.Timeout
  110. if not sel.select(timeout):
  111. raise dns.exception.Timeout
  112. def _set_selector_class(selector_class):
  113. # Internal API. Do not use.
  114. global _selector_class
  115. _selector_class = selector_class
  116. if hasattr(selectors, 'PollSelector'):
  117. # Prefer poll() on platforms that support it because it has no
  118. # limits on the maximum value of a file descriptor (plus it will
  119. # be more efficient for high values).
  120. _selector_class = selectors.PollSelector
  121. else:
  122. _selector_class = selectors.SelectSelector # pragma: no cover
  123. def _wait_for_readable(s, expiration):
  124. _wait_for(s, True, False, True, expiration)
  125. def _wait_for_writable(s, expiration):
  126. _wait_for(s, False, True, True, expiration)
  127. def _addresses_equal(af, a1, a2):
  128. # Convert the first value of the tuple, which is a textual format
  129. # address into binary form, so that we are not confused by different
  130. # textual representations of the same address
  131. try:
  132. n1 = dns.inet.inet_pton(af, a1[0])
  133. n2 = dns.inet.inet_pton(af, a2[0])
  134. except dns.exception.SyntaxError:
  135. return False
  136. return n1 == n2 and a1[1:] == a2[1:]
  137. def _matches_destination(af, from_address, destination, ignore_unexpected):
  138. # Check that from_address is appropriate for a response to a query
  139. # sent to destination.
  140. if not destination:
  141. return True
  142. if _addresses_equal(af, from_address, destination) or \
  143. (dns.inet.is_multicast(destination[0]) and
  144. from_address[1:] == destination[1:]):
  145. return True
  146. elif ignore_unexpected:
  147. return False
  148. raise UnexpectedSource(f'got a response from {from_address} instead of '
  149. f'{destination}')
  150. def _destination_and_source(where, port, source, source_port,
  151. where_must_be_address=True):
  152. # Apply defaults and compute destination and source tuples
  153. # suitable for use in connect(), sendto(), or bind().
  154. af = None
  155. destination = None
  156. try:
  157. af = dns.inet.af_for_address(where)
  158. destination = where
  159. except Exception:
  160. if where_must_be_address:
  161. raise
  162. # URLs are ok so eat the exception
  163. if source:
  164. saf = dns.inet.af_for_address(source)
  165. if af:
  166. # We know the destination af, so source had better agree!
  167. if saf != af:
  168. raise ValueError('different address families for source ' +
  169. 'and destination')
  170. else:
  171. # We didn't know the destination af, but we know the source,
  172. # so that's our af.
  173. af = saf
  174. if source_port and not source:
  175. # Caller has specified a source_port but not an address, so we
  176. # need to return a source, and we need to use the appropriate
  177. # wildcard address as the address.
  178. if af == socket.AF_INET:
  179. source = '0.0.0.0'
  180. elif af == socket.AF_INET6:
  181. source = '::'
  182. else:
  183. raise ValueError('source_port specified but address family is '
  184. 'unknown')
  185. # Convert high-level (address, port) tuples into low-level address
  186. # tuples.
  187. if destination:
  188. destination = dns.inet.low_level_address_tuple((destination, port), af)
  189. if source:
  190. source = dns.inet.low_level_address_tuple((source, source_port), af)
  191. return (af, destination, source)
  192. def _make_socket(af, type, source, ssl_context=None, server_hostname=None):
  193. s = socket_factory(af, type)
  194. try:
  195. s.setblocking(False)
  196. if source is not None:
  197. s.bind(source)
  198. if ssl_context:
  199. return ssl_context.wrap_socket(s, do_handshake_on_connect=False,
  200. server_hostname=server_hostname)
  201. else:
  202. return s
  203. except Exception:
  204. s.close()
  205. raise
  206. def https(q, where, timeout=None, port=443, source=None, source_port=0,
  207. one_rr_per_rrset=False, ignore_trailing=False,
  208. session=None, path='/dns-query', post=True,
  209. bootstrap_address=None, verify=True):
  210. """Return the response obtained after sending a query via DNS-over-HTTPS.
  211. *q*, a ``dns.message.Message``, the query to send.
  212. *where*, a ``str``, the nameserver IP address or the full URL. If an IP
  213. address is given, the URL will be constructed using the following schema:
  214. https://<IP-address>:<port>/<path>.
  215. *timeout*, a ``float`` or ``None``, the number of seconds to
  216. wait before the query times out. If ``None``, the default, wait forever.
  217. *port*, a ``int``, the port to send the query to. The default is 443.
  218. *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
  219. the source address. The default is the wildcard address.
  220. *source_port*, an ``int``, the port from which to send the message.
  221. The default is 0.
  222. *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
  223. RRset.
  224. *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
  225. junk at end of the received message.
  226. *session*, an ``httpx.Client`` or ``requests.session.Session``. If
  227. provided, the client/session to use to send the queries.
  228. *path*, a ``str``. If *where* is an IP address, then *path* will be used to
  229. construct the URL to send the DNS query to.
  230. *post*, a ``bool``. If ``True``, the default, POST method will be used.
  231. *bootstrap_address*, a ``str``, the IP address to use to bypass the
  232. system's DNS resolver.
  233. *verify*, a ``str``, containing a path to a certificate file or directory.
  234. Returns a ``dns.message.Message``.
  235. """
  236. if not have_doh:
  237. raise NoDOH('Neither httpx nor requests is available.') # pragma: no cover
  238. _httpx_ok = _have_httpx
  239. wire = q.to_wire()
  240. (af, _, source) = _destination_and_source(where, port, source, source_port,
  241. False)
  242. transport_adapter = None
  243. transport = None
  244. headers = {
  245. "accept": "application/dns-message"
  246. }
  247. if af is not None:
  248. if af == socket.AF_INET:
  249. url = 'https://{}:{}{}'.format(where, port, path)
  250. elif af == socket.AF_INET6:
  251. url = 'https://[{}]:{}{}'.format(where, port, path)
  252. elif bootstrap_address is not None:
  253. _httpx_ok = False
  254. split_url = urllib.parse.urlsplit(where)
  255. headers['Host'] = split_url.hostname
  256. url = where.replace(split_url.hostname, bootstrap_address)
  257. if _have_requests:
  258. transport_adapter = HostHeaderSSLAdapter()
  259. else:
  260. url = where
  261. if source is not None:
  262. # set source port and source address
  263. if _have_httpx:
  264. if source_port == 0:
  265. transport = httpx.HTTPTransport(local_address=source[0])
  266. else:
  267. _httpx_ok = False
  268. if _have_requests:
  269. transport_adapter = SourceAddressAdapter(source)
  270. if session:
  271. if _have_httpx:
  272. _is_httpx = isinstance(session, httpx.Client)
  273. else:
  274. _is_httpx = False
  275. if _is_httpx and not _httpx_ok:
  276. raise NoDOH('Session is httpx, but httpx cannot be used for '
  277. 'the requested operation.')
  278. else:
  279. _is_httpx = _httpx_ok
  280. if not _httpx_ok and not _have_requests:
  281. raise NoDOH('Cannot use httpx for this operation, and '
  282. 'requests is not available.')
  283. with contextlib.ExitStack() as stack:
  284. if not session:
  285. if _is_httpx:
  286. session = stack.enter_context(httpx.Client(http1=True,
  287. http2=_have_http2,
  288. verify=verify,
  289. transport=transport))
  290. else:
  291. session = stack.enter_context(requests.sessions.Session())
  292. if transport_adapter:
  293. session.mount(url, transport_adapter)
  294. # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
  295. # GET and POST examples
  296. if post:
  297. headers.update({
  298. "content-type": "application/dns-message",
  299. "content-length": str(len(wire))
  300. })
  301. if _is_httpx:
  302. response = session.post(url, headers=headers, content=wire,
  303. timeout=timeout)
  304. else:
  305. response = session.post(url, headers=headers, data=wire,
  306. timeout=timeout, verify=verify)
  307. else:
  308. wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
  309. if _is_httpx:
  310. wire = wire.decode() # httpx does a repr() if we give it bytes
  311. response = session.get(url, headers=headers,
  312. timeout=timeout,
  313. params={"dns": wire})
  314. else:
  315. response = session.get(url, headers=headers,
  316. timeout=timeout, verify=verify,
  317. params={"dns": wire})
  318. # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
  319. # status codes
  320. if response.status_code < 200 or response.status_code > 299:
  321. raise ValueError('{} responded with status code {}'
  322. '\nResponse body: {}'.format(where,
  323. response.status_code,
  324. response.content))
  325. r = dns.message.from_wire(response.content,
  326. keyring=q.keyring,
  327. request_mac=q.request_mac,
  328. one_rr_per_rrset=one_rr_per_rrset,
  329. ignore_trailing=ignore_trailing)
  330. r.time = response.elapsed
  331. if not q.is_response(r):
  332. raise BadResponse
  333. return r
  334. def _udp_recv(sock, max_size, expiration):
  335. """Reads a datagram from the socket.
  336. A Timeout exception will be raised if the operation is not completed
  337. by the expiration time.
  338. """
  339. while True:
  340. try:
  341. return sock.recvfrom(max_size)
  342. except BlockingIOError:
  343. _wait_for_readable(sock, expiration)
  344. def _udp_send(sock, data, destination, expiration):
  345. """Sends the specified datagram to destination over the socket.
  346. A Timeout exception will be raised if the operation is not completed
  347. by the expiration time.
  348. """
  349. while True:
  350. try:
  351. if destination:
  352. return sock.sendto(data, destination)
  353. else:
  354. return sock.send(data)
  355. except BlockingIOError: # pragma: no cover
  356. _wait_for_writable(sock, expiration)
  357. def send_udp(sock, what, destination, expiration=None):
  358. """Send a DNS message to the specified UDP socket.
  359. *sock*, a ``socket``.
  360. *what*, a ``bytes`` or ``dns.message.Message``, the message to send.
  361. *destination*, a destination tuple appropriate for the address family
  362. of the socket, specifying where to send the query.
  363. *expiration*, a ``float`` or ``None``, the absolute time at which
  364. a timeout exception should be raised. If ``None``, no timeout will
  365. occur.
  366. Returns an ``(int, float)`` tuple of bytes sent and the sent time.
  367. """
  368. if isinstance(what, dns.message.Message):
  369. what = what.to_wire()
  370. sent_time = time.time()
  371. n = _udp_send(sock, what, destination, expiration)
  372. return (n, sent_time)
  373. def receive_udp(sock, destination=None, expiration=None,
  374. ignore_unexpected=False, one_rr_per_rrset=False,
  375. keyring=None, request_mac=b'', ignore_trailing=False,
  376. raise_on_truncation=False):
  377. """Read a DNS message from a UDP socket.
  378. *sock*, a ``socket``.
  379. *destination*, a destination tuple appropriate for the address family
  380. of the socket, specifying where the message is expected to arrive from.
  381. When receiving a response, this would be where the associated query was
  382. sent.
  383. *expiration*, a ``float`` or ``None``, the absolute time at which
  384. a timeout exception should be raised. If ``None``, no timeout will
  385. occur.
  386. *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
  387. unexpected sources.
  388. *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
  389. RRset.
  390. *keyring*, a ``dict``, the keyring to use for TSIG.
  391. *request_mac*, a ``bytes``, the MAC of the request (for TSIG).
  392. *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
  393. junk at end of the received message.
  394. *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
  395. the TC bit is set.
  396. Raises if the message is malformed, if network errors occur, of if
  397. there is a timeout.
  398. If *destination* is not ``None``, returns a ``(dns.message.Message, float)``
  399. tuple of the received message and the received time.
  400. If *destination* is ``None``, returns a
  401. ``(dns.message.Message, float, tuple)``
  402. tuple of the received message, the received time, and the address where
  403. the message arrived from.
  404. """
  405. wire = b''
  406. while True:
  407. (wire, from_address) = _udp_recv(sock, 65535, expiration)
  408. if _matches_destination(sock.family, from_address, destination,
  409. ignore_unexpected):
  410. break
  411. received_time = time.time()
  412. r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
  413. one_rr_per_rrset=one_rr_per_rrset,
  414. ignore_trailing=ignore_trailing,
  415. raise_on_truncation=raise_on_truncation)
  416. if destination:
  417. return (r, received_time)
  418. else:
  419. return (r, received_time, from_address)
  420. def udp(q, where, timeout=None, port=53, source=None, source_port=0,
  421. ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False,
  422. raise_on_truncation=False, sock=None):
  423. """Return the response obtained after sending a query via UDP.
  424. *q*, a ``dns.message.Message``, the query to send
  425. *where*, a ``str`` containing an IPv4 or IPv6 address, where
  426. to send the message.
  427. *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
  428. query times out. If ``None``, the default, wait forever.
  429. *port*, an ``int``, the port send the message to. The default is 53.
  430. *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
  431. the source address. The default is the wildcard address.
  432. *source_port*, an ``int``, the port from which to send the message.
  433. The default is 0.
  434. *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
  435. unexpected sources.
  436. *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
  437. RRset.
  438. *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
  439. junk at end of the received message.
  440. *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
  441. the TC bit is set.
  442. *sock*, a ``socket.socket``, or ``None``, the socket to use for the
  443. query. If ``None``, the default, a socket is created. Note that
  444. if a socket is provided, it must be a nonblocking datagram socket,
  445. and the *source* and *source_port* are ignored.
  446. Returns a ``dns.message.Message``.
  447. """
  448. wire = q.to_wire()
  449. (af, destination, source) = _destination_and_source(where, port,
  450. source, source_port)
  451. (begin_time, expiration) = _compute_times(timeout)
  452. with contextlib.ExitStack() as stack:
  453. if sock:
  454. s = sock
  455. else:
  456. s = stack.enter_context(_make_socket(af, socket.SOCK_DGRAM, source))
  457. send_udp(s, wire, destination, expiration)
  458. (r, received_time) = receive_udp(s, destination, expiration,
  459. ignore_unexpected, one_rr_per_rrset,
  460. q.keyring, q.mac, ignore_trailing,
  461. raise_on_truncation)
  462. r.time = received_time - begin_time
  463. if not q.is_response(r):
  464. raise BadResponse
  465. return r
  466. def udp_with_fallback(q, where, timeout=None, port=53, source=None,
  467. source_port=0, ignore_unexpected=False,
  468. one_rr_per_rrset=False, ignore_trailing=False,
  469. udp_sock=None, tcp_sock=None):
  470. """Return the response to the query, trying UDP first and falling back
  471. to TCP if UDP results in a truncated response.
  472. *q*, a ``dns.message.Message``, the query to send
  473. *where*, a ``str`` containing an IPv4 or IPv6 address, where
  474. to send the message.
  475. *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
  476. query times out. If ``None``, the default, wait forever.
  477. *port*, an ``int``, the port send the message to. The default is 53.
  478. *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
  479. the source address. The default is the wildcard address.
  480. *source_port*, an ``int``, the port from which to send the message.
  481. The default is 0.
  482. *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
  483. unexpected sources.
  484. *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
  485. RRset.
  486. *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
  487. junk at end of the received message.
  488. *udp_sock*, a ``socket.socket``, or ``None``, the socket to use for the
  489. UDP query. If ``None``, the default, a socket is created. Note that
  490. if a socket is provided, it must be a nonblocking datagram socket,
  491. and the *source* and *source_port* are ignored for the UDP query.
  492. *tcp_sock*, a ``socket.socket``, or ``None``, the connected socket to use for the
  493. TCP query. If ``None``, the default, a socket is created. Note that
  494. if a socket is provided, it must be a nonblocking connected stream
  495. socket, and *where*, *source* and *source_port* are ignored for the TCP
  496. query.
  497. Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True``
  498. if and only if TCP was used.
  499. """
  500. try:
  501. response = udp(q, where, timeout, port, source, source_port,
  502. ignore_unexpected, one_rr_per_rrset,
  503. ignore_trailing, True, udp_sock)
  504. return (response, False)
  505. except dns.message.Truncated:
  506. response = tcp(q, where, timeout, port, source, source_port,
  507. one_rr_per_rrset, ignore_trailing, tcp_sock)
  508. return (response, True)
  509. def _net_read(sock, count, expiration):
  510. """Read the specified number of bytes from sock. Keep trying until we
  511. either get the desired amount, or we hit EOF.
  512. A Timeout exception will be raised if the operation is not completed
  513. by the expiration time.
  514. """
  515. s = b''
  516. while count > 0:
  517. try:
  518. n = sock.recv(count)
  519. if n == b'':
  520. raise EOFError
  521. count -= len(n)
  522. s += n
  523. except (BlockingIOError, ssl.SSLWantReadError):
  524. _wait_for_readable(sock, expiration)
  525. except ssl.SSLWantWriteError: # pragma: no cover
  526. _wait_for_writable(sock, expiration)
  527. return s
  528. def _net_write(sock, data, expiration):
  529. """Write the specified data to the socket.
  530. A Timeout exception will be raised if the operation is not completed
  531. by the expiration time.
  532. """
  533. current = 0
  534. l = len(data)
  535. while current < l:
  536. try:
  537. current += sock.send(data[current:])
  538. except (BlockingIOError, ssl.SSLWantWriteError):
  539. _wait_for_writable(sock, expiration)
  540. except ssl.SSLWantReadError: # pragma: no cover
  541. _wait_for_readable(sock, expiration)
  542. def send_tcp(sock, what, expiration=None):
  543. """Send a DNS message to the specified TCP socket.
  544. *sock*, a ``socket``.
  545. *what*, a ``bytes`` or ``dns.message.Message``, the message to send.
  546. *expiration*, a ``float`` or ``None``, the absolute time at which
  547. a timeout exception should be raised. If ``None``, no timeout will
  548. occur.
  549. Returns an ``(int, float)`` tuple of bytes sent and the sent time.
  550. """
  551. if isinstance(what, dns.message.Message):
  552. what = what.to_wire()
  553. l = len(what)
  554. # copying the wire into tcpmsg is inefficient, but lets us
  555. # avoid writev() or doing a short write that would get pushed
  556. # onto the net
  557. tcpmsg = struct.pack("!H", l) + what
  558. sent_time = time.time()
  559. _net_write(sock, tcpmsg, expiration)
  560. return (len(tcpmsg), sent_time)
  561. def receive_tcp(sock, expiration=None, one_rr_per_rrset=False,
  562. keyring=None, request_mac=b'', ignore_trailing=False):
  563. """Read a DNS message from a TCP socket.
  564. *sock*, a ``socket``.
  565. *expiration*, a ``float`` or ``None``, the absolute time at which
  566. a timeout exception should be raised. If ``None``, no timeout will
  567. occur.
  568. *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
  569. RRset.
  570. *keyring*, a ``dict``, the keyring to use for TSIG.
  571. *request_mac*, a ``bytes``, the MAC of the request (for TSIG).
  572. *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
  573. junk at end of the received message.
  574. Raises if the message is malformed, if network errors occur, of if
  575. there is a timeout.
  576. Returns a ``(dns.message.Message, float)`` tuple of the received message
  577. and the received time.
  578. """
  579. ldata = _net_read(sock, 2, expiration)
  580. (l,) = struct.unpack("!H", ldata)
  581. wire = _net_read(sock, l, expiration)
  582. received_time = time.time()
  583. r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
  584. one_rr_per_rrset=one_rr_per_rrset,
  585. ignore_trailing=ignore_trailing)
  586. return (r, received_time)
  587. def _connect(s, address, expiration):
  588. err = s.connect_ex(address)
  589. if err == 0:
  590. return
  591. if err in (errno.EINPROGRESS, errno.EWOULDBLOCK, errno.EALREADY):
  592. _wait_for_writable(s, expiration)
  593. err = s.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
  594. if err != 0:
  595. raise OSError(err, os.strerror(err))
  596. def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
  597. one_rr_per_rrset=False, ignore_trailing=False, sock=None):
  598. """Return the response obtained after sending a query via TCP.
  599. *q*, a ``dns.message.Message``, the query to send
  600. *where*, a ``str`` containing an IPv4 or IPv6 address, where
  601. to send the message.
  602. *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
  603. query times out. If ``None``, the default, wait forever.
  604. *port*, an ``int``, the port send the message to. The default is 53.
  605. *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
  606. the source address. The default is the wildcard address.
  607. *source_port*, an ``int``, the port from which to send the message.
  608. The default is 0.
  609. *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
  610. RRset.
  611. *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
  612. junk at end of the received message.
  613. *sock*, a ``socket.socket``, or ``None``, the connected socket to use for the
  614. query. If ``None``, the default, a socket is created. Note that
  615. if a socket is provided, it must be a nonblocking connected stream
  616. socket, and *where*, *port*, *source* and *source_port* are ignored.
  617. Returns a ``dns.message.Message``.
  618. """
  619. wire = q.to_wire()
  620. (begin_time, expiration) = _compute_times(timeout)
  621. with contextlib.ExitStack() as stack:
  622. if sock:
  623. s = sock
  624. else:
  625. (af, destination, source) = _destination_and_source(where, port,
  626. source,
  627. source_port)
  628. s = stack.enter_context(_make_socket(af, socket.SOCK_STREAM,
  629. source))
  630. _connect(s, destination, expiration)
  631. send_tcp(s, wire, expiration)
  632. (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset,
  633. q.keyring, q.mac, ignore_trailing)
  634. r.time = received_time - begin_time
  635. if not q.is_response(r):
  636. raise BadResponse
  637. return r
  638. def _tls_handshake(s, expiration):
  639. while True:
  640. try:
  641. s.do_handshake()
  642. return
  643. except ssl.SSLWantReadError:
  644. _wait_for_readable(s, expiration)
  645. except ssl.SSLWantWriteError: # pragma: no cover
  646. _wait_for_writable(s, expiration)
  647. def tls(q, where, timeout=None, port=853, source=None, source_port=0,
  648. one_rr_per_rrset=False, ignore_trailing=False, sock=None,
  649. ssl_context=None, server_hostname=None):
  650. """Return the response obtained after sending a query via TLS.
  651. *q*, a ``dns.message.Message``, the query to send
  652. *where*, a ``str`` containing an IPv4 or IPv6 address, where
  653. to send the message.
  654. *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
  655. query times out. If ``None``, the default, wait forever.
  656. *port*, an ``int``, the port send the message to. The default is 853.
  657. *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
  658. the source address. The default is the wildcard address.
  659. *source_port*, an ``int``, the port from which to send the message.
  660. The default is 0.
  661. *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
  662. RRset.
  663. *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
  664. junk at end of the received message.
  665. *sock*, an ``ssl.SSLSocket``, or ``None``, the socket to use for
  666. the query. If ``None``, the default, a socket is created. Note
  667. that if a socket is provided, it must be a nonblocking connected
  668. SSL stream socket, and *where*, *port*, *source*, *source_port*,
  669. and *ssl_context* are ignored.
  670. *ssl_context*, an ``ssl.SSLContext``, the context to use when establishing
  671. a TLS connection. If ``None``, the default, creates one with the default
  672. configuration.
  673. *server_hostname*, a ``str`` containing the server's hostname. The
  674. default is ``None``, which means that no hostname is known, and if an
  675. SSL context is created, hostname checking will be disabled.
  676. Returns a ``dns.message.Message``.
  677. """
  678. if sock:
  679. #
  680. # If a socket was provided, there's no special TLS handling needed.
  681. #
  682. return tcp(q, where, timeout, port, source, source_port,
  683. one_rr_per_rrset, ignore_trailing, sock)
  684. wire = q.to_wire()
  685. (begin_time, expiration) = _compute_times(timeout)
  686. (af, destination, source) = _destination_and_source(where, port,
  687. source, source_port)
  688. if ssl_context is None and not sock:
  689. ssl_context = ssl.create_default_context()
  690. if server_hostname is None:
  691. ssl_context.check_hostname = False
  692. with _make_socket(af, socket.SOCK_STREAM, source, ssl_context=ssl_context,
  693. server_hostname=server_hostname) as s:
  694. _connect(s, destination, expiration)
  695. _tls_handshake(s, expiration)
  696. send_tcp(s, wire, expiration)
  697. (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset,
  698. q.keyring, q.mac, ignore_trailing)
  699. r.time = received_time - begin_time
  700. if not q.is_response(r):
  701. raise BadResponse
  702. return r
  703. def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
  704. timeout=None, port=53, keyring=None, keyname=None, relativize=True,
  705. lifetime=None, source=None, source_port=0, serial=0,
  706. use_udp=False, keyalgorithm=dns.tsig.default_algorithm):
  707. """Return a generator for the responses to a zone transfer.
  708. *where*, a ``str`` containing an IPv4 or IPv6 address, where
  709. to send the message.
  710. *zone*, a ``dns.name.Name`` or ``str``, the name of the zone to transfer.
  711. *rdtype*, an ``int`` or ``str``, the type of zone transfer. The
  712. default is ``dns.rdatatype.AXFR``. ``dns.rdatatype.IXFR`` can be
  713. used to do an incremental transfer instead.
  714. *rdclass*, an ``int`` or ``str``, the class of the zone transfer.
  715. The default is ``dns.rdataclass.IN``.
  716. *timeout*, a ``float``, the number of seconds to wait for each
  717. response message. If None, the default, wait forever.
  718. *port*, an ``int``, the port send the message to. The default is 53.
  719. *keyring*, a ``dict``, the keyring to use for TSIG.
  720. *keyname*, a ``dns.name.Name`` or ``str``, the name of the TSIG
  721. key to use.
  722. *relativize*, a ``bool``. If ``True``, all names in the zone will be
  723. relativized to the zone origin. It is essential that the
  724. relativize setting matches the one specified to
  725. ``dns.zone.from_xfr()`` if using this generator to make a zone.
  726. *lifetime*, a ``float``, the total number of seconds to spend
  727. doing the transfer. If ``None``, the default, then there is no
  728. limit on the time the transfer may take.
  729. *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
  730. the source address. The default is the wildcard address.
  731. *source_port*, an ``int``, the port from which to send the message.
  732. The default is 0.
  733. *serial*, an ``int``, the SOA serial number to use as the base for
  734. an IXFR diff sequence (only meaningful if *rdtype* is
  735. ``dns.rdatatype.IXFR``).
  736. *use_udp*, a ``bool``. If ``True``, use UDP (only meaningful for IXFR).
  737. *keyalgorithm*, a ``dns.name.Name`` or ``str``, the TSIG algorithm to use.
  738. Raises on errors, and so does the generator.
  739. Returns a generator of ``dns.message.Message`` objects.
  740. """
  741. if isinstance(zone, str):
  742. zone = dns.name.from_text(zone)
  743. rdtype = dns.rdatatype.RdataType.make(rdtype)
  744. q = dns.message.make_query(zone, rdtype, rdclass)
  745. if rdtype == dns.rdatatype.IXFR:
  746. rrset = dns.rrset.from_text(zone, 0, 'IN', 'SOA',
  747. '. . %u 0 0 0 0' % serial)
  748. q.authority.append(rrset)
  749. if keyring is not None:
  750. q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
  751. wire = q.to_wire()
  752. (af, destination, source) = _destination_and_source(where, port,
  753. source, source_port)
  754. if use_udp and rdtype != dns.rdatatype.IXFR:
  755. raise ValueError('cannot do a UDP AXFR')
  756. sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM
  757. with _make_socket(af, sock_type, source) as s:
  758. (_, expiration) = _compute_times(lifetime)
  759. _connect(s, destination, expiration)
  760. l = len(wire)
  761. if use_udp:
  762. _udp_send(s, wire, None, expiration)
  763. else:
  764. tcpmsg = struct.pack("!H", l) + wire
  765. _net_write(s, tcpmsg, expiration)
  766. done = False
  767. delete_mode = True
  768. expecting_SOA = False
  769. soa_rrset = None
  770. if relativize:
  771. origin = zone
  772. oname = dns.name.empty
  773. else:
  774. origin = None
  775. oname = zone
  776. tsig_ctx = None
  777. while not done:
  778. (_, mexpiration) = _compute_times(timeout)
  779. if mexpiration is None or \
  780. (expiration is not None and mexpiration > expiration):
  781. mexpiration = expiration
  782. if use_udp:
  783. (wire, _) = _udp_recv(s, 65535, mexpiration)
  784. else:
  785. ldata = _net_read(s, 2, mexpiration)
  786. (l,) = struct.unpack("!H", ldata)
  787. wire = _net_read(s, l, mexpiration)
  788. is_ixfr = (rdtype == dns.rdatatype.IXFR)
  789. r = dns.message.from_wire(wire, keyring=q.keyring,
  790. request_mac=q.mac, xfr=True,
  791. origin=origin, tsig_ctx=tsig_ctx,
  792. multi=True, one_rr_per_rrset=is_ixfr)
  793. rcode = r.rcode()
  794. if rcode != dns.rcode.NOERROR:
  795. raise TransferError(rcode)
  796. tsig_ctx = r.tsig_ctx
  797. answer_index = 0
  798. if soa_rrset is None:
  799. if not r.answer or r.answer[0].name != oname:
  800. raise dns.exception.FormError(
  801. "No answer or RRset not for qname")
  802. rrset = r.answer[0]
  803. if rrset.rdtype != dns.rdatatype.SOA:
  804. raise dns.exception.FormError("first RRset is not an SOA")
  805. answer_index = 1
  806. soa_rrset = rrset.copy()
  807. if rdtype == dns.rdatatype.IXFR:
  808. if dns.serial.Serial(soa_rrset[0].serial) <= serial:
  809. #
  810. # We're already up-to-date.
  811. #
  812. done = True
  813. else:
  814. expecting_SOA = True
  815. #
  816. # Process SOAs in the answer section (other than the initial
  817. # SOA in the first message).
  818. #
  819. for rrset in r.answer[answer_index:]:
  820. if done:
  821. raise dns.exception.FormError("answers after final SOA")
  822. if rrset.rdtype == dns.rdatatype.SOA and rrset.name == oname:
  823. if expecting_SOA:
  824. if rrset[0].serial != serial:
  825. raise dns.exception.FormError(
  826. "IXFR base serial mismatch")
  827. expecting_SOA = False
  828. elif rdtype == dns.rdatatype.IXFR:
  829. delete_mode = not delete_mode
  830. #
  831. # If this SOA RRset is equal to the first we saw then we're
  832. # finished. If this is an IXFR we also check that we're
  833. # seeing the record in the expected part of the response.
  834. #
  835. if rrset == soa_rrset and \
  836. (rdtype == dns.rdatatype.AXFR or
  837. (rdtype == dns.rdatatype.IXFR and delete_mode)):
  838. done = True
  839. elif expecting_SOA:
  840. #
  841. # We made an IXFR request and are expecting another
  842. # SOA RR, but saw something else, so this must be an
  843. # AXFR response.
  844. #
  845. rdtype = dns.rdatatype.AXFR
  846. expecting_SOA = False
  847. if done and q.keyring and not r.had_tsig:
  848. raise dns.exception.FormError("missing TSIG")
  849. yield r
  850. class UDPMode(enum.IntEnum):
  851. """How should UDP be used in an IXFR from :py:func:`inbound_xfr()`?
  852. NEVER means "never use UDP; always use TCP"
  853. TRY_FIRST means "try to use UDP but fall back to TCP if needed"
  854. ONLY means "raise ``dns.xfr.UseTCP`` if trying UDP does not succeed"
  855. """
  856. NEVER = 0
  857. TRY_FIRST = 1
  858. ONLY = 2
  859. def inbound_xfr(where, txn_manager, query=None,
  860. port=53, timeout=None, lifetime=None, source=None,
  861. source_port=0, udp_mode=UDPMode.NEVER):
  862. """Conduct an inbound transfer and apply it via a transaction from the
  863. txn_manager.
  864. *where*, a ``str`` containing an IPv4 or IPv6 address, where
  865. to send the message.
  866. *txn_manager*, a ``dns.transaction.TransactionManager``, the txn_manager
  867. for this transfer (typically a ``dns.zone.Zone``).
  868. *query*, the query to send. If not supplied, a default query is
  869. constructed using information from the *txn_manager*.
  870. *port*, an ``int``, the port send the message to. The default is 53.
  871. *timeout*, a ``float``, the number of seconds to wait for each
  872. response message. If None, the default, wait forever.
  873. *lifetime*, a ``float``, the total number of seconds to spend
  874. doing the transfer. If ``None``, the default, then there is no
  875. limit on the time the transfer may take.
  876. *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
  877. the source address. The default is the wildcard address.
  878. *source_port*, an ``int``, the port from which to send the message.
  879. The default is 0.
  880. *udp_mode*, a ``dns.query.UDPMode``, determines how UDP is used
  881. for IXFRs. The default is ``dns.UDPMode.NEVER``, i.e. only use
  882. TCP. Other possibilities are ``dns.UDPMode.TRY_FIRST``, which
  883. means "try UDP but fallback to TCP if needed", and
  884. ``dns.UDPMode.ONLY``, which means "try UDP and raise
  885. ``dns.xfr.UseTCP`` if it does not succeed.
  886. Raises on errors.
  887. """
  888. if query is None:
  889. (query, serial) = dns.xfr.make_query(txn_manager)
  890. else:
  891. serial = dns.xfr.extract_serial_from_query(query)
  892. rdtype = query.question[0].rdtype
  893. is_ixfr = rdtype == dns.rdatatype.IXFR
  894. origin = txn_manager.from_wire_origin()
  895. wire = query.to_wire()
  896. (af, destination, source) = _destination_and_source(where, port,
  897. source, source_port)
  898. (_, expiration) = _compute_times(lifetime)
  899. retry = True
  900. while retry:
  901. retry = False
  902. if is_ixfr and udp_mode != UDPMode.NEVER:
  903. sock_type = socket.SOCK_DGRAM
  904. is_udp = True
  905. else:
  906. sock_type = socket.SOCK_STREAM
  907. is_udp = False
  908. with _make_socket(af, sock_type, source) as s:
  909. _connect(s, destination, expiration)
  910. if is_udp:
  911. _udp_send(s, wire, None, expiration)
  912. else:
  913. tcpmsg = struct.pack("!H", len(wire)) + wire
  914. _net_write(s, tcpmsg, expiration)
  915. with dns.xfr.Inbound(txn_manager, rdtype, serial,
  916. is_udp) as inbound:
  917. done = False
  918. tsig_ctx = None
  919. while not done:
  920. (_, mexpiration) = _compute_times(timeout)
  921. if mexpiration is None or \
  922. (expiration is not None and mexpiration > expiration):
  923. mexpiration = expiration
  924. if is_udp:
  925. (rwire, _) = _udp_recv(s, 65535, mexpiration)
  926. else:
  927. ldata = _net_read(s, 2, mexpiration)
  928. (l,) = struct.unpack("!H", ldata)
  929. rwire = _net_read(s, l, mexpiration)
  930. r = dns.message.from_wire(rwire, keyring=query.keyring,
  931. request_mac=query.mac, xfr=True,
  932. origin=origin, tsig_ctx=tsig_ctx,
  933. multi=(not is_udp),
  934. one_rr_per_rrset=is_ixfr)
  935. try:
  936. done = inbound.process_message(r)
  937. except dns.xfr.UseTCP:
  938. assert is_udp # should not happen if we used TCP!
  939. if udp_mode == UDPMode.ONLY:
  940. raise
  941. done = True
  942. retry = True
  943. udp_mode = UDPMode.NEVER
  944. continue
  945. tsig_ctx = r.tsig_ctx
  946. if not retry and query.keyring and not r.had_tsig:
  947. raise dns.exception.FormError("missing TSIG")