# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Classes for handling ongoing RPC calls."""

import enum
import logging
import math
import queue
from typing import (
    Any,
    Callable,
    Iterable,
    Iterator,
    NamedTuple,
    Union,
    Optional,
    Sequence,
    TypeVar,
)

from pw_protobuf_compiler.python_protos import proto_repr
from pw_status import Status
from google.protobuf.message import Message

from pw_rpc.callback_client.errors import RpcTimeout, RpcError
from pw_rpc.client import PendingRpc, PendingRpcs
from pw_rpc.descriptors import Method

_LOG = logging.getLogger(__package__)


class UseDefault(enum.Enum):
    """Marker for args that should use a default value, when None is valid."""

    VALUE = 0


CallTypeT = TypeVar(
    'CallTypeT',
    'UnaryCall',
    'ServerStreamingCall',
    'ClientStreamingCall',
    'BidirectionalStreamingCall',
)

# User callbacks are invoked as callback(call, arg); see Call._invoke_callback.
OnNextCallback = Callable[[CallTypeT, Any], Any]
OnCompletedCallback = Callable[[CallTypeT, Any], Any]
OnErrorCallback = Callable[[CallTypeT, Any], Any]

# A timeout in seconds, None to block indefinitely, or UseDefault.VALUE to use
# the call's default_timeout_s.
OptionalTimeout = Union[UseDefault, float, None]


class UnaryResponse(NamedTuple):
    """Result from a unary or client streaming RPC: status and response."""

    status: Status
    response: Any

    def __repr__(self) -> str:
        reply = proto_repr(self.response) if self.response else self.response
        return f'({self.status}, {reply})'


class StreamResponse(NamedTuple):
    """Results from a server or bidirectional streaming RPC."""

    status: Status
    responses: Sequence[Any]

    def __repr__(self) -> str:
        return (
            f'({self.status}, '
            f'[{", ".join(proto_repr(r) for r in self.responses)}])'
        )


class Call:
    """Represents an in-progress or completed RPC call."""

    def __init__(
        self,
        rpcs: PendingRpcs,
        rpc: PendingRpc,
        default_timeout_s: Optional[float],
        on_next: Optional[OnNextCallback],
        on_completed: Optional[OnCompletedCallback],
        on_error: Optional[OnErrorCallback],
    ) -> None:
        """Initializes call state. The request is sent later by _invoke.

        Args:
          rpcs: the PendingRpcs object through which this call's packets
              (request, client stream, cancel) are sent
          rpc: the PendingRpc this call object represents
          default_timeout_s: timeout used when a wait/read method receives
              UseDefault.VALUE; None blocks indefinitely
          on_next: called as on_next(call, response) for each response; if
              None, responses are logged at debug level
          on_completed: called as on_completed(call, status) on completion; if
              None, the completion is logged at info level
          on_error: called as on_error(call, error) when the RPC terminates
              with an error; if None, a warning is logged
        """
        self._rpcs = rpcs
        self._rpc = rpc
        self.default_timeout_s = default_timeout_s

        # Final status once the RPC completes; None while in progress.
        self.status: Optional[Status] = None
        # Set when the RPC terminates with an error or is cancelled.
        self.error: Optional[Status] = None
        # Wrapped exception from a user callback; re-raised by _check_errors.
        self._callback_exception: Optional[Exception] = None
        # Every response received so far, in arrival order.
        self._responses: list = []
        # Feeds blocking readers; None is pushed as an end-of-RPC sentinel.
        self._response_queue: queue.SimpleQueue = queue.SimpleQueue()

        # The defaults are plain functions stored on the instance, so they are
        # not bound methods; _invoke_callback passes the call explicitly as the
        # first argument, the same way it calls user-provided callbacks.
        self.on_next = on_next or Call._default_response
        self.on_completed = on_completed or Call._default_completion
        self.on_error = on_error or Call._default_error

    def _invoke(self, request: Optional[Message], ignore_errors: bool) -> None:
        """Calls the RPC. This must be called immediately after __init__.

        Args:
          request: the request message to send, or None
          ignore_errors: forwarded to PendingRpcs.send_request
        """
        previous = self._rpcs.send_request(
            self._rpc,
            request,
            self,
            ignore_errors=ignore_errors,
            override_pending=True,
        )

        # `previous` is treated as the call that was pending for this RPC
        # before this one. If cancel_duplicate_calls is enabled and that call
        # has not finished, end it with CANCELLED so its readers wake up.
        # TODO(hepler): Remove the cancel_duplicate_calls option.
        if (
            self._rpcs.cancel_duplicate_calls  # type: ignore[attr-defined]
            and previous is not None
            and not previous.completed()
        ):
            previous._handle_error(  # pylint: disable=protected-access
                Status.CANCELLED
            )

    def _default_response(self, response: Message) -> None:
        """Default on_next callback: logs the response at debug level."""
        _LOG.debug('%s received response: %s', self._rpc, response)

    def _default_completion(self, status: Status) -> None:
        """Default on_completed callback: logs the status at info level."""
        _LOG.info('%s completed: %s', self._rpc, status)

    def _default_error(self, error: Status) -> None:
        """Default on_error callback: logs the error as a warning."""
        _LOG.warning('%s terminated due to an error: %s', self._rpc, error)

    @property
    def method(self) -> Method:
        """The method descriptor of the RPC this call represents."""
        return self._rpc.method

    def completed(self) -> bool:
        """True if the RPC call has completed, successfully or from an error."""
        return self.status is not None or self.error is not None

    def _send_client_stream(
        self, request_proto: Optional[Message], request_fields: dict
    ) -> None:
        """Sends a request to the server in the client stream.

        Sending a client stream packet on a closed RPC raises an exception.

        Args:
          request_proto: a request message, or None to build one from fields
          request_fields: field values passed to Method.get_request

        Raises:
          RpcError: FAILED_PRECONDITION if the RPC already completed, or the
              RPC's error if it terminated with one
        """
        # Surfaces earlier callback exceptions and RPC errors first.
        self._check_errors()

        if self.status is not None:
            raise RpcError(self._rpc, Status.FAILED_PRECONDITION)

        self._rpcs.send_client_stream(
            self._rpc, self.method.get_request(request_proto, request_fields)
        )

    def _finish_client_stream(self, requests: Iterable[Message]) -> None:
        """Sends any final requests, then ends the client stream if open."""
        for request in requests:
            self._send_client_stream(request, {})

        if not self.completed():
            self._rpcs.send_client_stream_end(self._rpc)

    def _unary_wait(self, timeout_s: OptionalTimeout) -> UnaryResponse:
        """Waits until the RPC has completed; returns status and last response.

        Raises:
          RpcTimeout: if no packet arrives within timeout_s
          RpcError: if the RPC terminated with an error
        """
        for _ in self._get_responses(timeout_s=timeout_s):
            pass

        assert self.status is not None and self._responses
        return UnaryResponse(self.status, self._responses[-1])

    def _stream_wait(self, timeout_s: OptionalTimeout) -> StreamResponse:
        """Waits until the RPC has completed; returns status and all responses.

        Raises:
          RpcTimeout: if no packet arrives within timeout_s
          RpcError: if the RPC terminated with an error
        """
        for _ in self._get_responses(timeout_s=timeout_s):
            pass

        assert self.status is not None
        return StreamResponse(self.status, self._responses)

    def _get_responses(
        self, *, count: Optional[int] = None, timeout_s: OptionalTimeout
    ) -> Iterator:
        """Returns an iterator of stream responses.

        Args:
          count: Responses to read before returning; None reads all
          timeout_s: max time in seconds to wait between responses; 0 doesn't
            block, None blocks indefinitely

        Raises:
          RpcTimeout: if no response arrives within timeout_s
          RpcError: if the RPC terminated with an error or was cancelled
        """
        self._check_errors()

        # Nothing left to read: the RPC ended and the queue is drained.
        if self.completed() and self._response_queue.empty():
            return

        if timeout_s is UseDefault.VALUE:
            timeout_s = self.default_timeout_s

        remaining = math.inf if count is None else count

        try:
            while remaining:
                response = self._response_queue.get(True, timeout_s)

                # Errors are checked before the sentinel so that an error or
                # cancellation raises instead of ending iteration quietly.
                self._check_errors()

                if response is None:  # End-of-RPC sentinel.
                    return

                yield response
                remaining -= 1
        except queue.Empty:
            raise RpcTimeout(self._rpc, timeout_s)

    def cancel(self) -> bool:
        """Cancels the RPC; returns whether the RPC was active.

        Threads blocked reading responses from this call are woken and raise
        RpcError(CANCELLED) instead of waiting out their timeout.
        """
        if self.completed():
            return False

        self.error = Status.CANCELLED
        # Like _handle_error, push the end sentinel so any reader blocked in
        # _get_responses wakes up and raises via _check_errors. Without this,
        # a reader with no timeout would block forever after a cancel.
        self._response_queue.put(None)
        return self._rpcs.send_cancel(self._rpc)

    def _check_errors(self) -> None:
        """Raises a pending callback exception or the RPC's error, if any."""
        if self._callback_exception:
            raise self._callback_exception

        if self.error:
            raise RpcError(self._rpc, self.error)

    def _handle_response(self, response: Any) -> None:
        """Records a response, queues it for readers, and calls on_next."""
        # TODO(frolv): These lists could grow very large for persistent
        # streaming RPCs such as logs. The size should be limited.
        self._responses.append(response)
        self._response_queue.put(response)

        self._invoke_callback('on_next', response)

    def _handle_completion(self, status: Status) -> None:
        """Records the final status, wakes readers, and calls on_completed."""
        self.status = status
        self._response_queue.put(None)

        self._invoke_callback('on_completed', status)

    def _handle_error(self, error: Status) -> None:
        """Records the error, wakes readers, and calls on_error."""
        self.error = error
        self._response_queue.put(None)

        self._invoke_callback('on_error', error)

    def _invoke_callback(self, callback_name: str, arg: Any) -> None:
        """Invokes a user-provided callback function for an RPC event."""

        # Catch and log any exceptions from the user-provided callback so that
        # exceptions don't terminate the thread handling RPC packets. The
        # exception is stored and re-raised to readers by _check_errors.
        callback: Callable[[Call, Any], None] = getattr(self, callback_name)

        try:
            callback(self, arg)
        except Exception as callback_exception:  # pylint: disable=broad-except
            msg = (
                f'The {callback_name} callback ({callback}) for '
                f'{self._rpc} raised an exception'
            )
            _LOG.exception(msg)

            self._callback_exception = RuntimeError(msg)
            self._callback_exception.__cause__ = callback_exception

    def __enter__(self) -> 'Call':
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # cancel() is a no-op returning False if the RPC already completed.
        self.cancel()

    def __repr__(self) -> str:
        return f'{type(self).__name__}({self.method})'


class UnaryCall(Call):
    """Tracks the state of a unary RPC call."""

    @property
    def response(self) -> Any:
        """The most recent response, or None if none was received yet."""
        return self._responses[-1] if self._responses else None

    def wait(
        self, timeout_s: OptionalTimeout = UseDefault.VALUE
    ) -> UnaryResponse:
        """Waits for the RPC to complete; returns its status and response."""
        return self._unary_wait(timeout_s)


class ServerStreamingCall(Call):
    """Tracks the state of a server streaming RPC call."""

    @property
    def responses(self) -> Sequence:
        """All responses received so far, in arrival order."""
        return self._responses

    def wait(
        self, timeout_s: OptionalTimeout = UseDefault.VALUE
    ) -> StreamResponse:
        """Waits for the RPC to complete; returns status and all responses."""
        return self._stream_wait(timeout_s)

    def get_responses(
        self,
        *,
        count: Optional[int] = None,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> Iterator:
        """Iterates over responses as they arrive; see Call._get_responses."""
        return self._get_responses(count=count, timeout_s=timeout_s)

    def __iter__(self) -> Iterator:
        return self.get_responses()


class ClientStreamingCall(Call):
    """Tracks the state of a client streaming RPC call."""

    @property
    def response(self) -> Any:
        """The most recent response, or None if none was received yet."""
        return self._responses[-1] if self._responses else None

    # TODO(hepler): Use / to mark the first arg as positional-only
    #     when Python 3.7 support is no longer required.
    def send(
        self, _rpc_request_proto: Optional[Message] = None, **request_fields
    ) -> None:
        """Sends client stream request to the server."""
        self._send_client_stream(_rpc_request_proto, request_fields)

    def finish_and_wait(
        self,
        requests: Iterable[Message] = (),
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> UnaryResponse:
        """Ends the client stream and waits for the RPC to complete."""
        self._finish_client_stream(requests)
        return self._unary_wait(timeout_s)


class BidirectionalStreamingCall(Call):
    """Tracks the state of a bidirectional streaming RPC call."""

    @property
    def responses(self) -> Sequence:
        """All responses received so far, in arrival order."""
        return self._responses

    # TODO(hepler): Use / to mark the first arg as positional-only
    #     when Python 3.7 support is no longer required.
    def send(
        self, _rpc_request_proto: Optional[Message] = None, **request_fields
    ) -> None:
        """Sends a message to the server in the client stream."""
        self._send_client_stream(_rpc_request_proto, request_fields)

    def finish_and_wait(
        self,
        requests: Iterable[Message] = (),
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> StreamResponse:
        """Ends the client stream and waits for the RPC to complete."""
        self._finish_client_stream(requests)
        return self._stream_wait(timeout_s)

    def get_responses(
        self,
        *,
        count: Optional[int] = None,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> Iterator:
        """Iterates over responses as they arrive; see Call._get_responses."""
        return self._get_responses(count=count, timeout_s=timeout_s)

    def __iter__(self) -> Iterator:
        return self.get_responses()