• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Utilities shared by tests."""
2
3import collections
4import contextlib
5import io
6import logging
7import os
8import re
9import socket
10import socketserver
11import sys
12import tempfile
13import threading
14import time
15import unittest
16import weakref
17
18from unittest import mock
19
20from http.server import HTTPServer
21from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
22
23try:
24    import ssl
25except ImportError:  # pragma: no cover
26    ssl = None
27
28from . import base_events
29from . import compat
30from . import events
31from . import futures
32from . import selectors
33from . import tasks
34from .coroutines import coroutine
35from .log import logger
36
37
38if sys.platform == 'win32':  # pragma: no cover
39    from .windows_utils import socketpair
40else:
41    from socket import socketpair  # pragma: no cover
42
43
def dummy_ssl_context():
    """Return a fresh, unconfigured SSL context.

    Returns None when the ``ssl`` module could not be imported.
    """
    if ssl is not None:
        return ssl.SSLContext(ssl.PROTOCOL_SSLv23)
    return None
49
50
def run_briefly(loop):
    """Spin *loop* just long enough to run one trivial coroutine to completion.

    The coroutine object is always closed afterwards, even if the loop is
    stopped early or a BaseException escapes.
    """
    @coroutine
    def once():
        pass

    coro = once()
    task = loop.create_task(coro)
    # The task may legitimately still be pending after run_until_complete()
    # (the loop was stopped, or a BaseException was raised), so do not let
    # it log a "destroyed while pending" warning.
    task._log_destroy_pending = False
    try:
        loop.run_until_complete(task)
    finally:
        coro.close()
64
65
def run_until(loop, pred, timeout=30):
    """Drive *loop* in short steps until ``pred()`` returns a true value.

    Args:
        loop: the event loop to run.
        pred: zero-argument callable, checked before every step.
        timeout: maximum number of seconds to wait, or None to wait
            without limit.

    Raises:
        futures.TimeoutError: *timeout* seconds elapsed before ``pred()``
            became true.
    """
    # Only compute a deadline when there is a timeout: the previous
    # unconditional ``time.time() + timeout`` raised TypeError for None.
    # A monotonic clock keeps wall-clock adjustments from shortening or
    # stretching the wait.
    deadline = None if timeout is None else time.monotonic() + timeout
    while not pred():
        if deadline is not None and time.monotonic() >= deadline:
            raise futures.TimeoutError()
        loop.run_until_complete(tasks.sleep(0.001, loop=loop))
74
75
def run_once(loop):
    """Run exactly one iteration of *loop* (legacy helper).

    The helper itself is kept for backward compatibility. The pattern it
    wraps (schedule ``loop.stop`` with call_soon(), then call
    run_forever()) is still the recommended way for test code to run a
    single iteration. The selector is polled once, and every callback
    scheduled in response to I/O events is run.
    """
    loop.call_soon(loop.stop)
    return loop.run_forever() and None
85
86
class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGIRequestHandler that writes neither stderr output nor log lines."""

    def get_stderr(self):
        # Hand out a fresh in-memory buffer instead of the real stderr.
        buffer = io.StringIO()
        return buffer

    def log_message(self, format, *args):
        # Request logging is deliberately disabled.
        return None
94
95
class SilentWSGIServer(WSGIServer):
    """WSGIServer with a per-connection socket timeout and a no-op handle_error()."""

    # Timeout in seconds set on every accepted connection.
    request_timeout = 2

    def get_request(self):
        conn, peer = super().get_request()
        conn.settimeout(self.request_timeout)
        return conn, peer

    def handle_error(self, request, client_address):
        # Errors are intentionally not reported.
        return None
107
108
class SSLWSGIServerMixin:
    """Server mixin that wraps each accepted request socket in TLS.

    The server side of the connection uses the ``ssl_key.pem`` /
    ``ssl_cert.pem`` pair from the test directory. The wrapped socket is
    passed to ``RequestHandlerClass``.
    """

    def finish_request(self, request, client_address):
        # The relative location of our test directory (which
        # contains the ssl key and certificate files) differs
        # between the stdlib and stand-alone asyncio.
        # Prefer our own if we can find it.
        here = os.path.join(os.path.dirname(__file__), '..', 'tests')
        if not os.path.isdir(here):
            here = os.path.join(os.path.dirname(os.__file__),
                                'test', 'test_asyncio')
        keyfile = os.path.join(here, 'ssl_key.pem')
        certfile = os.path.join(here, 'ssl_cert.pem')
        # Pass the protocol explicitly, as dummy_ssl_context() does.
        # SSLContext() only gained a default protocol argument in
        # Python 3.6, and this module still supports older versions.
        context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        context.load_cert_chain(certfile, keyfile)

        ssock = context.wrap_socket(request, server_side=True)
        try:
            # ``with`` closes the TLS socket on every path. Previously it
            # was only closed on success, so it leaked when the handler
            # raised.
            with ssock:
                self.RequestHandlerClass(ssock, client_address, self)
        except OSError:
            # maybe socket has been closed by peer
            pass
132
133
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """SilentWSGIServer whose request sockets are wrapped in TLS by the mixin."""
136
137
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Serve a trivial WSGI app from a background thread and yield the server.

    *server_ssl_cls* is used when *use_ssl* is true, otherwise
    *server_cls*. The yielded server has an extra ``address`` attribute
    equal to its ``server_address``. On exit the server is shut down and
    closed, and the serving thread is joined.
    """

    def app(environ, start_response):
        start_response('200 OK', [('Content-type', 'text/plain')])
        return [b'Test message']

    # Serve from a separate thread so the event handling in the main
    # thread is not disturbed.
    if use_ssl:
        factory = server_ssl_cls
    else:
        factory = server_cls
    httpd = factory(address, SilentWSGIRequestHandler)
    httpd.set_app(app)
    httpd.address = httpd.server_address

    worker = threading.Thread(
        target=lambda: httpd.serve_forever(poll_interval=0.05))
    worker.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        worker.join()
161
162
if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTP server that listens on a UNIX socket path.

        Binding goes through UnixStreamServer.server_bind(). Fixed
        placeholder values are then recorded as server_name and
        server_port.
        """

        def server_bind(self):
            socketserver.UnixStreamServer.server_bind(self)
            self.server_name, self.server_port = '127.0.0.1', 80


    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGI server on a UNIX socket with a per-connection timeout."""

        # Timeout in seconds set on every accepted connection.
        request_timeout = 2

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            self.setup_environ()

        def get_request(self):
            conn, _peer = super().get_request()
            conn.settimeout(self.request_timeout)
            # The stdlib request handling expects a (host, port) client
            # address, but a UNIX socket peer is a path. Return a
            # placeholder address, which is good enough for the tests.
            return conn, ('127.0.0.1', '')


    class SilentUnixWSGIServer(UnixWSGIServer):
        """UnixWSGIServer whose handle_error() does nothing."""

        def handle_error(self, request, client_address):
            return None


    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """SilentUnixWSGIServer whose connections are wrapped in TLS."""


    def gen_unix_socket_path():
        """Return a fresh temporary file name to use as a UNIX socket path.

        A named temporary file is created only to obtain a unique name.
        The file is removed again when the ``with`` block exits.
        """
        with tempfile.NamedTemporaryFile() as tmp:
            name = tmp.name
        return name


    @contextlib.contextmanager
    def unix_socket_path():
        """Yield a temporary UNIX socket path and unlink it afterwards."""
        sock_path = gen_unix_socket_path()
        try:
            yield sock_path
        finally:
            try:
                os.unlink(sock_path)
            except OSError:
                # The path may never have been created; ignore failures.
                pass


    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Run the test WSGI server on a temporary UNIX socket path."""
        with unix_socket_path() as path:
            yield from _run_test_server(
                address=path,
                use_ssl=use_ssl,
                server_cls=SilentUnixWSGIServer,
                server_ssl_cls=UnixSSLWSGIServer)
226
227
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Run the test WSGI server on TCP ``(host, port)`` and yield it.

    ``port=0`` lets the OS choose a free port. The actual address is
    available as the yielded server's ``address`` attribute.
    """
    tcp_address = (host, port)
    yield from _run_test_server(
        address=tcp_address,
        use_ssl=use_ssl,
        server_cls=SilentWSGIServer,
        server_ssl_cls=SSLWSGIServer)
233
234
def make_test_protocol(base):
    """Return an instance of a *base* subclass whose methods are all mocks.

    Every attribute name of *base*, except dunder names, is overridden
    with ``MockCallback(return_value=None)``. The new class lists *base*
    followed by *base*'s own direct bases.
    """
    overrides = {
        attr: MockCallback(return_value=None)
        for attr in dir(base)
        # dunder (magic) names are left alone
        if not (attr.startswith('__') and attr.endswith('__'))
    }
    bases = (base,) + base.__bases__
    return type('TestProtocol', bases, overrides)()
243
244
class TestSelector(selectors.BaseSelector):
    """In-memory selector stub that only records registrations.

    select() never reports any ready file object.
    """

    def __init__(self):
        self.keys = {}

    def register(self, fileobj, events, data=None):
        # No real file descriptor is involved, so fd is always recorded as 0.
        self.keys[fileobj] = selectors.SelectorKey(fileobj, 0, events, data)
        return self.keys[fileobj]

    def unregister(self, fileobj):
        removed = self.keys[fileobj]
        del self.keys[fileobj]
        return removed

    def select(self, timeout):
        # Nothing is ever ready.
        return list()

    def get_map(self):
        return self.keys
263
264
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests whose clock is controlled by the test.

    time() returns an internal counter that only moves through
    advance_time(). After every loop iteration, each time passed to
    call_at() since the previous report is sent to the generator given
    to __init__. The value the generator yields back is used to advance
    the clock.

    The generator should look like this::

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value received from ``yield`` is the absolute time of a scheduled
    handler. The value yielded is how far to move the loop's time
    forward. When a generator is supplied, close() checks that it has
    been exhausted.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            def gen():
                yield
            # Only a caller-supplied generator is checked in close().
            self._check_on_close = False
        else:
            self._check_on_close = True

        self._gen = gen()
        # Prime the generator up to its first yield so send() can be used.
        next(self._gen)
        self._time = 0
        self._clock_resolution = 1e-9
        # Times passed to call_at() that still have to be reported to the
        # generator; drained by _run_once().
        self._timers = []
        self._selector = TestSelector()

        # fd -> events.Handle registered via add_reader()/add_writer().
        self.readers = {}
        self.writers = {}
        self.reset_counters()

        self._transports = weakref.WeakValueDictionary()

    def time(self):
        """Return the loop's simulated time."""
        return self._time

    def advance_time(self, advance):
        """Move test time forward by *advance*; a falsy value is ignored."""
        if advance:
            self._time += advance

    def close(self):
        """Close the loop, then check that the time generator is exhausted.

        Raises:
            AssertionError: a generator was passed to __init__ and it
                still yields after being sent 0.
        """
        super().close()
        if self._check_on_close:
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    def _add_reader(self, fd, callback, *args):
        self.readers[fd] = events.Handle(callback, args, self)

    def _remove_reader(self, fd):
        """Drop the reader for *fd*; return True if one was registered."""
        self.remove_reader_count[fd] += 1
        # Stored values are always events.Handle objects (never None), so
        # a None default tells whether fd was registered.
        return self.readers.pop(fd, None) is not None

    def assert_reader(self, fd, callback, *args):
        """Assert that *fd* has a reader registered with *callback*/*args*."""
        assert fd in self.readers, 'fd {} is not registered'.format(fd)
        handle = self.readers[fd]
        assert handle._callback == callback, '{!r} != {!r}'.format(
            handle._callback, callback)
        assert handle._args == args, '{!r} != {!r}'.format(
            handle._args, args)

    def _add_writer(self, fd, callback, *args):
        self.writers[fd] = events.Handle(callback, args, self)

    def _remove_writer(self, fd):
        """Drop the writer for *fd*; return True if one was registered."""
        self.remove_writer_count[fd] += 1
        # Stored values are always events.Handle objects (never None).
        return self.writers.pop(fd, None) is not None

    def assert_writer(self, fd, callback, *args):
        """Assert that *fd* has a writer registered with *callback*/*args*."""
        assert fd in self.writers, 'fd {} is not registered'.format(fd)
        handle = self.writers[fd]
        assert handle._callback == callback, '{!r} != {!r}'.format(
            handle._callback, callback)
        assert handle._args == args, '{!r} != {!r}'.format(
            handle._args, args)

    def _ensure_fd_no_transport(self, fd):
        """Raise RuntimeError if *fd* is owned by a registered transport."""
        try:
            transport = self._transports[fd]
        except KeyError:
            pass
        else:
            raise RuntimeError(
                'File descriptor {!r} is used by transport {!r}'.format(
                    fd, transport))

    def add_reader(self, fd, callback, *args):
        """Add a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_reader(fd, callback, *args)

    def remove_reader(self, fd):
        """Remove a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_reader(fd)

    def add_writer(self, fd, callback, *args):
        """Add a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_writer(fd, callback, *args)

    def remove_writer(self, fd):
        """Remove a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_writer(fd)

    def reset_counters(self):
        """Reset the per-fd counts of remove_reader/remove_writer calls."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Report every pending timer to the generator and let it decide
        # how far the clock moves for each one.
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args):
        # Remember the deadline so _run_once() can report it to the
        # time generator.
        self._timers.append(when)
        return super().call_at(when, callback, *args)

    def _process_events(self, event_list):
        # TestSelector never reports events, so there is nothing to do.
        return

    def _write_to_self(self):
        # Intentionally a no-op.
        pass
413
414
def MockCallback(**kwargs):
    """Return a Mock that may only be called; other attributes are rejected.

    Keyword arguments (e.g. ``return_value``) are forwarded to mock.Mock.
    """
    callable_only = ['__call__']
    return mock.Mock(spec=callable_only, **kwargs)
417
418
class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    where a regex comparison between strings is needed.

    For instance:
       mock_call.assert_called_with(MockPattern('spam.*ham'))

    The pattern is searched (``re.search`` with ``re.S``) anywhere in the
    other string unless it is anchored. Comparing with a non-str returns
    NotImplemented, so it ends up unequal instead of raising TypeError.
    ``!=`` is the exact negation of ``==``. Like any class that defines
    __eq__ without __hash__, instances are unhashable.
    """

    def __eq__(self, other):
        if not isinstance(other, str):
            return NotImplemented
        return bool(re.search(str(self), other, re.S))

    def __ne__(self, other):
        # Without this, str.__ne__ would be inherited and compare the
        # pattern text literally, contradicting the fuzzy __eq__.
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return equal
        return not equal
430
431
def get_function_source(func):
    """Return the source location that asyncio reports for *func*.

    Raises:
        ValueError: events._get_function_source() returned None.
    """
    found = events._get_function_source(func)
    if found is not None:
        return found
    raise ValueError("unable to get the source of %r" % (func,))
437
438
class TestCase(unittest.TestCase):
    """Base test case that keeps tests away from global event loop state.

    setUp() makes events._get_running_loop() always return None, and
    tearDown() restores it and clears the global event loop.
    """

    def set_event_loop(self, loop, *, cleanup=True):
        """Register *loop* for this test without installing it globally.

        The global event loop is set to None so that the loop has to be
        passed explicitly. Unless *cleanup* is false, the loop is closed
        by a test cleanup.
        """
        assert loop is not None
        events.set_event_loop(None)
        if not cleanup:
            return
        self.addCleanup(loop.close)

    def new_test_loop(self, gen=None):
        """Create a TestLoop driven by *gen*, register it and return it."""
        test_loop = TestLoop(gen)
        self.set_event_loop(test_loop)
        return test_loop

    def unpatch_get_running_loop(self):
        """Restore the events._get_running_loop saved by setUp()."""
        events._get_running_loop = self._get_running_loop

    def setUp(self):
        # Save the real implementation, then report that no loop is running.
        self._get_running_loop = events._get_running_loop
        events._get_running_loop = lambda: None

    def tearDown(self):
        self.unpatch_get_running_loop()

        events.set_event_loop(None)

        # Detect CPython bug #23353: ensure that yield/yield-from is not used
        # in an except block of a generator
        self.assertEqual(sys.exc_info(), (None, None, None))

    if not compat.PY34:
        # Python 3.3 compatibility: provide a do-nothing subTest()
        # context manager.
        def subTest(self, *args, **kwargs):
            class EmptyCM:
                def __enter__(self):
                    return None

                def __exit__(self, *exc):
                    return None
            return EmptyCM()
477
478
@contextlib.contextmanager
def disable_logger():
    """Silence the asyncio logger for the duration of a ``with`` block.

    The logger's level is raised above CRITICAL, and the previous level is
    restored on exit. This is useful, for example, to hide warnings emitted
    in debug mode.
    """
    saved_level = logger.level
    above_critical = logging.CRITICAL + 1
    try:
        logger.setLevel(above_critical)
        yield
    finally:
        logger.setLevel(saved_level)
491
492
def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
                            family=socket.AF_INET):
    """Return a MagicMock shaped like a non-blocking socket.socket.

    The mock reports the given *proto*, *type* and *family*, and its
    gettimeout() returns 0.0.
    """
    fake = mock.MagicMock(socket.socket)
    fake.gettimeout.return_value = 0.0
    fake.family = family
    fake.type = type
    fake.proto = proto
    return fake
502
503
def force_legacy_ssl_support():
    """Return a (not yet started) patcher that fakes an unavailable sslproto.

    While it is active, ``asyncio.sslproto._is_sslproto_available`` is
    replaced by a mock returning False.
    """
    target = 'asyncio.sslproto._is_sslproto_available'
    return mock.patch(target, return_value=False)
507