1"""Tests for asyncio/threads.py""" 2 3import asyncio 4import unittest 5 6from contextvars import ContextVar 7from unittest import mock 8from test.test_asyncio import utils as test_utils 9 10 11def tearDownModule(): 12 asyncio.set_event_loop_policy(None) 13 14 15class ToThreadTests(test_utils.TestCase): 16 def setUp(self): 17 super().setUp() 18 self.loop = asyncio.new_event_loop() 19 asyncio.set_event_loop(self.loop) 20 21 def tearDown(self): 22 self.loop.run_until_complete( 23 self.loop.shutdown_default_executor()) 24 self.loop.close() 25 asyncio.set_event_loop(None) 26 self.loop = None 27 super().tearDown() 28 29 def test_to_thread(self): 30 async def main(): 31 return await asyncio.to_thread(sum, [40, 2]) 32 33 result = self.loop.run_until_complete(main()) 34 self.assertEqual(result, 42) 35 36 def test_to_thread_exception(self): 37 def raise_runtime(): 38 raise RuntimeError("test") 39 40 async def main(): 41 await asyncio.to_thread(raise_runtime) 42 43 with self.assertRaisesRegex(RuntimeError, "test"): 44 self.loop.run_until_complete(main()) 45 46 def test_to_thread_once(self): 47 func = mock.Mock() 48 49 async def main(): 50 await asyncio.to_thread(func) 51 52 self.loop.run_until_complete(main()) 53 func.assert_called_once() 54 55 def test_to_thread_concurrent(self): 56 func = mock.Mock() 57 58 async def main(): 59 futs = [] 60 for _ in range(10): 61 fut = asyncio.to_thread(func) 62 futs.append(fut) 63 await asyncio.gather(*futs) 64 65 self.loop.run_until_complete(main()) 66 self.assertEqual(func.call_count, 10) 67 68 def test_to_thread_args_kwargs(self): 69 # Unlike run_in_executor(), to_thread() should directly accept kwargs. 70 func = mock.Mock() 71 72 async def main(): 73 await asyncio.to_thread(func, 'test', something=True) 74 75 self.loop.run_until_complete(main()) 76 func.assert_called_once_with('test', something=True) 77 78 def test_to_thread_contextvars(self): 79 test_ctx = ContextVar('test_ctx') 80 81 def get_ctx(): 82 return test_ctx.get() 83 84 async def main(): 85 test_ctx.set('parrot') 86 return await asyncio.to_thread(get_ctx) 87 88 result = self.loop.run_until_complete(main()) 89 self.assertEqual(result, 'parrot') 90 91 92if __name__ == "__main__": 93 unittest.main() 94