"""Utilities shared by tests."""

import asyncio
import collections
import contextlib
import io
import logging
import os
import re
import selectors
import socket
import socketserver
import sys
import tempfile
import threading
import time
import unittest
import weakref

from unittest import mock

from http.server import HTTPServer
from wsgiref.simple_server import WSGIRequestHandler, WSGIServer

try:
    import ssl
except ImportError:  # pragma: no cover
    ssl = None

from asyncio import base_events
from asyncio import events
from asyncio import format_helpers
from asyncio import futures
from asyncio import tasks
from asyncio.log import logger
from test import support


def data_file(filename):
    """Return the full path of the test data file *filename*.

    ``support.TEST_HOME_DIR`` is searched first (when the support module
    defines it), then the parent of this module's directory.

    Raises:
        FileNotFoundError: the file exists in neither location.
    """
    if hasattr(support, 'TEST_HOME_DIR'):
        fullname = os.path.join(support.TEST_HOME_DIR, filename)
        if os.path.isfile(fullname):
            return fullname
    fullname = os.path.join(os.path.dirname(__file__), '..', filename)
    if os.path.isfile(fullname):
        return fullname
    raise FileNotFoundError(filename)


ONLYCERT = data_file('ssl_cert.pem')
ONLYKEY = data_file('ssl_key.pem')
SIGNED_CERTFILE = data_file('keycert3.pem')
SIGNING_CA = data_file('pycacert.pem')
# Certificate details in the dict format returned by
# ssl.SSLSocket.getpeercert(); presumably the certificate in
# SIGNED_CERTFILE -- confirm against its callers.  Note the space-padded
# day in 'notAfter', which is how the ssl module formats single-digit days.
PEERCERT = {
    'OCSP': ('http://testca.pythontest.net/testca/ocsp/',),
    'caIssuers': ('http://testca.pythontest.net/testca/pycacert.cer',),
    'crlDistributionPoints': ('http://testca.pythontest.net/testca/revocation.crl',),
    'issuer': ((('countryName', 'XY'),),
               (('organizationName', 'Python Software Foundation CA'),),
               (('commonName', 'our-ca-server'),)),
    'notAfter': 'Jul  7 14:23:16 2028 GMT',
    'notBefore': 'Aug 29 14:23:16 2018 GMT',
    'serialNumber': 'CB2D80995A69525C',
    'subject': ((('countryName', 'XY'),),
                (('localityName', 'Castle Anthrax'),),
                (('organizationName', 'Python Software Foundation'),),
                (('commonName', 'localhost'),)),
    'subjectAltName': (('DNS', 'localhost'),),
    'version': 3
}


def simple_server_sslcontext():
    """Return a server-side TLS context using the test certificate and key.

    Client certificates are not requested (verify_mode is CERT_NONE).
    """
    server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    server_context.load_cert_chain(ONLYCERT, ONLYKEY)
    server_context.check_hostname = False
    server_context.verify_mode = ssl.CERT_NONE
    return server_context


def simple_client_sslcontext(*, disable_verify=True):
    """Return a client-side TLS context with hostname checking disabled.

    When *disable_verify* is true (the default) the server certificate is
    not verified either.
    """
    client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    client_context.check_hostname = False
    if disable_verify:
        client_context.verify_mode = ssl.CERT_NONE
    return client_context


def dummy_ssl_context():
    """Return a non-verifying client TLS context, or None without ssl.

    ``ssl.PROTOCOL_TLS`` is deprecated, so the context is built with
    PROTOCOL_TLS_CLIENT (hostname check and verification disabled) instead.
    """
    if ssl is None:
        return None
    return simple_client_sslcontext(disable_verify=True)


def run_briefly(loop):
    """Run *loop* until a no-op task has completed."""
    async def once():
        pass
    gen = once()
    t = loop.create_task(gen)
    # Don't log a warning if the task is not done after run_until_complete().
    # It occurs if the loop is stopped or if a task raises a BaseException.
    t._log_destroy_pending = False
    try:
        loop.run_until_complete(t)
    finally:
        gen.close()


def run_until(loop, pred, timeout=support.SHORT_TIMEOUT):
    """Run *loop* in 1 ms steps until ``pred()`` returns a true value.

    *timeout* is in seconds; ``None`` means wait forever.

    Raises:
        asyncio.TimeoutError: *pred* was still false after *timeout* seconds.
    """
    # Only compute a deadline when there is one: adding None to a float
    # would raise TypeError before the loop even starts.
    deadline = None if timeout is None else time.monotonic() + timeout
    while not pred():
        if deadline is not None and time.monotonic() >= deadline:
            # asyncio.futures no longer re-exports TimeoutError; use the
            # public asyncio name (the builtin TimeoutError since 3.11).
            raise asyncio.TimeoutError()
        loop.run_until_complete(tasks.sleep(0.001))


def run_once(loop):
    """Run a single iteration of *loop*.

    ``loop.stop()`` is scheduled and then ``loop.run_forever()`` is called.
    This is the recommended pattern for test code: it polls the selector
    once and runs all callbacks scheduled in response to I/O events.
    """
    loop.call_soon(loop.stop)
    loop.run_forever()


class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGI request handler that discards its stderr output and log lines."""

    def get_stderr(self):
        return io.StringIO()

    def log_message(self, format, *args):
        pass


class SilentWSGIServer(WSGIServer):
    """WSGI server with a per-request socket timeout and silent errors."""

    request_timeout = support.LOOPBACK_TIMEOUT

    def get_request(self):
        request, client_addr = super().get_request()
        request.settimeout(self.request_timeout)
        return request, client_addr

    def handle_error(self, request, client_address):
        pass


class SSLWSGIServerMixin:
    """Mixin that wraps each accepted connection in server-side TLS."""

    def finish_request(self, request, client_address):
        # Same settings as simple_server_sslcontext(); calling
        # ssl.SSLContext() without a protocol is deprecated.
        context = simple_server_sslcontext()

        ssock = context.wrap_socket(request, server_side=True)
        try:
            self.RequestHandlerClass(ssock, client_address, self)
        except OSError:
            # maybe socket has been closed by peer
            pass
        finally:
            # Close the TLS socket on every path so it is not leaked when
            # the handler fails.
            with contextlib.suppress(OSError):
                ssock.close()


class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    pass


def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Serve a tiny WSGI app from a background thread and yield the server.

    This is a generator meant to be driven with ``yield from`` inside a
    ``contextlib.contextmanager`` function. ``server.address`` is set to
    the bound address. On exit the server is shut down and closed, and its
    thread is joined.

    Every request gets ``200 OK``. Path ``/loop`` streams the request body
    back; any other path returns ``b'Test message'``.
    """

    def loop(environ):
        size = int(environ['CONTENT_LENGTH'])
        while size:
            data = environ['wsgi.input'].read(min(size, 0x10000))
            if not data:
                # The client stopped sending before CONTENT_LENGTH bytes
                # arrived; stop rather than spin forever on EOF.
                break
            yield data
            size -= len(data)

    def app(environ, start_response):
        status = '200 OK'
        headers = [('Content-type', 'text/plain')]
        start_response(status, headers)
        if environ['PATH_INFO'] == '/loop':
            return loop(environ)
        else:
            return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    server_class = server_ssl_cls if use_ssl else server_cls
    httpd = server_class(address, SilentWSGIRequestHandler)
    httpd.set_app(app)
    httpd.address = httpd.server_address
    server_thread = threading.Thread(
        target=lambda: httpd.serve_forever(poll_interval=0.05))
    server_thread.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        server_thread.join()


if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTP server on a Unix socket, with fake TCP-style name/port."""

        def server_bind(self):
            socketserver.UnixStreamServer.server_bind(self)
            self.server_name = '127.0.0.1'
            self.server_port = 80

    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGI server on a Unix socket with a per-request timeout."""

        request_timeout = support.LOOPBACK_TIMEOUT

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            self.setup_environ()

        def get_request(self):
            request, client_addr = super().get_request()
            request.settimeout(self.request_timeout)
            # Code in the stdlib expects that get_request
            # will return a socket and a tuple (host, port).
            # However, this isn't true for UNIX sockets,
            # as the second return value will be a path;
            # hence we return some fake data sufficient
            # to get the tests going
            return request, ('127.0.0.1', '')

    class SilentUnixWSGIServer(UnixWSGIServer):
        """UnixWSGIServer that ignores request-handling errors."""

        def handle_error(self, request, client_address):
            pass

    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        pass

    def gen_unix_socket_path():
        """Return a unique temporary path name for a Unix socket.

        The temporary file is removed as soon as the ``with`` block exits,
        so only its name is returned; the path does not exist afterwards.
        """
        with tempfile.NamedTemporaryFile() as file:
            return file.name

    @contextlib.contextmanager
    def unix_socket_path():
        """Yield a temporary Unix socket path and unlink it afterwards."""
        path = gen_unix_socket_path()
        try:
            yield path
        finally:
            # The socket file may never have been created; ignore that.
            with contextlib.suppress(OSError):
                os.unlink(path)

    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Run a test WSGI server on a temporary Unix socket path."""
        with unix_socket_path() as path:
            yield from _run_test_server(address=path, use_ssl=use_ssl,
                                        server_cls=SilentUnixWSGIServer,
                                        server_ssl_cls=UnixSSLWSGIServer)
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Run a test WSGI server on ``(host, port)`` in a background thread.

    Yields the running server object.
    """
    yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
                                server_cls=SilentWSGIServer,
                                server_ssl_cls=SSLWSGIServer)


def make_test_protocol(base):
    """Return an instance of a subclass of *base* whose attributes are mocks.

    Every non-dunder name found by ``dir(base)`` is replaced by a
    MockCallback that returns None.
    """
    dct = {}
    for name in dir(base):
        if name.startswith('__') and name.endswith('__'):
            # skip magic names
            continue
        dct[name] = MockCallback(return_value=None)
    return type('TestProtocol', (base,) + base.__bases__, dct)()


class TestSelector(selectors.BaseSelector):
    """In-memory selector: records registrations, never reports events."""

    def __init__(self):
        self.keys = {}

    def register(self, fileobj, events, data=None):
        key = selectors.SelectorKey(fileobj, 0, events, data)
        self.keys[fileobj] = key
        return key

    def unregister(self, fileobj):
        return self.keys.pop(fileobj)

    def select(self, timeout):
        return []

    def get_map(self):
        return self.keys


class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests whose clock is driven manually.

    time() returns an internal counter that only changes through
    advance_time(). After each loop iteration, every deadline passed to
    call_at() during that iteration is sent to the generator given to
    __init__, and the value the generator yields back is added to the
    loop time.

    The generator should look like this:

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value received from ``yield`` is the absolute time of a scheduled
    handler; the value yielded is how far to move the loop's time forward.
    When a generator is given, close() checks that it has finished.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            def gen():
                yield
            self._check_on_close = False
        else:
            self._check_on_close = True

        self._gen = gen()
        # Prime the generator so later send() calls reach its first yield.
        next(self._gen)
        self._time = 0
        self._clock_resolution = 1e-9
        self._timers = []
        self._selector = TestSelector()

        self.readers = {}
        self.writers = {}
        self.reset_counters()

        self._transports = weakref.WeakValueDictionary()

    def time(self):
        return self._time

    def advance_time(self, advance):
        """Move test time forward."""
        if advance:
            self._time += advance

    def close(self):
        """Close the loop; a user-supplied time generator must be done."""
        super().close()
        if self._check_on_close:
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    def _add_reader(self, fd, callback, *args):
        self.readers[fd] = events.Handle(callback, args, self, None)

    def _remove_reader(self, fd):
        self.remove_reader_count[fd] += 1
        if fd in self.readers:
            del self.readers[fd]
            return True
        else:
            return False

    def assert_reader(self, fd, callback, *args):
        """Raise AssertionError unless *fd* has this reader callback/args."""
        if fd not in self.readers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.readers[fd]
        if handle._callback != callback:
            raise AssertionError(
                f'unexpected callback: {handle._callback} != {callback}')
        if handle._args != args:
            raise AssertionError(
                f'unexpected callback args: {handle._args} != {args}')

    def assert_no_reader(self, fd):
        """Raise AssertionError if *fd* has a reader registered."""
        if fd in self.readers:
            raise AssertionError(f'fd {fd} is registered')

    def _add_writer(self, fd, callback, *args):
        self.writers[fd] = events.Handle(callback, args, self, None)

    def _remove_writer(self, fd):
        self.remove_writer_count[fd] += 1
        if fd in self.writers:
            del self.writers[fd]
            return True
        else:
            return False

    def assert_writer(self, fd, callback, *args):
        """Raise AssertionError unless *fd* has this writer callback/args."""
        # Explicit raises (like assert_reader) keep the check active under
        # ``python -O``, which strips bare assert statements.
        if fd not in self.writers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.writers[fd]
        if handle._callback != callback:
            raise AssertionError(
                f'unexpected callback: {handle._callback} != {callback}')
        if handle._args != args:
            raise AssertionError(
                f'unexpected callback args: {handle._args} != {args}')

    def _ensure_fd_no_transport(self, fd):
        """Raise RuntimeError if *fd* is owned by a registered transport.

        Raises:
            ValueError: *fd* is neither an int nor has a usable fileno().
        """
        if not isinstance(fd, int):
            try:
                fd = int(fd.fileno())
            except (AttributeError, TypeError, ValueError):
                # This code matches selectors._fileobj_to_fd function.
                raise ValueError("Invalid file object: "
                                 "{!r}".format(fd)) from None
        try:
            transport = self._transports[fd]
        except KeyError:
            pass
        else:
            raise RuntimeError(
                'File descriptor {!r} is used by transport {!r}'.format(
                    fd, transport))

    def add_reader(self, fd, callback, *args):
        """Add a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_reader(fd, callback, *args)

    def remove_reader(self, fd):
        """Remove a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_reader(fd)

    def add_writer(self, fd, callback, *args):
        """Add a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_writer(fd, callback, *args)

    def remove_writer(self, fd):
        """Remove a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_writer(fd)

    def reset_counters(self):
        """Zero the per-fd remove_reader/remove_writer call counters."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Feed each deadline recorded by call_at() to the time generator
        # and advance the clock by whatever it yields back.
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args, context=None):
        self._timers.append(when)
        return super().call_at(when, callback, *args, context=context)

    def _process_events(self, event_list):
        return

    def _write_to_self(self):
        pass


def MockCallback(**kwargs):
    """Return a Mock that only supports being called."""
    return mock.Mock(spec=['__call__'], **kwargs)


class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    where a regex comparison between strings is needed. Comparing with a
    non-str value is simply unequal instead of raising TypeError.

    For instance:
        mock_call.assert_called_with(MockPattern('spam.*ham'))
    """
    def __eq__(self, other):
        if not isinstance(other, str):
            return NotImplemented
        return bool(re.search(str(self), other, re.S))

    def __ne__(self, other):
        # str defines its own __ne__ (literal comparison); keep != the
        # exact negation of the regex-based ==.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result


class MockInstanceOf:
    """Compares equal to any instance of the given type."""

    def __init__(self, type):
        self._type = type

    def __eq__(self, other):
        return isinstance(other, self._type)

    def __repr__(self):
        # Readable output in mock assertion failure messages.
        return f'{type(self).__name__}({self._type!r})'


def get_function_source(func):
    """Return asyncio's source-location info for *func*.

    Raises:
        ValueError: the source cannot be determined.
    """
    source = format_helpers._get_function_source(func)
    if source is None:
        raise ValueError("unable to get the source of %r" % (func,))
    return source


class TestCase(unittest.TestCase):
    """Base test case that isolates tests from the global event loop."""

    @staticmethod
    def close_loop(loop):
        """Close *loop* after shutting down its default executor.

        Then, when an event loop policy is installed and its child watcher
        is a ThreadedChildWatcher, wait for the watcher's threads, each
        for at most support.SHORT_TIMEOUT seconds.

        Raises:
            RuntimeError: a watcher thread is still alive after the wait.
        """
        if loop._default_executor is not None:
            if not loop.is_closed():
                loop.run_until_complete(loop.shutdown_default_executor())
            else:
                loop._default_executor.shutdown(wait=True)
        loop.close()
        policy = support.maybe_get_event_loop_policy()
        if policy is not None:
            try:
                watcher = policy.get_child_watcher()
            except NotImplementedError:
                # watcher is not implemented by EventLoopPolicy, e.g. Windows
                pass
            else:
                if isinstance(watcher, asyncio.ThreadedChildWatcher):
                    # Wait for subprocess threads, but not forever: an
                    # unbounded join() would hang the whole test run if a
                    # child process never exits.
                    for thread in list(watcher._threads.values()):
                        thread.join(timeout=support.SHORT_TIMEOUT)
                        if thread.is_alive():
                            raise RuntimeError(
                                f"thread {thread} still alive: "
                                "subprocess still running")

    def set_event_loop(self, loop, *, cleanup=True):
        """Clear the global loop; optionally register close_loop cleanup."""
        assert loop is not None
        # ensure that the event loop is passed explicitly in asyncio
        events.set_event_loop(None)
        if cleanup:
            self.addCleanup(self.close_loop, loop)

    def new_test_loop(self, gen=None):
        """Create a TestLoop (see set_event_loop) and return it."""
        loop = TestLoop(gen)
        self.set_event_loop(loop)
        return loop

    def unpatch_get_running_loop(self):
        """Restore the events._get_running_loop saved by setUp()."""
        events._get_running_loop = self._get_running_loop

    def setUp(self):
        self._get_running_loop = events._get_running_loop
        # Pretend no loop is running so code under test cannot pick one up
        # implicitly.
        events._get_running_loop = lambda: None
        self._thread_cleanup = support.threading_setup()

    def tearDown(self):
        self.unpatch_get_running_loop()

        events.set_event_loop(None)

        # Detect CPython bug #23353: ensure that yield/yield-from is not used
        # in an except block of a generator
        self.assertEqual(sys.exc_info(), (None, None, None))

        self.doCleanups()
        support.threading_cleanup(*self._thread_cleanup)
        support.reap_children()


@contextlib.contextmanager
def disable_logger():
    """Context manager to disable asyncio logger.

    For example, it can be used to ignore warnings in debug mode.
    """
    old_level = logger.level
    try:
        logger.setLevel(logging.CRITICAL + 1)
        yield
    finally:
        logger.setLevel(old_level)


def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
                            family=socket.AF_INET):
    """Create a mock of a non-blocking socket.

    The mock has the given proto/type/family and gettimeout() returns 0.0.
    """
    sock = mock.MagicMock(socket.socket)
    sock.proto = proto
    sock.type = type
    sock.family = family
    sock.gettimeout.return_value = 0.0
    return sock