• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 the gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Test of responsiveness to signals."""
15
16import logging
17import os
18import signal
19import subprocess
20import tempfile
21import threading
22import unittest
23import sys
24
25import grpc
26
27from tests.unit import test_common
28from tests.unit import _signal_client
29
# Resolve the path of the client program that the tests launch.
if sys.executable is None:
    # NOTE(rbellevi): For compatibility with internal testing.
    # Without an interpreter, the client path must be passed as argv[1]; only
    # its final '/'-separated component is kept and it is resolved relative
    # to the directory containing this file.
    if len(sys.argv) != 2:
        raise RuntimeError("Must supply path to executable client.")
    client_name = sys.argv[1].split("/")[-1]
    del sys.argv[1]  # For compatibility with test runner.
    _CLIENT_PATH = os.path.realpath(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), client_name))
else:
    # Point at the _signal_client module's own source file.
    _CLIENT_PATH = os.path.abspath(os.path.realpath(_signal_client.__file__))

# Host the test server binds to and the client dials.
_HOST = 'localhost'
43
44
class _GenericHandler(grpc.GenericRpcHandler):
    """Generic handler whose RPCs block until the calling client goes away.

    It keeps a count of in-flight RPCs so the test can wait until the client
    subprocess has actually reached the server before signalling it.
    """

    def __init__(self):
        # The lock guards _connected_clients; the event is set while the
        # count is positive and cleared when it drops back to zero.
        self._connected_clients_lock = threading.RLock()
        self._connected_clients_event = threading.Event()
        self._connected_clients = 0

        self._unary_unary_handler = grpc.unary_unary_rpc_method_handler(
            self._handle_unary_unary)
        self._unary_stream_handler = grpc.unary_stream_rpc_method_handler(
            self._handle_unary_stream)

    def _on_client_connect(self):
        with self._connected_clients_lock:
            self._connected_clients += 1
            self._connected_clients_event.set()

    def _on_client_disconnect(self):
        with self._connected_clients_lock:
            self._connected_clients -= 1
            if self._connected_clients == 0:
                self._connected_clients_event.clear()

    def await_connected_client(self, timeout=None):
        """Blocks until a client connects to the server.

        Args:
          timeout: Maximum number of seconds to wait, or None (the default)
            to wait indefinitely.

        Returns:
          True if a client is connected, False if the timeout elapsed first.
        """
        return self._connected_clients_event.wait(timeout)

    def _block_until_rpc_end(self, servicer_context):
        """Counts the RPC as a connected client and blocks until it ends."""
        stop_event = threading.Event()

        def on_rpc_end():
            self._on_client_disconnect()
            stop_event.set()

        # Count the client before registering the termination callback so the
        # callback can never decrement the counter before it was incremented.
        self._on_client_connect()
        if not servicer_context.add_callback(on_rpc_end):
            # add_callback returns False when the RPC has already terminated,
            # in which case the callback will never be invoked. Run it now
            # instead of waiting forever and leaking the connection count.
            on_rpc_end()
        stop_event.wait()

    def _handle_unary_unary(self, request, servicer_context):
        """Handles a unary RPC.

        Blocks until the client disconnects and then echoes.
        """
        self._block_until_rpc_end(servicer_context)
        return request

    def _handle_unary_stream(self, request, servicer_context):
        """Handles a server streaming RPC.

        Blocks until the client disconnects and then echoes.
        """
        self._block_until_rpc_end(servicer_context)
        yield request

    def service(self, handler_call_details):
        """Returns the handler for the two client methods, else None."""
        if handler_call_details.method == _signal_client.UNARY_UNARY:
            return self._unary_unary_handler
        elif handler_call_details.method == _signal_client.UNARY_STREAM:
            return self._unary_stream_handler
        else:
            return None
111
112
def _read_stream(stream):
    """Rewinds ``stream`` to its start and returns its full contents."""
    stream.seek(0)
    contents = stream.read()
    return contents
116
117
def _start_client(args, stdout, stderr):
    """Launches the signal client as a subprocess.

    Args:
      args: Command-line arguments appended after the client path.
      stdout: File object that receives the client's standard output.
      stderr: File object that receives the client's standard error.

    Returns:
      The ``subprocess.Popen`` handle for the client.
    """
    command = (_CLIENT_PATH,) + tuple(args)
    if sys.executable is not None:
        # Run the client through the same interpreter as this test.
        command = (sys.executable,) + command
    return subprocess.Popen(command, stdout=stdout, stderr=stderr)
125
126
class SignalHandlingTest(unittest.TestCase):
    """Checks that in-flight gRPC calls do not stall a client's signal handling."""

    def setUp(self):
        self._server = test_common.test_server()
        self._port = self._server.add_insecure_port('{}:0'.format(_HOST))
        self._handler = _GenericHandler()
        self._server.add_generic_rpc_handlers((self._handler,))
        self._server.start()

    def tearDown(self):
        self._server.stop(None)

    def _server_target(self):
        """Returns the ``host:port`` address of the test server."""
        return '{}:{}'.format(_HOST, self._port)

    def _interrupt_client(self, args):
        """Runs the client and sends it SIGINT once it reaches the server.

        Args:
          args: Command-line arguments forwarded to the client.

        Returns:
          A ``(returncode, stdout, stderr)`` tuple for the client process.
        """
        with tempfile.TemporaryFile(mode='r') as client_stdout:
            with tempfile.TemporaryFile(mode='r') as client_stderr:
                client = _start_client(args, client_stdout, client_stderr)
                try:
                    self._handler.await_connected_client()
                    client.send_signal(signal.SIGINT)
                    returncode = client.wait()
                finally:
                    # Never leak the child if anything above raised (e.g. a
                    # KeyboardInterrupt while waiting for it to connect).
                    if client.poll() is None:
                        client.kill()
                        client.wait()
                return (returncode, _read_stream(client_stdout),
                        _read_stream(client_stderr))

    @unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows')
    def testUnary(self):
        """Tests that the server unary code path does not stall signal handlers."""
        returncode, stdout, stderr = self._interrupt_client(
            (self._server_target(), 'unary'))
        self.assertFalse(returncode, msg=stderr)
        self.assertIn(_signal_client.SIGTERM_MESSAGE, stdout)

    @unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows')
    def testStreaming(self):
        """Tests that the server streaming code path does not stall signal handlers."""
        returncode, stdout, stderr = self._interrupt_client(
            (self._server_target(), 'streaming'))
        self.assertFalse(returncode, msg=stderr)
        self.assertIn(_signal_client.SIGTERM_MESSAGE, stdout)

    @unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows')
    def testUnaryWithException(self):
        """Tests the unary path with the client's --exception flag; exit must be 0."""
        returncode, _, stderr = self._interrupt_client(
            ('--exception', self._server_target(), 'unary'))
        self.assertEqual(0, returncode, msg=stderr)

    @unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows')
    def testStreamingHandlerWithException(self):
        """Tests the streaming path with the client's --exception flag; exit must be 0."""
        returncode, _, stderr = self._interrupt_client(
            ('--exception', self._server_target(), 'streaming'))
        self.assertEqual(0, returncode, msg=stderr)
194
195
if __name__ == '__main__':
    # Install a default stderr handler on the root logger, then run every
    # test in this module with verbose per-test output.
    logging.basicConfig()
    unittest.main(verbosity=2)
199