• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""The callback-based pw_rpc client implementation."""
15
16import inspect
17import logging
18import textwrap
19from typing import Any, Callable, Dict, Iterable, Optional, Type
20
21from pw_status import Status
22from google.protobuf.message import Message
23
24from pw_rpc import client, descriptors
25from pw_rpc.client import PendingRpc, PendingRpcs
26from pw_rpc.descriptors import Channel, Method, Service
27
28from pw_rpc.callback_client.call import (
29    UseDefault,
30    OptionalTimeout,
31    CallTypeT,
32    UnaryResponse,
33    StreamResponse,
34    Call,
35    UnaryCall,
36    ServerStreamingCall,
37    ClientStreamingCall,
38    BidirectionalStreamingCall,
39    OnNextCallback,
40    OnCompletedCallback,
41    OnErrorCallback,
42)
43
# Logger shared by this module, named after the enclosing package.
_LOG = logging.getLogger(name=__package__)
45
46
class _MethodClient:
    """Binds one RPC method to one channel so that it can be invoked.

    Concrete subclasses (one per RPC type) provide ``invoke`` and ``open``;
    this base class holds the shared state and call-construction helpers.
    """

    def __init__(
        self,
        client_impl: 'Impl',
        rpcs: PendingRpcs,
        channel: Channel,
        method: Method,
        default_timeout_s: Optional[float],
    ) -> None:
        self._impl = client_impl
        self._rpcs = rpcs
        # Identifies this channel/service/method combination; it is handed to
        # every Call object created by _start_call.
        self._rpc = PendingRpc(channel, method.service, method)
        # Substituted by _start_call whenever a caller passes UseDefault.VALUE.
        self.default_timeout_s: Optional[float] = default_timeout_s

    @property
    def channel(self) -> Channel:
        """The channel that calls from this client are sent on."""
        return self._rpc.channel

    @property
    def method(self) -> Method:
        """The RPC method this client invokes."""
        return self._rpc.method

    @property
    def service(self) -> Service:
        """The service that owns the method."""
        return self._rpc.service

    @property
    def request(self) -> type:
        """Returns the request proto class."""
        return self.method.request_type

    @property
    def response(self) -> type:
        """Returns the response proto class."""
        return self.method.response_type

    def __repr__(self) -> str:
        return self.help()

    def help(self) -> str:
        """Returns a help message about this RPC.

        The message lists the method's full name with its request fields as
        arguments, the indented ``__call__`` docstring, and the return
        annotation of ``__call__``.
        """
        opening = f'{self.method.full_name}('

        call_doc = inspect.getdoc(self.__call__)  # type: ignore[operator] # pylint: disable=no-member
        assert call_doc is not None

        return_type = inspect.Signature.from_callable(self).return_annotation  # type: ignore[arg-type] # pylint: disable=line-too-long
        if isinstance(return_type, type):
            return_type = return_type.__name__

        # Continuation lines of the argument list line up after the '('.
        continuation = ',\n' + ' ' * len(opening)
        field_list = continuation.join(
            descriptors.field_help(self.method.request_type)
        )
        body = textwrap.indent(call_doc, '  ')
        return f'{opening}{field_list})\n\n{body}\n\n  Returns {return_type}.'

    def _start_call(
        self,
        call_type: Type[CallTypeT],
        request: Optional[Message],
        timeout_s: OptionalTimeout,
        on_next: Optional[OnNextCallback],
        on_completed: Optional[OnCompletedCallback],
        on_error: Optional[OnErrorCallback],
        ignore_errors: bool = False,
    ) -> CallTypeT:
        """Constructs a ``call_type`` object and invokes the RPC through it.

        ``UseDefault.VALUE`` for ``timeout_s`` is replaced with this client's
        ``default_timeout_s``. ``ignore_errors`` is forwarded to the call's
        ``_invoke``.
        """
        resolved_timeout = (
            self.default_timeout_s
            if timeout_s is UseDefault.VALUE
            else timeout_s
        )
        new_call = call_type(
            self._rpcs,
            self._rpc,
            resolved_timeout,
            on_next,
            on_completed,
            on_error,
        )
        new_call._invoke(request, ignore_errors)  # pylint: disable=protected-access
        return new_call

    def _client_streaming_call_type(
        self, base: Type[CallTypeT]
    ) -> Type[CallTypeT]:
        """Derives a client or bidirectional stream call type from ``base``.

        The derived type's ``send`` delegates to ``ClientStreamingCall.send``
        and carries the request protobuf's fields in its introspected
        signature.
        """

        def send(
            self, _rpc_request_proto: Optional[Message] = None, **request_fields
        ) -> None:
            ClientStreamingCall.send(self, _rpc_request_proto, **request_fields)

        _apply_protobuf_signature(self.method, send)

        derived_name = f'{self.method.name}_{base.__name__}'
        return type(derived_name, (base,), {'send': send})
145
146
def _function_docstring(method: Method) -> str:
    """Builds the docstring given to a generated ``__call__`` for ``method``."""
    header = (
        f'Invokes the {method.full_name} {method.type.sentence_name()} RPC.'
    )
    details = (
        'This function accepts either the request protobuf fields as keyword '
        'arguments or\n'
        'a request protobuf as a positional argument.\n'
    )
    return header + '\n\n' + details
154
155
def _update_call_method(method: Method, function: Callable) -> None:
    """Renames and re-documents ``function`` so it presents as ``method``.

    The function's ``__name__`` becomes the method's full name, its
    ``__doc__`` is generated from the method, and its introspected signature
    is rewritten to list the request proto fields.
    """
    function.__name__ = method.full_name
    function.__doc__ = _function_docstring(method)
    # Only the inspected signature changes; the real parameters are intact.
    _apply_protobuf_signature(method, function)
161
162
def _apply_protobuf_signature(method: Method, function: Callable) -> None:
    """Presents ``function`` as accepting the request proto's fields.

    For good tab completion and help messages, ``function.__signature__`` is
    replaced with one made of the original first (``self``) parameter, the
    parameters from ``method.request_parameters()``, and a keyword-only
    ``pw_rpc_timeout_s``. The function's actual parameters are unchanged;
    only how it appears when inspected differs.
    """
    current = inspect.signature(function)
    leading = next(iter(current.parameters.values()))  # the "self" parameter
    new_params = [
        leading,
        *method.request_parameters(),
        inspect.Parameter('pw_rpc_timeout_s', inspect.Parameter.KEYWORD_ONLY),
    ]
    function.__signature__ = current.replace(  # type: ignore[attr-defined]
        parameters=new_params
    )
182
183
class _UnaryMethodClient(_MethodClient):
    """Method client for unary RPCs.

    The synchronous ``__call__`` is attached per method by
    ``Impl._create_unary_method_client``.
    """

    def invoke(
        self,
        request: Optional[Message] = None,
        on_next: Optional[OnNextCallback] = None,
        on_completed: Optional[OnCompletedCallback] = None,
        on_error: Optional[OnErrorCallback] = None,
        *,
        request_args: Optional[Dict[str, Any]] = None,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> UnaryCall:
        """Invokes the unary RPC and returns a call object.

        The request is built by ``Method.get_request`` from ``request`` and
        ``request_args``. ``UseDefault.VALUE`` for ``timeout_s`` selects this
        client's ``default_timeout_s``.
        """
        return self._start_call(
            UnaryCall,
            self.method.get_request(request, request_args),
            timeout_s,
            on_next,
            on_completed,
            on_error,
        )

    def open(
        self,
        request: Optional[Message] = None,
        on_next: Optional[OnNextCallback] = None,
        on_completed: Optional[OnCompletedCallback] = None,
        on_error: Optional[OnErrorCallback] = None,
        *,
        request_args: Optional[Dict[str, Any]] = None,
    ) -> UnaryCall:
        """Returns a call object for the RPC, even if the RPC cannot be invoked.

        Can be used to listen for responses from an RPC server that may yet be
        available. The call is started with ``timeout_s=None`` and
        ``ignore_errors=True``.
        """
        return self._start_call(
            UnaryCall,
            self.method.get_request(request, request_args),
            None,
            on_next,
            on_completed,
            on_error,
            True,
        )
224
225
class _ServerStreamingMethodClient(_MethodClient):
    """Method client for RPCs in which the server streams responses."""

    def invoke(
        self,
        request: Optional[Message] = None,
        on_next: Optional[OnNextCallback] = None,
        on_completed: Optional[OnCompletedCallback] = None,
        on_error: Optional[OnErrorCallback] = None,
        *,
        request_args: Optional[Dict[str, Any]] = None,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> ServerStreamingCall:
        """Invokes the server streaming RPC and returns a call object."""
        request_proto = self.method.get_request(request, request_args)
        return self._start_call(
            call_type=ServerStreamingCall,
            request=request_proto,
            timeout_s=timeout_s,
            on_next=on_next,
            on_completed=on_completed,
            on_error=on_error,
        )

    def open(
        self,
        request: Optional[Message] = None,
        on_next: Optional[OnNextCallback] = None,
        on_completed: Optional[OnCompletedCallback] = None,
        on_error: Optional[OnErrorCallback] = None,
        *,
        request_args: Optional[Dict[str, Any]] = None,
    ) -> ServerStreamingCall:
        """Returns a call object for the RPC, even if the RPC cannot be invoked.

        Can be used to listen for responses from an RPC server that may yet be
        available. The call is started with no timeout value (``None``) and
        with ``ignore_errors`` enabled.
        """
        request_proto = self.method.get_request(request, request_args)
        return self._start_call(
            call_type=ServerStreamingCall,
            request=request_proto,
            timeout_s=None,
            on_next=on_next,
            on_completed=on_completed,
            on_error=on_error,
            ignore_errors=True,
        )
270
271
class _ClientStreamingMethodClient(_MethodClient):
    """Method client for RPCs in which the client streams requests."""

    def invoke(
        self,
        on_next: Optional[OnNextCallback] = None,
        on_completed: Optional[OnCompletedCallback] = None,
        on_error: Optional[OnErrorCallback] = None,
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> ClientStreamingCall:
        """Invokes the client streaming RPC and returns a call object.

        Like the other method clients' ``invoke``, ``ignore_errors`` is left
        at its default (False); use ``open`` for a call that ignores errors.
        """
        return self._start_call(
            self._client_streaming_call_type(ClientStreamingCall),
            None,
            timeout_s,
            on_next,
            on_completed,
            on_error,
        )

    def open(
        self,
        on_next: Optional[OnNextCallback] = None,
        on_completed: Optional[OnCompletedCallback] = None,
        on_error: Optional[OnErrorCallback] = None,
    ) -> ClientStreamingCall:
        """Returns a call object for the RPC, even if the RPC cannot be invoked.

        Can be used to listen for responses from an RPC server that may yet be
        available.
        """
        return self._start_call(
            self._client_streaming_call_type(ClientStreamingCall),
            None,
            None,
            on_next,
            on_completed,
            on_error,
            True,
        )

    # This docstring is what _MethodClient.help() (and so repr()) displays;
    # without it, inspect.getdoc falls back to type.__call__'s docstring.
    def __call__(
        self,
        requests: Iterable[Message] = (),
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> UnaryResponse:
        """Invokes the client streaming RPC synchronously.

        Starts the call with ``invoke`` and returns the result of passing
        ``requests`` and ``timeout_s`` to the call's ``finish_and_wait``.
        """
        return self.invoke().finish_and_wait(requests, timeout_s=timeout_s)
320
321
class _BidirectionalStreamingMethodClient(_MethodClient):
    """Method client for RPCs in which both sides stream messages."""

    def invoke(
        self,
        on_next: Optional[OnNextCallback] = None,
        on_completed: Optional[OnCompletedCallback] = None,
        on_error: Optional[OnErrorCallback] = None,
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> BidirectionalStreamingCall:
        """Invokes the bidirectional streaming RPC and returns a call object."""
        return self._start_call(
            self._client_streaming_call_type(BidirectionalStreamingCall),
            None,
            timeout_s,
            on_next,
            on_completed,
            on_error,
        )

    def open(
        self,
        on_next: Optional[OnNextCallback] = None,
        on_completed: Optional[OnCompletedCallback] = None,
        on_error: Optional[OnErrorCallback] = None,
    ) -> BidirectionalStreamingCall:
        """Returns a call object for the RPC, even if the RPC cannot be invoked.

        Can be used to listen for responses from an RPC server that may yet be
        available.
        """
        return self._start_call(
            self._client_streaming_call_type(BidirectionalStreamingCall),
            None,
            None,
            on_next,
            on_completed,
            on_error,
            True,
        )

    # This docstring is what _MethodClient.help() (and so repr()) displays;
    # without it, inspect.getdoc falls back to type.__call__'s docstring.
    def __call__(
        self,
        requests: Iterable[Message] = (),
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> StreamResponse:
        """Invokes the bidirectional streaming RPC synchronously.

        Starts the call with ``invoke`` and returns the result of passing
        ``requests`` and ``timeout_s`` to the call's ``finish_and_wait``.
        """
        return self.invoke().finish_and_wait(requests, timeout_s=timeout_s)
369
370
def _method_client_docstring(method: Method) -> str:
    """Builds the class docstring for a generated method client type."""
    paragraphs = [
        f'Class that invokes the {method.full_name} '
        f'{method.type.sentence_name()} RPC.',
        'Calling this directly invokes the RPC synchronously. The RPC can be '
        'invoked\nasynchronously using the invoke method.',
    ]
    return '\n\n'.join(paragraphs) + '\n'
378
379
class Impl(client.ClientImpl):
    """Callback-based ClientImpl, for use with pw_rpc.Client.

    Creates a customized ``_MethodClient`` subclass for each RPC method and
    forwards responses, completions, and errors for pending RPCs to their
    ``Call`` context objects.
    """

    def __init__(
        self,
        default_unary_timeout_s: Optional[float] = None,
        default_stream_timeout_s: Optional[float] = None,
        cancel_duplicate_calls: Optional[bool] = True,
    ) -> None:
        """Configures default timeouts and duplicate-call handling.

        Args:
            default_unary_timeout_s: Default timeout given to unary and client
                streaming method clients.
            default_stream_timeout_s: Default timeout given to server
                streaming and bidirectional streaming method clients.
            cancel_duplicate_calls: Copied onto ``self.rpcs`` whenever a method
                client is created (temporary workaround, see below).
        """
        super().__init__()
        self._default_unary_timeout_s = default_unary_timeout_s
        self._default_stream_timeout_s = default_stream_timeout_s

        # Temporary workaround for clients that rely on multiple in-flight
        # instances of an RPC on the same channel, which is not supported.
        # TODO(hepler): Remove this option when clients have updated.
        self._cancel_duplicate_calls = cancel_duplicate_calls

    @property
    def default_unary_timeout_s(self) -> Optional[float]:
        """Default timeout for unary and client streaming method clients."""
        return self._default_unary_timeout_s

    @property
    def default_stream_timeout_s(self) -> Optional[float]:
        """Default timeout for server and bidirectional streaming clients."""
        return self._default_stream_timeout_s

    def method_client(self, channel: Channel, method: Method) -> _MethodClient:
        """Returns an object that invokes a method using the given channel.

        Raises:
            AssertionError: ``method.type`` is not a known RPC type.
        """

        # Temporarily attach the cancel_duplicate_calls option to the
        # PendingRpcs object.
        # TODO(hepler): Remove this workaround.
        assert self.rpcs
        self.rpcs.cancel_duplicate_calls = (  # type: ignore[attr-defined]
            self._cancel_duplicate_calls
        )

        if method.type is Method.Type.UNARY:
            return self._create_unary_method_client(
                channel, method, self.default_unary_timeout_s
            )

        if method.type is Method.Type.SERVER_STREAMING:
            return self._create_server_streaming_method_client(
                channel, method, self.default_stream_timeout_s
            )

        if method.type is Method.Type.CLIENT_STREAMING:
            # Client streaming calls end with a single UnaryResponse, so they
            # share the unary default timeout.
            return self._create_method_client(
                _ClientStreamingMethodClient,
                channel,
                method,
                self.default_unary_timeout_s,
            )

        if method.type is Method.Type.BIDIRECTIONAL_STREAMING:
            return self._create_method_client(
                _BidirectionalStreamingMethodClient,
                channel,
                method,
                self.default_stream_timeout_s,
            )

        raise AssertionError(f'Unknown method type {method.type}')

    def _create_method_client(
        self,
        base: type,
        channel: Channel,
        method: Method,
        default_timeout_s: Optional[float],
        **fields,
    ) -> _MethodClient:
        """Creates a _MethodClient derived class customized for the method.

        Args:
            base: The _MethodClient subclass to derive from.
            channel: Channel the client sends on.
            method: The RPC method the client invokes.
            default_timeout_s: Default timeout passed to the client.
            **fields: Extra class attributes (e.g. ``__call__``) to add to the
                derived class.

        Returns:
            An instance of the newly created class.
        """
        method_client_type = type(
            f'{method.name}{base.__name__}',
            (base,),
            dict(__doc__=_method_client_docstring(method), **fields),
        )
        return method_client_type(
            self, self.rpcs, channel, method, default_timeout_s
        )

    def _create_unary_method_client(
        self,
        channel: Channel,
        method: Method,
        default_timeout_s: Optional[float],
    ) -> _UnaryMethodClient:
        """Creates a _UnaryMethodClient with a customized __call__ method."""

        # TODO(hepler): Use / to mark the first arg as positional-only
        #     when Python 3.7 support is no longer required.
        def call(
            self: _UnaryMethodClient,
            _rpc_request_proto: Optional[Message] = None,
            *,
            pw_rpc_timeout_s: OptionalTimeout = UseDefault.VALUE,
            **request_fields,
        ) -> UnaryResponse:
            # Name, docstring, and displayed signature are set below by
            # _update_call_method.
            return self.invoke(
                self.method.get_request(_rpc_request_proto, request_fields)
            ).wait(pw_rpc_timeout_s)

        _update_call_method(method, call)
        return self._create_method_client(
            _UnaryMethodClient,
            channel,
            method,
            default_timeout_s,
            __call__=call,
        )

    def _create_server_streaming_method_client(
        self,
        channel: Channel,
        method: Method,
        default_timeout_s: Optional[float],
    ) -> _ServerStreamingMethodClient:
        """Creates _ServerStreamingMethodClient with custom __call__ method."""

        # TODO(hepler): Use / to mark the first arg as positional-only
        #     when Python 3.7 support is no longer required.
        def call(
            self: _ServerStreamingMethodClient,
            _rpc_request_proto: Optional[Message] = None,
            *,
            pw_rpc_timeout_s: OptionalTimeout = UseDefault.VALUE,
            **request_fields,
        ) -> StreamResponse:
            # Name, docstring, and displayed signature are set below by
            # _update_call_method.
            return self.invoke(
                self.method.get_request(_rpc_request_proto, request_fields)
            ).wait(pw_rpc_timeout_s)

        _update_call_method(method, call)
        return self._create_method_client(
            _ServerStreamingMethodClient,
            channel,
            method,
            default_timeout_s,
            __call__=call,
        )

    def handle_response(
        self,
        rpc: PendingRpc,
        context: Call,
        payload,
        *,
        args: tuple = (),
        kwargs: Optional[dict] = None,
    ) -> None:
        """Invokes the callback associated with this RPC.

        Forwarding ``args``/``kwargs`` is not supported; both must be empty.
        """
        assert not args and not kwargs, 'Forwarding args & kwargs not supported'
        context._handle_response(payload)  # pylint: disable=protected-access

    def handle_completion(
        self,
        rpc: PendingRpc,
        context: Call,
        status: Status,
        *,
        args: tuple = (),
        kwargs: Optional[dict] = None,
    ) -> None:
        """Passes the RPC's completion status to its Call object.

        Forwarding ``args``/``kwargs`` is not supported; both must be empty.
        """
        assert not args and not kwargs, 'Forwarding args & kwargs not supported'
        context._handle_completion(status)  # pylint: disable=protected-access

    def handle_error(
        self,
        rpc: PendingRpc,
        context: Call,
        status: Status,
        *,
        args: tuple = (),
        kwargs: Optional[dict] = None,
    ) -> None:
        """Passes the RPC's error status to its Call object.

        Forwarding ``args``/``kwargs`` is not supported; both must be empty.
        """
        assert not args and not kwargs, 'Forwarding args & kwargs not supported'
        context._handle_error(status)  # pylint: disable=protected-access
559