1#!/usr/bin/env python3 2# Copyright 2020 The Pigweed Authors 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); you may not 5# use this file except in compliance with the License. You may obtain a copy of 6# the License at 7# 8# https://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 13# License for the specific language governing permissions and limitations under 14# the License. 15"""Tests creating pw_rpc client.""" 16 17import unittest 18from typing import Any, Callable 19 20from pw_protobuf_compiler import python_protos 21from pw_status import Status 22 23from pw_rpc import callback_client, client, packets 24import pw_rpc.ids 25from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket 26 27RpcIds = packets.RpcIds 28 29TEST_PROTO_1 = """\ 30syntax = "proto3"; 31 32package pw.test1; 33 34message SomeMessage { 35 uint32 magic_number = 1; 36} 37 38message AnotherMessage { 39 enum Result { 40 FAILED = 0; 41 FAILED_MISERABLY = 1; 42 I_DONT_WANT_TO_TALK_ABOUT_IT = 2; 43 } 44 45 Result result = 1; 46 string payload = 2; 47} 48 49service PublicService { 50 rpc SomeUnary(SomeMessage) returns (AnotherMessage) {} 51 rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {} 52 rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {} 53 rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {} 54} 55""" 56 57TEST_PROTO_2 = """\ 58syntax = "proto2"; 59 60package pw.test2; 61 62message Request { 63 optional float magic_number = 1; 64} 65 66message Response { 67} 68 69service Alpha { 70 rpc Unary(Request) returns (Response) {} 71} 72 73service Bravo { 74 rpc BidiStreaming(stream Request) returns (stream Response) {} 75} 76""" 77 78SOME_CHANNEL_ID: int = 237 79SOME_SERVICE_ID: int = 193 
SOME_METHOD_ID: int = 769
SOME_CALL_ID: int = 452

# Channel IDs registered by create_client().
CLIENT_FIRST_CHANNEL_ID: int = 557
CLIENT_SECOND_CHANNEL_ID: int = 474


def create_protos() -> Any:
    """Compiles TEST_PROTO_1 and TEST_PROTO_2 into a python_protos Library."""
    return python_protos.Library.from_strings([TEST_PROTO_1, TEST_PROTO_2])


def create_client(
    proto_modules: Any,
    first_channel_output_fn: Callable[[bytes], Any] | None = None,
) -> client.Client:
    """Creates a callback_client-backed Client with two channels.

    Args:
        proto_modules: Compiled proto modules whose services the client
            exposes.
        first_channel_output_fn: Output function for the channel with ID
            CLIENT_FIRST_CHANNEL_ID; passed unchanged to client.Channel.
            The second channel (CLIENT_SECOND_CHANNEL_ID) discards output.

    Returns:
        The constructed client.Client.
    """
    return client.Client.from_modules(
        callback_client.Impl(),
        [
            client.Channel(CLIENT_FIRST_CHANNEL_ID, first_channel_output_fn),
            client.Channel(CLIENT_SECOND_CHANNEL_ID, lambda _: None),
        ],
        proto_modules,
    )


class ChannelClientTest(unittest.TestCase):
    """Tests the ChannelClient."""

    def setUp(self) -> None:
        client_instance = create_client(create_protos().modules())
        self._channel_client: client.ChannelClient = client_instance.channel(
            CLIENT_FIRST_CHANNEL_ID
        )

    def test_access_service_client_as_attribute_or_index(self) -> None:
        self.assertIs(
            self._channel_client.rpcs.pw.test1.PublicService,
            self._channel_client.rpcs['pw.test1.PublicService'],
        )
        self.assertIs(
            self._channel_client.rpcs.pw.test1.PublicService,
            self._channel_client.rpcs[
                pw_rpc.ids.calculate('pw.test1.PublicService')
            ],
        )

    def test_access_method_client_as_attribute_or_index(self) -> None:
        self.assertIs(
            self._channel_client.rpcs.pw.test2.Alpha.Unary,
            self._channel_client.rpcs['pw.test2.Alpha']['Unary'],
        )
        self.assertIs(
            self._channel_client.rpcs.pw.test2.Alpha.Unary,
            self._channel_client.rpcs['pw.test2.Alpha'][
                pw_rpc.ids.calculate('Unary')
            ],
        )

    def test_service_name(self) -> None:
        self.assertEqual(
            self._channel_client.rpcs.pw.test2.Alpha.Unary.service.name, 'Alpha'
        )
        self.assertEqual(
            self._channel_client.rpcs.pw.test2.Alpha.Unary.service.full_name,
            'pw.test2.Alpha',
        )

    def test_method_name(self) -> None:
        self.assertEqual(
            self._channel_client.rpcs.pw.test2.Alpha.Unary.method.name, 'Unary'
        )
        self.assertEqual(
            self._channel_client.rpcs.pw.test2.Alpha.Unary.method.full_name,
            'pw.test2.Alpha.Unary',
        )

    def test_iterate_over_all_methods(self) -> None:
        channel_client = self._channel_client
        all_methods = {
            channel_client.rpcs.pw.test1.PublicService.SomeUnary,
            channel_client.rpcs.pw.test1.PublicService.SomeServerStreaming,
            channel_client.rpcs.pw.test1.PublicService.SomeClientStreaming,
            channel_client.rpcs.pw.test1.PublicService.SomeBidiStreaming,
            channel_client.rpcs.pw.test2.Alpha.Unary,
            channel_client.rpcs.pw.test2.Bravo.BidiStreaming,
        }
        self.assertEqual(set(channel_client.methods()), all_methods)

    def test_check_for_presence_of_services(self) -> None:
        self.assertIn('pw.test1.PublicService', self._channel_client.rpcs)
        self.assertIn(
            pw_rpc.ids.calculate('pw.test1.PublicService'),
            self._channel_client.rpcs,
        )

    def test_check_for_presence_of_missing_services(self) -> None:
        # Services are keyed by fully-qualified name, so the bare name misses.
        self.assertNotIn('PublicService', self._channel_client.rpcs)
        self.assertNotIn('NotAService', self._channel_client.rpcs)
        self.assertNotIn(-1213, self._channel_client.rpcs)

    def test_check_for_presence_of_methods(self) -> None:
        service = self._channel_client.rpcs.pw.test1.PublicService
        self.assertIn('SomeUnary', service)
        self.assertIn(pw_rpc.ids.calculate('SomeUnary'), service)

    def test_check_for_presence_of_missing_methods(self) -> None:
        # Partial method names must not match.
        service = self._channel_client.rpcs.pw.test1.PublicService
        self.assertNotIn('Some', service)
        self.assertNotIn('Unary', service)
        self.assertNotIn(12345, service)

    def test_method_fully_qualified_name(self) -> None:
        # Both '/' and '.' are accepted as the service/method separator.
        self.assertIs(
            self._channel_client.method('pw.test2.Alpha/Unary'),
            self._channel_client.rpcs.pw.test2.Alpha.Unary,
        )
        self.assertIs(
            self._channel_client.method('pw.test2.Alpha.Unary'),
            self._channel_client.rpcs.pw.test2.Alpha.Unary,
        )


class ClientTest(unittest.TestCase):
    """Tests the pw_rpc Client independently of the ClientImpl."""

    def setUp(self) -> None:
        self._last_packet_sent_bytes: bytes | None = None
        self._protos = create_protos()
        self._client = create_client(self._protos.modules(), self._save_packet)

    def _save_packet(self, packet: bytes) -> None:
        """Output function for the first channel: remembers the last packet."""
        self._last_packet_sent_bytes = packet

    def _last_packet_sent(self) -> RpcPacket:
        """Decodes the most recent packet written to the first channel.

        Fails the test if nothing has been sent. self.fail() is used instead
        of a bare assert so the check still runs under `python -O`.
        """
        if self._last_packet_sent_bytes is None:
            self.fail('Expected the client to send a packet, but none was sent')

        packet = RpcPacket()
        packet.MergeFromString(self._last_packet_sent_bytes)
        return packet

    def test_channel(self) -> None:
        self.assertEqual(
            self._client.channel(CLIENT_FIRST_CHANNEL_ID).channel.id,
            CLIENT_FIRST_CHANNEL_ID,
        )
        self.assertEqual(
            self._client.channel(CLIENT_SECOND_CHANNEL_ID).channel.id,
            CLIENT_SECOND_CHANNEL_ID,
        )

    def test_channel_default_is_first_listed(self) -> None:
        self.assertEqual(
            self._client.channel().channel.id, CLIENT_FIRST_CHANNEL_ID
        )

    def test_channel_invalid(self) -> None:
        with self.assertRaises(KeyError):
            self._client.channel(404)

    def test_all_methods(self) -> None:
        services = self._client.services

        all_methods = {
            services['pw.test1.PublicService'].methods['SomeUnary'],
            services['pw.test1.PublicService'].methods['SomeServerStreaming'],
            services['pw.test1.PublicService'].methods['SomeClientStreaming'],
            services['pw.test1.PublicService'].methods['SomeBidiStreaming'],
            services['pw.test2.Alpha'].methods['Unary'],
            services['pw.test2.Bravo'].methods['BidiStreaming'],
        }
        self.assertEqual(set(self._client.methods()), all_methods)

    def test_method_present(self) -> None:
        self.assertIs(
            self._client.method('pw.test1.PublicService.SomeUnary'),
            self._client.services['pw.test1.PublicService'].methods[
                'SomeUnary'
            ],
        )
        self.assertIs(
            self._client.method('pw.test1.PublicService/SomeUnary'),
            self._client.services['pw.test1.PublicService'].methods[
                'SomeUnary'
            ],
        )

    def test_method_invalid_format(self) -> None:
        with self.assertRaises(ValueError):
            self._client.method('SomeUnary')

    def test_method_not_present(self) -> None:
        with self.assertRaises(KeyError):
            self._client.method('pw.test1.PublicService/ThisIsNotGood')

        with self.assertRaises(KeyError):
            self._client.method('nothing.Good')

    def test_process_packet_invalid_proto_data(self) -> None:
        self.assertIs(
            self._client.process_packet(b'NOT a packet!'), Status.DATA_LOSS
        )

    def test_process_packet_not_for_client(self) -> None:
        # A REQUEST packet is addressed to a server, not to this client.
        self.assertIs(
            self._client.process_packet(
                RpcPacket(type=PacketType.REQUEST).SerializeToString()
            ),
            Status.INVALID_ARGUMENT,
        )

    def test_process_packet_unrecognized_channel(self) -> None:
        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    RpcIds(
                        SOME_CHANNEL_ID,
                        SOME_SERVICE_ID,
                        SOME_METHOD_ID,
                        SOME_CALL_ID,
                    ),
                    self._protos.packages.pw.test2.Request(),
                )
            ),
            Status.NOT_FOUND,
        )

    def test_process_packet_unrecognized_service(self) -> None:
        # A known channel with an unknown service is handled (OK) by replying
        # to the server with a NOT_FOUND client error.
        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    RpcIds(
                        CLIENT_FIRST_CHANNEL_ID,
                        SOME_SERVICE_ID,
                        SOME_METHOD_ID,
                        SOME_CALL_ID,
                    ),
                    self._protos.packages.pw.test2.Request(),
                )
            ),
            Status.OK,
        )

        self.assertEqual(
            self._last_packet_sent(),
            RpcPacket(
                type=PacketType.CLIENT_ERROR,
                channel_id=CLIENT_FIRST_CHANNEL_ID,
                service_id=SOME_SERVICE_ID,
                method_id=SOME_METHOD_ID,
                call_id=SOME_CALL_ID,
                status=Status.NOT_FOUND.value,
            ),
        )

    def test_process_packet_unrecognized_method(self) -> None:
        # A real service ID paired with a method ID it does not define.
        service = next(iter(self._client.services))

        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    RpcIds(
                        CLIENT_FIRST_CHANNEL_ID,
                        service.id,
                        SOME_METHOD_ID,
                        SOME_CALL_ID,
                    ),
                    self._protos.packages.pw.test2.Request(),
                )
            ),
            Status.OK,
        )

        self.assertEqual(
            self._last_packet_sent(),
            RpcPacket(
                type=PacketType.CLIENT_ERROR,
                channel_id=CLIENT_FIRST_CHANNEL_ID,
                service_id=service.id,
                method_id=SOME_METHOD_ID,
                call_id=SOME_CALL_ID,
                status=Status.NOT_FOUND.value,
            ),
        )

    def test_process_packet_non_pending_method(self) -> None:
        # Valid service and method, but no call was started, so the client
        # rejects the response with FAILED_PRECONDITION.
        service = next(iter(self._client.services))
        method = next(iter(service.methods))

        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    RpcIds(
                        CLIENT_FIRST_CHANNEL_ID,
                        service.id,
                        method.id,
                        SOME_CALL_ID,
                    ),
                    self._protos.packages.pw.test2.Request(),
                )
            ),
            Status.OK,
        )

        self.assertEqual(
            self._last_packet_sent(),
            RpcPacket(
                type=PacketType.CLIENT_ERROR,
                channel_id=CLIENT_FIRST_CHANNEL_ID,
                service_id=service.id,
                method_id=method.id,
                call_id=SOME_CALL_ID,
                status=Status.FAILED_PRECONDITION.value,
            ),
        )

    def test_process_packet_non_pending_calls_response_callback(self) -> None:
        method = self._client.method('pw.test1.PublicService.SomeUnary')
        reply = method.response_type(payload='hello')

        # Record invocations and assert on them afterwards. Asserting only
        # inside the callback would let this test pass vacuously if the
        # client never invoked the callback at all.
        calls: list[tuple[client.PendingRpc, Any, Status | None]] = []

        def response_callback(
            rpc: client.PendingRpc,
            message,
            status: Status | None,
        ) -> None:
            calls.append((rpc, message, status))

        self._client.response_callback = response_callback

        self.assertIs(
            self._client.process_packet(
                packets.encode_response(
                    RpcIds(
                        CLIENT_FIRST_CHANNEL_ID,
                        method.service.id,
                        method.id,
                        SOME_CALL_ID,
                    ),
                    reply,
                )
            ),
            Status.OK,
        )

        self.assertEqual(
            len(calls), 1, 'response_callback was not called exactly once'
        )
        rpc, message, status = calls[0]
        self.assertEqual(
            rpc,
            client.PendingRpc(
                self._client.channel(CLIENT_FIRST_CHANNEL_ID).channel,
                method.service,
                method,
                call_id=SOME_CALL_ID,
            ),
        )
        self.assertEqual(message, reply)
        self.assertIs(status, Status.OK)


if __name__ == '__main__':
    unittest.main()