# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines a number of module-scope gRPC scenarios to test clean exit."""

import argparse
import logging
import threading
import time

import grpc

from tests.unit.framework.common import test_constants

# Seconds every hanging handler (and --wait_for_interrupt) sleeps; large
# enough that, for a test run, the sleep effectively never ends.
WAIT_TIME = 1000

REQUEST = b"request"

# Scenario names accepted on the command line.
UNSTARTED_SERVER = "unstarted_server"
RUNNING_SERVER = "running_server"
POLL_CONNECTIVITY_NO_SERVER = "poll_connectivity_no_server"
POLL_CONNECTIVITY = "poll_connectivity"
IN_FLIGHT_UNARY_UNARY_CALL = "in_flight_unary_unary_call"
IN_FLIGHT_UNARY_STREAM_CALL = "in_flight_unary_stream_call"
IN_FLIGHT_STREAM_UNARY_CALL = "in_flight_stream_unary_call"
IN_FLIGHT_STREAM_STREAM_CALL = "in_flight_stream_stream_call"
IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL = "in_flight_partial_unary_stream_call"
IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL = "in_flight_partial_stream_unary_call"
IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL = "in_flight_partial_stream_stream_call"

# Fully-qualified method paths served by GenericHandler (kept as bytes).
UNARY_UNARY = b"/test/UnaryUnary"
UNARY_STREAM = b"/test/UnaryStream"
STREAM_UNARY = b"/test/StreamUnary"
STREAM_STREAM = b"/test/StreamStream"
PARTIAL_UNARY_STREAM = b"/test/PartialUnaryStream"
PARTIAL_STREAM_UNARY = b"/test/PartialStreamUnary"
PARTIAL_STREAM_STREAM = b"/test/PartialStreamStream"

# Which method each in-flight scenario invokes.
TEST_TO_METHOD = {
    IN_FLIGHT_UNARY_UNARY_CALL: UNARY_UNARY,
    IN_FLIGHT_UNARY_STREAM_CALL: UNARY_STREAM,
    IN_FLIGHT_STREAM_UNARY_CALL: STREAM_UNARY,
    IN_FLIGHT_STREAM_STREAM_CALL: STREAM_STREAM,
    IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL: PARTIAL_UNARY_STREAM,
    IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL: PARTIAL_STREAM_UNARY,
    IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL: PARTIAL_STREAM_STREAM,
}


def hang_unary_unary(request, servicer_context):
    """Sleep for WAIT_TIME seconds without producing a response."""
    time.sleep(WAIT_TIME)


def hang_unary_stream(request, servicer_context):
    """Sleep for WAIT_TIME seconds without streaming any response."""
    time.sleep(WAIT_TIME)


def hang_partial_unary_stream(request, servicer_context):
    """Echo the request STREAM_LENGTH // 2 times, then sleep WAIT_TIME."""
    for _ in range(test_constants.STREAM_LENGTH // 2):
        yield request
    time.sleep(WAIT_TIME)


def hang_stream_unary(request_iterator, servicer_context):
    """Sleep for WAIT_TIME seconds without consuming any request."""
    time.sleep(WAIT_TIME)


def hang_partial_stream_unary(request_iterator, servicer_context):
    """Consume STREAM_LENGTH // 2 requests, then sleep WAIT_TIME."""
    for _ in range(test_constants.STREAM_LENGTH // 2):
        next(request_iterator)
    time.sleep(WAIT_TIME)


def hang_stream_stream(request_iterator, servicer_context):
    """Sleep for WAIT_TIME seconds without consuming or producing."""
    time.sleep(WAIT_TIME)


def hang_partial_stream_stream(request_iterator, servicer_context):
    """Echo STREAM_LENGTH // 2 requests back, then sleep WAIT_TIME."""
    for _ in range(test_constants.STREAM_LENGTH // 2):
        yield next(request_iterator)  # pylint: disable=stop-iteration-return
    time.sleep(WAIT_TIME)


class MethodHandler(grpc.RpcMethodHandler):
    """RpcMethodHandler whose behavior is one of the hang_* functions.

    Args:
      request_streaming: Whether the method takes a request stream.
      response_streaming: Whether the method returns a response stream.
      partial_hang: Use the hang_partial_* variant, which processes
        STREAM_LENGTH // 2 messages before sleeping. There is no partial
        unary-unary variant, so this flag is ignored for that shape.
    """

    def __init__(self, request_streaming, response_streaming, partial_hang):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = None
        self.response_serializer = None
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None
        # Exactly one of the four behavior attributes is populated.
        if self.request_streaming and self.response_streaming:
            if partial_hang:
                self.stream_stream = hang_partial_stream_stream
            else:
                self.stream_stream = hang_stream_stream
        elif self.request_streaming:
            if partial_hang:
                self.stream_unary = hang_partial_stream_unary
            else:
                self.stream_unary = hang_stream_unary
        elif self.response_streaming:
            if partial_hang:
                self.unary_stream = hang_partial_unary_stream
            else:
                self.unary_stream = hang_unary_stream
        else:
            self.unary_unary = hang_unary_unary


# (request_streaming, response_streaming, partial_hang) for each method path.
_METHOD_HANDLER_ARGS = {
    UNARY_UNARY: (False, False, False),
    UNARY_STREAM: (False, True, False),
    STREAM_UNARY: (True, False, False),
    STREAM_STREAM: (True, True, False),
    PARTIAL_UNARY_STREAM: (False, True, True),
    PARTIAL_STREAM_UNARY: (True, False, True),
    PARTIAL_STREAM_STREAM: (True, True, True),
}


class GenericHandler(grpc.GenericRpcHandler):
    """Routes the /test/* method paths above to hanging MethodHandlers."""

    def service(self, handler_call_details):
        """Return a MethodHandler for a known method path, else None."""
        method = handler_call_details.method
        # gRPC hands the method name to generic handlers as a decoded str
        # (see grpc.HandlerCallDetails), while the path constants above are
        # bytes. A plain == between the two never matches, so normalize to
        # bytes; bytes input is still accepted unchanged.
        if isinstance(method, str):
            method = method.encode("utf-8")
        handler_args = _METHOD_HANDLER_ARGS.get(method)
        if handler_args is None:
            return None
        return MethodHandler(*handler_args)


# Traditional executors will not exit until all their
# current jobs complete. Because we submit jobs that will
# never finish, we don't want to block exit on these jobs.
class DaemonPool(object):
    """Executor stand-in that runs every submitted job on a daemon thread.

    Because the threads are daemonic, jobs that never finish do not keep the
    interpreter alive. submit() returns None rather than a Future, and
    shutdown() is a no-op.
    """

    def submit(self, fn, *args, **kwargs):
        """Start ``fn(*args, **kwargs)`` on a new daemon thread."""
        thread = threading.Thread(target=fn, args=args, kwargs=kwargs)
        thread.daemon = True
        thread.start()

    def shutdown(self, wait=True):
        """Do nothing; daemon threads are abandoned at interpreter exit."""
        pass


def infinite_request_iterator():
    """Yield REQUEST forever."""
    while True:
        yield REQUEST


if __name__ == "__main__":
    logging.basicConfig()
    parser = argparse.ArgumentParser()
    # Validate the scenario up front so a typo produces a usage error instead
    # of starting a server and then failing with a KeyError further down.
    parser.add_argument(
        "scenario",
        type=str,
        choices=(
            UNSTARTED_SERVER,
            RUNNING_SERVER,
            POLL_CONNECTIVITY_NO_SERVER,
            POLL_CONNECTIVITY,
        )
        + tuple(TEST_TO_METHOD),
    )
    parser.add_argument(
        "--wait_for_interrupt", dest="wait_for_interrupt", action="store_true"
    )
    args = parser.parse_args()

    if args.scenario == UNSTARTED_SERVER:
        server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),))
        if args.wait_for_interrupt:
            time.sleep(WAIT_TIME)
    elif args.scenario == RUNNING_SERVER:
        server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),))
        port = server.add_insecure_port("[::]:0")
        server.start()
        if args.wait_for_interrupt:
            time.sleep(WAIT_TIME)
    elif args.scenario == POLL_CONNECTIVITY_NO_SERVER:
        # Nothing is expected to listen on this port; the channel keeps
        # trying to connect.
        channel = grpc.insecure_channel("localhost:12345")

        def connectivity_callback(connectivity):
            pass

        channel.subscribe(connectivity_callback, try_to_connect=True)
        if args.wait_for_interrupt:
            time.sleep(WAIT_TIME)
    elif args.scenario == POLL_CONNECTIVITY:
        server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),))
        port = server.add_insecure_port("[::]:0")
        server.start()
        channel = grpc.insecure_channel("localhost:%d" % port)

        def connectivity_callback(connectivity):
            pass

        channel.subscribe(connectivity_callback, try_to_connect=True)
        if args.wait_for_interrupt:
            time.sleep(WAIT_TIME)

    else:
        # In-flight scenarios: serve hanging handlers and call one of them.
        handler = GenericHandler()
        server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),))
        port = server.add_insecure_port("[::]:0")
        server.add_generic_rpc_handlers((handler,))
        server.start()
        channel = grpc.insecure_channel("localhost:%d" % port)

        method = TEST_TO_METHOD[args.scenario]

        if args.scenario == IN_FLIGHT_UNARY_UNARY_CALL:
            multi_callable = channel.unary_unary(
                method,
                _registered_method=True,
            )
            # Leave one call in flight asynchronously, then block the main
            # thread on a second call.
            future = multi_callable.future(REQUEST)
            result, call = multi_callable.with_call(REQUEST)
        elif args.scenario in (
            IN_FLIGHT_UNARY_STREAM_CALL,
            IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL,
        ):
            multi_callable = channel.unary_stream(
                method,
                _registered_method=True,
            )
            response_iterator = multi_callable(REQUEST)
            for response in response_iterator:
                pass
        elif args.scenario in (
            IN_FLIGHT_STREAM_UNARY_CALL,
            IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL,
        ):
            multi_callable = channel.stream_unary(
                method,
                _registered_method=True,
            )
            future = multi_callable.future(infinite_request_iterator())
            result, call = multi_callable.with_call(
                iter([REQUEST] * test_constants.STREAM_LENGTH)
            )
        elif args.scenario in (
            IN_FLIGHT_STREAM_STREAM_CALL,
            IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL,
        ):
            multi_callable = channel.stream_stream(
                method,
                _registered_method=True,
            )
            response_iterator = multi_callable(infinite_request_iterator())
            for response in response_iterator:
                pass