• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Utilities shared by tests."""
2
3import asyncio
4import collections
5import contextlib
6import io
7import logging
8import os
9import re
10import selectors
11import socket
12import socketserver
13import sys
14import threading
15import unittest
16import weakref
17import warnings
18from unittest import mock
19
20from http.server import HTTPServer
21from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
22
23try:
24    import ssl
25except ImportError:  # pragma: no cover
26    ssl = None
27
28from asyncio import base_events
29from asyncio import events
30from asyncio import format_helpers
31from asyncio import futures
32from asyncio import tasks
33from asyncio.log import logger
34from test import support
35from test.support import socket_helper
36from test.support import threading_helper
37
38
# Largest clock resolution known to affect the tests (gh-75191, gh-110088):
# on Windows, GetTickCount64() only has a resolution of 15.6 ms.  A value of
# 50 ms is used so that rounding issues are tolerated as well.
CLOCK_RES = 0.05  # seconds
43
44
def data_file(*filename):
    """Return the full path of a test data file given as path components.

    ``support.TEST_HOME_DIR`` is searched first, then the parent of this
    module's directory.  The first candidate that is an existing regular
    file wins.

    Raises FileNotFoundError, carrying the relative path built from
    *filename*, when neither location contains the file.
    """
    search_dirs = (
        support.TEST_HOME_DIR,
        os.path.join(os.path.dirname(__file__), '..'),
    )
    for base in search_dirs:
        fullname = os.path.join(base, *filename)
        if os.path.isfile(fullname):
            return fullname
    # os.path.join() needs the components unpacked; passing the tuple itself
    # raises TypeError instead of the intended FileNotFoundError.
    raise FileNotFoundError(os.path.join(*filename))
53
54
# Certificate/key files used by the TLS tests, resolved through data_file().
ONLYCERT = data_file('certdata', 'ssl_cert.pem')
ONLYKEY = data_file('certdata', 'ssl_key.pem')
SIGNED_CERTFILE = data_file('certdata', 'keycert3.pem')
SIGNING_CA = data_file('certdata', 'pycacert.pem')

# Expected decoded peer certificate, in the dict layout produced by
# SSLSocket.getpeercert().  NOTE(review): presumably this describes
# SIGNED_CERTFILE (subject "localhost", issued by "our-ca-server") -- confirm.
PEERCERT = dict(
    OCSP=('http://testca.pythontest.net/testca/ocsp/',),
    caIssuers=('http://testca.pythontest.net/testca/pycacert.cer',),
    crlDistributionPoints=(
        'http://testca.pythontest.net/testca/revocation.crl',),
    issuer=(
        (('countryName', 'XY'),),
        (('organizationName', 'Python Software Foundation CA'),),
        (('commonName', 'our-ca-server'),),
    ),
    notAfter='Oct 28 14:23:16 2037 GMT',
    notBefore='Aug 29 14:23:16 2018 GMT',
    serialNumber='CB2D80995A69525C',
    subject=(
        (('countryName', 'XY'),),
        (('localityName', 'Castle Anthrax'),),
        (('organizationName', 'Python Software Foundation'),),
        (('commonName', 'localhost'),),
    ),
    subjectAltName=(('DNS', 'localhost'),),
    version=3,
)
76
77
def simple_server_sslcontext():
    """Return a TLS server context loaded with the test cert and key.

    Peer certificates are not requested or checked.
    """
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    ctx.load_cert_chain(ONLYCERT, ONLYKEY)
    # Test clients use throwaway credentials, so skip all peer checks.
    ctx.check_hostname = False
    ctx.verify_mode = ssl.CERT_NONE
    return ctx
84
85
def simple_client_sslcontext(*, disable_verify=True):
    """Return a TLS client context that does not check host names.

    Certificate verification is switched off as well, unless
    *disable_verify* is false, in which case the context keeps the default
    verify mode of PROTOCOL_TLS_CLIENT.
    """
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    # check_hostname has to be cleared before verify_mode may be CERT_NONE.
    ctx.check_hostname = False
    if not disable_verify:
        return ctx
    ctx.verify_mode = ssl.CERT_NONE
    return ctx
92
93
def dummy_ssl_context():
    """Return a non-verifying client SSL context, or None without ssl."""
    return None if ssl is None else simple_client_sslcontext(disable_verify=True)
99
100
def run_briefly(loop):
    """Run *loop* just long enough to complete one no-op task.

    Callbacks that were already scheduled on *loop* get to run before the
    no-op task finishes.
    """
    async def once():
        pass

    coro = once()
    task = loop.create_task(coro)
    # run_until_complete() can return while the task is still pending (the
    # loop was stopped, or a callback raised a BaseException); suppress the
    # "destroyed but pending" warning for it.
    task._log_destroy_pending = False
    try:
        loop.run_until_complete(task)
    finally:
        coro.close()
113
114
def run_until(loop, pred, timeout=support.SHORT_TIMEOUT):
    """Run *loop* until ``pred()`` returns a true value.

    Between checks the loop runs for a delay that starts at 1 ms and doubles
    after every unsuccessful check, capped at 1 second.

    Raises asyncio.TimeoutError if *pred* is still false once
    ``support.busy_retry`` gives up after *timeout* seconds.
    """
    delay = 0.001
    for _ in support.busy_retry(timeout, error=False):
        if pred():
            break
        loop.run_until_complete(tasks.sleep(delay))
        # Exponential backoff with a 1 s ceiling (min, not max: max would
        # make every poll after the first wait a full second).
        delay = min(delay * 2, 1.0)
    else:
        raise asyncio.TimeoutError()
124
125
def run_once(loop):
    """Run *loop* through a single iteration, then return.

    Scheduling ``loop.stop`` with ``call_soon()`` and then calling
    ``run_forever()`` is the recommended pattern for test code.  It polls
    the selector once and runs every callback that is ready, including
    those scheduled in response to I/O events.  Callbacks scheduled
    during this iteration run on the next one.
    """
    loop.call_soon(loop.stop)
    loop.run_forever()
135
136
class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGIRequestHandler that discards its error stream and access log."""

    def get_stderr(self):
        # Give each request a fresh in-memory sink instead of sys.stderr.
        sink = io.StringIO()
        return sink

    def log_message(self, format, *args):
        # Drop the per-request log line entirely.
        return None
144
145
class SilentWSGIServer(WSGIServer):
    """WSGIServer with a per-connection timeout and no error reporting."""

    request_timeout = support.LOOPBACK_TIMEOUT

    def get_request(self):
        # Accept normally, but keep a stuck client from blocking forever.
        conn, addr = super().get_request()
        conn.settimeout(self.request_timeout)
        return conn, addr

    def handle_error(self, request, client_address):
        # Deliberately silent: the base class would print a traceback.
        return None
157
158
class SSLWSGIServerMixin:
    """Mixin for socketserver-based WSGI servers that serve over TLS.

    Each accepted connection is wrapped server-side using the test
    certificate and key (ONLYCERT / ONLYKEY) before it is handed to the
    request handler class.
    """

    def finish_request(self, request, client_address):
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(ONLYCERT, ONLYKEY)

        ssock = context.wrap_socket(request, server_side=True)
        try:
            self.RequestHandlerClass(ssock, client_address, self)
        except OSError:
            # maybe socket has been closed by peer
            pass
        finally:
            # Close on every exit path so a failing handler does not leak
            # the TLS socket.
            ssock.close()
176
177
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """SilentWSGIServer whose connections are wrapped in TLS by the mixin."""
180
181
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Run a small WSGI test server in a background thread and yield it.

    This is a generator meant to be delegated to with ``yield from`` from a
    ``contextlib.contextmanager`` function.  *server_ssl_cls* is used
    instead of *server_cls* when *use_ssl* is true.  The yielded server has
    an ``address`` attribute equal to its ``server_address``.

    The WSGI app answers ``/loop`` by echoing the request body back, and
    every other path with ``b'Test message'``.  When the generator is
    closed (or an exception is thrown into it), the server is shut down and
    closed, and its thread is joined.
    """

    def loop(environ):
        # Echo the request body back in chunks of at most 64 KiB.
        # wsgiref's base environment sets CONTENT_LENGTH to '' when the
        # request has no body; treat that (or absence) as 0.
        size = int(environ.get('CONTENT_LENGTH') or 0)
        body = environ['wsgi.input']
        while size > 0:
            data = body.read(min(size, 0x10000))
            if not data:
                # The client sent fewer bytes than announced (or went away);
                # stop instead of spinning forever on empty reads.
                break
            yield data
            size -= len(data)

    def app(environ, start_response):
        status = '200 OK'
        headers = [('Content-type', 'text/plain')]
        start_response(status, headers)
        if environ['PATH_INFO'] == '/loop':
            return loop(environ)
        else:
            return [b'Test message']

    server_class = server_ssl_cls if use_ssl else server_cls
    httpd = server_class(address, SilentWSGIRequestHandler)
    try:
        httpd.set_app(app)
        httpd.address = httpd.server_address
        # Run the test WSGI server in a separate thread in order not to
        # interfere with event handling in the main thread
        server_thread = threading.Thread(
            target=lambda: httpd.serve_forever(poll_interval=0.05))
        server_thread.start()
    except BaseException:
        # serve_forever() is not running, so shutdown() would wait forever;
        # only release the listening socket before propagating.
        httpd.server_close()
        raise
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        server_thread.join()
215
216
if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTPServer that listens on a UNIX-domain socket path."""

        def server_bind(self):
            # Skip HTTPServer.server_bind(), which expects a (host, port)
            # address; a UNIX socket is bound to a path.  Fill in
            # placeholder name/port values instead.
            socketserver.UnixStreamServer.server_bind(self)
            self.server_name, self.server_port = '127.0.0.1', 80


    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGI server on a UNIX socket with a per-connection timeout."""

        request_timeout = support.LOOPBACK_TIMEOUT

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            self.setup_environ()

        def get_request(self):
            conn, _peer_path = super().get_request()
            conn.settimeout(self.request_timeout)
            # The stdlib request handling expects get_request() to return a
            # TCP-style (host, port) client address, but a UNIX socket peer
            # is a path.  Return placeholder data that is good enough for
            # the tests.
            fake_client_address = ('127.0.0.1', '')
            return conn, fake_client_address


    class SilentUnixWSGIServer(UnixWSGIServer):
        """UnixWSGIServer that swallows request-handling errors."""

        def handle_error(self, request, client_address):
            return None


    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """TLS-wrapped variant of SilentUnixWSGIServer."""


    def gen_unix_socket_path():
        """Return a fresh UNIX-domain socket path from socket_helper."""
        return socket_helper.create_unix_domain_name()


    @contextlib.contextmanager
    def unix_socket_path():
        """Yield a fresh socket path and try to unlink it afterwards."""
        path = gen_unix_socket_path()
        try:
            yield path
        finally:
            # Best effort: the path may not exist (or may already be gone).
            with contextlib.suppress(OSError):
                os.unlink(path)


    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Run the WSGI test server on a temporary UNIX socket; yield it."""
        with unix_socket_path() as path:
            yield from _run_test_server(
                address=path,
                use_ssl=use_ssl,
                server_cls=SilentUnixWSGIServer,
                server_ssl_cls=UnixSSLWSGIServer,
            )
279
280
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Run the WSGI test server on (*host*, *port*) and yield it."""
    server_gen = _run_test_server(
        address=(host, port),
        use_ssl=use_ssl,
        server_cls=SilentWSGIServer,
        server_ssl_cls=SSLWSGIServer,
    )
    yield from server_gen
286
287
def echo_datagrams(sock):
    """Send every datagram received on *sock* back to its sender.

    A datagram whose payload is exactly ``b'STOP'`` is not echoed; it
    closes *sock* and ends the loop.
    """
    while True:
        payload, peer = sock.recvfrom(4096)
        if payload == b'STOP':
            break
        sock.sendto(payload, peer)
    sock.close()
296
297
@contextlib.contextmanager
def run_udp_echo_server(*, host='127.0.0.1', port=0):
    """Run a UDP echo server (see echo_datagrams) in a thread; yield its address.

    On exit, a ``b'STOP'`` datagram is sent from a separate socket and the
    echo thread is joined; the echo thread closes the server socket itself.
    """
    family, sock_type, proto, _, _ = socket.getaddrinfo(
        host, port, type=socket.SOCK_DGRAM)[0]
    sock = socket.socket(family, sock_type, proto)
    try:
        sock.bind((host, port))
        sockname = sock.getsockname()
        thread = threading.Thread(target=lambda: echo_datagrams(sock))
        thread.start()
    except BaseException:
        # The echo thread never took ownership of the socket; release it.
        sock.close()
        raise
    try:
        yield sockname
    finally:
        # gh-122187: use a separate socket to send the stop message to avoid
        # TSan reported race on the same socket.
        with socket.socket(family, sock_type, proto) as stopper:
            stopper.sendto(b'STOP', sockname)
        thread.join()
316
317
def make_test_protocol(base):
    """Instantiate a subclass of *base* whose non-dunder names are mocks.

    Every name in ``dir(base)`` that is not ``__dunder__`` is replaced by a
    fresh ``MockCallback(return_value=None)``.
    """
    def _is_dunder(name):
        return name.startswith('__') and name.endswith('__')

    namespace = {
        name: MockCallback(return_value=None)
        for name in dir(base)
        if not _is_dunder(name)
    }
    protocol_cls = type('TestProtocol', (base,) + base.__bases__, namespace)
    return protocol_cls()
326
327
class TestSelector(selectors.BaseSelector):
    """In-memory selector: records registrations, never reports readiness."""

    def __init__(self):
        # Registered file object -> SelectorKey.
        self.keys = {}

    def register(self, fileobj, events, data=None):
        # The key's fd is always 0: no real descriptor is looked up.
        self.keys[fileobj] = key = selectors.SelectorKey(
            fileobj, 0, events, data)
        return key

    def unregister(self, fileobj):
        # Raises KeyError for an unknown file object.
        key = self.keys[fileobj]
        del self.keys[fileobj]
        return key

    def select(self, timeout):
        """Report no ready file objects, whatever *timeout* is."""
        return []

    def get_map(self):
        """Return the live registration mapping (not a copy)."""
        return self.keys
346
347
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests that runs on a virtual clock.

    ``time()`` does not read a real clock: it returns an internal counter
    that starts at 0 and only moves when ``advance_time()`` is called.
    Every deadline passed to ``call_at()`` is recorded.  After each loop
    iteration (``_run_once``), the recorded deadlines are sent one by one
    to the generator passed to ``__init__``.  Each value the generator
    yields back is used to advance the clock.

    The generator should look like this::

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value received by ``yield`` is the absolute time of a scheduled
    handle.  The value yielded is how far to move the loop's time forward.
    ``__init__`` consumes the generator's first yield to prime it.  When a
    generator is supplied, ``close()`` asserts that it has finished.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            # No schedule to verify: a generator that primes and then ends.
            def gen():
                yield
            self._check_on_close = False
        else:
            self._check_on_close = True

        self._gen = gen()
        next(self._gen)  # advance the generator to its first yield
        self._time = 0
        self._clock_resolution = 1e-9
        self._timers = []
        self._selector = TestSelector()

        self.readers = {}
        self.writers = {}
        self.reset_counters()

        self._transports = weakref.WeakValueDictionary()

    def time(self):
        """Return the loop's virtual time."""
        return self._time

    def advance_time(self, advance):
        """Move test time forward."""
        if advance:
            self._time += advance

    def close(self):
        """Close the loop; if a generator was supplied, check it finished."""
        super().close()
        if self._check_on_close:
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    @staticmethod
    def _check_registered_handle(handles, fd, callback, args):
        """Raise AssertionError unless ``handles[fd]`` wraps callback(*args)."""
        if fd not in handles:
            raise AssertionError(f'fd {fd} is not registered')
        handle = handles[fd]
        if handle._callback != callback:
            raise AssertionError(
                f'unexpected callback: {handle._callback!r} != {callback!r}')
        if handle._args != args:
            raise AssertionError(
                f'unexpected callback args: {handle._args!r} != {args!r}')

    def _add_reader(self, fd, callback, *args):
        self.readers[fd] = events.Handle(callback, args, self, None)

    def _remove_reader(self, fd):
        # Count every removal attempt, whether or not fd was registered.
        self.remove_reader_count[fd] += 1
        if fd in self.readers:
            del self.readers[fd]
            return True
        else:
            return False

    def assert_reader(self, fd, callback, *args):
        """Assert that *fd* has a reader that calls callback(*args)."""
        self._check_registered_handle(self.readers, fd, callback, args)

    def assert_no_reader(self, fd):
        """Assert that *fd* has no reader registered."""
        if fd in self.readers:
            raise AssertionError(f'fd {fd} is registered')

    def _add_writer(self, fd, callback, *args):
        self.writers[fd] = events.Handle(callback, args, self, None)

    def _remove_writer(self, fd):
        # Count every removal attempt, whether or not fd was registered.
        self.remove_writer_count[fd] += 1
        if fd in self.writers:
            del self.writers[fd]
            return True
        else:
            return False

    def assert_writer(self, fd, callback, *args):
        """Assert that *fd* has a writer that calls callback(*args)."""
        self._check_registered_handle(self.writers, fd, callback, args)

    def _ensure_fd_no_transport(self, fd):
        """Raise RuntimeError if *fd* (int or object with fileno()) is
        registered in ``self._transports``; ValueError if it is not a valid
        file object."""
        if not isinstance(fd, int):
            try:
                fd = int(fd.fileno())
            except (AttributeError, TypeError, ValueError):
                # This code matches selectors._fileobj_to_fd function.
                raise ValueError("Invalid file object: "
                                 "{!r}".format(fd)) from None
        try:
            transport = self._transports[fd]
        except KeyError:
            pass
        else:
            raise RuntimeError(
                'File descriptor {!r} is used by transport {!r}'.format(
                    fd, transport))

    def add_reader(self, fd, callback, *args):
        """Add a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_reader(fd, callback, *args)

    def remove_reader(self, fd):
        """Remove a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_reader(fd)

    def add_writer(self, fd, callback, *args):
        """Add a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_writer(fd, callback, *args)

    def remove_writer(self, fd):
        """Remove a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_writer(fd)

    def reset_counters(self):
        """Reset the per-fd counts of reader/writer removal attempts."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        """Run one iteration, then feed recorded deadlines to the generator."""
        super()._run_once()
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args, context=None):
        # Record the deadline so _run_once() can report it to the generator.
        self._timers.append(when)
        return super().call_at(when, callback, *args, context=context)

    def _process_events(self, event_list):
        # TestSelector.select() never reports events; nothing to process.
        return

    def _write_to_self(self):
        # Overridden as a no-op: there is no real selector to wake up.
        pass
511
512
def MockCallback(**kwargs):
    """Return a ``mock.Mock`` whose only permitted attribute is ``__call__``.

    Keyword arguments (``return_value``, ``side_effect``, ...) are forwarded
    to ``mock.Mock`` unchanged.
    """
    callable_only = ['__call__']
    return mock.Mock(spec=callable_only, **kwargs)
515
516
class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    where a regex comparison between strings is needed.

    For instance:
       mock_call.assert_called_with(MockPattern('spam.*ham'))

    It compares equal to any string in which ``re.search`` finds the
    pattern, using ``re.S`` so that ``.`` also matches newlines.  ``!=`` is
    the exact negation of ``==``.  Comparing with a non-string returns
    ``NotImplemented``, so ``==`` ends up False instead of raising.
    Instances are unhashable because ``__eq__`` is overridden.
    """
    def __eq__(self, other):
        if not isinstance(other, str):
            return NotImplemented
        return bool(re.search(str(self), other, re.S))

    def __ne__(self, other):
        # str defines its own __ne__ (plain string inequality); without this
        # override ``!=`` would ignore the pattern and disagree with ``==``.
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return NotImplemented
        return not equal
528
529
class MockInstanceOf:
    """Matcher that compares equal to any instance of the given type(s).

    Handy with ``mock.assert_called_with`` when only an argument's type
    matters.  *type* may be anything ``isinstance()`` accepts as its second
    argument.
    """
    def __init__(self, type):
        self._type = type

    def __eq__(self, other):
        return isinstance(other, self._type)

    def __repr__(self):
        # Shown in mock assertion failures instead of "<... object at 0x...>".
        return f'{self.__class__.__name__}({self._type!r})'
536
537
def get_function_source(func):
    """Return asyncio's recorded source location for *func*.

    Raises ValueError when ``format_helpers._get_function_source`` cannot
    locate it (it returns None).
    """
    location = format_helpers._get_function_source(func)
    if location is not None:
        return location
    raise ValueError("unable to get the source of %r" % (func,))
543
544
class TestCase(unittest.TestCase):
    """unittest.TestCase base class for asyncio tests.

    ``setUp`` snapshots the running threads.  ``tearDown`` clears the
    current event loop, runs the registered cleanups, then compares threads
    against the snapshot and reaps child processes.
    """
    @staticmethod
    def close_loop(loop):
        """Shut down *loop*'s default executor, close *loop*, then wait for
        any threads of a ThreadedChildWatcher from the current policy.

        Each watcher thread is joined for at most ``support.SHORT_TIMEOUT``
        seconds.  Raises RuntimeError if one is still alive after that.
        """
        if loop._default_executor is not None:
            if not loop.is_closed():
                loop.run_until_complete(loop.shutdown_default_executor())
            else:
                # A closed loop cannot run the async shutdown; use the
                # executor's blocking shutdown instead.
                loop._default_executor.shutdown(wait=True)
        loop.close()

        policy = support.maybe_get_event_loop_policy()
        if policy is not None:
            try:
                # Silence any DeprecationWarning from get_child_watcher().
                with warnings.catch_warnings():
                    warnings.simplefilter('ignore', DeprecationWarning)
                    watcher = policy.get_child_watcher()
            except NotImplementedError:
                # watcher is not implemented by EventLoopPolicy, e.g. Windows
                pass
            else:
                if isinstance(watcher, asyncio.ThreadedChildWatcher):
                    # Wait for subprocess to finish, but not forever
                    for thread in list(watcher._threads.values()):
                        thread.join(timeout=support.SHORT_TIMEOUT)
                        if thread.is_alive():
                            raise RuntimeError(f"thread {thread} still alive: "
                                               "subprocess still running")


    def set_event_loop(self, loop, *, cleanup=True):
        """Clear the current event loop and, by default, schedule *loop* to
        be closed (via close_loop) at cleanup.

        Despite the name, *loop* is not installed as the current loop: tests
        must pass it explicitly.  Raises AssertionError if *loop* is None.
        """
        if loop is None:
            raise AssertionError('loop is None')
        # ensure that the event loop is passed explicitly in asyncio
        events.set_event_loop(None)
        if cleanup:
            self.addCleanup(self.close_loop, loop)

    def new_test_loop(self, gen=None):
        """Create a TestLoop driven by *gen*, register it for cleanup and
        return it."""
        loop = TestLoop(gen)
        self.set_event_loop(loop)
        return loop

    def setUp(self):
        # Snapshot of running threads, checked again in tearDown().
        self._thread_cleanup = threading_helper.threading_setup()

    def tearDown(self):
        events.set_event_loop(None)

        # Detect CPython bug #23353: ensure that yield/yield-from is not used
        # in an except block of a generator
        self.assertIsNone(sys.exception())

        # Run cleanups (e.g. close_loop registered by set_event_loop) now,
        # before the thread snapshot is compared and children are reaped.
        self.doCleanups()
        threading_helper.threading_cleanup(*self._thread_cleanup)
        support.reap_children()
600
601
@contextlib.contextmanager
def disable_logger():
    """Silence the asyncio logger for the duration of the ``with`` block.

    The logger's level is raised above CRITICAL and restored on exit, even
    if the block raises.  Useful, for example, to hide warnings emitted in
    debug mode.
    """
    saved_level = logger.level
    silenced = logging.CRITICAL + 1
    try:
        logger.setLevel(silenced)
        yield
    finally:
        logger.setLevel(saved_level)
614
615
def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
                            family=socket.AF_INET):
    """Return a MagicMock specced on socket.socket that looks non-blocking.

    The mock carries the given *proto*, *type* and *family* attributes.
    ``gettimeout()`` returns 0.0, which is what a real socket reports in
    non-blocking mode.
    """
    fake = mock.MagicMock(socket.socket)
    for attr, value in (('proto', proto), ('type', type), ('family', family)):
        setattr(fake, attr, value)
    fake.gettimeout.return_value = 0.0
    return fake
625
626
async def await_without_task(coro):
    """Drive *coro* from a plain loop callback, outside any Task.

    ``coro.__await__()`` is exhausted inside a ``call_soon`` callback, so
    the coroutine body runs with no current task, and its return value is
    discarded.  After one ``asyncio.sleep(0)``, anything it raised
    (BaseException subclasses included) is re-raised here.
    """
    errors = []

    def drive():
        try:
            for _ in coro.__await__():
                pass
        except BaseException as err:
            errors.append(err)

    asyncio.get_running_loop().call_soon(drive)
    # Yield once so the callback above gets to run.
    await asyncio.sleep(0)
    if errors:
        raise errors[0]
640