# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Service-side implementation of gRPC Python."""

import collections
import enum
import logging
import threading
import time

from concurrent import futures
import six

import grpc
from grpc import _common
from grpc import _compression
from grpc import _interceptor
from grpc._cython import cygrpc

_LOGGER = logging.getLogger(__name__)

# Completion-queue tags tracked in _ServerState.due.
_SHUTDOWN_TAG = 'shutdown'
_REQUEST_CALL_TAG = 'request_call'

# Tokens naming the kinds of batches started on a call. A token sits in
# _RPCState.due from the moment its batch is started until the batch's
# completion callback removes it (see _possibly_finish_call).
_RECEIVE_CLOSE_ON_SERVER_TOKEN = 'receive_close_on_server'
_SEND_INITIAL_METADATA_TOKEN = 'send_initial_metadata'
_RECEIVE_MESSAGE_TOKEN = 'receive_message'
_SEND_MESSAGE_TOKEN = 'send_message'
_SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN = (
    'send_initial_metadata * send_message')
_SEND_STATUS_FROM_SERVER_TOKEN = 'send_status_from_server'
_SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN = (
    'send_initial_metadata * send_status_from_server')

# Values of _RPCState.client.
_OPEN = 'open'
_CLOSED = 'closed'
_CANCELLED = 'cancelled'

_EMPTY_FLAGS = 0

# Poll period of the serving loop, which checks for server deallocation
# between polls.
_DEALLOCATED_SERVER_CHECK_PERIOD_S = 1.0
_INF_TIMEOUT = 1e9


def _serialized_request(request_event):
    """Returns the message of the first operation in a receive-message event.

    The caller treats a None result as "the client sent no further message".
    """
    return request_event.batch_operations[0].message()


def _application_code(code):
    """Maps an application-set status code to cygrpc, defaulting to unknown."""
    cygrpc_code = _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE.get(code)
    return cygrpc.StatusCode.unknown if cygrpc_code is None else cygrpc_code


def _completion_code(state):
    """Status code for normal completion: OK unless the application set one."""
    if state.code is None:
        return cygrpc.StatusCode.ok
    else:
        return _application_code(state.code)


def _abortion_code(state, code):
    """Status code for an abort: the application's code if set, else `code`."""
    if state.code is None:
        return code
    else:
        return _application_code(state.code)


def _details(state):
    """Status details to send: the application's details, or b'' if unset."""
    return b'' if state.details is None else state.details


class _HandlerCallDetails(
        collections.namedtuple('_HandlerCallDetails', (
            'method',
            'invocation_metadata',
        )), grpc.HandlerCallDetails):
    """Immutable (method, invocation_metadata) passed to generic handlers."""


class _RPCState(object):
    """Mutable per-RPC bookkeeping.

    Fields are read and written while holding `condition`. Callbacks that
    change the state call `condition.notify_all()` to wake blocked threads.
    """

    def __init__(self):
        self.condition = threading.Condition()
        # Tokens of started batches whose completions are still pending.
        self.due = set()
        # Most recently received request not yet consumed by the handler.
        self.request = None
        # One of _OPEN, _CLOSED or _CANCELLED.
        self.client = _OPEN
        # False once initial metadata has been sent (or queued to be sent).
        self.initial_metadata_allowed = True
        self.compression_algorithm = None
        # Per-message flag; reset by _reset_per_message_state after each send.
        self.disable_next_compression = False
        self.trailing_metadata = None
        # Application-set status code and (encoded) details; None if unset.
        self.code = None
        self.details = None
        # True once a status batch has been started for this RPC.
        self.statused = False
        # RpcErrors raised by this module into application code. Exceptions
        # found here are not logged or aborted as application failures.
        self.rpc_errors = []
        # Termination callbacks; set to None once the RPC has finished.
        self.callbacks = []
        # Set by _Context.abort so the exception handler sends its status.
        self.aborted = False


def _raise_rpc_error(state):
    """Raises a fresh grpc.RpcError after recording it in `state.rpc_errors`."""
    rpc_error = grpc.RpcError()
    state.rpc_errors.append(rpc_error)
    raise rpc_error


def _possibly_finish_call(state, token):
    """Marks the batch identified by `token` as completed.

    Callers hold `state.condition`.

    Returns:
      `(state, callbacks)` if the RPC is no longer active and no batch is
      outstanding. `state.callbacks` is detached (set to None) so later
      add_callback calls fail. Otherwise `(None, ())`.
    """
    state.due.remove(token)
    if not _is_rpc_state_active(state) and not state.due:
        callbacks = state.callbacks
        state.callbacks = None
        return state, callbacks
    else:
        return None, ()


def _send_status_from_server(state, token):
    """Returns the completion callback for a status batch tagged `token`."""

    def send_status_from_server(unused_send_status_from_server_event):
        with state.condition:
            return _possibly_finish_call(state, token)

    return send_status_from_server


def _get_initial_metadata(state, metadata):
    """Returns `metadata`, prefixed with compression metadata when compression
    has been set on the RPC.

    `metadata` may be None; it is returned unchanged if no compression is set.
    """
    with state.condition:
        if state.compression_algorithm:
            compression_metadata = (
                _compression.compression_algorithm_to_metadata(
                    state.compression_algorithm),)
            if metadata is None:
                return compression_metadata
            else:
                return compression_metadata + tuple(metadata)
        else:
            return metadata


def _get_initial_metadata_operation(state, metadata):
    """Builds a SendInitialMetadataOperation (see _get_initial_metadata)."""
    operation = cygrpc.SendInitialMetadataOperation(
        _get_initial_metadata(state, metadata), _EMPTY_FLAGS)
    return operation


def _abort(state, call, code, details):
    """Terminates the RPC by sending a status to the client.

    Callers hold `state.condition`. The application's code/details, when set,
    take precedence over `code`/`details`. Initial metadata is sent in the
    same batch if it has not been sent yet.

    Does nothing if the client cancelled the RPC or a status batch was already
    started. gRPC core accepts only one send-status operation per call. A
    rejected batch would also add a token to `state.due` that no completion
    ever removes, so the RPC would never finish.
    """
    if not _is_rpc_state_active(state):
        return
    effective_code = _abortion_code(state, code)
    effective_details = details if state.details is None else state.details
    if state.initial_metadata_allowed:
        operations = (
            _get_initial_metadata_operation(state, None),
            cygrpc.SendStatusFromServerOperation(state.trailing_metadata,
                                                 effective_code,
                                                 effective_details,
                                                 _EMPTY_FLAGS),
        )
        token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN
    else:
        operations = (cygrpc.SendStatusFromServerOperation(
            state.trailing_metadata, effective_code, effective_details,
            _EMPTY_FLAGS),)
        token = _SEND_STATUS_FROM_SERVER_TOKEN
    call.start_server_batch(operations, _send_status_from_server(state, token))
    state.statused = True
    state.due.add(token)


def _receive_close_on_server(state):
    """Returns the completion callback for the receive-close-on-server batch.

    The callback records whether the client cancelled or closed, then wakes
    waiters.
    """

    def receive_close_on_server(receive_close_on_server_event):
        with state.condition:
            if receive_close_on_server_event.batch_operations[0].cancelled():
                state.client = _CANCELLED
            elif state.client is _OPEN:
                state.client = _CLOSED
            state.condition.notify_all()
            return _possibly_finish_call(state, _RECEIVE_CLOSE_ON_SERVER_TOKEN)

    return receive_close_on_server


def _receive_message(state, call, request_deserializer):
    """Returns the completion callback for one receive-message batch.

    A missing message marks the client closed. A message that fails to
    deserialize aborts the RPC with INTERNAL. Otherwise the request is stored
    in `state.request`. Waiters are woken in every case.
    """

    def receive_message(receive_message_event):
        serialized_request = _serialized_request(receive_message_event)
        if serialized_request is None:
            with state.condition:
                if state.client is _OPEN:
                    state.client = _CLOSED
                state.condition.notify_all()
                return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)
        else:
            request = _common.deserialize(serialized_request,
                                          request_deserializer)
            with state.condition:
                if request is None:
                    _abort(state, call, cygrpc.StatusCode.internal,
                           b'Exception deserializing request!')
                else:
                    state.request = request
                state.condition.notify_all()
                return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)

    return receive_message


def _send_initial_metadata(state):
    """Returns the completion callback for a send-initial-metadata batch."""

    def send_initial_metadata(unused_send_initial_metadata_event):
        with state.condition:
            return _possibly_finish_call(state, _SEND_INITIAL_METADATA_TOKEN)

    return send_initial_metadata


def _send_message(state, token):
    """Returns the completion callback for a send-message batch.

    The callback wakes the thread blocked in _send_response.
    """

    def send_message(unused_send_message_event):
        with state.condition:
            state.condition.notify_all()
            return _possibly_finish_call(state, token)

    return send_message


class _Context(grpc.ServicerContext):
    """grpc.ServicerContext implementation backed by an _RPCState."""

    def __init__(self, rpc_event, state, request_deserializer):
        self._rpc_event = rpc_event
        self._state = state
        self._request_deserializer = request_deserializer

    def is_active(self):
        with self._state.condition:
            return _is_rpc_state_active(self._state)

    def time_remaining(self):
        # Seconds until the call deadline, clamped at zero.
        return max(self._rpc_event.call_details.deadline - time.time(), 0)

    def cancel(self):
        self._rpc_event.call.cancel()

    def add_callback(self, callback):
        """Registers a termination callback; False if the RPC already ended."""
        with self._state.condition:
            if self._state.callbacks is None:
                return False
            else:
                self._state.callbacks.append(callback)
                return True

    def disable_next_message_compression(self):
        with self._state.condition:
            self._state.disable_next_compression = True

    def invocation_metadata(self):
        return self._rpc_event.invocation_metadata

    def peer(self):
        return _common.decode(self._rpc_event.call.peer())

    def peer_identities(self):
        return cygrpc.peer_identities(self._rpc_event.call)

    def peer_identity_key(self):
        id_key = cygrpc.peer_identity_key(self._rpc_event.call)
        return id_key if id_key is None else _common.decode(id_key)

    def auth_context(self):
        return {
            _common.decode(key): value for key, value in six.iteritems(
                cygrpc.auth_context(self._rpc_event.call))
        }

    def set_compression(self, compression):
        with self._state.condition:
            self._state.compression_algorithm = compression

    def send_initial_metadata(self, initial_metadata):
        """Starts a batch sending `initial_metadata`.

        Raises:
          grpc.RpcError: if the client cancelled the RPC.
          ValueError: if initial metadata was already sent.
        """
        with self._state.condition:
            if self._state.client is _CANCELLED:
                _raise_rpc_error(self._state)
            else:
                if self._state.initial_metadata_allowed:
                    operation = _get_initial_metadata_operation(
                        self._state, initial_metadata)
                    self._rpc_event.call.start_server_batch(
                        (operation,), _send_initial_metadata(self._state))
                    self._state.initial_metadata_allowed = False
                    self._state.due.add(_SEND_INITIAL_METADATA_TOKEN)
                else:
                    raise ValueError('Initial metadata no longer allowed!')

    def set_trailing_metadata(self, trailing_metadata):
        with self._state.condition:
            self._state.trailing_metadata = trailing_metadata

    def abort(self, code, details):
        """Records the abort status and raises to unwind the handler.

        The raised exception is caught by _call_behavior or
        _take_response_from_response_iterator. Those see `state.aborted` and
        send the stored code and details.
        """
        # treat OK like other invalid arguments: fail the RPC
        if code == grpc.StatusCode.OK:
            _LOGGER.error(
                'abort() called with StatusCode.OK; returning UNKNOWN')
            code = grpc.StatusCode.UNKNOWN
            details = ''
        with self._state.condition:
            self._state.code = code
            self._state.details = _common.encode(details)
            self._state.aborted = True
            raise Exception()

    def abort_with_status(self, status):
        # Go through the locked setter, as every other state mutation does.
        self.set_trailing_metadata(status.trailing_metadata)
        self.abort(status.code, status.details)

    def set_code(self, code):
        with self._state.condition:
            self._state.code = code

    def set_details(self, details):
        with self._state.condition:
            self._state.details = _common.encode(details)

    def _finalize_state(self):
        # No-op hook.
        pass


class _RequestIterator(object):
    """Blocking iterator over the request messages of a client-streaming RPC.

    Each step starts one receive-message batch and waits on `state.condition`.
    It yields the received request. It raises StopIteration once the client
    has closed or the RPC is no longer active, and grpc.RpcError if the
    client cancelled.
    """

    def __init__(self, state, call, request_deserializer):
        self._state = state
        self._call = call
        self._request_deserializer = request_deserializer

    def _raise_or_start_receive_message(self):
        if self._state.client is _CANCELLED:
            _raise_rpc_error(self._state)
        elif not _is_rpc_state_active(self._state):
            raise StopIteration()
        else:
            self._call.start_server_batch(
                (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),),
                _receive_message(self._state, self._call,
                                 self._request_deserializer))
            self._state.due.add(_RECEIVE_MESSAGE_TOKEN)

    def _look_for_request(self):
        """Returns the pending request, or None if the receive is still due."""
        if self._state.client is _CANCELLED:
            _raise_rpc_error(self._state)
        elif (self._state.request is None and
              _RECEIVE_MESSAGE_TOKEN not in self._state.due):
            # The receive completed without a request: end of stream.
            raise StopIteration()
        else:
            request = self._state.request
            self._state.request = None
            return request

    def _next(self):
        with self._state.condition:
            self._raise_or_start_receive_message()
            while True:
                self._state.condition.wait()
                request = self._look_for_request()
                if request is not None:
                    return request

    def __iter__(self):
        return self

    def __next__(self):
        return self._next()

    def next(self):
        return self._next()


def _unary_request(rpc_event, state, request_deserializer):
    """Returns a thunk that receives the single request of a unary-request RPC.

    The thunk returns the request, or None if the RPC was inactive, was
    cancelled, or the client closed without sending a message. In the last
    case the RPC is first aborted with UNIMPLEMENTED.
    """

    def unary_request():
        with state.condition:
            if not _is_rpc_state_active(state):
                return None
            else:
                rpc_event.call.start_server_batch(
                    (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),),
                    _receive_message(state, rpc_event.call,
                                     request_deserializer))
                state.due.add(_RECEIVE_MESSAGE_TOKEN)
                while True:
                    state.condition.wait()
                    if state.request is None:
                        if state.client is _CLOSED:
                            details = '"{}" requires exactly one request message.'.format(
                                rpc_event.call_details.method)
                            # No-op if a status was already sent, e.g. after
                            # a deserialization failure.
                            _abort(state, rpc_event.call,
                                   cygrpc.StatusCode.unimplemented,
                                   _common.encode(details))
                            return None
                        elif state.client is _CANCELLED:
                            return None
                    else:
                        request = state.request
                        state.request = None
                        return request

    return unary_request


def _call_behavior(rpc_event,
                   state,
                   behavior,
                   argument,
                   request_deserializer,
                   send_response_callback=None):
    """Invokes the application handler inside a servicer context.

    Returns:
      `(result, True)` if the handler returned normally. `(None, False)` if
      it raised; in that case the RPC has been aborted, unless the exception
      was an RpcError raised by this module (those are neither logged nor
      aborted here).
    """
    from grpc import _create_servicer_context
    with _create_servicer_context(rpc_event, state,
                                  request_deserializer) as context:
        try:
            response_or_iterator = None
            if send_response_callback is not None:
                response_or_iterator = behavior(argument, context,
                                                send_response_callback)
            else:
                response_or_iterator = behavior(argument, context)
            return response_or_iterator, True
        except Exception as exception:  # pylint: disable=broad-except
            with state.condition:
                if state.aborted:
                    # context.abort() stored code/details in `state`; they
                    # take precedence over these fallbacks inside _abort.
                    _abort(state, rpc_event.call, cygrpc.StatusCode.unknown,
                           b'RPC Aborted')
                elif exception not in state.rpc_errors:
                    details = 'Exception calling application: {}'.format(
                        exception)
                    _LOGGER.exception(details)
                    _abort(state, rpc_event.call, cygrpc.StatusCode.unknown,
                           _common.encode(details))
            return None, False


def _take_response_from_response_iterator(rpc_event, state, response_iterator):
    """Advances the handler's response iterator.

    Returns:
      `(response, True)` for the next response, `(None, True)` when exhausted,
      `(None, False)` if iteration raised (the RPC is aborted as in
      _call_behavior).
    """
    try:
        return next(response_iterator), True
    except StopIteration:
        return None, True
    except Exception as exception:  # pylint: disable=broad-except
        with state.condition:
            if state.aborted:
                _abort(state, rpc_event.call, cygrpc.StatusCode.unknown,
                       b'RPC Aborted')
            elif exception not in state.rpc_errors:
                details = 'Exception iterating responses: {}'.format(exception)
                _LOGGER.exception(details)
                _abort(state, rpc_event.call, cygrpc.StatusCode.unknown,
                       _common.encode(details))
        return None, False


def _serialize_response(rpc_event, state, response, response_serializer):
    """Serializes `response`; on failure aborts with INTERNAL and returns None."""
    serialized_response = _common.serialize(response, response_serializer)
    if serialized_response is None:
        with state.condition:
            _abort(state, rpc_event.call, cygrpc.StatusCode.internal,
                   b'Failed to serialize response!')
        return None
    else:
        return serialized_response


def _get_send_message_op_flags_from_state(state):
    """Write flags for the next message (no_compress if compression disabled)."""
    if state.disable_next_compression:
        return cygrpc.WriteFlag.no_compress
    else:
        return _EMPTY_FLAGS


def _reset_per_message_state(state):
    """Clears per-message settings after a message has been queued."""
    with state.condition:
        state.disable_next_compression = False


def _send_response(rpc_event, state, serialized_response):
    """Sends one response message and blocks until its batch completes.

    Initial metadata is included if it has not been sent yet.

    Returns:
      Whether the RPC is still active afterwards. Returns False without
      sending if the RPC was already inactive.
    """
    with state.condition:
        if not _is_rpc_state_active(state):
            return False
        else:
            if state.initial_metadata_allowed:
                operations = (
                    _get_initial_metadata_operation(state, None),
                    cygrpc.SendMessageOperation(
                        serialized_response,
                        _get_send_message_op_flags_from_state(state)),
                )
                state.initial_metadata_allowed = False
                token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN
            else:
                operations = (cygrpc.SendMessageOperation(
                    serialized_response,
                    _get_send_message_op_flags_from_state(state)),)
                token = _SEND_MESSAGE_TOKEN
            rpc_event.call.start_server_batch(operations,
                                              _send_message(state, token))
            state.due.add(token)
            _reset_per_message_state(state)
            while True:
                state.condition.wait()
                if token not in state.due:
                    return _is_rpc_state_active(state)


def _status(rpc_event, state, serialized_response):
    """Sends the final status of a handler that completed normally.

    Initial metadata and a last response message (`serialized_response`, if
    not None) go in the same batch when needed. Does nothing if the client
    cancelled or a status was already sent. The second case happens when a
    deserialization failure aborted a client-streaming RPC and the handler
    still returned a response; a second status batch would be rejected by
    core and its token would never leave `state.due`.
    """
    with state.condition:
        if _is_rpc_state_active(state):
            code = _completion_code(state)
            details = _details(state)
            operations = [
                cygrpc.SendStatusFromServerOperation(state.trailing_metadata,
                                                     code, details,
                                                     _EMPTY_FLAGS),
            ]
            if state.initial_metadata_allowed:
                operations.append(_get_initial_metadata_operation(state, None))
            if serialized_response is not None:
                operations.append(
                    cygrpc.SendMessageOperation(
                        serialized_response,
                        _get_send_message_op_flags_from_state(state)))
            rpc_event.call.start_server_batch(
                operations,
                _send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN))
            state.statused = True
            _reset_per_message_state(state)
            state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN)


def _unary_response_in_pool(rpc_event, state, behavior, argument_thunk,
                            request_deserializer, response_serializer):
    """Thread-pool task for unary-response RPCs.

    Obtains the request argument, runs the handler, then serializes and sends
    its response together with the final status.
    """
    cygrpc.install_context_from_request_call_event(rpc_event)
    try:
        argument = argument_thunk()
        if argument is not None:
            response, proceed = _call_behavior(rpc_event, state, behavior,
                                               argument, request_deserializer)
            if proceed:
                serialized_response = _serialize_response(
                    rpc_event, state, response, response_serializer)
                if serialized_response is not None:
                    _status(rpc_event, state, serialized_response)
    finally:
        cygrpc.uninstall_context()


def _stream_response_in_pool(rpc_event, state, behavior, argument_thunk,
                             request_deserializer, response_serializer):
    """Thread-pool task for streaming-response RPCs.

    Handlers with a truthy `experimental_non_blocking` attribute receive a
    `send_response` callback; call it with None to finish. Otherwise the
    handler's returned iterator is drained into the same callback.
    """
    cygrpc.install_context_from_request_call_event(rpc_event)

    def send_response(response):
        if response is None:
            _status(rpc_event, state, None)
        else:
            serialized_response = _serialize_response(rpc_event, state,
                                                      response,
                                                      response_serializer)
            if serialized_response is not None:
                _send_response(rpc_event, state, serialized_response)

    try:
        argument = argument_thunk()
        if argument is not None:
            if getattr(behavior, 'experimental_non_blocking', False):
                _call_behavior(rpc_event,
                               state,
                               behavior,
                               argument,
                               request_deserializer,
                               send_response_callback=send_response)
            else:
                response_iterator, proceed = _call_behavior(
                    rpc_event, state, behavior, argument, request_deserializer)
                if proceed:
                    _send_message_callback_to_blocking_iterator_adapter(
                        rpc_event, state, send_response, response_iterator)
    finally:
        cygrpc.uninstall_context()


def _is_rpc_state_active(state):
    """True while the client has not cancelled and no status has been sent."""
    return state.client is not _CANCELLED and not state.statused


def _send_message_callback_to_blocking_iterator_adapter(rpc_event, state,
                                                        send_response_callback,
                                                        response_iterator):
    """Feeds responses from `response_iterator` into `send_response_callback`.

    The terminating None is forwarded too. Stops when iteration fails or the
    RPC becomes inactive.
    """
    while True:
        response, proceed = _take_response_from_response_iterator(
            rpc_event, state, response_iterator)
        if proceed:
            send_response_callback(response)
            if not _is_rpc_state_active(state):
                break
        else:
            break


def _select_thread_pool_for_behavior(behavior, default_thread_pool):
    """Uses the handler's own `experimental_thread_pool` if it is a
    ThreadPoolExecutor, else `default_thread_pool`."""
    if hasattr(behavior, 'experimental_thread_pool') and isinstance(
            behavior.experimental_thread_pool, futures.ThreadPoolExecutor):
        return behavior.experimental_thread_pool
    else:
        return default_thread_pool


def _handle_unary_unary(rpc_event, state, method_handler, default_thread_pool):
    """Submits a unary-unary RPC to its thread pool; returns the future."""
    unary_request = _unary_request(rpc_event, state,
                                   method_handler.request_deserializer)
    thread_pool = _select_thread_pool_for_behavior(method_handler.unary_unary,
                                                   default_thread_pool)
    return thread_pool.submit(_unary_response_in_pool, rpc_event, state,
                              method_handler.unary_unary, unary_request,
                              method_handler.request_deserializer,
                              method_handler.response_serializer)


def _handle_unary_stream(rpc_event, state, method_handler, default_thread_pool):
    """Submits a unary-stream RPC to its thread pool; returns the future."""
    unary_request = _unary_request(rpc_event, state,
                                   method_handler.request_deserializer)
    thread_pool = _select_thread_pool_for_behavior(method_handler.unary_stream,
                                                   default_thread_pool)
    return thread_pool.submit(_stream_response_in_pool, rpc_event, state,
                              method_handler.unary_stream, unary_request,
                              method_handler.request_deserializer,
                              method_handler.response_serializer)


def _handle_stream_unary(rpc_event, state, method_handler, default_thread_pool):
    """Submits a stream-unary RPC to its thread pool; returns the future."""
    request_iterator = _RequestIterator(state, rpc_event.call,
                                        method_handler.request_deserializer)
    thread_pool = _select_thread_pool_for_behavior(method_handler.stream_unary,
                                                   default_thread_pool)
    return thread_pool.submit(_unary_response_in_pool, rpc_event, state,
                              method_handler.stream_unary,
                              lambda: request_iterator,
                              method_handler.request_deserializer,
                              method_handler.response_serializer)


def _handle_stream_stream(rpc_event, state, method_handler,
                          default_thread_pool):
    """Submits a stream-stream RPC to its thread pool; returns the future."""
    request_iterator = _RequestIterator(state, rpc_event.call,
                                        method_handler.request_deserializer)
    thread_pool = _select_thread_pool_for_behavior(method_handler.stream_stream,
                                                   default_thread_pool)
    return thread_pool.submit(_stream_response_in_pool, rpc_event, state,
                              method_handler.stream_stream,
                              lambda: request_iterator,
                              method_handler.request_deserializer,
                              method_handler.response_serializer)


def _find_method_handler(rpc_event, generic_handlers, interceptor_pipeline):
    """Returns the first method handler a generic handler provides, or None.

    The lookup runs through `interceptor_pipeline` when one is configured.
    """

    def query_handlers(handler_call_details):
        for generic_handler in generic_handlers:
            method_handler = generic_handler.service(handler_call_details)
            if method_handler is not None:
                return method_handler
        return None

    handler_call_details = _HandlerCallDetails(
        _common.decode(rpc_event.call_details.method),
        rpc_event.invocation_metadata)

    if interceptor_pipeline is not None:
        return interceptor_pipeline.execute(query_handlers,
                                            handler_call_details)
    else:
        return query_handlers(handler_call_details)


def _reject_rpc(rpc_event, status, details):
    """Fails the RPC immediately with `status` and `details` (bytes).

    Returns a fresh _RPCState. Its batch callback reports the RPC finished
    with no termination callbacks.
    """
    rpc_state = _RPCState()
    operations = (
        _get_initial_metadata_operation(rpc_state, None),
        cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
        cygrpc.SendStatusFromServerOperation(None, status, details,
                                             _EMPTY_FLAGS),
    )
    rpc_event.call.start_server_batch(operations, lambda ignored_event: (
        rpc_state,
        (),
    ))
    return rpc_state


def _handle_with_method_handler(rpc_event, method_handler, thread_pool):
    """Starts servicing an RPC with `method_handler`.

    Begins watching for client close and dispatches on the handler's
    request/response streaming flags. Returns `(state, future)`.
    """
    state = _RPCState()
    with state.condition:
        rpc_event.call.start_server_batch(
            (cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),),
            _receive_close_on_server(state))
        state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN)
        if method_handler.request_streaming:
            if method_handler.response_streaming:
                return state, _handle_stream_stream(rpc_event, state,
                                                    method_handler, thread_pool)
            else:
                return state, _handle_stream_unary(rpc_event, state,
                                                   method_handler, thread_pool)
        else:
            if method_handler.response_streaming:
                return state, _handle_unary_stream(rpc_event, state,
                                                   method_handler, thread_pool)
            else:
                return state, _handle_unary_unary(rpc_event, state,
                                                  method_handler, thread_pool)


def _handle_call(rpc_event, generic_handlers, interceptor_pipeline, thread_pool,
                 concurrency_exceeded):
    """Handles a request-call completion event.

    Returns:
      `(rpc_state, future)`. Both are None for an unsuccessful event or one
      without a method. `future` is None when the RPC was rejected (handler
      error, no handler found, or concurrency limit exceeded).
    """
    if not rpc_event.success:
        return None, None
    if rpc_event.call_details.method is not None:
        try:
            method_handler = _find_method_handler(rpc_event, generic_handlers,
                                                  interceptor_pipeline)
        except Exception as exception:  # pylint: disable=broad-except
            details = 'Exception servicing handler: {}'.format(exception)
            _LOGGER.exception(details)
            return _reject_rpc(rpc_event, cygrpc.StatusCode.unknown,
                               b'Error in service handler!'), None
        if method_handler is None:
            return _reject_rpc(rpc_event, cygrpc.StatusCode.unimplemented,
                               b'Method not found!'), None
        elif concurrency_exceeded:
            return _reject_rpc(rpc_event, cygrpc.StatusCode.resource_exhausted,
                               b'Concurrent RPC limit exceeded!'), None
        else:
            return _handle_with_method_handler(rpc_event, method_handler,
                                               thread_pool)
    else:
        return None, None


@enum.unique
class _ServerStage(enum.Enum):
    """Lifecycle stage of a server: STOPPED -> STARTED -> GRACE -> STOPPED."""
    STOPPED = 'stopped'
    STARTED = 'started'
    GRACE = 'grace'


class _ServerState(object):
    """Server-wide state; mutable fields are accessed under `lock`."""

    # pylint: disable=too-many-arguments
    def __init__(self, completion_queue, server, generic_handlers,
                 interceptor_pipeline, thread_pool, maximum_concurrent_rpcs):
        self.lock = threading.RLock()
        self.completion_queue = completion_queue
        self.server = server
        self.generic_handlers = list(generic_handlers)
        self.interceptor_pipeline = interceptor_pipeline
        self.thread_pool = thread_pool
        self.stage = _ServerStage.STOPPED
        # Set, with every other event in shutdown_events, once shutdown ends.
        self.termination_event = threading.Event()
        self.shutdown_events = [self.termination_event]
        # None means no limit on concurrently serviced RPCs.
        self.maximum_concurrent_rpcs = maximum_concurrent_rpcs
        self.active_rpc_count = 0

        # TODO(https://github.com/grpc/grpc/issues/6597): eliminate these fields.
        # RPC states that have not finished yet.
        self.rpc_states = set()
        # Outstanding server-level tags (_REQUEST_CALL_TAG, _SHUTDOWN_TAG).
        self.due = set()

        # A "volatile" flag to interrupt the daemon serving thread
        self.server_deallocated = False


def _add_generic_handlers(state, generic_handlers):
    """Appends `generic_handlers` to the server's handler list."""
    with state.lock:
        state.generic_handlers.extend(generic_handlers)


def _add_insecure_port(state, address):
    """Binds an insecure HTTP/2 port; returns the core server's result."""
    with state.lock:
        return state.server.add_http2_port(address)


def _add_secure_port(state, address, server_credentials):
    """Binds a secure HTTP/2 port; returns the core server's result."""
    with state.lock:
        return state.server.add_http2_port(address,
                                           server_credentials._credentials)


def _request_call(state):
    """Asks the core server for the next incoming call.

    The completion is tagged _REQUEST_CALL_TAG. Callers hold `state.lock`.
    """
    state.server.request_call(state.completion_queue, state.completion_queue,
                              _REQUEST_CALL_TAG)
    state.due.add(_REQUEST_CALL_TAG)


# TODO(https://github.com/grpc/grpc/issues/6597): delete this function.
def _stop_serving(state):
    """Completes server shutdown when nothing is outstanding.

    Callers hold `state.lock`. If any RPC state is still tracked or any
    server-level tag is still due, nothing changes and False is returned.
    Otherwise the core server is destroyed, every shutdown event is set, the
    stage becomes STOPPED and True is returned.
    """
    if state.rpc_states or state.due:
        return False
    state.server.destroy()
    for pending_event in state.shutdown_events:
        pending_event.set()
    state.stage = _ServerStage.STOPPED
    return True


def _on_call_completed(state):
    """Future done-callback: one fewer RPC is being serviced."""
    with state.lock:
        state.active_rpc_count = state.active_rpc_count - 1


def _process_event_and_continue(state, event):
    """Processes one completion-queue event on the serving thread.

    Shutdown and request-call tags are handled under `state.lock`. Any other
    tag is a per-RPC batch callback returning `(finished_rpc_state_or_None,
    callbacks)`; the callbacks run here and their exceptions are logged.

    Returns:
      False once the server has fully stopped and the loop should exit,
      True otherwise.
    """
    tag = event.tag
    if tag is _SHUTDOWN_TAG:
        with state.lock:
            state.due.remove(_SHUTDOWN_TAG)
            return not _stop_serving(state)

    if tag is _REQUEST_CALL_TAG:
        with state.lock:
            state.due.remove(_REQUEST_CALL_TAG)
            limit = state.maximum_concurrent_rpcs
            over_limit = (limit is not None and
                          state.active_rpc_count >= limit)
            rpc_state, rpc_future = _handle_call(event, state.generic_handlers,
                                                 state.interceptor_pipeline,
                                                 state.thread_pool, over_limit)
            if rpc_state is not None:
                state.rpc_states.add(rpc_state)
            if rpc_future is not None:
                state.active_rpc_count += 1
                rpc_future.add_done_callback(
                    lambda unused_future: _on_call_completed(state))
            if state.stage is _ServerStage.STARTED:
                # Keep accepting: re-arm the request for the next call.
                _request_call(state)
                return True
            return not _stop_serving(state)

    # Per-RPC batch completion.
    finished_rpc_state, termination_callbacks = tag(event)
    for termination_callback in termination_callbacks:
        try:
            termination_callback()
        except Exception:  # pylint: disable=broad-except
            _LOGGER.exception('Exception calling callback!')
    if finished_rpc_state is None:
        return True
    with state.lock:
        state.rpc_states.remove(finished_rpc_state)
        return not _stop_serving(state)


def _serve(state):
    """Body of the daemon serving thread started by _start.

    Each poll of the completion queue uses a deadline
    _DEALLOCATED_SERVER_CHECK_PERIOD_S seconds ahead. That lets the loop
    notice `state.server_deallocated` (set by _Server.__del__) and begin
    shutdown even when idle. Non-timeout events are dispatched until
    processing reports the server stopped.
    """
    while True:
        deadline = time.time() + _DEALLOCATED_SERVER_CHECK_PERIOD_S
        event = state.completion_queue.poll(deadline)
        if state.server_deallocated:
            _begin_shutdown_once(state)
        if event.completion_type == cygrpc.CompletionType.queue_timeout:
            continue
        if not _process_event_and_continue(state, event):
            return
        # We want to force the deletion of the previous event
        # ~before~ we poll again; if the event has a reference
        # to a shutdown Call object, this can induce spinlock.
        event = None


def _begin_shutdown_once(state):
    """Starts core server shutdown if the server is in the STARTED stage.

    The stage then moves to GRACE, so repeated calls are no-ops.
    """
    with state.lock:
        if state.stage is not _ServerStage.STARTED:
            return
        state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG)
        state.stage = _ServerStage.GRACE
        state.due.add(_SHUTDOWN_TAG)


def _stop(state, grace):
    """Implements grpc.Server.stop.

    A stopped server gets an already-set Event. Otherwise shutdown begins and
    a new Event, set when shutdown completes, is returned:
      - grace is None: in-flight calls are cancelled at once, and this call
        blocks until shutdown completes.
      - otherwise: a helper thread cancels remaining calls after `grace`
        seconds (or earlier, once shutdown completes), and this call returns
        immediately.
    """
    with state.lock:
        if state.stage is _ServerStage.STOPPED:
            already_stopped = threading.Event()
            already_stopped.set()
            return already_stopped
        _begin_shutdown_once(state)
        stopped = threading.Event()
        state.shutdown_events.append(stopped)
        if grace is not None:

            def cancel_all_calls_after_grace():
                stopped.wait(timeout=grace)
                with state.lock:
                    state.server.cancel_all_calls()

            canceller = threading.Thread(target=cancel_all_calls_after_grace)
            canceller.start()
            return stopped
        state.server.cancel_all_calls()
    stopped.wait()
    return stopped


def _start(state):
    """Implements grpc.Server.start.

    Starts the core server, requests the first call and launches the daemon
    serving thread.

    Raises:
      ValueError: if the server is not in the STOPPED stage.
    """
    with state.lock:
        if state.stage is not _ServerStage.STOPPED:
            raise ValueError('Cannot start already-started server!')
        state.server.start()
        state.stage = _ServerStage.STARTED
        _request_call(state)

        serving_thread = threading.Thread(target=_serve, args=(state,))
        serving_thread.daemon = True
        serving_thread.start()


def _validate_generic_rpc_handlers(generic_rpc_handlers):
    """Raises AttributeError for any handler whose `service` is missing/None."""
    for handler in generic_rpc_handlers:
        if getattr(handler, 'service', None) is not None:
            continue
        raise AttributeError(
            '"{}" must conform to grpc.GenericRpcHandler type but does '
            'not have "service" method!'.format(handler))


def _augment_options(base_options, compression):
    """Returns `base_options` as a tuple plus the option for `compression`."""
    extra_options = _compression.create_channel_option(compression)
    return tuple(base_options) + extra_options


class _Server(grpc.Server):
    """grpc.Server backed by a cygrpc.Server and a single completion queue."""

    # pylint: disable=too-many-arguments
    def __init__(self, thread_pool, generic_handlers, interceptors, options,
                 maximum_concurrent_rpcs, compression, xds):
        queue = cygrpc.CompletionQueue()
        core_server = cygrpc.Server(_augment_options(options, compression),
                                    xds)
        core_server.register_completion_queue(queue)
        pipeline = _interceptor.service_pipeline(interceptors)
        self._state = _ServerState(queue, core_server, generic_handlers,
                                   pipeline, thread_pool,
                                   maximum_concurrent_rpcs)

    def add_generic_rpc_handlers(self, generic_rpc_handlers):
        _validate_generic_rpc_handlers(generic_rpc_handlers)
        _add_generic_handlers(self._state, generic_rpc_handlers)

    def add_insecure_port(self, address):
        bound = _add_insecure_port(self._state, _common.encode(address))
        return _common.validate_port_binding_result(address, bound)

    def add_secure_port(self, address, server_credentials):
        bound = _add_secure_port(self._state, _common.encode(address),
                                 server_credentials)
        return _common.validate_port_binding_result(address, bound)

    def start(self):
        _start(self._state)

    def wait_for_termination(self, timeout=None):
        # NOTE(https://bugs.python.org/issue35935)
        # Remove this workaround once threading.Event.wait() is working with
        # CTRL+C across platforms.
        termination = self._state.termination_event
        return _common.wait(termination.wait,
                            termination.is_set,
                            timeout=timeout)

    def stop(self, grace):
        return _stop(self._state, grace)

    def __del__(self):
        # __init__ may have failed before _state was assigned.
        if not hasattr(self, '_state'):
            return
        # We can not grab a lock in __del__(), so set a flag to signal the
        # serving daemon thread (if it exists) to initiate shutdown.
        self._state.server_deallocated = True


def create_server(thread_pool, generic_rpc_handlers, interceptors, options,
                  maximum_concurrent_rpcs, compression, xds):
    """Validates the generic handlers and constructs a _Server.

    Raises:
      AttributeError: if a handler lacks a `service` method.
    """
    _validate_generic_rpc_handlers(generic_rpc_handlers)
    return _Server(thread_pool, generic_rpc_handlers, interceptors, options,
                   maximum_concurrent_rpcs, compression, xds)