#!/usr/bin/env python3
# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests using the callback client for pw_rpc."""

import unittest
from unittest import mock
from typing import Any

from pw_protobuf_compiler import python_protos
from pw_status import Status

from pw_rpc import callback_client, client, descriptors, packets
from pw_rpc.internal import packet_pb2

TEST_PROTO_1 = """\
syntax = "proto3";

package pw.test1;

message SomeMessage {
  uint32 magic_number = 1;
}

message AnotherMessage {
  enum Result {
    FAILED = 0;
    FAILED_MISERABLY = 1;
    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
  }

  Result result = 1;
  string payload = 2;
}

service PublicService {
  rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
  rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
  rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
  rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""

PROTOS = python_protos.Library.from_strings(TEST_PROTO_1)

CLIENT_CHANNEL_ID: int = 489


def _message_bytes(msg: Any) -> bytes:
    """Returns msg as-is if it is bytes, otherwise msg.SerializeToString()."""
    return msg if isinstance(msg, bytes) else msg.SerializeToString()


class _CallbackClientImplTestBase(unittest.TestCase):
    """Supports writing tests that require responses from an RPC server."""

    def setUp(self) -> None:
        self._request = PROTOS.packages.pw.test1.SomeMessage

        self._client = client.Client.from_modules(
            callback_client.Impl(),
            [client.Channel(CLIENT_CHANNEL_ID, self._handle_packet)],
            PROTOS.modules(),
        )
        self._service = self._client.channel(
            CLIENT_CHANNEL_ID
        ).rpcs.pw.test1.PublicService

        # Every packet the client has sent, decoded, in send order.
        self.requests: list[packet_pb2.RpcPacket] = []
        # Serialized packets to feed to the client, each paired with the
        # Status that process_packet() is expected to return for it.
        self._next_packets: list[tuple[bytes, Status]] = []
        # Counts down by one for each packet the client sends; once it is at
        # or below 1, each sent packet triggers delivery of _next_packets.
        self.send_responses_after_packets: float = 1

        # When set, _handle_packet raises this instead of recording packets.
        self.output_exception: Exception | None = None

    def last_request(self) -> packet_pb2.RpcPacket:
        """Returns the most recent packet sent by the client."""
        assert self.requests
        return self.requests[-1]

    def _enqueue_response(
        self,
        channel_id: int = CLIENT_CHANNEL_ID,
        method: descriptors.Method | None = None,
        status: Status = Status.OK,
        payload: Any = b'',
        *,
        ids: tuple[int, int] | None = None,
        process_status: Status = Status.OK,
        call_id: int = client.OPEN_CALL_ID,
    ) -> None:
        """Enqueues a RESPONSE packet for delivery to the client.

        Exactly one of method or ids ((service ID, method ID)) must be given.
        payload may be bytes or a protobuf message. process_status is the
        Status that process_packet() is expected to return for this packet.
        """
        if method:
            assert ids is None
            service_id, method_id = method.service.id, method.id
        else:
            assert ids is not None and method is None
            service_id, method_id = ids

        self._next_packets.append(
            (
                packet_pb2.RpcPacket(
                    type=packet_pb2.PacketType.RESPONSE,
                    channel_id=channel_id,
                    service_id=service_id,
                    method_id=method_id,
                    call_id=call_id,
                    status=status.value,
                    payload=_message_bytes(payload),
                ).SerializeToString(),
                process_status,
            )
        )

    def _enqueue_server_stream(
        self,
        channel_id: int,
        method,
        response,
        process_status=Status.OK,
        call_id: int = client.OPEN_CALL_ID,
    ) -> None:
        """Enqueues a SERVER_STREAM packet carrying response (bytes or msg)."""
        self._next_packets.append(
            (
                packet_pb2.RpcPacket(
                    type=packet_pb2.PacketType.SERVER_STREAM,
                    channel_id=channel_id,
                    service_id=method.service.id,
                    method_id=method.id,
                    call_id=call_id,
                    payload=_message_bytes(response),
                ).SerializeToString(),
                process_status,
            )
        )

    def _enqueue_error(
        self,
        channel_id: int,
        service,
        method,
        status: Status,
        process_status=Status.OK,
        call_id: int = client.OPEN_CALL_ID,
    ) -> None:
        """Enqueues a SERVER_ERROR packet with the given status.

        service and method may each be a raw integer ID or an object with an
        `id` attribute.
        """
        self._next_packets.append(
            (
                packet_pb2.RpcPacket(
                    type=packet_pb2.PacketType.SERVER_ERROR,
                    channel_id=channel_id,
                    service_id=service
                    if isinstance(service, int)
                    else service.id,
                    method_id=method if isinstance(method, int) else method.id,
                    call_id=call_id,
                    status=status.value,
                ).SerializeToString(),
                process_status,
            )
        )

    def _handle_packet(self, data: bytes) -> None:
        """Channel output: records sent packets and delivers enqueued ones."""
        if self.output_exception:
            raise self.output_exception  # pylint: disable=raising-bad-type

        self.requests.append(packets.decode(data))

        if self.send_responses_after_packets > 1:
            self.send_responses_after_packets -= 1
            return

        self._process_enqueued_packets()

    def _process_enqueued_packets(self) -> None:
        """Feeds every enqueued packet to the client and clears the queue.

        Asserts that process_packet() returns each packet's expected status.
        """
        # Set send_responses_after_packets to infinity to prevent potential
        # infinite recursion when a packet causes another packet to send.
        send_after_count = self.send_responses_after_packets
        self.send_responses_after_packets = float('inf')

        for packet, status in self._next_packets:
            self.assertIs(status, self._client.process_packet(packet))

        self._next_packets.clear()
        self.send_responses_after_packets = send_after_count

    def _sent_payload(self, message_type: type) -> Any:
        """Parses the last sent packet's payload as message_type."""
        message = message_type()
        message.ParseFromString(self.last_request().payload)
        return message


# Disable docstring requirements for test functions.
# pylint: disable=missing-function-docstring


class CallbackClientImplTest(_CallbackClientImplTestBase):
    """Tests the callback_client.Impl client implementation."""

    def test_callback_exceptions_suppressed(self) -> None:
        stub = self._service.SomeUnary

        self._enqueue_response(CLIENT_CHANNEL_ID, stub.method)
        exception_msg = 'YOU BROKE IT O-]-<'

        with self.assertLogs(callback_client.__package__, 'ERROR') as logs:
            stub.invoke(
                self._request(), mock.Mock(side_effect=Exception(exception_msg))
            )

        self.assertIn(exception_msg, ''.join(logs.output))

        # Make sure we can still invoke the RPC.
        self._enqueue_response(CLIENT_CHANNEL_ID, stub.method, Status.UNKNOWN)
        status, _ = stub()
        self.assertIs(status, Status.UNKNOWN)

    def test_ignore_bad_packets_with_pending_rpc(self) -> None:
        method = self._service.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_response(999, method, process_status=Status.NOT_FOUND)
        # Bad service
        self._enqueue_response(
            CLIENT_CHANNEL_ID, ids=(999, method.id), process_status=Status.OK
        )
        # Bad method
        self._enqueue_response(
            CLIENT_CHANNEL_ID, ids=(service_id, 999), process_status=Status.OK
        )
        # For RPC not pending (is Status.OK because the packet is processed)
        self._enqueue_response(
            CLIENT_CHANNEL_ID,
            ids=(service_id, self._service.SomeBidiStreaming.method.id),
            process_status=Status.OK,
        )

        self._enqueue_response(
            CLIENT_CHANNEL_ID, method, process_status=Status.OK
        )

        status, response = self._service.SomeUnary(magic_number=6)
        self.assertIs(Status.OK, status)
        self.assertEqual('', response.payload)

    def test_server_error_for_unknown_call_sends_no_errors(self) -> None:
        method = self._service.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_error(
            999,
            service_id,
            method,
            Status.NOT_FOUND,
            process_status=Status.NOT_FOUND,
        )
        # Bad service
        self._enqueue_error(
            CLIENT_CHANNEL_ID, 999, method.id, Status.INVALID_ARGUMENT
        )
        # Bad method
        self._enqueue_error(
            CLIENT_CHANNEL_ID, service_id, 999, Status.INVALID_ARGUMENT
        )
        # For RPC not pending
        self._enqueue_error(
            CLIENT_CHANNEL_ID,
            service_id,
            self._service.SomeBidiStreaming.method.id,
            Status.NOT_FOUND,
        )

        self._process_enqueued_packets()

        self.assertEqual(self.requests, [])

    def test_exception_if_payload_fails_to_decode(self) -> None:
        method = self._service.SomeUnary.method

        self._enqueue_response(
            CLIENT_CHANNEL_ID,
            method,
            Status.OK,
            b'INVALID DATA!!!',
            process_status=Status.OK,
        )

        with self.assertRaises(callback_client.RpcError) as context:
            self._service.SomeUnary(magic_number=6)

        self.assertIs(context.exception.status, Status.DATA_LOSS)

    def test_rpc_help_contains_method_name(self) -> None:
        rpc = self._service.SomeUnary
        self.assertIn(rpc.method.full_name, rpc.help())

    def test_default_timeouts_set_on_impl(self) -> None:
        impl = callback_client.Impl(None, 1.5)

        self.assertEqual(impl.default_unary_timeout_s, None)
        self.assertEqual(impl.default_stream_timeout_s, 1.5)

    def test_default_timeouts_set_for_all_rpcs(self) -> None:
        rpc_client = client.Client.from_modules(
            callback_client.Impl(99, 100),
            [client.Channel(CLIENT_CHANNEL_ID, lambda *a, **b: None)],
            PROTOS.modules(),
        )
        rpcs = rpc_client.channel(CLIENT_CHANNEL_ID).rpcs

        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeUnary.default_timeout_s, 99
        )
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeServerStreaming.default_timeout_s,
            100,
        )
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeClientStreaming.default_timeout_s,
            99,
        )
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeBidiStreaming.default_timeout_s, 100
        )

    def test_rpc_provides_request_type(self) -> None:
        self.assertIs(
            self._service.SomeUnary.request,
            self._service.SomeUnary.method.request_type,
        )

    def test_rpc_provides_response_type(self) -> None:
        # Previously a copy of the request-type test; check the response type.
        self.assertIs(
            self._service.SomeUnary.response,
            self._service.SomeUnary.method.response_type,
        )


class UnaryTest(_CallbackClientImplTestBase):
    """Tests for invoking a unary RPC."""

    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeUnary
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        for _ in range(3):
            self._enqueue_response(
                CLIENT_CHANNEL_ID,
                self.method,
                Status.ABORTED,
                self.method.response_type(payload='0_o'),
            )

            status, response = self._service.SomeUnary(
                self.method.request_type(magic_number=6)
            )

            self.assertEqual(
                6, self._sent_payload(self.method.request_type).magic_number
            )

            self.assertIs(Status.ABORTED, status)
            self.assertEqual('0_o', response.payload)

    def test_nonblocking_call(self) -> None:
        for _ in range(3):
            callback = mock.Mock()
            call = self.rpc.invoke(
                self._request(magic_number=5), callback, callback
            )

            self._enqueue_response(
                CLIENT_CHANNEL_ID,
                self.method,
                Status.ABORTED,
                self.method.response_type(payload='0_o'),
                call_id=call.call_id,
            )
            self._process_enqueued_packets()

            callback.assert_has_calls(
                [
                    mock.call(call, self.method.response_type(payload='0_o')),
                    mock.call(call, Status.ABORTED),
                ]
            )

            self.assertEqual(
                5, self._sent_payload(self.method.request_type).magic_number
            )

    def test_concurrent_nonblocking_calls(self) -> None:
        # Start several calls to the same method
        callbacks_and_calls: list[
            tuple[mock.Mock, callback_client.call.Call]
        ] = []
        for _ in range(3):
            callback = mock.Mock()
            call = self.rpc.invoke(self._request(magic_number=5), callback)
            callbacks_and_calls.append((callback, call))

        # Respond only to the last call
        last_callback, last_call = callbacks_and_calls.pop()
        last_payload = self.method.response_type(payload='last payload')
        self._enqueue_response(
            CLIENT_CHANNEL_ID,
            self.method,
            payload=last_payload,
            call_id=last_call.call_id,
        )
        self._process_enqueued_packets()

        # Assert that only the last caller received a response
        last_callback.assert_called_once_with(last_call, last_payload)
        for remaining_callback, _ in callbacks_and_calls:
            remaining_callback.assert_not_called()

        # Respond to the other callers and check for receipt
        other_payload = self.method.response_type(payload='other payload')
        for callback, call in callbacks_and_calls:
            self._enqueue_response(
                CLIENT_CHANNEL_ID,
                self.method,
                payload=other_payload,
                call_id=call.call_id,
            )
            self._process_enqueued_packets()
            callback.assert_called_once_with(call, other_payload)

    def test_open(self) -> None:
        self.output_exception = IOError('this test should not send packets!')

        for packet_id in (client.OPEN_CALL_ID, 123):
            for _ in range(3):
                self._enqueue_response(
                    CLIENT_CHANNEL_ID,
                    self.method,
                    Status.ABORTED,
                    self.method.response_type(payload='0_o'),
                    call_id=packet_id,
                )

                callback = mock.Mock()
                call = self.rpc.open(callback, callback, callback)
                self.assertEqual(self.requests, [])

                self._process_enqueued_packets()

                callback.assert_has_calls(
                    [
                        mock.call(
                            call, self.method.response_type(payload='0_o')
                        ),
                        mock.call(call, Status.ABORTED),
                    ]
                )
                self.assertEqual(call.call_id, packet_id, "Adopts inbound ID")

    def test_blocking_server_error(self) -> None:
        for _ in range(3):
            self._enqueue_error(
                CLIENT_CHANNEL_ID,
                self.method.service,
                self.method,
                Status.NOT_FOUND,
            )

            with self.assertRaises(callback_client.RpcError) as context:
                self._service.SomeUnary(
                    self.method.request_type(magic_number=6)
                )

            self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_cancel(self) -> None:
        callback = mock.Mock()

        for _ in range(3):
            call = self._service.SomeUnary.invoke(
                self._request(magic_number=55), callback
            )

            self.assertGreater(len(self.requests), 0)
            self.requests.clear()

            self.assertTrue(call.cancel())
            self.assertFalse(call.cancel())  # Already cancelled, returns False

            self.assertEqual(
                self.last_request().type, packet_pb2.PacketType.CLIENT_ERROR
            )
            self.assertEqual(self.last_request().status, Status.CANCELLED.value)

        callback.assert_not_called()

    def test_nonblocking_with_request_args(self) -> None:
        self.rpc.invoke(request_args=dict(magic_number=1138))
        self.assertEqual(
            self._sent_payload(self.rpc.request).magic_number, 1138
        )

    def test_blocking_timeout_as_argument(self) -> None:
        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary(pw_rpc_timeout_s=0.0001)

    def test_blocking_timeout_set_default(self) -> None:
        self._service.SomeUnary.default_timeout_s = 0.0001

        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary()

    def test_nonblocking_duplicate_calls_not_cancelled(self) -> None:
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIs(first_call.error, None)
        self.assertIs(second_call.error, None)

    def test_nonblocking_exception_in_callback(self) -> None:
        exception = ValueError('something went wrong! (intentionally)')

        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)

        call = self.rpc.invoke(on_completed=mock.Mock(side_effect=exception))

        with self.assertRaises(RuntimeError) as context:
            call.wait()

        self.assertEqual(context.exception.__cause__, exception)

    def test_unary_response(self) -> None:
        proto = PROTOS.packages.pw.test1.SomeMessage(magic_number=123)
        self.assertEqual(
            repr(callback_client.UnaryResponse(Status.ABORTED, proto)),
            '(Status.ABORTED, pw.test1.SomeMessage(magic_number=123))',
        )
        self.assertEqual(
            repr(callback_client.UnaryResponse(Status.OK, None)),
            '(Status.OK, None)',
        )

    def test_on_call_hook(self) -> None:
        hook_function = mock.Mock()

        self._client = client.Client.from_modules(
            callback_client.Impl(on_call_hook=hook_function),
            [client.Channel(CLIENT_CHANNEL_ID, self._handle_packet)],
            PROTOS.modules(),
        )

        self._service = self._client.channel(
            CLIENT_CHANNEL_ID
        ).rpcs.pw.test1.PublicService

        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)
        self._service.SomeUnary(self.method.request_type(magic_number=6))

        hook_function.assert_called_once()
        self.assertEqual(
            hook_function.call_args[0][0].method.full_name,
            self.method.full_name,
        )


class ServerStreamingTest(_CallbackClientImplTestBase):
    """Tests for server streaming RPCs."""

    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeServerStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1)
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep2)
            self._enqueue_response(
                CLIENT_CHANNEL_ID, self.method, Status.ABORTED
            )

            self.assertEqual(
                [rep1, rep2],
                self._service.SomeServerStreaming(magic_number=4).responses,
            )

            self.assertEqual(
                4, self._sent_payload(self.method.request_type).magic_number
            )

    def test_nonblocking_call(self) -> None:
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1)
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep2)
            self._enqueue_response(
                CLIENT_CHANNEL_ID, self.method, Status.ABORTED
            )

            callback = mock.Mock()
            call = self.rpc.invoke(
                self._request(magic_number=3), callback, callback
            )

            callback.assert_has_calls(
                [
                    mock.call(call, self.method.response_type(payload='!!!')),
                    mock.call(call, self.method.response_type(payload='?')),
                    mock.call(call, Status.ABORTED),
                ]
            )

            self.assertEqual(
                3, self._sent_payload(self.method.request_type).magic_number
            )

    def test_open(self) -> None:
        self.output_exception = IOError('this test should not send packets!')
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for packet_id in (client.OPEN_CALL_ID, 123):
            for _ in range(3):
                self._enqueue_server_stream(
                    CLIENT_CHANNEL_ID, self.method, rep1, call_id=packet_id
                )
                self._enqueue_server_stream(
                    CLIENT_CHANNEL_ID, self.method, rep2, call_id=packet_id
                )
                self._enqueue_response(
                    CLIENT_CHANNEL_ID,
                    self.method,
                    Status.ABORTED,
                    call_id=packet_id,
                )

                callback = mock.Mock()
                call = self.rpc.open(callback, callback, callback)
                self.assertEqual(self.requests, [])

                self._process_enqueued_packets()

                callback.assert_has_calls(
                    [
                        mock.call(
                            call, self.method.response_type(payload='!!!')
                        ),
                        mock.call(call, self.method.response_type(payload='?')),
                        mock.call(call, Status.ABORTED),
                    ]
                )
                self.assertEqual(call.call_id, packet_id, "Adopts inbound ID")

    def test_nonblocking_cancel(self) -> None:
        resp = self.rpc.method.response_type(payload='!!!')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.rpc.method, resp)

        callback = mock.Mock()
        call = self.rpc.invoke(self._request(magic_number=3), callback)
        callback.assert_called_once_with(
            call, self.rpc.method.response_type(payload='!!!')
        )

        callback.reset_mock()

        call.cancel()

        self.assertEqual(
            self.last_request().type, packet_pb2.PacketType.CLIENT_ERROR
        )
        self.assertEqual(self.last_request().status, Status.CANCELLED.value)

        # Ensure the RPC can be called after being cancelled.
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, resp)
        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)

        call = self.rpc.invoke(
            self._request(magic_number=3), callback, callback
        )

        callback.assert_has_calls(
            [
                mock.call(call, self.method.response_type(payload='!!!')),
                mock.call(call, Status.OK),
            ]
        )

    def test_request_completion(self) -> None:
        resp = self.rpc.method.response_type(payload='!!!')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.rpc.method, resp)

        callback = mock.Mock()
        call = self.rpc.invoke(self._request(magic_number=3), callback)
        callback.assert_called_once_with(
            call, self.rpc.method.response_type(payload='!!!')
        )

        callback.reset_mock()

        call.request_completion()

        self.assertEqual(
            self.last_request().type,
            packet_pb2.PacketType.CLIENT_REQUEST_COMPLETION,
        )

    def test_nonblocking_with_request_args(self) -> None:
        self.rpc.invoke(request_args=dict(magic_number=1138))
        self.assertEqual(
            self._sent_payload(self.rpc.request).magic_number, 1138
        )

    def test_blocking_timeout(self) -> None:
        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeServerStreaming(pw_rpc_timeout_s=0.0001)

    def test_nonblocking_iteration_timeout(self) -> None:
        call = self._service.SomeServerStreaming.invoke(timeout_s=0.0001)
        with self.assertRaises(callback_client.RpcTimeout):
            for _ in call:
                pass

    def test_nonblocking_duplicate_calls_not_cancelled(self) -> None:
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIs(first_call.error, None)
        self.assertIs(second_call.error, None)

    def test_nonblocking_iterate_over_count(self) -> None:
        reply = self.method.response_type(payload='!?')

        for _ in range(4):
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply)

        call = self.rpc.invoke()

        self.assertEqual(list(call.get_responses(count=1)), [reply])
        self.assertEqual(next(iter(call)), reply)
        self.assertEqual(list(call.get_responses(count=2)), [reply, reply])

    def test_nonblocking_iterate_after_completed_doesnt_block(self) -> None:
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply)
        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)

        call = self.rpc.invoke()

        self.assertEqual(list(call.get_responses()), [reply])
        self.assertEqual(list(call.get_responses()), [])
        self.assertEqual(list(call), [])


class ClientStreamingTest(_CallbackClientImplTestBase):
    """Tests for client streaming RPCs."""

    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeClientStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        requests = [
            self.method.request_type(magic_number=123),
            self.method.request_type(magic_number=456),
        ]

        # Send after len(requests) and the client stream end packet.
        self.send_responses_after_packets = 3
        response = self.method.response_type(payload='yo')
        self._enqueue_response(
            CLIENT_CHANNEL_ID, self.method, Status.OK, response
        )

        results = self.rpc(requests)
        self.assertIs(results.status, Status.OK)
        self.assertEqual(results.response, response)

    def test_blocking_server_error(self) -> None:
        requests = [self.method.request_type(magic_number=123)]

        # send_responses_after_packets keeps its default of 1, so the error is
        # delivered in response to the first packet the client sends.
        self._enqueue_error(
            CLIENT_CHANNEL_ID,
            self.method.service,
            self.method,
            Status.NOT_FOUND,
        )

        with self.assertRaises(callback_client.RpcError) as context:
            self.rpc(requests)

        self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_call(self) -> None:
        """Tests a successful client streaming RPC ended by the server."""
        payload_1 = self.method.response_type(payload='-_-')

        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()
            self.assertFalse(stream.completed())

            stream.send(magic_number=31)
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                31, self._sent_payload(self.method.request_type).magic_number
            )
            self.assertFalse(stream.completed())

            # Enqueue the server response to be sent after the next message.
            self._enqueue_response(
                CLIENT_CHANNEL_ID, self.method, Status.OK, payload_1
            )

            stream.send(magic_number=32)
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                32, self._sent_payload(self.method.request_type).magic_number
            )

            self.assertTrue(stream.completed())
            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)
            self.assertEqual(payload_1, stream.response)

    def test_open(self) -> None:
        self.output_exception = IOError('this test should not send packets!')
        payload = self.method.response_type(payload='-_-')

        for packet_id in (client.OPEN_CALL_ID, 123):
            for _ in range(3):
                self._enqueue_response(
                    CLIENT_CHANNEL_ID,
                    self.method,
                    Status.OK,
                    payload,
                    call_id=packet_id,
                )

                callback = mock.Mock()
                call = self.rpc.open(callback, callback, callback)
                self.assertEqual(self.requests, [])

                self._process_enqueued_packets()

                callback.assert_has_calls(
                    [
                        mock.call(call, payload),
                        mock.call(call, Status.OK),
                    ]
                )
                self.assertEqual(call.call_id, packet_id, "Adopts inbound ID")

    def test_nonblocking_finish(self) -> None:
        """Tests a client streaming RPC ended by the client."""
        payload_1 = self.method.response_type(payload='-_-')

        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()
            self.assertFalse(stream.completed())

            stream.send(magic_number=37)
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                37, self._sent_payload(self.method.request_type).magic_number
            )
            self.assertFalse(stream.completed())

            # Enqueue the server response to be sent after the next message.
            self._enqueue_response(
                CLIENT_CHANNEL_ID, self.method, Status.OK, payload_1
            )

            stream.finish_and_wait()
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_REQUEST_COMPLETION,
                self.last_request().type,
            )

            self.assertTrue(stream.completed())
            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)
            self.assertEqual(payload_1, stream.response)

    def test_nonblocking_cancel(self) -> None:
        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()
            stream.send(magic_number=37)

            self.assertTrue(stream.cancel())
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_ERROR, self.last_request().type
            )
            self.assertEqual(
                Status.CANCELLED.value, self.last_request().status
            )
            self.assertFalse(stream.cancel())

            self.assertTrue(stream.completed())
            self.assertIs(stream.error, Status.CANCELLED)

    def test_nonblocking_server_error(self) -> None:
        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()

            self._enqueue_error(
                CLIENT_CHANNEL_ID,
                self.method.service,
                self.method,
                Status.INVALID_ARGUMENT,
            )

            stream.send(magic_number=2**32 - 1)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_server_error_after_stream_end(self) -> None:
        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()

            # Error will be sent in response to the CLIENT_REQUEST_COMPLETION
            # packet.
            self._enqueue_error(
                CLIENT_CHANNEL_ID,
                self.method.service,
                self.method,
                Status.INVALID_ARGUMENT,
            )

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_send_after_cancelled(self) -> None:
        call = self._service.SomeClientStreaming.invoke()
        self.assertTrue(call.cancel())

        with self.assertRaises(callback_client.RpcError) as context:
            call.send(payload='hello')

        self.assertIs(context.exception.status, Status.CANCELLED)

    def test_nonblocking_finish_after_completed(self) -> None:
        reply = self.method.response_type(payload='!?')
        self._enqueue_response(
            CLIENT_CHANNEL_ID, self.method, Status.UNAVAILABLE, reply
        )

        call = self.rpc.invoke()
        result = call.finish_and_wait()
        self.assertEqual(result.response, reply)

        self.assertEqual(result, call.finish_and_wait())
        self.assertEqual(result, call.finish_and_wait())

    def test_nonblocking_finish_after_error(self) -> None:
        self._enqueue_error(
            CLIENT_CHANNEL_ID,
            self.method.service,
            self.method,
            Status.UNAVAILABLE,
        )

        call = self.rpc.invoke()

        for _ in range(3):
            with self.assertRaises(callback_client.RpcError) as context:
                call.finish_and_wait()

            self.assertIs(context.exception.status, Status.UNAVAILABLE)
            self.assertIs(call.error, Status.UNAVAILABLE)
            self.assertIsNone(call.response)

    def test_nonblocking_duplicate_calls_not_cancelled(self) -> None:
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIs(first_call.error, None)
        self.assertIs(second_call.error, None)


class BidirectionalStreamingTest(_CallbackClientImplTestBase):
    """Tests for bidirectional streaming RPCs."""

    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeBidiStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        requests = [
            self.method.request_type(magic_number=123),
            self.method.request_type(magic_number=456),
        ]

        # Send after len(requests) and the client stream end packet.
        self.send_responses_after_packets = 3
        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.NOT_FOUND)

        results = self.rpc(requests)
        self.assertIs(results.status, Status.NOT_FOUND)
        self.assertFalse(results.responses)

    def test_blocking_server_error(self) -> None:
        requests = [self.method.request_type(magic_number=123)]

        # send_responses_after_packets keeps its default of 1, so the error is
        # delivered in response to the first packet the client sends.
        self._enqueue_error(
            CLIENT_CHANNEL_ID,
            self.method.service,
            self.method,
            Status.NOT_FOUND,
        )

        with self.assertRaises(callback_client.RpcError) as context:
            self.rpc(requests)

        self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_call(self) -> None:
        """Tests a bidirectional streaming RPC ended by the server."""
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            responses: list = []
            stream = self._service.SomeBidiStreaming.invoke(
                lambda _, res, responses=responses: responses.append(res)
            )
            self.assertFalse(stream.completed())

            stream.send(magic_number=55)
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                55, self._sent_payload(self.method.request_type).magic_number
            )
            self.assertFalse(stream.completed())
            self.assertEqual([], responses)

            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1)
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep2)

            stream.send(magic_number=66)
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                66, self._sent_payload(self.method.request_type).magic_number
            )
            self.assertFalse(stream.completed())
            self.assertEqual([rep1, rep2], responses)

            self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)

            stream.send(magic_number=77)
            self.assertTrue(stream.completed())
            self.assertEqual([rep1, rep2], responses)

            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)

    def test_open(self) -> None:
        self.output_exception = IOError('this test should not send packets!')
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for packet_id in (client.OPEN_CALL_ID, 123):
            for _ in range(3):
                self._enqueue_server_stream(
                    CLIENT_CHANNEL_ID, self.method, rep1, call_id=packet_id
                )
                self._enqueue_server_stream(
                    CLIENT_CHANNEL_ID, self.method, rep2, call_id=packet_id
                )
                self._enqueue_response(
                    CLIENT_CHANNEL_ID, self.method, Status.OK, call_id=packet_id
                )

                callback = mock.Mock()
                call = self.rpc.open(callback, callback, callback)
                self.assertEqual(self.requests, [])

                self._process_enqueued_packets()

                callback.assert_has_calls(
                    [
                        mock.call(
                            call, self.method.response_type(payload='!!!')
                        ),
                        mock.call(call, self.method.response_type(payload='?')),
                        mock.call(call, Status.OK),
                    ]
                )
                self.assertEqual(call.call_id, packet_id, "Adopts inbound ID")

    @mock.patch('pw_rpc.callback_client.call.Call._default_response')
    def test_nonblocking(self, callback) -> None:
        """Tests that a stream response reaches Call._default_response.

        invoke() is called without an on_next callback, and the patched
        Call._default_response must receive the server stream payload.
        """
        reply = self.method.response_type(payload='This is the payload!')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply)

        self._service.SomeBidiStreaming.invoke()

        callback.assert_called_once_with(mock.ANY, reply)

    def test_nonblocking_server_error(self) -> None:
        rep1 = self.method.response_type(payload='!!!')

        for _ in range(3):
            responses: list = []
            stream = self._service.SomeBidiStreaming.invoke(
                lambda _, res, responses=responses: responses.append(res)
            )
            self.assertFalse(stream.completed())

            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1)

            stream.send(magic_number=55)
            self.assertFalse(stream.completed())
            self.assertEqual([rep1], responses)

            self._enqueue_error(
                CLIENT_CHANNEL_ID,
                self.method.service,
                self.method,
                Status.OUT_OF_RANGE,
            )

            stream.send(magic_number=99999)
            self.assertTrue(stream.completed())
            self.assertEqual([rep1], responses)

            self.assertIsNone(stream.status)
            self.assertIs(Status.OUT_OF_RANGE, stream.error)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()
            self.assertIs(context.exception.status, Status.OUT_OF_RANGE)

    def test_nonblocking_server_error_after_stream_end(self) -> None:
        for _ in range(3):
            stream = self._service.SomeBidiStreaming.invoke()

            # Error will be sent in response to the CLIENT_REQUEST_COMPLETION
            # packet.
            self._enqueue_error(
                CLIENT_CHANNEL_ID,
                self.method.service,
                self.method,
                Status.INVALID_ARGUMENT,
            )

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_send_after_cancelled(self) -> None:
        call = self._service.SomeBidiStreaming.invoke()
        self.assertTrue(call.cancel())

        with self.assertRaises(callback_client.RpcError) as context:
            call.send(payload='hello')

        self.assertIs(context.exception.status, Status.CANCELLED)

    def test_nonblocking_finish_after_completed(self) -> None:
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply)
        self._enqueue_response(
            CLIENT_CHANNEL_ID, self.method, Status.UNAVAILABLE
        )

        call = self.rpc.invoke()
        result = call.finish_and_wait()
        self.assertEqual(result.responses, [reply])

        self.assertEqual(result, call.finish_and_wait())
        self.assertEqual(result, call.finish_and_wait())

    def test_nonblocking_finish_after_error(self) -> None:
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply)
        self._enqueue_error(
            CLIENT_CHANNEL_ID,
            self.method.service,
            self.method,
            Status.UNAVAILABLE,
        )

        call = self.rpc.invoke()

        for _ in range(3):
            with self.assertRaises(callback_client.RpcError) as context:
                call.finish_and_wait()

            self.assertIs(context.exception.status, Status.UNAVAILABLE)
            self.assertIs(call.error, Status.UNAVAILABLE)
            self.assertEqual(list(call.responses), [reply])

    def test_nonblocking_duplicate_calls_not_cancelled(self) -> None:
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIs(first_call.error, None)
        self.assertIs(second_call.error, None)

    def test_max_responses(self) -> None:
        rep1 = self.method.response_type(payload='a')
        rep2 = self.method.response_type(payload='b')
        rep3 = self.method.response_type(payload='c')
        rep4 = self.method.response_type(payload='d')
        rep5 = self.method.response_type(payload='e')

        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1)
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep2)
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep3)
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep4)
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep5)
        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)

        responses: list = []
        call = self.rpc.invoke(
            on_next=lambda _, res, responses=responses: responses.append(res),
            max_responses=4,
        )
        result = call.finish_and_wait()

        # All 5 responses are received, but only the most recent 4 are stored
        # in the call.
        self.assertEqual(responses, [rep1, rep2, rep3, rep4, rep5])
        self.assertEqual(result.responses, [rep2, rep3, rep4, rep5])
        self.assertEqual(result.responses, list(call.responses))

    def test_stream_response(self) -> None:
        proto = PROTOS.packages.pw.test1.SomeMessage(magic_number=123)
        self.assertEqual(
            repr(callback_client.StreamResponse(Status.ABORTED, [proto] * 2)),
            '(Status.ABORTED, [pw.test1.SomeMessage(magic_number=123), '
            'pw.test1.SomeMessage(magic_number=123)])',
        )
        self.assertEqual(
            repr(callback_client.StreamResponse(Status.OK, [])),
            '(Status.OK, [])',
        )


if __name__ == '__main__':
    unittest.main()