• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#!/usr/bin/env python3
# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
15"""Tests using the callback client for pw_rpc."""
16
17import unittest
18from unittest import mock
19from typing import Any, List, Optional, Tuple
20
21from pw_protobuf_compiler import python_protos
22from pw_status import Status
23
24from pw_rpc import callback_client, client, packets
25from pw_rpc.internal import packet_pb2
26
# Proto source loaded by _CallbackClientImplTestBase.setUp through
# python_protos.Library.from_strings. PublicService declares one method of each
# RPC kind (unary, server streaming, client streaming, bidirectional
# streaming) so each test class can exercise a different call pattern.
TEST_PROTO_1 = """\
syntax = "proto3";

package pw.test1;

message SomeMessage {
  uint32 magic_number = 1;
}

message AnotherMessage {
  enum Result {
    FAILED = 0;
    FAILED_MISERABLY = 1;
    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
  }

  Result result = 1;
  string payload = 2;
}

service PublicService {
  rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
  rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
  rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
  rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""
54
55
56def _message_bytes(msg) -> bytes:
57    return msg if isinstance(msg, bytes) else msg.SerializeToString()
58
59
class _CallbackClientImplTestBase(unittest.TestCase):
    """Supports writing tests that require responses from an RPC server.

    The client's channel output is ``_handle_packet``, which decodes and
    records every sent packet in ``self.requests``. Tests queue server packets
    with the ``_enqueue_*`` helpers; the queue is handed to the client once it
    has sent ``send_responses_after_packets`` packets, or explicitly through
    ``_process_enqueued_packets``.
    """

    def setUp(self) -> None:
        self._protos = python_protos.Library.from_strings(TEST_PROTO_1)
        self._request = self._protos.packages.pw.test1.SomeMessage

        self._client = client.Client.from_modules(
            callback_client.Impl(),
            [client.Channel(1, self._handle_packet)],
            self._protos.modules(),
        )
        self._service = self._client.channel(1).rpcs.pw.test1.PublicService

        # Decoded copies of every packet the client has sent.
        self.requests: List[packet_pb2.RpcPacket] = []
        # Serialized packets to deliver to the client, each paired with the
        # status that process_packet() is expected to return for it.
        self._next_packets: List[Tuple[bytes, Status]] = []
        # Queued packets are delivered in response to the Nth packet the
        # client sends (counted down in _handle_packet). A float so it can be
        # set to infinity temporarily.
        self.send_responses_after_packets: float = 1

        # If set, raised by the channel output to simulate a failed send.
        self.output_exception: Optional[Exception] = None

    def last_request(self) -> packet_pb2.RpcPacket:
        """Returns the most recent packet sent by the client."""
        assert self.requests
        return self.requests[-1]

    def _enqueue_packet(
        self, packet: packet_pb2.RpcPacket, process_status: Status
    ) -> None:
        """Queues a serialized packet and its expected process_packet status."""
        self._next_packets.append((packet.SerializeToString(), process_status))

    def _enqueue_response(
        self,
        channel_id: int,
        method=None,
        status: Status = Status.OK,
        payload=b'',
        *,
        ids: Optional[Tuple[int, int]] = None,
        process_status=Status.OK,
    ) -> None:
        """Queues a RESPONSE packet.

        The target is given either by ``method`` or by raw
        ``ids=(service_id, method_id)``, never both. Raw IDs allow addressing
        services or methods that do not exist. ``payload`` may be a message
        or bytes.
        """
        if method is not None:
            assert ids is None
            service_id, method_id = method.service.id, method.id
        else:
            assert ids is not None
            service_id, method_id = ids

        self._enqueue_packet(
            packet_pb2.RpcPacket(
                type=packet_pb2.PacketType.RESPONSE,
                channel_id=channel_id,
                service_id=service_id,
                method_id=method_id,
                status=status.value,
                payload=_message_bytes(payload),
            ),
            process_status,
        )

    def _enqueue_server_stream(
        self, channel_id: int, method, response, process_status=Status.OK
    ) -> None:
        """Queues a SERVER_STREAM packet carrying one streamed response."""
        self._enqueue_packet(
            packet_pb2.RpcPacket(
                type=packet_pb2.PacketType.SERVER_STREAM,
                channel_id=channel_id,
                service_id=method.service.id,
                method_id=method.id,
                payload=_message_bytes(response),
            ),
            process_status,
        )

    def _enqueue_error(
        self,
        channel_id: int,
        service,
        method,
        status: Status,
        process_status=Status.OK,
    ) -> None:
        """Queues a SERVER_ERROR packet.

        ``service`` and ``method`` may each be an object with an ``id`` or a
        raw integer ID, so errors can target nonexistent services/methods.
        """
        service_id = service if isinstance(service, int) else service.id
        method_id = method if isinstance(method, int) else method.id

        self._enqueue_packet(
            packet_pb2.RpcPacket(
                type=packet_pb2.PacketType.SERVER_ERROR,
                channel_id=channel_id,
                service_id=service_id,
                method_id=method_id,
                status=status.value,
            ),
            process_status,
        )

    def _handle_packet(self, data: bytes) -> None:
        """Channel output: records the sent packet and maybe replies.

        Raises ``output_exception`` if it is set, simulating a failed send.
        """
        if self.output_exception:
            raise self.output_exception  # pylint: disable=raising-bad-type

        self.requests.append(packets.decode(data))

        if self.send_responses_after_packets > 1:
            self.send_responses_after_packets -= 1
            return

        self._process_enqueued_packets()

    def _process_enqueued_packets(self) -> None:
        """Feeds every queued packet to the client and checks its status."""
        # Set send_responses_after_packets to infinity to prevent potential
        # infinite recursion when a packet causes another packet to send.
        send_after_count = self.send_responses_after_packets
        self.send_responses_after_packets = float('inf')

        try:
            for packet, status in self._next_packets:
                self.assertIs(status, self._client.process_packet(packet))
        finally:
            # Restore the fixture even if an assertion fails partway, so it
            # is not left with an infinite threshold and a stale queue.
            self._next_packets.clear()
            self.send_responses_after_packets = send_after_count

    def _sent_payload(self, message_type: type) -> Any:
        """Decodes the last sent packet's payload as a message_type."""
        message = message_type()
        message.ParseFromString(self.last_request().payload)
        return message
182
183
class CallbackClientImplTest(_CallbackClientImplTestBase):
    """Tests the callback_client.Impl client implementation."""

    def test_callback_exceptions_suppressed(self) -> None:
        stub = self._service.SomeUnary

        self._enqueue_response(1, stub.method)
        exception_msg = 'YOU BROKE IT O-]-<'

        # The callback's exception is logged rather than raised from invoke().
        with self.assertLogs(callback_client.__package__, 'ERROR') as logs:
            stub.invoke(
                self._request(), mock.Mock(side_effect=Exception(exception_msg))
            )

        self.assertIn(exception_msg, ''.join(logs.output))

        # Make sure we can still invoke the RPC.
        self._enqueue_response(1, stub.method, Status.UNKNOWN)
        status, _ = stub()
        self.assertIs(status, Status.UNKNOWN)

    def test_ignore_bad_packets_with_pending_rpc(self) -> None:
        method = self._service.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_response(999, method, process_status=Status.NOT_FOUND)
        # Bad service
        self._enqueue_response(
            1, ids=(999, method.id), process_status=Status.OK
        )
        # Bad method
        self._enqueue_response(
            1, ids=(service_id, 999), process_status=Status.OK
        )
        # For RPC not pending (is Status.OK because the packet is processed)
        self._enqueue_response(
            1,
            ids=(service_id, self._service.SomeBidiStreaming.method.id),
            process_status=Status.OK,
        )

        self._enqueue_response(1, method, process_status=Status.OK)

        status, response = self._service.SomeUnary(magic_number=6)
        self.assertIs(Status.OK, status)
        self.assertEqual('', response.payload)

    def test_server_error_for_unknown_call_sends_no_errors(self) -> None:
        method = self._service.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_error(
            999,
            service_id,
            method,
            Status.NOT_FOUND,
            process_status=Status.NOT_FOUND,
        )
        # Bad service
        self._enqueue_error(1, 999, method.id, Status.INVALID_ARGUMENT)
        # Bad method
        self._enqueue_error(1, service_id, 999, Status.INVALID_ARGUMENT)
        # For RPC not pending
        self._enqueue_error(
            1,
            service_id,
            self._service.SomeBidiStreaming.method.id,
            Status.NOT_FOUND,
        )

        self._process_enqueued_packets()

        self.assertEqual(self.requests, [])

    def test_exception_if_payload_fails_to_decode(self) -> None:
        method = self._service.SomeUnary.method

        self._enqueue_response(
            1, method, Status.OK, b'INVALID DATA!!!', process_status=Status.OK
        )

        with self.assertRaises(callback_client.RpcError) as context:
            self._service.SomeUnary(magic_number=6)

        self.assertIs(context.exception.status, Status.DATA_LOSS)

    def test_rpc_help_contains_method_name(self) -> None:
        rpc = self._service.SomeUnary
        self.assertIn(rpc.method.full_name, rpc.help())

    def test_default_timeouts_set_on_impl(self) -> None:
        impl = callback_client.Impl(None, 1.5)

        self.assertIsNone(impl.default_unary_timeout_s)
        self.assertEqual(impl.default_stream_timeout_s, 1.5)

    def test_default_timeouts_set_for_all_rpcs(self) -> None:
        rpc_client = client.Client.from_modules(
            callback_client.Impl(99, 100),
            [client.Channel(1, lambda *a, **b: None)],
            self._protos.modules(),
        )
        rpcs = rpc_client.channel(1).rpcs

        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeUnary.default_timeout_s, 99
        )
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeServerStreaming.default_timeout_s,
            100,
        )
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeClientStreaming.default_timeout_s,
            99,
        )
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeBidiStreaming.default_timeout_s, 100
        )

    def test_rpc_provides_request_type(self) -> None:
        self.assertIs(
            self._service.SomeUnary.request,
            self._service.SomeUnary.method.request_type,
        )

    def test_rpc_provides_response_type(self) -> None:
        self.assertIs(
            self._service.SomeUnary.response,
            self._service.SomeUnary.method.response_type,
        )
316
317
class UnaryTest(_CallbackClientImplTestBase):
    """Exercises invoking the unary SomeUnary RPC."""

    def setUp(self) -> None:
        super().setUp()
        unary_stub = self._service.SomeUnary
        self.rpc = unary_stub
        self.method = unary_stub.method

    def _enqueue_aborted_reply(self) -> None:
        """Queues an ABORTED response carrying the payload '0_o'."""
        self._enqueue_response(
            1,
            self.method,
            Status.ABORTED,
            self.method.response_type(payload='0_o'),
        )

    def _aborted_reply_calls(self, call) -> list:
        """Returns the callback calls expected for an ABORTED '0_o' reply."""
        return [
            mock.call(call, self.method.response_type(payload='0_o')),
            mock.call(call, Status.ABORTED),
        ]

    def test_blocking_call(self) -> None:
        for _attempt in range(3):
            self._enqueue_aborted_reply()

            status, response = self._service.SomeUnary(
                self.method.request_type(magic_number=6)
            )

            sent = self._sent_payload(self.method.request_type)
            self.assertEqual(6, sent.magic_number)

            self.assertIs(Status.ABORTED, status)
            self.assertEqual('0_o', response.payload)

    def test_nonblocking_call(self) -> None:
        for _attempt in range(3):
            self._enqueue_aborted_reply()

            on_event = mock.Mock()
            call = self.rpc.invoke(
                self._request(magic_number=5), on_event, on_event
            )

            on_event.assert_has_calls(self._aborted_reply_calls(call))

            sent = self._sent_payload(self.method.request_type)
            self.assertEqual(5, sent.magic_number)

    def test_open(self) -> None:
        # Every send raises, so nothing is recorded; the reply queued for the
        # opened call is delivered explicitly instead.
        self.output_exception = IOError('something went wrong sending!')

        for _attempt in range(3):
            self._enqueue_aborted_reply()

            on_event = mock.Mock()
            call = self.rpc.open(
                self._request(magic_number=5), on_event, on_event
            )
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            on_event.assert_has_calls(self._aborted_reply_calls(call))

    def test_blocking_server_error(self) -> None:
        for _attempt in range(3):
            self._enqueue_error(
                1, self.method.service, self.method, Status.NOT_FOUND
            )

            with self.assertRaises(callback_client.RpcError) as raised:
                self._service.SomeUnary(
                    self.method.request_type(magic_number=6)
                )

            self.assertIs(raised.exception.status, Status.NOT_FOUND)

    def test_nonblocking_cancel(self) -> None:
        on_next = mock.Mock()

        for _attempt in range(3):
            call = self._service.SomeUnary.invoke(
                self._request(magic_number=55), on_next
            )

            self.assertGreater(len(self.requests), 0)
            self.requests.clear()

            self.assertTrue(call.cancel())
            self.assertFalse(call.cancel())  # Already cancelled, returns False

            cancel_packet = self.last_request()
            self.assertEqual(
                cancel_packet.type, packet_pb2.PacketType.CLIENT_ERROR
            )
            self.assertEqual(cancel_packet.status, Status.CANCELLED.value)

        on_next.assert_not_called()

    def test_nonblocking_with_request_args(self) -> None:
        self.rpc.invoke(request_args={'magic_number': 1138})
        sent = self._sent_payload(self.rpc.request)
        self.assertEqual(sent.magic_number, 1138)

    def test_blocking_timeout_as_argument(self) -> None:
        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary(pw_rpc_timeout_s=0.0001)

    def test_blocking_timeout_set_default(self) -> None:
        self._service.SomeUnary.default_timeout_s = 0.0001

        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary()

    def test_nonblocking_duplicate_calls_first_is_cancelled(self) -> None:
        original = self.rpc.invoke()
        self.assertFalse(original.completed())

        # Invoking again while the first call is pending cancels the first.
        replacement = self.rpc.invoke()

        self.assertIs(original.error, Status.CANCELLED)
        self.assertFalse(replacement.completed())

    def test_nonblocking_exception_in_callback(self) -> None:
        failure = ValueError('something went wrong!')

        self._enqueue_response(1, self.method, Status.OK)

        call = self.rpc.invoke(on_completed=mock.Mock(side_effect=failure))

        # wait() raises a RuntimeError whose cause is the callback's error.
        with self.assertRaises(RuntimeError) as raised:
            call.wait()

        self.assertEqual(raised.exception.__cause__, failure)

    def test_unary_response(self) -> None:
        message_type = self._protos.packages.pw.test1.SomeMessage
        message = message_type(magic_number=123)

        aborted = callback_client.UnaryResponse(Status.ABORTED, message)
        self.assertEqual(
            repr(aborted),
            '(Status.ABORTED, pw.test1.SomeMessage(magic_number=123))',
        )

        empty = callback_client.UnaryResponse(Status.OK, None)
        self.assertEqual(repr(empty), '(Status.OK, None)')
478
479
class ServerStreamingTest(_CallbackClientImplTestBase):
    """Exercises the server streaming SomeServerStreaming RPC."""

    def setUp(self) -> None:
        super().setUp()
        streaming_stub = self._service.SomeServerStreaming
        self.rpc = streaming_stub
        self.method = streaming_stub.method

    def _reply(self, text: str):
        """Builds a response message whose payload is text."""
        return self.method.response_type(payload=text)

    def _enqueue_stream(self, replies, final_status: Status) -> None:
        """Queues each reply as a stream packet, then a closing RESPONSE."""
        for reply in replies:
            self._enqueue_server_stream(1, self.method, reply)
        self._enqueue_response(1, self.method, final_status)

    def _aborted_stream_calls(self, call) -> list:
        """Returns callback calls expected for a '!!!', '?', ABORTED stream."""
        return [
            mock.call(call, self._reply('!!!')),
            mock.call(call, self._reply('?')),
            mock.call(call, Status.ABORTED),
        ]

    def test_blocking_call(self) -> None:
        first, second = self._reply('!!!'), self._reply('?')

        for _attempt in range(3):
            self._enqueue_stream([first, second], Status.ABORTED)

            result = self._service.SomeServerStreaming(magic_number=4)
            self.assertEqual([first, second], result.responses)

            sent = self._sent_payload(self.method.request_type)
            self.assertEqual(4, sent.magic_number)

    def test_nonblocking_call(self) -> None:
        first, second = self._reply('!!!'), self._reply('?')

        for _attempt in range(3):
            self._enqueue_stream([first, second], Status.ABORTED)

            on_event = mock.Mock()
            call = self.rpc.invoke(
                self._request(magic_number=3), on_event, on_event
            )

            on_event.assert_has_calls(self._aborted_stream_calls(call))

            sent = self._sent_payload(self.method.request_type)
            self.assertEqual(3, sent.magic_number)

    def test_open(self) -> None:
        # Every send raises, so nothing is recorded; the stream queued for
        # the opened call is delivered explicitly instead.
        self.output_exception = IOError('something went wrong sending!')
        first, second = self._reply('!!!'), self._reply('?')

        for _attempt in range(3):
            self._enqueue_stream([first, second], Status.ABORTED)

            on_event = mock.Mock()
            call = self.rpc.open(
                self._request(magic_number=3), on_event, on_event
            )
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            on_event.assert_has_calls(self._aborted_stream_calls(call))

    def test_nonblocking_cancel(self) -> None:
        reply = self._reply('!!!')
        self._enqueue_server_stream(1, self.method, reply)

        on_event = mock.Mock()
        call = self.rpc.invoke(self._request(magic_number=3), on_event)
        on_event.assert_called_once_with(call, self._reply('!!!'))

        on_event.reset_mock()

        call.cancel()

        cancel_packet = self.last_request()
        self.assertEqual(
            cancel_packet.type, packet_pb2.PacketType.CLIENT_ERROR
        )
        self.assertEqual(cancel_packet.status, Status.CANCELLED.value)

        # Ensure the RPC can be called after being cancelled.
        self._enqueue_stream([reply], Status.OK)

        restarted = self.rpc.invoke(
            self._request(magic_number=3), on_event, on_event
        )

        on_event.assert_has_calls(
            [
                mock.call(restarted, self._reply('!!!')),
                mock.call(restarted, Status.OK),
            ]
        )

    def test_nonblocking_with_request_args(self) -> None:
        self.rpc.invoke(request_args={'magic_number': 1138})
        sent = self._sent_payload(self.rpc.request)
        self.assertEqual(sent.magic_number, 1138)

    def test_blocking_timeout(self) -> None:
        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeServerStreaming(pw_rpc_timeout_s=0.0001)

    def test_nonblocking_iteration_timeout(self) -> None:
        call = self._service.SomeServerStreaming.invoke(timeout_s=0.0001)
        with self.assertRaises(callback_client.RpcTimeout):
            for _response in call:
                pass

    def test_nonblocking_duplicate_calls_first_is_cancelled(self) -> None:
        original = self.rpc.invoke()
        self.assertFalse(original.completed())

        # Invoking again while the first call is pending cancels the first.
        replacement = self.rpc.invoke()

        self.assertIs(original.error, Status.CANCELLED)
        self.assertFalse(replacement.completed())

    def test_nonblocking_iterate_over_count(self) -> None:
        reply = self._reply('!?')

        for _copy in range(4):
            self._enqueue_server_stream(1, self.method, reply)

        call = self.rpc.invoke()

        self.assertEqual(list(call.get_responses(count=1)), [reply])
        self.assertEqual(next(iter(call)), reply)
        self.assertEqual(list(call.get_responses(count=2)), [reply, reply])

    def test_nonblocking_iterate_after_completed_doesnt_block(self) -> None:
        reply = self._reply('!?')
        self._enqueue_stream([reply], Status.OK)

        call = self.rpc.invoke()

        self.assertEqual(list(call.get_responses()), [reply])
        self.assertEqual(list(call.get_responses()), [])
        self.assertEqual(list(call), [])
639
640
class ClientStreamingTest(_CallbackClientImplTestBase):
    """Tests for client streaming RPCs.

    Packet types and status codes read from RpcPacket fields are plain ints,
    so they are compared with assertEqual; identity checks (assertIs) on ints
    only pass by accident of CPython's small-int caching.
    """

    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeClientStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        requests = [
            self.method.request_type(magic_number=123),
            self.method.request_type(magic_number=456),
        ]

        # Send after len(requests) and the client stream end packet.
        self.send_responses_after_packets = 3
        response = self.method.response_type(payload='yo')
        self._enqueue_response(1, self.method, Status.OK, response)

        results = self.rpc(requests)
        self.assertIs(results.status, Status.OK)
        self.assertEqual(results.response, response)

    def test_blocking_server_error(self) -> None:
        requests = [self.method.request_type(magic_number=123)]

        # send_responses_after_packets keeps its default of 1, so the error is
        # processed in response to the first packet the client sends.
        self._enqueue_error(
            1, self.method.service, self.method, Status.NOT_FOUND
        )

        with self.assertRaises(callback_client.RpcError) as context:
            self.rpc(requests)

        self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_call(self) -> None:
        """Tests a successful client streaming RPC ended by the server."""
        payload_1 = self.method.response_type(payload='-_-')

        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()
            self.assertFalse(stream.completed())

            stream.send(magic_number=31)
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                31, self._sent_payload(self.method.request_type).magic_number
            )
            self.assertFalse(stream.completed())

            # Enqueue the server response to be sent after the next message.
            self._enqueue_response(1, self.method, Status.OK, payload_1)

            stream.send(magic_number=32)
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                32, self._sent_payload(self.method.request_type).magic_number
            )

            self.assertTrue(stream.completed())
            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)
            self.assertEqual(payload_1, stream.response)

    def test_open(self) -> None:
        # Every send raises, so nothing is recorded; the queued response is
        # delivered explicitly instead.
        self.output_exception = IOError('something went wrong sending!')
        payload = self.method.response_type(payload='-_-')

        for _ in range(3):
            self._enqueue_response(1, self.method, Status.OK, payload)

            callback = mock.Mock()
            call = self.rpc.open(callback, callback, callback)
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            callback.assert_has_calls(
                [
                    mock.call(call, payload),
                    mock.call(call, Status.OK),
                ]
            )

    def test_nonblocking_finish(self) -> None:
        """Tests a client streaming RPC ended by the client."""
        payload_1 = self.method.response_type(payload='-_-')

        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()
            self.assertFalse(stream.completed())

            stream.send(magic_number=37)
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                37, self._sent_payload(self.method.request_type).magic_number
            )
            self.assertFalse(stream.completed())

            # Enqueue the server response to be sent after the next message.
            self._enqueue_response(1, self.method, Status.OK, payload_1)

            stream.finish_and_wait()
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_STREAM_END,
                self.last_request().type,
            )

            self.assertTrue(stream.completed())
            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)
            self.assertEqual(payload_1, stream.response)

    def test_nonblocking_cancel(self) -> None:
        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()
            stream.send(magic_number=37)

            self.assertTrue(stream.cancel())
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_ERROR, self.last_request().type
            )
            self.assertEqual(
                Status.CANCELLED.value, self.last_request().status
            )
            self.assertFalse(stream.cancel())

            self.assertTrue(stream.completed())
            self.assertIs(stream.error, Status.CANCELLED)

    def test_nonblocking_server_error(self) -> None:
        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()

            self._enqueue_error(
                1, self.method.service, self.method, Status.INVALID_ARGUMENT
            )
            stream.send(magic_number=2**32 - 1)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_server_error_after_stream_end(self) -> None:
        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()

            # Error will be sent in response to the CLIENT_STREAM_END packet.
            self._enqueue_error(
                1, self.method.service, self.method, Status.INVALID_ARGUMENT
            )

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_send_after_cancelled(self) -> None:
        call = self._service.SomeClientStreaming.invoke()
        self.assertTrue(call.cancel())

        with self.assertRaises(callback_client.RpcError) as context:
            call.send(payload='hello')

        self.assertIs(context.exception.status, Status.CANCELLED)

    def test_nonblocking_finish_after_completed(self) -> None:
        reply = self.method.response_type(payload='!?')
        self._enqueue_response(1, self.method, Status.UNAVAILABLE, reply)

        call = self.rpc.invoke()
        result = call.finish_and_wait()
        self.assertEqual(result.response, reply)

        self.assertEqual(result, call.finish_and_wait())
        self.assertEqual(result, call.finish_and_wait())

    def test_nonblocking_finish_after_error(self) -> None:
        self._enqueue_error(
            1, self.method.service, self.method, Status.UNAVAILABLE
        )

        call = self.rpc.invoke()

        for _ in range(3):
            with self.assertRaises(callback_client.RpcError) as context:
                call.finish_and_wait()

            self.assertIs(context.exception.status, Status.UNAVAILABLE)
            self.assertIs(call.error, Status.UNAVAILABLE)
            self.assertIsNone(call.response)

    def test_nonblocking_duplicate_calls_first_is_cancelled(self) -> None:
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIs(first_call.error, Status.CANCELLED)
        self.assertFalse(second_call.completed())
847
848
class BidirectionalStreamingTest(_CallbackClientImplTestBase):
    """Tests for bidirectional streaming RPCs."""

    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeBidiStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        """Tests a blocking call that ends with a non-OK status."""
        requests = [
            self.method.request_type(magic_number=123),
            self.method.request_type(magic_number=456),
        ]

        # Send after len(requests) and the client stream end packet.
        self.send_responses_after_packets = 3
        self._enqueue_response(1, self.method, Status.NOT_FOUND)

        results = self.rpc(requests)
        self.assertIs(results.status, Status.NOT_FOUND)
        self.assertFalse(results.responses)

    def test_blocking_server_error(self) -> None:
        """Tests that a server error raises RpcError from a blocking call."""
        requests = [self.method.request_type(magic_number=123)]

        # Unlike test_blocking_call, send_responses_after_packets is not
        # overridden here, so the error follows the base class's default
        # response timing.
        self._enqueue_error(
            1, self.method.service, self.method, Status.NOT_FOUND
        )

        with self.assertRaises(callback_client.RpcError) as context:
            self.rpc(requests)

        self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_call(self) -> None:
        """Tests a bidirectional streaming RPC ended by the server."""
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            responses: list = []
            stream = self._service.SomeBidiStreaming.invoke(
                lambda _, res, responses=responses: responses.append(res)
            )
            self.assertFalse(stream.completed())

            stream.send(magic_number=55)
            self.assertIs(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                55, self._sent_payload(self.method.request_type).magic_number
            )
            self.assertFalse(stream.completed())
            self.assertEqual([], responses)

            self._enqueue_server_stream(1, self.method, rep1)
            self._enqueue_server_stream(1, self.method, rep2)

            stream.send(magic_number=66)
            self.assertIs(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                66, self._sent_payload(self.method.request_type).magic_number
            )
            self.assertFalse(stream.completed())
            self.assertEqual([rep1, rep2], responses)

            self._enqueue_response(1, self.method, Status.OK)

            stream.send(magic_number=77)
            self.assertTrue(stream.completed())
            self.assertEqual([rep1, rep2], responses)

            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)

    def test_open(self) -> None:
        """Tests that open() sends no packets but still handles responses."""
        # Configure an output exception; the test then checks that open()
        # produced no outgoing requests.
        self.output_exception = IOError('something went wrong sending!')
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(1, self.method, rep1)
            self._enqueue_server_stream(1, self.method, rep2)
            self._enqueue_response(1, self.method, Status.OK)

            callback = mock.Mock()
            call = self.rpc.open(callback, callback, callback)
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            callback.assert_has_calls(
                [
                    mock.call(call, self.method.response_type(payload='!!!')),
                    mock.call(call, self.method.response_type(payload='?')),
                    mock.call(call, Status.OK),
                ]
            )

    @mock.patch('pw_rpc.callback_client.call.Call._default_response')
    def test_nonblocking(self, callback) -> None:
        """Tests that server stream responses reach the default callback.

        invoke() is called without callbacks, so each response should be
        passed to the (patched) Call._default_response.
        """
        reply = self.method.response_type(payload='This is the payload!')
        self._enqueue_server_stream(1, self.method, reply)

        self._service.SomeBidiStreaming.invoke()

        callback.assert_called_once_with(mock.ANY, reply)

    def test_nonblocking_server_error(self) -> None:
        """Tests that a server error completes the stream with an error."""
        rep1 = self.method.response_type(payload='!!!')

        for _ in range(3):
            responses: list = []
            stream = self._service.SomeBidiStreaming.invoke(
                lambda _, res, responses=responses: responses.append(res)
            )
            self.assertFalse(stream.completed())

            self._enqueue_server_stream(1, self.method, rep1)

            stream.send(magic_number=55)
            self.assertFalse(stream.completed())
            self.assertEqual([rep1], responses)

            self._enqueue_error(
                1, self.method.service, self.method, Status.OUT_OF_RANGE
            )

            stream.send(magic_number=99999)
            self.assertTrue(stream.completed())
            self.assertEqual([rep1], responses)

            self.assertIsNone(stream.status)
            self.assertIs(Status.OUT_OF_RANGE, stream.error)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()
            self.assertIs(context.exception.status, Status.OUT_OF_RANGE)

    def test_nonblocking_server_error_after_stream_end(self) -> None:
        """Tests a server error delivered after the client ends its stream."""
        for _ in range(3):
            stream = self._service.SomeBidiStreaming.invoke()

            # Error will be sent in response to the CLIENT_STREAM_END packet.
            self._enqueue_error(
                1, self.method.service, self.method, Status.INVALID_ARGUMENT
            )

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_send_after_cancelled(self) -> None:
        """Tests that send() on a cancelled call raises RpcError(CANCELLED)."""
        call = self._service.SomeBidiStreaming.invoke()
        self.assertTrue(call.cancel())

        # Send a well-formed SomeMessage so the cancellation is the only
        # possible cause of the error.
        with self.assertRaises(callback_client.RpcError) as context:
            call.send(magic_number=123)

        self.assertIs(context.exception.status, Status.CANCELLED)

    def test_nonblocking_finish_after_completed(self) -> None:
        """Tests finish_and_wait() is repeatable after the call completes."""
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(1, self.method, reply)
        self._enqueue_response(1, self.method, Status.UNAVAILABLE)

        call = self.rpc.invoke()
        result = call.finish_and_wait()
        self.assertEqual(result.responses, [reply])
        self.assertIs(result.status, Status.UNAVAILABLE)

        self.assertEqual(result, call.finish_and_wait())
        self.assertEqual(result, call.finish_and_wait())

    def test_nonblocking_finish_after_error(self) -> None:
        """Tests finish_and_wait() re-raises the error and keeps responses."""
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(1, self.method, reply)
        self._enqueue_error(
            1, self.method.service, self.method, Status.UNAVAILABLE
        )

        call = self.rpc.invoke()

        for _ in range(3):
            with self.assertRaises(callback_client.RpcError) as context:
                call.finish_and_wait()

            self.assertIs(context.exception.status, Status.UNAVAILABLE)
            self.assertIs(call.error, Status.UNAVAILABLE)
            self.assertEqual(call.responses, [reply])

    def test_nonblocking_duplicate_calls_first_is_cancelled(self) -> None:
        """Tests that a second invoke() cancels the first pending call."""
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIs(first_call.error, Status.CANCELLED)
        self.assertFalse(second_call.completed())

    def test_stream_response(self) -> None:
        """Tests the repr() of StreamResponse."""
        proto = self._protos.packages.pw.test1.SomeMessage(magic_number=123)
        self.assertEqual(
            repr(callback_client.StreamResponse(Status.ABORTED, [proto] * 2)),
            '(Status.ABORTED, [pw.test1.SomeMessage(magic_number=123), '
            'pw.test1.SomeMessage(magic_number=123)])',
        )
        self.assertEqual(
            repr(callback_client.StreamResponse(Status.OK, [])),
            '(Status.OK, [])',
        )
1065
1066
if __name__ == '__main__':
    # Run this module's test cases when it is executed directly as a script.
    unittest.main()
1069