# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Provides a pw_rpc client for Python."""

from __future__ import annotations

import abc
from dataclasses import dataclass
import logging
from typing import (
    Any,
    Callable,
    Collection,
    Iterable,
    Iterator,
)

from google.protobuf.message import DecodeError, Message
from pw_status import Status

from pw_rpc import descriptors, packets
from pw_rpc.descriptors import Channel, Service, Method
from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket

_LOG = logging.getLogger(__package__)

# Incoming packets carrying one of these call IDs were not requested by this
# client (e.g. a server streaming to a call it opened itself). They are matched
# to the first pending call with the same channel, service, and method; see
# PendingRpcs.get_pending(). Because 0 is LEGACY_OPEN_CALL_ID,
# PendingRpcs.allocate_call_id() never hands out call ID 0.
LEGACY_OPEN_CALL_ID: int = 0
OPEN_CALL_ID: int = (2**32) - 1

# Call IDs allocated by PendingRpcs.allocate_call_id() wrap around below this.
_MAX_CALL_ID: int = 1 << 14


class Error(Exception):
    """Error from incorrectly using the RPC client classes."""


class PendingRpc(packets.RpcIds):
    """Uniquely identifies an RPC call.

    Attributes:
      channel: Channel
      service: Service
      method: Method
      channel_id: int
      service_id: int
      method_id: int
      call_id: int
    """

    def __init__(
        self,
        channel: Channel,
        service: Service,
        method: Method,
        call_id: int,
    ) -> None:
        super().__init__(channel.id, service.id, method.id, call_id)
        self.channel = channel
        self.service = service
        self.method = method


class _PendingRpcMetadata:
    """Holds the context object associated with a pending RPC.

    open() stores a freshly created wrapper for every call and detects an
    already-pending RPC by checking whether dict.setdefault() returned that
    exact wrapper. Wrapping guarantees a new identity even if a caller passes
    the same context object for two calls.
    """

    def __init__(self, context: object):
        self.context = context


class PendingRpcs:
    """Tracks pending RPCs and encodes outgoing RPC packets."""

    def __init__(self) -> None:
        self._pending: dict[PendingRpc, _PendingRpcMetadata] = {}
        # We skip call_id = 0 in order to avoid LEGACY_OPEN_CALL_ID.
        self._next_call_id: int = 1

    def allocate_call_id(self) -> int:
        """Returns the next call ID, cycling through 1 .. _MAX_CALL_ID - 1."""
        call_id = self._next_call_id
        self._next_call_id = (self._next_call_id + 1) % _MAX_CALL_ID
        # We skip call_id = 0 in order to avoid LEGACY_OPEN_CALL_ID.
        if self._next_call_id == 0:
            self._next_call_id = 1
        return call_id

    def request(
        self,
        rpc: PendingRpc,
        request: Message | None,
        context: object,
        override_pending: bool = True,
    ) -> bytes:
        """Starts the provided RPC and returns the encoded packet to send.

        Unlike send_request(), this does not write to the channel, and by
        default it replaces any call already pending for the same RPC.
        """
        # open() wraps the context in a fresh _PendingRpcMetadata, so each
        # call gets a distinct entry even if the context object is reused.
        self.open(rpc, context, override_pending)
        return packets.encode_request(rpc, request)

    def send_request(
        self,
        rpc: PendingRpc,
        request: Message | None,
        context: object,
        *,
        ignore_errors: bool = False,
        override_pending: bool = False,
    ) -> Any:
        """Starts the provided RPC and sends the request packet to the channel.

        Args:
          rpc: the RPC to start
          request: request message to encode, or None
          context: arbitrary object associated with the call
          ignore_errors: if True, exceptions raised by the channel output are
              logged at debug level and swallowed
          override_pending: if True, replace an already-pending call instead of
              raising Error

        Returns:
          the previous context object or None
        """
        previous = self.open(rpc, context, override_pending)
        packet = packets.encode_request(rpc, request)

        # TODO(hepler): Remove `type: ignore[misc]` below when
        # https://github.com/python/mypy/issues/10711 is fixed.
        if ignore_errors:
            try:
                rpc.channel.output(packet)  # type: ignore[misc]
            except Exception as err:  # pylint: disable=broad-except
                _LOG.debug('Ignoring exception when starting RPC: %s', err)
        else:
            rpc.channel.output(packet)  # type: ignore[misc]

        return previous

    def open(
        self, rpc: PendingRpc, context: object, override_pending: bool = False
    ) -> Any:
        """Creates a context for an RPC, but does not invoke it.

        open() can be used to receive streaming responses to an RPC that was not
        invoked by this client. For example, a server may stream logs with a
        server streaming RPC prior to any clients invoking it.

        Returns:
          the previous context object or None

        Raises:
          Error: the RPC is already pending and override_pending is False
        """
        _LOG.debug('Starting %s', rpc)
        metadata = _PendingRpcMetadata(context)

        if override_pending:
            previous = self._pending.get(rpc)
            self._pending[rpc] = metadata
            return None if previous is None else previous.context

        if self._pending.setdefault(rpc, metadata) is not metadata:
            # If the context was not added, the RPC was already pending.
            raise Error(
                f'Sent request for {rpc}, but it is already pending! '
                'Cancel the RPC before invoking it again'
            )

        return None

    def send_client_stream(self, rpc: PendingRpc, message: Message) -> None:
        """Sends a client stream message for a pending RPC.

        Raises:
          Error: the RPC is not pending
        """
        if rpc not in self._pending:
            raise Error(f'Attempt to send client stream for inactive RPC {rpc}')

        rpc.channel.output(  # type: ignore
            packets.encode_client_stream(rpc, message)
        )

    def send_client_stream_end(self, rpc: PendingRpc) -> None:
        """Sends a client stream end packet for a pending RPC.

        Raises:
          Error: the RPC is not pending
        """
        if rpc not in self._pending:
            raise Error(
                f'Attempt to send client stream end for inactive RPC {rpc}'
            )

        rpc.channel.output(  # type: ignore
            packets.encode_client_stream_end(rpc)
        )

    def cancel(self, rpc: PendingRpc) -> bytes:
        """Cancels the RPC.

        Returns:
          The CLIENT_ERROR packet to send.

        Raises:
          KeyError if the RPC is not pending
        """
        _LOG.debug('Cancelling %s', rpc)
        del self._pending[rpc]

        return packets.encode_cancel(rpc)

    def send_cancel(self, rpc: PendingRpc) -> bool:
        """Calls cancel and sends the cancel packet, if any, to the channel.

        Returns:
          True if the RPC was pending and was cancelled, False otherwise.
        """
        try:
            packet = self.cancel(rpc)
        except KeyError:
            return False

        if packet:
            rpc.channel.output(packet)  # type: ignore

        return True

    def get_pending(self, rpc: PendingRpc, status: Status | None):
        """Gets the pending RPC's context. If status is set, clears the RPC.

        A packet carrying OPEN_CALL_ID or LEGACY_OPEN_CALL_ID is matched to the
        first pending call (in insertion order) with the same channel, service,
        and method.

        Raises:
          KeyError if no matching RPC is pending
        """
        if rpc.call_id in (OPEN_CALL_ID, LEGACY_OPEN_CALL_ID):
            # Calls with ID `OPEN_CALL_ID` were unrequested, and are updated to
            # have the call ID of the first matching request.
            for pending in self._pending:
                if (
                    pending.channel == rpc.channel
                    and pending.service == rpc.service
                    and pending.method == rpc.method
                ):
                    rpc = pending
                    # Stop at the first match; without this, the last
                    # matching call would be chosen instead.
                    break

        if status is None:
            return self._pending[rpc].context

        _LOG.debug('%s finished with status %s', rpc, status)
        return self._pending.pop(rpc).context


class ClientImpl(abc.ABC):
    """The internal interface of the RPC client.

    This interface defines the semantics for invoking an RPC on a particular
    client.
    """

    def __init__(self) -> None:
        # Both are assigned by Client.__init__.
        self.client: Client | None = None
        self.rpcs: PendingRpcs | None = None

    @abc.abstractmethod
    def method_client(self, channel: Channel, method: Method) -> Any:
        """Returns an object that invokes a method using the given channel."""

    @abc.abstractmethod
    def handle_response(
        self,
        rpc: PendingRpc,
        context: Any,
        payload: Any,
        *,
        args: tuple = (),
        kwargs: dict | None = None,
    ) -> Any:
        """Handles a response from the RPC server.

        Args:
          rpc: Information about the pending RPC
          context: Arbitrary context object associated with the pending RPC
          payload: A protobuf message
          args, kwargs: Arbitrary arguments passed to the ClientImpl
        """

    @abc.abstractmethod
    def handle_completion(
        self,
        rpc: PendingRpc,
        context: Any,
        status: Status,
        *,
        args: tuple = (),
        kwargs: dict | None = None,
    ) -> Any:
        """Handles the successful completion of an RPC.

        Args:
          rpc: Information about the pending RPC
          context: Arbitrary context object associated with the pending RPC
          status: Status returned from the RPC
          args, kwargs: Arbitrary arguments passed to the ClientImpl
        """

    @abc.abstractmethod
    def handle_error(
        self,
        rpc: PendingRpc,
        context: Any,
        status: Status,
        *,
        args: tuple = (),
        kwargs: dict | None = None,
    ) -> Any:
        """Handles the abnormal termination of an RPC.

        Args:
          rpc: Information about the pending RPC
          context: Arbitrary context object associated with the pending RPC
          status: which error occurred
          args, kwargs: Arbitrary arguments passed to the ClientImpl
        """


class ServiceClient(descriptors.ServiceAccessor):
    """Navigates the methods in a service provided by a ChannelClient."""

    def __init__(
        self, client_impl: ClientImpl, channel: Channel, service: Service
    ) -> None:
        super().__init__(
            {
                method: client_impl.method_client(channel, method)
                for method in service.methods
            },
            as_attrs='members',
        )

        self._channel = channel
        self._service = service

    def __repr__(self) -> str:
        return (
            f'Service({self._service.full_name!r}, '
            f'methods={[m.name for m in self._service.methods]}, '
            f'channel={self._channel.id})'
        )

    def __str__(self) -> str:
        return str(self._service)


class Services(descriptors.ServiceAccessor[ServiceClient]):
    """Navigates the services provided by a ChannelClient."""

    def __init__(
        self,
        client_impl: ClientImpl,
        channel: Channel,
        services: Collection[Service],
    ) -> None:
        super().__init__(
            {s: ServiceClient(client_impl, channel, s) for s in services},
            as_attrs='packages',
        )

        self._channel = channel
        self._services = services

    def __repr__(self) -> str:
        return (
            f'Services(channel={self._channel.id}, '
            f'services={[s.full_name for s in self._services]})'
        )


def _decode_status(rpc: PendingRpc, packet) -> Status | None:
    """Returns the packet's status, or None for SERVER_STREAM packets.

    Unknown status codes are logged and mapped to Status.UNKNOWN.
    """
    if packet.type == PacketType.SERVER_STREAM:
        return None

    try:
        return Status(packet.status)
    except ValueError:
        _LOG.warning('Illegal status code %d for %s', packet.status, rpc)
        return Status.UNKNOWN


def _decode_payload(rpc: PendingRpc, packet) -> Message | None:
    """Decodes the packet payload as the method's response type, if present.

    Raises:
      DecodeError: the payload could not be decoded
    """
    if packet.type == PacketType.SERVER_ERROR:
        return None

    # Server streaming RPCs do not send a payload with their RESPONSE packet.
    if packet.type == PacketType.RESPONSE and rpc.method.server_streaming:
        return None

    return packets.decode_payload(packet, rpc.method.response_type)


@dataclass(frozen=True, eq=False)
class ChannelClient:
    """RPC services and methods bound to a particular channel.

    RPCs are invoked through service method clients. These may be accessed via
    the `rpcs` member. Service methods use a fully qualified name: package,
    service, method. Service methods may be selected as attributes or by
    indexing the rpcs member by service and method name or ID.

      # Access the service method client as an attribute
      rpc = client.channel(1).rpcs.the.package.FooService.SomeMethod

      # Access the service method client by service ID and method name
      rpc = client.channel(1).rpcs[foo_service_id]['SomeMethod']

    RPCs may also be accessed from their canonical name.

      # Access the service method client from its full name:
      rpc = client.channel(1).method('the.package.FooService/SomeMethod')

      # Using a . instead of a / is also supported:
      rpc = client.channel(1).method('the.package.FooService.SomeMethod')

    The ClientImpl class determines the type of the service method client. A
    synchronous RPC client might return a callable object, so an RPC could be
    invoked directly (e.g. rpc(field1=123, field2=b'456')).
    """

    client: Client
    channel: Channel
    rpcs: Services

    def method(self, method_name: str):
        """Returns a method client matching the given name.

        Args:
          method_name: name as package.Service/Method or package.Service.Method.

        Raises:
          ValueError: the method name is not properly formatted
          KeyError: the method is not present
        """
        return descriptors.get_method(self.rpcs, method_name)

    def services(self) -> Iterator:
        """Iterates over the ServiceClients in this ChannelClient."""
        return iter(self.rpcs)

    def methods(self) -> Iterator:
        """Iterates over all method clients in this ChannelClient."""
        for service_client in self.rpcs:
            yield from service_client

    def __repr__(self) -> str:
        return (
            f'ChannelClient(channel={self.channel.id}, '
            f'services={[str(s) for s in self.services()]})'
        )


def _update_for_backwards_compatibility(
    rpc: PendingRpc, packet: RpcPacket
) -> None:
    """Adapts server streaming RPC packets to the updated protocol if needed."""
    # The protocol changes only affect server streaming RPCs.
    if rpc.method.type is not Method.Type.SERVER_STREAMING:
        return

    # Prior to the introduction of SERVER_STREAM packets, RESPONSE packets with
    # a payload were used instead. If a non-zero payload is present, assume this
    # RESPONSE is equivalent to a SERVER_STREAM packet.
    #
    # Note that the payload field is not 'optional', so an empty payload is
    # equivalent to a payload that happens to encode to zero bytes. This would
    # only affect server streaming RPCs on the old protocol that intentionally
    # send empty payloads, which will not be an issue in practice.
    if packet.type == PacketType.RESPONSE and packet.payload:
        packet.type = PacketType.SERVER_STREAM


class Client:
    """Sends requests and handles responses for a set of channels.

    RPC invocations occur through a ChannelClient.

    Users may set an optional response_callback that is called before processing
    every response or server stream RPC packet.
    """

    @classmethod
    def from_modules(
        cls, impl: ClientImpl, channels: Iterable[Channel], modules: Iterable
    ):
        """Creates a Client from every service in the given protobuf modules.

        Each module must expose a DESCRIPTOR with services_by_name.
        """
        return cls(
            impl,
            channels,
            (
                Service.from_descriptor(service)
                for module in modules
                for service in module.DESCRIPTOR.services_by_name.values()
            ),
        )

    def __init__(
        self,
        impl: ClientImpl,
        channels: Iterable[Channel],
        services: Iterable[Service],
    ):
        self._impl = impl
        self._impl.client = self
        self._impl.rpcs = PendingRpcs()

        self.services = descriptors.Services(services)

        self._channels_by_id = {
            channel.id: ChannelClient(
                self, channel, Services(self._impl, channel, self.services)
            )
            for channel in channels
        }

        # Optional function called before processing every non-error RPC packet.
        self.response_callback: (
            Callable[[PendingRpc, Any, Status | None], Any] | None
        ) = None

    def channel(self, channel_id: int | None = None) -> ChannelClient:
        """Returns a ChannelClient, which is used to call RPCs on a channel.

        If no channel is provided, the first channel is used.

        Raises:
          KeyError: channel_id is not a channel of this client
        """
        if channel_id is None:
            return next(iter(self._channels_by_id.values()))

        return self._channels_by_id[channel_id]

    def channels(self) -> Iterable[ChannelClient]:
        """Accesses the ChannelClients in this client."""
        return self._channels_by_id.values()

    def method(self, method_name: str) -> Method:
        """Returns a Method matching the given name.

        Args:
          method_name: name as package.Service/Method or package.Service.Method.

        Raises:
          ValueError: the method name is not properly formatted
          KeyError: the method is not present
        """
        return descriptors.get_method(self.services, method_name)

    def methods(self) -> Iterator[Method]:
        """Iterates over all Methods supported by this client."""
        for service in self.services:
            yield from service.methods

    def process_packet(
        self, pw_rpc_raw_packet_data: bytes, *impl_args, **impl_kwargs
    ) -> Status:
        """Processes an incoming packet.

        Args:
          pw_rpc_raw_packet_data: raw binary data for exactly one RPC packet
          impl_args: optional positional arguments passed to the ClientImpl
          impl_kwargs: optional keyword arguments passed to the ClientImpl

        Returns:
          OK - the packet was processed by this client
          DATA_LOSS - the packet could not be decoded
          INVALID_ARGUMENT - the packet is for a server, not a client
          NOT_FOUND - the packet's channel ID is not known to this client
        """
        try:
            packet = packets.decode(pw_rpc_raw_packet_data)
        except DecodeError as err:
            _LOG.warning('Failed to decode packet: %s', err)
            _LOG.debug('Raw packet: %r', pw_rpc_raw_packet_data)
            return Status.DATA_LOSS

        if packets.for_server(packet):
            return Status.INVALID_ARGUMENT

        try:
            channel_client = self._channels_by_id[packet.channel_id]
        except KeyError:
            _LOG.warning('Unrecognized channel ID %d', packet.channel_id)
            return Status.NOT_FOUND

        try:
            rpc = self._look_up_service_and_method(packet, channel_client)
        except ValueError as err:
            # The packet was for this client's channel, so it counts as
            # handled; the server is told the service/method is unknown.
            _send_client_error(channel_client, packet, Status.NOT_FOUND)
            _LOG.warning('%s', err)
            return Status.OK

        _update_for_backwards_compatibility(rpc, packet)

        if packet.type not in (
            PacketType.RESPONSE,
            PacketType.SERVER_STREAM,
            PacketType.SERVER_ERROR,
        ):
            _LOG.error('%s: unexpected PacketType %s', rpc, packet.type)
            _LOG.debug('Packet:\n%s', packet)
            return Status.OK

        status = _decode_status(rpc, packet)

        payload: Message | None
        try:
            payload = _decode_payload(rpc, packet)
        except DecodeError as err:
            _send_client_error(channel_client, packet, Status.DATA_LOSS)
            _LOG.warning(
                'Failed to decode %s response for %s: %s',
                rpc.method.response_type.DESCRIPTOR.full_name,
                rpc.method.full_name,
                err,
            )
            _LOG.debug('Raw payload: %s', packet.payload)

            # Keep `payload` bound so no later code path can hit a NameError.
            payload = None

            # Make this an error packet so the error handler is called.
            packet.type = PacketType.SERVER_ERROR
            status = Status.DATA_LOSS

        # If set, call the response callback with non-error packets.
        if self.response_callback and packet.type != PacketType.SERVER_ERROR:
            self.response_callback(  # pylint: disable=not-callable
                rpc, payload, status
            )

        try:
            # Always set in __init__; the assert narrows the type for mypy.
            assert self._impl.rpcs
            context = self._impl.rpcs.get_pending(rpc, status)
        except KeyError:
            _send_client_error(
                channel_client, packet, Status.FAILED_PRECONDITION
            )
            _LOG.debug('Discarding response for %s, which is not pending', rpc)
            return Status.OK

        if packet.type == PacketType.SERVER_ERROR:
            assert status is not None and not status.ok()
            _LOG.warning('%s: invocation failed with %s', rpc, status)
            self._impl.handle_error(
                rpc, context, status, args=impl_args, kwargs=impl_kwargs
            )
            return Status.OK

        if payload is not None:
            self._impl.handle_response(
                rpc, context, payload, args=impl_args, kwargs=impl_kwargs
            )
        if status is not None:
            self._impl.handle_completion(
                rpc, context, status, args=impl_args, kwargs=impl_kwargs
            )

        return Status.OK

    def _look_up_service_and_method(
        self, packet: RpcPacket, channel_client: ChannelClient
    ) -> PendingRpc:
        """Builds a PendingRpc for the packet's service and method IDs.

        Raises:
          ValueError: the service or method ID is not known to this client
        """
        # Protobuf is sometimes silly so the 32 bit python bindings return
        # signed values from `fixed32` fields. Let's convert back to unsigned.
        # b/239712573
        service_id = packet.service_id & 0xFFFFFFFF
        try:
            service = self.services[service_id]
        except KeyError:
            raise ValueError(f'Unrecognized service ID {service_id}') from None

        # See above, also for b/239712573
        method_id = packet.method_id & 0xFFFFFFFF
        try:
            method = service.methods[method_id]
        except KeyError:
            raise ValueError(
                f'No method ID {method_id} in service {service.name}'
            ) from None

        return PendingRpc(
            channel_client.channel, service, method, packet.call_id
        )

    def __repr__(self) -> str:
        return (
            f'pw_rpc.Client(channels={list(self._channels_by_id)}, '
            f'services={[s.full_name for s in self.services]})'
        )


def _send_client_error(
    client: ChannelClient, packet: RpcPacket, error: Status
) -> None:
    """Sends a CLIENT_ERROR for the packet unless it is a SERVER_ERROR."""
    # Never send responses to SERVER_ERRORs.
    if packet.type != PacketType.SERVER_ERROR:
        client.channel.output(  # type: ignore
            packets.encode_client_error(packet, error)
        )