"""Utilities shared by tests."""

import asyncio
import collections
import contextlib
import io
import logging
import os
import re
import selectors
import socket
import socketserver
import sys
import tempfile
import threading
import time
import unittest
import weakref

from unittest import mock

from http.server import HTTPServer
from wsgiref.simple_server import WSGIRequestHandler, WSGIServer

try:
    import ssl
except ImportError:  # pragma: no cover
    ssl = None

from asyncio import base_events
from asyncio import events
from asyncio import format_helpers
from asyncio import futures
from asyncio import tasks
from asyncio.log import logger
from test import support
from test.support import threading_helper


def data_file(filename):
    """Return the full path of the test data file *filename*.

    ``support.TEST_HOME_DIR`` is searched first (when the attribute
    exists), then the parent directory of this module.

    Raises:
        FileNotFoundError: if the file exists in neither location.
    """
    if hasattr(support, 'TEST_HOME_DIR'):
        fullname = os.path.join(support.TEST_HOME_DIR, filename)
        if os.path.isfile(fullname):
            return fullname
    fullname = os.path.join(os.path.dirname(__file__), '..', filename)
    if os.path.isfile(fullname):
        return fullname
    raise FileNotFoundError(filename)


# Certificate and key files used by the SSL helpers below. They are resolved
# at import time, so importing this module fails if they cannot be found.
ONLYCERT = data_file('ssl_cert.pem')
ONLYKEY = data_file('ssl_key.pem')
SIGNED_CERTFILE = data_file('keycert3.pem')
SIGNING_CA = data_file('pycacert.pem')
# Expected result of getpeercert() for SIGNED_CERTFILE (as used by tests).
PEERCERT = {
    'OCSP': ('http://testca.pythontest.net/testca/ocsp/',),
    'caIssuers': ('http://testca.pythontest.net/testca/pycacert.cer',),
    'crlDistributionPoints': ('http://testca.pythontest.net/testca/revocation.crl',),
    'issuer': ((('countryName', 'XY'),),
               (('organizationName', 'Python Software Foundation CA'),),
               (('commonName', 'our-ca-server'),)),
    'notAfter': 'Oct 28 14:23:16 2037 GMT',
    'notBefore': 'Aug 29 14:23:16 2018 GMT',
    'serialNumber': 'CB2D80995A69525C',
    'subject': ((('countryName', 'XY'),),
                (('localityName', 'Castle Anthrax'),),
                (('organizationName', 'Python Software Foundation'),),
                (('commonName', 'localhost'),)),
    'subjectAltName': (('DNS', 'localhost'),),
    'version': 3
}


def simple_server_sslcontext():
    """Return a server-side SSLContext loaded with ONLYCERT/ONLYKEY.

    Hostname checking and peer certificate verification are disabled.
    """
    server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    server_context.load_cert_chain(ONLYCERT, ONLYKEY)
    server_context.check_hostname = False
    server_context.verify_mode = ssl.CERT_NONE
    return server_context


def simple_client_sslcontext(*, disable_verify=True):
    """Return a client-side SSLContext with hostname checking disabled.

    If *disable_verify* is true, server certificate verification is also
    turned off (``ssl.CERT_NONE``).
    """
    client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    client_context.check_hostname = False
    if disable_verify:
        client_context.verify_mode = ssl.CERT_NONE
    return client_context


def dummy_ssl_context():
    """Return a non-verifying client SSLContext, or None without ssl."""
    if ssl is None:
        return None
    else:
        return simple_client_sslcontext(disable_verify=True)


def run_briefly(loop):
    """Drive *loop* briefly by running a no-op task to completion."""
    async def once():
        pass
    gen = once()
    t = loop.create_task(gen)
    # Don't log a warning if the task is not done after run_until_complete().
    # It occurs if the loop is stopped or if a task raises a BaseException.
    t._log_destroy_pending = False
    try:
        loop.run_until_complete(t)
    finally:
        gen.close()


def run_until(loop, pred, timeout=support.SHORT_TIMEOUT):
    """Run *loop* in short sleeps until ``pred()`` returns a true value.

    Args:
        loop: the event loop to drive.
        pred: zero-argument callable checked before every sleep.
        timeout: maximum number of seconds to wait, or None to wait
            indefinitely.

    Raises:
        futures.TimeoutError: if *timeout* elapses before *pred* holds.
    """
    # Only compute a deadline when there is a timeout: adding None to the
    # monotonic clock would raise TypeError before the loop even starts.
    deadline = None if timeout is None else time.monotonic() + timeout
    while not pred():
        if deadline is not None and deadline - time.monotonic() <= 0:
            raise futures.TimeoutError()
        loop.run_until_complete(tasks.sleep(0.001))


def run_once(loop):
    """Legacy API to run once through the event loop.

    This is the recommended pattern for test code.  It will poll the
    selector once and run all callbacks scheduled in response to I/O
    events.
    """
    loop.call_soon(loop.stop)
    loop.run_forever()


class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGI request handler that discards stderr output and log messages."""

    def get_stderr(self):
        return io.StringIO()

    def log_message(self, format, *args):
        pass


class SilentWSGIServer(WSGIServer):
    """WSGI server with per-request socket timeouts and silenced errors."""

    # Timeout applied to every accepted connection in get_request().
    request_timeout = support.LOOPBACK_TIMEOUT

    def get_request(self):
        request, client_addr = super().get_request()
        request.settimeout(self.request_timeout)
        return request, client_addr

    def handle_error(self, request, client_address):
        pass


class SSLWSGIServerMixin:
    """Mixin that wraps each accepted connection in server-side TLS."""

    def finish_request(self, request, client_address):
        # The certificate and key are located by data_file() at import time.
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(ONLYCERT, ONLYKEY)

        ssock = context.wrap_socket(request, server_side=True)
        try:
            # wrap_socket() detaches *request*, so ssock owns the connection;
            # close it even when the handler fails.
            with ssock:
                self.RequestHandlerClass(ssock, client_address, self)
        except OSError:
            # maybe socket has been closed by peer
            pass


class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    pass


def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Generator that serves a small WSGI app on a background thread.

    Yields the running server (with an ``address`` attribute set to its
    bound address) once, then shuts it down and joins the thread when the
    generator is resumed or closed.  ``/loop`` echoes back the request body;
    every other path returns ``b'Test message'``.
    """

    def loop(environ):
        size = int(environ['CONTENT_LENGTH'])
        while size:
            data = environ['wsgi.input'].read(min(size, 0x10000))
            yield data
            size -= len(data)

    def app(environ, start_response):
        status = '200 OK'
        headers = [('Content-type', 'text/plain')]
        start_response(status, headers)
        if environ['PATH_INFO'] == '/loop':
            return loop(environ)
        else:
            return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    server_class = server_ssl_cls if use_ssl else server_cls
    httpd = server_class(address, SilentWSGIRequestHandler)
    httpd.set_app(app)
    httpd.address = httpd.server_address
    server_thread = threading.Thread(
        target=lambda: httpd.serve_forever(poll_interval=0.05))
    server_thread.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        server_thread.join()


if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTP server bound to a Unix socket with a fake name and port."""

        def server_bind(self):
            socketserver.UnixStreamServer.server_bind(self)
            self.server_name = '127.0.0.1'
            self.server_port = 80


    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGI server listening on a Unix socket."""

        # Timeout applied to every accepted connection in get_request().
        request_timeout = support.LOOPBACK_TIMEOUT

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            self.setup_environ()

        def get_request(self):
            request, client_addr = super().get_request()
            request.settimeout(self.request_timeout)
            # Code in the stdlib expects that get_request
            # will return a socket and a tuple (host, port).
            # However, this isn't true for UNIX sockets,
            # as the second return value will be a path;
            # hence we return some fake data sufficient
            # to get the tests going
            return request, ('127.0.0.1', '')


    class SilentUnixWSGIServer(UnixWSGIServer):

        def handle_error(self, request, client_address):
            pass


    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        pass


    def gen_unix_socket_path():
        """Return a unique temporary path usable for a Unix socket.

        The temporary file is removed when the ``with`` block exits, so only
        its name is reused.
        """
        with tempfile.NamedTemporaryFile() as file:
            return file.name


    @contextlib.contextmanager
    def unix_socket_path():
        """Yield a temporary Unix socket path; unlink it afterwards."""
        path = gen_unix_socket_path()
        try:
            yield path
        finally:
            try:
                os.unlink(path)
            except OSError:
                pass


    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Run the test WSGI server on a temporary Unix socket."""
        with unix_socket_path() as path:
            yield from _run_test_server(address=path, use_ssl=use_ssl,
                                        server_cls=SilentUnixWSGIServer,
                                        server_ssl_cls=UnixSSLWSGIServer)


@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Run the test WSGI server on (*host*, *port*); port 0 picks a free one."""
    yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
                                server_cls=SilentWSGIServer,
                                server_ssl_cls=SSLWSGIServer)


def make_test_protocol(base):
    """Instantiate a subclass of *base* whose non-dunder names are mocks.

    Every attribute of *base* that is not a ``__dunder__`` name is replaced
    by a MockCallback returning None.
    """
    dct = {}
    for name in dir(base):
        if name.startswith('__') and name.endswith('__'):
            # skip magic names
            continue
        dct[name] = MockCallback(return_value=None)
    return type('TestProtocol', (base,) + base.__bases__, dct)()


class TestSelector(selectors.BaseSelector):
    """In-memory selector that records registrations but reports no events."""

    def __init__(self):
        self.keys = {}

    def register(self, fileobj, events, data=None):
        key = selectors.SelectorKey(fileobj, 0, events, data)
        self.keys[fileobj] = key
        return key

    def unregister(self, fileobj):
        return self.keys.pop(fileobj)

    def select(self, timeout):
        return []

    def get_map(self):
        return self.keys


class TestLoop(base_events.BaseEventLoop):
    """Loop for unittests.

    The loop manages its own clock instead of reading the system time.
    After each iteration, every deadline scheduled via call_at() during that
    iteration is sent to the generator passed to __init__, which must yield
    back how far to advance the loop's clock.

    Generator should be like this:

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value received from ``yield`` is the absolute time of a scheduled
    handler.  The value yielded is the amount to move the loop's time forward.
    When a generator is given, close() requires it to be exhausted.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            def gen():
                yield
            self._check_on_close = False
        else:
            self._check_on_close = True

        self._gen = gen()
        next(self._gen)
        self._time = 0
        self._clock_resolution = 1e-9
        self._timers = []
        self._selector = TestSelector()

        self.readers = {}
        self.writers = {}
        self.reset_counters()

        self._transports = weakref.WeakValueDictionary()

    def time(self):
        return self._time

    def advance_time(self, advance):
        """Move test time forward."""
        if advance:
            self._time += advance

    def close(self):
        super().close()
        if self._check_on_close:
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    def _add_reader(self, fd, callback, *args):
        self.readers[fd] = events.Handle(callback, args, self, None)

    def _remove_reader(self, fd):
        self.remove_reader_count[fd] += 1
        if fd in self.readers:
            del self.readers[fd]
            return True
        else:
            return False

    def assert_reader(self, fd, callback, *args):
        """Assert that *fd* has a reader registered with *callback*/*args*."""
        if fd not in self.readers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.readers[fd]
        if handle._callback != callback:
            raise AssertionError(
                f'unexpected callback: {handle._callback} != {callback}')
        if handle._args != args:
            raise AssertionError(
                f'unexpected callback args: {handle._args} != {args}')

    def assert_no_reader(self, fd):
        """Assert that no reader is registered for *fd*."""
        if fd in self.readers:
            raise AssertionError(f'fd {fd} is registered')

    def _add_writer(self, fd, callback, *args):
        self.writers[fd] = events.Handle(callback, args, self, None)

    def _remove_writer(self, fd):
        self.remove_writer_count[fd] += 1
        if fd in self.writers:
            del self.writers[fd]
            return True
        else:
            return False

    def assert_writer(self, fd, callback, *args):
        """Assert that *fd* has a writer registered with *callback*/*args*."""
        if fd not in self.writers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.writers[fd]
        if handle._callback != callback:
            raise AssertionError(f'{handle._callback!r} != {callback!r}')
        if handle._args != args:
            raise AssertionError(f'{handle._args!r} != {args!r}')

    def _ensure_fd_no_transport(self, fd):
        """Raise RuntimeError if *fd* is owned by a registered transport."""
        if not isinstance(fd, int):
            try:
                fd = int(fd.fileno())
            except (AttributeError, TypeError, ValueError):
                # This code matches selectors._fileobj_to_fd function.
                raise ValueError("Invalid file object: "
                                 "{!r}".format(fd)) from None
        try:
            transport = self._transports[fd]
        except KeyError:
            pass
        else:
            raise RuntimeError(
                'File descriptor {!r} is used by transport {!r}'.format(
                    fd, transport))

    def add_reader(self, fd, callback, *args):
        """Add a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_reader(fd, callback, *args)

    def remove_reader(self, fd):
        """Remove a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_reader(fd)

    def add_writer(self, fd, callback, *args):
        """Add a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_writer(fd, callback, *args)

    def remove_writer(self, fd):
        """Remove a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_writer(fd)

    def reset_counters(self):
        """Reset the per-fd counts of remove_reader/remove_writer calls."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Let the time generator advance the clock for each timer scheduled
        # during this iteration.
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args, context=None):
        self._timers.append(when)
        return super().call_at(when, callback, *args, context=context)

    def _process_events(self, event_list):
        return

    def _write_to_self(self):
        pass


def MockCallback(**kwargs):
    """Return a Mock that only supports being called."""
    return mock.Mock(spec=['__call__'], **kwargs)


class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    where a regex comparison between strings is needed.

    For instance:
        mock_call.assert_called_with(MockPattern('spam.*ham'))

    Comparing against a non-str value is unequal rather than an error.
    """
    def __eq__(self, other):
        if not isinstance(other, str):
            # re.search() would raise TypeError; defer to the other operand
            # (ultimately an identity check) so mismatches fail cleanly.
            return NotImplemented
        return bool(re.search(str(self), other, re.S))


class MockInstanceOf:
    """Object that compares equal to any instance of the given type."""

    def __init__(self, type):
        self._type = type

    def __eq__(self, other):
        return isinstance(other, self._type)


def get_function_source(func):
    """Return the source location of *func* as found by format_helpers.

    Raises:
        ValueError: if the source cannot be determined.
    """
    source = format_helpers._get_function_source(func)
    if source is None:
        raise ValueError("unable to get the source of %r" % (func,))
    return source


class TestCase(unittest.TestCase):
    """Base class for asyncio tests: loop management and leak checks."""

    @staticmethod
    def close_loop(loop):
        """Shut down *loop*'s default executor, close it, join watcher threads.

        If the current policy's child watcher is a ThreadedChildWatcher, its
        remaining threads are joined.
        """
        if loop._default_executor is not None:
            if not loop.is_closed():
                loop.run_until_complete(loop.shutdown_default_executor())
            else:
                loop._default_executor.shutdown(wait=True)
        loop.close()
        policy = support.maybe_get_event_loop_policy()
        if policy is not None:
            try:
                watcher = policy.get_child_watcher()
            except NotImplementedError:
                # watcher is not implemented by EventLoopPolicy, e.g. Windows
                pass
            else:
                if isinstance(watcher, asyncio.ThreadedChildWatcher):
                    threads = list(watcher._threads.values())
                    for thread in threads:
                        thread.join()

    def set_event_loop(self, loop, *, cleanup=True):
        """Unset the global event loop and optionally close *loop* at cleanup."""
        if loop is None:
            raise AssertionError('loop is None')
        # ensure that the event loop is passed explicitly in asyncio
        events.set_event_loop(None)
        if cleanup:
            self.addCleanup(self.close_loop, loop)

    def new_test_loop(self, gen=None):
        """Create a TestLoop driven by *gen* and register it for cleanup."""
        loop = TestLoop(gen)
        self.set_event_loop(loop)
        return loop

    def setUp(self):
        self._thread_cleanup = threading_helper.threading_setup()

    def tearDown(self):
        events.set_event_loop(None)

        # Detect CPython bug #23353: ensure that yield/yield-from is not used
        # in an except block of a generator
        self.assertEqual(sys.exc_info(), (None, None, None))

        self.doCleanups()
        threading_helper.threading_cleanup(*self._thread_cleanup)
        support.reap_children()


@contextlib.contextmanager
def disable_logger():
    """Context manager to disable asyncio logger.

    For example, it can be used to ignore warnings in debug mode.
    """
    old_level = logger.level
    try:
        logger.setLevel(logging.CRITICAL+1)
        yield
    finally:
        logger.setLevel(old_level)


def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
                            family=socket.AF_INET):
    """Create a mock of a non-blocking socket.

    The mock is spec'd on socket.socket, carries the given *proto*, *type*
    and *family*, and its gettimeout() returns 0.0.
    """
    sock = mock.MagicMock(socket.socket)
    sock.proto = proto
    sock.type = type
    sock.family = family
    sock.gettimeout.return_value = 0.0
    return sock