• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Provides a pw_rpc client for Python."""
15
16import abc
17from dataclasses import dataclass
18import logging
19from typing import (
20    Any,
21    Callable,
22    Collection,
23    Dict,
24    Iterable,
25    Iterator,
26    NamedTuple,
27    Optional,
28)
29
30from google.protobuf.message import DecodeError, Message
31from pw_status import Status
32
33from pw_rpc import descriptors, packets
34from pw_rpc.descriptors import Channel, Service, Method
35from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket
36
# Logger shared by everything in this module, named after the package.
_LOG = logging.getLogger(name=__package__)
38
39
class Error(Exception):
    """Raised when the RPC client classes are used incorrectly.

    Examples include invoking an RPC that is already pending, or sending a
    client stream message for an RPC that is not active.
    """
42
43
class PendingRpc(NamedTuple):
    """Uniquely identifies an RPC call by its channel, service, and method.

    Instances are tuples, so they compare by value and are hashable whenever
    their fields are.
    """

    channel: Channel
    service: Service
    method: Method

    def __str__(self) -> str:
        # Only the channel ID and the method are shown; the service is omitted.
        channel_id = self.channel.id
        return f'PendingRpc(channel={channel_id}, method={self.method})'
53
54
class _PendingRpcMetadata:
    """Per-call bookkeeping stored for each pending RPC.

    A fresh instance is created for every call. This lets the pending-RPC map
    tell, by identity, whether its own entry was the one inserted.
    """

    def __init__(self, context: object):
        # Opaque object supplied by whoever opened the RPC.
        self.context = context
58
59
class PendingRpcs:
    """Tracks pending RPCs and encodes outgoing RPC packets.

    Each pending RPC maps to the context object it was opened with. At most one
    call per PendingRpc is tracked at a time.
    """

    def __init__(self):
        self._pending: Dict[PendingRpc, _PendingRpcMetadata] = {}

    def request(
        self,
        rpc: PendingRpc,
        request: Optional[Message],
        context: object,
        override_pending: bool = True,
    ) -> bytes:
        """Starts the provided RPC and returns the encoded packet to send.

        The packet is not sent; the caller is responsible for sending it.
        Unlike send_request() and open(), override_pending defaults to True
        here, so an existing pending call for the same RPC is replaced.

        Args:
          rpc: the RPC to start
          request: the request message, or None
          context: arbitrary object associated with this call
          override_pending: if True, replace any existing pending call for
              rpc; if False, raise Error if rpc is already pending

        Returns:
          the encoded request packet

        Raises:
          Error: override_pending is False and rpc is already pending
        """
        self.open(rpc, context, override_pending)
        return packets.encode_request(rpc, request)

    def send_request(
        self,
        rpc: PendingRpc,
        request: Optional[Message],
        context: object,
        *,
        ignore_errors: bool = False,
        override_pending: bool = False,
    ) -> Any:
        """Starts the provided RPC and sends the request packet to the channel.

        Args:
          rpc: the RPC to start; the packet is sent with rpc.channel.output
          request: the request message, or None
          context: arbitrary object associated with this call
          ignore_errors: if True, exceptions raised while sending are logged
              at debug level and discarded
          override_pending: if True, replace any existing pending call for
              rpc; if False, raise Error if rpc is already pending

        Returns:
          the previous context object or None

        Raises:
          Error: override_pending is False and rpc is already pending
          Any exception from the channel output, unless ignore_errors is set
        """
        previous = self.open(rpc, context, override_pending)
        packet = packets.encode_request(rpc, request)

        # NOTE: the RPC is registered as pending before the packet is sent and
        # remains registered even if sending fails.

        # TODO(hepler): Remove `type: ignore[misc]` below when
        #     https://github.com/python/mypy/issues/10711 is fixed.
        if ignore_errors:
            try:
                rpc.channel.output(packet)  # type: ignore[misc]
            except Exception as err:  # pylint: disable=broad-except
                _LOG.debug('Ignoring exception when starting RPC: %s', err)
        else:
            rpc.channel.output(packet)  # type: ignore[misc]

        return previous

    def open(
        self, rpc: PendingRpc, context: object, override_pending: bool = False
    ) -> Any:
        """Creates a context for an RPC, but does not invoke it.

        open() can be used to receive streaming responses to an RPC that was not
        invoked by this client. For example, a server may stream logs with a
        server streaming RPC prior to any clients invoking it.

        Args:
          rpc: the RPC to track
          context: arbitrary object associated with this call
          override_pending: if True, replace any existing pending call for
              rpc; if False, raise Error if rpc is already pending

        Returns:
          the previous context object or None

        Raises:
          Error: override_pending is False and rpc is already pending
        """
        _LOG.debug('Starting %s', rpc)
        metadata = _PendingRpcMetadata(context)

        if override_pending:
            previous = self._pending.get(rpc)
            self._pending[rpc] = metadata
            return None if previous is None else previous.context

        # setdefault returns the existing entry if there is one; since metadata
        # is a fresh object, an identity mismatch means rpc was already pending.
        if self._pending.setdefault(rpc, metadata) is not metadata:
            raise Error(
                f'Sent request for {rpc}, but it is already pending! '
                'Cancel the RPC before invoking it again'
            )

        return None

    def send_client_stream(self, rpc: PendingRpc, message: Message) -> None:
        """Encodes message as a client stream packet and sends it.

        Raises:
          Error: rpc is not pending
        """
        if rpc not in self._pending:
            raise Error(f'Attempt to send client stream for inactive RPC {rpc}')

        rpc.channel.output(  # type: ignore
            packets.encode_client_stream(rpc, message)
        )

    def send_client_stream_end(self, rpc: PendingRpc) -> None:
        """Sends a client stream end packet for rpc.

        Raises:
          Error: rpc is not pending
        """
        if rpc not in self._pending:
            raise Error(
                f'Attempt to send client stream end for inactive RPC {rpc}'
            )

        rpc.channel.output(  # type: ignore
            packets.encode_client_stream_end(rpc)
        )

    def cancel(self, rpc: PendingRpc) -> bytes:
        """Cancels the RPC.

        The RPC stops being tracked; the returned packet is not sent.

        Returns:
          The CLIENT_ERROR packet to send.

        Raises:
          KeyError if the RPC is not pending
        """
        _LOG.debug('Cancelling %s', rpc)
        del self._pending[rpc]

        return packets.encode_cancel(rpc)

    def send_cancel(self, rpc: PendingRpc) -> bool:
        """Calls cancel and sends the cancel packet, if any, to the channel.

        Returns:
          True if the RPC was pending and was cancelled; False if it was not
          pending, in which case nothing is sent
        """
        try:
            packet = self.cancel(rpc)
        except KeyError:
            return False

        if packet:
            rpc.channel.output(packet)  # type: ignore

        return True

    def get_pending(self, rpc: PendingRpc, status: Optional[Status]) -> Any:
        """Gets the pending RPC's context. If status is set, clears the RPC.

        Args:
          rpc: the RPC to look up
          status: None while the RPC is still in progress; otherwise its final
              status, in which case the RPC is no longer tracked afterwards

        Raises:
          KeyError: rpc is not pending
        """
        if status is None:
            return self._pending[rpc].context

        _LOG.debug('%s finished with status %s', rpc, status)
        return self._pending.pop(rpc).context
187
188
class ClientImpl(abc.ABC):
    """The internal interface of the RPC client.

    This interface defines the semantics for invoking an RPC on a particular
    client. When a Client is constructed with a ClientImpl, the Client assigns
    itself to the `client` attribute and a PendingRpcs to the `rpcs` attribute.
    """

    def __init__(self):
        # Both stay None until assigned by the Client that owns this object.
        self.client: Optional['Client'] = None
        self.rpcs: Optional[PendingRpcs] = None

    @abc.abstractmethod
    def method_client(self, channel: Channel, method: Method) -> Any:
        """Returns an object that invokes a method using the given channel."""

    @abc.abstractmethod
    def handle_response(
        self,
        rpc: PendingRpc,
        context: Any,
        payload: Any,
        *,
        args: tuple = (),
        kwargs: Optional[dict] = None,
    ) -> Any:
        """Handles a response from the RPC server.

        Args:
          rpc: Information about the pending RPC
          context: Arbitrary context object associated with the pending RPC
          payload: A protobuf message
          args, kwargs: Arbitrary arguments passed to the ClientImpl
        """

    @abc.abstractmethod
    def handle_completion(
        self,
        rpc: PendingRpc,
        context: Any,
        status: Status,
        *,
        args: tuple = (),
        kwargs: Optional[dict] = None,
    ) -> Any:
        """Handles the completion of an RPC that did not end in a server error.

        Args:
          rpc: Information about the pending RPC
          context: Arbitrary context object associated with the pending RPC
          status: Status returned from the RPC; it is not necessarily OK
          args, kwargs: Arbitrary arguments passed to the ClientImpl
        """

    @abc.abstractmethod
    def handle_error(
        self,
        rpc: PendingRpc,
        context: Any,
        status: Status,
        *,
        args: tuple = (),
        kwargs: Optional[dict] = None,
    ) -> Any:
        """Handles the abnormal termination of an RPC.

        Args:
          rpc: Information about the pending RPC
          context: Arbitrary context object associated with the pending RPC
          status: which error occurred
          args, kwargs: Arbitrary arguments passed to the ClientImpl
        """
260
261
class ServiceClient(descriptors.ServiceAccessor):
    """Navigates the methods in a service provided by a ChannelClient.

    Every method of the service is mapped to the method client that the
    ClientImpl creates for it on this service's channel.
    """

    def __init__(
        self, client_impl: ClientImpl, channel: Channel, service: Service
    ):
        method_clients = {
            method: client_impl.method_client(channel, method)
            for method in service.methods
        }
        super().__init__(method_clients, as_attrs='members')

        self._channel = channel
        self._service = service

    def __repr__(self) -> str:
        method_names = [method.name for method in self._service.methods]
        return (
            f'Service({self._service.full_name!r}, '
            f'methods={method_names}, '
            f'channel={self._channel.id})'
        )

    def __str__(self) -> str:
        return str(self._service)
288
289
class Services(descriptors.ServiceAccessor[ServiceClient]):
    """Navigates the services provided by a ChannelClient.

    Each service is mapped to a ServiceClient bound to this object's channel.
    """

    def __init__(
        self, client_impl, channel: Channel, services: Collection[Service]
    ):
        service_clients = {
            service: ServiceClient(client_impl, channel, service)
            for service in services
        }
        super().__init__(service_clients, as_attrs='packages')

        self._channel = channel
        self._services = services

    def __repr__(self) -> str:
        service_names = [service.full_name for service in self._services]
        return (
            f'Services(channel={self._channel.id}, '
            f'services={service_names})'
        )
309
310
def _decode_status(rpc: PendingRpc, packet) -> Optional[Status]:
    """Returns the packet's status as a Status.

    SERVER_STREAM packets yield None; their status field is not read.
    Status codes that Status() rejects are logged and reported as
    Status.UNKNOWN.
    """
    if packet.type != PacketType.SERVER_STREAM:
        try:
            return Status(packet.status)
        except ValueError:
            _LOG.warning('Illegal status code %d for %s', packet.status, rpc)
            return Status.UNKNOWN

    return None
320
321
def _decode_payload(rpc: PendingRpc, packet) -> Optional[Message]:
    """Decodes the packet's payload as the method's response type.

    Returns None without decoding for SERVER_ERROR packets, and for RESPONSE
    packets of server streaming methods (server streaming RPCs do not send a
    payload with their RESPONSE packet). Exceptions raised by
    packets.decode_payload are not caught here.
    """
    has_no_payload = packet.type == PacketType.SERVER_ERROR or (
        packet.type == PacketType.RESPONSE and rpc.method.server_streaming
    )
    if has_no_payload:
        return None

    return packets.decode_payload(packet, rpc.method.response_type)
331
332
@dataclass(frozen=True, eq=False)
class ChannelClient:
    """RPC services and methods bound to a particular channel.

    RPCs are invoked through service method clients, which are reachable from
    the `rpcs` member. Service methods use a fully qualified name: package,
    service, method. They may be selected as attributes or by indexing the
    rpcs member by service and method name or ID.

      # Access the service method client as an attribute
      rpc = client.channel(1).rpcs.the.package.FooService.SomeMethod

      # Access the service method client by service ID and method name
      rpc = client.channel(1).rpcs[foo_service_id]['SomeMethod']

    RPCs may also be accessed from their canonical name.

      # Access the service method client from its full name:
      rpc = client.channel(1).method('the.package.FooService/SomeMethod')

      # Using a . instead of a / is also supported:
      rpc = client.channel(1).method('the.package.FooService.SomeMethod')

    The ClientImpl class determines the type of the service method client. A
    synchronous RPC client might return a callable object, so an RPC could be
    invoked directly (e.g. rpc(field1=123, field2=b'456')).

    Because eq=False, instances compare and hash by identity.
    """

    client: 'Client'
    channel: Channel
    rpcs: Services

    def method(self, method_name: str):
        """Looks up a method client by its full name.

        Args:
          method_name: name as package.Service/Method or package.Service.Method.

        Raises:
          ValueError: the method name is not properly formatted
          KeyError: the method is not present
        """
        return descriptors.get_method(self.rpcs, method_name)

    def services(self) -> Iterator:
        """Returns an iterator over the service clients in `rpcs`."""
        return iter(self.rpcs)

    def methods(self) -> Iterator:
        """Yields every method client of every service on this channel."""
        for service in self.rpcs:
            yield from service

    def __repr__(self) -> str:
        service_names = [str(service) for service in self.services()]
        return (
            f'ChannelClient(channel={self.channel.id}, '
            f'services={service_names})'
        )
390
391
def _update_for_backwards_compatibility(
    rpc: PendingRpc, packet: RpcPacket
) -> None:
    """Adapts server streaming RPC packets to the updated protocol if needed.

    Before SERVER_STREAM packets were introduced, RESPONSE packets carrying a
    payload were used instead. For SERVER_STREAMING methods only, a RESPONSE
    packet with a non-empty payload is rewritten in place as a SERVER_STREAM
    packet. Packets for any other method type are left untouched.
    """
    # The payload field is not 'optional', so an empty payload is
    # indistinguishable from a payload that happens to encode to zero bytes.
    # This only affects old-protocol server streaming RPCs that intentionally
    # send empty payloads, which will not be an issue in practice.
    is_legacy_stream_packet = (
        rpc.method.type is Method.Type.SERVER_STREAMING
        and packet.type == PacketType.RESPONSE
        and packet.payload
    )
    if is_legacy_stream_packet:
        packet.type = PacketType.SERVER_STREAM
410
411
class Client:
    """Sends requests and handles responses for a set of channels.

    RPC invocations occur through a ChannelClient.

    Users may set an optional response_callback that is called before processing
    every response or server stream RPC packet.
    """

    @classmethod
    def from_modules(
        cls, impl: ClientImpl, channels: Iterable[Channel], modules: Iterable
    ):
        """Creates a Client for every service in the given protobuf modules.

        Args:
          impl: the ClientImpl used to invoke RPCs and handle responses
          channels: the channels RPCs may be invoked on
          modules: modules whose DESCRIPTOR.services_by_name values provide
              the services to add
        """
        return cls(
            impl,
            channels,
            (
                Service.from_descriptor(service)
                for module in modules
                for service in module.DESCRIPTOR.services_by_name.values()
            ),
        )

    def __init__(
        self,
        impl: ClientImpl,
        channels: Iterable[Channel],
        services: Iterable[Service],
    ):
        """Binds impl to this client and creates a ChannelClient per channel.

        Args:
          impl: the ClientImpl; its client and rpcs attributes are set here
          channels: the channels RPCs may be invoked on, indexed by channel ID
          services: the services this client supports
        """
        self._impl = impl
        self._impl.client = self
        self._impl.rpcs = PendingRpcs()

        self.services = descriptors.Services(services)

        self._channels_by_id = {
            channel.id: ChannelClient(
                self, channel, Services(self._impl, channel, self.services)
            )
            for channel in channels
        }

        # Optional function called before processing every non-error RPC packet.
        self.response_callback: Optional[
            Callable[[PendingRpc, Any, Optional[Status]], Any]
        ] = None

    def channel(self, channel_id: Optional[int] = None) -> ChannelClient:
        """Returns a ChannelClient, which is used to call RPCs on a channel.

        If no channel is provided, the first channel is used.

        Raises:
          KeyError: channel_id does not match any of this client's channels
          StopIteration: channel_id is None and the client has no channels
        """
        if channel_id is None:
            return next(iter(self._channels_by_id.values()))

        return self._channels_by_id[channel_id]

    def channels(self) -> Iterable[ChannelClient]:
        """Accesses the ChannelClients in this client."""
        return self._channels_by_id.values()

    def method(self, method_name: str) -> Method:
        """Returns a Method matching the given name.

        Args:
          method_name: name as package.Service/Method or package.Service.Method.

        Raises:
          ValueError: the method name is not properly formatted
          KeyError: the method is not present
        """
        return descriptors.get_method(self.services, method_name)

    def methods(self) -> Iterator[Method]:
        """Iterates over all Methods supported by this client."""
        for service in self.services:
            yield from service.methods

    def process_packet(
        self, pw_rpc_raw_packet_data: bytes, *impl_args, **impl_kwargs
    ) -> Status:
        """Processes an incoming packet.

        Packets that decode successfully but cannot be handled (unknown
        service or method, unexpected packet type, undecodable payload, RPC
        not pending) are logged, answered with a client error where
        appropriate, and still return OK.

        Args:
          pw_rpc_raw_packet_data: raw binary data for exactly one RPC packet
          impl_args: optional positional arguments passed to the ClientImpl
          impl_kwargs: optional keyword arguments passed to the ClientImpl

        Returns:
          OK - the packet was processed by this client
          DATA_LOSS - the packet could not be decoded
          INVALID_ARGUMENT - the packet is for a server, not a client
          NOT_FOUND - the packet's channel ID is not known to this client
        """
        try:
            packet = packets.decode(pw_rpc_raw_packet_data)
        except DecodeError as err:
            _LOG.warning('Failed to decode packet: %s', err)
            _LOG.debug('Raw packet: %r', pw_rpc_raw_packet_data)
            return Status.DATA_LOSS

        if packets.for_server(packet):
            return Status.INVALID_ARGUMENT

        try:
            channel_client = self._channels_by_id[packet.channel_id]
        except KeyError:
            _LOG.warning('Unrecognized channel ID %d', packet.channel_id)
            return Status.NOT_FOUND

        try:
            rpc = self._look_up_service_and_method(packet, channel_client)
        except ValueError as err:
            _send_client_error(channel_client, packet, Status.NOT_FOUND)
            _LOG.warning('%s', err)
            return Status.OK

        _update_for_backwards_compatibility(rpc, packet)

        if packet.type not in (
            PacketType.RESPONSE,
            PacketType.SERVER_STREAM,
            PacketType.SERVER_ERROR,
        ):
            _LOG.error('%s: unexpected PacketType %s', rpc, packet.type)
            _LOG.debug('Packet:\n%s', packet)
            return Status.OK

        status = _decode_status(rpc, packet)

        # Bound up front so payload is defined on every path, including the
        # decode-failure path below.
        payload: Optional[Message] = None
        try:
            payload = _decode_payload(rpc, packet)
        except DecodeError as err:
            _send_client_error(channel_client, packet, Status.DATA_LOSS)
            _LOG.warning(
                'Failed to decode %s response for %s: %s',
                rpc.method.response_type.DESCRIPTOR.full_name,
                rpc.method.full_name,
                err,
            )
            _LOG.debug('Raw payload: %s', packet.payload)

            # Make this an error packet so the error handler is called.
            packet.type = PacketType.SERVER_ERROR
            status = Status.DATA_LOSS

        if packet.type == PacketType.SERVER_ERROR and (
            status is None or status.ok()
        ):
            # A SERVER_ERROR should carry an error status, but the status comes
            # from the remote end. Report malformed packets as UNKNOWN instead
            # of failing an assertion on untrusted input.
            _LOG.warning(
                '%s: SERVER_ERROR packet has non-error status %s; '
                'treating it as UNKNOWN',
                rpc,
                status,
            )
            status = Status.UNKNOWN

        # If set, call the response callback with non-error packets.
        if self.response_callback and packet.type != PacketType.SERVER_ERROR:
            self.response_callback(  # pylint: disable=not-callable
                rpc, payload, status
            )

        assert self._impl.rpcs
        try:
            context = self._impl.rpcs.get_pending(rpc, status)
        except KeyError:
            _send_client_error(
                channel_client, packet, Status.FAILED_PRECONDITION
            )
            _LOG.debug('Discarding response for %s, which is not pending', rpc)
            return Status.OK

        if packet.type == PacketType.SERVER_ERROR:
            _LOG.warning('%s: invocation failed with %s', rpc, status)
            self._impl.handle_error(
                rpc, context, status, args=impl_args, kwargs=impl_kwargs
            )
            return Status.OK

        if payload is not None:
            self._impl.handle_response(
                rpc, context, payload, args=impl_args, kwargs=impl_kwargs
            )
        if status is not None:
            self._impl.handle_completion(
                rpc, context, status, args=impl_args, kwargs=impl_kwargs
            )

        return Status.OK

    def _look_up_service_and_method(
        self, packet: RpcPacket, channel_client: ChannelClient
    ) -> PendingRpc:
        """Finds the service and method a packet refers to.

        Raises:
          ValueError: the service ID or method ID is not known to this client
        """
        # Protobuf is sometimes silly so the 32 bit python bindings return
        # signed values from `fixed32` fields. Let's convert back to unsigned.
        # b/239712573
        service_id = packet.service_id & 0xFFFFFFFF
        try:
            service = self.services[service_id]
        except KeyError:
            raise ValueError(
                f'Unrecognized service ID {service_id}'
            ) from None

        # See above, also for b/239712573
        method_id = packet.method_id & 0xFFFFFFFF
        try:
            method = service.methods[method_id]
        except KeyError:
            raise ValueError(
                f'No method ID {method_id} in service {service.name}'
            ) from None

        return PendingRpc(channel_client.channel, service, method)

    def __repr__(self) -> str:
        return (
            f'pw_rpc.Client(channels={list(self._channels_by_id)}, '
            f'services={[s.full_name for s in self.services]})'
        )
621
622
def _send_client_error(
    client: ChannelClient, packet: RpcPacket, error: Status
) -> None:
    """Replies to packet on the client's channel with an encoded client error.

    Nothing is sent in reply to a SERVER_ERROR packet.
    """
    if packet.type == PacketType.SERVER_ERROR:
        return  # Never send responses to SERVER_ERRORs.

    error_packet = packets.encode_client_error(packet, error)
    client.channel.output(error_packet)  # type: ignore
631