# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests metadata flags feature by testing wait-for-ready semantics"""

import logging
import queue
import socket
import threading
import time
import unittest
import weakref

import grpc

from tests.unit import test_common
import tests.unit.framework.common
from tests.unit.framework.common import get_socket
from tests.unit.framework.common import test_constants

# Fully-qualified method names, one per RPC arity.
_UNARY_UNARY = "/test/UnaryUnary"
_UNARY_STREAM = "/test/UnaryStream"
_STREAM_UNARY = "/test/StreamUnary"
_STREAM_STREAM = "/test/StreamStream"

# Fixed payloads; no (de)serializer is installed, so raw bytes are exchanged.
_REQUEST = b"\x00\x00\x00"
_RESPONSE = b"\x00\x00\x00"


def handle_unary_unary(test, request, servicer_context):
    """Answer a unary-unary call with the fixed response."""
    return _RESPONSE


def handle_unary_stream(test, request, servicer_context):
    """Yield the fixed response ``test_constants.STREAM_LENGTH`` times."""
    for _ in range(test_constants.STREAM_LENGTH):
        yield _RESPONSE


def handle_stream_unary(test, request_iterator, servicer_context):
    """Drain the request stream, then return the fixed response."""
    for _ in request_iterator:
        pass
    return _RESPONSE


def handle_stream_stream(test, request_iterator, servicer_context):
    """Yield one fixed response for every request received."""
    for _ in request_iterator:
        yield _RESPONSE


class _MethodHandler(grpc.RpcMethodHandler):
    """Method handler exposing exactly one behavior, chosen by arity.

    Args:
      test: Object forwarded as the first argument to the module-level
        ``handle_*`` function.
      request_streaming: Whether the method consumes a request stream.
      response_streaming: Whether the method produces a response stream.
    """

    def __init__(self, test, request_streaming, response_streaming):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = None
        self.response_serializer = None
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None
        # Only the behavior matching the two streaming flags is populated;
        # the other three attributes stay None.
        if self.request_streaming and self.response_streaming:
            self.stream_stream = lambda req, ctx: handle_stream_stream(
                test, req, ctx
            )
        elif self.request_streaming:
            self.stream_unary = lambda req, ctx: handle_stream_unary(
                test, req, ctx
            )
        elif self.response_streaming:
            self.unary_stream = lambda req, ctx: handle_unary_stream(
                test, req, ctx
            )
        else:
            self.unary_unary = lambda req, ctx: handle_unary_unary(
                test, req, ctx
            )


class _GenericHandler(grpc.GenericRpcHandler):
    """Routes the four test method names to a matching ``_MethodHandler``."""

    def __init__(self, test):
        self._test = test

    def service(self, handler_call_details):
        """Return the handler for the requested method, or None if unknown."""
        if handler_call_details.method == _UNARY_UNARY:
            return _MethodHandler(self._test, False, False)
        elif handler_call_details.method == _UNARY_STREAM:
            return _MethodHandler(self._test, False, True)
        elif handler_call_details.method == _STREAM_UNARY:
            return _MethodHandler(self._test, True, False)
        elif handler_call_details.method == _STREAM_STREAM:
            return _MethodHandler(self._test, True, True)
        else:
            return None


def create_phony_channel():
    """Creating phony channels is a workaround for retries.

    Obtains a host/port from ``get_socket``, closes the socket right away and
    returns an insecure channel targeting that address (presumably nothing is
    listening there once the socket is closed).
    """
    host, port, sock = get_socket(sock_options=(socket.SO_REUSEADDR,))
    sock.close()
    return grpc.insecure_channel("{}:{}".format(host, port))


def perform_unary_unary_call(channel, wait_for_ready=None):
    """Invoke the unary-unary method through a blocking ``__call__``."""
    channel.unary_unary(
        _UNARY_UNARY,
        _registered_method=True,
    ).__call__(
        _REQUEST,
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready,
    )


def perform_unary_unary_with_call(channel, wait_for_ready=None):
    """Invoke the unary-unary method through ``with_call``."""
    channel.unary_unary(
        _UNARY_UNARY,
        _registered_method=True,
    ).with_call(
        _REQUEST,
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready,
    )


def perform_unary_unary_future(channel, wait_for_ready=None):
    """Invoke the unary-unary method through ``future`` and await its result."""
    channel.unary_unary(
        _UNARY_UNARY,
        _registered_method=True,
    ).future(
        _REQUEST,
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready,
    ).result(timeout=test_constants.LONG_TIMEOUT)


def perform_unary_stream_call(channel, wait_for_ready=None):
    """Invoke the unary-stream method and drain the response iterator."""
    response_iterator = channel.unary_stream(
        _UNARY_STREAM,
        _registered_method=True,
    ).__call__(
        _REQUEST,
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready,
    )
    for _ in response_iterator:
        pass


def perform_stream_unary_call(channel, wait_for_ready=None):
    """Invoke the stream-unary method through a blocking ``__call__``."""
    channel.stream_unary(
        _STREAM_UNARY,
        _registered_method=True,
    ).__call__(
        iter([_REQUEST] * test_constants.STREAM_LENGTH),
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready,
    )


def perform_stream_unary_with_call(channel, wait_for_ready=None):
    """Invoke the stream-unary method through ``with_call``."""
    channel.stream_unary(
        _STREAM_UNARY,
        _registered_method=True,
    ).with_call(
        iter([_REQUEST] * test_constants.STREAM_LENGTH),
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready,
    )


def perform_stream_unary_future(channel, wait_for_ready=None):
    """Invoke the stream-unary method through ``future`` and await its result."""
    channel.stream_unary(
        _STREAM_UNARY,
        _registered_method=True,
    ).future(
        iter([_REQUEST] * test_constants.STREAM_LENGTH),
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready,
    ).result(timeout=test_constants.LONG_TIMEOUT)


def perform_stream_stream_call(channel, wait_for_ready=None):
    """Invoke the stream-stream method and drain the response iterator."""
    response_iterator = channel.stream_stream(
        _STREAM_STREAM, _registered_method=True
    ).__call__(
        iter([_REQUEST] * test_constants.STREAM_LENGTH),
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready,
    )
    for _ in response_iterator:
        pass


# Every invocation style exercised by each test below.
_ALL_CALL_CASES = [
    perform_unary_unary_call,
    perform_unary_unary_with_call,
    perform_unary_unary_future,
    perform_unary_stream_call,
    perform_stream_unary_call,
    perform_stream_unary_with_call,
    perform_stream_unary_future,
    perform_stream_stream_call,
]


class MetadataFlagsTest(unittest.TestCase):
    """Checks fail-fast versus wait-for-ready behavior for every call style."""

    def check_connection_does_failfast(self, fn, channel, wait_for_ready=None):
        """Assert that ``fn(channel, wait_for_ready)`` fails with UNAVAILABLE.

        Args:
          fn: One of the ``perform_*`` helpers.
          channel: The channel to invoke the RPC on.
          wait_for_ready: Value forwarded to the helper.
        """
        try:
            fn(channel, wait_for_ready)
        except grpc.RpcError as rpc_error:
            self.assertIs(grpc.StatusCode.UNAVAILABLE, rpc_error.code())
        else:
            # Kept outside the ``try`` so the AssertionError raised by
            # ``fail`` is reported as-is instead of being caught and then
            # probed for a ``code()`` method it does not have.
            self.fail("The Call should fail")

    def test_call_wait_for_ready_default(self):
        for perform_call in _ALL_CALL_CASES:
            with create_phony_channel() as channel:
                self.check_connection_does_failfast(perform_call, channel)

    def test_call_wait_for_ready_disabled(self):
        for perform_call in _ALL_CALL_CASES:
            with create_phony_channel() as channel:
                self.check_connection_does_failfast(
                    perform_call, channel, wait_for_ready=False
                )

    def test_call_wait_for_ready_enabled(self):
        # To test the wait mechanism, Python threads are required so that the
        # clients are set up first without handling them case by case.
        # Also, Python threads don't pass unhandled exceptions to the main
        # thread, so another mechanism is needed to store the exceptions and
        # raise them again in the main thread.
        unhandled_exceptions = queue.Queue()

        # We just need an unused TCP port
        host, port, sock = get_socket(sock_options=(socket.SO_REUSEADDR,))
        sock.close()

        addr = "{}:{}".format(host, port)
        wg = test_common.WaitGroup(len(_ALL_CALL_CASES))

        def wait_for_transient_failure(channel_connectivity):
            if (
                channel_connectivity
                == grpc.ChannelConnectivity.TRANSIENT_FAILURE
            ):
                wg.done()

        def test_call(perform_call):
            with grpc.insecure_channel(addr) as channel:
                try:
                    channel.subscribe(wait_for_transient_failure)
                    perform_call(channel, wait_for_ready=True)
                except BaseException as e:  # pylint: disable=broad-except
                    # If the call failed, the thread would be destroyed. The
                    # channel object can be collected before calling the
                    # callback, which will result in a deadlock.
                    wg.done()
                    unhandled_exceptions.put(e, True)

        test_threads = []
        for perform_call in _ALL_CALL_CASES:
            test_thread = threading.Thread(
                target=test_call, args=(perform_call,)
            )
            test_thread.daemon = True
            test_thread.start()
            test_threads.append(test_thread)

        # Start the server after the connections are waiting
        wg.wait()
        server = test_common.test_server(reuse_port=True)
        server.add_generic_rpc_handlers((_GenericHandler(weakref.proxy(self)),))
        server.add_insecure_port(addr)
        server.start()

        try:
            for test_thread in test_threads:
                test_thread.join()
        finally:
            # Stop the server to make test end properly, even when joining
            # the worker threads is interrupted.
            server.stop(0)

        if not unhandled_exceptions.empty():
            raise unhandled_exceptions.get(True)


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)