• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Provides a pw_rpc client for Python."""
15
16from __future__ import annotations
17
18import abc
19from dataclasses import dataclass
20import logging
21from typing import (
22    Any,
23    Callable,
24    Collection,
25    Iterable,
26    Iterator,
27)
28
29from google.protobuf.message import DecodeError, Message
30from pw_status import Status
31
32from pw_rpc import descriptors, packets
33from pw_rpc.descriptors import Channel, Service, Method
34from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket
35
36_LOG = logging.getLogger(__package__)
37
38# Calls with ID of `kOpenCallId` were unrequested, and are updated to have the
39# call ID of the first matching request.
40LEGACY_OPEN_CALL_ID: int = 0
41OPEN_CALL_ID: int = (2**32) - 1
42
43_MAX_CALL_ID: int = 1 << 14
44
45
class Error(Exception):
    """Raised when the RPC client classes are used incorrectly."""
48
49
class PendingRpc(packets.RpcIds):
    """Uniquely identifies one RPC call and keeps its descriptors.

    The numeric IDs are handed to packets.RpcIds; the channel, service, and
    method descriptors are stored alongside them.

    Attributes:
      channel: Channel
      service: Service
      method: Method
      channel_id: int
      service_id: int
      method_id: int
      call_id: int
    """

    def __init__(
        self,
        channel: Channel,
        service: Service,
        method: Method,
        call_id: int,
    ) -> None:
        numeric_ids = (channel.id, service.id, method.id, call_id)
        super().__init__(*numeric_ids)
        self.channel = channel
        self.service = service
        self.method = method
74
75
76class _PendingRpcMetadata:
77    def __init__(self, context: object):
78        self.context = context
79
80
class PendingRpcs:
    """Tracks pending RPCs and encodes outgoing RPC packets.

    Pending RPCs are keyed by PendingRpc; each maps to a _PendingRpcMetadata
    wrapper around the caller-supplied context object.
    """

    def __init__(self) -> None:
        self._pending: dict[PendingRpc, _PendingRpcMetadata] = {}
        # We skip call_id = 0 in order to avoid LEGACY_OPEN_CALL_ID.
        self._next_call_id: int = 1

    def allocate_call_id(self) -> int:
        """Returns a new call ID.

        IDs count up from 1 and wrap around below _MAX_CALL_ID, skipping 0
        (LEGACY_OPEN_CALL_ID).
        """
        call_id = self._next_call_id
        self._next_call_id = (self._next_call_id + 1) % _MAX_CALL_ID
        # We skip call_id = 0 in order to avoid LEGACY_OPEN_CALL_ID.
        if self._next_call_id == 0:
            self._next_call_id = 1
        return call_id

    def request(
        self,
        rpc: PendingRpc,
        request: Message | None,
        context: object,
        override_pending: bool = True,
    ) -> bytes:
        """Starts the provided RPC and returns the encoded packet to send.

        Unlike send_request(), the packet is not written to the channel, and
        an already-pending RPC is replaced by default (override_pending=True).
        """
        # open() wraps the context in a fresh _PendingRpcMetadata, so every
        # registration is a distinct object even if a context is reused.
        self.open(rpc, context, override_pending)
        return packets.encode_request(rpc, request)

    def send_request(
        self,
        rpc: PendingRpc,
        request: Message | None,
        context: object,
        *,
        ignore_errors: bool = False,
        override_pending: bool = False,
    ) -> Any:
        """Starts the provided RPC and sends the request packet to the channel.

        Args:
          rpc: the RPC to start
          request: the request message (may be None)
          context: arbitrary object later returned by get_pending()
          ignore_errors: if True, exceptions raised while writing the packet
              to the channel are logged at debug level and suppressed
          override_pending: if True, replace an RPC that is already pending

        Returns:
          the previous context object or None

        Raises:
          Error: the RPC is already pending and override_pending is False
        """
        previous = self.open(rpc, context, override_pending)
        packet = packets.encode_request(rpc, request)

        # TODO(hepler): Remove `type: ignore[misc]` below when
        #     https://github.com/python/mypy/issues/10711 is fixed.
        if ignore_errors:
            try:
                rpc.channel.output(packet)  # type: ignore[misc]
            except Exception as err:  # pylint: disable=broad-except
                _LOG.debug('Ignoring exception when starting RPC: %s', err)
        else:
            rpc.channel.output(packet)  # type: ignore[misc]

        return previous

    def open(
        self, rpc: PendingRpc, context: object, override_pending: bool = False
    ) -> Any:
        """Creates a context for an RPC, but does not invoke it.

        open() can be used to receive streaming responses to an RPC that was not
        invoked by this client. For example, a server may stream logs with a
        server streaming RPC prior to any clients invoking it.

        Returns:
          the previous context object or None

        Raises:
          Error: the RPC is already pending and override_pending is False
        """
        _LOG.debug('Starting %s', rpc)
        metadata = _PendingRpcMetadata(context)

        if override_pending:
            previous = self._pending.get(rpc)
            self._pending[rpc] = metadata
            return None if previous is None else previous.context

        if self._pending.setdefault(rpc, metadata) is not metadata:
            # If the context was not added, the RPC was already pending.
            raise Error(
                f'Sent request for {rpc}, but it is already pending! '
                'Cancel the RPC before invoking it again'
            )

        return None

    def send_client_stream(self, rpc: PendingRpc, message: Message) -> None:
        """Sends a client stream message for a pending RPC.

        Raises:
          Error: the RPC is not pending
        """
        if rpc not in self._pending:
            raise Error(f'Attempt to send client stream for inactive RPC {rpc}')

        rpc.channel.output(  # type: ignore
            packets.encode_client_stream(rpc, message)
        )

    def send_client_stream_end(self, rpc: PendingRpc) -> None:
        """Signals the end of the client stream for a pending RPC.

        Raises:
          Error: the RPC is not pending
        """
        if rpc not in self._pending:
            raise Error(
                f'Attempt to send client stream end for inactive RPC {rpc}'
            )

        rpc.channel.output(  # type: ignore
            packets.encode_client_stream_end(rpc)
        )

    def cancel(self, rpc: PendingRpc) -> bytes:
        """Cancels the RPC.

        Returns:
          The CLIENT_ERROR packet to send.

        Raises:
          KeyError if the RPC is not pending
        """
        _LOG.debug('Cancelling %s', rpc)
        del self._pending[rpc]

        return packets.encode_cancel(rpc)

    def send_cancel(self, rpc: PendingRpc) -> bool:
        """Calls cancel and sends the cancel packet, if any, to the channel.

        Returns:
          True if the RPC was pending and has been cancelled, False otherwise
        """
        try:
            packet = self.cancel(rpc)
        except KeyError:
            return False

        if packet:
            rpc.channel.output(packet)  # type: ignore

        return True

    def get_pending(self, rpc: PendingRpc, status: Status | None):
        """Gets the pending RPC's context. If status is set, clears the RPC.

        A packet whose call ID is OPEN_CALL_ID or LEGACY_OPEN_CALL_ID is
        matched to the first pending RPC (in the order they were opened) on
        the same channel, service, and method.

        Raises:
          KeyError if no matching RPC is pending
        """
        if rpc.call_id == OPEN_CALL_ID or rpc.call_id == LEGACY_OPEN_CALL_ID:
            # Calls with ID `OPEN_CALL_ID` were unrequested, and are updated to
            # have the call ID of the first matching request.
            for pending in self._pending:
                if (
                    pending.channel == rpc.channel
                    and pending.service == rpc.service
                    and pending.method == rpc.method
                ):
                    rpc = pending
                    # Stop at the first match; continuing would select the
                    # last matching request instead.
                    break

        if status is None:
            return self._pending[rpc].context

        _LOG.debug('%s finished with status %s', rpc, status)
        return self._pending.pop(rpc).context
229
230
class ClientImpl(abc.ABC):
    """The internal interface of the RPC client.

    This interface defines the semantics for invoking an RPC on a particular
    client. Subclasses implement the four abstract methods below. A Client
    stores itself in ``client`` and its PendingRpcs tracker in ``rpcs`` when
    it is constructed with an implementation.
    """

    def __init__(self) -> None:
        # Both stay None until a Client adopts this implementation.
        self.client: Client | None = None
        self.rpcs: PendingRpcs | None = None

    @abc.abstractmethod
    def method_client(self, channel: Channel, method: Method) -> Any:
        """Returns an object that invokes a method using the given channel.

        Args:
          channel: The channel the method is invoked on
          method: Descriptor of the method to invoke
        """

    @abc.abstractmethod
    def handle_response(
        self,
        rpc: PendingRpc,
        context: Any,
        payload: Any,
        *,
        args: tuple = (),
        kwargs: dict | None = None,
    ) -> Any:
        """Handles a response from the RPC server.

        Args:
          rpc: Information about the pending RPC
          context: Arbitrary context object associated with the pending RPC
          payload: A protobuf message
          args, kwargs: Arbitrary arguments passed to the ClientImpl
        """

    @abc.abstractmethod
    def handle_completion(
        self,
        rpc: PendingRpc,
        context: Any,
        status: Status,
        *,
        args: tuple = (),
        kwargs: dict | None = None,
    ) -> Any:
        """Handles the normal completion of an RPC.

        Args:
          rpc: Information about the pending RPC
          context: Arbitrary context object associated with the pending RPC
          status: Status returned from the RPC; this is not necessarily OK
          args, kwargs: Arbitrary arguments passed to the ClientImpl
        """

    @abc.abstractmethod
    def handle_error(
        self,
        rpc: PendingRpc,
        context: Any,
        status: Status,
        *,
        args: tuple = (),
        kwargs: dict | None = None,
    ) -> Any:
        """Handles the abnormal termination of an RPC.

        Args:
          rpc: Information about the pending RPC
          context: Arbitrary context object associated with the pending RPC
          status: The error status that terminated the RPC
          args, kwargs: Arbitrary arguments passed to the ClientImpl
        """
302
303
class ServiceClient(descriptors.ServiceAccessor):
    """Navigates the methods in a service provided by a ChannelClient.

    Every method of the service is mapped to the object returned by
    ``client_impl.method_client(channel, method)``; the mapping is passed to
    ServiceAccessor with ``as_attrs='members'``.
    """

    def __init__(
        self, client_impl: ClientImpl, channel: Channel, service: Service
    ):
        method_clients = {}
        for method in service.methods:
            method_clients[method] = client_impl.method_client(channel, method)
        super().__init__(method_clients, as_attrs='members')

        self._channel = channel
        self._service = service

    def __repr__(self) -> str:
        method_names = [m.name for m in self._service.methods]
        full_name = self._service.full_name
        return (
            f'Service({full_name!r}, '
            f'methods={method_names}, '
            f'channel={self._channel.id})'
        )

    def __str__(self) -> str:
        service = self._service
        return str(service)
330
331
class Services(descriptors.ServiceAccessor[ServiceClient]):
    """Navigates the services provided by a ChannelClient.

    One ServiceClient is built per service; the mapping is passed to
    ServiceAccessor with ``as_attrs='packages'``.
    """

    def __init__(
        self, client_impl, channel: Channel, services: Collection[Service]
    ):
        service_clients = {}
        for service in services:
            service_clients[service] = ServiceClient(
                client_impl, channel, service
            )
        super().__init__(service_clients, as_attrs='packages')

        self._channel = channel
        self._services = services

    def __repr__(self) -> str:
        names = [service.full_name for service in self._services]
        return f'Services(channel={self._channel.id}, services={names})'
351
352
def _decode_status(rpc: PendingRpc, packet) -> Status | None:
    """Extracts the Status carried by a server packet.

    SERVER_STREAM packets carry no status, so None is returned for them. A
    status code that Status does not recognize is logged and reported as
    Status.UNKNOWN.
    """
    if packet.type != PacketType.SERVER_STREAM:
        try:
            decoded = Status(packet.status)
        except ValueError:
            _LOG.warning('Illegal status code %d for %s', packet.status, rpc)
            decoded = Status.UNKNOWN
        return decoded

    return None
362
363
def _decode_payload(rpc: PendingRpc, packet) -> Message | None:
    """Decodes the packet's payload as the method's response type.

    Returns None for SERVER_ERROR packets, and for RESPONSE packets of server
    streaming RPCs, which do not send a payload with their RESPONSE packet.
    """
    carries_no_payload = packet.type == PacketType.SERVER_ERROR or (
        packet.type == PacketType.RESPONSE and rpc.method.server_streaming
    )
    if carries_no_payload:
        return None

    return packets.decode_payload(packet, rpc.method.response_type)
373
374
@dataclass(frozen=True, eq=False)
class ChannelClient:
    """RPC services and methods bound to a particular channel.

    Method clients are reached through the ``rpcs`` member using a fully
    qualified name: package, service, method. Each part may be selected as an
    attribute, or by indexing ``rpcs`` with a service or method name or ID.

      # Attribute access to a service method client
      rpc = client.channel(1).rpcs.the.package.FooService.SomeMethod

      # Index access by service ID, then method name
      rpc = client.channel(1).rpcs[foo_service_id]['SomeMethod']

    The canonical full name works too, with either a / or a . before the
    method name:

      rpc = client.channel(1).method('the.package.FooService/SomeMethod')
      rpc = client.channel(1).method('the.package.FooService.SomeMethod')

    The ClientImpl decides what type a service method client is. A
    synchronous RPC client might return a callable object, so an RPC could be
    invoked directly (e.g. rpc(field1=123, field2=b'456')).
    """

    client: Client
    channel: Channel
    rpcs: Services

    def method(self, method_name: str):
        """Looks up a method client by its full name.

        Args:
          method_name: name as package.Service/Method or package.Service.Method.

        Raises:
          ValueError: the method name is not properly formatted
          KeyError: the method is not present
        """
        return descriptors.get_method(self.rpcs, method_name)

    def services(self) -> Iterator:
        """Returns an iterator over the service clients in ``rpcs``."""
        return iter(self.rpcs)

    def methods(self) -> Iterator:
        """Yields every method client of every service on this channel."""
        for service_client in self.services():
            for method_client in service_client:
                yield method_client

    def __repr__(self) -> str:
        service_names = [str(service) for service in self.services()]
        return (
            f'ChannelClient(channel={self.channel.id}, '
            f'services={service_names})'
        )
432
433
def _update_for_backwards_compatibility(
    rpc: PendingRpc, packet: RpcPacket
) -> None:
    """Adapts server streaming RPC packets to the updated protocol if needed.

    Only server streaming RPCs are affected. Before SERVER_STREAM packets
    existed, stream messages were sent as RESPONSE packets with a payload, so
    a RESPONSE with a non-empty payload is rewritten in place to
    SERVER_STREAM.

    The payload field is not 'optional', so an empty payload cannot be told
    apart from a payload that encodes to zero bytes. That only matters for
    old-protocol server streams that deliberately send empty payloads, which
    is not an issue in practice.
    """
    # The `and` chain short-circuits in the same order as the original checks.
    if (
        rpc.method.type is Method.Type.SERVER_STREAMING
        and packet.type == PacketType.RESPONSE
        and packet.payload
    ):
        packet.type = PacketType.SERVER_STREAM
452
453
class Client:
    """Sends requests and handles responses for a set of channels.

    RPC invocations occur through a ChannelClient.

    Users may set an optional response_callback that is called before processing
    every response or server stream RPC packet.
    """

    @classmethod
    def from_modules(
        cls, impl: ClientImpl, channels: Iterable[Channel], modules: Iterable
    ):
        """Creates a Client for every service declared in the given modules.

        Args:
          impl: the ClientImpl that handles invocations and responses
          channels: the channels this client communicates over
          modules: objects exposing ``DESCRIPTOR.services_by_name`` (presumably
              generated protobuf modules)
        """
        return cls(
            impl,
            channels,
            (
                Service.from_descriptor(service)
                for module in modules
                for service in module.DESCRIPTOR.services_by_name.values()
            ),
        )

    def __init__(
        self,
        impl: ClientImpl,
        channels: Iterable[Channel],
        services: Iterable[Service],
    ):
        self._impl = impl
        self._impl.client = self
        self._impl.rpcs = PendingRpcs()

        self.services = descriptors.Services(services)

        self._channels_by_id = {
            channel.id: ChannelClient(
                self, channel, Services(self._impl, channel, self.services)
            )
            for channel in channels
        }

        # Optional function called before processing every non-error RPC packet.
        self.response_callback: (
            Callable[[PendingRpc, Any, Status | None], Any] | None
        ) = None

    def channel(self, channel_id: int | None = None) -> ChannelClient:
        """Returns a ChannelClient, which is used to call RPCs on a channel.

        If no channel is provided, the first channel is used.

        Raises:
          KeyError: channel_id is not one of this client's channels
          StopIteration: channel_id is None and the client has no channels
        """
        if channel_id is None:
            return next(iter(self._channels_by_id.values()))

        return self._channels_by_id[channel_id]

    def channels(self) -> Iterable[ChannelClient]:
        """Accesses the ChannelClients in this client."""
        return self._channels_by_id.values()

    def method(self, method_name: str) -> Method:
        """Returns a Method matching the given name.

        Args:
          method_name: name as package.Service/Method or package.Service.Method.

        Raises:
          ValueError: the method name is not properly formatted
          KeyError: the method is not present
        """
        return descriptors.get_method(self.services, method_name)

    def methods(self) -> Iterator[Method]:
        """Iterates over all Methods supported by this client."""
        for service in self.services:
            yield from service.methods

    def process_packet(
        self, pw_rpc_raw_packet_data: bytes, *impl_args, **impl_kwargs
    ) -> Status:
        """Processes an incoming packet.

        Unknown services or methods, undecodable payloads, and packets for
        RPCs that are not pending are answered with a client error packet
        (never in reply to a SERVER_ERROR) and still return OK.

        Args:
          pw_rpc_raw_packet_data: raw binary data for exactly one RPC packet
          impl_args: optional positional arguments passed to the ClientImpl
          impl_kwargs: optional keyword arguments passed to the ClientImpl

        Returns:
          OK - the packet was processed by this client
          DATA_LOSS - the packet could not be decoded
          INVALID_ARGUMENT - the packet is for a server, not a client
          NOT_FOUND - the packet's channel ID is not known to this client
        """
        try:
            packet = packets.decode(pw_rpc_raw_packet_data)
        except DecodeError as err:
            _LOG.warning('Failed to decode packet: %s', err)
            _LOG.debug('Raw packet: %r', pw_rpc_raw_packet_data)
            return Status.DATA_LOSS

        if packets.for_server(packet):
            return Status.INVALID_ARGUMENT

        try:
            channel_client = self._channels_by_id[packet.channel_id]
        except KeyError:
            _LOG.warning('Unrecognized channel ID %d', packet.channel_id)
            return Status.NOT_FOUND

        try:
            rpc = self._look_up_service_and_method(packet, channel_client)
        except ValueError as err:
            _send_client_error(channel_client, packet, Status.NOT_FOUND)
            _LOG.warning('%s', err)
            return Status.OK

        _update_for_backwards_compatibility(rpc, packet)

        if packet.type not in (
            PacketType.RESPONSE,
            PacketType.SERVER_STREAM,
            PacketType.SERVER_ERROR,
        ):
            _LOG.error('%s: unexpected PacketType %s', rpc, packet.type)
            _LOG.debug('Packet:\n%s', packet)
            return Status.OK

        status = _decode_status(rpc, packet)

        try:
            payload = _decode_payload(rpc, packet)
        except DecodeError as err:
            _send_client_error(channel_client, packet, Status.DATA_LOSS)
            _LOG.warning(
                'Failed to decode %s response for %s: %s',
                rpc.method.response_type.DESCRIPTOR.full_name,
                rpc.method.full_name,
                err,
            )
            _LOG.debug('Raw payload: %s', packet.payload)

            # Make this an error packet so the error handler is called.
            packet.type = PacketType.SERVER_ERROR
            status = Status.DATA_LOSS
            # Keep `payload` bound on every path; error packets carry none.
            payload = None

        # If set, call the response callback with non-error packets.
        if self.response_callback and packet.type != PacketType.SERVER_ERROR:
            self.response_callback(  # pylint: disable=not-callable
                rpc, payload, status
            )

        try:
            assert self._impl.rpcs
            context = self._impl.rpcs.get_pending(rpc, status)
        except KeyError:
            _send_client_error(
                channel_client, packet, Status.FAILED_PRECONDITION
            )
            _LOG.debug('Discarding response for %s, which is not pending', rpc)
            return Status.OK

        if packet.type == PacketType.SERVER_ERROR:
            assert status is not None and not status.ok()
            _LOG.warning('%s: invocation failed with %s', rpc, status)
            self._impl.handle_error(
                rpc, context, status, args=impl_args, kwargs=impl_kwargs
            )
            return Status.OK

        if payload is not None:
            self._impl.handle_response(
                rpc, context, payload, args=impl_args, kwargs=impl_kwargs
            )
        if status is not None:
            self._impl.handle_completion(
                rpc, context, status, args=impl_args, kwargs=impl_kwargs
            )

        return Status.OK

    def _look_up_service_and_method(
        self, packet: RpcPacket, channel_client: ChannelClient
    ) -> PendingRpc:
        """Resolves the packet's service and method IDs into a PendingRpc.

        Raises:
          ValueError: the service or method ID is not known to this client
        """
        # Protobuf is sometimes silly so the 32 bit python bindings return
        # signed values from `fixed32` fields. Let's convert back to unsigned.
        # b/239712573
        service_id = packet.service_id & 0xFFFFFFFF
        try:
            service = self.services[service_id]
        except KeyError:
            # `from None`: the internal KeyError adds nothing for callers.
            raise ValueError(f'Unrecognized service ID {service_id}') from None

        # See above, also for b/239712573
        method_id = packet.method_id & 0xFFFFFFFF
        try:
            method = service.methods[method_id]
        except KeyError:
            raise ValueError(
                f'No method ID {method_id} in service {service.name}'
            ) from None

        return PendingRpc(
            channel_client.channel, service, method, packet.call_id
        )

    def __repr__(self) -> str:
        return (
            f'pw_rpc.Client(channels={list(self._channels_by_id)}, '
            f'services={[s.full_name for s in self.services]})'
        )
665
666
def _send_client_error(
    client: ChannelClient, packet: RpcPacket, error: Status
) -> None:
    """Replies to ``packet`` with an error on the client's channel.

    The reply is built by packets.encode_client_error. Nothing is sent when
    ``packet`` is itself a SERVER_ERROR.
    """
    if packet.type == PacketType.SERVER_ERROR:
        # Never send responses to SERVER_ERRORs.
        return

    error_packet = packets.encode_client_error(packet, error)
    client.channel.output(error_packet)  # type: ignore
675