• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Adapted with permission from the EdgeDB project;
2# license: PSFL.
3
4import gc
5import asyncio
6import contextvars
7import contextlib
8from asyncio import taskgroups
9import unittest
10import warnings
11
12from test.test_asyncio.utils import await_without_task
13
def tearDownModule():
    """Module-level unittest hook run after every test in this file.

    Resetting the event loop policy to the default (``None``) prevents the
    test runner's "test altered the execution environment" warning.
    """
    asyncio.set_event_loop_policy(None)
17
18
class MyExc(Exception):
    """Plain ``Exception`` subclass the tests raise as a recognizable error."""
21
22
class MyBaseExc(BaseException):
    """Direct ``BaseException`` subclass (deliberately not an ``Exception``)."""
25
26
27def get_error_types(eg):
28    return {type(exc) for exc in eg.exceptions}
29
30
31class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
32
33    async def test_taskgroup_01(self):
34
35        async def foo1():
36            await asyncio.sleep(0.1)
37            return 42
38
39        async def foo2():
40            await asyncio.sleep(0.2)
41            return 11
42
43        async with taskgroups.TaskGroup() as g:
44            t1 = g.create_task(foo1())
45            t2 = g.create_task(foo2())
46
47        self.assertEqual(t1.result(), 42)
48        self.assertEqual(t2.result(), 11)
49
50    async def test_taskgroup_02(self):
51
52        async def foo1():
53            await asyncio.sleep(0.1)
54            return 42
55
56        async def foo2():
57            await asyncio.sleep(0.2)
58            return 11
59
60        async with taskgroups.TaskGroup() as g:
61            t1 = g.create_task(foo1())
62            await asyncio.sleep(0.15)
63            t2 = g.create_task(foo2())
64
65        self.assertEqual(t1.result(), 42)
66        self.assertEqual(t2.result(), 11)
67
68    async def test_taskgroup_03(self):
69
70        async def foo1():
71            await asyncio.sleep(1)
72            return 42
73
74        async def foo2():
75            await asyncio.sleep(0.2)
76            return 11
77
78        async with taskgroups.TaskGroup() as g:
79            t1 = g.create_task(foo1())
80            await asyncio.sleep(0.15)
81            # cancel t1 explicitly, i.e. everything should continue
82            # working as expected.
83            t1.cancel()
84
85            t2 = g.create_task(foo2())
86
87        self.assertTrue(t1.cancelled())
88        self.assertEqual(t2.result(), 11)
89
90    async def test_taskgroup_04(self):
91
92        NUM = 0
93        t2_cancel = False
94        t2 = None
95
96        async def foo1():
97            await asyncio.sleep(0.1)
98            1 / 0
99
100        async def foo2():
101            nonlocal NUM, t2_cancel
102            try:
103                await asyncio.sleep(1)
104            except asyncio.CancelledError:
105                t2_cancel = True
106                raise
107            NUM += 1
108
109        async def runner():
110            nonlocal NUM, t2
111
112            async with taskgroups.TaskGroup() as g:
113                g.create_task(foo1())
114                t2 = g.create_task(foo2())
115
116            NUM += 10
117
118        with self.assertRaises(ExceptionGroup) as cm:
119            await asyncio.create_task(runner())
120
121        self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
122
123        self.assertEqual(NUM, 0)
124        self.assertTrue(t2_cancel)
125        self.assertTrue(t2.cancelled())
126
127    async def test_cancel_children_on_child_error(self):
128        # When a child task raises an error, the rest of the children
129        # are cancelled and the errors are gathered into an EG.
130
131        NUM = 0
132        t2_cancel = False
133        runner_cancel = False
134
135        async def foo1():
136            await asyncio.sleep(0.1)
137            1 / 0
138
139        async def foo2():
140            nonlocal NUM, t2_cancel
141            try:
142                await asyncio.sleep(5)
143            except asyncio.CancelledError:
144                t2_cancel = True
145                raise
146            NUM += 1
147
148        async def runner():
149            nonlocal NUM, runner_cancel
150
151            async with taskgroups.TaskGroup() as g:
152                g.create_task(foo1())
153                g.create_task(foo1())
154                g.create_task(foo1())
155                g.create_task(foo2())
156                try:
157                    await asyncio.sleep(10)
158                except asyncio.CancelledError:
159                    runner_cancel = True
160                    raise
161
162            NUM += 10
163
164        # The 3 foo1 sub tasks can be racy when the host is busy - if the
165        # cancellation happens in the middle, we'll see partial sub errors here
166        with self.assertRaises(ExceptionGroup) as cm:
167            await asyncio.create_task(runner())
168
169        self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
170        self.assertEqual(NUM, 0)
171        self.assertTrue(t2_cancel)
172        self.assertTrue(runner_cancel)
173
174    async def test_cancellation(self):
175
176        NUM = 0
177
178        async def foo():
179            nonlocal NUM
180            try:
181                await asyncio.sleep(5)
182            except asyncio.CancelledError:
183                NUM += 1
184                raise
185
186        async def runner():
187            async with taskgroups.TaskGroup() as g:
188                for _ in range(5):
189                    g.create_task(foo())
190
191        r = asyncio.create_task(runner())
192        await asyncio.sleep(0.1)
193
194        self.assertFalse(r.done())
195        r.cancel()
196        with self.assertRaises(asyncio.CancelledError) as cm:
197            await r
198
199        self.assertEqual(NUM, 5)
200
201    async def test_taskgroup_07(self):
202
203        NUM = 0
204
205        async def foo():
206            nonlocal NUM
207            try:
208                await asyncio.sleep(5)
209            except asyncio.CancelledError:
210                NUM += 1
211                raise
212
213        async def runner():
214            nonlocal NUM
215            async with taskgroups.TaskGroup() as g:
216                for _ in range(5):
217                    g.create_task(foo())
218
219                try:
220                    await asyncio.sleep(10)
221                except asyncio.CancelledError:
222                    NUM += 10
223                    raise
224
225        r = asyncio.create_task(runner())
226        await asyncio.sleep(0.1)
227
228        self.assertFalse(r.done())
229        r.cancel()
230        with self.assertRaises(asyncio.CancelledError):
231            await r
232
233        self.assertEqual(NUM, 15)
234
235    async def test_taskgroup_08(self):
236
237        async def foo():
238            try:
239                await asyncio.sleep(10)
240            finally:
241                1 / 0
242
243        async def runner():
244            async with taskgroups.TaskGroup() as g:
245                for _ in range(5):
246                    g.create_task(foo())
247
248                await asyncio.sleep(10)
249
250        r = asyncio.create_task(runner())
251        await asyncio.sleep(0.1)
252
253        self.assertFalse(r.done())
254        r.cancel()
255        with self.assertRaises(ExceptionGroup) as cm:
256            await r
257        self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
258
259    async def test_taskgroup_09(self):
260
261        t1 = t2 = None
262
263        async def foo1():
264            await asyncio.sleep(1)
265            return 42
266
267        async def foo2():
268            await asyncio.sleep(2)
269            return 11
270
271        async def runner():
272            nonlocal t1, t2
273            async with taskgroups.TaskGroup() as g:
274                t1 = g.create_task(foo1())
275                t2 = g.create_task(foo2())
276                await asyncio.sleep(0.1)
277                1 / 0
278
279        try:
280            await runner()
281        except ExceptionGroup as t:
282            self.assertEqual(get_error_types(t), {ZeroDivisionError})
283        else:
284            self.fail('ExceptionGroup was not raised')
285
286        self.assertTrue(t1.cancelled())
287        self.assertTrue(t2.cancelled())
288
289    async def test_taskgroup_10(self):
290
291        t1 = t2 = None
292
293        async def foo1():
294            await asyncio.sleep(1)
295            return 42
296
297        async def foo2():
298            await asyncio.sleep(2)
299            return 11
300
301        async def runner():
302            nonlocal t1, t2
303            async with taskgroups.TaskGroup() as g:
304                t1 = g.create_task(foo1())
305                t2 = g.create_task(foo2())
306                1 / 0
307
308        try:
309            await runner()
310        except ExceptionGroup as t:
311            self.assertEqual(get_error_types(t), {ZeroDivisionError})
312        else:
313            self.fail('ExceptionGroup was not raised')
314
315        self.assertTrue(t1.cancelled())
316        self.assertTrue(t2.cancelled())
317
318    async def test_taskgroup_11(self):
319
320        async def foo():
321            try:
322                await asyncio.sleep(10)
323            finally:
324                1 / 0
325
326        async def runner():
327            async with taskgroups.TaskGroup():
328                async with taskgroups.TaskGroup() as g2:
329                    for _ in range(5):
330                        g2.create_task(foo())
331
332                    await asyncio.sleep(10)
333
334        r = asyncio.create_task(runner())
335        await asyncio.sleep(0.1)
336
337        self.assertFalse(r.done())
338        r.cancel()
339        with self.assertRaises(ExceptionGroup) as cm:
340            await r
341
342        self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
343        self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError})
344
345    async def test_taskgroup_12(self):
346
347        async def foo():
348            try:
349                await asyncio.sleep(10)
350            finally:
351                1 / 0
352
353        async def runner():
354            async with taskgroups.TaskGroup() as g1:
355                g1.create_task(asyncio.sleep(10))
356
357                async with taskgroups.TaskGroup() as g2:
358                    for _ in range(5):
359                        g2.create_task(foo())
360
361                    await asyncio.sleep(10)
362
363        r = asyncio.create_task(runner())
364        await asyncio.sleep(0.1)
365
366        self.assertFalse(r.done())
367        r.cancel()
368        with self.assertRaises(ExceptionGroup) as cm:
369            await r
370
371        self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
372        self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError})
373
374    async def test_taskgroup_13(self):
375
376        async def crash_after(t):
377            await asyncio.sleep(t)
378            raise ValueError(t)
379
380        async def runner():
381            async with taskgroups.TaskGroup() as g1:
382                g1.create_task(crash_after(0.1))
383
384                async with taskgroups.TaskGroup() as g2:
385                    g2.create_task(crash_after(10))
386
387        r = asyncio.create_task(runner())
388        with self.assertRaises(ExceptionGroup) as cm:
389            await r
390
391        self.assertEqual(get_error_types(cm.exception), {ValueError})
392
393    async def test_taskgroup_14(self):
394
395        async def crash_after(t):
396            await asyncio.sleep(t)
397            raise ValueError(t)
398
399        async def runner():
400            async with taskgroups.TaskGroup() as g1:
401                g1.create_task(crash_after(10))
402
403                async with taskgroups.TaskGroup() as g2:
404                    g2.create_task(crash_after(0.1))
405
406        r = asyncio.create_task(runner())
407        with self.assertRaises(ExceptionGroup) as cm:
408            await r
409
410        self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
411        self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ValueError})
412
413    async def test_taskgroup_15(self):
414
415        async def crash_soon():
416            await asyncio.sleep(0.3)
417            1 / 0
418
419        async def runner():
420            async with taskgroups.TaskGroup() as g1:
421                g1.create_task(crash_soon())
422                try:
423                    await asyncio.sleep(10)
424                except asyncio.CancelledError:
425                    await asyncio.sleep(0.5)
426                    raise
427
428        r = asyncio.create_task(runner())
429        await asyncio.sleep(0.1)
430
431        self.assertFalse(r.done())
432        r.cancel()
433        with self.assertRaises(ExceptionGroup) as cm:
434            await r
435        self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
436
437    async def test_taskgroup_16(self):
438
439        async def crash_soon():
440            await asyncio.sleep(0.3)
441            1 / 0
442
443        async def nested_runner():
444            async with taskgroups.TaskGroup() as g1:
445                g1.create_task(crash_soon())
446                try:
447                    await asyncio.sleep(10)
448                except asyncio.CancelledError:
449                    await asyncio.sleep(0.5)
450                    raise
451
452        async def runner():
453            t = asyncio.create_task(nested_runner())
454            await t
455
456        r = asyncio.create_task(runner())
457        await asyncio.sleep(0.1)
458
459        self.assertFalse(r.done())
460        r.cancel()
461        with self.assertRaises(ExceptionGroup) as cm:
462            await r
463        self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
464
465    async def test_taskgroup_17(self):
466        NUM = 0
467
468        async def runner():
469            nonlocal NUM
470            async with taskgroups.TaskGroup():
471                try:
472                    await asyncio.sleep(10)
473                except asyncio.CancelledError:
474                    NUM += 10
475                    raise
476
477        r = asyncio.create_task(runner())
478        await asyncio.sleep(0.1)
479
480        self.assertFalse(r.done())
481        r.cancel()
482        with self.assertRaises(asyncio.CancelledError):
483            await r
484
485        self.assertEqual(NUM, 10)
486
487    async def test_taskgroup_18(self):
488        NUM = 0
489
490        async def runner():
491            nonlocal NUM
492            async with taskgroups.TaskGroup():
493                try:
494                    await asyncio.sleep(10)
495                except asyncio.CancelledError:
496                    NUM += 10
497                    # This isn't a good idea, but we have to support
498                    # this weird case.
499                    raise MyExc
500
501        r = asyncio.create_task(runner())
502        await asyncio.sleep(0.1)
503
504        self.assertFalse(r.done())
505        r.cancel()
506
507        try:
508            await r
509        except ExceptionGroup as t:
510            self.assertEqual(get_error_types(t),{MyExc})
511        else:
512            self.fail('ExceptionGroup was not raised')
513
514        self.assertEqual(NUM, 10)
515
516    async def test_taskgroup_19(self):
517        async def crash_soon():
518            await asyncio.sleep(0.1)
519            1 / 0
520
521        async def nested():
522            try:
523                await asyncio.sleep(10)
524            finally:
525                raise MyExc
526
527        async def runner():
528            async with taskgroups.TaskGroup() as g:
529                g.create_task(crash_soon())
530                await nested()
531
532        r = asyncio.create_task(runner())
533        try:
534            await r
535        except ExceptionGroup as t:
536            self.assertEqual(get_error_types(t), {MyExc, ZeroDivisionError})
537        else:
538            self.fail('TasgGroupError was not raised')
539
540    async def test_taskgroup_20(self):
541        async def crash_soon():
542            await asyncio.sleep(0.1)
543            1 / 0
544
545        async def nested():
546            try:
547                await asyncio.sleep(10)
548            finally:
549                raise KeyboardInterrupt
550
551        async def runner():
552            async with taskgroups.TaskGroup() as g:
553                g.create_task(crash_soon())
554                await nested()
555
556        with self.assertRaises(KeyboardInterrupt):
557            await runner()
558
559    async def test_taskgroup_20a(self):
560        async def crash_soon():
561            await asyncio.sleep(0.1)
562            1 / 0
563
564        async def nested():
565            try:
566                await asyncio.sleep(10)
567            finally:
568                raise MyBaseExc
569
570        async def runner():
571            async with taskgroups.TaskGroup() as g:
572                g.create_task(crash_soon())
573                await nested()
574
575        with self.assertRaises(BaseExceptionGroup) as cm:
576            await runner()
577
578        self.assertEqual(
579            get_error_types(cm.exception), {MyBaseExc, ZeroDivisionError}
580        )
581
582    async def _test_taskgroup_21(self):
583        # This test doesn't work as asyncio, currently, doesn't
584        # correctly propagate KeyboardInterrupt (or SystemExit) --
585        # those cause the event loop itself to crash.
586        # (Compare to the previous (passing) test -- that one raises
587        # a plain exception but raises KeyboardInterrupt in nested();
588        # this test does it the other way around.)
589
590        async def crash_soon():
591            await asyncio.sleep(0.1)
592            raise KeyboardInterrupt
593
594        async def nested():
595            try:
596                await asyncio.sleep(10)
597            finally:
598                raise TypeError
599
600        async def runner():
601            async with taskgroups.TaskGroup() as g:
602                g.create_task(crash_soon())
603                await nested()
604
605        with self.assertRaises(KeyboardInterrupt):
606            await runner()
607
608    async def test_taskgroup_21a(self):
609
610        async def crash_soon():
611            await asyncio.sleep(0.1)
612            raise MyBaseExc
613
614        async def nested():
615            try:
616                await asyncio.sleep(10)
617            finally:
618                raise TypeError
619
620        async def runner():
621            async with taskgroups.TaskGroup() as g:
622                g.create_task(crash_soon())
623                await nested()
624
625        with self.assertRaises(BaseExceptionGroup) as cm:
626            await runner()
627
628        self.assertEqual(get_error_types(cm.exception), {MyBaseExc, TypeError})
629
630    async def test_taskgroup_22(self):
631
632        async def foo1():
633            await asyncio.sleep(1)
634            return 42
635
636        async def foo2():
637            await asyncio.sleep(2)
638            return 11
639
640        async def runner():
641            async with taskgroups.TaskGroup() as g:
642                g.create_task(foo1())
643                g.create_task(foo2())
644
645        r = asyncio.create_task(runner())
646        await asyncio.sleep(0.05)
647        r.cancel()
648
649        with self.assertRaises(asyncio.CancelledError):
650            await r
651
652    async def test_taskgroup_23(self):
653
654        async def do_job(delay):
655            await asyncio.sleep(delay)
656
657        async with taskgroups.TaskGroup() as g:
658            for count in range(10):
659                await asyncio.sleep(0.1)
660                g.create_task(do_job(0.3))
661                if count == 5:
662                    self.assertLess(len(g._tasks), 5)
663            await asyncio.sleep(1.35)
664            self.assertEqual(len(g._tasks), 0)
665
666    async def test_taskgroup_24(self):
667
668        async def root(g):
669            await asyncio.sleep(0.1)
670            g.create_task(coro1(0.1))
671            g.create_task(coro1(0.2))
672
673        async def coro1(delay):
674            await asyncio.sleep(delay)
675
676        async def runner():
677            async with taskgroups.TaskGroup() as g:
678                g.create_task(root(g))
679
680        await runner()
681
682    async def test_taskgroup_25(self):
683        nhydras = 0
684
685        async def hydra(g):
686            nonlocal nhydras
687            nhydras += 1
688            await asyncio.sleep(0.01)
689            g.create_task(hydra(g))
690            g.create_task(hydra(g))
691
692        async def hercules():
693            while nhydras < 10:
694                await asyncio.sleep(0.015)
695            1 / 0
696
697        async def runner():
698            async with taskgroups.TaskGroup() as g:
699                g.create_task(hydra(g))
700                g.create_task(hercules())
701
702        with self.assertRaises(ExceptionGroup) as cm:
703            await runner()
704
705        self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
706        self.assertGreaterEqual(nhydras, 10)
707
708    async def test_taskgroup_task_name(self):
709        async def coro():
710            await asyncio.sleep(0)
711        async with taskgroups.TaskGroup() as g:
712            t = g.create_task(coro(), name="yolo")
713            self.assertEqual(t.get_name(), "yolo")
714
715    async def test_taskgroup_task_context(self):
716        cvar = contextvars.ContextVar('cvar')
717
718        async def coro(val):
719            await asyncio.sleep(0)
720            cvar.set(val)
721
722        async with taskgroups.TaskGroup() as g:
723            ctx = contextvars.copy_context()
724            self.assertIsNone(ctx.get(cvar))
725            t1 = g.create_task(coro(1), context=ctx)
726            await t1
727            self.assertEqual(1, ctx.get(cvar))
728            t2 = g.create_task(coro(2), context=ctx)
729            await t2
730            self.assertEqual(2, ctx.get(cvar))
731
732    async def test_taskgroup_no_create_task_after_failure(self):
733        async def coro1():
734            await asyncio.sleep(0.001)
735            1 / 0
736        async def coro2(g):
737            try:
738                await asyncio.sleep(1)
739            except asyncio.CancelledError:
740                with self.assertRaises(RuntimeError):
741                    g.create_task(coro1())
742
743        with self.assertRaises(ExceptionGroup) as cm:
744            async with taskgroups.TaskGroup() as g:
745                g.create_task(coro1())
746                g.create_task(coro2(g))
747
748        self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
749
750    async def test_taskgroup_context_manager_exit_raises(self):
751        # See https://github.com/python/cpython/issues/95289
752        class CustomException(Exception):
753            pass
754
755        async def raise_exc():
756            raise CustomException
757
758        @contextlib.asynccontextmanager
759        async def database():
760            try:
761                yield
762            finally:
763                raise CustomException
764
765        async def main():
766            task = asyncio.current_task()
767            try:
768                async with taskgroups.TaskGroup() as tg:
769                    async with database():
770                        tg.create_task(raise_exc())
771                        await asyncio.sleep(1)
772            except* CustomException as err:
773                self.assertEqual(task.cancelling(), 0)
774                self.assertEqual(len(err.exceptions), 2)
775
776            else:
777                self.fail('CustomException not raised')
778
779        await asyncio.create_task(main())
780
781    async def test_taskgroup_already_entered(self):
782        tg = taskgroups.TaskGroup()
783        async with tg:
784            with self.assertRaisesRegex(RuntimeError, "has already been entered"):
785                async with tg:
786                    pass
787
788    async def test_taskgroup_double_enter(self):
789        tg = taskgroups.TaskGroup()
790        async with tg:
791            pass
792        with self.assertRaisesRegex(RuntimeError, "has already been entered"):
793            async with tg:
794                pass
795
796    async def test_taskgroup_finished(self):
797        async def create_task_after_tg_finish():
798            tg = taskgroups.TaskGroup()
799            async with tg:
800                pass
801            coro = asyncio.sleep(0)
802            with self.assertRaisesRegex(RuntimeError, "is finished"):
803                tg.create_task(coro)
804
805        # Make sure the coroutine was closed when submitted to the inactive tg
806        # (if not closed, a RuntimeWarning should have been raised)
807        with warnings.catch_warnings(record=True) as w:
808            await create_task_after_tg_finish()
809        self.assertEqual(len(w), 0)
810
811    async def test_taskgroup_not_entered(self):
812        tg = taskgroups.TaskGroup()
813        coro = asyncio.sleep(0)
814        with self.assertRaisesRegex(RuntimeError, "has not been entered"):
815            tg.create_task(coro)
816
817    async def test_taskgroup_without_parent_task(self):
818        tg = taskgroups.TaskGroup()
819        with self.assertRaisesRegex(RuntimeError, "parent task"):
820            await await_without_task(tg.__aenter__())
821        coro = asyncio.sleep(0)
822        with self.assertRaisesRegex(RuntimeError, "has not been entered"):
823            tg.create_task(coro)
824
825    def test_coro_closed_when_tg_closed(self):
826        async def run_coro_after_tg_closes():
827            async with taskgroups.TaskGroup() as tg:
828                pass
829            coro = asyncio.sleep(0)
830            with self.assertRaisesRegex(RuntimeError, "is finished"):
831                tg.create_task(coro)
832        loop = asyncio.get_event_loop()
833        loop.run_until_complete(run_coro_after_tg_closes())
834
835    async def test_cancelling_level_preserved(self):
836        async def raise_after(t, e):
837            await asyncio.sleep(t)
838            raise e()
839
840        try:
841            async with asyncio.TaskGroup() as tg:
842                tg.create_task(raise_after(0.0, RuntimeError))
843        except* RuntimeError:
844            pass
845        self.assertEqual(asyncio.current_task().cancelling(), 0)
846
847    async def test_nested_groups_both_cancelled(self):
848        async def raise_after(t, e):
849            await asyncio.sleep(t)
850            raise e()
851
852        try:
853            async with asyncio.TaskGroup() as outer_tg:
854                try:
855                    async with asyncio.TaskGroup() as inner_tg:
856                        inner_tg.create_task(raise_after(0, RuntimeError))
857                        outer_tg.create_task(raise_after(0, ValueError))
858                except* RuntimeError:
859                    pass
860                else:
861                    self.fail("RuntimeError not raised")
862            self.assertEqual(asyncio.current_task().cancelling(), 1)
863        except* ValueError:
864            pass
865        else:
866            self.fail("ValueError not raised")
867        self.assertEqual(asyncio.current_task().cancelling(), 0)
868
869    async def test_error_and_cancel(self):
870        event = asyncio.Event()
871
872        async def raise_error():
873            event.set()
874            await asyncio.sleep(0)
875            raise RuntimeError()
876
877        async def inner():
878            try:
879                async with taskgroups.TaskGroup() as tg:
880                    tg.create_task(raise_error())
881                    await asyncio.sleep(1)
882                    self.fail("Sleep in group should have been cancelled")
883            except* RuntimeError:
884                self.assertEqual(asyncio.current_task().cancelling(), 1)
885            self.assertEqual(asyncio.current_task().cancelling(), 1)
886            await asyncio.sleep(1)
887            self.fail("Sleep after group should have been cancelled")
888
889        async def outer():
890            t = asyncio.create_task(inner())
891            await event.wait()
892            self.assertEqual(t.cancelling(), 0)
893            t.cancel()
894            self.assertEqual(t.cancelling(), 1)
895            with self.assertRaises(asyncio.CancelledError):
896                await t
897            self.assertTrue(t.cancelled())
898
899        await outer()
900
901    async def test_exception_refcycles_direct(self):
902        """Test that TaskGroup doesn't keep a reference to the raised ExceptionGroup"""
903        tg = asyncio.TaskGroup()
904        exc = None
905
906        class _Done(Exception):
907            pass
908
909        try:
910            async with tg:
911                raise _Done
912        except ExceptionGroup as e:
913            exc = e
914
915        self.assertIsNotNone(exc)
916        self.assertListEqual(gc.get_referrers(exc), [])
917
918
919    async def test_exception_refcycles_errors(self):
920        """Test that TaskGroup deletes self._errors, and __aexit__ args"""
921        tg = asyncio.TaskGroup()
922        exc = None
923
924        class _Done(Exception):
925            pass
926
927        try:
928            async with tg:
929                raise _Done
930        except* _Done as excs:
931            exc = excs.exceptions[0]
932
933        self.assertIsInstance(exc, _Done)
934        self.assertListEqual(gc.get_referrers(exc), [])
935
936
937    async def test_exception_refcycles_parent_task(self):
938        """Test that TaskGroup deletes self._parent_task"""
939        tg = asyncio.TaskGroup()
940        exc = None
941
942        class _Done(Exception):
943            pass
944
945        async def coro_fn():
946            async with tg:
947                raise _Done
948
949        try:
950            async with asyncio.TaskGroup() as tg2:
951                tg2.create_task(coro_fn())
952        except* _Done as excs:
953            exc = excs.exceptions[0].exceptions[0]
954
955        self.assertIsInstance(exc, _Done)
956        self.assertListEqual(gc.get_referrers(exc), [])
957
958    async def test_exception_refcycles_propagate_cancellation_error(self):
959        """Test that TaskGroup deletes propagate_cancellation_error"""
960        tg = asyncio.TaskGroup()
961        exc = None
962
963        try:
964            async with asyncio.timeout(-1):
965                async with tg:
966                    await asyncio.sleep(0)
967        except TimeoutError as e:
968            exc = e.__cause__
969
970        self.assertIsInstance(exc, asyncio.CancelledError)
971        self.assertListEqual(gc.get_referrers(exc), [])
972
973    async def test_exception_refcycles_base_error(self):
974        """Test that TaskGroup deletes self._base_error"""
975        class MyKeyboardInterrupt(KeyboardInterrupt):
976            pass
977
978        tg = asyncio.TaskGroup()
979        exc = None
980
981        try:
982            async with tg:
983                raise MyKeyboardInterrupt
984        except MyKeyboardInterrupt as e:
985            exc = e
986
987        self.assertIsNotNone(exc)
988        self.assertListEqual(gc.get_referrers(exc), [])
989
990
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
993