• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines a number of module-scope gRPC scenarios to test clean exit."""
15
16import argparse
17import logging
18import threading
19import time
20
21import grpc
22
23from tests.unit.framework.common import test_constants
24
# Seconds to sleep wherever a scenario or a server handler should block
# "forever" for the purposes of these tests.
WAIT_TIME = 1000

# Payload sent by every client-side call in the in-flight scenarios.
REQUEST = b"request"

# Scenario names accepted as the positional "scenario" command-line argument.
UNSTARTED_SERVER = "unstarted_server"
RUNNING_SERVER = "running_server"
POLL_CONNECTIVITY_NO_SERVER = "poll_connectivity_no_server"
POLL_CONNECTIVITY = "poll_connectivity"
IN_FLIGHT_UNARY_UNARY_CALL = "in_flight_unary_unary_call"
IN_FLIGHT_UNARY_STREAM_CALL = "in_flight_unary_stream_call"
IN_FLIGHT_STREAM_UNARY_CALL = "in_flight_stream_unary_call"
IN_FLIGHT_STREAM_STREAM_CALL = "in_flight_stream_stream_call"
IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL = "in_flight_partial_unary_stream_call"
IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL = "in_flight_partial_stream_unary_call"
IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL = "in_flight_partial_stream_stream_call"

_SERVICE_NAME = "test"
# Method names registered under _SERVICE_NAME. These must be str, not bytes:
# grpc._common.fully_qualified_method formats them into a "/service/method"
# path, and formatting a bytes object there would embed its repr
# (e.g. "/test/b'UnaryUnary'") instead of the method name.
UNARY_UNARY = "UnaryUnary"
UNARY_STREAM = "UnaryStream"
STREAM_UNARY = "StreamUnary"
STREAM_STREAM = "StreamStream"
PARTIAL_UNARY_STREAM = "PartialUnaryStream"
PARTIAL_STREAM_UNARY = "PartialStreamUnary"
PARTIAL_STREAM_STREAM = "PartialStreamStream"

# Maps each in-flight-call scenario to the method name it invokes.
TEST_TO_METHOD = {
    IN_FLIGHT_UNARY_UNARY_CALL: UNARY_UNARY,
    IN_FLIGHT_UNARY_STREAM_CALL: UNARY_STREAM,
    IN_FLIGHT_STREAM_UNARY_CALL: STREAM_UNARY,
    IN_FLIGHT_STREAM_STREAM_CALL: STREAM_STREAM,
    IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL: PARTIAL_UNARY_STREAM,
    IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL: PARTIAL_STREAM_UNARY,
    IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL: PARTIAL_STREAM_STREAM,
}
59
60
def hang_unary_unary(request, servicer_context):
    """Unary-unary behavior that blocks for WAIT_TIME seconds."""
    del request, servicer_context  # Unused.
    time.sleep(WAIT_TIME)
63
64
def hang_unary_stream(request, servicer_context):
    """Unary-stream behavior that blocks for WAIT_TIME seconds.

    Unlike the partial variant this is a plain function, not a generator, so
    it blocks as soon as it is invoked, before any response is produced.
    """
    del request, servicer_context  # Unused.
    time.sleep(WAIT_TIME)
67
68
def hang_partial_unary_stream(request, servicer_context):
    """Yield ``request`` STREAM_LENGTH // 2 times, then block for WAIT_TIME."""
    del servicer_context  # Unused.
    yield from [request] * (test_constants.STREAM_LENGTH // 2)
    time.sleep(WAIT_TIME)
73
74
def hang_stream_unary(request_iterator, servicer_context):
    """Stream-unary behavior that blocks for WAIT_TIME without reading input."""
    del request_iterator, servicer_context  # Unused.
    time.sleep(WAIT_TIME)
77
78
def hang_partial_stream_unary(request_iterator, servicer_context):
    """Consume STREAM_LENGTH // 2 requests, then block for WAIT_TIME seconds.

    A StopIteration from ``next`` propagates if fewer requests arrive.
    """
    del servicer_context  # Unused.
    remaining = test_constants.STREAM_LENGTH // 2
    while remaining > 0:
        next(request_iterator)
        remaining -= 1
    time.sleep(WAIT_TIME)
83
84
def hang_stream_stream(request_iterator, servicer_context):
    """Stream-stream behavior that blocks for WAIT_TIME without any I/O.

    This is a plain function, not a generator, so it blocks immediately.
    """
    del request_iterator, servicer_context  # Unused.
    time.sleep(WAIT_TIME)
87
88
def hang_partial_stream_stream(request_iterator, servicer_context):
    """Echo the first STREAM_LENGTH // 2 requests back, then block."""
    del servicer_context  # Unused.
    remaining = test_constants.STREAM_LENGTH // 2
    while remaining > 0:
        remaining -= 1
        # Inside a generator, a StopIteration raised by next() here surfaces
        # as RuntimeError (PEP 479) if the client sends fewer requests.
        yield next(request_iterator)  # pylint: disable=stop-iteration-return
    time.sleep(WAIT_TIME)
93
94
class MethodHandler(grpc.RpcMethodHandler):
    """RpcMethodHandler whose behavior is one of the hang_* functions.

    Exactly one of ``unary_unary``, ``unary_stream``, ``stream_unary`` and
    ``stream_stream`` is set, selected by the two streaming flags; the other
    three remain ``None``. No request deserializer or response serializer is
    configured.

    Args:
      request_streaming: Whether the method takes a request stream.
      response_streaming: Whether the method returns a response stream.
      partial_hang: For the three streaming shapes, use the ``hang_partial_*``
        variant that exchanges some messages before blocking. Ignored for
        unary-unary, which has no partial variant.
    """

    def __init__(self, request_streaming, response_streaming, partial_hang):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = self.response_serializer = None
        self.unary_unary = self.unary_stream = None
        self.stream_unary = self.stream_stream = None

        if not request_streaming:
            if not response_streaming:
                # There is no partial unary-unary variant.
                self.unary_unary = hang_unary_unary
            else:
                self.unary_stream = (
                    hang_partial_unary_stream if partial_hang else hang_unary_stream
                )
        elif response_streaming:
            self.stream_stream = (
                hang_partial_stream_stream if partial_hang else hang_stream_stream
            )
        else:
            self.stream_unary = (
                hang_partial_stream_unary if partial_hang else hang_stream_unary
            )
122
123
class GenericHandler(grpc.GenericRpcHandler):
    """Build a hanging MethodHandler for a recognised method name.

    ``service`` returns ``None`` for any method it does not recognise. The
    ``__main__`` block of this module registers ``_METHOD_HANDLERS`` through
    ``add_registered_method_handlers`` and does not use this class.

    NOTE(review): ``handler_call_details.method`` is compared against the bare
    method-name constants; if gRPC supplies the fully-qualified
    "/service/method" path here, nothing will match. Confirm before reusing.
    """

    def service(self, handler_call_details):
        requested = handler_call_details.method
        # Each shape is (request_streaming, response_streaming, partial_hang).
        for name, shape in (
            (UNARY_UNARY, (False, False, False)),
            (UNARY_STREAM, (False, True, False)),
            (STREAM_UNARY, (True, False, False)),
            (STREAM_STREAM, (True, True, False)),
            (PARTIAL_UNARY_STREAM, (False, True, True)),
            (PARTIAL_STREAM_UNARY, (True, False, True)),
            (PARTIAL_STREAM_STREAM, (True, True, True)),
        ):
            if requested == name:
                return MethodHandler(*shape)
        return None
142
143
# Method name -> MethodHandler, registered under _SERVICE_NAME by the
# in-flight-call scenarios. Each shape tuple is
# (request_streaming, response_streaming, partial_hang).
_METHOD_HANDLERS = {
    name: MethodHandler(*shape)
    for name, shape in (
        (UNARY_UNARY, (False, False, False)),
        (UNARY_STREAM, (False, True, False)),
        (STREAM_UNARY, (True, False, False)),
        (STREAM_STREAM, (True, True, False)),
        (PARTIAL_UNARY_STREAM, (False, True, True)),
        (PARTIAL_STREAM_UNARY, (True, False, True)),
        (PARTIAL_STREAM_STREAM, (True, True, True)),
    )
}
153
154
class DaemonPool:
    """Executor stand-in that runs every submitted job on its own daemon thread.

    Traditional executors do not exit until all of their current jobs have
    completed. These scenarios submit jobs that never finish, so exit must not
    wait on them; daemon threads do not keep the interpreter alive.
    """

    def submit(self, fn, *args, **kwargs):
        """Start ``fn(*args, **kwargs)`` on a new daemon thread; returns None."""
        threading.Thread(target=fn, args=args, kwargs=kwargs, daemon=True).start()

    def shutdown(self, wait=True):
        """Do nothing: worker threads are daemons and are never joined."""
        del wait  # Unused.
166
167
def infinite_request_iterator():
    """Yield ``REQUEST`` endlessly; this generator never terminates."""
    # iter(callable, sentinel) calls the lambda until it returns the sentinel;
    # the lambda always returns REQUEST, never None, so iteration never stops.
    yield from iter(lambda: REQUEST, None)
171
172
if __name__ == "__main__":
    # Every gRPC object created below (servers, channels, calls) is bound at
    # module scope on purpose: these are the module-scope scenarios described
    # in the module docstring, exercised while the interpreter exits.
    logging.basicConfig()
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "scenario",
        type=str,
        # Reject unknown names up front instead of failing later with a
        # KeyError from the TEST_TO_METHOD lookup.
        choices=(
            UNSTARTED_SERVER,
            RUNNING_SERVER,
            POLL_CONNECTIVITY_NO_SERVER,
            POLL_CONNECTIVITY,
        )
        + tuple(TEST_TO_METHOD),
    )
    parser.add_argument(
        "--wait_for_interrupt", dest="wait_for_interrupt", action="store_true"
    )
    args = parser.parse_args()

    if args.scenario == UNSTARTED_SERVER:
        server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),))
        if args.wait_for_interrupt:
            time.sleep(WAIT_TIME)
    elif args.scenario == RUNNING_SERVER:
        server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),))
        server.add_insecure_port("[::]:0")
        server.start()
        if args.wait_for_interrupt:
            time.sleep(WAIT_TIME)
    elif args.scenario == POLL_CONNECTIVITY_NO_SERVER:
        # No server is started for this address in this scenario.
        channel = grpc.insecure_channel("localhost:12345")

        def connectivity_callback(connectivity):
            pass

        channel.subscribe(connectivity_callback, try_to_connect=True)
        if args.wait_for_interrupt:
            time.sleep(WAIT_TIME)
    elif args.scenario == POLL_CONNECTIVITY:
        server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),))
        port = server.add_insecure_port("[::]:0")
        server.start()
        channel = grpc.insecure_channel("localhost:%d" % port)

        def connectivity_callback(connectivity):
            pass

        channel.subscribe(connectivity_callback, try_to_connect=True)
        if args.wait_for_interrupt:
            time.sleep(WAIT_TIME)
    else:
        # In-flight call scenarios: start a server whose handlers hang (see
        # _METHOD_HANDLERS), then issue calls against it that cannot finish.
        server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),))
        port = server.add_insecure_port("[::]:0")
        server.add_registered_method_handlers(_SERVICE_NAME, _METHOD_HANDLERS)
        server.start()
        channel = grpc.insecure_channel("localhost:%d" % port)

        # Must be a single "/service/method" string -- not wrapped in a tuple,
        # which the channel cannot use as a method name.
        method = grpc._common.fully_qualified_method(
            _SERVICE_NAME, TEST_TO_METHOD[args.scenario]
        )

        if args.scenario == IN_FLIGHT_UNARY_UNARY_CALL:
            multi_callable = channel.unary_unary(
                method,
                _registered_method=True,
            )
            # One call left pending in the background, plus one blocking call
            # in the main thread while the server-side handler sleeps.
            future = multi_callable.future(REQUEST)
            result, call = multi_callable.with_call(REQUEST)
        elif args.scenario in (
            IN_FLIGHT_UNARY_STREAM_CALL,
            IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL,
        ):
            multi_callable = channel.unary_stream(
                method,
                _registered_method=True,
            )
            response_iterator = multi_callable(REQUEST)
            for _ in response_iterator:
                pass
        elif args.scenario in (
            IN_FLIGHT_STREAM_UNARY_CALL,
            IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL,
        ):
            multi_callable = channel.stream_unary(
                method,
                _registered_method=True,
            )
            # A background call whose request stream never ends, plus a
            # blocking call with a finite request stream.
            future = multi_callable.future(infinite_request_iterator())
            result, call = multi_callable.with_call(
                iter([REQUEST] * test_constants.STREAM_LENGTH)
            )
        elif args.scenario in (
            IN_FLIGHT_STREAM_STREAM_CALL,
            IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL,
        ):
            multi_callable = channel.stream_stream(
                method,
                _registered_method=True,
            )
            response_iterator = multi_callable(infinite_request_iterator())
            for _ in response_iterator:
                pass
268