# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Classes for handling ongoing RPC calls."""

import enum
import logging
import math
import queue
from typing import (Any, Callable, Iterable, Iterator, NamedTuple, Union,
                    Optional, Sequence, TypeVar)

from pw_protobuf_compiler.python_protos import proto_repr
from pw_status import Status
from google.protobuf.message import Message

from pw_rpc.callback_client.errors import RpcTimeout, RpcError
from pw_rpc.client import PendingRpc, PendingRpcs
from pw_rpc.descriptors import Method

_LOG = logging.getLogger(__package__)


class UseDefault(enum.Enum):
    """Marker for args that should use a default value, when None is valid."""
    VALUE = 0


CallType = TypeVar('CallType', 'UnaryCall', 'ServerStreamingCall',
                   'ClientStreamingCall', 'BidirectionalStreamingCall')

# User callbacks are invoked as callback(call, arg) by Call._invoke_callback.
OnNextCallback = Callable[[CallType, Any], Any]
OnCompletedCallback = Callable[[CallType, Any], Any]
OnErrorCallback = Callable[[CallType, Any], Any]

# A timeout in seconds, None to block indefinitely, or UseDefault.VALUE to use
# the call's default_timeout_s.
OptionalTimeout = Union[UseDefault, float, None]


class UnaryResponse(NamedTuple):
    """Result from a unary or client streaming RPC: status and response."""
    status: Status
    response: Any

    def __repr__(self) -> str:
        return f'({self.status}, {proto_repr(self.response)})'


class StreamResponse(NamedTuple):
    """Results from a server or bidirectional streaming RPC."""
    status: Status
    responses: Sequence[Any]

    def __repr__(self) -> str:
        return (f'({self.status}, '
                f'[{", ".join(proto_repr(r) for r in self.responses)}])')


class Call:
    """Represents an in-progress or completed RPC call.

    Each event handled by _handle_response, _handle_completion or
    _handle_error is recorded on this object and pushed onto an internal
    queue, and the matching user callback is invoked. Completion and errors
    push a None sentinel that marks the end of the RPC. The wait /
    get_responses methods on the subclasses consume that queue.

    NOTE(review): the _handle_* methods are presumably called by the RPC
    client as packets arrive, possibly on a different thread than the one
    waiting. Confirm against pw_rpc.client.
    """
    def __init__(self, rpcs: PendingRpcs, rpc: PendingRpc,
                 default_timeout_s: Optional[float],
                 on_next: Optional[OnNextCallback],
                 on_completed: Optional[OnCompletedCallback],
                 on_error: Optional[OnErrorCallback]) -> None:
        self._rpcs = rpcs
        self._rpc = rpc
        self.default_timeout_s = default_timeout_s

        self.status: Optional[Status] = None
        self.error: Optional[Status] = None
        self._callback_exception: Optional[Exception] = None
        self._responses: list = []
        # Holds responses in arrival order; None is the end-of-RPC sentinel.
        self._response_queue: queue.SimpleQueue = queue.SimpleQueue()

        # Callbacks are stored as plain functions on the instance and are
        # called with this Call as the first argument, which is why the
        # unbound Call._default_* functions can be used directly.
        self.on_next = on_next or Call._default_response
        self.on_completed = on_completed or Call._default_completion
        self.on_error = on_error or Call._default_error

    def _invoke(self, request: Optional[Message], ignore_errors: bool) -> None:
        """Calls the RPC. This must be called immediately after __init__."""
        # send_request returns the call context it replaced, if any
        # (presumably an earlier call for the same RPC).
        previous = self._rpcs.send_request(self._rpc,
                                           request,
                                           self,
                                           ignore_errors=ignore_errors,
                                           override_pending=True)

        # TODO(hepler): Remove the cancel_duplicate_calls option.
        # Terminate the replaced call with CANCELLED; _handle_error also
        # releases any thread waiting on it.
        if (self._rpcs.cancel_duplicate_calls and  # type: ignore[attr-defined]
                previous is not None and not previous.completed()):
            previous._handle_error(Status.CANCELLED)  # pylint: disable=protected-access

    def _default_response(self, response: Message) -> None:
        """Default on_next callback: logs the response at DEBUG level."""
        _LOG.debug('%s received response: %s', self._rpc, response)

    def _default_completion(self, status: Status) -> None:
        """Default on_completed callback: logs the status at INFO level."""
        _LOG.info('%s completed: %s', self._rpc, status)

    def _default_error(self, error: Status) -> None:
        """Default on_error callback: logs the error at WARNING level."""
        _LOG.warning('%s terminated due to an error: %s', self._rpc, error)

    @property
    def method(self) -> Method:
        """The method descriptor for this RPC."""
        return self._rpc.method

    def completed(self) -> bool:
        """True if the RPC call has completed, successfully or from an error."""
        return self.status is not None or self.error is not None

    def _send_client_stream(self, request_proto: Optional[Message],
                            request_fields: dict) -> None:
        """Sends a request to the server in the client stream.

        The request is built from request_proto and/or request_fields with
        Method.get_request. Sending a client stream packet on a closed RPC
        raises an exception.

        Raises:
          RuntimeError: a user callback previously raised an exception.
          RpcError: the RPC terminated with an error, or (with
            FAILED_PRECONDITION) it has already completed.
        """
        self._check_errors()

        if self.status is not None:
            raise RpcError(self._rpc, Status.FAILED_PRECONDITION)

        self._rpcs.send_client_stream(
            self._rpc, self.method.get_request(request_proto, request_fields))

    def _finish_client_stream(self, requests: Iterable[Message]) -> None:
        """Sends any remaining requests, then ends the client stream."""
        for request in requests:
            self._send_client_stream(request, {})

        # The RPC may already have finished; only signal the end of the
        # client stream while it is still open.
        if not self.completed():
            self._rpcs.send_client_stream_end(self._rpc)

    def _unary_wait(self, timeout_s: OptionalTimeout) -> UnaryResponse:
        """Waits until the RPC has completed; returns status and last response.

        Raises RpcTimeout, RpcError or RuntimeError as _get_responses does.
        """
        for _ in self._get_responses(timeout_s=timeout_s):
            pass

        assert self.status is not None and self._responses
        return UnaryResponse(self.status, self._responses[-1])

    def _stream_wait(self, timeout_s: OptionalTimeout) -> StreamResponse:
        """Waits until the RPC has completed; returns status and all responses.

        Raises RpcTimeout, RpcError or RuntimeError as _get_responses does.
        """
        for _ in self._get_responses(timeout_s=timeout_s):
            pass

        assert self.status is not None
        return StreamResponse(self.status, self._responses)

    def _get_responses(self,
                       *,
                       count: Optional[int] = None,
                       timeout_s: OptionalTimeout) -> Iterator:
        """Returns an iterator of stream responses.

        Args:
          count: Responses to read before returning; None reads all
          timeout_s: max time in seconds to wait between responses; 0 doesn't
              block, None blocks indefinitely, UseDefault.VALUE uses
              default_timeout_s

        Raises (while iterating):
          RpcTimeout: no response arrived within timeout_s.
          RpcError: the RPC terminated with an error (including CANCELLED).
          RuntimeError: a user callback raised an exception.
        """
        self._check_errors()

        if self.completed() and self._response_queue.empty():
            return

        if timeout_s is UseDefault.VALUE:
            timeout_s = self.default_timeout_s

        remaining = math.inf if count is None else count

        while remaining:
            try:
                response = self._response_queue.get(True, timeout_s)
            except queue.Empty:
                raise RpcTimeout(self._rpc, timeout_s) from None

            # An error or callback exception may have been recorded while this
            # thread was blocked; surface it before yielding anything else.
            self._check_errors()

            if response is None:  # End-of-RPC sentinel.
                return

            yield response
            remaining -= 1

    def cancel(self) -> bool:
        """Cancels the RPC; returns whether the RPC was active.

        Threads blocked waiting for responses are woken and raise RpcError
        with Status.CANCELLED.
        """
        if self.completed():
            return False

        self.error = Status.CANCELLED
        # Wake any waiter, like _handle_error does. The waiter then sees
        # self.error in _check_errors. Without the sentinel, a wait with no
        # timeout would block forever after cancellation.
        self._response_queue.put(None)
        return self._rpcs.send_cancel(self._rpc)

    def _check_errors(self) -> None:
        """Raises a pending callback exception or RPC error, if any."""
        if self._callback_exception:
            raise self._callback_exception

        if self.error:
            raise RpcError(self._rpc, self.error)

    def _handle_response(self, response: Any) -> None:
        """Records a response, queues it for waiters and calls on_next."""
        # TODO(frolv): These lists could grow very large for persistent
        # streaming RPCs such as logs. The size should be limited.
        self._responses.append(response)
        self._response_queue.put(response)

        self._invoke_callback('on_next', response)

    def _handle_completion(self, status: Status) -> None:
        """Records the final status, releases waiters and calls on_completed."""
        self.status = status
        self._response_queue.put(None)

        self._invoke_callback('on_completed', status)

    def _handle_error(self, error: Status) -> None:
        """Records an error, releases waiters and calls on_error."""
        self.error = error
        self._response_queue.put(None)

        self._invoke_callback('on_error', error)

    def _invoke_callback(self, callback_name: str, arg: Any) -> None:
        """Invokes a user-provided callback function for an RPC event.

        If the callback raises, the exception is logged and stored as the
        __cause__ of a RuntimeError. That RuntimeError is raised by the next
        _check_errors call (on waiting, iterating or sending).
        """

        # Catch and log any exceptions from the user-provided callback so that
        # exceptions don't terminate the thread handling RPC packets.
        callback: Callable[[Call, Any], None] = getattr(self, callback_name)

        try:
            callback(self, arg)
        except Exception as callback_exception:  # pylint: disable=broad-except
            msg = (f'The {callback_name} callback ({callback}) for '
                   f'{self._rpc} raised an exception')
            _LOG.exception(msg)

            self._callback_exception = RuntimeError(msg)
            self._callback_exception.__cause__ = callback_exception

    def __enter__(self) -> 'Call':
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Cancels the RPC if it is still active when the block exits."""
        self.cancel()

    def __repr__(self) -> str:
        return f'{type(self).__name__}({self.method})'


class UnaryCall(Call):
    """Tracks the state of a unary RPC call."""
    @property
    def response(self) -> Any:
        """The last response received, or None if none has arrived."""
        return self._responses[-1] if self._responses else None

    def wait(self,
             timeout_s: OptionalTimeout = UseDefault.VALUE) -> UnaryResponse:
        """Waits for the RPC to complete; returns its status and response."""
        return self._unary_wait(timeout_s)


class ServerStreamingCall(Call):
    """Tracks the state of a server streaming RPC call."""
    @property
    def responses(self) -> Sequence:
        """All responses received so far (the live list, not a copy)."""
        return self._responses

    def wait(self,
             timeout_s: OptionalTimeout = UseDefault.VALUE) -> StreamResponse:
        """Waits for the RPC to complete; returns status and all responses."""
        return self._stream_wait(timeout_s)

    def get_responses(
            self,
            *,
            count: Optional[int] = None,
            timeout_s: OptionalTimeout = UseDefault.VALUE) -> Iterator:
        """Iterates over up to count responses (all if None) as they arrive."""
        return self._get_responses(count=count, timeout_s=timeout_s)

    def __iter__(self) -> Iterator:
        return self.get_responses()


class ClientStreamingCall(Call):
    """Tracks the state of a client streaming RPC call."""
    @property
    def response(self) -> Any:
        """The last response received, or None if none has arrived."""
        return self._responses[-1] if self._responses else None

    # TODO(hepler): Use / to mark the first arg as positional-only
    #     when Python 3.7 support is no longer required.
    def send(self,
             _rpc_request_proto: Optional[Message] = None,
             **request_fields) -> None:
        """Sends client stream request to the server."""
        self._send_client_stream(_rpc_request_proto, request_fields)

    def finish_and_wait(
            self,
            requests: Iterable[Message] = (),
            *,
            timeout_s: OptionalTimeout = UseDefault.VALUE) -> UnaryResponse:
        """Ends the client stream and waits for the RPC to complete."""
        self._finish_client_stream(requests)
        return self._unary_wait(timeout_s)


class BidirectionalStreamingCall(Call):
    """Tracks the state of a bidirectional streaming RPC call."""
    @property
    def responses(self) -> Sequence:
        """All responses received so far (the live list, not a copy)."""
        return self._responses

    # TODO(hepler): Use / to mark the first arg as positional-only
    #     when Python 3.7 support is no longer required.
    def send(self,
             _rpc_request_proto: Optional[Message] = None,
             **request_fields) -> None:
        """Sends a message to the server in the client stream."""
        self._send_client_stream(_rpc_request_proto, request_fields)

    def finish_and_wait(
            self,
            requests: Iterable[Message] = (),
            *,
            timeout_s: OptionalTimeout = UseDefault.VALUE) -> StreamResponse:
        """Ends the client stream and waits for the RPC to complete."""
        self._finish_client_stream(requests)
        return self._stream_wait(timeout_s)

    def get_responses(
            self,
            *,
            count: Optional[int] = None,
            timeout_s: OptionalTimeout = UseDefault.VALUE) -> Iterator:
        """Iterates over up to count responses (all if None) as they arrive."""
        return self._get_responses(count=count, timeout_s=timeout_s)

    def __iter__(self) -> Iterator:
        return self.get_responses()