# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Provides a pw_rpc client for Python."""

import abc
from dataclasses import dataclass
import logging
from typing import (
    Any,
    Callable,
    Collection,
    Dict,
    Iterable,
    Iterator,
    NamedTuple,
    Optional,
)

from google.protobuf.message import DecodeError, Message
from pw_status import Status

from pw_rpc import descriptors, packets
from pw_rpc.descriptors import Channel, Service, Method
from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket

_LOG = logging.getLogger(__package__)


class Error(Exception):
    """Error from incorrectly using the RPC client classes."""


class PendingRpc(NamedTuple):
    """Uniquely identifies an RPC call."""

    channel: Channel
    service: Service
    method: Method

    def __str__(self) -> str:
        return f'PendingRpc(channel={self.channel.id}, method={self.method})'


class _PendingRpcMetadata:
    """Holds the caller-supplied context for one pending RPC.

    A fresh wrapper is created for every open() call, so open() can tell by
    identity whether its own entry was inserted, even when callers reuse the
    same context object.
    """

    def __init__(self, context: object):
        self.context = context


class PendingRpcs:
    """Tracks pending RPCs and encodes outgoing RPC packets."""

    def __init__(self):
        self._pending: Dict[PendingRpc, _PendingRpcMetadata] = {}

    def request(
        self,
        rpc: PendingRpc,
        request: Optional[Message],
        context: object,
        override_pending: bool = True,
    ) -> bytes:
        """Starts the provided RPC and returns the encoded packet to send.

        Unlike send_request() and open(), override_pending defaults to True
        here, so an already-pending RPC is replaced rather than rejected.
        """
        self.open(rpc, context, override_pending)
        return packets.encode_request(rpc, request)

    def send_request(
        self,
        rpc: PendingRpc,
        request: Optional[Message],
        context: object,
        *,
        ignore_errors: bool = False,
        override_pending: bool = False,
    ) -> Any:
        """Starts the provided RPC and sends the request packet to the channel.

        Args:
            rpc: the RPC to start
            request: the request message, or None
            context: arbitrary object associated with the pending RPC
            ignore_errors: if True, exceptions raised by the channel's output
                function are logged at debug level and suppressed
            override_pending: if True, replace an already-pending RPC instead
                of raising Error

        Returns:
            the previous context object or None
        """
        previous = self.open(rpc, context, override_pending)
        packet = packets.encode_request(rpc, request)

        # TODO(hepler): Remove `type: ignore[misc]` below when
        #     https://github.com/python/mypy/issues/10711 is fixed.
        if ignore_errors:
            try:
                rpc.channel.output(packet)  # type: ignore[misc]
            except Exception as err:  # pylint: disable=broad-except
                _LOG.debug('Ignoring exception when starting RPC: %s', err)
        else:
            rpc.channel.output(packet)  # type: ignore[misc]

        return previous

    def open(
        self, rpc: PendingRpc, context: object, override_pending: bool = False
    ) -> Any:
        """Creates a context for an RPC, but does not invoke it.

        open() can be used to receive streaming responses to an RPC that was not
        invoked by this client. For example, a server may stream logs with a
        server streaming RPC prior to any clients invoking it.

        Returns:
            the previous context object or None

        Raises:
            Error: override_pending is False and the RPC is already pending
        """
        _LOG.debug('Starting %s', rpc)
        metadata = _PendingRpcMetadata(context)

        if override_pending:
            previous = self._pending.get(rpc)
            self._pending[rpc] = metadata
            return None if previous is None else previous.context

        # setdefault() returns the existing entry if there is one. Since
        # metadata is a brand-new object, an identity mismatch means another
        # entry was already present.
        if self._pending.setdefault(rpc, metadata) is not metadata:
            # If the context was not added, the RPC was already pending.
            raise Error(
                f'Sent request for {rpc}, but it is already pending! '
                'Cancel the RPC before invoking it again'
            )

        return None

    def send_client_stream(self, rpc: PendingRpc, message: Message) -> None:
        """Sends a client stream message for a pending RPC.

        Raises:
            Error: the RPC is not pending
        """
        if rpc not in self._pending:
            raise Error(f'Attempt to send client stream for inactive RPC {rpc}')

        rpc.channel.output(  # type: ignore
            packets.encode_client_stream(rpc, message)
        )

    def send_client_stream_end(self, rpc: PendingRpc) -> None:
        """Sends the client stream end packet for a pending RPC.

        Raises:
            Error: the RPC is not pending
        """
        if rpc not in self._pending:
            raise Error(
                f'Attempt to send client stream end for inactive RPC {rpc}'
            )

        rpc.channel.output(  # type: ignore
            packets.encode_client_stream_end(rpc)
        )

    def cancel(self, rpc: PendingRpc) -> bytes:
        """Cancels the RPC.

        Returns:
            The CLIENT_ERROR packet to send.

        Raises:
            KeyError if the RPC is not pending
        """
        _LOG.debug('Cancelling %s', rpc)
        del self._pending[rpc]

        return packets.encode_cancel(rpc)

    def send_cancel(self, rpc: PendingRpc) -> bool:
        """Calls cancel and sends the cancel packet, if any, to the channel.

        Returns:
            True if the RPC was pending and has been cancelled, False otherwise
        """
        try:
            packet = self.cancel(rpc)
        except KeyError:
            return False

        if packet:
            rpc.channel.output(packet)  # type: ignore

        return True

    def get_pending(self, rpc: PendingRpc, status: Optional[Status]) -> Any:
        """Gets the pending RPC's context. If status is set, clears the RPC.

        Raises:
            KeyError: the RPC is not pending
        """
        if status is None:
            return self._pending[rpc].context

        _LOG.debug('%s finished with status %s', rpc, status)
        return self._pending.pop(rpc).context


class ClientImpl(abc.ABC):
    """The internal interface of the RPC client.

    This interface defines the semantics for invoking an RPC on a particular
    client.
    """

    def __init__(self):
        # Both are assigned by Client.__init__ when this impl is attached.
        self.client: Optional['Client'] = None
        self.rpcs: Optional[PendingRpcs] = None

    @abc.abstractmethod
    def method_client(self, channel: Channel, method: Method) -> Any:
        """Returns an object that invokes a method using the given channel."""

    @abc.abstractmethod
    def handle_response(
        self,
        rpc: PendingRpc,
        context: Any,
        payload: Any,
        *,
        args: tuple = (),
        kwargs: Optional[dict] = None,
    ) -> Any:
        """Handles a response from the RPC server.

        Args:
            rpc: Information about the pending RPC
            context: Arbitrary context object associated with the pending RPC
            payload: A protobuf message
            args, kwargs: Arbitrary arguments passed to the ClientImpl
        """

    @abc.abstractmethod
    def handle_completion(
        self,
        rpc: PendingRpc,
        context: Any,
        status: Status,
        *,
        args: tuple = (),
        kwargs: Optional[dict] = None,
    ) -> Any:
        """Handles the successful completion of an RPC.

        Args:
            rpc: Information about the pending RPC
            context: Arbitrary context object associated with the pending RPC
            status: Status returned from the RPC
            args, kwargs: Arbitrary arguments passed to the ClientImpl
        """

    @abc.abstractmethod
    def handle_error(
        self,
        rpc: PendingRpc,
        context: Any,
        status: Status,
        *,
        args: tuple = (),
        kwargs: Optional[dict] = None,
    ) -> Any:
        """Handles the abnormal termination of an RPC.

        Args:
            rpc: Information about the pending RPC
            context: Arbitrary context object associated with the pending RPC
            status: which error occurred
            args, kwargs: Arbitrary arguments passed to the ClientImpl
        """


class ServiceClient(descriptors.ServiceAccessor):
    """Navigates the methods in a service provided by a ChannelClient."""

    def __init__(
        self, client_impl: ClientImpl, channel: Channel, service: Service
    ):
        super().__init__(
            {
                method: client_impl.method_client(channel, method)
                for method in service.methods
            },
            as_attrs='members',
        )

        self._channel = channel
        self._service = service

    def __repr__(self) -> str:
        return (
            f'Service({self._service.full_name!r}, '
            f'methods={[m.name for m in self._service.methods]}, '
            f'channel={self._channel.id})'
        )

    def __str__(self) -> str:
        return str(self._service)


class Services(descriptors.ServiceAccessor[ServiceClient]):
    """Navigates the services provided by a ChannelClient."""

    def __init__(
        self,
        client_impl: ClientImpl,
        channel: Channel,
        services: Collection[Service],
    ):
        super().__init__(
            {s: ServiceClient(client_impl, channel, s) for s in services},
            as_attrs='packages',
        )

        self._channel = channel
        self._services = services

    def __repr__(self) -> str:
        return (
            f'Services(channel={self._channel.id}, '
            f'services={[s.full_name for s in self._services]})'
        )


def _decode_status(rpc: PendingRpc, packet) -> Optional[Status]:
    """Returns the packet's status, or None for SERVER_STREAM packets.

    Status codes that are not valid Status values are logged and reported as
    Status.UNKNOWN.
    """
    if packet.type == PacketType.SERVER_STREAM:
        return None

    try:
        return Status(packet.status)
    except ValueError:
        _LOG.warning('Illegal status code %d for %s', packet.status, rpc)
        return Status.UNKNOWN


def _decode_payload(rpc: PendingRpc, packet) -> Optional[Message]:
    """Decodes the packet's payload, or returns None if it carries none.

    Raises:
        DecodeError: the payload could not be decoded as the response type
    """
    if packet.type == PacketType.SERVER_ERROR:
        return None

    # Server streaming RPCs do not send a payload with their RESPONSE packet.
    if packet.type == PacketType.RESPONSE and rpc.method.server_streaming:
        return None

    return packets.decode_payload(packet, rpc.method.response_type)


@dataclass(frozen=True, eq=False)
class ChannelClient:
    """RPC services and methods bound to a particular channel.

    RPCs are invoked through service method clients. These may be accessed via
    the `rpcs` member. Service methods use a fully qualified name: package,
    service, method. Service methods may be selected as attributes or by
    indexing the rpcs member by service and method name or ID.

      # Access the service method client as an attribute
      rpc = client.channel(1).rpcs.the.package.FooService.SomeMethod

      # Access the service method client by indexing (service ID, method name)
      rpc = client.channel(1).rpcs[foo_service_id]['SomeMethod']

    RPCs may also be accessed from their canonical name.

      # Access the service method client from its full name:
      rpc = client.channel(1).method('the.package.FooService/SomeMethod')

      # Using a . instead of a / is also supported:
      rpc = client.channel(1).method('the.package.FooService.SomeMethod')

    The ClientImpl class determines the type of the service method client. A
    synchronous RPC client might return a callable object, so an RPC could be
    invoked directly (e.g. rpc(field1=123, field2=b'456')).
    """

    client: 'Client'
    channel: Channel
    rpcs: Services

    def method(self, method_name: str):
        """Returns a method client matching the given name.

        Args:
            method_name: name as package.Service/Method or package.Service.Method.

        Raises:
            ValueError: the method name is not properly formatted
            KeyError: the method is not present
        """
        return descriptors.get_method(self.rpcs, method_name)

    def services(self) -> Iterator:
        """Iterates over the ServiceClients in this ChannelClient."""
        return iter(self.rpcs)

    def methods(self) -> Iterator:
        """Iterates over all method clients in this ChannelClient."""
        for service_client in self.rpcs:
            yield from service_client

    def __repr__(self) -> str:
        return (
            f'ChannelClient(channel={self.channel.id}, '
            f'services={[str(s) for s in self.services()]})'
        )


def _update_for_backwards_compatibility(
    rpc: PendingRpc, packet: RpcPacket
) -> None:
    """Adapts server streaming RPC packets to the updated protocol if needed."""
    # The protocol changes only affect server streaming RPCs.
    if rpc.method.type is not Method.Type.SERVER_STREAMING:
        return

    # Prior to the introduction of SERVER_STREAM packets, RESPONSE packets with
    # a payload were used instead. If a non-zero payload is present, assume this
    # RESPONSE is equivalent to a SERVER_STREAM packet.
    #
    # Note that the payload field is not 'optional', so an empty payload is
    # equivalent to a payload that happens to encode to zero bytes. This would
    # only affect server streaming RPCs on the old protocol that intentionally
    # send empty payloads, which will not be an issue in practice.
    if packet.type == PacketType.RESPONSE and packet.payload:
        packet.type = PacketType.SERVER_STREAM


class Client:
    """Sends requests and handles responses for a set of channels.

    RPC invocations occur through a ChannelClient.

    Users may set an optional response_callback that is called before processing
    every response or server stream RPC packet.
    """

    @classmethod
    def from_modules(
        cls, impl: ClientImpl, channels: Iterable[Channel], modules: Iterable
    ):
        """Creates a Client for every service declared in the given modules.

        Each module must expose a DESCRIPTOR with a services_by_name mapping.
        """
        return cls(
            impl,
            channels,
            (
                Service.from_descriptor(service)
                for module in modules
                for service in module.DESCRIPTOR.services_by_name.values()
            ),
        )

    def __init__(
        self,
        impl: ClientImpl,
        channels: Iterable[Channel],
        services: Iterable[Service],
    ):
        self._impl = impl
        self._impl.client = self
        self._impl.rpcs = PendingRpcs()

        self.services = descriptors.Services(services)

        self._channels_by_id = {
            channel.id: ChannelClient(
                self, channel, Services(self._impl, channel, self.services)
            )
            for channel in channels
        }

        # Optional function called before processing every non-error RPC packet.
        self.response_callback: Optional[
            Callable[[PendingRpc, Any, Optional[Status]], Any]
        ] = None

    def channel(self, channel_id: Optional[int] = None) -> ChannelClient:
        """Returns a ChannelClient, which is used to call RPCs on a channel.

        If no channel is provided, the first channel is used.

        Raises:
            KeyError: channel_id is not a channel known to this client
        """
        if channel_id is None:
            return next(iter(self._channels_by_id.values()))

        return self._channels_by_id[channel_id]

    def channels(self) -> Iterable[ChannelClient]:
        """Accesses the ChannelClients in this client."""
        return self._channels_by_id.values()

    def method(self, method_name: str) -> Method:
        """Returns a Method matching the given name.

        Args:
            method_name: name as package.Service/Method or package.Service.Method.

        Raises:
            ValueError: the method name is not properly formatted
            KeyError: the method is not present
        """
        return descriptors.get_method(self.services, method_name)

    def methods(self) -> Iterator[Method]:
        """Iterates over all Methods supported by this client."""
        for service in self.services:
            yield from service.methods

    def process_packet(
        self, pw_rpc_raw_packet_data: bytes, *impl_args, **impl_kwargs
    ) -> Status:
        """Processes an incoming packet.

        Args:
            pw_rpc_raw_packet_data: raw binary data for exactly one RPC packet
            impl_args: optional positional arguments passed to the ClientImpl
            impl_kwargs: optional keyword arguments passed to the ClientImpl

        Returns:
            OK - the packet was processed by this client
            DATA_LOSS - the packet could not be decoded
            INVALID_ARGUMENT - the packet is for a server, not a client
            NOT_FOUND - the packet's channel ID is not known to this client
        """
        try:
            packet = packets.decode(pw_rpc_raw_packet_data)
        except DecodeError as err:
            _LOG.warning('Failed to decode packet: %s', err)
            _LOG.debug('Raw packet: %r', pw_rpc_raw_packet_data)
            return Status.DATA_LOSS

        if packets.for_server(packet):
            return Status.INVALID_ARGUMENT

        try:
            channel_client = self._channels_by_id[packet.channel_id]
        except KeyError:
            _LOG.warning('Unrecognized channel ID %d', packet.channel_id)
            return Status.NOT_FOUND

        try:
            rpc = self._look_up_service_and_method(packet, channel_client)
        except ValueError as err:
            _send_client_error(channel_client, packet, Status.NOT_FOUND)
            _LOG.warning('%s', err)
            return Status.OK

        _update_for_backwards_compatibility(rpc, packet)

        if packet.type not in (
            PacketType.RESPONSE,
            PacketType.SERVER_STREAM,
            PacketType.SERVER_ERROR,
        ):
            _LOG.error('%s: unexpected PacketType %s', rpc, packet.type)
            _LOG.debug('Packet:\n%s', packet)
            return Status.OK

        status = _decode_status(rpc, packet)

        payload: Optional[Message]
        try:
            payload = _decode_payload(rpc, packet)
        except DecodeError as err:
            _send_client_error(channel_client, packet, Status.DATA_LOSS)
            _LOG.warning(
                'Failed to decode %s response for %s: %s',
                rpc.method.response_type.DESCRIPTOR.full_name,
                rpc.method.full_name,
                err,
            )
            _LOG.debug('Raw payload: %s', packet.payload)

            # Make this an error packet so the error handler is called. No
            # payload is delivered for a packet that failed to decode.
            packet.type = PacketType.SERVER_ERROR
            status = Status.DATA_LOSS
            payload = None

        # If set, call the response callback with non-error packets.
        if self.response_callback and packet.type != PacketType.SERVER_ERROR:
            self.response_callback(  # pylint: disable=not-callable
                rpc, payload, status
            )

        try:
            # rpcs is Optional on ClientImpl but always set in __init__; the
            # assert narrows the type for static checkers.
            assert self._impl.rpcs
            # A non-None status ends the RPC, so this also removes it.
            context = self._impl.rpcs.get_pending(rpc, status)
        except KeyError:
            _send_client_error(
                channel_client, packet, Status.FAILED_PRECONDITION
            )
            _LOG.debug('Discarding response for %s, which is not pending', rpc)
            return Status.OK

        if packet.type == PacketType.SERVER_ERROR:
            # The status comes from the remote peer, so a SERVER_ERROR that
            # carries OK (e.g. an absent status field) must not crash the
            # client or leave the already-removed RPC without a callback.
            if status is None or status.ok():
                _LOG.warning(
                    '%s: SERVER_ERROR packet has non-error status %s; '
                    'reporting UNKNOWN',
                    rpc,
                    status,
                )
                status = Status.UNKNOWN

            _LOG.warning('%s: invocation failed with %s', rpc, status)
            self._impl.handle_error(
                rpc, context, status, args=impl_args, kwargs=impl_kwargs
            )
            return Status.OK

        if payload is not None:
            self._impl.handle_response(
                rpc, context, payload, args=impl_args, kwargs=impl_kwargs
            )
        if status is not None:
            self._impl.handle_completion(
                rpc, context, status, args=impl_args, kwargs=impl_kwargs
            )

        return Status.OK

    def _look_up_service_and_method(
        self, packet: RpcPacket, channel_client: ChannelClient
    ) -> PendingRpc:
        """Resolves the packet's service and method IDs to a PendingRpc.

        Raises:
            ValueError: the service or method ID is not known to this client
        """
        # Protobuf is sometimes silly so the 32 bit python bindings return
        # signed values from `fixed32` fields. Let's convert back to unsigned.
        # b/239712573
        service_id = packet.service_id & 0xFFFFFFFF
        try:
            service = self.services[service_id]
        except KeyError:
            # The KeyError adds nothing beyond this message.
            raise ValueError(f'Unrecognized service ID {service_id}') from None

        # See above, also for b/239712573
        method_id = packet.method_id & 0xFFFFFFFF
        try:
            method = service.methods[method_id]
        except KeyError:
            raise ValueError(
                f'No method ID {method_id} in service {service.name}'
            ) from None

        return PendingRpc(channel_client.channel, service, method)

    def __repr__(self) -> str:
        return (
            f'pw_rpc.Client(channels={list(self._channels_by_id)}, '
            f'services={[s.full_name for s in self.services]})'
        )


def _send_client_error(
    client: ChannelClient, packet: RpcPacket, error: Status
) -> None:
    """Sends a CLIENT_ERROR with the given status in reply to packet."""
    # Never send responses to SERVER_ERRORs.
    if packet.type != PacketType.SERVER_ERROR:
        client.channel.output(  # type: ignore
            packets.encode_client_error(packet, error)
        )