"""Utilities shared by tests."""

import collections
import contextlib
import io
import logging
import os
import re
import socket
import socketserver
import sys
import tempfile
import threading
import time
import unittest
import weakref

from unittest import mock

from http.server import HTTPServer
from wsgiref.simple_server import WSGIRequestHandler, WSGIServer

try:
    import ssl
except ImportError:  # pragma: no cover
    ssl = None

from . import base_events
from . import compat
from . import events
from . import futures
from . import selectors
from . import tasks
from .coroutines import coroutine
from .log import logger


if sys.platform == 'win32':  # pragma: no cover
    from .windows_utils import socketpair
else:
    from socket import socketpair  # pragma: no cover


def dummy_ssl_context():
    """Return an SSLContext for PROTOCOL_SSLv23, or None without ssl."""
    if ssl is None:
        return None
    else:
        return ssl.SSLContext(ssl.PROTOCOL_SSLv23)


def run_briefly(loop):
    """Run a trivial no-op task on *loop* until that task completes."""
    @coroutine
    def once():
        pass
    gen = once()
    t = loop.create_task(gen)
    # Don't log a warning if the task is not done after run_until_complete().
    # It occurs if the loop is stopped or if a task raises a BaseException.
    t._log_destroy_pending = False
    try:
        loop.run_until_complete(t)
    finally:
        gen.close()


def run_until(loop, pred, timeout=30):
    """Run *loop* in short sleeps until pred() returns a true value.

    *pred* is called with no arguments before every iteration.  Between
    calls the loop runs a 1 ms sleep.

    *timeout* is the maximum number of seconds to wait.  Pass None to
    wait without a limit.

    Raises futures.TimeoutError if the timeout expires while pred() is
    still false.
    """
    # Only compute a deadline when a timeout was given: adding None to a
    # float would raise TypeError before the loop below is ever reached.
    deadline = None if timeout is None else time.time() + timeout
    while not pred():
        if deadline is not None and deadline - time.time() <= 0:
            raise futures.TimeoutError()
        loop.run_until_complete(tasks.sleep(0.001, loop=loop))


def run_once(loop):
    """Legacy API to run once through the event loop.

    This is the recommended pattern for test code.  It will poll the
    selector once and run all callbacks scheduled in response to I/O
    events.
    """
    loop.call_soon(loop.stop)
    loop.run_forever()


class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGI request handler that produces no log or stderr output."""

    def get_stderr(self):
        """Return a fresh in-memory text buffer instead of a real stream."""
        return io.StringIO()

    def log_message(self, format, *args):
        """Ignore the log message."""
        pass


class SilentWSGIServer(WSGIServer):
    """WSGI server that times out idle requests and ignores errors."""

    # Timeout, in seconds, applied to every accepted request socket.
    request_timeout = 2

    def get_request(self):
        """Accept a request and apply request_timeout to its socket."""
        request, client_addr = super().get_request()
        request.settimeout(self.request_timeout)
        return request, client_addr

    def handle_error(self, request, client_address):
        """Ignore errors raised while handling a request."""
        pass


class SSLWSGIServerMixin:
    """Mixin that serves each request over a server-side SSL socket."""

    def finish_request(self, request, client_address):
        """Wrap *request* in SSL and run the request handler on it."""
        # The relative location of our test directory (which
        # contains the ssl key and certificate files) differs
        # between the stdlib and stand-alone asyncio.
        # Prefer our own if we can find it.
        here = os.path.join(os.path.dirname(__file__), '..', 'tests')
        if not os.path.isdir(here):
            here = os.path.join(os.path.dirname(os.__file__),
                                'test', 'test_asyncio')
        keyfile = os.path.join(here, 'ssl_key.pem')
        certfile = os.path.join(here, 'ssl_cert.pem')
        context = ssl.SSLContext()
        context.load_cert_chain(certfile, keyfile)

        ssock = context.wrap_socket(request, server_side=True)
        try:
            self.RequestHandlerClass(ssock, client_address, self)
            ssock.close()
        except OSError:
            # maybe socket has been closed by peer
            pass


class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """Silent WSGI server that speaks SSL."""
    pass


def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Generator that runs a test WSGI server around a single yield.

    The server is an instance of *server_ssl_cls* if *use_ssl* is true,
    otherwise of *server_cls*, bound to *address*.  It answers every
    request with a plain-text 'Test message' body.

    Yields the server object, whose 'address' attribute is set to its
    server_address.  When the generator is resumed or closed, the server
    is shut down and its thread is joined.
    """

    def app(environ, start_response):
        status = '200 OK'
        headers = [('Content-type', 'text/plain')]
        start_response(status, headers)
        return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    server_class = server_ssl_cls if use_ssl else server_cls
    httpd = server_class(address, SilentWSGIRequestHandler)
    httpd.set_app(app)
    httpd.address = httpd.server_address
    server_thread = threading.Thread(
        target=lambda: httpd.serve_forever(poll_interval=0.05))
    server_thread.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        server_thread.join()


if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTP server bound to a UNIX stream socket."""

        def server_bind(self):
            """Bind the UNIX socket and set a fake server name and port."""
            socketserver.UnixStreamServer.server_bind(self)
            self.server_name = '127.0.0.1'
            self.server_port = 80


    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGI server bound to a UNIX stream socket."""

        # Timeout, in seconds, applied to every accepted request socket.
        request_timeout = 2

        def server_bind(self):
            """Bind the UNIX socket, then set up the WSGI environment."""
            UnixHTTPServer.server_bind(self)
            self.setup_environ()

        def get_request(self):
            """Accept a request and return it with a fake client address."""
            request, client_addr = super().get_request()
            request.settimeout(self.request_timeout)
            # Code in the stdlib expects that get_request
            # will return a socket and a tuple (host, port).
            # However, this isn't true for UNIX sockets,
            # as the second return value will be a path;
            # hence we return some fake data sufficient
            # to get the tests going
            return request, ('127.0.0.1', '')


    class SilentUnixWSGIServer(UnixWSGIServer):
        """UNIX-socket WSGI server that ignores request errors."""

        def handle_error(self, request, client_address):
            """Ignore errors raised while handling a request."""
            pass


    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """Silent UNIX-socket WSGI server that speaks SSL."""
        pass


    def gen_unix_socket_path():
        """Return a temporary file name to use as a UNIX socket path.

        The temporary file itself is closed before this returns; only
        its name is kept.
        """
        with tempfile.NamedTemporaryFile() as file:
            return file.name


    @contextlib.contextmanager
    def unix_socket_path():
        """Context manager yielding a socket path that is unlinked on exit.

        A failure to unlink the path (OSError) is ignored.
        """
        path = gen_unix_socket_path()
        try:
            yield path
        finally:
            try:
                os.unlink(path)
            except OSError:
                pass


    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Context manager running a test WSGI server on a UNIX socket.

        Yields the server object.  With *use_ssl* true the server speaks
        SSL.
        """
        with unix_socket_path() as path:
            yield from _run_test_server(address=path, use_ssl=use_ssl,
                                        server_cls=SilentUnixWSGIServer,
                                        server_ssl_cls=UnixSSLWSGIServer)


@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Context manager running a test WSGI server on (host, port).

    Yields the server object.  With *use_ssl* true the server speaks SSL.
    """
    yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
                                server_cls=SilentWSGIServer,
                                server_ssl_cls=SSLWSGIServer)


def make_test_protocol(base):
    """Return an instance of a mocked-out subclass of *base*.

    Every attribute of *base* whose name is not a dunder name is
    replaced with a MockCallback that returns None.
    """
    dct = {}
    for name in dir(base):
        if name.startswith('__') and name.endswith('__'):
            # skip magic names
            continue
        dct[name] = MockCallback(return_value=None)
    return type('TestProtocol', (base,) + base.__bases__, dct)()


class TestSelector(selectors.BaseSelector):
    """Selector stub that records registrations and reports no events."""

    def __init__(self):
        # Maps each registered file object to its SelectorKey.
        self.keys = {}

    def register(self, fileobj, events, data=None):
        """Record *fileobj* and return its key (the fd is always 0)."""
        key = selectors.SelectorKey(fileobj, 0, events, data)
        self.keys[fileobj] = key
        return key

    def unregister(self, fileobj):
        """Forget *fileobj* and return its key; KeyError if unknown."""
        return self.keys.pop(fileobj)

    def select(self, timeout):
        """Return no ready events, regardless of *timeout*."""
        return []

    def get_map(self):
        """Return the mapping of registered file objects to keys."""
        return self.keys


class TestLoop(base_events.BaseEventLoop):
    """Loop for unittests.

    It manages its own time directly: time() returns an internal value
    that only changes through advance_time().

    Whenever handlers have been scheduled through call_at(), then at the
    end of the loop iteration the generator passed to __init__ is
    resumed once for each of them.

    The generator should look like this:

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value returned by yield is the absolute time of the next
    scheduled handler.  The value passed to yield is the time advance
    used to move the loop's time forward.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            def gen():
                yield
            self._check_on_close = False
        else:
            # A caller-supplied generator must be exhausted by close().
            self._check_on_close = True

        self._gen = gen()
        # Advance to the first yield so that send() can be used later.
        next(self._gen)
        self._time = 0
        self._clock_resolution = 1e-9
        self._timers = []
        self._selector = TestSelector()

        self.readers = {}
        self.writers = {}
        self.reset_counters()

        self._transports = weakref.WeakValueDictionary()

    def time(self):
        """Return the loop's current test time."""
        return self._time

    def advance_time(self, advance):
        """Move test time forward."""
        if advance:
            self._time += advance

    def close(self):
        """Close the loop and check that the time generator is finished.

        The check only applies when a generator was passed to __init__.
        It raises AssertionError if that generator yields again.
        """
        super().close()
        if self._check_on_close:
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    def _add_reader(self, fd, callback, *args):
        self.readers[fd] = events.Handle(callback, args, self)

    def _remove_reader(self, fd):
        # Count every removal attempt, whether or not fd was registered.
        self.remove_reader_count[fd] += 1
        if fd in self.readers:
            del self.readers[fd]
            return True
        else:
            return False

    def assert_reader(self, fd, callback, *args):
        """Assert that *fd* has a reader with this callback and args."""
        assert fd in self.readers, 'fd {} is not registered'.format(fd)
        handle = self.readers[fd]
        assert handle._callback == callback, '{!r} != {!r}'.format(
            handle._callback, callback)
        assert handle._args == args, '{!r} != {!r}'.format(
            handle._args, args)

    def _add_writer(self, fd, callback, *args):
        self.writers[fd] = events.Handle(callback, args, self)

    def _remove_writer(self, fd):
        # Count every removal attempt, whether or not fd was registered.
        self.remove_writer_count[fd] += 1
        if fd in self.writers:
            del self.writers[fd]
            return True
        else:
            return False

    def assert_writer(self, fd, callback, *args):
        """Assert that *fd* has a writer with this callback and args."""
        assert fd in self.writers, 'fd {} is not registered'.format(fd)
        handle = self.writers[fd]
        assert handle._callback == callback, '{!r} != {!r}'.format(
            handle._callback, callback)
        assert handle._args == args, '{!r} != {!r}'.format(
            handle._args, args)

    def _ensure_fd_no_transport(self, fd):
        """Raise RuntimeError if *fd* is recorded in self._transports."""
        try:
            transport = self._transports[fd]
        except KeyError:
            pass
        else:
            raise RuntimeError(
                'File descriptor {!r} is used by transport {!r}'.format(
                    fd, transport))

    def add_reader(self, fd, callback, *args):
        """Add a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_reader(fd, callback, *args)

    def remove_reader(self, fd):
        """Remove a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_reader(fd)

    def add_writer(self, fd, callback, *args):
        """Add a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_writer(fd, callback, *args)

    def remove_writer(self, fd):
        """Remove a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_writer(fd)

    def reset_counters(self):
        """Reset the per-fd counts of reader and writer removals."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Report each newly scheduled time to the generator and advance
        # the test clock by whatever it answers.
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args):
        """Schedule *callback* at *when* and remember the time."""
        self._timers.append(when)
        return super().call_at(when, callback, *args)

    def _process_events(self, event_list):
        return

    def _write_to_self(self):
        pass


def MockCallback(**kwargs):
    """Return a mock.Mock that only has __call__ in its spec."""
    return mock.Mock(spec=['__call__'], **kwargs)


class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    where a regex comparison between strings is needed.

    For instance:
       mock_call.assert_called_with(MockPattern('spam.*ham'))
    """
    def __eq__(self, other):
        return bool(re.search(str(self), other, re.S))


def get_function_source(func):
    """Return the source location of *func*.

    Raises ValueError if events._get_function_source() returns None.
    """
    source = events._get_function_source(func)
    if source is None:
        raise ValueError("unable to get the source of %r" % (func,))
    return source


class TestCase(unittest.TestCase):
    """Base test case that isolates tests from the global event loop."""

    def set_event_loop(self, loop, *, cleanup=True):
        """Clear the global event loop and optionally close *loop* later.

        If *cleanup* is true, loop.close is registered as a cleanup.
        """
        assert loop is not None
        # ensure that the event loop is passed explicitly in asyncio
        events.set_event_loop(None)
        if cleanup:
            self.addCleanup(loop.close)

    def new_test_loop(self, gen=None):
        """Create a TestLoop, pass it to set_event_loop(), and return it."""
        loop = TestLoop(gen)
        self.set_event_loop(loop)
        return loop

    def unpatch_get_running_loop(self):
        """Restore the events._get_running_loop saved by setUp()."""
        events._get_running_loop = self._get_running_loop

    def setUp(self):
        # Replace events._get_running_loop with a stub returning None;
        # the saved original is restored by unpatch_get_running_loop().
        self._get_running_loop = events._get_running_loop
        events._get_running_loop = lambda: None

    def tearDown(self):
        self.unpatch_get_running_loop()

        events.set_event_loop(None)

        # Detect CPython bug #23353: ensure that yield/yield-from is not used
        # in an except block of a generator
        self.assertEqual(sys.exc_info(), (None, None, None))

    if not compat.PY34:
        # Python 3.3 compatibility
        def subTest(self, *args, **kwargs):
            class EmptyCM:
                def __enter__(self):
                    pass
                def __exit__(self, *exc):
                    pass
            return EmptyCM()


@contextlib.contextmanager
def disable_logger():
    """Context manager to disable asyncio logger.

    For example, it can be used to ignore warnings in debug mode.
    """
    old_level = logger.level
    try:
        logger.setLevel(logging.CRITICAL+1)
        yield
    finally:
        logger.setLevel(old_level)


def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
                            family=socket.AF_INET):
    """Create a mock of a non-blocking socket."""
    sock = mock.MagicMock(socket.socket)
    sock.proto = proto
    sock.type = type
    sock.family = family
    sock.gettimeout.return_value = 0.0
    return sock


def force_legacy_ssl_support():
    """Return a patcher making _is_sslproto_available return False."""
    return mock.patch('asyncio.sslproto._is_sslproto_available',
                      return_value=False)