# Copyright 2019 The gRPC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests covering the behavior of the grpc.aio Call classes."""

import asyncio
import datetime
import logging
import unittest

import grpc
from grpc.experimental import aio

from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._constants import UNREACHABLE_TARGET

# One "unit" of time (in seconds); the tests derive timeouts, sleeps and
# server-side response intervals from it.
_SHORT_TIMEOUT_S = datetime.timedelta(seconds=1).total_seconds()

# Number of messages exchanged by the streaming tests.
_NUM_STREAM_RESPONSES = 5
# Payload sizes (in bytes) asked from the server / sent by the client.
_RESPONSE_PAYLOAD_SIZE = 42
_REQUEST_PAYLOAD_SIZE = 7
# Status details reported by a call after it is cancelled locally.
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
# Server-side delay before each response, in microseconds (one time unit).
_RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000)
# Used as a response interval so large that the server effectively never
# answers within a test (presumably the int32 maximum of `interval_us`).
_INFINITE_INTERVAL_US = 2**31 - 1


class _MulticallableTestMixin:
    """Starts a test server plus a channel and stub bound to it per test."""

    async def setUp(self):
        address, self._server = await start_test_server()
        self._channel = aio.insecure_channel(address)
        self._stub = test_pb2_grpc.TestServiceStub(self._channel)

    async def tearDown(self):
        # Stop the server even if closing the channel fails, so a failure
        # here does not leave a running server behind.
        try:
            await self._channel.close()
        finally:
            await self._server.stop(None)


class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
    """Tests for calls produced by a unary-unary multicallable."""

    async def test_call_to_string(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        # str()/repr() must work both before and after the call completes.
        self.assertIsNotNone(str(call))
        self.assertIsNotNone(repr(call))

        await call

        self.assertIsNotNone(str(call))
        self.assertIsNotNone(repr(call))

    async def test_call_ok(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        self.assertFalse(call.done())

        response = await call

        self.assertTrue(call.done())
        self.assertIsInstance(response, messages_pb2.SimpleResponse)
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

        # The response is cached at the call object level, so awaiting the
        # call again returns the very same response object.
        response_retry = await call
        self.assertIs(response, response_retry)

    async def test_call_rpc_error(self):
        async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            call = stub.UnaryCall(messages_pb2.SimpleRequest())

            with self.assertRaises(aio.AioRpcError) as exception_context:
                await call

            self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                             exception_context.exception.code())

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())

    async def test_call_code_awaitable(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_call_details_awaitable(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual('', await call.details())

    async def test_call_initial_metadata_awaitable(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual(aio.Metadata(), await call.initial_metadata())

    async def test_call_trailing_metadata_awaitable(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
        self.assertEqual(aio.Metadata(), await call.trailing_metadata())

    async def test_call_initial_metadata_cancelable(self):
        coro_started = asyncio.Event()
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            coro_started.set()
            await call.initial_metadata()

        task = self.loop.create_task(coro())
        await coro_started.wait()
        task.cancel()

        # Initial metadata can still be requested even though a previous
        # task waiting for it was cancelled.
        self.assertEqual(aio.Metadata(), await call.initial_metadata())

    async def test_call_initial_metadata_multiple_waiters(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            return await call.initial_metadata()

        task1 = self.loop.create_task(coro())
        task2 = self.loop.create_task(coro())

        await call
        expected = [aio.Metadata() for _ in range(2)]
        self.assertEqual(expected, await asyncio.gather(task1, task2))

    async def test_call_code_cancelable(self):
        coro_started = asyncio.Event()
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            coro_started.set()
            await call.code()

        task = self.loop.create_task(coro())
        await coro_started.wait()
        task.cancel()

        # The code can still be requested even though a previous task
        # waiting for it was cancelled.
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_call_code_multiple_waiters(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        async def coro():
            return await call.code()

        task1 = self.loop.create_task(coro())
        task2 = self.loop.create_task(coro())

        await call

        self.assertEqual([grpc.StatusCode.OK, grpc.StatusCode.OK], await
                         asyncio.gather(task1, task2))

    async def test_cancel_unary_unary(self):
        call = self._stub.UnaryCall(messages_pb2.SimpleRequest())

        self.assertFalse(call.cancelled())

        # Only the first cancel() takes effect.
        self.assertTrue(call.cancel())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call

        # After a local cancellation the Call object reports the CANCELLED
        # status together with the local cancellation details.
        self.assertTrue(call.cancelled())
        self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
        self.assertEqual(await call.details(),
                         _LOCAL_CANCEL_DETAILS_EXPECTATION)

    async def test_cancel_unary_unary_in_task(self):
        coro_started = asyncio.Event()
        # NOTE(review): EmptyCall is fed a SimpleRequest; an empty
        # SimpleRequest serializes to the same bytes as Empty, so this works,
        # but empty_pb2.Empty() would be the matching type.
        call = self._stub.EmptyCall(messages_pb2.SimpleRequest())

        async def another_coro():
            coro_started.set()
            await call

        task = self.loop.create_task(another_coro())
        await coro_started.wait()

        self.assertFalse(task.done())
        # Cancelling the task that awaits the call cancels the RPC as well.
        task.cancel()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        with self.assertRaises(asyncio.CancelledError):
            await task

    async def test_passing_credentials_fails_over_insecure_channel(self):
        call_credentials = grpc.composite_call_credentials(
            grpc.access_token_call_credentials("abc"),
            grpc.access_token_call_credentials("def"),
        )
        with self.assertRaisesRegex(
                aio.UsageError,
                "Call credentials are only valid on secure channels"):
            self._stub.UnaryCall(messages_pb2.SimpleRequest(),
                                 credentials=call_credentials)


class TestUnaryStreamCall(_MulticallableTestMixin, AioTestBase):
    """Tests for calls produced by a unary-stream multicallable."""

    async def test_call_rpc_error(self):
        # `async with` guarantees the channel is closed even if an
        # assertion below fails.
        async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
            request = messages_pb2.StreamingOutputCallRequest()
            stub = test_pb2_grpc.TestServiceStub(channel)
            call = stub.StreamingOutputCall(request)

            with self.assertRaises(aio.AioRpcError) as exception_context:
                async for _ in call:
                    pass

            self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                             exception_context.exception.code())

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())

    async def test_cancel_unary_stream(self):
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(
                    size=_RESPONSE_PAYLOAD_SIZE,
                    interval_us=_RESPONSE_INTERVAL_US,
                ))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)
        self.assertFalse(call.cancelled())

        response = await call.read()
        self.assertIs(type(response), messages_pb2.StreamingOutputCallResponse)
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertTrue(call.cancel())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
        self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await
                         call.details())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call.read()
        self.assertTrue(call.cancelled())

    async def test_multiple_cancel_unary_stream(self):
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(
                    size=_RESPONSE_PAYLOAD_SIZE,
                    interval_us=_RESPONSE_INTERVAL_US,
                ))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)
        self.assertFalse(call.cancelled())

        response = await call.read()
        self.assertIs(type(response), messages_pb2.StreamingOutputCallResponse)
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Only the first cancel() takes effect; the rest report False.
        self.assertTrue(call.cancel())
        self.assertFalse(call.cancel())
        self.assertFalse(call.cancel())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call.read()

    async def test_early_cancel_unary_stream(self):
        """Test cancellation before receiving messages."""
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(
                    size=_RESPONSE_PAYLOAD_SIZE,
                    interval_us=_RESPONSE_INTERVAL_US,
                ))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertFalse(call.cancel())

        with self.assertRaises(asyncio.CancelledError):
            await call.read()

        self.assertTrue(call.cancelled())

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
        self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await
                         call.details())

    async def test_late_cancel_unary_stream(self):
        """Test cancellation after received all messages."""
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertIs(type(response),
                          messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # After all messages are received, the final status may already have
        # arrived or may still be on its way. It's basically a data race, so
        # the only expectation here is not to crash :)
        call.cancel()
        self.assertIn(await call.code(),
                      [grpc.StatusCode.OK, grpc.StatusCode.CANCELLED])

    async def test_too_many_reads_unary_stream(self):
        """Test calling read after received all messages fails."""
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertIs(type(response),
                          messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
        self.assertIs(await call.read(), aio.EOF)

        # After the RPC is finished, further reads keep returning EOF.
        self.assertEqual(await call.code(), grpc.StatusCode.OK)
        self.assertIs(await call.read(), aio.EOF)

    async def test_unary_stream_async_generator(self):
        """Sunny day test case for unary_stream."""
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)
        self.assertFalse(call.cancelled())

        async for response in call:
            self.assertIs(type(response),
                          messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_cancel_unary_stream_in_task_using_read(self):
        coro_started = asyncio.Event()

        # Configures the server method to block forever
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_INFINITE_INTERVAL_US,
            ))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        async def another_coro():
            coro_started.set()
            await call.read()

        task = self.loop.create_task(another_coro())
        await coro_started.wait()

        self.assertFalse(task.done())
        # Cancelling the reading task cancels the RPC as well.
        task.cancel()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        with self.assertRaises(asyncio.CancelledError):
            await task

    async def test_cancel_unary_stream_in_task_using_async_for(self):
        coro_started = asyncio.Event()

        # Configures the server method to block forever
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_INFINITE_INTERVAL_US,
            ))

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        async def another_coro():
            coro_started.set()
            async for _ in call:
                pass

        task = self.loop.create_task(another_coro())
        await coro_started.wait()

        self.assertFalse(task.done())
        # Cancelling the iterating task cancels the RPC as well.
        task.cancel()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

        with self.assertRaises(asyncio.CancelledError):
            await task

    async def test_time_remaining(self):
        request = messages_pb2.StreamingOutputCallRequest()
        # First message comes back immediately
        request.response_parameters.append(
            messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))
        # Second message comes back after a unit of wait time
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_RESPONSE_INTERVAL_US,
            ))

        call = self._stub.StreamingOutputCall(request,
                                              timeout=_SHORT_TIMEOUT_S * 2)

        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Should be around the same as the timeout
        remained_time = call.time_remaining()
        self.assertGreater(remained_time, _SHORT_TIMEOUT_S * 3 / 2)
        self.assertLess(remained_time, _SHORT_TIMEOUT_S * 5 / 2)

        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Should be around the timeout minus a unit of wait time
        remained_time = call.time_remaining()
        self.assertGreater(remained_time, _SHORT_TIMEOUT_S / 2)
        self.assertLess(remained_time, _SHORT_TIMEOUT_S * 3 / 2)

        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_empty_responses(self):
        # Prepares the request
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters())

        # Invokes the actual RPC
        call = self._stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertIs(type(response),
                          messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(b'', response.SerializeToString())

        self.assertEqual(grpc.StatusCode.OK, await call.code())


class TestStreamUnaryCall(_MulticallableTestMixin, AioTestBase):
    """Tests for calls produced by a stream-unary multicallable."""

    async def test_cancel_stream_unary(self):
        call = self._stub.StreamingInputCall()

        # Prepares the request
        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        # Sends out requests
        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(request)

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())

        await call.done_writing()

        with self.assertRaises(asyncio.CancelledError):
            await call

    async def test_early_cancel_stream_unary(self):
        call = self._stub.StreamingInputCall()

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())

        with self.assertRaises(asyncio.InvalidStateError):
            await call.write(messages_pb2.StreamingInputCallRequest())

        # Should be no-op
        await call.done_writing()

        with self.assertRaises(asyncio.CancelledError):
            await call

    async def test_write_after_done_writing(self):
        call = self._stub.StreamingInputCall()

        # Prepares the request
        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        # Sends out requests
        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(request)

        # Half-closes the request stream
        await call.done_writing()

        # Writing after done_writing() is rejected
        with self.assertRaises(asyncio.InvalidStateError):
            await call.write(messages_pb2.StreamingInputCallRequest())

        response = await call
        self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
        self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                         response.aggregated_payload_size)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_error_in_async_generator(self):
        # StreamingInputCall consumes StreamingInputCallRequest messages.
        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        # We expect the request iterator to receive the exception
        request_iterator_received_the_exception = asyncio.Event()

        async def request_iterator():
            # The iterator pauses between requests, so the RPC is still
            # streaming when it gets cancelled.
            with self.assertRaises(asyncio.CancelledError):
                for _ in range(_NUM_STREAM_RESPONSES):
                    yield request
                    await asyncio.sleep(_SHORT_TIMEOUT_S)
            request_iterator_received_the_exception.set()

        call = self._stub.StreamingInputCall(request_iterator())

        # Cancel the RPC after at least one request has been sent
        async def cancel_later():
            await asyncio.sleep(_SHORT_TIMEOUT_S * 2)
            call.cancel()

        cancel_later_task = self.loop.create_task(cancel_later())

        with self.assertRaises(asyncio.CancelledError):
            await call

        await request_iterator_received_the_exception.wait()

        # No failures in the cancel later task!
        await cancel_later_task

    async def test_normal_iterable_requests(self):
        # Prepares the request
        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)
        requests = [request] * _NUM_STREAM_RESPONSES

        # Sends out requests
        call = self._stub.StreamingInputCall(requests)

        # RPC should succeed
        response = await call
        self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
        self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                         response.aggregated_payload_size)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_call_rpc_error(self):
        async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
            stub = test_pb2_grpc.TestServiceStub(channel)

            # The error should be raised automatically without any traffic.
            call = stub.StreamingInputCall()
            with self.assertRaises(aio.AioRpcError) as exception_context:
                await call

            self.assertEqual(grpc.StatusCode.UNAVAILABLE,
                             exception_context.exception.code())

            self.assertTrue(call.done())
            self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())

    async def test_timeout(self):
        call = self._stub.StreamingInputCall(timeout=_SHORT_TIMEOUT_S)

        # The error should be raised automatically without any traffic.
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call

        rpc_error = exception_context.exception
        self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, rpc_error.code())
        self.assertTrue(call.done())
        self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, await call.code())


# Requests for streaming in a ping-pong manner: each one asks the server for
# exactly one response, either with a payload or completely empty.
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest()
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE.response_parameters.append(
    messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
_STREAM_OUTPUT_REQUEST_ONE_EMPTY_RESPONSE = (
    messages_pb2.StreamingOutputCallRequest())
_STREAM_OUTPUT_REQUEST_ONE_EMPTY_RESPONSE.response_parameters.append(
    messages_pb2.ResponseParameters())


class TestStreamStreamCall(_MulticallableTestMixin, AioTestBase):
    """Tests for calls produced by a stream-stream multicallable."""

    async def test_cancel(self):
        # Invokes the actual RPC
        call = self._stub.FullDuplexCall()

        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
            response = await call.read()
            self.assertIsInstance(response,
                                  messages_pb2.StreamingOutputCallResponse)
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

    async def test_cancel_with_pending_read(self):
        call = self._stub.FullDuplexCall()

        await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

    async def test_cancel_with_ongoing_read(self):
        call = self._stub.FullDuplexCall()
        coro_started = asyncio.Event()

        async def read_coro():
            coro_started.set()
            await call.read()

        read_task = self.loop.create_task(read_coro())
        await coro_started.wait()
        self.assertFalse(read_task.done())

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

    async def test_early_cancel(self):
        call = self._stub.FullDuplexCall()

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

    async def test_cancel_after_done_writing(self):
        call = self._stub.FullDuplexCall()
        await call.done_writing()

        # Cancels the RPC
        self.assertFalse(call.done())
        self.assertFalse(call.cancelled())
        self.assertTrue(call.cancel())
        self.assertTrue(call.cancelled())
        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())

    async def test_late_cancel(self):
        call = self._stub.FullDuplexCall()
        await call.done_writing()
        self.assertEqual(grpc.StatusCode.OK, await call.code())

        # Cancelling a finished RPC is rejected
        self.assertTrue(call.done())
        self.assertFalse(call.cancelled())
        self.assertFalse(call.cancel())
        self.assertFalse(call.cancelled())

        # Status is still OK
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_async_generator(self):

        async def request_generator():
            yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE
            yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE

        call = self._stub.FullDuplexCall(request_generator())
        async for response in call:
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_too_many_reads(self):

        async def request_generator():
            for _ in range(_NUM_STREAM_RESPONSES):
                yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE

        call = self._stub.FullDuplexCall(request_generator())
        for _ in range(_NUM_STREAM_RESPONSES):
            response = await call.read()
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
        self.assertIs(await call.read(), aio.EOF)

        self.assertEqual(await call.code(), grpc.StatusCode.OK)
        # After the RPC finished, the read should also produce EOF
        self.assertIs(await call.read(), aio.EOF)

    async def test_read_write_after_done_writing(self):
        call = self._stub.FullDuplexCall()

        # Writes two requests; their two responses are still pending
        await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
        await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)
        await call.done_writing()

        # Further write should fail
        with self.assertRaises(asyncio.InvalidStateError):
            await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE)

        # But read should be unaffected
        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
        response = await call.read()
        self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_error_in_async_generator(self):
        # Server will pause between responses
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.append(
            messages_pb2.ResponseParameters(
                size=_RESPONSE_PAYLOAD_SIZE,
                interval_us=_RESPONSE_INTERVAL_US,
            ))

        # We expect the request iterator to receive the exception
        request_iterator_received_the_exception = asyncio.Event()

        async def request_iterator():
            with self.assertRaises(asyncio.CancelledError):
                for _ in range(_NUM_STREAM_RESPONSES):
                    yield request
                    await asyncio.sleep(_SHORT_TIMEOUT_S)
            request_iterator_received_the_exception.set()

        call = self._stub.FullDuplexCall(request_iterator())

        # Cancel the RPC after at least one response
        async def cancel_later():
            await asyncio.sleep(_SHORT_TIMEOUT_S * 2)
            call.cancel()

        cancel_later_task = self.loop.create_task(cancel_later())

        with self.assertRaises(asyncio.CancelledError):
            async for response in call:
                self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
                                 len(response.payload.body))

        await request_iterator_received_the_exception.wait()

        self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
        # No failures in the cancel later task!
        await cancel_later_task

    async def test_normal_iterable_requests(self):
        requests = [_STREAM_OUTPUT_REQUEST_ONE_RESPONSE] * _NUM_STREAM_RESPONSES

        call = self._stub.FullDuplexCall(iter(requests))
        async for response in call:
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        self.assertEqual(await call.code(), grpc.StatusCode.OK)

    async def test_empty_ping_pong(self):
        call = self._stub.FullDuplexCall()
        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(_STREAM_OUTPUT_REQUEST_ONE_EMPTY_RESPONSE)
            response = await call.read()
            self.assertEqual(b'', response.SerializeToString())
        await call.done_writing()
        self.assertEqual(await call.code(), grpc.StatusCode.OK)


if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)