• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 the gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Test of responsiveness to signals."""
15
16import logging
17import os
18import signal
19import subprocess
20import sys
21import tempfile
22import threading
23import unittest
24
25import grpc
26
27from tests.unit import _signal_client
28from tests.unit import test_common
29
# Work out which client program the tests launch as a child process.
if sys.executable is None:
    # NOTE(rbellevi): For compatibility with internal testing.
    # Without an interpreter path, the client is expected to be an executable
    # named by the single command-line argument and located next to this file.
    if len(sys.argv) != 2:
        raise RuntimeError("Must supply path to executable client.")
    client_name = sys.argv[1].split("/")[-1]
    del sys.argv[1]  # For compatibility with test runner.
    _CLIENT_PATH = os.path.realpath(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), client_name)
    )
else:
    # With an interpreter available, the client module's own source file is
    # what gets run.
    _CLIENT_PATH = os.path.abspath(os.path.realpath(_signal_client.__file__))
42
# Host name used both for the server's listening port and the client's target.
_HOST = "localhost"

# The gevent test harness cannot run the monkeypatch code for the child process,
# so we need to instrument it manually.
if test_common.running_under_gevent():
    _GEVENT_ARG = ("--gevent",)
else:
    _GEVENT_ARG = ()
48
49
50class _GenericHandler:
51    def __init__(self):
52        self._connected_clients_lock = threading.RLock()
53        self._connected_clients_event = threading.Event()
54        self._connected_clients = 0
55
56        self._unary_unary_handler = grpc.unary_unary_rpc_method_handler(
57            self._handle_unary_unary
58        )
59        self._unary_stream_handler = grpc.unary_stream_rpc_method_handler(
60            self._handle_unary_stream
61        )
62
63    def _on_client_connect(self):
64        with self._connected_clients_lock:
65            self._connected_clients += 1
66            self._connected_clients_event.set()
67
68    def _on_client_disconnect(self):
69        with self._connected_clients_lock:
70            self._connected_clients -= 1
71            if self._connected_clients == 0:
72                self._connected_clients_event.clear()
73
74    def await_connected_client(self):
75        """Blocks until a client connects to the server."""
76        self._connected_clients_event.wait()
77
78    def _handle_unary_unary(self, request, servicer_context):
79        """Handles a unary RPC.
80
81        Blocks until the client disconnects and then echoes.
82        """
83        stop_event = threading.Event()
84
85        def on_rpc_end():
86            self._on_client_disconnect()
87            stop_event.set()
88
89        servicer_context.add_callback(on_rpc_end)
90        self._on_client_connect()
91        stop_event.wait()
92        return request
93
94    def _handle_unary_stream(self, request, servicer_context):
95        """Handles a server streaming RPC.
96
97        Blocks until the client disconnects and then echoes.
98        """
99        stop_event = threading.Event()
100
101        def on_rpc_end():
102            self._on_client_disconnect()
103            stop_event.set()
104
105        servicer_context.add_callback(on_rpc_end)
106        self._on_client_connect()
107        stop_event.wait()
108        yield request
109
110
def get_method_handlers(handler):
    """Maps the signal client's method names to the handler's RPC handlers.

    Args:
      handler: An object exposing `_unary_unary_handler` and
        `_unary_stream_handler` attributes.

    Returns:
      A dict keyed by `_signal_client.UNARY_UNARY` and
      `_signal_client.UNARY_STREAM`.
    """
    method_handlers = {}
    method_handlers[_signal_client.UNARY_UNARY] = handler._unary_unary_handler
    method_handlers[_signal_client.UNARY_STREAM] = handler._unary_stream_handler
    return method_handlers
116
117
118def _read_stream(stream):
119    stream.seek(0)
120    return stream.read()
121
122
def _start_client(args, stdout, stderr):
    """Launches the signal client as a child process.

    Args:
      args: Iterable of command-line arguments passed to the client.
      stdout: File object (or descriptor) receiving the child's stdout.
      stderr: File object (or descriptor) receiving the child's stderr.

    Returns:
      The `subprocess.Popen` object for the started client.
    """
    # Run the client through the interpreter when one is known; otherwise
    # _CLIENT_PATH is treated as directly executable.
    launcher = (
        (_CLIENT_PATH,)
        if sys.executable is None
        else (sys.executable, _CLIENT_PATH)
    )
    return subprocess.Popen(
        launcher + tuple(args), stdout=stdout, stderr=stderr
    )
130
131
class SignalHandlingTest(unittest.TestCase):
    """Checks that a client blocked in an RPC still responds to SIGINT.

    Every test starts the signal client in a child process against an
    in-process server whose handlers block until the RPC terminates, sends the
    child SIGINT once it has connected, and then checks how the child exited.
    """

    def setUp(self):
        self._server = test_common.test_server()
        self._port = self._server.add_insecure_port("{}:0".format(_HOST))
        self._handler = _GenericHandler()
        self._server.add_registered_method_handlers(
            _signal_client._SERVICE_NAME, get_method_handlers(self._handler)
        )
        self._server.start()

    def tearDown(self):
        self._server.stop(None)

    def _run_interrupted_client(self, mode, leading_flags=()):
        """Runs the client, interrupts it once connected, and collects results.

        Args:
          mode: The RPC arity argument passed to the client ("unary" or
            "streaming").
          leading_flags: Tuple of flags placed before the server target on the
            client's command line.

        Returns:
          A `(returncode, stdout, stderr)` tuple, where the latter two are the
          full text the client wrote to each stream.
        """
        server_target = "{}:{}".format(_HOST, self._port)
        client_args = (
            tuple(leading_flags) + (server_target, mode) + _GEVENT_ARG
        )
        with tempfile.TemporaryFile(mode="r") as client_stdout:
            with tempfile.TemporaryFile(mode="r") as client_stderr:
                client = _start_client(
                    client_args, client_stdout, client_stderr
                )
                try:
                    self._handler.await_connected_client()
                    client.send_signal(signal.SIGINT)
                    returncode = client.wait()
                finally:
                    # If anything above raised, do not leave the child process
                    # running (and holding the temp files) after the test.
                    if client.poll() is None:
                        client.kill()
                        client.wait()
                return (
                    returncode,
                    _read_stream(client_stdout),
                    _read_stream(client_stderr),
                )

    @unittest.skipIf(os.name == "nt", "SIGINT not supported on windows")
    def testUnary(self):
        """Tests that the server unary code path does not stall signal handlers."""
        returncode, stdout, stderr = self._run_interrupted_client("unary")
        self.assertFalse(returncode, msg=stderr)
        self.assertIn(_signal_client.SIGTERM_MESSAGE, stdout)

    @unittest.skipIf(os.name == "nt", "SIGINT not supported on windows")
    def testStreaming(self):
        """Tests that the server streaming code path does not stall signal handlers."""
        returncode, stdout, stderr = self._run_interrupted_client("streaming")
        self.assertFalse(returncode, msg=stderr)
        self.assertIn(_signal_client.SIGTERM_MESSAGE, stdout)

    @unittest.skipIf(os.name == "nt", "SIGINT not supported on windows")
    def testUnaryWithException(self):
        """Tests a unary client started with --exception exits with status 0."""
        returncode, _, stderr = self._run_interrupted_client(
            "unary", leading_flags=("--exception",)
        )
        # Attach the client's stderr so a failure is diagnosable, matching the
        # other tests in this class.
        self.assertEqual(0, returncode, msg=stderr)

    @unittest.skipIf(os.name == "nt", "SIGINT not supported on windows")
    def testStreamingHandlerWithException(self):
        """Tests a streaming client started with --exception exits with status 0."""
        returncode, _, stderr = self._run_interrupted_client(
            "streaming", leading_flags=("--exception",)
        )
        self.assertEqual(0, returncode, msg=stderr)
217
218
# Script entry point: run this module's tests directly.
if __name__ == "__main__":
    # Set up default logging before the tests start.
    logging.basicConfig()
    # verbosity=2 asks unittest for per-test output.
    unittest.main(verbosity=2)
222