# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines a number of module-scope gRPC scenarios to test clean exit."""

import argparse
import logging
import threading
import time

import grpc

from tests.unit.framework.common import test_constants

# Number of seconds the hanging handlers (and the --wait_for_interrupt
# branches of the script) sleep for.
WAIT_TIME = 1000

REQUEST = b"request"

# Scenario names accepted as the positional "scenario" command-line argument.
UNSTARTED_SERVER = "unstarted_server"
RUNNING_SERVER = "running_server"
POLL_CONNECTIVITY_NO_SERVER = "poll_connectivity_no_server"
POLL_CONNECTIVITY = "poll_connectivity"
IN_FLIGHT_UNARY_UNARY_CALL = "in_flight_unary_unary_call"
IN_FLIGHT_UNARY_STREAM_CALL = "in_flight_unary_stream_call"
IN_FLIGHT_STREAM_UNARY_CALL = "in_flight_stream_unary_call"
IN_FLIGHT_STREAM_STREAM_CALL = "in_flight_stream_stream_call"
IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL = "in_flight_partial_unary_stream_call"
IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL = "in_flight_partial_stream_unary_call"
IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL = "in_flight_partial_stream_stream_call"

_SERVICE_NAME = "test"
# Bare method names within _SERVICE_NAME. They must be text, not bytes: both
# the client and the server build a "/<service>/<method>" path from them, and
# formatting a bytes value into such a path yields "/test/b'UnaryUnary'"
# rather than "/test/UnaryUnary".
UNARY_UNARY = "UnaryUnary"
UNARY_STREAM = "UnaryStream"
STREAM_UNARY = "StreamUnary"
STREAM_STREAM = "StreamStream"
PARTIAL_UNARY_STREAM = "PartialUnaryStream"
PARTIAL_STREAM_UNARY = "PartialStreamUnary"
PARTIAL_STREAM_STREAM = "PartialStreamStream"

# Which method each in-flight scenario invokes.
TEST_TO_METHOD = {
    IN_FLIGHT_UNARY_UNARY_CALL: UNARY_UNARY,
    IN_FLIGHT_UNARY_STREAM_CALL: UNARY_STREAM,
    IN_FLIGHT_STREAM_UNARY_CALL: STREAM_UNARY,
    IN_FLIGHT_STREAM_STREAM_CALL: STREAM_STREAM,
    IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL: PARTIAL_UNARY_STREAM,
    IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL: PARTIAL_STREAM_UNARY,
    IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL: PARTIAL_STREAM_STREAM,
}


def hang_unary_unary(request, servicer_context):
    """Sleep for WAIT_TIME seconds without producing a response."""
    time.sleep(WAIT_TIME)


def hang_unary_stream(request, servicer_context):
    """Sleep for WAIT_TIME seconds before streaming anything."""
    time.sleep(WAIT_TIME)


def hang_partial_unary_stream(request, servicer_context):
    """Stream STREAM_LENGTH // 2 copies of the request, then hang."""
    for _ in range(test_constants.STREAM_LENGTH // 2):
        yield request
    time.sleep(WAIT_TIME)


def hang_stream_unary(request_iterator, servicer_context):
    """Sleep for WAIT_TIME seconds without reading any request."""
    time.sleep(WAIT_TIME)


def hang_partial_stream_unary(request_iterator, servicer_context):
    """Consume STREAM_LENGTH // 2 requests, then hang."""
    for _ in range(test_constants.STREAM_LENGTH // 2):
        next(request_iterator)
    time.sleep(WAIT_TIME)


def hang_stream_stream(request_iterator, servicer_context):
    """Sleep for WAIT_TIME seconds without reading or writing anything."""
    time.sleep(WAIT_TIME)


def hang_partial_stream_stream(request_iterator, servicer_context):
    """Echo the first STREAM_LENGTH // 2 requests, then hang."""
    for _ in range(test_constants.STREAM_LENGTH // 2):
        yield next(request_iterator)  # pylint: disable=stop-iteration-return
    time.sleep(WAIT_TIME)


class MethodHandler(grpc.RpcMethodHandler):
    """RpcMethodHandler whose behavior is one of the hang_* functions above.

    Args:
      request_streaming: whether the method takes a request stream.
      response_streaming: whether the method returns a response stream.
      partial_hang: if true, exchange some messages before hanging. Ignored
        for unary-unary methods, which always use hang_unary_unary.
    """

    def __init__(self, request_streaming, response_streaming, partial_hang):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        # No (de)serializers: requests and responses stay raw bytes.
        self.request_deserializer = None
        self.response_serializer = None
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None
        if self.request_streaming and self.response_streaming:
            if partial_hang:
                self.stream_stream = hang_partial_stream_stream
            else:
                self.stream_stream = hang_stream_stream
        elif self.request_streaming:
            if partial_hang:
                self.stream_unary = hang_partial_stream_unary
            else:
                self.stream_unary = hang_stream_unary
        elif self.response_streaming:
            if partial_hang:
                self.unary_stream = hang_partial_unary_stream
            else:
                self.unary_stream = hang_unary_stream
        else:
            self.unary_unary = hang_unary_unary


# (request_streaming, response_streaming, partial_hang) for each method name.
# Both _METHOD_HANDLERS and GenericHandler are derived from this one table.
_METHOD_SHAPES = {
    UNARY_UNARY: (False, False, False),
    UNARY_STREAM: (False, True, False),
    STREAM_UNARY: (True, False, False),
    STREAM_STREAM: (True, True, False),
    PARTIAL_UNARY_STREAM: (False, True, True),
    PARTIAL_STREAM_UNARY: (True, False, True),
    PARTIAL_STREAM_STREAM: (True, True, True),
}


class GenericHandler(grpc.GenericRpcHandler):
    """Serve the test methods through the generic-handler API.

    ``handler_call_details.method`` is the fully-qualified
    ``/<service>/<method>`` path, so the ``/<_SERVICE_NAME>/`` prefix is
    stripped before the bare name is looked up. Comparing the full path with
    a bare name would never match. Returns None for any other method.
    """

    def service(self, handler_call_details):
        prefix = "/%s/" % _SERVICE_NAME
        path = handler_call_details.method
        if not isinstance(path, str) or not path.startswith(prefix):
            return None
        shape = _METHOD_SHAPES.get(path[len(prefix):])
        if shape is None:
            return None
        return MethodHandler(*shape)


# Handlers passed to Server.add_registered_method_handlers, keyed by bare name.
_METHOD_HANDLERS = {
    name: MethodHandler(*shape) for name, shape in _METHOD_SHAPES.items()
}


# Traditional executors will not exit until all their
# current jobs complete. Because we submit jobs that will
# never finish, we don't want to block exit on these jobs.
class DaemonPool(object):
    """Executor stand-in that runs each submitted job on a daemon thread.

    Jobs submitted by these scenarios never finish. Daemon threads do not
    keep the interpreter alive, so they do not block exit. ``shutdown`` is a
    no-op because nothing is tracked or joined.
    """

    def submit(self, fn, *args, **kwargs):
        thread = threading.Thread(target=fn, args=args, kwargs=kwargs)
        thread.daemon = True
        thread.start()

    def shutdown(self, wait=True):
        pass


def infinite_request_iterator():
    """Yield REQUEST forever so the client's request stream never ends."""
    while True:
        yield REQUEST


if __name__ == "__main__":
    logging.basicConfig()
    parser = argparse.ArgumentParser()
    parser.add_argument("scenario", type=str)
    parser.add_argument(
        "--wait_for_interrupt", dest="wait_for_interrupt", action="store_true"
    )
    args = parser.parse_args()

    if args.scenario == UNSTARTED_SERVER:
        server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),))
        if args.wait_for_interrupt:
            time.sleep(WAIT_TIME)
    elif args.scenario == RUNNING_SERVER:
        server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),))
        port = server.add_insecure_port("[::]:0")
        server.start()
        if args.wait_for_interrupt:
            time.sleep(WAIT_TIME)
    elif args.scenario == POLL_CONNECTIVITY_NO_SERVER:
        channel = grpc.insecure_channel("localhost:12345")

        def connectivity_callback(connectivity):
            pass

        channel.subscribe(connectivity_callback, try_to_connect=True)
        if args.wait_for_interrupt:
            time.sleep(WAIT_TIME)
    elif args.scenario == POLL_CONNECTIVITY:
        server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),))
        port = server.add_insecure_port("[::]:0")
        server.start()
        channel = grpc.insecure_channel("localhost:%d" % port)

        def connectivity_callback(connectivity):
            pass

        channel.subscribe(connectivity_callback, try_to_connect=True)
        if args.wait_for_interrupt:
            time.sleep(WAIT_TIME)

    else:
        # Import the helper module explicitly instead of relying on
        # ``grpc._common`` having been loaded as a side effect of
        # ``grpc.server`` / ``grpc.insecure_channel``.
        from grpc import _common

        server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),))
        port = server.add_insecure_port("[::]:0")
        server.add_registered_method_handlers(_SERVICE_NAME, _METHOD_HANDLERS)
        server.start()
        channel = grpc.insecure_channel("localhost:%d" % port)

        # A plain "/<service>/<method>" string. It must not be wrapped in a
        # tuple: the multi-callable factories below expect a method name.
        method = _common.fully_qualified_method(
            _SERVICE_NAME, TEST_TO_METHOD[args.scenario]
        )

        if args.scenario == IN_FLIGHT_UNARY_UNARY_CALL:
            multi_callable = channel.unary_unary(
                method,
                _registered_method=True,
            )
            future = multi_callable.future(REQUEST)
            result, call = multi_callable.with_call(REQUEST)
        elif (
            args.scenario == IN_FLIGHT_UNARY_STREAM_CALL
            or args.scenario == IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL
        ):
            multi_callable = channel.unary_stream(
                method,
                _registered_method=True,
            )
            response_iterator = multi_callable(REQUEST)
            for response in response_iterator:
                pass
        elif (
            args.scenario == IN_FLIGHT_STREAM_UNARY_CALL
            or args.scenario == IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL
        ):
            multi_callable = channel.stream_unary(
                method,
                _registered_method=True,
            )
            future = multi_callable.future(infinite_request_iterator())
            result, call = multi_callable.with_call(
                iter([REQUEST] * test_constants.STREAM_LENGTH)
            )
        elif (
            args.scenario == IN_FLIGHT_STREAM_STREAM_CALL
            or args.scenario == IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL
        ):
            multi_callable = channel.stream_stream(
                method,
                _registered_method=True,
            )
            response_iterator = multi_callable(infinite_request_iterator())
            for response in response_iterator:
                pass