#!/usr/bin/env python3
# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests creating pw_rpc client."""

import unittest
from typing import Optional

from pw_protobuf_compiler import python_protos
from pw_status import Status

from pw_rpc import callback_client, client, packets
import pw_rpc.ids
from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket

TEST_PROTO_1 = """\
syntax = "proto3";

package pw.test1;

message SomeMessage {
  uint32 magic_number = 1;
}

message AnotherMessage {
  enum Result {
    FAILED = 0;
    FAILED_MISERABLY = 1;
    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
  }

  Result result = 1;
  string payload = 2;
}

service PublicService {
  rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
  rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
  rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
  rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""

TEST_PROTO_2 = """\
syntax = "proto2";

package pw.test2;

message Request {
  optional float magic_number = 1;
}

message Response {
}

service Alpha {
  rpc Unary(Request) returns (Response) {}
}

service Bravo {
  rpc BidiStreaming(stream Request) returns (stream Response) {}
}
"""


def _test_setup(output=None):
    """Compiles the test protos and builds a client with two channels.

    Channel 1 is given ``output`` as its output function; channel 2 gets a
    function that ignores whatever it is handed.

    Returns:
        A ``(protos, client)`` tuple: the compiled proto library and the
        ``client.Client`` built from its modules.
    """
    protos = python_protos.Library.from_strings([TEST_PROTO_1, TEST_PROTO_2])
    impl = callback_client.Impl()
    channels = [
        client.Channel(1, output),
        client.Channel(2, lambda _: None),
    ]
    rpc_client = client.Client.from_modules(impl, channels, protos.modules())
    return protos, rpc_client


class ChannelClientTest(unittest.TestCase):
    """Tests the ChannelClient."""
    def setUp(self) -> None:
        _, rpc_client = _test_setup()
        self._channel_client = rpc_client.channel(1)

    def test_access_service_client_as_attribute_or_index(self) -> None:
        """A service is reachable by attribute, full name, or numeric ID."""
        rpcs = self._channel_client.rpcs
        by_attribute = rpcs.pw.test1.PublicService
        self.assertIs(by_attribute, rpcs['pw.test1.PublicService'])
        self.assertIs(by_attribute,
                      rpcs[pw_rpc.ids.calculate('pw.test1.PublicService')])

    def test_access_method_client_as_attribute_or_index(self) -> None:
        """A method is reachable by attribute, name, or numeric ID."""
        rpcs = self._channel_client.rpcs
        unary = rpcs.pw.test2.Alpha.Unary
        alpha = rpcs['pw.test2.Alpha']
        self.assertIs(unary, alpha['Unary'])
        self.assertIs(unary, alpha[pw_rpc.ids.calculate('Unary')])

    def test_service_name(self) -> None:
        """The service exposes both its short and fully qualified names."""
        service = self._channel_client.rpcs.pw.test2.Alpha.Unary.service
        self.assertEqual(service.name, 'Alpha')
        self.assertEqual(service.full_name, 'pw.test2.Alpha')

    def test_method_name(self) -> None:
        """The method exposes both its short and fully qualified names."""
        method = self._channel_client.rpcs.pw.test2.Alpha.Unary.method
        self.assertEqual(method.name, 'Unary')
        self.assertEqual(method.full_name, 'pw.test2.Alpha.Unary')

    def test_iterate_over_all_methods(self) -> None:
        """methods() yields exactly the method clients of every service."""
        rpcs = self._channel_client.rpcs
        public = rpcs.pw.test1.PublicService
        expected = {
            public.SomeUnary,
            public.SomeServerStreaming,
            public.SomeClientStreaming,
            public.SomeBidiStreaming,
            rpcs.pw.test2.Alpha.Unary,
            rpcs.pw.test2.Bravo.BidiStreaming,
        }
        self.assertEqual(set(self._channel_client.methods()), expected)

    def test_check_for_presence_of_services(self) -> None:
        """`in` accepts a service's full name or its numeric ID."""
        rpcs = self._channel_client.rpcs
        service_name = 'pw.test1.PublicService'
        for key in (service_name, pw_rpc.ids.calculate(service_name)):
            self.assertIn(key, rpcs)

    def test_check_for_presence_of_missing_services(self) -> None:
        """Short names, unknown names and unknown IDs are not present."""
        for key in ('PublicService', 'NotAService', -1213):
            self.assertNotIn(key, self._channel_client.rpcs)

    def test_check_for_presence_of_methods(self) -> None:
        """`in` on a service accepts a method's name or its numeric ID."""
        service = self._channel_client.rpcs.pw.test1.PublicService
        for key in ('SomeUnary', pw_rpc.ids.calculate('SomeUnary')):
            self.assertIn(key, service)

    def test_check_for_presence_of_missing_methods(self) -> None:
        """Partial names and unknown IDs do not match any method."""
        service = self._channel_client.rpcs.pw.test1.PublicService
        for key in ('Some', 'Unary', 12345):
            self.assertNotIn(key, service)

    def test_method_fully_qualified_name(self) -> None:
        """method() accepts both '/' and '.' as the method separator."""
        unary = self._channel_client.rpcs.pw.test2.Alpha.Unary
        for name in ('pw.test2.Alpha/Unary', 'pw.test2.Alpha.Unary'):
            self.assertIs(self._channel_client.method(name), unary)


class ClientTest(unittest.TestCase):
    """Tests the pw_rpc Client independently of the ClientImpl."""
    def setUp(self) -> None:
        self._last_packet_sent_bytes: Optional[bytes] = None
        self._protos, self._client = _test_setup(self._save_packet)

    def _save_packet(self, packet) -> None:
        """Output function for channel 1; remembers the latest packet."""
        self._last_packet_sent_bytes = packet

    def _last_packet_sent(self) -> RpcPacket:
        """Parses the bytes most recently captured by _save_packet."""
        assert self._last_packet_sent_bytes is not None
        decoded = RpcPacket()
        decoded.MergeFromString(self._last_packet_sent_bytes)
        return decoded

    def _process_response(self, ids) -> Status:
        """Feeds the client a response packet addressed to ``ids``.

        ``ids`` is the (channel_id, service_id, method_id) triple passed to
        packets.encode_response; the payload is an empty pw.test2.Request.
        """
        payload = self._protos.packages.pw.test2.Request()
        return self._client.process_packet(
            packets.encode_response(ids, payload))

    def _assert_client_error_sent(self, service_id: int, method_id: int,
                                  status: Status) -> None:
        """Checks that the last sent packet is a channel-1 CLIENT_ERROR."""
        self.assertEqual(
            self._last_packet_sent(),
            RpcPacket(type=PacketType.CLIENT_ERROR,
                      channel_id=1,
                      service_id=service_id,
                      method_id=method_id,
                      status=status.value))

    def test_channel(self) -> None:
        """channel(id) returns the client for the requested channel."""
        for channel_id in (1, 2):
            self.assertEqual(self._client.channel(channel_id).channel.id,
                             channel_id)

    def test_channel_default_is_first_listed(self) -> None:
        """Without an argument, channel() picks the first channel given."""
        default_channel = self._client.channel()
        self.assertEqual(default_channel.channel.id, 1)

    def test_channel_invalid(self) -> None:
        """An unknown channel ID raises KeyError."""
        self.assertRaises(KeyError, self._client.channel, 404)

    def test_all_methods(self) -> None:
        """methods() yields every method of every registered service."""
        services = self._client.services
        public = services['pw.test1.PublicService'].methods
        expected = {
            public['SomeUnary'],
            public['SomeServerStreaming'],
            public['SomeClientStreaming'],
            public['SomeBidiStreaming'],
            services['pw.test2.Alpha'].methods['Unary'],
            services['pw.test2.Bravo'].methods['BidiStreaming'],
        }
        self.assertEqual(set(self._client.methods()), expected)

    def test_method_present(self) -> None:
        """method() resolves '.'- and '/'-separated full names alike."""
        services = self._client.services
        expected = services['pw.test1.PublicService'].methods['SomeUnary']
        for name in ('pw.test1.PublicService.SomeUnary',
                     'pw.test1.PublicService/SomeUnary'):
            self.assertIs(self._client.method(name), expected)

    def test_method_invalid_format(self) -> None:
        """A bare method name without a service raises ValueError."""
        self.assertRaises(ValueError, self._client.method, 'SomeUnary')

    def test_method_not_present(self) -> None:
        """Unknown methods and unknown services both raise KeyError."""
        for name in ('pw.test1.PublicService/ThisIsNotGood', 'nothing.Good'):
            with self.assertRaises(KeyError):
                self._client.method(name)

    def test_process_packet_invalid_proto_data(self) -> None:
        """Bytes that do not decode as a packet yield DATA_LOSS."""
        status = self._client.process_packet(b'NOT a packet!')
        self.assertIs(status, Status.DATA_LOSS)

    def test_process_packet_not_for_client(self) -> None:
        """A REQUEST packet is rejected with INVALID_ARGUMENT."""
        request = RpcPacket(type=PacketType.REQUEST).SerializeToString()
        self.assertIs(self._client.process_packet(request),
                      Status.INVALID_ARGUMENT)

    def test_process_packet_unrecognized_channel(self) -> None:
        """A packet for an unknown channel yields NOT_FOUND."""
        self.assertIs(self._process_response((123, 456, 789)),
                      Status.NOT_FOUND)

    def test_process_packet_unrecognized_service(self) -> None:
        """An unknown service is reported back as a CLIENT_ERROR."""
        self.assertIs(self._process_response((1, 456, 789)), Status.OK)
        self._assert_client_error_sent(456, 789, Status.NOT_FOUND)

    def test_process_packet_unrecognized_method(self) -> None:
        """An unknown method on a known service yields a CLIENT_ERROR."""
        service = next(iter(self._client.services))

        self.assertIs(self._process_response((1, service.id, 789)),
                      Status.OK)
        self._assert_client_error_sent(service.id, 789, Status.NOT_FOUND)

    def test_process_packet_non_pending_method(self) -> None:
        """A response for a method with no pending call is rejected."""
        service = next(iter(self._client.services))
        method = next(iter(service.methods))

        self.assertIs(self._process_response((1, service.id, method.id)),
                      Status.OK)
        self._assert_client_error_sent(service.id, method.id,
                                       Status.FAILED_PRECONDITION)


if __name__ == '__main__':
    unittest.main()