1# Copyright 2016 gRPC authors. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14"""Service-side implementation of gRPC Python.""" 15 16from __future__ import annotations 17 18import abc 19import collections 20from concurrent import futures 21import contextvars 22import enum 23import logging 24import threading 25import time 26import traceback 27from typing import ( 28 Any, 29 Callable, 30 Dict, 31 Iterable, 32 Iterator, 33 List, 34 Mapping, 35 Optional, 36 Sequence, 37 Set, 38 Tuple, 39 Union, 40) 41 42import grpc # pytype: disable=pyi-error 43from grpc import _common # pytype: disable=pyi-error 44from grpc import _compression # pytype: disable=pyi-error 45from grpc import _interceptor # pytype: disable=pyi-error 46from grpc import _observability # pytype: disable=pyi-error 47from grpc._cython import cygrpc 48from grpc._typing import ArityAgnosticMethodHandler 49from grpc._typing import ChannelArgumentType 50from grpc._typing import DeserializingFunction 51from grpc._typing import MetadataType 52from grpc._typing import NullaryCallbackType 53from grpc._typing import ResponseType 54from grpc._typing import SerializingFunction 55from grpc._typing import ServerCallbackTag 56from grpc._typing import ServerTagCallbackType 57 58_LOGGER = logging.getLogger(__name__) 59 60_SHUTDOWN_TAG = "shutdown" 61_REQUEST_CALL_TAG = "request_call" 62 63_RECEIVE_CLOSE_ON_SERVER_TOKEN = "receive_close_on_server" 64_SEND_INITIAL_METADATA_TOKEN = "send_initial_metadata" 65_RECEIVE_MESSAGE_TOKEN = "receive_message" 66_SEND_MESSAGE_TOKEN = "send_message" 67_SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN = ( 68 "send_initial_metadata * send_message" 69) 70_SEND_STATUS_FROM_SERVER_TOKEN = "send_status_from_server" 71_SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN = ( 72 "send_initial_metadata * send_status_from_server" 73) 74 75_OPEN = "open" 76_CLOSED = "closed" 77_CANCELLED = "cancelled" 78 79_EMPTY_FLAGS = 0 80 81_DEALLOCATED_SERVER_CHECK_PERIOD_S = 1.0 82_INF_TIMEOUT = 1e9 83 84 85def _serialized_request(request_event: cygrpc.BaseEvent) -> bytes: 86 return request_event.batch_operations[0].message() 87 88 89def _application_code(code: grpc.StatusCode) -> cygrpc.StatusCode: 90 cygrpc_code = _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE.get(code) 91 return cygrpc.StatusCode.unknown if cygrpc_code is None else cygrpc_code 92 93 94def _completion_code(state: _RPCState) -> cygrpc.StatusCode: 95 if state.code is None: 96 return cygrpc.StatusCode.ok 97 else: 98 return _application_code(state.code) 99 100 101def _abortion_code( 102 state: _RPCState, code: cygrpc.StatusCode 103) -> cygrpc.StatusCode: 104 if state.code is None: 105 return code 106 else: 107 return _application_code(state.code) 108 109 110def _details(state: _RPCState) -> bytes: 111 return b"" if state.details is None else state.details 112 113 114class _HandlerCallDetails( 115 collections.namedtuple( 116 "_HandlerCallDetails", 117 ( 118 "method", 119 "invocation_metadata", 120 ), 121 ), 122 grpc.HandlerCallDetails, 123): 124 pass 125 126 127class _Method(abc.ABC): 128 @abc.abstractmethod 129 def name(self) -> Optional[str]: 130 raise NotImplementedError() 131 132 @abc.abstractmethod 133 def handler( 134 self, handler_call_details: _HandlerCallDetails 135 ) -> Optional[grpc.RpcMethodHandler]: 136 raise NotImplementedError() 137 138 139class _RegisteredMethod(_Method): 140 def __init__( 141 self, 142 name: str, 143 registered_handler: Optional[grpc.RpcMethodHandler], 144 ): 145 self._name = name 146 self._registered_handler = registered_handler 147 148 def name(self) -> Optional[str]: 149 return self._name 150 151 def handler( 152 self, handler_call_details: _HandlerCallDetails 153 ) -> Optional[grpc.RpcMethodHandler]: 154 return self._registered_handler 155 156 157class _GenericMethod(_Method): 158 def __init__( 159 self, 160 generic_handlers: List[grpc.GenericRpcHandler], 161 ): 162 self._generic_handlers = generic_handlers 163 164 def name(self) -> Optional[str]: 165 return None 166 167 def handler( 168 self, handler_call_details: _HandlerCallDetails 169 ) -> Optional[grpc.RpcMethodHandler]: 170 # If the same method have both generic and registered handler, 171 # registered handler will take precedence. 172 for generic_handler in self._generic_handlers: 173 method_handler = generic_handler.service(handler_call_details) 174 if method_handler is not None: 175 return method_handler 176 return None 177 178 179class _RPCState(object): 180 context: contextvars.Context 181 condition: threading.Condition 182 due = Set[str] 183 request: Any 184 client: str 185 initial_metadata_allowed: bool 186 compression_algorithm: Optional[grpc.Compression] 187 disable_next_compression: bool 188 trailing_metadata: Optional[MetadataType] 189 code: Optional[grpc.StatusCode] 190 details: Optional[bytes] 191 statused: bool 192 rpc_errors: List[Exception] 193 callbacks: Optional[List[NullaryCallbackType]] 194 aborted: bool 195 196 def __init__(self): 197 self.context = contextvars.Context() 198 self.condition = threading.Condition() 199 self.due = set() 200 self.request = None 201 self.client = _OPEN 202 self.initial_metadata_allowed = True 203 self.compression_algorithm = None 204 self.disable_next_compression = False 205 self.trailing_metadata = None 206 self.code = None 207 self.details = None 208 self.statused = False 209 self.rpc_errors = [] 210 self.callbacks = [] 211 self.aborted = False 212 213 214def _raise_rpc_error(state: _RPCState) -> None: 215 rpc_error = grpc.RpcError() 216 state.rpc_errors.append(rpc_error) 217 raise rpc_error 218 219 220def _possibly_finish_call( 221 state: _RPCState, token: str 222) -> ServerTagCallbackType: 223 state.due.remove(token) 224 if not _is_rpc_state_active(state) and not state.due: 225 callbacks = state.callbacks 226 state.callbacks = None 227 return state, callbacks 228 else: 229 return None, () 230 231 232def _send_status_from_server(state: _RPCState, token: str) -> ServerCallbackTag: 233 def send_status_from_server(unused_send_status_from_server_event): 234 with state.condition: 235 return _possibly_finish_call(state, token) 236 237 return send_status_from_server 238 239 240def _get_initial_metadata( 241 state: _RPCState, metadata: Optional[MetadataType] 242) -> Optional[MetadataType]: 243 with state.condition: 244 if state.compression_algorithm: 245 compression_metadata = ( 246 _compression.compression_algorithm_to_metadata( 247 state.compression_algorithm 248 ), 249 ) 250 if metadata is None: 251 return compression_metadata 252 else: 253 return compression_metadata + tuple(metadata) 254 else: 255 return metadata 256 257 258def _get_initial_metadata_operation( 259 state: _RPCState, metadata: Optional[MetadataType] 260) -> cygrpc.Operation: 261 operation = cygrpc.SendInitialMetadataOperation( 262 _get_initial_metadata(state, metadata), _EMPTY_FLAGS 263 ) 264 return operation 265 266 267def _abort( 268 state: _RPCState, call: cygrpc.Call, code: cygrpc.StatusCode, details: bytes 269) -> None: 270 if state.client is not _CANCELLED: 271 effective_code = _abortion_code(state, code) 272 effective_details = details if state.details is None else state.details 273 if state.initial_metadata_allowed: 274 operations = ( 275 _get_initial_metadata_operation(state, None), 276 cygrpc.SendStatusFromServerOperation( 277 state.trailing_metadata, 278 effective_code, 279 effective_details, 280 _EMPTY_FLAGS, 281 ), 282 ) 283 token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN 284 else: 285 operations = ( 286 cygrpc.SendStatusFromServerOperation( 287 state.trailing_metadata, 288 effective_code, 289 effective_details, 290 _EMPTY_FLAGS, 291 ), 292 ) 293 token = _SEND_STATUS_FROM_SERVER_TOKEN 294 call.start_server_batch( 295 operations, _send_status_from_server(state, token) 296 ) 297 state.statused = True 298 state.due.add(token) 299 300 301def _receive_close_on_server(state: _RPCState) -> ServerCallbackTag: 302 def receive_close_on_server(receive_close_on_server_event): 303 with state.condition: 304 if receive_close_on_server_event.batch_operations[0].cancelled(): 305 state.client = _CANCELLED 306 elif state.client is _OPEN: 307 state.client = _CLOSED 308 state.condition.notify_all() 309 return _possibly_finish_call(state, _RECEIVE_CLOSE_ON_SERVER_TOKEN) 310 311 return receive_close_on_server 312 313 314def _receive_message( 315 state: _RPCState, 316 call: cygrpc.Call, 317 request_deserializer: Optional[DeserializingFunction], 318) -> ServerCallbackTag: 319 def receive_message(receive_message_event): 320 serialized_request = _serialized_request(receive_message_event) 321 if serialized_request is None: 322 with state.condition: 323 if state.client is _OPEN: 324 state.client = _CLOSED 325 state.condition.notify_all() 326 return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN) 327 else: 328 request = _common.deserialize( 329 serialized_request, request_deserializer 330 ) 331 with state.condition: 332 if request is None: 333 _abort( 334 state, 335 call, 336 cygrpc.StatusCode.internal, 337 b"Exception deserializing request!", 338 ) 339 else: 340 state.request = request 341 state.condition.notify_all() 342 return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN) 343 344 return receive_message 345 346 347def _send_initial_metadata(state: _RPCState) -> ServerCallbackTag: 348 def send_initial_metadata(unused_send_initial_metadata_event): 349 with state.condition: 350 return _possibly_finish_call(state, _SEND_INITIAL_METADATA_TOKEN) 351 352 return send_initial_metadata 353 354 355def _send_message(state: _RPCState, token: str) -> ServerCallbackTag: 356 def send_message(unused_send_message_event): 357 with state.condition: 358 state.condition.notify_all() 359 return _possibly_finish_call(state, token) 360 361 return send_message 362 363 364class _Context(grpc.ServicerContext): 365 _rpc_event: cygrpc.BaseEvent 366 _state: _RPCState 367 request_deserializer: Optional[DeserializingFunction] 368 369 def __init__( 370 self, 371 rpc_event: cygrpc.BaseEvent, 372 state: _RPCState, 373 request_deserializer: Optional[DeserializingFunction], 374 ): 375 self._rpc_event = rpc_event 376 self._state = state 377 self._request_deserializer = request_deserializer 378 379 def is_active(self) -> bool: 380 with self._state.condition: 381 return _is_rpc_state_active(self._state) 382 383 def time_remaining(self) -> float: 384 return max(self._rpc_event.call_details.deadline - time.time(), 0) 385 386 def cancel(self) -> None: 387 self._rpc_event.call.cancel() 388 389 def add_callback(self, callback: NullaryCallbackType) -> bool: 390 with self._state.condition: 391 if self._state.callbacks is None: 392 return False 393 else: 394 self._state.callbacks.append(callback) 395 return True 396 397 def disable_next_message_compression(self) -> None: 398 with self._state.condition: 399 self._state.disable_next_compression = True 400 401 def invocation_metadata(self) -> Optional[MetadataType]: 402 return self._rpc_event.invocation_metadata 403 404 def peer(self) -> str: 405 return _common.decode(self._rpc_event.call.peer()) 406 407 def peer_identities(self) -> Optional[Sequence[bytes]]: 408 return cygrpc.peer_identities(self._rpc_event.call) 409 410 def peer_identity_key(self) -> Optional[str]: 411 id_key = cygrpc.peer_identity_key(self._rpc_event.call) 412 return id_key if id_key is None else _common.decode(id_key) 413 414 def auth_context(self) -> Mapping[str, Sequence[bytes]]: 415 auth_context = cygrpc.auth_context(self._rpc_event.call) 416 auth_context_dict = {} if auth_context is None else auth_context 417 return { 418 _common.decode(key): value 419 for key, value in auth_context_dict.items() 420 } 421 422 def set_compression(self, compression: grpc.Compression) -> None: 423 with self._state.condition: 424 self._state.compression_algorithm = compression 425 426 def send_initial_metadata(self, initial_metadata: MetadataType) -> None: 427 with self._state.condition: 428 if self._state.client is _CANCELLED: 429 _raise_rpc_error(self._state) 430 else: 431 if self._state.initial_metadata_allowed: 432 operation = _get_initial_metadata_operation( 433 self._state, initial_metadata 434 ) 435 self._rpc_event.call.start_server_batch( 436 (operation,), _send_initial_metadata(self._state) 437 ) 438 self._state.initial_metadata_allowed = False 439 self._state.due.add(_SEND_INITIAL_METADATA_TOKEN) 440 else: 441 raise ValueError("Initial metadata no longer allowed!") 442 443 def set_trailing_metadata(self, trailing_metadata: MetadataType) -> None: 444 with self._state.condition: 445 self._state.trailing_metadata = trailing_metadata 446 447 def trailing_metadata(self) -> Optional[MetadataType]: 448 return self._state.trailing_metadata 449 450 def abort(self, code: grpc.StatusCode, details: str) -> None: 451 # treat OK like other invalid arguments: fail the RPC 452 if code == grpc.StatusCode.OK: 453 _LOGGER.error( 454 "abort() called with StatusCode.OK; returning UNKNOWN" 455 ) 456 code = grpc.StatusCode.UNKNOWN 457 details = "" 458 with self._state.condition: 459 self._state.code = code 460 self._state.details = _common.encode(details) 461 self._state.aborted = True 462 raise Exception() 463 464 def abort_with_status(self, status: grpc.Status) -> None: 465 self._state.trailing_metadata = status.trailing_metadata 466 self.abort(status.code, status.details) 467 468 def set_code(self, code: grpc.StatusCode) -> None: 469 with self._state.condition: 470 self._state.code = code 471 472 def code(self) -> grpc.StatusCode: 473 return self._state.code 474 475 def set_details(self, details: str) -> None: 476 with self._state.condition: 477 self._state.details = _common.encode(details) 478 479 def details(self) -> bytes: 480 return self._state.details 481 482 def _finalize_state(self) -> None: 483 pass 484 485 486class _RequestIterator(object): 487 _state: _RPCState 488 _call: cygrpc.Call 489 _request_deserializer: Optional[DeserializingFunction] 490 491 def __init__( 492 self, 493 state: _RPCState, 494 call: cygrpc.Call, 495 request_deserializer: Optional[DeserializingFunction], 496 ): 497 self._state = state 498 self._call = call 499 self._request_deserializer = request_deserializer 500 501 def _raise_or_start_receive_message(self) -> None: 502 if self._state.client is _CANCELLED: 503 _raise_rpc_error(self._state) 504 elif not _is_rpc_state_active(self._state): 505 raise StopIteration() 506 else: 507 self._call.start_server_batch( 508 (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),), 509 _receive_message( 510 self._state, self._call, self._request_deserializer 511 ), 512 ) 513 self._state.due.add(_RECEIVE_MESSAGE_TOKEN) 514 515 def _look_for_request(self) -> Any: 516 if self._state.client is _CANCELLED: 517 _raise_rpc_error(self._state) 518 elif ( 519 self._state.request is None 520 and _RECEIVE_MESSAGE_TOKEN not in self._state.due 521 ): 522 raise StopIteration() 523 else: 524 request = self._state.request 525 self._state.request = None 526 return request 527 528 raise AssertionError() # should never run 529 530 def _next(self) -> Any: 531 with self._state.condition: 532 self._raise_or_start_receive_message() 533 while True: 534 self._state.condition.wait() 535 request = self._look_for_request() 536 if request is not None: 537 return request 538 539 def __iter__(self) -> _RequestIterator: 540 return self 541 542 def __next__(self) -> Any: 543 return self._next() 544 545 def next(self) -> Any: 546 return self._next() 547 548 549def _unary_request( 550 rpc_event: cygrpc.BaseEvent, 551 state: _RPCState, 552 request_deserializer: Optional[DeserializingFunction], 553) -> Callable[[], Any]: 554 def unary_request(): 555 with state.condition: 556 if not _is_rpc_state_active(state): 557 return None 558 else: 559 rpc_event.call.start_server_batch( 560 (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),), 561 _receive_message( 562 state, rpc_event.call, request_deserializer 563 ), 564 ) 565 state.due.add(_RECEIVE_MESSAGE_TOKEN) 566 while True: 567 state.condition.wait() 568 if state.request is None: 569 if state.client is _CLOSED: 570 details = '"{}" requires exactly one request message.'.format( 571 rpc_event.call_details.method 572 ) 573 _abort( 574 state, 575 rpc_event.call, 576 cygrpc.StatusCode.unimplemented, 577 _common.encode(details), 578 ) 579 return None 580 elif state.client is _CANCELLED: 581 return None 582 else: 583 request = state.request 584 state.request = None 585 return request 586 587 return unary_request 588 589 590def _call_behavior( 591 rpc_event: cygrpc.BaseEvent, 592 state: _RPCState, 593 behavior: ArityAgnosticMethodHandler, 594 argument: Any, 595 request_deserializer: Optional[DeserializingFunction], 596 send_response_callback: Optional[Callable[[ResponseType], None]] = None, 597) -> Tuple[Union[ResponseType, Iterator[ResponseType]], bool]: 598 from grpc import _create_servicer_context # pytype: disable=pyi-error 599 600 with _create_servicer_context( 601 rpc_event, state, request_deserializer 602 ) as context: 603 try: 604 response_or_iterator = None 605 if send_response_callback is not None: 606 response_or_iterator = behavior( 607 argument, context, send_response_callback 608 ) 609 else: 610 response_or_iterator = behavior(argument, context) 611 return response_or_iterator, True 612 except Exception as exception: # pylint: disable=broad-except 613 with state.condition: 614 if state.aborted: 615 _abort( 616 state, 617 rpc_event.call, 618 cygrpc.StatusCode.unknown, 619 b"RPC Aborted", 620 ) 621 elif exception not in state.rpc_errors: 622 try: 623 details = "Exception calling application: {}".format( 624 exception 625 ) 626 except Exception: # pylint: disable=broad-except 627 details = ( 628 "Calling application raised unprintable Exception!" 629 ) 630 _LOGGER.exception( 631 traceback.format_exception( 632 type(exception), 633 exception, 634 exception.__traceback__, 635 ) 636 ) 637 traceback.print_exc() 638 _LOGGER.exception(details) 639 _abort( 640 state, 641 rpc_event.call, 642 cygrpc.StatusCode.unknown, 643 _common.encode(details), 644 ) 645 return None, False 646 647 648def _take_response_from_response_iterator( 649 rpc_event: cygrpc.BaseEvent, 650 state: _RPCState, 651 response_iterator: Iterator[ResponseType], 652) -> Tuple[ResponseType, bool]: 653 try: 654 return next(response_iterator), True 655 except StopIteration: 656 return None, True 657 except Exception as exception: # pylint: disable=broad-except 658 with state.condition: 659 if state.aborted: 660 _abort( 661 state, 662 rpc_event.call, 663 cygrpc.StatusCode.unknown, 664 b"RPC Aborted", 665 ) 666 elif exception not in state.rpc_errors: 667 details = "Exception iterating responses: {}".format(exception) 668 _LOGGER.exception(details) 669 _abort( 670 state, 671 rpc_event.call, 672 cygrpc.StatusCode.unknown, 673 _common.encode(details), 674 ) 675 return None, False 676 677 678def _serialize_response( 679 rpc_event: cygrpc.BaseEvent, 680 state: _RPCState, 681 response: Any, 682 response_serializer: Optional[SerializingFunction], 683) -> Optional[bytes]: 684 serialized_response = _common.serialize(response, response_serializer) 685 if serialized_response is None: 686 with state.condition: 687 _abort( 688 state, 689 rpc_event.call, 690 cygrpc.StatusCode.internal, 691 b"Failed to serialize response!", 692 ) 693 return None 694 else: 695 return serialized_response 696 697 698def _get_send_message_op_flags_from_state( 699 state: _RPCState, 700) -> Union[int, cygrpc.WriteFlag]: 701 if state.disable_next_compression: 702 return cygrpc.WriteFlag.no_compress 703 else: 704 return _EMPTY_FLAGS 705 706 707def _reset_per_message_state(state: _RPCState) -> None: 708 with state.condition: 709 state.disable_next_compression = False 710 711 712def _send_response( 713 rpc_event: cygrpc.BaseEvent, state: _RPCState, serialized_response: bytes 714) -> bool: 715 with state.condition: 716 if not _is_rpc_state_active(state): 717 return False 718 else: 719 if state.initial_metadata_allowed: 720 operations = ( 721 _get_initial_metadata_operation(state, None), 722 cygrpc.SendMessageOperation( 723 serialized_response, 724 _get_send_message_op_flags_from_state(state), 725 ), 726 ) 727 state.initial_metadata_allowed = False 728 token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN 729 else: 730 operations = ( 731 cygrpc.SendMessageOperation( 732 serialized_response, 733 _get_send_message_op_flags_from_state(state), 734 ), 735 ) 736 token = _SEND_MESSAGE_TOKEN 737 rpc_event.call.start_server_batch( 738 operations, _send_message(state, token) 739 ) 740 state.due.add(token) 741 _reset_per_message_state(state) 742 while True: 743 state.condition.wait() 744 if token not in state.due: 745 return _is_rpc_state_active(state) 746 747 748def _status( 749 rpc_event: cygrpc.BaseEvent, 750 state: _RPCState, 751 serialized_response: Optional[bytes], 752) -> None: 753 with state.condition: 754 if state.client is not _CANCELLED: 755 code = _completion_code(state) 756 details = _details(state) 757 operations = [ 758 cygrpc.SendStatusFromServerOperation( 759 state.trailing_metadata, code, details, _EMPTY_FLAGS 760 ), 761 ] 762 if state.initial_metadata_allowed: 763 operations.append(_get_initial_metadata_operation(state, None)) 764 if serialized_response is not None: 765 operations.append( 766 cygrpc.SendMessageOperation( 767 serialized_response, 768 _get_send_message_op_flags_from_state(state), 769 ) 770 ) 771 rpc_event.call.start_server_batch( 772 operations, 773 _send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN), 774 ) 775 state.statused = True 776 _reset_per_message_state(state) 777 state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN) 778 779 780def _unary_response_in_pool( 781 rpc_event: cygrpc.BaseEvent, 782 state: _RPCState, 783 behavior: ArityAgnosticMethodHandler, 784 argument_thunk: Callable[[], Any], 785 request_deserializer: Optional[SerializingFunction], 786 response_serializer: Optional[SerializingFunction], 787) -> None: 788 cygrpc.install_context_from_request_call_event(rpc_event) 789 790 try: 791 argument = argument_thunk() 792 if argument is not None: 793 response, proceed = _call_behavior( 794 rpc_event, state, behavior, argument, request_deserializer 795 ) 796 if proceed: 797 serialized_response = _serialize_response( 798 rpc_event, state, response, response_serializer 799 ) 800 if serialized_response is not None: 801 _status(rpc_event, state, serialized_response) 802 except Exception: # pylint: disable=broad-except 803 traceback.print_exc() 804 finally: 805 cygrpc.uninstall_context() 806 807 808def _stream_response_in_pool( 809 rpc_event: cygrpc.BaseEvent, 810 state: _RPCState, 811 behavior: ArityAgnosticMethodHandler, 812 argument_thunk: Callable[[], Any], 813 request_deserializer: Optional[DeserializingFunction], 814 response_serializer: Optional[SerializingFunction], 815) -> None: 816 cygrpc.install_context_from_request_call_event(rpc_event) 817 818 def send_response(response: Any) -> None: 819 if response is None: 820 _status(rpc_event, state, None) 821 else: 822 serialized_response = _serialize_response( 823 rpc_event, state, response, response_serializer 824 ) 825 if serialized_response is not None: 826 _send_response(rpc_event, state, serialized_response) 827 828 try: 829 argument = argument_thunk() 830 if argument is not None: 831 if ( 832 hasattr(behavior, "experimental_non_blocking") 833 and behavior.experimental_non_blocking 834 ): 835 _call_behavior( 836 rpc_event, 837 state, 838 behavior, 839 argument, 840 request_deserializer, 841 send_response_callback=send_response, 842 ) 843 else: 844 response_iterator, proceed = _call_behavior( 845 rpc_event, state, behavior, argument, request_deserializer 846 ) 847 if proceed: 848 _send_message_callback_to_blocking_iterator_adapter( 849 rpc_event, state, send_response, response_iterator 850 ) 851 except Exception: # pylint: disable=broad-except 852 traceback.print_exc() 853 finally: 854 cygrpc.uninstall_context() 855 856 857def _is_rpc_state_active(state: _RPCState) -> bool: 858 return state.client is not _CANCELLED and not state.statused 859 860 861def _send_message_callback_to_blocking_iterator_adapter( 862 rpc_event: cygrpc.BaseEvent, 863 state: _RPCState, 864 send_response_callback: Callable[[ResponseType], None], 865 response_iterator: Iterator[ResponseType], 866) -> None: 867 while True: 868 response, proceed = _take_response_from_response_iterator( 869 rpc_event, state, response_iterator 870 ) 871 if proceed: 872 send_response_callback(response) 873 if not _is_rpc_state_active(state): 874 break 875 else: 876 break 877 878 879def _select_thread_pool_for_behavior( 880 behavior: ArityAgnosticMethodHandler, 881 default_thread_pool: futures.ThreadPoolExecutor, 882) -> futures.ThreadPoolExecutor: 883 if hasattr(behavior, "experimental_thread_pool") and isinstance( 884 behavior.experimental_thread_pool, futures.ThreadPoolExecutor 885 ): 886 return behavior.experimental_thread_pool 887 else: 888 return default_thread_pool 889 890 891def _handle_unary_unary( 892 rpc_event: cygrpc.BaseEvent, 893 state: _RPCState, 894 method_handler: grpc.RpcMethodHandler, 895 default_thread_pool: futures.ThreadPoolExecutor, 896) -> futures.Future: 897 unary_request = _unary_request( 898 rpc_event, state, method_handler.request_deserializer 899 ) 900 thread_pool = _select_thread_pool_for_behavior( 901 method_handler.unary_unary, default_thread_pool 902 ) 903 return thread_pool.submit( 904 state.context.run, 905 _unary_response_in_pool, 906 rpc_event, 907 state, 908 method_handler.unary_unary, 909 unary_request, 910 method_handler.request_deserializer, 911 method_handler.response_serializer, 912 ) 913 914 915def _handle_unary_stream( 916 rpc_event: cygrpc.BaseEvent, 917 state: _RPCState, 918 method_handler: grpc.RpcMethodHandler, 919 default_thread_pool: futures.ThreadPoolExecutor, 920) -> futures.Future: 921 unary_request = _unary_request( 922 rpc_event, state, method_handler.request_deserializer 923 ) 924 thread_pool = _select_thread_pool_for_behavior( 925 method_handler.unary_stream, default_thread_pool 926 ) 927 return thread_pool.submit( 928 state.context.run, 929 _stream_response_in_pool, 930 rpc_event, 931 state, 932 method_handler.unary_stream, 933 unary_request, 934 method_handler.request_deserializer, 935 method_handler.response_serializer, 936 ) 937 938 939def _handle_stream_unary( 940 rpc_event: cygrpc.BaseEvent, 941 state: _RPCState, 942 method_handler: grpc.RpcMethodHandler, 943 default_thread_pool: futures.ThreadPoolExecutor, 944) -> futures.Future: 945 request_iterator = _RequestIterator( 946 state, rpc_event.call, method_handler.request_deserializer 947 ) 948 thread_pool = _select_thread_pool_for_behavior( 949 method_handler.stream_unary, default_thread_pool 950 ) 951 return thread_pool.submit( 952 state.context.run, 953 _unary_response_in_pool, 954 rpc_event, 955 state, 956 method_handler.stream_unary, 957 lambda: request_iterator, 958 method_handler.request_deserializer, 959 method_handler.response_serializer, 960 ) 961 962 963def _handle_stream_stream( 964 rpc_event: cygrpc.BaseEvent, 965 state: _RPCState, 966 method_handler: grpc.RpcMethodHandler, 967 default_thread_pool: futures.ThreadPoolExecutor, 968) -> futures.Future: 969 request_iterator = _RequestIterator( 970 state, rpc_event.call, method_handler.request_deserializer 971 ) 972 thread_pool = _select_thread_pool_for_behavior( 973 method_handler.stream_stream, default_thread_pool 974 ) 975 return thread_pool.submit( 976 state.context.run, 977 _stream_response_in_pool, 978 rpc_event, 979 state, 980 method_handler.stream_stream, 981 lambda: request_iterator, 982 method_handler.request_deserializer, 983 method_handler.response_serializer, 984 ) 985 986 987def _find_method_handler( 988 rpc_event: cygrpc.BaseEvent, 989 state: _RPCState, 990 method_with_handler: _Method, 991 interceptor_pipeline: Optional[_interceptor._ServicePipeline], 992) -> Optional[grpc.RpcMethodHandler]: 993 def query_handlers( 994 handler_call_details: _HandlerCallDetails, 995 ) -> Optional[grpc.RpcMethodHandler]: 996 return method_with_handler.handler(handler_call_details) 997 998 method_name = method_with_handler.name() 999 if not method_name: 1000 method_name = _common.decode(rpc_event.call_details.method) 1001 1002 handler_call_details = _HandlerCallDetails( 1003 method_name, 1004 rpc_event.invocation_metadata, 1005 ) 1006 1007 if interceptor_pipeline is not None: 1008 return state.context.run( 1009 interceptor_pipeline.execute, query_handlers, handler_call_details 1010 ) 1011 else: 1012 return state.context.run(query_handlers, handler_call_details) 1013 1014 1015def _reject_rpc( 1016 rpc_event: cygrpc.BaseEvent, 1017 rpc_state: _RPCState, 1018 status: cygrpc.StatusCode, 1019 details: bytes, 1020): 1021 operations = ( 1022 _get_initial_metadata_operation(rpc_state, None), 1023 cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS), 1024 cygrpc.SendStatusFromServerOperation( 1025 None, status, details, _EMPTY_FLAGS 1026 ), 1027 ) 1028 rpc_event.call.start_server_batch( 1029 operations, 1030 lambda ignored_event: ( 1031 rpc_state, 1032 (), 1033 ), 1034 ) 1035 1036 1037def _handle_with_method_handler( 1038 rpc_event: cygrpc.BaseEvent, 1039 state: _RPCState, 1040 method_handler: grpc.RpcMethodHandler, 1041 thread_pool: futures.ThreadPoolExecutor, 1042) -> futures.Future: 1043 with state.condition: 1044 rpc_event.call.start_server_batch( 1045 (cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),), 1046 _receive_close_on_server(state), 1047 ) 1048 state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN) 1049 if method_handler.request_streaming: 1050 if method_handler.response_streaming: 1051 return _handle_stream_stream( 1052 rpc_event, state, method_handler, thread_pool 1053 ) 1054 else: 1055 return _handle_stream_unary( 1056 rpc_event, state, method_handler, thread_pool 1057 ) 1058 else: 1059 if method_handler.response_streaming: 1060 return _handle_unary_stream( 1061 rpc_event, state, method_handler, thread_pool 1062 ) 1063 else: 1064 return _handle_unary_unary( 1065 rpc_event, state, method_handler, thread_pool 1066 ) 1067 1068 1069def _handle_call( 1070 rpc_event: cygrpc.BaseEvent, 1071 method_with_handler: _Method, 1072 interceptor_pipeline: Optional[_interceptor._ServicePipeline], 1073 thread_pool: futures.ThreadPoolExecutor, 1074 concurrency_exceeded: bool, 1075) -> Tuple[Optional[_RPCState], Optional[futures.Future]]: 1076 """Handles RPC based on provided handlers. 1077 1078 When receiving a call event from Core, registered method will have its 1079 name as tag, we pass the tag as registered_method_name to this method, 1080 then we can find the handler in registered_method_handlers based on 1081 the method name. 1082 1083 For call event with unregistered method, the method name will be included 1084 in rpc_event.call_details.method and we need to query the generics handlers 1085 to find the actual handler. 1086 """ 1087 if not rpc_event.success: 1088 return None, None 1089 if rpc_event.call_details.method or method_with_handler.name(): 1090 rpc_state = _RPCState() 1091 try: 1092 method_handler = _find_method_handler( 1093 rpc_event, 1094 rpc_state, 1095 method_with_handler, 1096 interceptor_pipeline, 1097 ) 1098 except Exception as exception: # pylint: disable=broad-except 1099 details = "Exception servicing handler: {}".format(exception) 1100 _LOGGER.exception(details) 1101 _reject_rpc( 1102 rpc_event, 1103 rpc_state, 1104 cygrpc.StatusCode.unknown, 1105 b"Error in service handler!", 1106 ) 1107 return rpc_state, None 1108 if method_handler is None: 1109 _reject_rpc( 1110 rpc_event, 1111 rpc_state, 1112 cygrpc.StatusCode.unimplemented, 1113 b"Method not found!", 1114 ) 1115 return rpc_state, None 1116 elif concurrency_exceeded: 1117 _reject_rpc( 1118 rpc_event, 1119 rpc_state, 1120 cygrpc.StatusCode.resource_exhausted, 1121 b"Concurrent RPC limit exceeded!", 1122 ) 1123 return rpc_state, None 1124 else: 1125 return ( 1126 rpc_state, 1127 _handle_with_method_handler( 1128 rpc_event, rpc_state, method_handler, thread_pool 1129 ), 1130 ) 1131 else: 1132 return None, None 1133 1134 1135@enum.unique 1136class _ServerStage(enum.Enum): 1137 STOPPED = "stopped" 1138 STARTED = "started" 1139 GRACE = "grace" 1140 1141 1142class _ServerState(object): 1143 lock: threading.RLock 1144 completion_queue: cygrpc.CompletionQueue 1145 server: cygrpc.Server 1146 generic_handlers: List[grpc.GenericRpcHandler] 1147 registered_method_handlers: Dict[str, grpc.RpcMethodHandler] 1148 interceptor_pipeline: Optional[_interceptor._ServicePipeline] 1149 thread_pool: futures.ThreadPoolExecutor 1150 stage: _ServerStage 1151 termination_event: threading.Event 1152 shutdown_events: List[threading.Event] 1153 maximum_concurrent_rpcs: Optional[int] 1154 active_rpc_count: int 1155 rpc_states: Set[_RPCState] 1156 due: Set[str] 1157 server_deallocated: bool 1158 1159 # pylint: disable=too-many-arguments 1160 def __init__( 1161 self, 1162 completion_queue: cygrpc.CompletionQueue, 1163 server: cygrpc.Server, 1164 generic_handlers: Sequence[grpc.GenericRpcHandler], 1165 interceptor_pipeline: Optional[_interceptor._ServicePipeline], 1166 thread_pool: futures.ThreadPoolExecutor, 1167 maximum_concurrent_rpcs: Optional[int], 1168 ): 1169 self.lock = threading.RLock() 1170 self.completion_queue = completion_queue 1171 self.server = server 1172 self.generic_handlers = list(generic_handlers) 1173 self.interceptor_pipeline = interceptor_pipeline 1174 self.thread_pool = thread_pool 1175 self.stage = _ServerStage.STOPPED 1176 self.termination_event = threading.Event() 1177 self.shutdown_events = [self.termination_event] 1178 self.maximum_concurrent_rpcs = maximum_concurrent_rpcs 1179 self.active_rpc_count = 0 1180 self.registered_method_handlers = {} 1181 1182 # TODO(https://github.com/grpc/grpc/issues/6597): eliminate these fields. 1183 self.rpc_states = set() 1184 self.due = set() 1185 1186 # A "volatile" flag to interrupt the daemon serving thread 1187 self.server_deallocated = False 1188 1189 1190def _add_generic_handlers( 1191 state: _ServerState, generic_handlers: Iterable[grpc.GenericRpcHandler] 1192) -> None: 1193 with state.lock: 1194 state.generic_handlers.extend(generic_handlers) 1195 1196 1197def _add_registered_method_handlers( 1198 state: _ServerState, method_handlers: Dict[str, grpc.RpcMethodHandler] 1199) -> None: 1200 with state.lock: 1201 state.registered_method_handlers.update(method_handlers) 1202 1203 1204def _add_insecure_port(state: _ServerState, address: bytes) -> int: 1205 with state.lock: 1206 return state.server.add_http2_port(address) 1207 1208 1209def _add_secure_port( 1210 state: _ServerState, 1211 address: bytes, 1212 server_credentials: grpc.ServerCredentials, 1213) -> int: 1214 with state.lock: 1215 return state.server.add_http2_port( 1216 address, server_credentials._credentials 1217 ) 1218 1219 1220def _request_call(state: _ServerState) -> None: 1221 state.server.request_call( 1222 state.completion_queue, state.completion_queue, _REQUEST_CALL_TAG 1223 ) 1224 state.due.add(_REQUEST_CALL_TAG) 1225 1226 1227def _request_registered_call(state: _ServerState, method: str) -> None: 1228 registered_call_tag = method 1229 state.server.request_registered_call( 1230 state.completion_queue, 1231 state.completion_queue, 1232 method, 1233 registered_call_tag, 1234 ) 1235 state.due.add(registered_call_tag) 1236 1237 1238# TODO(https://github.com/grpc/grpc/issues/6597): delete this function. 1239def _stop_serving(state: _ServerState) -> bool: 1240 if not state.rpc_states and not state.due: 1241 state.server.destroy() 1242 for shutdown_event in state.shutdown_events: 1243 shutdown_event.set() 1244 state.stage = _ServerStage.STOPPED 1245 return True 1246 else: 1247 return False 1248 1249 1250def _on_call_completed(state: _ServerState) -> None: 1251 with state.lock: 1252 state.active_rpc_count -= 1 1253 1254 1255# pylint: disable=too-many-branches 1256def _process_event_and_continue( 1257 state: _ServerState, event: cygrpc.BaseEvent 1258) -> bool: 1259 should_continue = True 1260 if event.tag is _SHUTDOWN_TAG: 1261 with state.lock: 1262 state.due.remove(_SHUTDOWN_TAG) 1263 if _stop_serving(state): 1264 should_continue = False 1265 elif ( 1266 event.tag is _REQUEST_CALL_TAG 1267 or event.tag in state.registered_method_handlers.keys() 1268 ): 1269 registered_method_name = None 1270 if event.tag in state.registered_method_handlers.keys(): 1271 registered_method_name = event.tag 1272 method_with_handler = _RegisteredMethod( 1273 registered_method_name, 1274 state.registered_method_handlers.get( 1275 registered_method_name, None 1276 ), 1277 ) 1278 else: 1279 method_with_handler = _GenericMethod( 1280 state.generic_handlers, 1281 ) 1282 with state.lock: 1283 state.due.remove(event.tag) 1284 concurrency_exceeded = ( 1285 state.maximum_concurrent_rpcs is not None 1286 and state.active_rpc_count >= state.maximum_concurrent_rpcs 1287 ) 1288 rpc_state, rpc_future = _handle_call( 1289 event, 1290 method_with_handler, 1291 state.interceptor_pipeline, 1292 state.thread_pool, 1293 concurrency_exceeded, 1294 ) 1295 if rpc_state is not None: 1296 state.rpc_states.add(rpc_state) 1297 if rpc_future is not None: 1298 state.active_rpc_count += 1 1299 rpc_future.add_done_callback( 1300 lambda unused_future: _on_call_completed(state) 1301 ) 1302 if state.stage is _ServerStage.STARTED: 1303 if ( 1304 registered_method_name 1305 in state.registered_method_handlers.keys() 1306 ): 1307 _request_registered_call(state, registered_method_name) 1308 else: 1309 _request_call(state) 1310 elif _stop_serving(state): 1311 should_continue = False 1312 else: 1313 rpc_state, callbacks = event.tag(event) 1314 for callback in callbacks: 1315 try: 1316 callback() 1317 except Exception: # pylint: disable=broad-except 1318 _LOGGER.exception("Exception calling callback!") 1319 if rpc_state is not None: 1320 with state.lock: 1321 state.rpc_states.remove(rpc_state) 1322 if _stop_serving(state): 1323 should_continue = False 1324 return should_continue 1325 1326 1327def _serve(state: _ServerState) -> None: 1328 while True: 1329 timeout = time.time() + _DEALLOCATED_SERVER_CHECK_PERIOD_S 1330 event = state.completion_queue.poll(timeout) 1331 if state.server_deallocated: 1332 _begin_shutdown_once(state) 1333 if event.completion_type != cygrpc.CompletionType.queue_timeout: 1334 if not _process_event_and_continue(state, event): 1335 return 1336 # We want to force the deletion of the previous event 1337 # ~before~ we poll again; if the event has a reference 1338 # to a shutdown Call object, this can induce spinlock. 1339 event = None 1340 1341 1342def _begin_shutdown_once(state: _ServerState) -> None: 1343 with state.lock: 1344 if state.stage is _ServerStage.STARTED: 1345 state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG) 1346 state.stage = _ServerStage.GRACE 1347 state.due.add(_SHUTDOWN_TAG) 1348 1349 1350def _stop(state: _ServerState, grace: Optional[float]) -> threading.Event: 1351 with state.lock: 1352 if state.stage is _ServerStage.STOPPED: 1353 shutdown_event = threading.Event() 1354 shutdown_event.set() 1355 return shutdown_event 1356 else: 1357 _begin_shutdown_once(state) 1358 shutdown_event = threading.Event() 1359 state.shutdown_events.append(shutdown_event) 1360 if grace is None: 1361 state.server.cancel_all_calls() 1362 else: 1363 1364 def cancel_all_calls_after_grace(): 1365 shutdown_event.wait(timeout=grace) 1366 with state.lock: 1367 state.server.cancel_all_calls() 1368 1369 thread = threading.Thread(target=cancel_all_calls_after_grace) 1370 thread.start() 1371 return shutdown_event 1372 shutdown_event.wait() 1373 return shutdown_event 1374 1375 1376def _start(state: _ServerState) -> None: 1377 with state.lock: 1378 if state.stage is not _ServerStage.STOPPED: 1379 raise ValueError("Cannot start already-started server!") 1380 state.server.start() 1381 state.stage = _ServerStage.STARTED 1382 # Request a call for each registered method so we can handle any of them. 1383 for method in state.registered_method_handlers.keys(): 1384 _request_registered_call(state, method) 1385 # Also request a call for non-registered method. 1386 _request_call(state) 1387 thread = threading.Thread(target=_serve, args=(state,)) 1388 thread.daemon = True 1389 thread.start() 1390 1391 1392def _validate_generic_rpc_handlers( 1393 generic_rpc_handlers: Iterable[grpc.GenericRpcHandler], 1394) -> None: 1395 for generic_rpc_handler in generic_rpc_handlers: 1396 service_attribute = getattr(generic_rpc_handler, "service", None) 1397 if service_attribute is None: 1398 raise AttributeError( 1399 '"{}" must conform to grpc.GenericRpcHandler type but does ' 1400 'not have "service" method!'.format(generic_rpc_handler) 1401 ) 1402 1403 1404def _augment_options( 1405 base_options: Sequence[ChannelArgumentType], 1406 compression: Optional[grpc.Compression], 1407 xds: bool, 1408) -> Sequence[ChannelArgumentType]: 1409 compression_option = _compression.create_channel_option(compression) 1410 maybe_server_call_tracer_factory_option = ( 1411 _observability.create_server_call_tracer_factory_option(xds) 1412 ) 1413 return ( 1414 tuple(base_options) 1415 + compression_option 1416 + maybe_server_call_tracer_factory_option 1417 ) 1418 1419 1420class _Server(grpc.Server): 1421 _state: _ServerState 1422 1423 # pylint: disable=too-many-arguments 1424 def __init__( 1425 self, 1426 thread_pool: futures.ThreadPoolExecutor, 1427 generic_handlers: Sequence[grpc.GenericRpcHandler], 1428 interceptors: Sequence[grpc.ServerInterceptor], 1429 options: Sequence[ChannelArgumentType], 1430 maximum_concurrent_rpcs: Optional[int], 1431 compression: Optional[grpc.Compression], 1432 xds: bool, 1433 ): 1434 completion_queue = cygrpc.CompletionQueue() 1435 server = cygrpc.Server(_augment_options(options, compression, xds), xds) 1436 server.register_completion_queue(completion_queue) 1437 self._state = _ServerState( 1438 completion_queue, 1439 server, 1440 generic_handlers, 1441 _interceptor.service_pipeline(interceptors), 1442 thread_pool, 1443 maximum_concurrent_rpcs, 1444 ) 1445 self._cy_server = server 1446 1447 def add_generic_rpc_handlers( 1448 self, generic_rpc_handlers: Iterable[grpc.GenericRpcHandler] 1449 ) -> None: 1450 _validate_generic_rpc_handlers(generic_rpc_handlers) 1451 _add_generic_handlers(self._state, generic_rpc_handlers) 1452 1453 def add_registered_method_handlers( 1454 self, 1455 service_name: str, 1456 method_handlers: Dict[str, grpc.RpcMethodHandler], 1457 ) -> None: 1458 # Can't register method once server started. 1459 with self._state.lock: 1460 if self._state.stage is _ServerStage.STARTED: 1461 return 1462 1463 # TODO(xuanwn): We should validate method_handlers first. 1464 method_to_handlers = { 1465 _common.fully_qualified_method(service_name, method): method_handler 1466 for method, method_handler in method_handlers.items() 1467 } 1468 for fully_qualified_method in method_to_handlers.keys(): 1469 self._cy_server.register_method(fully_qualified_method) 1470 _add_registered_method_handlers(self._state, method_to_handlers) 1471 1472 def add_insecure_port(self, address: str) -> int: 1473 return _common.validate_port_binding_result( 1474 address, _add_insecure_port(self._state, _common.encode(address)) 1475 ) 1476 1477 def add_secure_port( 1478 self, address: str, server_credentials: grpc.ServerCredentials 1479 ) -> int: 1480 return _common.validate_port_binding_result( 1481 address, 1482 _add_secure_port( 1483 self._state, _common.encode(address), server_credentials 1484 ), 1485 ) 1486 1487 def start(self) -> None: 1488 _start(self._state) 1489 1490 def wait_for_termination(self, timeout: Optional[float] = None) -> bool: 1491 # NOTE(https://bugs.python.org/issue35935) 1492 # Remove this workaround once threading.Event.wait() is working with 1493 # CTRL+C across platforms. 1494 return _common.wait( 1495 self._state.termination_event.wait, 1496 self._state.termination_event.is_set, 1497 timeout=timeout, 1498 ) 1499 1500 def stop(self, grace: Optional[float]) -> threading.Event: 1501 return _stop(self._state, grace) 1502 1503 def __del__(self): 1504 if hasattr(self, "_state"): 1505 # We can not grab a lock in __del__(), so set a flag to signal the 1506 # serving daemon thread (if it exists) to initiate shutdown. 1507 self._state.server_deallocated = True 1508 1509 1510def create_server( 1511 thread_pool: futures.ThreadPoolExecutor, 1512 generic_rpc_handlers: Sequence[grpc.GenericRpcHandler], 1513 interceptors: Sequence[grpc.ServerInterceptor], 1514 options: Sequence[ChannelArgumentType], 1515 maximum_concurrent_rpcs: Optional[int], 1516 compression: Optional[grpc.Compression], 1517 xds: bool, 1518) -> _Server: 1519 _validate_generic_rpc_handlers(generic_rpc_handlers) 1520 return _Server( 1521 thread_pool, 1522 generic_rpc_handlers, 1523 interceptors, 1524 options, 1525 maximum_concurrent_rpcs, 1526 compression, 1527 xds, 1528 ) 1529