#!/usr/bin/env python3
# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests using the callback client for pw_rpc."""

import unittest
from unittest import mock
from typing import Any, List, Optional, Tuple

from pw_protobuf_compiler import python_protos
from pw_status import Status

from pw_rpc import callback_client, client, packets
from pw_rpc.internal import packet_pb2

TEST_PROTO_1 = """\
syntax = "proto3";

package pw.test1;

message SomeMessage {
  uint32 magic_number = 1;
}

message AnotherMessage {
  enum Result {
    FAILED = 0;
    FAILED_MISERABLY = 1;
    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
  }

  Result result = 1;
  string payload = 2;
}

service PublicService {
  rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
  rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
  rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
  rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""


def _message_bytes(msg) -> bytes:
    """Returns msg unchanged if it is bytes; otherwise serializes it."""
    return msg if isinstance(msg, bytes) else msg.SerializeToString()


class _CallbackClientImplTestBase(unittest.TestCase):
    """Supports writing tests that require responses from an RPC server."""
    def setUp(self) -> None:
        self._protos = python_protos.Library.from_strings(TEST_PROTO_1)
        self._request = self._protos.packages.pw.test1.SomeMessage

        # Packets the client sends on channel 1 are routed to _handle_packet.
        self._client = client.Client.from_modules(
            callback_client.Impl(), [client.Channel(1, self._handle_packet)],
            self._protos.modules())
        self._service = self._client.channel(1).rpcs.pw.test1.PublicService

        # Decoded packets sent by the client, in the order they were sent.
        self.requests: List[packet_pb2.RpcPacket] = []
        # Encoded packets to feed to the client, each paired with the status
        # that Client.process_packet is expected to return for it.
        self._next_packets: List[Tuple[bytes, Status]] = []
        # Number of client packets that must be sent before the enqueued
        # packets are processed. 1 means they are processed on the next send.
        self.send_responses_after_packets: float = 1

        # If set, raised by the channel output instead of recording a packet.
        self.output_exception: Optional[Exception] = None

    def last_request(self) -> packet_pb2.RpcPacket:
        """Returns the most recent packet sent by the client."""
        assert self.requests
        return self.requests[-1]

    def _enqueue_response(self,
                          channel_id: int,
                          method=None,
                          status: Status = Status.OK,
                          payload=b'',
                          *,
                          ids: Optional[Tuple[int, int]] = None,
                          process_status=Status.OK) -> None:
        """Enqueues a RESPONSE packet for the client to process later.

        The RPC is identified either by method or by ids, a (service ID,
        method ID) pair; exactly one of the two must be provided.
        process_status is the status Client.process_packet is expected to
        return when this packet is processed.
        """
        if method:
            assert ids is None
            service_id, method_id = method.service.id, method.id
        else:
            assert ids is not None and method is None
            service_id, method_id = ids

        self._next_packets.append((packet_pb2.RpcPacket(
            type=packet_pb2.PacketType.RESPONSE,
            channel_id=channel_id,
            service_id=service_id,
            method_id=method_id,
            status=status.value,
            payload=_message_bytes(payload)).SerializeToString(),
                                   process_status))

    def _enqueue_server_stream(self,
                               channel_id: int,
                               method,
                               response,
                               process_status=Status.OK) -> None:
        """Enqueues a SERVER_STREAM packet carrying response for method."""
        self._next_packets.append((packet_pb2.RpcPacket(
            type=packet_pb2.PacketType.SERVER_STREAM,
            channel_id=channel_id,
            service_id=method.service.id,
            method_id=method.id,
            payload=_message_bytes(response)).SerializeToString(),
                                   process_status))

    def _enqueue_error(self,
                       channel_id: int,
                       service,
                       method,
                       status: Status,
                       process_status=Status.OK) -> None:
        """Enqueues a SERVER_ERROR packet with the given status.

        service and method may each be either an integer ID or an object
        with an id attribute.
        """
        self._next_packets.append((packet_pb2.RpcPacket(
            type=packet_pb2.PacketType.SERVER_ERROR,
            channel_id=channel_id,
            service_id=service if isinstance(service, int) else service.id,
            method_id=method if isinstance(method, int) else method.id,
            status=status.value).SerializeToString(), process_status))

    def _handle_packet(self, data: bytes) -> None:
        """Channel output: records the sent packet and may send responses."""
        if self.output_exception:
            raise self.output_exception  # pylint: disable=raising-bad-type

        self.requests.append(packets.decode(data))

        # Hold the enqueued packets back until enough packets were sent.
        if self.send_responses_after_packets > 1:
            self.send_responses_after_packets -= 1
            return

        self._process_enqueued_packets()

    def _process_enqueued_packets(self) -> None:
        """Feeds all enqueued packets to the client and checks the results."""
        # Set send_responses_after_packets to infinity to prevent potential
        # infinite recursion when a packet causes another packet to send.
        send_after_count = self.send_responses_after_packets
        self.send_responses_after_packets = float('inf')

        try:
            for packet, status in self._next_packets:
                self.assertIs(status, self._client.process_packet(packet))
        finally:
            # Restore the state even if an assertion fails partway through so
            # a failure does not leave the fixture stuck in the "never
            # respond" mode with stale packets queued.
            self._next_packets.clear()
            self.send_responses_after_packets = send_after_count

    def _sent_payload(self, message_type: type) -> Any:
        """Decodes the last sent packet's payload as a message_type."""
        message = message_type()
        message.ParseFromString(self.last_request().payload)
        return message


class CallbackClientImplTest(_CallbackClientImplTestBase):
    """Tests the callback_client.Impl client implementation."""
    def test_callback_exceptions_suppressed(self) -> None:
        stub = self._service.SomeUnary

        self._enqueue_response(1, stub.method)
        exception_msg = 'YOU BROKE IT O-]-<'

        with self.assertLogs(callback_client.__package__, 'ERROR') as logs:
            stub.invoke(self._request(),
                        mock.Mock(side_effect=Exception(exception_msg)))

        self.assertIn(exception_msg, ''.join(logs.output))

        # Make sure we can still invoke the RPC.
        self._enqueue_response(1, stub.method, Status.UNKNOWN)
        status, _ = stub()
        self.assertIs(status, Status.UNKNOWN)

    def test_ignore_bad_packets_with_pending_rpc(self) -> None:
        method = self._service.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_response(999, method, process_status=Status.NOT_FOUND)
        # Bad service
        self._enqueue_response(1,
                               ids=(999, method.id),
                               process_status=Status.OK)
        # Bad method
        self._enqueue_response(1,
                               ids=(service_id, 999),
                               process_status=Status.OK)
        # For RPC not pending (is Status.OK because the packet is processed)
        self._enqueue_response(1,
                               ids=(service_id,
                                    self._service.SomeBidiStreaming.method.id),
                               process_status=Status.OK)

        self._enqueue_response(1, method, process_status=Status.OK)

        status, response = self._service.SomeUnary(magic_number=6)
        self.assertIs(Status.OK, status)
        self.assertEqual('', response.payload)

    def test_server_error_for_unknown_call_sends_no_errors(self) -> None:
        method = self._service.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_error(999,
                            service_id,
                            method,
                            Status.NOT_FOUND,
                            process_status=Status.NOT_FOUND)
        # Bad service
        self._enqueue_error(1, 999, method.id, Status.INVALID_ARGUMENT)
        # Bad method
        self._enqueue_error(1, service_id, 999, Status.INVALID_ARGUMENT)
        # For RPC not pending
        self._enqueue_error(1, service_id,
                            self._service.SomeBidiStreaming.method.id,
                            Status.NOT_FOUND)

        self._process_enqueued_packets()

        self.assertEqual(self.requests, [])

    def test_exception_if_payload_fails_to_decode(self) -> None:
        method = self._service.SomeUnary.method

        self._enqueue_response(1,
                               method,
                               Status.OK,
                               b'INVALID DATA!!!',
                               process_status=Status.OK)

        with self.assertRaises(callback_client.RpcError) as context:
            self._service.SomeUnary(magic_number=6)

        self.assertIs(context.exception.status, Status.DATA_LOSS)

    def test_rpc_help_contains_method_name(self) -> None:
        rpc = self._service.SomeUnary
        self.assertIn(rpc.method.full_name, rpc.help())

    def test_default_timeouts_set_on_impl(self) -> None:
        impl = callback_client.Impl(None, 1.5)

        self.assertEqual(impl.default_unary_timeout_s, None)
        self.assertEqual(impl.default_stream_timeout_s, 1.5)

    def test_default_timeouts_set_for_all_rpcs(self) -> None:
        rpc_client = client.Client.from_modules(
            callback_client.Impl(99, 100),
            [client.Channel(1, lambda *a, **b: None)],
            self._protos.modules())
        rpcs = rpc_client.channel(1).rpcs

        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeUnary.default_timeout_s, 99)
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeServerStreaming.default_timeout_s,
            100)
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeClientStreaming.default_timeout_s,
            99)
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeBidiStreaming.default_timeout_s,
            100)

    def test_rpc_provides_request_type(self) -> None:
        self.assertIs(self._service.SomeUnary.request,
                      self._service.SomeUnary.method.request_type)

    def test_rpc_provides_response_type(self) -> None:
        # This previously repeated the request-type assertion, so the response
        # type was never actually checked.
        self.assertIs(self._service.SomeUnary.response,
                      self._service.SomeUnary.method.response_type)


class UnaryTest(_CallbackClientImplTestBase):
    """Tests for invoking a unary RPC."""
    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeUnary
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        for _ in range(3):
            self._enqueue_response(1, self.method, Status.ABORTED,
                                   self.method.response_type(payload='0_o'))

            status, response = self._service.SomeUnary(
                self.method.request_type(magic_number=6))

            self.assertEqual(
                6,
                self._sent_payload(self.method.request_type).magic_number)

            self.assertIs(Status.ABORTED, status)
            self.assertEqual('0_o', response.payload)

    def test_nonblocking_call(self) -> None:
        for _ in range(3):
            self._enqueue_response(1, self.method, Status.ABORTED,
                                   self.method.response_type(payload='0_o'))

            callback = mock.Mock()
            call = self.rpc.invoke(self._request(magic_number=5), callback,
                                   callback)

            callback.assert_has_calls([
                mock.call(call, self.method.response_type(payload='0_o')),
                mock.call(call, Status.ABORTED)
            ])

            self.assertEqual(
                5,
                self._sent_payload(self.method.request_type).magic_number)

    def test_open(self) -> None:
        # The channel output raises, so open() must not record any request.
        self.output_exception = IOError('something went wrong sending!')

        for _ in range(3):
            self._enqueue_response(1, self.method, Status.ABORTED,
                                   self.method.response_type(payload='0_o'))

            callback = mock.Mock()
            call = self.rpc.open(self._request(magic_number=5), callback,
                                 callback)
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            callback.assert_has_calls([
                mock.call(call, self.method.response_type(payload='0_o')),
                mock.call(call, Status.ABORTED)
            ])

    def test_blocking_server_error(self) -> None:
        for _ in range(3):
            self._enqueue_error(1, self.method.service, self.method,
                                Status.NOT_FOUND)

            with self.assertRaises(callback_client.RpcError) as context:
                self._service.SomeUnary(
                    self.method.request_type(magic_number=6))

            self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_cancel(self) -> None:
        callback = mock.Mock()

        for _ in range(3):
            call = self._service.SomeUnary.invoke(
                self._request(magic_number=55), callback)

            self.assertGreater(len(self.requests), 0)
            self.requests.clear()

            self.assertTrue(call.cancel())
            self.assertFalse(call.cancel())  # Already cancelled, returns False

            # Unary RPCs do not send a cancel request to the server.
            self.assertFalse(self.requests)

        callback.assert_not_called()

    def test_nonblocking_with_request_args(self) -> None:
        self.rpc.invoke(request_args=dict(magic_number=1138))
        self.assertEqual(
            self._sent_payload(self.rpc.request).magic_number, 1138)

    def test_blocking_timeout_as_argument(self) -> None:
        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary(pw_rpc_timeout_s=0.0001)

    def test_blocking_timeout_set_default(self) -> None:
        self._service.SomeUnary.default_timeout_s = 0.0001

        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary()

    def test_nonblocking_duplicate_calls_first_is_cancelled(self) -> None:
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIs(first_call.error, Status.CANCELLED)
        self.assertFalse(second_call.completed())

    def test_nonblocking_exception_in_callback(self) -> None:
        exception = ValueError('something went wrong!')

        self._enqueue_response(1, self.method, Status.OK)

        call = self.rpc.invoke(on_completed=mock.Mock(side_effect=exception))

        with self.assertRaises(RuntimeError) as context:
            call.wait()

        self.assertEqual(context.exception.__cause__, exception)


class ServerStreamingTest(_CallbackClientImplTestBase):
    """Tests for server streaming RPCs."""
    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeServerStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(1, self.method, rep1)
            self._enqueue_server_stream(1, self.method, rep2)
            self._enqueue_response(1, self.method, Status.ABORTED)

            self.assertEqual(
                [rep1, rep2],
                self._service.SomeServerStreaming(magic_number=4).responses)

            self.assertEqual(
                4,
                self._sent_payload(self.method.request_type).magic_number)

    def test_deprecated_packet_format(self) -> None:
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            # The original packet format used RESPONSE packets for the server
            # stream and a SERVER_STREAM_END packet as the last packet. These
            # are converted to SERVER_STREAM packets followed by a RESPONSE.
            self._enqueue_response(1, self.method, payload=rep1)
            self._enqueue_response(1, self.method, payload=rep2)

            self._next_packets.append((packet_pb2.RpcPacket(
                type=packet_pb2.PacketType.DEPRECATED_SERVER_STREAM_END,
                channel_id=1,
                service_id=self.method.service.id,
                method_id=self.method.id,
                status=Status.INVALID_ARGUMENT.value).SerializeToString(),
                                       Status.OK))

            status, replies = self._service.SomeServerStreaming(magic_number=4)
            self.assertEqual([rep1, rep2], replies)
            self.assertIs(status, Status.INVALID_ARGUMENT)

            self.assertEqual(
                4,
                self._sent_payload(self.method.request_type).magic_number)

    def test_nonblocking_call(self) -> None:
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(1, self.method, rep1)
            self._enqueue_server_stream(1, self.method, rep2)
            self._enqueue_response(1, self.method, Status.ABORTED)

            callback = mock.Mock()
            call = self.rpc.invoke(self._request(magic_number=3), callback,
                                   callback)

            callback.assert_has_calls([
                mock.call(call, self.method.response_type(payload='!!!')),
                mock.call(call, self.method.response_type(payload='?')),
                mock.call(call, Status.ABORTED),
            ])

            self.assertEqual(
                3,
                self._sent_payload(self.method.request_type).magic_number)

    def test_open(self) -> None:
        # The channel output raises, so open() must not record any request.
        self.output_exception = IOError('something went wrong sending!')
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(1, self.method, rep1)
            self._enqueue_server_stream(1, self.method, rep2)
            self._enqueue_response(1, self.method, Status.ABORTED)

            callback = mock.Mock()
            call = self.rpc.open(self._request(magic_number=3), callback,
                                 callback)
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            callback.assert_has_calls([
                mock.call(call, self.method.response_type(payload='!!!')),
                mock.call(call, self.method.response_type(payload='?')),
                mock.call(call, Status.ABORTED),
            ])

    def test_nonblocking_cancel(self) -> None:
        resp = self.rpc.method.response_type(payload='!!!')
        self._enqueue_server_stream(1, self.rpc.method, resp)

        callback = mock.Mock()
        call = self.rpc.invoke(self._request(magic_number=3), callback)
        callback.assert_called_once_with(
            call, self.rpc.method.response_type(payload='!!!'))

        callback.reset_mock()

        call.cancel()

        self.assertEqual(self.last_request().type,
                         packet_pb2.PacketType.CLIENT_ERROR)
        self.assertEqual(self.last_request().status, Status.CANCELLED.value)

        # Ensure the RPC can be called after being cancelled.
        self._enqueue_server_stream(1, self.method, resp)
        self._enqueue_response(1, self.method, Status.OK)

        call = self.rpc.invoke(self._request(magic_number=3), callback,
                               callback)

        callback.assert_has_calls([
            mock.call(call, self.method.response_type(payload='!!!')),
            mock.call(call, Status.OK),
        ])

    def test_nonblocking_with_request_args(self) -> None:
        self.rpc.invoke(request_args=dict(magic_number=1138))
        self.assertEqual(
            self._sent_payload(self.rpc.request).magic_number, 1138)

    def test_blocking_timeout(self) -> None:
        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeServerStreaming(pw_rpc_timeout_s=0.0001)

    def test_nonblocking_iteration_timeout(self) -> None:
        call = self._service.SomeServerStreaming.invoke(timeout_s=0.0001)
        with self.assertRaises(callback_client.RpcTimeout):
            for _ in call:
                pass

    def test_nonblocking_duplicate_calls_first_is_cancelled(self) -> None:
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIs(first_call.error, Status.CANCELLED)
        self.assertFalse(second_call.completed())

    def test_nonblocking_iterate_over_count(self) -> None:
        reply = self.method.response_type(payload='!?')

        # Four replies are enqueued and consumed below as 1 + 1 + 2.
        for _ in range(4):
            self._enqueue_server_stream(1, self.method, reply)

        call = self.rpc.invoke()

        self.assertEqual(list(call.get_responses(count=1)), [reply])
        self.assertEqual(next(iter(call)), reply)
        self.assertEqual(list(call.get_responses(count=2)), [reply, reply])

    def test_nonblocking_iterate_after_completed_doesnt_block(self) -> None:
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(1, self.method, reply)
        self._enqueue_response(1, self.method, Status.OK)

        call = self.rpc.invoke()

        self.assertEqual(list(call.get_responses()), [reply])
        self.assertEqual(list(call.get_responses()), [])
        self.assertEqual(list(call), [])


class ClientStreamingTest(_CallbackClientImplTestBase):
    """Tests for client streaming RPCs."""
    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeClientStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        requests = [
            self.method.request_type(magic_number=123),
            self.method.request_type(magic_number=456),
        ]

        # Send after len(requests) and the client stream end packet.
        self.send_responses_after_packets = 3
        response = self.method.response_type(payload='yo')
        self._enqueue_response(1, self.method, Status.OK, response)

        results = self.rpc(requests)
        self.assertIs(results.status, Status.OK)
        self.assertEqual(results.response, response)

    def test_blocking_server_error(self) -> None:
        requests = [self.method.request_type(magic_number=123)]

        # No delay is configured, so the error is processed as soon as the
        # client sends its first packet.
        self._enqueue_error(1, self.method.service, self.method,
                            Status.NOT_FOUND)

        with self.assertRaises(callback_client.RpcError) as context:
            self.rpc(requests)

        self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_call(self) -> None:
        """Tests a successful client streaming RPC ended by the server."""
        payload_1 = self.method.response_type(payload='-_-')

        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()
            self.assertFalse(stream.completed())

            stream.send(magic_number=31)
            # Packet types are plain ints; compare by value, not identity.
            self.assertEqual(packet_pb2.PacketType.CLIENT_STREAM,
                             self.last_request().type)
            self.assertEqual(
                31,
                self._sent_payload(self.method.request_type).magic_number)
            self.assertFalse(stream.completed())

            # Enqueue the server response to be sent after the next message.
            self._enqueue_response(1, self.method, Status.OK, payload_1)

            stream.send(magic_number=32)
            self.assertEqual(packet_pb2.PacketType.CLIENT_STREAM,
                             self.last_request().type)
            self.assertEqual(
                32,
                self._sent_payload(self.method.request_type).magic_number)

            self.assertTrue(stream.completed())
            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)
            self.assertEqual(payload_1, stream.response)

    def test_open(self) -> None:
        # The channel output raises, so open() must not record any request.
        self.output_exception = IOError('something went wrong sending!')
        payload = self.method.response_type(payload='-_-')

        for _ in range(3):
            self._enqueue_response(1, self.method, Status.OK, payload)

            callback = mock.Mock()
            call = self.rpc.open(callback, callback, callback)
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            callback.assert_has_calls([
                mock.call(call, payload),
                mock.call(call, Status.OK),
            ])

    def test_nonblocking_finish(self) -> None:
        """Tests a client streaming RPC ended by the client."""
        payload_1 = self.method.response_type(payload='-_-')

        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()
            self.assertFalse(stream.completed())

            stream.send(magic_number=37)
            # Packet types are plain ints; compare by value, not identity.
            self.assertEqual(packet_pb2.PacketType.CLIENT_STREAM,
                             self.last_request().type)
            self.assertEqual(
                37,
                self._sent_payload(self.method.request_type).magic_number)
            self.assertFalse(stream.completed())

            # Enqueue the server response to be sent after the next message.
            self._enqueue_response(1, self.method, Status.OK, payload_1)

            stream.finish_and_wait()
            self.assertEqual(packet_pb2.PacketType.CLIENT_STREAM_END,
                             self.last_request().type)

            self.assertTrue(stream.completed())
            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)
            self.assertEqual(payload_1, stream.response)

    def test_nonblocking_cancel(self) -> None:
        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()
            stream.send(magic_number=37)

            self.assertTrue(stream.cancel())
            # The packet type and status code are plain ints; compare by
            # value, not identity.
            self.assertEqual(packet_pb2.PacketType.CLIENT_ERROR,
                             self.last_request().type)
            self.assertEqual(Status.CANCELLED.value,
                             self.last_request().status)
            self.assertFalse(stream.cancel())

            self.assertTrue(stream.completed())
            self.assertIs(stream.error, Status.CANCELLED)

    def test_nonblocking_server_error(self) -> None:
        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()

            self._enqueue_error(1, self.method.service, self.method,
                                Status.INVALID_ARGUMENT)
            stream.send(magic_number=2**32 - 1)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_server_error_after_stream_end(self) -> None:
        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()

            # Error will be sent in response to the CLIENT_STREAM_END packet.
            self._enqueue_error(1, self.method.service, self.method,
                                Status.INVALID_ARGUMENT)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_send_after_cancelled(self) -> None:
        call = self._service.SomeClientStreaming.invoke()
        self.assertTrue(call.cancel())

        with self.assertRaises(callback_client.RpcError) as context:
            call.send(payload='hello')

        self.assertIs(context.exception.status, Status.CANCELLED)

    def test_nonblocking_finish_after_completed(self) -> None:
        reply = self.method.response_type(payload='!?')
        self._enqueue_response(1, self.method, Status.UNAVAILABLE, reply)

        call = self.rpc.invoke()
        result = call.finish_and_wait()
        self.assertEqual(result.response, reply)

        # Finishing again returns the same result.
        self.assertEqual(result, call.finish_and_wait())
        self.assertEqual(result, call.finish_and_wait())

    def test_nonblocking_finish_after_error(self) -> None:
        self._enqueue_error(1, self.method.service, self.method,
                            Status.UNAVAILABLE)

        call = self.rpc.invoke()

        for _ in range(3):
            with self.assertRaises(callback_client.RpcError) as context:
                call.finish_and_wait()

            self.assertIs(context.exception.status, Status.UNAVAILABLE)
            self.assertIs(call.error, Status.UNAVAILABLE)
            self.assertIsNone(call.response)

    def test_nonblocking_duplicate_calls_first_is_cancelled(self) -> None:
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIs(first_call.error, Status.CANCELLED)
        self.assertFalse(second_call.completed())


class BidirectionalStreamingTest(_CallbackClientImplTestBase):
    """Tests for bidirectional streaming RPCs."""
    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeBidiStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        requests = [
            self.method.request_type(magic_number=123),
            self.method.request_type(magic_number=456),
        ]

        # Send after len(requests) and the client stream end packet.
        self.send_responses_after_packets = 3
        self._enqueue_response(1, self.method, Status.NOT_FOUND)

        results = self.rpc(requests)
        self.assertIs(results.status, Status.NOT_FOUND)
        self.assertFalse(results.responses)

    def test_blocking_server_error(self) -> None:
        requests = [self.method.request_type(magic_number=123)]

        # No delay is configured, so the error is processed as soon as the
        # client sends its first packet.
        self._enqueue_error(1, self.method.service, self.method,
                            Status.NOT_FOUND)

        with self.assertRaises(callback_client.RpcError) as context:
            self.rpc(requests)

        self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_call(self) -> None:
        """Tests a bidirectional streaming RPC ended by the server."""
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            responses: list = []
            stream = self._service.SomeBidiStreaming.invoke(
                lambda _, res, responses=responses: responses.append(res))
            self.assertFalse(stream.completed())

            stream.send(magic_number=55)
            # Packet types are plain ints; compare by value, not identity.
            self.assertEqual(packet_pb2.PacketType.CLIENT_STREAM,
                             self.last_request().type)
            self.assertEqual(
                55,
                self._sent_payload(self.method.request_type).magic_number)
            self.assertFalse(stream.completed())
            self.assertEqual([], responses)

            self._enqueue_server_stream(1, self.method, rep1)
            self._enqueue_server_stream(1, self.method, rep2)

            stream.send(magic_number=66)
            self.assertEqual(packet_pb2.PacketType.CLIENT_STREAM,
                             self.last_request().type)
            self.assertEqual(
                66,
                self._sent_payload(self.method.request_type).magic_number)
            self.assertFalse(stream.completed())
            self.assertEqual([rep1, rep2], responses)

            self._enqueue_response(1, self.method, Status.OK)

            stream.send(magic_number=77)
            self.assertTrue(stream.completed())
            self.assertEqual([rep1, rep2], responses)

            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)

    def test_open(self) -> None:
        # The channel output raises, so open() must not record any request.
        self.output_exception = IOError('something went wrong sending!')
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(1, self.method, rep1)
            self._enqueue_server_stream(1, self.method, rep2)
            self._enqueue_response(1, self.method, Status.OK)

            callback = mock.Mock()
            call = self.rpc.open(callback, callback, callback)
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            callback.assert_has_calls([
                mock.call(call, self.method.response_type(payload='!!!')),
                mock.call(call, self.method.response_type(payload='?')),
                mock.call(call, Status.OK),
            ])

    @mock.patch('pw_rpc.callback_client.call.Call._default_response')
    def test_nonblocking(self, callback) -> None:
        """Tests that Call._default_response gets replies by default."""
        reply = self.method.response_type(payload='This is the payload!')
        self._enqueue_server_stream(1, self.method, reply)

        # No callbacks are passed, so the patched default must be invoked.
        self._service.SomeBidiStreaming.invoke()

        callback.assert_called_once_with(mock.ANY, reply)

    def test_nonblocking_server_error(self) -> None:
        rep1 = self.method.response_type(payload='!!!')

        for _ in range(3):
            responses: list = []
            stream = self._service.SomeBidiStreaming.invoke(
                lambda _, res, responses=responses: responses.append(res))
            self.assertFalse(stream.completed())

            self._enqueue_server_stream(1, self.method, rep1)

            stream.send(magic_number=55)
            self.assertFalse(stream.completed())
            self.assertEqual([rep1], responses)

            self._enqueue_error(1, self.method.service, self.method,
                                Status.OUT_OF_RANGE)

            stream.send(magic_number=99999)
            self.assertTrue(stream.completed())
            self.assertEqual([rep1], responses)

            self.assertIsNone(stream.status)
            self.assertIs(Status.OUT_OF_RANGE, stream.error)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()
            self.assertIs(context.exception.status, Status.OUT_OF_RANGE)

    def test_nonblocking_server_error_after_stream_end(self) -> None:
        for _ in range(3):
            stream = self._service.SomeBidiStreaming.invoke()

            # Error will be sent in response to the CLIENT_STREAM_END packet.
            self._enqueue_error(1, self.method.service, self.method,
                                Status.INVALID_ARGUMENT)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_send_after_cancelled(self) -> None:
        call = self._service.SomeBidiStreaming.invoke()
        self.assertTrue(call.cancel())

        with self.assertRaises(callback_client.RpcError) as context:
            call.send(payload='hello')

        self.assertIs(context.exception.status, Status.CANCELLED)

    def test_nonblocking_finish_after_completed(self) -> None:
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(1, self.method, reply)
        self._enqueue_response(1, self.method, Status.UNAVAILABLE)

        call = self.rpc.invoke()
        result = call.finish_and_wait()
        self.assertEqual(result.responses, [reply])

        # Finishing again returns the same result.
        self.assertEqual(result, call.finish_and_wait())
        self.assertEqual(result, call.finish_and_wait())

    def test_nonblocking_finish_after_error(self) -> None:
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(1, self.method, reply)
        self._enqueue_error(1, self.method.service, self.method,
                            Status.UNAVAILABLE)

        call = self.rpc.invoke()

        for _ in range(3):
            with self.assertRaises(callback_client.RpcError) as context:
                call.finish_and_wait()

            self.assertIs(context.exception.status, Status.UNAVAILABLE)
            self.assertIs(call.error, Status.UNAVAILABLE)
            self.assertEqual(call.responses, [reply])

    def test_nonblocking_duplicate_calls_first_is_cancelled(self) -> None:
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIs(first_call.error, Status.CANCELLED)
        self.assertFalse(second_call.completed())


if __name__ == '__main__':
    unittest.main()