123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446 |
- """Utilities shared by tests."""
- import collections
- import contextlib
- import io
- import logging
- import os
- import re
- import socket
- import socketserver
- import sys
- import tempfile
- import threading
- import time
- import unittest
- from unittest import mock
- from http.server import HTTPServer
- from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
- try:
- import ssl
- except ImportError: # pragma: no cover
- ssl = None
- from . import base_events
- from . import events
- from . import futures
- from . import selectors
- from . import tasks
- from .coroutines import coroutine
- from .log import logger
- if sys.platform == 'win32': # pragma: no cover
- from .windows_utils import socketpair
- else:
- from socket import socketpair # pragma: no cover
def dummy_ssl_context():
    """Return a bare SSL context for tests, or None if ssl is unavailable.

    The context is created with ``ssl.PROTOCOL_SSLv23`` and nothing else
    (no certificates, no verification settings) is configured on it.
    """
    return None if ssl is None else ssl.SSLContext(ssl.PROTOCOL_SSLv23)
def run_briefly(loop):
    """Run *loop* just long enough to complete a no-op task."""
    @coroutine
    def noop():
        pass

    coro = noop()
    task = loop.create_task(coro)
    # Suppress the "task destroyed but pending" warning in case the task is
    # not done after run_until_complete(); that happens when the loop is
    # stopped or when a task raises a BaseException.
    task._log_destroy_pending = False
    try:
        loop.run_until_complete(task)
    finally:
        coro.close()
def run_until(loop, pred, timeout=30):
    """Drive *loop* in short slices until ``pred()`` returns a true value.

    Args:
        loop: event loop to run; each slice runs a 1 ms ``tasks.sleep`` on it.
        pred: zero-argument callable, polled before every slice.
        timeout: maximum number of seconds to wait, or None to wait
            indefinitely.

    Raises:
        futures.TimeoutError: if *timeout* elapses before ``pred()`` is true.
    """
    # Only compute a deadline when a timeout is given, so timeout=None waits
    # indefinitely instead of failing on ``None + float``.  A monotonic clock
    # keeps wall-clock adjustments from shortening or extending the wait.
    deadline = None if timeout is None else time.monotonic() + timeout
    while not pred():
        if deadline is not None and deadline - time.monotonic() <= 0:
            raise futures.TimeoutError()
        loop.run_until_complete(tasks.sleep(0.001, loop=loop))
def run_once(loop):
    """Run a single pass of *loop* by stopping it and then running it.

    ``loop.stop()`` schedules ``_raise_stop_error()``, and ``run_forever()``
    then runs until that callback fires.  This won't work if the test waits
    for I/O events, because ``_raise_stop_error()`` runs before any I/O
    event callbacks.
    """
    # Schedule the stop first ...
    loop.stop()
    # ... so that run_forever() returns after one pass over ready callbacks.
    loop.run_forever()
class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGI request handler that discards its error stream and access log."""

    def get_stderr(self):
        # Give wsgiref a throwaway buffer so handler errors are not printed.
        return io.StringIO()

    def log_message(self, format, *args):
        # Drop the per-request log line.
        return None
class SilentWSGIServer(WSGIServer):
    """WSGI server that applies a timeout to connections and ignores errors."""

    # Timeout (seconds) set on every accepted connection by get_request().
    request_timeout = 2

    def get_request(self):
        conn, peer = super().get_request()
        conn.settimeout(self.request_timeout)
        return conn, peer

    def handle_error(self, request, client_address):
        # Swallow request errors instead of printing a traceback.
        return None
class SSLWSGIServerMixin:
    """Server mixin that wraps each accepted connection in server-side TLS.

    The key and certificate are read from ``ssl_key.pem`` and
    ``ssl_cert.pem`` in the asyncio test directory.
    """

    def finish_request(self, request, client_address):
        # The relative location of our test directory (which
        # contains the ssl key and certificate files) differs
        # between the stdlib and stand-alone asyncio.
        # Prefer our own if we can find it.
        here = os.path.join(os.path.dirname(__file__), '..', 'tests')
        if not os.path.isdir(here):
            here = os.path.join(os.path.dirname(os.__file__),
                                'test', 'test_asyncio')
        keyfile = os.path.join(here, 'ssl_key.pem')
        certfile = os.path.join(here, 'ssl_cert.pem')
        # Build the context explicitly instead of calling the module-level
        # ssl.wrap_socket() helper, which is deprecated (removed in 3.12).
        context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        context.load_cert_chain(certfile, keyfile)
        ssock = context.wrap_socket(request, server_side=True)
        try:
            self.RequestHandlerClass(ssock, client_address, self)
        except OSError:
            # maybe socket has been closed by peer
            pass
        finally:
            # Close the TLS socket on every path, including when the handler
            # fails with OSError, so it is not left for garbage collection.
            ssock.close()
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """Silent TCP WSGI server that serves every connection over TLS."""
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Generator that serves a one-route WSGI app on *address* in a thread.

    Yields the running server, which carries an extra ``address`` attribute
    copied from ``server_address``.  When the generator is resumed or closed,
    the server is shut down and closed and the serving thread is joined.
    ``run_test_server`` and ``run_test_unix_server`` delegate to it with
    ``yield from``.
    """
    def app(environ, start_response):
        start_response('200 OK', [('Content-type', 'text/plain')])
        return [b'Test message']

    # Serve from a separate thread so request handling does not interfere
    # with event handling in the main thread.
    factory = server_ssl_cls if use_ssl else server_cls
    httpd = factory(address, SilentWSGIRequestHandler)
    httpd.set_app(app)
    httpd.address = httpd.server_address

    def serve():
        httpd.serve_forever(poll_interval=0.05)

    worker = threading.Thread(target=serve)
    worker.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        worker.join()
if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTP server listening on a UNIX-domain socket path."""

        def server_bind(self):
            # Bind through UnixStreamServer rather than HTTPServer, whose
            # server_bind assumes a (host, port) address; placeholder
            # name/port values are filled in instead.
            socketserver.UnixStreamServer.server_bind(self)
            self.server_name = '127.0.0.1'
            self.server_port = 80

    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGI server on a UNIX socket with a per-connection timeout."""

        # Timeout (seconds) set on every accepted connection.
        request_timeout = 2

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            self.setup_environ()

        def get_request(self):
            conn, _peer_path = super().get_request()
            conn.settimeout(self.request_timeout)
            # Code in the stdlib expects get_request() to return a socket and
            # a (host, port) tuple.  For UNIX sockets the second value is a
            # path, so return fake data that is sufficient for the tests.
            return conn, ('127.0.0.1', '')

    class SilentUnixWSGIServer(UnixWSGIServer):
        """UNIX-socket WSGI server that ignores request errors."""

        def handle_error(self, request, client_address):
            return None

    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """Silent UNIX-socket WSGI server that serves connections over TLS."""

    def gen_unix_socket_path():
        """Return a fresh temporary path to use as a UNIX socket address.

        A NamedTemporaryFile is created and deleted on exit from the
        with-block, so only its name is kept.
        NOTE(review): another process could claim the same path before the
        caller binds to it.
        """
        with tempfile.NamedTemporaryFile() as tmp:
            chosen = tmp.name
        return chosen

    @contextlib.contextmanager
    def unix_socket_path():
        """Yield a temporary UNIX socket path and unlink it afterwards."""
        path = gen_unix_socket_path()
        try:
            yield path
        finally:
            # The socket file may never have been created; ignore OSError.
            with contextlib.suppress(OSError):
                os.unlink(path)

    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Run the test WSGI server on a temporary UNIX socket path."""
        with unix_socket_path() as path:
            yield from _run_test_server(address=path, use_ssl=use_ssl,
                                        server_cls=SilentUnixWSGIServer,
                                        server_ssl_cls=UnixSSLWSGIServer)
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Run the test WSGI server on ``(host, port)`` for a with-block.

    Yields the running server; its ``address`` attribute holds the address
    it actually bound to.  Pass ``use_ssl=True`` to serve over TLS.
    """
    bind_address = (host, port)
    yield from _run_test_server(address=bind_address, use_ssl=use_ssl,
                                server_cls=SilentWSGIServer,
                                server_ssl_cls=SSLWSGIServer)
def make_test_protocol(base):
    """Instantiate a subclass of *base* whose non-dunder names are mocks.

    Every name in ``dir(base)`` that is not of the ``__name__`` form is
    replaced by a MockCallback that returns None.
    """
    namespace = {
        attr: MockCallback(return_value=None)
        for attr in dir(base)
        # skip magic names
        if not (attr.startswith('__') and attr.endswith('__'))
    }
    return type('TestProtocol', (base,) + base.__bases__, namespace)()
class TestSelector(selectors.BaseSelector):
    """In-memory selector that records registrations and never reports events."""

    def __init__(self):
        # Maps each registered file object to its SelectorKey.
        self.keys = {}

    def register(self, fileobj, events, data=None):
        # The fd field is always 0: no real descriptor is looked up.
        self.keys[fileobj] = key = selectors.SelectorKey(
            fileobj=fileobj, fd=0, events=events, data=data)
        return key

    def unregister(self, fileobj):
        return self.keys.pop(fileobj)

    def select(self, timeout):
        # Nothing is ever ready.
        return []

    def get_map(self):
        return self.keys
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests that manages its own virtual clock.

    After each loop iteration's ready handlers have run, every deadline
    passed to call_at() during that iteration is sent, in order, to the
    generator given to __init__.  The generator answers each one with the
    amount by which to advance the loop's time.

    The generator should look like this::

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value received from ``yield`` is the absolute time of the next
    scheduled handler; the value yielded back is the time advance that moves
    the loop's clock forward.
    """

    def __init__(self, gen=None):
        super().__init__()
        # Without a user generator, install a trivial one and skip the
        # "generator finished" check in close().
        self._check_on_close = gen is not None
        if gen is None:
            def gen():
                yield
        self._gen = gen()
        next(self._gen)  # prime the generator up to its first yield
        self._time = 0
        self._clock_resolution = 1e-9
        self._timers = []
        self._selector = TestSelector()
        self.readers = {}
        self.writers = {}
        self.reset_counters()

    def time(self):
        return self._time

    def advance_time(self, advance):
        """Move test time forward by *advance*; falsy values are ignored."""
        if advance:
            self._time += advance

    def close(self):
        super().close()
        if not self._check_on_close:
            return
        try:
            self._gen.send(0)
        except StopIteration:
            pass
        else:  # pragma: no cover
            raise AssertionError("Time generator is not finished")

    @staticmethod
    def _discard_handle(registry, fd):
        # Delete fd from registry; report whether it was registered.
        try:
            del registry[fd]
        except KeyError:
            return False
        return True

    @staticmethod
    def _check_registered_handle(registry, fd, callback, args):
        # Assert fd is registered with exactly this callback and args.
        assert fd in registry, 'fd {} is not registered'.format(fd)
        handle = registry[fd]
        assert handle._callback == callback, '{!r} != {!r}'.format(
            handle._callback, callback)
        assert handle._args == args, '{!r} != {!r}'.format(
            handle._args, args)

    def add_reader(self, fd, callback, *args):
        self.readers[fd] = events.Handle(callback, args, self)

    def remove_reader(self, fd):
        self.remove_reader_count[fd] += 1
        return self._discard_handle(self.readers, fd)

    def assert_reader(self, fd, callback, *args):
        self._check_registered_handle(self.readers, fd, callback, args)

    def add_writer(self, fd, callback, *args):
        self.writers[fd] = events.Handle(callback, args, self)

    def remove_writer(self, fd):
        self.remove_writer_count[fd] += 1
        return self._discard_handle(self.writers, fd)

    def assert_writer(self, fd, callback, *args):
        self._check_registered_handle(self.writers, fd, callback, args)

    def reset_counters(self):
        # Per-fd counts of remove_reader()/remove_writer() calls.
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        for when in self._timers:
            self.advance_time(self._gen.send(when))
        self._timers = []

    def call_at(self, when, callback, *args):
        self._timers.append(when)
        return super().call_at(when, callback, *args)

    def _process_events(self, event_list):
        return None

    def _write_to_self(self):
        pass
def MockCallback(**kwargs):
    """Return a Mock that only supports being called.

    Extra keyword arguments (e.g. ``return_value``) go to ``mock.Mock``.
    """
    callable_only = ['__call__']
    return mock.Mock(spec=callable_only, **kwargs)
class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    where a regex comparison between strings is needed.

    For instance:
       mock_call.assert_called_with(MockPattern('spam.*ham'))

    The pattern is searched (not fully matched) with ``re.S``, so ``.``
    also matches newlines.  Comparisons against non-str objects are
    unequal instead of raising TypeError.  Because __eq__ is overridden,
    instances are unhashable.
    """

    def __eq__(self, other):
        if not isinstance(other, str):
            # Let Python fall back to its default (identity-based) result.
            return NotImplemented
        return bool(re.search(str(self), other, re.S))

    def __ne__(self, other):
        # str defines its own __ne__, which would otherwise compare the raw
        # pattern text literally and disagree with __eq__.
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return equal
        return not equal
def get_function_source(func):
    """Return the source location of *func* as reported by asyncio.

    Raises:
        ValueError: if no source information is available for *func*.
    """
    source = events._get_function_source(func)
    if source is not None:
        return source
    raise ValueError("unable to get the source of %r" % (func,))
class TestCase(unittest.TestCase):
    """Base test case that keeps asyncio's global event loop unset."""

    def set_event_loop(self, loop, *, cleanup=True):
        """Register *loop* for this test without installing it globally.

        The global event loop is reset to None so the loop must be passed
        explicitly.  If *cleanup* is true, *loop* is closed when the test
        finishes.
        """
        assert loop is not None
        # ensure that the event loop is passed explicitly in asyncio
        events.set_event_loop(None)
        if cleanup:
            self.addCleanup(loop.close)

    def new_test_loop(self, gen=None):
        """Create a TestLoop driven by *gen* and register it for cleanup."""
        test_loop = TestLoop(gen)
        self.set_event_loop(test_loop)
        return test_loop

    def tearDown(self):
        events.set_event_loop(None)

        # Detect CPython bug #23353: ensure that yield/yield-from is not used
        # in an except block of a generator
        self.assertEqual(sys.exc_info(), (None, None, None))
@contextlib.contextmanager
def disable_logger():
    """Temporarily silence the asyncio logger.

    The logger's level is raised above CRITICAL for the duration of the
    with-block and then restored.  This is useful, for example, to ignore
    warnings emitted in debug mode.
    """
    saved_level = logger.level
    try:
        logger.setLevel(logging.CRITICAL + 1)
        yield
    finally:
        logger.setLevel(saved_level)
def mock_nonblocking_socket():
    """Return a socket.socket-spec'd Mock whose gettimeout() returns 0.0."""
    fake_sock = mock.Mock(socket.socket)
    # A timeout of 0.0 is what a non-blocking socket reports.
    fake_sock.gettimeout.return_value = 0.0
    return fake_sock
def force_legacy_ssl_support():
    """Return an unstarted patcher making ``_is_sslproto_available`` false.

    The patch replaces ``asyncio.sslproto._is_sslproto_available`` with a
    mock returning False.  Presumably this forces the legacy SSL code path,
    as the name suggests (verify in sslproto).  Use the result as a
    decorator or context manager.
    """
    target = 'asyncio.sslproto._is_sslproto_available'
    return mock.patch(target, return_value=False)
|