# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Defines a callback-based RPC ClientImpl to use with pw_rpc.Client.

callback_client.Impl supports invoking RPCs synchronously or asynchronously.
Asynchronous invocations use callbacks.

Synchronous invocations look like a function call:

  status, response = client.channel(1).call.MyService.MyUnary(some_field=123)

  # Streaming calls return an iterable of responses
  for reply in client.channel(1).call.MyService.MyServerStreaming(request):
      pass

Synchronous calls accept the request either as a message object or as kwargs
for the message fields (but not both). The pw_rpc_timeout_s keyword argument
overrides the default timeout for waiting on a response; None waits
indefinitely.

Asynchronous invocations use the invoke method, which takes a request message
object followed by optional response, completion, and error callbacks. Each
callback receives the PendingRpc as its first argument:

  def on_response(rpc, payload):
      print('Response:', rpc, payload)

  def on_completion(rpc, status):
      print('Completed:', rpc, status)

  call = client.channel(1).call.MyService.MyUnary.invoke(
      request, on_response, on_completion)

  call = client.channel(1).call.MyService.MyServerStreaming.invoke(
      request, on_response, on_completion)

Callbacks that are not provided default to logging the event. The returned
call object can cancel the RPC, either explicitly with cancel() or by using it
as a context manager.
"""

import enum
import inspect
import logging
import queue
import textwrap
import threading
from typing import Any, Callable, Iterator, NamedTuple, Union, Optional

from pw_protobuf_compiler.python_protos import proto_repr
from pw_status import Status

from pw_rpc import client, descriptors
from pw_rpc.client import PendingRpc, PendingRpcs
from pw_rpc.descriptors import Channel, Method, Service

_LOG = logging.getLogger(__name__)


class UseDefault(enum.Enum):
    """Marker for args that should use a default value, when None is valid."""
    VALUE = 0


# A timeout argument: UseDefault.VALUE selects the configured default, a float
# is a timeout in seconds, and None blocks indefinitely.
OptionalTimeout = Union[UseDefault, float, None]

ResponseCallback = Callable[[PendingRpc, Any], Any]
CompletionCallback = Callable[[PendingRpc, Status], Any]
ErrorCallback = Callable[[PendingRpc, Status], Any]


class _Callbacks(NamedTuple):
    """The callbacks registered with PendingRpcs when an RPC is invoked.

    NOTE(review): presumably this tuple is what Impl.handle_* later receive as
    their `context` argument (they read .response/.completion/.error) —
    confirm against pw_rpc.client.
    """
    response: ResponseCallback
    completion: CompletionCallback
    error: ErrorCallback


def _default_response(rpc: PendingRpc, response: Any) -> None:
    _LOG.info('%s response: %s', rpc, response)


def _default_completion(rpc: PendingRpc, status: Status) -> None:
    _LOG.info('%s finished: %s', rpc, status)


def _default_error(rpc: PendingRpc, status: Status) -> None:
    _LOG.error('%s error: %s', rpc, status)


class _MethodClient:
    """A method that can be invoked for a particular channel."""
    def __init__(self, client_impl: 'Impl', rpcs: PendingRpcs,
                 channel: Channel, method: Method,
                 default_timeout_s: Optional[float]):
        self._impl = client_impl
        self._rpcs = rpcs
        self._rpc = PendingRpc(channel, method.service, method)
        # Used by synchronous calls when no explicit timeout is given.
        self.default_timeout_s: Optional[float] = default_timeout_s

    @property
    def channel(self) -> Channel:
        return self._rpc.channel

    @property
    def method(self) -> Method:
        return self._rpc.method

    @property
    def service(self) -> Service:
        return self._rpc.service

    def invoke(self,
               request: Any,
               response: ResponseCallback = _default_response,
               completion: CompletionCallback = _default_completion,
               error: ErrorCallback = _default_error,
               *,
               override_pending: bool = True,
               keep_open: bool = False) -> '_AsyncCall':
        """Invokes an RPC with callbacks.

        Args:
          request: the request message, passed as-is to
              PendingRpcs.send_request (kwargs are not accepted here)
          response: called as response(rpc, payload) for each response
          completion: called as completion(rpc, status) when the RPC completes
          error: called as error(rpc, status) when the RPC fails
          override_pending: forwarded to PendingRpcs.send_request
          keep_open: forwarded to PendingRpcs.send_request

        Returns:
          An _AsyncCall that can be used to cancel the RPC.
        """
        self._rpcs.send_request(self._rpc,
                                request,
                                _Callbacks(response, completion, error),
                                override_pending=override_pending,
                                keep_open=keep_open)
        return _AsyncCall(self._rpcs, self._rpc)

    def __repr__(self) -> str:
        return self.help()

    def __call__(self):
        raise NotImplementedError('Implemented by derived classes')

    def help(self) -> str:
        """Returns a help message about this RPC."""
        function_call = self.method.full_name + '('

        # Fall back to an empty docstring rather than asserting: assert is
        # stripped under -O, and textwrap.indent cannot handle None.
        docstring = inspect.getdoc(self.__call__)
        if docstring is None:
            docstring = ''

        annotation = inspect.Signature.from_callable(self).return_annotation
        if isinstance(annotation, type):
            annotation = annotation.__name__

        # Align each request field under the first one, after the '('.
        arg_sep = f',\n{" " * len(function_call)}'
        return (
            f'{function_call}'
            f'{arg_sep.join(descriptors.field_help(self.method.request_type))})'
            f'\n\n{textwrap.indent(docstring, " ")}\n\n'
            f' Returns {annotation}.')


class RpcTimeout(Exception):
    """Raised when no response is received for an RPC within its timeout."""
    def __init__(self, rpc: PendingRpc, timeout: Optional[float]):
        super().__init__(
            f'No response received for {rpc.method} after {timeout} s')
        self.rpc = rpc
        self.timeout = timeout


class RpcError(Exception):
    """Raised when an RPC terminates with an error status."""
    def __init__(self, rpc: PendingRpc, status: Status):
        if status is Status.NOT_FOUND:
            msg = ': the RPC server does not support this RPC'
        else:
            msg = ''

        super().__init__(f'{rpc.method} failed with error {status}{msg}')
        self.rpc = rpc
        self.status = status


class _AsyncCall:
    """Represents an ongoing callback-based call.

    Exiting a `with` block that holds this object cancels the RPC.
    """

    # TODO(hepler): Consider alternatives (futures) and/or expand functionality.

    def __init__(self, rpcs: PendingRpcs, rpc: PendingRpc):
        self._rpc = rpc
        self._rpcs = rpcs

    def cancel(self) -> bool:
        """Cancels the RPC; returns the result of PendingRpcs.send_cancel.

        NOTE(review): presumably False means the RPC was no longer pending —
        confirm against pw_rpc.client.
        """
        return self._rpcs.send_cancel(self._rpc)

    def __enter__(self) -> '_AsyncCall':
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.cancel()


class StreamingResponses:
    """Iterates over server streaming responses delivered via a SimpleQueue.

    The queue holds response payloads, followed by either a Status (the stream
    completed) or an Exception (the stream failed). After the stream
    completes, the final Status is stored in the `status` attribute.
    """
    def __init__(self, method_client: _MethodClient,
                 responses: queue.SimpleQueue,
                 default_timeout_s: OptionalTimeout):
        self._method_client = method_client
        self._queue = responses
        self.status: Optional[Status] = None

        if default_timeout_s is UseDefault.VALUE:
            self.default_timeout_s = self._method_client.default_timeout_s
        else:
            self.default_timeout_s = default_timeout_s

    @property
    def method(self) -> Method:
        return self._method_client.method

    def cancel(self) -> None:
        """Sends a cancellation for the underlying RPC."""
        self._method_client._rpcs.send_cancel(self._method_client._rpc)  # pylint: disable=protected-access

    def responses(self,
                  *,
                  block: bool = True,
                  timeout_s: OptionalTimeout = UseDefault.VALUE) -> Iterator:
        """Returns an iterator of stream responses.

        Args:
          block: passed to queue.SimpleQueue.get; if False, a missing
              response is treated as a timeout
          timeout_s: timeout in seconds to wait for EACH response; None blocks
              indefinitely

        Raises:
          RpcTimeout: no response arrived within timeout_s; the RPC is
              cancelled first
          RpcError: the RPC failed
        """
        if timeout_s is UseDefault.VALUE:
            timeout_s = self.default_timeout_s

        try:
            while True:
                response = self._queue.get(block, timeout_s)

                if isinstance(response, Exception):
                    raise response

                if isinstance(response, Status):
                    self.status = response
                    return

                yield response
        except queue.Empty:
            self.cancel()
            raise RpcTimeout(self._method_client._rpc, timeout_s)  # pylint: disable=protected-access
        except:  # pylint: disable=bare-except
            # Deliberately bare: this also catches GeneratorExit (the caller
            # stopped iterating early) and KeyboardInterrupt, so the RPC is
            # cancelled in every abnormal exit before re-raising.
            self.cancel()
            raise

    def __iter__(self):
        return self.responses()

    def __repr__(self) -> str:
        return f'{type(self).__name__}({self.method})'


def _method_client_docstring(method: Method) -> str:
    return f'''\
Class that invokes the {method.full_name} {method.type.sentence_name()} RPC.

Calling this directly invokes the RPC synchronously. The RPC can be invoked
asynchronously using the invoke method.
'''


def _function_docstring(method: Method) -> str:
    return f'''\
Invokes the {method.full_name} {method.type.sentence_name()} RPC.

This function accepts either the request protobuf fields as keyword arguments or
a request protobuf as a positional argument.
'''


def _update_function_signature(method: Method, function: Callable) -> None:
    """Updates the name, docstring, and parameters to match a method."""
    function.__name__ = method.full_name
    function.__doc__ = _function_docstring(method)

    # In order to have good tab completion and help messages, update the
    # function signature to accept only keyword arguments for the proto message
    # fields. This doesn't actually change the function signature -- it just
    # updates how it appears when inspected.
    sig = inspect.signature(function)

    params = [next(iter(sig.parameters.values()))]  # Get the "self" parameter
    params += method.request_parameters()
    params.append(
        inspect.Parameter('pw_rpc_timeout_s', inspect.Parameter.KEYWORD_ONLY))
    function.__signature__ = sig.replace(  # type: ignore[attr-defined]
        parameters=params)


class UnaryResponse(NamedTuple):
    """Result of invoking a unary RPC: status and response."""
    status: Status
    response: Any

    def __repr__(self) -> str:
        return f'({self.status}, {proto_repr(self.response)})'


class _UnaryResponseHandler:
    """Tracks the state of an ongoing synchronous unary RPC call."""
    def __init__(self, rpc: PendingRpc):
        self._rpc = rpc
        self._response: Any = None
        self._status: Optional[Status] = None
        self._error: Optional[RpcError] = None
        # Set on completion or error to wake the thread blocked in wait().
        self._event = threading.Event()

    def on_response(self, _: PendingRpc, response: Any) -> None:
        """Stores the response payload (the last one wins)."""
        self._response = response

    def on_completion(self, _: PendingRpc, status: Status) -> None:
        """Records the final status and wakes the waiter."""
        self._status = status
        self._event.set()

    def on_error(self, _: PendingRpc, status: Status) -> None:
        """Records an RpcError for the status and wakes the waiter."""
        self._error = RpcError(self._rpc, status)
        self._event.set()

    def wait(self, timeout_s: Optional[float]) -> UnaryResponse:
        """Blocks until the RPC finishes; None waits indefinitely.

        Raises:
          RpcTimeout: the RPC did not finish within timeout_s
          RpcError: the RPC failed
        """
        if not self._event.wait(timeout_s):
            raise RpcTimeout(self._rpc, timeout_s)

        if self._error is not None:
            raise self._error

        assert self._status is not None
        return UnaryResponse(self._status, self._response)


def _unary_method_client(client_impl: 'Impl', rpcs: PendingRpcs,
                         channel: Channel, method: Method,
                         default_timeout: Optional[float]) -> _MethodClient:
    """Creates an object used to call a unary method."""
    # The docstring and displayed signature of call are replaced by
    # _update_function_signature below.
    def call(self: _MethodClient,
             _rpc_request_proto=None,
             *,
             pw_rpc_timeout_s=UseDefault.VALUE,
             **request_fields) -> UnaryResponse:

        handler = _UnaryResponseHandler(self._rpc)  # pylint: disable=protected-access
        self.invoke(
            self.method.get_request(_rpc_request_proto, request_fields),
            handler.on_response, handler.on_completion, handler.on_error)

        if pw_rpc_timeout_s is UseDefault.VALUE:
            pw_rpc_timeout_s = self.default_timeout_s

        return handler.wait(pw_rpc_timeout_s)

    _update_function_signature(method, call)

    # The MethodClient class is created dynamically so that the __call__ method
    # can be configured differently for each method.
    method_client_type = type(
        f'{method.name}_UnaryMethodClient', (_MethodClient, ),
        dict(__call__=call, __doc__=_method_client_docstring(method)))
    return method_client_type(client_impl, rpcs, channel, method,
                              default_timeout)


def _server_streaming_method_client(client_impl: 'Impl', rpcs: PendingRpcs,
                                    channel: Channel, method: Method,
                                    default_timeout: Optional[float]
                                    ) -> _MethodClient:
    """Creates an object used to call a server streaming method."""
    # The docstring and displayed signature of call are replaced by
    # _update_function_signature below.
    def call(self: _MethodClient,
             _rpc_request_proto=None,
             *,
             pw_rpc_timeout_s=UseDefault.VALUE,
             **request_fields) -> StreamingResponses:
        # Responses, the final Status, or an RpcError are all funneled through
        # one queue, which StreamingResponses drains in order.
        responses: queue.SimpleQueue = queue.SimpleQueue()
        self.invoke(
            self.method.get_request(_rpc_request_proto, request_fields),
            lambda _, response: responses.put(response),
            lambda _, status: responses.put(status),
            lambda rpc, status: responses.put(RpcError(rpc, status)))
        return StreamingResponses(self, responses, pw_rpc_timeout_s)

    _update_function_signature(method, call)

    # The MethodClient class is created dynamically so that the __call__ method
    # can be configured differently for each method type.
    method_client_type = type(
        f'{method.name}_ServerStreamingMethodClient', (_MethodClient, ),
        dict(__call__=call, __doc__=_method_client_docstring(method)))
    return method_client_type(client_impl, rpcs, channel, method,
                              default_timeout)


class ClientStreamingMethodClient(_MethodClient):
    """Placeholder for client streaming RPCs; invoking is not implemented."""
    def __call__(self):
        raise NotImplementedError

    def invoke(self,
               request: Any,
               response: ResponseCallback = _default_response,
               completion: CompletionCallback = _default_completion,
               error: ErrorCallback = _default_error,
               *,
               override_pending: bool = True,
               keep_open: bool = False) -> _AsyncCall:
        raise NotImplementedError


class BidirectionalStreamingMethodClient(_MethodClient):
    """Placeholder for bidirectional streaming RPCs; invoking is not
    implemented."""
    def __call__(self):
        raise NotImplementedError

    def invoke(self,
               request: Any,
               response: ResponseCallback = _default_response,
               completion: CompletionCallback = _default_completion,
               error: ErrorCallback = _default_error,
               *,
               override_pending: bool = True,
               keep_open: bool = False) -> _AsyncCall:
        raise NotImplementedError


class Impl(client.ClientImpl):
    """Callback-based ClientImpl."""
    def __init__(self,
                 default_unary_timeout_s: Optional[float] = 1.0,
                 default_stream_timeout_s: Optional[float] = 1.0):
        super().__init__()
        self._default_unary_timeout_s = default_unary_timeout_s
        self._default_stream_timeout_s = default_stream_timeout_s

    @property
    def default_unary_timeout_s(self) -> Optional[float]:
        return self._default_unary_timeout_s

    @property
    def default_stream_timeout_s(self) -> Optional[float]:
        return self._default_stream_timeout_s

    def method_client(self, channel: Channel, method: Method) -> _MethodClient:
        """Returns an object that invokes a method using the given channel."""

        if method.type is Method.Type.UNARY:
            return _unary_method_client(self, self.rpcs, channel, method,
                                        self.default_unary_timeout_s)

        if method.type is Method.Type.SERVER_STREAMING:
            return _server_streaming_method_client(
                self, self.rpcs, channel, method,
                self.default_stream_timeout_s)

        if method.type is Method.Type.CLIENT_STREAMING:
            # Uses the unary timeout; NOTE(review): presumably because a
            # client stream ends with a single response — confirm.
            return ClientStreamingMethodClient(self, self.rpcs, channel,
                                               method,
                                               self.default_unary_timeout_s)

        if method.type is Method.Type.BIDIRECTIONAL_STREAMING:
            return BidirectionalStreamingMethodClient(
                self, self.rpcs, channel, method,
                self.default_stream_timeout_s)

        raise AssertionError(f'Unknown method type {method.type}')

    def handle_response(self,
                        rpc: PendingRpc,
                        context,
                        payload,
                        *,
                        args: tuple = (),
                        kwargs: Optional[dict] = None) -> None:
        """Invokes the callback associated with this RPC.

        Any additional positional and keyword args passed through
        Client.process_packet are forwarded to the callback. If the response
        callback raises, the RPC is cancelled and the exception is logged.
        """
        if kwargs is None:
            kwargs = {}

        try:
            context.response(rpc, payload, *args, **kwargs)
        except:  # pylint: disable=bare-except
            self.rpcs.send_cancel(rpc)
            _LOG.exception('Response callback %s for %s raised exception',
                           context.response, rpc)

    def handle_completion(self,
                          rpc: PendingRpc,
                          context,
                          status: Status,
                          *,
                          args: tuple = (),
                          kwargs: Optional[dict] = None):
        """Invokes the completion callback; exceptions it raises are logged."""
        if kwargs is None:
            kwargs = {}

        try:
            context.completion(rpc, status, *args, **kwargs)
        except:  # pylint: disable=bare-except
            _LOG.exception('Completion callback %s for %s raised exception',
                           context.completion, rpc)

    def handle_error(self,
                     rpc: PendingRpc,
                     context,
                     status: Status,
                     *,
                     args: tuple = (),
                     kwargs: Optional[dict] = None) -> None:
        """Invokes the error callback; exceptions it raises are logged."""
        if kwargs is None:
            kwargs = {}

        try:
            context.error(rpc, status, *args, **kwargs)
        except:  # pylint: disable=bare-except
            _LOG.exception('Error callback %s for %s raised exception',
                           context.error, rpc)