• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test compatibility between the AsyncIO stack and the synchronous stack."""
15
16import asyncio
17from concurrent.futures import ThreadPoolExecutor
18import logging
19import os
20import random
21import threading
22from typing import Callable, Iterable, Sequence, Tuple
23import unittest
24
25import grpc
26from grpc.experimental import aio
27
28from src.proto.grpc.testing import messages_pb2
29from src.proto.grpc.testing import test_pb2_grpc
30from tests.unit.framework.common import test_constants
31from tests_aio.unit import _common
32from tests_aio.unit._test_base import AioTestBase
33from tests_aio.unit._test_server import TestServiceServicer
34from tests_aio.unit._test_server import start_test_server
35
36_NUM_STREAM_RESPONSES = 5
37_REQUEST_PAYLOAD_SIZE = 7
38_RESPONSE_PAYLOAD_SIZE = 42
39_REQUEST = b"\x03\x07"
40
41
42def _unique_options() -> Sequence[Tuple[str, float]]:
43    return (("iv", random.random()),)
44
45
@unittest.skipIf(
    os.environ.get("GRPC_ASYNCIO_ENGINE", "").lower() == "custom_io_manager",
    "Compatible mode needs POLLER completion queue.",
)
class TestCompatibility(AioTestBase):
    """Exercises the AsyncIO and synchronous gRPC stacks in one process.

    The RPC-shape tests call an ``aio`` server first through an ``aio`` stub
    on the event-loop thread, then through a sync stub from a separate
    thread. The ``test_sync_*`` tests register synchronous (non-``async``)
    method handlers on the ``aio`` server and call them from an ``aio``
    channel.
    """

    async def setUp(self):
        # Kept on self so tearDown can release its worker threads.
        # Presumably the aio server runs the synchronous handlers used by the
        # test_sync_* tests on this pool.
        self._migration_thread_pool = ThreadPoolExecutor()
        self._async_server = aio.server(
            options=(("grpc.so_reuseport", 0),),
            migration_thread_pool=self._migration_thread_pool,
        )

        test_pb2_grpc.add_TestServiceServicer_to_server(
            TestServiceServicer(), self._async_server
        )
        # Tests install per-test handlers via set_adhoc_handler() and invoke
        # them at _common.ADHOC_METHOD.
        self._adhoc_handlers = _common.AdhocGenericHandler()
        self._async_server.add_generic_rpc_handlers((self._adhoc_handlers,))

        port = self._async_server.add_insecure_port("[::]:0")
        address = "localhost:%d" % port
        await self._async_server.start()

        # Create async stub
        self._async_channel = aio.insecure_channel(
            address, options=_unique_options()
        )
        self._async_stub = test_pb2_grpc.TestServiceStub(self._async_channel)

        # Create sync stub
        self._sync_channel = grpc.insecure_channel(
            address, options=_unique_options()
        )
        self._sync_stub = test_pb2_grpc.TestServiceStub(self._sync_channel)

    async def tearDown(self):
        self._sync_channel.close()
        await self._async_channel.close()
        await self._async_server.stop(None)
        # After stop() completes no RPC should still need the pool; shut it
        # down so its worker threads are not leaked across tests.
        self._migration_thread_pool.shutdown(wait=False)

    async def _run_in_another_thread(self, func: Callable[[], None]):
        """Run ``func`` on a new daemon thread without blocking the loop.

        Any exception raised by ``func`` (including failed assertions) is
        re-raised here, so it fails the test. Without this, the event loop
        would wait forever for a completion signal that never arrives.
        """
        work_done = asyncio.Event()
        errors = []

        def thread_work():
            try:
                func()
            except BaseException as e:  # Re-raised on the loop thread below.
                errors.append(e)
            finally:
                # asyncio.Event is not thread-safe; signal through the loop.
                self.loop.call_soon_threadsafe(work_done.set)

        thread = threading.Thread(target=thread_work, daemon=True)
        thread.start()
        await work_done.wait()
        thread.join()
        if errors:
            raise errors[0]

    async def test_unary_unary(self):
        # Calling async API in this thread
        await self._async_stub.UnaryCall(
            messages_pb2.SimpleRequest(), timeout=test_constants.LONG_TIMEOUT
        )

        # Calling sync API in a different thread
        def sync_work() -> None:
            response, call = self._sync_stub.UnaryCall.with_call(
                messages_pb2.SimpleRequest(),
                timeout=test_constants.LONG_TIMEOUT,
            )
            self.assertIsInstance(response, messages_pb2.SimpleResponse)
            self.assertEqual(grpc.StatusCode.OK, call.code())

        await self._run_in_another_thread(sync_work)

    async def test_unary_stream(self):
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)
            )

        # Calling async API in this thread
        call = self._async_stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(
                _RESPONSE_PAYLOAD_SIZE, len(response.payload.body)
            )
        self.assertEqual(grpc.StatusCode.OK, await call.code())

        # Calling sync API in a different thread
        def sync_work() -> None:
            response_iterator = self._sync_stub.StreamingOutputCall(request)
            for response in response_iterator:
                self.assertEqual(
                    _RESPONSE_PAYLOAD_SIZE, len(response.payload.body)
                )
            self.assertEqual(grpc.StatusCode.OK, response_iterator.code())

        await self._run_in_another_thread(sync_work)

    async def test_stream_unary(self):
        payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        # Calling async API in this thread
        async def gen():
            for _ in range(_NUM_STREAM_RESPONSES):
                yield request

        response = await self._async_stub.StreamingInputCall(gen())
        self.assertEqual(
            _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
            response.aggregated_payload_size,
        )

        # Calling sync API in a different thread
        def sync_work() -> None:
            response = self._sync_stub.StreamingInputCall(
                iter([request] * _NUM_STREAM_RESPONSES)
            )
            self.assertEqual(
                _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                response.aggregated_payload_size,
            )

        await self._run_in_another_thread(sync_work)

    async def test_stream_stream(self):
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.append(
            messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)
        )

        # Calling async API in this thread
        call = self._async_stub.FullDuplexCall()

        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(request)
            response = await call.read()
            self.assertEqual(
                _RESPONSE_PAYLOAD_SIZE, len(response.payload.body)
            )

        await call.done_writing()
        self.assertEqual(grpc.StatusCode.OK, await call.code())

        # Calling sync API in a different thread
        def sync_work() -> None:
            response_iterator = self._sync_stub.FullDuplexCall(iter([request]))
            for response in response_iterator:
                self.assertEqual(
                    _RESPONSE_PAYLOAD_SIZE, len(response.payload.body)
                )
            self.assertEqual(grpc.StatusCode.OK, response_iterator.code())

        await self._run_in_another_thread(sync_work)

    async def test_server(self):
        class GenericHandlers(grpc.GenericRpcHandler):
            def service(self, handler_call_details):
                return grpc.unary_unary_rpc_method_handler(lambda x, _: x)

        # It's fine to instantiate server object in the event loop thread.
        # The server will spawn its own serving thread.
        executor = ThreadPoolExecutor()
        server = grpc.server(executor, handlers=(GenericHandlers(),))
        port = server.add_insecure_port("localhost:0")
        server.start()

        def sync_work() -> None:
            for _ in range(100):
                with grpc.insecure_channel("localhost:%d" % port) as channel:
                    response = channel.unary_unary("/test/test")(b"\x07\x08")
                    self.assertEqual(response, b"\x07\x08")

        try:
            await self._run_in_another_thread(sync_work)
        finally:
            # Stop the sync server and its worker pool so no threads outlive
            # this test.
            server.stop(None)
            executor.shutdown(wait=False)

    async def test_many_loop(self):
        address, server = await start_test_server()

        # Run another loop in another thread
        def sync_work():
            async def async_work():
                # Create async stub; the channel is closed on exit.
                async with aio.insecure_channel(
                    address, options=_unique_options()
                ) as async_channel:
                    async_stub = test_pb2_grpc.TestServiceStub(async_channel)

                    call = async_stub.UnaryCall(messages_pb2.SimpleRequest())
                    response = await call
                    self.assertIsInstance(
                        response, messages_pb2.SimpleResponse
                    )
                    self.assertEqual(grpc.StatusCode.OK, await call.code())

            loop = asyncio.new_event_loop()
            try:
                loop.run_until_complete(async_work())
            finally:
                loop.close()

        try:
            await self._run_in_another_thread(sync_work)
        finally:
            await server.stop(None)

    async def test_sync_unary_unary_success(self):
        @grpc.unary_unary_rpc_method_handler
        def echo_unary_unary(request: bytes, unused_context):
            return request

        self._adhoc_handlers.set_adhoc_handler(echo_unary_unary)
        response = await self._async_channel.unary_unary(_common.ADHOC_METHOD)(
            _REQUEST
        )
        self.assertEqual(_REQUEST, response)

    async def test_sync_unary_unary_metadata(self):
        metadata = (("unique", "key-42"),)

        @grpc.unary_unary_rpc_method_handler
        def metadata_unary_unary(request: bytes, context: grpc.ServicerContext):
            context.send_initial_metadata(metadata)
            return request

        self._adhoc_handlers.set_adhoc_handler(metadata_unary_unary)
        call = self._async_channel.unary_unary(_common.ADHOC_METHOD)(_REQUEST)
        self.assertTrue(
            _common.seen_metadata(
                aio.Metadata(*metadata), await call.initial_metadata()
            )
        )

    async def test_sync_unary_unary_abort(self):
        @grpc.unary_unary_rpc_method_handler
        def abort_unary_unary(request: bytes, context: grpc.ServicerContext):
            context.abort(grpc.StatusCode.INTERNAL, "Test")

        self._adhoc_handlers.set_adhoc_handler(abort_unary_unary)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await self._async_channel.unary_unary(_common.ADHOC_METHOD)(
                _REQUEST
            )
        self.assertEqual(
            grpc.StatusCode.INTERNAL, exception_context.exception.code()
        )

    async def test_sync_unary_unary_set_code(self):
        @grpc.unary_unary_rpc_method_handler
        def set_code_unary_unary(request: bytes, context: grpc.ServicerContext):
            context.set_code(grpc.StatusCode.INTERNAL)

        self._adhoc_handlers.set_adhoc_handler(set_code_unary_unary)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await self._async_channel.unary_unary(_common.ADHOC_METHOD)(
                _REQUEST
            )
        self.assertEqual(
            grpc.StatusCode.INTERNAL, exception_context.exception.code()
        )

    async def test_sync_unary_stream_success(self):
        @grpc.unary_stream_rpc_method_handler
        def echo_unary_stream(request: bytes, unused_context):
            for _ in range(_NUM_STREAM_RESPONSES):
                yield request

        self._adhoc_handlers.set_adhoc_handler(echo_unary_stream)
        call = self._async_channel.unary_stream(_common.ADHOC_METHOD)(_REQUEST)
        # Collect everything so a short (or empty) stream also fails.
        responses = [response async for response in call]
        self.assertEqual([_REQUEST] * _NUM_STREAM_RESPONSES, responses)
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_sync_unary_stream_error(self):
        @grpc.unary_stream_rpc_method_handler
        def error_unary_stream(request: bytes, unused_context):
            for _ in range(_NUM_STREAM_RESPONSES):
                yield request
            raise RuntimeError("Test")

        self._adhoc_handlers.set_adhoc_handler(error_unary_stream)
        call = self._async_channel.unary_stream(_common.ADHOC_METHOD)(_REQUEST)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            async for response in call:
                self.assertEqual(_REQUEST, response)
        self.assertEqual(
            grpc.StatusCode.UNKNOWN, exception_context.exception.code()
        )

    async def test_sync_stream_unary_success(self):
        @grpc.stream_unary_rpc_method_handler
        def echo_stream_unary(
            request_iterator: Iterable[bytes], unused_context
        ):
            self.assertEqual(_NUM_STREAM_RESPONSES, len(list(request_iterator)))
            return _REQUEST

        self._adhoc_handlers.set_adhoc_handler(echo_stream_unary)
        request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES)
        response = await self._async_channel.stream_unary(_common.ADHOC_METHOD)(
            request_iterator
        )
        self.assertEqual(_REQUEST, response)

    async def test_sync_stream_unary_error(self):
        @grpc.stream_unary_rpc_method_handler
        def echo_stream_unary(
            request_iterator: Iterable[bytes], unused_context
        ):
            self.assertEqual(_NUM_STREAM_RESPONSES, len(list(request_iterator)))
            raise RuntimeError("Test")

        self._adhoc_handlers.set_adhoc_handler(echo_stream_unary)
        request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await self._async_channel.stream_unary(_common.ADHOC_METHOD)(
                request_iterator
            )
        self.assertEqual(
            grpc.StatusCode.UNKNOWN, exception_context.exception.code()
        )

    async def test_sync_stream_stream_success(self):
        @grpc.stream_stream_rpc_method_handler
        def echo_stream_stream(
            request_iterator: Iterable[bytes], unused_context
        ):
            for request in request_iterator:
                yield request

        self._adhoc_handlers.set_adhoc_handler(echo_stream_stream)
        request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES)
        call = self._async_channel.stream_stream(_common.ADHOC_METHOD)(
            request_iterator
        )
        # Collect everything so a short (or empty) stream also fails.
        responses = [response async for response in call]
        self.assertEqual([_REQUEST] * _NUM_STREAM_RESPONSES, responses)
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_sync_stream_stream_error(self):
        @grpc.stream_stream_rpc_method_handler
        def echo_stream_stream(
            request_iterator: Iterable[bytes], unused_context
        ):
            for request in request_iterator:
                yield request
            raise RuntimeError("test")

        self._adhoc_handlers.set_adhoc_handler(echo_stream_stream)
        request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES)
        call = self._async_channel.stream_stream(_common.ADHOC_METHOD)(
            request_iterator
        )
        with self.assertRaises(aio.AioRpcError) as exception_context:
            async for response in call:
                self.assertEqual(_REQUEST, response)
        self.assertEqual(
            grpc.StatusCode.UNKNOWN, exception_context.exception.code()
        )
384
385
if __name__ == "__main__":
    # Direct invocation: emit DEBUG-level logs and verbose per-test output.
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)
389