• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import asyncio
2import unittest
3
4from unittest import mock
5from test.test_asyncio import utils as test_utils
6
7
def tearDownModule():
    """Undo the per-test policy swapping once the whole module has run."""
    # Passing None asks asyncio to reinstall its default event loop policy.
    default_policy = None
    asyncio.set_event_loop_policy(default_policy)
10
11
class TestPolicy(asyncio.AbstractEventLoopPolicy):
    """Event loop policy whose loops come from a caller-supplied factory.

    The most recent non-None loop passed to ``set_event_loop`` is kept in
    ``self.loop`` so BaseTest.tearDown can check whether it was closed.
    """

    def __init__(self, loop_factory):
        self.loop = None
        self.loop_factory = loop_factory

    def get_event_loop(self):
        """Always fail: asyncio.run() is not expected to call this."""
        raise RuntimeError

    def new_event_loop(self):
        """Create a loop by delegating to the stored factory."""
        factory = self.loop_factory
        return factory()

    def set_event_loop(self, loop):
        """Remember *loop* for later inspection; ``None`` is ignored."""
        if loop is None:
            return
        self.loop = loop
30
31
class BaseTest(unittest.TestCase):
    """Shared fixture for the asyncio.run() tests.

    setUp installs a TestPolicy whose loops are built by ``new_loop``.
    tearDown verifies that the last loop handed to that policy was closed
    and had ``shutdown_asyncgens`` awaited, then restores the default policy.
    """

    def new_loop(self):
        """Return a BaseEventLoop with mocked selector/event processing.

        The mocked selector never reports any events. ``shutdown_asyncgens``
        is replaced by a coroutine that only records, in the
        ``shutdown_ag_run`` attribute, that it was awaited.
        """
        loop = asyncio.BaseEventLoop()
        loop._process_events = mock.Mock()
        loop._selector = mock.Mock()
        loop._selector.select.return_value = ()
        loop.shutdown_ag_run = False

        async def shutdown_asyncgens():
            loop.shutdown_ag_run = True
        loop.shutdown_asyncgens = shutdown_asyncgens

        return loop

    def setUp(self):
        super().setUp()

        policy = TestPolicy(self.new_loop)
        asyncio.set_event_loop_policy(policy)

    def tearDown(self):
        policy = asyncio.get_event_loop_policy()
        try:
            # ``policy.loop`` is only set once a non-None loop has been
            # installed through TestPolicy.set_event_loop.
            if policy.loop is not None:
                self.assertTrue(policy.loop.is_closed())
                self.assertTrue(policy.loop.shutdown_ag_run)
        finally:
            # Always restore the default policy, even when an assertion above
            # fails, so a failing test cannot leave its policy (and loop)
            # installed for whatever runs next.
            asyncio.set_event_loop_policy(None)
            super().tearDown()
61
62
class RunTests(BaseTest):
    """Behavioural tests for asyncio.run() under the mocked TestPolicy."""

    def test_asyncio_run_return(self):
        async def answer():
            await asyncio.sleep(0)
            return 42

        result = asyncio.run(answer())
        self.assertEqual(result, 42)

    def test_asyncio_run_raises(self):
        async def explode():
            await asyncio.sleep(0)
            raise ValueError('spam')

        with self.assertRaisesRegex(ValueError, 'spam'):
            asyncio.run(explode())

    def test_asyncio_run_only_coro(self):
        # A plain value and an ordinary callable are both non-coroutines.
        not_coroutines = {1, lambda: None}
        for candidate in not_coroutines:
            with self.subTest(obj=candidate):
                with self.assertRaisesRegex(ValueError,
                                            'a coroutine was expected'):
                    asyncio.run(candidate)

    def test_asyncio_run_debug(self):
        async def check_debug(expected):
            current = asyncio.get_event_loop()
            self.assertIs(current.get_debug(), expected)

        asyncio.run(check_debug(False))
        asyncio.run(check_debug(True), debug=True)
        # With the global debug mode forced on, run() is expected to enable
        # debug unless debug=False is passed explicitly.
        with mock.patch('asyncio.coroutines._is_debug_mode', lambda: True):
            asyncio.run(check_debug(True))
            asyncio.run(check_debug(False), debug=False)

    def test_asyncio_run_from_running_loop(self):
        async def reenter():
            nested = reenter()
            try:
                asyncio.run(nested)
            finally:
                nested.close()  # Suppress ResourceWarning

        with self.assertRaisesRegex(RuntimeError,
                                    'cannot be called from a running'):
            asyncio.run(reenter())

    def test_asyncio_run_cancels_hanging_tasks(self):
        straggler = None

        async def leftover():
            await asyncio.sleep(0.1)

        async def main():
            nonlocal straggler
            straggler = asyncio.create_task(leftover())
            return 123

        self.assertEqual(asyncio.run(main()), 123)
        # The spawned task must already be finished when run() returns.
        self.assertTrue(straggler.done())

    def test_asyncio_run_reports_hanging_tasks_errors(self):
        straggler = None
        handler_spy = mock.Mock()

        async def leftover():
            try:
                await asyncio.sleep(0.1)
            except asyncio.CancelledError:
                # Fail while being cancelled during run()'s shutdown.
                1 / 0

        async def main():
            running = asyncio.get_running_loop()
            running.call_exception_handler = handler_spy

            nonlocal straggler
            straggler = asyncio.create_task(leftover())
            return 123

        self.assertEqual(asyncio.run(main()), 123)
        self.assertTrue(straggler.done())

        expected_context = {
            'message': test_utils.MockPattern(r'asyncio.run.*shutdown'),
            'task': straggler,
            'exception': test_utils.MockInstanceOf(ZeroDivisionError)
        }
        handler_spy.assert_called_with(expected_context)

    def test_asyncio_run_closes_gens_after_hanging_tasks_errors(self):
        agen = None
        spin_task = None

        class FancyExit(Exception):
            pass

        async def fidget():
            while True:
                yield 1
                await asyncio.sleep(1)

        async def spin():
            nonlocal agen
            agen = fidget()
            try:
                async for _ in agen:
                    pass
            except asyncio.CancelledError:
                1 / 0

        async def main():
            running = asyncio.get_running_loop()
            running.call_exception_handler = mock.Mock()

            nonlocal spin_task
            spin_task = asyncio.create_task(spin())
            raise FancyExit

        with self.assertRaises(FancyExit):
            asyncio.run(main())

        self.assertTrue(spin_task.done())

        # The async generator must be finished and not left suspended.
        self.assertIsNone(agen.ag_frame)
        self.assertFalse(agen.ag_running)
187
188
if __name__ == '__main__':
    # Run every TestCase in this module when it is executed as a script.
    unittest.main()
191