"""Utilities shared by tests."""

import collections
import contextlib
import io
import logging
import os
import re
import selectors
import socket
import socketserver
import sys
import tempfile
import threading
import time
import unittest
import weakref

from unittest import mock

from http.server import HTTPServer
from wsgiref.simple_server import WSGIRequestHandler, WSGIServer

try:
    import ssl
except ImportError:  # pragma: no cover
    ssl = None

from asyncio import base_events
from asyncio import events
from asyncio import format_helpers
from asyncio import futures
from asyncio import tasks
from asyncio.log import logger
from test import support


def data_file(filename):
    """Return the path of the test data file *filename*.

    Looks first in ``support.TEST_HOME_DIR`` (when ``test.support``
    defines it), then in the parent directory of this module.

    Raises:
        FileNotFoundError: if the file exists in neither location.
    """
    if hasattr(support, 'TEST_HOME_DIR'):
        fullname = os.path.join(support.TEST_HOME_DIR, filename)
        if os.path.isfile(fullname):
            return fullname
    fullname = os.path.join(os.path.dirname(__file__), '..', filename)
    if os.path.isfile(fullname):
        return fullname
    raise FileNotFoundError(filename)


# Certificate/key files used by the SSL tests.  These are resolved at import
# time, so importing this module fails if the data files are missing.
ONLYCERT = data_file('ssl_cert.pem')
ONLYKEY = data_file('ssl_key.pem')
SIGNED_CERTFILE = data_file('keycert3.pem')
SIGNING_CA = data_file('pycacert.pem')
# Peer-certificate dict in the shape returned by ssl.SSLSocket.getpeercert();
# presumably it describes SIGNED_CERTFILE (issued by SIGNING_CA) -- update it
# if those files are regenerated.  OpenSSL space-pads the day of the month,
# hence the two spaces in 'Jul  7'.
PEERCERT = {
    'OCSP': ('http://testca.pythontest.net/testca/ocsp/',),
    'caIssuers': ('http://testca.pythontest.net/testca/pycacert.cer',),
    'crlDistributionPoints': ('http://testca.pythontest.net/testca/revocation.crl',),
    'issuer': ((('countryName', 'XY'),),
               (('organizationName', 'Python Software Foundation CA'),),
               (('commonName', 'our-ca-server'),)),
    'notAfter': 'Jul  7 14:23:16 2028 GMT',
    'notBefore': 'Aug 29 14:23:16 2018 GMT',
    'serialNumber': 'CB2D80995A69525C',
    'subject': ((('countryName', 'XY'),),
                (('localityName', 'Castle Anthrax'),),
                (('organizationName', 'Python Software Foundation'),),
                (('commonName', 'localhost'),)),
    'subjectAltName': (('DNS', 'localhost'),),
    'version': 3
}


def simple_server_sslcontext():
    """Return a server-side SSLContext loaded with ONLYCERT/ONLYKEY.

    Hostname checking and peer verification are both disabled.
    """
    server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    server_context.load_cert_chain(ONLYCERT, ONLYKEY)
    server_context.check_hostname = False
    server_context.verify_mode = ssl.CERT_NONE
    return server_context


def simple_client_sslcontext(*, disable_verify=True):
    """Return a client-side SSLContext with hostname checking disabled.

    When *disable_verify* is true (the default), certificate verification
    is turned off as well (``verify_mode = CERT_NONE``).
    """
    client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    client_context.check_hostname = False
    if disable_verify:
        client_context.verify_mode = ssl.CERT_NONE
    return client_context


def dummy_ssl_context():
    """Return a bare PROTOCOL_TLS SSLContext, or None if ssl is unavailable."""
    if ssl is None:
        return None
    else:
        return ssl.SSLContext(ssl.PROTOCOL_TLS)


def run_briefly(loop):
    """Run *loop* until a trivial, immediately-finishing task completes.

    This gives already-scheduled callbacks a chance to run.
    """
    async def once():
        pass
    gen = once()
    t = loop.create_task(gen)
    # Don't log a warning if the task is not done after run_until_complete().
    # It occurs if the loop is stopped or if a task raises a BaseException.
    t._log_destroy_pending = False
    try:
        loop.run_until_complete(t)
    finally:
        gen.close()


def run_until(loop, pred, timeout=30):
    """Run *loop* in 1 ms slices until ``pred()`` returns a true value.

    Args:
        loop: the event loop to drive.
        pred: zero-argument callable polled before each slice.
        timeout: maximum number of seconds to wait, or None to wait forever.

    Raises:
        futures.TimeoutError: if *timeout* elapses before *pred* is satisfied.
    """
    # Only compute a deadline when a timeout was given; adding None to a
    # float would raise TypeError before the loop even starts.
    deadline = None if timeout is None else time.monotonic() + timeout
    while not pred():
        if deadline is not None and time.monotonic() >= deadline:
            raise futures.TimeoutError()
        loop.run_until_complete(tasks.sleep(0.001, loop=loop))


def run_once(loop):
    """Run a single iteration of the event loop.

    This is a legacy helper, but still the recommended pattern for test
    code: it schedules ``loop.stop`` and then runs the loop, which polls the
    selector once and runs all callbacks scheduled in response to I/O events.
    """
    loop.call_soon(loop.stop)
    loop.run_forever()


class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGI request handler that discards its error output and access log."""

    def get_stderr(self):
        # Route the WSGI error stream to a throwaway buffer.
        return io.StringIO()

    def log_message(self, format, *args):
        pass


class SilentWSGIServer(WSGIServer):
    """WSGI server with a per-request socket timeout and silenced errors."""

    request_timeout = 2

    def get_request(self):
        request, client_addr = super().get_request()
        request.settimeout(self.request_timeout)
        return request, client_addr

    def handle_error(self, request, client_address):
        pass


class SSLWSGIServerMixin:
    """Mixin that wraps each accepted connection in server-side TLS."""

    def finish_request(self, request, client_address):
        # The certificate files are located by data_file() at import time.
        context = ssl.SSLContext()
        context.load_cert_chain(ONLYCERT, ONLYKEY)

        ssock = context.wrap_socket(request, server_side=True)
        # wrap_socket() detaches *request*, so the server's own
        # shutdown_request() cannot close the connection for us: always
        # close the SSL socket here, even when the handler fails.
        # OSError is swallowed as before -- the peer may have closed the
        # connection already.
        with contextlib.suppress(OSError):
            try:
                self.RequestHandlerClass(ssock, client_address, self)
            finally:
                ssock.close()


class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    pass


def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Generator that serves a trivial WSGI app on *address* in a thread.

    Yields the server object (with an extra ``address`` attribute set to
    its ``server_address``) and shuts the server down and joins the thread
    when resumed or closed.  Intended to be driven via ``yield from`` inside
    a ``contextlib.contextmanager`` function.
    """

    def app(environ, start_response):
        status = '200 OK'
        headers = [('Content-type', 'text/plain')]
        start_response(status, headers)
        return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    server_class = server_ssl_cls if use_ssl else server_cls
    httpd = server_class(address, SilentWSGIRequestHandler)
    httpd.set_app(app)
    httpd.address = httpd.server_address
    server_thread = threading.Thread(
        target=lambda: httpd.serve_forever(poll_interval=0.05))
    server_thread.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        server_thread.join()


if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTP server bound to a UNIX socket, with fake host/port names."""

        def server_bind(self):
            socketserver.UnixStreamServer.server_bind(self)
            # HTTPServer normally derives these from a (host, port)
            # address; a UNIX socket has neither, so use placeholders.
            self.server_name = '127.0.0.1'
            self.server_port = 80


    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGI server on a UNIX socket with a per-request timeout."""

        request_timeout = 2

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            self.setup_environ()

        def get_request(self):
            request, client_addr = super().get_request()
            request.settimeout(self.request_timeout)
            # Code in the stdlib expects that get_request
            # will return a socket and a tuple (host, port).
            # However, this isn't true for UNIX sockets,
            # as the second return value will be a path;
            # hence we return some fake data sufficient
            # to get the tests going
            return request, ('127.0.0.1', '')


    class SilentUnixWSGIServer(UnixWSGIServer):

        def handle_error(self, request, client_address):
            pass


    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        pass


    def gen_unix_socket_path():
        """Return a fresh temporary path that no longer exists on disk.

        The temporary file is deleted when the ``with`` block exits, so the
        name is free to bind a UNIX socket to.  NOTE(review): another process
        could claim the name in between; acceptable for tests.
        """
        with tempfile.NamedTemporaryFile() as file:
            return file.name


    @contextlib.contextmanager
    def unix_socket_path():
        """Yield a temporary UNIX socket path and unlink it afterwards."""
        path = gen_unix_socket_path()
        try:
            yield path
        finally:
            try:
                os.unlink(path)
            except OSError:
                pass


    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Context manager running the test WSGI server on a UNIX socket."""
        with unix_socket_path() as path:
            yield from _run_test_server(address=path, use_ssl=use_ssl,
                                        server_cls=SilentUnixWSGIServer,
                                        server_ssl_cls=UnixSSLWSGIServer)


@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Context manager running the test WSGI server on (host, port).

    The default port 0 lets the OS pick a free port; read it back from the
    yielded server's ``address`` attribute.
    """
    yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
                                server_cls=SilentWSGIServer,
                                server_ssl_cls=SSLWSGIServer)


def make_test_protocol(base):
    """Instantiate a subclass of *base* whose public methods are mocks.

    Every attribute name of *base* that is not a ``__dunder__`` name is
    replaced by a callable mock returning None.
    """
    dct = {}
    for name in dir(base):
        if name.startswith('__') and name.endswith('__'):
            # skip magic names
            continue
        dct[name] = MockCallback(return_value=None)
    return type('TestProtocol', (base,) + base.__bases__, dct)()


class TestSelector(selectors.BaseSelector):
    """In-memory selector: records registrations, never reports events."""

    def __init__(self):
        self.keys = {}

    def register(self, fileobj, events, data=None):
        # The fd field is always 0: no real file descriptor is involved.
        key = selectors.SelectorKey(fileobj, 0, events, data)
        self.keys[fileobj] = key
        return key

    def unregister(self, fileobj):
        return self.keys.pop(fileobj)

    def select(self, timeout):
        return []

    def get_map(self):
        return self.keys


class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests, driven by a virtual clock.

    The loop never reads the real time: ``time()`` returns an internal
    counter that starts at 0.  Every ``call_at`` records the absolute time
    it was given.  After each loop iteration, each recorded time is sent
    into the generator passed to ``__init__``, and the value the generator
    yields back is added to the clock as a time advance.

    The generator should look like this:

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value received from ``yield`` is the absolute time of the next
    scheduled handler; the value passed to ``yield`` is how far to move the
    loop's time forward.  The generator is primed with ``next()`` on
    construction, and on ``close()`` it must be exhausted: sending it one
    more value has to raise StopIteration.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            def gen():
                yield
            # Nothing to verify on close for the default generator.
            self._check_on_close = False
        else:
            self._check_on_close = True

        self._gen = gen()
        next(self._gen)
        self._time = 0
        self._clock_resolution = 1e-9
        self._timers = []
        self._selector = TestSelector()

        self.readers = {}
        self.writers = {}
        self.reset_counters()

        self._transports = weakref.WeakValueDictionary()

    def time(self):
        return self._time

    def advance_time(self, advance):
        """Move test time forward."""
        if advance:
            self._time += advance

    def close(self):
        super().close()
        if self._check_on_close:
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    def _add_reader(self, fd, callback, *args):
        self.readers[fd] = events.Handle(callback, args, self, None)

    def _remove_reader(self, fd):
        self.remove_reader_count[fd] += 1
        if fd in self.readers:
            del self.readers[fd]
            return True
        else:
            return False

    def assert_reader(self, fd, callback, *args):
        """Fail unless *fd* has a reader with this callback and args."""
        if fd not in self.readers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.readers[fd]
        if handle._callback != callback:
            raise AssertionError(
                f'unexpected callback: {handle._callback} != {callback}')
        if handle._args != args:
            raise AssertionError(
                f'unexpected callback args: {handle._args} != {args}')

    def assert_no_reader(self, fd):
        """Fail if *fd* has a reader registered."""
        if fd in self.readers:
            raise AssertionError(f'fd {fd} is registered')

    def _add_writer(self, fd, callback, *args):
        self.writers[fd] = events.Handle(callback, args, self, None)

    def _remove_writer(self, fd):
        self.remove_writer_count[fd] += 1
        if fd in self.writers:
            del self.writers[fd]
            return True
        else:
            return False

    def assert_writer(self, fd, callback, *args):
        """Fail unless *fd* has a writer with this callback and args."""
        # Raise explicitly (like assert_reader) so the check still runs
        # under ``python -O``, which strips bare assert statements.
        if fd not in self.writers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.writers[fd]
        if handle._callback != callback:
            raise AssertionError(f'{handle._callback!r} != {callback!r}')
        if handle._args != args:
            raise AssertionError(f'{handle._args!r} != {args!r}')

    def assert_no_writer(self, fd):
        """Fail if *fd* has a writer registered."""
        if fd in self.writers:
            raise AssertionError(f'fd {fd} is registered')

    def _ensure_fd_no_transport(self, fd):
        """Raise RuntimeError if *fd* is owned by a tracked transport.

        *fd* may be an int or any object with a ``fileno()`` method;
        anything else raises ValueError.
        """
        if not isinstance(fd, int):
            try:
                fd = int(fd.fileno())
            except (AttributeError, TypeError, ValueError):
                # This code matches selectors._fileobj_to_fd function.
                raise ValueError("Invalid file object: "
                                 "{!r}".format(fd)) from None
        try:
            transport = self._transports[fd]
        except KeyError:
            pass
        else:
            raise RuntimeError(
                'File descriptor {!r} is used by transport {!r}'.format(
                    fd, transport))

    def add_reader(self, fd, callback, *args):
        """Add a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_reader(fd, callback, *args)

    def remove_reader(self, fd):
        """Remove a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_reader(fd)

    def add_writer(self, fd, callback, *args):
        """Add a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_writer(fd, callback, *args)

    def remove_writer(self, fd):
        """Remove a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_writer(fd)

    def reset_counters(self):
        """Reset the per-fd counts of remove_reader/remove_writer calls."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Feed every time scheduled since the last iteration to the time
        # generator and advance the virtual clock by what it yields back.
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args, context=None):
        self._timers.append(when)
        return super().call_at(when, callback, *args, context=context)

    def _process_events(self, event_list):
        return

    def _write_to_self(self):
        pass


def MockCallback(**kwargs):
    """Return a Mock whose spec only allows it to be called."""
    return mock.Mock(spec=['__call__'], **kwargs)


class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    where a regex comparison between strings is needed.  The pattern is
    matched with ``re.search`` and ``re.S``, so it may match anywhere in
    the other string and ``.`` also matches newlines.

    For instance:
        mock_call.assert_called_with(MockPattern('spam.*ham'))
    """
    def __eq__(self, other):
        if not isinstance(other, str):
            # Let a mismatched type compare unequal instead of making
            # re.search raise TypeError inside a mock assertion.
            return NotImplemented
        return bool(re.search(str(self), other, re.S))


class MockInstanceOf:
    """Compares equal to any instance of the given type.

    Useful as a placeholder argument in mock call assertions.
    """
    def __init__(self, type):
        self._type = type

    def __eq__(self, other):
        return isinstance(other, self._type)

    def __repr__(self):
        # Readable output in failed assert_called_with messages.
        return f'{type(self).__name__}({self._type!r})'


def get_function_source(func):
    """Return asyncio's source location for *func*.

    Raises:
        ValueError: if format_helpers cannot determine the source.
    """
    source = format_helpers._get_function_source(func)
    if source is None:
        raise ValueError("unable to get the source of %r" % (func,))
    return source


class TestCase(unittest.TestCase):
    """Base TestCase for asyncio tests.

    setUp replaces ``events._get_running_loop`` with a function that always
    returns None, and tearDown restores it.  tearDown also clears the
    current event loop, runs cleanups, and checks for leaked threads and
    child processes.
    """

    @staticmethod
    def close_loop(loop):
        """Shut down *loop*'s default executor (if any), then close it."""
        executor = loop._default_executor
        if executor is not None:
            executor.shutdown(wait=True)
        loop.close()

    def set_event_loop(self, loop, *, cleanup=True):
        """Register *loop* for this test; close it on cleanup by default."""
        assert loop is not None
        # ensure that the event loop is passed explicitly in asyncio
        events.set_event_loop(None)
        if cleanup:
            self.addCleanup(self.close_loop, loop)

    def new_test_loop(self, gen=None):
        """Create a TestLoop driven by *gen* and register it for cleanup."""
        loop = TestLoop(gen)
        self.set_event_loop(loop)
        return loop

    def unpatch_get_running_loop(self):
        """Restore the ``events._get_running_loop`` saved in setUp."""
        events._get_running_loop = self._get_running_loop

    def setUp(self):
        self._get_running_loop = events._get_running_loop
        events._get_running_loop = lambda: None
        self._thread_cleanup = support.threading_setup()

    def tearDown(self):
        self.unpatch_get_running_loop()

        events.set_event_loop(None)

        # Detect CPython bug #23353: ensure that yield/yield-from is not used
        # in an except block of a generator
        self.assertEqual(sys.exc_info(), (None, None, None))

        self.doCleanups()
        support.threading_cleanup(*self._thread_cleanup)
        support.reap_children()


@contextlib.contextmanager
def disable_logger():
    """Context manager to disable asyncio logger.

    For example, it can be used to ignore warnings in debug mode.  The
    logger's previous level is restored on exit.
    """
    old_level = logger.level
    try:
        # Above CRITICAL, so every record is dropped.
        logger.setLevel(logging.CRITICAL+1)
        yield
    finally:
        logger.setLevel(old_level)


def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
                            family=socket.AF_INET):
    """Create a mock of a non-blocking socket.

    The mock is spec'd on socket.socket, carries the given proto/type/family
    attributes, and reports a timeout of 0.0 (non-blocking mode).
    """
    sock = mock.MagicMock(socket.socket)
    sock.proto = proto
    sock.type = type
    sock.family = family
    sock.gettimeout.return_value = 0.0
    return sock