import asyncio
from contextlib import (
    asynccontextmanager, AbstractAsyncContextManager,
    AsyncExitStack, nullcontext, aclosing, contextmanager)
import functools
from test import support
import unittest

from test.test_contextlib import TestBaseExitStack


def _async_test(func):
    """Decorator to turn an async function into a test case.

    Each decorated test runs to completion on a brand-new event loop.  The
    loop is closed afterwards and the event loop policy is reset to the
    default, so state does not carry over between tests.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        coro = func(*args, **kwargs)
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()
            asyncio.set_event_loop_policy(None)
    return wrapper


class TestAbstractAsyncContextManager(unittest.TestCase):
    """Tests for contextlib.AbstractAsyncContextManager."""

    @_async_test
    async def test_enter(self):
        class DefaultEnter(AbstractAsyncContextManager):
            async def __aexit__(self, *args):
                await super().__aexit__(*args)

        manager = DefaultEnter()
        # The inherited __aenter__ returns the manager itself.
        self.assertIs(await manager.__aenter__(), manager)

        async with manager as context:
            self.assertIs(manager, context)

    @_async_test
    async def test_async_gen_propagates_generator_exit(self):
        # A regression test for https://bugs.python.org/issue33786.

        @asynccontextmanager
        async def ctx():
            yield

        async def gen():
            async with ctx():
                yield 11

        ret = []
        exc = ValueError(22)
        with self.assertRaises(ValueError):
            async with ctx():
                async for val in gen():
                    ret.append(val)
                    raise exc

        self.assertEqual(ret, [11])

    def test_exit_is_abstract(self):
        class MissingAexit(AbstractAsyncContextManager):
            pass

        # __aexit__ is abstract, so a subclass without it cannot be built.
        with self.assertRaises(TypeError):
            MissingAexit()

    def test_structural_subclassing(self):
        class ManagerFromScratch:
            async def __aenter__(self):
                return self
            async def __aexit__(self, exc_type, exc_value, traceback):
                return None

        # Any class defining both __aenter__ and __aexit__ counts as a
        # subclass, even without inheriting from the ABC.
        self.assertTrue(issubclass(ManagerFromScratch, AbstractAsyncContextManager))

        class DefaultEnter(AbstractAsyncContextManager):
            async def __aexit__(self, *args):
                await super().__aexit__(*args)

        self.assertTrue(issubclass(DefaultEnter, AbstractAsyncContextManager))

        # Setting either method to None opts the class out of the protocol.
        class NoneAenter(ManagerFromScratch):
            __aenter__ = None

        self.assertFalse(issubclass(NoneAenter, AbstractAsyncContextManager))

        class NoneAexit(ManagerFromScratch):
            __aexit__ = None

        self.assertFalse(issubclass(NoneAexit, AbstractAsyncContextManager))


class AsyncContextManagerTestCase(unittest.TestCase):
    """Tests for the contextlib.asynccontextmanager decorator."""

    @_async_test
    async def test_contextmanager_plain(self):
        state = []
        @asynccontextmanager
        async def woohoo():
            state.append(1)
            yield 42
            state.append(999)
        async with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_finally(self):
        state = []
        @asynccontextmanager
        async def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)
        with self.assertRaises(ZeroDivisionError):
            async with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_no_reraise(self):
        @asynccontextmanager
        async def whee():
            yield
        ctx = whee()
        await ctx.__aenter__()
        # Calling __aexit__ should not result in an exception
        self.assertFalse(await ctx.__aexit__(TypeError, TypeError("foo"), None))

    @_async_test
    async def test_contextmanager_trap_yield_after_throw(self):
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except:  # deliberately catch everything, then misbehave by yielding
                yield
        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(RuntimeError):
            await ctx.__aexit__(TypeError, TypeError('foo'), None)

    @_async_test
    async def test_contextmanager_trap_no_yield(self):
        @asynccontextmanager
        async def whoo():
            if False:
                yield  # makes this an async generator that never yields
        ctx = whoo()
        with self.assertRaises(RuntimeError):
            await ctx.__aenter__()

    @_async_test
    async def test_contextmanager_trap_second_yield(self):
        @asynccontextmanager
        async def whoo():
            yield
            yield
        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(RuntimeError):
            await ctx.__aexit__(None, None, None)

    @_async_test
    async def test_contextmanager_non_normalised(self):
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except RuntimeError:
                raise SyntaxError

        ctx = whoo()
        await ctx.__aenter__()
        # Only the exception type is passed (value is None); __aexit__ must
        # still be able to throw it into the generator.
        with self.assertRaises(SyntaxError):
            await ctx.__aexit__(RuntimeError, None, None)

    @_async_test
    async def test_contextmanager_except(self):
        state = []
        @asynccontextmanager
        async def woohoo():
            state.append(1)
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])
        async with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_except_stopiter(self):
        @asynccontextmanager
        async def woohoo():
            yield

        class StopIterationSubclass(StopIteration):
            pass

        class StopAsyncIterationSubclass(StopAsyncIteration):
            pass

        for stop_exc in (
            StopIteration('spam'),
            StopAsyncIteration('ham'),
            StopIterationSubclass('spam'),
            StopAsyncIterationSubclass('spam')
        ):
            with self.subTest(type=type(stop_exc)):
                try:
                    async with woohoo():
                        raise stop_exc
                except Exception as ex:
                    # The very same exception object must propagate.
                    self.assertIs(ex, stop_exc)
                else:
                    self.fail(f'{stop_exc} was suppressed')

    @_async_test
    async def test_contextmanager_wrap_runtimeerror(self):
        @asynccontextmanager
        async def woohoo():
            try:
                yield
            except Exception as exc:
                raise RuntimeError(f'caught {exc}') from exc

        with self.assertRaises(RuntimeError):
            async with woohoo():
                1 / 0

        # If the context manager wrapped StopAsyncIteration in a RuntimeError,
        # we also unwrap it, because we can't tell whether the wrapping was
        # done by the generator machinery or by the generator itself.
        with self.assertRaises(StopAsyncIteration):
            async with woohoo():
                raise StopAsyncIteration

    def _create_contextmanager_attribs(self):
        # Build an asynccontextmanager whose underlying function carries a
        # docstring and a custom attribute, to check metadata propagation.
        def attribs(**kw):
            def decorate(func):
                for k,v in kw.items():
                    setattr(func,k,v)
                return func
            return decorate
        @asynccontextmanager
        @attribs(foo='bar')
        async def baz(spam):
            """Whee!"""
            yield
        return baz

    def test_contextmanager_attribs(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__,'baz')
        self.assertEqual(baz.foo, 'bar')

    @support.requires_docstrings
    def test_contextmanager_doc_attrib(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")

    @support.requires_docstrings
    @_async_test
    async def test_instance_docstring_given_cm_docstring(self):
        baz = self._create_contextmanager_attribs()(None)
        self.assertEqual(baz.__doc__, "Whee!")
        async with baz:
            pass  # suppress warning

    @_async_test
    async def test_keywords(self):
        # Ensure no keyword arguments are inhibited
        @asynccontextmanager
        async def woohoo(self, func, args, kwds):
            yield (self, func, args, kwds)
        async with woohoo(self=11, func=22, args=33, kwds=44) as target:
            self.assertEqual(target, (11, 22, 33, 44))

    @_async_test
    async def test_recursive(self):
        depth = 0
        ncols = 0

        @asynccontextmanager
        async def woohoo():
            nonlocal ncols
            ncols += 1

            nonlocal depth
            before = depth
            depth += 1
            yield
            depth -= 1
            self.assertEqual(depth, before)

        # Used as a decorator, the context manager is entered once per call
        # of the decorated function (ncols counts those entries).
        @woohoo()
        async def recursive():
            if depth < 10:
                await recursive()

        await recursive()

        self.assertEqual(ncols, 10)
        self.assertEqual(depth, 0)


class AclosingTestCase(unittest.TestCase):
    """Tests for contextlib.aclosing."""

    @support.requires_docstrings
    def test_instance_docs(self):
        cm_docstring = aclosing.__doc__
        obj = aclosing(None)
        self.assertEqual(obj.__doc__, cm_docstring)

    @_async_test
    async def test_aclosing(self):
        state = []
        class C:
            async def aclose(self):
                state.append(1)
        x = C()
        self.assertEqual(state, [])
        async with aclosing(x) as y:
            self.assertEqual(x, y)
        self.assertEqual(state, [1])

    @_async_test
    async def test_aclosing_error(self):
        state = []
        class C:
            async def aclose(self):
                state.append(1)
        x = C()
        self.assertEqual(state, [])
        with self.assertRaises(ZeroDivisionError):
            async with aclosing(x) as y:
                self.assertEqual(x, y)
                1 / 0
        self.assertEqual(state, [1])

    @_async_test
    async def test_aclosing_bpo41229(self):
        state = []

        @contextmanager
        def sync_resource():
            try:
                yield
            finally:
                state.append(1)

        async def agenfunc():
            with sync_resource():
                yield -1
                yield -2

        x = agenfunc()
        self.assertEqual(state, [])
        with self.assertRaises(ZeroDivisionError):
            async with aclosing(x) as y:
                self.assertEqual(x, y)
                self.assertEqual(-1, await x.__anext__())
                1 / 0
        # Closing the suspended async generator ran the inner finally block.
        self.assertEqual(state, [1])


class TestAsyncExitStack(TestBaseExitStack, unittest.TestCase):
    """Tests for contextlib.AsyncExitStack.

    Also reuses the tests inherited from TestBaseExitStack by exposing a
    synchronous adapter as ``exit_stack`` (presumably the attribute those
    shared tests instantiate — confirm against test_contextlib).
    """

    class SyncAsyncExitStack(AsyncExitStack):
        """AsyncExitStack driven synchronously on the current event loop."""

        @staticmethod
        def run_coroutine(coro):
            loop = asyncio.get_event_loop()

            f = asyncio.ensure_future(coro)
            f.add_done_callback(lambda f: loop.stop())
            loop.run_forever()

            exc = f.exception()

            if not exc:
                return f.result()
            else:
                # Re-raising here may overwrite exc.__context__ with whatever
                # exception the caller is currently handling; save and restore
                # the context chain that the coroutine produced.
                context = exc.__context__

                try:
                    raise exc
                except BaseException:
                    exc.__context__ = context
                    raise exc

        def close(self):
            return self.run_coroutine(self.aclose())

        def __enter__(self):
            return self.run_coroutine(self.__aenter__())

        def __exit__(self, *exc_details):
            return self.run_coroutine(self.__aexit__(*exc_details))

    exit_stack = SyncAsyncExitStack

    def setUp(self):
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.addCleanup(self.loop.close)
        self.addCleanup(asyncio.set_event_loop_policy, None)

    @_async_test
    async def test_async_callback(self):
        expected = [
            ((), {}),
            ((1,), {}),
            ((1,2), {}),
            ((), dict(example=1)),
            ((1,), dict(example=1)),
            ((1,2), dict(example=1)),
        ]
        result = []
        async def _exit(*args, **kwds):
            """Test metadata propagation"""
            result.append((args, kwds))

        async with AsyncExitStack() as stack:
            # Pushed in reverse so that LIFO unwinding yields `expected` order.
            for args, kwds in reversed(expected):
                if args and kwds:
                    f = stack.push_async_callback(_exit, *args, **kwds)
                elif args:
                    f = stack.push_async_callback(_exit, *args)
                elif kwds:
                    f = stack.push_async_callback(_exit, **kwds)
                else:
                    f = stack.push_async_callback(_exit)
                self.assertIs(f, _exit)
            for wrapper in stack._exit_callbacks:
                self.assertIs(wrapper[1].__wrapped__, _exit)
                self.assertNotEqual(wrapper[1].__name__, _exit.__name__)
                self.assertIsNone(wrapper[1].__doc__, _exit.__doc__)

        self.assertEqual(result, expected)

        result = []
        async with AsyncExitStack() as stack:
            with self.assertRaises(TypeError):
                stack.push_async_callback(arg=1)
            with self.assertRaises(TypeError):
                self.exit_stack.push_async_callback(arg=2)
            with self.assertRaises(TypeError):
                stack.push_async_callback(callback=_exit, arg=3)
        self.assertEqual(result, [])

    @_async_test
    async def test_async_push(self):
        exc_raised = ZeroDivisionError
        async def _expect_exc(exc_type, exc, exc_tb):
            self.assertIs(exc_type, exc_raised)
        async def _suppress_exc(*exc_details):
            return True
        async def _expect_ok(exc_type, exc, exc_tb):
            self.assertIsNone(exc_type)
            self.assertIsNone(exc)
            self.assertIsNone(exc_tb)
        class ExitCM(object):
            def __init__(self, check_exc):
                self.check_exc = check_exc
            async def __aenter__(self):
                # `self` is the ExitCM here, not the TestCase, so it has no
                # fail(); raise the assertion error directly instead.
                raise AssertionError("Should not be called!")
            async def __aexit__(self, *exc_details):
                await self.check_exc(*exc_details)

        async with self.exit_stack() as stack:
            stack.push_async_exit(_expect_ok)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
            cm = ExitCM(_expect_ok)
            stack.push_async_exit(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            # Callbacks below this one see the exception as suppressed.
            stack.push_async_exit(_suppress_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc)
            cm = ExitCM(_expect_exc)
            stack.push_async_exit(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            stack.push_async_exit(_expect_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
            stack.push_async_exit(_expect_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
            1/0

    @_async_test
    async def test_async_enter_context(self):
        class TestCM(object):
            async def __aenter__(self):
                result.append(1)
            async def __aexit__(self, *exc_details):
                result.append(3)

        result = []
        cm = TestCM()

        async with AsyncExitStack() as stack:
            @stack.push_async_callback  # Registered first => cleaned up last
            async def _exit():
                result.append(4)
            self.assertIsNotNone(_exit)
            await stack.enter_async_context(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            result.append(2)

        self.assertEqual(result, [1, 2, 3, 4])

    @_async_test
    async def test_async_exit_exception_chaining(self):
        # Ensure exception chaining matches the reference behaviour
        async def raise_exc(exc):
            raise exc

        saved_details = None
        async def suppress_exc(*exc_details):
            nonlocal saved_details
            saved_details = exc_details
            return True

        try:
            async with self.exit_stack() as stack:
                stack.push_async_callback(raise_exc, IndexError)
                stack.push_async_callback(raise_exc, KeyError)
                stack.push_async_callback(raise_exc, AttributeError)
                stack.push_async_exit(suppress_exc)
                stack.push_async_callback(raise_exc, ValueError)
                1 / 0
        except IndexError as exc:
            self.assertIsInstance(exc.__context__, KeyError)
            self.assertIsInstance(exc.__context__.__context__, AttributeError)
            # Inner exceptions were suppressed
            self.assertIsNone(exc.__context__.__context__.__context__)
        else:
            self.fail("Expected IndexError, but no exception was raised")
        # Check the inner exceptions
        inner_exc = saved_details[1]
        self.assertIsInstance(inner_exc, ValueError)
        self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)

    @_async_test
    async def test_async_exit_exception_explicit_none_context(self):
        # Ensure AsyncExitStack chaining matches actual nested `with` statements
        # regarding explicit __context__ = None.

        class MyException(Exception):
            pass

        @asynccontextmanager
        async def my_cm():
            try:
                yield
            except BaseException:
                exc = MyException()
                try:
                    raise exc
                finally:
                    exc.__context__ = None

        @asynccontextmanager
        async def my_cm_with_exit_stack():
            async with self.exit_stack() as stack:
                await stack.enter_async_context(my_cm())
                yield stack

        for cm in (my_cm, my_cm_with_exit_stack):
            with self.subTest():
                try:
                    async with cm():
                        raise IndexError()
                except MyException as exc:
                    self.assertIsNone(exc.__context__)
                else:
                    self.fail("Expected MyException, but no exception was raised")


class TestAsyncNullcontext(unittest.TestCase):
    """Tests for using contextlib.nullcontext with `async with`."""

    @_async_test
    async def test_async_nullcontext(self):
        class C:
            pass
        c = C()
        async with nullcontext(c) as c_in:
            self.assertIs(c_in, c)


if __name__ == '__main__':
    unittest.main()