test_utils.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
  1. """Utilities shared by tests."""
  2. import collections
  3. import contextlib
  4. import io
  5. import logging
  6. import os
  7. import re
  8. import socket
  9. import socketserver
  10. import sys
  11. import tempfile
  12. import threading
  13. import time
  14. import unittest
  15. from unittest import mock
  16. from http.server import HTTPServer
  17. from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
  18. try:
  19. import ssl
  20. except ImportError: # pragma: no cover
  21. ssl = None
  22. from . import base_events
  23. from . import events
  24. from . import futures
  25. from . import selectors
  26. from . import tasks
  27. from .coroutines import coroutine
  28. from .log import logger
  29. if sys.platform == 'win32': # pragma: no cover
  30. from .windows_utils import socketpair
  31. else:
  32. from socket import socketpair # pragma: no cover
  33. def dummy_ssl_context():
  34. if ssl is None:
  35. return None
  36. else:
  37. return ssl.SSLContext(ssl.PROTOCOL_SSLv23)
  38. def run_briefly(loop):
  39. @coroutine
  40. def once():
  41. pass
  42. gen = once()
  43. t = loop.create_task(gen)
  44. # Don't log a warning if the task is not done after run_until_complete().
  45. # It occurs if the loop is stopped or if a task raises a BaseException.
  46. t._log_destroy_pending = False
  47. try:
  48. loop.run_until_complete(t)
  49. finally:
  50. gen.close()
  51. def run_until(loop, pred, timeout=30):
  52. deadline = time.time() + timeout
  53. while not pred():
  54. if timeout is not None:
  55. timeout = deadline - time.time()
  56. if timeout <= 0:
  57. raise futures.TimeoutError()
  58. loop.run_until_complete(tasks.sleep(0.001, loop=loop))
  59. def run_once(loop):
  60. """loop.stop() schedules _raise_stop_error()
  61. and run_forever() runs until _raise_stop_error() callback.
  62. this wont work if test waits for some IO events, because
  63. _raise_stop_error() runs before any of io events callbacks.
  64. """
  65. loop.stop()
  66. loop.run_forever()
  67. class SilentWSGIRequestHandler(WSGIRequestHandler):
  68. def get_stderr(self):
  69. return io.StringIO()
  70. def log_message(self, format, *args):
  71. pass
  72. class SilentWSGIServer(WSGIServer):
  73. request_timeout = 2
  74. def get_request(self):
  75. request, client_addr = super().get_request()
  76. request.settimeout(self.request_timeout)
  77. return request, client_addr
  78. def handle_error(self, request, client_address):
  79. pass
  80. class SSLWSGIServerMixin:
  81. def finish_request(self, request, client_address):
  82. # The relative location of our test directory (which
  83. # contains the ssl key and certificate files) differs
  84. # between the stdlib and stand-alone asyncio.
  85. # Prefer our own if we can find it.
  86. here = os.path.join(os.path.dirname(__file__), '..', 'tests')
  87. if not os.path.isdir(here):
  88. here = os.path.join(os.path.dirname(os.__file__),
  89. 'test', 'test_asyncio')
  90. keyfile = os.path.join(here, 'ssl_key.pem')
  91. certfile = os.path.join(here, 'ssl_cert.pem')
  92. ssock = ssl.wrap_socket(request,
  93. keyfile=keyfile,
  94. certfile=certfile,
  95. server_side=True)
  96. try:
  97. self.RequestHandlerClass(ssock, client_address, self)
  98. ssock.close()
  99. except OSError:
  100. # maybe socket has been closed by peer
  101. pass
  102. class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
  103. pass
  104. def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
  105. def app(environ, start_response):
  106. status = '200 OK'
  107. headers = [('Content-type', 'text/plain')]
  108. start_response(status, headers)
  109. return [b'Test message']
  110. # Run the test WSGI server in a separate thread in order not to
  111. # interfere with event handling in the main thread
  112. server_class = server_ssl_cls if use_ssl else server_cls
  113. httpd = server_class(address, SilentWSGIRequestHandler)
  114. httpd.set_app(app)
  115. httpd.address = httpd.server_address
  116. server_thread = threading.Thread(
  117. target=lambda: httpd.serve_forever(poll_interval=0.05))
  118. server_thread.start()
  119. try:
  120. yield httpd
  121. finally:
  122. httpd.shutdown()
  123. httpd.server_close()
  124. server_thread.join()
  125. if hasattr(socket, 'AF_UNIX'):
  126. class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
  127. def server_bind(self):
  128. socketserver.UnixStreamServer.server_bind(self)
  129. self.server_name = '127.0.0.1'
  130. self.server_port = 80
  131. class UnixWSGIServer(UnixHTTPServer, WSGIServer):
  132. request_timeout = 2
  133. def server_bind(self):
  134. UnixHTTPServer.server_bind(self)
  135. self.setup_environ()
  136. def get_request(self):
  137. request, client_addr = super().get_request()
  138. request.settimeout(self.request_timeout)
  139. # Code in the stdlib expects that get_request
  140. # will return a socket and a tuple (host, port).
  141. # However, this isn't true for UNIX sockets,
  142. # as the second return value will be a path;
  143. # hence we return some fake data sufficient
  144. # to get the tests going
  145. return request, ('127.0.0.1', '')
  146. class SilentUnixWSGIServer(UnixWSGIServer):
  147. def handle_error(self, request, client_address):
  148. pass
  149. class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
  150. pass
  151. def gen_unix_socket_path():
  152. with tempfile.NamedTemporaryFile() as file:
  153. return file.name
  154. @contextlib.contextmanager
  155. def unix_socket_path():
  156. path = gen_unix_socket_path()
  157. try:
  158. yield path
  159. finally:
  160. try:
  161. os.unlink(path)
  162. except OSError:
  163. pass
  164. @contextlib.contextmanager
  165. def run_test_unix_server(*, use_ssl=False):
  166. with unix_socket_path() as path:
  167. yield from _run_test_server(address=path, use_ssl=use_ssl,
  168. server_cls=SilentUnixWSGIServer,
  169. server_ssl_cls=UnixSSLWSGIServer)
  170. @contextlib.contextmanager
  171. def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
  172. yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
  173. server_cls=SilentWSGIServer,
  174. server_ssl_cls=SSLWSGIServer)
  175. def make_test_protocol(base):
  176. dct = {}
  177. for name in dir(base):
  178. if name.startswith('__') and name.endswith('__'):
  179. # skip magic names
  180. continue
  181. dct[name] = MockCallback(return_value=None)
  182. return type('TestProtocol', (base,) + base.__bases__, dct)()
  183. class TestSelector(selectors.BaseSelector):
  184. def __init__(self):
  185. self.keys = {}
  186. def register(self, fileobj, events, data=None):
  187. key = selectors.SelectorKey(fileobj, 0, events, data)
  188. self.keys[fileobj] = key
  189. return key
  190. def unregister(self, fileobj):
  191. return self.keys.pop(fileobj)
  192. def select(self, timeout):
  193. return []
  194. def get_map(self):
  195. return self.keys
  196. class TestLoop(base_events.BaseEventLoop):
  197. """Loop for unittests.
  198. It manages self time directly.
  199. If something scheduled to be executed later then
  200. on next loop iteration after all ready handlers done
  201. generator passed to __init__ is calling.
  202. Generator should be like this:
  203. def gen():
  204. ...
  205. when = yield ...
  206. ... = yield time_advance
  207. Value returned by yield is absolute time of next scheduled handler.
  208. Value passed to yield is time advance to move loop's time forward.
  209. """
  210. def __init__(self, gen=None):
  211. super().__init__()
  212. if gen is None:
  213. def gen():
  214. yield
  215. self._check_on_close = False
  216. else:
  217. self._check_on_close = True
  218. self._gen = gen()
  219. next(self._gen)
  220. self._time = 0
  221. self._clock_resolution = 1e-9
  222. self._timers = []
  223. self._selector = TestSelector()
  224. self.readers = {}
  225. self.writers = {}
  226. self.reset_counters()
  227. def time(self):
  228. return self._time
  229. def advance_time(self, advance):
  230. """Move test time forward."""
  231. if advance:
  232. self._time += advance
  233. def close(self):
  234. super().close()
  235. if self._check_on_close:
  236. try:
  237. self._gen.send(0)
  238. except StopIteration:
  239. pass
  240. else: # pragma: no cover
  241. raise AssertionError("Time generator is not finished")
  242. def add_reader(self, fd, callback, *args):
  243. self.readers[fd] = events.Handle(callback, args, self)
  244. def remove_reader(self, fd):
  245. self.remove_reader_count[fd] += 1
  246. if fd in self.readers:
  247. del self.readers[fd]
  248. return True
  249. else:
  250. return False
  251. def assert_reader(self, fd, callback, *args):
  252. assert fd in self.readers, 'fd {} is not registered'.format(fd)
  253. handle = self.readers[fd]
  254. assert handle._callback == callback, '{!r} != {!r}'.format(
  255. handle._callback, callback)
  256. assert handle._args == args, '{!r} != {!r}'.format(
  257. handle._args, args)
  258. def add_writer(self, fd, callback, *args):
  259. self.writers[fd] = events.Handle(callback, args, self)
  260. def remove_writer(self, fd):
  261. self.remove_writer_count[fd] += 1
  262. if fd in self.writers:
  263. del self.writers[fd]
  264. return True
  265. else:
  266. return False
  267. def assert_writer(self, fd, callback, *args):
  268. assert fd in self.writers, 'fd {} is not registered'.format(fd)
  269. handle = self.writers[fd]
  270. assert handle._callback == callback, '{!r} != {!r}'.format(
  271. handle._callback, callback)
  272. assert handle._args == args, '{!r} != {!r}'.format(
  273. handle._args, args)
  274. def reset_counters(self):
  275. self.remove_reader_count = collections.defaultdict(int)
  276. self.remove_writer_count = collections.defaultdict(int)
  277. def _run_once(self):
  278. super()._run_once()
  279. for when in self._timers:
  280. advance = self._gen.send(when)
  281. self.advance_time(advance)
  282. self._timers = []
  283. def call_at(self, when, callback, *args):
  284. self._timers.append(when)
  285. return super().call_at(when, callback, *args)
  286. def _process_events(self, event_list):
  287. return
  288. def _write_to_self(self):
  289. pass
  290. def MockCallback(**kwargs):
  291. return mock.Mock(spec=['__call__'], **kwargs)
  292. class MockPattern(str):
  293. """A regex based str with a fuzzy __eq__.
  294. Use this helper with 'mock.assert_called_with', or anywhere
  295. where a regex comparison between strings is needed.
  296. For instance:
  297. mock_call.assert_called_with(MockPattern('spam.*ham'))
  298. """
  299. def __eq__(self, other):
  300. return bool(re.search(str(self), other, re.S))
  301. def get_function_source(func):
  302. source = events._get_function_source(func)
  303. if source is None:
  304. raise ValueError("unable to get the source of %r" % (func,))
  305. return source
  306. class TestCase(unittest.TestCase):
  307. def set_event_loop(self, loop, *, cleanup=True):
  308. assert loop is not None
  309. # ensure that the event loop is passed explicitly in asyncio
  310. events.set_event_loop(None)
  311. if cleanup:
  312. self.addCleanup(loop.close)
  313. def new_test_loop(self, gen=None):
  314. loop = TestLoop(gen)
  315. self.set_event_loop(loop)
  316. return loop
  317. def tearDown(self):
  318. events.set_event_loop(None)
  319. # Detect CPython bug #23353: ensure that yield/yield-from is not used
  320. # in an except block of a generator
  321. self.assertEqual(sys.exc_info(), (None, None, None))
  322. @contextlib.contextmanager
  323. def disable_logger():
  324. """Context manager to disable asyncio logger.
  325. For example, it can be used to ignore warnings in debug mode.
  326. """
  327. old_level = logger.level
  328. try:
  329. logger.setLevel(logging.CRITICAL+1)
  330. yield
  331. finally:
  332. logger.setLevel(old_level)
  333. def mock_nonblocking_socket():
  334. """Create a mock of a non-blocking socket."""
  335. sock = mock.Mock(socket.socket)
  336. sock.gettimeout.return_value = 0.0
  337. return sock
  338. def force_legacy_ssl_support():
  339. return mock.patch('asyncio.sslproto._is_sslproto_available',
  340. return_value=False)