1# Copyright 2019 The gRPC Authors. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15import asyncio 16import gc 17import logging 18import socket 19import time 20import unittest 21 22import grpc 23from grpc.experimental import aio 24 25from tests.unit import resources 26from tests.unit.framework.common import test_constants 27from tests_aio.unit._test_base import AioTestBase 28 29_SIMPLE_UNARY_UNARY = '/test/SimpleUnaryUnary' 30_BLOCK_FOREVER = '/test/BlockForever' 31_BLOCK_BRIEFLY = '/test/BlockBriefly' 32_UNARY_STREAM_ASYNC_GEN = '/test/UnaryStreamAsyncGen' 33_UNARY_STREAM_READER_WRITER = '/test/UnaryStreamReaderWriter' 34_UNARY_STREAM_EVILLY_MIXED = '/test/UnaryStreamEvillyMixed' 35_STREAM_UNARY_ASYNC_GEN = '/test/StreamUnaryAsyncGen' 36_STREAM_UNARY_READER_WRITER = '/test/StreamUnaryReaderWriter' 37_STREAM_UNARY_EVILLY_MIXED = '/test/StreamUnaryEvillyMixed' 38_STREAM_STREAM_ASYNC_GEN = '/test/StreamStreamAsyncGen' 39_STREAM_STREAM_READER_WRITER = '/test/StreamStreamReaderWriter' 40_STREAM_STREAM_EVILLY_MIXED = '/test/StreamStreamEvillyMixed' 41_UNIMPLEMENTED_METHOD = '/test/UnimplementedMethod' 42_ERROR_IN_STREAM_STREAM = '/test/ErrorInStreamStream' 43_ERROR_WITHOUT_RAISE_IN_UNARY_UNARY = '/test/ErrorWithoutRaiseInUnaryUnary' 44_ERROR_WITHOUT_RAISE_IN_STREAM_STREAM = '/test/ErrorWithoutRaiseInStreamStream' 45 46_REQUEST = b'\x00\x00\x00' 47_RESPONSE = b'\x01\x01\x01' 48_NUM_STREAM_REQUESTS = 3 49_NUM_STREAM_RESPONSES = 5 50_MAXIMUM_CONCURRENT_RPCS = 5 51 52 
class _GenericHandler(grpc.GenericRpcHandler):
    """Generic RPC handler that serves every method used by `TestServer`.

    Each entry of the routing table maps a method path to a handler that
    exercises one server-side API flavor: async generator, reader/writer
    (`context.read()` / `context.write()`), or a mix of both.
    """

    def __init__(self):
        # Resolved by `service()` on the first handler lookup, so tests can
        # wait until the server has begun routing an incoming RPC.
        self._called = asyncio.get_event_loop().create_future()
        self._routing_table = {
            _SIMPLE_UNARY_UNARY:
                grpc.unary_unary_rpc_method_handler(self._unary_unary),
            _BLOCK_FOREVER:
                grpc.unary_unary_rpc_method_handler(self._block_forever),
            _BLOCK_BRIEFLY:
                grpc.unary_unary_rpc_method_handler(self._block_briefly),
            _UNARY_STREAM_ASYNC_GEN:
                grpc.unary_stream_rpc_method_handler(
                    self._unary_stream_async_gen),
            _UNARY_STREAM_READER_WRITER:
                grpc.unary_stream_rpc_method_handler(
                    self._unary_stream_reader_writer),
            _UNARY_STREAM_EVILLY_MIXED:
                grpc.unary_stream_rpc_method_handler(
                    self._unary_stream_evilly_mixed),
            _STREAM_UNARY_ASYNC_GEN:
                grpc.stream_unary_rpc_method_handler(
                    self._stream_unary_async_gen),
            _STREAM_UNARY_READER_WRITER:
                grpc.stream_unary_rpc_method_handler(
                    self._stream_unary_reader_writer),
            _STREAM_UNARY_EVILLY_MIXED:
                grpc.stream_unary_rpc_method_handler(
                    self._stream_unary_evilly_mixed),
            _STREAM_STREAM_ASYNC_GEN:
                grpc.stream_stream_rpc_method_handler(
                    self._stream_stream_async_gen),
            _STREAM_STREAM_READER_WRITER:
                grpc.stream_stream_rpc_method_handler(
                    self._stream_stream_reader_writer),
            _STREAM_STREAM_EVILLY_MIXED:
                grpc.stream_stream_rpc_method_handler(
                    self._stream_stream_evilly_mixed),
            _ERROR_IN_STREAM_STREAM:
                grpc.stream_stream_rpc_method_handler(
                    self._error_in_stream_stream),
            _ERROR_WITHOUT_RAISE_IN_UNARY_UNARY:
                grpc.unary_unary_rpc_method_handler(
                    self._error_without_raise_in_unary_unary),
            _ERROR_WITHOUT_RAISE_IN_STREAM_STREAM:
                grpc.stream_stream_rpc_method_handler(
                    self._error_without_raise_in_stream_stream),
        }

    @staticmethod
    async def _unary_unary(unused_request, unused_context):
        """Returns the canned response immediately."""
        return _RESPONSE

    async def _block_forever(self, unused_request, unused_context):
        """Awaits a fresh future that this handler never resolves."""
        await asyncio.get_event_loop().create_future()

    async def _block_briefly(self, unused_request, unused_context):
        """Sleeps for half of SHORT_TIMEOUT, then returns the response."""
        await asyncio.sleep(test_constants.SHORT_TIMEOUT / 2)
        return _RESPONSE

    async def _unary_stream_async_gen(self, unused_request, unused_context):
        """Yields _NUM_STREAM_RESPONSES responses (async generator API)."""
        for _ in range(_NUM_STREAM_RESPONSES):
            yield _RESPONSE

    async def _unary_stream_reader_writer(self, unused_request, context):
        """Writes _NUM_STREAM_RESPONSES responses through the context."""
        for _ in range(_NUM_STREAM_RESPONSES):
            await context.write(_RESPONSE)

    async def _unary_stream_evilly_mixed(self, unused_request, context):
        """Sends the first response via yield and the rest via write()."""
        yield _RESPONSE
        for _ in range(_NUM_STREAM_RESPONSES - 1):
            await context.write(_RESPONSE)

    async def _stream_unary_async_gen(self, request_iterator, unused_context):
        """Consumes the request iterator and checks the request count."""
        request_count = 0
        async for request in request_iterator:
            assert _REQUEST == request
            request_count += 1
        assert _NUM_STREAM_REQUESTS == request_count
        return _RESPONSE

    async def _stream_unary_reader_writer(self, unused_request, context):
        """Reads _NUM_STREAM_REQUESTS requests through the context."""
        for _ in range(_NUM_STREAM_REQUESTS):
            assert _REQUEST == await context.read()
        return _RESPONSE

    async def _stream_unary_evilly_mixed(self, request_iterator, context):
        """Reads one request via read() and the rest via the iterator."""
        assert _REQUEST == await context.read()
        request_count = 0
        async for request in request_iterator:
            assert _REQUEST == request
            request_count += 1
        # One request was already consumed through context.read() above.
        assert _NUM_STREAM_REQUESTS - 1 == request_count
        return _RESPONSE

    async def _stream_stream_async_gen(self, request_iterator, unused_context):
        """Drains the request iterator, then yields all responses."""
        request_count = 0
        async for request in request_iterator:
            assert _REQUEST == request
            request_count += 1
        assert _NUM_STREAM_REQUESTS == request_count

        for _ in range(_NUM_STREAM_RESPONSES):
            yield _RESPONSE

    async def _stream_stream_reader_writer(self, unused_request, context):
        """Reads all requests, then writes all responses via the context."""
        for _ in range(_NUM_STREAM_REQUESTS):
            assert _REQUEST == await context.read()
        for _ in range(_NUM_STREAM_RESPONSES):
            await context.write(_RESPONSE)

    async def _stream_stream_evilly_mixed(self, request_iterator, context):
        """Mixes read()/iterator on input and yield/write() on output."""
        assert _REQUEST == await context.read()
        request_count = 0
        async for request in request_iterator:
            assert _REQUEST == request
            request_count += 1
        # One request was already consumed through context.read() above.
        assert _NUM_STREAM_REQUESTS - 1 == request_count

        yield _RESPONSE
        for _ in range(_NUM_STREAM_RESPONSES - 1):
            await context.write(_RESPONSE)

    async def _error_in_stream_stream(self, request_iterator, unused_context):
        """Raises RuntimeError as soon as the first request arrives."""
        async for request in request_iterator:
            assert _REQUEST == request
            raise RuntimeError('A testing RuntimeError!')
        # Only reached if the request stream ends without any request; the
        # `yield` also makes this function an async generator.
        yield _RESPONSE

    async def _error_without_raise_in_unary_unary(self, request, context):
        """Sets INTERNAL on the context and returns without raising."""
        assert _REQUEST == request
        context.set_code(grpc.StatusCode.INTERNAL)

    async def _error_without_raise_in_stream_stream(self, request_iterator,
                                                    context):
        """Drains the requests, then sets INTERNAL without raising."""
        async for request in request_iterator:
            assert _REQUEST == request
        context.set_code(grpc.StatusCode.INTERNAL)

    def service(self, handler_details):
        """Returns the handler for `handler_details.method`, or None.

        The first lookup also resolves the future awaited by
        `wait_for_call()`.
        """
        if not self._called.done():
            self._called.set_result(None)
        return self._routing_table.get(handler_details.method)

    async def wait_for_call(self):
        """Waits until `service()` has been invoked at least once."""
        await self._called


async def _start_test_server():
    """Starts an insecure AsyncIO server with a fresh `_GenericHandler`.

    Returns:
      A tuple of (address string, server, generic handler).
    """
    server = aio.server()
    # Port 0 asks for any free port; the chosen one is returned.
    port = server.add_insecure_port('[::]:0')
    generic_handler = _GenericHandler()
    server.add_generic_rpc_handlers((generic_handler,))
    await server.start()
    return 'localhost:%d' % port, server, generic_handler


class TestServer(AioTestBase):
    """Exercises the AsyncIO server through a local insecure channel."""

    async def setUp(self):
        addr, self._server, self._generic_handler = await _start_test_server()
        self._channel = aio.insecure_channel(addr)

    async def tearDown(self):
        await self._channel.close()
        await self._server.stop(None)

    async def test_unary_unary(self):
        unary_unary_call = self._channel.unary_unary(_SIMPLE_UNARY_UNARY)
        response = await unary_unary_call(_REQUEST)
        self.assertEqual(response, _RESPONSE)

    async def test_unary_stream_async_generator(self):
        unary_stream_call = self._channel.unary_stream(_UNARY_STREAM_ASYNC_GEN)
        call = unary_stream_call(_REQUEST)

        response_cnt = 0
        async for response in call:
            response_cnt += 1
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_unary_stream_reader_writer(self):
        unary_stream_call = self._channel.unary_stream(
            _UNARY_STREAM_READER_WRITER)
        call = unary_stream_call(_REQUEST)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_unary_stream_evilly_mixed(self):
        unary_stream_call = self._channel.unary_stream(
            _UNARY_STREAM_EVILLY_MIXED)
        call = unary_stream_call(_REQUEST)

        # Uses reader API
        self.assertEqual(_RESPONSE, await call.read())

        # Uses async generator API, mixed!
        with self.assertRaises(aio.UsageError):
            async for response in call:
                self.assertEqual(_RESPONSE, response)

    async def test_stream_unary_async_generator(self):
        stream_unary_call = self._channel.stream_unary(_STREAM_UNARY_ASYNC_GEN)
        call = stream_unary_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        response = await call
        self.assertEqual(_RESPONSE, response)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_unary_reader_writer(self):
        stream_unary_call = self._channel.stream_unary(
            _STREAM_UNARY_READER_WRITER)
        call = stream_unary_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        response = await call
        self.assertEqual(_RESPONSE, response)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_unary_evilly_mixed(self):
        stream_unary_call = self._channel.stream_unary(
            _STREAM_UNARY_EVILLY_MIXED)
        call = stream_unary_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        response = await call
        self.assertEqual(_RESPONSE, response)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_stream_async_generator(self):
        stream_stream_call = self._channel.stream_stream(
            _STREAM_STREAM_ASYNC_GEN)
        call = stream_stream_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_stream_reader_writer(self):
        stream_stream_call = self._channel.stream_stream(
            _STREAM_STREAM_READER_WRITER)
        call = stream_stream_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_stream_evilly_mixed(self):
        stream_stream_call = self._channel.stream_stream(
            _STREAM_STREAM_EVILLY_MIXED)
        call = stream_stream_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_shutdown(self):
        await self._server.stop(None)
        # Ensures no SIGSEGV triggered, and ends within timeout.

    async def test_shutdown_after_call(self):
        await self._channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST)

        await self._server.stop(None)

    async def test_graceful_shutdown_success(self):
        call = self._channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
        await self._generic_handler.wait_for_call()

        # Monotonic clock: the measured interval must not be affected by
        # wall-clock adjustments while the test is running.
        shutdown_start_time = time.monotonic()
        await self._server.stop(test_constants.SHORT_TIMEOUT)
        grace_period_length = time.monotonic() - shutdown_start_time
        # The handler sleeps for SHORT_TIMEOUT / 2, so a graceful stop has
        # to take noticeably longer than an immediate one.
        self.assertGreater(grace_period_length,
                           test_constants.SHORT_TIMEOUT / 3)

        # Validates the states.
        self.assertEqual(_RESPONSE, await call)
        self.assertTrue(call.done())

    async def test_graceful_shutdown_failed(self):
        call = self._channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
        await self._generic_handler.wait_for_call()

        await self._server.stop(test_constants.SHORT_TIMEOUT)

        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call
        self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                         exception_context.exception.code())

    async def test_concurrent_graceful_shutdown(self):
        call = self._channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
        await self._generic_handler.wait_for_call()

        # Expects the shortest grace period to be effective.
        shutdown_start_time = time.monotonic()
        await asyncio.gather(
            self._server.stop(test_constants.LONG_TIMEOUT),
            self._server.stop(test_constants.SHORT_TIMEOUT),
            self._server.stop(test_constants.LONG_TIMEOUT),
        )
        grace_period_length = time.monotonic() - shutdown_start_time
        self.assertGreater(grace_period_length,
                           test_constants.SHORT_TIMEOUT / 3)

        self.assertEqual(_RESPONSE, await call)
        self.assertTrue(call.done())

    async def test_concurrent_graceful_shutdown_immediate(self):
        call = self._channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
        await self._generic_handler.wait_for_call()

        # Expects no grace period, due to the "server.stop(None)".
        await asyncio.gather(
            self._server.stop(test_constants.LONG_TIMEOUT),
            self._server.stop(None),
            self._server.stop(test_constants.SHORT_TIMEOUT),
            self._server.stop(test_constants.LONG_TIMEOUT),
        )

        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call
        self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                         exception_context.exception.code())

    async def test_shutdown_before_call(self):
        await self._server.stop(None)

        # Ensures the server is cleaned up at this point.
        # Some proper exception should be raised.
        with self.assertRaises(aio.AioRpcError):
            await self._channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST)

    async def test_unimplemented(self):
        call = self._channel.unary_unary(_UNIMPLEMENTED_METHOD)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call(_REQUEST)
        rpc_error = exception_context.exception
        self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code())

    async def test_shutdown_during_stream_stream(self):
        stream_stream_call = self._channel.stream_stream(
            _STREAM_STREAM_ASYNC_GEN)
        call = stream_stream_call()

        # Don't half close the RPC yet, keep it alive.
        await call.write(_REQUEST)
        await self._server.stop(None)

        self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())
        # No segfault

    async def test_error_in_stream_stream(self):
        stream_stream_call = self._channel.stream_stream(
            _ERROR_IN_STREAM_STREAM)
        call = stream_stream_call()

        # Don't half close the RPC yet, keep it alive.
        await call.write(_REQUEST)

        # Don't segfault here
        self.assertEqual(grpc.StatusCode.UNKNOWN, await call.code())

    async def test_error_without_raise_in_unary_unary(self):
        call = self._channel.unary_unary(_ERROR_WITHOUT_RAISE_IN_UNARY_UNARY)(
            _REQUEST)

        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call

        rpc_error = exception_context.exception
        self.assertEqual(grpc.StatusCode.INTERNAL, rpc_error.code())

    async def test_error_without_raise_in_stream_stream(self):
        call = self._channel.stream_stream(
            _ERROR_WITHOUT_RAISE_IN_STREAM_STREAM)()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        self.assertEqual(grpc.StatusCode.INTERNAL, await call.code())

    async def test_port_binding_exception(self):
        # With SO_REUSEPORT disabled, binding the same address twice is
        # expected to fail.
        server = aio.server(options=(('grpc.so_reuseport', 0),))
        port = server.add_insecure_port('localhost:0')
        bind_address = "localhost:%d" % port

        with self.assertRaises(RuntimeError):
            server.add_insecure_port(bind_address)

        server_credentials = grpc.ssl_server_credentials([
            (resources.private_key(), resources.certificate_chain())
        ])
        with self.assertRaises(RuntimeError):
            server.add_secure_port(bind_address, server_credentials)

    async def test_maximum_concurrent_rpcs(self):
        # Build the server with concurrent rpc argument
        server = aio.server(maximum_concurrent_rpcs=_MAXIMUM_CONCURRENT_RPCS)
        port = server.add_insecure_port('localhost:0')
        bind_address = "localhost:%d" % port
        server.add_generic_rpc_handlers((_GenericHandler(),))
        await server.start()
        # Build the channel
        channel = aio.insecure_channel(bind_address)
        # This test owns its own server and channel (separate from the ones
        # made in setUp), so release them even when an assertion fails.
        try:
            # Deplete the concurrent quota with 3 times of max RPCs
            rpcs = []
            for _ in range(3 * _MAXIMUM_CONCURRENT_RPCS):
                rpcs.append(channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST))
            task = self.loop.create_task(
                asyncio.wait(rpcs, return_when=asyncio.FIRST_EXCEPTION))
            # Each batch took test_constants.SHORT_TIMEOUT /2
            start_time = time.monotonic()
            await task
            elapsed_time = time.monotonic() - start_time
            self.assertGreater(elapsed_time,
                               test_constants.SHORT_TIMEOUT * 3 / 2)
        finally:
            # Clean-up
            await channel.close()
            await server.stop(0)


if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)