• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import asyncio
2import inspect
3
4from .case import TestCase
5
6
7class IsolatedAsyncioTestCase(TestCase):
8    # Names intentionally have a long prefix
9    # to reduce a chance of clashing with user-defined attributes
10    # from inherited test case
11    #
12    # The class doesn't call loop.run_until_complete(self.setUp()) and family
13    # but uses a different approach:
14    # 1. create a long-running task that reads self.setUp()
15    #    awaitable from queue along with a future
16    # 2. await the awaitable object passing in and set the result
17    #    into the future object
18    # 3. Outer code puts the awaitable and the future object into a queue
19    #    with waiting for the future
20    # The trick is necessary because every run_until_complete() call
21    # creates a new task with embedded ContextVar context.
22    # To share contextvars between setUp(), test and tearDown() we need to execute
23    # them inside the same task.
24
25    # Note: the test case modifies event loop policy if the policy was not instantiated
26    # yet.
27    # asyncio.get_event_loop_policy() creates a default policy on demand but never
28    # returns None
29    # I believe this is not an issue in user level tests but python itself for testing
30    # should reset a policy in every test module
31    # by calling asyncio.set_event_loop_policy(None) in tearDownModule()
32
33    def __init__(self, methodName='runTest'):
34        super().__init__(methodName)
35        self._asyncioTestLoop = None
36        self._asyncioCallsQueue = None
37
38    async def asyncSetUp(self):
39        pass
40
41    async def asyncTearDown(self):
42        pass
43
44    def addAsyncCleanup(self, func, /, *args, **kwargs):
45        # A trivial trampoline to addCleanup()
46        # the function exists because it has a different semantics
47        # and signature:
48        # addCleanup() accepts regular functions
49        # but addAsyncCleanup() accepts coroutines
50        #
51        # We intentionally don't add inspect.iscoroutinefunction() check
52        # for func argument because there is no way
53        # to check for async function reliably:
54        # 1. It can be "async def func()" itself
55        # 2. Class can implement "async def __call__()" method
56        # 3. Regular "def func()" that returns awaitable object
57        self.addCleanup(*(func, *args), **kwargs)
58
59    def _callSetUp(self):
60        self.setUp()
61        self._callAsync(self.asyncSetUp)
62
63    def _callTestMethod(self, method):
64        self._callMaybeAsync(method)
65
66    def _callTearDown(self):
67        self._callAsync(self.asyncTearDown)
68        self.tearDown()
69
70    def _callCleanup(self, function, *args, **kwargs):
71        self._callMaybeAsync(function, *args, **kwargs)
72
73    def _callAsync(self, func, /, *args, **kwargs):
74        assert self._asyncioTestLoop is not None, 'asyncio test loop is not initialized'
75        ret = func(*args, **kwargs)
76        assert inspect.isawaitable(ret), f'{func!r} returned non-awaitable'
77        fut = self._asyncioTestLoop.create_future()
78        self._asyncioCallsQueue.put_nowait((fut, ret))
79        return self._asyncioTestLoop.run_until_complete(fut)
80
81    def _callMaybeAsync(self, func, /, *args, **kwargs):
82        assert self._asyncioTestLoop is not None, 'asyncio test loop is not initialized'
83        ret = func(*args, **kwargs)
84        if inspect.isawaitable(ret):
85            fut = self._asyncioTestLoop.create_future()
86            self._asyncioCallsQueue.put_nowait((fut, ret))
87            return self._asyncioTestLoop.run_until_complete(fut)
88        else:
89            return ret
90
91    async def _asyncioLoopRunner(self, fut):
92        self._asyncioCallsQueue = queue = asyncio.Queue()
93        fut.set_result(None)
94        while True:
95            query = await queue.get()
96            queue.task_done()
97            if query is None:
98                return
99            fut, awaitable = query
100            try:
101                ret = await awaitable
102                if not fut.cancelled():
103                    fut.set_result(ret)
104            except (SystemExit, KeyboardInterrupt):
105                raise
106            except (BaseException, asyncio.CancelledError) as ex:
107                if not fut.cancelled():
108                    fut.set_exception(ex)
109
110    def _setupAsyncioLoop(self):
111        assert self._asyncioTestLoop is None, 'asyncio test loop already initialized'
112        loop = asyncio.new_event_loop()
113        asyncio.set_event_loop(loop)
114        loop.set_debug(True)
115        self._asyncioTestLoop = loop
116        fut = loop.create_future()
117        self._asyncioCallsTask = loop.create_task(self._asyncioLoopRunner(fut))
118        loop.run_until_complete(fut)
119
120    def _tearDownAsyncioLoop(self):
121        assert self._asyncioTestLoop is not None, 'asyncio test loop is not initialized'
122        loop = self._asyncioTestLoop
123        self._asyncioTestLoop = None
124        self._asyncioCallsQueue.put_nowait(None)
125        loop.run_until_complete(self._asyncioCallsQueue.join())
126
127        try:
128            # cancel all tasks
129            to_cancel = asyncio.all_tasks(loop)
130            if not to_cancel:
131                return
132
133            for task in to_cancel:
134                task.cancel()
135
136            loop.run_until_complete(
137                asyncio.gather(*to_cancel, return_exceptions=True))
138
139            for task in to_cancel:
140                if task.cancelled():
141                    continue
142                if task.exception() is not None:
143                    loop.call_exception_handler({
144                        'message': 'unhandled exception during test shutdown',
145                        'exception': task.exception(),
146                        'task': task,
147                    })
148            # shutdown asyncgens
149            loop.run_until_complete(loop.shutdown_asyncgens())
150        finally:
151            asyncio.set_event_loop(None)
152            loop.close()
153
154    def run(self, result=None):
155        self._setupAsyncioLoop()
156        try:
157            return super().run(result)
158        finally:
159            self._tearDownAsyncioLoop()
160
161    def debug(self):
162        self._setupAsyncioLoop()
163        super().debug()
164        self._tearDownAsyncioLoop()
165
166    def __del__(self):
167        if self._asyncioTestLoop is not None:
168            self._tearDownAsyncioLoop()
169