# Copyright 2019 The gRPC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import gc
import logging
import time
import unittest

import grpc
from grpc.experimental import aio

from tests.unit.framework.common import test_constants
from tests_aio.unit._test_base import AioTestBase

_SIMPLE_UNARY_UNARY = '/test/SimpleUnaryUnary'
_BLOCK_FOREVER = '/test/BlockForever'
_BLOCK_BRIEFLY = '/test/BlockBriefly'
_UNARY_STREAM_ASYNC_GEN = '/test/UnaryStreamAsyncGen'
_UNARY_STREAM_READER_WRITER = '/test/UnaryStreamReaderWriter'
_UNARY_STREAM_EVILLY_MIXED = '/test/UnaryStreamEvillyMixed'
_STREAM_UNARY_ASYNC_GEN = '/test/StreamUnaryAsyncGen'
_STREAM_UNARY_READER_WRITER = '/test/StreamUnaryReaderWriter'
_STREAM_UNARY_EVILLY_MIXED = '/test/StreamUnaryEvillyMixed'
_STREAM_STREAM_ASYNC_GEN = '/test/StreamStreamAsyncGen'
_STREAM_STREAM_READER_WRITER = '/test/StreamStreamReaderWriter'
_STREAM_STREAM_EVILLY_MIXED = '/test/StreamStreamEvillyMixed'
_UNIMPLEMENTED_METHOD = '/test/UnimplementedMethod'
_ERROR_IN_STREAM_STREAM = '/test/ErrorInStreamStream'
_ERROR_WITHOUT_RAISE_IN_UNARY_UNARY = '/test/ErrorWithoutRaiseInUnaryUnary'
_ERROR_WITHOUT_RAISE_IN_STREAM_STREAM = '/test/ErrorWithoutRaiseInStreamStream'

_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x01\x01\x01'
_NUM_STREAM_REQUESTS = 3
_NUM_STREAM_RESPONSES = 5


class _GenericHandler(grpc.GenericRpcHandler):
    """Routes the test method names above to in-process handler coroutines.

    Also exposes ``wait_for_call()``, which completes once ``service()`` has
    been invoked at least once (i.e. the server has looked up a handler for
    an incoming RPC).
    """

    def __init__(self):
        # Resolved by service() the first time a handler is looked up.
        self._called = asyncio.get_event_loop().create_future()
        self._routing_table = {
            _SIMPLE_UNARY_UNARY:
                grpc.unary_unary_rpc_method_handler(self._unary_unary),
            _BLOCK_FOREVER:
                grpc.unary_unary_rpc_method_handler(self._block_forever),
            _BLOCK_BRIEFLY:
                grpc.unary_unary_rpc_method_handler(self._block_briefly),
            _UNARY_STREAM_ASYNC_GEN:
                grpc.unary_stream_rpc_method_handler(
                    self._unary_stream_async_gen),
            _UNARY_STREAM_READER_WRITER:
                grpc.unary_stream_rpc_method_handler(
                    self._unary_stream_reader_writer),
            _UNARY_STREAM_EVILLY_MIXED:
                grpc.unary_stream_rpc_method_handler(
                    self._unary_stream_evilly_mixed),
            _STREAM_UNARY_ASYNC_GEN:
                grpc.stream_unary_rpc_method_handler(
                    self._stream_unary_async_gen),
            _STREAM_UNARY_READER_WRITER:
                grpc.stream_unary_rpc_method_handler(
                    self._stream_unary_reader_writer),
            _STREAM_UNARY_EVILLY_MIXED:
                grpc.stream_unary_rpc_method_handler(
                    self._stream_unary_evilly_mixed),
            _STREAM_STREAM_ASYNC_GEN:
                grpc.stream_stream_rpc_method_handler(
                    self._stream_stream_async_gen),
            _STREAM_STREAM_READER_WRITER:
                grpc.stream_stream_rpc_method_handler(
                    self._stream_stream_reader_writer),
            _STREAM_STREAM_EVILLY_MIXED:
                grpc.stream_stream_rpc_method_handler(
                    self._stream_stream_evilly_mixed),
            _ERROR_IN_STREAM_STREAM:
                grpc.stream_stream_rpc_method_handler(
                    self._error_in_stream_stream),
            _ERROR_WITHOUT_RAISE_IN_UNARY_UNARY:
                grpc.unary_unary_rpc_method_handler(
                    self._error_without_raise_in_unary_unary),
            _ERROR_WITHOUT_RAISE_IN_STREAM_STREAM:
                grpc.stream_stream_rpc_method_handler(
                    self._error_without_raise_in_stream_stream),
        }

    @staticmethod
    async def _unary_unary(unused_request, unused_context):
        """Returns the canned response immediately."""
        return _RESPONSE

    async def _block_forever(self, unused_request, unused_context):
        """Awaits a future nobody resolves, so the RPC never completes."""
        await asyncio.get_event_loop().create_future()

    async def _block_briefly(self, unused_request, unused_context):
        """Sleeps for half of SHORT_TIMEOUT, then returns the response."""
        await asyncio.sleep(test_constants.SHORT_TIMEOUT / 2)
        return _RESPONSE

    async def _unary_stream_async_gen(self, unused_request, unused_context):
        """Streams responses using the async-generator API."""
        for _ in range(_NUM_STREAM_RESPONSES):
            yield _RESPONSE

    async def _unary_stream_reader_writer(self, unused_request, context):
        """Streams responses using the context.write() API."""
        for _ in range(_NUM_STREAM_RESPONSES):
            await context.write(_RESPONSE)

    async def _unary_stream_evilly_mixed(self, unused_request, context):
        """Deliberately mixes the async-generator and context.write() APIs."""
        yield _RESPONSE
        for _ in range(_NUM_STREAM_RESPONSES - 1):
            await context.write(_RESPONSE)

    async def _stream_unary_async_gen(self, request_iterator, unused_context):
        """Consumes requests via the iterator, then returns one response."""
        request_count = 0
        async for request in request_iterator:
            assert _REQUEST == request
            request_count += 1
        assert _NUM_STREAM_REQUESTS == request_count
        return _RESPONSE

    async def _stream_unary_reader_writer(self, unused_request_iterator,
                                          context):
        """Consumes requests via context.read(), then returns one response."""
        for _ in range(_NUM_STREAM_REQUESTS):
            assert _REQUEST == await context.read()
        return _RESPONSE

    async def _stream_unary_evilly_mixed(self, request_iterator, context):
        """Reads the first request via context.read(), the rest via iterator."""
        assert _REQUEST == await context.read()
        request_count = 0
        async for request in request_iterator:
            assert _REQUEST == request
            request_count += 1
        # One request was already consumed by context.read() above.
        assert _NUM_STREAM_REQUESTS - 1 == request_count
        return _RESPONSE

    async def _stream_stream_async_gen(self, request_iterator, unused_context):
        """Drains all requests, then yields the responses."""
        request_count = 0
        async for request in request_iterator:
            assert _REQUEST == request
            request_count += 1
        assert _NUM_STREAM_REQUESTS == request_count

        for _ in range(_NUM_STREAM_RESPONSES):
            yield _RESPONSE

    async def _stream_stream_reader_writer(self, unused_request_iterator,
                                           context):
        """Reads all requests with context.read(), writes with write()."""
        for _ in range(_NUM_STREAM_REQUESTS):
            assert _REQUEST == await context.read()
        for _ in range(_NUM_STREAM_RESPONSES):
            await context.write(_RESPONSE)

    async def _stream_stream_evilly_mixed(self, request_iterator, context):
        """Mixes read()/iterator on input and yield/write() on output."""
        assert _REQUEST == await context.read()
        request_count = 0
        async for request in request_iterator:
            assert _REQUEST == request
            request_count += 1
        # One request was already consumed by context.read() above.
        assert _NUM_STREAM_REQUESTS - 1 == request_count

        yield _RESPONSE
        for _ in range(_NUM_STREAM_RESPONSES - 1):
            await context.write(_RESPONSE)

    async def _error_in_stream_stream(self, request_iterator, unused_context):
        """Raises RuntimeError upon receiving the first request."""
        async for request in request_iterator:
            assert _REQUEST == request
            raise RuntimeError('A testing RuntimeError!')
        # Unreachable, but its presence makes this an async generator, which
        # is the shape a streaming-response handler is expected to have.
        yield _RESPONSE

    async def _error_without_raise_in_unary_unary(self, request, context):
        """Sets a non-OK status without raising or returning a response."""
        assert _REQUEST == request
        context.set_code(grpc.StatusCode.INTERNAL)

    async def _error_without_raise_in_stream_stream(self, request_iterator,
                                                    context):
        """Sets a non-OK status after draining requests, without raising."""
        async for request in request_iterator:
            assert _REQUEST == request
        context.set_code(grpc.StatusCode.INTERNAL)

    def service(self, handler_details):
        """Returns the handler for ``handler_details.method``, or None.

        Also signals ``wait_for_call()``.
        """
        # service() may be invoked once per incoming RPC; calling set_result()
        # on an already-completed future would raise InvalidStateError, so
        # only the first call resolves it.
        if not self._called.done():
            self._called.set_result(None)
        return self._routing_table.get(handler_details.method)

    async def wait_for_call(self):
        """Waits until service() has been invoked at least once."""
        await self._called


async def _start_test_server():
    """Starts an aio server on an OS-assigned port with a fresh handler.

    Returns:
      A tuple ``(address, server, generic_handler)`` where ``address`` is
      ``'localhost:<port>'``.
    """
    server = aio.server()
    port = server.add_insecure_port('[::]:0')
    generic_handler = _GenericHandler()
    server.add_generic_rpc_handlers((generic_handler,))
    await server.start()
    return 'localhost:%d' % port, server, generic_handler


class TestServer(AioTestBase):
    """Exercises the aio server through an insecure channel to localhost."""

    async def setUp(self):
        addr, self._server, self._generic_handler = await _start_test_server()
        self._channel = aio.insecure_channel(addr)

    async def tearDown(self):
        await self._channel.close()
        await self._server.stop(None)

    async def test_unary_unary(self):
        unary_unary_call = self._channel.unary_unary(_SIMPLE_UNARY_UNARY)
        response = await unary_unary_call(_REQUEST)
        self.assertEqual(response, _RESPONSE)

    async def test_unary_stream_async_generator(self):
        unary_stream_call = self._channel.unary_stream(_UNARY_STREAM_ASYNC_GEN)
        call = unary_stream_call(_REQUEST)

        response_cnt = 0
        async for response in call:
            response_cnt += 1
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_unary_stream_reader_writer(self):
        unary_stream_call = self._channel.unary_stream(
            _UNARY_STREAM_READER_WRITER)
        call = unary_stream_call(_REQUEST)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_unary_stream_evilly_mixed(self):
        unary_stream_call = self._channel.unary_stream(
            _UNARY_STREAM_EVILLY_MIXED)
        call = unary_stream_call(_REQUEST)

        # Uses reader API
        self.assertEqual(_RESPONSE, await call.read())

        # Uses async generator API, mixed!
        with self.assertRaises(aio.UsageError):
            async for response in call:
                self.assertEqual(_RESPONSE, response)

    async def test_stream_unary_async_generator(self):
        stream_unary_call = self._channel.stream_unary(_STREAM_UNARY_ASYNC_GEN)
        call = stream_unary_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        response = await call
        self.assertEqual(_RESPONSE, response)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_unary_reader_writer(self):
        stream_unary_call = self._channel.stream_unary(
            _STREAM_UNARY_READER_WRITER)
        call = stream_unary_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        response = await call
        self.assertEqual(_RESPONSE, response)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_unary_evilly_mixed(self):
        stream_unary_call = self._channel.stream_unary(
            _STREAM_UNARY_EVILLY_MIXED)
        call = stream_unary_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        response = await call
        self.assertEqual(_RESPONSE, response)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_stream_async_generator(self):
        stream_stream_call = self._channel.stream_stream(
            _STREAM_STREAM_ASYNC_GEN)
        call = stream_stream_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_stream_reader_writer(self):
        stream_stream_call = self._channel.stream_stream(
            _STREAM_STREAM_READER_WRITER)
        call = stream_stream_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_stream_stream_evilly_mixed(self):
        stream_stream_call = self._channel.stream_stream(
            _STREAM_STREAM_EVILLY_MIXED)
        call = stream_stream_call()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE, response)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_shutdown(self):
        await self._server.stop(None)
        # Ensures no SIGSEGV triggered, and ends within timeout.

    async def test_shutdown_after_call(self):
        await self._channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST)

        await self._server.stop(None)

    async def test_graceful_shutdown_success(self):
        call = self._channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
        await self._generic_handler.wait_for_call()

        shutdown_start_time = time.time()
        await self._server.stop(test_constants.SHORT_TIMEOUT)
        grace_period_length = time.time() - shutdown_start_time
        self.assertGreater(grace_period_length,
                           test_constants.SHORT_TIMEOUT / 3)

        # Validates the states.
        self.assertEqual(_RESPONSE, await call)
        self.assertTrue(call.done())

    async def test_graceful_shutdown_failed(self):
        call = self._channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
        await self._generic_handler.wait_for_call()

        await self._server.stop(test_constants.SHORT_TIMEOUT)

        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call
        self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                         exception_context.exception.code())

    async def test_concurrent_graceful_shutdown(self):
        call = self._channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)
        await self._generic_handler.wait_for_call()

        # Expects the shortest grace period to be effective.
        shutdown_start_time = time.time()
        await asyncio.gather(
            self._server.stop(test_constants.LONG_TIMEOUT),
            self._server.stop(test_constants.SHORT_TIMEOUT),
            self._server.stop(test_constants.LONG_TIMEOUT),
        )
        grace_period_length = time.time() - shutdown_start_time
        self.assertGreater(grace_period_length,
                           test_constants.SHORT_TIMEOUT / 3)

        self.assertEqual(_RESPONSE, await call)
        self.assertTrue(call.done())

    async def test_concurrent_graceful_shutdown_immediate(self):
        call = self._channel.unary_unary(_BLOCK_FOREVER)(_REQUEST)
        await self._generic_handler.wait_for_call()

        # Expects no grace period, due to the "server.stop(None)".
        await asyncio.gather(
            self._server.stop(test_constants.LONG_TIMEOUT),
            self._server.stop(None),
            self._server.stop(test_constants.SHORT_TIMEOUT),
            self._server.stop(test_constants.LONG_TIMEOUT),
        )

        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call
        self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                         exception_context.exception.code())

    async def test_shutdown_before_call(self):
        await self._server.stop(None)

        # Ensures the server is cleaned up at this point.
        # Some proper exception should be raised.
        with self.assertRaises(aio.AioRpcError):
            await self._channel.unary_unary(_SIMPLE_UNARY_UNARY)(_REQUEST)

    async def test_unimplemented(self):
        call = self._channel.unary_unary(_UNIMPLEMENTED_METHOD)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call(_REQUEST)
        rpc_error = exception_context.exception
        self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code())

    async def test_shutdown_during_stream_stream(self):
        stream_stream_call = self._channel.stream_stream(
            _STREAM_STREAM_ASYNC_GEN)
        call = stream_stream_call()

        # Don't half close the RPC yet, keep it alive.
        await call.write(_REQUEST)
        await self._server.stop(None)

        self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())
        # No segfault

    async def test_error_in_stream_stream(self):
        stream_stream_call = self._channel.stream_stream(
            _ERROR_IN_STREAM_STREAM)
        call = stream_stream_call()

        # Don't half close the RPC yet, keep it alive.
        await call.write(_REQUEST)

        # Don't segfault here
        self.assertEqual(grpc.StatusCode.UNKNOWN, await call.code())

    async def test_error_without_raise_in_unary_unary(self):
        call = self._channel.unary_unary(_ERROR_WITHOUT_RAISE_IN_UNARY_UNARY)(
            _REQUEST)

        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call

        rpc_error = exception_context.exception
        self.assertEqual(grpc.StatusCode.INTERNAL, rpc_error.code())

    async def test_error_without_raise_in_stream_stream(self):
        call = self._channel.stream_stream(
            _ERROR_WITHOUT_RAISE_IN_STREAM_STREAM)()

        for _ in range(_NUM_STREAM_REQUESTS):
            await call.write(_REQUEST)
        await call.done_writing()

        self.assertEqual(grpc.StatusCode.INTERNAL, await call.code())


if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)