• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The gRPC Authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import asyncio
16import gc
17import logging
18import time
19import unittest
20
21import grpc
22from grpc.experimental import aio
23
24from tests.unit.framework.common import test_constants
25from tests_aio.unit._test_base import AioTestBase
26
# Fully-qualified method names routed by _GenericHandler, grouped by RPC shape.

# Unary request -> unary response.
_SIMPLE_UNARY_UNARY = '/test/SimpleUnaryUnary'
_BLOCK_FOREVER = '/test/BlockForever'
_BLOCK_BRIEFLY = '/test/BlockBriefly'
_ERROR_WITHOUT_RAISE_IN_UNARY_UNARY = '/test/ErrorWithoutRaiseInUnaryUnary'

# Unary request -> response stream.
_UNARY_STREAM_ASYNC_GEN = '/test/UnaryStreamAsyncGen'
_UNARY_STREAM_READER_WRITER = '/test/UnaryStreamReaderWriter'
_UNARY_STREAM_EVILLY_MIXED = '/test/UnaryStreamEvillyMixed'

# Request stream -> unary response.
_STREAM_UNARY_ASYNC_GEN = '/test/StreamUnaryAsyncGen'
_STREAM_UNARY_READER_WRITER = '/test/StreamUnaryReaderWriter'
_STREAM_UNARY_EVILLY_MIXED = '/test/StreamUnaryEvillyMixed'

# Request stream -> response stream.
_STREAM_STREAM_ASYNC_GEN = '/test/StreamStreamAsyncGen'
_STREAM_STREAM_READER_WRITER = '/test/StreamStreamReaderWriter'
_STREAM_STREAM_EVILLY_MIXED = '/test/StreamStreamEvillyMixed'
_ERROR_IN_STREAM_STREAM = '/test/ErrorInStreamStream'
_ERROR_WITHOUT_RAISE_IN_STREAM_STREAM = '/test/ErrorWithoutRaiseInStreamStream'

# Deliberately absent from the handler's routing table.
_UNIMPLEMENTED_METHOD = '/test/UnimplementedMethod'

# Payloads and message counts shared by client and server sides.
_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x01\x01\x01'
_NUM_STREAM_REQUESTS = 3
_NUM_STREAM_RESPONSES = 5
48
49
class _GenericHandler(grpc.GenericRpcHandler):
    """Generic RPC handler that routes the test method names to async handlers.

    It also exposes ``wait_for_call()`` so a test can block until the server
    has looked up a handler for at least one incoming RPC.
    """

    def __init__(self):
        # Resolved by ``service`` on the first method lookup.
        self._called = asyncio.get_event_loop().create_future()
        self._routing_table = {
            _SIMPLE_UNARY_UNARY:
                grpc.unary_unary_rpc_method_handler(self._unary_unary),
            _BLOCK_FOREVER:
                grpc.unary_unary_rpc_method_handler(self._block_forever),
            _BLOCK_BRIEFLY:
                grpc.unary_unary_rpc_method_handler(self._block_briefly),
            _UNARY_STREAM_ASYNC_GEN:
                grpc.unary_stream_rpc_method_handler(
                    self._unary_stream_async_gen),
            _UNARY_STREAM_READER_WRITER:
                grpc.unary_stream_rpc_method_handler(
                    self._unary_stream_reader_writer),
            _UNARY_STREAM_EVILLY_MIXED:
                grpc.unary_stream_rpc_method_handler(
                    self._unary_stream_evilly_mixed),
            _STREAM_UNARY_ASYNC_GEN:
                grpc.stream_unary_rpc_method_handler(
                    self._stream_unary_async_gen),
            _STREAM_UNARY_READER_WRITER:
                grpc.stream_unary_rpc_method_handler(
                    self._stream_unary_reader_writer),
            _STREAM_UNARY_EVILLY_MIXED:
                grpc.stream_unary_rpc_method_handler(
                    self._stream_unary_evilly_mixed),
            _STREAM_STREAM_ASYNC_GEN:
                grpc.stream_stream_rpc_method_handler(
                    self._stream_stream_async_gen),
            _STREAM_STREAM_READER_WRITER:
                grpc.stream_stream_rpc_method_handler(
                    self._stream_stream_reader_writer),
            _STREAM_STREAM_EVILLY_MIXED:
                grpc.stream_stream_rpc_method_handler(
                    self._stream_stream_evilly_mixed),
            _ERROR_IN_STREAM_STREAM:
                grpc.stream_stream_rpc_method_handler(
                    self._error_in_stream_stream),
            _ERROR_WITHOUT_RAISE_IN_UNARY_UNARY:
                grpc.unary_unary_rpc_method_handler(
                    self._error_without_raise_in_unary_unary),
            _ERROR_WITHOUT_RAISE_IN_STREAM_STREAM:
                grpc.stream_stream_rpc_method_handler(
                    self._error_without_raise_in_stream_stream),
        }

    @staticmethod
    async def _unary_unary(unused_request, unused_context):
        """Returns the canned response immediately."""
        return _RESPONSE

    async def _block_forever(self, unused_request, unused_context):
        """Awaits a future nobody resolves, so the RPC never completes."""
        await asyncio.get_event_loop().create_future()

    async def _block_briefly(self, unused_request, unused_context):
        """Sleeps for half of SHORT_TIMEOUT, then returns the response."""
        await asyncio.sleep(test_constants.SHORT_TIMEOUT / 2)
        return _RESPONSE

    async def _unary_stream_async_gen(self, unused_request, unused_context):
        """Streams the responses by yielding from an async generator."""
        for _ in range(_NUM_STREAM_RESPONSES):
            yield _RESPONSE

    async def _unary_stream_reader_writer(self, unused_request, context):
        """Streams the responses through ``context.write``."""
        for _ in range(_NUM_STREAM_RESPONSES):
            await context.write(_RESPONSE)

    async def _unary_stream_evilly_mixed(self, unused_request, context):
        """Mixes ``yield`` with ``context.write`` in one handler."""
        yield _RESPONSE
        for _ in range(_NUM_STREAM_RESPONSES - 1):
            await context.write(_RESPONSE)

    async def _stream_unary_async_gen(self, request_iterator, unused_context):
        """Consumes every request via ``async for`` and checks the count."""
        request_count = 0
        async for request in request_iterator:
            assert _REQUEST == request
            request_count += 1
        assert _NUM_STREAM_REQUESTS == request_count
        return _RESPONSE

    async def _stream_unary_reader_writer(self, unused_request_iterator,
                                          context):
        """Consumes exactly the expected requests via ``context.read``."""
        for _ in range(_NUM_STREAM_REQUESTS):
            assert _REQUEST == await context.read()
        return _RESPONSE

    async def _stream_unary_evilly_mixed(self, request_iterator, context):
        """Reads one request with ``context.read``, the rest via iteration."""
        assert _REQUEST == await context.read()
        request_count = 0
        async for request in request_iterator:
            assert _REQUEST == request
            request_count += 1
        # The first request was already consumed by context.read() above.
        assert _NUM_STREAM_REQUESTS - 1 == request_count
        return _RESPONSE

    async def _stream_stream_async_gen(self, request_iterator, unused_context):
        """Drains all requests, then yields the responses."""
        request_count = 0
        async for request in request_iterator:
            assert _REQUEST == request
            request_count += 1
        assert _NUM_STREAM_REQUESTS == request_count

        for _ in range(_NUM_STREAM_RESPONSES):
            yield _RESPONSE

    async def _stream_stream_reader_writer(self, unused_request_iterator,
                                           context):
        """Reads and writes exclusively through the servicer context."""
        for _ in range(_NUM_STREAM_REQUESTS):
            assert _REQUEST == await context.read()
        for _ in range(_NUM_STREAM_RESPONSES):
            await context.write(_RESPONSE)

    async def _stream_stream_evilly_mixed(self, request_iterator, context):
        """Mixes both read styles and both write styles in one handler."""
        assert _REQUEST == await context.read()
        request_count = 0
        async for request in request_iterator:
            assert _REQUEST == request
            request_count += 1
        # The first request was already consumed by context.read() above.
        assert _NUM_STREAM_REQUESTS - 1 == request_count

        yield _RESPONSE
        for _ in range(_NUM_STREAM_RESPONSES - 1):
            await context.write(_RESPONSE)

    async def _error_in_stream_stream(self, request_iterator, unused_context):
        """Raises RuntimeError on the first request received."""
        async for request in request_iterator:
            assert _REQUEST == request
            raise RuntimeError('A testing RuntimeError!')
        # Only reached with no requests; the yield also makes this an async
        # generator, i.e. a streaming-response handler.
        yield _RESPONSE

    async def _error_without_raise_in_unary_unary(self, request, context):
        """Sets INTERNAL on the context and returns without raising."""
        assert _REQUEST == request
        context.set_code(grpc.StatusCode.INTERNAL)

    async def _error_without_raise_in_stream_stream(self, request_iterator,
                                                    context):
        """Drains the requests, then sets INTERNAL without raising."""
        async for request in request_iterator:
            assert _REQUEST == request
        context.set_code(grpc.StatusCode.INTERNAL)

    def service(self, handler_details):
        """Looks up the method handler for ``handler_details.method``.

        Returns None for unknown methods; test_unimplemented expects such
        calls to fail with UNIMPLEMENTED.

        The ``_called`` future is resolved only on the first lookup. Calling
        ``set_result`` on a future that is already done raises
        ``asyncio.InvalidStateError``, which would break every RPC after the
        first on the same server.
        """
        if not self._called.done():
            self._called.set_result(None)
        return self._routing_table.get(handler_details.method)

    async def wait_for_call(self):
        """Waits until ``service`` has been invoked at least once."""
        await self._called
195
196
async def _start_test_server():
    """Starts an insecure aio server on an ephemeral port.

    Returns:
      A tuple ``(address, server, generic_handler)``: the ``localhost:<port>``
      target for the bound port, the started server, and the
      ``_GenericHandler`` registered on it.
    """
    test_server = aio.server()
    bound_port = test_server.add_insecure_port('[::]:0')
    handler = _GenericHandler()
    test_server.add_generic_rpc_handlers((handler,))
    await test_server.start()
    target = 'localhost:%d' % bound_port
    return target, test_server, handler
204
205
class TestServer(AioTestBase):
    """End-to-end tests of the aio server against ``_GenericHandler``.

    Each test gets a fresh server and client channel from ``setUp``.
    ``tearDown`` closes the channel and stops the server with no grace
    period.
    """

    async def setUp(self):
        addr, self._server, self._generic_handler = await _start_test_server()
        self._channel = aio.insecure_channel(addr)

    async def tearDown(self):
        await self._channel.close()
        await self._server.stop(None)

    async def test_unary_unary(self):
        unary_unary_call = self._channel.unary_unary(_SIMPLE_UNARY_UNARY)
        response = await unary_unary_call(_REQUEST)
        self.assertEqual(response, _RESPONSE)

    async def test_unary_stream_async_generator(self):
        unary_stream_call = self._channel.unary_stream(_UNARY_STREAM_ASYNC_GEN)
        call = unary_stream_call(_REQUEST)

        response_cnt = 0
        async for response in call:
            response_cnt += 1
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_unary_stream_reader_writer(self):
        unary_stream_call = self._channel.unary_stream(
            _UNARY_STREAM_READER_WRITER)
        call = unary_stream_call(_REQUEST)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_unary_stream_evilly_mixed(self):
        unary_stream_call = self._channel.unary_stream(
            _UNARY_STREAM_EVILLY_MIXED)
        call = unary_stream_call(_REQUEST)

        # Uses reader API
        self.assertEqual(_RESPONSE, await call.read())

        # Uses async generator API, mixed!
        with self.assertRaises(aio.UsageError):
            async for response in call:
                self.assertEqual(_RESPONSE, response)

    async def test_stream_unary_async_generator(self):
        stream_unary_call = self._channel.stream_unary(_STREAM_UNARY_ASYNC_GEN)
        call = stream_unary_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        response = await call
        self.assertEqual(_RESPONSE, response)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_unary_reader_writer(self):
        stream_unary_call = self._channel.stream_unary(
            _STREAM_UNARY_READER_WRITER)
        call = stream_unary_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        response = await call
        self.assertEqual(_RESPONSE, response)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_unary_evilly_mixed(self):
        stream_unary_call = self._channel.stream_unary(
            _STREAM_UNARY_EVILLY_MIXED)
        call = stream_unary_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        response = await call
        self.assertEqual(_RESPONSE, response)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_stream_async_generator(self):
        stream_stream_call = self._channel.stream_stream(
            _STREAM_STREAM_ASYNC_GEN)
        call = stream_stream_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_stream_reader_writer(self):
        stream_stream_call = self._channel.stream_stream(
            _STREAM_STREAM_READER_WRITER)
        call = stream_stream_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_stream_evilly_mixed(self):
        stream_stream_call = self._channel.stream_stream(
            _STREAM_STREAM_EVILLY_MIXED)
        call = stream_stream_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_shutdown(self):
        await self._server.stop(None)
        # Ensures no SIGSEGV triggered, and ends within timeout.

    async def test_shutdown_after_call(self):
        await self._channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST)

        await self._server.stop(None)

    async def test_graceful_shutdown_success(self):
        call = self._channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
        await self._generic_handler.wait_for_call()

        # A monotonic clock keeps the elapsed-time measurement immune to
        # wall-clock adjustments, which would make this assertion flaky.
        shutdown_start_time = time.monotonic()
        await self._server.stop(test_constants.SHORT_TIMEOUT)
        grace_period_length = time.monotonic() - shutdown_start_time
        self.assertGreater(grace_period_length,
                           test_constants.SHORT_TIMEOUT / 3)

        # Validates the states.
        self.assertEqual(_RESPONSE, await call)
        self.assertTrue(call.done())

    async def test_graceful_shutdown_failed(self):
        call = self._channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
        await self._generic_handler.wait_for_call()

        await self._server.stop(test_constants.SHORT_TIMEOUT)

        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call
        self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                         exception_context.exception.code())

    async def test_concurrent_graceful_shutdown(self):
        call = self._channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
        await self._generic_handler.wait_for_call()

        # Expects the shortest grace period to be effective.
        # Measured with a monotonic clock so wall-clock jumps cannot skew it.
        shutdown_start_time = time.monotonic()
        await asyncio.gather(
            self._server.stop(test_constants.LONG_TIMEOUT),
            self._server.stop(test_constants.SHORT_TIMEOUT),
            self._server.stop(test_constants.LONG_TIMEOUT),
        )
        grace_period_length = time.monotonic() - shutdown_start_time
        self.assertGreater(grace_period_length,
                           test_constants.SHORT_TIMEOUT / 3)

        self.assertEqual(_RESPONSE, await call)
        self.assertTrue(call.done())

    async def test_concurrent_graceful_shutdown_immediate(self):
        call = self._channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
        await self._generic_handler.wait_for_call()

        # Expects no grace period, due to the "server.stop(None)".
        await asyncio.gather(
            self._server.stop(test_constants.LONG_TIMEOUT),
            self._server.stop(None),
            self._server.stop(test_constants.SHORT_TIMEOUT),
            self._server.stop(test_constants.LONG_TIMEOUT),
        )

        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call
        self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                         exception_context.exception.code())

    async def test_shutdown_before_call(self):
        await self._server.stop(None)

        # Ensures the server is cleaned up at this point.
        # Some proper exception should be raised.
        with self.assertRaises(aio.AioRpcError):
            await self._channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST)

    async def test_unimplemented(self):
        call = self._channel.unary_unary(_UNIMPLEMENTED_METHOD)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call(_REQUEST)
        rpc_error = exception_context.exception
        self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code())

    async def test_shutdown_during_stream_stream(self):
        stream_stream_call = self._channel.stream_stream(
            _STREAM_STREAM_ASYNC_GEN)
        call = stream_stream_call()

        # Don't half close the RPC yet, keep it alive.
        await call.write(_REQUEST)
        await self._server.stop(None)

        self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())
        # No segfault

    async def test_error_in_stream_stream(self):
        stream_stream_call = self._channel.stream_stream(
            _ERROR_IN_STREAM_STREAM)
        call = stream_stream_call()

        # Don't half close the RPC yet, keep it alive.
        await call.write(_REQUEST)

        # Don't segfault here
        self.assertEqual(grpc.StatusCode.UNKNOWN, await call.code())

    async def test_error_without_raise_in_unary_unary(self):
        call = self._channel.unary_unary(_ERROR_WITHOUT_RAISE_IN_UNARY_UNARY)(
            _REQUEST)

        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call

        rpc_error = exception_context.exception
        self.assertEqual(grpc.StatusCode.INTERNAL, rpc_error.code())

    async def test_error_without_raise_in_stream_stream(self):
        call = self._channel.stream_stream(
            _ERROR_WITHOUT_RAISE_IN_STREAM_STREAM)()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        self.assertEqual(grpc.StatusCode.INTERNAL, await call.code())
467
if __name__ == '__main__':
    # Configure root logging at DEBUG for the whole run, then let unittest
    # discover and run this module's tests with per-test (verbose) output.
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(argv=None, verbosity=2)
471