• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests metadata flags feature by testing wait-for-ready semantics"""
15
16import time
17import weakref
18import unittest
19import threading
20import logging
21import socket
22from six.moves import queue
23
24import grpc
25
26from tests.unit import test_common
27from tests.unit.framework.common import test_constants
28import tests.unit.framework.common
29from tests.unit.framework.common import get_socket
30
# Fully-qualified method paths, one per RPC arity, that the test server
# routes on.
_UNARY_UNARY = '/test/UnaryUnary'
_UNARY_STREAM = '/test/UnaryStream'
_STREAM_UNARY = '/test/StreamUnary'
_STREAM_STREAM = '/test/StreamStream'

# Fixed byte payloads used for every request and every response.
_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x00\x00\x00'
38
39
def handle_unary_unary(test, request, servicer_context):
    """Serve a unary-unary RPC by replying with the canned response.

    None of the arguments are inspected; every call returns ``_RESPONSE``.
    """
    response = _RESPONSE
    return response
42
43
def handle_unary_stream(test, request, servicer_context):
    """Serve a unary-stream RPC with a fixed-length stream of responses.

    Yields ``_RESPONSE`` ``test_constants.STREAM_LENGTH`` times. The
    arguments are not inspected.
    """
    remaining = test_constants.STREAM_LENGTH
    while remaining > 0:
        yield _RESPONSE
        remaining -= 1
47
48
def handle_stream_unary(test, request_iterator, servicer_context):
    """Serve a stream-unary RPC: drain the requests, then reply once.

    The request iterator is consumed to exhaustion and its items discarded
    before ``_RESPONSE`` is returned.
    """
    for _request in request_iterator:
        continue
    return _RESPONSE
53
54
def handle_stream_stream(test, request_iterator, servicer_context):
    """Serve a stream-stream RPC, answering each request with one response.

    Yields ``_RESPONSE`` once per item pulled from ``request_iterator``; the
    request contents are ignored.
    """
    for _request in request_iterator:
        yield _RESPONSE
58
59
class _MethodHandler(grpc.RpcMethodHandler):
    """Method handler that exposes exactly one behavior.

    The behavior attribute matching the (request_streaming,
    response_streaming) pair forwards to the corresponding module-level
    ``handle_*`` function with ``test`` supplied as the first argument. The
    other three behavior attributes are left as ``None``.
    """

    def __init__(self, test, request_streaming, response_streaming):
        # No (de)serializers are configured on this handler.
        self.request_deserializer = None
        self.response_serializer = None

        self.request_streaming = request_streaming
        self.response_streaming = response_streaming

        # Start with every behavior unset; exactly one is filled in below.
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None

        if self.request_streaming:
            if self.response_streaming:

                def stream_stream(request_iterator, servicer_context):
                    return handle_stream_stream(test, request_iterator,
                                                servicer_context)

                self.stream_stream = stream_stream
            else:

                def stream_unary(request_iterator, servicer_context):
                    return handle_stream_unary(test, request_iterator,
                                               servicer_context)

                self.stream_unary = stream_unary
        elif self.response_streaming:

            def unary_stream(request, servicer_context):
                return handle_unary_stream(test, request, servicer_context)

            self.unary_stream = unary_stream
        else:

            def unary_unary(request, servicer_context):
                return handle_unary_unary(test, request, servicer_context)

            self.unary_unary = unary_unary
83
84
class _GenericHandler(grpc.GenericRpcHandler):
    """Generic handler that routes the four test method paths."""

    def __init__(self, test):
        self._test = test

    def service(self, handler_call_details):
        """Return a ``_MethodHandler`` for a known method path, else None."""
        # Maps each method path to (request_streaming, response_streaming).
        streaming_flags = {
            _UNARY_UNARY: (False, False),
            _UNARY_STREAM: (False, True),
            _STREAM_UNARY: (True, False),
            _STREAM_STREAM: (True, True),
        }.get(handler_call_details.method)

        if streaming_flags is None:
            return None

        request_streaming, response_streaming = streaming_flags
        return _MethodHandler(self._test, request_streaming,
                              response_streaming)
101
102
def create_dummy_channel():
    """Return an insecure channel aimed at a just-closed socket's address.

    A socket is obtained only to learn a host/port pair and is closed again
    before the channel is created. Per the original note, creating dummy
    channels is a workaround for retries.
    """
    host, port, placeholder_socket = get_socket(
        sock_options=(socket.SO_REUSEADDR,))
    placeholder_socket.close()
    target = '{}:{}'.format(host, port)
    return grpc.insecure_channel(target)
108
109
def perform_unary_unary_call(channel, wait_for_ready=None):
    """Invoke the unary-unary method through the multi-callable's __call__.

    The response is discarded; anything raised by the call propagates.
    """
    multi_callable = channel.unary_unary(_UNARY_UNARY)
    multi_callable.__call__(_REQUEST,
                            timeout=test_constants.LONG_TIMEOUT,
                            wait_for_ready=wait_for_ready)
115
116
def perform_unary_unary_with_call(channel, wait_for_ready=None):
    """Invoke the unary-unary method through ``with_call``.

    The returned value is discarded; anything raised by the call propagates.
    """
    multi_callable = channel.unary_unary(_UNARY_UNARY)
    multi_callable.with_call(_REQUEST,
                             timeout=test_constants.LONG_TIMEOUT,
                             wait_for_ready=wait_for_ready)
122
123
def perform_unary_unary_future(channel, wait_for_ready=None):
    """Invoke the unary-unary method through ``future`` and await its result.

    The result is discarded; anything raised while starting the call or
    resolving the future propagates.
    """
    multi_callable = channel.unary_unary(_UNARY_UNARY)
    response_future = multi_callable.future(
        _REQUEST,
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready)
    response_future.result(timeout=test_constants.LONG_TIMEOUT)
130
131
def perform_unary_stream_call(channel, wait_for_ready=None):
    """Invoke the unary-stream method and drain the response iterator.

    Responses are discarded; anything raised while iterating propagates.
    """
    multi_callable = channel.unary_stream(_UNARY_STREAM)
    responses = multi_callable.__call__(_REQUEST,
                                        timeout=test_constants.LONG_TIMEOUT,
                                        wait_for_ready=wait_for_ready)
    for _response in responses:
        continue
139
140
141def perform_stream_unary_call(channel, wait_for_ready=None):
142    channel.stream_unary(_STREAM_UNARY).__call__(
143        iter([_REQUEST] * test_constants.STREAM_LENGTH),
144        timeout=test_constants.LONG_TIMEOUT,
145        wait_for_ready=wait_for_ready)
146
147
def perform_stream_unary_with_call(channel, wait_for_ready=None):
    """Invoke the stream-unary method through ``with_call``.

    Sends ``test_constants.STREAM_LENGTH`` copies of ``_REQUEST``. The
    returned value is discarded; anything raised by the call propagates.
    """
    requests = [_REQUEST] * test_constants.STREAM_LENGTH
    multi_callable = channel.stream_unary(_STREAM_UNARY)
    multi_callable.with_call(iter(requests),
                             timeout=test_constants.LONG_TIMEOUT,
                             wait_for_ready=wait_for_ready)
153
154
def perform_stream_unary_future(channel, wait_for_ready=None):
    """Invoke the stream-unary method through ``future`` and await its result.

    Sends ``test_constants.STREAM_LENGTH`` copies of ``_REQUEST``. The result
    is discarded; anything raised while starting the call or resolving the
    future propagates.
    """
    requests = [_REQUEST] * test_constants.STREAM_LENGTH
    multi_callable = channel.stream_unary(_STREAM_UNARY)
    response_future = multi_callable.future(
        iter(requests),
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready)
    response_future.result(timeout=test_constants.LONG_TIMEOUT)
161
162
def perform_stream_stream_call(channel, wait_for_ready=None):
    """Invoke the stream-stream method and drain the response iterator.

    Sends ``test_constants.STREAM_LENGTH`` copies of ``_REQUEST``. Responses
    are discarded; anything raised while iterating propagates.
    """
    requests = [_REQUEST] * test_constants.STREAM_LENGTH
    multi_callable = channel.stream_stream(_STREAM_STREAM)
    responses = multi_callable.__call__(iter(requests),
                                        timeout=test_constants.LONG_TIMEOUT,
                                        wait_for_ready=wait_for_ready)
    for _response in responses:
        continue
170
171
# Every client-side invocation style exercised by the tests, grouped by
# RPC arity. Each entry is called as ``fn(channel, wait_for_ready)``.
_ALL_CALL_CASES = [
    # unary-unary
    perform_unary_unary_call,
    perform_unary_unary_with_call,
    perform_unary_unary_future,
    # unary-stream
    perform_unary_stream_call,
    # stream-unary
    perform_stream_unary_call,
    perform_stream_unary_with_call,
    perform_stream_unary_future,
    # stream-stream
    perform_stream_stream_call,
]
178
179
class MetadataFlagsTest(unittest.TestCase):
    """Checks fail-fast versus wait-for-ready behavior for every call style."""

    def check_connection_does_failfast(self, fn, channel, wait_for_ready=None):
        """Assert that ``fn`` fails with UNAVAILABLE when run on ``channel``.

        Args:
          fn: One of the ``perform_*`` helpers; invoked as
            ``fn(channel, wait_for_ready)``.
          channel: The channel the RPC is issued on.
          wait_for_ready: Forwarded to ``fn`` unchanged.
        """
        # Only grpc.RpcError is caught here. A broad handler would also
        # swallow the AssertionError raised by self.fail() and then blow up
        # with an AttributeError on ``.code()``, hiding the real failure.
        try:
            fn(channel, wait_for_ready)
        except grpc.RpcError as rpc_error:
            self.assertIs(grpc.StatusCode.UNAVAILABLE, rpc_error.code())
        else:
            self.fail("The Call should fail")

    def test_call_wait_for_ready_default(self):
        # With wait_for_ready left unset, every call style must fail fast.
        for perform_call in _ALL_CALL_CASES:
            with create_dummy_channel() as channel:
                self.check_connection_does_failfast(perform_call, channel)

    def test_call_wait_for_ready_disabled(self):
        # With wait_for_ready explicitly False, every call style must fail
        # fast.
        for perform_call in _ALL_CALL_CASES:
            with create_dummy_channel() as channel:
                self.check_connection_does_failfast(perform_call,
                                                    channel,
                                                    wait_for_ready=False)

    def test_call_wait_for_ready_enabled(self):
        # To test the wait mechanism, each call runs in its own Python
        # thread so that all clients are set up first, without handling
        # them case by case.
        # Python threads do not pass unhandled exceptions to the main
        # thread, so the exceptions are stored in a queue and raised again
        # from the main thread.
        unhandled_exceptions = queue.Queue()

        # We just need an unused TCP port
        host, port, sock = get_socket(sock_options=(socket.SO_REUSEADDR,))
        sock.close()

        addr = '{}:{}'.format(host, port)
        wg = test_common.WaitGroup(len(_ALL_CALL_CASES))

        def wait_for_transient_failure(channel_connectivity):
            # Signal the wait group when the channel reports
            # TRANSIENT_FAILURE.
            if channel_connectivity == grpc.ChannelConnectivity.TRANSIENT_FAILURE:
                wg.done()

        def test_call(perform_call):
            with grpc.insecure_channel(addr) as channel:
                try:
                    channel.subscribe(wait_for_transient_failure)
                    perform_call(channel, wait_for_ready=True)
                except BaseException as e:  # pylint: disable=broad-except
                    # If the call failed, the thread would be destroyed. The
                    # channel object can be collected before calling the
                    # callback, which will result in a deadlock. Signal the
                    # wait group here so the main thread is not left waiting.
                    wg.done()
                    unhandled_exceptions.put(e, True)

        test_threads = []
        for perform_call in _ALL_CALL_CASES:
            test_thread = threading.Thread(target=test_call,
                                           args=(perform_call,))
            test_thread.daemon = True
            test_thread.start()
            test_threads.append(test_thread)

        # Start the server after the connections are waiting
        wg.wait()
        server = test_common.test_server(reuse_port=True)
        server.add_generic_rpc_handlers((_GenericHandler(weakref.proxy(self)),))
        server.add_insecure_port(addr)
        server.start()

        try:
            for test_thread in test_threads:
                test_thread.join()
        finally:
            # Stop the server to make the test end properly, even if joining
            # is interrupted.
            server.stop(0)

        # Surface the first exception captured by a worker thread, if any.
        if not unhandled_exceptions.empty():
            raise unhandled_exceptions.get(True)
256
257
if __name__ == '__main__':
    # When run directly: debug-level logging and verbose per-test output.
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)
261