• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Unit tests for contextlib.py, and other context managers."""
2
3import io
4import os
5import sys
6import tempfile
7import threading
8import traceback
9import unittest
10from contextlib import *  # Tests __all__
11from test import support
12from test.support import os_helper
13from test.support.testcase import ExceptionIsLikeMixin
14import weakref
15
16
17class TestAbstractContextManager(unittest.TestCase):
18
19    def test_enter(self):
20        class DefaultEnter(AbstractContextManager):
21            def __exit__(self, *args):
22                super().__exit__(*args)
23
24        manager = DefaultEnter()
25        self.assertIs(manager.__enter__(), manager)
26
27    def test_slots(self):
28        class DefaultContextManager(AbstractContextManager):
29            __slots__ = ()
30
31            def __exit__(self, *args):
32                super().__exit__(*args)
33
34        with self.assertRaises(AttributeError):
35            DefaultContextManager().var = 42
36
37    def test_exit_is_abstract(self):
38        class MissingExit(AbstractContextManager):
39            pass
40
41        with self.assertRaises(TypeError):
42            MissingExit()
43
44    def test_structural_subclassing(self):
45        class ManagerFromScratch:
46            def __enter__(self):
47                return self
48            def __exit__(self, exc_type, exc_value, traceback):
49                return None
50
51        self.assertTrue(issubclass(ManagerFromScratch, AbstractContextManager))
52
53        class DefaultEnter(AbstractContextManager):
54            def __exit__(self, *args):
55                super().__exit__(*args)
56
57        self.assertTrue(issubclass(DefaultEnter, AbstractContextManager))
58
59        class NoEnter(ManagerFromScratch):
60            __enter__ = None
61
62        self.assertFalse(issubclass(NoEnter, AbstractContextManager))
63
64        class NoExit(ManagerFromScratch):
65            __exit__ = None
66
67        self.assertFalse(issubclass(NoExit, AbstractContextManager))
68
69
70class ContextManagerTestCase(unittest.TestCase):
71
72    def test_contextmanager_plain(self):
73        state = []
74        @contextmanager
75        def woohoo():
76            state.append(1)
77            yield 42
78            state.append(999)
79        with woohoo() as x:
80            self.assertEqual(state, [1])
81            self.assertEqual(x, 42)
82            state.append(x)
83        self.assertEqual(state, [1, 42, 999])
84
85    def test_contextmanager_finally(self):
86        state = []
87        @contextmanager
88        def woohoo():
89            state.append(1)
90            try:
91                yield 42
92            finally:
93                state.append(999)
94        with self.assertRaises(ZeroDivisionError):
95            with woohoo() as x:
96                self.assertEqual(state, [1])
97                self.assertEqual(x, 42)
98                state.append(x)
99                raise ZeroDivisionError()
100        self.assertEqual(state, [1, 42, 999])
101
102    def test_contextmanager_traceback(self):
103        @contextmanager
104        def f():
105            yield
106
107        try:
108            with f():
109                1/0
110        except ZeroDivisionError as e:
111            frames = traceback.extract_tb(e.__traceback__)
112
113        self.assertEqual(len(frames), 1)
114        self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
115        self.assertEqual(frames[0].line, '1/0')
116
117        # Repeat with RuntimeError (which goes through a different code path)
118        class RuntimeErrorSubclass(RuntimeError):
119            pass
120
121        try:
122            with f():
123                raise RuntimeErrorSubclass(42)
124        except RuntimeErrorSubclass as e:
125            frames = traceback.extract_tb(e.__traceback__)
126
127        self.assertEqual(len(frames), 1)
128        self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
129        self.assertEqual(frames[0].line, 'raise RuntimeErrorSubclass(42)')
130
131        class StopIterationSubclass(StopIteration):
132            pass
133
134        for stop_exc in (
135            StopIteration('spam'),
136            StopIterationSubclass('spam'),
137        ):
138            with self.subTest(type=type(stop_exc)):
139                try:
140                    with f():
141                        raise stop_exc
142                except type(stop_exc) as e:
143                    self.assertIs(e, stop_exc)
144                    frames = traceback.extract_tb(e.__traceback__)
145                else:
146                    self.fail(f'{stop_exc} was suppressed')
147
148                self.assertEqual(len(frames), 1)
149                self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
150                self.assertEqual(frames[0].line, 'raise stop_exc')
151
152    def test_contextmanager_no_reraise(self):
153        @contextmanager
154        def whee():
155            yield
156        ctx = whee()
157        ctx.__enter__()
158        # Calling __exit__ should not result in an exception
159        self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))
160
161    def test_contextmanager_trap_yield_after_throw(self):
162        @contextmanager
163        def whoo():
164            try:
165                yield
166            except:
167                yield
168        ctx = whoo()
169        ctx.__enter__()
170        with self.assertRaises(RuntimeError):
171            ctx.__exit__(TypeError, TypeError("foo"), None)
172        if support.check_impl_detail(cpython=True):
173            # The "gen" attribute is an implementation detail.
174            self.assertFalse(ctx.gen.gi_suspended)
175
176    def test_contextmanager_trap_no_yield(self):
177        @contextmanager
178        def whoo():
179            if False:
180                yield
181        ctx = whoo()
182        with self.assertRaises(RuntimeError):
183            ctx.__enter__()
184
185    def test_contextmanager_trap_second_yield(self):
186        @contextmanager
187        def whoo():
188            yield
189            yield
190        ctx = whoo()
191        ctx.__enter__()
192        with self.assertRaises(RuntimeError):
193            ctx.__exit__(None, None, None)
194        if support.check_impl_detail(cpython=True):
195            # The "gen" attribute is an implementation detail.
196            self.assertFalse(ctx.gen.gi_suspended)
197
198    def test_contextmanager_non_normalised(self):
199        @contextmanager
200        def whoo():
201            try:
202                yield
203            except RuntimeError:
204                raise SyntaxError
205
206        ctx = whoo()
207        ctx.__enter__()
208        with self.assertRaises(SyntaxError):
209            ctx.__exit__(RuntimeError, None, None)
210
211    def test_contextmanager_except(self):
212        state = []
213        @contextmanager
214        def woohoo():
215            state.append(1)
216            try:
217                yield 42
218            except ZeroDivisionError as e:
219                state.append(e.args[0])
220                self.assertEqual(state, [1, 42, 999])
221        with woohoo() as x:
222            self.assertEqual(state, [1])
223            self.assertEqual(x, 42)
224            state.append(x)
225            raise ZeroDivisionError(999)
226        self.assertEqual(state, [1, 42, 999])
227
228    def test_contextmanager_except_stopiter(self):
229        @contextmanager
230        def woohoo():
231            yield
232
233        class StopIterationSubclass(StopIteration):
234            pass
235
236        for stop_exc in (StopIteration('spam'), StopIterationSubclass('spam')):
237            with self.subTest(type=type(stop_exc)):
238                try:
239                    with woohoo():
240                        raise stop_exc
241                except Exception as ex:
242                    self.assertIs(ex, stop_exc)
243                else:
244                    self.fail(f'{stop_exc} was suppressed')
245
246    def test_contextmanager_except_pep479(self):
247        code = """\
248from __future__ import generator_stop
249from contextlib import contextmanager
250@contextmanager
251def woohoo():
252    yield
253"""
254        locals = {}
255        exec(code, locals, locals)
256        woohoo = locals['woohoo']
257
258        stop_exc = StopIteration('spam')
259        try:
260            with woohoo():
261                raise stop_exc
262        except Exception as ex:
263            self.assertIs(ex, stop_exc)
264        else:
265            self.fail('StopIteration was suppressed')
266
267    def test_contextmanager_do_not_unchain_non_stopiteration_exceptions(self):
268        @contextmanager
269        def test_issue29692():
270            try:
271                yield
272            except Exception as exc:
273                raise RuntimeError('issue29692:Chained') from exc
274        try:
275            with test_issue29692():
276                raise ZeroDivisionError
277        except Exception as ex:
278            self.assertIs(type(ex), RuntimeError)
279            self.assertEqual(ex.args[0], 'issue29692:Chained')
280            self.assertIsInstance(ex.__cause__, ZeroDivisionError)
281
282        try:
283            with test_issue29692():
284                raise StopIteration('issue29692:Unchained')
285        except Exception as ex:
286            self.assertIs(type(ex), StopIteration)
287            self.assertEqual(ex.args[0], 'issue29692:Unchained')
288            self.assertIsNone(ex.__cause__)
289
290    def test_contextmanager_wrap_runtimeerror(self):
291        @contextmanager
292        def woohoo():
293            try:
294                yield
295            except Exception as exc:
296                raise RuntimeError(f'caught {exc}') from exc
297
298        with self.assertRaises(RuntimeError):
299            with woohoo():
300                1 / 0
301
302        # If the context manager wrapped StopIteration in a RuntimeError,
303        # we also unwrap it, because we can't tell whether the wrapping was
304        # done by the generator machinery or by the generator itself.
305        with self.assertRaises(StopIteration):
306            with woohoo():
307                raise StopIteration
308
309    def _create_contextmanager_attribs(self):
310        def attribs(**kw):
311            def decorate(func):
312                for k,v in kw.items():
313                    setattr(func,k,v)
314                return func
315            return decorate
316        @contextmanager
317        @attribs(foo='bar')
318        def baz(spam):
319            """Whee!"""
320            yield
321        return baz
322
323    def test_contextmanager_attribs(self):
324        baz = self._create_contextmanager_attribs()
325        self.assertEqual(baz.__name__,'baz')
326        self.assertEqual(baz.foo, 'bar')
327
328    @support.requires_docstrings
329    def test_contextmanager_doc_attrib(self):
330        baz = self._create_contextmanager_attribs()
331        self.assertEqual(baz.__doc__, "Whee!")
332
333    @support.requires_docstrings
334    def test_instance_docstring_given_cm_docstring(self):
335        baz = self._create_contextmanager_attribs()(None)
336        self.assertEqual(baz.__doc__, "Whee!")
337
338    def test_keywords(self):
339        # Ensure no keyword arguments are inhibited
340        @contextmanager
341        def woohoo(self, func, args, kwds):
342            yield (self, func, args, kwds)
343        with woohoo(self=11, func=22, args=33, kwds=44) as target:
344            self.assertEqual(target, (11, 22, 33, 44))
345
346    def test_nokeepref(self):
347        class A:
348            pass
349
350        @contextmanager
351        def woohoo(a, b):
352            a = weakref.ref(a)
353            b = weakref.ref(b)
354            # Allow test to work with a non-refcounted GC
355            support.gc_collect()
356            self.assertIsNone(a())
357            self.assertIsNone(b())
358            yield
359
360        with woohoo(A(), b=A()):
361            pass
362
363    def test_param_errors(self):
364        @contextmanager
365        def woohoo(a, *, b):
366            yield
367
368        with self.assertRaises(TypeError):
369            woohoo()
370        with self.assertRaises(TypeError):
371            woohoo(3, 5)
372        with self.assertRaises(TypeError):
373            woohoo(b=3)
374
375    def test_recursive(self):
376        depth = 0
377        ncols = 0
378        @contextmanager
379        def woohoo():
380            nonlocal ncols
381            ncols += 1
382            nonlocal depth
383            before = depth
384            depth += 1
385            yield
386            depth -= 1
387            self.assertEqual(depth, before)
388
389        @woohoo()
390        def recursive():
391            if depth < 10:
392                recursive()
393
394        recursive()
395        self.assertEqual(ncols, 10)
396        self.assertEqual(depth, 0)
397
398
class ClosingTestCase(unittest.TestCase):
    """Tests for contextlib.closing."""

    @staticmethod
    def _make_closeable():
        # Return a fresh object plus the list its close() appends 1 to.
        log = []

        class Closeable:
            def close(self):
                log.append(1)

        return Closeable(), log

    @support.requires_docstrings
    def test_instance_docs(self):
        # Issue 19330: ensure context manager instances have good docstrings
        instance = closing(None)
        self.assertEqual(instance.__doc__, closing.__doc__)

    def test_closing(self):
        resource, log = self._make_closeable()
        self.assertEqual(log, [])
        with closing(resource) as entered:
            self.assertEqual(resource, entered)
        self.assertEqual(log, [1])

    def test_closing_error(self):
        # close() must still be called when the body raises.
        resource, log = self._make_closeable()
        self.assertEqual(log, [])
        with self.assertRaises(ZeroDivisionError):
            with closing(resource) as entered:
                self.assertEqual(resource, entered)
                1 / 0
        self.assertEqual(log, [1])
431
432
class NullcontextTestCase(unittest.TestCase):
    """Tests for contextlib.nullcontext."""

    def test_nullcontext(self):
        # nullcontext(obj) must hand back obj itself from __enter__.
        class Marker:
            pass

        sentinel = Marker()
        cm = nullcontext(sentinel)
        with cm as entered:
            self.assertIs(entered, sentinel)
440
441
442class FileContextTestCase(unittest.TestCase):
443
444    def testWithOpen(self):
445        tfn = tempfile.mktemp()
446        try:
447            with open(tfn, "w", encoding="utf-8") as f:
448                self.assertFalse(f.closed)
449                f.write("Booh\n")
450            self.assertTrue(f.closed)
451            with self.assertRaises(ZeroDivisionError):
452                with open(tfn, "r", encoding="utf-8") as f:
453                    self.assertFalse(f.closed)
454                    self.assertEqual(f.read(), "Booh\n")
455                    1 / 0
456            self.assertTrue(f.closed)
457        finally:
458            os_helper.unlink(tfn)
459
class LockContextTestCase(unittest.TestCase):
    """Check that threading synchronization primitives work with ``with``."""

    def boilerPlate(self, lock, locked):
        """Exercise *lock* as a context manager, normally and on error.

        *locked* is a zero-argument callable reporting whether the lock is
        currently held. It is checked before, inside and after each block.
        """
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())
        with self.assertRaises(ZeroDivisionError):
            with lock:
                self.assertTrue(locked())
                1 / 0
        self.assertFalse(locked())

    @staticmethod
    def _semaphore_locked(sem):
        """Return True if *sem* cannot currently be acquired.

        Probes with a non-blocking acquire and releases again on success,
        so the semaphore's counter is left unchanged.
        """
        if sem.acquire(False):
            sem.release()
            return False
        return True

    def testWithLock(self):
        lock = threading.Lock()
        self.boilerPlate(lock, lock.locked)

    def testWithRLock(self):
        lock = threading.RLock()
        self.boilerPlate(lock, lock._is_owned)

    def testWithCondition(self):
        lock = threading.Condition()
        def locked():
            return lock._is_owned()
        self.boilerPlate(lock, locked)

    def testWithSemaphore(self):
        lock = threading.Semaphore()
        self.boilerPlate(lock, lambda: self._semaphore_locked(lock))

    def testWithBoundedSemaphore(self):
        lock = threading.BoundedSemaphore()
        self.boilerPlate(lock, lambda: self._semaphore_locked(lock))
506
507
class mycontext(ContextDecorator):
    """Decorator-capable context manager that records how it was used."""

    # Class-level defaults, shadowed per instance as the manager is used.
    started = False  # flipped to True by __enter__
    exc = None       # the exception triple most recently passed to __exit__
    catch = False    # value returned by __exit__ (True suppresses)

    def __enter__(self):
        self.started = True
        return self

    def __exit__(self, *exc_info):
        self.exc = exc_info
        suppress = self.catch
        return suppress
521
522
class TestContextDecorator(unittest.TestCase):
    """Tests for ContextDecorator, mostly via the module-level ``mycontext``."""

    @support.requires_docstrings
    def test_instance_docs(self):
        # Issue 19330: ensure context manager instances have good docstrings
        cm_docstring = mycontext.__doc__
        obj = mycontext()
        self.assertEqual(obj.__doc__, cm_docstring)

    def test_contextdecorator(self):
        context = mycontext()
        with context as result:
            self.assertIs(result, context)
            self.assertTrue(context.started)

        self.assertEqual(context.exc, (None, None, None))


    def test_contextdecorator_with_exception(self):
        context = mycontext()

        with self.assertRaisesRegex(NameError, 'foo'):
            with context:
                raise NameError('foo')
        self.assertIsNotNone(context.exc)
        self.assertIs(context.exc[0], NameError)

        # With catch=True, __exit__ returns True and the error is suppressed.
        context = mycontext()
        context.catch = True
        with context:
            raise NameError('foo')
        self.assertIsNotNone(context.exc)
        self.assertIs(context.exc[0], NameError)


    def test_decorator(self):
        context = mycontext()

        @context
        def test():
            self.assertIsNone(context.exc)
            self.assertTrue(context.started)
        test()
        self.assertEqual(context.exc, (None, None, None))


    def test_decorator_with_exception(self):
        context = mycontext()

        @context
        def test():
            self.assertIsNone(context.exc)
            self.assertTrue(context.started)
            raise NameError('foo')

        with self.assertRaisesRegex(NameError, 'foo'):
            test()
        self.assertIsNotNone(context.exc)
        self.assertIs(context.exc[0], NameError)


    def test_decorating_method(self):
        context = mycontext()

        class Test:

            @context
            def method(self, a, b, c=None):
                self.a = a
                self.b = b
                self.c = c

        # these tests are for argument passing when used as a decorator
        test = Test()
        test.method(1, 2)
        self.assertEqual(test.a, 1)
        self.assertEqual(test.b, 2)
        self.assertIsNone(test.c)

        test = Test()
        test.method('a', 'b', 'c')
        self.assertEqual(test.a, 'a')
        self.assertEqual(test.b, 'b')
        self.assertEqual(test.c, 'c')

        test = Test()
        test.method(a=1, b=2)
        self.assertEqual(test.a, 1)
        self.assertEqual(test.b, 2)
        # The default for ``c`` must also survive a keyword-only call.
        self.assertIsNone(test.c)


    def test_typo_enter(self):
        class mycontext(ContextDecorator):
            def __unter__(self):
                pass
            def __exit__(self, *exc):
                pass

        with self.assertRaisesRegex(TypeError, 'the context manager'):
            with mycontext():
                pass


    def test_typo_exit(self):
        class mycontext(ContextDecorator):
            def __enter__(self):
                pass
            def __uxit__(self, *exc):
                pass

        with self.assertRaisesRegex(TypeError, 'the context manager.*__exit__'):
            with mycontext():
                pass


    def test_contextdecorator_as_mixin(self):
        # ContextDecorator can be mixed into an existing context manager.
        class somecontext(object):
            started = False
            exc = None

            def __enter__(self):
                self.started = True
                return self

            def __exit__(self, *exc):
                self.exc = exc

        class mycontext(somecontext, ContextDecorator):
            pass

        context = mycontext()
        @context
        def test():
            self.assertIsNone(context.exc)
            self.assertTrue(context.started)
        test()
        self.assertEqual(context.exc, (None, None, None))


    def test_contextmanager_as_decorator(self):
        @contextmanager
        def woohoo(y):
            state.append(y)
            yield
            state.append(999)

        state = []
        @woohoo(1)
        def test(x):
            self.assertEqual(state, [1])
            state.append(x)
        test('something')
        self.assertEqual(state, [1, 'something', 999])

        # Issue #11647: Ensure the decorated function is 'reusable'
        state = []
        test('something else')
        self.assertEqual(state, [1, 'something else', 999])
681
682
683class TestBaseExitStack:
684    exit_stack = None
685
686    @support.requires_docstrings
687    def test_instance_docs(self):
688        # Issue 19330: ensure context manager instances have good docstrings
689        cm_docstring = self.exit_stack.__doc__
690        obj = self.exit_stack()
691        self.assertEqual(obj.__doc__, cm_docstring)
692
693    def test_no_resources(self):
694        with self.exit_stack():
695            pass
696
697    def test_callback(self):
698        expected = [
699            ((), {}),
700            ((1,), {}),
701            ((1,2), {}),
702            ((), dict(example=1)),
703            ((1,), dict(example=1)),
704            ((1,2), dict(example=1)),
705            ((1,2), dict(self=3, callback=4)),
706        ]
707        result = []
708        def _exit(*args, **kwds):
709            """Test metadata propagation"""
710            result.append((args, kwds))
711        with self.exit_stack() as stack:
712            for args, kwds in reversed(expected):
713                if args and kwds:
714                    f = stack.callback(_exit, *args, **kwds)
715                elif args:
716                    f = stack.callback(_exit, *args)
717                elif kwds:
718                    f = stack.callback(_exit, **kwds)
719                else:
720                    f = stack.callback(_exit)
721                self.assertIs(f, _exit)
722            for wrapper in stack._exit_callbacks:
723                self.assertIs(wrapper[1].__wrapped__, _exit)
724                self.assertNotEqual(wrapper[1].__name__, _exit.__name__)
725                self.assertIsNone(wrapper[1].__doc__, _exit.__doc__)
726        self.assertEqual(result, expected)
727
728        result = []
729        with self.exit_stack() as stack:
730            with self.assertRaises(TypeError):
731                stack.callback(arg=1)
732            with self.assertRaises(TypeError):
733                self.exit_stack.callback(arg=2)
734            with self.assertRaises(TypeError):
735                stack.callback(callback=_exit, arg=3)
736        self.assertEqual(result, [])
737
738    def test_push(self):
739        exc_raised = ZeroDivisionError
740        def _expect_exc(exc_type, exc, exc_tb):
741            self.assertIs(exc_type, exc_raised)
742        def _suppress_exc(*exc_details):
743            return True
744        def _expect_ok(exc_type, exc, exc_tb):
745            self.assertIsNone(exc_type)
746            self.assertIsNone(exc)
747            self.assertIsNone(exc_tb)
748        class ExitCM(object):
749            def __init__(self, check_exc):
750                self.check_exc = check_exc
751            def __enter__(self):
752                self.fail("Should not be called!")
753            def __exit__(self, *exc_details):
754                self.check_exc(*exc_details)
755        with self.exit_stack() as stack:
756            stack.push(_expect_ok)
757            self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
758            cm = ExitCM(_expect_ok)
759            stack.push(cm)
760            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
761            stack.push(_suppress_exc)
762            self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc)
763            cm = ExitCM(_expect_exc)
764            stack.push(cm)
765            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
766            stack.push(_expect_exc)
767            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
768            stack.push(_expect_exc)
769            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
770            1/0
771
772    def test_enter_context(self):
773        class TestCM(object):
774            def __enter__(self):
775                result.append(1)
776            def __exit__(self, *exc_details):
777                result.append(3)
778
779        result = []
780        cm = TestCM()
781        with self.exit_stack() as stack:
782            @stack.callback  # Registered first => cleaned up last
783            def _exit():
784                result.append(4)
785            self.assertIsNotNone(_exit)
786            stack.enter_context(cm)
787            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
788            result.append(2)
789        self.assertEqual(result, [1, 2, 3, 4])
790
791    def test_enter_context_errors(self):
792        class LacksEnterAndExit:
793            pass
794        class LacksEnter:
795            def __exit__(self, *exc_info):
796                pass
797        class LacksExit:
798            def __enter__(self):
799                pass
800
801        with self.exit_stack() as stack:
802            with self.assertRaisesRegex(TypeError, 'the context manager'):
803                stack.enter_context(LacksEnterAndExit())
804            with self.assertRaisesRegex(TypeError, 'the context manager'):
805                stack.enter_context(LacksEnter())
806            with self.assertRaisesRegex(TypeError, 'the context manager'):
807                stack.enter_context(LacksExit())
808            self.assertFalse(stack._exit_callbacks)
809
810    def test_close(self):
811        result = []
812        with self.exit_stack() as stack:
813            @stack.callback
814            def _exit():
815                result.append(1)
816            self.assertIsNotNone(_exit)
817            stack.close()
818            result.append(2)
819        self.assertEqual(result, [1, 2])
820
821    def test_pop_all(self):
822        result = []
823        with self.exit_stack() as stack:
824            @stack.callback
825            def _exit():
826                result.append(3)
827            self.assertIsNotNone(_exit)
828            new_stack = stack.pop_all()
829            result.append(1)
830        result.append(2)
831        new_stack.close()
832        self.assertEqual(result, [1, 2, 3])
833
834    def test_exit_raise(self):
835        with self.assertRaises(ZeroDivisionError):
836            with self.exit_stack() as stack:
837                stack.push(lambda *exc: False)
838                1/0
839
840    def test_exit_suppress(self):
841        with self.exit_stack() as stack:
842            stack.push(lambda *exc: True)
843            1/0
844
845    def test_exit_exception_traceback(self):
846        # This test captures the current behavior of ExitStack so that we know
847        # if we ever unintendedly change it. It is not a statement of what the
848        # desired behavior is (for instance, we may want to remove some of the
849        # internal contextlib frames).
850
851        def raise_exc(exc):
852            raise exc
853
854        try:
855            with self.exit_stack() as stack:
856                stack.callback(raise_exc, ValueError)
857                1/0
858        except ValueError as e:
859            exc = e
860
861        self.assertIsInstance(exc, ValueError)
862        ve_frames = traceback.extract_tb(exc.__traceback__)
863        expected = \
864            [('test_exit_exception_traceback', 'with self.exit_stack() as stack:')] + \
865            self.callback_error_internal_frames + \
866            [('_exit_wrapper', 'callback(*args, **kwds)'),
867             ('raise_exc', 'raise exc')]
868
869        self.assertEqual(
870            [(f.name, f.line) for f in ve_frames], expected)
871
872        self.assertIsInstance(exc.__context__, ZeroDivisionError)
873        zde_frames = traceback.extract_tb(exc.__context__.__traceback__)
874        self.assertEqual([(f.name, f.line) for f in zde_frames],
875                         [('test_exit_exception_traceback', '1/0')])
876
877    def test_exit_exception_chaining_reference(self):
878        # Sanity check to make sure that ExitStack chaining matches
879        # actual nested with statements
880        class RaiseExc:
881            def __init__(self, exc):
882                self.exc = exc
883            def __enter__(self):
884                return self
885            def __exit__(self, *exc_details):
886                raise self.exc
887
888        class RaiseExcWithContext:
889            def __init__(self, outer, inner):
890                self.outer = outer
891                self.inner = inner
892            def __enter__(self):
893                return self
894            def __exit__(self, *exc_details):
895                try:
896                    raise self.inner
897                except:
898                    raise self.outer
899
900        class SuppressExc:
901            def __enter__(self):
902                return self
903            def __exit__(self, *exc_details):
904                type(self).saved_details = exc_details
905                return True
906
907        try:
908            with RaiseExc(IndexError):
909                with RaiseExcWithContext(KeyError, AttributeError):
910                    with SuppressExc():
911                        with RaiseExc(ValueError):
912                            1 / 0
913        except IndexError as exc:
914            self.assertIsInstance(exc.__context__, KeyError)
915            self.assertIsInstance(exc.__context__.__context__, AttributeError)
916            # Inner exceptions were suppressed
917            self.assertIsNone(exc.__context__.__context__.__context__)
918        else:
919            self.fail("Expected IndexError, but no exception was raised")
920        # Check the inner exceptions
921        inner_exc = SuppressExc.saved_details[1]
922        self.assertIsInstance(inner_exc, ValueError)
923        self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)
924
925    def test_exit_exception_chaining(self):
926        # Ensure exception chaining matches the reference behaviour
927        def raise_exc(exc):
928            raise exc
929
930        saved_details = None
931        def suppress_exc(*exc_details):
932            nonlocal saved_details
933            saved_details = exc_details
934            return True
935
936        try:
937            with self.exit_stack() as stack:
938                stack.callback(raise_exc, IndexError)
939                stack.callback(raise_exc, KeyError)
940                stack.callback(raise_exc, AttributeError)
941                stack.push(suppress_exc)
942                stack.callback(raise_exc, ValueError)
943                1 / 0
944        except IndexError as exc:
945            self.assertIsInstance(exc.__context__, KeyError)
946            self.assertIsInstance(exc.__context__.__context__, AttributeError)
947            # Inner exceptions were suppressed
948            self.assertIsNone(exc.__context__.__context__.__context__)
949        else:
950            self.fail("Expected IndexError, but no exception was raised")
951        # Check the inner exceptions
952        inner_exc = saved_details[1]
953        self.assertIsInstance(inner_exc, ValueError)
954        self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)
955
956    def test_exit_exception_explicit_none_context(self):
957        # Ensure ExitStack chaining matches actual nested `with` statements
958        # regarding explicit __context__ = None.
959
960        class MyException(Exception):
961            pass
962
963        @contextmanager
964        def my_cm():
965            try:
966                yield
967            except BaseException:
968                exc = MyException()
969                try:
970                    raise exc
971                finally:
972                    exc.__context__ = None
973
974        @contextmanager
975        def my_cm_with_exit_stack():
976            with self.exit_stack() as stack:
977                stack.enter_context(my_cm())
978                yield stack
979
980        for cm in (my_cm, my_cm_with_exit_stack):
981            with self.subTest():
982                try:
983                    with cm():
984                        raise IndexError()
985                except MyException as exc:
986                    self.assertIsNone(exc.__context__)
987                else:
988                    self.fail("Expected IndexError, but no exception was raised")
989
990    def test_exit_exception_non_suppressing(self):
991        # http://bugs.python.org/issue19092
992        def raise_exc(exc):
993            raise exc
994
995        def suppress_exc(*exc_details):
996            return True
997
998        try:
999            with self.exit_stack() as stack:
1000                stack.callback(lambda: None)
1001                stack.callback(raise_exc, IndexError)
1002        except Exception as exc:
1003            self.assertIsInstance(exc, IndexError)
1004        else:
1005            self.fail("Expected IndexError, but no exception was raised")
1006
1007        try:
1008            with self.exit_stack() as stack:
1009                stack.callback(raise_exc, KeyError)
1010                stack.push(suppress_exc)
1011                stack.callback(raise_exc, IndexError)
1012        except Exception as exc:
1013            self.assertIsInstance(exc, KeyError)
1014        else:
1015            self.fail("Expected KeyError, but no exception was raised")
1016
1017    def test_exit_exception_with_correct_context(self):
1018        # http://bugs.python.org/issue20317
1019        @contextmanager
1020        def gets_the_context_right(exc):
1021            try:
1022                yield
1023            finally:
1024                raise exc
1025
1026        exc1 = Exception(1)
1027        exc2 = Exception(2)
1028        exc3 = Exception(3)
1029        exc4 = Exception(4)
1030
1031        # The contextmanager already fixes the context, so prior to the
1032        # fix, ExitStack would try to fix it *again* and get into an
1033        # infinite self-referential loop
1034        try:
1035            with self.exit_stack() as stack:
1036                stack.enter_context(gets_the_context_right(exc4))
1037                stack.enter_context(gets_the_context_right(exc3))
1038                stack.enter_context(gets_the_context_right(exc2))
1039                raise exc1
1040        except Exception as exc:
1041            self.assertIs(exc, exc4)
1042            self.assertIs(exc.__context__, exc3)
1043            self.assertIs(exc.__context__.__context__, exc2)
1044            self.assertIs(exc.__context__.__context__.__context__, exc1)
1045            self.assertIsNone(
1046                       exc.__context__.__context__.__context__.__context__)
1047
1048    def test_exit_exception_with_existing_context(self):
1049        # Addresses a lack of test coverage discovered after checking in a
1050        # fix for issue 20317 that still contained debugging code.
1051        def raise_nested(inner_exc, outer_exc):
1052            try:
1053                raise inner_exc
1054            finally:
1055                raise outer_exc
1056        exc1 = Exception(1)
1057        exc2 = Exception(2)
1058        exc3 = Exception(3)
1059        exc4 = Exception(4)
1060        exc5 = Exception(5)
1061        try:
1062            with self.exit_stack() as stack:
1063                stack.callback(raise_nested, exc4, exc5)
1064                stack.callback(raise_nested, exc2, exc3)
1065                raise exc1
1066        except Exception as exc:
1067            self.assertIs(exc, exc5)
1068            self.assertIs(exc.__context__, exc4)
1069            self.assertIs(exc.__context__.__context__, exc3)
1070            self.assertIs(exc.__context__.__context__.__context__, exc2)
1071            self.assertIs(
1072                 exc.__context__.__context__.__context__.__context__, exc1)
1073            self.assertIsNone(
1074                exc.__context__.__context__.__context__.__context__.__context__)
1075
1076    def test_body_exception_suppress(self):
1077        def suppress_exc(*exc_details):
1078            return True
1079        try:
1080            with self.exit_stack() as stack:
1081                stack.push(suppress_exc)
1082                1/0
1083        except IndexError as exc:
1084            self.fail("Expected no exception, got IndexError")
1085
1086    def test_exit_exception_chaining_suppress(self):
1087        with self.exit_stack() as stack:
1088            stack.push(lambda *exc: True)
1089            stack.push(lambda *exc: 1/0)
1090            stack.push(lambda *exc: {}[1])
1091
1092    def test_excessive_nesting(self):
1093        # The original implementation would die with RecursionError here
1094        with self.exit_stack() as stack:
1095            for i in range(10000):
1096                stack.callback(int)
1097
1098    def test_instance_bypass(self):
1099        class Example(object): pass
1100        cm = Example()
1101        cm.__enter__ = object()
1102        cm.__exit__ = object()
1103        stack = self.exit_stack()
1104        with self.assertRaisesRegex(TypeError, 'the context manager'):
1105            stack.enter_context(cm)
1106        stack.push(cm)
1107        self.assertIs(stack._exit_callbacks[-1][1], cm)
1108
1109    def test_dont_reraise_RuntimeError(self):
1110        # https://bugs.python.org/issue27122
1111        class UniqueException(Exception): pass
1112        class UniqueRuntimeError(RuntimeError): pass
1113
1114        @contextmanager
1115        def second():
1116            try:
1117                yield 1
1118            except Exception as exc:
1119                raise UniqueException("new exception") from exc
1120
1121        @contextmanager
1122        def first():
1123            try:
1124                yield 1
1125            except Exception as exc:
1126                raise exc
1127
1128        # The UniqueRuntimeError should be caught by second()'s exception
1129        # handler which chain raised a new UniqueException.
1130        with self.assertRaises(UniqueException) as err_ctx:
1131            with self.exit_stack() as es_ctx:
1132                es_ctx.enter_context(second())
1133                es_ctx.enter_context(first())
1134                raise UniqueRuntimeError("please no infinite loop.")
1135
1136        exc = err_ctx.exception
1137        self.assertIsInstance(exc, UniqueException)
1138        self.assertIsInstance(exc.__context__, UniqueRuntimeError)
1139        self.assertIsNone(exc.__context__.__context__)
1140        self.assertIsNone(exc.__context__.__cause__)
1141        self.assertIs(exc.__cause__, exc.__context__)
1142
1143
class TestExitStack(TestBaseExitStack, unittest.TestCase):
    """Run the TestBaseExitStack tests against contextlib.ExitStack."""

    exit_stack = ExitStack
    # (function name, source line) pairs for frames inside
    # ExitStack.__exit__. Presumably the base class compares these against
    # a callback error's traceback (confirm in TestBaseExitStack).
    callback_error_internal_frames = [
        ('__exit__', 'raise exc'),
        ('__exit__', 'if cb(*exc_details):'),
    ]
1150
1151
1152class TestRedirectStream:
1153
1154    redirect_stream = None
1155    orig_stream = None
1156
1157    @support.requires_docstrings
1158    def test_instance_docs(self):
1159        # Issue 19330: ensure context manager instances have good docstrings
1160        cm_docstring = self.redirect_stream.__doc__
1161        obj = self.redirect_stream(None)
1162        self.assertEqual(obj.__doc__, cm_docstring)
1163
1164    def test_no_redirect_in_init(self):
1165        orig_stdout = getattr(sys, self.orig_stream)
1166        self.redirect_stream(None)
1167        self.assertIs(getattr(sys, self.orig_stream), orig_stdout)
1168
1169    def test_redirect_to_string_io(self):
1170        f = io.StringIO()
1171        msg = "Consider an API like help(), which prints directly to stdout"
1172        orig_stdout = getattr(sys, self.orig_stream)
1173        with self.redirect_stream(f):
1174            print(msg, file=getattr(sys, self.orig_stream))
1175        self.assertIs(getattr(sys, self.orig_stream), orig_stdout)
1176        s = f.getvalue().strip()
1177        self.assertEqual(s, msg)
1178
1179    def test_enter_result_is_target(self):
1180        f = io.StringIO()
1181        with self.redirect_stream(f) as enter_result:
1182            self.assertIs(enter_result, f)
1183
1184    def test_cm_is_reusable(self):
1185        f = io.StringIO()
1186        write_to_f = self.redirect_stream(f)
1187        orig_stdout = getattr(sys, self.orig_stream)
1188        with write_to_f:
1189            print("Hello", end=" ", file=getattr(sys, self.orig_stream))
1190        with write_to_f:
1191            print("World!", file=getattr(sys, self.orig_stream))
1192        self.assertIs(getattr(sys, self.orig_stream), orig_stdout)
1193        s = f.getvalue()
1194        self.assertEqual(s, "Hello World!\n")
1195
1196    def test_cm_is_reentrant(self):
1197        f = io.StringIO()
1198        write_to_f = self.redirect_stream(f)
1199        orig_stdout = getattr(sys, self.orig_stream)
1200        with write_to_f:
1201            print("Hello", end=" ", file=getattr(sys, self.orig_stream))
1202            with write_to_f:
1203                print("World!", file=getattr(sys, self.orig_stream))
1204        self.assertIs(getattr(sys, self.orig_stream), orig_stdout)
1205        s = f.getvalue()
1206        self.assertEqual(s, "Hello World!\n")
1207
1208
class TestRedirectStdout(TestRedirectStream, unittest.TestCase):
    """Run the TestRedirectStream tests against redirect_stdout."""

    orig_stream = "stdout"
    redirect_stream = redirect_stdout
1213
1214
class TestRedirectStderr(TestRedirectStream, unittest.TestCase):
    """Run the TestRedirectStream tests against redirect_stderr."""

    orig_stream = "stderr"
    redirect_stream = redirect_stderr
1219
1220
1221class TestSuppress(ExceptionIsLikeMixin, unittest.TestCase):
1222
1223    @support.requires_docstrings
1224    def test_instance_docs(self):
1225        # Issue 19330: ensure context manager instances have good docstrings
1226        cm_docstring = suppress.__doc__
1227        obj = suppress()
1228        self.assertEqual(obj.__doc__, cm_docstring)
1229
1230    def test_no_result_from_enter(self):
1231        with suppress(ValueError) as enter_result:
1232            self.assertIsNone(enter_result)
1233
1234    def test_no_exception(self):
1235        with suppress(ValueError):
1236            self.assertEqual(pow(2, 5), 32)
1237
1238    def test_exact_exception(self):
1239        with suppress(TypeError):
1240            len(5)
1241
1242    def test_exception_hierarchy(self):
1243        with suppress(LookupError):
1244            'Hello'[50]
1245
1246    def test_other_exception(self):
1247        with self.assertRaises(ZeroDivisionError):
1248            with suppress(TypeError):
1249                1/0
1250
1251    def test_no_args(self):
1252        with self.assertRaises(ZeroDivisionError):
1253            with suppress():
1254                1/0
1255
1256    def test_multiple_exception_args(self):
1257        with suppress(ZeroDivisionError, TypeError):
1258            1/0
1259        with suppress(ZeroDivisionError, TypeError):
1260            len(5)
1261
1262    def test_cm_is_reentrant(self):
1263        ignore_exceptions = suppress(Exception)
1264        with ignore_exceptions:
1265            pass
1266        with ignore_exceptions:
1267            len(5)
1268        with ignore_exceptions:
1269            with ignore_exceptions: # Check nested usage
1270                len(5)
1271            outer_continued = True
1272            1/0
1273        self.assertTrue(outer_continued)
1274
1275    def test_exception_groups(self):
1276        eg_ve = lambda: ExceptionGroup(
1277            "EG with ValueErrors only",
1278            [ValueError("ve1"), ValueError("ve2"), ValueError("ve3")],
1279        )
1280        eg_all = lambda: ExceptionGroup(
1281            "EG with many types of exceptions",
1282            [ValueError("ve1"), KeyError("ke1"), ValueError("ve2"), KeyError("ke2")],
1283        )
1284        with suppress(ValueError):
1285            raise eg_ve()
1286        with suppress(ValueError, KeyError):
1287            raise eg_all()
1288        with self.assertRaises(ExceptionGroup) as eg1:
1289            with suppress(ValueError):
1290                raise eg_all()
1291        self.assertExceptionIsLike(
1292            eg1.exception,
1293            ExceptionGroup(
1294                "EG with many types of exceptions",
1295                [KeyError("ke1"), KeyError("ke2")],
1296            ),
1297        )
1298        # Check handling of BaseExceptionGroup, using GeneratorExit so that
1299        # we don't accidentally discard a ctrl-c with KeyboardInterrupt.
1300        with suppress(GeneratorExit):
1301            raise BaseExceptionGroup("message", [GeneratorExit()])
1302        # If we raise a BaseException group, we can still suppress parts
1303        with self.assertRaises(BaseExceptionGroup) as eg1:
1304            with suppress(KeyError):
1305                raise BaseExceptionGroup("message", [GeneratorExit("g"), KeyError("k")])
1306        self.assertExceptionIsLike(
1307            eg1.exception, BaseExceptionGroup("message", [GeneratorExit("g")]),
1308        )
1309        # If we suppress all the leaf BaseExceptions, we get a non-base ExceptionGroup
1310        with self.assertRaises(ExceptionGroup) as eg1:
1311            with suppress(GeneratorExit):
1312                raise BaseExceptionGroup("message", [GeneratorExit("g"), KeyError("k")])
1313        self.assertExceptionIsLike(
1314            eg1.exception, ExceptionGroup("message", [KeyError("k")]),
1315        )
1316
1317
1318class TestChdir(unittest.TestCase):
1319    def make_relative_path(self, *parts):
1320        return os.path.join(
1321            os.path.dirname(os.path.realpath(__file__)),
1322            *parts,
1323        )
1324
1325    def test_simple(self):
1326        old_cwd = os.getcwd()
1327        target = self.make_relative_path('data')
1328        self.assertNotEqual(old_cwd, target)
1329
1330        with chdir(target):
1331            self.assertEqual(os.getcwd(), target)
1332        self.assertEqual(os.getcwd(), old_cwd)
1333
1334    def test_reentrant(self):
1335        old_cwd = os.getcwd()
1336        target1 = self.make_relative_path('data')
1337        target2 = self.make_relative_path('archivetestdata')
1338        self.assertNotIn(old_cwd, (target1, target2))
1339        chdir1, chdir2 = chdir(target1), chdir(target2)
1340
1341        with chdir1:
1342            self.assertEqual(os.getcwd(), target1)
1343            with chdir2:
1344                self.assertEqual(os.getcwd(), target2)
1345                with chdir1:
1346                    self.assertEqual(os.getcwd(), target1)
1347                self.assertEqual(os.getcwd(), target2)
1348            self.assertEqual(os.getcwd(), target1)
1349        self.assertEqual(os.getcwd(), old_cwd)
1350
1351    def test_exception(self):
1352        old_cwd = os.getcwd()
1353        target = self.make_relative_path('data')
1354        self.assertNotEqual(old_cwd, target)
1355
1356        try:
1357            with chdir(target):
1358                self.assertEqual(os.getcwd(), target)
1359                raise RuntimeError("boom")
1360        except RuntimeError as re:
1361            self.assertEqual(str(re), "boom")
1362        self.assertEqual(os.getcwd(), old_cwd)
1363
1364
if __name__ == "__main__":
    # Allow this test module to be run directly as a script.
    unittest.main()
1367