• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The gRPC Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Testing the compatibility between AsyncIO stack and the old stack."""
15
16import asyncio
17import logging
18import os
19import random
20import threading
21import unittest
22from concurrent.futures import ThreadPoolExecutor
23from typing import Callable, Iterable, Sequence, Tuple
24
25import grpc
26from grpc.experimental import aio
27
28from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
29from tests.unit.framework.common import test_constants
30from tests_aio.unit import _common
31from tests_aio.unit._test_base import AioTestBase
32from tests_aio.unit._test_server import TestServiceServicer, start_test_server
33
34_NUM_STREAM_RESPONSES = 5
35_REQUEST_PAYLOAD_SIZE = 7
36_RESPONSE_PAYLOAD_SIZE = 42
37_REQUEST = b'\x03\x07'
38_ADHOC_METHOD = '/test/AdHoc'
39
40
41def _unique_options() -> Sequence[Tuple[str, float]]:
42    return (('iv', random.random()),)
43
44
class _AdhocGenericHandler(grpc.GenericRpcHandler):
    """Generic handler serving a swappable handler for ``_ADHOC_METHOD`` only.

    Tests install the method handler they need via ``set_adhoc_handler``.
    For any other method name ``service`` returns ``None``, i.e. this handler
    declines it.
    """
    # None until set_adhoc_handler() is called.
    _handler: grpc.RpcMethodHandler

    def __init__(self):
        self._handler = None

    def set_adhoc_handler(self, handler: grpc.RpcMethodHandler):
        """Replace the method handler returned for ``_ADHOC_METHOD``."""
        self._handler = handler

    def service(self, handler_call_details):
        """Return the installed handler for ``_ADHOC_METHOD``, else ``None``."""
        is_adhoc_call = handler_call_details.method == _ADHOC_METHOD
        return self._handler if is_adhoc_call else None
59
60
@unittest.skipIf(
    os.environ.get('GRPC_ASYNCIO_ENGINE', '').lower() == 'custom_io_manager',
    'Compatible mode needs POLLER completion queue.')
class TestCompatibility(AioTestBase):
    """Mixes the AsyncIO (aio) gRPC stack with the synchronous gRPC stack.

    The tests cover several combinations:
      * calling one aio server from an aio channel on the event-loop thread;
      * calling the same server from a sync channel on a separate thread;
      * registering plain (synchronous) method handlers on the aio server;
      * running a sync server, or a second event loop, alongside the test's
        own loop.

    The class is skipped when GRPC_ASYNCIO_ENGINE is ``custom_io_manager``
    because, per the skip message, compatible mode needs the POLLER
    completion queue.
    """

    async def setUp(self):
        # NOTE(review): the migration thread pool is presumably what lets this
        # aio server run the synchronous handlers installed by the test_sync_*
        # cases -- inferred from the parameter name, confirm.
        self._async_server = aio.server(
            options=(('grpc.so_reuseport', 0),),
            migration_thread_pool=ThreadPoolExecutor())

        test_pb2_grpc.add_TestServiceServicer_to_server(TestServiceServicer(),
                                                        self._async_server)
        # Individual tests install their own handler for _ADHOC_METHOD.
        self._adhoc_handlers = _AdhocGenericHandler()
        self._async_server.add_generic_rpc_handlers((self._adhoc_handlers,))

        port = self._async_server.add_insecure_port('[::]:0')
        address = 'localhost:%d' % port
        await self._async_server.start()

        # Create async stub
        self._async_channel = aio.insecure_channel(address,
                                                   options=_unique_options())
        self._async_stub = test_pb2_grpc.TestServiceStub(self._async_channel)

        # Create sync stub
        self._sync_channel = grpc.insecure_channel(address,
                                                   options=_unique_options())
        self._sync_stub = test_pb2_grpc.TestServiceStub(self._sync_channel)

    async def tearDown(self):
        self._sync_channel.close()
        await self._async_channel.close()
        await self._async_server.stop(None)

    async def _run_in_another_thread(self, func: Callable[[], None]):
        """Run ``func`` on a new daemon thread and wait for it without blocking.

        Any exception raised by ``func``, including a failed ``self.assert*``,
        is re-raised here so that the test fails instead of hanging.
        """
        # No ``loop=`` argument: it was removed in Python 3.10. The event binds
        # to the loop running this coroutine.
        work_done = asyncio.Event()
        errors = []

        def thread_work():
            try:
                func()
            except Exception as e:  # pylint: disable=broad-except
                errors.append(e)
            finally:
                # Always wake the awaiting coroutine; if this were skipped when
                # func() raised, work_done.wait() would block forever.
                self.loop.call_soon_threadsafe(work_done.set)

        thread = threading.Thread(target=thread_work, daemon=True)
        thread.start()
        await work_done.wait()
        thread.join()
        if errors:
            raise errors[0]

    async def test_unary_unary(self):
        # Calling async API in this thread
        await self._async_stub.UnaryCall(messages_pb2.SimpleRequest(),
                                         timeout=test_constants.LONG_TIMEOUT)

        # Calling sync API in a different thread
        def sync_work() -> None:
            response, call = self._sync_stub.UnaryCall.with_call(
                messages_pb2.SimpleRequest(),
                timeout=test_constants.LONG_TIMEOUT)
            self.assertIsInstance(response, messages_pb2.SimpleResponse)
            self.assertEqual(grpc.StatusCode.OK, call.code())

        await self._run_in_another_thread(sync_work)

    async def test_unary_stream(self):
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))

        # Calling async API in this thread
        call = self._async_stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
        self.assertEqual(grpc.StatusCode.OK, await call.code())

        # Calling sync API in a different thread
        def sync_work() -> None:
            response_iterator = self._sync_stub.StreamingOutputCall(request)
            for response in response_iterator:
                self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
                                 len(response.payload.body))
            self.assertEqual(grpc.StatusCode.OK, response_iterator.code())

        await self._run_in_another_thread(sync_work)

    async def test_stream_unary(self):
        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        # Calling async API in this thread
        async def gen():
            for _ in range(_NUM_STREAM_RESPONSES):
                yield request

        response = await self._async_stub.StreamingInputCall(gen())
        self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                         response.aggregated_payload_size)

        # Calling sync API in a different thread
        def sync_work() -> None:
            response = self._sync_stub.StreamingInputCall(
                iter([request] * _NUM_STREAM_RESPONSES))
            self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                             response.aggregated_payload_size)

        await self._run_in_another_thread(sync_work)

    async def test_stream_stream(self):
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.append(
            messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))

        # Calling async API in this thread
        call = self._async_stub.FullDuplexCall()

        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(request)
            response = await call.read()
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        await call.done_writing()
        self.assertEqual(grpc.StatusCode.OK, await call.code())

        # Calling sync API in a different thread
        def sync_work() -> None:
            response_iterator = self._sync_stub.FullDuplexCall(iter([request]))
            for response in response_iterator:
                self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
                                 len(response.payload.body))
            self.assertEqual(grpc.StatusCode.OK, response_iterator.code())

        await self._run_in_another_thread(sync_work)

    async def test_server(self):

        class GenericHandlers(grpc.GenericRpcHandler):

            def service(self, handler_call_details):
                # Echo the request back, whatever the method name.
                return grpc.unary_unary_rpc_method_handler(lambda x, _: x)

        # It's fine to instantiate server object in the event loop thread.
        # The server will spawn its own serving thread.
        server = grpc.server(ThreadPoolExecutor(),
                             handlers=(GenericHandlers(),))
        port = server.add_insecure_port('localhost:0')
        server.start()

        def sync_work() -> None:
            for _ in range(100):
                with grpc.insecure_channel('localhost:%d' % port) as channel:
                    response = channel.unary_unary('/test/test')(b'\x07\x08')
                    self.assertEqual(response, b'\x07\x08')

        try:
            await self._run_in_another_thread(sync_work)
        finally:
            # Stop the sync server so it does not outlive this test.
            server.stop(None)

    async def test_many_loop(self):
        address, server = await start_test_server()

        # Run another loop in another thread
        def sync_work():

            async def async_work():
                # Create async stub
                async_channel = aio.insecure_channel(address,
                                                     options=_unique_options())
                try:
                    async_stub = test_pb2_grpc.TestServiceStub(async_channel)

                    call = async_stub.UnaryCall(messages_pb2.SimpleRequest())
                    response = await call
                    self.assertIsInstance(response,
                                          messages_pb2.SimpleResponse)
                    self.assertEqual(grpc.StatusCode.OK, await call.code())
                finally:
                    # Close on the loop that created the channel.
                    await async_channel.close()

            loop = asyncio.new_event_loop()
            try:
                loop.run_until_complete(async_work())
            finally:
                loop.close()

        try:
            await self._run_in_another_thread(sync_work)
        finally:
            await server.stop(None)

    async def test_sync_unary_unary_success(self):

        @grpc.unary_unary_rpc_method_handler
        def echo_unary_unary(request: bytes, unused_context):
            return request

        self._adhoc_handlers.set_adhoc_handler(echo_unary_unary)
        multicallable = self._async_channel.unary_unary(_ADHOC_METHOD)
        response = await multicallable(_REQUEST)
        self.assertEqual(_REQUEST, response)

    async def test_sync_unary_unary_metadata(self):
        metadata = (('unique', 'key-42'),)

        @grpc.unary_unary_rpc_method_handler
        def metadata_unary_unary(request: bytes, context: grpc.ServicerContext):
            context.send_initial_metadata(metadata)
            return request

        self._adhoc_handlers.set_adhoc_handler(metadata_unary_unary)
        call = self._async_channel.unary_unary(_ADHOC_METHOD)(_REQUEST)
        self.assertTrue(
            _common.seen_metadata(aio.Metadata(*metadata), await
                                  call.initial_metadata()))

    async def test_sync_unary_unary_abort(self):

        @grpc.unary_unary_rpc_method_handler
        def abort_unary_unary(request: bytes, context: grpc.ServicerContext):
            context.abort(grpc.StatusCode.INTERNAL, 'Test')

        self._adhoc_handlers.set_adhoc_handler(abort_unary_unary)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await self._async_channel.unary_unary(_ADHOC_METHOD)(_REQUEST)
        self.assertEqual(grpc.StatusCode.INTERNAL,
                         exception_context.exception.code())

    async def test_sync_unary_unary_set_code(self):

        @grpc.unary_unary_rpc_method_handler
        def set_code_unary_unary(request: bytes, context: grpc.ServicerContext):
            context.set_code(grpc.StatusCode.INTERNAL)

        self._adhoc_handlers.set_adhoc_handler(set_code_unary_unary)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await self._async_channel.unary_unary(_ADHOC_METHOD)(_REQUEST)
        self.assertEqual(grpc.StatusCode.INTERNAL,
                         exception_context.exception.code())

    async def test_sync_unary_stream_success(self):

        @grpc.unary_stream_rpc_method_handler
        def echo_unary_stream(request: bytes, unused_context):
            for _ in range(_NUM_STREAM_RESPONSES):
                yield request

        self._adhoc_handlers.set_adhoc_handler(echo_unary_stream)
        call = self._async_channel.unary_stream(_ADHOC_METHOD)(_REQUEST)
        # Collect everything so a short (or empty) stream also fails the test.
        responses = [response async for response in call]
        self.assertEqual([_REQUEST] * _NUM_STREAM_RESPONSES, responses)

    async def test_sync_unary_stream_error(self):

        @grpc.unary_stream_rpc_method_handler
        def error_unary_stream(request: bytes, unused_context):
            for _ in range(_NUM_STREAM_RESPONSES):
                yield request
            raise RuntimeError('Test')

        self._adhoc_handlers.set_adhoc_handler(error_unary_stream)
        call = self._async_channel.unary_stream(_ADHOC_METHOD)(_REQUEST)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            async for response in call:
                self.assertEqual(_REQUEST, response)
        self.assertEqual(grpc.StatusCode.UNKNOWN,
                         exception_context.exception.code())

    async def test_sync_stream_unary_success(self):

        @grpc.stream_unary_rpc_method_handler
        def echo_stream_unary(request_iterator: Iterable[bytes],
                              unused_context):
            self.assertEqual(len(list(request_iterator)), _NUM_STREAM_RESPONSES)
            return _REQUEST

        self._adhoc_handlers.set_adhoc_handler(echo_stream_unary)
        request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES)
        response = await self._async_channel.stream_unary(_ADHOC_METHOD)(
            request_iterator)
        self.assertEqual(_REQUEST, response)

    async def test_sync_stream_unary_error(self):

        @grpc.stream_unary_rpc_method_handler
        def echo_stream_unary(request_iterator: Iterable[bytes],
                              unused_context):
            self.assertEqual(len(list(request_iterator)), _NUM_STREAM_RESPONSES)
            raise RuntimeError('Test')

        self._adhoc_handlers.set_adhoc_handler(echo_stream_unary)
        request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await self._async_channel.stream_unary(_ADHOC_METHOD)(
                request_iterator)
        self.assertEqual(grpc.StatusCode.UNKNOWN,
                         exception_context.exception.code())

    async def test_sync_stream_stream_success(self):

        @grpc.stream_stream_rpc_method_handler
        def echo_stream_stream(request_iterator: Iterable[bytes],
                               unused_context):
            for request in request_iterator:
                yield request

        self._adhoc_handlers.set_adhoc_handler(echo_stream_stream)
        request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES)
        call = self._async_channel.stream_stream(_ADHOC_METHOD)(
            request_iterator)
        # Collect everything so a short (or empty) stream also fails the test.
        responses = [response async for response in call]
        self.assertEqual([_REQUEST] * _NUM_STREAM_RESPONSES, responses)

    async def test_sync_stream_stream_error(self):

        @grpc.stream_stream_rpc_method_handler
        def echo_stream_stream(request_iterator: Iterable[bytes],
                               unused_context):
            for request in request_iterator:
                yield request
            raise RuntimeError('test')

        self._adhoc_handlers.set_adhoc_handler(echo_stream_stream)
        request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES)
        call = self._async_channel.stream_stream(_ADHOC_METHOD)(
            request_iterator)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            async for response in call:
                self.assertEqual(_REQUEST, response)
        self.assertEqual(grpc.StatusCode.UNKNOWN,
                         exception_context.exception.code())
376
377
if __name__ == '__main__':
    # Running the module directly: emit DEBUG logs and list each test as it
    # runs, since a hang is easier to localize that way.
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)
381