# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Classes for handling ongoing RPC calls."""

from __future__ import annotations

import enum
import logging
import math
import queue
from typing import (
    Any,
    Callable,
    Iterable,
    Iterator,
    NamedTuple,
    Sequence,
    TypeVar,
)

from pw_protobuf_compiler.python_protos import proto_repr
from pw_status import Status
from google.protobuf.message import Message

from pw_rpc.callback_client.errors import RpcTimeout, RpcError
from pw_rpc.client import PendingRpc, PendingRpcs
from pw_rpc.descriptors import Method

_LOG = logging.getLogger(__package__)


class UseDefault(enum.Enum):
    """Sentinel for arguments where None is itself a meaningful value.

    Passing ``UseDefault.VALUE`` asks the callee to substitute its own
    default, which leaves ``None`` free to keep its ordinary meaning.
    """

    VALUE = 0


# The concrete call classes a user callback may receive as its first argument.
CallTypeT = TypeVar(
    'CallTypeT',
    'UnaryCall',
    'ServerStreamingCall',
    'ClientStreamingCall',
    'BidirectionalStreamingCall',
)

# User callback signatures: (call object, event argument) -> ignored result.
OnNextCallback = Callable[[CallTypeT, Any], Any]
OnCompletedCallback = Callable[[CallTypeT, Any], Any]
OnErrorCallback = Callable[[CallTypeT, Any], Any]

# Timeout in seconds, None to wait without limit, or UseDefault.VALUE to use
# the call's configured default.
OptionalTimeout = UseDefault | float | None


class UnaryResponse(NamedTuple):
    """Outcome of a unary or client streaming RPC.

    Pairs the final status reported for the call with its response payload.
    """

    status: Status
    response: Any

    def __repr__(self) -> str:
        # Pretty-print a truthy payload; otherwise show the raw value as-is.
        if self.response:
            reply = proto_repr(self.response)
        else:
            reply = self.response
        return f'({self.status}, {reply})'
class StreamResponse(NamedTuple):
    """Results from a server or bidirectional streaming RPC.

    Pairs the final status reported for the call with the responses received.
    """

    status: Status
    responses: Sequence[Any]

    def __repr__(self) -> str:
        return (
            f'({self.status}, '
            f'[{", ".join(proto_repr(r) for r in self.responses)}])'
        )


class Call:
    """Represents an in-progress or completed RPC call.

    Each received response is appended to ``_responses`` (the full history)
    and also pushed onto ``_response_queue``. Blocking readers such as
    ``wait()`` and ``get_responses()`` consume that queue. A ``None`` entry on
    the queue is a wake-up marker meaning the call has terminated. It is
    pushed on normal completion, on error, and on local cancellation.
    """

    def __init__(
        self,
        rpcs: PendingRpcs,
        rpc: PendingRpc,
        default_timeout_s: float | None,
        on_next: OnNextCallback | None,
        on_completed: OnCompletedCallback | None,
        on_error: OnErrorCallback | None,
    ) -> None:
        """Sets up call state; _invoke() must be called right afterwards.

        Args:
            rpcs: used to send the request, client stream packets, and
                cancellations for this call
            rpc: the pending RPC tracked by this object (provides call_id and
                method)
            default_timeout_s: timeout substituted when a wait is given
                UseDefault.VALUE; None waits without limit
            on_next: called as on_next(call, response); None selects a
                callback that logs at debug level
            on_completed: called as on_completed(call, status); None selects
                a callback that logs at info level
            on_error: called as on_error(call, error); None selects a callback
                that logs a warning
        """
        self._rpcs = rpcs
        self._rpc = rpc
        self.default_timeout_s = default_timeout_s

        # Set when the call completes normally.
        self.status: Status | None = None
        # Set when the call fails or is cancelled locally.
        self.error: Status | None = None
        # Set when a user callback raises; re-raised to readers.
        self._callback_exception: Exception | None = None
        self._responses: list = []
        self._response_queue: queue.SimpleQueue = queue.SimpleQueue()

        # Stored as plain functions on the instance (not bound methods);
        # _invoke_callback passes the call object explicitly.
        self.on_next = on_next or Call._default_response
        self.on_completed = on_completed or Call._default_completion
        self.on_error = on_error or Call._default_error

    def _invoke(self, request: Message | None, ignore_errors: bool) -> None:
        """Calls the RPC. This must be called immediately after __init__."""
        previous = self._rpcs.send_request(
            self._rpc,
            request,
            self,
            ignore_errors=ignore_errors,
            override_pending=True,
        )

        # send_request returns the call previously registered for this RPC,
        # if any (presumably displaced because override_pending=True). When
        # configured, that older call is notified that it was cancelled.
        # TODO(hepler): Remove the cancel_duplicate_calls option.
        if (
            self._rpcs.cancel_duplicate_calls  # type: ignore[attr-defined]
            and previous is not None
            and not previous.completed()
        ):
            previous._handle_error(  # pylint: disable=protected-access
                Status.CANCELLED
            )

    def _default_response(self, response: Message) -> None:
        """Default on_next callback: logs the response at debug level."""
        _LOG.debug('%s received response: %s', self._rpc, response)

    def _default_completion(self, status: Status) -> None:
        """Default on_completed callback: logs the final status."""
        _LOG.info('%s completed: %s', self._rpc, status)

    def _default_error(self, error: Status) -> None:
        """Default on_error callback: logs the error as a warning."""
        _LOG.warning('%s terminated due to an error: %s', self._rpc, error)

    @property
    def call_id(self) -> int:
        return self._rpc.call_id

    @property
    def method(self) -> Method:
        return self._rpc.method

    def completed(self) -> bool:
        """True if the RPC call has completed, successfully or from an error."""
        return self.status is not None or self.error is not None

    def _send_client_stream(
        self, request_proto: Message | None, request_fields: dict
    ) -> None:
        """Sends a request to the server in the client stream.

        The request is built from request_proto and/or request_fields by the
        method's get_request().

        Raises:
            RpcError: the call already failed (its error status), or already
                completed (FAILED_PRECONDITION)
            RuntimeError: a user callback previously raised
        """
        self._check_errors()

        if self.status is not None:
            raise RpcError(self._rpc, Status.FAILED_PRECONDITION)

        self._rpcs.send_client_stream(
            self._rpc, self.method.get_request(request_proto, request_fields)
        )

    def _finish_client_stream(self, requests: Iterable[Message]) -> None:
        """Sends any final requests, then ends the client stream if open."""
        for request in requests:
            self._send_client_stream(request, {})

        if not self.completed():
            self._rpcs.send_client_stream_end(self._rpc)

    def _unary_wait(self, timeout_s: OptionalTimeout) -> UnaryResponse:
        """Waits until the RPC has completed; returns status and last response.

        Raises:
            RpcTimeout / RpcError / RuntimeError: see _get_responses
        """
        # Drain the queue; only the termination side effect matters here.
        for _ in self._get_responses(timeout_s=timeout_s):
            pass

        assert self.status is not None and self._responses
        return UnaryResponse(self.status, self._responses[-1])

    def _stream_wait(self, timeout_s: OptionalTimeout) -> StreamResponse:
        """Waits until the RPC has completed; returns status and all responses.

        Raises:
            RpcTimeout / RpcError / RuntimeError: see _get_responses
        """
        for _ in self._get_responses(timeout_s=timeout_s):
            pass

        assert self.status is not None
        return StreamResponse(self.status, self._responses)

    def _get_responses(
        self, *, count: int | None = None, timeout_s: OptionalTimeout
    ) -> Iterator:
        """Returns an iterator of stream responses.

        This is a generator, so nothing is checked until iteration starts.

        Args:
            count: Responses to read before returning; None reads all
            timeout_s: max time in seconds to wait between responses; 0 doesn't
                block, None blocks indefinitely, UseDefault.VALUE uses
                default_timeout_s

        Raises:
            RpcTimeout: no response arrived within timeout_s
            RpcError: the RPC terminated with an error or was cancelled
            RuntimeError: a user callback raised (the original exception is
                chained as __cause__)
        """
        self._check_errors()

        # Already finished, and the termination marker was consumed earlier.
        if self.completed() and self._response_queue.empty():
            return

        if timeout_s is UseDefault.VALUE:
            timeout_s = self.default_timeout_s

        remaining = math.inf if count is None else count

        try:
            while remaining:
                response = self._response_queue.get(True, timeout_s)

                # Errors may have been recorded while this reader was blocked.
                self._check_errors()

                # None is the marker pushed when the call terminates.
                if response is None:
                    return

                yield response
                remaining -= 1
        except queue.Empty:
            # The queue.Empty context adds nothing useful to the traceback.
            raise RpcTimeout(self._rpc, timeout_s) from None

    def cancel(self) -> bool:
        """Cancels the RPC; returns whether the RPC was active.

        Returns False without sending anything if the call already completed.
        Otherwise it does three things:

        * marks the call as CANCELLED;
        * wakes any reader blocked waiting for responses, which then raises
          RpcError;
        * returns the result of sending the cancellation.
        """
        if self.completed():
            return False

        self.error = Status.CANCELLED
        # Push the termination marker so readers blocked in _get_responses()
        # wake up. After dequeuing it they call _check_errors() and raise
        # RpcError(CANCELLED), instead of waiting (forever, with
        # timeout_s=None) for more packets.
        self._response_queue.put(None)
        return self._rpcs.send_cancel(self._rpc)

    def _check_errors(self) -> None:
        """Re-raises a stored callback exception or the call's error status."""
        if self._callback_exception:
            raise self._callback_exception

        if self.error:
            raise RpcError(self._rpc, self.error)

    def _handle_response(self, response: Any) -> None:
        """Records a response, queues it for readers, and invokes on_next."""
        # TODO(frolv): These lists could grow very large for persistent
        # streaming RPCs such as logs. The size should be limited.
        self._responses.append(response)
        self._response_queue.put(response)

        self._invoke_callback('on_next', response)

    def _handle_completion(self, status: Status) -> None:
        """Records the final status, wakes readers, and invokes on_completed."""
        self.status = status
        self._response_queue.put(None)

        self._invoke_callback('on_completed', status)

    def _handle_error(self, error: Status) -> None:
        """Records the error, wakes readers, and invokes on_error."""
        self.error = error
        self._response_queue.put(None)

        self._invoke_callback('on_error', error)

    def _invoke_callback(self, callback_name: str, arg: Any) -> None:
        """Invokes a user-provided callback function for an RPC event."""

        # Catch and log any exceptions from the user-provided callback so that
        # exceptions don't terminate the thread handling RPC packets. The
        # failure is stored and re-raised to readers by _check_errors().
        callback: Callable[[Call, Any], None] = getattr(self, callback_name)

        try:
            callback(self, arg)
        except Exception as callback_exception:  # pylint: disable=broad-except
            msg = (
                f'The {callback_name} callback ({callback}) for '
                f'{self._rpc} raised an exception'
            )
            _LOG.exception(msg)

            self._callback_exception = RuntimeError(msg)
            self._callback_exception.__cause__ = callback_exception

    def __enter__(self) -> Call:
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Cancel the RPC if it is still active; exceptions are not suppressed.
        self.cancel()

    def __repr__(self) -> str:
        return f'{type(self).__name__}({self.method})'


class UnaryCall(Call):
    """Tracks the state of a unary RPC call."""

    @property
    def response(self) -> Any:
        """The most recent response, or None if none has arrived yet."""
        return self._responses[-1] if self._responses else None

    def wait(
        self, timeout_s: OptionalTimeout = UseDefault.VALUE
    ) -> UnaryResponse:
        """Blocks until the RPC completes; returns its status and response.

        Raises:
            RpcTimeout: no packet arrived within timeout_s
            RpcError: the RPC failed or was cancelled
        """
        return self._unary_wait(timeout_s)


class ServerStreamingCall(Call):
    """Tracks the state of a server streaming RPC call."""

    @property
    def responses(self) -> Sequence:
        """All responses received so far, in arrival order."""
        return self._responses

    def wait(
        self, timeout_s: OptionalTimeout = UseDefault.VALUE
    ) -> StreamResponse:
        """Blocks until the RPC completes; returns its status and responses.

        Raises:
            RpcTimeout: no packet arrived within timeout_s
            RpcError: the RPC failed or was cancelled
        """
        return self._stream_wait(timeout_s)

    def get_responses(
        self,
        *,
        count: int | None = None,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> Iterator:
        """Iterates over responses as they arrive (up to count, if given)."""
        return self._get_responses(count=count, timeout_s=timeout_s)

    def request_completion(self) -> None:
        """Sends client completion packet to server."""
        if not self.completed():
            self._rpcs.send_client_stream_end(self._rpc)

    def __iter__(self) -> Iterator:
        return self.get_responses()


class ClientStreamingCall(Call):
    """Tracks the state of a client streaming RPC call."""

    @property
    def response(self) -> Any:
        """The most recent response, or None if none has arrived yet."""
        return self._responses[-1] if self._responses else None

    # TODO(hepler): Use / to mark the first arg as positional-only
    # when Python 3.7 support is no longer required.
    def send(
        self, _rpc_request_proto: Message | None = None, **request_fields
    ) -> None:
        """Sends client stream request to the server."""
        self._send_client_stream(_rpc_request_proto, request_fields)

    def finish_and_wait(
        self,
        requests: Iterable[Message] = (),
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> UnaryResponse:
        """Ends the client stream and waits for the RPC to complete."""
        self._finish_client_stream(requests)
        return self._unary_wait(timeout_s)


class BidirectionalStreamingCall(Call):
    """Tracks the state of a bidirectional streaming RPC call."""

    @property
    def responses(self) -> Sequence:
        """All responses received so far, in arrival order."""
        return self._responses

    # TODO(hepler): Use / to mark the first arg as positional-only
    # when Python 3.7 support is no longer required.
    def send(
        self, _rpc_request_proto: Message | None = None, **request_fields
    ) -> None:
        """Sends a message to the server in the client stream."""
        self._send_client_stream(_rpc_request_proto, request_fields)

    def finish_and_wait(
        self,
        requests: Iterable[Message] = (),
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> StreamResponse:
        """Ends the client stream and waits for the RPC to complete."""
        self._finish_client_stream(requests)
        return self._stream_wait(timeout_s)

    def get_responses(
        self,
        *,
        count: int | None = None,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> Iterator:
        """Iterates over responses as they arrive (up to count, if given)."""
        return self._get_responses(count=count, timeout_s=timeout_s)

    def __iter__(self) -> Iterator:
        return self.get_responses()