• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Classes for handling ongoing RPC calls."""
15
16import enum
17import logging
18import math
19import queue
20from typing import (Any, Callable, Iterable, Iterator, NamedTuple, Union,
21                    Optional, Sequence, TypeVar)
22
23from pw_protobuf_compiler.python_protos import proto_repr
24from pw_status import Status
25from google.protobuf.message import Message
26
27from pw_rpc.callback_client.errors import RpcTimeout, RpcError
28from pw_rpc.client import PendingRpc, PendingRpcs
29from pw_rpc.descriptors import Method
30
# Module logger, named after the package that contains this module.
_LOG = logging.getLogger(__package__)
32
33
class UseDefault(enum.Enum):
    """Sentinel requesting an argument's default value.

    Needed where None is itself a meaningful argument value and therefore
    cannot double as "use the default".
    """
    VALUE = 0  # the one and only marker member
37
38
# The concrete call classes a user callback may receive as its first argument.
CallType = TypeVar(
    'CallType',
    'UnaryCall',
    'ServerStreamingCall',
    'ClientStreamingCall',
    'BidirectionalStreamingCall',
)

# User callbacks are invoked as callback(call, arg): arg is a response for
# on_next, and the completion / error Status for on_completed / on_error.
OnNextCallback = Callable[[CallType, Any], Any]
OnCompletedCallback = Callable[[CallType, Any], Any]
OnErrorCallback = Callable[[CallType, Any], Any]

# A timeout argument: seconds to wait, None to block indefinitely, or
# UseDefault.VALUE to fall back to the call's default_timeout_s.
OptionalTimeout = Union[UseDefault, float, None]
47
48
class UnaryResponse(NamedTuple):
    """Outcome of a unary or client streaming RPC: status plus response."""
    status: Status
    response: Any

    def __repr__(self) -> str:
        # Render the protobuf payload compactly via proto_repr.
        return '({}, {})'.format(self.status, proto_repr(self.response))
56
57
class StreamResponse(NamedTuple):
    """Outcome of a server or bidirectional streaming RPC.

    Holds the final status and every response received on the stream.
    """
    status: Status
    responses: Sequence[Any]

    def __repr__(self) -> str:
        rendered = ', '.join(map(proto_repr, self.responses))
        return f'({self.status}, [{rendered}])'
66
67
class Call:
    """Represents an in-progress or completed RPC call.

    Incoming events arrive through _handle_response, _handle_completion, and
    _handle_error. Each one records the event and pushes an item onto an
    internal queue that the wait/iterate helpers block on. It then invokes
    the matching user callback.
    """
    def __init__(self, rpcs: PendingRpcs, rpc: PendingRpc,
                 default_timeout_s: Optional[float],
                 on_next: Optional[OnNextCallback],
                 on_completed: Optional[OnCompletedCallback],
                 on_error: Optional[OnErrorCallback]) -> None:
        """Sets up call state; _invoke() must be called right afterwards.

        Args:
          rpcs: pending-RPC tracker used to send this call's packets
          rpc: the pending RPC this object tracks
          default_timeout_s: timeout used when a wait is given
              UseDefault.VALUE; None blocks indefinitely
          on_next: called as on_next(call, response) for each response;
              defaults to logging the response at debug level
          on_completed: called as on_completed(call, status) on completion;
              defaults to logging at info level
          on_error: called as on_error(call, error) when the call fails;
              defaults to logging a warning
        """
        self._rpcs = rpcs
        self._rpc = rpc
        self.default_timeout_s = default_timeout_s

        # Completion status from the server; None until the call completes.
        self.status: Optional[Status] = None
        # Status that terminated the call with an error; None if no error.
        self.error: Optional[Status] = None
        # RuntimeError wrapping an exception raised by a user callback; it is
        # re-raised by _check_errors.
        self._callback_exception: Optional[Exception] = None
        # Every response received so far, in arrival order.
        self._responses: list = []
        # Responses handed to waiters; None is pushed as an end-of-call marker.
        self._response_queue: queue.SimpleQueue = queue.SimpleQueue()

        self.on_next = on_next or Call._default_response
        self.on_completed = on_completed or Call._default_completion
        self.on_error = on_error or Call._default_error

    def _invoke(self, request: Optional[Message], ignore_errors: bool) -> None:
        """Calls the RPC. This must be called immediately after __init__.

        The request is sent with override_pending=True. If the rpcs object
        has cancel_duplicate_calls set, the call returned by send_request
        (presumably the one previously registered for this RPC) is
        terminated with CANCELLED when it has not completed yet.
        """
        previous = self._rpcs.send_request(self._rpc,
                                           request,
                                           self,
                                           ignore_errors=ignore_errors,
                                           override_pending=True)

        # TODO(hepler): Remove the cancel_duplicate_calls option.
        if (self._rpcs.cancel_duplicate_calls and  # type: ignore[attr-defined]
                previous is not None and not previous.completed()):
            previous._handle_error(Status.CANCELLED)  # pylint: disable=protected-access

    def _default_response(self, response: Message) -> None:
        """Default on_next callback: logs the response at debug level."""
        _LOG.debug('%s received response: %s', self._rpc, response)

    def _default_completion(self, status: Status) -> None:
        """Default on_completed callback: logs the status at info level."""
        _LOG.info('%s completed: %s', self._rpc, status)

    def _default_error(self, error: Status) -> None:
        """Default on_error callback: logs the error as a warning."""
        _LOG.warning('%s terminated due to an error: %s', self._rpc, error)

    @property
    def method(self) -> Method:
        """The RPC method descriptor for this call."""
        return self._rpc.method

    def completed(self) -> bool:
        """True if the RPC call has completed, successfully or from an error."""
        return self.status is not None or self.error is not None

    def _send_client_stream(self, request_proto: Optional[Message],
                            request_fields: dict) -> None:
        """Sends a request to the server in the client stream.

        Sending a client stream packet on a closed RPC raises an exception.

        Raises:
          RuntimeError: a user callback previously raised an exception
          RpcError: the call failed (its error status), or it already
              completed (FAILED_PRECONDITION)
        """
        self._check_errors()

        if self.status is not None:
            raise RpcError(self._rpc, Status.FAILED_PRECONDITION)

        self._rpcs.send_client_stream(
            self._rpc, self.method.get_request(request_proto, request_fields))

    def _finish_client_stream(self, requests: Iterable[Message]) -> None:
        """Sends the remaining requests, then ends the client stream.

        The stream-end packet is skipped if the call has already completed.
        """
        for request in requests:
            self._send_client_stream(request, {})

        if not self.completed():
            self._rpcs.send_client_stream_end(self._rpc)

    def _unary_wait(self, timeout_s: OptionalTimeout) -> UnaryResponse:
        """Waits until the RPC has completed; returns status and last response.

        Raises:
          RpcTimeout: no event arrived within timeout_s
          RpcError: the call terminated with an error
        """
        for _ in self._get_responses(timeout_s=timeout_s):
            pass

        assert self.status is not None and self._responses
        return UnaryResponse(self.status, self._responses[-1])

    def _stream_wait(self, timeout_s: OptionalTimeout) -> StreamResponse:
        """Waits until the RPC has completed; returns status and responses.

        Raises:
          RpcTimeout: no event arrived within timeout_s
          RpcError: the call terminated with an error
        """
        for _ in self._get_responses(timeout_s=timeout_s):
            pass

        assert self.status is not None
        return StreamResponse(self.status, self._responses)

    def _get_responses(self,
                       *,
                       count: Optional[int] = None,
                       timeout_s: OptionalTimeout) -> Iterator:
        """Returns an iterator of stream responses.

        Args:
          count: Responses to read before returning; None reads all
          timeout_s: max time in seconds to wait between responses; 0 doesn't
              block, None blocks indefinitely, UseDefault.VALUE uses
              default_timeout_s

        Raises (lazily, while iterating):
          RpcTimeout: no response arrived within timeout_s
          RpcError: the call terminated with an error
          RuntimeError: a user callback raised an exception
        """
        self._check_errors()

        # Nothing left to deliver: the call finished and the queue is drained.
        if self.completed() and self._response_queue.empty():
            return

        if timeout_s is UseDefault.VALUE:
            timeout_s = self.default_timeout_s

        remaining = math.inf if count is None else count

        try:
            while remaining:
                response = self._response_queue.get(True, timeout_s)

                # An error may have arrived while blocked on the queue.
                self._check_errors()

                if response is None:  # end-of-call marker
                    return

                yield response
                remaining -= 1
        except queue.Empty:
            raise RpcTimeout(self._rpc, timeout_s)

    def cancel(self) -> bool:
        """Cancels the RPC; returns whether the RPC was active.

        Marks the call as failed with CANCELLED. Any waiter blocked in a
        response iterator is woken and raises RpcError.
        """
        if self.completed():
            return False

        self.error = Status.CANCELLED
        # Wake blocked waiters so they re-check errors. Otherwise a waiter
        # with no timeout would keep waiting for a packet that may never
        # arrive. This mirrors _handle_error.
        self._response_queue.put(None)
        return self._rpcs.send_cancel(self._rpc)

    def _check_errors(self) -> None:
        """Raises a pending callback exception or the call's error, if any.

        Raises:
          RuntimeError: a user callback raised (checked first)
          RpcError: the call terminated with an error status
        """
        if self._callback_exception is not None:
            raise self._callback_exception

        if self.error is not None:
            raise RpcError(self._rpc, self.error)

    def _handle_response(self, response: Any) -> None:
        """Records a response, queues it for waiters, and calls on_next."""
        # TODO(frolv): These lists could grow very large for persistent
        # streaming RPCs such as logs. The size should be limited.
        self._responses.append(response)
        self._response_queue.put(response)

        self._invoke_callback('on_next', response)

    def _handle_completion(self, status: Status) -> None:
        """Records the final status, wakes waiters, and calls on_completed."""
        self.status = status
        self._response_queue.put(None)

        self._invoke_callback('on_completed', status)

    def _handle_error(self, error: Status) -> None:
        """Records the error status, wakes waiters, and calls on_error."""
        self.error = error
        self._response_queue.put(None)

        self._invoke_callback('on_error', error)

    def _invoke_callback(self, callback_name: str, arg: Any) -> None:
        """Invokes a user-provided callback function for an RPC event."""

        # Catch and log any exceptions from the user-provided callback so that
        # exceptions don't terminate the thread handling RPC packets.
        callback: Callable[[Call, Any], None] = getattr(self, callback_name)

        try:
            callback(self, arg)
        except Exception as callback_exception:  # pylint: disable=broad-except
            msg = (f'The {callback_name} callback ({callback}) for '
                   f'{self._rpc} raised an exception')
            _LOG.exception(msg)

            # Stash the failure so the next _check_errors surfaces it.
            self._callback_exception = RuntimeError(msg)
            self._callback_exception.__cause__ = callback_exception

    def __enter__(self) -> 'Call':
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Cancels the call if it is still active; no-op once completed.
        self.cancel()

    def __repr__(self) -> str:
        return f'{type(self).__name__}({self.method})'
251
252
class UnaryCall(Call):
    """Tracks a unary RPC call and its single response."""
    @property
    def response(self) -> Any:
        """The last response received, or None if none has arrived yet."""
        if not self._responses:
            return None
        return self._responses[-1]

    def wait(self,
             timeout_s: OptionalTimeout = UseDefault.VALUE) -> UnaryResponse:
        """Blocks until the call completes; returns (status, response)."""
        result = self._unary_wait(timeout_s)
        return result
262
263
class ServerStreamingCall(Call):
    """Tracks the state of a server streaming RPC call."""
    @property
    def responses(self) -> Sequence:
        """All responses received so far, in arrival order."""
        return self._responses

    def wait(self,
             timeout_s: OptionalTimeout = UseDefault.VALUE) -> StreamResponse:
        """Waits for the stream to finish; returns status and responses."""
        return self._stream_wait(timeout_s)

    def get_responses(
            self,
            *,
            count: Optional[int] = None,
            timeout_s: OptionalTimeout = UseDefault.VALUE) -> Iterator:
        """Returns an iterator that yields responses as they are dequeued.

        Args:
          count: number of responses to read; None reads until the call ends
          timeout_s: max seconds to wait between responses; 0 doesn't block,
              None blocks indefinitely, UseDefault.VALUE uses the call's
              default_timeout_s
        """
        return self._get_responses(count=count, timeout_s=timeout_s)

    def __iter__(self) -> Iterator:
        """Iterates over responses with the call's default timeout."""
        return self.get_responses()
283
284
class ClientStreamingCall(Call):
    """Tracks the state of a client streaming RPC call."""
    @property
    def response(self) -> Any:
        """The last response received, or None if none has arrived yet."""
        return self._responses[-1] if self._responses else None

    # TODO(hepler): Use / to mark the first arg as positional-only
    #     when Python 3.7 support is no longer required.
    def send(self,
             _rpc_request_proto: Optional[Message] = None,
             **request_fields) -> None:
        """Sends client stream request to the server.

        Args:
          _rpc_request_proto: request message to send, or None
          **request_fields: request fields; passed with _rpc_request_proto
              to the method's get_request() to build the message

        Raises:
          RpcError: the call failed, or it already completed
              (FAILED_PRECONDITION)
        """
        self._send_client_stream(_rpc_request_proto, request_fields)

    def finish_and_wait(
            self,
            requests: Iterable[Message] = (),
            *,
            timeout_s: OptionalTimeout = UseDefault.VALUE) -> UnaryResponse:
        """Ends the client stream and waits for the RPC to complete.

        Any messages in requests are sent first. The stream-end packet is
        skipped if the call has already completed.
        """
        self._finish_client_stream(requests)
        return self._unary_wait(timeout_s)
307
308
class BidirectionalStreamingCall(Call):
    """Tracks the state of a bidirectional streaming RPC call."""
    @property
    def responses(self) -> Sequence:
        """All responses received so far, in arrival order."""
        return self._responses

    # TODO(hepler): Use / to mark the first arg as positional-only
    #     when Python 3.7 support is no longer required.
    def send(self,
             _rpc_request_proto: Optional[Message] = None,
             **request_fields) -> None:
        """Sends a message to the server in the client stream.

        Args:
          _rpc_request_proto: request message to send, or None
          **request_fields: request fields; passed with _rpc_request_proto
              to the method's get_request() to build the message

        Raises:
          RpcError: the call failed, or it already completed
              (FAILED_PRECONDITION)
        """
        self._send_client_stream(_rpc_request_proto, request_fields)

    def finish_and_wait(
            self,
            requests: Iterable[Message] = (),
            *,
            timeout_s: OptionalTimeout = UseDefault.VALUE) -> StreamResponse:
        """Ends the client stream and waits for the RPC to complete.

        Any messages in requests are sent first. The stream-end packet is
        skipped if the call has already completed.
        """
        self._finish_client_stream(requests)
        return self._stream_wait(timeout_s)

    def get_responses(
            self,
            *,
            count: Optional[int] = None,
            timeout_s: OptionalTimeout = UseDefault.VALUE) -> Iterator:
        """Returns an iterator that yields responses as they are dequeued.

        Args:
          count: number of responses to read; None reads until the call ends
          timeout_s: max seconds to wait between responses; 0 doesn't block,
              None blocks indefinitely, UseDefault.VALUE uses the call's
              default_timeout_s
        """
        return self._get_responses(count=count, timeout_s=timeout_s)

    def __iter__(self) -> Iterator:
        """Iterates over responses with the call's default timeout."""
        return self.get_responses()
341