• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Classes for handling ongoing RPC calls."""
15
16from __future__ import annotations
17
18import enum
19import logging
20import math
21import queue
22from typing import (
23    Any,
24    Callable,
25    Iterable,
26    Iterator,
27    NamedTuple,
28    Sequence,
29    TypeVar,
30)
31
32from pw_protobuf_compiler.python_protos import proto_repr
33from pw_status import Status
34from google.protobuf.message import Message
35
36from pw_rpc.callback_client.errors import RpcTimeout, RpcError
37from pw_rpc.client import PendingRpc, PendingRpcs
38from pw_rpc.descriptors import Method
39
40_LOG = logging.getLogger(__package__)
41
42
43class UseDefault(enum.Enum):
44    """Marker for args that should use a default value, when None is valid."""
45
46    VALUE = 0
47
48
49CallTypeT = TypeVar(
50    'CallTypeT',
51    'UnaryCall',
52    'ServerStreamingCall',
53    'ClientStreamingCall',
54    'BidirectionalStreamingCall',
55)
56
57OnNextCallback = Callable[[CallTypeT, Any], Any]
58OnCompletedCallback = Callable[[CallTypeT, Any], Any]
59OnErrorCallback = Callable[[CallTypeT, Any], Any]
60
61OptionalTimeout = UseDefault | float | None
62
63
class UnaryResponse(NamedTuple):
    """Outcome of a unary or client streaming RPC.

    Pairs the RPC's final status with its response message.
    """

    status: Status
    response: Any

    def __repr__(self) -> str:
        # Falsy responses are shown as-is; anything else is pretty-printed
        # through proto_repr.
        if not self.response:
            return f'({self.status}, {self.response})'
        return f'({self.status}, {proto_repr(self.response)})'
73
74
class StreamResponse(NamedTuple):
    """Outcome of a server or bidirectional streaming RPC.

    Pairs the RPC's final status with every response it produced.
    """

    status: Status
    responses: Sequence[Any]

    def __repr__(self) -> str:
        rendered = ', '.join(proto_repr(item) for item in self.responses)
        return f'({self.status}, [{rendered}])'
86
87
class Call:
    """Represents an in-progress or completed RPC call.

    Received responses are appended to a list (for later inspection) and
    pushed onto a queue (for the blocking wait/iteration helpers). A None
    entry on the queue tells a waiting reader that the call has ended.
    """

    def __init__(
        self,
        rpcs: PendingRpcs,
        rpc: PendingRpc,
        default_timeout_s: float | None,
        on_next: OnNextCallback | None,
        on_completed: OnCompletedCallback | None,
        on_error: OnErrorCallback | None,
    ) -> None:
        """Sets up call state; the RPC itself is sent by _invoke().

        Args:
          rpcs: the PendingRpcs object used to send this call's packets
          rpc: the pending RPC this call tracks
          default_timeout_s: timeout used when a wait method is given
              UseDefault.VALUE; None blocks indefinitely
          on_next: called as on_next(call, response) for each response
          on_completed: called as on_completed(call, status) on completion
          on_error: called as on_error(call, error) when the RPC fails

        A None callback selects a default implementation that only logs.
        """
        self._rpcs = rpcs
        self._rpc = rpc
        self.default_timeout_s = default_timeout_s

        self.status: Status | None = None
        self.error: Status | None = None
        self._callback_exception: Exception | None = None
        self._responses: list = []
        self._response_queue: queue.SimpleQueue = queue.SimpleQueue()

        self.on_next = on_next or Call._default_response
        self.on_completed = on_completed or Call._default_completion
        self.on_error = on_error or Call._default_error

    def _invoke(self, request: Message | None, ignore_errors: bool) -> None:
        """Calls the RPC. This must be called immediately after __init__.

        Args:
          request: request message to send with the call, or None
          ignore_errors: forwarded to PendingRpcs.send_request
        """
        previous = self._rpcs.send_request(
            self._rpc,
            request,
            self,
            ignore_errors=ignore_errors,
            override_pending=True,
        )

        # `previous` is what send_request returns when overriding a pending
        # RPC (presumably the displaced call). If duplicate-call cancellation
        # is enabled and that call is still running, end it with CANCELLED.
        # TODO(hepler): Remove the cancel_duplicate_calls option.
        if (
            self._rpcs.cancel_duplicate_calls  # type: ignore[attr-defined]
            and previous is not None
            and not previous.completed()
        ):
            previous._handle_error(  # pylint: disable=protected-access
                Status.CANCELLED
            )

    def _default_response(self, response: Message) -> None:
        """Default on_next callback: logs the response at debug level."""
        _LOG.debug('%s received response: %s', self._rpc, response)

    def _default_completion(self, status: Status) -> None:
        """Default on_completed callback: logs the final status."""
        _LOG.info('%s completed: %s', self._rpc, status)

    def _default_error(self, error: Status) -> None:
        """Default on_error callback: logs the error as a warning."""
        _LOG.warning('%s terminated due to an error: %s', self._rpc, error)

    @property
    def call_id(self) -> int:
        return self._rpc.call_id

    @property
    def method(self) -> Method:
        return self._rpc.method

    def completed(self) -> bool:
        """True if the RPC call has completed, successfully or from an error."""
        return self.status is not None or self.error is not None

    def _send_client_stream(
        self, request_proto: Message | None, request_fields: dict
    ) -> None:
        """Sends a request to the server in the client stream.

        The request is built by the method's get_request() from the message
        and/or keyword fields.

        Raises:
          The stored callback exception or RpcError for a failed call, and
          RpcError(FAILED_PRECONDITION) if the RPC has already completed.
        """
        self._check_errors()

        if self.status is not None:
            raise RpcError(self._rpc, Status.FAILED_PRECONDITION)

        self._rpcs.send_client_stream(
            self._rpc, self.method.get_request(request_proto, request_fields)
        )

    def _finish_client_stream(self, requests: Iterable[Message]) -> None:
        """Sends any final requests, then ends the client stream if active."""
        for request in requests:
            self._send_client_stream(request, {})

        if not self.completed():
            self._rpcs.send_client_stream_end(self._rpc)

    def _unary_wait(self, timeout_s: OptionalTimeout) -> UnaryResponse:
        """Waits until the RPC has completed; returns status and last reply."""
        for _ in self._get_responses(timeout_s=timeout_s):
            pass

        assert self.status is not None and self._responses
        return UnaryResponse(self.status, self._responses[-1])

    def _stream_wait(self, timeout_s: OptionalTimeout) -> StreamResponse:
        """Waits until the RPC has completed; returns status and all replies."""
        for _ in self._get_responses(timeout_s=timeout_s):
            pass

        assert self.status is not None
        return StreamResponse(self.status, self._responses)

    def _get_responses(
        self, *, count: int | None = None, timeout_s: OptionalTimeout
    ) -> Iterator:
        """Returns an iterator of stream responses.

        Args:
          count: Responses to read before returning; None reads all
          timeout_s: max time in seconds to wait between responses; 0 doesn't
              block, None blocks indefinitely, UseDefault.VALUE uses
              default_timeout_s

        Raises:
          RpcTimeout: no response arrived within timeout_s
          RpcError: the call failed or was cancelled
          RuntimeError: a user callback raised (original exception chained)
        """
        self._check_errors()

        if self.completed() and self._response_queue.empty():
            return

        if timeout_s is UseDefault.VALUE:
            timeout_s = self.default_timeout_s

        remaining = math.inf if count is None else count

        try:
            while remaining:
                response = self._response_queue.get(True, timeout_s)

                self._check_errors()

                # None is the end-of-call sentinel.
                if response is None:
                    return

                yield response
                remaining -= 1
        except queue.Empty:
            raise RpcTimeout(self._rpc, timeout_s)

    def cancel(self) -> bool:
        """Cancels the RPC; returns whether the RPC was active.

        Any thread blocked waiting for this call's responses is woken so that
        it raises instead of waiting indefinitely.
        """
        if self.completed():
            return False

        self.error = Status.CANCELLED
        # Wake readers blocked in _get_responses(): nothing else may ever be
        # queued for a cancelled call, so a reader with no timeout would
        # otherwise hang. The woken reader sees the error via _check_errors().
        self._response_queue.put(None)
        return self._rpcs.send_cancel(self._rpc)

    def _check_errors(self) -> None:
        """Raises a stored callback exception, or RpcError if the call failed."""
        if self._callback_exception is not None:
            raise self._callback_exception

        if self.error is not None:
            raise RpcError(self._rpc, self.error)

    def _handle_response(self, response: Any) -> None:
        """Records a response, queues it for readers, and calls on_next."""
        # TODO(frolv): These lists could grow very large for persistent
        # streaming RPCs such as logs. The size should be limited.
        self._responses.append(response)
        self._response_queue.put(response)

        self._invoke_callback('on_next', response)

    def _handle_completion(self, status: Status) -> None:
        """Records the final status, wakes readers, and calls on_completed."""
        self.status = status
        self._response_queue.put(None)

        self._invoke_callback('on_completed', status)

    def _handle_error(self, error: Status) -> None:
        """Records the error status, wakes readers, and calls on_error."""
        self.error = error
        self._response_queue.put(None)

        self._invoke_callback('on_error', error)

    def _invoke_callback(self, callback_name: str, arg: Any) -> None:
        """Invokes a user-provided callback function for an RPC event."""

        # Catch and log any exceptions from the user-provided callback so that
        # exceptions don't terminate the thread handling RPC packets.
        callback: Callable[[Call, Any], None] = getattr(self, callback_name)

        try:
            callback(self, arg)
        except Exception as callback_exception:  # pylint: disable=broad-except
            msg = (
                f'The {callback_name} callback ({callback}) for '
                f'{self._rpc} raised an exception'
            )
            _LOG.exception(msg)

            # Stored so the next _check_errors() re-raises it to the caller.
            self._callback_exception = RuntimeError(msg)
            self._callback_exception.__cause__ = callback_exception

    def __enter__(self) -> Call:
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Cancels the RPC if it is still active when the with-block exits.
        self.cancel()

    def __repr__(self) -> str:
        return f'{type(self).__name__}({self.method})'
291
class UnaryCall(Call):
    """Call object for a unary RPC."""

    @property
    def response(self) -> Any:
        """The latest response received, or None if none has arrived."""
        if not self._responses:
            return None
        return self._responses[-1]

    def wait(
        self, timeout_s: OptionalTimeout = UseDefault.VALUE
    ) -> UnaryResponse:
        """Blocks until the RPC completes and returns (status, response)."""
        return self._unary_wait(timeout_s)
303
304
class ServerStreamingCall(Call):
    """Call object for an RPC in which the server streams responses."""

    @property
    def responses(self) -> Sequence:
        """All responses received so far, in arrival order."""
        return self._responses

    def wait(
        self, timeout_s: OptionalTimeout = UseDefault.VALUE
    ) -> StreamResponse:
        """Blocks until the RPC finishes; returns its status and responses."""
        return self._stream_wait(timeout_s)

    def get_responses(
        self,
        *,
        count: int | None = None,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> Iterator:
        """Iterates over responses as they arrive.

        Args:
          count: stop after this many responses; None reads until the RPC ends
          timeout_s: max seconds to wait between responses; 0 doesn't block,
              None blocks indefinitely, UseDefault.VALUE uses the default
        """
        return self._get_responses(count=count, timeout_s=timeout_s)

    def request_completion(self) -> None:
        """Tells the server the client is done, unless the RPC has ended."""
        if self.completed():
            return
        self._rpcs.send_client_stream_end(self._rpc)

    def __iter__(self) -> Iterator:
        return self.get_responses()
332
333
class ClientStreamingCall(Call):
    """Call object for an RPC in which the client streams requests."""

    @property
    def response(self) -> Any:
        """The latest response received, or None if none has arrived."""
        if not self._responses:
            return None
        return self._responses[-1]

    # TODO(hepler): Use / to mark the first arg as positional-only
    #     when Python 3.7 support is no longer required.
    def send(
        self, _rpc_request_proto: Message | None = None, **request_fields
    ) -> None:
        """Sends one request on the client stream.

        The request may be given as a message, as keyword fields, or both.
        """
        self._send_client_stream(_rpc_request_proto, request_fields)

    def finish_and_wait(
        self,
        requests: Iterable[Message] = (),
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> UnaryResponse:
        """Sends any final requests, closes the client stream, and waits.

        Returns the RPC's status together with its response.
        """
        self._finish_client_stream(requests)
        return self._unary_wait(timeout_s)
358
359
class BidirectionalStreamingCall(Call):
    """Call object for an RPC in which both sides stream messages."""

    @property
    def responses(self) -> Sequence:
        """All responses received so far, in arrival order."""
        return self._responses

    # TODO(hepler): Use / to mark the first arg as positional-only
    #     when Python 3.7 support is no longer required.
    def send(
        self, _rpc_request_proto: Message | None = None, **request_fields
    ) -> None:
        """Sends one request on the client stream.

        The request may be given as a message, as keyword fields, or both.
        """
        self._send_client_stream(_rpc_request_proto, request_fields)

    def finish_and_wait(
        self,
        requests: Iterable[Message] = (),
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> StreamResponse:
        """Sends any final requests, closes the client stream, and waits.

        Returns the RPC's status together with every response received.
        """
        self._finish_client_stream(requests)
        return self._stream_wait(timeout_s)

    def get_responses(
        self,
        *,
        count: int | None = None,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> Iterator:
        """Iterates over responses as they arrive.

        Args:
          count: stop after this many responses; None reads until the RPC ends
          timeout_s: max seconds to wait between responses; 0 doesn't block,
              None blocks indefinitely, UseDefault.VALUE uses the default
        """
        return self._get_responses(count=count, timeout_s=timeout_s)

    def __iter__(self) -> Iterator:
        return self.get_responses()
395