# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Provides a pw_rpc client for Python."""

import abc
from dataclasses import dataclass
import logging
from typing import (Any, Collection, Dict, Iterable, Iterator, NamedTuple,
                    Optional)

from google.protobuf.message import DecodeError
from pw_status import Status

from pw_rpc import descriptors, packets
from pw_rpc.descriptors import Channel, Service, Method
from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket

_LOG = logging.getLogger(__name__)


class Error(Exception):
    """Error from incorrectly using the RPC client classes."""


class PendingRpc(NamedTuple):
    """Uniquely identifies an RPC call."""
    channel: Channel
    service: Service
    method: Method

    def __str__(self) -> str:
        return f'PendingRpc(channel={self.channel.id}, method={self.method})'


class _PendingRpcMetadata:
    """Bookkeeping stored by PendingRpcs for each pending RPC."""
    def __init__(self, context: Any, keep_open: bool):
        # Arbitrary object associated with the RPC by its creator.
        self.context = context
        # If True, get_pending() does not clear the RPC when a status arrives.
        self.keep_open = keep_open


class PendingRpcs:
    """Tracks pending RPCs and encodes outgoing RPC packets."""
    def __init__(self):
        self._pending: Dict[PendingRpc, _PendingRpcMetadata] = {}

    def request(self,
                rpc: PendingRpc,
                request,
                context,
                override_pending: bool = True,
                keep_open: bool = False) -> bytes:
        """Starts the provided RPC and returns the encoded packet to send.

        Args:
          rpc: the RPC to start
          request: request message, passed to packets.encode_request
          context: arbitrary object associated with the pending RPC
          override_pending: if True (the default for this method), replace an
              existing pending entry for the RPC instead of raising Error
          keep_open: if True, the RPC stays pending after a status arrives

        Raises:
          Error: override_pending is False and the RPC is already pending
        """
        # Register the context first; Client.process_packet looks it up with
        # get_pending() when packets for this RPC arrive.
        self.open(rpc, context, override_pending, keep_open)
        _LOG.debug('Starting %s', rpc)
        return packets.encode_request(rpc, request)

    def send_request(self,
                     rpc: PendingRpc,
                     request,
                     context,
                     override_pending: bool = False,
                     keep_open: bool = False) -> None:
        """Calls request and sends the resulting packet to the channel."""
        # TODO(hepler): Remove `type: ignore` on this and similar lines when
        #     https://github.com/python/mypy/issues/5485 is fixed
        rpc.channel.output(  # type: ignore
            self.request(rpc, request, context, override_pending, keep_open))

    def open(self,
             rpc: PendingRpc,
             context,
             override_pending: bool = False,
             keep_open: bool = False) -> None:
        """Creates a context for an RPC, but does not invoke it.

        open() can be used to receive streaming responses to an RPC that was not
        invoked by this client. For example, a server may stream logs with a
        server streaming RPC prior to any clients invoking it.

        Raises:
          Error: override_pending is False and the RPC is already pending
        """
        metadata = _PendingRpcMetadata(context, keep_open)

        if override_pending:
            self._pending[rpc] = metadata
        elif self._pending.setdefault(rpc, metadata) is not metadata:
            # If the context was not added, the RPC was already pending.
            raise Error(f'Sent request for {rpc}, but it is already pending! '
                        'Cancel the RPC before invoking it again')

    def cancel(self, rpc: PendingRpc) -> Optional[bytes]:
        """Cancels the RPC. Returns the CANCEL packet to send.

        Returns:
          The encoded CANCEL packet, or None if the RPC is unary (no packet
          is encoded for unary RPCs)

        Raises:
          KeyError if the RPC is not pending
        """
        _LOG.debug('Cancelling %s', rpc)
        del self._pending[rpc]

        if rpc.method.type is Method.Type.UNARY:
            return None

        return packets.encode_cancel(rpc)

    def send_cancel(self, rpc: PendingRpc) -> bool:
        """Calls cancel and sends the cancel packet, if any, to the channel.

        Returns:
          True if the RPC was cancelled; False if it was not pending
        """
        try:
            packet = self.cancel(rpc)
        except KeyError:
            return False

        if packet:
            rpc.channel.output(packet)  # type: ignore

        return True

    def get_pending(self, rpc: PendingRpc, status: Optional[Status]):
        """Gets the pending RPC's context.

        If status is None, the RPC remains pending. If status is set, the RPC
        is cleared, unless it was opened with keep_open=True.

        Raises:
          KeyError if the RPC is not pending
        """
        if status is None:
            return self._pending[rpc].context

        if self._pending[rpc].keep_open:
            _LOG.debug('%s finished with status %s; keeping open', rpc, status)
            return self._pending[rpc].context

        _LOG.debug('%s finished with status %s', rpc, status)
        return self._pending.pop(rpc).context


class ClientImpl(abc.ABC):
    """The internal interface of the RPC client.

    This interface defines the semantics for invoking an RPC on a particular
    client.
    """
    def __init__(self):
        # Both attributes are assigned by Client.__init__.
        self.client: 'Client' = None
        self.rpcs: PendingRpcs = None

    @abc.abstractmethod
    def method_client(self, channel: Channel, method: Method) -> Any:
        """Returns an object that invokes a method using the given channel."""

    @abc.abstractmethod
    def handle_response(self,
                        rpc: PendingRpc,
                        context: Any,
                        payload: Any,
                        *,
                        args: tuple = (),
                        kwargs: Optional[dict] = None) -> Any:
        """Handles a response from the RPC server.

        Args:
          rpc: Information about the pending RPC
          context: Arbitrary context object associated with the pending RPC
          payload: A protobuf message
          args, kwargs: Arbitrary arguments passed to the ClientImpl
        """

    @abc.abstractmethod
    def handle_completion(self,
                          rpc: PendingRpc,
                          context: Any,
                          status: Status,
                          *,
                          args: tuple = (),
                          kwargs: Optional[dict] = None) -> Any:
        """Handles the successful completion of an RPC.

        Args:
          rpc: Information about the pending RPC
          context: Arbitrary context object associated with the pending RPC
          status: Status returned from the RPC
          args, kwargs: Arbitrary arguments passed to the ClientImpl
        """

    @abc.abstractmethod
    def handle_error(self,
                     rpc: PendingRpc,
                     context,
                     status: Status,
                     *,
                     args: tuple = (),
                     kwargs: Optional[dict] = None):
        """Handles the abnormal termination of an RPC.

        Args:
          rpc: Information about the pending RPC
          context: Arbitrary context object associated with the pending RPC
          status: which error occurred
          args, kwargs: Arbitrary arguments passed to the ClientImpl
        """


class ServiceClient(descriptors.ServiceAccessor):
    """Navigates the methods in a service provided by a ChannelClient."""
    def __init__(self, client_impl: ClientImpl, channel: Channel,
                 service: Service):
        super().__init__(
            {
                method: client_impl.method_client(channel, method)
                for method in service.methods
            },
            as_attrs='members')

        self._channel = channel
        self._service = service

    def __repr__(self) -> str:
        return (f'Service({self._service.full_name!r}, '
                f'methods={[m.name for m in self._service.methods]}, '
                f'channel={self._channel.id})')

    def __str__(self) -> str:
        return str(self._service)


class Services(descriptors.ServiceAccessor[ServiceClient]):
    """Navigates the services provided by a ChannelClient."""
    def __init__(self, client_impl, channel: Channel,
                 services: Collection[Service]):
        super().__init__(
            {s: ServiceClient(client_impl, channel, s)
             for s in services},
            as_attrs='packages')

        self._channel = channel
        self._services = services

    def __repr__(self) -> str:
        return (f'Services(channel={self._channel.id}, '
                f'services={[s.full_name for s in self._services]})')


def _decode_status(rpc: PendingRpc, packet) -> Optional[Status]:
    """Returns the packet's Status, or None if it has none or it is invalid."""
    # Server streaming RPC packets never have a status; all other packets do.
    if packet.type == PacketType.RESPONSE and rpc.method.server_streaming:
        return None

    try:
        return Status(packet.status)
    except ValueError:
        _LOG.warning('Illegal status code %d for %s', packet.status, rpc)

    return None


def _decode_payload(rpc: PendingRpc, packet):
    """Decodes a RESPONSE packet's payload; returns None otherwise or on error.
    """
    if packet.type == PacketType.RESPONSE:
        try:
            return packets.decode_payload(packet, rpc.method.response_type)
        except DecodeError as err:
            _LOG.warning('Failed to decode %s response for %s: %s',
                         rpc.method.response_type.DESCRIPTOR.full_name,
                         rpc.method.full_name, err)
    return None


@dataclass(frozen=True, eq=False)
class ChannelClient:
    """RPC services and methods bound to a particular channel.

    RPCs are invoked through service method clients. These may be accessed via
    the `rpcs` member. Service methods use a fully qualified name: package,
    service, method. Service methods may be selected as attributes or by
    indexing the rpcs member by service and method name or ID.

        # Access the service method client as an attribute
        rpc = client.channel(1).rpcs.the.package.FooService.SomeMethod

        # Access the service method client by indexing (service ID, then
        # method name)
        rpc = client.channel(1).rpcs[foo_service_id]['SomeMethod']

    RPCs may also be accessed from their canonical name.

        # Access the service method client from its full name:
        rpc = client.channel(1).method('the.package.FooService/SomeMethod')

        # Using a . instead of a / is also supported:
        rpc = client.channel(1).method('the.package.FooService.SomeMethod')

    The ClientImpl class determines the type of the service method client. A
    synchronous RPC client might return a callable object, so an RPC could be
    invoked directly (e.g. rpc(field1=123, field2=b'456')).
    """
    client: 'Client'
    channel: Channel
    rpcs: Services

    def method(self, method_name: str):
        """Returns a method client matching the given name.

        Args:
          method_name: name as package.Service/Method or package.Service.Method.

        Raises:
          ValueError: the method name is not properly formatted
          KeyError: the method is not present
        """
        return descriptors.get_method(self.rpcs, method_name)

    def services(self) -> Iterator:
        """Iterates over the ServiceClients in this ChannelClient."""
        return iter(self.rpcs)

    def methods(self) -> Iterator:
        """Iterates over all method clients in this ChannelClient."""
        for service_client in self.rpcs:
            yield from service_client

    def __repr__(self) -> str:
        return (f'ChannelClient(channel={self.channel.id}, '
                f'services={[str(s) for s in self.services()]})')


class Client:
    """Sends requests and handles responses for a set of channels.

    RPC invocations occur through a ChannelClient.
    """
    @classmethod
    def from_modules(cls, impl: ClientImpl, channels: Iterable[Channel],
                     modules: Iterable):
        """Creates a Client for every service in the given protobuf modules.

        Each module's DESCRIPTOR.services_by_name values are converted with
        Service.from_descriptor.
        """
        return cls(
            impl, channels,
            (Service.from_descriptor(service) for module in modules
             for service in module.DESCRIPTOR.services_by_name.values()))

    def __init__(self, impl: ClientImpl, channels: Iterable[Channel],
                 services: Iterable[Service]):
        self._impl = impl
        self._impl.client = self
        self._impl.rpcs = PendingRpcs()

        self.services = descriptors.Services(services)

        self._channels_by_id = {
            channel.id:
            ChannelClient(self, channel,
                          Services(self._impl, channel, self.services))
            for channel in channels
        }

    def channel(self, channel_id: Optional[int] = None) -> ChannelClient:
        """Returns a ChannelClient, which is used to call RPCs on a channel.

        If no channel is provided, the first channel is used.

        Raises:
          KeyError: channel_id is not a channel known to this client
          StopIteration: channel_id is None and this client has no channels
        """
        if channel_id is None:
            return next(iter(self._channels_by_id.values()))

        return self._channels_by_id[channel_id]

    def channels(self) -> Iterable[ChannelClient]:
        """Accesses the ChannelClients in this client."""
        return self._channels_by_id.values()

    def method(self, method_name: str) -> Method:
        """Returns a Method matching the given name.

        Args:
          method_name: name as package.Service/Method or package.Service.Method.

        Raises:
          ValueError: the method name is not properly formatted
          KeyError: the method is not present
        """
        return descriptors.get_method(self.services, method_name)

    def methods(self) -> Iterator[Method]:
        """Iterates over all Methods supported by this client."""
        for service in self.services:
            yield from service.methods

    def process_packet(self, pw_rpc_raw_packet_data: bytes, *impl_args,
                       **impl_kwargs) -> Status:
        """Processes an incoming packet.

        Args:
          pw_rpc_raw_packet_data: raw binary data for exactly one RPC packet
          impl_args: optional positional arguments passed to the ClientImpl
          impl_kwargs: optional keyword arguments passed to the ClientImpl

        Returns:
          OK - the packet was processed by this client
          DATA_LOSS - the packet could not be decoded
          INVALID_ARGUMENT - the packet is for a server, not a client
          NOT_FOUND - the packet's channel ID is not known to this client
        """
        try:
            packet = packets.decode(pw_rpc_raw_packet_data)
        except DecodeError as err:
            _LOG.warning('Failed to decode packet: %s', err)
            _LOG.debug('Raw packet: %r', pw_rpc_raw_packet_data)
            return Status.DATA_LOSS

        if packets.for_server(packet):
            return Status.INVALID_ARGUMENT

        try:
            channel_client = self._channels_by_id[packet.channel_id]
        except KeyError:
            _LOG.warning('Unrecognized channel ID %d', packet.channel_id)
            return Status.NOT_FOUND

        try:
            rpc = self._look_up_service_and_method(packet, channel_client)
        except ValueError as err:
            channel_client.channel.output(  # type: ignore
                packets.encode_client_error(packet, Status.NOT_FOUND))
            _LOG.warning('%s', err)
            return Status.OK

        status = _decode_status(rpc, packet)

        if packet.type not in (PacketType.RESPONSE,
                               PacketType.SERVER_STREAM_END,
                               PacketType.SERVER_ERROR):
            _LOG.error('%s: unexpected PacketType %s', rpc, packet.type)
            _LOG.debug('Packet:\n%s', packet)
            return Status.OK

        payload = _decode_payload(rpc, packet)

        try:
            context = self._impl.rpcs.get_pending(rpc, status)
        except KeyError:
            channel_client.channel.output(  # type: ignore
                packets.encode_client_error(packet,
                                            Status.FAILED_PRECONDITION))
            _LOG.debug('Discarding response for %s, which is not pending', rpc)
            return Status.OK

        if packet.type == PacketType.SERVER_ERROR:
            # The status comes from the server, so validate it explicitly
            # instead of with an assert (asserts are stripped under -O, and an
            # AssertionError here would drop an RPC that get_pending() has
            # already cleared without notifying the ClientImpl).
            if status is None or status.ok():
                _LOG.warning(
                    '%s: SERVER_ERROR packet has invalid status %s; '
                    'treating it as UNKNOWN', rpc, status)
                status = Status.UNKNOWN

            _LOG.warning('%s: invocation failed with %s', rpc, status)
            self._impl.handle_error(rpc,
                                    context,
                                    status,
                                    args=impl_args,
                                    kwargs=impl_kwargs)
            return Status.OK

        if payload is not None:
            self._impl.handle_response(rpc,
                                       context,
                                       payload,
                                       args=impl_args,
                                       kwargs=impl_kwargs)
        if status is not None:
            self._impl.handle_completion(rpc,
                                         context,
                                         status,
                                         args=impl_args,
                                         kwargs=impl_kwargs)

        return Status.OK

    def _look_up_service_and_method(
            self, packet: RpcPacket,
            channel_client: ChannelClient) -> PendingRpc:
        """Finds the packet's service and method; raises ValueError if unknown.
        """
        try:
            service = self.services[packet.service_id]
        except KeyError:
            # The message is self-contained; the KeyError adds no information.
            raise ValueError(
                f'Unrecognized service ID {packet.service_id}') from None

        try:
            method = service.methods[packet.method_id]
        except KeyError:
            raise ValueError(
                f'No method ID {packet.method_id} in service {service.name}'
            ) from None

        return PendingRpc(channel_client.channel, service, method)

    def __repr__(self) -> str:
        return (f'pw_rpc.Client(channels={list(self._channels_by_id)}, '
                f'services={[s.full_name for s in self.services]})')