# Copyright 2019 the gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test of responsiveness to signals."""

import logging
import os
import signal
import subprocess
import sys
import tempfile
import threading
import unittest

import grpc

from tests.unit import _signal_client
from tests.unit import test_common

_CLIENT_PATH = None
# sys.executable may be None *or* an empty string when the interpreter path
# is unknown, so test truthiness rather than comparing against None.
if sys.executable:
    _CLIENT_PATH = os.path.abspath(os.path.realpath(_signal_client.__file__))
else:
    # NOTE(rbellevi): For compatibility with internal testing.
    if len(sys.argv) != 2:
        raise RuntimeError("Must supply path to executable client.")
    client_name = sys.argv[1].split("/")[-1]
    del sys.argv[1]  # For compatibility with test runner.
    _CLIENT_PATH = os.path.realpath(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), client_name)
    )

_HOST = "localhost"

# The gevent test harness cannot run the monkeypatch code for the child process,
# so we need to instrument it manually.
_GEVENT_ARG = ("--gevent",) if test_common.running_under_gevent() else ()


class _GenericHandler:
    """Server-side handlers that track how many client RPCs are in flight.

    Each handler blocks until its RPC terminates, which lets the test signal
    the client while the call is still outstanding.
    """

    def __init__(self):
        # Guards _connected_clients; the event is set while the count is > 0.
        self._connected_clients_lock = threading.RLock()
        self._connected_clients_event = threading.Event()
        self._connected_clients = 0

        self._unary_unary_handler = grpc.unary_unary_rpc_method_handler(
            self._handle_unary_unary
        )
        self._unary_stream_handler = grpc.unary_stream_rpc_method_handler(
            self._handle_unary_stream
        )

    def _on_client_connect(self):
        """Records a newly started RPC and marks a client as connected."""
        with self._connected_clients_lock:
            self._connected_clients += 1
            self._connected_clients_event.set()

    def _on_client_disconnect(self):
        """Records a finished RPC; clears the event once none remain."""
        with self._connected_clients_lock:
            self._connected_clients -= 1
            if self._connected_clients == 0:
                self._connected_clients_event.clear()

    def await_connected_client(self):
        """Blocks until a client connects to the server."""
        self._connected_clients_event.wait()

    def _block_until_rpc_end(self, servicer_context):
        """Registers the RPC as connected and blocks until its end callback.

        Args:
          servicer_context: The context of the RPC being handled; the
            disconnect bookkeeping is registered on it via add_callback.
        """
        stop_event = threading.Event()

        def on_rpc_end():
            self._on_client_disconnect()
            stop_event.set()

        # Register the callback before announcing the connection so the
        # disconnect is never missed once the test starts signalling.
        servicer_context.add_callback(on_rpc_end)
        self._on_client_connect()
        stop_event.wait()

    def _handle_unary_unary(self, request, servicer_context):
        """Handles a unary RPC.

        Blocks until the client disconnects and then echoes.
        """
        self._block_until_rpc_end(servicer_context)
        return request

    def _handle_unary_stream(self, request, servicer_context):
        """Handles a server streaming RPC.

        Blocks until the client disconnects and then echoes.
        """
        self._block_until_rpc_end(servicer_context)
        yield request


def get_method_handlers(handler):
    """Returns the method-name -> RPC handler mapping for `handler`."""
    return {
        _signal_client.UNARY_UNARY: handler._unary_unary_handler,
        _signal_client.UNARY_STREAM: handler._unary_stream_handler,
    }


def _read_stream(stream):
    """Rewinds `stream` to its start and returns its entire contents."""
    stream.seek(0)
    return stream.read()


def _start_client(args, stdout, stderr):
    """Spawns the signal client as a subprocess.

    Args:
      args: Iterable of command-line arguments for the client.
      stdout: File object that receives the client's standard output.
      stderr: File object that receives the client's standard error.

    Returns:
      The subprocess.Popen object for the started client.
    """
    # Mirror the module-level check: run through the interpreter when one is
    # known, otherwise treat _CLIENT_PATH as a directly executable client.
    if sys.executable:
        invocation = (sys.executable, _CLIENT_PATH) + tuple(args)
    else:
        invocation = (_CLIENT_PATH,) + tuple(args)
    return subprocess.Popen(invocation, stdout=stdout, stderr=stderr)


class SignalHandlingTest(unittest.TestCase):
    def setUp(self):
        self._server = test_common.test_server()
        self._port = self._server.add_insecure_port("{}:0".format(_HOST))
        self._handler = _GenericHandler()
        self._server.add_registered_method_handlers(
            _signal_client._SERVICE_NAME, get_method_handlers(self._handler)
        )
        self._server.start()

    def tearDown(self):
        self._server.stop(None)

    def _run_interrupted_client(self, rpc_kind, extra_flags=()):
        """Starts a client, SIGINTs it once connected, and collects results.

        Args:
          rpc_kind: The RPC arity argument passed to the client
            ("unary" or "streaming").
          extra_flags: Flags placed before the server target on the client's
            command line (e.g. ("--exception",)).

        Returns:
          A (returncode, stdout_text, stderr_text) tuple for the client.
        """
        server_target = "{}:{}".format(_HOST, self._port)
        client_args = (
            tuple(extra_flags) + (server_target, rpc_kind) + _GEVENT_ARG
        )
        with tempfile.TemporaryFile(mode="r") as client_stdout:
            with tempfile.TemporaryFile(mode="r") as client_stderr:
                client = _start_client(
                    client_args, client_stdout, client_stderr
                )
                try:
                    self._handler.await_connected_client()
                    client.send_signal(signal.SIGINT)
                    returncode = client.wait()
                finally:
                    # Do not leave an orphaned child behind if anything above
                    # raised before the client exited.
                    if client.poll() is None:
                        client.kill()
                        client.wait()
                return (
                    returncode,
                    _read_stream(client_stdout),
                    _read_stream(client_stderr),
                )

    @unittest.skipIf(os.name == "nt", "SIGINT not supported on windows")
    def testUnary(self):
        """Tests that the server unary code path does not stall signal handlers."""
        returncode, stdout, stderr = self._run_interrupted_client("unary")
        self.assertFalse(returncode, msg=stderr)
        self.assertIn(_signal_client.SIGTERM_MESSAGE, stdout)

    @unittest.skipIf(os.name == "nt", "SIGINT not supported on windows")
    def testStreaming(self):
        """Tests that the server streaming code path does not stall signal handlers."""
        returncode, stdout, stderr = self._run_interrupted_client("streaming")
        self.assertFalse(returncode, msg=stderr)
        self.assertIn(_signal_client.SIGTERM_MESSAGE, stdout)

    @unittest.skipIf(os.name == "nt", "SIGINT not supported on windows")
    def testUnaryWithException(self):
        """Tests a SIGINT'd unary client started with --exception exits 0."""
        returncode, _, stderr = self._run_interrupted_client(
            "unary", extra_flags=("--exception",)
        )
        self.assertEqual(0, returncode, msg=stderr)

    @unittest.skipIf(os.name == "nt", "SIGINT not supported on windows")
    def testStreamingHandlerWithException(self):
        """Tests a SIGINT'd streaming client started with --exception exits 0."""
        returncode, _, stderr = self._run_interrupted_client(
            "streaming", extra_flags=("--exception",)
        )
        self.assertEqual(0, returncode, msg=stderr)


if __name__ == "__main__":
    logging.basicConfig()
    unittest.main(verbosity=2)