• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Service-side implementation of gRPC Python."""
15
16import collections
17import enum
18import logging
19import threading
20import time
21
22from concurrent import futures
23import six
24
25import grpc
26from grpc import _common
27from grpc import _compression
28from grpc import _interceptor
29from grpc._cython import cygrpc
30
# Module-level logger for this file.
_LOGGER = logging.getLogger(__name__)

# Completion-queue tags for server-level events. They are compared by
# identity (`event.tag is _SHUTDOWN_TAG`) in _process_event_and_continue, so
# the very same string objects must be used when registering and checking.
_SHUTDOWN_TAG = 'shutdown'
_REQUEST_CALL_TAG = 'request_call'

# Per-RPC "due" tokens. Each names one kind of batch started on a call; the
# token is added to _RPCState.due when the batch starts and removed by
# _possibly_finish_call when the batch's completion callback runs.
_RECEIVE_CLOSE_ON_SERVER_TOKEN = 'receive_close_on_server'
_SEND_INITIAL_METADATA_TOKEN = 'send_initial_metadata'
_RECEIVE_MESSAGE_TOKEN = 'receive_message'
_SEND_MESSAGE_TOKEN = 'send_message'
_SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN = (
    'send_initial_metadata * send_message')
_SEND_STATUS_FROM_SERVER_TOKEN = 'send_status_from_server'
_SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN = (
    'send_initial_metadata * send_status_from_server')

# Values of _RPCState.client: the client may still send messages, has
# finished sending (no more messages), or the call was cancelled.
_OPEN = 'open'
_CLOSED = 'closed'
_CANCELLED = 'cancelled'

# Flags value meaning "no operation/write flags".
_EMPTY_FLAGS = 0

# NOTE(review): not used in the visible part of this file; presumably a
# polling period and an effectively-infinite timeout (seconds) consumed
# further down — confirm.
_DEALLOCATED_SERVER_CHECK_PERIOD_S = 1.0
_INF_TIMEOUT = 1e9
54
55
56def _serialized_request(request_event):
57    return request_event.batch_operations[0].message()
58
59
def _application_code(code):
    """Translates an application-supplied status code to its cygrpc value.

    Codes missing from _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE are reported
    as cygrpc.StatusCode.unknown.
    """
    cygrpc_code = _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE.get(code)
    if cygrpc_code is None:
        return cygrpc.StatusCode.unknown
    return cygrpc_code
63
64
def _completion_code(state):
    """Returns the status code to send when an RPC completes normally.

    Without an application-set code (state.code) the RPC ends with OK;
    otherwise the application's code is translated by _application_code.
    """
    if state.code is not None:
        return _application_code(state.code)
    return cygrpc.StatusCode.ok
70
71
72def _abortion_code(state, code):
73    if state.code is None:
74        return code
75    else:
76        return _application_code(state.code)
77
78
79def _details(state):
80    return b'' if state.details is None else state.details
81
82
class _HandlerCallDetails(
        collections.namedtuple('_HandlerCallDetails',
                               ['method', 'invocation_metadata']),
        grpc.HandlerCallDetails):
    """Concrete grpc.HandlerCallDetails: an immutable (method, metadata) pair.

    Built by _find_method_handler and passed to each generic handler's
    service() method, or to the interceptor pipeline when one is configured.
    """
89
90
class _RPCState(object):
    """Mutable bookkeeping for one RPC.

    Shared between the handler thread and the completion-queue callbacks;
    callers generally hold `condition` while reading or writing the fields.
    """

    def __init__(self):
        # Lock + wakeup channel for every field below.
        self.condition = threading.Condition()

        # Tokens of batches started but not yet completed.
        self.due = set()

        # Client-side progress and the most recently received request.
        self.client = _OPEN
        self.request = None

        # Outbound metadata and compression settings.
        self.initial_metadata_allowed = True
        self.trailing_metadata = None
        self.compression_algorithm = None
        self.disable_next_compression = False

        # Final status of the RPC.
        self.code = None
        self.details = None
        self.statused = False
        self.aborted = False

        # RpcErrors raised into application code, and completion callbacks
        # (replaced by None once the RPC has finished).
        self.rpc_errors = []
        self.callbacks = []
108
109
def _raise_rpc_error(state):
    """Raises a new grpc.RpcError after recording it on `state`.

    Recording the instance lets the exception handlers in _call_behavior and
    _take_response_from_response_iterator recognise it and not report it as
    an application failure.
    """
    error = grpc.RpcError()
    state.rpc_errors.append(error)
    raise error
114
115
def _possibly_finish_call(state, token):
    """Retires the batch identified by `token` and reports whether the RPC ended.

    Callers hold state.condition. Once the RPC is inactive and no batches
    remain due, the callbacks list is detached (so add_callback() starts
    returning False) and (state, callbacks) is returned. Otherwise the result
    is (None, ()).
    """
    state.due.remove(token)
    if _is_rpc_state_active(state) or state.due:
        return None, ()
    finished_callbacks, state.callbacks = state.callbacks, None
    return state, finished_callbacks
124
125
def _send_status_from_server(state, token):
    """Builds the completion callback for a batch that sends the final status.

    The callback ignores its event and retires `token` under state.condition.
    """

    def on_status_sent(unused_event):
        with state.condition:
            return _possibly_finish_call(state, token)

    return on_status_sent
133
134
135def _get_initial_metadata(state, metadata):
136    with state.condition:
137        if state.compression_algorithm:
138            compression_metadata = (
139                _compression.compression_algorithm_to_metadata(
140                    state.compression_algorithm),)
141            if metadata is None:
142                return compression_metadata
143            else:
144                return compression_metadata + tuple(metadata)
145        else:
146            return metadata
147
148
def _get_initial_metadata_operation(state, metadata):
    """Wraps the effective initial metadata in a SendInitialMetadataOperation."""
    return cygrpc.SendInitialMetadataOperation(
        _get_initial_metadata(state, metadata), _EMPTY_FLAGS)
153
154
def _abort(state, call, code, details):
    """Ends the RPC with an error status.

    Callers hold state.condition. `code` and `details` are fallbacks: a code
    or details already recorded on `state` (e.g. via the servicer context)
    take precedence. If initial metadata has not been sent yet, it goes out
    in the same batch as the status.

    Does nothing when the client cancelled or a status has already been
    sent. A call carries a single status, so a second SendStatusFromServer
    batch must not be started (this can otherwise happen when, e.g., a
    non-blocking behavior finishes the stream and then raises).
    """
    if not _is_rpc_state_active(state):
        return
    effective_code = _abortion_code(state, code)
    effective_details = details if state.details is None else state.details
    if state.initial_metadata_allowed:
        operations = (
            _get_initial_metadata_operation(state, None),
            cygrpc.SendStatusFromServerOperation(state.trailing_metadata,
                                                 effective_code,
                                                 effective_details,
                                                 _EMPTY_FLAGS),
        )
        token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN
    else:
        operations = (cygrpc.SendStatusFromServerOperation(
            state.trailing_metadata, effective_code, effective_details,
            _EMPTY_FLAGS),)
        token = _SEND_STATUS_FROM_SERVER_TOKEN
    call.start_server_batch(operations, _send_status_from_server(state, token))
    state.statused = True
    state.due.add(token)
177
178
def _receive_close_on_server(state):
    """Builds the callback for the ReceiveCloseOnServer batch.

    The callback records whether the call was cancelled or merely closed,
    wakes all waiters and retires _RECEIVE_CLOSE_ON_SERVER_TOKEN.
    """

    def on_close(close_event):
        with state.condition:
            close_operation = close_event.batch_operations[0]
            if close_operation.cancelled():
                state.client = _CANCELLED
            elif state.client is _OPEN:
                # A clean close leaves an already closed/cancelled client as is.
                state.client = _CLOSED
            state.condition.notify_all()
            return _possibly_finish_call(state, _RECEIVE_CLOSE_ON_SERVER_TOKEN)

    return on_close
191
192
def _receive_message(state, call, request_deserializer):
    """Builds the callback for a ReceiveMessage batch.

    A None payload means the client sent no further message, so an open
    client is marked closed. Otherwise the payload is deserialized outside
    the lock. A None deserialization result aborts the RPC with INTERNAL; a
    real request is stored in state.request for the waiting reader. In every
    case waiters are notified and _RECEIVE_MESSAGE_TOKEN is retired.
    """

    def on_message(message_event):
        payload = _serialized_request(message_event)
        request = None
        if payload is not None:
            request = _common.deserialize(payload, request_deserializer)
        with state.condition:
            if payload is None:
                if state.client is _OPEN:
                    state.client = _CLOSED
            elif request is None:
                _abort(state, call, cygrpc.StatusCode.internal,
                       b'Exception deserializing request!')
            else:
                state.request = request
            state.condition.notify_all()
            return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)

    return on_message
216
217
def _send_initial_metadata(state):
    """Builds the callback that retires _SEND_INITIAL_METADATA_TOKEN."""

    def on_initial_metadata_sent(unused_event):
        with state.condition:
            return _possibly_finish_call(state, _SEND_INITIAL_METADATA_TOKEN)

    return on_initial_metadata_sent
225
226
def _send_message(state, token):
    """Builds the callback for a send-message batch identified by `token`.

    Unlike the other send callbacks it also wakes waiters, because
    _send_response blocks on state.condition until its token is retired.
    """

    def on_message_sent(unused_event):
        with state.condition:
            state.condition.notify_all()
            return _possibly_finish_call(state, token)

    return on_message_sent
235
236
class _Context(grpc.ServicerContext):
    """Server-side grpc.ServicerContext backed by an _RPCState.

    Every read or write of the shared _RPCState is done while holding
    state.condition. The completion-queue callbacks update the same state
    from other threads.
    """

    def __init__(self, rpc_event, state, request_deserializer):
        self._rpc_event = rpc_event
        self._state = state
        self._request_deserializer = request_deserializer

    def is_active(self):
        """Returns True until the client cancels or a status has been sent."""
        with self._state.condition:
            return _is_rpc_state_active(self._state)

    def time_remaining(self):
        """Returns the time left until the call deadline, clamped at 0.

        The deadline is compared against time.time(). NOTE(review):
        presumably seconds since the epoch — confirm against cygrpc.
        """
        return max(self._rpc_event.call_details.deadline - time.time(), 0)

    def cancel(self):
        """Cancels the underlying call."""
        self._rpc_event.call.cancel()

    def add_callback(self, callback):
        """Queues `callback` on the RPC state.

        The queued callbacks are handed out by _possibly_finish_call when the
        RPC finishes. Returns False (queuing nothing) if that already
        happened, True otherwise.
        """
        with self._state.condition:
            if self._state.callbacks is None:
                return False
            else:
                self._state.callbacks.append(callback)
                return True

    def disable_next_message_compression(self):
        """Requests that the next response message be sent uncompressed."""
        with self._state.condition:
            self._state.disable_next_compression = True

    def invocation_metadata(self):
        """Returns the metadata the client sent with the call."""
        return self._rpc_event.invocation_metadata

    def peer(self):
        """Returns the call's peer string, decoded to text."""
        return _common.decode(self._rpc_event.call.peer())

    def peer_identities(self):
        """Returns the peer identities reported by cygrpc for this call."""
        return cygrpc.peer_identities(self._rpc_event.call)

    def peer_identity_key(self):
        """Returns the decoded peer identity key, or None when there is none."""
        id_key = cygrpc.peer_identity_key(self._rpc_event.call)
        return id_key if id_key is None else _common.decode(id_key)

    def auth_context(self):
        """Returns the call's auth context with keys decoded to text."""
        return {
            _common.decode(key): value for key, value in six.iteritems(
                cygrpc.auth_context(self._rpc_event.call))
        }

    def set_compression(self, compression):
        """Selects the compression algorithm advertised in initial metadata."""
        with self._state.condition:
            self._state.compression_algorithm = compression

    def send_initial_metadata(self, initial_metadata):
        """Sends initial metadata now, prefixed with any compression metadata.

        Raises:
          grpc.RpcError: if the client has cancelled the call.
          ValueError: if initial metadata has already been sent.
        """
        with self._state.condition:
            if self._state.client is _CANCELLED:
                _raise_rpc_error(self._state)
            else:
                if self._state.initial_metadata_allowed:
                    operation = _get_initial_metadata_operation(
                        self._state, initial_metadata)
                    self._rpc_event.call.start_server_batch(
                        (operation,), _send_initial_metadata(self._state))
                    self._state.initial_metadata_allowed = False
                    self._state.due.add(_SEND_INITIAL_METADATA_TOKEN)
                else:
                    raise ValueError('Initial metadata no longer allowed!')

    def set_trailing_metadata(self, trailing_metadata):
        """Sets the trailing metadata sent with the final status."""
        with self._state.condition:
            self._state.trailing_metadata = trailing_metadata

    def abort(self, code, details):
        """Records an error status and raises to unwind the servicer.

        The framework catches the raised exception, sees state.aborted and
        sends `code`/`details` as the status. StatusCode.OK is not a valid
        abort code: it is logged and replaced by UNKNOWN with empty details.
        """
        # treat OK like other invalid arguments: fail the RPC
        if code == grpc.StatusCode.OK:
            _LOGGER.error(
                'abort() called with StatusCode.OK; returning UNKNOWN')
            code = grpc.StatusCode.UNKNOWN
            details = ''
        with self._state.condition:
            self._state.code = code
            self._state.details = _common.encode(details)
            self._state.aborted = True
            raise Exception()

    def abort_with_status(self, status):
        """Like abort(), additionally setting the status's trailing metadata."""
        # Take the lock as set_trailing_metadata() does. The RPC state is read
        # concurrently (e.g. by _abort/_status) under state.condition, so an
        # unlocked write here would race with them.
        with self._state.condition:
            self._state.trailing_metadata = status.trailing_metadata
        self.abort(status.code, status.details)

    def set_code(self, code):
        """Sets the status code reported when the RPC completes."""
        with self._state.condition:
            self._state.code = code

    def set_details(self, details):
        """Sets the status details reported when the RPC completes."""
        with self._state.condition:
            self._state.details = _common.encode(details)

    def _finalize_state(self):
        # No-op hook. NOTE(review): presumably invoked when the servicer
        # context from grpc._create_servicer_context exits — confirm.
        pass
335
336
class _RequestIterator(object):
    """Blocking iterator over the request messages of a client-streaming RPC.

    Each step starts one ReceiveMessage batch and then waits on
    state.condition. It stops waiting once the batch's callback
    (_receive_message) has stored a request or retired the batch, or once
    the client has cancelled.
    """

    def __init__(self, state, call, request_deserializer):
        self._state = state
        self._call = call
        self._request_deserializer = request_deserializer

    def _raise_or_start_receive_message(self):
        # Called with state.condition held (from _next).
        # Cancelled client -> RpcError; RPC no longer active -> end of
        # iteration; otherwise ask the client for its next message.
        if self._state.client is _CANCELLED:
            _raise_rpc_error(self._state)
        elif not _is_rpc_state_active(self._state):
            raise StopIteration()
        else:
            self._call.start_server_batch(
                (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),),
                _receive_message(self._state, self._call,
                                 self._request_deserializer))
            self._state.due.add(_RECEIVE_MESSAGE_TOKEN)

    def _look_for_request(self):
        # Called with state.condition held after a wakeup.
        # Returns the received request and clears it from the state. Returns
        # None while the receive batch is still outstanding (keep waiting).
        # Raises StopIteration when the batch completed without a request
        # (client finished sending, or the message was rejected).
        if self._state.client is _CANCELLED:
            _raise_rpc_error(self._state)
        elif (self._state.request is None and
              _RECEIVE_MESSAGE_TOKEN not in self._state.due):
            raise StopIteration()
        else:
            request = self._state.request
            self._state.request = None
            return request

        raise AssertionError()  # should never run

    def _next(self):
        with self._state.condition:
            self._raise_or_start_receive_message()
            while True:
                # wait() releases the lock so the receive callback can run;
                # loop to tolerate wakeups that did not deliver a request.
                self._state.condition.wait()
                request = self._look_for_request()
                if request is not None:
                    return request

    def __iter__(self):
        return self

    def __next__(self):
        return self._next()

    # Python 2 spelling of the iterator protocol.
    def next(self):
        return self._next()
386
387
def _unary_request(rpc_event, state, request_deserializer):
    """Returns a thunk that reads the single request of a unary-request RPC.

    The thunk runs on the handler thread. It returns the deserialized
    request, or None when there is nothing for the servicer to handle:

    * the RPC is already inactive when the thunk starts;
    * the RPC becomes inactive while waiting, because the client cancelled
      or a status was already sent (e.g. _receive_message aborted the RPC
      because the request could not be deserialized);
    * the client finishes without sending a request. In this case the RPC
      is aborted with UNIMPLEMENTED.
    """

    def unary_request():
        with state.condition:
            if not _is_rpc_state_active(state):
                return None
            rpc_event.call.start_server_batch(
                (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),),
                _receive_message(state, rpc_event.call, request_deserializer))
            state.due.add(_RECEIVE_MESSAGE_TOKEN)
            while True:
                # wait() releases the lock so the receive/close callbacks can
                # run; re-examine the state after every wakeup.
                state.condition.wait()
                if state.request is not None:
                    request = state.request
                    state.request = None
                    return request
                if not _is_rpc_state_active(state):
                    # Cancelled, or already statused. Aborting here would
                    # start a second status batch on the call.
                    return None
                if state.client is _CLOSED:
                    details = '"{}" requires exactly one request message.'.format(
                        rpc_event.call_details.method)
                    _abort(state, rpc_event.call,
                           cygrpc.StatusCode.unimplemented,
                           _common.encode(details))
                    return None

    return unary_request
418
419
def _call_behavior(rpc_event,
                   state,
                   behavior,
                   argument,
                   request_deserializer,
                   send_response_callback=None):
    """Invokes the servicer `behavior` inside a fresh servicer context.

    `send_response_callback`, when given, is passed as a third argument.
    Returns (result, True) on success. If the behavior raises, the RPC is
    aborted and (None, False) is returned. The abort is skipped when the
    exception is one of the RpcErrors this module raised into the
    application.
    """
    from grpc import _create_servicer_context
    with _create_servicer_context(rpc_event, state,
                                  request_deserializer) as context:
        extra_arguments = (() if send_response_callback is None else
                           (send_response_callback,))
        try:
            return behavior(argument, context, *extra_arguments), True
        except Exception as exception:  # pylint: disable=broad-except
            with state.condition:
                if state.aborted:
                    # context.abort() ran; the code/details it stored on the
                    # state take precedence over these fallbacks in _abort.
                    _abort(state, rpc_event.call, cygrpc.StatusCode.unknown,
                           b'RPC Aborted')
                elif exception not in state.rpc_errors:
                    details = 'Exception calling application: {}'.format(
                        exception)
                    _LOGGER.exception(details)
                    _abort(state, rpc_event.call, cygrpc.StatusCode.unknown,
                           _common.encode(details))
        return None, False
449
450
def _take_response_from_response_iterator(rpc_event, state, response_iterator):
    """Pulls the next response from a servicer's response iterator.

    Returns (response, True) for a response and (None, True) once the
    iterator is exhausted. If the iterator raises, the result is
    (None, False), and the RPC is aborted unless the exception is one of the
    RpcErrors this module raised into the application.
    """
    try:
        response = next(response_iterator)
    except StopIteration:
        return None, True
    except Exception as exception:  # pylint: disable=broad-except
        with state.condition:
            if state.aborted:
                _abort(state, rpc_event.call, cygrpc.StatusCode.unknown,
                       b'RPC Aborted')
            elif exception not in state.rpc_errors:
                details = 'Exception iterating responses: {}'.format(exception)
                _LOGGER.exception(details)
                _abort(state, rpc_event.call, cygrpc.StatusCode.unknown,
                       _common.encode(details))
        return None, False
    else:
        return response, True
467
468
def _serialize_response(rpc_event, state, response, response_serializer):
    """Serializes `response`, aborting the RPC with INTERNAL on failure.

    Returns the serialized form, or None when _common.serialize produced
    None (the RPC has then been aborted).
    """
    serialized_response = _common.serialize(response, response_serializer)
    if serialized_response is not None:
        return serialized_response
    with state.condition:
        _abort(state, rpc_event.call, cygrpc.StatusCode.internal,
               b'Failed to serialize response!')
    return None
478
479
def _get_send_message_op_flags_from_state(state):
    """Returns the write flags for the next outgoing message.

    Returns no_compress if disable_next_message_compression() was requested,
    and no flags otherwise.
    """
    return (cygrpc.WriteFlag.no_compress
            if state.disable_next_compression else _EMPTY_FLAGS)
485
486
487def _reset_per_message_state(state):
488    with state.condition:
489        state.disable_next_compression = False
490
491
def _send_response(rpc_event, state, serialized_response):
    """Sends one serialized response message and waits for the send to finish.

    Returns False without sending if the RPC is no longer active. Otherwise
    it starts a SendMessage batch, preceded by initial metadata if that has
    not been sent yet. It then blocks on state.condition until the
    _send_message callback retires the batch's token, and returns whether
    the RPC is still active.
    """
    with state.condition:
        if not _is_rpc_state_active(state):
            return False
        else:
            if state.initial_metadata_allowed:
                # First outgoing message: send initial metadata in the same batch.
                operations = (
                    _get_initial_metadata_operation(state, None),
                    cygrpc.SendMessageOperation(
                        serialized_response,
                        _get_send_message_op_flags_from_state(state)),
                )
                state.initial_metadata_allowed = False
                token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN
            else:
                operations = (cygrpc.SendMessageOperation(
                    serialized_response,
                    _get_send_message_op_flags_from_state(state)),)
                token = _SEND_MESSAGE_TOKEN
            rpc_event.call.start_server_batch(operations,
                                              _send_message(state, token))
            state.due.add(token)
            # Per-message options (e.g. no-compress) applied to this message only.
            _reset_per_message_state(state)
            while True:
                # wait() releases the lock so the completion callback can run;
                # loop until this batch's token has been retired.
                state.condition.wait()
                if token not in state.due:
                    return _is_rpc_state_active(state)
519
520
def _status(rpc_event, state, serialized_response):
    """Finishes the RPC by sending its final status.

    A single batch carries the status (the application's code/details, or
    OK), the initial metadata if it has not been sent yet, and
    `serialized_response` if it is not None.

    Does nothing if the client cancelled or a status was already sent (e.g.
    by _abort after a request failed to deserialize). A call carries a
    single status, so a second status batch must not be started.
    """
    with state.condition:
        if not _is_rpc_state_active(state):
            return
        code = _completion_code(state)
        details = _details(state)
        operations = [
            cygrpc.SendStatusFromServerOperation(state.trailing_metadata,
                                                 code, details, _EMPTY_FLAGS),
        ]
        if state.initial_metadata_allowed:
            operations.append(_get_initial_metadata_operation(state, None))
        if serialized_response is not None:
            operations.append(
                cygrpc.SendMessageOperation(
                    serialized_response,
                    _get_send_message_op_flags_from_state(state)))
        rpc_event.call.start_server_batch(
            operations,
            _send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN))
        state.statused = True
        _reset_per_message_state(state)
        state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN)
544
545
def _unary_response_in_pool(rpc_event, state, behavior, argument_thunk,
                            request_deserializer, response_serializer):
    """Thread-pool entry point for RPCs with a single response.

    Obtains the argument, runs the behavior, serializes the result and
    finishes the RPC with status. A stage that yields None/False has already
    handled the RPC (aborted it or found it inactive), so the remaining
    stages are skipped. The cygrpc context for `rpc_event` is installed for
    the duration of the call.
    """
    cygrpc.install_context_from_request_call_event(rpc_event)
    try:
        argument = argument_thunk()
        if argument is None:
            return
        response, proceed = _call_behavior(rpc_event, state, behavior,
                                           argument, request_deserializer)
        if not proceed:
            return
        serialized_response = _serialize_response(rpc_event, state, response,
                                                  response_serializer)
        if serialized_response is None:
            return
        _status(rpc_event, state, serialized_response)
    finally:
        cygrpc.uninstall_context()
561
562
def _stream_response_in_pool(rpc_event, state, behavior, argument_thunk,
                             request_deserializer, response_serializer):
    """Thread-pool entry point for response-streaming RPCs.

    Behaviors flagged `experimental_non_blocking` receive a send_response
    callback and push responses themselves. Other behaviors return an
    iterator that is drained here. Passing None to send_response finishes
    the RPC with its status.
    """
    cygrpc.install_context_from_request_call_event(rpc_event)

    def send_response(response):
        # None marks the end of the stream.
        if response is None:
            _status(rpc_event, state, None)
            return
        serialized_response = _serialize_response(rpc_event, state, response,
                                                  response_serializer)
        if serialized_response is not None:
            _send_response(rpc_event, state, serialized_response)

    try:
        argument = argument_thunk()
        if argument is None:
            return
        non_blocking = (hasattr(behavior, 'experimental_non_blocking') and
                        behavior.experimental_non_blocking)
        if non_blocking:
            _call_behavior(rpc_event,
                           state,
                           behavior,
                           argument,
                           request_deserializer,
                           send_response_callback=send_response)
            return
        response_iterator, proceed = _call_behavior(rpc_event, state, behavior,
                                                    argument,
                                                    request_deserializer)
        if proceed:
            _send_message_callback_to_blocking_iterator_adapter(
                rpc_event, state, send_response, response_iterator)
    finally:
        cygrpc.uninstall_context()
596
597
def _is_rpc_state_active(state):
    """True while the client has not cancelled and no status has been sent."""
    return not (state.client is _CANCELLED or state.statused)
600
601
def _send_message_callback_to_blocking_iterator_adapter(rpc_event, state,
                                                        send_response_callback,
                                                        response_iterator):
    """Feeds a blocking response iterator into `send_response_callback`.

    Each response, plus the None signalling exhaustion, is handed to the
    callback. The loop ends when the RPC is no longer active after a send,
    or when taking a response failed (the RPC was then handled by
    _take_response_from_response_iterator).
    """
    while True:
        response, proceed = _take_response_from_response_iterator(
            rpc_event, state, response_iterator)
        if not proceed:
            return
        send_response_callback(response)
        if not _is_rpc_state_active(state):
            return
614
615
616def _select_thread_pool_for_behavior(behavior, default_thread_pool):
617    if hasattr(behavior, 'experimental_thread_pool') and isinstance(
618            behavior.experimental_thread_pool, futures.ThreadPoolExecutor):
619        return behavior.experimental_thread_pool
620    else:
621        return default_thread_pool
622
623
def _handle_unary_unary(rpc_event, state, method_handler, default_thread_pool):
    """Schedules a unary-request / unary-response RPC; returns its future."""
    request_deserializer = method_handler.request_deserializer
    behavior = method_handler.unary_unary
    argument_thunk = _unary_request(rpc_event, state, request_deserializer)
    pool = _select_thread_pool_for_behavior(behavior, default_thread_pool)
    return pool.submit(_unary_response_in_pool, rpc_event, state, behavior,
                       argument_thunk, request_deserializer,
                       method_handler.response_serializer)
633
634
def _handle_unary_stream(rpc_event, state, method_handler, default_thread_pool):
    """Schedules a unary-request / streaming-response RPC; returns its future."""
    request_deserializer = method_handler.request_deserializer
    behavior = method_handler.unary_stream
    argument_thunk = _unary_request(rpc_event, state, request_deserializer)
    pool = _select_thread_pool_for_behavior(behavior, default_thread_pool)
    return pool.submit(_stream_response_in_pool, rpc_event, state, behavior,
                       argument_thunk, request_deserializer,
                       method_handler.response_serializer)
644
645
def _handle_stream_unary(rpc_event, state, method_handler, default_thread_pool):
    """Schedules a streaming-request / unary-response RPC; returns its future."""
    request_deserializer = method_handler.request_deserializer
    behavior = method_handler.stream_unary
    request_iterator = _RequestIterator(state, rpc_event.call,
                                        request_deserializer)
    pool = _select_thread_pool_for_behavior(behavior, default_thread_pool)
    return pool.submit(_unary_response_in_pool, rpc_event, state, behavior,
                       lambda: request_iterator, request_deserializer,
                       method_handler.response_serializer)
656
657
def _handle_stream_stream(rpc_event, state, method_handler,
                          default_thread_pool):
    """Schedules a bidirectional-streaming RPC; returns its future."""
    request_deserializer = method_handler.request_deserializer
    behavior = method_handler.stream_stream
    request_iterator = _RequestIterator(state, rpc_event.call,
                                        request_deserializer)
    pool = _select_thread_pool_for_behavior(behavior, default_thread_pool)
    return pool.submit(_stream_response_in_pool, rpc_event, state, behavior,
                       lambda: request_iterator, request_deserializer,
                       method_handler.response_serializer)
669
670
def _find_method_handler(rpc_event, generic_handlers, interceptor_pipeline):
    """Resolves the method handler for an incoming call.

    The generic handlers are queried in order and the first non-None result
    wins (None if none match). With an interceptor pipeline configured, the
    lookup is delegated to interceptor_pipeline.execute.
    """

    def query_handlers(handler_call_details):
        # Lazy generator: stop calling service() after the first match.
        candidates = (generic_handler.service(handler_call_details)
                      for generic_handler in generic_handlers)
        return next(
            (handler for handler in candidates if handler is not None), None)

    handler_call_details = _HandlerCallDetails(
        _common.decode(rpc_event.call_details.method),
        rpc_event.invocation_metadata)

    if interceptor_pipeline is None:
        return query_handlers(handler_call_details)
    return interceptor_pipeline.execute(query_handlers, handler_call_details)
689
690
def _reject_rpc(rpc_event, status, details):
    """Immediately fails an incoming call with `status` and `details`.

    A throwaway _RPCState supplies the initial metadata. One batch sends the
    initial metadata, receives the close and sends the status. The batch
    callback yields (rpc_state, ()), the same (state, callbacks) shape that
    _possibly_finish_call produces. The state is returned to the caller.
    """
    rpc_state = _RPCState()
    initial_metadata = _get_initial_metadata_operation(rpc_state, None)
    receive_close = cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS)
    send_status = cygrpc.SendStatusFromServerOperation(None, status, details,
                                                       _EMPTY_FLAGS)

    def on_batch_complete(ignored_event):
        return rpc_state, ()

    rpc_event.call.start_server_batch(
        (initial_metadata, receive_close, send_status), on_batch_complete)
    return rpc_state
704
705
def _handle_with_method_handler(rpc_event, method_handler, thread_pool):
    """Starts serving an accepted call with `method_handler`.

    Starts the ReceiveCloseOnServer batch, then dispatches to the handler
    matching the method's request/response streaming shape. Returns
    (state, future), where future is the thread-pool submission.
    """
    state = _RPCState()
    with state.condition:
        rpc_event.call.start_server_batch(
            (cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),),
            _receive_close_on_server(state))
        state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN)
        if method_handler.request_streaming:
            handle = (_handle_stream_stream if method_handler.response_streaming
                      else _handle_stream_unary)
        else:
            handle = (_handle_unary_stream if method_handler.response_streaming
                      else _handle_unary_unary)
        return state, handle(rpc_event, state, method_handler, thread_pool)
727
728
def _handle_call(rpc_event, generic_handlers, interceptor_pipeline, thread_pool,
                 concurrency_exceeded):
    """Dispatches one request-call completion event.

    Returns (rpc_state, future):
      * (None, None) for an unsuccessful event or one without a method;
      * (state, None) when the call is rejected: handler lookup raised, no
        handler matched, or the concurrency limit was exceeded;
      * otherwise the result of _handle_with_method_handler.
    """
    if not rpc_event.success or rpc_event.call_details.method is None:
        return None, None
    try:
        method_handler = _find_method_handler(rpc_event, generic_handlers,
                                              interceptor_pipeline)
    except Exception as exception:  # pylint: disable=broad-except
        details = 'Exception servicing handler: {}'.format(exception)
        _LOGGER.exception(details)
        return _reject_rpc(rpc_event, cygrpc.StatusCode.unknown,
                           b'Error in service handler!'), None
    if method_handler is None:
        return _reject_rpc(rpc_event, cygrpc.StatusCode.unimplemented,
                           b'Method not found!'), None
    if concurrency_exceeded:
        return _reject_rpc(rpc_event, cygrpc.StatusCode.resource_exhausted,
                           b'Concurrent RPC limit exceeded!'), None
    return _handle_with_method_handler(rpc_event, method_handler, thread_pool)
753
754
755@enum.unique
756class _ServerStage(enum.Enum):
757    STOPPED = 'stopped'
758    STARTED = 'started'
759    GRACE = 'grace'
760
761
class _ServerState(object):
    """Mutable state of one server; most accesses in this file hold `lock`."""

    # pylint: disable=too-many-arguments
    def __init__(self, completion_queue, server, generic_handlers,
                 interceptor_pipeline, thread_pool, maximum_concurrent_rpcs):
        self.lock = threading.RLock()

        # Core objects.
        self.server = server
        self.completion_queue = completion_queue

        # Dispatch configuration. The handlers are copied into a list so that
        # _add_generic_handlers can extend it in place.
        self.generic_handlers = list(generic_handlers)
        self.interceptor_pipeline = interceptor_pipeline
        self.thread_pool = thread_pool

        # Concurrency limiting; a None maximum disables the limit.
        self.maximum_concurrent_rpcs = maximum_concurrent_rpcs
        self.active_rpc_count = 0

        # Lifecycle; every event in shutdown_events is set by _stop_serving.
        self.stage = _ServerStage.STOPPED
        self.termination_event = threading.Event()
        self.shutdown_events = [self.termination_event]

        # TODO(https://github.com/grpc/grpc/issues/6597): eliminate these fields.
        self.rpc_states = set()
        self.due = set()

        # A "volatile" flag to interrupt the daemon serving thread
        self.server_deallocated = False
785
786
787def _add_generic_handlers(state, generic_handlers):
788    with state.lock:
789        state.generic_handlers.extend(generic_handlers)
790
791
792def _add_insecure_port(state, address):
793    with state.lock:
794        return state.server.add_http2_port(address)
795
796
797def _add_secure_port(state, address, server_credentials):
798    with state.lock:
799        return state.server.add_http2_port(address,
800                                           server_credentials._credentials)
801
802
def _request_call(state):
    """Ask the core server for the next incoming call.

    ``state.completion_queue`` is passed for both queue arguments, and the
    request is tagged with _REQUEST_CALL_TAG. The tag is recorded in
    ``state.due`` as an outstanding completion. Both callers in this module
    invoke this while holding ``state.lock``.
    """
    queue = state.completion_queue
    state.server.request_call(queue, queue, _REQUEST_CALL_TAG)
    state.due.add(_REQUEST_CALL_TAG)
807
808
# TODO(https://github.com/grpc/grpc/issues/6597): delete this function.
def _stop_serving(state):
    """Complete shutdown if no RPC states or completions are outstanding.

    Callers in this module hold ``state.lock`` while calling this.

    Returns:
      True after destroying the core server, setting every shutdown event and
      moving the stage to STOPPED; False (changing nothing) while
      ``state.rpc_states`` or ``state.due`` is non-empty.
    """
    if state.rpc_states or state.due:
        return False
    state.server.destroy()
    for event in state.shutdown_events:
        event.set()
    state.stage = _ServerStage.STOPPED
    return True
819
820
821def _on_call_completed(state):
822    with state.lock:
823        state.active_rpc_count -= 1
824
825
def _process_event_and_continue(state, event):
    """Handle one completion-queue event and say whether serving should go on.

    ``event.tag`` selects the handling:

    * ``_SHUTDOWN_TAG``: the completion of the ``state.server.shutdown``
      request made by ``_begin_shutdown_once``.
    * ``_REQUEST_CALL_TAG``: the completion of a ``request_call`` made by
      ``_request_call``; the event is passed to ``_handle_call``.
    * anything else: the tag is called with the event and must return an
      ``(rpc_state, callbacks)`` pair.

    Args:
      state: The _ServerState of the server being served.
      event: An event returned by ``state.completion_queue.poll``.

    Returns:
      False once ``_stop_serving`` reports that the server has stopped, so
      the serving loop should exit; True otherwise.
    """
    should_continue = True
    if event.tag is _SHUTDOWN_TAG:
        with state.lock:
            # The shutdown completion is no longer outstanding.
            state.due.remove(_SHUTDOWN_TAG)
            if _stop_serving(state):
                should_continue = False
    elif event.tag is _REQUEST_CALL_TAG:
        with state.lock:
            state.due.remove(_REQUEST_CALL_TAG)
            # The limit only applies when maximum_concurrent_rpcs was given.
            concurrency_exceeded = (
                state.maximum_concurrent_rpcs is not None and
                state.active_rpc_count >= state.maximum_concurrent_rpcs)
            rpc_state, rpc_future = _handle_call(event, state.generic_handlers,
                                                 state.interceptor_pipeline,
                                                 state.thread_pool,
                                                 concurrency_exceeded)
            # Tracked RPC states keep _stop_serving from finishing shutdown.
            if rpc_state is not None:
                state.rpc_states.add(rpc_state)
            # Each returned future counts toward the concurrency limit until
            # it is done; the done-callback decrements active_rpc_count.
            if rpc_future is not None:
                state.active_rpc_count += 1
                rpc_future.add_done_callback(
                    lambda unused_future: _on_call_completed(state))
            # While serving, request the next call. Once shutdown has begun,
            # stop re-requesting and check whether the server can now stop.
            if state.stage is _ServerStage.STARTED:
                _request_call(state)
            elif _stop_serving(state):
                should_continue = False
    else:
        rpc_state, callbacks = event.tag(event)
        # Callbacks run without holding state.lock; an exception from one is
        # logged so it does not propagate out of this function.
        for callback in callbacks:
            try:
                callback()
            except Exception:  # pylint: disable=broad-except
                _LOGGER.exception('Exception calling callback!')
        if rpc_state is not None:
            with state.lock:
                # Stop tracking the RPC state the tag returned, then see
                # whether that was the last thing shutdown was waiting on.
                state.rpc_states.remove(rpc_state)
                if _stop_serving(state):
                    should_continue = False
    return should_continue
866
867
def _serve(state):
    """Polling loop run by the daemon serving thread started in ``_start``.

    Each poll is given a deadline _DEALLOCATED_SERVER_CHECK_PERIOD_S seconds
    in the future. As a result, ``state.server_deallocated`` (set from
    ``_Server.__del__``) is re-checked even when no events arrive. The loop
    ends when ``_process_event_and_continue`` returns False.
    """
    while True:
        deadline = time.time() + _DEALLOCATED_SERVER_CHECK_PERIOD_S
        event = state.completion_queue.poll(deadline)
        if state.server_deallocated:
            # The owning _Server was collected; this is a no-op unless the
            # server is still STARTED.
            _begin_shutdown_once(state)
        if (event.completion_type != cygrpc.CompletionType.queue_timeout and
                not _process_event_and_continue(state, event)):
            return
        # Drop our reference to the event ~before~ polling again: if it
        # references a shutdown Call object, keeping it can induce spinlock.
        event = None
881
882
def _begin_shutdown_once(state):
    """Request core-server shutdown unless it was already requested.

    Only a STARTED server is affected. It moves to GRACE, and the shutdown
    completion is recorded as outstanding in ``state.due``. Any further call
    (or a call on a STOPPED server) does nothing.
    """
    with state.lock:
        if state.stage is not _ServerStage.STARTED:
            return
        state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG)
        state.stage = _ServerStage.GRACE
        state.due.add(_SHUTDOWN_TAG)
889
890
def _stop(state, grace):
    """Initiate server shutdown and return an Event set once it has stopped.

    Args:
      state: The _ServerState of the server to stop.
      grace: None to cancel all calls immediately and block until the server
        has stopped. Otherwise, a timeout in seconds: a helper thread cancels
        all calls after that timeout (or as soon as shutdown completes,
        whichever is first), and this function returns without waiting.

    Returns:
      A threading.Event that is set once the server has stopped. If the
      server is already STOPPED, a fresh Event that is already set.
    """
    with state.lock:
        if state.stage is _ServerStage.STOPPED:
            already_stopped = threading.Event()
            already_stopped.set()
            return already_stopped

        _begin_shutdown_once(state)
        stopped = threading.Event()
        # _stop_serving sets every registered shutdown event.
        state.shutdown_events.append(stopped)

        if grace is not None:

            def cancel_all_calls_after_grace():
                # Returns early if shutdown completes first; the calls are
                # cancelled in either case.
                stopped.wait(timeout=grace)
                with state.lock:
                    state.server.cancel_all_calls()

            threading.Thread(target=cancel_all_calls_after_grace).start()
            return stopped

        state.server.cancel_all_calls()
    # Wait outside the lock: the serving thread must acquire it to finish.
    stopped.wait()
    return stopped
915
916
def _start(state):
    """Start the core server and spawn the daemon thread that runs _serve.

    Raises:
      ValueError: If the server is not in the STOPPED stage.
    """
    with state.lock:
        if state.stage is not _ServerStage.STOPPED:
            raise ValueError('Cannot start already-started server!')
        state.server.start()
        state.stage = _ServerStage.STARTED
        # Request the first incoming call; afterwards
        # _process_event_and_continue requests another after each call
        # while the server stays STARTED.
        _request_call(state)

        serving_thread = threading.Thread(target=_serve, args=(state,))
        # Daemon, so the serving loop alone does not keep the process alive.
        serving_thread.daemon = True
        serving_thread.start()
928
929
930def _validate_generic_rpc_handlers(generic_rpc_handlers):
931    for generic_rpc_handler in generic_rpc_handlers:
932        service_attribute = getattr(generic_rpc_handler, 'service', None)
933        if service_attribute is None:
934            raise AttributeError(
935                '"{}" must conform to grpc.GenericRpcHandler type but does '
936                'not have "service" method!'.format(generic_rpc_handler))
937
938
def _augment_options(base_options, compression):
    """Return ``base_options`` as a tuple followed by the compression option.

    Args:
      base_options: An iterable of channel options supplied by the caller.
      compression: Passed to ``_compression.create_channel_option``, whose
        result is appended after ``base_options``.
    """
    compression_args = _compression.create_channel_option(compression)
    merged = tuple(base_options)
    merged += compression_args
    return merged
942
943
class _Server(grpc.Server):
    """The grpc.Server implementation built by create_server.

    Every public method forwards to a module-level helper operating on a
    _ServerState. The daemon serving thread receives only that state (see
    _start), and ``__del__`` signals it through a flag on the state.
    """

    # pylint: disable=too-many-arguments
    def __init__(self, thread_pool, generic_handlers, interceptors, options,
                 maximum_concurrent_rpcs, compression, xds):
        completion_queue = cygrpc.CompletionQueue()
        server = cygrpc.Server(_augment_options(options, compression), xds)
        # _serve polls this queue for all of the server's events.
        server.register_completion_queue(completion_queue)
        self._state = _ServerState(completion_queue, server, generic_handlers,
                                   _interceptor.service_pipeline(interceptors),
                                   thread_pool, maximum_concurrent_rpcs)

    def add_generic_rpc_handlers(self, generic_rpc_handlers):
        """Validate and register additional generic RPC handlers.

        Args:
          generic_rpc_handlers: Any iterable of grpc.GenericRpcHandler,
            including a one-shot iterator or generator.

        Raises:
          AttributeError: If a handler lacks a usable ``service`` method.
        """
        # Materialize once: validation iterates the handlers, so a one-shot
        # iterator would otherwise be exhausted before being registered and
        # no handlers would be added at all.
        generic_rpc_handlers = tuple(generic_rpc_handlers)
        _validate_generic_rpc_handlers(generic_rpc_handlers)
        _add_generic_handlers(self._state, generic_rpc_handlers)

    def add_insecure_port(self, address):
        """Bind ``address`` without transport security.

        Returns the result of ``_common.validate_port_binding_result`` for
        the binding (presumably the bound port number).
        """
        return _common.validate_port_binding_result(
            address, _add_insecure_port(self._state, _common.encode(address)))

    def add_secure_port(self, address, server_credentials):
        """Bind ``address`` secured with ``server_credentials``.

        Returns the result of ``_common.validate_port_binding_result`` for
        the binding (presumably the bound port number).
        """
        return _common.validate_port_binding_result(
            address,
            _add_secure_port(self._state, _common.encode(address),
                             server_credentials))

    def start(self):
        """Start serving; raises ValueError if the server is not STOPPED."""
        _start(self._state)

    def wait_for_termination(self, timeout=None):
        """Block until the server terminates or ``timeout`` elapses."""
        # NOTE(https://bugs.python.org/issue35935)
        # Remove this workaround once threading.Event.wait() is working with
        # CTRL+C across platforms.
        return _common.wait(self._state.termination_event.wait,
                            self._state.termination_event.is_set,
                            timeout=timeout)

    def stop(self, grace):
        """Stop the server; returns an Event set once it has stopped."""
        return _stop(self._state, grace)

    def __del__(self):
        if hasattr(self, '_state'):
            # We can not grab a lock in __del__(), so set a flag to signal the
            # serving daemon thread (if it exists) to initiate shutdown.
            self._state.server_deallocated = True
989
990
def create_server(thread_pool, generic_rpc_handlers, interceptors, options,
                  maximum_concurrent_rpcs, compression, xds):
    """Validate the generic RPC handlers and build a new, unstarted _Server.

    ``generic_rpc_handlers`` may be any iterable. It is materialized into a
    tuple first, so that validating it cannot exhaust a one-shot iterator
    before the handlers are registered with the server state.

    Raises:
      AttributeError: If a handler does not conform to grpc.GenericRpcHandler.
    """
    generic_rpc_handlers = tuple(generic_rpc_handlers)
    _validate_generic_rpc_handlers(generic_rpc_handlers)
    return _Server(thread_pool, generic_rpc_handlers, interceptors, options,
                   maximum_concurrent_rpcs, compression, xds)
996