#!/usr/bin/env python3
# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests using the callback client for pw_rpc."""

import unittest
from unittest import mock
from typing import Any

from pw_protobuf_compiler import python_protos
from pw_status import Status

from pw_rpc import callback_client, client, descriptors, packets
from pw_rpc.internal import packet_pb2

TEST_PROTO_1 = """\
syntax = "proto3";

package pw.test1;

message SomeMessage {
  uint32 magic_number = 1;
}

message AnotherMessage {
  enum Result {
    FAILED = 0;
    FAILED_MISERABLY = 1;
    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
  }

  Result result = 1;
  string payload = 2;
}

service PublicService {
  rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
  rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
  rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
  rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""

CLIENT_CHANNEL_ID: int = 489


def _message_bytes(msg: Any) -> bytes:
    """Returns msg as-is if it is bytes, otherwise msg.SerializeToString().

    Lets the enqueue helpers accept either raw bytes (e.g. deliberately
    malformed payloads) or message objects that provide SerializeToString().
    """
    return msg if isinstance(msg, bytes) else msg.SerializeToString()


class _CallbackClientImplTestBase(unittest.TestCase):
    """Supports writing tests that require responses from an RPC server.

    Every packet the client sends through the test channel is decoded and
    recorded in ``self.requests``. Server packets are queued with the
    ``_enqueue_*`` helpers and fed into the client either automatically, once
    ``send_responses_after_packets`` outgoing packets have been observed, or
    explicitly through ``_process_enqueued_packets``.
    """

    def setUp(self) -> None:
        self._protos = python_protos.Library.from_strings(TEST_PROTO_1)
        self._request = self._protos.packages.pw.test1.SomeMessage

        self._client = client.Client.from_modules(
            callback_client.Impl(),
            [client.Channel(CLIENT_CHANNEL_ID, self._handle_packet)],
            self._protos.modules(),
        )
        self._service = self._client.channel(
            CLIENT_CHANNEL_ID
        ).rpcs.pw.test1.PublicService

        # Decoded packets sent by the client, in send order.
        self.requests: list[packet_pb2.RpcPacket] = []
        # Serialized server packets to deliver, each paired with the status
        # that client.process_packet() is expected to return for it.
        self._next_packets: list[tuple[bytes, Status]] = []
        # How many outgoing packets to observe before delivering the queue.
        # Typed float so it can temporarily be set to infinity.
        self.send_responses_after_packets: float = 1

        # When set, _handle_packet raises this instead of recording a packet.
        self.output_exception: Exception | None = None

    def last_request(self) -> packet_pb2.RpcPacket:
        """Returns the most recent packet sent by the client."""
        assert self.requests
        return self.requests[-1]

    def _enqueue_response(
        self,
        channel_id: int = CLIENT_CHANNEL_ID,
        method: descriptors.Method | None = None,
        status: Status = Status.OK,
        payload: Any = b'',
        *,
        ids: tuple[int, int] | None = None,
        process_status: Status = Status.OK,
        call_id: int = client.OPEN_CALL_ID,
    ) -> None:
        """Queues a RESPONSE packet for delivery to the client.

        Exactly one of ``method`` or ``ids`` (a ``(service_id, method_id)``
        pair) must be given; ``ids`` allows addressing unknown services or
        methods. ``payload`` may be raw bytes or a message object.
        ``process_status`` is the status process_packet() must return.
        """
        if method is not None:
            assert ids is None
            service_id, method_id = method.service.id, method.id
        else:
            assert ids is not None
            service_id, method_id = ids

        self._next_packets.append(
            (
                packet_pb2.RpcPacket(
                    type=packet_pb2.PacketType.RESPONSE,
                    channel_id=channel_id,
                    service_id=service_id,
                    method_id=method_id,
                    call_id=call_id,
                    status=status.value,
                    payload=_message_bytes(payload),
                ).SerializeToString(),
                process_status,
            )
        )

    def _enqueue_server_stream(
        self,
        channel_id: int,
        method: descriptors.Method,
        response: Any,
        process_status: Status = Status.OK,
        call_id: int = client.OPEN_CALL_ID,
    ) -> None:
        """Queues a SERVER_STREAM packet carrying ``response``.

        ``response`` may be raw bytes or a message object.
        """
        self._next_packets.append(
            (
                packet_pb2.RpcPacket(
                    type=packet_pb2.PacketType.SERVER_STREAM,
                    channel_id=channel_id,
                    service_id=method.service.id,
                    method_id=method.id,
                    call_id=call_id,
                    payload=_message_bytes(response),
                ).SerializeToString(),
                process_status,
            )
        )

    def _enqueue_error(
        self,
        channel_id: int,
        service: int | descriptors.Service,
        method: int | descriptors.Method,
        status: Status,
        process_status: Status = Status.OK,
        call_id: int = client.OPEN_CALL_ID,
    ) -> None:
        """Queues a SERVER_ERROR packet with the given ``status``.

        ``service`` and ``method`` may be descriptors or raw integer IDs; raw
        IDs allow addressing services or methods that do not exist.
        """
        self._next_packets.append(
            (
                packet_pb2.RpcPacket(
                    type=packet_pb2.PacketType.SERVER_ERROR,
                    channel_id=channel_id,
                    service_id=service
                    if isinstance(service, int)
                    else service.id,
                    method_id=method if isinstance(method, int) else method.id,
                    call_id=call_id,
                    status=status.value,
                ).SerializeToString(),
                process_status,
            )
        )

    def _handle_packet(self, data: bytes) -> None:
        """Channel output: records sent packets and triggers queued replies."""
        # Checked before recording so tests of rpc.open() can assert that
        # nothing was sent.
        if self.output_exception:
            raise self.output_exception  # pylint: disable=raising-bad-type

        self.requests.append(packets.decode(data))

        if self.send_responses_after_packets > 1:
            self.send_responses_after_packets -= 1
            return

        self._process_enqueued_packets()

    def _process_enqueued_packets(self) -> None:
        """Delivers every queued packet to the client, then clears the queue.

        Each packet's process_packet() result is checked against the status
        recorded when it was enqueued.
        """
        # Set send_responses_after_packets to infinity to prevent potential
        # infinite recursion when a packet causes another packet to send.
        send_after_count = self.send_responses_after_packets
        self.send_responses_after_packets = float('inf')

        try:
            for packet, status in self._next_packets:
                self.assertIs(status, self._client.process_packet(packet))
        finally:
            # Restore harness state even if a status assertion fails midway.
            self._next_packets.clear()
            self.send_responses_after_packets = send_after_count

    def _sent_payload(self, message_type: type) -> Any:
        """Parses the payload of the last sent packet as ``message_type``."""
        message = message_type()
        message.ParseFromString(self.last_request().payload)
        return message


# Disable docstring requirements for test functions.
199# pylint: disable=missing-function-docstring 200 201 202class CallbackClientImplTest(_CallbackClientImplTestBase): 203 """Tests the callback_client.Impl client implementation.""" 204 205 def test_callback_exceptions_suppressed(self) -> None: 206 stub = self._service.SomeUnary 207 208 self._enqueue_response(CLIENT_CHANNEL_ID, stub.method) 209 exception_msg = 'YOU BROKE IT O-]-<' 210 211 with self.assertLogs(callback_client.__package__, 'ERROR') as logs: 212 stub.invoke( 213 self._request(), mock.Mock(side_effect=Exception(exception_msg)) 214 ) 215 216 self.assertIn(exception_msg, ''.join(logs.output)) 217 218 # Make sure we can still invoke the RPC. 219 self._enqueue_response(CLIENT_CHANNEL_ID, stub.method, Status.UNKNOWN) 220 status, _ = stub() 221 self.assertIs(status, Status.UNKNOWN) 222 223 def test_ignore_bad_packets_with_pending_rpc(self) -> None: 224 method = self._service.SomeUnary.method 225 service_id = method.service.id 226 227 # Unknown channel 228 self._enqueue_response(999, method, process_status=Status.NOT_FOUND) 229 # Bad service 230 self._enqueue_response( 231 CLIENT_CHANNEL_ID, ids=(999, method.id), process_status=Status.OK 232 ) 233 # Bad method 234 self._enqueue_response( 235 CLIENT_CHANNEL_ID, ids=(service_id, 999), process_status=Status.OK 236 ) 237 # For RPC not pending (is Status.OK because the packet is processed) 238 self._enqueue_response( 239 CLIENT_CHANNEL_ID, 240 ids=(service_id, self._service.SomeBidiStreaming.method.id), 241 process_status=Status.OK, 242 ) 243 244 self._enqueue_response( 245 CLIENT_CHANNEL_ID, method, process_status=Status.OK 246 ) 247 248 status, response = self._service.SomeUnary(magic_number=6) 249 self.assertIs(Status.OK, status) 250 self.assertEqual('', response.payload) 251 252 def test_server_error_for_unknown_call_sends_no_errors(self) -> None: 253 method = self._service.SomeUnary.method 254 service_id = method.service.id 255 256 # Unknown channel 257 self._enqueue_error( 258 999, 259 service_id, 260 
method, 261 Status.NOT_FOUND, 262 process_status=Status.NOT_FOUND, 263 ) 264 # Bad service 265 self._enqueue_error( 266 CLIENT_CHANNEL_ID, 999, method.id, Status.INVALID_ARGUMENT 267 ) 268 # Bad method 269 self._enqueue_error( 270 CLIENT_CHANNEL_ID, service_id, 999, Status.INVALID_ARGUMENT 271 ) 272 # For RPC not pending 273 self._enqueue_error( 274 CLIENT_CHANNEL_ID, 275 service_id, 276 self._service.SomeBidiStreaming.method.id, 277 Status.NOT_FOUND, 278 ) 279 280 self._process_enqueued_packets() 281 282 self.assertEqual(self.requests, []) 283 284 def test_exception_if_payload_fails_to_decode(self) -> None: 285 method = self._service.SomeUnary.method 286 287 self._enqueue_response( 288 CLIENT_CHANNEL_ID, 289 method, 290 Status.OK, 291 b'INVALID DATA!!!', 292 process_status=Status.OK, 293 ) 294 295 with self.assertRaises(callback_client.RpcError) as context: 296 self._service.SomeUnary(magic_number=6) 297 298 self.assertIs(context.exception.status, Status.DATA_LOSS) 299 300 def test_rpc_help_contains_method_name(self) -> None: 301 rpc = self._service.SomeUnary 302 self.assertIn(rpc.method.full_name, rpc.help()) 303 304 def test_default_timeouts_set_on_impl(self) -> None: 305 impl = callback_client.Impl(None, 1.5) 306 307 self.assertEqual(impl.default_unary_timeout_s, None) 308 self.assertEqual(impl.default_stream_timeout_s, 1.5) 309 310 def test_default_timeouts_set_for_all_rpcs(self) -> None: 311 rpc_client = client.Client.from_modules( 312 callback_client.Impl(99, 100), 313 [client.Channel(CLIENT_CHANNEL_ID, lambda *a, **b: None)], 314 self._protos.modules(), 315 ) 316 rpcs = rpc_client.channel(CLIENT_CHANNEL_ID).rpcs 317 318 self.assertEqual( 319 rpcs.pw.test1.PublicService.SomeUnary.default_timeout_s, 99 320 ) 321 self.assertEqual( 322 rpcs.pw.test1.PublicService.SomeServerStreaming.default_timeout_s, 323 100, 324 ) 325 self.assertEqual( 326 rpcs.pw.test1.PublicService.SomeClientStreaming.default_timeout_s, 327 99, 328 ) 329 self.assertEqual( 330 
rpcs.pw.test1.PublicService.SomeBidiStreaming.default_timeout_s, 100 331 ) 332 333 def test_rpc_provides_request_type(self) -> None: 334 self.assertIs( 335 self._service.SomeUnary.request, 336 self._service.SomeUnary.method.request_type, 337 ) 338 339 def test_rpc_provides_response_type(self) -> None: 340 self.assertIs( 341 self._service.SomeUnary.request, 342 self._service.SomeUnary.method.request_type, 343 ) 344 345 346class UnaryTest(_CallbackClientImplTestBase): 347 """Tests for invoking a unary RPC.""" 348 349 def setUp(self) -> None: 350 super().setUp() 351 self.rpc = self._service.SomeUnary 352 self.method = self.rpc.method 353 354 def test_blocking_call(self) -> None: 355 for _ in range(3): 356 self._enqueue_response( 357 CLIENT_CHANNEL_ID, 358 self.method, 359 Status.ABORTED, 360 self.method.response_type(payload='0_o'), 361 ) 362 363 status, response = self._service.SomeUnary( 364 self.method.request_type(magic_number=6) 365 ) 366 367 self.assertEqual( 368 6, self._sent_payload(self.method.request_type).magic_number 369 ) 370 371 self.assertIs(Status.ABORTED, status) 372 self.assertEqual('0_o', response.payload) 373 374 def test_nonblocking_call(self) -> None: 375 for _ in range(3): 376 callback = mock.Mock() 377 call = self.rpc.invoke( 378 self._request(magic_number=5), callback, callback 379 ) 380 381 self._enqueue_response( 382 CLIENT_CHANNEL_ID, 383 self.method, 384 Status.ABORTED, 385 self.method.response_type(payload='0_o'), 386 call_id=call.call_id, 387 ) 388 self._process_enqueued_packets() 389 390 callback.assert_has_calls( 391 [ 392 mock.call(call, self.method.response_type(payload='0_o')), 393 mock.call(call, Status.ABORTED), 394 ] 395 ) 396 397 self.assertEqual( 398 5, self._sent_payload(self.method.request_type).magic_number 399 ) 400 401 def test_concurrent_nonblocking_calls(self) -> None: 402 # Start several calls to the same method 403 callbacks_and_calls: list[ 404 tuple[mock.Mock, callback_client.call.Call] 405 ] = [] 406 for _ in 
range(3): 407 callback = mock.Mock() 408 call = self.rpc.invoke(self._request(magic_number=5), callback) 409 callbacks_and_calls.append((callback, call)) 410 411 # Respond only to the last call 412 last_callback, last_call = callbacks_and_calls.pop() 413 last_payload = self.method.response_type(payload='last payload') 414 self._enqueue_response( 415 CLIENT_CHANNEL_ID, 416 self.method, 417 payload=last_payload, 418 call_id=last_call.call_id, 419 ) 420 self._process_enqueued_packets() 421 422 # Assert that only the last caller received a response 423 last_callback.assert_called_once_with(last_call, last_payload) 424 for remaining_callback, _ in callbacks_and_calls: 425 remaining_callback.assert_not_called() 426 427 # Respond to the other callers and check for receipt 428 other_payload = self.method.response_type(payload='other payload') 429 for callback, call in callbacks_and_calls: 430 self._enqueue_response( 431 CLIENT_CHANNEL_ID, 432 self.method, 433 payload=other_payload, 434 call_id=call.call_id, 435 ) 436 self._process_enqueued_packets() 437 callback.assert_called_once_with(call, other_payload) 438 439 def test_open(self) -> None: 440 self.output_exception = IOError('something went wrong sending!') 441 442 for _ in range(3): 443 self._enqueue_response( 444 CLIENT_CHANNEL_ID, 445 self.method, 446 Status.ABORTED, 447 self.method.response_type(payload='0_o'), 448 ) 449 450 callback = mock.Mock() 451 call = self.rpc.open( 452 self._request(magic_number=5), callback, callback 453 ) 454 self.assertEqual(self.requests, []) 455 456 self._process_enqueued_packets() 457 458 callback.assert_has_calls( 459 [ 460 mock.call(call, self.method.response_type(payload='0_o')), 461 mock.call(call, Status.ABORTED), 462 ] 463 ) 464 465 def test_blocking_server_error(self) -> None: 466 for _ in range(3): 467 self._enqueue_error( 468 CLIENT_CHANNEL_ID, 469 self.method.service, 470 self.method, 471 Status.NOT_FOUND, 472 ) 473 474 with self.assertRaises(callback_client.RpcError) as 
context: 475 self._service.SomeUnary( 476 self.method.request_type(magic_number=6) 477 ) 478 479 self.assertIs(context.exception.status, Status.NOT_FOUND) 480 481 def test_nonblocking_cancel(self) -> None: 482 callback = mock.Mock() 483 484 for _ in range(3): 485 call = self._service.SomeUnary.invoke( 486 self._request(magic_number=55), callback 487 ) 488 489 self.assertGreater(len(self.requests), 0) 490 self.requests.clear() 491 492 self.assertTrue(call.cancel()) 493 self.assertFalse(call.cancel()) # Already cancelled, returns False 494 495 self.assertEqual( 496 self.last_request().type, packet_pb2.PacketType.CLIENT_ERROR 497 ) 498 self.assertEqual(self.last_request().status, Status.CANCELLED.value) 499 500 callback.assert_not_called() 501 502 def test_nonblocking_with_request_args(self) -> None: 503 self.rpc.invoke(request_args=dict(magic_number=1138)) 504 self.assertEqual( 505 self._sent_payload(self.rpc.request).magic_number, 1138 506 ) 507 508 def test_blocking_timeout_as_argument(self) -> None: 509 with self.assertRaises(callback_client.RpcTimeout): 510 self._service.SomeUnary(pw_rpc_timeout_s=0.0001) 511 512 def test_blocking_timeout_set_default(self) -> None: 513 self._service.SomeUnary.default_timeout_s = 0.0001 514 515 with self.assertRaises(callback_client.RpcTimeout): 516 self._service.SomeUnary() 517 518 def test_nonblocking_duplicate_calls_not_cancelled(self) -> None: 519 first_call = self.rpc.invoke() 520 self.assertFalse(first_call.completed()) 521 522 second_call = self.rpc.invoke() 523 524 self.assertIs(first_call.error, None) 525 self.assertIs(second_call.error, None) 526 527 def test_nonblocking_exception_in_callback(self) -> None: 528 exception = ValueError('something went wrong! 
(intentionally)') 529 530 self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK) 531 532 call = self.rpc.invoke(on_completed=mock.Mock(side_effect=exception)) 533 534 with self.assertRaises(RuntimeError) as context: 535 call.wait() 536 537 self.assertEqual(context.exception.__cause__, exception) 538 539 def test_unary_response(self) -> None: 540 proto = self._protos.packages.pw.test1.SomeMessage(magic_number=123) 541 self.assertEqual( 542 repr(callback_client.UnaryResponse(Status.ABORTED, proto)), 543 '(Status.ABORTED, pw.test1.SomeMessage(magic_number=123))', 544 ) 545 self.assertEqual( 546 repr(callback_client.UnaryResponse(Status.OK, None)), 547 '(Status.OK, None)', 548 ) 549 550 def test_on_call_hook(self) -> None: 551 hook_function = mock.Mock() 552 553 self._client = client.Client.from_modules( 554 callback_client.Impl(on_call_hook=hook_function), 555 [client.Channel(CLIENT_CHANNEL_ID, self._handle_packet)], 556 self._protos.modules(), 557 ) 558 559 self._service = self._client.channel( 560 CLIENT_CHANNEL_ID 561 ).rpcs.pw.test1.PublicService 562 563 self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK) 564 self._service.SomeUnary(self.method.request_type(magic_number=6)) 565 566 hook_function.assert_called_once() 567 self.assertEqual( 568 hook_function.call_args[0][0].method.full_name, 569 self.method.full_name, 570 ) 571 572 573class ServerStreamingTest(_CallbackClientImplTestBase): 574 """Tests for server streaming RPCs.""" 575 576 def setUp(self) -> None: 577 super().setUp() 578 self.rpc = self._service.SomeServerStreaming 579 self.method = self.rpc.method 580 581 def test_blocking_call(self) -> None: 582 rep1 = self.method.response_type(payload='!!!') 583 rep2 = self.method.response_type(payload='?') 584 585 for _ in range(3): 586 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1) 587 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep2) 588 self._enqueue_response( 589 CLIENT_CHANNEL_ID, self.method, 
Status.ABORTED 590 ) 591 592 self.assertEqual( 593 [rep1, rep2], 594 self._service.SomeServerStreaming(magic_number=4).responses, 595 ) 596 597 self.assertEqual( 598 4, self._sent_payload(self.method.request_type).magic_number 599 ) 600 601 def test_nonblocking_call(self) -> None: 602 rep1 = self.method.response_type(payload='!!!') 603 rep2 = self.method.response_type(payload='?') 604 605 for _ in range(3): 606 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1) 607 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep2) 608 self._enqueue_response( 609 CLIENT_CHANNEL_ID, self.method, Status.ABORTED 610 ) 611 612 callback = mock.Mock() 613 call = self.rpc.invoke( 614 self._request(magic_number=3), callback, callback 615 ) 616 617 callback.assert_has_calls( 618 [ 619 mock.call(call, self.method.response_type(payload='!!!')), 620 mock.call(call, self.method.response_type(payload='?')), 621 mock.call(call, Status.ABORTED), 622 ] 623 ) 624 625 self.assertEqual( 626 3, self._sent_payload(self.method.request_type).magic_number 627 ) 628 629 def test_open(self) -> None: 630 self.output_exception = IOError('something went wrong sending!') 631 rep1 = self.method.response_type(payload='!!!') 632 rep2 = self.method.response_type(payload='?') 633 634 for _ in range(3): 635 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1) 636 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep2) 637 self._enqueue_response( 638 CLIENT_CHANNEL_ID, self.method, Status.ABORTED 639 ) 640 641 callback = mock.Mock() 642 call = self.rpc.open( 643 self._request(magic_number=3), callback, callback 644 ) 645 self.assertEqual(self.requests, []) 646 647 self._process_enqueued_packets() 648 649 callback.assert_has_calls( 650 [ 651 mock.call(call, self.method.response_type(payload='!!!')), 652 mock.call(call, self.method.response_type(payload='?')), 653 mock.call(call, Status.ABORTED), 654 ] 655 ) 656 657 def test_nonblocking_cancel(self) -> None: 658 resp = 
self.rpc.method.response_type(payload='!!!') 659 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.rpc.method, resp) 660 661 callback = mock.Mock() 662 call = self.rpc.invoke(self._request(magic_number=3), callback) 663 callback.assert_called_once_with( 664 call, self.rpc.method.response_type(payload='!!!') 665 ) 666 667 callback.reset_mock() 668 669 call.cancel() 670 671 self.assertEqual( 672 self.last_request().type, packet_pb2.PacketType.CLIENT_ERROR 673 ) 674 self.assertEqual(self.last_request().status, Status.CANCELLED.value) 675 676 # Ensure the RPC can be called after being cancelled. 677 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, resp) 678 self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK) 679 680 call = self.rpc.invoke( 681 self._request(magic_number=3), callback, callback 682 ) 683 684 callback.assert_has_calls( 685 [ 686 mock.call(call, self.method.response_type(payload='!!!')), 687 mock.call(call, Status.OK), 688 ] 689 ) 690 691 def test_request_completion(self) -> None: 692 resp = self.rpc.method.response_type(payload='!!!') 693 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.rpc.method, resp) 694 695 callback = mock.Mock() 696 call = self.rpc.invoke(self._request(magic_number=3), callback) 697 callback.assert_called_once_with( 698 call, self.rpc.method.response_type(payload='!!!') 699 ) 700 701 callback.reset_mock() 702 703 call.request_completion() 704 705 self.assertEqual( 706 self.last_request().type, 707 packet_pb2.PacketType.CLIENT_REQUEST_COMPLETION, 708 ) 709 710 # Ensure the RPC can be called after being completed. 
711 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, resp) 712 self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK) 713 714 call = self.rpc.invoke( 715 self._request(magic_number=3), callback, callback 716 ) 717 718 callback.assert_has_calls( 719 [ 720 mock.call(call, self.method.response_type(payload='!!!')), 721 mock.call(call, Status.OK), 722 ] 723 ) 724 725 def test_nonblocking_with_request_args(self) -> None: 726 self.rpc.invoke(request_args=dict(magic_number=1138)) 727 self.assertEqual( 728 self._sent_payload(self.rpc.request).magic_number, 1138 729 ) 730 731 def test_blocking_timeout(self) -> None: 732 with self.assertRaises(callback_client.RpcTimeout): 733 self._service.SomeServerStreaming(pw_rpc_timeout_s=0.0001) 734 735 def test_nonblocking_iteration_timeout(self) -> None: 736 call = self._service.SomeServerStreaming.invoke(timeout_s=0.0001) 737 with self.assertRaises(callback_client.RpcTimeout): 738 for _ in call: 739 pass 740 741 def test_nonblocking_duplicate_calls_not_cancelled(self) -> None: 742 first_call = self.rpc.invoke() 743 self.assertFalse(first_call.completed()) 744 745 second_call = self.rpc.invoke() 746 747 self.assertIs(first_call.error, None) 748 self.assertIs(second_call.error, None) 749 750 def test_nonblocking_iterate_over_count(self) -> None: 751 reply = self.method.response_type(payload='!?') 752 753 for _ in range(4): 754 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply) 755 756 call = self.rpc.invoke() 757 758 self.assertEqual(list(call.get_responses(count=1)), [reply]) 759 self.assertEqual(next(iter(call)), reply) 760 self.assertEqual(list(call.get_responses(count=2)), [reply, reply]) 761 762 def test_nonblocking_iterate_after_completed_doesnt_block(self) -> None: 763 reply = self.method.response_type(payload='!?') 764 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply) 765 self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK) 766 767 call = self.rpc.invoke() 768 
769 self.assertEqual(list(call.get_responses()), [reply]) 770 self.assertEqual(list(call.get_responses()), []) 771 self.assertEqual(list(call), []) 772 773 774class ClientStreamingTest(_CallbackClientImplTestBase): 775 """Tests for client streaming RPCs.""" 776 777 def setUp(self) -> None: 778 super().setUp() 779 self.rpc = self._service.SomeClientStreaming 780 self.method = self.rpc.method 781 782 def test_blocking_call(self) -> None: 783 requests = [ 784 self.method.request_type(magic_number=123), 785 self.method.request_type(magic_number=456), 786 ] 787 788 # Send after len(requests) and the client stream end packet. 789 self.send_responses_after_packets = 3 790 response = self.method.response_type(payload='yo') 791 self._enqueue_response( 792 CLIENT_CHANNEL_ID, self.method, Status.OK, response 793 ) 794 795 results = self.rpc(requests) 796 self.assertIs(results.status, Status.OK) 797 self.assertEqual(results.response, response) 798 799 def test_blocking_server_error(self) -> None: 800 requests = [self.method.request_type(magic_number=123)] 801 802 # Send after len(requests) and the client stream end packet. 
803 self._enqueue_error( 804 CLIENT_CHANNEL_ID, 805 self.method.service, 806 self.method, 807 Status.NOT_FOUND, 808 ) 809 810 with self.assertRaises(callback_client.RpcError) as context: 811 self.rpc(requests) 812 813 self.assertIs(context.exception.status, Status.NOT_FOUND) 814 815 def test_nonblocking_call(self) -> None: 816 """Tests a successful client streaming RPC ended by the server.""" 817 payload_1 = self.method.response_type(payload='-_-') 818 819 for _ in range(3): 820 stream = self._service.SomeClientStreaming.invoke() 821 self.assertFalse(stream.completed()) 822 823 stream.send(magic_number=31) 824 self.assertIs( 825 packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type 826 ) 827 self.assertEqual( 828 31, self._sent_payload(self.method.request_type).magic_number 829 ) 830 self.assertFalse(stream.completed()) 831 832 # Enqueue the server response to be sent after the next message. 833 self._enqueue_response( 834 CLIENT_CHANNEL_ID, self.method, Status.OK, payload_1 835 ) 836 837 stream.send(magic_number=32) 838 self.assertIs( 839 packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type 840 ) 841 self.assertEqual( 842 32, self._sent_payload(self.method.request_type).magic_number 843 ) 844 845 self.assertTrue(stream.completed()) 846 self.assertIs(Status.OK, stream.status) 847 self.assertIsNone(stream.error) 848 self.assertEqual(payload_1, stream.response) 849 850 def test_open(self) -> None: 851 self.output_exception = IOError('something went wrong sending!') 852 payload = self.method.response_type(payload='-_-') 853 854 for _ in range(3): 855 self._enqueue_response( 856 CLIENT_CHANNEL_ID, self.method, Status.OK, payload 857 ) 858 859 callback = mock.Mock() 860 call = self.rpc.open(callback, callback, callback) 861 self.assertEqual(self.requests, []) 862 863 self._process_enqueued_packets() 864 865 callback.assert_has_calls( 866 [ 867 mock.call(call, payload), 868 mock.call(call, Status.OK), 869 ] 870 ) 871 872 def test_nonblocking_finish(self) 
-> None: 873 """Tests a client streaming RPC ended by the client.""" 874 payload_1 = self.method.response_type(payload='-_-') 875 876 for _ in range(3): 877 stream = self._service.SomeClientStreaming.invoke() 878 self.assertFalse(stream.completed()) 879 880 stream.send(magic_number=37) 881 self.assertIs( 882 packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type 883 ) 884 self.assertEqual( 885 37, self._sent_payload(self.method.request_type).magic_number 886 ) 887 self.assertFalse(stream.completed()) 888 889 # Enqueue the server response to be sent after the next message. 890 self._enqueue_response( 891 CLIENT_CHANNEL_ID, self.method, Status.OK, payload_1 892 ) 893 894 stream.finish_and_wait() 895 self.assertIs( 896 packet_pb2.PacketType.CLIENT_REQUEST_COMPLETION, 897 self.last_request().type, 898 ) 899 900 self.assertTrue(stream.completed()) 901 self.assertIs(Status.OK, stream.status) 902 self.assertIsNone(stream.error) 903 self.assertEqual(payload_1, stream.response) 904 905 def test_nonblocking_cancel(self) -> None: 906 for _ in range(3): 907 stream = self._service.SomeClientStreaming.invoke() 908 stream.send(magic_number=37) 909 910 self.assertTrue(stream.cancel()) 911 self.assertIs( 912 packet_pb2.PacketType.CLIENT_ERROR, self.last_request().type 913 ) 914 self.assertIs(Status.CANCELLED.value, self.last_request().status) 915 self.assertFalse(stream.cancel()) 916 917 self.assertTrue(stream.completed()) 918 self.assertIs(stream.error, Status.CANCELLED) 919 920 def test_nonblocking_server_error(self) -> None: 921 for _ in range(3): 922 stream = self._service.SomeClientStreaming.invoke() 923 924 self._enqueue_error( 925 CLIENT_CHANNEL_ID, 926 self.method.service, 927 self.method, 928 Status.INVALID_ARGUMENT, 929 ) 930 stream.send(magic_number=2**32 - 1) 931 932 with self.assertRaises(callback_client.RpcError) as context: 933 stream.finish_and_wait() 934 935 self.assertIs(context.exception.status, Status.INVALID_ARGUMENT) 936 937 def 
test_nonblocking_server_error_after_stream_end(self) -> None: 938 for _ in range(3): 939 stream = self._service.SomeClientStreaming.invoke() 940 941 # Error will be sent in response to the CLIENT_REQUEST_COMPLETION 942 # packet. 943 self._enqueue_error( 944 CLIENT_CHANNEL_ID, 945 self.method.service, 946 self.method, 947 Status.INVALID_ARGUMENT, 948 ) 949 950 with self.assertRaises(callback_client.RpcError) as context: 951 stream.finish_and_wait() 952 953 self.assertIs(context.exception.status, Status.INVALID_ARGUMENT) 954 955 def test_nonblocking_send_after_cancelled(self) -> None: 956 call = self._service.SomeClientStreaming.invoke() 957 self.assertTrue(call.cancel()) 958 959 with self.assertRaises(callback_client.RpcError) as context: 960 call.send(payload='hello') 961 962 self.assertIs(context.exception.status, Status.CANCELLED) 963 964 def test_nonblocking_finish_after_completed(self) -> None: 965 reply = self.method.response_type(payload='!?') 966 self._enqueue_response( 967 CLIENT_CHANNEL_ID, self.method, Status.UNAVAILABLE, reply 968 ) 969 970 call = self.rpc.invoke() 971 result = call.finish_and_wait() 972 self.assertEqual(result.response, reply) 973 974 self.assertEqual(result, call.finish_and_wait()) 975 self.assertEqual(result, call.finish_and_wait()) 976 977 def test_nonblocking_finish_after_error(self) -> None: 978 self._enqueue_error( 979 CLIENT_CHANNEL_ID, 980 self.method.service, 981 self.method, 982 Status.UNAVAILABLE, 983 ) 984 985 call = self.rpc.invoke() 986 987 for _ in range(3): 988 with self.assertRaises(callback_client.RpcError) as context: 989 call.finish_and_wait() 990 991 self.assertIs(context.exception.status, Status.UNAVAILABLE) 992 self.assertIs(call.error, Status.UNAVAILABLE) 993 self.assertIsNone(call.response) 994 995 def test_nonblocking_duplicate_calls_not_cancelled(self) -> None: 996 first_call = self.rpc.invoke() 997 self.assertFalse(first_call.completed()) 998 999 second_call = self.rpc.invoke() 1000 1001 
self.assertIs(first_call.error, None) 1002 self.assertIs(second_call.error, None) 1003 1004 1005class BidirectionalStreamingTest(_CallbackClientImplTestBase): 1006 """Tests for bidirectional streaming RPCs.""" 1007 1008 def setUp(self) -> None: 1009 super().setUp() 1010 self.rpc = self._service.SomeBidiStreaming 1011 self.method = self.rpc.method 1012 1013 def test_blocking_call(self) -> None: 1014 requests = [ 1015 self.method.request_type(magic_number=123), 1016 self.method.request_type(magic_number=456), 1017 ] 1018 1019 # Send after len(requests) and the client stream end packet. 1020 self.send_responses_after_packets = 3 1021 self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.NOT_FOUND) 1022 1023 results = self.rpc(requests) 1024 self.assertIs(results.status, Status.NOT_FOUND) 1025 self.assertFalse(results.responses) 1026 1027 def test_blocking_server_error(self) -> None: 1028 requests = [self.method.request_type(magic_number=123)] 1029 1030 # Send after len(requests) and the client stream end packet. 
1031 self._enqueue_error( 1032 CLIENT_CHANNEL_ID, 1033 self.method.service, 1034 self.method, 1035 Status.NOT_FOUND, 1036 ) 1037 1038 with self.assertRaises(callback_client.RpcError) as context: 1039 self.rpc(requests) 1040 1041 self.assertIs(context.exception.status, Status.NOT_FOUND) 1042 1043 def test_nonblocking_call(self) -> None: 1044 """Tests a bidirectional streaming RPC ended by the server.""" 1045 rep1 = self.method.response_type(payload='!!!') 1046 rep2 = self.method.response_type(payload='?') 1047 1048 for _ in range(3): 1049 responses: list = [] 1050 stream = self._service.SomeBidiStreaming.invoke( 1051 lambda _, res, responses=responses: responses.append(res) 1052 ) 1053 self.assertFalse(stream.completed()) 1054 1055 stream.send(magic_number=55) 1056 self.assertIs( 1057 packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type 1058 ) 1059 self.assertEqual( 1060 55, self._sent_payload(self.method.request_type).magic_number 1061 ) 1062 self.assertFalse(stream.completed()) 1063 self.assertEqual([], responses) 1064 1065 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1) 1066 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep2) 1067 1068 stream.send(magic_number=66) 1069 self.assertIs( 1070 packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type 1071 ) 1072 self.assertEqual( 1073 66, self._sent_payload(self.method.request_type).magic_number 1074 ) 1075 self.assertFalse(stream.completed()) 1076 self.assertEqual([rep1, rep2], responses) 1077 1078 self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK) 1079 1080 stream.send(magic_number=77) 1081 self.assertTrue(stream.completed()) 1082 self.assertEqual([rep1, rep2], responses) 1083 1084 self.assertIs(Status.OK, stream.status) 1085 self.assertIsNone(stream.error) 1086 1087 def test_open(self) -> None: 1088 self.output_exception = IOError('something went wrong sending!') 1089 rep1 = self.method.response_type(payload='!!!') 1090 rep2 = 
self.method.response_type(payload='?') 1091 1092 for _ in range(3): 1093 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1) 1094 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep2) 1095 self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK) 1096 1097 callback = mock.Mock() 1098 call = self.rpc.open(callback, callback, callback) 1099 self.assertEqual(self.requests, []) 1100 1101 self._process_enqueued_packets() 1102 1103 callback.assert_has_calls( 1104 [ 1105 mock.call(call, self.method.response_type(payload='!!!')), 1106 mock.call(call, self.method.response_type(payload='?')), 1107 mock.call(call, Status.OK), 1108 ] 1109 ) 1110 1111 @mock.patch('pw_rpc.callback_client.call.Call._default_response') 1112 def test_nonblocking(self, callback) -> None: 1113 """Tests a bidirectional streaming RPC ended by the server.""" 1114 reply = self.method.response_type(payload='This is the payload!') 1115 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply) 1116 1117 self._service.SomeBidiStreaming.invoke() 1118 1119 callback.assert_called_once_with(mock.ANY, reply) 1120 1121 def test_nonblocking_server_error(self) -> None: 1122 rep1 = self.method.response_type(payload='!!!') 1123 1124 for _ in range(3): 1125 responses: list = [] 1126 stream = self._service.SomeBidiStreaming.invoke( 1127 lambda _, res, responses=responses: responses.append(res) 1128 ) 1129 self.assertFalse(stream.completed()) 1130 1131 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1) 1132 1133 stream.send(magic_number=55) 1134 self.assertFalse(stream.completed()) 1135 self.assertEqual([rep1], responses) 1136 1137 self._enqueue_error( 1138 CLIENT_CHANNEL_ID, 1139 self.method.service, 1140 self.method, 1141 Status.OUT_OF_RANGE, 1142 ) 1143 1144 stream.send(magic_number=99999) 1145 self.assertTrue(stream.completed()) 1146 self.assertEqual([rep1], responses) 1147 1148 self.assertIsNone(stream.status) 1149 self.assertIs(Status.OUT_OF_RANGE, 
stream.error) 1150 1151 with self.assertRaises(callback_client.RpcError) as context: 1152 stream.finish_and_wait() 1153 self.assertIs(context.exception.status, Status.OUT_OF_RANGE) 1154 1155 def test_nonblocking_server_error_after_stream_end(self) -> None: 1156 for _ in range(3): 1157 stream = self._service.SomeBidiStreaming.invoke() 1158 1159 # Error will be sent in response to the CLIENT_REQUEST_COMPLETION 1160 # packet. 1161 self._enqueue_error( 1162 CLIENT_CHANNEL_ID, 1163 self.method.service, 1164 self.method, 1165 Status.INVALID_ARGUMENT, 1166 ) 1167 1168 with self.assertRaises(callback_client.RpcError) as context: 1169 stream.finish_and_wait() 1170 1171 self.assertIs(context.exception.status, Status.INVALID_ARGUMENT) 1172 1173 def test_nonblocking_send_after_cancelled(self) -> None: 1174 call = self._service.SomeBidiStreaming.invoke() 1175 self.assertTrue(call.cancel()) 1176 1177 with self.assertRaises(callback_client.RpcError) as context: 1178 call.send(payload='hello') 1179 1180 self.assertIs(context.exception.status, Status.CANCELLED) 1181 1182 def test_nonblocking_finish_after_completed(self) -> None: 1183 reply = self.method.response_type(payload='!?') 1184 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply) 1185 self._enqueue_response( 1186 CLIENT_CHANNEL_ID, self.method, Status.UNAVAILABLE 1187 ) 1188 1189 call = self.rpc.invoke() 1190 result = call.finish_and_wait() 1191 self.assertEqual(result.responses, [reply]) 1192 1193 self.assertEqual(result, call.finish_and_wait()) 1194 self.assertEqual(result, call.finish_and_wait()) 1195 1196 def test_nonblocking_finish_after_error(self) -> None: 1197 reply = self.method.response_type(payload='!?') 1198 self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply) 1199 self._enqueue_error( 1200 CLIENT_CHANNEL_ID, 1201 self.method.service, 1202 self.method, 1203 Status.UNAVAILABLE, 1204 ) 1205 1206 call = self.rpc.invoke() 1207 1208 for _ in range(3): 1209 with 
self.assertRaises(callback_client.RpcError) as context: 1210 call.finish_and_wait() 1211 1212 self.assertIs(context.exception.status, Status.UNAVAILABLE) 1213 self.assertIs(call.error, Status.UNAVAILABLE) 1214 self.assertEqual(call.responses, [reply]) 1215 1216 def test_nonblocking_duplicate_calls_not_cancelled(self) -> None: 1217 first_call = self.rpc.invoke() 1218 self.assertFalse(first_call.completed()) 1219 1220 second_call = self.rpc.invoke() 1221 1222 self.assertIs(first_call.error, None) 1223 self.assertIs(second_call.error, None) 1224 1225 def test_stream_response(self) -> None: 1226 proto = self._protos.packages.pw.test1.SomeMessage(magic_number=123) 1227 self.assertEqual( 1228 repr(callback_client.StreamResponse(Status.ABORTED, [proto] * 2)), 1229 '(Status.ABORTED, [pw.test1.SomeMessage(magic_number=123), ' 1230 'pw.test1.SomeMessage(magic_number=123)])', 1231 ) 1232 self.assertEqual( 1233 repr(callback_client.StreamResponse(Status.OK, [])), 1234 '(Status.OK, [])', 1235 ) 1236 1237 1238if __name__ == '__main__': 1239 unittest.main() 1240