# Copyright 2019 the gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test of responsiveness to signals."""

import logging
import os
import signal
import subprocess
import tempfile
import threading
import unittest
import sys

import grpc

from tests.unit import test_common
from tests.unit import _signal_client

_CLIENT_PATH = None
if sys.executable is not None:
    # A Python interpreter is available, so the client module is run as a
    # script (see _start_client). realpath already yields an absolute path.
    _CLIENT_PATH = os.path.realpath(_signal_client.__file__)
else:
    # NOTE(rbellevi): For compatibility with internal testing.
    # Without an interpreter, the client is a standalone executable whose path
    # is supplied as the sole command-line argument; it is looked up next to
    # this file by its base name.
    if len(sys.argv) != 2:
        raise RuntimeError("Must supply path to executable client.")
    client_name = os.path.basename(sys.argv[1])
    del sys.argv[1]  # For compatibility with test runner.
    _CLIENT_PATH = os.path.realpath(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), client_name))

_HOST = 'localhost'

# How often, in seconds, to check that the client process is still alive while
# waiting for its RPC to reach the server.
_CONNECT_POLL_INTERVAL_S = 0.1


class _GenericHandler(grpc.GenericRpcHandler):
    """Serves the client's RPCs and tracks how many are currently in flight.

    Each handler blocks until its RPC terminates, so a connected client stays
    "connected" until it goes away (e.g. because it was interrupted).
    """

    def __init__(self):
        self._connected_clients_lock = threading.RLock()
        # Set while at least one RPC is being handled.
        self._connected_clients_event = threading.Event()
        self._connected_clients = 0

        self._unary_unary_handler = grpc.unary_unary_rpc_method_handler(
            self._handle_unary_unary)
        self._unary_stream_handler = grpc.unary_stream_rpc_method_handler(
            self._handle_unary_stream)

    def _on_client_connect(self):
        with self._connected_clients_lock:
            self._connected_clients += 1
            self._connected_clients_event.set()

    def _on_client_disconnect(self):
        with self._connected_clients_lock:
            self._connected_clients -= 1
            if self._connected_clients == 0:
                self._connected_clients_event.clear()

    def await_connected_client(self, timeout=None):
        """Blocks until a client connects to the server.

        Args:
          timeout: Maximum number of seconds to wait, or None to wait forever.

        Returns:
          True if a client is connected, False if the timeout elapsed first.
        """
        return self._connected_clients_event.wait(timeout)

    def _block_until_rpc_ends(self, servicer_context):
        """Records a connected client and blocks until its RPC terminates."""
        rpc_ended = threading.Event()

        def on_rpc_end():
            self._on_client_disconnect()
            rpc_ended.set()

        # Count the client before registering the callback so that a callback
        # firing immediately cannot decrement the count before it is raised.
        self._on_client_connect()
        if not servicer_context.add_callback(on_rpc_end):
            # add_callback returns False when the callback will never be
            # invoked (e.g. the RPC already terminated); run it ourselves so
            # the count stays balanced and this thread does not block forever.
            on_rpc_end()
        rpc_ended.wait()

    def _handle_unary_unary(self, request, servicer_context):
        """Handles a unary RPC.

        Blocks until the client disconnects and then echoes.
        """
        self._block_until_rpc_ends(servicer_context)
        return request

    def _handle_unary_stream(self, request, servicer_context):
        """Handles a server streaming RPC.

        Blocks until the client disconnects and then echoes.
        """
        self._block_until_rpc_ends(servicer_context)
        yield request

    def service(self, handler_call_details):
        if handler_call_details.method == _signal_client.UNARY_UNARY:
            return self._unary_unary_handler
        elif handler_call_details.method == _signal_client.UNARY_STREAM:
            return self._unary_stream_handler
        else:
            return None


def _read_stream(stream):
    """Rewinds `stream` and returns its entire contents."""
    stream.seek(0)
    return stream.read()


def _start_client(args, stdout, stderr):
    """Launches the signal client as a subprocess.

    Args:
      args: Command-line arguments passed to the client.
      stdout: File object receiving the client's standard output.
      stderr: File object receiving the client's standard error.

    Returns:
      The subprocess.Popen handle of the client.
    """
    if sys.executable is not None:
        invocation = (sys.executable, _CLIENT_PATH) + tuple(args)
    else:
        invocation = (_CLIENT_PATH,) + tuple(args)
    return subprocess.Popen(invocation, stdout=stdout, stderr=stderr)


class SignalHandlingTest(unittest.TestCase):

    def setUp(self):
        self._server = test_common.test_server()
        self._port = self._server.add_insecure_port('{}:0'.format(_HOST))
        self._handler = _GenericHandler()
        self._server.add_generic_rpc_handlers((self._handler,))
        self._server.start()

    def tearDown(self):
        self._server.stop(None)

    def _run_interrupted_client(self, rpc_type, exception=False):
        """Runs the client and sends it SIGINT once its RPC is in flight.

        Args:
          rpc_type: The client's RPC mode argument ('unary' or 'streaming').
          exception: Whether to pass '--exception' to the client.

        Returns:
          A (returncode, stdout, stderr) tuple for the client process.
        """
        server_target = '{}:{}'.format(_HOST, self._port)
        args = (server_target, rpc_type)
        if exception:
            args = ('--exception',) + args
        with tempfile.TemporaryFile(mode='r') as client_stdout:
            with tempfile.TemporaryFile(mode='r') as client_stderr:
                client = _start_client(args, client_stdout, client_stderr)
                # Poll rather than block indefinitely so that a client which
                # dies before its RPC reaches the server fails the test
                # instead of hanging it.
                while not self._handler.await_connected_client(
                        timeout=_CONNECT_POLL_INTERVAL_S):
                    if client.poll() is not None:
                        self.fail(
                            'Client exited with code {} before connecting:\n{}'
                            .format(client.returncode,
                                    _read_stream(client_stderr)))
                client.send_signal(signal.SIGINT)
                returncode = client.wait()
                return (returncode, _read_stream(client_stdout),
                        _read_stream(client_stderr))

    @unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows')
    def testUnary(self):
        """Tests that the server unary code path does not stall signal handlers."""
        returncode, stdout, stderr = self._run_interrupted_client('unary')
        self.assertEqual(0, returncode, msg=stderr)
        self.assertIn(_signal_client.SIGTERM_MESSAGE, stdout)

    @unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows')
    def testStreaming(self):
        """Tests that the server streaming code path does not stall signal handlers."""
        returncode, stdout, stderr = self._run_interrupted_client('streaming')
        self.assertEqual(0, returncode, msg=stderr)
        self.assertIn(_signal_client.SIGTERM_MESSAGE, stdout)

    @unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows')
    def testUnaryWithException(self):
        """Tests that a unary client in '--exception' mode exits cleanly on SIGINT."""
        returncode, _, stderr = self._run_interrupted_client('unary',
                                                             exception=True)
        self.assertEqual(0, returncode, msg=stderr)

    @unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows')
    def testStreamingHandlerWithException(self):
        """Tests that a streaming client in '--exception' mode exits cleanly on SIGINT."""
        returncode, _, stderr = self._run_interrupted_client('streaming',
                                                             exception=True)
        self.assertEqual(0, returncode, msg=stderr)


if __name__ == '__main__':
    logging.basicConfig()
    unittest.main(verbosity=2)