#!/usr/bin/env python3
# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests for creating and using the pw_rpc Python client."""

import unittest
from typing import Optional

from pw_protobuf_compiler import python_protos
from pw_status import Status

from pw_rpc import callback_client, client, packets
import pw_rpc.ids
from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket

TEST_PROTO_1 = """\
syntax = "proto3";

package pw.test1;

message SomeMessage {
  uint32 magic_number = 1;
}

message AnotherMessage {
  enum Result {
    FAILED = 0;
    FAILED_MISERABLY = 1;
    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
  }

  Result result = 1;
  string payload = 2;
}

service PublicService {
  rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
  rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
  rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
  rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""

TEST_PROTO_2 = """\
syntax = "proto2";

package pw.test2;

message Request {
  optional float magic_number = 1;
}

message Response {
}

service Alpha {
  rpc Unary(Request) returns (Response) {}
}

service Bravo {
  rpc BidiStreaming(stream Request) returns (stream Response) {}
}
"""


def _test_setup(output=None):
    """Builds the test proto library and a client that serves it.

    Args:
      output: Output callback for channel 1. Channel 2 always discards
          whatever it is given.

    Returns:
      A (protos, client) tuple. protos is the python_protos.Library compiled
      from TEST_PROTO_1 and TEST_PROTO_2. client is a client.Client that uses
      callback_client.Impl() over channels 1 and 2.
    """
    protos = python_protos.Library.from_strings([TEST_PROTO_1, TEST_PROTO_2])
    return protos, client.Client.from_modules(
        callback_client.Impl(),
        [client.Channel(1, output), client.Channel(2, lambda _: None)],
        protos.modules(),
    )


class ChannelClientTest(unittest.TestCase):
    """Tests the ChannelClient."""

    def setUp(self) -> None:
        self._channel_client = _test_setup()[1].channel(1)

    def test_access_service_client_as_attribute_or_index(self) -> None:
        # A service can be reached by attribute, by full name, or by ID.
        self.assertIs(
            self._channel_client.rpcs.pw.test1.PublicService,
            self._channel_client.rpcs['pw.test1.PublicService'],
        )
        self.assertIs(
            self._channel_client.rpcs.pw.test1.PublicService,
            self._channel_client.rpcs[
                pw_rpc.ids.calculate('pw.test1.PublicService')
            ],
        )

    def test_access_method_client_as_attribute_or_index(self) -> None:
        # A method can be reached by attribute, by name, or by ID.
        self.assertIs(
            self._channel_client.rpcs.pw.test2.Alpha.Unary,
            self._channel_client.rpcs['pw.test2.Alpha']['Unary'],
        )
        self.assertIs(
            self._channel_client.rpcs.pw.test2.Alpha.Unary,
            self._channel_client.rpcs['pw.test2.Alpha'][
                pw_rpc.ids.calculate('Unary')
            ],
        )

    def test_service_name(self) -> None:
        self.assertEqual(
            self._channel_client.rpcs.pw.test2.Alpha.Unary.service.name, 'Alpha'
        )
        self.assertEqual(
            self._channel_client.rpcs.pw.test2.Alpha.Unary.service.full_name,
            'pw.test2.Alpha',
        )

    def test_method_name(self) -> None:
        self.assertEqual(
            self._channel_client.rpcs.pw.test2.Alpha.Unary.method.name, 'Unary'
        )
        self.assertEqual(
            self._channel_client.rpcs.pw.test2.Alpha.Unary.method.full_name,
            'pw.test2.Alpha.Unary',
        )

    def test_iterate_over_all_methods(self) -> None:
        channel_client = self._channel_client
        all_methods = {
            channel_client.rpcs.pw.test1.PublicService.SomeUnary,
            channel_client.rpcs.pw.test1.PublicService.SomeServerStreaming,
            channel_client.rpcs.pw.test1.PublicService.SomeClientStreaming,
            channel_client.rpcs.pw.test1.PublicService.SomeBidiStreaming,
            channel_client.rpcs.pw.test2.Alpha.Unary,
            channel_client.rpcs.pw.test2.Bravo.BidiStreaming,
        }
        self.assertEqual(set(channel_client.methods()), all_methods)

    def test_check_for_presence_of_services(self) -> None:
        self.assertIn('pw.test1.PublicService', self._channel_client.rpcs)
        self.assertIn(
            pw_rpc.ids.calculate('pw.test1.PublicService'),
            self._channel_client.rpcs,
        )

    def test_check_for_presence_of_missing_services(self) -> None:
        # The unqualified service name is not accepted as a key.
        self.assertNotIn('PublicService', self._channel_client.rpcs)
        self.assertNotIn('NotAService', self._channel_client.rpcs)
        self.assertNotIn(-1213, self._channel_client.rpcs)

    def test_check_for_presence_of_methods(self) -> None:
        service = self._channel_client.rpcs.pw.test1.PublicService
        self.assertIn('SomeUnary', service)
        self.assertIn(pw_rpc.ids.calculate('SomeUnary'), service)

    def test_check_for_presence_of_missing_methods(self) -> None:
        # Partial method names (prefixes or suffixes) must not match.
        service = self._channel_client.rpcs.pw.test1.PublicService
        self.assertNotIn('Some', service)
        self.assertNotIn('Unary', service)
        self.assertNotIn(12345, service)

    def test_method_fully_qualified_name(self) -> None:
        # Both '/' and '.' are accepted between the service and method names.
        self.assertIs(
            self._channel_client.method('pw.test2.Alpha/Unary'),
            self._channel_client.rpcs.pw.test2.Alpha.Unary,
        )
        self.assertIs(
            self._channel_client.method('pw.test2.Alpha.Unary'),
            self._channel_client.rpcs.pw.test2.Alpha.Unary,
        )


class ClientTest(unittest.TestCase):
    """Tests the pw_rpc Client independently of the ClientImpl."""

    def setUp(self) -> None:
        self._last_packet_sent_bytes: Optional[bytes] = None
        self._protos, self._client = _test_setup(self._save_packet)

    def _save_packet(self, packet: bytes) -> None:
        """Output for channel 1: remembers the most recently sent packet."""
        self._last_packet_sent_bytes = packet

    def _last_packet_sent(self) -> RpcPacket:
        """Decodes the last packet the client sent on channel 1."""
        packet = RpcPacket()
        assert self._last_packet_sent_bytes is not None
        packet.MergeFromString(self._last_packet_sent_bytes)
        return packet

    def test_channel(self) -> None:
        self.assertEqual(self._client.channel(1).channel.id, 1)
        self.assertEqual(self._client.channel(2).channel.id, 2)

    def test_channel_default_is_first_listed(self) -> None:
        self.assertEqual(self._client.channel().channel.id, 1)

    def test_channel_invalid(self) -> None:
        with self.assertRaises(KeyError):
            self._client.channel(404)

    def test_all_methods(self) -> None:
        services = self._client.services

        all_methods = {
            services['pw.test1.PublicService'].methods['SomeUnary'],
            services['pw.test1.PublicService'].methods['SomeServerStreaming'],
            services['pw.test1.PublicService'].methods['SomeClientStreaming'],
            services['pw.test1.PublicService'].methods['SomeBidiStreaming'],
            services['pw.test2.Alpha'].methods['Unary'],
            services['pw.test2.Bravo'].methods['BidiStreaming'],
        }
        self.assertEqual(set(self._client.methods()), all_methods)

    def test_method_present(self) -> None:
        self.assertIs(
            self._client.method('pw.test1.PublicService.SomeUnary'),
            self._client.services['pw.test1.PublicService'].methods[
                'SomeUnary'
            ],
        )
        self.assertIs(
            self._client.method('pw.test1.PublicService/SomeUnary'),
            self._client.services['pw.test1.PublicService'].methods[
                'SomeUnary'
            ],
        )

    def test_method_invalid_format(self) -> None:
        # A bare method name without a service is a format error.
        with self.assertRaises(ValueError):
            self._client.method('SomeUnary')

    def test_method_not_present(self) -> None:
        with self.assertRaises(KeyError):
            self._client.method('pw.test1.PublicService/ThisIsNotGood')

        with self.assertRaises(KeyError):
            self._client.method('nothing.Good')

    def test_process_packet_invalid_proto_data(self) -> None:
        self.assertIs(
            self._client.process_packet(b'NOT a packet!'), Status.DATA_LOSS
        )

    def test_process_packet_not_for_client(self) -> None:
        # A REQUEST packet is not addressed to a client and is rejected.
        self.assertIs(
            self._client.process_packet(
                RpcPacket(type=PacketType.REQUEST).SerializeToString()
            ),
            Status.INVALID_ARGUMENT,
        )

    def test_process_packet_unrecognized_channel(self) -> None:
        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    (123, 456, 789), self._protos.packages.pw.test2.Request()
                )
            ),
            Status.NOT_FOUND,
        )

    def test_process_packet_unrecognized_service(self) -> None:
        # The packet is accepted (OK), but the client replies on the same
        # channel with a CLIENT_ERROR that carries NOT_FOUND.
        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    (1, 456, 789), self._protos.packages.pw.test2.Request()
                )
            ),
            Status.OK,
        )

        self.assertEqual(
            self._last_packet_sent(),
            RpcPacket(
                type=PacketType.CLIENT_ERROR,
                channel_id=1,
                service_id=456,
                method_id=789,
                status=Status.NOT_FOUND.value,
            ),
        )

    def test_process_packet_unrecognized_method(self) -> None:
        service = next(iter(self._client.services))

        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    (1, service.id, 789),
                    self._protos.packages.pw.test2.Request(),
                )
            ),
            Status.OK,
        )

        self.assertEqual(
            self._last_packet_sent(),
            RpcPacket(
                type=PacketType.CLIENT_ERROR,
                channel_id=1,
                service_id=service.id,
                method_id=789,
                status=Status.NOT_FOUND.value,
            ),
        )

    def test_process_packet_non_pending_method(self) -> None:
        # The method exists but has no pending call, so the client replies
        # with a CLIENT_ERROR that carries FAILED_PRECONDITION.
        service = next(iter(self._client.services))
        method = next(iter(service.methods))

        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    (1, service.id, method.id),
                    self._protos.packages.pw.test2.Request(),
                )
            ),
            Status.OK,
        )

        self.assertEqual(
            self._last_packet_sent(),
            RpcPacket(
                type=PacketType.CLIENT_ERROR,
                channel_id=1,
                service_id=service.id,
                method_id=method.id,
                status=Status.FAILED_PRECONDITION.value,
            ),
        )

    def test_process_packet_non_pending_calls_response_callback(self) -> None:
        method = self._client.method('pw.test1.PublicService.SomeUnary')
        reply = method.response_type(payload='hello')

        # Record the callback's arguments and check them afterwards instead
        # of asserting inside the callback. Otherwise the test would pass if
        # the callback were never invoked, or if the client swallowed an
        # AssertionError raised from within the callback.
        calls: list = []

        def response_callback(
            rpc: client.PendingRpc, message, status: Optional[Status]
        ) -> None:
            calls.append((rpc, message, status))

        self._client.response_callback = response_callback

        self.assertIs(
            self._client.process_packet(
                packets.encode_response((1, method.service, method), reply)
            ),
            Status.OK,
        )

        self.assertEqual(
            len(calls), 1, 'response_callback was not invoked exactly once'
        )
        rpc, message, status = calls[0]
        self.assertEqual(
            rpc,
            client.PendingRpc(
                self._client.channel(1).channel, method.service, method
            ),
        )
        self.assertEqual(message, reply)
        self.assertIs(status, Status.OK)


if __name__ == '__main__':
    unittest.main()