• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""The callback-based pw_rpc client implementation."""
15
16from __future__ import annotations
17
18import inspect
19import logging
20import textwrap
21from typing import Any, Callable, Iterable, Type
22
23from dataclasses import dataclass
24from pw_status import Status
25from google.protobuf.message import Message
26
27from pw_rpc import client, descriptors
28from pw_rpc.client import PendingRpc, PendingRpcs
29from pw_rpc.descriptors import Channel, Method, Service
30
31from pw_rpc.callback_client.call import (
32    UseDefault,
33    OptionalTimeout,
34    CallTypeT,
35    UnaryResponse,
36    StreamResponse,
37    Call,
38    UnaryCall,
39    ServerStreamingCall,
40    ClientStreamingCall,
41    BidirectionalStreamingCall,
42    OnNextCallback,
43    OnCompletedCallback,
44    OnErrorCallback,
45)
46
# Logger for this module, named after the package that contains it.
_LOG = logging.getLogger(name=__package__)
48
49
@dataclass(frozen=True, eq=True)
class CallInfo:
    """Immutable description of an RPC call, given to Impl.on_call_hook.

    Attributes:
        method: The RPC method that is about to be called.
    """

    method: Method

    @property
    def service(self) -> Service:
        """The service that the called method belongs to."""
        owning_service = self.method.service
        return owning_service
57
58
class _MethodClient:
    """A method that can be invoked for a particular channel.

    Holds the channel, method, and default timeout shared by the per-RPC-type
    subclasses, and creates and invokes their Call objects.
    """

    def __init__(
        self,
        client_impl: Impl,
        rpcs: PendingRpcs,
        channel: Channel,
        method: Method,
        default_timeout_s: float | None,
    ) -> None:
        """Initializes the method client.

        Args:
            client_impl: The Impl that created this client. Its on_call_hook,
                if set, is called before each call is started.
            rpcs: Tracker used to allocate call IDs; also passed to each Call.
            channel: Channel on which calls are made.
            method: The RPC method invoked by this client.
            default_timeout_s: Timeout used when a call is started with
                UseDefault.VALUE.
        """
        self._impl = client_impl
        self._rpcs = rpcs
        self._channel = channel
        self._method = method
        self.default_timeout_s: float | None = default_timeout_s

    @property
    def channel(self) -> Channel:
        """The channel on which this client makes calls."""
        return self._channel

    @property
    def method(self) -> Method:
        """The RPC method invoked by this client."""
        return self._method

    @property
    def service(self) -> Service:
        """The service that defines the method."""
        return self._method.service

    @property
    def request(self) -> type:
        """Returns the request proto class."""
        return self.method.request_type

    @property
    def response(self) -> type:
        """Returns the response proto class."""
        return self.method.response_type

    def __repr__(self) -> str:
        return self.help()

    def help(self) -> str:
        """Returns a help message about this RPC.

        The message lists the request fields, the docstring of the subclass's
        ``__call__``, and that method's return annotation.
        """
        function_call = self.method.full_name + '('

        docstring = inspect.getdoc(self.__call__)  # type: ignore[operator] # pylint: disable=no-member
        assert docstring is not None

        annotation = inspect.Signature.from_callable(self).return_annotation  # type: ignore[arg-type] # pylint: disable=line-too-long
        # Postponed annotations are strings; only real types need converting.
        if isinstance(annotation, type):
            annotation = annotation.__name__

        # Align each field under the first one, just after the '('.
        arg_sep = f',\n{" " * len(function_call)}'
        return (
            f'{function_call}'
            f'{arg_sep.join(descriptors.field_help(self.method.request_type))})'
            f'\n\n{textwrap.indent(docstring, "  ")}\n\n'
            f'  Returns {annotation}.'
        )

    def _start_call(
        self,
        call_type: Type[CallTypeT],
        request: Message | None,
        timeout_s: OptionalTimeout,
        on_next: OnNextCallback | None,
        on_completed: OnCompletedCallback | None,
        on_error: OnErrorCallback | None,
        ignore_errors: bool = False,
    ) -> CallTypeT:
        """Creates the Call object and invokes the RPC using it.

        Args:
            call_type: The Call class to instantiate.
            request: Request message passed to the call's _invoke(), or None.
            timeout_s: Timeout for the call; UseDefault.VALUE is replaced with
                default_timeout_s.
            on_next: Callback passed to the Call object.
            on_completed: Callback passed to the Call object.
            on_error: Callback passed to the Call object.
            ignore_errors: Forwarded to the call's _invoke().

        Returns:
            The newly created and invoked call object.
        """
        if timeout_s is UseDefault.VALUE:
            timeout_s = self.default_timeout_s

        if self._impl.on_call_hook:
            self._impl.on_call_hook(CallInfo(self._method))

        rpc = PendingRpc(
            self._channel,
            self.service,
            self.method,
            self._rpcs.allocate_call_id(),
        )
        call = call_type(
            self._rpcs, rpc, timeout_s, on_next, on_completed, on_error
        )
        call._invoke(request, ignore_errors)  # pylint: disable=protected-access
        return call

    def _client_streaming_call_type(
        self, base: Type[CallTypeT]
    ) -> Type[CallTypeT]:
        """Creates a client or bidirectional stream call type.

        Applies the signature from the request protobuf to the send method, so
        tab completion and help() list the request fields.

        Args:
            base: The call class to derive from.

        Returns:
            A subclass of ``base`` named ``<method name>_<base name>`` whose
            ``send`` advertises the request protobuf's fields.
        """

        def send(
            self, _rpc_request_proto: Message | None = None, **request_fields
        ) -> None:
            ClientStreamingCall.send(self, _rpc_request_proto, **request_fields)

        # send() takes no timeout: a pw_rpc_timeout_s keyword would land in
        # request_fields and be treated as a proto field. So advertise only
        # "self" plus the request fields, without the timeout parameter that
        # _apply_protobuf_signature appends for the call methods.
        signature = inspect.signature(send)
        self_param = next(iter(signature.parameters.values()))
        send.__signature__ = signature.replace(  # type: ignore[attr-defined]
            parameters=[self_param, *self.method.request_parameters()]
        )

        return type(
            f'{self.method.name}_{base.__name__}', (base,), dict(send=send)
        )
167
168
def _function_docstring(method: Method) -> str:
    """Builds the docstring attached to a generated RPC call function."""
    summary = (
        f'Invokes the {method.full_name} {method.type.sentence_name()} RPC.'
    )
    details = (
        'This function accepts either the request protobuf fields as keyword '
        'arguments or\na request protobuf as a positional argument.'
    )
    return f'{summary}\n\n{details}\n'
176
177
def _update_call_method(method: Method, function: Callable) -> None:
    """Makes ``function`` present itself as the RPC ``method``.

    Renames it to the method's full name, gives it a generated docstring, and
    rewrites its inspectable signature to list the request protobuf fields.
    """
    generated_doc = _function_docstring(method)
    function.__name__ = method.full_name
    function.__doc__ = generated_doc
    _apply_protobuf_signature(method, function)
183
184
def _apply_protobuf_signature(method: Method, function: Callable) -> None:
    """Update a function signature to accept proto arguments.

    For good tab completion and help messages, the signature reported by
    ``inspect`` is rewritten to: the original first ("self") parameter, the
    method's request parameters, and a keyword-only ``pw_rpc_timeout_s``.
    The function's actual calling convention is unchanged; only how it
    appears when inspected differs.
    """
    original = inspect.signature(function)
    # Keep only the leading "self" parameter of the real signature.
    self_param = next(iter(original.parameters.values()))
    timeout_param = inspect.Parameter(
        'pw_rpc_timeout_s', inspect.Parameter.KEYWORD_ONLY
    )
    function.__signature__ = original.replace(  # type: ignore[attr-defined]
        parameters=[self_param, *method.request_parameters(), timeout_param]
    )
204
205
class _UnaryMethodClient(_MethodClient):
    """Method client for unary RPCs."""

    def invoke(
        self,
        request: Message | None = None,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
        *,
        request_args: dict[str, Any] | None = None,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> UnaryCall:
        """Invokes the unary RPC and returns a call object.

        Args:
            request: Request protobuf, or None to build one from request_args.
            on_next: Callback passed to the call object.
            on_completed: Callback passed to the call object.
            on_error: Callback passed to the call object.
            request_args: Request fields used to build the request.
            timeout_s: Call timeout; UseDefault.VALUE selects the method
                client's default_timeout_s.
        """
        message = self.method.get_request(request, request_args)
        return self._start_call(
            call_type=UnaryCall,
            request=message,
            timeout_s=timeout_s,
            on_next=on_next,
            on_completed=on_completed,
            on_error=on_error,
        )

    def open(
        self,
        request: Message | None = None,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
        *,
        request_args: dict[str, Any] | None = None,
    ) -> UnaryCall:
        """Returns a call object for the RPC, even if the RPC cannot be invoked.

        The call is started with ``timeout_s=None`` and with errors from
        invoking it ignored.
        """
        message = self.method.get_request(request, request_args)
        return self._start_call(
            call_type=UnaryCall,
            request=message,
            timeout_s=None,
            on_next=on_next,
            on_completed=on_completed,
            on_error=on_error,
            ignore_errors=True,
        )
246
247
class _ServerStreamingMethodClient(_MethodClient):
    """Method client for server streaming RPCs."""

    def invoke(
        self,
        request: Message | None = None,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
        *,
        request_args: dict[str, Any] | None = None,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> ServerStreamingCall:
        """Invokes the server streaming RPC and returns a call object.

        Args:
            request: Request protobuf, or None to build one from request_args.
            on_next: Callback passed to the call object.
            on_completed: Callback passed to the call object.
            on_error: Callback passed to the call object.
            request_args: Request fields used to build the request.
            timeout_s: Call timeout; UseDefault.VALUE selects the method
                client's default_timeout_s.
        """
        message = self.method.get_request(request, request_args)
        return self._start_call(
            call_type=ServerStreamingCall,
            request=message,
            timeout_s=timeout_s,
            on_next=on_next,
            on_completed=on_completed,
            on_error=on_error,
        )

    def open(
        self,
        request: Message | None = None,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
        *,
        request_args: dict[str, Any] | None = None,
    ) -> ServerStreamingCall:
        """Returns a call object for the RPC, even if the RPC cannot be invoked.

        Can be used to listen for responses from an RPC server that may not
        yet be available. The call has ``timeout_s=None``.
        """
        message = self.method.get_request(request, request_args)
        return self._start_call(
            call_type=ServerStreamingCall,
            request=message,
            timeout_s=None,
            on_next=on_next,
            on_completed=on_completed,
            on_error=on_error,
            ignore_errors=True,
        )
292
293
class _ClientStreamingMethodClient(_MethodClient):
    """Method client for client streaming RPCs."""

    def invoke(
        self,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> ClientStreamingCall:
        """Invokes the client streaming RPC and returns a call object.

        Args:
            on_next: Callback passed to the call object.
            on_completed: Callback passed to the call object.
            on_error: Callback passed to the call object.
            timeout_s: Call timeout; UseDefault.VALUE selects the method
                client's default_timeout_s.
        """
        # Like every other invoke(), errors from invoking the RPC are NOT
        # ignored; only open() passes ignore_errors=True.
        return self._start_call(
            self._client_streaming_call_type(ClientStreamingCall),
            None,
            timeout_s,
            on_next,
            on_completed,
            on_error,
        )

    def open(
        self,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
    ) -> ClientStreamingCall:
        """Returns a call object for the RPC, even if the RPC cannot be invoked.

        Can be used to listen for responses from an RPC server that may not
        yet be available. The call has ``timeout_s=None``.
        """
        return self._start_call(
            self._client_streaming_call_type(ClientStreamingCall),
            None,
            None,
            on_next,
            on_completed,
            on_error,
            True,
        )

    def __call__(
        self,
        requests: Iterable[Message] = (),
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> UnaryResponse:
        """Invokes the RPC, then sends requests and waits via finish_and_wait().

        Args:
            requests: Request messages for the client stream.
            timeout_s: Timeout forwarded to finish_and_wait().
        """
        return self.invoke().finish_and_wait(requests, timeout_s=timeout_s)
342
343
class _BidirectionalStreamingMethodClient(_MethodClient):
    """Method client for bidirectional streaming RPCs."""

    def invoke(
        self,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> BidirectionalStreamingCall:
        """Invokes the bidirectional streaming RPC and returns a call object.

        Args:
            on_next: Callback passed to the call object.
            on_completed: Callback passed to the call object.
            on_error: Callback passed to the call object.
            timeout_s: Call timeout; UseDefault.VALUE selects the method
                client's default_timeout_s.
        """
        return self._start_call(
            self._client_streaming_call_type(BidirectionalStreamingCall),
            None,
            timeout_s,
            on_next,
            on_completed,
            on_error,
        )

    def open(
        self,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
    ) -> BidirectionalStreamingCall:
        """Returns a call object for the RPC, even if the RPC cannot be invoked.

        Can be used to listen for responses from an RPC server that may not
        yet be available. The call has ``timeout_s=None``.
        """
        return self._start_call(
            self._client_streaming_call_type(BidirectionalStreamingCall),
            None,
            None,
            on_next,
            on_completed,
            on_error,
            True,
        )

    def __call__(
        self,
        requests: Iterable[Message] = (),
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> StreamResponse:
        """Invokes the RPC, then sends requests and waits via finish_and_wait().

        Args:
            requests: Request messages for the client stream.
            timeout_s: Timeout forwarded to finish_and_wait().
        """
        return self.invoke().finish_and_wait(requests, timeout_s=timeout_s)
391
392
def _method_client_docstring(method: Method) -> str:
    """Builds the class docstring for a generated method client type."""
    headline = (
        f'Class that invokes the {method.full_name} '
        f'{method.type.sentence_name()} RPC.'
    )
    body = (
        'Calling this directly invokes the RPC synchronously. The RPC can be '
        'invoked\nasynchronously using the invoke method.'
    )
    return '\n\n'.join((headline, body)) + '\n'
400
401
class Impl(client.ClientImpl):
    """Callback-based ClientImpl, for use with pw_rpc.Client.

    Args:
        default_unary_timeout_s: Default timeout for unary and client
            streaming RPCs.
        default_stream_timeout_s: Default timeout for server streaming and
            bidirectional streaming RPCs.
        on_call_hook: A callable object to handle RPC method calls.
            If hook is set, it is called with a CallInfo before each RPC is
            started.
        cancel_duplicate_calls: Value attached to the PendingRpcs object as
            a temporary workaround (see method_client).
    """

    def __init__(
        self,
        default_unary_timeout_s: float | None = None,
        default_stream_timeout_s: float | None = None,
        on_call_hook: Callable[[CallInfo], Any] | None = None,
        cancel_duplicate_calls: bool | None = True,
    ) -> None:
        super().__init__()
        self._default_unary_timeout_s = default_unary_timeout_s
        self._default_stream_timeout_s = default_stream_timeout_s
        self.on_call_hook = on_call_hook
        # Temporary workaround for clients that rely on multiple in-flight
        # instances of an RPC on the same channel, which is not supported.
        # TODO(hepler): Remove this option when clients have updated.
        self._cancel_duplicate_calls = cancel_duplicate_calls

    @property
    def default_unary_timeout_s(self) -> float | None:
        """Default timeout for unary and client streaming RPCs."""
        return self._default_unary_timeout_s

    @property
    def default_stream_timeout_s(self) -> float | None:
        """Default timeout for server and bidirectional streaming RPCs."""
        return self._default_stream_timeout_s

    def method_client(self, channel: Channel, method: Method) -> _MethodClient:
        """Returns an object that invokes a method using the given channel.

        Raises:
            AssertionError: If the method type is not recognized.
        """

        # Temporarily attach the cancel_duplicate_calls option to the
        # PendingRpcs object.
        # TODO(hepler): Remove this workaround.
        assert self.rpcs
        self.rpcs.cancel_duplicate_calls = (  # type: ignore[attr-defined]
            self._cancel_duplicate_calls
        )

        kind = method.type

        # Unary and server streaming clients get a customized __call__.
        if kind is Method.Type.UNARY:
            return self._create_unary_method_client(
                channel, method, self.default_unary_timeout_s
            )
        if kind is Method.Type.SERVER_STREAMING:
            return self._create_server_streaming_method_client(
                channel, method, self.default_stream_timeout_s
            )

        # Client-side streaming clients only need a generated subclass.
        if kind is Method.Type.CLIENT_STREAMING:
            base: type = _ClientStreamingMethodClient
            timeout_s = self.default_unary_timeout_s
        elif kind is Method.Type.BIDIRECTIONAL_STREAMING:
            base = _BidirectionalStreamingMethodClient
            timeout_s = self.default_stream_timeout_s
        else:
            raise AssertionError(f'Unknown method type {method.type}')

        return self._create_method_client(base, channel, method, timeout_s)

    def _create_method_client(
        self,
        base: type,
        channel: Channel,
        method: Method,
        default_timeout_s: float | None,
        **fields,
    ) -> _MethodClient:
        """Creates a _MethodClient derived class customized for the method.

        Args:
            base: The _MethodClient subclass to derive from.
            channel: Channel passed to the new client.
            method: Method passed to the new client; also names the class.
            default_timeout_s: Default timeout passed to the new client.
            **fields: Extra class attributes, such as a custom __call__.

        Returns:
            An instance of the newly created class.
        """
        class_name = f'{method.name}{base.__name__}'
        namespace = dict(__doc__=_method_client_docstring(method), **fields)
        method_client_type = type(class_name, (base,), namespace)
        return method_client_type(
            self, self.rpcs, channel, method, default_timeout_s
        )

    def _create_unary_method_client(
        self,
        channel: Channel,
        method: Method,
        default_timeout_s: float | None,
    ) -> _UnaryMethodClient:
        """Creates a _UnaryMethodClient with a customized __call__ method."""

        # TODO(hepler): Use / to mark the first arg as positional-only
        #     when Python 3.7 support is no longer required.
        def call(
            self: _UnaryMethodClient,
            _rpc_request_proto: Message | None = None,
            *,
            pw_rpc_timeout_s: OptionalTimeout = UseDefault.VALUE,
            **request_fields,
        ) -> UnaryResponse:
            request = self.method.get_request(
                _rpc_request_proto, request_fields
            )
            pending_call = self.invoke(request)
            return pending_call.wait(pw_rpc_timeout_s)

        _update_call_method(method, call)
        return self._create_method_client(
            _UnaryMethodClient,
            channel,
            method,
            default_timeout_s,
            __call__=call,
        )

    def _create_server_streaming_method_client(
        self,
        channel: Channel,
        method: Method,
        default_timeout_s: float | None,
    ) -> _ServerStreamingMethodClient:
        """Creates _ServerStreamingMethodClient with custom __call__ method."""

        # TODO(hepler): Use / to mark the first arg as positional-only
        #     when Python 3.7 support is no longer required.
        def call(
            self: _ServerStreamingMethodClient,
            _rpc_request_proto: Message | None = None,
            *,
            pw_rpc_timeout_s: OptionalTimeout = UseDefault.VALUE,
            **request_fields,
        ) -> StreamResponse:
            request = self.method.get_request(
                _rpc_request_proto, request_fields
            )
            pending_call = self.invoke(request)
            return pending_call.wait(pw_rpc_timeout_s)

        _update_call_method(method, call)
        return self._create_method_client(
            _ServerStreamingMethodClient,
            channel,
            method,
            default_timeout_s,
            __call__=call,
        )

    def handle_response(
        self,
        rpc: PendingRpc,
        context: Call,
        payload,
        *,
        args: tuple = (),
        kwargs: dict | None = None,
    ) -> None:
        """Invokes the callback associated with this RPC."""
        assert not (args or kwargs), 'Forwarding args & kwargs not supported'
        context._handle_response(payload)  # pylint: disable=protected-access

    def handle_completion(
        self,
        rpc: PendingRpc,
        context: Call,
        status: Status,
        *,
        args: tuple = (),
        kwargs: dict | None = None,
    ) -> None:
        """Forwards the RPC's completion status to the call object."""
        assert not (args or kwargs), 'Forwarding args & kwargs not supported'
        context._handle_completion(status)  # pylint: disable=protected-access

    def handle_error(
        self,
        rpc: PendingRpc,
        context: Call,
        status: Status,
        *,
        args: tuple = (),
        kwargs: dict | None = None,
    ) -> None:
        """Forwards the RPC's error status to the call object."""
        assert not (args or kwargs), 'Forwarding args & kwargs not supported'
        context._handle_error(status)  # pylint: disable=protected-access
587