1# Adapted with permission from the EdgeDB project; 2# license: PSFL. 3 4import gc 5import asyncio 6import contextvars 7import contextlib 8from asyncio import taskgroups 9import unittest 10import warnings 11 12from test.test_asyncio.utils import await_without_task 13 14# To prevent a warning "test altered the execution environment" 15def tearDownModule(): 16 asyncio.set_event_loop_policy(None) 17 18 19class MyExc(Exception): 20 pass 21 22 23class MyBaseExc(BaseException): 24 pass 25 26 27def get_error_types(eg): 28 return {type(exc) for exc in eg.exceptions} 29 30 31class TestTaskGroup(unittest.IsolatedAsyncioTestCase): 32 33 async def test_taskgroup_01(self): 34 35 async def foo1(): 36 await asyncio.sleep(0.1) 37 return 42 38 39 async def foo2(): 40 await asyncio.sleep(0.2) 41 return 11 42 43 async with taskgroups.TaskGroup() as g: 44 t1 = g.create_task(foo1()) 45 t2 = g.create_task(foo2()) 46 47 self.assertEqual(t1.result(), 42) 48 self.assertEqual(t2.result(), 11) 49 50 async def test_taskgroup_02(self): 51 52 async def foo1(): 53 await asyncio.sleep(0.1) 54 return 42 55 56 async def foo2(): 57 await asyncio.sleep(0.2) 58 return 11 59 60 async with taskgroups.TaskGroup() as g: 61 t1 = g.create_task(foo1()) 62 await asyncio.sleep(0.15) 63 t2 = g.create_task(foo2()) 64 65 self.assertEqual(t1.result(), 42) 66 self.assertEqual(t2.result(), 11) 67 68 async def test_taskgroup_03(self): 69 70 async def foo1(): 71 await asyncio.sleep(1) 72 return 42 73 74 async def foo2(): 75 await asyncio.sleep(0.2) 76 return 11 77 78 async with taskgroups.TaskGroup() as g: 79 t1 = g.create_task(foo1()) 80 await asyncio.sleep(0.15) 81 # cancel t1 explicitly, i.e. everything should continue 82 # working as expected. 83 t1.cancel() 84 85 t2 = g.create_task(foo2()) 86 87 self.assertTrue(t1.cancelled()) 88 self.assertEqual(t2.result(), 11) 89 90 async def test_taskgroup_04(self): 91 92 NUM = 0 93 t2_cancel = False 94 t2 = None 95 96 async def foo1(): 97 await asyncio.sleep(0.1) 98 1 / 0 99 100 async def foo2(): 101 nonlocal NUM, t2_cancel 102 try: 103 await asyncio.sleep(1) 104 except asyncio.CancelledError: 105 t2_cancel = True 106 raise 107 NUM += 1 108 109 async def runner(): 110 nonlocal NUM, t2 111 112 async with taskgroups.TaskGroup() as g: 113 g.create_task(foo1()) 114 t2 = g.create_task(foo2()) 115 116 NUM += 10 117 118 with self.assertRaises(ExceptionGroup) as cm: 119 await asyncio.create_task(runner()) 120 121 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) 122 123 self.assertEqual(NUM, 0) 124 self.assertTrue(t2_cancel) 125 self.assertTrue(t2.cancelled()) 126 127 async def test_cancel_children_on_child_error(self): 128 # When a child task raises an error, the rest of the children 129 # are cancelled and the errors are gathered into an EG. 130 131 NUM = 0 132 t2_cancel = False 133 runner_cancel = False 134 135 async def foo1(): 136 await asyncio.sleep(0.1) 137 1 / 0 138 139 async def foo2(): 140 nonlocal NUM, t2_cancel 141 try: 142 await asyncio.sleep(5) 143 except asyncio.CancelledError: 144 t2_cancel = True 145 raise 146 NUM += 1 147 148 async def runner(): 149 nonlocal NUM, runner_cancel 150 151 async with taskgroups.TaskGroup() as g: 152 g.create_task(foo1()) 153 g.create_task(foo1()) 154 g.create_task(foo1()) 155 g.create_task(foo2()) 156 try: 157 await asyncio.sleep(10) 158 except asyncio.CancelledError: 159 runner_cancel = True 160 raise 161 162 NUM += 10 163 164 # The 3 foo1 sub tasks can be racy when the host is busy - if the 165 # cancellation happens in the middle, we'll see partial sub errors here 166 with self.assertRaises(ExceptionGroup) as cm: 167 await asyncio.create_task(runner()) 168 169 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) 170 self.assertEqual(NUM, 0) 171 self.assertTrue(t2_cancel) 172 self.assertTrue(runner_cancel) 173 174 async def test_cancellation(self): 175 176 NUM = 0 177 178 async def foo(): 179 nonlocal NUM 180 try: 181 await asyncio.sleep(5) 182 except asyncio.CancelledError: 183 NUM += 1 184 raise 185 186 async def runner(): 187 async with taskgroups.TaskGroup() as g: 188 for _ in range(5): 189 g.create_task(foo()) 190 191 r = asyncio.create_task(runner()) 192 await asyncio.sleep(0.1) 193 194 self.assertFalse(r.done()) 195 r.cancel() 196 with self.assertRaises(asyncio.CancelledError) as cm: 197 await r 198 199 self.assertEqual(NUM, 5) 200 201 async def test_taskgroup_07(self): 202 203 NUM = 0 204 205 async def foo(): 206 nonlocal NUM 207 try: 208 await asyncio.sleep(5) 209 except asyncio.CancelledError: 210 NUM += 1 211 raise 212 213 async def runner(): 214 nonlocal NUM 215 async with taskgroups.TaskGroup() as g: 216 for _ in range(5): 217 g.create_task(foo()) 218 219 try: 220 await asyncio.sleep(10) 221 except asyncio.CancelledError: 222 NUM += 10 223 raise 224 225 r = asyncio.create_task(runner()) 226 await asyncio.sleep(0.1) 227 228 self.assertFalse(r.done()) 229 r.cancel() 230 with self.assertRaises(asyncio.CancelledError): 231 await r 232 233 self.assertEqual(NUM, 15) 234 235 async def test_taskgroup_08(self): 236 237 async def foo(): 238 try: 239 await asyncio.sleep(10) 240 finally: 241 1 / 0 242 243 async def runner(): 244 async with taskgroups.TaskGroup() as g: 245 for _ in range(5): 246 g.create_task(foo()) 247 248 await asyncio.sleep(10) 249 250 r = asyncio.create_task(runner()) 251 await asyncio.sleep(0.1) 252 253 self.assertFalse(r.done()) 254 r.cancel() 255 with self.assertRaises(ExceptionGroup) as cm: 256 await r 257 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) 258 259 async def test_taskgroup_09(self): 260 261 t1 = t2 = None 262 263 async def foo1(): 264 await asyncio.sleep(1) 265 return 42 266 267 async def foo2(): 268 await asyncio.sleep(2) 269 return 11 270 271 async def runner(): 272 nonlocal t1, t2 273 async with taskgroups.TaskGroup() as g: 274 t1 = g.create_task(foo1()) 275 t2 = g.create_task(foo2()) 276 await asyncio.sleep(0.1) 277 1 / 0 278 279 try: 280 await runner() 281 except ExceptionGroup as t: 282 self.assertEqual(get_error_types(t), {ZeroDivisionError}) 283 else: 284 self.fail('ExceptionGroup was not raised') 285 286 self.assertTrue(t1.cancelled()) 287 self.assertTrue(t2.cancelled()) 288 289 async def test_taskgroup_10(self): 290 291 t1 = t2 = None 292 293 async def foo1(): 294 await asyncio.sleep(1) 295 return 42 296 297 async def foo2(): 298 await asyncio.sleep(2) 299 return 11 300 301 async def runner(): 302 nonlocal t1, t2 303 async with taskgroups.TaskGroup() as g: 304 t1 = g.create_task(foo1()) 305 t2 = g.create_task(foo2()) 306 1 / 0 307 308 try: 309 await runner() 310 except ExceptionGroup as t: 311 self.assertEqual(get_error_types(t), {ZeroDivisionError}) 312 else: 313 self.fail('ExceptionGroup was not raised') 314 315 self.assertTrue(t1.cancelled()) 316 self.assertTrue(t2.cancelled()) 317 318 async def test_taskgroup_11(self): 319 320 async def foo(): 321 try: 322 await asyncio.sleep(10) 323 finally: 324 1 / 0 325 326 async def runner(): 327 async with taskgroups.TaskGroup(): 328 async with taskgroups.TaskGroup() as g2: 329 for _ in range(5): 330 g2.create_task(foo()) 331 332 await asyncio.sleep(10) 333 334 r = asyncio.create_task(runner()) 335 await asyncio.sleep(0.1) 336 337 self.assertFalse(r.done()) 338 r.cancel() 339 with self.assertRaises(ExceptionGroup) as cm: 340 await r 341 342 self.assertEqual(get_error_types(cm.exception), {ExceptionGroup}) 343 self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError}) 344 345 async def test_taskgroup_12(self): 346 347 async def foo(): 348 try: 349 await asyncio.sleep(10) 350 finally: 351 1 / 0 352 353 async def runner(): 354 async with taskgroups.TaskGroup() as g1: 355 g1.create_task(asyncio.sleep(10)) 356 357 async with taskgroups.TaskGroup() as g2: 358 for _ in range(5): 359 g2.create_task(foo()) 360 361 await asyncio.sleep(10) 362 363 r = asyncio.create_task(runner()) 364 await asyncio.sleep(0.1) 365 366 self.assertFalse(r.done()) 367 r.cancel() 368 with self.assertRaises(ExceptionGroup) as cm: 369 await r 370 371 self.assertEqual(get_error_types(cm.exception), {ExceptionGroup}) 372 self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError}) 373 374 async def test_taskgroup_13(self): 375 376 async def crash_after(t): 377 await asyncio.sleep(t) 378 raise ValueError(t) 379 380 async def runner(): 381 async with taskgroups.TaskGroup() as g1: 382 g1.create_task(crash_after(0.1)) 383 384 async with taskgroups.TaskGroup() as g2: 385 g2.create_task(crash_after(10)) 386 387 r = asyncio.create_task(runner()) 388 with self.assertRaises(ExceptionGroup) as cm: 389 await r 390 391 self.assertEqual(get_error_types(cm.exception), {ValueError}) 392 393 async def test_taskgroup_14(self): 394 395 async def crash_after(t): 396 await asyncio.sleep(t) 397 raise ValueError(t) 398 399 async def runner(): 400 async with taskgroups.TaskGroup() as g1: 401 g1.create_task(crash_after(10)) 402 403 async with taskgroups.TaskGroup() as g2: 404 g2.create_task(crash_after(0.1)) 405 406 r = asyncio.create_task(runner()) 407 with self.assertRaises(ExceptionGroup) as cm: 408 await r 409 410 self.assertEqual(get_error_types(cm.exception), {ExceptionGroup}) 411 self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ValueError}) 412 413 async def test_taskgroup_15(self): 414 415 async def crash_soon(): 416 await asyncio.sleep(0.3) 417 1 / 0 418 419 async def runner(): 420 async with taskgroups.TaskGroup() as g1: 421 g1.create_task(crash_soon()) 422 try: 423 await asyncio.sleep(10) 424 except asyncio.CancelledError: 425 await asyncio.sleep(0.5) 426 raise 427 428 r = asyncio.create_task(runner()) 429 await asyncio.sleep(0.1) 430 431 self.assertFalse(r.done()) 432 r.cancel() 433 with self.assertRaises(ExceptionGroup) as cm: 434 await r 435 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) 436 437 async def test_taskgroup_16(self): 438 439 async def crash_soon(): 440 await asyncio.sleep(0.3) 441 1 / 0 442 443 async def nested_runner(): 444 async with taskgroups.TaskGroup() as g1: 445 g1.create_task(crash_soon()) 446 try: 447 await asyncio.sleep(10) 448 except asyncio.CancelledError: 449 await asyncio.sleep(0.5) 450 raise 451 452 async def runner(): 453 t = asyncio.create_task(nested_runner()) 454 await t 455 456 r = asyncio.create_task(runner()) 457 await asyncio.sleep(0.1) 458 459 self.assertFalse(r.done()) 460 r.cancel() 461 with self.assertRaises(ExceptionGroup) as cm: 462 await r 463 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) 464 465 async def test_taskgroup_17(self): 466 NUM = 0 467 468 async def runner(): 469 nonlocal NUM 470 async with taskgroups.TaskGroup(): 471 try: 472 await asyncio.sleep(10) 473 except asyncio.CancelledError: 474 NUM += 10 475 raise 476 477 r = asyncio.create_task(runner()) 478 await asyncio.sleep(0.1) 479 480 self.assertFalse(r.done()) 481 r.cancel() 482 with self.assertRaises(asyncio.CancelledError): 483 await r 484 485 self.assertEqual(NUM, 10) 486 487 async def test_taskgroup_18(self): 488 NUM = 0 489 490 async def runner(): 491 nonlocal NUM 492 async with taskgroups.TaskGroup(): 493 try: 494 await asyncio.sleep(10) 495 except asyncio.CancelledError: 496 NUM += 10 497 # This isn't a good idea, but we have to support 498 # this weird case. 499 raise MyExc 500 501 r = asyncio.create_task(runner()) 502 await asyncio.sleep(0.1) 503 504 self.assertFalse(r.done()) 505 r.cancel() 506 507 try: 508 await r 509 except ExceptionGroup as t: 510 self.assertEqual(get_error_types(t),{MyExc}) 511 else: 512 self.fail('ExceptionGroup was not raised') 513 514 self.assertEqual(NUM, 10) 515 516 async def test_taskgroup_19(self): 517 async def crash_soon(): 518 await asyncio.sleep(0.1) 519 1 / 0 520 521 async def nested(): 522 try: 523 await asyncio.sleep(10) 524 finally: 525 raise MyExc 526 527 async def runner(): 528 async with taskgroups.TaskGroup() as g: 529 g.create_task(crash_soon()) 530 await nested() 531 532 r = asyncio.create_task(runner()) 533 try: 534 await r 535 except ExceptionGroup as t: 536 self.assertEqual(get_error_types(t), {MyExc, ZeroDivisionError}) 537 else: 538 self.fail('TasgGroupError was not raised') 539 540 async def test_taskgroup_20(self): 541 async def crash_soon(): 542 await asyncio.sleep(0.1) 543 1 / 0 544 545 async def nested(): 546 try: 547 await asyncio.sleep(10) 548 finally: 549 raise KeyboardInterrupt 550 551 async def runner(): 552 async with taskgroups.TaskGroup() as g: 553 g.create_task(crash_soon()) 554 await nested() 555 556 with self.assertRaises(KeyboardInterrupt): 557 await runner() 558 559 async def test_taskgroup_20a(self): 560 async def crash_soon(): 561 await asyncio.sleep(0.1) 562 1 / 0 563 564 async def nested(): 565 try: 566 await asyncio.sleep(10) 567 finally: 568 raise MyBaseExc 569 570 async def runner(): 571 async with taskgroups.TaskGroup() as g: 572 g.create_task(crash_soon()) 573 await nested() 574 575 with self.assertRaises(BaseExceptionGroup) as cm: 576 await runner() 577 578 self.assertEqual( 579 get_error_types(cm.exception), {MyBaseExc, ZeroDivisionError} 580 ) 581 582 async def _test_taskgroup_21(self): 583 # This test doesn't work as asyncio, currently, doesn't 584 # correctly propagate KeyboardInterrupt (or SystemExit) -- 585 # those cause the event loop itself to crash. 586 # (Compare to the previous (passing) test -- that one raises 587 # a plain exception but raises KeyboardInterrupt in nested(); 588 # this test does it the other way around.) 589 590 async def crash_soon(): 591 await asyncio.sleep(0.1) 592 raise KeyboardInterrupt 593 594 async def nested(): 595 try: 596 await asyncio.sleep(10) 597 finally: 598 raise TypeError 599 600 async def runner(): 601 async with taskgroups.TaskGroup() as g: 602 g.create_task(crash_soon()) 603 await nested() 604 605 with self.assertRaises(KeyboardInterrupt): 606 await runner() 607 608 async def test_taskgroup_21a(self): 609 610 async def crash_soon(): 611 await asyncio.sleep(0.1) 612 raise MyBaseExc 613 614 async def nested(): 615 try: 616 await asyncio.sleep(10) 617 finally: 618 raise TypeError 619 620 async def runner(): 621 async with taskgroups.TaskGroup() as g: 622 g.create_task(crash_soon()) 623 await nested() 624 625 with self.assertRaises(BaseExceptionGroup) as cm: 626 await runner() 627 628 self.assertEqual(get_error_types(cm.exception), {MyBaseExc, TypeError}) 629 630 async def test_taskgroup_22(self): 631 632 async def foo1(): 633 await asyncio.sleep(1) 634 return 42 635 636 async def foo2(): 637 await asyncio.sleep(2) 638 return 11 639 640 async def runner(): 641 async with taskgroups.TaskGroup() as g: 642 g.create_task(foo1()) 643 g.create_task(foo2()) 644 645 r = asyncio.create_task(runner()) 646 await asyncio.sleep(0.05) 647 r.cancel() 648 649 with self.assertRaises(asyncio.CancelledError): 650 await r 651 652 async def test_taskgroup_23(self): 653 654 async def do_job(delay): 655 await asyncio.sleep(delay) 656 657 async with taskgroups.TaskGroup() as g: 658 for count in range(10): 659 await asyncio.sleep(0.1) 660 g.create_task(do_job(0.3)) 661 if count == 5: 662 self.assertLess(len(g._tasks), 5) 663 await asyncio.sleep(1.35) 664 self.assertEqual(len(g._tasks), 0) 665 666 async def test_taskgroup_24(self): 667 668 async def root(g): 669 await asyncio.sleep(0.1) 670 g.create_task(coro1(0.1)) 671 g.create_task(coro1(0.2)) 672 673 async def coro1(delay): 674 await asyncio.sleep(delay) 675 676 async def runner(): 677 async with taskgroups.TaskGroup() as g: 678 g.create_task(root(g)) 679 680 await runner() 681 682 async def test_taskgroup_25(self): 683 nhydras = 0 684 685 async def hydra(g): 686 nonlocal nhydras 687 nhydras += 1 688 await asyncio.sleep(0.01) 689 g.create_task(hydra(g)) 690 g.create_task(hydra(g)) 691 692 async def hercules(): 693 while nhydras < 10: 694 await asyncio.sleep(0.015) 695 1 / 0 696 697 async def runner(): 698 async with taskgroups.TaskGroup() as g: 699 g.create_task(hydra(g)) 700 g.create_task(hercules()) 701 702 with self.assertRaises(ExceptionGroup) as cm: 703 await runner() 704 705 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) 706 self.assertGreaterEqual(nhydras, 10) 707 708 async def test_taskgroup_task_name(self): 709 async def coro(): 710 await asyncio.sleep(0) 711 async with taskgroups.TaskGroup() as g: 712 t = g.create_task(coro(), name="yolo") 713 self.assertEqual(t.get_name(), "yolo") 714 715 async def test_taskgroup_task_context(self): 716 cvar = contextvars.ContextVar('cvar') 717 718 async def coro(val): 719 await asyncio.sleep(0) 720 cvar.set(val) 721 722 async with taskgroups.TaskGroup() as g: 723 ctx = contextvars.copy_context() 724 self.assertIsNone(ctx.get(cvar)) 725 t1 = g.create_task(coro(1), context=ctx) 726 await t1 727 self.assertEqual(1, ctx.get(cvar)) 728 t2 = g.create_task(coro(2), context=ctx) 729 await t2 730 self.assertEqual(2, ctx.get(cvar)) 731 732 async def test_taskgroup_no_create_task_after_failure(self): 733 async def coro1(): 734 await asyncio.sleep(0.001) 735 1 / 0 736 async def coro2(g): 737 try: 738 await asyncio.sleep(1) 739 except asyncio.CancelledError: 740 with self.assertRaises(RuntimeError): 741 g.create_task(coro1()) 742 743 with self.assertRaises(ExceptionGroup) as cm: 744 async with taskgroups.TaskGroup() as g: 745 g.create_task(coro1()) 746 g.create_task(coro2(g)) 747 748 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) 749 750 async def test_taskgroup_context_manager_exit_raises(self): 751 # See https://github.com/python/cpython/issues/95289 752 class CustomException(Exception): 753 pass 754 755 async def raise_exc(): 756 raise CustomException 757 758 @contextlib.asynccontextmanager 759 async def database(): 760 try: 761 yield 762 finally: 763 raise CustomException 764 765 async def main(): 766 task = asyncio.current_task() 767 try: 768 async with taskgroups.TaskGroup() as tg: 769 async with database(): 770 tg.create_task(raise_exc()) 771 await asyncio.sleep(1) 772 except* CustomException as err: 773 self.assertEqual(task.cancelling(), 0) 774 self.assertEqual(len(err.exceptions), 2) 775 776 else: 777 self.fail('CustomException not raised') 778 779 await asyncio.create_task(main()) 780 781 async def test_taskgroup_already_entered(self): 782 tg = taskgroups.TaskGroup() 783 async with tg: 784 with self.assertRaisesRegex(RuntimeError, "has already been entered"): 785 async with tg: 786 pass 787 788 async def test_taskgroup_double_enter(self): 789 tg = taskgroups.TaskGroup() 790 async with tg: 791 pass 792 with self.assertRaisesRegex(RuntimeError, "has already been entered"): 793 async with tg: 794 pass 795 796 async def test_taskgroup_finished(self): 797 async def create_task_after_tg_finish(): 798 tg = taskgroups.TaskGroup() 799 async with tg: 800 pass 801 coro = asyncio.sleep(0) 802 with self.assertRaisesRegex(RuntimeError, "is finished"): 803 tg.create_task(coro) 804 805 # Make sure the coroutine was closed when submitted to the inactive tg 806 # (if not closed, a RuntimeWarning should have been raised) 807 with warnings.catch_warnings(record=True) as w: 808 await create_task_after_tg_finish() 809 self.assertEqual(len(w), 0) 810 811 async def test_taskgroup_not_entered(self): 812 tg = taskgroups.TaskGroup() 813 coro = asyncio.sleep(0) 814 with self.assertRaisesRegex(RuntimeError, "has not been entered"): 815 tg.create_task(coro) 816 817 async def test_taskgroup_without_parent_task(self): 818 tg = taskgroups.TaskGroup() 819 with self.assertRaisesRegex(RuntimeError, "parent task"): 820 await await_without_task(tg.__aenter__()) 821 coro = asyncio.sleep(0) 822 with self.assertRaisesRegex(RuntimeError, "has not been entered"): 823 tg.create_task(coro) 824 825 def test_coro_closed_when_tg_closed(self): 826 async def run_coro_after_tg_closes(): 827 async with taskgroups.TaskGroup() as tg: 828 pass 829 coro = asyncio.sleep(0) 830 with self.assertRaisesRegex(RuntimeError, "is finished"): 831 tg.create_task(coro) 832 loop = asyncio.get_event_loop() 833 loop.run_until_complete(run_coro_after_tg_closes()) 834 835 async def test_cancelling_level_preserved(self): 836 async def raise_after(t, e): 837 await asyncio.sleep(t) 838 raise e() 839 840 try: 841 async with asyncio.TaskGroup() as tg: 842 tg.create_task(raise_after(0.0, RuntimeError)) 843 except* RuntimeError: 844 pass 845 self.assertEqual(asyncio.current_task().cancelling(), 0) 846 847 async def test_nested_groups_both_cancelled(self): 848 async def raise_after(t, e): 849 await asyncio.sleep(t) 850 raise e() 851 852 try: 853 async with asyncio.TaskGroup() as outer_tg: 854 try: 855 async with asyncio.TaskGroup() as inner_tg: 856 inner_tg.create_task(raise_after(0, RuntimeError)) 857 outer_tg.create_task(raise_after(0, ValueError)) 858 except* RuntimeError: 859 pass 860 else: 861 self.fail("RuntimeError not raised") 862 self.assertEqual(asyncio.current_task().cancelling(), 1) 863 except* ValueError: 864 pass 865 else: 866 self.fail("ValueError not raised") 867 self.assertEqual(asyncio.current_task().cancelling(), 0) 868 869 async def test_error_and_cancel(self): 870 event = asyncio.Event() 871 872 async def raise_error(): 873 event.set() 874 await asyncio.sleep(0) 875 raise RuntimeError() 876 877 async def inner(): 878 try: 879 async with taskgroups.TaskGroup() as tg: 880 tg.create_task(raise_error()) 881 await asyncio.sleep(1) 882 self.fail("Sleep in group should have been cancelled") 883 except* RuntimeError: 884 self.assertEqual(asyncio.current_task().cancelling(), 1) 885 self.assertEqual(asyncio.current_task().cancelling(), 1) 886 await asyncio.sleep(1) 887 self.fail("Sleep after group should have been cancelled") 888 889 async def outer(): 890 t = asyncio.create_task(inner()) 891 await event.wait() 892 self.assertEqual(t.cancelling(), 0) 893 t.cancel() 894 self.assertEqual(t.cancelling(), 1) 895 with self.assertRaises(asyncio.CancelledError): 896 await t 897 self.assertTrue(t.cancelled()) 898 899 await outer() 900 901 async def test_exception_refcycles_direct(self): 902 """Test that TaskGroup doesn't keep a reference to the raised ExceptionGroup""" 903 tg = asyncio.TaskGroup() 904 exc = None 905 906 class _Done(Exception): 907 pass 908 909 try: 910 async with tg: 911 raise _Done 912 except ExceptionGroup as e: 913 exc = e 914 915 self.assertIsNotNone(exc) 916 self.assertListEqual(gc.get_referrers(exc), []) 917 918 919 async def test_exception_refcycles_errors(self): 920 """Test that TaskGroup deletes self._errors, and __aexit__ args""" 921 tg = asyncio.TaskGroup() 922 exc = None 923 924 class _Done(Exception): 925 pass 926 927 try: 928 async with tg: 929 raise _Done 930 except* _Done as excs: 931 exc = excs.exceptions[0] 932 933 self.assertIsInstance(exc, _Done) 934 self.assertListEqual(gc.get_referrers(exc), []) 935 936 937 async def test_exception_refcycles_parent_task(self): 938 """Test that TaskGroup deletes self._parent_task""" 939 tg = asyncio.TaskGroup() 940 exc = None 941 942 class _Done(Exception): 943 pass 944 945 async def coro_fn(): 946 async with tg: 947 raise _Done 948 949 try: 950 async with asyncio.TaskGroup() as tg2: 951 tg2.create_task(coro_fn()) 952 except* _Done as excs: 953 exc = excs.exceptions[0].exceptions[0] 954 955 self.assertIsInstance(exc, _Done) 956 self.assertListEqual(gc.get_referrers(exc), []) 957 958 async def test_exception_refcycles_propagate_cancellation_error(self): 959 """Test that TaskGroup deletes propagate_cancellation_error""" 960 tg = asyncio.TaskGroup() 961 exc = None 962 963 try: 964 async with asyncio.timeout(-1): 965 async with tg: 966 await asyncio.sleep(0) 967 except TimeoutError as e: 968 exc = e.__cause__ 969 970 self.assertIsInstance(exc, asyncio.CancelledError) 971 self.assertListEqual(gc.get_referrers(exc), []) 972 973 async def test_exception_refcycles_base_error(self): 974 """Test that TaskGroup deletes self._base_error""" 975 class MyKeyboardInterrupt(KeyboardInterrupt): 976 pass 977 978 tg = asyncio.TaskGroup() 979 exc = None 980 981 try: 982 async with tg: 983 raise MyKeyboardInterrupt 984 except MyKeyboardInterrupt as e: 985 exc = e 986 987 self.assertIsNotNone(exc) 988 self.assertListEqual(gc.get_referrers(exc), []) 989 990 991if __name__ == "__main__": 992 unittest.main() 993