# Copyright 2019 The gRPC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the AsyncIO gRPC server driven through a generic RPC handler."""

import asyncio
import logging
import time
import unittest

import grpc
from grpc.experimental import aio

from tests.unit import resources
from tests.unit.framework.common import test_constants
from tests_aio.unit._test_base import AioTestBase

# Fully-qualified method names served (or deliberately not served) by
# _GenericHandler.
_SIMPLE_UNARY_UNARY = "/test/SimpleUnaryUnary"
_BLOCK_FOREVER = "/test/BlockForever"
_BLOCK_BRIEFLY = "/test/BlockBriefly"
_UNARY_STREAM_ASYNC_GEN = "/test/UnaryStreamAsyncGen"
_UNARY_STREAM_READER_WRITER = "/test/UnaryStreamReaderWriter"
_UNARY_STREAM_EVILLY_MIXED = "/test/UnaryStreamEvillyMixed"
_STREAM_UNARY_ASYNC_GEN = "/test/StreamUnaryAsyncGen"
_STREAM_UNARY_READER_WRITER = "/test/StreamUnaryReaderWriter"
_STREAM_UNARY_EVILLY_MIXED = "/test/StreamUnaryEvillyMixed"
_STREAM_STREAM_ASYNC_GEN = "/test/StreamStreamAsyncGen"
_STREAM_STREAM_READER_WRITER = "/test/StreamStreamReaderWriter"
_STREAM_STREAM_EVILLY_MIXED = "/test/StreamStreamEvillyMixed"
# Intentionally absent from the routing table.
_UNIMPLEMENTED_METHOD = "/test/UnimplementedMethod"
_ERROR_IN_STREAM_STREAM = "/test/ErrorInStreamStream"
_ERROR_IN_STREAM_UNARY = "/test/ErrorInStreamUnary"
_ERROR_WITHOUT_RAISE_IN_UNARY_UNARY = "/test/ErrorWithoutRaiseInUnaryUnary"
_ERROR_WITHOUT_RAISE_IN_STREAM_STREAM = "/test/ErrorWithoutRaiseInStreamStream"
_INVALID_TRAILING_METADATA = "/test/InvalidTrailingMetadata"

_REQUEST = b"\x00\x00\x00"
_RESPONSE = b"\x01\x01\x01"
_NUM_STREAM_REQUESTS = 3
_NUM_STREAM_RESPONSES = 5
_MAXIMUM_CONCURRENT_RPCS = 5


class _GenericHandler(grpc.GenericRpcHandler):
    """Routes the test method names above to the handler coroutines below.

    The first call to `service` resolves an internal future, which lets a test
    wait (via `wait_for_call`) until the server has looked up a handler.
    """

    def __init__(self):
        # Resolved by `service` on its first invocation.
        self._called = asyncio.get_event_loop().create_future()
        self._routing_table = {
            _SIMPLE_UNARY_UNARY: grpc.unary_unary_rpc_method_handler(
                self._unary_unary
            ),
            _BLOCK_FOREVER: grpc.unary_unary_rpc_method_handler(
                self._block_forever
            ),
            _BLOCK_BRIEFLY: grpc.unary_unary_rpc_method_handler(
                self._block_briefly
            ),
            _UNARY_STREAM_ASYNC_GEN: grpc.unary_stream_rpc_method_handler(
                self._unary_stream_async_gen
            ),
            _UNARY_STREAM_READER_WRITER: grpc.unary_stream_rpc_method_handler(
                self._unary_stream_reader_writer
            ),
            _UNARY_STREAM_EVILLY_MIXED: grpc.unary_stream_rpc_method_handler(
                self._unary_stream_evilly_mixed
            ),
            _STREAM_UNARY_ASYNC_GEN: grpc.stream_unary_rpc_method_handler(
                self._stream_unary_async_gen
            ),
            _STREAM_UNARY_READER_WRITER: grpc.stream_unary_rpc_method_handler(
                self._stream_unary_reader_writer
            ),
            _STREAM_UNARY_EVILLY_MIXED: grpc.stream_unary_rpc_method_handler(
                self._stream_unary_evilly_mixed
            ),
            _STREAM_STREAM_ASYNC_GEN: grpc.stream_stream_rpc_method_handler(
                self._stream_stream_async_gen
            ),
            _STREAM_STREAM_READER_WRITER: grpc.stream_stream_rpc_method_handler(
                self._stream_stream_reader_writer
            ),
            _STREAM_STREAM_EVILLY_MIXED: grpc.stream_stream_rpc_method_handler(
                self._stream_stream_evilly_mixed
            ),
            _ERROR_IN_STREAM_STREAM: grpc.stream_stream_rpc_method_handler(
                self._error_in_stream_stream
            ),
            _ERROR_IN_STREAM_UNARY: grpc.stream_unary_rpc_method_handler(
                self._value_error_in_stream_unary
            ),
            _ERROR_WITHOUT_RAISE_IN_UNARY_UNARY: grpc.unary_unary_rpc_method_handler(
                self._error_without_raise_in_unary_unary
            ),
            _ERROR_WITHOUT_RAISE_IN_STREAM_STREAM: grpc.stream_stream_rpc_method_handler(
                self._error_without_raise_in_stream_stream
            ),
            _INVALID_TRAILING_METADATA: grpc.unary_unary_rpc_method_handler(
                self._invalid_trailing_metadata
            ),
        }

    @staticmethod
    async def _unary_unary(unused_request, unused_context):
        """Returns the canned response immediately."""
        return _RESPONSE

    async def _block_forever(self, unused_request, unused_context):
        """Awaits a fresh future that nothing in this handler resolves."""
        await asyncio.get_event_loop().create_future()

    async def _block_briefly(self, unused_request, unused_context):
        """Sleeps for half of SHORT_TIMEOUT, then returns the response."""
        await asyncio.sleep(test_constants.SHORT_TIMEOUT / 2)
        return _RESPONSE

    async def _unary_stream_async_gen(self, unused_request, unused_context):
        """Yields _NUM_STREAM_RESPONSES responses (async generator API)."""
        for _ in range(_NUM_STREAM_RESPONSES):
            yield _RESPONSE

    async def _unary_stream_reader_writer(self, unused_request, context):
        """Writes _NUM_STREAM_RESPONSES responses through the context."""
        for _ in range(_NUM_STREAM_RESPONSES):
            await context.write(_RESPONSE)

    async def _unary_stream_evilly_mixed(self, unused_request, context):
        """Yields one response, then writes the rest through the context."""
        yield _RESPONSE
        for _ in range(_NUM_STREAM_RESPONSES - 1):
            await context.write(_RESPONSE)

    async def _stream_unary_async_gen(self, request_iterator, unused_context):
        """Consumes the request iterator and checks the request count."""
        request_count = 0
        async for request in request_iterator:
            assert _REQUEST == request
            request_count += 1
        assert _NUM_STREAM_REQUESTS == request_count
        return _RESPONSE

    async def _stream_unary_reader_writer(self, unused_request, context):
        """Reads _NUM_STREAM_REQUESTS requests through the context."""
        for _ in range(_NUM_STREAM_REQUESTS):
            assert _REQUEST == await context.read()
        return _RESPONSE

    async def _stream_unary_evilly_mixed(self, request_iterator, context):
        """Reads one request via the context, the rest via the iterator."""
        assert _REQUEST == await context.read()
        request_count = 0
        async for request in request_iterator:
            assert _REQUEST == request
            request_count += 1
        assert _NUM_STREAM_REQUESTS - 1 == request_count
        return _RESPONSE

    async def _stream_stream_async_gen(self, request_iterator, unused_context):
        """Drains the request iterator, then yields all responses."""
        request_count = 0
        async for request in request_iterator:
            assert _REQUEST == request
            request_count += 1
        assert _NUM_STREAM_REQUESTS == request_count

        for _ in range(_NUM_STREAM_RESPONSES):
            yield _RESPONSE

    async def _stream_stream_reader_writer(self, unused_request, context):
        """Reads all requests, then writes all responses via the context."""
        for _ in range(_NUM_STREAM_REQUESTS):
            assert _REQUEST == await context.read()
        for _ in range(_NUM_STREAM_RESPONSES):
            await context.write(_RESPONSE)

    async def _stream_stream_evilly_mixed(self, request_iterator, context):
        """Mixes context read/write with iterator consumption and yield."""
        assert _REQUEST == await context.read()
        request_count = 0
        async for request in request_iterator:
            assert _REQUEST == request
            request_count += 1
        assert _NUM_STREAM_REQUESTS - 1 == request_count

        yield _RESPONSE
        for _ in range(_NUM_STREAM_RESPONSES - 1):
            await context.write(_RESPONSE)

    async def _error_in_stream_stream(self, request_iterator, unused_context):
        """Raises RuntimeError upon receiving the first request."""
        async for request in request_iterator:
            assert _REQUEST == request
            raise RuntimeError("A testing RuntimeError!")
        # Only reached if no request arrives; the `yield` is also what makes
        # this function an async generator.
        yield _RESPONSE

    async def _value_error_in_stream_unary(self, request_iterator, context):
        """Raises ValueError upon receiving the first request."""
        request_count = 0
        async for request in request_iterator:
            assert _REQUEST == request
            request_count += 1
            if request_count >= 1:
                raise ValueError("A testing ValueError!")

    async def _error_without_raise_in_unary_unary(self, request, context):
        """Sets status INTERNAL on the context without raising."""
        assert _REQUEST == request
        context.set_code(grpc.StatusCode.INTERNAL)

    async def _error_without_raise_in_stream_stream(
        self, request_iterator, context
    ):
        """Drains the requests, then sets status INTERNAL without raising."""
        async for request in request_iterator:
            assert _REQUEST == request
        context.set_code(grpc.StatusCode.INTERNAL)

    async def _invalid_trailing_metadata(self, request, context):
        """Checks that malformed trailing metadata is rejected.

        Each malformed value passed to `set_trailing_metadata` must raise
        TypeError; otherwise a ValueError is raised. Finally `abort` is called
        with a dict as trailing metadata.
        """
        assert _REQUEST == request

        for invalid_metadata in [
            42,
            {},
            {"error": "error"},
            [{"error": "error"}],
        ]:
            try:
                context.set_trailing_metadata(invalid_metadata)
            except TypeError:
                # Expected: the malformed value was rejected.
                pass
            else:
                raise ValueError(
                    "No TypeError raised for invalid metadata:"
                    f" {invalid_metadata}"
                )

        await context.abort(
            grpc.StatusCode.DATA_LOSS,
            details="invalid abort",
            trailing_metadata=({"error": ("error1", "error2")}),
        )

    def service(self, handler_details):
        """Returns the handler for the requested method, or None if unknown.

        Also resolves the `_called` future the first time it is invoked.
        """
        if not self._called.done():
            self._called.set_result(None)
        return self._routing_table.get(handler_details.method)

    async def wait_for_call(self):
        """Waits until `service` has been invoked at least once."""
        await self._called


async def _start_test_server():
    """Starts an insecure server on an ephemeral port.

    Returns:
      A tuple of (address string, server, generic handler).
    """
    server = aio.server()
    port = server.add_insecure_port("[::]:0")
    generic_handler = _GenericHandler()
    server.add_generic_rpc_handlers((generic_handler,))
    await server.start()
    return f"localhost:{port}", server, generic_handler


class TestServer(AioTestBase):
    """Exercises the AsyncIO server with a fresh server/channel per test."""

    async def setUp(self):
        addr, self._server, self._generic_handler = await _start_test_server()
        self._channel = aio.insecure_channel(addr)

    async def tearDown(self):
        await self._channel.close()
        await self._server.stop(None)

    async def test_unary_unary(self):
        unary_unary_call = self._channel.unary_unary(_SIMPLE_UNARY_UNARY)
        response = await unary_unary_call(_REQUEST)
        self.assertEqual(response, _RESPONSE)

    async def test_unary_stream_async_generator(self):
        unary_stream_call = self._channel.unary_stream(_UNARY_STREAM_ASYNC_GEN)
        call = unary_stream_call(_REQUEST)

        response_cnt = 0
        async for response in call:
            response_cnt += 1
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_unary_stream_reader_writer(self):
        unary_stream_call = self._channel.unary_stream(
            _UNARY_STREAM_READER_WRITER
        )
        call = unary_stream_call(_REQUEST)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_unary_stream_evilly_mixed(self):
        unary_stream_call = self._channel.unary_stream(
            _UNARY_STREAM_EVILLY_MIXED
        )
        call = unary_stream_call(_REQUEST)

        # Uses reader API
        self.assertEqual(_RESPONSE, await call.read())

        # Uses async generator API, mixed!
        with self.assertRaises(aio.UsageError):
            async for response in call:
                self.assertEqual(_RESPONSE, response)

    async def test_stream_unary_async_generator(self):
        stream_unary_call = self._channel.stream_unary(_STREAM_UNARY_ASYNC_GEN)
        call = stream_unary_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        response = await call
        self.assertEqual(_RESPONSE, response)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_unary_async_generator_with_request_iter(self):
        stream_unary_call = self._channel.stream_unary(_STREAM_UNARY_ASYNC_GEN)

        finished = False

        def request_gen():
            nonlocal finished
            for _ in range(_NUM_STREAM_REQUESTS):
                yield _REQUEST
            # Set only once the generator has been fully consumed.
            finished = True

        call = stream_unary_call(request_gen())

        response = await call
        self.assertEqual(_RESPONSE, response)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)
        self.assertTrue(finished)

    async def test_stream_unary_reader_writer(self):
        stream_unary_call = self._channel.stream_unary(
            _STREAM_UNARY_READER_WRITER
        )
        call = stream_unary_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        response = await call
        self.assertEqual(_RESPONSE, response)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_unary_evilly_mixed(self):
        stream_unary_call = self._channel.stream_unary(
            _STREAM_UNARY_EVILLY_MIXED
        )
        call = stream_unary_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        response = await call
        self.assertEqual(_RESPONSE, response)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_stream_async_generator(self):
        stream_stream_call = self._channel.stream_stream(
            _STREAM_STREAM_ASYNC_GEN
        )
        call = stream_stream_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_stream_reader_writer(self):
        stream_stream_call = self._channel.stream_stream(
            _STREAM_STREAM_READER_WRITER
        )
        call = stream_stream_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_stream_evilly_mixed(self):
        stream_stream_call = self._channel.stream_stream(
            _STREAM_STREAM_EVILLY_MIXED
        )
        call = stream_stream_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_shutdown(self):
        await self._server.stop(None)
        # Ensures no SIGSEGV triggered, and ends within timeout.

    async def test_shutdown_after_call(self):
        await self._channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST)

        await self._server.stop(None)

    async def test_graceful_shutdown_success(self):
        call = self._channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
        await self._generic_handler.wait_for_call()

        # The handler sleeps for SHORT_TIMEOUT / 2, so a graceful stop is
        # expected to take longer than SHORT_TIMEOUT / 3. A monotonic clock
        # keeps the measurement immune to wall-clock adjustments.
        shutdown_start_time = time.monotonic()
        await self._server.stop(test_constants.SHORT_TIMEOUT)
        grace_period_length = time.monotonic() - shutdown_start_time
        self.assertGreater(
            grace_period_length, test_constants.SHORT_TIMEOUT / 3
        )

        # Validates the states.
        self.assertEqual(_RESPONSE, await call)
        self.assertTrue(call.done())

    async def test_graceful_shutdown_failed(self):
        call = self._channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
        await self._generic_handler.wait_for_call()

        await self._server.stop(test_constants.SHORT_TIMEOUT)

        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call
        self.assertEqual(
            grpc.StatusCode.UNAVAILABLE, exception_context.exception.code()
        )

    async def test_concurrent_graceful_shutdown(self):
        call = self._channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
        await self._generic_handler.wait_for_call()

        # Expects the shortest grace period to be effective.
        shutdown_start_time = time.monotonic()
        await asyncio.gather(
            self._server.stop(test_constants.LONG_TIMEOUT),
            self._server.stop(test_constants.SHORT_TIMEOUT),
            self._server.stop(test_constants.LONG_TIMEOUT),
        )
        grace_period_length = time.monotonic() - shutdown_start_time
        self.assertGreater(
            grace_period_length, test_constants.SHORT_TIMEOUT / 3
        )

        self.assertEqual(_RESPONSE, await call)
        self.assertTrue(call.done())

    async def test_concurrent_graceful_shutdown_immediate(self):
        call = self._channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
        await self._generic_handler.wait_for_call()

        # Expects no grace period, due to the "server.stop(None)".
        await asyncio.gather(
            self._server.stop(test_constants.LONG_TIMEOUT),
            self._server.stop(None),
            self._server.stop(test_constants.SHORT_TIMEOUT),
            self._server.stop(test_constants.LONG_TIMEOUT),
        )

        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call
        self.assertEqual(
            grpc.StatusCode.UNAVAILABLE, exception_context.exception.code()
        )

    async def test_shutdown_before_call(self):
        await self._server.stop(None)

        # Ensures the server is cleaned up at this point.
        # Some proper exception should be raised.
        with self.assertRaises(aio.AioRpcError):
            await self._channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST)

    async def test_unimplemented(self):
        call = self._channel.unary_unary(_UNIMPLEMENTED_METHOD)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call(_REQUEST)
        rpc_error = exception_context.exception
        self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code())

    async def test_shutdown_during_stream_stream(self):
        stream_stream_call = self._channel.stream_stream(
            _STREAM_STREAM_ASYNC_GEN
        )
        call = stream_stream_call()

        # Don't half close the RPC yet, keep it alive.
        await call.write(_REQUEST)
        await self._server.stop(None)

        self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())
        # No segfault

    async def test_error_in_stream_stream(self):
        stream_stream_call = self._channel.stream_stream(
            _ERROR_IN_STREAM_STREAM
        )
        call = stream_stream_call()

        # Don't half close the RPC yet, keep it alive.
        await call.write(_REQUEST)

        # Don't segfault here
        self.assertEqual(grpc.StatusCode.UNKNOWN, await call.code())

    async def test_error_without_raise_in_unary_unary(self):
        call = self._channel.unary_unary(_ERROR_WITHOUT_RAISE_IN_UNARY_UNARY)(
            _REQUEST
        )

        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call

        rpc_error = exception_context.exception
        self.assertEqual(grpc.StatusCode.INTERNAL, rpc_error.code())

    async def test_error_without_raise_in_stream_stream(self):
        call = self._channel.stream_stream(
            _ERROR_WITHOUT_RAISE_IN_STREAM_STREAM
        )()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        self.assertEqual(grpc.StatusCode.INTERNAL, await call.code())

    async def test_error_in_stream_unary(self):
        stream_unary_call = self._channel.stream_unary(_ERROR_IN_STREAM_UNARY)

        async def request_gen():
            for _ in range(_NUM_STREAM_REQUESTS):
                yield _REQUEST

        call = stream_unary_call(request_gen())

        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call
        rpc_error = exception_context.exception
        self.assertEqual(grpc.StatusCode.UNKNOWN, rpc_error.code())

    async def test_port_binding_exception(self):
        server = aio.server(options=(("grpc.so_reuseport", 0),))
        port = server.add_insecure_port("localhost:0")
        bind_address = f"localhost:{port}"

        # Binding the same address again must fail, insecure or secure.
        with self.assertRaises(RuntimeError):
            server.add_insecure_port(bind_address)

        server_credentials = grpc.ssl_server_credentials(
            [(resources.private_key(), resources.certificate_chain())]
        )
        with self.assertRaises(RuntimeError):
            server.add_secure_port(bind_address, server_credentials)

    async def test_maximum_concurrent_rpcs(self):
        async def coro_wrapper(awaitable):
            return await awaitable

        # Build the server with concurrent rpc argument
        server = aio.server(maximum_concurrent_rpcs=_MAXIMUM_CONCURRENT_RPCS)
        port = server.add_insecure_port("localhost:0")
        bind_address = f"localhost:{port}"
        server.add_generic_rpc_handlers((_GenericHandler(),))
        await server.start()
        # Build the channel
        channel = aio.insecure_channel(bind_address)
        try:
            # Deplete the concurrent quota with 3 times of max RPCs
            rpc_tasks = []
            for _ in range(3 * _MAXIMUM_CONCURRENT_RPCS):
                task = asyncio.create_task(
                    coro_wrapper(channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST))
                )
                rpc_tasks.append(task)
            done, _ = await asyncio.wait(
                rpc_tasks, return_when=asyncio.ALL_COMPLETED
            )

            exceptions = []
            for task in done:
                exception = task.exception()
                if exception:
                    exceptions.append(exception)

            # Check whether the number of tasks raised RESOURCE_EXHAUSTED is
            # roughly 2 * _MAXIMUM_CONCURRENT_RPCS.
            self.assertAlmostEqual(
                len(exceptions),
                2 * _MAXIMUM_CONCURRENT_RPCS,
                delta=_MAXIMUM_CONCURRENT_RPCS,
            )
            for exception in exceptions:
                self.assertIsInstance(exception, aio.AioRpcError)
                self.assertEqual(
                    grpc.StatusCode.RESOURCE_EXHAUSTED, exception.code()
                )
                self.assertIn(
                    "Concurrent RPC limit exceeded", exception.details()
                )
        finally:
            # Clean-up, performed even when an assertion above fails so the
            # locally created channel and server do not outlive the test.
            await channel.close()
            await server.stop(0)

    async def test_invalid_trailing_metadata(self):
        call = self._channel.unary_unary(_INVALID_TRAILING_METADATA)(_REQUEST)

        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call

        rpc_error = exception_context.exception
        self.assertEqual(grpc.StatusCode.UNKNOWN, rpc_error.code())
        self.assertIn("trailing", rpc_error.details())


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)