• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Utilities shared by tests."""
2
3import asyncio
4import collections
5import contextlib
6import io
7import logging
8import os
9import re
10import selectors
11import socket
12import socketserver
13import sys
14import tempfile
15import threading
16import time
17import unittest
18import weakref
19
20from unittest import mock
21
22from http.server import HTTPServer
23from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
24
25try:
26    import ssl
27except ImportError:  # pragma: no cover
28    ssl = None
29
30from asyncio import base_events
31from asyncio import events
32from asyncio import format_helpers
33from asyncio import futures
34from asyncio import tasks
35from asyncio.log import logger
36from test import support
37from test.support import threading_helper
38
39
def data_file(filename):
    """Return the path of the test data file called *filename*.

    ``support.TEST_HOME_DIR`` is tried first (only when ``support`` has
    that attribute), then the parent of this module's directory.  The
    first candidate that is an existing regular file wins.

    Raises FileNotFoundError(filename) when no candidate exists.
    """
    search_dirs = []
    if hasattr(support, 'TEST_HOME_DIR'):
        search_dirs.append(support.TEST_HOME_DIR)
    search_dirs.append(os.path.join(os.path.dirname(__file__), '..'))

    for directory in search_dirs:
        candidate = os.path.join(directory, filename)
        if os.path.isfile(candidate):
            return candidate
    raise FileNotFoundError(filename)
49
50
# Paths to the certificate/key files used by the SSL test helpers below.
# Each is resolved with data_file() at import time, so a missing file makes
# importing this module fail with FileNotFoundError.
ONLYCERT = data_file('ssl_cert.pem')
ONLYKEY = data_file('ssl_key.pem')
SIGNED_CERTFILE = data_file('keycert3.pem')
SIGNING_CA = data_file('pycacert.pem')
# Expected decoded peer certificate, laid out like the dict returned by
# ssl.SSLSocket.getpeercert().  NOTE(review): presumably this describes the
# certificate in SIGNED_CERTFILE (issued by 'our-ca-server') — confirm
# against the tests that compare against it.
PEERCERT = {
    'OCSP': ('http://testca.pythontest.net/testca/ocsp/',),
    'caIssuers': ('http://testca.pythontest.net/testca/pycacert.cer',),
    'crlDistributionPoints': ('http://testca.pythontest.net/testca/revocation.crl',),
    'issuer': ((('countryName', 'XY'),),
            (('organizationName', 'Python Software Foundation CA'),),
            (('commonName', 'our-ca-server'),)),
    'notAfter': 'Oct 28 14:23:16 2037 GMT',
    'notBefore': 'Aug 29 14:23:16 2018 GMT',
    'serialNumber': 'CB2D80995A69525C',
    'subject': ((('countryName', 'XY'),),
             (('localityName', 'Castle Anthrax'),),
             (('organizationName', 'Python Software Foundation'),),
             (('commonName', 'localhost'),)),
    'subjectAltName': (('DNS', 'localhost'),),
    'version': 3
}
72
73
def simple_server_sslcontext():
    """Build a TLS server context presenting ONLYCERT / ONLYKEY.

    Hostname checking is off and verify_mode is CERT_NONE, so client
    certificates are not requested.
    """
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    ctx.load_cert_chain(ONLYCERT, ONLYKEY)
    # check_hostname must be cleared before verify_mode can be CERT_NONE.
    ctx.check_hostname = False
    ctx.verify_mode = ssl.CERT_NONE
    return ctx
80
81
def simple_client_sslcontext(*, disable_verify=True):
    """Build a TLS client context with hostname checking disabled.

    With *disable_verify* true (the default) the server certificate is
    not verified either; otherwise the context keeps the verification
    mode that PROTOCOL_TLS_CLIENT starts with.
    """
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    ctx.check_hostname = False
    if not disable_verify:
        return ctx
    ctx.verify_mode = ssl.CERT_NONE
    return ctx
88
89
def dummy_ssl_context():
    """Return a non-verifying client SSL context.

    Returns None when the ssl module could not be imported.
    """
    return (None if ssl is None
            else simple_client_sslcontext(disable_verify=True))
95
96
def run_briefly(loop):
    """Drive *loop* until one trivial task has completed.

    Useful for letting callbacks already queued on the loop (e.g. with
    call_soon()) run before the test continues.
    """
    async def noop():
        pass

    coro = noop()
    task = loop.create_task(coro)
    # Suppress the "task was destroyed but it is pending" warning: the task
    # can stay unfinished if the loop is stopped or a BaseException escapes.
    task._log_destroy_pending = False
    try:
        loop.run_until_complete(task)
    finally:
        coro.close()
109
110
def run_until(loop, pred, timeout=support.SHORT_TIMEOUT):
    """Run *loop* in short slices until ``pred()`` returns a true value.

    *timeout* is in seconds; pass None to wait with no limit.

    Raises asyncio.TimeoutError if ``pred()`` is still false once the
    timeout has elapsed.
    """
    # Only compute a deadline for a real timeout: time.monotonic() + None
    # would raise TypeError, making timeout=None unusable.
    deadline = None if timeout is None else time.monotonic() + timeout
    while not pred():
        if deadline is not None and deadline - time.monotonic() <= 0:
            # asyncio.TimeoutError is the documented public name.
            raise asyncio.TimeoutError()
        loop.run_until_complete(tasks.sleep(0.001))
119
120
def run_once(loop):
    """Legacy API: run a single iteration of the event loop.

    This is still the recommended pattern for test code.  The loop polls
    its selector once and runs all callbacks scheduled in response to
    I/O events, then stops.
    """
    # Queue stop() behind everything already scheduled, then spin the loop;
    # run_forever() returns as soon as that stop() callback runs.
    stop_callback = loop.stop
    loop.call_soon(stop_callback)
    loop.run_forever()
130
131
class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGI request handler that discards error output and access logs."""

    def get_stderr(self):
        # Give the WSGI machinery a throwaway buffer instead of sys.stderr.
        sink = io.StringIO()
        return sink

    def log_message(self, format, *_args):
        # Per-request log lines are dropped entirely.
        return None
139
140
class SilentWSGIServer(WSGIServer):
    """WSGI server with a per-connection socket timeout and no error
    reporting."""

    # Timeout applied to every accepted connection socket.
    request_timeout = support.LOOPBACK_TIMEOUT

    def get_request(self):
        conn, peer = super().get_request()
        conn.settimeout(self.request_timeout)
        return conn, peer

    def handle_error(self, request, client_address):
        # Swallow handler errors instead of printing a traceback.
        return None
152
153
class SSLWSGIServerMixin:
    """Server mixin that wraps each accepted connection in TLS.

    The server side of the handshake presents ONLYCERT / ONLYKEY; the
    request handler then talks over the TLS socket.
    """

    def finish_request(self, request, client_address):
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(ONLYCERT, ONLYKEY)

        ssock = context.wrap_socket(request, server_side=True)
        try:
            # ``with`` closes the TLS socket even when the handler raises,
            # so a failed request cannot leak it.
            with ssock:
                self.RequestHandlerClass(ssock, client_address, self)
        except OSError:
            # maybe socket has been closed by peer
            pass
171
172
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """SilentWSGIServer whose connections are wrapped in TLS by
    SSLWSGIServerMixin.finish_request()."""
175
176
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Serve a tiny WSGI app on *address* from a background thread.

    This is a generator meant to be wrapped with contextlib.contextmanager:
    it yields the running server (with ``address`` set to its bound
    address) and shuts it down when resumed or closed.  *server_ssl_cls*
    is instantiated when *use_ssl* is true, *server_cls* otherwise.

    The app echoes the request body back for ``/loop`` and answers every
    other path with ``b'Test message'``.
    """

    def loop(environ):
        # A missing or empty Content-Length means "no body"; int('') would
        # otherwise raise ValueError mid-response.
        size = int(environ.get('CONTENT_LENGTH') or 0)
        while size:
            data = environ['wsgi.input'].read(min(size, 0x10000))
            if not data:
                # The peer closed before sending the whole body: stop
                # instead of spinning forever on empty reads (which would
                # also block httpd.shutdown()).
                break
            yield data
            size -= len(data)

    def app(environ, start_response):
        status = '200 OK'
        headers = [('Content-type', 'text/plain')]
        start_response(status, headers)
        if environ['PATH_INFO'] == '/loop':
            return loop(environ)
        else:
            return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    server_class = server_ssl_cls if use_ssl else server_cls
    httpd = server_class(address, SilentWSGIRequestHandler)
    httpd.set_app(app)
    httpd.address = httpd.server_address
    server_thread = threading.Thread(
        target=lambda: httpd.serve_forever(poll_interval=0.05))
    server_thread.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        server_thread.join()
210
211
if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTPServer listening on a UNIX socket path.

        It reports a fixed server name and port so HTTP code that expects
        a TCP address keeps working.
        """

        def server_bind(self):
            socketserver.UnixStreamServer.server_bind(self)
            # A UNIX socket has no host/port; fill in placeholder values.
            self.server_name, self.server_port = '127.0.0.1', 80


    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGI server listening on a UNIX socket path."""

        # Timeout applied to every accepted connection socket.
        request_timeout = support.LOOPBACK_TIMEOUT

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            self.setup_environ()

        def get_request(self):
            conn, _peer_path = super().get_request()
            conn.settimeout(self.request_timeout)
            # The stdlib request handling expects get_request() to return
            # a socket and a (host, port) tuple, but for a UNIX socket the
            # second item is a path.  Return a fake address that is good
            # enough for the tests.
            fake_address = ('127.0.0.1', '')
            return conn, fake_address


    class SilentUnixWSGIServer(UnixWSGIServer):
        """UnixWSGIServer that suppresses the default error report."""

        def handle_error(self, request, client_address):
            return None


    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """SilentUnixWSGIServer whose connections are wrapped in TLS by
        SSLWSGIServerMixin."""


    def gen_unix_socket_path():
        """Return a fresh temporary file name to use as a socket path.

        The placeholder file is deleted when the NamedTemporaryFile is
        closed, so the path does not exist when this returns.
        """
        # NOTE(review): another process could claim the name between the
        # deletion and the caller's bind(); acceptable for test code.
        with tempfile.NamedTemporaryFile() as placeholder:
            name = placeholder.name
        return name


    @contextlib.contextmanager
    def unix_socket_path():
        """Yield a temporary socket path and try to remove it afterwards."""
        path = gen_unix_socket_path()
        try:
            yield path
        finally:
            try:
                os.unlink(path)
            except OSError:
                # The socket file may never have been created.
                pass


    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Run the test WSGI app on a temporary UNIX socket path."""
        with unix_socket_path() as path:
            server_gen = _run_test_server(address=path, use_ssl=use_ssl,
                                          server_cls=SilentUnixWSGIServer,
                                          server_ssl_cls=UnixSSLWSGIServer)
            yield from server_gen
275
276
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Run the test WSGI app on (*host*, *port*) in a background thread.

    With the default port 0 the OS picks a free port; read it from the
    yielded server's ``address`` attribute.
    """
    server_gen = _run_test_server(address=(host, port), use_ssl=use_ssl,
                                  server_cls=SilentWSGIServer,
                                  server_ssl_cls=SSLWSGIServer)
    yield from server_gen
282
283
def make_test_protocol(base):
    """Return an instance of a throwaway subclass of *base*.

    Every non-dunder name visible through dir(base) is replaced by a
    MockCallback whose return value is None.
    """
    mocked_attrs = {
        attr: MockCallback(return_value=None)
        for attr in dir(base)
        # Magic (dunder) names are left alone.
        if not (attr.startswith('__') and attr.endswith('__'))
    }
    bases = (base, *base.__bases__)
    return type('TestProtocol', bases, mocked_attrs)()
292
293
class TestSelector(selectors.BaseSelector):
    """In-memory selector that records registrations and never reports
    ready events."""

    def __init__(self):
        # fileobj -> selectors.SelectorKey
        self.keys = {}

    def register(self, fileobj, events, data=None):
        # The fd field is always 0; select() never polls real descriptors.
        key = selectors.SelectorKey(fileobj, 0, events, data)
        self.keys[fileobj] = key
        return key

    def unregister(self, fileobj):
        """Remove and return the key for *fileobj* (KeyError if absent)."""
        return self.keys.pop(fileobj)

    def select(self, timeout=None):
        # ``timeout`` defaults to None to match BaseSelector.select().
        return []

    def get_map(self):
        return self.keys
312
313
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests with a manually controlled clock.

    time() returns an internal counter that only moves via advance_time().
    After each _run_once() iteration, the absolute time of every timer
    recorded by call_at() since the previous iteration is sent, in order,
    into the generator created from the function passed to __init__.  The
    value the generator yields back is how far to advance the clock.

    The generator function should look like this:

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value a yield expression evaluates to is the absolute time of the
    next scheduled handler.  The value given to yield is the time advance
    used to move the loop's clock forward.  When a generator function is
    supplied, close() checks that the generator has finished.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            # No time generator supplied: use one that pauses once, and
            # skip the "generator finished" check in close().
            def gen():
                yield
            self._check_on_close = False
        else:
            self._check_on_close = True

        # ``gen`` is a generator *function*: call it, then run the
        # generator up to its first yield so it is ready for send().
        self._gen = gen()
        next(self._gen)
        self._time = 0
        self._clock_resolution = 1e-9
        # Absolute times passed to call_at() since the last _run_once().
        self._timers = []
        self._selector = TestSelector()

        # fd -> events.Handle recorded by _add_reader() / _add_writer().
        self.readers = {}
        self.writers = {}
        self.reset_counters()

        # fd -> transport; the public add/remove_reader/writer methods
        # refuse fds present here (see _ensure_fd_no_transport).
        self._transports = weakref.WeakValueDictionary()

    def time(self):
        """Return the loop's simulated current time."""
        return self._time

    def advance_time(self, advance):
        """Move test time forward.

        A falsy *advance* (0 or None) leaves the clock unchanged.
        """
        if advance:
            self._time += advance

    def close(self):
        """Close the loop and, if a generator function was supplied,
        check that its generator is exhausted.

        Raises AssertionError if the generator still yields after being
        sent 0.
        """
        super().close()
        if self._check_on_close:
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    def _add_reader(self, fd, callback, *args):
        # Only record the callback; nothing is registered with a selector.
        self.readers[fd] = events.Handle(callback, args, self, None)

    def _remove_reader(self, fd):
        # Counted even when fd is not registered (presumably so tests can
        # assert on how many removals were attempted).
        self.remove_reader_count[fd] += 1
        if fd in self.readers:
            del self.readers[fd]
            return True
        else:
            return False

    def assert_reader(self, fd, callback, *args):
        """Raise AssertionError unless *fd* has a recorded reader with
        exactly this callback and these args."""
        if fd not in self.readers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.readers[fd]
        if handle._callback != callback:
            raise AssertionError(
                f'unexpected callback: {handle._callback} != {callback}')
        if handle._args != args:
            raise AssertionError(
                f'unexpected callback args: {handle._args} != {args}')

    def assert_no_reader(self, fd):
        """Raise AssertionError if *fd* has a recorded reader."""
        if fd in self.readers:
            raise AssertionError(f'fd {fd} is registered')

    def _add_writer(self, fd, callback, *args):
        # Only record the callback; nothing is registered with a selector.
        self.writers[fd] = events.Handle(callback, args, self, None)

    def _remove_writer(self, fd):
        # Counted even when fd is not registered, like _remove_reader().
        self.remove_writer_count[fd] += 1
        if fd in self.writers:
            del self.writers[fd]
            return True
        else:
            return False

    def assert_writer(self, fd, callback, *args):
        """Raise AssertionError unless *fd* has a recorded writer with
        exactly this callback and these args."""
        if fd not in self.writers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.writers[fd]
        if handle._callback != callback:
            raise AssertionError(f'{handle._callback!r} != {callback!r}')
        if handle._args != args:
            raise AssertionError(f'{handle._args!r} != {args!r}')

    def _ensure_fd_no_transport(self, fd):
        """Raise RuntimeError if *fd* is owned by a registered transport.

        *fd* may be an int or an object with a fileno() method; anything
        else raises ValueError.
        """
        if not isinstance(fd, int):
            try:
                fd = int(fd.fileno())
            except (AttributeError, TypeError, ValueError):
                # This code matches selectors._fileobj_to_fd function.
                raise ValueError("Invalid file object: "
                                 "{!r}".format(fd)) from None
        try:
            transport = self._transports[fd]
        except KeyError:
            pass
        else:
            raise RuntimeError(
                'File descriptor {!r} is used by transport {!r}'.format(
                    fd, transport))

    def add_reader(self, fd, callback, *args):
        """Add a reader callback for *fd* unless a transport owns it."""
        self._ensure_fd_no_transport(fd)
        return self._add_reader(fd, callback, *args)

    def remove_reader(self, fd):
        """Remove a reader callback for *fd* unless a transport owns it."""
        self._ensure_fd_no_transport(fd)
        return self._remove_reader(fd)

    def add_writer(self, fd, callback, *args):
        """Add a writer callback for *fd* unless a transport owns it."""
        self._ensure_fd_no_transport(fd)
        return self._add_writer(fd, callback, *args)

    def remove_writer(self, fd):
        """Remove a writer callback for *fd* unless a transport owns it."""
        self._ensure_fd_no_transport(fd)
        return self._remove_writer(fd)

    def reset_counters(self):
        """Reset the per-fd counts of _remove_reader()/_remove_writer()
        calls."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        """Run one base-loop iteration, then feed each recorded timer time
        to the generator and advance the clock by what it yields."""
        super()._run_once()
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args, context=None):
        """Record *when* for the time generator, then schedule normally."""
        self._timers.append(when)
        return super().call_at(when, callback, *args, context=context)

    def _process_events(self, event_list):
        """Intentionally a no-op: TestSelector never reports events."""
        return

    def _write_to_self(self):
        """Intentionally a no-op."""
        pass
477
478
def MockCallback(**kwargs):
    """Return a Mock that supports being called and nothing else.

    The spec is limited to ``__call__``, so reading any other attribute
    raises AttributeError.  Keyword arguments (e.g. ``return_value``) are
    forwarded to mock.Mock.
    """
    callable_only = ['__call__']
    return mock.Mock(spec=callable_only, **kwargs)
481
482
class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    where a regex comparison between strings is needed.

    For instance:
       mock_call.assert_called_with(MockPattern('spam.*ham'))

    The pattern may match anywhere in the other string (re.search) and is
    compiled with re.S, so '.' also matches newlines.  Non-str objects
    compare unequal instead of raising TypeError.
    """
    def __eq__(self, other):
        if not isinstance(other, str):
            return NotImplemented
        return bool(re.search(str(self), other, re.S))

    def __ne__(self, other):
        # Without this, str.__ne__ is inherited and compares the raw
        # pattern text, so ``!=`` could disagree with ``==``.
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return NotImplemented
        return not equal
494
495
class MockInstanceOf:
    """Compare equal to any instance of the given type.

    *type* may be anything isinstance() accepts (a class or a tuple of
    classes).  Handy as a placeholder argument in mock.assert_called_with().
    """

    def __init__(self, type):
        self._type = type

    def __eq__(self, other):
        return isinstance(other, self._type)

    def __repr__(self):
        # Failed mock assertions then show the expected type instead of a
        # bare object address.
        return f'{self.__class__.__name__}({self._type!r})'
502
503
def get_function_source(func):
    """Return the source location of *func*.

    The location comes from asyncio.format_helpers._get_function_source(),
    a (filename, lineno) pair in CPython.

    Raises ValueError when that helper cannot locate the source.
    """
    location = format_helpers._get_function_source(func)
    if location is not None:
        return location
    raise ValueError("unable to get the source of %r" % (func,))
509
510
class TestCase(unittest.TestCase):
    """Base class for asyncio tests.

    setUp() records threading_helper.threading_setup() state, which
    tearDown() hands to threading_cleanup() after running cleanups.
    """

    @staticmethod
    def close_loop(loop):
        """Shut down *loop*'s default executor and close the loop.

        Afterwards, join the threads of the current policy's child
        watcher when it is a ThreadedChildWatcher.
        """
        executor = loop._default_executor
        if executor is not None:
            if loop.is_closed():
                executor.shutdown(wait=True)
            else:
                loop.run_until_complete(loop.shutdown_default_executor())
        loop.close()

        policy = support.maybe_get_event_loop_policy()
        if policy is None:
            return
        try:
            watcher = policy.get_child_watcher()
        except NotImplementedError:
            # watcher is not implemented by EventLoopPolicy, e.g. Windows
            return
        if not isinstance(watcher, asyncio.ThreadedChildWatcher):
            return
        # Snapshot the threads first (presumably because the mapping can
        # change while they finish).
        for thread in list(watcher._threads.values()):
            thread.join()

    def set_event_loop(self, loop, *, cleanup=True):
        """Clear the global event loop.

        By default, also schedule *loop* to be closed with close_loop()
        during cleanup.
        """
        if loop is None:
            raise AssertionError('loop is None')
        # ensure that the event loop is passed explicitly in asyncio
        events.set_event_loop(None)
        if not cleanup:
            return
        self.addCleanup(self.close_loop, loop)

    def new_test_loop(self, gen=None):
        """Create a TestLoop driven by *gen* and register it for cleanup."""
        test_loop = TestLoop(gen)
        self.set_event_loop(test_loop)
        return test_loop

    def setUp(self):
        self._thread_cleanup = threading_helper.threading_setup()

    def tearDown(self):
        events.set_event_loop(None)

        # Detect CPython bug #23353: ensure that yield/yield-from is not used
        # in an except block of a generator
        self.assertEqual(sys.exc_info(), (None, None, None))

        self.doCleanups()
        threading_helper.threading_cleanup(*self._thread_cleanup)
        support.reap_children()
559
560
@contextlib.contextmanager
def disable_logger():
    """Silence the asyncio logger for the duration of a ``with`` block.

    The level is raised above CRITICAL so every record is dropped (for
    example, warnings emitted in debug mode).  The previous level is
    restored on exit, even if the block raises.
    """
    saved_level = logger.level
    silent_level = logging.CRITICAL + 1
    try:
        logger.setLevel(silent_level)
        yield
    finally:
        logger.setLevel(saved_level)
573
574
def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
                            family=socket.AF_INET):
    """Create a mock of a non-blocking socket.

    The mock is spec'd on socket.socket and has the given proto, type and
    family.  Its gettimeout() returns 0.0, which is what a socket in
    non-blocking mode reports.
    """
    fake = mock.MagicMock(socket.socket)
    for attr, value in (('proto', proto), ('type', type), ('family', family)):
        setattr(fake, attr, value)
    fake.gettimeout.return_value = 0.0
    return fake
584