• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Service-side implementation of gRPC Python."""
15
16import collections
17import enum
18import logging
19import threading
20import time
21
22import six
23
24import grpc
25from grpc import _common
26from grpc import _interceptor
27from grpc._cython import cygrpc
28from grpc.framework.foundation import callable_util
29
30logging.basicConfig()
31_LOGGER = logging.getLogger(__name__)
32
33_SHUTDOWN_TAG = 'shutdown'
34_REQUEST_CALL_TAG = 'request_call'
35
36_RECEIVE_CLOSE_ON_SERVER_TOKEN = 'receive_close_on_server'
37_SEND_INITIAL_METADATA_TOKEN = 'send_initial_metadata'
38_RECEIVE_MESSAGE_TOKEN = 'receive_message'
39_SEND_MESSAGE_TOKEN = 'send_message'
40_SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN = (
41    'send_initial_metadata * send_message')
42_SEND_STATUS_FROM_SERVER_TOKEN = 'send_status_from_server'
43_SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN = (
44    'send_initial_metadata * send_status_from_server')
45
46_OPEN = 'open'
47_CLOSED = 'closed'
48_CANCELLED = 'cancelled'
49
50_EMPTY_FLAGS = 0
51
52_UNEXPECTED_EXIT_SERVER_GRACE = 1.0
53
54
55def _serialized_request(request_event):
56    return request_event.batch_operations[0].message()
57
58
def _application_code(code):
    """Translate an application-level status code to its cygrpc equivalent.

    Codes with no entry in the mapping table fall back to
    ``cygrpc.StatusCode.unknown``.
    """
    mapped = _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE.get(code)
    if mapped is None:
        return cygrpc.StatusCode.unknown
    return mapped
62
63
def _completion_code(state):
    """Return the cygrpc status code for a normally completing RPC.

    This is OK unless the application set a code on ``state``.
    """
    code = state.code
    return cygrpc.StatusCode.ok if code is None else _application_code(code)
69
70
71def _abortion_code(state, code):
72    if state.code is None:
73        return code
74    else:
75        return _application_code(state.code)
76
77
78def _details(state):
79    return b'' if state.details is None else state.details
80
81
class _HandlerCallDetails(
        collections.namedtuple('_HandlerCallDetails',
                               ('method', 'invocation_metadata')),
        grpc.HandlerCallDetails):
    """Immutable grpc.HandlerCallDetails given to generic RPC handlers.

    Attributes:
      method: The RPC's method name, decoded from the call details.
      invocation_metadata: The metadata supplied with the call.
    """
88
89
class _RPCState(object):
    """Per-RPC bookkeeping shared by the handler thread and batch callbacks.

    Code in this module reads and writes these fields while holding
    ``condition``.
    """

    def __init__(self):
        self.condition = threading.Condition()
        # Tokens of batches that have been started but not yet completed.
        self.due = set()
        # Latest deserialized request not yet consumed by the handler.
        self.request = None
        # One of _OPEN, _CLOSED or _CANCELLED.
        self.client = _OPEN
        # True until initial metadata has been sent.
        self.initial_metadata_allowed = True
        self.disable_next_compression = False
        self.trailing_metadata = None
        # Application-set status code and (encoded) details, if any.
        self.code = None
        self.details = None
        # True once a send-status batch has been started.
        self.statused = False
        # RpcError instances raised into application code by the framework.
        self.rpc_errors = []
        # Completion callbacks; replaced by None once the RPC has finished.
        self.callbacks = []
        # The exception raised by _Context.abort, if it was called.
        self.abortion = None
106
107
def _raise_rpc_error(state):
    """Raise a new grpc.RpcError and record it on ``state``.

    Recording the instance lets the exception handlers recognize it if it
    propagates out of application code, so it is not reported as an
    application failure.
    """
    error = grpc.RpcError()
    state.rpc_errors.append(error)
    raise error
112
113
def _possibly_finish_call(state, token):
    """Mark batch ``token`` complete and report whether the RPC has finished.

    Callers hold ``state.condition``. When the client has cancelled or a
    status has been sent, and no batch is still outstanding, this detaches
    the callbacks list (so later add_callback calls are refused) and returns
    ``(state, callbacks)``. Otherwise it returns ``(None, ())``.
    """
    state.due.remove(token)
    finished = state.client is _CANCELLED or state.statused
    if not finished or state.due:
        return None, ()
    callbacks, state.callbacks = state.callbacks, None
    return state, callbacks
122
123
def _send_status_from_server(state, token):
    """Build the completion callback for a batch that sent the RPC's status."""

    def on_status_sent(unused_event):
        with state.condition:
            return _possibly_finish_call(state, token)

    return on_status_sent
131
132
def _abort(state, call, code, details):
    """Terminate the RPC by sending a final status to the client.

    Callers hold ``state.condition``. Nothing is sent if the client has
    cancelled, or if a status has already been sent. A call carries a single
    status, and a second status batch would add a token to ``state.due``
    that (presumably, since core should reject the batch) never completes,
    so the RPC could never finish.

    Args:
      state: The _RPCState of the RPC.
      call: The call on which to start the status batch.
      code: The cygrpc status code to use unless the application set one.
      details: Status details (bytes) used unless the application set some.
    """
    if state.client is _CANCELLED or state.statused:
        return
    effective_code = _abortion_code(state, code)
    effective_details = details if state.details is None else state.details
    send_status = cygrpc.SendStatusFromServerOperation(
        state.trailing_metadata, effective_code, effective_details,
        _EMPTY_FLAGS)
    if state.initial_metadata_allowed:
        # No initial metadata was sent yet; it has to go out with the status.
        operations = (
            cygrpc.SendInitialMetadataOperation(None, _EMPTY_FLAGS),
            send_status,
        )
        token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN
        # Initial metadata may only be sent once per call.
        state.initial_metadata_allowed = False
    else:
        operations = (send_status,)
        token = _SEND_STATUS_FROM_SERVER_TOKEN
    call.start_server_batch(operations, _send_status_from_server(state, token))
    state.statused = True
    state.due.add(token)
154
155
def _receive_close_on_server(state):
    """Build the completion callback for the receive-close-on-server batch.

    The callback records whether the client cancelled or simply closed, and
    wakes every thread waiting on ``state.condition``.
    """

    def on_close(event):
        with state.condition:
            if event.batch_operations[0].cancelled():
                state.client = _CANCELLED
            elif state.client is _OPEN:
                state.client = _CLOSED
            state.condition.notify_all()
            return _possibly_finish_call(state, _RECEIVE_CLOSE_ON_SERVER_TOKEN)

    return on_close
168
169
def _receive_message(state, call, request_deserializer):
    """Build the completion callback for a receive-message batch.

    A None payload marks the client's stream as closed. A payload that fails
    to deserialize aborts the RPC (INTERNAL unless the application already
    set a code). In every case waiters on ``state.condition`` are woken.
    """

    def on_message(event):
        payload = _serialized_request(event)
        if payload is None:
            with state.condition:
                if state.client is _OPEN:
                    state.client = _CLOSED
                state.condition.notify_all()
                return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)
        # Deserialization happens outside the lock; the deserializer comes
        # from the method handler (presumably application-supplied code).
        request = _common.deserialize(payload, request_deserializer)
        with state.condition:
            if request is None:
                _abort(state, call, cygrpc.StatusCode.internal,
                       b'Exception deserializing request!')
            else:
                state.request = request
            state.condition.notify_all()
            return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)

    return on_message
193
194
def _send_initial_metadata(state):
    """Build the completion callback for a send-initial-metadata batch."""

    def on_initial_metadata_sent(unused_event):
        with state.condition:
            return _possibly_finish_call(state, _SEND_INITIAL_METADATA_TOKEN)

    return on_initial_metadata_sent
202
203
def _send_message(state, token):
    """Build the completion callback for a batch that sent a response message.

    Waiters are notified so that _send_response can observe the completion.
    """

    def on_message_sent(unused_event):
        with state.condition:
            state.condition.notify_all()
            return _possibly_finish_call(state, token)

    return on_message_sent
212
213
class _Context(grpc.ServicerContext):
    """grpc.ServicerContext handed to application servicer methods.

    State changes go through the RPC's shared _RPCState while holding its
    condition, so the framework threads that send messages and status see
    them.
    """

    def __init__(self, rpc_event, state, request_deserializer):
        self._rpc_event = rpc_event
        self._state = state
        self._request_deserializer = request_deserializer

    def is_active(self):
        state = self._state
        with state.condition:
            return not (state.client is _CANCELLED or state.statused)

    def time_remaining(self):
        remaining = self._rpc_event.call_details.deadline - time.time()
        return max(remaining, 0)

    def cancel(self):
        self._rpc_event.call.cancel()

    def add_callback(self, callback):
        state = self._state
        with state.condition:
            # callbacks becomes None once the RPC has finished.
            if state.callbacks is None:
                return False
            state.callbacks.append(callback)
            return True

    def disable_next_message_compression(self):
        with self._state.condition:
            self._state.disable_next_compression = True

    def invocation_metadata(self):
        return self._rpc_event.invocation_metadata

    def peer(self):
        return _common.decode(self._rpc_event.call.peer())

    def peer_identities(self):
        return cygrpc.peer_identities(self._rpc_event.call)

    def peer_identity_key(self):
        id_key = cygrpc.peer_identity_key(self._rpc_event.call)
        if id_key is None:
            return None
        return _common.decode(id_key)

    def auth_context(self):
        raw_context = cygrpc.auth_context(self._rpc_event.call)
        return {
            _common.decode(key): value
            for key, value in six.iteritems(raw_context)
        }

    def send_initial_metadata(self, initial_metadata):
        """Send initial metadata now.

        Raises ValueError if it was already sent, or grpc.RpcError if the
        client cancelled.
        """
        state = self._state
        with state.condition:
            if state.client is _CANCELLED:
                _raise_rpc_error(state)
            elif not state.initial_metadata_allowed:
                raise ValueError('Initial metadata no longer allowed!')
            else:
                operation = cygrpc.SendInitialMetadataOperation(
                    initial_metadata, _EMPTY_FLAGS)
                self._rpc_event.call.start_server_batch(
                    (operation,), _send_initial_metadata(state))
                state.initial_metadata_allowed = False
                state.due.add(_SEND_INITIAL_METADATA_TOKEN)

    def set_trailing_metadata(self, trailing_metadata):
        with self._state.condition:
            self._state.trailing_metadata = trailing_metadata

    def abort(self, code, details):
        """Record ``code``/``details`` and raise to unwind the servicer.

        The raised exception is stored on the state so the framework can
        tell it apart from an application error.
        """
        # Aborting with OK would report success for a failed RPC, so OK is
        # rejected like any other invalid code and replaced by UNKNOWN.
        if code == grpc.StatusCode.OK:
            _LOGGER.error(
                'abort() called with StatusCode.OK; returning UNKNOWN')
            code = grpc.StatusCode.UNKNOWN
            details = ''
        state = self._state
        with state.condition:
            state.code = code
            state.details = _common.encode(details)
            state.abortion = Exception()
            raise state.abortion

    def set_code(self, code):
        with self._state.condition:
            self._state.code = code

    def set_details(self, details):
        with self._state.condition:
            self._state.details = _common.encode(details)
302
303
class _RequestIterator(object):
    """Blocking iterator over a client-streaming RPC's deserialized requests.

    Each step starts one receive-message batch and then waits on the RPC's
    condition until the batch's callback delivers a request or the stream
    ends. If the client cancelled, grpc.RpcError is raised via
    _raise_rpc_error.
    """

    def __init__(self, state, call, request_deserializer):
        self._state = state
        self._call = call
        self._request_deserializer = request_deserializer

    def _raise_or_start_receive_message(self):
        # Called with the condition held.
        state = self._state
        if state.client is _CANCELLED:
            _raise_rpc_error(state)
        if state.client is _CLOSED or state.statused:
            raise StopIteration()
        on_message = _receive_message(state, self._call,
                                      self._request_deserializer)
        self._call.start_server_batch(
            (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),), on_message)
        state.due.add(_RECEIVE_MESSAGE_TOKEN)

    def _look_for_request(self):
        # Called with the condition held. Returns None while the receive is
        # still pending.
        state = self._state
        if state.client is _CANCELLED:
            _raise_rpc_error(state)
        if state.request is None and _RECEIVE_MESSAGE_TOKEN not in state.due:
            # The receive completed without producing a request.
            raise StopIteration()
        request, state.request = state.request, None
        return request

    def _next(self):
        with self._state.condition:
            self._raise_or_start_receive_message()
            request = None
            while request is None:
                self._state.condition.wait()
                request = self._look_for_request()
            return request

    def __iter__(self):
        return self

    def __next__(self):
        return self._next()

    def next(self):
        # Python 2 iterator protocol.
        return self._next()
353
354
def _unary_request(rpc_event, state, request_deserializer):
    """Build a thunk that receives the single request of a unary-request RPC.

    The thunk starts a receive-message batch and blocks until it completes.
    It returns the deserialized request, or None when the RPC must not
    proceed. That is the case when:

    * the client cancelled, or
    * a status was already sent (e.g. the receive callback aborted the RPC
      because the request failed to deserialize), or
    * the client closed without sending a message, in which case the RPC is
      aborted with UNIMPLEMENTED.
    """

    def unary_request():
        with state.condition:
            if state.client is _CANCELLED or state.statused:
                return None
            rpc_event.call.start_server_batch(
                (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),),
                _receive_message(state, rpc_event.call, request_deserializer))
            state.due.add(_RECEIVE_MESSAGE_TOKEN)
            while True:
                state.condition.wait()
                if state.request is not None:
                    request = state.request
                    state.request = None
                    return request
                if state.client is _CANCELLED or state.statused:
                    # Already terminated; in particular no second status may
                    # be sent after a deserialization-failure abort.
                    return None
                if state.client is _CLOSED:
                    # Decode the method so the message doesn't show a bytes
                    # repr such as "b'/pkg.Svc/Method'" on Python 3.
                    details = '"{}" requires exactly one request message.'.format(
                        _common.decode(rpc_event.call_details.method))
                    _abort(state, rpc_event.call,
                           cygrpc.StatusCode.unimplemented,
                           _common.encode(details))
                    return None

    return unary_request
385
386
def _call_behavior(rpc_event, state, behavior, argument, request_deserializer):
    """Invoke the application's handler, aborting the RPC if it raises.

    Returns:
      A ``(result, proceed)`` pair. ``proceed`` is False (and ``result``
      None) when the handler raised.
    """
    context = _Context(rpc_event, state, request_deserializer)
    try:
        result = behavior(argument, context)
    except Exception as exception:  # pylint: disable=broad-except
        with state.condition:
            if exception is state.abortion:
                # Raised by context.abort(); its code/details are on state.
                _abort(state, rpc_event.call, cygrpc.StatusCode.unknown,
                       b'RPC Aborted')
            elif exception not in state.rpc_errors:
                # Framework-raised RpcErrors are not application failures.
                details = 'Exception calling application: {}'.format(exception)
                _LOGGER.exception(details)
                _abort(state, rpc_event.call, cygrpc.StatusCode.unknown,
                       _common.encode(details))
        return None, False
    return result, True
402
403
def _take_response_from_response_iterator(rpc_event, state, response_iterator):
    """Pull the next response from the application's response iterator.

    Returns:
      ``(response, True)`` for a response, ``(None, True)`` once the
      iterator is exhausted, and ``(None, False)`` if it raised. In the last
      case the RPC has been aborted, unless the error was a framework-raised
      RpcError.
    """
    try:
        response = next(response_iterator)
    except StopIteration:
        return None, True
    except Exception as exception:  # pylint: disable=broad-except
        with state.condition:
            if exception is state.abortion:
                _abort(state, rpc_event.call, cygrpc.StatusCode.unknown,
                       b'RPC Aborted')
            elif exception not in state.rpc_errors:
                details = 'Exception iterating responses: {}'.format(exception)
                _LOGGER.exception(details)
                _abort(state, rpc_event.call, cygrpc.StatusCode.unknown,
                       _common.encode(details))
        return None, False
    return response, True
420
421
def _serialize_response(rpc_event, state, response, response_serializer):
    """Serialize ``response``, aborting the RPC if serialization fails.

    Returns:
      The serialized bytes, or None after aborting (INTERNAL unless the
      application set a code).
    """
    serialized = _common.serialize(response, response_serializer)
    if serialized is not None:
        return serialized
    with state.condition:
        _abort(state, rpc_event.call, cygrpc.StatusCode.internal,
               b'Failed to serialize response!')
    return None
431
432
def _send_response(rpc_event, state, serialized_response):
    """Send one response message and block until the send completes.

    Initial metadata is sent with the message if the application has not
    already sent it.

    Returns:
      True if the RPC may keep streaming; False if the client cancelled or a
      status has been sent.
    """
    with state.condition:
        if state.client is _CANCELLED or state.statused:
            return False
        send_message = cygrpc.SendMessageOperation(serialized_response,
                                                   _EMPTY_FLAGS)
        if state.initial_metadata_allowed:
            operations = (
                cygrpc.SendInitialMetadataOperation(None, _EMPTY_FLAGS),
                send_message,
            )
            state.initial_metadata_allowed = False
            token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN
        else:
            operations = (send_message,)
            token = _SEND_MESSAGE_TOKEN
        rpc_event.call.start_server_batch(operations,
                                          _send_message(state, token))
        state.due.add(token)
        # The token was just added, so this waits at least once; the batch
        # callback removes it and notifies.
        while token in state.due:
            state.condition.wait()
        return state.client is not _CANCELLED and not state.statused
457
458
def _status(rpc_event, state, serialized_response):
    """Finish the RPC by sending its final status.

    Any initial metadata not yet sent goes out in the same batch, as does
    ``serialized_response`` (the last message) when it is not None.

    Nothing is sent if the client cancelled or a status was already sent,
    for example because the RPC was aborted by a batch callback while the
    handler was still running. A call carries only one status.
    """
    with state.condition:
        if state.client is _CANCELLED or state.statused:
            return
        operations = [
            cygrpc.SendStatusFromServerOperation(
                state.trailing_metadata, _completion_code(state),
                _details(state), _EMPTY_FLAGS),
        ]
        if state.initial_metadata_allowed:
            operations.append(
                cygrpc.SendInitialMetadataOperation(None, _EMPTY_FLAGS))
            # Initial metadata may only be sent once per call.
            state.initial_metadata_allowed = False
        if serialized_response is not None:
            operations.append(
                cygrpc.SendMessageOperation(serialized_response,
                                            _EMPTY_FLAGS))
        rpc_event.call.start_server_batch(
            operations,
            _send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN))
        state.statused = True
        state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN)
480
481
def _unary_response_in_pool(rpc_event, state, behavior, argument_thunk,
                            request_deserializer, response_serializer):
    """Thread-pool task for RPCs that produce a single response.

    Obtains the handler argument, runs the handler, serializes the result,
    and sends it together with the status. Stops at the first step that
    yields nothing to continue with.
    """
    argument = argument_thunk()
    if argument is None:
        return
    response, proceed = _call_behavior(rpc_event, state, behavior, argument,
                                       request_deserializer)
    if not proceed:
        return
    serialized_response = _serialize_response(rpc_event, state, response,
                                              response_serializer)
    if serialized_response is not None:
        _status(rpc_event, state, serialized_response)
493
494
def _stream_response_in_pool(rpc_event, state, behavior, argument_thunk,
                             request_deserializer, response_serializer):
    """Thread-pool task for RPCs that produce a stream of responses.

    Runs the handler to get a response iterator and sends each response as
    it is produced. The status is sent once the iterator is exhausted. A
    handler error, serialization failure, or inactive RPC ends the loop
    without sending that normal status.
    """
    argument = argument_thunk()
    if argument is None:
        return
    response_iterator, proceed = _call_behavior(
        rpc_event, state, behavior, argument, request_deserializer)
    if not proceed:
        return
    while True:
        response, proceed = _take_response_from_response_iterator(
            rpc_event, state, response_iterator)
        if not proceed:
            return
        if response is None:
            _status(rpc_event, state, None)
            return
        serialized_response = _serialize_response(
            rpc_event, state, response, response_serializer)
        if serialized_response is None:
            return
        if not _send_response(rpc_event, state, serialized_response):
            return
521
522
def _handle_unary_unary(rpc_event, state, method_handler, thread_pool):
    """Submit a unary-request/unary-response RPC to the thread pool."""
    deserializer = method_handler.request_deserializer
    argument_thunk = _unary_request(rpc_event, state, deserializer)
    return thread_pool.submit(_unary_response_in_pool, rpc_event, state,
                              method_handler.unary_unary, argument_thunk,
                              deserializer, method_handler.response_serializer)
530
531
def _handle_unary_stream(rpc_event, state, method_handler, thread_pool):
    """Submit a unary-request/streaming-response RPC to the thread pool."""
    deserializer = method_handler.request_deserializer
    argument_thunk = _unary_request(rpc_event, state, deserializer)
    return thread_pool.submit(_stream_response_in_pool, rpc_event, state,
                              method_handler.unary_stream, argument_thunk,
                              deserializer, method_handler.response_serializer)
539
540
def _handle_stream_unary(rpc_event, state, method_handler, thread_pool):
    """Submit a streaming-request/unary-response RPC to the thread pool."""
    deserializer = method_handler.request_deserializer
    request_iterator = _RequestIterator(state, rpc_event.call, deserializer)

    def argument_thunk():
        return request_iterator

    return thread_pool.submit(_unary_response_in_pool, rpc_event, state,
                              method_handler.stream_unary, argument_thunk,
                              deserializer, method_handler.response_serializer)
548
549
def _handle_stream_stream(rpc_event, state, method_handler, thread_pool):
    """Submit a streaming-request/streaming-response RPC to the thread pool."""
    deserializer = method_handler.request_deserializer
    request_iterator = _RequestIterator(state, rpc_event.call, deserializer)

    def argument_thunk():
        return request_iterator

    return thread_pool.submit(_stream_response_in_pool, rpc_event, state,
                              method_handler.stream_stream, argument_thunk,
                              deserializer, method_handler.response_serializer)
557
558
def _find_method_handler(rpc_event, generic_handlers, interceptor_pipeline):
    """Look up the method handler for ``rpc_event``'s method.

    Generic handlers are asked in order and the first non-None answer wins.
    When an interceptor pipeline is configured, the lookup is routed through
    it, and interceptors may shape the result.

    Returns:
      The method handler, or None if nothing serves the method.
    """

    def query_handlers(handler_call_details):
        candidates = (generic_handler.service(handler_call_details)
                      for generic_handler in generic_handlers)
        return next(
            (candidate for candidate in candidates if candidate is not None),
            None)

    handler_call_details = _HandlerCallDetails(
        _common.decode(rpc_event.call_details.method),
        rpc_event.invocation_metadata)

    if interceptor_pipeline is None:
        return query_handlers(handler_call_details)
    return interceptor_pipeline.execute(query_handlers, handler_call_details)
577
578
def _reject_rpc(rpc_event, status, details):
    """Fail an incoming call immediately with ``status`` and ``details``.

    Returns:
      A fresh _RPCState. The batch's completion callback returns this state
      so the serving loop can stop tracking the call once the batch is done.
    """
    rpc_state = _RPCState()

    def on_rejected(ignored_event):
        return rpc_state, ()

    rpc_event.call.start_server_batch(
        (
            cygrpc.SendInitialMetadataOperation(None, _EMPTY_FLAGS),
            cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
            cygrpc.SendStatusFromServerOperation(None, status, details,
                                                 _EMPTY_FLAGS),
        ), on_rejected)
    return rpc_state
590
591
def _handle_with_method_handler(rpc_event, method_handler, thread_pool):
    """Start serving an accepted call with ``method_handler``.

    A receive-close batch is started first so the client's close or
    cancellation is observed. The RPC is then dispatched according to its
    request/response streaming arity.

    Returns:
      ``(state, future)``: the RPC's _RPCState and its thread-pool future.
    """
    state = _RPCState()
    with state.condition:
        rpc_event.call.start_server_batch(
            (cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),),
            _receive_close_on_server(state))
        state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN)
        # Keyed by (request_streaming, response_streaming).
        handlers_by_arity = {
            (False, False): _handle_unary_unary,
            (False, True): _handle_unary_stream,
            (True, False): _handle_stream_unary,
            (True, True): _handle_stream_stream,
        }
        handle = handlers_by_arity[(bool(method_handler.request_streaming),
                                    bool(method_handler.response_streaming))]
        return state, handle(rpc_event, state, method_handler, thread_pool)
613
614
def _handle_call(rpc_event, generic_handlers, interceptor_pipeline, thread_pool,
                 concurrency_exceeded):
    """Handle a request-call completion event.

    Returns:
      ``(rpc_state, rpc_future)``. Both are None when the event carries no
      call. A rejected call yields its state and a None future.
    """
    if not rpc_event.success or rpc_event.call_details.method is None:
        return None, None
    try:
        method_handler = _find_method_handler(rpc_event, generic_handlers,
                                              interceptor_pipeline)
    except Exception as exception:  # pylint: disable=broad-except
        _LOGGER.exception('Exception servicing handler: {}'.format(exception))
        return _reject_rpc(rpc_event, cygrpc.StatusCode.unknown,
                           b'Error in service handler!'), None
    if method_handler is None:
        return _reject_rpc(rpc_event, cygrpc.StatusCode.unimplemented,
                           b'Method not found!'), None
    if concurrency_exceeded:
        return _reject_rpc(rpc_event, cygrpc.StatusCode.resource_exhausted,
                           b'Concurrent RPC limit exceeded!'), None
    return _handle_with_method_handler(rpc_event, method_handler, thread_pool)
639
640
641@enum.unique
642class _ServerStage(enum.Enum):
643    STOPPED = 'stopped'
644    STARTED = 'started'
645    GRACE = 'grace'
646
647
class _ServerState(object):
    """Mutable state of a _Server; ``lock`` guards it."""

    # pylint: disable=too-many-arguments
    def __init__(self, completion_queue, server, generic_handlers,
                 interceptor_pipeline, thread_pool, maximum_concurrent_rpcs):
        self.lock = threading.RLock()
        self.completion_queue = completion_queue
        self.server = server
        # Copied so later additions do not mutate the caller's sequence.
        self.generic_handlers = list(generic_handlers)
        self.interceptor_pipeline = interceptor_pipeline
        self.thread_pool = thread_pool
        self.stage = _ServerStage.STOPPED
        # Becomes a list of events (set on shutdown) once stopping begins.
        self.shutdown_events = None
        # None means no limit on concurrently executing RPCs.
        self.maximum_concurrent_rpcs = maximum_concurrent_rpcs
        self.active_rpc_count = 0

        # TODO(https://github.com/grpc/grpc/issues/6597): eliminate these fields.
        self.rpc_states = set()
        self.due = set()
667
668
669def _add_generic_handlers(state, generic_handlers):
670    with state.lock:
671        state.generic_handlers.extend(generic_handlers)
672
673
674def _add_insecure_port(state, address):
675    with state.lock:
676        return state.server.add_http2_port(address)
677
678
679def _add_secure_port(state, address, server_credentials):
680    with state.lock:
681        return state.server.add_http2_port(address,
682                                           server_credentials._credentials)
683
684
def _request_call(state):
    """Ask the core server for the next incoming call.

    The completion arrives on ``state.completion_queue`` tagged
    _REQUEST_CALL_TAG, which is recorded as outstanding in ``state.due``.
    """
    queue = state.completion_queue
    state.server.request_call(queue, queue, _REQUEST_CALL_TAG)
    state.due.add(_REQUEST_CALL_TAG)
689
690
691# TODO(https://github.com/grpc/grpc/issues/6597): delete this function.
692def _stop_serving(state):
693    if not state.rpc_states and not state.due:
694        for shutdown_event in state.shutdown_events:
695            shutdown_event.set()
696        state.stage = _ServerStage.STOPPED
697        return True
698    else:
699        return False
700
701
702def _on_call_completed(state):
703    with state.lock:
704        state.active_rpc_count -= 1
705
706
def _serve(state):
    """Serving loop, run on the server's daemon thread.

    Polls the completion queue and handles each event:

    * the shutdown completion;
    * a newly requested call;
    * an RPC batch completion, whose tag is a callback returning
      ``(rpc_state, callbacks)``.

    Returns once _stop_serving reports that nothing remains outstanding.
    """
    while True:
        event = state.completion_queue.poll()
        tag = event.tag
        if tag is _SHUTDOWN_TAG:
            with state.lock:
                state.due.remove(_SHUTDOWN_TAG)
                if _stop_serving(state):
                    return
        elif tag is _REQUEST_CALL_TAG:
            with state.lock:
                state.due.remove(_REQUEST_CALL_TAG)
                limit = state.maximum_concurrent_rpcs
                concurrency_exceeded = (limit is not None and
                                        state.active_rpc_count >= limit)
                rpc_state, rpc_future = _handle_call(
                    event, state.generic_handlers, state.interceptor_pipeline,
                    state.thread_pool, concurrency_exceeded)
                if rpc_state is not None:
                    state.rpc_states.add(rpc_state)
                if rpc_future is not None:
                    state.active_rpc_count += 1
                    rpc_future.add_done_callback(
                        lambda unused_future: _on_call_completed(state))
                # Only keep accepting calls while fully started.
                if state.stage is _ServerStage.STARTED:
                    _request_call(state)
                elif _stop_serving(state):
                    return
        else:
            finished_state, callbacks = tag(event)
            for callback in callbacks:
                callable_util.call_logging_exceptions(
                    callback, 'Exception calling callback!')
            if finished_state is not None:
                with state.lock:
                    state.rpc_states.remove(finished_state)
                    if _stop_serving(state):
                        return
        # Drop the references to this event before polling again: an event
        # that refers to a shut-down Call object can otherwise induce a
        # spinlock.
        tag = None
        event = None
748
749
def _stop(state, grace):
    """Begin (or join) server shutdown and return its completion event.

    If ``grace`` is None, all in-flight calls are cancelled immediately and
    this blocks until shutdown completes. Otherwise a helper thread cancels
    the remaining calls after ``grace`` seconds (unless shutdown finishes
    first), and the event is returned without waiting.
    """
    with state.lock:
        if state.stage is _ServerStage.STOPPED:
            already_stopped = threading.Event()
            already_stopped.set()
            return already_stopped
        if state.stage is _ServerStage.STARTED:
            state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG)
            state.stage = _ServerStage.GRACE
            state.shutdown_events = []
            state.due.add(_SHUTDOWN_TAG)
        shutdown_event = threading.Event()
        state.shutdown_events.append(shutdown_event)
        if grace is not None:

            def cancel_all_calls_after_grace():
                shutdown_event.wait(timeout=grace)
                with state.lock:
                    state.server.cancel_all_calls()

            threading.Thread(target=cancel_all_calls_after_grace).start()
            return shutdown_event
        state.server.cancel_all_calls()
    shutdown_event.wait()
    return shutdown_event
778
779
def _start(state):
    """Start the core server, request the first call, and launch _serve.

    Raises:
      ValueError: if the server is not in the STOPPED stage.
    """
    with state.lock:
        if state.stage is not _ServerStage.STOPPED:
            raise ValueError('Cannot start already-started server!')
        state.server.start()
        state.stage = _ServerStage.STARTED
        _request_call(state)

        serving_thread = threading.Thread(target=_serve, args=(state,))
        serving_thread.daemon = True
        serving_thread.start()
791
792
793def _validate_generic_rpc_handlers(generic_rpc_handlers):
794    for generic_rpc_handler in generic_rpc_handlers:
795        service_attribute = getattr(generic_rpc_handler, 'service', None)
796        if service_attribute is None:
797            raise AttributeError(
798                '"{}" must conform to grpc.GenericRpcHandler type but does '
799                'not have "service" method!'.format(generic_rpc_handler))
800
801
class _Server(grpc.Server):
    """grpc.Server backed by a cygrpc server and a daemon serving thread."""

    # pylint: disable=too-many-arguments
    def __init__(self, thread_pool, generic_handlers, interceptors, options,
                 maximum_concurrent_rpcs):
        completion_queue = cygrpc.CompletionQueue()
        server = cygrpc.Server(options)
        server.register_completion_queue(completion_queue)
        self._state = _ServerState(completion_queue, server, generic_handlers,
                                   _interceptor.service_pipeline(interceptors),
                                   thread_pool, maximum_concurrent_rpcs)

    def add_generic_rpc_handlers(self, generic_rpc_handlers):
        _validate_generic_rpc_handlers(generic_rpc_handlers)
        _add_generic_handlers(self._state, generic_rpc_handlers)

    def add_insecure_port(self, address):
        return _add_insecure_port(self._state, _common.encode(address))

    def add_secure_port(self, address, server_credentials):
        return _add_secure_port(self._state, _common.encode(address),
                                server_credentials)

    def start(self):
        _start(self._state)

    def stop(self, grace):
        return _stop(self._state, grace)

    def __del__(self):
        # If __init__ raised before _state was assigned (e.g. while creating
        # the core server or completion queue), there is nothing to stop.
        # Reading self._state would then raise AttributeError during
        # finalization.
        if hasattr(self, '_state'):
            _stop(self._state, None)
833
834
def create_server(thread_pool, generic_rpc_handlers, interceptors, options,
                  maximum_concurrent_rpcs):
    """Create a server in the STOPPED stage.

    Raises:
      AttributeError: if a generic RPC handler lacks a ``service`` method.
    """
    _validate_generic_rpc_handlers(generic_rpc_handlers)
    server = _Server(thread_pool, generic_rpc_handlers, interceptors, options,
                     maximum_concurrent_rpcs)
    return server
840