# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""The callback-based pw_rpc client implementation."""

from __future__ import annotations

import inspect
import logging
import textwrap
from typing import Any, Callable, Iterable, Type

from dataclasses import dataclass
from pw_status import Status
from google.protobuf.message import Message

from pw_rpc import client, descriptors
from pw_rpc.client import PendingRpc, PendingRpcs
from pw_rpc.descriptors import Channel, Method, Service

from pw_rpc.callback_client.call import (
    UseDefault,
    OptionalTimeout,
    CallTypeT,
    UnaryResponse,
    StreamResponse,
    Call,
    UnaryCall,
    ServerStreamingCall,
    ClientStreamingCall,
    BidirectionalStreamingCall,
    OnNextCallback,
    OnCompletedCallback,
    OnErrorCallback,
)

_LOG = logging.getLogger(__package__)


@dataclass(eq=True, frozen=True)
class CallInfo:
    """Describes an RPC method call; passed to ``Impl.on_call_hook``."""

    method: Method

    @property
    def service(self) -> Service:
        """The service that contains the called method."""
        return self.method.service


class _MethodClient:
    """A method that can be invoked for a particular channel."""

    def __init__(
        self,
        client_impl: Impl,
        rpcs: PendingRpcs,
        channel: Channel,
        method: Method,
        default_timeout_s: float | None,
    ) -> None:
        """Binds an RPC method to a channel.

        Args:
            client_impl: The Impl that created this object. Its on_call_hook,
                if set, is called before every call is started.
            rpcs: Allocates call IDs and is handed to every Call created.
            channel: The channel the RPC is sent on.
            method: Descriptor of the RPC method.
            default_timeout_s: Timeout used for calls that request the
                default timeout (``UseDefault.VALUE``).
        """
        self._impl = client_impl
        self._rpcs = rpcs
        self._channel = channel
        self._method = method
        self.default_timeout_s: float | None = default_timeout_s

    @property
    def channel(self) -> Channel:
        """The channel this method client sends on."""
        return self._channel

    @property
    def method(self) -> Method:
        """The RPC method descriptor."""
        return self._method

    @property
    def service(self) -> Service:
        """The service that contains the method."""
        return self._method.service

    @property
    def request(self) -> type:
        """Returns the request proto class."""
        return self.method.request_type

    @property
    def response(self) -> type:
        """Returns the response proto class."""
        return self.method.response_type

    def __repr__(self) -> str:
        return self.help()

    def help(self) -> str:
        """Returns a help message about this RPC.

        The message lists the method's full name and request fields, followed
        by the docstring of this object's ``__call__`` and its return
        annotation.
        """
        function_call = self.method.full_name + '('

        docstring = inspect.getdoc(self.__call__)  # type: ignore[operator] # pylint: disable=no-member
        assert docstring is not None

        annotation = inspect.Signature.from_callable(self).return_annotation  # type: ignore[arg-type] # pylint: disable=line-too-long
        # With postponed annotations the return annotation may already be a
        # string; only real types need to be converted to their names.
        if isinstance(annotation, type):
            annotation = annotation.__name__

        # Continuation lines of the field list align after the "(".
        arg_sep = f',\n{" " * len(function_call)}'
        return (
            f'{function_call}'
            f'{arg_sep.join(descriptors.field_help(self.method.request_type))})'
            f'\n\n{textwrap.indent(docstring, " ")}\n\n'
            f' Returns {annotation}.'
        )

    def _start_call(
        self,
        call_type: Type[CallTypeT],
        request: Message | None,
        timeout_s: OptionalTimeout,
        on_next: OnNextCallback | None,
        on_completed: OnCompletedCallback | None,
        on_error: OnErrorCallback | None,
        ignore_errors: bool = False,
    ) -> CallTypeT:
        """Creates the Call object and invokes the RPC using it.

        Args:
            call_type: The Call class to instantiate.
            request: The initial request message, or None if there is none.
            timeout_s: Call timeout; ``UseDefault.VALUE`` selects this
                client's ``default_timeout_s``.
            on_next: Callback forwarded to the Call.
            on_completed: Callback forwarded to the Call.
            on_error: Callback forwarded to the Call.
            ignore_errors: Forwarded to ``Call._invoke``. The ``open()``
                methods pass True so a call object is returned even if the
                RPC cannot be invoked.

        Returns:
            The started call object.
        """
        if timeout_s is UseDefault.VALUE:
            timeout_s = self.default_timeout_s

        if self._impl.on_call_hook:
            self._impl.on_call_hook(CallInfo(self._method))

        rpc = PendingRpc(
            self._channel,
            self.service,
            self.method,
            self._rpcs.allocate_call_id(),
        )
        call = call_type(
            self._rpcs, rpc, timeout_s, on_next, on_completed, on_error
        )
        call._invoke(request, ignore_errors)  # pylint: disable=protected-access
        return call

    def _client_streaming_call_type(
        self, base: Type[CallTypeT]
    ) -> Type[CallTypeT]:
        """Creates a client or bidirectional stream call type.

        Applies the signature from the request protobuf to the send method.
        """

        # No docstring here on purpose, so inspect.getdoc() still finds the
        # base class's send() documentation.
        # NOTE(review): this delegates to ClientStreamingCall.send even when
        # base is BidirectionalStreamingCall; presumably both send()
        # implementations are equivalent -- confirm.
        def send(
            self, _rpc_request_proto: Message | None = None, **request_fields
        ) -> None:
            ClientStreamingCall.send(self, _rpc_request_proto, **request_fields)

        _apply_protobuf_signature(self.method, send)

        return type(
            f'{self.method.name}_{base.__name__}', (base,), dict(send=send)
        )


def _function_docstring(method: Method) -> str:
    """Builds the docstring for a generated synchronous call method."""
    return f'''\
Invokes the {method.full_name} {method.type.sentence_name()} RPC.

This function accepts either the request protobuf fields as keyword arguments or
a request protobuf as a positional argument.
'''


def _update_call_method(method: Method, function: Callable) -> None:
    """Updates the name, docstring, and parameters to match a method."""
    function.__name__ = method.full_name
    function.__doc__ = _function_docstring(method)
    _apply_protobuf_signature(method, function)


def _apply_protobuf_signature(method: Method, function: Callable) -> None:
    """Update a function signature to accept proto arguments.

    In order to have good tab completion and help messages, update the function
    signature to accept only keyword arguments for the proto message fields.
    This doesn't actually change the function signature -- it just updates how
    it appears when inspected.

    The ``pw_rpc_timeout_s`` keyword is only advertised when ``function``
    actually accepts it (the generated client-stream ``send()`` does not).
    """
    sig = inspect.signature(function)

    params = [next(iter(sig.parameters.values()))]  # Get the "self" parameter
    params += method.request_parameters()

    # Advertising an argument the function doesn't take would route it into
    # **request_fields, where it is not a valid proto field.
    if 'pw_rpc_timeout_s' in sig.parameters:
        params.append(
            inspect.Parameter(
                'pw_rpc_timeout_s', inspect.Parameter.KEYWORD_ONLY
            )
        )

    function.__signature__ = sig.replace(  # type: ignore[attr-defined]
        parameters=params
    )


class _UnaryMethodClient(_MethodClient):
    """Method client for unary RPCs."""

    def invoke(
        self,
        request: Message | None = None,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
        *,
        request_args: dict[str, Any] | None = None,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> UnaryCall:
        """Invokes the unary RPC and returns a call object.

        The request is built from ``request`` and/or ``request_args``.
        """
        return self._start_call(
            UnaryCall,
            self.method.get_request(request, request_args),
            timeout_s,
            on_next,
            on_completed,
            on_error,
        )

    def open(
        self,
        request: Message | None = None,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
        *,
        request_args: dict[str, Any] | None = None,
    ) -> UnaryCall:
        """Returns a call object for the RPC, even if the RPC cannot be invoked.

        Can be used to listen for responses from an RPC server that may not yet
        be available. The call has no timeout.
        """
        return self._start_call(
            UnaryCall,
            self.method.get_request(request, request_args),
            None,
            on_next,
            on_completed,
            on_error,
            True,
        )


class _ServerStreamingMethodClient(_MethodClient):
    """Method client for server streaming RPCs."""

    def invoke(
        self,
        request: Message | None = None,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
        *,
        request_args: dict[str, Any] | None = None,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> ServerStreamingCall:
        """Invokes the server streaming RPC and returns a call object."""
        return self._start_call(
            ServerStreamingCall,
            self.method.get_request(request, request_args),
            timeout_s,
            on_next,
            on_completed,
            on_error,
        )

    def open(
        self,
        request: Message | None = None,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
        *,
        request_args: dict[str, Any] | None = None,
    ) -> ServerStreamingCall:
        """Returns a call object for the RPC, even if the RPC cannot be invoked.

        Can be used to listen for responses from an RPC server that may not yet
        be available. The call has no timeout.
        """
        return self._start_call(
            ServerStreamingCall,
            self.method.get_request(request, request_args),
            None,
            on_next,
            on_completed,
            on_error,
            True,
        )


class _ClientStreamingMethodClient(_MethodClient):
    """Method client for client streaming RPCs."""

    def invoke(
        self,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> ClientStreamingCall:
        """Invokes the client streaming RPC and returns a call object.

        Unlike open(), failures to invoke the RPC are not ignored, matching
        the other invoke() methods.
        """
        return self._start_call(
            self._client_streaming_call_type(ClientStreamingCall),
            None,
            timeout_s,
            on_next,
            on_completed,
            on_error,
        )

    def open(
        self,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
    ) -> ClientStreamingCall:
        """Returns a call object for the RPC, even if the RPC cannot be invoked.

        Can be used to listen for responses from an RPC server that may not yet
        be available. The call has no timeout.
        """
        return self._start_call(
            self._client_streaming_call_type(ClientStreamingCall),
            None,
            None,
            on_next,
            on_completed,
            on_error,
            True,
        )

    def __call__(
        self,
        requests: Iterable[Message] = (),
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> UnaryResponse:
        """Invokes the client streaming RPC synchronously.

        Starts the RPC, then passes ``requests`` and ``timeout_s`` to the
        call's ``finish_and_wait`` and returns its result.
        """
        return self.invoke().finish_and_wait(requests, timeout_s=timeout_s)


class _BidirectionalStreamingMethodClient(_MethodClient):
    """Method client for bidirectional streaming RPCs."""

    def invoke(
        self,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> BidirectionalStreamingCall:
        """Invokes the bidirectional streaming RPC and returns a call object."""
        return self._start_call(
            self._client_streaming_call_type(BidirectionalStreamingCall),
            None,
            timeout_s,
            on_next,
            on_completed,
            on_error,
        )

    def open(
        self,
        on_next: OnNextCallback | None = None,
        on_completed: OnCompletedCallback | None = None,
        on_error: OnErrorCallback | None = None,
    ) -> BidirectionalStreamingCall:
        """Returns a call object for the RPC, even if the RPC cannot be invoked.

        Can be used to listen for responses from an RPC server that may not yet
        be available. The call has no timeout.
        """
        return self._start_call(
            self._client_streaming_call_type(BidirectionalStreamingCall),
            None,
            None,
            on_next,
            on_completed,
            on_error,
            True,
        )

    def __call__(
        self,
        requests: Iterable[Message] = (),
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> StreamResponse:
        """Invokes the bidirectional streaming RPC synchronously.

        Starts the RPC, then passes ``requests`` and ``timeout_s`` to the
        call's ``finish_and_wait`` and returns its result.
        """
        return self.invoke().finish_and_wait(requests, timeout_s=timeout_s)


def _method_client_docstring(method: Method) -> str:
    """Builds the class docstring for a generated method client type."""
    return f'''\
Class that invokes the {method.full_name} {method.type.sentence_name()} RPC.

Calling this directly invokes the RPC synchronously. The RPC can be invoked
asynchronously using the invoke method.
'''


class Impl(client.ClientImpl):
    """Callback-based ClientImpl, for use with pw_rpc.Client.

    Args:
        default_unary_timeout_s: Default timeout for unary and client
            streaming RPCs.
        default_stream_timeout_s: Default timeout for server streaming and
            bidirectional streaming RPCs.
        on_call_hook: A callable object to handle RPC method calls.
            If hook is set, it will be called before RPC execution.
        cancel_duplicate_calls: Temporary option; attached to the PendingRpcs
            object when method clients are created.
    """

    def __init__(
        self,
        default_unary_timeout_s: float | None = None,
        default_stream_timeout_s: float | None = None,
        on_call_hook: Callable[[CallInfo], Any] | None = None,
        cancel_duplicate_calls: bool | None = True,
    ) -> None:
        super().__init__()
        self._default_unary_timeout_s = default_unary_timeout_s
        self._default_stream_timeout_s = default_stream_timeout_s
        self.on_call_hook = on_call_hook
        # Temporary workaround for clients that rely on multiple in-flight
        # instances of an RPC on the same channel, which is not supported.
        # TODO(hepler): Remove this option when clients have updated.
        self._cancel_duplicate_calls = cancel_duplicate_calls

    @property
    def default_unary_timeout_s(self) -> float | None:
        """Default timeout for unary and client streaming RPCs."""
        return self._default_unary_timeout_s

    @property
    def default_stream_timeout_s(self) -> float | None:
        """Default timeout for server and bidirectional streaming RPCs."""
        return self._default_stream_timeout_s

    def method_client(self, channel: Channel, method: Method) -> _MethodClient:
        """Returns an object that invokes a method using the given channel.

        Raises:
            AssertionError: If the method type is not recognized.
        """

        # Temporarily attach the cancel_duplicate_calls option to the
        # PendingRpcs object.
        # TODO(hepler): Remove this workaround.
        assert self.rpcs
        self.rpcs.cancel_duplicate_calls = (  # type: ignore[attr-defined]
            self._cancel_duplicate_calls
        )

        if method.type is Method.Type.UNARY:
            return self._create_unary_method_client(
                channel, method, self.default_unary_timeout_s
            )

        if method.type is Method.Type.SERVER_STREAMING:
            return self._create_server_streaming_method_client(
                channel, method, self.default_stream_timeout_s
            )

        if method.type is Method.Type.CLIENT_STREAMING:
            # A client stream ends in a single response, so it uses the
            # unary default timeout.
            return self._create_method_client(
                _ClientStreamingMethodClient,
                channel,
                method,
                self.default_unary_timeout_s,
            )

        if method.type is Method.Type.BIDIRECTIONAL_STREAMING:
            return self._create_method_client(
                _BidirectionalStreamingMethodClient,
                channel,
                method,
                self.default_stream_timeout_s,
            )

        raise AssertionError(f'Unknown method type {method.type}')

    def _create_method_client(
        self,
        base: type,
        channel: Channel,
        method: Method,
        default_timeout_s: float | None,
        **fields,
    ):
        """Creates a _MethodClient derived class customized for the method.

        Args:
            base: The _MethodClient subclass to derive from.
            channel: Channel the RPC is sent on.
            method: The RPC method descriptor.
            default_timeout_s: Default timeout for calls of this method.
            **fields: Extra class attributes, e.g. a custom ``__call__``.

        Returns:
            An instance of the newly created class.
        """
        method_client_type = type(
            f'{method.name}{base.__name__}',
            (base,),
            dict(__doc__=_method_client_docstring(method), **fields),
        )
        return method_client_type(
            self, self.rpcs, channel, method, default_timeout_s
        )

    def _create_unary_method_client(
        self,
        channel: Channel,
        method: Method,
        default_timeout_s: float | None,
    ) -> _UnaryMethodClient:
        """Creates a _UnaryMethodClient with a customized __call__ method."""

        # TODO(hepler): Use / to mark the first arg as positional-only
        #     when Python 3.7 support is no longer required.
        def call(
            self: _UnaryMethodClient,
            _rpc_request_proto: Message | None = None,
            *,
            pw_rpc_timeout_s: OptionalTimeout = UseDefault.VALUE,
            **request_fields,
        ) -> UnaryResponse:
            return self.invoke(
                self.method.get_request(_rpc_request_proto, request_fields)
            ).wait(pw_rpc_timeout_s)

        # Gives call() the method's name, a docstring, and a field-based
        # signature for help() and tab completion.
        _update_call_method(method, call)
        return self._create_method_client(
            _UnaryMethodClient,
            channel,
            method,
            default_timeout_s,
            __call__=call,
        )

    def _create_server_streaming_method_client(
        self,
        channel: Channel,
        method: Method,
        default_timeout_s: float | None,
    ) -> _ServerStreamingMethodClient:
        """Creates _ServerStreamingMethodClient with custom __call__ method."""

        # TODO(hepler): Use / to mark the first arg as positional-only
        #     when Python 3.7 support is no longer required.
        def call(
            self: _ServerStreamingMethodClient,
            _rpc_request_proto: Message | None = None,
            *,
            pw_rpc_timeout_s: OptionalTimeout = UseDefault.VALUE,
            **request_fields,
        ) -> StreamResponse:
            return self.invoke(
                self.method.get_request(_rpc_request_proto, request_fields)
            ).wait(pw_rpc_timeout_s)

        # Gives call() the method's name, a docstring, and a field-based
        # signature for help() and tab completion.
        _update_call_method(method, call)
        return self._create_method_client(
            _ServerStreamingMethodClient,
            channel,
            method,
            default_timeout_s,
            __call__=call,
        )

    def handle_response(
        self,
        rpc: PendingRpc,
        context: Call,
        payload,
        *,
        args: tuple = (),
        kwargs: dict | None = None,
    ) -> None:
        """Invokes the callback associated with this RPC.

        Forwards ``payload`` to the call object given as ``context``.
        """
        assert not args and not kwargs, 'Forwarding args & kwargs not supported'
        context._handle_response(payload)  # pylint: disable=protected-access

    def handle_completion(
        self,
        rpc: PendingRpc,
        context: Call,
        status: Status,
        *,
        args: tuple = (),
        kwargs: dict | None = None,
    ) -> None:
        """Forwards the RPC's completion status to the call object."""
        assert not args and not kwargs, 'Forwarding args & kwargs not supported'
        context._handle_completion(status)  # pylint: disable=protected-access

    def handle_error(
        self,
        rpc: PendingRpc,
        context: Call,
        status: Status,
        *,
        args: tuple = (),
        kwargs: dict | None = None,
    ) -> None:
        """Forwards the RPC's error status to the call object."""
        assert not args and not kwargs, 'Forwarding args & kwargs not supported'
        context._handle_error(status)  # pylint: disable=protected-access