# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests metadata flags feature by testing wait-for-ready semantics"""

import time
import weakref
import unittest
import threading
import logging
import socket
from six.moves import queue

import grpc

from tests.unit import test_common
from tests.unit.framework.common import test_constants
import tests.unit.framework.common
from tests.unit.framework.common import get_socket

_UNARY_UNARY = '/test/UnaryUnary'
_UNARY_STREAM = '/test/UnaryStream'
_STREAM_UNARY = '/test/StreamUnary'
_STREAM_STREAM = '/test/StreamStream'

_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x00\x00\x00'


def handle_unary_unary(test, request, servicer_context):
    """Server behavior for the unary-unary method: return one response."""
    return _RESPONSE


def handle_unary_stream(test, request, servicer_context):
    """Server behavior for unary-stream: yield STREAM_LENGTH responses."""
    for _ in range(test_constants.STREAM_LENGTH):
        yield _RESPONSE


def handle_stream_unary(test, request_iterator, servicer_context):
    """Server behavior for stream-unary: drain requests, return one response."""
    for _ in request_iterator:
        pass
    return _RESPONSE


def handle_stream_stream(test, request_iterator, servicer_context):
    """Server behavior for stream-stream: one response per request."""
    for _ in request_iterator:
        yield _RESPONSE


class _MethodHandler(grpc.RpcMethodHandler):
    """RpcMethodHandler dispatching to one of the module's handle_* functions.

    Exactly one of the four behavior attributes is set, chosen by the
    (request_streaming, response_streaming) pair; the others stay None.
    No (de)serializers are configured, so raw bytes are passed through.
    """

    def __init__(self, test, request_streaming, response_streaming):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = None
        self.response_serializer = None
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None
        if self.request_streaming and self.response_streaming:
            self.stream_stream = lambda req, ctx: handle_stream_stream(
                test, req, ctx)
        elif self.request_streaming:
            self.stream_unary = lambda req, ctx: handle_stream_unary(
                test, req, ctx)
        elif self.response_streaming:
            self.unary_stream = lambda req, ctx: handle_unary_stream(
                test, req, ctx)
        else:
            self.unary_unary = lambda req, ctx: handle_unary_unary(
                test, req, ctx)


# Method name -> (request_streaming, response_streaming).
_METHOD_STREAMING = {
    _UNARY_UNARY: (False, False),
    _UNARY_STREAM: (False, True),
    _STREAM_UNARY: (True, False),
    _STREAM_STREAM: (True, True),
}


class _GenericHandler(grpc.GenericRpcHandler):
    """Routes the four test method names to a matching _MethodHandler."""

    def __init__(self, test):
        self._test = test

    def service(self, handler_call_details):
        """Return a handler for a known method name, or None if unknown."""
        streaming = _METHOD_STREAMING.get(handler_call_details.method)
        if streaming is None:
            return None
        request_streaming, response_streaming = streaming
        return _MethodHandler(self._test, request_streaming,
                              response_streaming)


def create_dummy_channel():
    """Return an insecure channel to an address with no listening server.

    A free port is obtained via get_socket() and its socket is closed
    immediately, so RPCs on the returned channel cannot connect. Creating
    a fresh dummy channel per call is a workaround for retries (presumably
    so connection/backoff state from one call does not leak into the
    next — NOTE(review): confirm).
    """
    host, port, sock = get_socket(sock_options=(socket.SO_REUSEADDR,))
    sock.close()
    return grpc.insecure_channel('{}:{}'.format(host, port))


# Each perform_* helper issues one RPC of a given shape/invocation style on
# `channel`, forwarding `wait_for_ready`, and blocks until it completes
# (consuming any response stream). Failures surface as raised exceptions.


def perform_unary_unary_call(channel, wait_for_ready=None):
    """Blocking unary-unary call via the multicallable itself."""
    channel.unary_unary(_UNARY_UNARY)(_REQUEST,
                                      timeout=test_constants.LONG_TIMEOUT,
                                      wait_for_ready=wait_for_ready)


def perform_unary_unary_with_call(channel, wait_for_ready=None):
    """Blocking unary-unary call via with_call()."""
    channel.unary_unary(_UNARY_UNARY).with_call(
        _REQUEST,
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready)


def perform_unary_unary_future(channel, wait_for_ready=None):
    """Unary-unary call via future(), waiting on result()."""
    channel.unary_unary(_UNARY_UNARY).future(
        _REQUEST,
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready).result(
            timeout=test_constants.LONG_TIMEOUT)


def perform_unary_stream_call(channel, wait_for_ready=None):
    """Unary-stream call; iterates the full response stream."""
    response_iterator = channel.unary_stream(_UNARY_STREAM)(
        _REQUEST,
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready)
    for _ in response_iterator:
        pass


def perform_stream_unary_call(channel, wait_for_ready=None):
    """Blocking stream-unary call via the multicallable itself."""
    channel.stream_unary(_STREAM_UNARY)(
        iter([_REQUEST] * test_constants.STREAM_LENGTH),
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready)


def perform_stream_unary_with_call(channel, wait_for_ready=None):
    """Blocking stream-unary call via with_call()."""
    channel.stream_unary(_STREAM_UNARY).with_call(
        iter([_REQUEST] * test_constants.STREAM_LENGTH),
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready)


def perform_stream_unary_future(channel, wait_for_ready=None):
    """Stream-unary call via future(), waiting on result()."""
    channel.stream_unary(_STREAM_UNARY).future(
        iter([_REQUEST] * test_constants.STREAM_LENGTH),
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready).result(
            timeout=test_constants.LONG_TIMEOUT)


def perform_stream_stream_call(channel, wait_for_ready=None):
    """Stream-stream call; iterates the full response stream."""
    response_iterator = channel.stream_stream(_STREAM_STREAM)(
        iter([_REQUEST] * test_constants.STREAM_LENGTH),
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready)
    for _ in response_iterator:
        pass


_ALL_CALL_CASES = [
    perform_unary_unary_call, perform_unary_unary_with_call,
    perform_unary_unary_future, perform_unary_stream_call,
    perform_stream_unary_call, perform_stream_unary_with_call,
    perform_stream_unary_future, perform_stream_stream_call
]


class MetadataFlagsTest(unittest.TestCase):

    def check_connection_does_failfast(self, fn, channel, wait_for_ready=None):
        """Assert that `fn(channel, wait_for_ready)` fails with UNAVAILABLE.

        Only grpc.RpcError is treated as the expected failure; any other
        exception propagates unchanged so its real cause is reported. If
        the call succeeds, the test fails with an explicit message.
        """
        try:
            fn(channel, wait_for_ready)
        except grpc.RpcError as e:
            self.assertIs(grpc.StatusCode.UNAVAILABLE, e.code())
        else:
            # Kept outside the try so this AssertionError is not caught
            # and mistaken for an RPC error.
            self.fail("The Call should fail")

    def test_call_wait_for_ready_default(self):
        """With wait_for_ready unset, every call shape fails fast."""
        for perform_call in _ALL_CALL_CASES:
            with create_dummy_channel() as channel:
                self.check_connection_does_failfast(perform_call, channel)

    def test_call_wait_for_ready_disabled(self):
        """With wait_for_ready=False, every call shape fails fast."""
        for perform_call in _ALL_CALL_CASES:
            with create_dummy_channel() as channel:
                self.check_connection_does_failfast(perform_call,
                                                    channel,
                                                    wait_for_ready=False)

    def test_call_wait_for_ready_enabled(self):
        """With wait_for_ready=True, calls wait for a late-starting server."""
        # Each call runs on its own thread so that every client can be set
        # up and left waiting before the server exists, without handling
        # the call shapes one by one.
        # Exceptions raised on a worker thread are not propagated to the
        # main thread, so they are collected in a queue and re-raised here.
        unhandled_exceptions = queue.Queue()

        # We just need an unused TCP port
        host, port, sock = get_socket(sock_options=(socket.SO_REUSEADDR,))
        sock.close()

        addr = '{}:{}'.format(host, port)
        wg = test_common.WaitGroup(len(_ALL_CALL_CASES))

        def test_call(perform_call):
            # Each worker releases the wait group exactly once: when its
            # channel first reports TRANSIENT_FAILURE, or when its call
            # fails. A channel may report TRANSIENT_FAILURE more than once
            # while retrying, and the failure path may run after it has
            # already signalled; an unguarded wg.done() would over-count
            # and let the server start before every call is waiting.
            signal_lock = threading.Lock()
            signalled = [False]

            def signal_once():
                with signal_lock:
                    if signalled[0]:
                        return
                    signalled[0] = True
                wg.done()

            def wait_for_transient_failure(channel_connectivity):
                if (channel_connectivity ==
                        grpc.ChannelConnectivity.TRANSIENT_FAILURE):
                    signal_once()

            with grpc.insecure_channel(addr) as channel:
                try:
                    channel.subscribe(wait_for_transient_failure)
                    perform_call(channel, wait_for_ready=True)
                except BaseException as e:  # pylint: disable=broad-except
                    # If the call failed, the thread would be destroyed. The
                    # channel object can be collected before calling the
                    # callback, which would leave wg.wait() blocked forever.
                    signal_once()
                    unhandled_exceptions.put(e, True)

        test_threads = []
        for perform_call in _ALL_CALL_CASES:
            test_thread = threading.Thread(target=test_call,
                                           args=(perform_call,))
            test_thread.daemon = True
            test_thread.start()
            test_threads.append(test_thread)

        # Start the server only after the connections are waiting.
        wg.wait()
        server = test_common.test_server(reuse_port=True)
        try:
            server.add_generic_rpc_handlers(
                (_GenericHandler(weakref.proxy(self)),))
            server.add_insecure_port(addr)
            server.start()

            for test_thread in test_threads:
                test_thread.join()
        finally:
            # Stop the server so the test ends properly even on failure.
            server.stop(0)

        if not unhandled_exceptions.empty():
            raise unhandled_exceptions.get(True)


if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)