• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Classes for handling ongoing RPC calls."""
15
16import enum
17import logging
18import math
19import queue
20from typing import (
21    Any,
22    Callable,
23    Iterable,
24    Iterator,
25    NamedTuple,
26    Union,
27    Optional,
28    Sequence,
29    TypeVar,
30)
31
32from pw_protobuf_compiler.python_protos import proto_repr
33from pw_status import Status
34from google.protobuf.message import Message
35
36from pw_rpc.callback_client.errors import RpcTimeout, RpcError
37from pw_rpc.client import PendingRpc, PendingRpcs
38from pw_rpc.descriptors import Method
39
# Module logger, named after the enclosing package.
_LOG: logging.Logger = logging.getLogger(__package__)
41
42
class UseDefault(enum.Enum):
    """Sentinel meaning "use the default value" for arguments that accept None.

    A plain None cannot serve as the marker because None is itself a
    meaningful value for those arguments.
    """

    # The only member; compare with `is UseDefault.VALUE`.
    VALUE = 0
48
# Type variable constrained to the concrete call classes of this module. The
# constraints are given as strings (forward references) because those classes
# are defined further down in the file.
CallTypeT = TypeVar(
    'CallTypeT',
    'UnaryCall', 'ServerStreamingCall',
    'ClientStreamingCall', 'BidirectionalStreamingCall',
)

# User callbacks are invoked as callback(call, value): value is a response for
# on_next, the final status for on_completed, and the error for on_error.
OnNextCallback = Callable[[CallTypeT, Any], Any]
OnCompletedCallback = Callable[[CallTypeT, Any], Any]
OnErrorCallback = Callable[[CallTypeT, Any], Any]
60
# Timeout argument type: UseDefault.VALUE selects the call's default timeout,
# a float is a number of seconds, and None waits indefinitely.
OptionalTimeout = Optional[Union[UseDefault, float]]
62
63
class UnaryResponse(NamedTuple):
    """Outcome of a unary or client streaming RPC: its status and response."""

    status: Status
    response: Any

    def __repr__(self) -> str:
        # Falsy responses (e.g. None) are shown as-is; anything else is
        # rendered with proto_repr.
        if self.response:
            shown = proto_repr(self.response)
        else:
            shown = self.response
        return f'({self.status}, {shown})'
73
74
class StreamResponse(NamedTuple):
    """Outcome of a server or bidirectional streaming RPC.

    Holds the final status together with every response that was received.
    """

    status: Status
    responses: Sequence[Any]

    def __repr__(self) -> str:
        rendered = ', '.join(map(proto_repr, self.responses))
        return f'({self.status}, [{rendered}])'
86
87
class Call:
    """Represents an in-progress or completed RPC call.

    Each received response is appended to an internal list and also pushed
    onto a queue that _get_responses() reads from. A None entry on the queue
    marks the end of the call (completion, error, or cancellation) so that
    any reader blocked on the queue wakes up.
    """

    def __init__(
        self,
        rpcs: PendingRpcs,
        rpc: PendingRpc,
        default_timeout_s: Optional[float],
        on_next: Optional[OnNextCallback],
        on_completed: Optional[OnCompletedCallback],
        on_error: Optional[OnErrorCallback],
    ) -> None:
        """Sets up call state. _invoke() must be called right after this.

        Args:
          rpcs: the PendingRpcs object used to send packets for this call
          rpc: the PendingRpc this call tracks
          default_timeout_s: timeout used when a wait/iteration method is
              given UseDefault.VALUE; None blocks indefinitely
          on_next: invoked as on_next(call, response) for each response;
              defaults to logging the response at debug level
          on_completed: invoked as on_completed(call, status) on completion;
              defaults to logging at info level
          on_error: invoked as on_error(call, error) when the call fails;
              defaults to logging a warning
        """
        self._rpcs = rpcs
        self._rpc = rpc
        self.default_timeout_s = default_timeout_s

        # status is set by _handle_completion(); error is set by
        # _handle_error() or cancel(). completed() checks both.
        self.status: Optional[Status] = None
        self.error: Optional[Status] = None
        # Exception from a user callback. It is stored here and re-raised by
        # _check_errors() instead of propagating out of _invoke_callback().
        self._callback_exception: Optional[Exception] = None
        self._responses: list = []
        self._response_queue: queue.SimpleQueue = queue.SimpleQueue()

        # The defaults are plain functions taken from the class, so they are
        # called as callback(self, arg), the same way as user callbacks.
        self.on_next = on_next or Call._default_response
        self.on_completed = on_completed or Call._default_completion
        self.on_error = on_error or Call._default_error

    def _invoke(self, request: Optional[Message], ignore_errors: bool) -> None:
        """Calls the RPC. This must be called immediately after __init__.

        Args:
          request: request message to send, or None
          ignore_errors: forwarded to PendingRpcs.send_request()
        """
        previous = self._rpcs.send_request(
            self._rpc,
            request,
            self,
            ignore_errors=ignore_errors,
            override_pending=True,
        )

        # NOTE(review): `previous` appears to be the call that was pending for
        # the same RPC before override_pending replaced it (or None). When
        # cancel_duplicate_calls is set, that call is failed with CANCELLED.
        # TODO(hepler): Remove the cancel_duplicate_calls option.
        if (
            self._rpcs.cancel_duplicate_calls  # type: ignore[attr-defined]
            and previous is not None
            and not previous.completed()
        ):
            previous._handle_error(  # pylint: disable=protected-access
                Status.CANCELLED
            )

    def _default_response(self, response: Message) -> None:
        """Default on_next callback: logs the response at debug level."""
        _LOG.debug('%s received response: %s', self._rpc, response)

    def _default_completion(self, status: Status) -> None:
        """Default on_completed callback: logs the final status."""
        _LOG.info('%s completed: %s', self._rpc, status)

    def _default_error(self, error: Status) -> None:
        """Default on_error callback: logs the error as a warning."""
        _LOG.warning('%s terminated due to an error: %s', self._rpc, error)

    @property
    def method(self) -> Method:
        """The method descriptor of the RPC this call belongs to."""
        return self._rpc.method

    def completed(self) -> bool:
        """True if the RPC call has completed, successfully or from an error."""
        return self.status is not None or self.error is not None

    def _send_client_stream(
        self, request_proto: Optional[Message], request_fields: dict
    ) -> None:
        """Sends a request to the server in the client stream.

        The request is built by Method.get_request() from either
        request_proto or request_fields.

        Raises:
          RpcError: with the call's error if it failed, or with
              FAILED_PRECONDITION if the RPC already completed.
        """
        self._check_errors()

        if self.status is not None:
            raise RpcError(self._rpc, Status.FAILED_PRECONDITION)

        self._rpcs.send_client_stream(
            self._rpc, self.method.get_request(request_proto, request_fields)
        )

    def _finish_client_stream(self, requests: Iterable[Message]) -> None:
        """Sends any remaining requests, then ends the client stream.

        The end-of-stream packet is only sent if the call has not completed.
        """
        for request in requests:
            self._send_client_stream(request, {})

        if not self.completed():
            self._rpcs.send_client_stream_end(self._rpc)

    def _unary_wait(self, timeout_s: OptionalTimeout) -> UnaryResponse:
        """Waits until the RPC has completed.

        Returns:
          The final status and the last response received.
        """
        for _ in self._get_responses(timeout_s=timeout_s):
            pass

        assert self.status is not None and self._responses
        return UnaryResponse(self.status, self._responses[-1])

    def _stream_wait(self, timeout_s: OptionalTimeout) -> StreamResponse:
        """Waits until the RPC has completed.

        Returns:
          The final status and the list of all responses received.
        """
        for _ in self._get_responses(timeout_s=timeout_s):
            pass

        assert self.status is not None
        return StreamResponse(self.status, self._responses)

    def _get_responses(
        self, *, count: Optional[int] = None, timeout_s: OptionalTimeout
    ) -> Iterator:
        """Returns an iterator of stream responses.

        Args:
          count: Responses to read before returning; None reads all
          timeout_s: max time in seconds to wait between responses; 0 doesn't
              block, None blocks indefinitely, UseDefault.VALUE uses
              default_timeout_s

        Raises:
          RpcTimeout: no response arrived within timeout_s
          RpcError: the call ended with an error (including cancellation)
        """
        self._check_errors()

        if self.completed() and self._response_queue.empty():
            return

        if timeout_s is UseDefault.VALUE:
            timeout_s = self.default_timeout_s

        remaining = math.inf if count is None else count

        try:
            while remaining:
                response = self._response_queue.get(True, timeout_s)

                # An error (or callback exception) may have been recorded
                # while this thread was waiting.
                self._check_errors()

                # None is the end-of-call sentinel.
                if response is None:
                    return

                yield response
                remaining -= 1
        except queue.Empty:
            raise RpcTimeout(self._rpc, timeout_s)

    def cancel(self) -> bool:
        """Cancels the RPC; returns whether the RPC was active.

        Marks the call as failed with CANCELLED and wakes any reader blocked
        in _get_responses(). That reader then raises RpcError(CANCELLED)
        instead of waiting forever. on_error is not invoked here.
        """
        if self.completed():
            return False

        self.error = Status.CANCELLED
        # Push the end-of-call sentinel before sending the cancel packet so a
        # blocked reader is woken even if sending fails.
        self._response_queue.put(None)
        return self._rpcs.send_cancel(self._rpc)

    def _check_errors(self) -> None:
        """Raises a stored callback exception or the call's RPC error, if any."""
        if self._callback_exception is not None:
            raise self._callback_exception

        if self.error is not None:
            raise RpcError(self._rpc, self.error)

    def _handle_response(self, response: Any) -> None:
        """Records a response, queues it for readers, and calls on_next."""
        # TODO(frolv): These lists could grow very large for persistent
        # streaming RPCs such as logs. The size should be limited.
        self._responses.append(response)
        self._response_queue.put(response)

        self._invoke_callback('on_next', response)

    def _handle_completion(self, status: Status) -> None:
        """Records the final status, wakes readers, and calls on_completed."""
        self.status = status
        self._response_queue.put(None)

        self._invoke_callback('on_completed', status)

    def _handle_error(self, error: Status) -> None:
        """Records an error, wakes readers, and calls on_error."""
        self.error = error
        self._response_queue.put(None)

        self._invoke_callback('on_error', error)

    def _invoke_callback(self, callback_name: str, arg: Any) -> None:
        """Invokes a user-provided callback function for an RPC event."""

        # Catch and log any exceptions from the user-provided callback so that
        # exceptions don't terminate the thread handling RPC packets. The
        # exception is wrapped in a RuntimeError that _check_errors() raises
        # the next time the call is used.
        callback: Callable[[Call, Any], None] = getattr(self, callback_name)

        try:
            callback(self, arg)
        except Exception as callback_exception:  # pylint: disable=broad-except
            msg = (
                f'The {callback_name} callback ({callback}) for '
                f'{self._rpc} raised an exception'
            )
            _LOG.exception(msg)

            self._callback_exception = RuntimeError(msg)
            self._callback_exception.__cause__ = callback_exception

    def __enter__(self) -> 'Call':
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Cancels the RPC when leaving a with block (no-op if completed)."""
        self.cancel()

    def __repr__(self) -> str:
        return f'{type(self).__name__}({self.method})'
286
287
class UnaryCall(Call):
    """State tracker for a unary RPC call."""

    @property
    def response(self) -> Any:
        """The most recently received response, or None if none arrived."""
        if not self._responses:
            return None
        return self._responses[-1]

    def wait(
        self,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> UnaryResponse:
        """Blocks until the RPC finishes; returns its status and response.

        Args:
          timeout_s: seconds to wait between responses; UseDefault.VALUE uses
              the call's default timeout and None waits indefinitely.

        Raises:
          RpcTimeout: if nothing arrives within the timeout.
          RpcError: if the RPC ended with an error.
        """
        result = self._unary_wait(timeout_s)
        return result
299
300
class ServerStreamingCall(Call):
    """State tracker for a server streaming RPC call."""

    @property
    def responses(self) -> Sequence:
        """Every response received so far, in arrival order."""
        return self._responses

    def wait(
        self,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> StreamResponse:
        """Blocks until the stream ends; returns the status and all responses.

        Args:
          timeout_s: seconds to wait between responses; UseDefault.VALUE uses
              the call's default timeout and None waits indefinitely.
        """
        result = self._stream_wait(timeout_s)
        return result

    def get_responses(
        self,
        *,
        count: Optional[int] = None,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> Iterator:
        """Iterates over responses as they arrive.

        Args:
          count: number of responses to read before stopping; None reads
              until the RPC finishes.
          timeout_s: seconds to wait between responses; UseDefault.VALUE uses
              the call's default timeout and None waits indefinitely.
        """
        return self._get_responses(timeout_s=timeout_s, count=count)

    def __iter__(self) -> Iterator:
        """Iterates over all responses using the default timeout."""
        return self.get_responses()
323
324
class ClientStreamingCall(Call):
    """State tracker for a client streaming RPC call."""

    @property
    def response(self) -> Any:
        """The most recently received response, or None if none arrived."""
        if not self._responses:
            return None
        return self._responses[-1]

    # TODO(hepler): Use / to mark the first arg as positional-only
    #     when Python 3.7 support is no longer required.
    def send(
        self, _rpc_request_proto: Optional[Message] = None, **request_fields
    ) -> None:
        """Sends client stream request to the server.

        Pass either a request message or the request's fields as keywords.
        """
        fields = request_fields
        self._send_client_stream(_rpc_request_proto, fields)

    def finish_and_wait(
        self,
        requests: Iterable[Message] = (),
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> UnaryResponse:
        """Ends the client stream and waits for the RPC to complete.

        Args:
          requests: extra requests to send before closing the client stream.
          timeout_s: seconds to wait between responses; UseDefault.VALUE uses
              the call's default timeout and None waits indefinitely.
        """
        self._finish_client_stream(requests)
        result = self._unary_wait(timeout_s)
        return result
349
350
class BidirectionalStreamingCall(Call):
    """State tracker for a bidirectional streaming RPC call."""

    @property
    def responses(self) -> Sequence:
        """Every response received so far, in arrival order."""
        return self._responses

    # TODO(hepler): Use / to mark the first arg as positional-only
    #     when Python 3.7 support is no longer required.
    def send(
        self, _rpc_request_proto: Optional[Message] = None, **request_fields
    ) -> None:
        """Sends a message to the server in the client stream.

        Pass either a request message or the request's fields as keywords.
        """
        fields = request_fields
        self._send_client_stream(_rpc_request_proto, fields)

    def finish_and_wait(
        self,
        requests: Iterable[Message] = (),
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> StreamResponse:
        """Ends the client stream and waits for the RPC to complete.

        Args:
          requests: extra requests to send before closing the client stream.
          timeout_s: seconds to wait between responses; UseDefault.VALUE uses
              the call's default timeout and None waits indefinitely.
        """
        self._finish_client_stream(requests)
        result = self._stream_wait(timeout_s)
        return result

    def get_responses(
        self,
        *,
        count: Optional[int] = None,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> Iterator:
        """Iterates over responses as they arrive.

        Args:
          count: number of responses to read before stopping; None reads
              until the RPC finishes.
          timeout_s: seconds to wait between responses; UseDefault.VALUE uses
              the call's default timeout and None waits indefinitely.
        """
        return self._get_responses(timeout_s=timeout_s, count=count)

    def __iter__(self) -> Iterator:
        """Iterates over all responses using the default timeout."""
        return self.get_responses()
386