• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# Copyright 2020 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
15"""Tests creating pw_rpc client."""
16
17import unittest
18from typing import Optional
19
20from pw_protobuf_compiler import python_protos
21from pw_status import Status
22
23from pw_rpc import callback_client, client, packets
24import pw_rpc.ids
25from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket
26
# proto3 source for package pw.test1. It declares SomeMessage, AnotherMessage,
# and PublicService, whose four RPCs cover each combination of streaming and
# non-streaming request/response. Passed to
# python_protos.Library.from_strings() in _test_setup().
TEST_PROTO_1 = """\
syntax = "proto3";

package pw.test1;

message SomeMessage {
  uint32 magic_number = 1;
}

message AnotherMessage {
  enum Result {
    FAILED = 0;
    FAILED_MISERABLY = 1;
    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
  }

  Result result = 1;
  string payload = 2;
}

service PublicService {
  rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
  rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
  rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
  rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""
54
# proto2 source for package pw.test2. It declares Request, an empty Response,
# and two single-method services: Alpha (Unary) and Bravo (BidiStreaming).
# Passed to python_protos.Library.from_strings() in _test_setup().
TEST_PROTO_2 = """\
syntax = "proto2";

package pw.test2;

message Request {
  optional float magic_number = 1;
}

message Response {
}

service Alpha {
  rpc Unary(Request) returns (Response) {}
}

service Bravo {
  rpc BidiStreaming(stream Request) returns (stream Response) {}
}
"""
75
76
def _test_setup(output=None):
    """Creates the proto library and two-channel client shared by the tests.

    Args:
      output: Passed as the second argument of client.Channel for channel 1.
          Channel 2 always receives a lambda that ignores its argument.

    Returns:
      A (library, client) tuple: the python_protos.Library built from
      TEST_PROTO_1 and TEST_PROTO_2, and the client.Client created from that
      library's modules with a callback_client.Impl().
    """
    library = python_protos.Library.from_strings([TEST_PROTO_1, TEST_PROTO_2])

    impl = callback_client.Impl()
    channels = [
        client.Channel(1, output),
        client.Channel(2, lambda _: None),
    ]
    rpc_client = client.Client.from_modules(impl, channels, library.modules())

    return library, rpc_client
83
84
class ChannelClientTest(unittest.TestCase):
    """Checks service and method lookup through a ChannelClient."""

    def setUp(self) -> None:
        # Only the client half of the (protos, client) pair is needed here.
        _, rpc_client = _test_setup()
        self._channel_client = rpc_client.channel(1)

    def test_access_service_client_as_attribute_or_index(self) -> None:
        # Attribute access must match indexing by the full service name and by
        # the value pw_rpc.ids.calculate() derives from that name.
        rpcs = self._channel_client.rpcs
        service_name = 'pw.test1.PublicService'

        self.assertIs(rpcs.pw.test1.PublicService, rpcs[service_name])
        self.assertIs(rpcs.pw.test1.PublicService,
                      rpcs[pw_rpc.ids.calculate(service_name)])

    def test_access_method_client_as_attribute_or_index(self) -> None:
        # Same check as for services, one level down at the method.
        rpcs = self._channel_client.rpcs

        self.assertIs(rpcs.pw.test2.Alpha.Unary,
                      rpcs['pw.test2.Alpha']['Unary'])
        self.assertIs(rpcs.pw.test2.Alpha.Unary,
                      rpcs['pw.test2.Alpha'][pw_rpc.ids.calculate('Unary')])

    def test_service_name(self) -> None:
        service = self._channel_client.rpcs.pw.test2.Alpha.Unary.service

        self.assertEqual(service.name, 'Alpha')
        self.assertEqual(service.full_name, 'pw.test2.Alpha')

    def test_method_name(self) -> None:
        method = self._channel_client.rpcs.pw.test2.Alpha.Unary.method

        self.assertEqual(method.name, 'Unary')
        self.assertEqual(method.full_name, 'pw.test2.Alpha.Unary')

    def test_iterate_over_all_methods(self) -> None:
        rpcs = self._channel_client.rpcs
        public_service = rpcs.pw.test1.PublicService

        # Every RPC declared in TEST_PROTO_1 and TEST_PROTO_2.
        expected = {
            public_service.SomeUnary,
            public_service.SomeServerStreaming,
            public_service.SomeClientStreaming,
            public_service.SomeBidiStreaming,
            rpcs.pw.test2.Alpha.Unary,
            rpcs.pw.test2.Bravo.BidiStreaming,
        }
        self.assertEqual(set(self._channel_client.methods()), expected)

    def test_check_for_presence_of_services(self) -> None:
        rpcs = self._channel_client.rpcs
        service_name = 'pw.test1.PublicService'

        self.assertIn(service_name, rpcs)
        self.assertIn(pw_rpc.ids.calculate(service_name), rpcs)

    def test_check_for_presence_of_missing_services(self) -> None:
        rpcs = self._channel_client.rpcs

        # An unqualified name, an unknown name, and an arbitrary integer.
        for absent in ('PublicService', 'NotAService', -1213):
            self.assertNotIn(absent, rpcs)

    def test_check_for_presence_of_methods(self) -> None:
        public_service = self._channel_client.rpcs.pw.test1.PublicService

        for present in ('SomeUnary', pw_rpc.ids.calculate('SomeUnary')):
            self.assertIn(present, public_service)

    def test_check_for_presence_of_missing_methods(self) -> None:
        public_service = self._channel_client.rpcs.pw.test1.PublicService

        # Fragments of a real method name and an arbitrary integer.
        for absent in ('Some', 'Unary', 12345):
            self.assertNotIn(absent, public_service)

    def test_method_fully_qualified_name(self) -> None:
        # Both the '/' and the '.' separator before the method name work.
        for full_name in ('pw.test2.Alpha/Unary', 'pw.test2.Alpha.Unary'):
            self.assertIs(self._channel_client.method(full_name),
                          self._channel_client.rpcs.pw.test2.Alpha.Unary)
160
161
class ClientTest(unittest.TestCase):
    """Exercises the pw_rpc Client on its own, apart from the ClientImpl."""

    def setUp(self) -> None:
        self._last_packet_sent_bytes: Optional[bytes] = None
        # _save_packet is handed to _test_setup() as channel 1's output.
        self._protos, self._client = _test_setup(self._save_packet)

    def _save_packet(self, packet) -> None:
        # Remember only the most recent value passed in.
        self._last_packet_sent_bytes = packet

    def _last_packet_sent(self) -> RpcPacket:
        """Returns the most recently saved packet parsed as an RpcPacket."""
        assert self._last_packet_sent_bytes is not None
        decoded = RpcPacket()
        decoded.MergeFromString(self._last_packet_sent_bytes)
        return decoded

    def _process_response(self, channel_id, service_id, method_id):
        """Encodes a response for the given IDs and hands it to the client.

        The payload is always an empty pw.test2.Request message. Returns
        whatever the client's process_packet() returns.
        """
        payload = self._protos.packages.pw.test2.Request()
        encoded = packets.encode_response(
            (channel_id, service_id, method_id), payload)
        return self._client.process_packet(encoded)

    def _assert_client_error_sent(self, service_id, method_id,
                                  status) -> None:
        """Asserts the last saved packet is this CLIENT_ERROR on channel 1."""
        self.assertEqual(
            self._last_packet_sent(),
            RpcPacket(type=PacketType.CLIENT_ERROR,
                      channel_id=1,
                      service_id=service_id,
                      method_id=method_id,
                      status=status.value))

    def test_channel(self) -> None:
        for channel_id in (1, 2):
            self.assertEqual(
                self._client.channel(channel_id).channel.id, channel_id)

    def test_channel_default_is_first_listed(self) -> None:
        default_channel = self._client.channel()
        self.assertEqual(default_channel.channel.id, 1)

    def test_channel_invalid(self) -> None:
        with self.assertRaises(KeyError):
            self._client.channel(404)

    def test_all_methods(self) -> None:
        services = self._client.services
        public_methods = services['pw.test1.PublicService'].methods

        # Every RPC declared in TEST_PROTO_1 and TEST_PROTO_2.
        expected = {
            public_methods['SomeUnary'],
            public_methods['SomeServerStreaming'],
            public_methods['SomeClientStreaming'],
            public_methods['SomeBidiStreaming'],
            services['pw.test2.Alpha'].methods['Unary'],
            services['pw.test2.Bravo'].methods['BidiStreaming'],
        }
        self.assertEqual(set(self._client.methods()), expected)

    def test_method_present(self) -> None:
        # Both the '.' and the '/' separator before the method name work.
        for full_name in ('pw.test1.PublicService.SomeUnary',
                          'pw.test1.PublicService/SomeUnary'):
            self.assertIs(
                self._client.method(full_name),
                self._client.services['pw.test1.PublicService'].
                methods['SomeUnary'])

    def test_method_invalid_format(self) -> None:
        with self.assertRaises(ValueError):
            self._client.method('SomeUnary')

    def test_method_not_present(self) -> None:
        # First an unknown method on a known service, then an unknown service.
        for full_name in ('pw.test1.PublicService/ThisIsNotGood',
                          'nothing.Good'):
            with self.assertRaises(KeyError):
                self._client.method(full_name)

    def test_process_packet_invalid_proto_data(self) -> None:
        result = self._client.process_packet(b'NOT a packet!')
        self.assertIs(result, Status.DATA_LOSS)

    def test_process_packet_not_for_client(self) -> None:
        request = RpcPacket(type=PacketType.REQUEST).SerializeToString()
        self.assertIs(self._client.process_packet(request),
                      Status.INVALID_ARGUMENT)

    def test_process_packet_unrecognized_channel(self) -> None:
        # _test_setup() creates only channels 1 and 2, so 123 is unknown.
        self.assertIs(self._process_response(123, 456, 789),
                      Status.NOT_FOUND)

    def test_process_packet_unrecognized_service(self) -> None:
        self.assertIs(self._process_response(1, 456, 789), Status.OK)

        self._assert_client_error_sent(456, 789, Status.NOT_FOUND)

    def test_process_packet_unrecognized_method(self) -> None:
        service = next(iter(self._client.services))

        self.assertIs(self._process_response(1, service.id, 789), Status.OK)

        self._assert_client_error_sent(service.id, 789, Status.NOT_FOUND)

    def test_process_packet_non_pending_method(self) -> None:
        service = next(iter(self._client.services))
        method = next(iter(service.methods))

        # The method exists, but this test never invoked it.
        self.assertIs(self._process_response(1, service.id, method.id),
                      Status.OK)

        self._assert_client_error_sent(service.id, method.id,
                                       Status.FAILED_PRECONDITION)
287
288
# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
291