• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Defines a callback-based RPC ClientImpl to use with pw_rpc.Client.
15
16callback_client.Impl supports invoking RPCs synchronously or asynchronously.
17Asynchronous invocations use a callback.
18
19Synchronous invocations look like a function call:
20
21  status, response = client.channel(1).call.MyServer.MyUnary(some_field=123)
22
23  # Streaming calls return an iterable of responses
24  for reply in client.channel(1).call.MyService.MyServerStreaming(request):
25      pass
26
27Asynchronous invocations pass a callback in addition to the request. The
28callback must be a callable that accepts a status and a payload, either of
29which may be None. The Status is only set when the RPC is completed.
30
31  callback = lambda status, payload: print('Response:', status, payload)
32
33  call = client.channel(1).call.MyServer.MyUnary.invoke(
34      callback, some_field=123)
35
36  call = client.channel(1).call.MyService.MyServerStreaming.invoke(
37      callback, request):
38
39When invoking a method, requests may be provided as a message object or as
40kwargs for the message fields (but not both).
41"""
42
43import enum
44import inspect
45import logging
46import queue
47import textwrap
48import threading
49from typing import Any, Callable, Iterator, NamedTuple, Union, Optional
50
51from pw_protobuf_compiler.python_protos import proto_repr
52from pw_status import Status
53
54from pw_rpc import client, descriptors
55from pw_rpc.client import PendingRpc, PendingRpcs
56from pw_rpc.descriptors import Channel, Method, Service
57
# Logger shared by this module's default callbacks and Impl's error reporting.
_LOG = logging.getLogger(__name__)
59
60
class UseDefault(enum.Enum):
    """Sentinel marking an argument whose default is resolved later.

    Needed where None is itself a meaningful value (a timeout of None means
    "block indefinitely"), so None cannot also mean "not provided".
    """

    VALUE = 0
64
65
# A timeout argument: seconds as a float, None to wait indefinitely, or
# UseDefault.VALUE to fall back to the configured default.
OptionalTimeout = Optional[Union[UseDefault, float]]

# Callback signatures; every callback receives the PendingRpc first.
ResponseCallback = Callable[[PendingRpc, Any], Any]  # (rpc, payload)
CompletionCallback = Callable[[PendingRpc, Status], Any]  # (rpc, status)
ErrorCallback = Callable[[PendingRpc, Status], Any]  # (rpc, status)


class _Callbacks(NamedTuple):
    """The callbacks registered for a single RPC invocation."""

    response: ResponseCallback
    completion: CompletionCallback
    error: ErrorCallback
77
78
def _default_response(rpc: PendingRpc, response: Any) -> None:
    """Fallback response callback: logs the payload at INFO level."""
    _LOG.info('%s response: %s', rpc, response)


def _default_completion(rpc: PendingRpc, status: Status) -> None:
    """Fallback completion callback: logs the final status at INFO level."""
    _LOG.info('%s finished: %s', rpc, status)


def _default_error(rpc: PendingRpc, status: Status) -> None:
    """Fallback error callback: logs the error status at ERROR level."""
    _LOG.error('%s error: %s', rpc, status)
89
90
class _MethodClient:
    """A method that can be invoked for a particular channel.

    Synchronous invocation goes through __call__, which the per-method
    subclasses created in this module override. invoke() starts the RPC
    asynchronously and reports results through callbacks.
    """
    def __init__(self, client_impl: 'Impl', rpcs: PendingRpcs,
                 channel: Channel, method: Method,
                 default_timeout_s: Optional[float]):
        self._impl = client_impl
        self._rpcs = rpcs
        self._rpc = PendingRpc(channel, method.service, method)
        # Timeout used by synchronous calls that do not specify one.
        self.default_timeout_s: Optional[float] = default_timeout_s

    @property
    def channel(self) -> Channel:
        return self._rpc.channel

    @property
    def method(self) -> Method:
        return self._rpc.method

    @property
    def service(self) -> Service:
        return self._rpc.service

    def invoke(self,
               request: Any,
               response: ResponseCallback = _default_response,
               completion: CompletionCallback = _default_completion,
               error: ErrorCallback = _default_error,
               *,
               override_pending: bool = True,
               keep_open: bool = False) -> '_AsyncCall':
        """Invokes an RPC with callbacks.

        Args:
          request: the request message to send
          response: called as response(rpc, payload) for each response
          completion: called as completion(rpc, status) when the RPC completes
          error: called as error(rpc, status) if the RPC fails
          override_pending: forwarded to PendingRpcs.send_request
          keep_open: forwarded to PendingRpcs.send_request

        Returns:
          An _AsyncCall that can be used to cancel the RPC.
        """
        self._rpcs.send_request(self._rpc,
                                request,
                                _Callbacks(response, completion, error),
                                override_pending=override_pending,
                                keep_open=keep_open)
        return _AsyncCall(self._rpcs, self._rpc)

    def __repr__(self) -> str:
        return self.help()

    def __call__(self):
        raise NotImplementedError('Implemented by derived classes')

    def help(self) -> str:
        """Returns a help message about this RPC.

        The message lists the request fields. It also includes the __call__
        docstring and the return annotation when they are available.
        """
        function_call = self.method.full_name + '('

        # Some method clients (e.g. unsupported method types) have an
        # undocumented __call__. Fall back to an empty description instead of
        # failing, because __repr__ depends on this method.
        docstring = inspect.getdoc(self.__call__) or ''

        annotation = inspect.Signature.from_callable(self).return_annotation
        if annotation is inspect.Signature.empty:
            returns = ''  # Unannotated; there is nothing meaningful to report.
        else:
            if isinstance(annotation, type):
                annotation = annotation.__name__
            returns = f'\n\n  Returns {annotation}.'

        arg_sep = f',\n{" " * len(function_call)}'
        return (
            f'{function_call}'
            f'{arg_sep.join(descriptors.field_help(self.method.request_type))})'
            f'\n\n{textwrap.indent(docstring, "  ")}'
            f'{returns}')
152
153
class RpcTimeout(Exception):
    """Raised when an RPC produces no response within its timeout.

    Attributes:
      rpc: the PendingRpc that timed out
      timeout: the timeout in seconds that elapsed (may be None)
    """
    def __init__(self, rpc: PendingRpc, timeout: Optional[float]):
        message = f'No response received for {rpc.method} after {timeout} s'
        super().__init__(message)
        self.rpc = rpc
        self.timeout = timeout
160
161
class RpcError(Exception):
    """Raised when an RPC fails with an error status.

    Attributes:
      rpc: the PendingRpc that failed
      status: the error Status reported for the RPC
    """
    def __init__(self, rpc: PendingRpc, status: Status):
        # NOT_FOUND gets an explanatory suffix, since it indicates that the
        # server does not support this RPC.
        detail = (': the RPC server does not support this RPC'
                  if status is Status.NOT_FOUND else '')
        super().__init__(f'{rpc.method} failed with error {status}{detail}')
        self.rpc = rpc
        self.status = status
172
173
class _AsyncCall:
    """Handle to an RPC started through a callback-based invoke()."""

    # TODO(hepler): Consider alternatives (futures) and/or expand functionality.

    def __init__(self, rpcs: PendingRpcs, rpc: PendingRpc):
        self._rpcs = rpcs
        self._rpc = rpc

    def cancel(self) -> bool:
        """Cancels the RPC; returns the result of PendingRpcs.send_cancel."""
        return self._rpcs.send_cancel(self._rpc)

    def __enter__(self) -> '_AsyncCall':
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Leaving the `with` block always cancels, whether or not an exception
        # occurred. Returning None means exceptions are not suppressed.
        self.cancel()
191
192
class StreamingResponses:
    """Iterates over the responses of a streaming RPC.

    Items are read from a queue.SimpleQueue. Each item is expected to be a
    response payload, the final Status, or an exception to raise.
    """
    def __init__(self, method_client: _MethodClient,
                 responses: queue.SimpleQueue,
                 default_timeout_s: OptionalTimeout):
        self._method_client = method_client
        self._queue = responses
        # Holds the final Status once it has been read from the queue.
        self.status: Optional[Status] = None
        self.default_timeout_s = (self._method_client.default_timeout_s
                                  if default_timeout_s is UseDefault.VALUE else
                                  default_timeout_s)

    @property
    def method(self) -> Method:
        return self._method_client.method

    def cancel(self) -> None:
        """Cancels the RPC that produces these responses."""
        owner = self._method_client
        owner._rpcs.send_cancel(owner._rpc)  # pylint: disable=protected-access

    def responses(self,
                  *,
                  block: bool = True,
                  timeout_s: OptionalTimeout = UseDefault.VALUE) -> Iterator:
        """Returns an iterator of stream responses.

        Iteration ends when a Status is read from the queue; that Status is
        stored in self.status. If iteration ends any other way, the RPC is
        cancelled. This includes raising and the generator being closed early.

        Args:
          block: passed to queue.get; if False, an empty queue raises
              RpcTimeout immediately
          timeout_s: timeout in seconds; None blocks indefinitely

        Raises:
          RpcTimeout: no item was available within timeout_s
          Exception: any exception found in the queue (e.g. RpcError)
        """
        if timeout_s is UseDefault.VALUE:
            timeout_s = self.default_timeout_s

        try:
            while True:
                item = self._queue.get(block, timeout_s)

                if isinstance(item, Exception):
                    raise item

                if isinstance(item, Status):
                    self.status = item
                    return

                yield item
        except queue.Empty:
            self.cancel()
            raise RpcTimeout(self._method_client._rpc, timeout_s)  # pylint: disable=protected-access
        except BaseException:
            # Covers errors from the queue and GeneratorExit from an early
            # close: cancel the RPC, then propagate.
            self.cancel()
            raise

    def __iter__(self):
        return self.responses()

    def __repr__(self) -> str:
        return f'{type(self).__name__}({self.method})'
250
251
def _method_client_docstring(method: Method) -> str:
    """Builds the class docstring for a generated method client type."""
    lines = [
        f'Class that invokes the {method.full_name} '
        f'{method.type.sentence_name()} RPC.',
        '',
        'Calling this directly invokes the RPC synchronously. The RPC can be '
        'invoked',
        'asynchronously using the invoke method.',
        '',  # Keeps the trailing newline.
    ]
    return '\n'.join(lines)
259
260
def _function_docstring(method: Method) -> str:
    """Builds the docstring for a generated synchronous call function."""
    lines = [
        f'Invokes the {method.full_name} {method.type.sentence_name()} RPC.',
        '',
        'This function accepts either the request protobuf fields as keyword '
        'arguments or',
        'a request protobuf as a positional argument.',
        '',  # Keeps the trailing newline.
    ]
    return '\n'.join(lines)
268
269
def _update_function_signature(method: Method, function: Callable) -> None:
    """Makes a generated call function look like the RPC it invokes.

    Sets __name__ and __doc__ from the method. Replaces the signature that
    inspect reports with: the function's first ("self") parameter, the
    method's request parameters, and a keyword-only pw_rpc_timeout_s.
    """
    function.__name__ = method.full_name
    function.__doc__ = _function_docstring(method)

    # Only the reported signature changes, for tab completion and help
    # messages. The function's real calling convention is left untouched.
    original = inspect.signature(function)
    self_param = next(iter(original.parameters.values()))
    timeout_param = inspect.Parameter('pw_rpc_timeout_s',
                                      inspect.Parameter.KEYWORD_ONLY)
    function.__signature__ = original.replace(  # type: ignore[attr-defined]
        parameters=[self_param, *method.request_parameters(), timeout_param])
287
288
class UnaryResponse(NamedTuple):
    """Result of invoking a unary RPC: status and response."""
    status: Status
    response: Any

    def __repr__(self) -> str:
        # The payload is rendered with proto_repr rather than its own repr.
        status, response = self
        return f'({status}, {proto_repr(response)})'
296
297
class _UnaryResponseHandler:
    """Collects the callbacks of one synchronous unary RPC and waits on them."""
    def __init__(self, rpc: PendingRpc):
        self._rpc = rpc
        self._response: Any = None  # Last payload passed to on_response.
        self._status: Optional[Status] = None  # Set by on_completion.
        self._error: Optional[RpcError] = None  # Set by on_error.
        # Signalled when the RPC either completes or fails.
        self._event = threading.Event()

    def on_response(self, _: PendingRpc, response: Any) -> None:
        self._response = response

    def on_completion(self, _: PendingRpc, status: Status) -> None:
        self._status = status
        self._event.set()

    def on_error(self, _: PendingRpc, status: Status) -> None:
        self._error = RpcError(self._rpc, status)
        self._event.set()

    def wait(self, timeout_s: Optional[float]) -> UnaryResponse:
        """Blocks until the RPC completes or fails.

        Args:
          timeout_s: seconds to wait; None waits indefinitely

        Returns:
          The final status together with the last response payload.

        Raises:
          RpcTimeout: neither completion nor error arrived in time
          RpcError: the error recorded by on_error
        """
        finished = self._event.wait(timeout_s)
        if not finished:
            raise RpcTimeout(self._rpc, timeout_s)

        if self._error is not None:
            raise self._error

        assert self._status is not None
        return UnaryResponse(self._status, self._response)
327
328
def _unary_method_client(client_impl: 'Impl', rpcs: PendingRpcs,
                         channel: Channel, method: Method,
                         default_timeout: Optional[float]) -> _MethodClient:
    """Creates an object used to call a unary method.

    Calling the returned client invokes the RPC and blocks until it completes,
    returning a UnaryResponse. If the call times out, the RPC is cancelled
    before RpcTimeout propagates. This matches how StreamingResponses handles
    timeouts.
    """
    def call(self: _MethodClient,
             _rpc_request_proto=None,
             *,
             pw_rpc_timeout_s=UseDefault.VALUE,
             **request_fields) -> UnaryResponse:
        # pylint: disable=protected-access
        handler = _UnaryResponseHandler(self._rpc)
        self.invoke(
            self.method.get_request(_rpc_request_proto, request_fields),
            handler.on_response, handler.on_completion, handler.on_error)

        if pw_rpc_timeout_s is UseDefault.VALUE:
            pw_rpc_timeout_s = self.default_timeout_s

        try:
            return handler.wait(pw_rpc_timeout_s)
        except RpcTimeout:
            # The caller is giving up on this RPC. Cancel it rather than
            # leaving it pending after the timeout.
            self._rpcs.send_cancel(self._rpc)
            raise

    _update_function_signature(method, call)

    # The MethodClient class is created dynamically so that the __call__ method
    # can be configured differently for each method.
    method_client_type = type(
        f'{method.name}_UnaryMethodClient', (_MethodClient, ),
        dict(__call__=call, __doc__=_method_client_docstring(method)))
    return method_client_type(client_impl, rpcs, channel, method,
                              default_timeout)
358
359
def _server_streaming_method_client(client_impl: 'Impl', rpcs: PendingRpcs,
                                    channel: Channel, method: Method,
                                    default_timeout: Optional[float]):
    """Creates an object used to call a server streaming method.

    Calling the returned client starts the RPC and returns a
    StreamingResponses. Its queue receives each response, the final Status,
    or an RpcError.
    """
    def call(self: _MethodClient,
             _rpc_request_proto=None,
             *,
             pw_rpc_timeout_s=UseDefault.VALUE,
             **request_fields) -> StreamingResponses:
        responses: queue.SimpleQueue = queue.SimpleQueue()

        def on_response(_, response):
            responses.put(response)

        def on_completion(_, status):
            responses.put(status)

        def on_error(rpc, status):
            responses.put(RpcError(rpc, status))

        request = self.method.get_request(_rpc_request_proto, request_fields)
        self.invoke(request, on_response, on_completion, on_error)
        return StreamingResponses(self, responses, pw_rpc_timeout_s)

    _update_function_signature(method, call)

    # Build the client class at runtime so each method gets its own __call__.
    client_class = type(
        f'{method.name}_ServerStreamingMethodClient', (_MethodClient, ),
        dict(__call__=call, __doc__=_method_client_docstring(method)))
    return client_class(client_impl, rpcs, channel, method, default_timeout)
386
387
class ClientStreamingMethodClient(_MethodClient):
    """Method client for client streaming RPCs, which are not yet supported.

    Both synchronous calls and invoke() raise NotImplementedError.
    """
    def __call__(self):
        """Client streaming RPCs cannot be invoked synchronously yet."""
        raise NotImplementedError(
            'Client streaming RPCs are not supported by callback_client')

    def invoke(self,
               request: Any,
               response: ResponseCallback = _default_response,
               completion: CompletionCallback = _default_completion,
               error: ErrorCallback = _default_error,
               *,
               override_pending: bool = True,
               keep_open: bool = False) -> _AsyncCall:
        """Client streaming RPCs cannot be invoked asynchronously yet."""
        raise NotImplementedError(
            'Client streaming RPCs are not supported by callback_client')
401
402
class BidirectionalStreamingMethodClient(_MethodClient):
    """Method client for bidirectional streaming RPCs (not yet supported).

    Both synchronous calls and invoke() raise NotImplementedError.
    """
    def __call__(self):
        """Bidirectional streaming RPCs cannot be invoked synchronously yet."""
        raise NotImplementedError(
            'Bidirectional streaming RPCs are not supported by callback_client')

    def invoke(self,
               request: Any,
               response: ResponseCallback = _default_response,
               completion: CompletionCallback = _default_completion,
               error: ErrorCallback = _default_error,
               *,
               override_pending: bool = True,
               keep_open: bool = False) -> _AsyncCall:
        """Bidirectional streaming RPCs cannot be invoked asynchronously yet."""
        raise NotImplementedError(
            'Bidirectional streaming RPCs are not supported by callback_client')
416
417
class Impl(client.ClientImpl):
    """Callback-based ClientImpl.

    Creates a method client for each RPC type. Dispatches incoming results to
    the callbacks registered when each RPC was invoked.
    """
    def __init__(self,
                 default_unary_timeout_s: Optional[float] = 1.0,
                 default_stream_timeout_s: Optional[float] = 1.0):
        """Creates a callback-based client implementation.

        Args:
          default_unary_timeout_s: default timeout in seconds for unary and
              client streaming method clients; None blocks indefinitely
          default_stream_timeout_s: default timeout in seconds for server and
              bidirectional streaming method clients; None blocks indefinitely
        """
        super().__init__()
        self._default_unary_timeout_s = default_unary_timeout_s
        self._default_stream_timeout_s = default_stream_timeout_s

    @property
    def default_unary_timeout_s(self) -> Optional[float]:
        return self._default_unary_timeout_s

    @property
    def default_stream_timeout_s(self) -> Optional[float]:
        return self._default_stream_timeout_s

    def method_client(self, channel: Channel, method: Method) -> _MethodClient:
        """Returns an object that invokes a method using the given channel."""

        if method.type is Method.Type.UNARY:
            return _unary_method_client(self, self.rpcs, channel, method,
                                        self.default_unary_timeout_s)

        if method.type is Method.Type.SERVER_STREAMING:
            return _server_streaming_method_client(
                self, self.rpcs, channel, method,
                self.default_stream_timeout_s)

        if method.type is Method.Type.CLIENT_STREAMING:
            return ClientStreamingMethodClient(self, self.rpcs, channel,
                                               method,
                                               self.default_unary_timeout_s)

        if method.type is Method.Type.BIDIRECTIONAL_STREAMING:
            return BidirectionalStreamingMethodClient(
                self, self.rpcs, channel, method,
                self.default_stream_timeout_s)

        raise AssertionError(f'Unknown method type {method.type}')

    def handle_response(self,
                        rpc: PendingRpc,
                        context,
                        payload,
                        *,
                        args: tuple = (),
                        kwargs: Optional[dict] = None) -> None:
        """Invokes the response callback associated with this RPC.

        Any additional positional and keyword args passed through
        Client.process_packet are forwarded to the callback. If the callback
        raises, the RPC is cancelled and the exception is logged.

        context is presumably the _Callbacks passed to send_request by
        invoke(); only its .response attribute is used here.
        """
        if kwargs is None:
            kwargs = {}

        try:
            context.response(rpc, payload, *args, **kwargs)
        except Exception:  # pylint: disable=broad-except
            # Log instead of propagating, and cancel the RPC because its
            # response handling failed. KeyboardInterrupt/SystemExit are not
            # swallowed.
            self.rpcs.send_cancel(rpc)
            _LOG.exception('Response callback %s for %s raised exception',
                           context.response, rpc)

    def handle_completion(self,
                          rpc: PendingRpc,
                          context,
                          status: Status,
                          *,
                          args: tuple = (),
                          kwargs: Optional[dict] = None) -> None:
        """Invokes the completion callback associated with this RPC.

        Extra args/kwargs are forwarded to the callback. Exceptions raised by
        the callback are logged, not propagated.
        """
        if kwargs is None:
            kwargs = {}

        try:
            context.completion(rpc, status, *args, **kwargs)
        except Exception:  # pylint: disable=broad-except
            _LOG.exception('Completion callback %s for %s raised exception',
                           context.completion, rpc)

    def handle_error(self,
                     rpc: PendingRpc,
                     context,
                     status: Status,
                     *,
                     args: tuple = (),
                     kwargs: Optional[dict] = None) -> None:
        """Invokes the error callback associated with this RPC.

        Extra args/kwargs are forwarded to the callback. Exceptions raised by
        the callback are logged, not propagated.
        """
        if kwargs is None:
            kwargs = {}

        try:
            context.error(rpc, status, *args, **kwargs)
        except Exception:  # pylint: disable=broad-except
            _LOG.exception('Error callback %s for %s raised exception',
                           context.error, rpc)
512