# Copyright 2019 The gRPC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests behavior of the Call classes."""

import asyncio
import datetime
import logging
import random
import unittest

import grpc
from grpc.experimental import aio

from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc
from tests_aio.unit._constants import UNREACHABLE_TARGET
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit._test_server import start_test_server

# Base unit for the delays and deadlines used by the tests: one second,
# expressed as a float number of seconds.
_SHORT_TIMEOUT_S = datetime.timedelta(milliseconds=1000).total_seconds()

# Number of messages sent or received by the multi-message streaming tests.
_NUM_STREAM_RESPONSES = 5

# Payload sizes, in bytes.
_RESPONSE_PAYLOAD_SIZE = 42
_REQUEST_PAYLOAD_SIZE = 7

# Status details a call reports after the application cancels it locally.
_LOCAL_CANCEL_DETAILS_EXPECTATION = "Locally cancelled by application!"
# Server-side delay, in microseconds, equal to one short timeout unit.
_RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000)
# Largest signed 32-bit interval; used to keep the server from answering
# within the lifetime of a test.
_INFINITE_INTERVAL_US = 2**31 - 1

# Parameters of the randomized interleaving search performed by
# TestUnaryStreamCall.test_cancel_unary_stream_with_many_interleavings.
_NONDETERMINISTIC_ITERATIONS = 50
_NONDETERMINISTIC_SERVER_SLEEP_MAX_US = 1000


class _MulticallableTestMixin:
    """Fixture giving each test its own test server and TestService stub.

    ``setUp`` starts the server and opens an insecure channel to it;
    ``tearDown`` closes the channel and then stops the server.
    """

    async def setUp(self):
        address, self._server = await start_test_server()
        self._channel = aio.insecure_channel(address)
        self._stub = test_pb2_grpc.TestServiceStub(self._channel)

    async def tearDown(self):
        await self._channel.close()
        await self._server.stop(None)


class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
    """Tests for the call object returned by a unary-unary multicallable."""

    async def test_call_to_string(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        # str() and repr() must work while the RPC is in flight...
        self.assertIsNotNone(str(call))
        self.assertIsNotNone(repr(call))

        await call

        # ...and after it has completed.
        self.assertIsNotNone(str(call))
        self.assertIsNotNone(repr(call))

    async def test_call_ok(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        self.assertFalse(call.done())

        response = await call

        self.assertTrue(call.done())
        self.assertIsInstance(response, messages_pb2.SimpleResponse)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

        # The response is cached on the call object, so awaiting the call
        # again returns the very same response object.
        response_retry = await call
        self.assertIs(response, response_retry)

    async def test_call_rpc_error(self):
        async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            call = stub.UnaryCall(messages_pb2.SimpleRequest())

            with self.assertRaises(aio.AioRpcError) as exception_context:
                await call

            self.assertEqual(
                grpc.StatusCode.UNAVAILABLE, exception_context.exception.code()
            )

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())

    async def test_call_code_awaitable(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_call_details_awaitable(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual("", await call.details())

    async def test_call_initial_metadata_awaitable(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual(aio.Metadata(), await call.initial_metadata())

    async def test_call_trailing_metadata_awaitable(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual(aio.Metadata(), await call.trailing_metadata())

    async def test_call_initial_metadata_cancelable(self):
        coro_started = asyncio.Event()
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            coro_started.set()
            await call.initial_metadata()

        task = self.loop.create_task(coro())
        await coro_started.wait()
        task.cancel()

        # Initial metadata can still be requested even though a previous
        # task waiting on it was cancelled.
        self.assertEqual(aio.Metadata(), await call.initial_metadata())

    async def test_call_initial_metadata_multiple_waiters(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            return await call.initial_metadata()

        task1 = self.loop.create_task(coro())
        task2 = self.loop.create_task(coro())

        await call
        expected = [aio.Metadata() for _ in range(2)]
        self.assertEqual(expected, await asyncio.gather(task1, task2))

    async def test_call_code_cancelable(self):
        coro_started = asyncio.Event()
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            coro_started.set()
            await call.code()

        task = self.loop.create_task(coro())
        await coro_started.wait()
        task.cancel()

        # The status code can still be requested even though a previous
        # task waiting on it was cancelled.
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_call_code_multiple_waiters(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            return await call.code()

        task1 = self.loop.create_task(coro())
        task2 = self.loop.create_task(coro())

        await call

        self.assertEqual(
            [grpc.StatusCode.OK, grpc.StatusCode.OK],
            await asyncio.gather(task1, task2),
        )

    async def test_cancel_unary_unary(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        self.assertFalse(call.cancelled())

        # Only the first cancel() takes effect.
        self.assertTrue(call.cancel())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call

        # The call object reports the local cancellation as its status.
        self.assertTrue(call.cancelled())
        self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
        self.assertEqual(
            await call.details(), _LOCAL_CANCEL_DETAILS_EXPECTATION
        )

    async def test_cancel_unary_unary_in_task(self):
        coro_started = asyncio.Event()
        # NOTE(review): EmptyCall is given an (empty) SimpleRequest rather
        # than an Empty message; both serialize to zero bytes, which is
        # presumably why this works.
        call = self._stub.EmptyCall(messages_pb2.SimpleRequest())

        async def another_coro():
            coro_started.set()
            await call

        task = self.loop.create_task(another_coro())
        await coro_started.wait()

        self.assertFalse(task.done())
        # Cancelling the task awaiting the call cancels the RPC itself.
        task.cancel()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        with self.assertRaises(asyncio.CancelledError):
            await task

    async def test_passing_credentials_fails_over_insecure_channel(self):
        call_credentials = grpc.composite_call_credentials(
            grpc.access_token_call_credentials("abc"),
            grpc.access_token_call_credentials("def"),
        )
        with self.assertRaisesRegex(
            aio.UsageError, "Call credentials are only valid on secure channels"
        ):
            self._stub.UnaryCall(
                messages_pb2.SimpleRequest(), credentials=call_credentials
            )


class TestUnaryStreamCall(_MulticallableTestMixin, AioTestBase):
    """Tests for the call object returned by a unary-stream multicallable."""

    async def test_call_rpc_error(self):
        # The context manager closes the channel even if an assertion fails.
        async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
            request = messages_pb2.StreamingOutputCallRequest()
            stub = test_pb2_grpc.TestServiceStub(channel)
            call = stub.StreamingOutputCall(request)

            with self.assertRaises(aio.AioRpcError) as exception_context:
                async for _ in call:
                    pass

            self.assertEqual(
                grpc.StatusCode.UNAVAILABLE, exception_context.exception.code()
            )

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())

    async def test_cancel_unary_stream(self):
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(
                    size=_RESPONSE_PAYLOAD_SIZE,
                    interval_us=_RESPONSE_INTERVAL_US,
                )
            )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)
        self.assertFalse(call.cancelled())

        response = await call.read()
        self.assertIs(type(response), messages_pb2.StreamingOutputCallResponse)
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertTrue(call.cancel())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
        self.assertEqual(
            _LOCAL_CANCEL_DETAILS_EXPECTATION, await call.details()
        )
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call.read()
        self.assertTrue(call.cancelled())

    async def test_multiple_cancel_unary_stream(self):
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(
                    size=_RESPONSE_PAYLOAD_SIZE,
                    interval_us=_RESPONSE_INTERVAL_US,
                )
            )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)
        self.assertFalse(call.cancelled())

        response = await call.read()
        self.assertIs(type(response), messages_pb2.StreamingOutputCallResponse)
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Only the first cancel() takes effect; later ones report False.
        self.assertTrue(call.cancel())
        self.assertFalse(call.cancel())
        self.assertFalse(call.cancel())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call.read()

    async def test_early_cancel_unary_stream(self):
        """Test cancellation before receiving messages."""
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(
                    size=_RESPONSE_PAYLOAD_SIZE,
                    interval_us=_RESPONSE_INTERVAL_US,
                )
            )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call.read()

        self.assertTrue(call.cancelled())

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
        self.assertEqual(
            _LOCAL_CANCEL_DETAILS_EXPECTATION, await call.details()
        )

    async def test_late_cancel_unary_stream(self):
        """Test cancellation after received all messages."""
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(
                    size=_RESPONSE_PAYLOAD_SIZE,
                )
            )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertIs(
                type(response), messages_pb2.StreamingOutputCallResponse
            )
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # After all messages are received, the final status may already have
        # arrived or still be on its way. This is a data race, so the only
        # expectation here is not to crash.
        call.cancel()
        self.assertIn(
            await call.code(), [grpc.StatusCode.OK, grpc.StatusCode.CANCELLED]
        )

    async def test_too_many_reads_unary_stream(self):
        """Test calling read after received all messages fails."""
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(
                    size=_RESPONSE_PAYLOAD_SIZE,
                )
            )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertIs(
                type(response), messages_pb2.StreamingOutputCallResponse
            )
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
        self.assertIs(await call.read(), aio.EOF)

        # Once the RPC has finished, further reads keep returning EOF.
        self.assertEqual(await call.code(), grpc.StatusCode.OK)
        self.assertIs(await call.read(), aio.EOF)

    async def test_unary_stream_async_generator(self):
        """Sunny day test case for unary_stream."""
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(
                    size=_RESPONSE_PAYLOAD_SIZE,
                )
            )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)
        self.assertFalse(call.cancelled())

        async for response in call:
            self.assertIs(
                type(response), messages_pb2.StreamingOutputCallResponse
            )
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_cancel_unary_stream_with_many_interleavings(self):
        """A cheap alternative to a structured fuzzer.

        Certain classes of error only appear for very specific interleavings of
        coroutines. Rather than inserting semi-private asyncio.Events throughout
        the implementation on which to coordinate and explicitly waiting on those
        in tests, we instead search for bugs over the space of interleavings by
        stochastically varying the durations of certain events within the test.
        """

        # We range over several orders of magnitude to ensure that switching
        # platforms (e.g. to slow CI machines) does not result in this test
        # becoming a no-op.
        sleep_ranges = (10.0**-i for i in range(1, 4))
        for sleep_range in sleep_ranges:
            for _ in range(_NONDETERMINISTIC_ITERATIONS):
                interval_us = random.randrange(
                    _NONDETERMINISTIC_SERVER_SLEEP_MAX_US
                )
                sleep_secs = sleep_range * random.random()

                coro_started = asyncio.Event()

                # The server waits a random, sub-millisecond interval before
                # sending its single one-byte response.
                request = messages_pb2.StreamingOutputCallRequest()
                request.response_parameters.append(
                    messages_pb2.ResponseParameters(
                        size=1,
                        interval_us=interval_us,
                    )
                )

                # Invokes the actual RPC
                call = self._stub.StreamingOutputCall(request)

                unhandled_error = False

                async def another_coro():
                    nonlocal unhandled_error
                    coro_started.set()
                    try:
                        await call.read()
                    except asyncio.CancelledError:
                        pass
                    except Exception:
                        # Anything other than cancellation is a bug; record
                        # it and let it propagate out of the task.
                        unhandled_error = True
                        raise

                task = self.loop.create_task(another_coro())
                await coro_started.wait()
                await asyncio.sleep(sleep_secs)

                task.cancel()

                try:
                    await task
                except asyncio.CancelledError:
                    pass

                self.assertFalse(unhandled_error)

    async def test_cancel_unary_stream_in_task_using_read(self):
        coro_started = asyncio.Event()

        # Configs the server method to block forever
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_INFINITE_INTERVAL_US,
            )
        )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        async def another_coro():
            coro_started.set()
            await call.read()

        task = self.loop.create_task(another_coro())
        await coro_started.wait()

        self.assertFalse(task.done())
        task.cancel()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        with self.assertRaises(asyncio.CancelledError):
            await task

    async def test_cancel_unary_stream_in_task_using_async_for(self):
        coro_started = asyncio.Event()

        # Configs the server method to block forever
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_INFINITE_INTERVAL_US,
            )
        )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        async def another_coro():
            coro_started.set()
            async for _ in call:
                pass

        task = self.loop.create_task(another_coro())
        await coro_started.wait()

        self.assertFalse(task.done())
        task.cancel()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        with self.assertRaises(asyncio.CancelledError):
            await task

    async def test_time_remaining(self):
        request = messages_pb2.StreamingOutputCallRequest()
        # First message comes back immediately
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
            )
        )
        # Second message comes back after a unit of wait time
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_RESPONSE_INTERVAL_US,
            )
        )

        call = self._stub.StreamingOutputCall(
            request, timeout=_SHORT_TIMEOUT_S * 2
        )

        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Should be around the same as the timeout
        remained_time = call.time_remaining()
        self.assertGreater(remained_time, _SHORT_TIMEOUT_S * 3 / 2)
        self.assertLess(remained_time, _SHORT_TIMEOUT_S * 5 / 2)

        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Should be around the timeout minus a unit of wait time
        remained_time = call.time_remaining()
        self.assertGreater(remained_time, _SHORT_TIMEOUT_S / 2)
        self.assertLess(remained_time, _SHORT_TIMEOUT_S * 3 / 2)

        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_empty_responses(self):
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters()
            )

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertIs(
                type(response), messages_pb2.StreamingOutputCallResponse
            )
            self.assertEqual(b"", response.SerializeToString())

        self.assertEqual(grpc.StatusCode.OK, await call.code())


class TestStreamUnaryCall(_MulticallableTestMixin, AioTestBase):
    """Tests for the call object returned by a stream-unary multicallable."""

    async def test_cancel_stream_unary(self):
        call = self._stub.StreamingInputCall()

        # Prepares the request
        payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        # Sends out requests
        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(request)

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())

        await call.done_writing()

        with self.assertRaises(asyncio.CancelledError):
            await call

    async def test_early_cancel_stream_unary(self):
        call = self._stub.StreamingInputCall()

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())

        with self.assertRaises(asyncio.InvalidStateError):
            await call.write(messages_pb2.StreamingInputCallRequest())

        # Should be no-op
        await call.done_writing()

        with self.assertRaises(asyncio.CancelledError):
            await call

    async def test_write_after_done_writing(self):
        call = self._stub.StreamingInputCall()

        # Prepares the request
        payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        # Sends out requests
        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(request)

        # Half-closes the request stream
        await call.done_writing()

        # Writing after done_writing() is rejected
        with self.assertRaises(asyncio.InvalidStateError):
            await call.write(messages_pb2.StreamingInputCallRequest())

        response = await call
        self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
        self.assertEqual(
            _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
            response.aggregated_payload_size,
        )

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_error_in_async_generator(self):
        # StreamingInputCall consumes StreamingInputCallRequest messages.
        payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        # We expect the request iterator to receive the exception
        request_iterator_received_the_exception = asyncio.Event()

        async def request_iterator():
            # Sleeping between requests keeps the request stream open long
            # enough for cancel_later() to fire while it is still running.
            with self.assertRaises(asyncio.CancelledError):
                for _ in range(_NUM_STREAM_RESPONSES):
                    yield request
                    await asyncio.sleep(_SHORT_TIMEOUT_S)
            request_iterator_received_the_exception.set()

        call = self._stub.StreamingInputCall(request_iterator())

        # Cancel the RPC after the iterator has yielded at least one request
        async def cancel_later():
            await asyncio.sleep(_SHORT_TIMEOUT_S * 2)
            call.cancel()

        cancel_later_task = self.loop.create_task(cancel_later())

        with self.assertRaises(asyncio.CancelledError):
            await call

        await request_iterator_received_the_exception.wait()

        # No failures in the cancel later task!
        await cancel_later_task

    async def test_normal_iterable_requests(self):
        # Prepares the request
        payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)
        requests = [request] * _NUM_STREAM_RESPONSES

        # Sends out requests
        call = self._stub.StreamingInputCall(requests)

        # RPC should succeed
        response = await call
        self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
        self.assertEqual(
            _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
            response.aggregated_payload_size,
        )

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_call_rpc_error(self):
        async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # The error should be raised automatically without any traffic.
            call = stub.StreamingInputCall()
            with self.assertRaises(aio.AioRpcError) as exception_context:
                await call

            self.assertEqual(
                grpc.StatusCode.UNAVAILABLE, exception_context.exception.code()
            )

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())

    async def test_timeout(self):
        call = self._stub.StreamingInputCall(timeout=_SHORT_TIMEOUT_S)

        # The error should be raised automatically without any traffic.
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call

        rpc_error = exception_context.exception
        self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, rpc_error.code())
        self.assertTrue(call.done())
        self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, await call.code())


# Requests for ping-pong style streaming: each asks for exactly one response.
# Asks the server for a single response carrying a
# _RESPONSE_PAYLOAD_SIZE-byte payload.
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest()
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE.response_parameters.append(
    messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)
)
# Asks the server for a single response with default (empty) parameters.
_STREAM_OUTPUT_REQUEST_ONE_EMPTY_RESPONSE = (
    messages_pb2.StreamingOutputCallRequest()
)
_STREAM_OUTPUT_REQUEST_ONE_EMPTY_RESPONSE.response_parameters.append(
    messages_pb2.ResponseParameters()
)


class TestStreamStreamCall(_MulticallableTestMixin, AioTestBase):
    """Tests for the call object returned by a stream-stream multicallable."""

    async def test_cancel(self):
        # Invokes the actual RPC
        call = self._stub.FullDuplexCall()

        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
            response = await call.read()
            self.assertIsInstance(
                response, messages_pb2.StreamingOutputCallResponse
            )
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

    async def test_cancel_with_pending_read(self):
        call = self._stub.FullDuplexCall()

        # A response is requested but never read before cancelling.
        await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

    async def test_cancel_with_ongoing_read(self):
        call = self._stub.FullDuplexCall()
        coro_started = asyncio.Event()

        async def read_coro():
            coro_started.set()
            await call.read()

        read_task = self.loop.create_task(read_coro())
        await coro_started.wait()
        self.assertFalse(read_task.done())

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        # The in-flight read must be interrupted by the cancellation; awaiting
        # the task also keeps it from outliving the test.
        with self.assertRaises(asyncio.CancelledError):
            await read_task

    async def test_early_cancel(self):
        call = self._stub.FullDuplexCall()

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

    async def test_cancel_after_done_writing(self):
        call = self._stub.FullDuplexCall()
        # Each request makes the server wait 10ms before responding.
        request_with_delay = messages_pb2.StreamingOutputCallRequest()
        request_with_delay.response_parameters.append(
            messages_pb2.ResponseParameters(interval_us=10000)
        )
        await call.write(request_with_delay)
        await call.write(request_with_delay)
        await call.done_writing()

        # Cancels the RPC
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

    async def test_late_cancel(self):
        call = self._stub.FullDuplexCall()
        await call.done_writing()
        self.assertEqual(grpc.StatusCode.OK, await call.code())

        # Cancelling an already-finished RPC has no effect
        self.assertTrue(call.done())
        self.assertFalse(call.cancelled())
        self.assertFalse(call.cancel())
        self.assertFalse(call.cancelled())

        # Status is still OK
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_async_generator(self):
        async def request_generator():
            yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE
            yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE

        call = self._stub.FullDuplexCall(request_generator())
        async for response in call:
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_too_many_reads(self):
        async def request_generator():
            for _ in range(_NUM_STREAM_RESPONSES):
                yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE

        call = self._stub.FullDuplexCall(request_generator())
        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
        self.assertIs(await call.read(), aio.EOF)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)
        # After the RPC finished, the read should also produce EOF
        self.assertIs(await call.read(), aio.EOF)

    async def test_read_write_after_done_writing(self):
        call = self._stub.FullDuplexCall()

        # Writes two requests, leaving two responses pending
        await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
        await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
        await call.done_writing()

        # Further write should fail
        with self.assertRaises(asyncio.InvalidStateError):
            await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)

        # But read should be unaffected
        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_error_in_async_generator(self):
        # Server will pause between responses
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_RESPONSE_INTERVAL_US,
            )
        )

        # We expect the request iterator to receive the exception
        request_iterator_received_the_exception = asyncio.Event()

        async def request_iterator():
            with self.assertRaises(asyncio.CancelledError):
                for _ in range(_NUM_STREAM_RESPONSES):
                    yield request
                    await asyncio.sleep(_SHORT_TIMEOUT_S)
            request_iterator_received_the_exception.set()

        call = self._stub.FullDuplexCall(request_iterator())

        # Cancel the RPC after at least one response
        async def cancel_later():
            await asyncio.sleep(_SHORT_TIMEOUT_S * 2)
            call.cancel()

        cancel_later_task = self.loop.create_task(cancel_later())

        with self.assertRaises(asyncio.CancelledError):
            async for response in call:
                self.assertEqual(
                    _RESPONSE_PAYLOAD_SIZE, len(response.payload.body)
                )

        await request_iterator_received_the_exception.wait()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
        # No failures in the cancel later task!
        await cancel_later_task

    async def test_normal_iterable_requests(self):
        requests = [_STREAM_OUTPUT_REQUEST_ONE_RESPONSE] * _NUM_STREAM_RESPONSES

        call = self._stub.FullDuplexCall(iter(requests))
        async for response in call:
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_empty_ping_pong(self):
        call = self._stub.FullDuplexCall()
        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(_STREAM_OUTPUT_REQUEST_ONE_EMPTY_RESPONSE)
            response = await call.read()
            self.assertEqual(b"", response.SerializeToString())
        await call.done_writing()
        self.assertEqual(await call.code(), grpc.StatusCode.OK)


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)