• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# Copyright 2021 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
15"""Tests using the callback client for pw_rpc."""
16
import unittest
from unittest import mock
from typing import List, Optional, Tuple

from pw_protobuf_compiler import python_protos
from pw_status import Status

from pw_rpc import callback_client, client, packets
from pw_rpc.internal import packet_pb2
26
# Proto source compiled in CallbackClientImplTest.setUp (via
# python_protos.Library.from_strings) to build pw.test1.PublicService, the
# service every test calls. PublicService declares one RPC of each kind:
# unary, server streaming, client streaming, and bidirectional streaming.
TEST_PROTO_1 = """\
syntax = "proto3";

package pw.test1;

message SomeMessage {
  uint32 magic_number = 1;
}

message AnotherMessage {
  enum Result {
    FAILED = 0;
    FAILED_MISERABLY = 1;
    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
  }

  Result result = 1;
  string payload = 2;
}

service PublicService {
  rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
  rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
  rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
  rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""
54
55
def _rpc(method_stub):
    """Returns the client.PendingRpc identifying a call made through a stub.

    The PendingRpc is assembled from the stub's channel, the service that owns
    the stub's method, and the method itself, in that order.
    """
    rpc_method = method_stub.method
    return client.PendingRpc(method_stub.channel, rpc_method.service,
                             rpc_method)
59
60
class CallbackClientImplTest(unittest.TestCase):
    """Tests the callback_client as used within a pw_rpc Client."""
    def setUp(self):
        self._protos = python_protos.Library.from_strings(TEST_PROTO_1)
        self._request = self._protos.packages.pw.test1.SomeMessage

        # Packets the client sends on channel 1 are routed to _handle_request,
        # which plays the role of the server and replays queued packets.
        self._client = client.Client.from_modules(
            callback_client.Impl(), [client.Channel(1, self._handle_request)],
            self._protos.modules())
        self._service = self._client.channel(1).rpcs.pw.test1.PublicService

        # Most recent packet the client sent, decoded; None until one is sent.
        self._last_request: Optional[packet_pb2.RpcPacket] = None
        # (encoded server packet, status process_packet should return) pairs
        # delivered on the next client request.
        self._next_packets: List[Tuple[bytes, Status]] = []
        # Re-entrancy guard for _handle_request.
        self._send_responses_on_request = True

    def _enqueue_packet(self,
                        packet_type,
                        channel_id: int,
                        service_id: int,
                        method_id: int,
                        status: Status,
                        process_status: Status,
                        payload: Optional[bytes] = None) -> None:
        """Queues a server packet to deliver on the next client request.

        Args:
          packet_type: packet_pb2.PacketType value for the packet.
          channel_id: channel ID written into the packet.
          service_id: service ID written into the packet.
          method_id: method ID written into the packet.
          status: Status whose value is written into the packet.
          process_status: Status that Client.process_packet is expected to
              return when this packet is delivered.
          payload: serialized payload; the field is left unset when None.
        """
        fields = {
            'type': packet_type,
            'channel_id': channel_id,
            'service_id': service_id,
            'method_id': method_id,
            'status': status.value,
        }
        # Only set payload when one is given so packets that never carried a
        # payload are encoded exactly as before.
        if payload is not None:
            fields['payload'] = payload

        self._next_packets.append(
            (packet_pb2.RpcPacket(**fields).SerializeToString(),
             process_status))

    def _enqueue_response(self,
                          channel_id: int,
                          method=None,
                          status: Status = Status.OK,
                          response=b'',
                          *,
                          ids: Optional[Tuple[int, int]] = None,
                          process_status=Status.OK):
        """Queues a RESPONSE packet.

        Exactly one of method or ids must be provided. ids is a
        (service_id, method_id) pair, which lets tests address services or
        methods that do not exist. response is either raw bytes or a protobuf
        message, which is serialized.
        """
        if method is not None:
            assert ids is None, 'Pass either method or ids, not both'
            service_id, method_id = method.service.id, method.id
        else:
            assert ids is not None, 'Either method or ids is required'
            service_id, method_id = ids

        if isinstance(response, bytes):
            payload = response
        else:
            payload = response.SerializeToString()

        self._enqueue_packet(packet_pb2.PacketType.RESPONSE, channel_id,
                             service_id, method_id, status, process_status,
                             payload)

    def _enqueue_stream_end(self,
                            channel_id: int,
                            method,
                            status: Status = Status.OK,
                            process_status=Status.OK):
        """Queues a SERVER_STREAM_END packet for method."""
        self._enqueue_packet(packet_pb2.PacketType.SERVER_STREAM_END,
                             channel_id, method.service.id, method.id, status,
                             process_status)

    def _enqueue_error(self,
                       channel_id: int,
                       method,
                       status: Status,
                       process_status=Status.OK):
        """Queues a SERVER_ERROR packet for method."""
        self._enqueue_packet(packet_pb2.PacketType.SERVER_ERROR, channel_id,
                             method.service.id, method.id, status,
                             process_status)

    def _handle_request(self, data: bytes) -> None:
        """Channel output: records the request and delivers queued packets.

        Each queued packet is passed to Client.process_packet, and the returned
        status is checked against the expected one.
        """
        # Disable this method to prevent infinite recursion if processing the
        # packet happens to send another packet.
        if not self._send_responses_on_request:
            return

        self._send_responses_on_request = False
        try:
            self._last_request = packets.decode(data)

            for packet, status in self._next_packets:
                self.assertIs(status, self._client.process_packet(packet))
        finally:
            # Always reset, even if a check above failed; otherwise every later
            # request in the test would be silently ignored.
            self._next_packets.clear()
            self._send_responses_on_request = True

    def _sent_payload(self, message_type):
        """Parses the last request's payload as a message_type message."""
        self.assertIsNotNone(self._last_request)
        message = message_type()
        message.ParseFromString(self._last_request.payload)
        return message

    def test_invoke_unary_rpc(self):
        method = self._service.SomeUnary.method

        for _ in range(3):
            self._enqueue_response(1, method, Status.ABORTED,
                                   method.response_type(payload='0_o'))

            status, response = self._service.SomeUnary(
                method.request_type(magic_number=6))

            self.assertEqual(
                6,
                self._sent_payload(method.request_type).magic_number)

            self.assertIs(Status.ABORTED, status)
            self.assertEqual('0_o', response.payload)

    def test_invoke_unary_rpc_keep_open(self) -> None:
        method = self._service.SomeUnary.method

        payload_1 = method.response_type(payload='-_-')
        payload_2 = method.response_type(payload='0_o')

        self._enqueue_response(1, method, Status.ABORTED, payload_1)

        replies: List[object] = []

        def enqueue_replies(_, reply) -> None:
            replies.append(reply)

        self._service.SomeUnary.invoke(method.request_type(magic_number=6),
                                       enqueue_replies,
                                       enqueue_replies,
                                       keep_open=True)

        self.assertEqual([payload_1, Status.ABORTED], replies)

        # Send another packet and make sure it is processed even though the RPC
        # terminated.
        self._client.process_packet(
            packet_pb2.RpcPacket(
                type=packet_pb2.PacketType.RESPONSE,
                channel_id=1,
                service_id=method.service.id,
                method_id=method.id,
                status=Status.OK.value,
                payload=payload_2.SerializeToString()).SerializeToString())

        self.assertEqual([payload_1, Status.ABORTED, payload_2, Status.OK],
                         replies)

    def test_invoke_unary_rpc_with_callback(self):
        method = self._service.SomeUnary.method

        for _ in range(3):
            self._enqueue_response(1, method, Status.ABORTED,
                                   method.response_type(payload='0_o'))

            callback = mock.Mock()
            self._service.SomeUnary.invoke(self._request(magic_number=5),
                                           callback, callback)

            callback.assert_has_calls([
                mock.call(_rpc(self._service.SomeUnary),
                          method.response_type(payload='0_o')),
                mock.call(_rpc(self._service.SomeUnary), Status.ABORTED)
            ])

            self.assertEqual(
                5,
                self._sent_payload(method.request_type).magic_number)

    def test_unary_rpc_server_error(self):
        method = self._service.SomeUnary.method

        for _ in range(3):
            self._enqueue_error(1, method, Status.NOT_FOUND)

            with self.assertRaises(callback_client.RpcError) as context:
                self._service.SomeUnary(method.request_type(magic_number=6))

            self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_invoke_unary_rpc_callback_exceptions_suppressed(self):
        stub = self._service.SomeUnary

        self._enqueue_response(1, stub.method)
        exception_msg = 'YOU BROKE IT O-]-<'

        with self.assertLogs(callback_client.__name__, 'ERROR') as logs:
            stub.invoke(self._request(),
                        mock.Mock(side_effect=Exception(exception_msg)))

        self.assertIn(exception_msg, ''.join(logs.output))

        # Make sure we can still invoke the RPC.
        self._enqueue_response(1, stub.method, Status.UNKNOWN)
        status, _ = stub()
        self.assertIs(status, Status.UNKNOWN)

    def test_invoke_unary_rpc_with_callback_cancel(self):
        callback = mock.Mock()

        for _ in range(3):
            call = self._service.SomeUnary.invoke(
                self._request(magic_number=55), callback)

            self.assertIsNotNone(self._last_request)
            self._last_request = None

            # Try to invoke the RPC again before cancelling, without overriding
            # pending RPCs.
            with self.assertRaises(client.Error):
                self._service.SomeUnary.invoke(self._request(magic_number=56),
                                               callback,
                                               override_pending=False)

            self.assertTrue(call.cancel())
            self.assertFalse(call.cancel())  # Already cancelled, returns False

            # Unary RPCs do not send a cancel request to the server.
            self.assertIsNone(self._last_request)

        callback.assert_not_called()

    def test_reinvoke_unary_rpc(self):
        for _ in range(3):
            self._last_request = None
            self._service.SomeUnary.invoke(self._request(magic_number=55),
                                           override_pending=True)
            self.assertEqual(self._last_request.type,
                             packet_pb2.PacketType.REQUEST)

    def test_invoke_server_streaming(self):
        method = self._service.SomeServerStreaming.method

        rep1 = method.response_type(payload='!!!')
        rep2 = method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_response(1, method, response=rep1)
            self._enqueue_response(1, method, response=rep2)
            self._enqueue_stream_end(1, method, Status.ABORTED)

            self.assertEqual(
                [rep1, rep2],
                list(self._service.SomeServerStreaming(magic_number=4)))

            self.assertEqual(
                4,
                self._sent_payload(method.request_type).magic_number)

    def test_invoke_server_streaming_with_callbacks(self):
        method = self._service.SomeServerStreaming.method

        rep1 = method.response_type(payload='!!!')
        rep2 = method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_response(1, method, response=rep1)
            self._enqueue_response(1, method, response=rep2)
            self._enqueue_stream_end(1, method, Status.ABORTED)

            callback = mock.Mock()
            self._service.SomeServerStreaming.invoke(
                self._request(magic_number=3), callback, callback)

            rpc = _rpc(self._service.SomeServerStreaming)
            callback.assert_has_calls([
                mock.call(rpc, method.response_type(payload='!!!')),
                mock.call(rpc, method.response_type(payload='?')),
                mock.call(rpc, Status.ABORTED),
            ])

            self.assertEqual(
                3,
                self._sent_payload(method.request_type).magic_number)

    def test_invoke_server_streaming_with_callback_cancel(self):
        stub = self._service.SomeServerStreaming

        resp = stub.method.response_type(payload='!!!')
        self._enqueue_response(1, stub.method, response=resp)

        callback = mock.Mock()
        call = stub.invoke(self._request(magic_number=3), callback)
        callback.assert_called_once_with(
            _rpc(stub), stub.method.response_type(payload='!!!'))

        callback.reset_mock()

        call.cancel()

        self.assertEqual(self._last_request.type,
                         packet_pb2.PacketType.CANCEL_SERVER_STREAM)

        # Ensure the RPC can be called after being cancelled.
        self._enqueue_response(1, stub.method, response=resp)
        self._enqueue_stream_end(1, stub.method, Status.OK)

        stub.invoke(self._request(magic_number=3), callback, callback)

        callback.assert_has_calls([
            mock.call(_rpc(stub), stub.method.response_type(payload='!!!')),
            mock.call(_rpc(stub), Status.OK),
        ])

    def test_ignore_bad_packets_with_pending_rpc(self):
        method = self._service.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_response(999, method, process_status=Status.NOT_FOUND)
        # Bad service
        self._enqueue_response(1,
                               ids=(999, method.id),
                               process_status=Status.OK)
        # Bad method
        self._enqueue_response(1,
                               ids=(service_id, 999),
                               process_status=Status.OK)
        # For an RPC that is not pending (Status.OK because the packet is
        # still processed)
        self._enqueue_response(1,
                               ids=(service_id,
                                    self._service.SomeBidiStreaming.method.id),
                               process_status=Status.OK)

        self._enqueue_response(1, method, process_status=Status.OK)

        status, response = self._service.SomeUnary(magic_number=6)
        self.assertIs(Status.OK, status)
        self.assertEqual('', response.payload)

    def test_pass_none_if_payload_fails_to_decode(self):
        method = self._service.SomeUnary.method

        self._enqueue_response(1,
                               method,
                               Status.OK,
                               b'INVALID DATA!!!',
                               process_status=Status.OK)

        status, response = self._service.SomeUnary(magic_number=6)
        self.assertIs(status, Status.OK)
        self.assertIsNone(response)

    def test_rpc_help_contains_method_name(self):
        rpc = self._service.SomeUnary
        self.assertIn(rpc.method.full_name, rpc.help())

    def test_default_timeouts_set_on_impl(self):
        impl = callback_client.Impl(None, 1.5)

        self.assertIsNone(impl.default_unary_timeout_s)
        self.assertEqual(impl.default_stream_timeout_s, 1.5)

    def test_default_timeouts_set_for_all_rpcs(self):
        rpc_client = client.Client.from_modules(
            callback_client.Impl(99, 100),
            [client.Channel(1, lambda *a, **b: None)],
            self._protos.modules())
        rpcs = rpc_client.channel(1).rpcs

        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeUnary.default_timeout_s, 99)
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeServerStreaming.default_timeout_s,
            100)

    def test_timeout_unary(self):
        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary(pw_rpc_timeout_s=0.0001)

    def test_timeout_unary_set_default(self):
        self._service.SomeUnary.default_timeout_s = 0.0001

        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary()

    def test_timeout_server_streaming_iteration(self):
        responses = self._service.SomeServerStreaming(pw_rpc_timeout_s=0.0001)
        with self.assertRaises(callback_client.RpcTimeout):
            for _ in responses:
                pass

    def test_timeout_server_streaming_responses(self):
        responses = self._service.SomeServerStreaming()
        with self.assertRaises(callback_client.RpcTimeout):
            for _ in responses.responses(timeout_s=0.0001):
                pass
440
441
if __name__ == '__main__':
    # Run this module's tests when the file is executed directly.
    unittest.main()
444