• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Tests for asyncio/threads.py"""
2
3import asyncio
4import unittest
5
6from contextvars import ContextVar
7from unittest import mock
8from test.test_asyncio import utils as test_utils
9
10
def tearDownModule():
    """Module-level teardown hook: restore the default event loop policy.

    Passing ``None`` asks asyncio to drop any custom policy, so whatever
    policy state these tests created does not carry over to other modules.
    """
    asyncio.set_event_loop_policy(policy=None)
13
14
class ToThreadTests(test_utils.TestCase):
    """Tests for asyncio.to_thread(), each run on its own event loop."""

    def setUp(self):
        """Create a fresh event loop and install it as the current loop."""
        super().setUp()
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def tearDown(self):
        """Shut down the default executor, then close and unset the loop."""
        # The executor must be shut down while the loop can still run.
        self.loop.run_until_complete(
            self.loop.shutdown_default_executor())
        self.loop.close()
        asyncio.set_event_loop(None)
        self.loop = None
        super().tearDown()

    def test_to_thread(self):
        """The awaited result is the wrapped function's return value."""
        async def main():
            return await asyncio.to_thread(sum, [40, 2])

        result = self.loop.run_until_complete(main())
        self.assertEqual(result, 42)

    def test_to_thread_exception(self):
        """An exception raised by the function propagates to the awaiter."""
        def raise_runtime():
            raise RuntimeError("test")

        async def main():
            await asyncio.to_thread(raise_runtime)

        with self.assertRaisesRegex(RuntimeError, "test"):
            self.loop.run_until_complete(main())

    def test_to_thread_once(self):
        """A single to_thread() call invokes the function exactly once."""
        func = mock.Mock()

        async def main():
            await asyncio.to_thread(func)

        self.loop.run_until_complete(main())
        func.assert_called_once()

    def test_to_thread_concurrent(self):
        """Ten gathered to_thread() calls invoke the function ten times."""
        # Record calls in a list instead of relying on Mock.call_count:
        # the mock bumps its counter with a non-atomic read-modify-write,
        # so increments made by concurrently running threads can be lost,
        # whereas list.append() is safe to call from several threads.
        calls = []

        def func():
            calls.append(1)

        async def main():
            futs = [asyncio.to_thread(func) for _ in range(10)]
            await asyncio.gather(*futs)

        self.loop.run_until_complete(main())
        self.assertEqual(len(calls), 10)

    def test_to_thread_args_kwargs(self):
        """Positional and keyword arguments are forwarded to the function."""
        # Unlike run_in_executor(), to_thread() should directly accept kwargs.
        func = mock.Mock()

        async def main():
            await asyncio.to_thread(func, 'test', something=True)

        self.loop.run_until_complete(main())
        func.assert_called_once_with('test', something=True)

    def test_to_thread_contextvars(self):
        """A ContextVar set before the call is readable inside the function."""
        test_ctx = ContextVar('test_ctx')

        def get_ctx():
            return test_ctx.get()

        async def main():
            test_ctx.set('parrot')
            return await asyncio.to_thread(get_ctx)

        result = self.loop.run_until_complete(main())
        self.assertEqual(result, 'parrot')
90
91
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
94