• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests metadata flags feature by testing wait-for-ready semantics"""
15
16import logging
17import queue
18import socket
19import threading
20import time
21import unittest
22import weakref
23
24import grpc
25
26from tests.unit import test_common
27import tests.unit.framework.common
28from tests.unit.framework.common import get_socket
29from tests.unit.framework.common import test_constants
30
# Service and method names under which the test handlers are registered.
_SERVICE_NAME = "test"
_UNARY_UNARY = "UnaryUnary"
_UNARY_STREAM = "UnaryStream"
_STREAM_UNARY = "StreamUnary"
_STREAM_STREAM = "StreamStream"

# Three zero bytes: sent by the client helpers (_REQUEST) and returned by
# every server handler (_RESPONSE).
_REQUEST = b"\x00" * 3
_RESPONSE = b"\x00" * 3
39
40
def handle_unary_unary(test, request, servicer_context):
    """Unary-unary server behaviour: ignore the request and reply with _RESPONSE.

    ``test`` and ``servicer_context`` are accepted only to match the shared
    handler signature; neither is used.
    """
    del test, request, servicer_context  # unused
    return _RESPONSE
43
44
def handle_unary_stream(test, request, servicer_context):
    """Unary-stream server behaviour: yield ``STREAM_LENGTH`` copies of _RESPONSE.

    The request, ``test`` and ``servicer_context`` are ignored.
    """
    stream_length = test_constants.STREAM_LENGTH
    for _index in range(stream_length):
        yield _RESPONSE
48
49
def handle_stream_unary(test, request_iterator, servicer_context):
    """Stream-unary server behaviour: drain every request, then reply once.

    Request contents are ignored; the single reply is _RESPONSE.
    """
    for _request in request_iterator:
        continue
    return _RESPONSE
54
55
def handle_stream_stream(test, request_iterator, servicer_context):
    """Stream-stream server behaviour: yield one _RESPONSE per request received."""
    yield from (_RESPONSE for _request in request_iterator)
59
60
class _MethodHandler(grpc.RpcMethodHandler):
    """Method handler that delegates to one of the module-level handle_* functions.

    Exactly one of ``unary_unary``, ``unary_stream``, ``stream_unary`` and
    ``stream_stream`` is populated, selected by the two streaming flags; the
    others stay ``None``. No request deserializer or response serializer is
    configured.
    """

    def __init__(self, test, request_streaming, response_streaming):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = None
        self.response_serializer = None
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None

        # Each behaviour closes over `test` and exposes the two-argument
        # (request, context) shape expected of an RpcMethodHandler behaviour.
        if request_streaming:
            if response_streaming:
                self.stream_stream = lambda request_iterator, context: (
                    handle_stream_stream(test, request_iterator, context)
                )
            else:
                self.stream_unary = lambda request_iterator, context: (
                    handle_stream_unary(test, request_iterator, context)
                )
        elif response_streaming:
            self.unary_stream = lambda request, context: (
                handle_unary_stream(test, request, context)
            )
        else:
            self.unary_unary = lambda request, context: (
                handle_unary_unary(test, request, context)
            )
87
88
def get_method_handlers(test):
    """Return the method-name -> _MethodHandler mapping for the test service.

    One handler is created per RPC cardinality, in the order
    unary-unary, unary-stream, stream-unary, stream-stream.
    """
    cardinalities = (
        (_UNARY_UNARY, False, False),
        (_UNARY_STREAM, False, True),
        (_STREAM_UNARY, True, False),
        (_STREAM_STREAM, True, True),
    )
    return {
        method: _MethodHandler(test, request_streaming, response_streaming)
        for method, request_streaming, response_streaming in cardinalities
    }
96
97
def create_phony_channel():
    """Return an insecure channel aimed at a port with no server behind it.

    A port is obtained from ``get_socket`` and the socket is closed right
    away, so nothing is listening at the channel's target. (Originally noted
    as a workaround for retries.)
    """
    host, port, reserved = get_socket(sock_options=(socket.SO_REUSEADDR,))
    reserved.close()
    target = "{}:{}".format(host, port)
    return grpc.insecure_channel(target)
103
104
def perform_unary_unary_call(channel, wait_for_ready=None):
    """Issue one blocking unary-unary RPC on *channel*; the response is dropped.

    Any exception raised by the call propagates to the caller.
    """
    method = grpc._common.fully_qualified_method(_SERVICE_NAME, _UNARY_UNARY)
    stub = channel.unary_unary(method, _registered_method=True)
    stub(
        _REQUEST,
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready,
    )
114
115
def perform_unary_unary_with_call(channel, wait_for_ready=None):
    """Issue one unary-unary RPC via ``with_call``; the result is dropped.

    Any exception raised by the call propagates to the caller.
    """
    method = grpc._common.fully_qualified_method(_SERVICE_NAME, _UNARY_UNARY)
    stub = channel.unary_unary(method, _registered_method=True)
    stub.with_call(
        _REQUEST,
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready,
    )
125
126
def perform_unary_unary_future(channel, wait_for_ready=None):
    """Issue one unary-unary RPC via ``future`` and wait for its result.

    Both the RPC deadline and the wait on ``result`` use LONG_TIMEOUT; the
    result value is dropped.
    """
    method = grpc._common.fully_qualified_method(_SERVICE_NAME, _UNARY_UNARY)
    stub = channel.unary_unary(method, _registered_method=True)
    pending = stub.future(
        _REQUEST,
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready,
    )
    pending.result(timeout=test_constants.LONG_TIMEOUT)
138
139
def perform_unary_stream_call(channel, wait_for_ready=None):
    """Issue a unary-stream RPC on *channel* and iterate the whole response stream.

    Individual responses are discarded.
    """
    method = grpc._common.fully_qualified_method(_SERVICE_NAME, _UNARY_STREAM)
    stub = channel.unary_stream(method, _registered_method=True)
    responses = stub(
        _REQUEST,
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready,
    )
    for _response in responses:
        continue
151
152
def perform_stream_unary_call(channel, wait_for_ready=None):
    """Send ``STREAM_LENGTH`` copies of _REQUEST in one blocking stream-unary RPC.

    The response is dropped; any exception propagates to the caller.
    """
    method = grpc._common.fully_qualified_method(_SERVICE_NAME, _STREAM_UNARY)
    stub = channel.stream_unary(method, _registered_method=True)
    requests = iter([_REQUEST] * test_constants.STREAM_LENGTH)
    stub(
        requests,
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready,
    )
162
163
def perform_stream_unary_with_call(channel, wait_for_ready=None):
    """Send ``STREAM_LENGTH`` copies of _REQUEST via stream-unary ``with_call``.

    The result is dropped; any exception propagates to the caller.
    """
    method = grpc._common.fully_qualified_method(_SERVICE_NAME, _STREAM_UNARY)
    stub = channel.stream_unary(method, _registered_method=True)
    requests = iter([_REQUEST] * test_constants.STREAM_LENGTH)
    stub.with_call(
        requests,
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready,
    )
173
174
def perform_stream_unary_future(channel, wait_for_ready=None):
    """Send ``STREAM_LENGTH`` copies of _REQUEST via stream-unary ``future``.

    Waits up to LONG_TIMEOUT on the future's result, which is dropped.
    """
    method = grpc._common.fully_qualified_method(_SERVICE_NAME, _STREAM_UNARY)
    stub = channel.stream_unary(method, _registered_method=True)
    requests = iter([_REQUEST] * test_constants.STREAM_LENGTH)
    pending = stub.future(
        requests,
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready,
    )
    pending.result(timeout=test_constants.LONG_TIMEOUT)
186
187
def perform_stream_stream_call(channel, wait_for_ready=None):
    """Send ``STREAM_LENGTH`` requests on a stream-stream RPC and drain the replies.

    Individual responses are discarded.
    """
    method = grpc._common.fully_qualified_method(_SERVICE_NAME, _STREAM_STREAM)
    stub = channel.stream_stream(method, _registered_method=True)
    requests = iter([_REQUEST] * test_constants.STREAM_LENGTH)
    responses = stub(
        requests,
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready,
    )
    for _response in responses:
        continue
199
200
# Every client-side call helper exercised by MetadataFlagsTest, covering all
# four RPC cardinalities and the call / with_call / future invocation forms.
_ALL_CALL_CASES = [
    # unary request, unary response
    perform_unary_unary_call,
    perform_unary_unary_with_call,
    perform_unary_unary_future,
    # unary request, streamed response
    perform_unary_stream_call,
    # streamed request, unary response
    perform_stream_unary_call,
    perform_stream_unary_with_call,
    perform_stream_unary_future,
    # streamed request, streamed response
    perform_stream_stream_call,
]
211
212
class MetadataFlagsTest(unittest.TestCase):
    """Checks wait_for_ready handling for every helper in _ALL_CALL_CASES."""

    def check_connection_does_failfast(self, fn, channel, wait_for_ready=None):
        """Assert that ``fn(channel, wait_for_ready)`` fails with UNAVAILABLE.

        Args:
          fn: A perform_* helper taking ``(channel, wait_for_ready)``.
          channel: Channel to issue the RPC on; the tests in this class pass
            one created by create_phony_channel().
          wait_for_ready: Forwarded unchanged to ``fn``.
        """
        # Catch only grpc.RpcError. A broad `except BaseException` would also
        # catch the AssertionError that reports an unexpectedly successful
        # call. That failure would then be hidden behind an AttributeError
        # on `.code()`.
        with self.assertRaises(
            grpc.RpcError, msg="The Call should fail"
        ) as raised:
            fn(channel, wait_for_ready)
        self.assertIs(grpc.StatusCode.UNAVAILABLE, raised.exception.code())

    def test_call_wait_for_ready_default(self):
        """With wait_for_ready left unset, every call form fails fast."""
        for perform_call in _ALL_CALL_CASES:
            with create_phony_channel() as channel:
                self.check_connection_does_failfast(perform_call, channel)

    def test_call_wait_for_ready_disabled(self):
        """With wait_for_ready=False, every call form fails fast."""
        for perform_call in _ALL_CALL_CASES:
            with create_phony_channel() as channel:
                self.check_connection_does_failfast(
                    perform_call, channel, wait_for_ready=False
                )

    def test_call_wait_for_ready_enabled(self):
        """With wait_for_ready=True, calls wait for a server started later.

        Each call form runs in its own thread against an address with no
        server. Once every channel has reported TRANSIENT_FAILURE (or its
        call has already failed), the server is started. The threads are
        then joined and any exception they recorded is re-raised here.
        """
        # Worker threads do not propagate exceptions to the main thread, so
        # they are collected here and re-raised after all threads finish.
        unhandled_exceptions = queue.Queue()

        # Only an unused TCP port is needed; close the socket so nothing is
        # listening on it until the server starts.
        host, port, sock = get_socket(sock_options=(socket.SO_REUSEADDR,))
        sock.close()

        addr = "{}:{}".format(host, port)
        wg = test_common.WaitGroup(len(_ALL_CALL_CASES))

        def wait_for_transient_failure(channel_connectivity):
            # NOTE(review): this may fire on each entry into
            # TRANSIENT_FAILURE; confirm extra wg.done() calls are harmless.
            if (
                channel_connectivity
                == grpc.ChannelConnectivity.TRANSIENT_FAILURE
            ):
                wg.done()

        def test_call(perform_call):
            with grpc.insecure_channel(addr) as channel:
                try:
                    channel.subscribe(wait_for_transient_failure)
                    perform_call(channel, wait_for_ready=True)
                except BaseException as e:  # pylint: disable=broad-except
                    # The call failed and this thread is about to exit. The
                    # channel may be collected before the connectivity
                    # callback runs, so signal the WaitGroup here. Otherwise
                    # wg.wait() below would deadlock.
                    wg.done()
                    unhandled_exceptions.put(e, True)

        test_threads = []
        for perform_call in _ALL_CALL_CASES:
            test_thread = threading.Thread(
                target=test_call, args=(perform_call,)
            )
            test_thread.daemon = True
            test_thread.start()
            test_threads.append(test_thread)

        # Start the server only after every connection is waiting.
        wg.wait()
        server = test_common.test_server(reuse_port=True)
        try:
            server.add_registered_method_handlers(
                _SERVICE_NAME, get_method_handlers(weakref.proxy(self))
            )
            server.add_insecure_port(addr)
            server.start()

            for test_thread in test_threads:
                test_thread.join()
        finally:
            # Stop the server even if setup or a join fails, so it does not
            # outlive the test.
            server.stop(0)

        if not unhandled_exceptions.empty():
            raise unhandled_exceptions.get(True)
294
295
if __name__ == "__main__":
    # Running this file directly: enable DEBUG logging and run the suite
    # with verbose per-test output.
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)
299