• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The gRPC Authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import asyncio
16import gc
17import logging
18import socket
19import time
20import unittest
21
22import grpc
23from grpc.experimental import aio
24
25from tests.unit import resources
26from tests.unit.framework.common import test_constants
27from tests_aio.unit._test_base import AioTestBase
28
# Method paths served by _GenericHandler, grouped by RPC arity.

# Unary request -> unary response.
_SIMPLE_UNARY_UNARY = '/test/SimpleUnaryUnary'
_BLOCK_FOREVER = '/test/BlockForever'
_BLOCK_BRIEFLY = '/test/BlockBriefly'
_ERROR_WITHOUT_RAISE_IN_UNARY_UNARY = '/test/ErrorWithoutRaiseInUnaryUnary'

# Unary request -> streamed responses.
_UNARY_STREAM_ASYNC_GEN = '/test/UnaryStreamAsyncGen'
_UNARY_STREAM_READER_WRITER = '/test/UnaryStreamReaderWriter'
_UNARY_STREAM_EVILLY_MIXED = '/test/UnaryStreamEvillyMixed'

# Streamed requests -> unary response.
_STREAM_UNARY_ASYNC_GEN = '/test/StreamUnaryAsyncGen'
_STREAM_UNARY_READER_WRITER = '/test/StreamUnaryReaderWriter'
_STREAM_UNARY_EVILLY_MIXED = '/test/StreamUnaryEvillyMixed'

# Streamed requests -> streamed responses.
_STREAM_STREAM_ASYNC_GEN = '/test/StreamStreamAsyncGen'
_STREAM_STREAM_READER_WRITER = '/test/StreamStreamReaderWriter'
_STREAM_STREAM_EVILLY_MIXED = '/test/StreamStreamEvillyMixed'
_ERROR_IN_STREAM_STREAM = '/test/ErrorInStreamStream'
_ERROR_WITHOUT_RAISE_IN_STREAM_STREAM = '/test/ErrorWithoutRaiseInStreamStream'

# A path that no handler registers, used to provoke UNIMPLEMENTED.
_UNIMPLEMENTED_METHOD = '/test/UnimplementedMethod'

# Fixed payloads exchanged by every RPC.
_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x01\x01\x01'

# Message counts for the streaming RPCs.
_NUM_STREAM_REQUESTS = 3
_NUM_STREAM_RESPONSES = 5

# Concurrency cap used when building the throttled server.
_MAXIMUM_CONCURRENT_RPCS = 5
51
52
class _GenericHandler(grpc.GenericRpcHandler):
    """Dispatches each test method path to one of the handlers below.

    Besides routing, the first call to ``service`` resolves a future, so a
    test can wait (via ``wait_for_call``) until the server has started
    dispatching an RPC.
    """

    def __init__(self):
        # Resolved by ``service`` the first time any RPC is dispatched.
        self._called = asyncio.get_event_loop().create_future()
        routes = (
            (_SIMPLE_UNARY_UNARY, grpc.unary_unary_rpc_method_handler,
             self._unary_unary),
            (_BLOCK_FOREVER, grpc.unary_unary_rpc_method_handler,
             self._block_forever),
            (_BLOCK_BRIEFLY, grpc.unary_unary_rpc_method_handler,
             self._block_briefly),
            (_UNARY_STREAM_ASYNC_GEN, grpc.unary_stream_rpc_method_handler,
             self._unary_stream_async_gen),
            (_UNARY_STREAM_READER_WRITER, grpc.unary_stream_rpc_method_handler,
             self._unary_stream_reader_writer),
            (_UNARY_STREAM_EVILLY_MIXED, grpc.unary_stream_rpc_method_handler,
             self._unary_stream_evilly_mixed),
            (_STREAM_UNARY_ASYNC_GEN, grpc.stream_unary_rpc_method_handler,
             self._stream_unary_async_gen),
            (_STREAM_UNARY_READER_WRITER, grpc.stream_unary_rpc_method_handler,
             self._stream_unary_reader_writer),
            (_STREAM_UNARY_EVILLY_MIXED, grpc.stream_unary_rpc_method_handler,
             self._stream_unary_evilly_mixed),
            (_STREAM_STREAM_ASYNC_GEN, grpc.stream_stream_rpc_method_handler,
             self._stream_stream_async_gen),
            (_STREAM_STREAM_READER_WRITER,
             grpc.stream_stream_rpc_method_handler,
             self._stream_stream_reader_writer),
            (_STREAM_STREAM_EVILLY_MIXED,
             grpc.stream_stream_rpc_method_handler,
             self._stream_stream_evilly_mixed),
            (_ERROR_IN_STREAM_STREAM, grpc.stream_stream_rpc_method_handler,
             self._error_in_stream_stream),
            (_ERROR_WITHOUT_RAISE_IN_UNARY_UNARY,
             grpc.unary_unary_rpc_method_handler,
             self._error_without_raise_in_unary_unary),
            (_ERROR_WITHOUT_RAISE_IN_STREAM_STREAM,
             grpc.stream_stream_rpc_method_handler,
             self._error_without_raise_in_stream_stream),
        )
        self._routing_table = {
            method: make_handler(behavior)
            for method, make_handler, behavior in routes
        }

    @staticmethod
    async def _consume_requests(request_iterator):
        """Drains ``request_iterator``, asserting every message is _REQUEST.

        Returns:
          The number of messages received.
        """
        received = 0
        async for message in request_iterator:
            assert _REQUEST == message
            received += 1
        return received

    @staticmethod
    async def _unary_unary(unused_request, unused_context):
        return _RESPONSE

    async def _block_forever(self, unused_request, unused_context):
        # Nothing ever resolves this future, so the handler never returns.
        never_resolved = asyncio.get_event_loop().create_future()
        await never_resolved

    async def _block_briefly(self, unused_request, unused_context):
        delay = test_constants.SHORT_TIMEOUT / 2
        await asyncio.sleep(delay)
        return _RESPONSE

    async def _unary_stream_async_gen(self, unused_request, unused_context):
        for unused_index in range(_NUM_STREAM_RESPONSES):
            yield _RESPONSE

    async def _unary_stream_reader_writer(self, unused_request, context):
        for unused_index in range(_NUM_STREAM_RESPONSES):
            await context.write(_RESPONSE)

    async def _unary_stream_evilly_mixed(self, unused_request, context):
        # The first response is yielded; the rest go through context.write().
        yield _RESPONSE
        for unused_index in range(1, _NUM_STREAM_RESPONSES):
            await context.write(_RESPONSE)

    async def _stream_unary_async_gen(self, request_iterator, unused_context):
        received = await self._consume_requests(request_iterator)
        assert _NUM_STREAM_REQUESTS == received
        return _RESPONSE

    async def _stream_unary_reader_writer(self, unused_request, context):
        for unused_index in range(_NUM_STREAM_REQUESTS):
            assert _REQUEST == await context.read()
        return _RESPONSE

    async def _stream_unary_evilly_mixed(self, request_iterator, context):
        # One request is read explicitly; the iterator yields the remainder.
        assert _REQUEST == await context.read()
        received = await self._consume_requests(request_iterator)
        assert _NUM_STREAM_REQUESTS - 1 == received
        return _RESPONSE

    async def _stream_stream_async_gen(self, request_iterator, unused_context):
        received = await self._consume_requests(request_iterator)
        assert _NUM_STREAM_REQUESTS == received

        for unused_index in range(_NUM_STREAM_RESPONSES):
            yield _RESPONSE

    async def _stream_stream_reader_writer(self, unused_request, context):
        for unused_index in range(_NUM_STREAM_REQUESTS):
            assert _REQUEST == await context.read()
        for unused_index in range(_NUM_STREAM_RESPONSES):
            await context.write(_RESPONSE)

    async def _stream_stream_evilly_mixed(self, request_iterator, context):
        # Reads mix context.read() with iteration; writes mix yield with
        # context.write().
        assert _REQUEST == await context.read()
        received = await self._consume_requests(request_iterator)
        assert _NUM_STREAM_REQUESTS - 1 == received

        yield _RESPONSE
        for unused_index in range(1, _NUM_STREAM_RESPONSES):
            await context.write(_RESPONSE)

    async def _error_in_stream_stream(self, request_iterator, unused_context):
        async for message in request_iterator:
            assert _REQUEST == message
            raise RuntimeError('A testing RuntimeError!')
        # Reached only if no request arrives; its presence also makes this
        # function an async generator.
        yield _RESPONSE

    async def _error_without_raise_in_unary_unary(self, request, context):
        assert _REQUEST == request
        context.set_code(grpc.StatusCode.INTERNAL)

    async def _error_without_raise_in_stream_stream(self, request_iterator,
                                                    context):
        await self._consume_requests(request_iterator)
        context.set_code(grpc.StatusCode.INTERNAL)

    def service(self, handler_details):
        """Records that a call arrived and returns its handler (or None)."""
        called = self._called
        if not called.done():
            called.set_result(None)
        return self._routing_table.get(handler_details.method)

    async def wait_for_call(self):
        """Waits until ``service`` has been invoked at least once."""
        await self._called
199
200
async def _start_test_server():
    """Starts an insecure aio server on an ephemeral port.

    Returns:
      A tuple ``(address, server, handler)``: the 'localhost:<port>' target
      for clients, the started server, and the _GenericHandler it serves.
    """
    server = aio.server()
    bound_port = server.add_insecure_port('[::]:0')
    handler = _GenericHandler()
    server.add_generic_rpc_handlers((handler,))
    await server.start()
    address = 'localhost:%d' % bound_port
    return address, server, handler
208
209
class TestServer(AioTestBase):
    """End-to-end tests of the asyncio server against _GenericHandler.

    setUp starts a fresh server and opens an insecure channel to it. tearDown
    closes the channel and stops the server with no grace period.
    """

    async def setUp(self):
        addr, self._server, self._generic_handler = await _start_test_server()
        self._channel = aio.insecure_channel(addr)

    async def tearDown(self):
        await self._channel.close()
        # Several tests stop the server themselves, so this may be a second
        # stop() call on the same server.
        await self._server.stop(None)

    async def test_unary_unary(self):
        """A unary-unary call returns the handler's response."""
        unary_unary_call = self._channel.unary_unary(_SIMPLE_UNARY_UNARY)
        response = await unary_unary_call(_REQUEST)
        self.assertEqual(response, _RESPONSE)

    async def test_unary_stream_async_generator(self):
        """Every response yielded by an async-generator handler arrives."""
        unary_stream_call = self._channel.unary_stream(_UNARY_STREAM_ASYNC_GEN)
        call = unary_stream_call(_REQUEST)

        response_cnt = 0
        async for response in call:
            response_cnt += 1
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_unary_stream_reader_writer(self):
        """Responses sent with context.write() are readable via read()."""
        unary_stream_call = self._channel.unary_stream(
            _UNARY_STREAM_READER_WRITER)
        call = unary_stream_call(_REQUEST)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_unary_stream_evilly_mixed(self):
        """Iterating a call after using read() raises aio.UsageError."""
        unary_stream_call = self._channel.unary_stream(
            _UNARY_STREAM_EVILLY_MIXED)
        call = unary_stream_call(_REQUEST)

        # Uses reader API
        self.assertEqual(_RESPONSE, await call.read())

        # Uses async generator API, mixed!
        with self.assertRaises(aio.UsageError):
            async for response in call:
                self.assertEqual(_RESPONSE, response)

    async def test_stream_unary_async_generator(self):
        """The server consumes requests by async iteration, then answers."""
        stream_unary_call = self._channel.stream_unary(_STREAM_UNARY_ASYNC_GEN)
        call = stream_unary_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        response = await call
        self.assertEqual(_RESPONSE, response)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_unary_reader_writer(self):
        """The server consumes requests with context.read(), then answers."""
        stream_unary_call = self._channel.stream_unary(
            _STREAM_UNARY_READER_WRITER)
        call = stream_unary_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        response = await call
        self.assertEqual(_RESPONSE, response)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_unary_evilly_mixed(self):
        """The server may mix context.read() with request iteration."""
        stream_unary_call = self._channel.stream_unary(
            _STREAM_UNARY_EVILLY_MIXED)
        call = stream_unary_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        response = await call
        self.assertEqual(_RESPONSE, response)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_stream_async_generator(self):
        """Bidi streaming with an async-generator handler."""
        stream_stream_call = self._channel.stream_stream(
            _STREAM_STREAM_ASYNC_GEN)
        call = stream_stream_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_stream_reader_writer(self):
        """Bidi streaming with a context.read()/context.write() handler."""
        stream_stream_call = self._channel.stream_stream(
            _STREAM_STREAM_READER_WRITER)
        call = stream_stream_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_stream_evilly_mixed(self):
        """Bidi streaming with a handler that mixes both server APIs."""
        stream_stream_call = self._channel.stream_stream(
            _STREAM_STREAM_EVILLY_MIXED)
        call = stream_stream_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_shutdown(self):
        """Stopping an idle server completes."""
        await self._server.stop(None)
        # Ensures no SIGSEGV triggered, and ends within timeout.

    async def test_shutdown_after_call(self):
        """Stopping the server after a completed RPC completes."""
        await self._channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST)

        await self._server.stop(None)

    async def test_graceful_shutdown_success(self):
        """A grace period lets an in-flight RPC finish successfully."""
        call = self._channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
        await self._generic_handler.wait_for_call()

        # Monotonic clock: elapsed time must not depend on wall-clock changes.
        shutdown_start_time = time.monotonic()
        await self._server.stop(test_constants.SHORT_TIMEOUT)
        grace_period_length = time.monotonic() - shutdown_start_time
        self.assertGreater(grace_period_length,
                           test_constants.SHORT_TIMEOUT / 3)

        # Validates the states.
        self.assertEqual(_RESPONSE, await call)
        self.assertTrue(call.done())

    async def test_graceful_shutdown_failed(self):
        """An RPC still running when the grace period ends gets UNAVAILABLE."""
        call = self._channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
        await self._generic_handler.wait_for_call()

        await self._server.stop(test_constants.SHORT_TIMEOUT)

        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call
        self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                         exception_context.exception.code())

    async def test_concurrent_graceful_shutdown(self):
        """Concurrent stop() calls with different grace periods all finish."""
        call = self._channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
        await self._generic_handler.wait_for_call()

        # Expects the shortest grace period to be effective.
        shutdown_start_time = time.monotonic()
        await asyncio.gather(
            self._server.stop(test_constants.LONG_TIMEOUT),
            self._server.stop(test_constants.SHORT_TIMEOUT),
            self._server.stop(test_constants.LONG_TIMEOUT),
        )
        grace_period_length = time.monotonic() - shutdown_start_time
        self.assertGreater(grace_period_length,
                           test_constants.SHORT_TIMEOUT / 3)

        self.assertEqual(_RESPONSE, await call)
        self.assertTrue(call.done())

    async def test_concurrent_graceful_shutdown_immediate(self):
        """A concurrent stop(None) terminates in-flight RPCs immediately."""
        call = self._channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
        await self._generic_handler.wait_for_call()

        # Expects no grace period, due to the "server.stop(None)".
        await asyncio.gather(
            self._server.stop(test_constants.LONG_TIMEOUT),
            self._server.stop(None),
            self._server.stop(test_constants.SHORT_TIMEOUT),
            self._server.stop(test_constants.LONG_TIMEOUT),
        )

        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call
        self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                         exception_context.exception.code())

    async def test_shutdown_before_call(self):
        """An RPC issued after the server stopped fails with AioRpcError."""
        await self._server.stop(None)

        # Ensures the server is cleaned up at this point.
        # Some proper exception should be raised.
        with self.assertRaises(aio.AioRpcError):
            await self._channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST)

    async def test_unimplemented(self):
        """An unrouted method path yields UNIMPLEMENTED."""
        call = self._channel.unary_unary(_UNIMPLEMENTED_METHOD)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call(_REQUEST)
        rpc_error = exception_context.exception
        self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code())

    async def test_shutdown_during_stream_stream(self):
        """Stopping the server mid-stream ends the RPC with UNAVAILABLE."""
        stream_stream_call = self._channel.stream_stream(
            _STREAM_STREAM_ASYNC_GEN)
        call = stream_stream_call()

        # Don't half close the RPC yet, keep it alive.
        await call.write(_REQUEST)
        await self._server.stop(None)

        self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())
        # No segfault

    async def test_error_in_stream_stream(self):
        """An exception raised by a streaming handler surfaces as UNKNOWN."""
        stream_stream_call = self._channel.stream_stream(
            _ERROR_IN_STREAM_STREAM)
        call = stream_stream_call()

        # Don't half close the RPC yet, keep it alive.
        await call.write(_REQUEST)

        # Don't segfault here
        self.assertEqual(grpc.StatusCode.UNKNOWN, await call.code())

    async def test_error_without_raise_in_unary_unary(self):
        """context.set_code() without raising still fails a unary RPC."""
        call = self._channel.unary_unary(_ERROR_WITHOUT_RAISE_IN_UNARY_UNARY)(
            _REQUEST)

        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call

        rpc_error = exception_context.exception
        self.assertEqual(grpc.StatusCode.INTERNAL, rpc_error.code())

    async def test_error_without_raise_in_stream_stream(self):
        """context.set_code() without raising sets a streaming RPC's code."""
        call = self._channel.stream_stream(
            _ERROR_WITHOUT_RAISE_IN_STREAM_STREAM)()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        self.assertEqual(grpc.StatusCode.INTERNAL, await call.code())

    async def test_port_binding_exception(self):
        """Re-binding an address already bound (no SO_REUSEPORT) raises."""
        server = aio.server(options=(('grpc.so_reuseport', 0),))
        port = server.add_insecure_port('localhost:0')
        bind_address = "localhost:%d" % port

        with self.assertRaises(RuntimeError):
            server.add_insecure_port(bind_address)

        server_credentials = grpc.ssl_server_credentials([
            (resources.private_key(), resources.certificate_chain())
        ])
        with self.assertRaises(RuntimeError):
            server.add_secure_port(bind_address, server_credentials)

    async def test_maximum_concurrent_rpcs(self):
        """RPCs beyond maximum_concurrent_rpcs are throttled, not run at once."""
        # Build the server with concurrent rpc argument
        server = aio.server(maximum_concurrent_rpcs=_MAXIMUM_CONCURRENT_RPCS)
        port = server.add_insecure_port('localhost:0')
        bind_address = "localhost:%d" % port
        server.add_generic_rpc_handlers((_GenericHandler(),))
        await server.start()
        # Build the channel
        channel = aio.insecure_channel(bind_address)
        try:
            # Issue 3x the allowed concurrency. Each _BLOCK_BRIEFLY call sleeps
            # SHORT_TIMEOUT / 2 on the server, so a throttled server needs
            # about three such periods, while an unthrottled one needs one.
            # Start timing before the calls exist, so none of their work
            # escapes the measurement.
            start_time = time.monotonic()
            rpcs = [
                channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
                for _ in range(3 * _MAXIMUM_CONCURRENT_RPCS)
            ]
            # gather() wraps each call in a task itself. asyncio.wait()
            # expects Futures/Tasks, and recent Pythons no longer wrap plain
            # awaitables for it. gather() also re-raises the first failure
            # instead of returning early with unretrieved exceptions.
            responses = await asyncio.gather(*rpcs)
            elapsed_time = time.monotonic() - start_time

            self.assertEqual([_RESPONSE] * len(rpcs), responses)
            # Leave slack below the ideal 3 * SHORT_TIMEOUT / 2, as the other
            # timing checks in this file do. The bound is still well above the
            # ~SHORT_TIMEOUT / 2 an unthrottled server would take.
            self.assertGreater(elapsed_time, test_constants.SHORT_TIMEOUT)
        finally:
            # Clean-up, even when an assertion above fails.
            await channel.close()
            await server.stop(0)
508
509
if __name__ == '__main__':
    # When run directly, log everything and list each test as it runs.
    logging.basicConfig(level='DEBUG')
    unittest.main(verbosity=2)
513