• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# Copyright 2020 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
15"""Tests creating pw_rpc client."""
16
17import unittest
18from typing import Optional
19
20from pw_protobuf_compiler import python_protos
21from pw_status import Status
22
23from pw_rpc import callback_client, client, packets
24import pw_rpc.ids
25from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket
26
# Proto3 fixture for package pw.test1. PublicService declares one method of
# each RPC kind: unary, server streaming, client streaming, and bidirectional
# streaming.
TEST_PROTO_1 = """\
syntax = "proto3";

package pw.test1;

message SomeMessage {
  uint32 magic_number = 1;
}

message AnotherMessage {
  enum Result {
    FAILED = 0;
    FAILED_MISERABLY = 1;
    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
  }

  Result result = 1;
  string payload = 2;
}

service PublicService {
  rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
  rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
  rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
  rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""

# Proto2 fixture for package pw.test2 with two services: Alpha (one unary
# method) and Bravo (one bidirectional streaming method). Both fixtures are
# compiled together by _test_setup(), so the client sees several packages and
# services.
TEST_PROTO_2 = """\
syntax = "proto2";

package pw.test2;

message Request {
  optional float magic_number = 1;
}

message Response {
}

service Alpha {
  rpc Unary(Request) returns (Response) {}
}

service Bravo {
  rpc BidiStreaming(stream Request) returns (stream Response) {}
}
"""
75
76
def _test_setup(output=None):
    """Compiles the test protos and builds a callback-based pw_rpc client.

    Args:
      output: Function that channel 1 uses to send outgoing packets. Channel 2
        always discards its packets.

    Returns:
      A (protos, client) tuple: the compiled python_protos.Library and the
      client.Client built from its modules.
    """
    protos = python_protos.Library.from_strings([TEST_PROTO_1, TEST_PROTO_2])

    # Build the pieces in the same order the original nested call evaluated
    # them: implementation first, then channels, then the proto modules.
    impl = callback_client.Impl()
    channels = [
        client.Channel(1, output),
        client.Channel(2, lambda _: None),
    ]
    rpc_client = client.Client.from_modules(impl, channels, protos.modules())

    return protos, rpc_client
84
85
class ChannelClientTest(unittest.TestCase):
    """Covers service and method lookup through a ChannelClient on channel 1."""

    def setUp(self) -> None:
        _, rpc_client = _test_setup()
        self._channel_client = rpc_client.channel(1)

    def test_access_service_client_as_attribute_or_index(self) -> None:
        rpcs = self._channel_client.rpcs
        by_attribute = rpcs.pw.test1.PublicService

        # Lookup by full name and by numeric ID must yield the same object.
        self.assertIs(by_attribute, rpcs['pw.test1.PublicService'])
        self.assertIs(
            by_attribute,
            rpcs[pw_rpc.ids.calculate('pw.test1.PublicService')],
        )

    def test_access_method_client_as_attribute_or_index(self) -> None:
        rpcs = self._channel_client.rpcs
        by_attribute = rpcs.pw.test2.Alpha.Unary
        alpha = rpcs['pw.test2.Alpha']

        self.assertIs(by_attribute, alpha['Unary'])
        self.assertIs(by_attribute, alpha[pw_rpc.ids.calculate('Unary')])

    def test_service_name(self) -> None:
        service = self._channel_client.rpcs.pw.test2.Alpha.Unary.service

        self.assertEqual(service.name, 'Alpha')
        self.assertEqual(service.full_name, 'pw.test2.Alpha')

    def test_method_name(self) -> None:
        method = self._channel_client.rpcs.pw.test2.Alpha.Unary.method

        self.assertEqual(method.name, 'Unary')
        self.assertEqual(method.full_name, 'pw.test2.Alpha.Unary')

    def test_iterate_over_all_methods(self) -> None:
        rpcs = self._channel_client.rpcs
        public = rpcs.pw.test1.PublicService

        expected = {
            public.SomeUnary,
            public.SomeServerStreaming,
            public.SomeClientStreaming,
            public.SomeBidiStreaming,
            rpcs.pw.test2.Alpha.Unary,
            rpcs.pw.test2.Bravo.BidiStreaming,
        }
        self.assertEqual(set(self._channel_client.methods()), expected)

    def test_check_for_presence_of_services(self) -> None:
        rpcs = self._channel_client.rpcs
        full_name = 'pw.test1.PublicService'

        self.assertIn(full_name, rpcs)
        self.assertIn(pw_rpc.ids.calculate(full_name), rpcs)

    def test_check_for_presence_of_missing_services(self) -> None:
        rpcs = self._channel_client.rpcs

        # An unqualified service name is not enough; unknown names and IDs
        # must not match either.
        for absent in ('PublicService', 'NotAService', -1213):
            self.assertNotIn(absent, rpcs)

    def test_check_for_presence_of_methods(self) -> None:
        public = self._channel_client.rpcs.pw.test1.PublicService

        self.assertIn('SomeUnary', public)
        self.assertIn(pw_rpc.ids.calculate('SomeUnary'), public)

    def test_check_for_presence_of_missing_methods(self) -> None:
        public = self._channel_client.rpcs.pw.test1.PublicService

        # Prefixes/suffixes of real method names must not match.
        for absent in ('Some', 'Unary', 12345):
            self.assertNotIn(absent, public)

    def test_method_fully_qualified_name(self) -> None:
        expected = self._channel_client.rpcs.pw.test2.Alpha.Unary

        # Both '/' and '.' are accepted as the service/method separator.
        for full_name in ('pw.test2.Alpha/Unary', 'pw.test2.Alpha.Unary'):
            self.assertIs(self._channel_client.method(full_name), expected)
172            self._channel_client.rpcs.pw.test2.Alpha.Unary,
173        )
174        self.assertIs(
175            self._channel_client.method('pw.test2.Alpha.Unary'),
176            self._channel_client.rpcs.pw.test2.Alpha.Unary,
177        )
178
179
class ClientTest(unittest.TestCase):
    """Tests the pw_rpc Client independently of the ClientImpl."""

    def setUp(self) -> None:
        # Raw bytes of the most recent packet written to channel 1, if any.
        self._last_packet_sent_bytes: Optional[bytes] = None
        self._protos, self._client = _test_setup(self._save_packet)

    def _save_packet(self, packet: bytes) -> None:
        """Output function for channel 1; records the outgoing packet."""
        self._last_packet_sent_bytes = packet

    def _last_packet_sent(self) -> RpcPacket:
        """Decodes and returns the most recent packet sent on channel 1.

        Fails the test (rather than using a bare assert, which is stripped
        under ``python -O``) if nothing has been sent since setUp().
        """
        if self._last_packet_sent_bytes is None:
            self.fail('Expected the client to send a packet, but none was sent')

        packet = RpcPacket()
        packet.MergeFromString(self._last_packet_sent_bytes)
        return packet

    def test_channel(self) -> None:
        self.assertEqual(self._client.channel(1).channel.id, 1)
        self.assertEqual(self._client.channel(2).channel.id, 2)

    def test_channel_default_is_first_listed(self) -> None:
        self.assertEqual(self._client.channel().channel.id, 1)

    def test_channel_invalid(self) -> None:
        with self.assertRaises(KeyError):
            self._client.channel(404)

    def test_all_methods(self) -> None:
        services = self._client.services

        all_methods = {
            services['pw.test1.PublicService'].methods['SomeUnary'],
            services['pw.test1.PublicService'].methods['SomeServerStreaming'],
            services['pw.test1.PublicService'].methods['SomeClientStreaming'],
            services['pw.test1.PublicService'].methods['SomeBidiStreaming'],
            services['pw.test2.Alpha'].methods['Unary'],
            services['pw.test2.Bravo'].methods['BidiStreaming'],
        }
        self.assertEqual(set(self._client.methods()), all_methods)

    def test_method_present(self) -> None:
        self.assertIs(
            self._client.method('pw.test1.PublicService.SomeUnary'),
            self._client.services['pw.test1.PublicService'].methods[
                'SomeUnary'
            ],
        )
        self.assertIs(
            self._client.method('pw.test1.PublicService/SomeUnary'),
            self._client.services['pw.test1.PublicService'].methods[
                'SomeUnary'
            ],
        )

    def test_method_invalid_format(self) -> None:
        with self.assertRaises(ValueError):
            self._client.method('SomeUnary')

    def test_method_not_present(self) -> None:
        with self.assertRaises(KeyError):
            self._client.method('pw.test1.PublicService/ThisIsNotGood')

        with self.assertRaises(KeyError):
            self._client.method('nothing.Good')

    def test_process_packet_invalid_proto_data(self) -> None:
        self.assertIs(
            self._client.process_packet(b'NOT a packet!'), Status.DATA_LOSS
        )

    def test_process_packet_not_for_client(self) -> None:
        # A REQUEST packet is not one the client should handle.
        self.assertIs(
            self._client.process_packet(
                RpcPacket(type=PacketType.REQUEST).SerializeToString()
            ),
            Status.INVALID_ARGUMENT,
        )

    def test_process_packet_unrecognized_channel(self) -> None:
        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    (123, 456, 789), self._protos.packages.pw.test2.Request()
                )
            ),
            Status.NOT_FOUND,
        )

    def test_process_packet_unrecognized_service(self) -> None:
        # The packet itself is handled (OK), but the client reports the
        # unknown service back over the channel as a CLIENT_ERROR.
        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    (1, 456, 789), self._protos.packages.pw.test2.Request()
                )
            ),
            Status.OK,
        )

        self.assertEqual(
            self._last_packet_sent(),
            RpcPacket(
                type=PacketType.CLIENT_ERROR,
                channel_id=1,
                service_id=456,
                method_id=789,
                status=Status.NOT_FOUND.value,
            ),
        )

    def test_process_packet_unrecognized_method(self) -> None:
        service = next(iter(self._client.services))

        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    (1, service.id, 789),
                    self._protos.packages.pw.test2.Request(),
                )
            ),
            Status.OK,
        )

        self.assertEqual(
            self._last_packet_sent(),
            RpcPacket(
                type=PacketType.CLIENT_ERROR,
                channel_id=1,
                service_id=service.id,
                method_id=789,
                status=Status.NOT_FOUND.value,
            ),
        )

    def test_process_packet_non_pending_method(self) -> None:
        # The method exists but has no call in progress, so the client
        # replies with FAILED_PRECONDITION.
        service = next(iter(self._client.services))
        method = next(iter(service.methods))

        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    (1, service.id, method.id),
                    self._protos.packages.pw.test2.Request(),
                )
            ),
            Status.OK,
        )

        self.assertEqual(
            self._last_packet_sent(),
            RpcPacket(
                type=PacketType.CLIENT_ERROR,
                channel_id=1,
                service_id=service.id,
                method_id=method.id,
                status=Status.FAILED_PRECONDITION.value,
            ),
        )

    def test_process_packet_non_pending_calls_response_callback(self) -> None:
        method = self._client.method('pw.test1.PublicService.SomeUnary')
        reply = method.response_type(payload='hello')

        # Record invocations and assert after process_packet() returns.
        # Asserting only inside the callback would let the test pass
        # vacuously if the callback were never called, and any assertion
        # raised there could be swallowed by the code that invokes it.
        calls: list = []

        def response_callback(
            rpc: client.PendingRpc, message, status: Optional[Status]
        ) -> None:
            calls.append((rpc, message, status))

        self._client.response_callback = response_callback

        self.assertIs(
            self._client.process_packet(
                packets.encode_response((1, method.service, method), reply)
            ),
            Status.OK,
        )

        self.assertEqual(len(calls), 1, 'response_callback was not called once')
        rpc, message, status = calls[0]
        self.assertEqual(
            rpc,
            client.PendingRpc(
                self._client.channel(1).channel, method.service, method
            ),
        )
        self.assertEqual(message, reply)
        self.assertIs(status, Status.OK)
356        self.assertIs(
357            self._client.process_packet(
358                packets.encode_response((1, method.service, method), reply)
359            ),
360            Status.OK,
361        )
362
363
364if __name__ == '__main__':
365    unittest.main()
366