"""Utilities shared by tests."""

import asyncio
import collections
import contextlib
import io
import logging
import os
import re
import selectors
import socket
import socketserver
import sys
import threading
import unittest
import weakref
import warnings
from unittest import mock

from http.server import HTTPServer
from wsgiref.simple_server import WSGIRequestHandler, WSGIServer

try:
    import ssl
except ImportError:  # pragma: no cover
    ssl = None

from asyncio import base_events
from asyncio import events
from asyncio import format_helpers
from asyncio import futures
from asyncio import tasks
from asyncio.log import logger
from test import support
from test.support import socket_helper
from test.support import threading_helper


# Use the maximum known clock resolution (gh-75191, gh-110088): Windows
# GetTickCount64() has a resolution of 15.6 ms. Use 50 ms to tolerate rounding
# issues.
CLOCK_RES = 0.050


def data_file(*filename):
    """Return the full path of a test data file.

    The path components in *filename* are looked up first under
    ``support.TEST_HOME_DIR`` and then relative to the parent of this
    module's directory.

    Raises FileNotFoundError if neither location contains the file.
    """
    fullname = os.path.join(support.TEST_HOME_DIR, *filename)
    if os.path.isfile(fullname):
        return fullname
    fullname = os.path.join(os.path.dirname(__file__), '..', *filename)
    if os.path.isfile(fullname):
        return fullname
    # The components must be unpacked: os.path.join() rejects a tuple, so
    # passing `filename` directly raised TypeError instead of the intended
    # FileNotFoundError. The leading '' keeps the call valid with no args.
    raise FileNotFoundError(os.path.join('', *filename))


ONLYCERT = data_file('certdata', 'ssl_cert.pem')
ONLYKEY = data_file('certdata', 'ssl_key.pem')
SIGNED_CERTFILE = data_file('certdata', 'keycert3.pem')
SIGNING_CA = data_file('certdata', 'pycacert.pem')
# Expected peer-certificate dictionary; presumably what getpeercert() returns
# for SIGNED_CERTFILE -- confirm against the tests that use it.
PEERCERT = {
    'OCSP': ('http://testca.pythontest.net/testca/ocsp/',),
    'caIssuers': ('http://testca.pythontest.net/testca/pycacert.cer',),
    'crlDistributionPoints': ('http://testca.pythontest.net/testca/revocation.crl',),
    'issuer': ((('countryName', 'XY'),),
               (('organizationName', 'Python Software Foundation CA'),),
               (('commonName', 'our-ca-server'),)),
    'notAfter': 'Oct 28 14:23:16 2037 GMT',
    'notBefore': 'Aug 29 14:23:16 2018 GMT',
    'serialNumber': 'CB2D80995A69525C',
    'subject': ((('countryName', 'XY'),),
                (('localityName', 'Castle Anthrax'),),
                (('organizationName', 'Python Software Foundation'),),
                (('commonName', 'localhost'),)),
    'subjectAltName': (('DNS', 'localhost'),),
    'version': 3
}


def simple_server_sslcontext():
    """Return a TLS server context loaded with ONLYCERT/ONLYKEY.

    Hostname checking and client-certificate verification are disabled.
    """
    server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    server_context.load_cert_chain(ONLYCERT, ONLYKEY)
    server_context.check_hostname = False
    server_context.verify_mode = ssl.CERT_NONE
    return server_context


def simple_client_sslcontext(*, disable_verify=True):
    """Return a TLS client context with hostname checking disabled.

    When *disable_verify* is true, certificate verification is disabled too.
    """
    client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    client_context.check_hostname = False
    if disable_verify:
        client_context.verify_mode = ssl.CERT_NONE
    return client_context


def dummy_ssl_context():
    """Return a non-verifying client context, or None when ssl is missing."""
    if ssl is None:
        return None
    else:
        return simple_client_sslcontext(disable_verify=True)


def run_briefly(loop):
    """Run *loop* until a trivial task scheduled on it has completed."""
    async def once():
        pass
    gen = once()
    t = loop.create_task(gen)
    # Don't log a warning if the task is not done after run_until_complete().
    # It occurs if the loop is stopped or if a task raises a BaseException.
    t._log_destroy_pending = False
    try:
        loop.run_until_complete(t)
    finally:
        gen.close()


def run_until(loop, pred, timeout=support.SHORT_TIMEOUT):
    """Run *loop* in short bursts until ``pred()`` returns a true value.

    Between checks the loop sleeps with exponential backoff, starting at
    1 ms and capped at 1 second. Raises TimeoutError when
    ``support.busy_retry(timeout)`` stops iterating before *pred* succeeds.
    """
    delay = 0.001
    for _ in support.busy_retry(timeout, error=False):
        if pred():
            break
        loop.run_until_complete(tasks.sleep(delay))
        # Cap the backoff at 1 s. (max() here made every sleep after the
        # first one a full second.)
        delay = min(delay * 2, 1.0)
    else:
        # asyncio.TimeoutError is an alias of the builtin TimeoutError since
        # Python 3.11; this module already requires 3.11+ (sys.exception()).
        raise TimeoutError()


def run_once(loop):
    """Run a single iteration of the event loop (legacy helper).

    Schedules ``loop.stop()`` and then runs the loop, so the selector is
    polled once and all callbacks scheduled in response to I/O events are
    run before the loop stops. This remains the recommended pattern for
    test code.
    """
    loop.call_soon(loop.stop)
    loop.run_forever()


class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGI request handler that discards stderr output and access logs."""

    def get_stderr(self):
        return io.StringIO()

    def log_message(self, format, *args):
        pass


class SilentWSGIServer(WSGIServer):
    """WSGI server that applies a socket timeout and ignores request errors."""

    request_timeout = support.LOOPBACK_TIMEOUT

    def get_request(self):
        request, client_addr = super().get_request()
        request.settimeout(self.request_timeout)
        return request, client_addr

    def handle_error(self, request, client_address):
        pass


class SSLWSGIServerMixin:
    """Mixin that wraps each accepted connection in TLS before handling it."""

    def finish_request(self, request, client_address):
        # Build a server context from the test certificate/key pair
        # (ONLYCERT/ONLYKEY, located by data_file()).
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(ONLYCERT, ONLYKEY)

        ssock = context.wrap_socket(request, server_side=True)
        try:
            try:
                self.RequestHandlerClass(ssock, client_address, self)
            finally:
                # Close even if the handler failed, so the TLS socket is
                # never leaked.
                ssock.close()
        except OSError:
            # maybe socket has been closed by peer
            pass


class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    pass


def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Generator that serves a tiny WSGI app in a thread and yields the server.

    Requests to ``/loop`` get their request body echoed back; every other
    path answers ``b'Test message'``. The server is shut down, closed and
    its thread joined once the generator is closed.
    """

    def echo_body(environ):
        size = int(environ['CONTENT_LENGTH'])
        while size > 0:
            data = environ['wsgi.input'].read(min(size, 0x10000))
            if not data:
                # The peer closed before sending CONTENT_LENGTH bytes; stop
                # instead of yielding empty chunks forever.
                break
            yield data
            size -= len(data)

    def app(environ, start_response):
        status = '200 OK'
        headers = [('Content-type', 'text/plain')]
        start_response(status, headers)
        if environ['PATH_INFO'] == '/loop':
            return echo_body(environ)
        else:
            return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    server_class = server_ssl_cls if use_ssl else server_cls
    httpd = server_class(address, SilentWSGIRequestHandler)
    httpd.set_app(app)
    httpd.address = httpd.server_address
    server_thread = threading.Thread(
        target=lambda: httpd.serve_forever(poll_interval=0.05))
    server_thread.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        server_thread.join()


if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTP server bound to a UNIX socket, with fake host/port values."""

        def server_bind(self):
            socketserver.UnixStreamServer.server_bind(self)
            self.server_name = '127.0.0.1'
            self.server_port = 80


    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGI server listening on a UNIX socket."""

        request_timeout = support.LOOPBACK_TIMEOUT

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            self.setup_environ()

        def get_request(self):
            request, client_addr = super().get_request()
            request.settimeout(self.request_timeout)
            # Code in the stdlib expects that get_request
            # will return a socket and a tuple (host, port).
            # However, this isn't true for UNIX sockets,
            # as the second return value will be a path;
            # hence we return some fake data sufficient
            # to get the tests going
            return request, ('127.0.0.1', '')


    class SilentUnixWSGIServer(UnixWSGIServer):

        def handle_error(self, request, client_address):
            pass


    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        pass


    def gen_unix_socket_path():
        """Return a fresh path suitable for binding a UNIX domain socket."""
        return socket_helper.create_unix_domain_name()


    @contextlib.contextmanager
    def unix_socket_path():
        """Yield a UNIX socket path and unlink it (best effort) on exit."""
        path = gen_unix_socket_path()
        try:
            yield path
        finally:
            try:
                os.unlink(path)
            except OSError:
                pass


    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Run the test WSGI server on a temporary UNIX socket."""
        with unix_socket_path() as path:
            yield from _run_test_server(address=path, use_ssl=use_ssl,
                                        server_cls=SilentUnixWSGIServer,
                                        server_ssl_cls=UnixSSLWSGIServer)


@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Run the test WSGI server on (*host*, *port*); yield the server."""
    yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
                                server_cls=SilentWSGIServer,
                                server_ssl_cls=SSLWSGIServer)


def echo_datagrams(sock):
    """Echo datagrams back to their sender until b'STOP' is received.

    The socket is closed when the STOP message arrives.
    """
    while True:
        data, addr = sock.recvfrom(4096)
        if data == b'STOP':
            sock.close()
            break
        else:
            sock.sendto(data, addr)


@contextlib.contextmanager
def run_udp_echo_server(*, host='127.0.0.1', port=0):
    """Run a UDP echo server in a background thread and yield its address.

    On exit a b'STOP' datagram is sent from a separate socket and the
    server thread is joined.
    """
    addr_info = socket.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
    family, sock_type, proto, _, _ = addr_info[0]
    sock = socket.socket(family, sock_type, proto)
    try:
        sock.bind((host, port))
    except OSError:
        # Don't leak the socket when binding fails (e.g. address in use).
        sock.close()
        raise
    sockname = sock.getsockname()
    thread = threading.Thread(target=lambda: echo_datagrams(sock))
    thread.start()
    try:
        yield sockname
    finally:
        # gh-122187: use a separate socket to send the stop message to avoid
        # TSan reported race on the same socket.
        with socket.socket(family, sock_type, proto) as sock2:
            sock2.sendto(b'STOP', sockname)
        thread.join()


def make_test_protocol(base):
    """Instantiate a subclass of *base* whose public attributes are mocks.

    Every non-dunder name found by dir(base) is replaced by
    MockCallback(return_value=None).
    """
    dct = {}
    for name in dir(base):
        if name.startswith('__') and name.endswith('__'):
            # skip magic names
            continue
        dct[name] = MockCallback(return_value=None)
    return type('TestProtocol', (base,) + base.__bases__, dct)()


class TestSelector(selectors.BaseSelector):
    """Selector stub that records registrations and never reports events."""

    def __init__(self):
        self.keys = {}

    def register(self, fileobj, events, data=None):
        key = selectors.SelectorKey(fileobj, 0, events, data)
        self.keys[fileobj] = key
        return key

    def unregister(self, fileobj):
        return self.keys.pop(fileobj)

    def select(self, timeout):
        return []

    def get_map(self):
        return self.keys


class TestLoop(base_events.BaseEventLoop):
    """Loop for unittests.

    It manages its own time. When something is scheduled to run later, the
    generator passed to __init__ is resumed on the next loop iteration,
    after all ready handlers have run.

    The generator should look like this:

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value received from yield is the absolute time of the next scheduled
    handler. The value passed to yield is the amount by which to move the
    loop's time forward.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            def gen():
                yield
            self._check_on_close = False
        else:
            self._check_on_close = True

        self._gen = gen()
        next(self._gen)
        self._time = 0
        self._clock_resolution = 1e-9
        self._timers = []
        self._selector = TestSelector()

        self.readers = {}
        self.writers = {}
        self.reset_counters()

        self._transports = weakref.WeakValueDictionary()

    def time(self):
        return self._time

    def advance_time(self, advance):
        """Move test time forward."""
        if advance:
            self._time += advance

    def close(self):
        super().close()
        if self._check_on_close:
            # A user-supplied time generator must be exhausted by now.
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    def _add_reader(self, fd, callback, *args):
        self.readers[fd] = events.Handle(callback, args, self, None)

    def _remove_reader(self, fd):
        self.remove_reader_count[fd] += 1
        if fd in self.readers:
            del self.readers[fd]
            return True
        else:
            return False

    def assert_reader(self, fd, callback, *args):
        if fd not in self.readers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.readers[fd]
        if handle._callback != callback:
            raise AssertionError(
                f'unexpected callback: {handle._callback} != {callback}')
        if handle._args != args:
            raise AssertionError(
                f'unexpected callback args: {handle._args} != {args}')

    def assert_no_reader(self, fd):
        if fd in self.readers:
            raise AssertionError(f'fd {fd} is registered')

    def _add_writer(self, fd, callback, *args):
        self.writers[fd] = events.Handle(callback, args, self, None)

    def _remove_writer(self, fd):
        self.remove_writer_count[fd] += 1
        if fd in self.writers:
            del self.writers[fd]
            return True
        else:
            return False

    def assert_writer(self, fd, callback, *args):
        if fd not in self.writers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.writers[fd]
        if handle._callback != callback:
            raise AssertionError(f'{handle._callback!r} != {callback!r}')
        if handle._args != args:
            raise AssertionError(f'{handle._args!r} != {args!r}')

    def _ensure_fd_no_transport(self, fd):
        if not isinstance(fd, int):
            try:
                fd = int(fd.fileno())
            except (AttributeError, TypeError, ValueError):
                # This code matches selectors._fileobj_to_fd function.
                raise ValueError("Invalid file object: "
                                 "{!r}".format(fd)) from None
        try:
            transport = self._transports[fd]
        except KeyError:
            pass
        else:
            raise RuntimeError(
                'File descriptor {!r} is used by transport {!r}'.format(
                    fd, transport))

    def add_reader(self, fd, callback, *args):
        """Add a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_reader(fd, callback, *args)

    def remove_reader(self, fd):
        """Remove a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_reader(fd)

    def add_writer(self, fd, callback, *args):
        """Add a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_writer(fd, callback, *args)

    def remove_writer(self, fd):
        """Remove a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_writer(fd)

    def reset_counters(self):
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Report each timer scheduled during this iteration to the time
        # generator and advance the clock by the amount it answers.
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args, context=None):
        self._timers.append(when)
        return super().call_at(when, callback, *args, context=context)

    def _process_events(self, event_list):
        return

    def _write_to_self(self):
        pass


def MockCallback(**kwargs):
    """Return a Mock restricted to being called (spec=['__call__'])."""
    return mock.Mock(spec=['__call__'], **kwargs)


class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    where a regex comparison between strings is needed.

    For instance:
        mock_call.assert_called_with(MockPattern('spam.*ham'))
    """
    def __eq__(self, other):
        if not isinstance(other, str):
            # re.search() would raise TypeError on non-strings; let Python
            # try the other operand and fall back to identity instead.
            return NotImplemented
        return bool(re.search(str(self), other, re.S))

    def __ne__(self, other):
        # str defines its own __ne__, which would compare the raw strings and
        # disagree with the fuzzy __eq__ above.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    # Implied by defining __eq__; stated explicitly because fuzzy-equal
    # instances cannot honour the str hash contract.
    __hash__ = None


class MockInstanceOf:
    """Compares equal to any instance of the given type."""

    def __init__(self, type):
        self._type = type

    def __eq__(self, other):
        return isinstance(other, self._type)

    def __repr__(self):
        # Makes mock assertion failure messages readable.
        return f'{type(self).__name__}({self._type!r})'


def get_function_source(func):
    """Return asyncio's source location for *func*; ValueError if unknown."""
    source = format_helpers._get_function_source(func)
    if source is None:
        raise ValueError("unable to get the source of %r" % (func,))
    return source


class TestCase(unittest.TestCase):
    @staticmethod
    def close_loop(loop):
        """Shut down *loop*'s default executor, close the loop, and wait
        (bounded by support.SHORT_TIMEOUT) for threaded child watchers."""
        if loop._default_executor is not None:
            if not loop.is_closed():
                loop.run_until_complete(loop.shutdown_default_executor())
            else:
                loop._default_executor.shutdown(wait=True)
        loop.close()

        policy = support.maybe_get_event_loop_policy()
        if policy is not None:
            try:
                with warnings.catch_warnings():
                    warnings.simplefilter('ignore', DeprecationWarning)
                    watcher = policy.get_child_watcher()
            except NotImplementedError:
                # watcher is not implemented by EventLoopPolicy, e.g. Windows
                pass
            else:
                if isinstance(watcher, asyncio.ThreadedChildWatcher):
                    # Wait for subprocess to finish, but not forever
                    for thread in list(watcher._threads.values()):
                        thread.join(timeout=support.SHORT_TIMEOUT)
                        if thread.is_alive():
                            raise RuntimeError(f"thread {thread} still alive: "
                                               "subprocess still running")

    def set_event_loop(self, loop, *, cleanup=True):
        """Register *loop* for cleanup; the global event loop is set to None."""
        if loop is None:
            raise AssertionError('loop is None')
        # ensure that the event loop is passed explicitly in asyncio
        events.set_event_loop(None)
        if cleanup:
            self.addCleanup(self.close_loop, loop)

    def new_test_loop(self, gen=None):
        """Create a TestLoop driven by *gen* and register it for cleanup."""
        loop = TestLoop(gen)
        self.set_event_loop(loop)
        return loop

    def setUp(self):
        self._thread_cleanup = threading_helper.threading_setup()

    def tearDown(self):
        events.set_event_loop(None)

        # Detect CPython bug #23353: ensure that yield/yield-from is not used
        # in an except block of a generator
        self.assertIsNone(sys.exception())

        self.doCleanups()
        threading_helper.threading_cleanup(*self._thread_cleanup)
        support.reap_children()


@contextlib.contextmanager
def disable_logger():
    """Context manager to disable asyncio logger.

    For example, it can be used to ignore warnings in debug mode.
    """
    old_level = logger.level
    try:
        logger.setLevel(logging.CRITICAL+1)
        yield
    finally:
        logger.setLevel(old_level)


def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
                            family=socket.AF_INET):
    """Create a mock of a non-blocking socket."""
    sock = mock.MagicMock(socket.socket)
    sock.proto = proto
    sock.type = type
    sock.family = family
    sock.gettimeout.return_value = 0.0
    return sock


async def await_without_task(coro):
    """Drive *coro* from a plain loop callback rather than from a Task.

    The coroutine's __await__ iterator is exhausted inside a call_soon()
    callback (values it yields are ignored, not awaited). Any exception it
    raises is re-raised here.
    """
    exc = None
    def func():
        try:
            for _ in coro.__await__():
                pass
        except BaseException as err:
            nonlocal exc
            exc = err
    asyncio.get_running_loop().call_soon(func)
    # Yield once so the scheduled callback gets a chance to run.
    await asyncio.sleep(0)
    if exc is not None:
        raise exc