• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests metadata flags feature by testing wait-for-ready semantics"""
15
16import logging
17import queue
18import socket
19import threading
20import time
21import unittest
22import weakref
23
24import grpc
25
26from tests.unit import test_common
27import tests.unit.framework.common
28from tests.unit.framework.common import get_socket
29from tests.unit.framework.common import test_constants
30
# Fully-qualified method names routed by the generic handler below.
_UNARY_UNARY = "/test/UnaryUnary"
_UNARY_STREAM = "/test/UnaryStream"
_STREAM_UNARY = "/test/StreamUnary"
_STREAM_STREAM = "/test/StreamStream"

# Opaque three-byte payloads. No (de)serializers are configured anywhere in
# this test, so these raw bytes travel unchanged in both directions.
_REQUEST = bytes(3)
_RESPONSE = bytes(3)
38
39
def handle_unary_unary(test, request, servicer_context):
    """Unary-unary handler: ignore the request and reply with ``_RESPONSE``."""
    del test, request, servicer_context  # Unused; kept for a uniform signature.
    return _RESPONSE
42
43
def handle_unary_stream(test, request, servicer_context):
    """Unary-stream handler: emit ``STREAM_LENGTH`` copies of ``_RESPONSE``.

    The request, the test handle and the servicer context are ignored.
    """
    yield from [_RESPONSE] * test_constants.STREAM_LENGTH
47
48
def handle_stream_unary(test, request_iterator, servicer_context):
    """Stream-unary handler: consume every request, then reply once."""
    # Drain the incoming stream fully before producing the single response.
    for _unused_request in request_iterator:
        continue
    return _RESPONSE
53
54
def handle_stream_stream(test, request_iterator, servicer_context):
    """Stream-stream handler: answer each incoming request with ``_RESPONSE``."""
    yield from (_RESPONSE for _ in request_iterator)
58
59
class _MethodHandler(grpc.RpcMethodHandler):
    """Method handler wired to one of the module-level ``handle_*`` functions.

    Exactly one of ``unary_unary``, ``unary_stream``, ``stream_unary`` and
    ``stream_stream`` is populated, chosen by the pair
    (``request_streaming``, ``response_streaming``); the rest stay ``None``.
    Both (de)serializers are ``None``, so payloads are passed through as-is.
    """

    def __init__(self, test, request_streaming, response_streaming):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = None
        self.response_serializer = None
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None

        def bind(handler):
            # Close over ``test`` so the server can call the behavior with
            # just (request_or_iterator, context).
            return lambda payload, context: handler(test, payload, context)

        if request_streaming:
            if response_streaming:
                self.stream_stream = bind(handle_stream_stream)
            else:
                self.stream_unary = bind(handle_stream_unary)
        elif response_streaming:
            self.unary_stream = bind(handle_unary_stream)
        else:
            self.unary_unary = bind(handle_unary_unary)
86
87
class _GenericHandler(grpc.GenericRpcHandler):
    """Routes the four ``/test/...`` methods to matching ``_MethodHandler``s."""

    def __init__(self, test):
        self._test = test

    def service(self, handler_call_details):
        """Return a handler for a known method name, or ``None`` otherwise."""
        # method name -> (request_streaming, response_streaming)
        streaming_flags = {
            _UNARY_UNARY: (False, False),
            _UNARY_STREAM: (False, True),
            _STREAM_UNARY: (True, False),
            _STREAM_STREAM: (True, True),
        }.get(handler_call_details.method)
        if streaming_flags is None:
            return None
        request_streaming, response_streaming = streaming_flags
        return _MethodHandler(self._test, request_streaming, response_streaming)
103
104
def create_phony_channel():
    """Return an insecure channel pointed at a port that was just released.

    Creating phony channels is a workaround for retries. A free port is
    reserved and then closed before the channel is built. Presumably nothing
    is listening there, so calls on the channel cannot connect.
    """
    host, port, reserved_socket = get_socket(
        sock_options=(socket.SO_REUSEADDR,)
    )
    reserved_socket.close()
    target = "{host}:{port}".format(host=host, port=port)
    return grpc.insecure_channel(target)
110
111
def perform_unary_unary_call(channel, wait_for_ready=None):
    """Invoke the unary-unary method by calling the multicallable directly.

    The response is discarded; any RPC error propagates to the caller.
    """
    rpc = channel.unary_unary(_UNARY_UNARY, _registered_method=True)
    rpc(
        _REQUEST,
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready,
    )
121
122
def perform_unary_unary_with_call(channel, wait_for_ready=None):
    """Invoke the unary-unary method through ``with_call``.

    The result is discarded; any RPC error propagates to the caller.
    """
    rpc = channel.unary_unary(_UNARY_UNARY, _registered_method=True)
    rpc.with_call(
        _REQUEST,
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready,
    )
132
133
def perform_unary_unary_future(channel, wait_for_ready=None):
    """Invoke the unary-unary method through ``future`` and wait on it."""
    rpc = channel.unary_unary(_UNARY_UNARY, _registered_method=True)
    response_future = rpc.future(
        _REQUEST,
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready,
    )
    # Block on the result so that a failed RPC raises here.
    response_future.result(timeout=test_constants.LONG_TIMEOUT)
143
144
def perform_unary_stream_call(channel, wait_for_ready=None):
    """Invoke the unary-stream method and drain every response."""
    rpc = channel.unary_stream(_UNARY_STREAM, _registered_method=True)
    responses = rpc(
        _REQUEST,
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready,
    )
    # Iterating the stream to the end surfaces any RPC error.
    for _unused_response in responses:
        continue
156
157
def perform_stream_unary_call(channel, wait_for_ready=None):
    """Send ``STREAM_LENGTH`` requests to the stream-unary method directly.

    The response is discarded; any RPC error propagates to the caller.
    """
    rpc = channel.stream_unary(_STREAM_UNARY, _registered_method=True)
    requests = iter([_REQUEST] * test_constants.STREAM_LENGTH)
    rpc(
        requests,
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready,
    )
167
168
def perform_stream_unary_with_call(channel, wait_for_ready=None):
    """Send ``STREAM_LENGTH`` requests to the stream-unary method via ``with_call``.

    The result is discarded; any RPC error propagates to the caller.
    """
    rpc = channel.stream_unary(_STREAM_UNARY, _registered_method=True)
    requests = iter([_REQUEST] * test_constants.STREAM_LENGTH)
    rpc.with_call(
        requests,
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready,
    )
178
179
def perform_stream_unary_future(channel, wait_for_ready=None):
    """Send ``STREAM_LENGTH`` requests via ``future`` and wait for the reply."""
    rpc = channel.stream_unary(_STREAM_UNARY, _registered_method=True)
    requests = iter([_REQUEST] * test_constants.STREAM_LENGTH)
    response_future = rpc.future(
        requests,
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready,
    )
    # Block on the result so that a failed RPC raises here.
    response_future.result(timeout=test_constants.LONG_TIMEOUT)
189
190
def perform_stream_stream_call(channel, wait_for_ready=None):
    """Exchange ``STREAM_LENGTH`` messages with the stream-stream method."""
    rpc = channel.stream_stream(_STREAM_STREAM, _registered_method=True)
    requests = iter([_REQUEST] * test_constants.STREAM_LENGTH)
    responses = rpc(
        requests,
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready,
    )
    # Iterating the stream to the end surfaces any RPC error.
    for _unused_response in responses:
        continue
201
202
# Every call style exercised by the tests. Each entry is called as
# ``perform_call(channel, wait_for_ready)``.
_ALL_CALL_CASES = [
    # Unary request, unary response: direct call, with_call, future.
    perform_unary_unary_call,
    perform_unary_unary_with_call,
    perform_unary_unary_future,
    # Unary request, streamed responses.
    perform_unary_stream_call,
    # Streamed requests, unary response: direct call, with_call, future.
    perform_stream_unary_call,
    perform_stream_unary_with_call,
    perform_stream_unary_future,
    # Streamed requests, streamed responses.
    perform_stream_stream_call,
]
213
214
class MetadataFlagsTest(unittest.TestCase):
    """Tests the ``wait_for_ready`` call flag across every call style."""

    def check_connection_does_failfast(self, fn, channel, wait_for_ready=None):
        """Assert that ``fn`` fails with UNAVAILABLE instead of waiting.

        Args:
          fn: A ``perform_*`` helper, invoked as ``fn(channel, wait_for_ready)``.
          channel: Channel whose target is expected to have no server.
          wait_for_ready: Value forwarded to ``fn``.
        """
        # Only an RPC error counts as the expected failure. A broad
        # ``except BaseException`` would also catch the AssertionError raised
        # when the call unexpectedly succeeds, and then replace it with an
        # AttributeError from ``e.code()``, hiding the real failure.
        with self.assertRaises(
            grpc.RpcError, msg="The Call should fail"
        ) as caught:
            fn(channel, wait_for_ready)
        self.assertIs(grpc.StatusCode.UNAVAILABLE, caught.exception.code())

    def test_call_wait_for_ready_default(self):
        """Without an explicit flag, calls fail fast on an unreachable target."""
        for perform_call in _ALL_CALL_CASES:
            with self.subTest(call=perform_call.__name__):
                with create_phony_channel() as channel:
                    self.check_connection_does_failfast(perform_call, channel)

    def test_call_wait_for_ready_disabled(self):
        """With ``wait_for_ready=False``, calls fail fast on an unreachable target."""
        for perform_call in _ALL_CALL_CASES:
            with self.subTest(call=perform_call.__name__):
                with create_phony_channel() as channel:
                    self.check_connection_does_failfast(
                        perform_call, channel, wait_for_ready=False
                    )

    def test_call_wait_for_ready_enabled(self):
        """Calls with ``wait_for_ready=True`` wait for a server started late.

        Each call style runs on its own thread against a free port with no
        server. Once every call's channel has reported TRANSIENT_FAILURE, or
        its call has already failed, the server is started. Every call is then
        expected to complete without error.
        """
        # Worker threads do not propagate exceptions to the main thread, so
        # they are queued here and re-raised after all threads have finished.
        unhandled_exceptions = queue.Queue()

        # We just need an unused TCP port.
        host, port, sock = get_socket(sock_options=(socket.SO_REUSEADDR,))
        sock.close()
        addr = "{}:{}".format(host, port)

        # One event per call. A channel may report TRANSIENT_FAILURE more than
        # once (e.g. across reconnect attempts), and a failing call signals
        # again on top of that. A single counter decremented per signal could
        # therefore let the server start before every call is actually
        # waiting. Event.set() is idempotent, so each call counts exactly once.
        waiting_events = [threading.Event() for _ in _ALL_CALL_CASES]

        def test_call(perform_call, waiting):
            def on_connectivity_change(connectivity):
                if connectivity == grpc.ChannelConnectivity.TRANSIENT_FAILURE:
                    waiting.set()

            with grpc.insecure_channel(addr) as channel:
                try:
                    channel.subscribe(on_connectivity_change)
                    perform_call(channel, wait_for_ready=True)
                except BaseException as e:  # pylint: disable=broad-except
                    # If the call failed, the thread would be destroyed. The
                    # channel object can be collected before calling the
                    # callback, which will result in a deadlock.
                    waiting.set()
                    unhandled_exceptions.put(e, True)

        test_threads = [
            threading.Thread(
                target=test_call, args=(perform_call, waiting), daemon=True
            )
            for perform_call, waiting in zip(_ALL_CALL_CASES, waiting_events)
        ]
        for test_thread in test_threads:
            test_thread.start()

        # Start the server only after every call is waiting for a connection.
        for waiting in waiting_events:
            waiting.wait()
        server = test_common.test_server(reuse_port=True)
        server.add_generic_rpc_handlers((_GenericHandler(weakref.proxy(self)),))
        server.add_insecure_port(addr)
        server.start()
        try:
            for test_thread in test_threads:
                test_thread.join()
        finally:
            # Stop the server so the test ends properly even if joining is
            # interrupted.
            server.stop(0)

        if not unhandled_exceptions.empty():
            raise unhandled_exceptions.get(True)
294
295
if __name__ == "__main__":
    # Run directly: log at DEBUG (level given by name) and report each test.
    logging.basicConfig(level="DEBUG")
    unittest.main(verbosity=2)
299