# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""The callback-based pw_rpc client implementation."""

import inspect
import logging
import textwrap
from typing import Any, Callable, Dict, Iterable, Optional, Type

from pw_status import Status
from google.protobuf.message import Message

from pw_rpc import client, descriptors
from pw_rpc.client import PendingRpc, PendingRpcs
from pw_rpc.descriptors import Channel, Method, Service

from pw_rpc.callback_client.call import (
    UseDefault,
    OptionalTimeout,
    CallTypeT,
    UnaryResponse,
    StreamResponse,
    Call,
    UnaryCall,
    ServerStreamingCall,
    ClientStreamingCall,
    BidirectionalStreamingCall,
    OnNextCallback,
    OnCompletedCallback,
    OnErrorCallback,
)

_LOG = logging.getLogger(__package__)


class _MethodClient:
    """A method that can be invoked for a particular channel."""

    def __init__(
        self,
        client_impl: 'Impl',
        rpcs: PendingRpcs,
        channel: Channel,
        method: Method,
        default_timeout_s: Optional[float],
    ) -> None:
        """Binds an RPC method to a channel.

        Args:
          client_impl: The Impl that created this method client.
          rpcs: The PendingRpcs object used to start and track calls.
          channel: The channel on which the RPC is invoked.
          method: The RPC method descriptor.
          default_timeout_s: Timeout substituted when a call is started with
              a timeout of UseDefault.VALUE.
        """
        self._impl = client_impl
        self._rpcs = rpcs
        self._rpc = PendingRpc(channel, method.service, method)
        self.default_timeout_s: Optional[float] = default_timeout_s

    @property
    def channel(self) -> Channel:
        return self._rpc.channel

    @property
    def method(self) -> Method:
        return self._rpc.method

    @property
    def service(self) -> Service:
        return self._rpc.service

    @property
    def request(self) -> type:
        """Returns the request proto class."""
        return self.method.request_type

    @property
    def response(self) -> type:
        """Returns the response proto class."""
        return self.method.response_type

    def __repr__(self) -> str:
        return self.help()

    def help(self) -> str:
        """Returns a help message about this RPC.

        The message shows the request fields as arguments, the docstring of
        this client's __call__ method, and its return annotation.
        """
        function_call = self.method.full_name + '('

        docstring = inspect.getdoc(self.__call__)  # type: ignore[operator] # pylint: disable=no-member
        assert docstring is not None

        annotation = inspect.Signature.from_callable(self).return_annotation  # type: ignore[arg-type] # pylint: disable=line-too-long
        if isinstance(annotation, type):
            annotation = annotation.__name__

        # Align each argument after the first under the opening parenthesis.
        arg_sep = f',\n{" " * len(function_call)}'
        return (
            f'{function_call}'
            f'{arg_sep.join(descriptors.field_help(self.method.request_type))})'
            f'\n\n{textwrap.indent(docstring, "  ")}\n\n'
            f'  Returns {annotation}.'
        )

    def _start_call(
        self,
        call_type: Type[CallTypeT],
        request: Optional[Message],
        timeout_s: OptionalTimeout,
        on_next: Optional[OnNextCallback],
        on_completed: Optional[OnCompletedCallback],
        on_error: Optional[OnErrorCallback],
        ignore_errors: bool = False,
    ) -> CallTypeT:
        """Creates the Call object and invokes the RPC using it.

        Args:
          call_type: The Call class to instantiate.
          request: The initial request, or None.
          timeout_s: Call timeout; UseDefault.VALUE selects
              self.default_timeout_s.
          on_next: Callback forwarded to the call.
          on_completed: Callback forwarded to the call.
          on_error: Callback forwarded to the call.
          ignore_errors: Forwarded to Call._invoke. The open() methods pass
              True so that a call object is returned even if the RPC cannot
              be invoked; invoke() methods leave it False.

        Returns:
          The started call object.
        """
        if timeout_s is UseDefault.VALUE:
            timeout_s = self.default_timeout_s

        call = call_type(
            self._rpcs, self._rpc, timeout_s, on_next, on_completed, on_error
        )
        call._invoke(request, ignore_errors)  # pylint: disable=protected-access
        return call

    def _client_streaming_call_type(
        self, base: Type[CallTypeT]
    ) -> Type[CallTypeT]:
        """Creates a client or bidirectional stream call type.

        Applies the signature from the request protobuf to the send method.
        The signature is stored on the function object, so a new send
        function and subclass of base are created for each call.
        """

        def send(
            self, _rpc_request_proto: Optional[Message] = None, **request_fields
        ) -> None:
            # Delegate to the base call type's own send(). The previous code
            # always used ClientStreamingCall.send, even for bidirectional
            # calls.
            base.send(  # type: ignore[attr-defined]
                self, _rpc_request_proto, **request_fields
            )

        # send() does not take a timeout, so don't advertise pw_rpc_timeout_s.
        _apply_protobuf_signature(self.method, send, include_timeout=False)

        return type(
            f'{self.method.name}_{base.__name__}', (base,), dict(send=send)
        )


def _function_docstring(method: Method) -> str:
    """Returns the docstring used for a generated __call__ method."""
    return f'''\
Invokes the {method.full_name} {method.type.sentence_name()} RPC.

This function accepts either the request protobuf fields as keyword arguments or
a request protobuf as a positional argument.
'''


def _update_call_method(method: Method, function: Callable) -> None:
    """Updates the name, docstring, and parameters to match a method."""
    function.__name__ = method.full_name
    function.__doc__ = _function_docstring(method)
    _apply_protobuf_signature(method, function)


def _apply_protobuf_signature(
    method: Method, function: Callable, include_timeout: bool = True
) -> None:
    """Update a function signature to accept proto arguments.

    In order to have good tab completion and help messages, update the function
    signature to accept only keyword arguments for the proto message fields.
    This doesn't actually change the function signature -- it just updates how
    it appears when inspected.

    Args:
      method: The RPC method whose request parameters are advertised.
      function: The function whose __signature__ is replaced. Its first
          parameter (self) is kept.
      include_timeout: Whether to also advertise a keyword-only
          pw_rpc_timeout_s parameter. Defaults to True (the original
          behavior). Pass False for functions that do not accept a timeout.
    """
    sig = inspect.signature(function)

    params = [next(iter(sig.parameters.values()))]  # Get the "self" parameter
    params += method.request_parameters()
    if include_timeout:
        params.append(
            inspect.Parameter(
                'pw_rpc_timeout_s', inspect.Parameter.KEYWORD_ONLY
            )
        )

    function.__signature__ = sig.replace(  # type: ignore[attr-defined]
        parameters=params
    )


class _UnaryMethodClient(_MethodClient):
    """Method client for unary RPCs."""

    def invoke(
        self,
        request: Optional[Message] = None,
        on_next: Optional[OnNextCallback] = None,
        on_completed: Optional[OnCompletedCallback] = None,
        on_error: Optional[OnErrorCallback] = None,
        *,
        request_args: Optional[Dict[str, Any]] = None,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> UnaryCall:
        """Invokes the unary RPC and returns a call object."""
        return self._start_call(
            UnaryCall,
            self.method.get_request(request, request_args),
            timeout_s,
            on_next,
            on_completed,
            on_error,
        )

    def open(
        self,
        request: Optional[Message] = None,
        on_next: Optional[OnNextCallback] = None,
        on_completed: Optional[OnCompletedCallback] = None,
        on_error: Optional[OnErrorCallback] = None,
        *,
        request_args: Optional[Dict[str, Any]] = None,
    ) -> UnaryCall:
        """Returns a call object for the RPC, even if the RPC cannot be invoked.

        Can be used to listen for responses from an RPC server that may not
        yet be available.
        """
        return self._start_call(
            UnaryCall,
            self.method.get_request(request, request_args),
            None,
            on_next,
            on_completed,
            on_error,
            True,
        )


class _ServerStreamingMethodClient(_MethodClient):
    """Method client for server streaming RPCs."""

    def invoke(
        self,
        request: Optional[Message] = None,
        on_next: Optional[OnNextCallback] = None,
        on_completed: Optional[OnCompletedCallback] = None,
        on_error: Optional[OnErrorCallback] = None,
        *,
        request_args: Optional[Dict[str, Any]] = None,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> ServerStreamingCall:
        """Invokes the server streaming RPC and returns a call object."""
        return self._start_call(
            ServerStreamingCall,
            self.method.get_request(request, request_args),
            timeout_s,
            on_next,
            on_completed,
            on_error,
        )

    def open(
        self,
        request: Optional[Message] = None,
        on_next: Optional[OnNextCallback] = None,
        on_completed: Optional[OnCompletedCallback] = None,
        on_error: Optional[OnErrorCallback] = None,
        *,
        request_args: Optional[Dict[str, Any]] = None,
    ) -> ServerStreamingCall:
        """Returns a call object for the RPC, even if the RPC cannot be invoked.

        Can be used to listen for responses from an RPC server that may not
        yet be available.
        """
        return self._start_call(
            ServerStreamingCall,
            self.method.get_request(request, request_args),
            None,
            on_next,
            on_completed,
            on_error,
            True,
        )


class _ClientStreamingMethodClient(_MethodClient):
    """Method client for client streaming RPCs."""

    def invoke(
        self,
        on_next: Optional[OnNextCallback] = None,
        on_completed: Optional[OnCompletedCallback] = None,
        on_error: Optional[OnErrorCallback] = None,
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> ClientStreamingCall:
        """Invokes the client streaming RPC and returns a call object."""
        # Like every other invoke(), errors are not ignored here. Previously
        # this passed ignore_errors=True, which made invoke() swallow send
        # errors the same way open() does.
        return self._start_call(
            self._client_streaming_call_type(ClientStreamingCall),
            None,
            timeout_s,
            on_next,
            on_completed,
            on_error,
        )

    def open(
        self,
        on_next: Optional[OnNextCallback] = None,
        on_completed: Optional[OnCompletedCallback] = None,
        on_error: Optional[OnErrorCallback] = None,
    ) -> ClientStreamingCall:
        """Returns a call object for the RPC, even if the RPC cannot be invoked.

        Can be used to listen for responses from an RPC server that may not
        yet be available.
        """
        return self._start_call(
            self._client_streaming_call_type(ClientStreamingCall),
            None,
            None,
            on_next,
            on_completed,
            on_error,
            True,
        )

    def __call__(
        self,
        requests: Iterable[Message] = (),
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> UnaryResponse:
        """Invokes the client streaming RPC and waits for it to finish.

        Starts the RPC with invoke(), passes requests and timeout_s to the
        call's finish_and_wait(), and returns its result.
        """
        # Without this docstring, inspect.getdoc() (used by help()) falls back
        # to type.__call__'s "Call self as a function.".
        return self.invoke().finish_and_wait(requests, timeout_s=timeout_s)


class _BidirectionalStreamingMethodClient(_MethodClient):
    """Method client for bidirectional streaming RPCs."""

    def invoke(
        self,
        on_next: Optional[OnNextCallback] = None,
        on_completed: Optional[OnCompletedCallback] = None,
        on_error: Optional[OnErrorCallback] = None,
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> BidirectionalStreamingCall:
        """Invokes the bidirectional streaming RPC and returns a call object."""
        return self._start_call(
            self._client_streaming_call_type(BidirectionalStreamingCall),
            None,
            timeout_s,
            on_next,
            on_completed,
            on_error,
        )

    def open(
        self,
        on_next: Optional[OnNextCallback] = None,
        on_completed: Optional[OnCompletedCallback] = None,
        on_error: Optional[OnErrorCallback] = None,
    ) -> BidirectionalStreamingCall:
        """Returns a call object for the RPC, even if the RPC cannot be invoked.

        Can be used to listen for responses from an RPC server that may not
        yet be available.
        """
        return self._start_call(
            self._client_streaming_call_type(BidirectionalStreamingCall),
            None,
            None,
            on_next,
            on_completed,
            on_error,
            True,
        )

    def __call__(
        self,
        requests: Iterable[Message] = (),
        *,
        timeout_s: OptionalTimeout = UseDefault.VALUE,
    ) -> StreamResponse:
        """Invokes the bidirectional streaming RPC and waits for it to finish.

        Starts the RPC with invoke(), passes requests and timeout_s to the
        call's finish_and_wait(), and returns its result.
        """
        return self.invoke().finish_and_wait(requests, timeout_s=timeout_s)


def _method_client_docstring(method: Method) -> str:
    """Returns the class docstring for a generated method client type."""
    return f'''\
Class that invokes the {method.full_name} {method.type.sentence_name()} RPC.

Calling this directly invokes the RPC synchronously. The RPC can be invoked
asynchronously using the invoke method.
'''


class Impl(client.ClientImpl):
    """Callback-based ClientImpl, for use with pw_rpc.Client."""

    def __init__(
        self,
        default_unary_timeout_s: Optional[float] = None,
        default_stream_timeout_s: Optional[float] = None,
        cancel_duplicate_calls: Optional[bool] = True,
    ) -> None:
        """Creates the callback-based client implementation.

        Args:
          default_unary_timeout_s: Default timeout for unary and client
              streaming RPCs.
          default_stream_timeout_s: Default timeout for server streaming and
              bidirectional streaming RPCs.
          cancel_duplicate_calls: Temporary option that method_client()
              attaches to the PendingRpcs object; see the TODO below.
        """
        super().__init__()
        self._default_unary_timeout_s = default_unary_timeout_s
        self._default_stream_timeout_s = default_stream_timeout_s

        # Temporary workaround for clients that rely on multiple in-flight
        # instances of an RPC on the same channel, which is not supported.
        # TODO(hepler): Remove this option when clients have updated.
        self._cancel_duplicate_calls = cancel_duplicate_calls

    @property
    def default_unary_timeout_s(self) -> Optional[float]:
        return self._default_unary_timeout_s

    @property
    def default_stream_timeout_s(self) -> Optional[float]:
        return self._default_stream_timeout_s

    def method_client(self, channel: Channel, method: Method) -> _MethodClient:
        """Returns an object that invokes a method using the given channel.

        Raises:
          AssertionError: If the method type is not recognized.
        """

        # Temporarily attach the cancel_duplicate_calls option to the
        # PendingRpcs object.
        # TODO(hepler): Remove this workaround.
        assert self.rpcs
        self.rpcs.cancel_duplicate_calls = (  # type: ignore[attr-defined]
            self._cancel_duplicate_calls
        )

        if method.type is Method.Type.UNARY:
            return self._create_unary_method_client(
                channel, method, self.default_unary_timeout_s
            )

        if method.type is Method.Type.SERVER_STREAMING:
            return self._create_server_streaming_method_client(
                channel, method, self.default_stream_timeout_s
            )

        if method.type is Method.Type.CLIENT_STREAMING:
            # Client streaming calls return a UnaryResponse (see __call__),
            # presumably why they use the unary default timeout.
            return self._create_method_client(
                _ClientStreamingMethodClient,
                channel,
                method,
                self.default_unary_timeout_s,
            )

        if method.type is Method.Type.BIDIRECTIONAL_STREAMING:
            return self._create_method_client(
                _BidirectionalStreamingMethodClient,
                channel,
                method,
                self.default_stream_timeout_s,
            )

        raise AssertionError(f'Unknown method type {method.type}')

    def _create_method_client(
        self,
        base: type,
        channel: Channel,
        method: Method,
        default_timeout_s: Optional[float],
        **fields,
    ):
        """Creates a _MethodClient derived class customized for the method.

        Args:
          base: The _MethodClient subclass to derive from.
          channel: The channel to bind the method to.
          method: The RPC method descriptor.
          default_timeout_s: Default timeout passed to the new instance.
          **fields: Extra class attributes (e.g. __call__) for the new type.

        Returns:
          An instance of the newly created subclass of base.
        """
        method_client_type = type(
            f'{method.name}{base.__name__}',
            (base,),
            dict(__doc__=_method_client_docstring(method), **fields),
        )
        return method_client_type(
            self, self.rpcs, channel, method, default_timeout_s
        )

    def _create_unary_method_client(
        self,
        channel: Channel,
        method: Method,
        default_timeout_s: Optional[float],
    ) -> _UnaryMethodClient:
        """Creates a _UnaryMethodClient with a customized __call__ method."""

        # TODO(hepler): Use / to mark the first arg as positional-only
        #     when Python 3.7 support is no longer required.
        def call(
            self: _UnaryMethodClient,
            _rpc_request_proto: Optional[Message] = None,
            *,
            pw_rpc_timeout_s: OptionalTimeout = UseDefault.VALUE,
            **request_fields,
        ) -> UnaryResponse:
            return self.invoke(
                self.method.get_request(_rpc_request_proto, request_fields)
            ).wait(pw_rpc_timeout_s)

        _update_call_method(method, call)
        return self._create_method_client(
            _UnaryMethodClient,
            channel,
            method,
            default_timeout_s,
            __call__=call,
        )

    def _create_server_streaming_method_client(
        self,
        channel: Channel,
        method: Method,
        default_timeout_s: Optional[float],
    ) -> _ServerStreamingMethodClient:
        """Creates _ServerStreamingMethodClient with custom __call__ method."""

        # TODO(hepler): Use / to mark the first arg as positional-only
        #     when Python 3.7 support is no longer required.
        def call(
            self: _ServerStreamingMethodClient,
            _rpc_request_proto: Optional[Message] = None,
            *,
            pw_rpc_timeout_s: OptionalTimeout = UseDefault.VALUE,
            **request_fields,
        ) -> StreamResponse:
            return self.invoke(
                self.method.get_request(_rpc_request_proto, request_fields)
            ).wait(pw_rpc_timeout_s)

        _update_call_method(method, call)
        return self._create_method_client(
            _ServerStreamingMethodClient,
            channel,
            method,
            default_timeout_s,
            __call__=call,
        )

    def handle_response(
        self,
        rpc: PendingRpc,
        context: Call,
        payload,
        *,
        args: tuple = (),
        kwargs: Optional[dict] = None,
    ) -> None:
        """Invokes the callback associated with this RPC."""
        assert not args and not kwargs, 'Forwarding args & kwargs not supported'
        context._handle_response(payload)  # pylint: disable=protected-access

    def handle_completion(
        self,
        rpc: PendingRpc,
        context: Call,
        status: Status,
        *,
        args: tuple = (),
        kwargs: Optional[dict] = None,
    ) -> None:
        """Forwards the RPC's completion status to its call object."""
        assert not args and not kwargs, 'Forwarding args & kwargs not supported'
        context._handle_completion(status)  # pylint: disable=protected-access

    def handle_error(
        self,
        rpc: PendingRpc,
        context: Call,
        status: Status,
        *,
        args: tuple = (),
        kwargs: Optional[dict] = None,
    ) -> None:
        """Forwards the RPC's error status to its call object."""
        assert not args and not kwargs, 'Forwarding args & kwargs not supported'
        context._handle_error(status)  # pylint: disable=protected-access