#!/usr/bin/env python3
# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests using the callback client for pw_rpc."""

import unittest
from unittest import mock
from typing import List, Optional, Tuple

from pw_protobuf_compiler import python_protos
from pw_status import Status

from pw_rpc import callback_client, client, packets
from pw_rpc.internal import packet_pb2

TEST_PROTO_1 = """\
syntax = "proto3";

package pw.test1;

message SomeMessage {
  uint32 magic_number = 1;
}

message AnotherMessage {
  enum Result {
    FAILED = 0;
    FAILED_MISERABLY = 1;
    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
  }

  Result result = 1;
  string payload = 2;
}

service PublicService {
  rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
  rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
  rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
  rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""


def _rpc(method_stub):
    """Returns the client.PendingRpc that identifies a method stub's call."""
    return client.PendingRpc(method_stub.channel, method_stub.method.service,
                             method_stub.method)


class CallbackClientImplTest(unittest.TestCase):
    """Tests the callback_client as used within a pw_rpc Client."""
    def setUp(self):
        """Creates a client on channel 1 that sends via _handle_request."""
        self._protos = python_protos.Library.from_strings(TEST_PROTO_1)
        self._request = self._protos.packages.pw.test1.SomeMessage

        self._client = client.Client.from_modules(
            callback_client.Impl(), [client.Channel(1, self._handle_request)],
            self._protos.modules())
        self._service = self._client.channel(1).rpcs.pw.test1.PublicService

        # Most recent packet the client sent; None until something is sent.
        self._last_request: Optional[packet_pb2.RpcPacket] = None
        # (serialized packet, expected process_packet status) pairs that are
        # fed back to the client the next time it sends a packet.
        self._next_packets: List[Tuple[bytes, Status]] = []
        self._send_responses_on_request = True

    def _enqueue_packet(self, process_status: Status, **fields) -> None:
        """Queues a serialized RpcPacket built from the given fields.

        Args:
          process_status: status that Client.process_packet is expected to
              return when _handle_request delivers the packet
          **fields: keyword arguments forwarded to packet_pb2.RpcPacket
        """
        self._next_packets.append(
            (packet_pb2.RpcPacket(**fields).SerializeToString(),
             process_status))

    def _enqueue_response(self,
                          channel_id: int,
                          method=None,
                          status: Status = Status.OK,
                          response=b'',
                          *,
                          ids: Optional[Tuple[int, int]] = None,
                          process_status=Status.OK):
        """Queues a RESPONSE packet to deliver on the next client request.

        Exactly one of method or ids must be provided.

        Args:
          channel_id: channel ID to put in the packet
          method: method whose service and method IDs are used
          status: RPC status to put in the packet
          response: payload, either raw bytes or a protobuf message
          ids: explicit (service_id, method_id) pair, used instead of method
          process_status: status process_packet is expected to return
        """
        if method:
            assert ids is None
            service_id, method_id = method.service.id, method.id
        else:
            assert ids is not None and method is None
            service_id, method_id = ids

        if isinstance(response, bytes):
            payload = response
        else:
            payload = response.SerializeToString()

        self._enqueue_packet(process_status,
                             type=packet_pb2.PacketType.RESPONSE,
                             channel_id=channel_id,
                             service_id=service_id,
                             method_id=method_id,
                             status=status.value,
                             payload=payload)

    def _enqueue_stream_end(self,
                            channel_id: int,
                            method,
                            status: Status = Status.OK,
                            process_status=Status.OK):
        """Queues a SERVER_STREAM_END packet for the given method."""
        self._enqueue_packet(process_status,
                             type=packet_pb2.PacketType.SERVER_STREAM_END,
                             channel_id=channel_id,
                             service_id=method.service.id,
                             method_id=method.id,
                             status=status.value)

    def _enqueue_error(self,
                       channel_id: int,
                       method,
                       status: Status,
                       process_status=Status.OK):
        """Queues a SERVER_ERROR packet for the given method."""
        self._enqueue_packet(process_status,
                             type=packet_pb2.PacketType.SERVER_ERROR,
                             channel_id=channel_id,
                             service_id=method.service.id,
                             method_id=method.id,
                             status=status.value)

    def _handle_request(self, data: bytes):
        """Channel output: records the request and delivers queued packets.

        Each queued packet is passed to the client's process_packet, and the
        returned status is checked against the expected one.
        """
        # Disable this method to prevent infinite recursion if processing the
        # packet happens to send another packet.
        if not self._send_responses_on_request:
            return

        self._send_responses_on_request = False

        try:
            self._last_request = packets.decode(data)

            for packet, status in self._next_packets:
                self.assertIs(status, self._client.process_packet(packet))
        finally:
            # Always restore the fixture, even if decoding or an assertion
            # failed, so that later requests are not silently ignored.
            self._next_packets.clear()
            self._send_responses_on_request = True

    def _sent_payload(self, message_type):
        """Decodes the payload of the last sent request as message_type."""
        self.assertIsNotNone(self._last_request)
        message = message_type()
        message.ParseFromString(self._last_request.payload)
        return message

    def test_invoke_unary_rpc(self):
        """A blocking unary call returns the queued status and response."""
        method = self._service.SomeUnary.method

        for _ in range(3):
            self._enqueue_response(1, method, Status.ABORTED,
                                   method.response_type(payload='0_o'))

            status, response = self._service.SomeUnary(
                method.request_type(magic_number=6))

            self.assertEqual(
                6,
                self._sent_payload(method.request_type).magic_number)

            self.assertIs(Status.ABORTED, status)
            self.assertEqual('0_o', response.payload)

    def test_invoke_unary_rpc_keep_open(self) -> None:
        """With keep_open=True, packets after completion are still handled."""
        method = self._service.SomeUnary.method

        payload_1 = method.response_type(payload='-_-')
        payload_2 = method.response_type(payload='0_o')

        self._enqueue_response(1, method, Status.ABORTED, payload_1)

        replies: list = []

        def enqueue_replies(_, reply) -> None:
            replies.append(reply)

        self._service.SomeUnary.invoke(method.request_type(magic_number=6),
                                       enqueue_replies,
                                       enqueue_replies,
                                       keep_open=True)

        self.assertEqual([payload_1, Status.ABORTED], replies)

        # Send another packet and make sure it is processed even though the RPC
        # terminated.
        self._client.process_packet(
            packet_pb2.RpcPacket(
                type=packet_pb2.PacketType.RESPONSE,
                channel_id=1,
                service_id=method.service.id,
                method_id=method.id,
                status=Status.OK.value,
                payload=payload_2.SerializeToString()).SerializeToString())

        self.assertEqual([payload_1, Status.ABORTED, payload_2, Status.OK],
                         replies)

    def test_invoke_unary_rpc_with_callback(self):
        """invoke() passes the response and then the status to callbacks."""
        method = self._service.SomeUnary.method

        for _ in range(3):
            self._enqueue_response(1, method, Status.ABORTED,
                                   method.response_type(payload='0_o'))

            callback = mock.Mock()
            self._service.SomeUnary.invoke(self._request(magic_number=5),
                                           callback, callback)

            callback.assert_has_calls([
                mock.call(_rpc(self._service.SomeUnary),
                          method.response_type(payload='0_o')),
                mock.call(_rpc(self._service.SomeUnary), Status.ABORTED)
            ])

            self.assertEqual(
                5,
                self._sent_payload(method.request_type).magic_number)

    def test_unary_rpc_server_error(self):
        """A SERVER_ERROR packet makes the blocking call raise RpcError."""
        method = self._service.SomeUnary.method

        for _ in range(3):
            self._enqueue_error(1, method, Status.NOT_FOUND)

            with self.assertRaises(callback_client.RpcError) as context:
                self._service.SomeUnary(method.request_type(magic_number=6))

            self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_invoke_unary_rpc_callback_exceptions_suppressed(self):
        """An exception raised by a callback is logged, not propagated."""
        stub = self._service.SomeUnary

        self._enqueue_response(1, stub.method)
        exception_msg = 'YOU BROKE IT O-]-<'

        with self.assertLogs(callback_client.__name__, 'ERROR') as logs:
            stub.invoke(self._request(),
                        mock.Mock(side_effect=Exception(exception_msg)))

        self.assertIn(exception_msg, ''.join(logs.output))

        # Make sure we can still invoke the RPC.
        self._enqueue_response(1, stub.method, Status.UNKNOWN)
        status, _ = stub()
        self.assertIs(status, Status.UNKNOWN)

    def test_invoke_unary_rpc_with_callback_cancel(self):
        """Cancelling a pending unary call sends nothing and calls nothing."""
        callback = mock.Mock()

        for _ in range(3):
            call = self._service.SomeUnary.invoke(
                self._request(magic_number=55), callback)

            self.assertIsNotNone(self._last_request)
            self._last_request = None

            # Try to invoke the RPC again before cancelling, without overriding
            # pending RPCs.
            with self.assertRaises(client.Error):
                self._service.SomeUnary.invoke(self._request(magic_number=56),
                                               callback,
                                               override_pending=False)

            self.assertTrue(call.cancel())
            self.assertFalse(call.cancel())  # Already cancelled, returns False

            # Unary RPCs do not send a cancel request to the server.
            self.assertIsNone(self._last_request)

        callback.assert_not_called()

    def test_reinvoke_unary_rpc(self):
        """A pending unary RPC can be re-invoked with override_pending=True."""
        for _ in range(3):
            self._last_request = None
            self._service.SomeUnary.invoke(self._request(magic_number=55),
                                           override_pending=True)
            self.assertEqual(self._last_request.type,
                             packet_pb2.PacketType.REQUEST)

    def test_invoke_server_streaming(self):
        """Iterating a blocking server streaming call yields each response."""
        method = self._service.SomeServerStreaming.method

        rep1 = method.response_type(payload='!!!')
        rep2 = method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_response(1, method, response=rep1)
            self._enqueue_response(1, method, response=rep2)
            self._enqueue_stream_end(1, method, Status.ABORTED)

            self.assertEqual(
                [rep1, rep2],
                list(self._service.SomeServerStreaming(magic_number=4)))

            self.assertEqual(
                4,
                self._sent_payload(method.request_type).magic_number)

    def test_invoke_server_streaming_with_callbacks(self):
        """Callbacks receive each streamed response and the final status."""
        method = self._service.SomeServerStreaming.method

        rep1 = method.response_type(payload='!!!')
        rep2 = method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_response(1, method, response=rep1)
            self._enqueue_response(1, method, response=rep2)
            self._enqueue_stream_end(1, method, Status.ABORTED)

            callback = mock.Mock()
            self._service.SomeServerStreaming.invoke(
                self._request(magic_number=3), callback, callback)

            rpc = _rpc(self._service.SomeServerStreaming)
            callback.assert_has_calls([
                mock.call(rpc, method.response_type(payload='!!!')),
                mock.call(rpc, method.response_type(payload='?')),
                mock.call(rpc, Status.ABORTED),
            ])

            self.assertEqual(
                3,
                self._sent_payload(method.request_type).magic_number)

    def test_invoke_server_streaming_with_callback_cancel(self):
        """Cancelling a stream sends CANCEL_SERVER_STREAM; it can be re-run."""
        stub = self._service.SomeServerStreaming

        resp = stub.method.response_type(payload='!!!')
        self._enqueue_response(1, stub.method, response=resp)

        callback = mock.Mock()
        call = stub.invoke(self._request(magic_number=3), callback)
        callback.assert_called_once_with(
            _rpc(stub), stub.method.response_type(payload='!!!'))

        callback.reset_mock()

        call.cancel()

        self.assertEqual(self._last_request.type,
                         packet_pb2.PacketType.CANCEL_SERVER_STREAM)

        # Ensure the RPC can be called after being cancelled.
        self._enqueue_response(1, stub.method, response=resp)
        self._enqueue_stream_end(1, stub.method, Status.OK)

        stub.invoke(self._request(magic_number=3), callback, callback)

        callback.assert_has_calls([
            mock.call(_rpc(stub), stub.method.response_type(payload='!!!')),
            mock.call(_rpc(stub), Status.OK),
        ])

    def test_ignore_bad_packets_with_pending_rpc(self):
        """Packets that do not match the pending RPC do not complete it."""
        method = self._service.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_response(999, method, process_status=Status.NOT_FOUND)
        # Bad service
        self._enqueue_response(1,
                               ids=(999, method.id),
                               process_status=Status.OK)
        # Bad method
        self._enqueue_response(1,
                               ids=(service_id, 999),
                               process_status=Status.OK)
        # For RPC not pending (is Status.OK because the packet is processed)
        self._enqueue_response(1,
                               ids=(service_id,
                                    self._service.SomeBidiStreaming.method.id),
                               process_status=Status.OK)

        self._enqueue_response(1, method, process_status=Status.OK)

        status, response = self._service.SomeUnary(magic_number=6)
        self.assertIs(Status.OK, status)
        self.assertEqual('', response.payload)

    def test_pass_none_if_payload_fails_to_decode(self):
        """An undecodable payload is reported to the caller as None."""
        method = self._service.SomeUnary.method

        self._enqueue_response(1,
                               method,
                               Status.OK,
                               b'INVALID DATA!!!',
                               process_status=Status.OK)

        status, response = self._service.SomeUnary(magic_number=6)
        self.assertIs(status, Status.OK)
        self.assertIsNone(response)

    def test_rpc_help_contains_method_name(self):
        """help() output for an RPC includes the method's full name."""
        rpc = self._service.SomeUnary
        self.assertIn(rpc.method.full_name, rpc.help())

    def test_default_timeouts_set_on_impl(self):
        """Impl stores the unary and stream default timeouts it is given."""
        impl = callback_client.Impl(None, 1.5)

        self.assertIsNone(impl.default_unary_timeout_s)
        self.assertEqual(impl.default_stream_timeout_s, 1.5)

    def test_default_timeouts_set_for_all_rpcs(self):
        """Default timeouts given to Impl are applied to the RPC stubs."""
        rpc_client = client.Client.from_modules(
            callback_client.Impl(99, 100),
            [client.Channel(1, lambda *a, **b: None)], self._protos.modules())
        rpcs = rpc_client.channel(1).rpcs

        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeUnary.default_timeout_s, 99)
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeServerStreaming.default_timeout_s,
            100)

    def test_timeout_unary(self):
        """A unary call with no response raises RpcTimeout."""
        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary(pw_rpc_timeout_s=0.0001)

    def test_timeout_unary_set_default(self):
        """The stub's default_timeout_s is used when no timeout is passed."""
        self._service.SomeUnary.default_timeout_s = 0.0001

        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary()

    def test_timeout_server_streaming_iteration(self):
        """Iterating a stream with no responses raises RpcTimeout."""
        responses = self._service.SomeServerStreaming(pw_rpc_timeout_s=0.0001)
        with self.assertRaises(callback_client.RpcTimeout):
            for _ in responses:
                pass

    def test_timeout_server_streaming_responses(self):
        """responses(timeout_s=...) raises RpcTimeout when nothing arrives."""
        responses = self._service.SomeServerStreaming()
        with self.assertRaises(callback_client.RpcTimeout):
            for _ in responses.responses(timeout_s=0.0001):
                pass


if __name__ == '__main__':
    unittest.main()