• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# Copyright 2021 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
15"""Tests using the callback client for pw_rpc."""
16
17import unittest
18from unittest import mock
19from typing import Any, List, Optional, Tuple
20
21from pw_protobuf_compiler import python_protos
22from pw_status import Status
23
24from pw_rpc import callback_client, client, packets
25from pw_rpc.internal import packet_pb2
26
# Proto source loaded with python_protos.Library.from_strings() by the test
# base class. PublicService declares one method of each RPC kind: unary,
# server streaming, client streaming, and bidirectional streaming.
TEST_PROTO_1 = """\
syntax = "proto3";

package pw.test1;

message SomeMessage {
  uint32 magic_number = 1;
}

message AnotherMessage {
  enum Result {
    FAILED = 0;
    FAILED_MISERABLY = 1;
    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
  }

  Result result = 1;
  string payload = 2;
}

service PublicService {
  rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
  rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
  rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
  rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""
54
55
56def _message_bytes(msg) -> bytes:
57    return msg if isinstance(msg, bytes) else msg.SerializeToString()
58
59
class _CallbackClientImplTestBase(unittest.TestCase):
    """Supports writing tests that require responses from an RPC server.

    Packets the client sends on channel 1 are decoded into ``self.requests``.
    Server packets are queued with the ``_enqueue_*`` helpers and handed to
    the client either from the channel output (see
    ``send_responses_after_packets``) or explicitly with
    ``_process_enqueued_packets``.
    """
    def setUp(self) -> None:
        self._protos = python_protos.Library.from_strings(TEST_PROTO_1)
        self._request = self._protos.packages.pw.test1.SomeMessage

        self._client = client.Client.from_modules(
            callback_client.Impl(), [client.Channel(1, self._handle_packet)],
            self._protos.modules())
        self._service = self._client.channel(1).rpcs.pw.test1.PublicService

        # Decoded copies of every packet the client has sent.
        self.requests: List[packet_pb2.RpcPacket] = []
        # Serialized packets to feed to the client, each paired with the
        # status Client.process_packet() is expected to return for it.
        self._next_packets: List[Tuple[bytes, Status]] = []
        # While this is above 1, each sent packet decrements it instead of
        # delivering queued packets; once it is at most 1, every sent packet
        # triggers delivery. Typed float so delivery can be suspended by
        # setting it to infinity.
        self.send_responses_after_packets: float = 1

        # When set, raised by the channel output instead of recording packets.
        self.output_exception: Optional[Exception] = None

    def last_request(self) -> packet_pb2.RpcPacket:
        """Returns the most recently sent packet; fails if none was sent."""
        assert self.requests
        return self.requests[-1]

    def _enqueue_response(self,
                          channel_id: int,
                          method=None,
                          status: Status = Status.OK,
                          payload=b'',
                          *,
                          ids: Optional[Tuple[int, int]] = None,
                          process_status=Status.OK) -> None:
        """Queues a RESPONSE packet for delivery to the client.

        Args:
          channel_id: Channel ID placed in the packet.
          method: Method being responded to. Mutually exclusive with ids.
          status: Status carried by the packet.
          payload: Response message, or already-serialized bytes.
          ids: Explicit (service_id, method_id) pair, e.g. for addressing IDs
              that are not registered. Mutually exclusive with method.
          process_status: Status process_packet() is expected to return.
        """
        if method:
            assert ids is None
            service_id, method_id = method.service.id, method.id
        else:
            assert ids is not None and method is None
            service_id, method_id = ids

        packet = packet_pb2.RpcPacket(
            type=packet_pb2.PacketType.RESPONSE,
            channel_id=channel_id,
            service_id=service_id,
            method_id=method_id,
            status=status.value,
            payload=_message_bytes(payload))
        self._next_packets.append((packet.SerializeToString(), process_status))

    def _enqueue_server_stream(self,
                               channel_id: int,
                               method,
                               response,
                               process_status=Status.OK) -> None:
        """Queues a SERVER_STREAM packet whose payload is response.

        response may be a message or already-serialized bytes; process_status
        is the status process_packet() is expected to return.
        """
        packet = packet_pb2.RpcPacket(
            type=packet_pb2.PacketType.SERVER_STREAM,
            channel_id=channel_id,
            service_id=method.service.id,
            method_id=method.id,
            payload=_message_bytes(response))
        self._next_packets.append((packet.SerializeToString(), process_status))

    def _enqueue_error(self,
                       channel_id: int,
                       service,
                       method,
                       status: Status,
                       process_status=Status.OK) -> None:
        """Queues a SERVER_ERROR packet carrying status.

        service and method may each be an object with an ``id`` attribute or
        a raw integer ID.
        """
        packet = packet_pb2.RpcPacket(
            type=packet_pb2.PacketType.SERVER_ERROR,
            channel_id=channel_id,
            service_id=service if isinstance(service, int) else service.id,
            method_id=method if isinstance(method, int) else method.id,
            status=status.value)
        self._next_packets.append((packet.SerializeToString(), process_status))

    def _handle_packet(self, data: bytes) -> None:
        """Channel output: records each sent packet and may deliver replies."""
        if self.output_exception:
            raise self.output_exception  # pylint: disable=raising-bad-type

        self.requests.append(packets.decode(data))

        if self.send_responses_after_packets > 1:
            self.send_responses_after_packets -= 1
            return

        self._process_enqueued_packets()

    def _process_enqueued_packets(self) -> None:
        """Delivers every queued packet to the client and checks its status."""
        # Set send_responses_after_packets to infinity to prevent potential
        # infinite recursion when a packet causes another packet to send.
        send_after_count = self.send_responses_after_packets
        self.send_responses_after_packets = float('inf')

        # Clean up even if an assertion fails or process_packet() raises, so
        # delivery is not left disabled and delivered packets are not replayed.
        try:
            for packet, status in self._next_packets:
                self.assertIs(status, self._client.process_packet(packet))
        finally:
            self._next_packets.clear()
            self.send_responses_after_packets = send_after_count

    def _sent_payload(self, message_type: type) -> Any:
        """Parses the last sent packet's payload as a message_type message."""
        message = message_type()
        message.ParseFromString(self.last_request().payload)
        return message
159
160
class CallbackClientImplTest(_CallbackClientImplTestBase):
    """Tests the callback_client.Impl client implementation."""
    def test_callback_exceptions_suppressed(self) -> None:
        stub = self._service.SomeUnary

        self._enqueue_response(1, stub.method)
        exception_msg = 'YOU BROKE IT O-]-<'

        # A callback that raises is logged instead of propagating from invoke.
        with self.assertLogs(callback_client.__package__, 'ERROR') as logs:
            stub.invoke(self._request(),
                        mock.Mock(side_effect=Exception(exception_msg)))

        self.assertIn(exception_msg, ''.join(logs.output))

        # Make sure we can still invoke the RPC.
        self._enqueue_response(1, stub.method, Status.UNKNOWN)
        status, _ = stub()
        self.assertIs(status, Status.UNKNOWN)

    def test_ignore_bad_packets_with_pending_rpc(self) -> None:
        method = self._service.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_response(999, method, process_status=Status.NOT_FOUND)
        # Bad service
        self._enqueue_response(1,
                               ids=(999, method.id),
                               process_status=Status.OK)
        # Bad method
        self._enqueue_response(1,
                               ids=(service_id, 999),
                               process_status=Status.OK)
        # For RPC not pending (is Status.OK because the packet is processed)
        self._enqueue_response(1,
                               ids=(service_id,
                                    self._service.SomeBidiStreaming.method.id),
                               process_status=Status.OK)

        self._enqueue_response(1, method, process_status=Status.OK)

        status, response = self._service.SomeUnary(magic_number=6)
        self.assertIs(Status.OK, status)
        self.assertEqual('', response.payload)

    def test_server_error_for_unknown_call_sends_no_errors(self) -> None:
        method = self._service.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_error(999,
                            service_id,
                            method,
                            Status.NOT_FOUND,
                            process_status=Status.NOT_FOUND)
        # Bad service
        self._enqueue_error(1, 999, method.id, Status.INVALID_ARGUMENT)
        # Bad method
        self._enqueue_error(1, service_id, 999, Status.INVALID_ARGUMENT)
        # For RPC not pending
        self._enqueue_error(1, service_id,
                            self._service.SomeBidiStreaming.method.id,
                            Status.NOT_FOUND)

        self._process_enqueued_packets()

        self.assertEqual(self.requests, [])

    def test_exception_if_payload_fails_to_decode(self) -> None:
        method = self._service.SomeUnary.method

        self._enqueue_response(1,
                               method,
                               Status.OK,
                               b'INVALID DATA!!!',
                               process_status=Status.OK)

        with self.assertRaises(callback_client.RpcError) as context:
            self._service.SomeUnary(magic_number=6)

        self.assertIs(context.exception.status, Status.DATA_LOSS)

    def test_rpc_help_contains_method_name(self) -> None:
        rpc = self._service.SomeUnary
        self.assertIn(rpc.method.full_name, rpc.help())

    def test_default_timeouts_set_on_impl(self) -> None:
        impl = callback_client.Impl(None, 1.5)

        self.assertEqual(impl.default_unary_timeout_s, None)
        self.assertEqual(impl.default_stream_timeout_s, 1.5)

    def test_default_timeouts_set_for_all_rpcs(self) -> None:
        rpc_client = client.Client.from_modules(
            callback_client.Impl(99, 100),
            [client.Channel(1, lambda *a, **b: None)], self._protos.modules())
        rpcs = rpc_client.channel(1).rpcs

        # Unary and client streaming RPCs get the unary default (99); server
        # and bidirectional streaming RPCs get the stream default (100).
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeUnary.default_timeout_s, 99)
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeServerStreaming.default_timeout_s,
            100)
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeClientStreaming.default_timeout_s,
            99)
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeBidiStreaming.default_timeout_s,
            100)

    def test_rpc_provides_request_type(self) -> None:
        self.assertIs(self._service.SomeUnary.request,
                      self._service.SomeUnary.method.request_type)

    def test_rpc_provides_response_type(self) -> None:
        # Previously a copy of the request test; check the response type.
        self.assertIs(self._service.SomeUnary.response,
                      self._service.SomeUnary.method.response_type)
278
279
class UnaryTest(_CallbackClientImplTestBase):
    """Exercises the SomeUnary RPC through the callback client."""
    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeUnary
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        for _ in range(3):
            reply = self.method.response_type(payload='0_o')
            self._enqueue_response(1, self.method, Status.ABORTED, reply)

            request = self.method.request_type(magic_number=6)
            status, response = self._service.SomeUnary(request)

            sent = self._sent_payload(self.method.request_type)
            self.assertEqual(6, sent.magic_number)

            self.assertIs(Status.ABORTED, status)
            self.assertEqual('0_o', response.payload)

    def test_nonblocking_call(self) -> None:
        for _ in range(3):
            reply = self.method.response_type(payload='0_o')
            self._enqueue_response(1, self.method, Status.ABORTED, reply)

            on_event = mock.Mock()
            call = self.rpc.invoke(self._request(magic_number=5), on_event,
                                   on_event)

            expected_calls = [
                mock.call(call, self.method.response_type(payload='0_o')),
                mock.call(call, Status.ABORTED),
            ]
            on_event.assert_has_calls(expected_calls)

            sent = self._sent_payload(self.method.request_type)
            self.assertEqual(5, sent.magic_number)

    def test_open(self) -> None:
        # Any packet written to the channel raises; open() is expected to
        # start the call without sending anything.
        self.output_exception = IOError('something went wrong sending!')

        for _ in range(3):
            reply = self.method.response_type(payload='0_o')
            self._enqueue_response(1, self.method, Status.ABORTED, reply)

            on_event = mock.Mock()
            call = self.rpc.open(self._request(magic_number=5), on_event,
                                 on_event)
            self.assertEqual([], self.requests)

            self._process_enqueued_packets()

            expected_calls = [
                mock.call(call, self.method.response_type(payload='0_o')),
                mock.call(call, Status.ABORTED),
            ]
            on_event.assert_has_calls(expected_calls)

    def test_blocking_server_error(self) -> None:
        for _ in range(3):
            self._enqueue_error(1, self.method.service, self.method,
                                Status.NOT_FOUND)

            request = self.method.request_type(magic_number=6)
            with self.assertRaises(callback_client.RpcError) as ctx:
                self._service.SomeUnary(request)

            self.assertIs(ctx.exception.status, Status.NOT_FOUND)

    def test_nonblocking_cancel(self) -> None:
        on_event = mock.Mock()

        for _ in range(3):
            call = self._service.SomeUnary.invoke(
                self._request(magic_number=55), on_event)

            self.assertTrue(self.requests)
            self.requests.clear()

            self.assertTrue(call.cancel())
            # A second cancel is a no-op and reports False.
            self.assertFalse(call.cancel())

            # Cancelling a unary RPC sends nothing to the server.
            self.assertFalse(self.requests)

        on_event.assert_not_called()

    def test_nonblocking_with_request_args(self) -> None:
        self.rpc.invoke(request_args=dict(magic_number=1138))
        sent = self._sent_payload(self.rpc.request)
        self.assertEqual(sent.magic_number, 1138)

    def test_blocking_timeout_as_argument(self) -> None:
        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary(pw_rpc_timeout_s=0.0001)

    def test_blocking_timeout_set_default(self) -> None:
        self._service.SomeUnary.default_timeout_s = 0.0001

        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary()

    def test_nonblocking_duplicate_calls_first_is_cancelled(self) -> None:
        first = self.rpc.invoke()
        self.assertFalse(first.completed())

        second = self.rpc.invoke()

        # Starting a second call on the same RPC cancels the first one.
        self.assertIs(first.error, Status.CANCELLED)
        self.assertFalse(second.completed())

    def test_nonblocking_exception_in_callback(self) -> None:
        error = ValueError('something went wrong!')

        self._enqueue_response(1, self.method, Status.OK)

        failing_callback = mock.Mock(side_effect=error)
        call = self.rpc.invoke(on_completed=failing_callback)

        # The callback's exception surfaces from wait() as the cause of a
        # RuntimeError.
        with self.assertRaises(RuntimeError) as ctx:
            call.wait()

        self.assertEqual(ctx.exception.__cause__, error)
403
404
class ServerStreamingTest(_CallbackClientImplTestBase):
    """Exercises the SomeServerStreaming RPC through the callback client."""
    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeServerStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        first = self.method.response_type(payload='!!!')
        second = self.method.response_type(payload='?')

        for _ in range(3):
            for reply in (first, second):
                self._enqueue_server_stream(1, self.method, reply)
            self._enqueue_response(1, self.method, Status.ABORTED)

            result = self._service.SomeServerStreaming(magic_number=4)
            self.assertEqual([first, second], result.responses)

            sent = self._sent_payload(self.method.request_type)
            self.assertEqual(4, sent.magic_number)

    def test_deprecated_packet_format(self) -> None:
        first = self.method.response_type(payload='!!!')
        second = self.method.response_type(payload='?')

        for _ in range(3):
            # The original packet format used RESPONSE packets for the server
            # stream and a SERVER_STREAM_END packet as the last packet. These
            # are converted to SERVER_STREAM packets followed by a RESPONSE.
            for reply in (first, second):
                self._enqueue_response(1, self.method, payload=reply)

            stream_end = packet_pb2.RpcPacket(
                type=packet_pb2.PacketType.DEPRECATED_SERVER_STREAM_END,
                channel_id=1,
                service_id=self.method.service.id,
                method_id=self.method.id,
                status=Status.INVALID_ARGUMENT.value)
            self._next_packets.append(
                (stream_end.SerializeToString(), Status.OK))

            status, replies = self._service.SomeServerStreaming(magic_number=4)
            self.assertEqual([first, second], replies)
            self.assertIs(status, Status.INVALID_ARGUMENT)

            sent = self._sent_payload(self.method.request_type)
            self.assertEqual(4, sent.magic_number)

    def test_nonblocking_call(self) -> None:
        first = self.method.response_type(payload='!!!')
        second = self.method.response_type(payload='?')

        for _ in range(3):
            for reply in (first, second):
                self._enqueue_server_stream(1, self.method, reply)
            self._enqueue_response(1, self.method, Status.ABORTED)

            on_event = mock.Mock()
            call = self.rpc.invoke(self._request(magic_number=3), on_event,
                                   on_event)

            on_event.assert_has_calls([
                mock.call(call, self.method.response_type(payload='!!!')),
                mock.call(call, self.method.response_type(payload='?')),
                mock.call(call, Status.ABORTED),
            ])

            sent = self._sent_payload(self.method.request_type)
            self.assertEqual(3, sent.magic_number)

    def test_open(self) -> None:
        # Any packet written to the channel raises; open() is expected to
        # start the call without sending anything.
        self.output_exception = IOError('something went wrong sending!')
        first = self.method.response_type(payload='!!!')
        second = self.method.response_type(payload='?')

        for _ in range(3):
            for reply in (first, second):
                self._enqueue_server_stream(1, self.method, reply)
            self._enqueue_response(1, self.method, Status.ABORTED)

            on_event = mock.Mock()
            call = self.rpc.open(self._request(magic_number=3), on_event,
                                 on_event)
            self.assertEqual([], self.requests)

            self._process_enqueued_packets()

            on_event.assert_has_calls([
                mock.call(call, self.method.response_type(payload='!!!')),
                mock.call(call, self.method.response_type(payload='?')),
                mock.call(call, Status.ABORTED),
            ])

    def test_nonblocking_cancel(self) -> None:
        reply = self.rpc.method.response_type(payload='!!!')
        self._enqueue_server_stream(1, self.rpc.method, reply)

        on_event = mock.Mock()
        call = self.rpc.invoke(self._request(magic_number=3), on_event)
        on_event.assert_called_once_with(
            call, self.rpc.method.response_type(payload='!!!'))

        on_event.reset_mock()

        call.cancel()

        # Cancelling a streaming RPC sends a CLIENT_ERROR with CANCELLED.
        cancel_packet = self.last_request()
        self.assertEqual(cancel_packet.type,
                         packet_pb2.PacketType.CLIENT_ERROR)
        self.assertEqual(cancel_packet.status, Status.CANCELLED.value)

        # Ensure the RPC can be called after being cancelled.
        self._enqueue_server_stream(1, self.method, reply)
        self._enqueue_response(1, self.method, Status.OK)

        call = self.rpc.invoke(self._request(magic_number=3), on_event,
                               on_event)

        on_event.assert_has_calls([
            mock.call(call, self.method.response_type(payload='!!!')),
            mock.call(call, Status.OK),
        ])

    def test_nonblocking_with_request_args(self) -> None:
        self.rpc.invoke(request_args=dict(magic_number=1138))
        sent = self._sent_payload(self.rpc.request)
        self.assertEqual(sent.magic_number, 1138)

    def test_blocking_timeout(self) -> None:
        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeServerStreaming(pw_rpc_timeout_s=0.0001)

    def test_nonblocking_iteration_timeout(self) -> None:
        call = self._service.SomeServerStreaming.invoke(timeout_s=0.0001)
        with self.assertRaises(callback_client.RpcTimeout):
            list(call)

    def test_nonblocking_duplicate_calls_first_is_cancelled(self) -> None:
        first = self.rpc.invoke()
        self.assertFalse(first.completed())

        second = self.rpc.invoke()

        # Starting a second call on the same RPC cancels the first one.
        self.assertIs(first.error, Status.CANCELLED)
        self.assertFalse(second.completed())

    def test_nonblocking_iterate_over_count(self) -> None:
        reply = self.method.response_type(payload='!?')

        for _ in range(4):
            self._enqueue_server_stream(1, self.method, reply)

        call = self.rpc.invoke()

        self.assertEqual([reply], list(call.get_responses(count=1)))
        self.assertEqual(reply, next(iter(call)))
        self.assertEqual([reply, reply], list(call.get_responses(count=2)))

    def test_nonblocking_iterate_after_completed_doesnt_block(self) -> None:
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(1, self.method, reply)
        self._enqueue_response(1, self.method, Status.OK)

        call = self.rpc.invoke()

        # Responses are consumed once; iterating a completed call again
        # yields nothing.
        self.assertEqual([reply], list(call.get_responses()))
        self.assertEqual([], list(call.get_responses()))
        self.assertEqual([], list(call))
577
578
class ClientStreamingTest(_CallbackClientImplTestBase):
    """Tests for client streaming RPCs."""
    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeClientStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        requests = [
            self.method.request_type(magic_number=123),
            self.method.request_type(magic_number=456),
        ]

        # Send after len(requests) and the client stream end packet.
        self.send_responses_after_packets = 3
        response = self.method.response_type(payload='yo')
        self._enqueue_response(1, self.method, Status.OK, response)

        results = self.rpc(requests)
        self.assertIs(results.status, Status.OK)
        self.assertEqual(results.response, response)

    def test_blocking_server_error(self) -> None:
        requests = [self.method.request_type(magic_number=123)]

        # send_responses_after_packets is left at 1, so the error is delivered
        # as soon as the client sends its first packet.
        self._enqueue_error(1, self.method.service, self.method,
                            Status.NOT_FOUND)

        with self.assertRaises(callback_client.RpcError) as context:
            self.rpc(requests)

        self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_call(self) -> None:
        """Tests a successful client streaming RPC ended by the server."""
        payload_1 = self.method.response_type(payload='-_-')

        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()
            self.assertFalse(stream.completed())

            stream.send(magic_number=31)
            # Packet types are plain ints, so compare by value, not identity.
            self.assertEqual(packet_pb2.PacketType.CLIENT_STREAM,
                             self.last_request().type)
            self.assertEqual(
                31,
                self._sent_payload(self.method.request_type).magic_number)
            self.assertFalse(stream.completed())

            # Enqueue the server response to be sent after the next message.
            self._enqueue_response(1, self.method, Status.OK, payload_1)

            stream.send(magic_number=32)
            self.assertEqual(packet_pb2.PacketType.CLIENT_STREAM,
                             self.last_request().type)
            self.assertEqual(
                32,
                self._sent_payload(self.method.request_type).magic_number)

            self.assertTrue(stream.completed())
            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)
            self.assertEqual(payload_1, stream.response)

    def test_open(self) -> None:
        # Any packet written to the channel raises; open() is expected to
        # start the call without sending anything.
        self.output_exception = IOError('something went wrong sending!')
        payload = self.method.response_type(payload='-_-')

        for _ in range(3):
            self._enqueue_response(1, self.method, Status.OK, payload)

            callback = mock.Mock()
            call = self.rpc.open(callback, callback, callback)
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            callback.assert_has_calls([
                mock.call(call, payload),
                mock.call(call, Status.OK),
            ])

    def test_nonblocking_finish(self) -> None:
        """Tests a client streaming RPC ended by the client."""
        payload_1 = self.method.response_type(payload='-_-')

        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()
            self.assertFalse(stream.completed())

            stream.send(magic_number=37)
            self.assertEqual(packet_pb2.PacketType.CLIENT_STREAM,
                             self.last_request().type)
            self.assertEqual(
                37,
                self._sent_payload(self.method.request_type).magic_number)
            self.assertFalse(stream.completed())

            # Enqueue the server response to be sent after the next message.
            self._enqueue_response(1, self.method, Status.OK, payload_1)

            stream.finish_and_wait()
            self.assertEqual(packet_pb2.PacketType.CLIENT_STREAM_END,
                             self.last_request().type)

            self.assertTrue(stream.completed())
            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)
            self.assertEqual(payload_1, stream.response)

    def test_nonblocking_cancel(self) -> None:
        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()
            stream.send(magic_number=37)

            self.assertTrue(stream.cancel())
            self.assertEqual(packet_pb2.PacketType.CLIENT_ERROR,
                             self.last_request().type)
            self.assertEqual(Status.CANCELLED.value,
                             self.last_request().status)
            self.assertFalse(stream.cancel())

            self.assertTrue(stream.completed())
            self.assertIs(stream.error, Status.CANCELLED)

    def test_nonblocking_server_error(self) -> None:
        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()

            self._enqueue_error(1, self.method.service, self.method,
                                Status.INVALID_ARGUMENT)
            stream.send(magic_number=2**32 - 1)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_server_error_after_stream_end(self) -> None:
        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()

            # Error will be sent in response to the CLIENT_STREAM_END packet.
            self._enqueue_error(1, self.method.service, self.method,
                                Status.INVALID_ARGUMENT)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_send_after_cancelled(self) -> None:
        call = self._service.SomeClientStreaming.invoke()
        self.assertTrue(call.cancel())

        with self.assertRaises(callback_client.RpcError) as context:
            call.send(payload='hello')

        self.assertIs(context.exception.status, Status.CANCELLED)

    def test_nonblocking_finish_after_completed(self) -> None:
        reply = self.method.response_type(payload='!?')
        self._enqueue_response(1, self.method, Status.UNAVAILABLE, reply)

        call = self.rpc.invoke()
        result = call.finish_and_wait()
        self.assertEqual(result.response, reply)

        # Finishing an already completed call returns the same result.
        self.assertEqual(result, call.finish_and_wait())
        self.assertEqual(result, call.finish_and_wait())

    def test_nonblocking_finish_after_error(self) -> None:
        self._enqueue_error(1, self.method.service, self.method,
                            Status.UNAVAILABLE)

        call = self.rpc.invoke()

        for _ in range(3):
            with self.assertRaises(callback_client.RpcError) as context:
                call.finish_and_wait()

            self.assertIs(context.exception.status, Status.UNAVAILABLE)
            self.assertIs(call.error, Status.UNAVAILABLE)
            self.assertIsNone(call.response)

    def test_nonblocking_duplicate_calls_first_is_cancelled(self) -> None:
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIs(first_call.error, Status.CANCELLED)
        self.assertFalse(second_call.completed())
772
773
class BidirectionalStreamingTest(_CallbackClientImplTestBase):
    """Tests for bidirectional streaming RPCs."""
    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeBidiStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        """A blocking call reports the final status with no responses."""
        requests = [
            self.method.request_type(magic_number=123),
            self.method.request_type(magic_number=456),
        ]

        # Send after len(requests) and the client stream end packet.
        self.send_responses_after_packets = len(requests) + 1
        self._enqueue_response(1, self.method, Status.NOT_FOUND)

        results = self.rpc(requests)
        self.assertIs(results.status, Status.NOT_FOUND)
        self.assertFalse(results.responses)

    def test_blocking_server_error(self) -> None:
        """A server error during a blocking call is raised as RpcError."""
        requests = [self.method.request_type(magic_number=123)]

        # Unlike test_blocking_call, send_responses_after_packets is not
        # overridden here, so the error is delivered on the base class's
        # default schedule.
        self._enqueue_error(1, self.method.service, self.method,
                            Status.NOT_FOUND)

        with self.assertRaises(callback_client.RpcError) as context:
            self.rpc(requests)

        self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_call(self) -> None:
        """Tests a bidirectional streaming RPC ended by the server."""
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            responses: List[Any] = []
            # Bind responses as a default argument so each iteration's
            # callback appends to that iteration's list.
            stream = self.rpc.invoke(
                lambda _, res, responses=responses: responses.append(res))
            self.assertFalse(stream.completed())

            stream.send(magic_number=55)
            self.assertIs(packet_pb2.PacketType.CLIENT_STREAM,
                          self.last_request().type)
            self.assertEqual(
                55,
                self._sent_payload(self.method.request_type).magic_number)
            self.assertFalse(stream.completed())
            self.assertEqual([], responses)

            self._enqueue_server_stream(1, self.method, rep1)
            self._enqueue_server_stream(1, self.method, rep2)

            stream.send(magic_number=66)
            self.assertIs(packet_pb2.PacketType.CLIENT_STREAM,
                          self.last_request().type)
            self.assertEqual(
                66,
                self._sent_payload(self.method.request_type).magic_number)
            self.assertFalse(stream.completed())
            self.assertEqual([rep1, rep2], responses)

            self._enqueue_response(1, self.method, Status.OK)

            stream.send(magic_number=77)
            self.assertTrue(stream.completed())
            self.assertEqual([rep1, rep2], responses)

            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)

    def test_open(self) -> None:
        """open() sends nothing but still dispatches responses to callbacks."""
        # NOTE(review): presumably this makes any outgoing packet raise (see
        # the base class's packet handler), so open() succeeding with no
        # recorded requests shows it does not send a request.
        self.output_exception = IOError('something went wrong sending!')
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(1, self.method, rep1)
            self._enqueue_server_stream(1, self.method, rep2)
            self._enqueue_response(1, self.method, Status.OK)

            # One mock serves as the response, completion and error callback.
            callback = mock.Mock()
            call = self.rpc.open(callback, callback, callback)
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            callback.assert_has_calls([
                mock.call(call, self.method.response_type(payload='!!!')),
                mock.call(call, self.method.response_type(payload='?')),
                mock.call(call, Status.OK),
            ])

    @mock.patch('pw_rpc.callback_client.call.Call._default_response')
    def test_nonblocking(self, callback) -> None:
        """Responses go to the default response callback if none is given."""
        reply = self.method.response_type(payload='This is the payload!')
        self._enqueue_server_stream(1, self.method, reply)

        self.rpc.invoke()

        callback.assert_called_once_with(mock.ANY, reply)

    def test_nonblocking_server_error(self) -> None:
        """A server error completes the stream with an error and no status."""
        rep1 = self.method.response_type(payload='!!!')

        for _ in range(3):
            responses: List[Any] = []
            stream = self.rpc.invoke(
                lambda _, res, responses=responses: responses.append(res))
            self.assertFalse(stream.completed())

            self._enqueue_server_stream(1, self.method, rep1)

            stream.send(magic_number=55)
            self.assertFalse(stream.completed())
            self.assertEqual([rep1], responses)

            self._enqueue_error(1, self.method.service, self.method,
                                Status.OUT_OF_RANGE)

            stream.send(magic_number=99999)
            self.assertTrue(stream.completed())
            self.assertEqual([rep1], responses)

            self.assertIsNone(stream.status)
            self.assertIs(Status.OUT_OF_RANGE, stream.error)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()
            self.assertIs(context.exception.status, Status.OUT_OF_RANGE)

    def test_nonblocking_server_error_after_stream_end(self) -> None:
        """finish_and_wait() raises an error the server sends at stream end."""
        for _ in range(3):
            stream = self.rpc.invoke()

            # Error will be sent in response to the CLIENT_STREAM_END packet.
            self._enqueue_error(1, self.method.service, self.method,
                                Status.INVALID_ARGUMENT)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_send_after_cancelled(self) -> None:
        """Sending on a cancelled bidirectional call raises CANCELLED."""
        call = self.rpc.invoke()
        self.assertTrue(call.cancel())

        # SomeMessage (the request type) only defines magic_number. Use a
        # valid field so the only way send() can fail is the cancelled-call
        # check, not an invalid request field.
        with self.assertRaises(callback_client.RpcError) as context:
            call.send(magic_number=1)

        self.assertIs(context.exception.status, Status.CANCELLED)

    def test_nonblocking_finish_after_completed(self) -> None:
        """Repeated finish_and_wait() calls return the same completed result."""
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(1, self.method, reply)
        self._enqueue_response(1, self.method, Status.UNAVAILABLE)

        call = self.rpc.invoke()
        result = call.finish_and_wait()
        self.assertEqual(result.responses, [reply])

        self.assertEqual(result, call.finish_and_wait())
        self.assertEqual(result, call.finish_and_wait())

    def test_nonblocking_finish_after_error(self) -> None:
        """Errored calls re-raise on each wait and keep earlier responses."""
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(1, self.method, reply)
        self._enqueue_error(1, self.method.service, self.method,
                            Status.UNAVAILABLE)

        call = self.rpc.invoke()

        for _ in range(3):
            with self.assertRaises(callback_client.RpcError) as context:
                call.finish_and_wait()

            self.assertIs(context.exception.status, Status.UNAVAILABLE)
            self.assertIs(call.error, Status.UNAVAILABLE)
            self.assertEqual(call.responses, [reply])

    def test_nonblocking_duplicate_calls_first_is_cancelled(self) -> None:
        """Invoking the same RPC again cancels the still-pending first call."""
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIs(first_call.error, Status.CANCELLED)
        self.assertFalse(second_call.completed())
967
968
# Run the tests in this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()
971