• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# Copyright 2020 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
15"""Tests creating pw_rpc client."""
16
17import unittest
18from typing import Any, Callable
19
20from pw_protobuf_compiler import python_protos
21from pw_status import Status
22
23from pw_rpc import callback_client, client, packets
24import pw_rpc.ids
25from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket
26
# Alias for packets.RpcIds; the tests build it positionally from
# (channel ID, service ID, method ID, call ID) when encoding test packets.
RpcIds = packets.RpcIds

# proto3 test package "pw.test1": PublicService declares one method of each
# RPC kind (unary, server streaming, client streaming, bidirectional).
TEST_PROTO_1 = """\
syntax = "proto3";

package pw.test1;

message SomeMessage {
  uint32 magic_number = 1;
}

message AnotherMessage {
  enum Result {
    FAILED = 0;
    FAILED_MISERABLY = 1;
    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
  }

  Result result = 1;
  string payload = 2;
}

service PublicService {
  rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
  rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
  rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
  rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""

# proto2 test package "pw.test2" with two services (Alpha, Bravo), so lookups
# are exercised across more than one package and service.
TEST_PROTO_2 = """\
syntax = "proto2";

package pw.test2;

message Request {
  optional float magic_number = 1;
}

message Response {
}

service Alpha {
  rpc Unary(Request) returns (Response) {}
}

service Bravo {
  rpc BidiStreaming(stream Request) returns (stream Response) {}
}
"""

# IDs used to address packets the client is not expecting. SOME_CHANNEL_ID
# differs from both channels registered by create_client(); the tests expect
# SOME_SERVICE_ID / SOME_METHOD_ID to be unknown to the client.
# NOTE(review): this presumably relies on them not colliding with the hashed
# IDs of the test services/methods -- re-check if the protos above change.
SOME_CHANNEL_ID: int = 237
SOME_SERVICE_ID: int = 193
SOME_METHOD_ID: int = 769
SOME_CALL_ID: int = 452

# Channel IDs registered (in this order) on every client from create_client().
CLIENT_FIRST_CHANNEL_ID: int = 557
CLIENT_SECOND_CHANNEL_ID: int = 474
85
86
def create_protos() -> Any:
    """Builds a python_protos library from the two test .proto sources."""
    proto_sources = [TEST_PROTO_1, TEST_PROTO_2]
    library = python_protos.Library.from_strings(proto_sources)
    return library
89
90
def create_client(
    proto_modules: Any,
    first_channel_output_fn: Callable[[bytes], Any] | None = None,
) -> client.Client:
    """Creates a callback-based pw_rpc client with two channels.

    Args:
        proto_modules: Compiled proto modules that define the services.
        first_channel_output_fn: Output function for the channel with ID
            CLIENT_FIRST_CHANNEL_ID; passed through unchanged, may be None.

    Returns:
        A client.Client whose channels are CLIENT_FIRST_CHANNEL_ID and
        CLIENT_SECOND_CHANNEL_ID, listed in that order. Anything written to
        the second channel is discarded.
    """

    def discard_output(_: bytes) -> None:
        return None

    impl = callback_client.Impl()
    channels = [
        client.Channel(CLIENT_FIRST_CHANNEL_ID, first_channel_output_fn),
        client.Channel(CLIENT_SECOND_CHANNEL_ID, discard_output),
    ]
    return client.Client.from_modules(impl, channels, proto_modules)
103
104
class ChannelClientTest(unittest.TestCase):
    """Exercises the service/method lookups exposed by a ChannelClient."""

    def setUp(self) -> None:
        # These tests only inspect lookups and never invoke an RPC, so the
        # first channel is created without an output function.
        protos = create_protos()
        instance = create_client(protos.modules())
        self._channel_client: client.ChannelClient = instance.channel(
            CLIENT_FIRST_CHANNEL_ID
        )

    def test_access_service_client_as_attribute_or_index(self) -> None:
        rpcs = self._channel_client.rpcs
        by_attribute = rpcs.pw.test1.PublicService
        full_name = 'pw.test1.PublicService'
        # Indexing by full name or by calculated ID yields the same object.
        for key in (full_name, pw_rpc.ids.calculate(full_name)):
            self.assertIs(by_attribute, rpcs[key])

    def test_access_method_client_as_attribute_or_index(self) -> None:
        rpcs = self._channel_client.rpcs
        by_attribute = rpcs.pw.test2.Alpha.Unary
        alpha = rpcs['pw.test2.Alpha']
        for key in ('Unary', pw_rpc.ids.calculate('Unary')):
            self.assertIs(by_attribute, alpha[key])

    def test_service_name(self) -> None:
        service = self._channel_client.rpcs.pw.test2.Alpha.Unary.service
        self.assertEqual(service.name, 'Alpha')
        self.assertEqual(service.full_name, 'pw.test2.Alpha')

    def test_method_name(self) -> None:
        method = self._channel_client.rpcs.pw.test2.Alpha.Unary.method
        self.assertEqual(method.name, 'Unary')
        self.assertEqual(method.full_name, 'pw.test2.Alpha.Unary')

    def test_iterate_over_all_methods(self) -> None:
        rpcs = self._channel_client.rpcs
        public = rpcs.pw.test1.PublicService
        expected = {
            public.SomeUnary,
            public.SomeServerStreaming,
            public.SomeClientStreaming,
            public.SomeBidiStreaming,
            rpcs.pw.test2.Alpha.Unary,
            rpcs.pw.test2.Bravo.BidiStreaming,
        }
        self.assertEqual(set(self._channel_client.methods()), expected)

    def test_check_for_presence_of_services(self) -> None:
        rpcs = self._channel_client.rpcs
        full_name = 'pw.test1.PublicService'
        self.assertIn(full_name, rpcs)
        self.assertIn(pw_rpc.ids.calculate(full_name), rpcs)

    def test_check_for_presence_of_missing_services(self) -> None:
        rpcs = self._channel_client.rpcs
        # An unqualified service name does not count as a match.
        for absent in ('PublicService', 'NotAService', -1213):
            self.assertNotIn(absent, rpcs)

    def test_check_for_presence_of_methods(self) -> None:
        public = self._channel_client.rpcs.pw.test1.PublicService
        for key in ('SomeUnary', pw_rpc.ids.calculate('SomeUnary')):
            self.assertIn(key, public)

    def test_check_for_presence_of_missing_methods(self) -> None:
        public = self._channel_client.rpcs.pw.test1.PublicService
        # Fragments of a real method name are not matches.
        for absent in ('Some', 'Unary', 12345):
            self.assertNotIn(absent, public)

    def test_method_fully_qualified_name(self) -> None:
        expected = self._channel_client.rpcs.pw.test2.Alpha.Unary
        # Both '/' and '.' are accepted between the service and method name.
        for qualified in ('pw.test2.Alpha/Unary', 'pw.test2.Alpha.Unary'):
            self.assertIs(self._channel_client.method(qualified), expected)
200
201
class ClientTest(unittest.TestCase):
    """Tests the pw_rpc Client independently of the ClientImpl."""

    def setUp(self) -> None:
        # Bytes of the most recent packet written to the first channel, or
        # None if nothing has been written yet in this test.
        self._last_packet_sent_bytes: bytes | None = None
        self._protos = create_protos()
        self._client = create_client(self._protos.modules(), self._save_packet)

    def _save_packet(self, packet: bytes) -> None:
        """Output function for the first channel; remembers the last packet."""
        self._last_packet_sent_bytes = packet

    def _last_packet_sent(self) -> RpcPacket:
        """Returns the last packet sent on the first channel, decoded.

        Uses ``self.fail`` rather than a bare ``assert`` so the check still
        runs (with a clear message) when Python is started with ``-O``.
        """
        if self._last_packet_sent_bytes is None:
            self.fail('No packet was sent on the first channel')
        packet = RpcPacket()
        packet.MergeFromString(self._last_packet_sent_bytes)
        return packet

    def _assert_client_error_sent(
        self, service_id: int, method_id: int, status: Status
    ) -> None:
        """Asserts the last packet sent is a CLIENT_ERROR carrying ``status``.

        The error is expected on CLIENT_FIRST_CHANNEL_ID for SOME_CALL_ID,
        the IDs the process_packet tests use for their incoming responses.
        """
        self.assertEqual(
            self._last_packet_sent(),
            RpcPacket(
                type=PacketType.CLIENT_ERROR,
                channel_id=CLIENT_FIRST_CHANNEL_ID,
                service_id=service_id,
                method_id=method_id,
                call_id=SOME_CALL_ID,
                status=status.value,
            ),
        )

    def test_channel(self) -> None:
        self.assertEqual(
            self._client.channel(CLIENT_FIRST_CHANNEL_ID).channel.id,
            CLIENT_FIRST_CHANNEL_ID,
        )
        self.assertEqual(
            self._client.channel(CLIENT_SECOND_CHANNEL_ID).channel.id,
            CLIENT_SECOND_CHANNEL_ID,
        )

    def test_channel_default_is_first_listed(self) -> None:
        self.assertEqual(
            self._client.channel().channel.id, CLIENT_FIRST_CHANNEL_ID
        )

    def test_channel_invalid(self) -> None:
        # 404 is not one of the channel IDs registered by create_client().
        with self.assertRaises(KeyError):
            self._client.channel(404)

    def test_all_methods(self) -> None:
        services = self._client.services

        all_methods = {
            services['pw.test1.PublicService'].methods['SomeUnary'],
            services['pw.test1.PublicService'].methods['SomeServerStreaming'],
            services['pw.test1.PublicService'].methods['SomeClientStreaming'],
            services['pw.test1.PublicService'].methods['SomeBidiStreaming'],
            services['pw.test2.Alpha'].methods['Unary'],
            services['pw.test2.Bravo'].methods['BidiStreaming'],
        }
        self.assertEqual(set(self._client.methods()), all_methods)

    def test_method_present(self) -> None:
        expected = self._client.services['pw.test1.PublicService'].methods[
            'SomeUnary'
        ]
        # Both '.' and '/' are accepted between the service and method name.
        self.assertIs(
            self._client.method('pw.test1.PublicService.SomeUnary'), expected
        )
        self.assertIs(
            self._client.method('pw.test1.PublicService/SomeUnary'), expected
        )

    def test_method_invalid_format(self) -> None:
        with self.assertRaises(ValueError):
            self._client.method('SomeUnary')

    def test_method_not_present(self) -> None:
        with self.assertRaises(KeyError):
            self._client.method('pw.test1.PublicService/ThisIsNotGood')

        with self.assertRaises(KeyError):
            self._client.method('nothing.Good')

    def test_process_packet_invalid_proto_data(self) -> None:
        self.assertIs(
            self._client.process_packet(b'NOT a packet!'), Status.DATA_LOSS
        )

    def test_process_packet_not_for_client(self) -> None:
        # REQUEST packets travel client -> server, so the client rejects them.
        self.assertIs(
            self._client.process_packet(
                RpcPacket(type=PacketType.REQUEST).SerializeToString()
            ),
            Status.INVALID_ARGUMENT,
        )

    def test_process_packet_unrecognized_channel(self) -> None:
        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    RpcIds(
                        SOME_CHANNEL_ID,
                        SOME_SERVICE_ID,
                        SOME_METHOD_ID,
                        SOME_CALL_ID,
                    ),
                    self._protos.packages.pw.test2.Request(),
                )
            ),
            Status.NOT_FOUND,
        )

    def test_process_packet_unrecognized_service(self) -> None:
        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    RpcIds(
                        CLIENT_FIRST_CHANNEL_ID,
                        SOME_SERVICE_ID,
                        SOME_METHOD_ID,
                        SOME_CALL_ID,
                    ),
                    self._protos.packages.pw.test2.Request(),
                )
            ),
            Status.OK,
        )

        self._assert_client_error_sent(
            SOME_SERVICE_ID, SOME_METHOD_ID, Status.NOT_FOUND
        )

    def test_process_packet_unrecognized_method(self) -> None:
        service = next(iter(self._client.services))

        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    RpcIds(
                        CLIENT_FIRST_CHANNEL_ID,
                        service.id,
                        SOME_METHOD_ID,
                        SOME_CALL_ID,
                    ),
                    self._protos.packages.pw.test2.Request(),
                )
            ),
            Status.OK,
        )

        self._assert_client_error_sent(
            service.id, SOME_METHOD_ID, Status.NOT_FOUND
        )

    def test_process_packet_non_pending_method(self) -> None:
        service = next(iter(self._client.services))
        method = next(iter(service.methods))

        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    RpcIds(
                        CLIENT_FIRST_CHANNEL_ID,
                        service.id,
                        method.id,
                        SOME_CALL_ID,
                    ),
                    self._protos.packages.pw.test2.Request(),
                )
            ),
            Status.OK,
        )

        # A response for a known method with no pending call is rejected.
        self._assert_client_error_sent(
            service.id, method.id, Status.FAILED_PRECONDITION
        )

    def test_process_packet_non_pending_calls_response_callback(self) -> None:
        method = self._client.method('pw.test1.PublicService.SomeUnary')
        reply = method.response_type(payload='hello')

        # Record each invocation and check it after process_packet returns.
        # Asserting only inside the callback lets the test pass vacuously if
        # the callback is never invoked, and depends on process_packet
        # propagating exceptions raised by the callback.
        calls: list[tuple[client.PendingRpc, Any, Status | None]] = []

        def response_callback(
            rpc: client.PendingRpc,
            message: Any,
            status: Status | None,
        ) -> None:
            calls.append((rpc, message, status))

        self._client.response_callback = response_callback

        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    RpcIds(
                        CLIENT_FIRST_CHANNEL_ID,
                        method.service.id,
                        method.id,
                        SOME_CALL_ID,
                    ),
                    reply,
                )
            ),
            Status.OK,
        )

        self.assertEqual(
            len(calls), 1, 'response_callback should be invoked exactly once'
        )
        rpc, message, status = calls[0]
        self.assertEqual(
            rpc,
            client.PendingRpc(
                self._client.channel(CLIENT_FIRST_CHANNEL_ID).channel,
                method.service,
                method,
                call_id=SOME_CALL_ID,
            ),
        )
        self.assertEqual(message, reply)
        self.assertIs(status, Status.OK)
431
432
if __name__ == '__main__':
    # Run this module's tests when the file is executed directly.
    unittest.main()
435