• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The gRPC Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
import asyncio
import functools
import inspect
import logging
import unittest
from typing import Callable

from grpc.experimental import aio
21
# Must be a sequence of names. A bare string would make ``from ... import *``
# look up every character ('A', 'i', 'o', ...) as a separate public name.
__all__ = ('AioTestBase',)

# Method names, in addition to those starting with 'test_', whose coroutine
# implementations AioTestBase runs synchronously on its shared event loop.
_COROUTINE_FUNCTION_ALLOWLIST = ['setUp', 'tearDown']
25
26
def _async_to_sync_decorator(f: Callable, loop: asyncio.AbstractEventLoop):
    """Turn coroutine function ``f`` into a blocking callable.

    Args:
      f: A coroutine function.
      loop: The event loop that each call's coroutine is driven to
        completion on.

    Returns:
      A plain function with ``f``'s metadata (via ``functools.wraps``). It
      forwards its arguments to ``f`` and returns whatever the resulting
      coroutine returns.
    """

    def _run_blocking(*call_args, **call_kwargs):
        coro = f(*call_args, **call_kwargs)
        return loop.run_until_complete(coro)

    return functools.wraps(f)(_run_blocking)
34
35
def _get_default_loop(debug=True):
    """Return a usable event loop for the current thread.

    Uses the thread's current event loop when one exists and is still open.
    Otherwise a new loop is created and installed as the current loop.

    Args:
      debug: Passed to ``loop.set_debug`` on the returned loop.

    Returns:
      An open ``asyncio.AbstractEventLoop``.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # Raised when there is no current loop, e.g. in a non-main thread, or
        # on newer Python versions where get_event_loop() no longer creates one.
        loop = None

    # A closed loop cannot run anything, so replace it just like a missing one.
    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    loop.set_debug(debug)
    return loop
45
46
# NOTE(gnossen) this test class can also be implemented with metaclass.
class AioTestBase(unittest.TestCase):
    """TestCase whose test methods, setUp and tearDown may be coroutines.

    Attributes are wrapped when their name starts with ``test_`` or appears in
    ``_COROUTINE_FUNCTION_ALLOWLIST``, and they resolve to a coroutine
    function. The lookup then returns a synchronous wrapper that runs the
    coroutine to completion on ``_TEST_LOOP``, so ``unittest`` can call it
    like an ordinary method.
    """

    # NOTE(lidi) A single loop is picked for the entire testing phase;
    # otherwise new loops would be created in new threads, which leads to
    # deadlocks.
    _TEST_LOOP = _get_default_loop()

    @property
    def loop(self):
        """The event loop (``_TEST_LOOP``) that coroutine methods run on."""
        return self._TEST_LOOP

    def __getattribute__(self, name):
        """Overrides the loading logic to support coroutine functions."""
        attr = super().__getattribute__(name)

        # If possible, converts the coroutine into a sync function.
        # inspect.iscoroutinefunction is used because
        # asyncio.iscoroutinefunction is deprecated as of Python 3.14.
        if name.startswith('test_') or name in _COROUTINE_FUNCTION_ALLOWLIST:
            if inspect.iscoroutinefunction(attr):
                return _async_to_sync_decorator(attr, self._TEST_LOOP)
        # For other attributes, let them pass.
        return attr
67