• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# Copyright 2021 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
15"""Tests using the callback client for pw_rpc."""
16
17import unittest
18from unittest import mock
19from typing import Any
20
21from pw_protobuf_compiler import python_protos
22from pw_status import Status
23
24from pw_rpc import callback_client, client, descriptors, packets
25from pw_rpc.internal import packet_pb2
26
27TEST_PROTO_1 = """\
28syntax = "proto3";
29
30package pw.test1;
31
32message SomeMessage {
33  uint32 magic_number = 1;
34}
35
36message AnotherMessage {
37  enum Result {
38    FAILED = 0;
39    FAILED_MISERABLY = 1;
40    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
41  }
42
43  Result result = 1;
44  string payload = 2;
45}
46
47service PublicService {
48  rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
49  rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
50  rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
51  rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
52}
53"""
54
55CLIENT_CHANNEL_ID: int = 489
56
57
def _message_bytes(msg) -> bytes:
    """Returns msg as bytes.

    A bytes object is passed through unchanged; anything else is serialized by
    calling its SerializeToString() method.
    """
    if isinstance(msg, bytes):
        return msg
    return msg.SerializeToString()
60
61
class _CallbackClientImplTestBase(unittest.TestCase):
    """Shared fixture for tests that need scripted replies from an RPC server.

    Packets queued with the ``_enqueue_*`` helpers are passed to the client's
    ``process_packet`` either when the client writes to the test channel (see
    ``_handle_packet``) or when a test calls ``_process_enqueued_packets``.
    """

    def setUp(self) -> None:
        self._protos = python_protos.Library.from_strings(TEST_PROTO_1)
        test_package = self._protos.packages.pw.test1
        self._request = test_package.SomeMessage

        # Everything the client writes to this channel goes to _handle_packet.
        impl = callback_client.Impl()
        channel = client.Channel(CLIENT_CHANNEL_ID, self._handle_packet)
        self._client = client.Client.from_modules(
            impl, [channel], self._protos.modules()
        )
        channel_client = self._client.channel(CLIENT_CHANNEL_ID)
        self._service = channel_client.rpcs.pw.test1.PublicService

        # Decoded packets written by the client, oldest first.
        self.requests: list[packet_pb2.RpcPacket] = []
        # (serialized packet, expected process_packet() result) pairs.
        self._next_packets: list[tuple[bytes, Status]] = []
        # Queued packets are only delivered once this countdown is at most 1.
        self.send_responses_after_packets: float = 1

        # When set, _handle_packet raises this instead of recording the packet.
        self.output_exception: Exception | None = None

    def last_request(self) -> packet_pb2.RpcPacket:
        """Returns the most recently recorded outgoing packet."""
        sent = self.requests
        assert sent
        return sent[-1]

    def _enqueue_response(
        self,
        channel_id: int = CLIENT_CHANNEL_ID,
        method: descriptors.Method | None = None,
        status: Status = Status.OK,
        payload: bytes = b'',
        *,
        ids: tuple[int, int] | None = None,
        process_status: Status = Status.OK,
        call_id: int = client.OPEN_CALL_ID,
    ) -> None:
        """Queues a RESPONSE packet.

        The target RPC is identified either by ``method`` or by an explicit
        ``(service_id, method_id)`` pair in ``ids``; exactly one of the two must
        be given. ``process_status`` is the value ``process_packet`` is expected
        to return when the packet is delivered.
        """
        if not method:
            assert ids is not None and method is None
            service_id, method_id = ids
        else:
            assert ids is None
            service_id = method.service.id
            method_id = method.id

        packet = packet_pb2.RpcPacket(
            type=packet_pb2.PacketType.RESPONSE,
            channel_id=channel_id,
            service_id=service_id,
            method_id=method_id,
            call_id=call_id,
            status=status.value,
            payload=_message_bytes(payload),
        )
        self._next_packets.append((packet.SerializeToString(), process_status))

    def _enqueue_server_stream(
        self,
        channel_id: int,
        method,
        response,
        process_status=Status.OK,
        call_id: int = client.OPEN_CALL_ID,
    ) -> None:
        """Queues a SERVER_STREAM packet carrying ``response`` for ``method``."""
        packet = packet_pb2.RpcPacket(
            type=packet_pb2.PacketType.SERVER_STREAM,
            channel_id=channel_id,
            service_id=method.service.id,
            method_id=method.id,
            call_id=call_id,
            payload=_message_bytes(response),
        )
        self._next_packets.append((packet.SerializeToString(), process_status))

    def _enqueue_error(
        self,
        channel_id: int,
        service,
        method,
        status: Status,
        process_status=Status.OK,
        call_id: int = client.OPEN_CALL_ID,
    ) -> None:
        """Queues a SERVER_ERROR packet.

        ``service`` and ``method`` may each be either a raw integer ID or an
        object exposing an ``id`` attribute.
        """
        if isinstance(service, int):
            service_id = service
        else:
            service_id = service.id

        if isinstance(method, int):
            method_id = method
        else:
            method_id = method.id

        packet = packet_pb2.RpcPacket(
            type=packet_pb2.PacketType.SERVER_ERROR,
            channel_id=channel_id,
            service_id=service_id,
            method_id=method_id,
            call_id=call_id,
            status=status.value,
        )
        self._next_packets.append((packet.SerializeToString(), process_status))

    def _handle_packet(self, data: bytes) -> None:
        """Channel output function: records a sent packet and maybe replies.

        Raises ``output_exception`` if one is set. Otherwise the decoded packet
        is appended to ``requests``, and the queued packets are delivered unless
        the ``send_responses_after_packets`` countdown is still above 1.
        """
        if self.output_exception:
            raise self.output_exception  # pylint: disable=raising-bad-type

        decoded = packets.decode(data)
        self.requests.append(decoded)

        if self.send_responses_after_packets <= 1:
            self._process_enqueued_packets()
        else:
            self.send_responses_after_packets -= 1

    def _process_enqueued_packets(self) -> None:
        """Delivers every queued packet to the client, then empties the queue.

        Each packet's ``process_packet`` result is checked against the expected
        status that was queued with it.
        """
        # Park the countdown at infinity while delivering so that a packet
        # which makes the client send another packet cannot recurse back into
        # this method without end.
        saved_count = self.send_responses_after_packets
        self.send_responses_after_packets = float('inf')

        for serialized, expected_status in self._next_packets:
            actual_status = self._client.process_packet(serialized)
            self.assertIs(expected_status, actual_status)

        self._next_packets.clear()
        self.send_responses_after_packets = saved_count

    def _sent_payload(self, message_type: type) -> Any:
        """Parses the last sent packet's payload as a ``message_type``."""
        parsed = message_type()
        parsed.ParseFromString(self.last_request().payload)
        return parsed
196
197
198# Disable docstring requirements for test functions.
199# pylint: disable=missing-function-docstring
200
201
class CallbackClientImplTest(_CallbackClientImplTestBase):
    """Tests the callback_client.Impl client implementation."""

    def test_callback_exceptions_suppressed(self) -> None:
        """An exception raised inside a user callback is logged as an error."""
        stub = self._service.SomeUnary

        self._enqueue_response(CLIENT_CHANNEL_ID, stub.method)
        exception_msg = 'YOU BROKE IT O-]-<'

        with self.assertLogs(callback_client.__package__, 'ERROR') as logs:
            stub.invoke(
                self._request(), mock.Mock(side_effect=Exception(exception_msg))
            )

        self.assertIn(exception_msg, ''.join(logs.output))

        # Make sure we can still invoke the RPC.
        self._enqueue_response(CLIENT_CHANNEL_ID, stub.method, Status.UNKNOWN)
        status, _ = stub()
        self.assertIs(status, Status.UNKNOWN)

    def test_ignore_bad_packets_with_pending_rpc(self) -> None:
        """Stray responses do not disturb a pending unary call."""
        method = self._service.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_response(999, method, process_status=Status.NOT_FOUND)
        # Bad service
        self._enqueue_response(
            CLIENT_CHANNEL_ID, ids=(999, method.id), process_status=Status.OK
        )
        # Bad method
        self._enqueue_response(
            CLIENT_CHANNEL_ID, ids=(service_id, 999), process_status=Status.OK
        )
        # For RPC not pending (is Status.OK because the packet is processed)
        self._enqueue_response(
            CLIENT_CHANNEL_ID,
            ids=(service_id, self._service.SomeBidiStreaming.method.id),
            process_status=Status.OK,
        )

        # The only response that actually targets the pending call.
        self._enqueue_response(
            CLIENT_CHANNEL_ID, method, process_status=Status.OK
        )

        status, response = self._service.SomeUnary(magic_number=6)
        self.assertIs(Status.OK, status)
        self.assertEqual('', response.payload)

    def test_server_error_for_unknown_call_sends_no_errors(self) -> None:
        """Server errors for unknown calls do not make the client send packets."""
        method = self._service.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_error(
            999,
            service_id,
            method,
            Status.NOT_FOUND,
            process_status=Status.NOT_FOUND,
        )
        # Bad service
        self._enqueue_error(
            CLIENT_CHANNEL_ID, 999, method.id, Status.INVALID_ARGUMENT
        )
        # Bad method
        self._enqueue_error(
            CLIENT_CHANNEL_ID, service_id, 999, Status.INVALID_ARGUMENT
        )
        # For RPC not pending
        self._enqueue_error(
            CLIENT_CHANNEL_ID,
            service_id,
            self._service.SomeBidiStreaming.method.id,
            Status.NOT_FOUND,
        )

        self._process_enqueued_packets()

        self.assertEqual(self.requests, [])

    def test_exception_if_payload_fails_to_decode(self) -> None:
        """An undecodable response payload raises RpcError with DATA_LOSS."""
        method = self._service.SomeUnary.method

        self._enqueue_response(
            CLIENT_CHANNEL_ID,
            method,
            Status.OK,
            b'INVALID DATA!!!',
            process_status=Status.OK,
        )

        with self.assertRaises(callback_client.RpcError) as context:
            self._service.SomeUnary(magic_number=6)

        self.assertIs(context.exception.status, Status.DATA_LOSS)

    def test_rpc_help_contains_method_name(self) -> None:
        """help() output mentions the method's full name."""
        rpc = self._service.SomeUnary
        self.assertIn(rpc.method.full_name, rpc.help())

    def test_default_timeouts_set_on_impl(self) -> None:
        """Impl exposes the unary and stream default timeouts it was given."""
        impl = callback_client.Impl(None, 1.5)

        self.assertIsNone(impl.default_unary_timeout_s)
        self.assertEqual(impl.default_stream_timeout_s, 1.5)

    def test_default_timeouts_set_for_all_rpcs(self) -> None:
        """Each RPC picks up the Impl default timeout expected for its type."""
        rpc_client = client.Client.from_modules(
            callback_client.Impl(99, 100),
            [client.Channel(CLIENT_CHANNEL_ID, lambda *a, **b: None)],
            self._protos.modules(),
        )
        rpcs = rpc_client.channel(CLIENT_CHANNEL_ID).rpcs

        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeUnary.default_timeout_s, 99
        )
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeServerStreaming.default_timeout_s,
            100,
        )
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeClientStreaming.default_timeout_s,
            99,
        )
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeBidiStreaming.default_timeout_s, 100
        )

    def test_rpc_provides_request_type(self) -> None:
        """The RPC's request attribute is the method's request type."""
        self.assertIs(
            self._service.SomeUnary.request,
            self._service.SomeUnary.method.request_type,
        )

    def test_rpc_provides_response_type(self) -> None:
        """The RPC's response attribute is the method's response type."""
        # This previously repeated the request-type assertion, so the response
        # type was never actually checked.
        self.assertIs(
            self._service.SomeUnary.response,
            self._service.SomeUnary.method.response_type,
        )
344
345
class UnaryTest(_CallbackClientImplTestBase):
    """Covers blocking and non-blocking invocation of the unary RPC."""

    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeUnary
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        """A blocking call returns the queued status and response."""
        for _ in range(3):
            reply = self.method.response_type(payload='0_o')
            self._enqueue_response(
                CLIENT_CHANNEL_ID, self.method, Status.ABORTED, reply
            )

            request = self.method.request_type(magic_number=6)
            status, response = self._service.SomeUnary(request)

            sent = self._sent_payload(self.method.request_type)
            self.assertEqual(6, sent.magic_number)

            self.assertIs(Status.ABORTED, status)
            self.assertEqual('0_o', response.payload)

    def test_nonblocking_call(self) -> None:
        """invoke() reports the response and then the status via callbacks."""
        for _ in range(3):
            on_event = mock.Mock()
            request = self._request(magic_number=5)
            call = self.rpc.invoke(request, on_event, on_event)

            reply = self.method.response_type(payload='0_o')
            self._enqueue_response(
                CLIENT_CHANNEL_ID,
                self.method,
                Status.ABORTED,
                reply,
                call_id=call.call_id,
            )
            self._process_enqueued_packets()

            expected_calls = [
                mock.call(call, self.method.response_type(payload='0_o')),
                mock.call(call, Status.ABORTED),
            ]
            on_event.assert_has_calls(expected_calls)

            sent = self._sent_payload(self.method.request_type)
            self.assertEqual(5, sent.magic_number)

    def test_concurrent_nonblocking_calls(self) -> None:
        """Responses are routed to the matching call by call ID."""
        # Start several calls to the same method
        pending: list[tuple[mock.Mock, callback_client.call.Call]] = []
        for _ in range(3):
            on_response = mock.Mock()
            started = self.rpc.invoke(
                self._request(magic_number=5), on_response
            )
            pending.append((on_response, started))

        # Respond only to the last call
        *others, (last_callback, last_call) = pending
        last_payload = self.method.response_type(payload='last payload')
        self._enqueue_response(
            CLIENT_CHANNEL_ID,
            self.method,
            payload=last_payload,
            call_id=last_call.call_id,
        )
        self._process_enqueued_packets()

        # Assert that only the last caller received a response
        last_callback.assert_called_once_with(last_call, last_payload)
        for untouched_callback, _ in others:
            untouched_callback.assert_not_called()

        # Respond to the other callers and check for receipt
        other_payload = self.method.response_type(payload='other payload')
        for on_response, started in others:
            self._enqueue_response(
                CLIENT_CHANNEL_ID,
                self.method,
                payload=other_payload,
                call_id=started.call_id,
            )
            self._process_enqueued_packets()
            on_response.assert_called_once_with(started, other_payload)

    def test_open(self) -> None:
        """open() starts a call without sending and still gets responses."""
        # Any attempt to send a packet raises, so open() must not send one.
        self.output_exception = IOError('something went wrong sending!')

        for _ in range(3):
            reply = self.method.response_type(payload='0_o')
            self._enqueue_response(
                CLIENT_CHANNEL_ID, self.method, Status.ABORTED, reply
            )

            on_event = mock.Mock()
            request = self._request(magic_number=5)
            call = self.rpc.open(request, on_event, on_event)
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            expected_calls = [
                mock.call(call, self.method.response_type(payload='0_o')),
                mock.call(call, Status.ABORTED),
            ]
            on_event.assert_has_calls(expected_calls)

    def test_blocking_server_error(self) -> None:
        """A server error surfaces as RpcError carrying the error status."""
        for _ in range(3):
            self._enqueue_error(
                CLIENT_CHANNEL_ID,
                self.method.service,
                self.method,
                Status.NOT_FOUND,
            )

            request = self.method.request_type(magic_number=6)
            with self.assertRaises(callback_client.RpcError) as context:
                self._service.SomeUnary(request)

            raised = context.exception
            self.assertIs(raised.status, Status.NOT_FOUND)

    def test_nonblocking_cancel(self) -> None:
        """cancel() sends CLIENT_ERROR/CANCELLED once; a repeat returns False."""
        on_response = mock.Mock()

        for _ in range(3):
            request = self._request(magic_number=55)
            call = self._service.SomeUnary.invoke(request, on_response)

            sent_count = len(self.requests)
            self.assertGreater(sent_count, 0)
            self.requests.clear()

            cancelled = call.cancel()
            self.assertTrue(cancelled)
            cancelled_again = call.cancel()
            self.assertFalse(cancelled_again)  # Already cancelled

            self.assertEqual(
                self.last_request().type, packet_pb2.PacketType.CLIENT_ERROR
            )
            self.assertEqual(self.last_request().status, Status.CANCELLED.value)

        on_response.assert_not_called()

    def test_nonblocking_with_request_args(self) -> None:
        """request_args are used to build the request that is sent."""
        self.rpc.invoke(request_args={'magic_number': 1138})
        sent = self._sent_payload(self.rpc.request)
        self.assertEqual(sent.magic_number, 1138)

    def test_blocking_timeout_as_argument(self) -> None:
        """With no response queued, a per-call timeout raises RpcTimeout."""
        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary(pw_rpc_timeout_s=0.0001)

    def test_blocking_timeout_set_default(self) -> None:
        """With no response queued, the default timeout raises RpcTimeout."""
        self._service.SomeUnary.default_timeout_s = 0.0001

        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary()

    def test_nonblocking_duplicate_calls_not_cancelled(self) -> None:
        """Starting a second call leaves both calls without an error."""
        first_call = self.rpc.invoke()
        still_running = not first_call.completed()
        self.assertTrue(still_running)

        second_call = self.rpc.invoke()

        for started in (first_call, second_call):
            self.assertIs(started.error, None)

    def test_nonblocking_exception_in_callback(self) -> None:
        """wait() raises RuntimeError chained from the callback's exception."""
        exception = ValueError('something went wrong! (intentionally)')

        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)

        failing_callback = mock.Mock(side_effect=exception)
        call = self.rpc.invoke(on_completed=failing_callback)

        with self.assertRaises(RuntimeError) as context:
            call.wait()

        cause = context.exception.__cause__
        self.assertEqual(cause, exception)

    def test_unary_response(self) -> None:
        """UnaryResponse has a tuple-like repr, with or without a message."""
        proto = self._protos.packages.pw.test1.SomeMessage(magic_number=123)

        with_message = callback_client.UnaryResponse(Status.ABORTED, proto)
        self.assertEqual(
            repr(with_message),
            '(Status.ABORTED, pw.test1.SomeMessage(magic_number=123))',
        )

        without_message = callback_client.UnaryResponse(Status.OK, None)
        self.assertEqual(repr(without_message), '(Status.OK, None)')

    def test_on_call_hook(self) -> None:
        """The on_call_hook runs once and receives the invoked method."""
        hook_function = mock.Mock()

        # Rebuild the client and service so that the hook is installed.
        impl = callback_client.Impl(on_call_hook=hook_function)
        channel = client.Channel(CLIENT_CHANNEL_ID, self._handle_packet)
        self._client = client.Client.from_modules(
            impl, [channel], self._protos.modules()
        )
        channel_client = self._client.channel(CLIENT_CHANNEL_ID)
        self._service = channel_client.rpcs.pw.test1.PublicService

        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)
        request = self.method.request_type(magic_number=6)
        self._service.SomeUnary(request)

        hook_function.assert_called_once()
        hooked_rpc = hook_function.call_args[0][0]
        self.assertEqual(
            hooked_rpc.method.full_name,
            self.method.full_name,
        )
571
572
class ServerStreamingTest(_CallbackClientImplTestBase):
    """Covers blocking and non-blocking use of the server streaming RPC."""

    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeServerStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        """A blocking call collects every streamed response in order."""
        replies = [
            self.method.response_type(payload='!!!'),
            self.method.response_type(payload='?'),
        ]

        for _ in range(3):
            for reply in replies:
                self._enqueue_server_stream(
                    CLIENT_CHANNEL_ID, self.method, reply
                )
            self._enqueue_response(
                CLIENT_CHANNEL_ID, self.method, Status.ABORTED
            )

            result = self._service.SomeServerStreaming(magic_number=4)
            self.assertEqual(replies, result.responses)

            sent = self._sent_payload(self.method.request_type)
            self.assertEqual(4, sent.magic_number)

    def test_nonblocking_call(self) -> None:
        """invoke() reports each response and then the final status."""
        replies = [
            self.method.response_type(payload='!!!'),
            self.method.response_type(payload='?'),
        ]

        for _ in range(3):
            for reply in replies:
                self._enqueue_server_stream(
                    CLIENT_CHANNEL_ID, self.method, reply
                )
            self._enqueue_response(
                CLIENT_CHANNEL_ID, self.method, Status.ABORTED
            )

            on_event = mock.Mock()
            request = self._request(magic_number=3)
            call = self.rpc.invoke(request, on_event, on_event)

            expected_calls = [
                mock.call(call, self.method.response_type(payload='!!!')),
                mock.call(call, self.method.response_type(payload='?')),
                mock.call(call, Status.ABORTED),
            ]
            on_event.assert_has_calls(expected_calls)

            sent = self._sent_payload(self.method.request_type)
            self.assertEqual(3, sent.magic_number)

    def test_open(self) -> None:
        """open() starts a call without sending and still gets responses."""
        # Any attempt to send a packet raises, so open() must not send one.
        self.output_exception = IOError('something went wrong sending!')
        replies = [
            self.method.response_type(payload='!!!'),
            self.method.response_type(payload='?'),
        ]

        for _ in range(3):
            for reply in replies:
                self._enqueue_server_stream(
                    CLIENT_CHANNEL_ID, self.method, reply
                )
            self._enqueue_response(
                CLIENT_CHANNEL_ID, self.method, Status.ABORTED
            )

            on_event = mock.Mock()
            request = self._request(magic_number=3)
            call = self.rpc.open(request, on_event, on_event)
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            expected_calls = [
                mock.call(call, self.method.response_type(payload='!!!')),
                mock.call(call, self.method.response_type(payload='?')),
                mock.call(call, Status.ABORTED),
            ]
            on_event.assert_has_calls(expected_calls)

    def test_nonblocking_cancel(self) -> None:
        """cancel() sends CLIENT_ERROR/CANCELLED and the RPC stays usable."""
        resp = self.rpc.method.response_type(payload='!!!')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.rpc.method, resp)

        on_event = mock.Mock()
        call = self.rpc.invoke(self._request(magic_number=3), on_event)
        first_reply = self.rpc.method.response_type(payload='!!!')
        on_event.assert_called_once_with(call, first_reply)

        on_event.reset_mock()

        call.cancel()

        self.assertEqual(
            self.last_request().type, packet_pb2.PacketType.CLIENT_ERROR
        )
        self.assertEqual(self.last_request().status, Status.CANCELLED.value)

        # Ensure the RPC can be called after being cancelled.
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, resp)
        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)

        request = self._request(magic_number=3)
        call = self.rpc.invoke(request, on_event, on_event)

        expected_calls = [
            mock.call(call, self.method.response_type(payload='!!!')),
            mock.call(call, Status.OK),
        ]
        on_event.assert_has_calls(expected_calls)

    def test_request_completion(self) -> None:
        """request_completion() sends CLIENT_REQUEST_COMPLETION."""
        resp = self.rpc.method.response_type(payload='!!!')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.rpc.method, resp)

        on_event = mock.Mock()
        call = self.rpc.invoke(self._request(magic_number=3), on_event)
        first_reply = self.rpc.method.response_type(payload='!!!')
        on_event.assert_called_once_with(call, first_reply)

        on_event.reset_mock()

        call.request_completion()

        self.assertEqual(
            self.last_request().type,
            packet_pb2.PacketType.CLIENT_REQUEST_COMPLETION,
        )

        # Ensure the RPC can be called after being completed.
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, resp)
        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)

        request = self._request(magic_number=3)
        call = self.rpc.invoke(request, on_event, on_event)

        expected_calls = [
            mock.call(call, self.method.response_type(payload='!!!')),
            mock.call(call, Status.OK),
        ]
        on_event.assert_has_calls(expected_calls)

    def test_nonblocking_with_request_args(self) -> None:
        """request_args are used to build the request that is sent."""
        self.rpc.invoke(request_args={'magic_number': 1138})
        sent = self._sent_payload(self.rpc.request)
        self.assertEqual(sent.magic_number, 1138)

    def test_blocking_timeout(self) -> None:
        """With no response queued, a blocking call raises RpcTimeout."""
        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeServerStreaming(pw_rpc_timeout_s=0.0001)

    def test_nonblocking_iteration_timeout(self) -> None:
        """Iterating a call that gets no responses raises RpcTimeout."""
        call = self._service.SomeServerStreaming.invoke(timeout_s=0.0001)
        with self.assertRaises(callback_client.RpcTimeout):
            for _ in call:
                pass

    def test_nonblocking_duplicate_calls_not_cancelled(self) -> None:
        """Starting a second call leaves both calls without an error."""
        first_call = self.rpc.invoke()
        still_running = not first_call.completed()
        self.assertTrue(still_running)

        second_call = self.rpc.invoke()

        for started in (first_call, second_call):
            self.assertIs(started.error, None)

    def test_nonblocking_iterate_over_count(self) -> None:
        """get_responses(count=...) and iteration consume responses in turn."""
        reply = self.method.response_type(payload='!?')

        for _ in range(4):
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply)

        call = self.rpc.invoke()

        first_batch = list(call.get_responses(count=1))
        self.assertEqual(first_batch, [reply])

        single = next(iter(call))
        self.assertEqual(single, reply)

        second_batch = list(call.get_responses(count=2))
        self.assertEqual(second_batch, [reply, reply])

    def test_nonblocking_iterate_after_completed_doesnt_block(self) -> None:
        """Once a call has completed, further iteration yields nothing."""
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply)
        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)

        call = self.rpc.invoke()

        received = list(call.get_responses())
        self.assertEqual(received, [reply])

        received_again = list(call.get_responses())
        self.assertEqual(received_again, [])

        remaining = list(call)
        self.assertEqual(remaining, [])
772
773
774class ClientStreamingTest(_CallbackClientImplTestBase):
775    """Tests for client streaming RPCs."""
776
777    def setUp(self) -> None:
778        super().setUp()
779        self.rpc = self._service.SomeClientStreaming
780        self.method = self.rpc.method
781
782    def test_blocking_call(self) -> None:
783        requests = [
784            self.method.request_type(magic_number=123),
785            self.method.request_type(magic_number=456),
786        ]
787
788        # Send after len(requests) and the client stream end packet.
789        self.send_responses_after_packets = 3
790        response = self.method.response_type(payload='yo')
791        self._enqueue_response(
792            CLIENT_CHANNEL_ID, self.method, Status.OK, response
793        )
794
795        results = self.rpc(requests)
796        self.assertIs(results.status, Status.OK)
797        self.assertEqual(results.response, response)
798
799    def test_blocking_server_error(self) -> None:
800        requests = [self.method.request_type(magic_number=123)]
801
802        # Send after len(requests) and the client stream end packet.
803        self._enqueue_error(
804            CLIENT_CHANNEL_ID,
805            self.method.service,
806            self.method,
807            Status.NOT_FOUND,
808        )
809
810        with self.assertRaises(callback_client.RpcError) as context:
811            self.rpc(requests)
812
813        self.assertIs(context.exception.status, Status.NOT_FOUND)
814
815    def test_nonblocking_call(self) -> None:
816        """Tests a successful client streaming RPC ended by the server."""
817        payload_1 = self.method.response_type(payload='-_-')
818
819        for _ in range(3):
820            stream = self._service.SomeClientStreaming.invoke()
821            self.assertFalse(stream.completed())
822
823            stream.send(magic_number=31)
824            self.assertIs(
825                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
826            )
827            self.assertEqual(
828                31, self._sent_payload(self.method.request_type).magic_number
829            )
830            self.assertFalse(stream.completed())
831
832            # Enqueue the server response to be sent after the next message.
833            self._enqueue_response(
834                CLIENT_CHANNEL_ID, self.method, Status.OK, payload_1
835            )
836
837            stream.send(magic_number=32)
838            self.assertIs(
839                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
840            )
841            self.assertEqual(
842                32, self._sent_payload(self.method.request_type).magic_number
843            )
844
845            self.assertTrue(stream.completed())
846            self.assertIs(Status.OK, stream.status)
847            self.assertIsNone(stream.error)
848            self.assertEqual(payload_1, stream.response)
849
850    def test_open(self) -> None:
851        self.output_exception = IOError('something went wrong sending!')
852        payload = self.method.response_type(payload='-_-')
853
854        for _ in range(3):
855            self._enqueue_response(
856                CLIENT_CHANNEL_ID, self.method, Status.OK, payload
857            )
858
859            callback = mock.Mock()
860            call = self.rpc.open(callback, callback, callback)
861            self.assertEqual(self.requests, [])
862
863            self._process_enqueued_packets()
864
865            callback.assert_has_calls(
866                [
867                    mock.call(call, payload),
868                    mock.call(call, Status.OK),
869                ]
870            )
871
872    def test_nonblocking_finish(self) -> None:
873        """Tests a client streaming RPC ended by the client."""
874        payload_1 = self.method.response_type(payload='-_-')
875
876        for _ in range(3):
877            stream = self._service.SomeClientStreaming.invoke()
878            self.assertFalse(stream.completed())
879
880            stream.send(magic_number=37)
881            self.assertIs(
882                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
883            )
884            self.assertEqual(
885                37, self._sent_payload(self.method.request_type).magic_number
886            )
887            self.assertFalse(stream.completed())
888
889            # Enqueue the server response to be sent after the next message.
890            self._enqueue_response(
891                CLIENT_CHANNEL_ID, self.method, Status.OK, payload_1
892            )
893
894            stream.finish_and_wait()
895            self.assertIs(
896                packet_pb2.PacketType.CLIENT_REQUEST_COMPLETION,
897                self.last_request().type,
898            )
899
900            self.assertTrue(stream.completed())
901            self.assertIs(Status.OK, stream.status)
902            self.assertIsNone(stream.error)
903            self.assertEqual(payload_1, stream.response)
904
905    def test_nonblocking_cancel(self) -> None:
906        for _ in range(3):
907            stream = self._service.SomeClientStreaming.invoke()
908            stream.send(magic_number=37)
909
910            self.assertTrue(stream.cancel())
911            self.assertIs(
912                packet_pb2.PacketType.CLIENT_ERROR, self.last_request().type
913            )
914            self.assertIs(Status.CANCELLED.value, self.last_request().status)
915            self.assertFalse(stream.cancel())
916
917            self.assertTrue(stream.completed())
918            self.assertIs(stream.error, Status.CANCELLED)
919
920    def test_nonblocking_server_error(self) -> None:
921        for _ in range(3):
922            stream = self._service.SomeClientStreaming.invoke()
923
924            self._enqueue_error(
925                CLIENT_CHANNEL_ID,
926                self.method.service,
927                self.method,
928                Status.INVALID_ARGUMENT,
929            )
930            stream.send(magic_number=2**32 - 1)
931
932            with self.assertRaises(callback_client.RpcError) as context:
933                stream.finish_and_wait()
934
935            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)
936
937    def test_nonblocking_server_error_after_stream_end(self) -> None:
938        for _ in range(3):
939            stream = self._service.SomeClientStreaming.invoke()
940
941            # Error will be sent in response to the CLIENT_REQUEST_COMPLETION
942            # packet.
943            self._enqueue_error(
944                CLIENT_CHANNEL_ID,
945                self.method.service,
946                self.method,
947                Status.INVALID_ARGUMENT,
948            )
949
950            with self.assertRaises(callback_client.RpcError) as context:
951                stream.finish_and_wait()
952
953            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)
954
955    def test_nonblocking_send_after_cancelled(self) -> None:
956        call = self._service.SomeClientStreaming.invoke()
957        self.assertTrue(call.cancel())
958
959        with self.assertRaises(callback_client.RpcError) as context:
960            call.send(payload='hello')
961
962        self.assertIs(context.exception.status, Status.CANCELLED)
963
964    def test_nonblocking_finish_after_completed(self) -> None:
965        reply = self.method.response_type(payload='!?')
966        self._enqueue_response(
967            CLIENT_CHANNEL_ID, self.method, Status.UNAVAILABLE, reply
968        )
969
970        call = self.rpc.invoke()
971        result = call.finish_and_wait()
972        self.assertEqual(result.response, reply)
973
974        self.assertEqual(result, call.finish_and_wait())
975        self.assertEqual(result, call.finish_and_wait())
976
977    def test_nonblocking_finish_after_error(self) -> None:
978        self._enqueue_error(
979            CLIENT_CHANNEL_ID,
980            self.method.service,
981            self.method,
982            Status.UNAVAILABLE,
983        )
984
985        call = self.rpc.invoke()
986
987        for _ in range(3):
988            with self.assertRaises(callback_client.RpcError) as context:
989                call.finish_and_wait()
990
991            self.assertIs(context.exception.status, Status.UNAVAILABLE)
992            self.assertIs(call.error, Status.UNAVAILABLE)
993            self.assertIsNone(call.response)
994
995    def test_nonblocking_duplicate_calls_not_cancelled(self) -> None:
996        first_call = self.rpc.invoke()
997        self.assertFalse(first_call.completed())
998
999        second_call = self.rpc.invoke()
1000
1001        self.assertIs(first_call.error, None)
1002        self.assertIs(second_call.error, None)
1003
1004
class BidirectionalStreamingTest(_CallbackClientImplTestBase):
    """Tests for bidirectional streaming RPCs."""

    def setUp(self) -> None:
        """Binds the SomeBidiStreaming RPC and its method for each test."""
        super().setUp()
        self.rpc = self._service.SomeBidiStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        """Tests a blocking call that completes with a status and no replies."""
        requests = [
            self.method.request_type(magic_number=123),
            self.method.request_type(magic_number=456),
        ]

        # Send after len(requests) and the client stream end packet.
        self.send_responses_after_packets = 3
        self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.NOT_FOUND)

        results = self.rpc(requests)
        self.assertIs(results.status, Status.NOT_FOUND)
        self.assertFalse(results.responses)

    def test_blocking_server_error(self) -> None:
        """Tests that a blocking call raises RpcError on a server error."""
        requests = [self.method.request_type(magic_number=123)]

        # Enqueue a server error for the call; unlike test_blocking_call, no
        # packet-count delay is configured here.
        self._enqueue_error(
            CLIENT_CHANNEL_ID,
            self.method.service,
            self.method,
            Status.NOT_FOUND,
        )

        with self.assertRaises(callback_client.RpcError) as context:
            self.rpc(requests)

        self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_call(self) -> None:
        """Tests a bidirectional streaming RPC ended by the server."""
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            responses: list = []
            # Bind this iteration's list as a default so the callback appends
            # to it rather than to a later iteration's list.
            stream = self._service.SomeBidiStreaming.invoke(
                lambda _, res, responses=responses: responses.append(res)
            )
            self.assertFalse(stream.completed())

            stream.send(magic_number=55)
            # Protobuf enum values are plain integers in Python, so compare
            # them by value; identity only holds through small-int caching.
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                55, self._sent_payload(self.method.request_type).magic_number
            )
            self.assertFalse(stream.completed())
            self.assertEqual([], responses)

            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1)
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep2)

            stream.send(magic_number=66)
            self.assertEqual(
                packet_pb2.PacketType.CLIENT_STREAM, self.last_request().type
            )
            self.assertEqual(
                66, self._sent_payload(self.method.request_type).magic_number
            )
            self.assertFalse(stream.completed())
            self.assertEqual([rep1, rep2], responses)

            self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)

            stream.send(magic_number=77)
            self.assertTrue(stream.completed())
            self.assertEqual([rep1, rep2], responses)

            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)

    def test_open(self) -> None:
        """Tests receiving replies on a call created with open().

        The same mock is passed for all three callback arguments of open(), so
        both stream replies and the final status are recorded on it. The test
        also checks that open() itself sends no request packets.
        """
        # NOTE(review): presumably the test channel raises this if the client
        # tries to send a packet -- confirm against the base class.
        self.output_exception = IOError('something went wrong sending!')
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1)
            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep2)
            self._enqueue_response(CLIENT_CHANNEL_ID, self.method, Status.OK)

            callback = mock.Mock()
            call = self.rpc.open(callback, callback, callback)
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            # Reuse the enqueued messages instead of rebuilding equal copies.
            callback.assert_has_calls(
                [
                    mock.call(call, rep1),
                    mock.call(call, rep2),
                    mock.call(call, Status.OK),
                ]
            )

    @mock.patch('pw_rpc.callback_client.call.Call._default_response')
    def test_nonblocking(self, callback) -> None:
        """Tests that replies reach the default response handler.

        invoke() is called without callbacks, so the enqueued server stream
        reply must be passed to the patched Call._default_response once.
        """
        reply = self.method.response_type(payload='This is the payload!')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply)

        self._service.SomeBidiStreaming.invoke()

        callback.assert_called_once_with(mock.ANY, reply)

    def test_nonblocking_server_error(self) -> None:
        """Tests a server error arriving in the middle of an active stream.

        A reply received before the error is still delivered; after the error
        the call is completed with no status, and finish_and_wait() raises.
        """
        rep1 = self.method.response_type(payload='!!!')

        for _ in range(3):
            responses: list = []
            stream = self._service.SomeBidiStreaming.invoke(
                lambda _, res, responses=responses: responses.append(res)
            )
            self.assertFalse(stream.completed())

            self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, rep1)

            stream.send(magic_number=55)
            self.assertFalse(stream.completed())
            self.assertEqual([rep1], responses)

            self._enqueue_error(
                CLIENT_CHANNEL_ID,
                self.method.service,
                self.method,
                Status.OUT_OF_RANGE,
            )

            stream.send(magic_number=99999)
            self.assertTrue(stream.completed())
            self.assertEqual([rep1], responses)

            self.assertIsNone(stream.status)
            self.assertIs(Status.OUT_OF_RANGE, stream.error)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()
            self.assertIs(context.exception.status, Status.OUT_OF_RANGE)

    def test_nonblocking_server_error_after_stream_end(self) -> None:
        """Tests a server error that arrives once the client ends the stream."""
        for _ in range(3):
            stream = self._service.SomeBidiStreaming.invoke()

            # Error will be sent in response to the CLIENT_REQUEST_COMPLETION
            # packet.
            self._enqueue_error(
                CLIENT_CHANNEL_ID,
                self.method.service,
                self.method,
                Status.INVALID_ARGUMENT,
            )

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_send_after_cancelled(self) -> None:
        """Tests that send() on a cancelled call raises RpcError(CANCELLED)."""
        call = self._service.SomeBidiStreaming.invoke()
        self.assertTrue(call.cancel())

        with self.assertRaises(callback_client.RpcError) as context:
            # Use a field that exists on the request type (SomeMessage has
            # only magic_number) so the only possible failure is the
            # cancellation itself, not request construction.
            call.send(magic_number=55)

        self.assertIs(context.exception.status, Status.CANCELLED)

    def test_nonblocking_finish_after_completed(self) -> None:
        """Tests that finish_and_wait() is repeatable on a completed call."""
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply)
        self._enqueue_response(
            CLIENT_CHANNEL_ID, self.method, Status.UNAVAILABLE
        )

        call = self.rpc.invoke()
        result = call.finish_and_wait()
        self.assertEqual(result.responses, [reply])

        # Finishing again must keep yielding the same result.
        self.assertEqual(result, call.finish_and_wait())
        self.assertEqual(result, call.finish_and_wait())

    def test_nonblocking_finish_after_error(self) -> None:
        """Tests that finish_and_wait() keeps raising after a server error.

        The reply received before the error must stay available through
        call.responses.
        """
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(CLIENT_CHANNEL_ID, self.method, reply)
        self._enqueue_error(
            CLIENT_CHANNEL_ID,
            self.method.service,
            self.method,
            Status.UNAVAILABLE,
        )

        call = self.rpc.invoke()

        for _ in range(3):
            with self.assertRaises(callback_client.RpcError) as context:
                call.finish_and_wait()

            self.assertIs(context.exception.status, Status.UNAVAILABLE)
            self.assertIs(call.error, Status.UNAVAILABLE)
            self.assertEqual(call.responses, [reply])

    def test_nonblocking_duplicate_calls_not_cancelled(self) -> None:
        """Tests that invoking the same RPC twice leaves both calls error-free."""
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIsNone(first_call.error)
        self.assertIsNone(second_call.error)

    def test_stream_response(self) -> None:
        """Tests the repr() of StreamResponse with and without replies."""
        proto = self._protos.packages.pw.test1.SomeMessage(magic_number=123)
        self.assertEqual(
            repr(callback_client.StreamResponse(Status.ABORTED, [proto] * 2)),
            '(Status.ABORTED, [pw.test1.SomeMessage(magic_number=123), '
            'pw.test1.SomeMessage(magic_number=123)])',
        )
        self.assertEqual(
            repr(callback_client.StreamResponse(Status.OK, [])),
            '(Status.OK, [])',
        )
1236
1237
if __name__ == '__main__':
    # Run this module's tests when the file is executed directly as a script.
    unittest.main()
1240