#!/usr/bin/env python3
# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests using the callback client for pw_rpc."""

import unittest
from unittest import mock
from typing import Any, List, Optional, Tuple

from pw_protobuf_compiler import python_protos
from pw_status import Status

from pw_rpc import callback_client, client, packets
from pw_rpc.internal import packet_pb2

TEST_PROTO_1 = """\
syntax = "proto3";

package pw.test1;

message SomeMessage {
  uint32 magic_number = 1;
}

message AnotherMessage {
  enum Result {
    FAILED = 0;
    FAILED_MISERABLY = 1;
    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
  }

  Result result = 1;
  string payload = 2;
}

service PublicService {
  rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
  rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
  rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
  rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""


def _message_bytes(msg) -> bytes:
    """Returns msg unchanged if it is bytes; otherwise serializes it."""
    return msg if isinstance(msg, bytes) else msg.SerializeToString()


class _CallbackClientImplTestBase(unittest.TestCase):
    """Supports writing tests that require responses from an RPC server."""

    def setUp(self) -> None:
        self._protos = python_protos.Library.from_strings(TEST_PROTO_1)
        self._request = self._protos.packages.pw.test1.SomeMessage

        self._client = client.Client.from_modules(
            callback_client.Impl(),
            [client.Channel(1, self._handle_packet)],
            self._protos.modules(),
        )
        self._service = self._client.channel(1).rpcs.pw.test1.PublicService

        # Decoded packets the client has sent through channel 1.
        self.requests: List[packet_pb2.RpcPacket] = []
        # Serialized server packets paired with the Status that
        # Client.process_packet is expected to return for each one.
        self._next_packets: List[Tuple[bytes, Status]] = []
        # How many client packets must be sent before the queued server
        # packets are processed (counted down in _handle_packet).
        self.send_responses_after_packets: float = 1

        # When set, _handle_packet raises this instead of recording packets.
        self.output_exception: Optional[Exception] = None

    def last_request(self) -> packet_pb2.RpcPacket:
        """Returns the most recent packet sent by the client."""
        assert self.requests
        return self.requests[-1]

    def _enqueue_response(
        self,
        channel_id: int,
        method=None,
        status: Status = Status.OK,
        payload=b'',
        *,
        ids: Optional[Tuple[int, int]] = None,
        process_status=Status.OK,
    ) -> None:
        """Queues a RESPONSE packet for the client to process.

        Identify the RPC either with ``method`` or with ``ids`` as a
        ``(service_id, method_id)`` pair, but not both. ``payload`` may be raw
        bytes or a protobuf message. ``process_status`` is the Status that
        ``Client.process_packet`` is expected to return for this packet.
        """
        if method:
            assert ids is None
            service_id, method_id = method.service.id, method.id
        else:
            assert ids is not None and method is None
            service_id, method_id = ids

        self._next_packets.append(
            (
                packet_pb2.RpcPacket(
                    type=packet_pb2.PacketType.RESPONSE,
                    channel_id=channel_id,
                    service_id=service_id,
                    method_id=method_id,
                    status=status.value,
                    payload=_message_bytes(payload),
                ).SerializeToString(),
                process_status,
            )
        )

    def _enqueue_server_stream(
        self, channel_id: int, method, response, process_status=Status.OK
    ) -> None:
        """Queues a SERVER_STREAM packet carrying ``response`` for ``method``.

        ``process_status`` is the Status that ``Client.process_packet`` is
        expected to return for this packet.
        """
        self._next_packets.append(
            (
                packet_pb2.RpcPacket(
                    type=packet_pb2.PacketType.SERVER_STREAM,
                    channel_id=channel_id,
                    service_id=method.service.id,
                    method_id=method.id,
                    payload=_message_bytes(response),
                ).SerializeToString(),
                process_status,
            )
        )

    def _enqueue_error(
        self,
        channel_id: int,
        service,
        method,
        status: Status,
        process_status=Status.OK,
    ) -> None:
        """Queues a SERVER_ERROR packet with the given ``status``.

        ``service`` and ``method`` may each be an object with an ``id``
        attribute or a raw integer ID, which allows addressing unknown RPCs.
        """
        self._next_packets.append(
            (
                packet_pb2.RpcPacket(
                    type=packet_pb2.PacketType.SERVER_ERROR,
                    channel_id=channel_id,
                    service_id=(
                        service if isinstance(service, int) else service.id
                    ),
                    method_id=method if isinstance(method, int) else method.id,
                    status=status.value,
                ).SerializeToString(),
                process_status,
            )
        )

    def _handle_packet(self, data: bytes) -> None:
        """Channel output for the client under test.

        Records each sent packet. Once ``send_responses_after_packets``
        packets have been sent, processes all queued server packets. Raises
        ``output_exception`` instead if it is set.
        """
        if self.output_exception:
            raise self.output_exception  # pylint: disable=raising-bad-type

        self.requests.append(packets.decode(data))

        if self.send_responses_after_packets > 1:
            self.send_responses_after_packets -= 1
            return

        self._process_enqueued_packets()

    def _process_enqueued_packets(self) -> None:
        """Feeds every queued packet to the client, then clears the queue.

        Each packet's ``process_packet`` result is checked against the status
        recorded when the packet was queued.
        """
        # Set send_responses_after_packets to infinity to prevent potential
        # infinite recursion when a packet causes another packet to send.
        send_after_count = self.send_responses_after_packets
        self.send_responses_after_packets = float('inf')

        for packet, status in self._next_packets:
            self.assertIs(status, self._client.process_packet(packet))

        self._next_packets.clear()
        self.send_responses_after_packets = send_after_count

    def _sent_payload(self, message_type: type) -> Any:
        """Decodes the payload of the last sent packet as ``message_type``."""
        message = message_type()
        message.ParseFromString(self.last_request().payload)
        return message


class CallbackClientImplTest(_CallbackClientImplTestBase):
    """Tests the callback_client.Impl client implementation."""

    def test_callback_exceptions_suppressed(self) -> None:
        stub = self._service.SomeUnary

        self._enqueue_response(1, stub.method)
        exception_msg = 'YOU BROKE IT O-]-<'

        with self.assertLogs(callback_client.__package__, 'ERROR') as logs:
            stub.invoke(
                self._request(), mock.Mock(side_effect=Exception(exception_msg))
            )

        self.assertIn(exception_msg, ''.join(logs.output))

        # Make sure we can still invoke the RPC.
        self._enqueue_response(1, stub.method, Status.UNKNOWN)
        status, _ = stub()
        self.assertIs(status, Status.UNKNOWN)

    def test_ignore_bad_packets_with_pending_rpc(self) -> None:
        method = self._service.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_response(999, method, process_status=Status.NOT_FOUND)
        # Bad service
        self._enqueue_response(
            1, ids=(999, method.id), process_status=Status.OK
        )
        # Bad method
        self._enqueue_response(
            1, ids=(service_id, 999), process_status=Status.OK
        )
        # For RPC not pending (is Status.OK because the packet is processed)
        self._enqueue_response(
            1,
            ids=(service_id, self._service.SomeBidiStreaming.method.id),
            process_status=Status.OK,
        )

        self._enqueue_response(1, method, process_status=Status.OK)

        status, response = self._service.SomeUnary(magic_number=6)
        self.assertIs(Status.OK, status)
        self.assertEqual('', response.payload)

    def test_server_error_for_unknown_call_sends_no_errors(self) -> None:
        method = self._service.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_error(
            999,
            service_id,
            method,
            Status.NOT_FOUND,
            process_status=Status.NOT_FOUND,
        )
        # Bad service
        self._enqueue_error(1, 999, method.id, Status.INVALID_ARGUMENT)
        # Bad method
        self._enqueue_error(1, service_id, 999, Status.INVALID_ARGUMENT)
        # For RPC not pending
        self._enqueue_error(
            1,
            service_id,
            self._service.SomeBidiStreaming.method.id,
            Status.NOT_FOUND,
        )

        self._process_enqueued_packets()

        self.assertEqual(self.requests, [])

    def test_exception_if_payload_fails_to_decode(self) -> None:
        method = self._service.SomeUnary.method

        self._enqueue_response(
            1, method, Status.OK, b'INVALID DATA!!!', process_status=Status.OK
        )

        with self.assertRaises(callback_client.RpcError) as context:
            self._service.SomeUnary(magic_number=6)

        self.assertIs(context.exception.status, Status.DATA_LOSS)

    def test_rpc_help_contains_method_name(self) -> None:
        rpc = self._service.SomeUnary
        self.assertIn(rpc.method.full_name, rpc.help())

    def test_default_timeouts_set_on_impl(self) -> None:
        impl = callback_client.Impl(None, 1.5)

        self.assertIsNone(impl.default_unary_timeout_s)
        self.assertEqual(impl.default_stream_timeout_s, 1.5)

    def test_default_timeouts_set_for_all_rpcs(self) -> None:
        rpc_client = client.Client.from_modules(
            callback_client.Impl(99, 100),
            [client.Channel(1, lambda *a, **b: None)],
            self._protos.modules(),
        )
        rpcs = rpc_client.channel(1).rpcs

        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeUnary.default_timeout_s, 99
        )
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeServerStreaming.default_timeout_s,
            100,
        )
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeClientStreaming.default_timeout_s,
            99,
        )
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeBidiStreaming.default_timeout_s, 100
        )

    def test_rpc_provides_request_type(self) -> None:
        self.assertIs(
            self._service.SomeUnary.request,
            self._service.SomeUnary.method.request_type,
        )

    def test_rpc_provides_response_type(self) -> None:
        # Previously a copy of the request-type test; check the response side.
        self.assertIs(
            self._service.SomeUnary.response,
            self._service.SomeUnary.method.response_type,
        )


class UnaryTest(_CallbackClientImplTestBase):
    """Tests for invoking a unary RPC."""

    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeUnary
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        for _ in range(3):
            self._enqueue_response(
                1,
                self.method,
                Status.ABORTED,
                self.method.response_type(payload='0_o'),
            )

            status, response = self._service.SomeUnary(
                self.method.request_type(magic_number=6)
            )

            self.assertEqual(
                6, self._sent_payload(self.method.request_type).magic_number
            )

            self.assertIs(Status.ABORTED, status)
            self.assertEqual('0_o', response.payload)

    def test_nonblocking_call(self) -> None:
        for _ in range(3):
            self._enqueue_response(
                1,
                self.method,
                Status.ABORTED,
                self.method.response_type(payload='0_o'),
            )

            callback = mock.Mock()
            call = self.rpc.invoke(
                self._request(magic_number=5), callback, callback
            )

            callback.assert_has_calls(
                [
                    mock.call(call, self.method.response_type(payload='0_o')),
                    mock.call(call, Status.ABORTED),
                ]
            )

            self.assertEqual(
                5, self._sent_payload(self.method.request_type).magic_number
            )

    def test_open(self) -> None:
        self.output_exception = IOError('something went wrong sending!')

        for _ in range(3):
            self._enqueue_response(
                1,
                self.method,
                Status.ABORTED,
                self.method.response_type(payload='0_o'),
            )

            callback = mock.Mock()
            call = self.rpc.open(
                self._request(magic_number=5), callback, callback
            )
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            callback.assert_has_calls(
                [
                    mock.call(call, self.method.response_type(payload='0_o')),
                    mock.call(call, Status.ABORTED),
                ]
            )

    def test_blocking_server_error(self) -> None:
        for _ in range(3):
            self._enqueue_error(
                1, self.method.service, self.method, Status.NOT_FOUND
            )

            with self.assertRaises(callback_client.RpcError) as context:
                self._service.SomeUnary(
                    self.method.request_type(magic_number=6)
                )

            self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_cancel(self) -> None:
        callback = mock.Mock()

        for _ in range(3):
            call = self._service.SomeUnary.invoke(
                self._request(magic_number=55), callback
            )

            self.assertGreater(len(self.requests), 0)
            self.requests.clear()

            self.assertTrue(call.cancel())
            self.assertFalse(call.cancel())  # Already cancelled, returns False

            self.assertEqual(
                self.last_request().type, packet_pb2.PacketType.CLIENT_ERROR
            )
            self.assertEqual(self.last_request().status, Status.CANCELLED.value)

        callback.assert_not_called()

    def test_nonblocking_with_request_args(self) -> None:
        self.rpc.invoke(request_args=dict(magic_number=1138))
        self.assertEqual(
            self._sent_payload(self.rpc.request).magic_number, 1138
        )

    def test_blocking_timeout_as_argument(self) -> None:
        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary(pw_rpc_timeout_s=0.0001)

    def test_blocking_timeout_set_default(self) -> None:
        self._service.SomeUnary.default_timeout_s = 0.0001

        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary()

    def test_nonblocking_duplicate_calls_first_is_cancelled(self) -> None:
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIs(first_call.error, Status.CANCELLED)
        self.assertFalse(second_call.completed())

    def test_nonblocking_exception_in_callback(self) -> None:
        exception = ValueError('something went wrong!')

        self._enqueue_response(1, self.method, Status.OK)

        call = self.rpc.invoke(on_completed=mock.Mock(side_effect=exception))

        with self.assertRaises(RuntimeError) as context:
            call.wait()

        self.assertEqual(context.exception.__cause__, exception)

    def test_unary_response(self) -> None:
        proto = self._protos.packages.pw.test1.SomeMessage(magic_number=123)
        self.assertEqual(
            repr(callback_client.UnaryResponse(Status.ABORTED, proto)),
            '(Status.ABORTED, pw.test1.SomeMessage(magic_number=123))',
        )
        self.assertEqual(
            repr(callback_client.UnaryResponse(Status.OK, None)),
            '(Status.OK, None)',
        )


class ServerStreamingTest(_CallbackClientImplTestBase):
    """Tests for server streaming RPCs."""

    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeServerStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(1, self.method, rep1)
            self._enqueue_server_stream(1, self.method, rep2)
            self._enqueue_response(1, self.method, Status.ABORTED)

            self.assertEqual(
                [rep1, rep2],
                self._service.SomeServerStreaming(magic_number=4).responses,
            )

            self.assertEqual(
                4, self._sent_payload(self.method.request_type).magic_number
            )

    def test_nonblocking_call(self) -> None:
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(1, self.method, rep1)
            self._enqueue_server_stream(1, self.method, rep2)
            self._enqueue_response(1, self.method, Status.ABORTED)

            callback = mock.Mock()
            call = self.rpc.invoke(
                self._request(magic_number=3), callback, callback
            )

            callback.assert_has_calls(
                [
                    mock.call(call, self.method.response_type(payload='!!!')),
                    mock.call(call, self.method.response_type(payload='?')),
                    mock.call(call, Status.ABORTED),
                ]
            )

            self.assertEqual(
                3, self._sent_payload(self.method.request_type).magic_number
            )

    def test_open(self) -> None:
        self.output_exception = IOError('something went wrong sending!')
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(1, self.method, rep1)
            self._enqueue_server_stream(1, self.method, rep2)
            self._enqueue_response(1, self.method, Status.ABORTED)

            callback = mock.Mock()
            call = self.rpc.open(
                self._request(magic_number=3), callback, callback
            )
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            callback.assert_has_calls(
                [
                    mock.call(call, self.method.response_type(payload='!!!')),
                    mock.call(call, self.method.response_type(payload='?')),
                    mock.call(call, Status.ABORTED),
                ]
            )

    def test_nonblocking_cancel(self) -> None:
        resp = self.rpc.method.response_type(payload='!!!')
        self._enqueue_server_stream(1, self.rpc.method, resp)

        callback = mock.Mock()
        call = self.rpc.invoke(self._request(magic_number=3), callback)
        callback.assert_called_once_with(
            call, self.rpc.method.response_type(payload='!!!')
        )

        callback.reset_mock()

        call.cancel()

        self.assertEqual(
            self.last_request().type, packet_pb2.PacketType.CLIENT_ERROR
        )
        self.assertEqual(self.last_request().status, Status.CANCELLED.value)

        # Ensure the RPC can be called after being cancelled.
        self._enqueue_server_stream(1, self.method, resp)
        self._enqueue_response(1, self.method, Status.OK)

        call = self.rpc.invoke(
            self._request(magic_number=3), callback, callback
        )

        callback.assert_has_calls(
            [
                mock.call(call, self.method.response_type(payload='!!!')),
                mock.call(call, Status.OK),
            ]
        )

    def test_nonblocking_with_request_args(self) -> None:
        self.rpc.invoke(request_args=dict(magic_number=1138))
        self.assertEqual(
            self._sent_payload(self.rpc.request).magic_number, 1138
        )

    def test_blocking_timeout(self) -> None:
        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeServerStreaming(pw_rpc_timeout_s=0.0001)

    def test_nonblocking_iteration_timeout(self) -> None:
        call = self._service.SomeServerStreaming.invoke(timeout_s=0.0001)
        with self.assertRaises(callback_client.RpcTimeout):
            for _ in call:
                pass

    def test_nonblocking_duplicate_calls_first_is_cancelled(self) -> None:
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIs(first_call.error, Status.CANCELLED)
        self.assertFalse(second_call.completed())

    def test_nonblocking_iterate_over_count(self) -> None:
        reply = self.method.response_type(payload='!?')

        for _ in range(4):
            self._enqueue_server_stream(1, self.method, reply)

        call = self.rpc.invoke()

        self.assertEqual(list(call.get_responses(count=1)), [reply])
        self.assertEqual(next(iter(call)), reply)
        self.assertEqual(list(call.get_responses(count=2)), [reply, reply])

    def test_nonblocking_iterate_after_completed_doesnt_block(self) -> None:
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(1, self.method, reply)
        self._enqueue_response(1, self.method, Status.OK)

        call = self.rpc.invoke()

        self.assertEqual(list(call.get_responses()), [reply])
        self.assertEqual(list(call.get_responses()), [])
        self.assertEqual(list(call), [])


class ClientStreamingTest(_CallbackClientImplTestBase):
    """Tests for client streaming RPCs."""

    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeClientStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        requests = [
            self.method.request_type(magic_number=123),
            self.method.request_type(magic_number=456),
        ]

        # Send after len(requests) and the client stream end packet.
        self.send_responses_after_packets = 3
        response = self.method.response_type(payload='yo')
        self._enqueue_response(1, self.method, Status.OK, response)

        results = self.rpc(requests)
        self.assertIs(results.status, Status.OK)
        self.assertEqual(results.response, response)

    def test_blocking_server_error(self) -> None:
        requests = [self.method.request_type(magic_number=123)]

        # send_responses_after_packets keeps its default of 1, so the error is
        # processed as soon as the client sends its first packet.
        self._enqueue_error(
            1, self.method.service, self.method, Status.NOT_FOUND
        )

        with self.assertRaises(callback_client.RpcError) as context:
            self.rpc(requests)

        self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_call(self) -> None:
        """Tests a successful client streaming RPC ended by the server."""
        payload_1 = self.method.response_type(payload='-_-')

        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()
            self.assertFalse(stream.completed())

            stream.send(magic_number=31)
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                31, self._sent_payload(self.method.request_type).magic_number
            )
            self.assertFalse(stream.completed())

            # Enqueue the server response to be sent after the next message.
            self._enqueue_response(1, self.method, Status.OK, payload_1)

            stream.send(magic_number=32)
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                32, self._sent_payload(self.method.request_type).magic_number
            )

            self.assertTrue(stream.completed())
            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)
            self.assertEqual(payload_1, stream.response)

    def test_open(self) -> None:
        self.output_exception = IOError('something went wrong sending!')
        payload = self.method.response_type(payload='-_-')

        for _ in range(3):
            self._enqueue_response(1, self.method, Status.OK, payload)

            callback = mock.Mock()
            call = self.rpc.open(callback, callback, callback)
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            callback.assert_has_calls(
                [
                    mock.call(call, payload),
                    mock.call(call, Status.OK),
                ]
            )

    def test_nonblocking_finish(self) -> None:
        """Tests a client streaming RPC ended by the client."""
        payload_1 = self.method.response_type(payload='-_-')

        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()
            self.assertFalse(stream.completed())

            stream.send(magic_number=37)
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                37, self._sent_payload(self.method.request_type).magic_number
            )
            self.assertFalse(stream.completed())

            # Enqueue the server response to be sent after the next message.
            self._enqueue_response(1, self.method, Status.OK, payload_1)

            stream.finish_and_wait()
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_STREAM_END,
                self.last_request().type,
            )

            self.assertTrue(stream.completed())
            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)
            self.assertEqual(payload_1, stream.response)

    def test_nonblocking_cancel(self) -> None:
        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()
            stream.send(magic_number=37)

            self.assertTrue(stream.cancel())
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_ERROR, self.last_request().type
            )
            self.assertEqual(Status.CANCELLED.value, self.last_request().status)
            self.assertFalse(stream.cancel())

            self.assertTrue(stream.completed())
            self.assertIs(stream.error, Status.CANCELLED)

    def test_nonblocking_server_error(self) -> None:
        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()

            self._enqueue_error(
                1, self.method.service, self.method, Status.INVALID_ARGUMENT
            )
            stream.send(magic_number=2**32 - 1)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_server_error_after_stream_end(self) -> None:
        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()

            # Error will be sent in response to the CLIENT_STREAM_END packet.
            self._enqueue_error(
                1, self.method.service, self.method, Status.INVALID_ARGUMENT
            )

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_send_after_cancelled(self) -> None:
        call = self._service.SomeClientStreaming.invoke()
        self.assertTrue(call.cancel())

        with self.assertRaises(callback_client.RpcError) as context:
            call.send(payload='hello')

        self.assertIs(context.exception.status, Status.CANCELLED)

    def test_nonblocking_finish_after_completed(self) -> None:
        reply = self.method.response_type(payload='!?')
        self._enqueue_response(1, self.method, Status.UNAVAILABLE, reply)

        call = self.rpc.invoke()
        result = call.finish_and_wait()
        self.assertEqual(result.response, reply)

        self.assertEqual(result, call.finish_and_wait())
        self.assertEqual(result, call.finish_and_wait())

    def test_nonblocking_finish_after_error(self) -> None:
        self._enqueue_error(
            1, self.method.service, self.method, Status.UNAVAILABLE
        )

        call = self.rpc.invoke()

        for _ in range(3):
            with self.assertRaises(callback_client.RpcError) as context:
                call.finish_and_wait()

            self.assertIs(context.exception.status, Status.UNAVAILABLE)
            self.assertIs(call.error, Status.UNAVAILABLE)
            self.assertIsNone(call.response)

    def test_nonblocking_duplicate_calls_first_is_cancelled(self) -> None:
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIs(first_call.error, Status.CANCELLED)
        self.assertFalse(second_call.completed())


class BidirectionalStreamingTest(_CallbackClientImplTestBase):
    """Tests for bidirectional streaming RPCs."""

    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeBidiStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        requests = [
            self.method.request_type(magic_number=123),
            self.method.request_type(magic_number=456),
        ]

        # Send after len(requests) and the client stream end packet.
        self.send_responses_after_packets = 3
        self._enqueue_response(1, self.method, Status.NOT_FOUND)

        results = self.rpc(requests)
        self.assertIs(results.status, Status.NOT_FOUND)
        self.assertFalse(results.responses)

    def test_blocking_server_error(self) -> None:
        requests = [self.method.request_type(magic_number=123)]

        # send_responses_after_packets keeps its default of 1, so the error is
        # processed as soon as the client sends its first packet.
        self._enqueue_error(
            1, self.method.service, self.method, Status.NOT_FOUND
        )

        with self.assertRaises(callback_client.RpcError) as context:
            self.rpc(requests)

        self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_call(self) -> None:
        """Tests a bidirectional streaming RPC ended by the server."""
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            responses: list = []
            stream = self._service.SomeBidiStreaming.invoke(
                lambda _, res, responses=responses: responses.append(res)
            )
            self.assertFalse(stream.completed())

            stream.send(magic_number=55)
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                55, self._sent_payload(self.method.request_type).magic_number
            )
            self.assertFalse(stream.completed())
            self.assertEqual([], responses)

            self._enqueue_server_stream(1, self.method, rep1)
            self._enqueue_server_stream(1, self.method, rep2)

            stream.send(magic_number=66)
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                66, self._sent_payload(self.method.request_type).magic_number
            )
            self.assertFalse(stream.completed())
            self.assertEqual([rep1, rep2], responses)

            self._enqueue_response(1, self.method, Status.OK)

            stream.send(magic_number=77)
            self.assertTrue(stream.completed())
            self.assertEqual([rep1, rep2], responses)

            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)

    def test_open(self) -> None:
        self.output_exception = IOError('something went wrong sending!')
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(1, self.method, rep1)
            self._enqueue_server_stream(1, self.method, rep2)
            self._enqueue_response(1, self.method, Status.OK)

            callback = mock.Mock()
            call = self.rpc.open(callback, callback, callback)
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            callback.assert_has_calls(
                [
                    mock.call(call, self.method.response_type(payload='!!!')),
                    mock.call(call, self.method.response_type(payload='?')),
                    mock.call(call, Status.OK),
                ]
            )

    @mock.patch('pw_rpc.callback_client.call.Call._default_response')
    def test_nonblocking(self, callback) -> None:
        """Tests that responses reach the default handler without callbacks.

        invoke() is called with no callbacks, so the streamed response must be
        delivered to the (patched) Call._default_response.
        """
        reply = self.method.response_type(payload='This is the payload!')
        self._enqueue_server_stream(1, self.method, reply)

        self._service.SomeBidiStreaming.invoke()

        callback.assert_called_once_with(mock.ANY, reply)

    def test_nonblocking_server_error(self) -> None:
        rep1 = self.method.response_type(payload='!!!')

        for _ in range(3):
            responses: list = []
            stream = self._service.SomeBidiStreaming.invoke(
                lambda _, res, responses=responses: responses.append(res)
            )
            self.assertFalse(stream.completed())

            self._enqueue_server_stream(1, self.method, rep1)

            stream.send(magic_number=55)
            self.assertFalse(stream.completed())
            self.assertEqual([rep1], responses)

            self._enqueue_error(
                1, self.method.service, self.method, Status.OUT_OF_RANGE
            )

            stream.send(magic_number=99999)
            self.assertTrue(stream.completed())
            self.assertEqual([rep1], responses)

            self.assertIsNone(stream.status)
            self.assertIs(Status.OUT_OF_RANGE, stream.error)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            # Checked outside the context manager so it always executes.
            self.assertIs(context.exception.status, Status.OUT_OF_RANGE)

    def test_nonblocking_server_error_after_stream_end(self) -> None:
        for _ in range(3):
            stream = self._service.SomeBidiStreaming.invoke()

            # Error will be sent in response to the CLIENT_STREAM_END packet.
            self._enqueue_error(
                1, self.method.service, self.method, Status.INVALID_ARGUMENT
            )

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_send_after_cancelled(self) -> None:
        call = self._service.SomeBidiStreaming.invoke()
        self.assertTrue(call.cancel())

        with self.assertRaises(callback_client.RpcError) as context:
            call.send(payload='hello')

        self.assertIs(context.exception.status, Status.CANCELLED)

    def test_nonblocking_finish_after_completed(self) -> None:
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(1, self.method, reply)
        self._enqueue_response(1, self.method, Status.UNAVAILABLE)

        call = self.rpc.invoke()
        result = call.finish_and_wait()
        self.assertEqual(result.responses, [reply])

        self.assertEqual(result, call.finish_and_wait())
        self.assertEqual(result, call.finish_and_wait())

    def test_nonblocking_finish_after_error(self) -> None:
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(1, self.method, reply)
        self._enqueue_error(
            1, self.method.service, self.method, Status.UNAVAILABLE
        )

        call = self.rpc.invoke()

        for _ in range(3):
            with self.assertRaises(callback_client.RpcError) as context:
                call.finish_and_wait()

            self.assertIs(context.exception.status, Status.UNAVAILABLE)
            self.assertIs(call.error, Status.UNAVAILABLE)
            self.assertEqual(call.responses, [reply])

    def test_nonblocking_duplicate_calls_first_is_cancelled(self) -> None:
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIs(first_call.error, Status.CANCELLED)
        self.assertFalse(second_call.completed())

    def test_stream_response(self) -> None:
        proto = self._protos.packages.pw.test1.SomeMessage(magic_number=123)
        self.assertEqual(
            repr(callback_client.StreamResponse(Status.ABORTED, [proto] * 2)),
            '(Status.ABORTED, [pw.test1.SomeMessage(magic_number=123), '
            'pw.test1.SomeMessage(magic_number=123)])',
        )
        self.assertEqual(
            repr(callback_client.StreamResponse(Status.OK, [])),
            '(Status.OK, [])',
        )


if __name__ == '__main__':
    unittest.main()