• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
"""Tests that closing a channel cancels the RPCs active on it."""
15
import collections
import itertools
import logging
import threading
import time
import unittest

import grpc

from tests.unit import test_common
from tests.unit.framework.common import test_constants
26
27_BEAT = 0.5
28_SOME_TIME = 5
29_MORE_TIME = 10
30
31_SERVICE_NAME = ""
32_STREAM_URI = "Meffod"
33_UNARY_URI = "MeffodMan"
34
35
class _StreamingMethodHandler(grpc.RpcMethodHandler):
    """Stream-stream handler that answers each request with it doubled.

    Both (de)serialization hooks are ``None``, so the handler works directly
    on the request objects it is given; every response is ``request * 2``.
    """

    request_streaming = response_streaming = True
    request_deserializer = response_serializer = None

    def stream_stream(self, request_iterator, servicer_context):
        # Still a generator function, so the requests are consumed lazily.
        yield from (payload * 2 for payload in request_iterator)
45
46
class _UnaryMethodHandler(grpc.RpcMethodHandler):
    """Unary-unary handler whose response is the request doubled.

    Both (de)serialization hooks are ``None``, so the request object is
    used as-is and the response is ``request * 2``.
    """

    request_streaming = response_streaming = False
    request_deserializer = response_serializer = None

    def unary_unary(self, request, servicer_context):
        doubled = request * 2
        return doubled
55
56
# One shared instance per RPC shape; the handler classes hold no
# per-instance state.
_STREAMING_METHOD_HANDLER = _StreamingMethodHandler()
_UNARY_METHOD_HANDLER = _UnaryMethodHandler()

# Method name -> handler table registered on the test server.
_METHOD_HANDLERS = {
    _STREAM_URI: _STREAMING_METHOD_HANDLER,  # bidirectional streaming
    _UNARY_URI: _UNARY_METHOD_HANDLER,  # single request / single response
}
65
66
67class _Pipe(object):
68    def __init__(self, values):
69        self._condition = threading.Condition()
70        self._values = list(values)
71        self._open = True
72
73    def __iter__(self):
74        return self
75
76    def _next(self):
77        with self._condition:
78            while not self._values and self._open:
79                self._condition.wait()
80            if self._values:
81                return self._values.pop(0)
82            else:
83                raise StopIteration()
84
85    def next(self):
86        return self._next()
87
88    def __next__(self):
89        return self._next()
90
91    def add(self, value):
92        with self._condition:
93            self._values.append(value)
94            self._condition.notify()
95
96    def close(self):
97        with self._condition:
98            self._open = False
99            self._condition.notify()
100
101    def __enter__(self):
102        return self
103
104    def __exit__(self, type, value, traceback):
105        self.close()
106
107
class ChannelCloseTest(unittest.TestCase):
    """Checks that closing a channel cancels the RPCs running on it.

    Each test talks to an insecure in-process server started in ``setUp``,
    closes its channel at some point in the RPC lifecycle and asserts the
    resulting status code.
    """

    def setUp(self):
        self._server = test_common.test_server(
            max_workers=test_constants.THREAD_CONCURRENCY
        )
        # Register under the same service name the client calls use, so the
        # two sides stay in sync if _SERVICE_NAME ever changes.
        self._server.add_registered_method_handlers(
            _SERVICE_NAME, _METHOD_HANDLERS
        )
        self._port = self._server.add_insecure_port("[::]:0")
        self._server.start()

    def tearDown(self):
        self._server.stop(None)

    def _target(self):
        """Return the address of the server started in ``setUp``."""
        return "localhost:{}".format(self._port)

    @staticmethod
    def _stream_multi_callable(channel):
        """Return the registered stream-stream multi-callable on *channel*."""
        return channel.stream_stream(
            grpc._common.fully_qualified_method(_SERVICE_NAME, _STREAM_URI),
            _registered_method=True,
        )

    def test_close_immediately_after_call_invocation(self):
        """Closing right after starting a call cancels it."""
        channel = grpc.insecure_channel(self._target())
        multi_callable = self._stream_multi_callable(channel)
        request_iterator = _Pipe(())
        response_iterator = multi_callable(request_iterator)
        channel.close()
        request_iterator.close()

        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_close_while_call_active(self):
        """Closing after a response has been received cancels the call."""
        channel = grpc.insecure_channel(self._target())
        multi_callable = self._stream_multi_callable(channel)
        request_iterator = _Pipe((b"abc",))
        response_iterator = multi_callable(request_iterator)
        next(response_iterator)
        channel.close()
        request_iterator.close()

        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_context_manager_close_while_call_active(self):
        """Leaving the channel's ``with`` block cancels an active call."""
        with grpc.insecure_channel(self._target()) as channel:
            multi_callable = self._stream_multi_callable(channel)
            request_iterator = _Pipe((b"abc",))
            response_iterator = multi_callable(request_iterator)
            next(response_iterator)
        request_iterator.close()

        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_context_manager_close_while_many_calls_active(self):
        """Leaving the ``with`` block cancels every active call."""
        with grpc.insecure_channel(self._target()) as channel:
            multi_callable = self._stream_multi_callable(channel)
            request_iterators = tuple(
                _Pipe((b"abc",))
                for _ in range(test_constants.THREAD_CONCURRENCY)
            )
            response_iterators = []
            for request_iterator in request_iterators:
                response_iterator = multi_callable(request_iterator)
                next(response_iterator)
                response_iterators.append(response_iterator)
        for request_iterator in request_iterators:
            request_iterator.close()

        for response_iterator in response_iterators:
            self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_many_concurrent_closes(self):
        """Many threads closing one channel concurrently cancel its call."""
        channel = grpc.insecure_channel(self._target())
        multi_callable = self._stream_multi_callable(channel)
        request_iterator = _Pipe((b"abc",))
        response_iterator = multi_callable(request_iterator)
        next(response_iterator)
        # A monotonic clock is unaffected by wall-clock adjustments, so the
        # loop below really lasts _MORE_TIME seconds.
        end = time.monotonic() + _MORE_TIME

        def sleep_some_time_then_close():
            time.sleep(_SOME_TIME)
            channel.close()

        close_threads = [
            threading.Thread(target=sleep_some_time_then_close)
            for _ in range(test_constants.THREAD_CONCURRENCY)
        ]
        for close_thread in close_threads:
            close_thread.start()
        # Write at least one request, then keep writing until the deadline.
        # Since _SOME_TIME < _MORE_TIME, the closes land mid-stream.
        while True:
            request_iterator.add(b"def")
            time.sleep(_BEAT)
            if end < time.monotonic():
                break
        request_iterator.close()
        # Make sure every concurrent close() has returned and no closer
        # thread outlives the test.
        for close_thread in close_threads:
            close_thread.join()

        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_exception_in_callback(self):
        """An exception raised in a done-callback must not deadlock."""
        with grpc.insecure_channel(self._target()) as channel:
            stream_multi_callable = self._stream_multi_callable(channel)
            endless_iterator = itertools.repeat(b"abc")
            # Deliberately kept but never consumed; presumably it keeps a
            # streaming call in flight while the callback below raises.
            stream_response_iterator = stream_multi_callable(endless_iterator)  # pylint: disable=unused-variable
            future = channel.unary_unary(
                grpc._common.fully_qualified_method(_SERVICE_NAME, _UNARY_URI),
                _registered_method=True,
            ).future(b"abc")

            def on_done_callback(unused_future):
                raise Exception("This should not cause a deadlock.")

            future.add_done_callback(on_done_callback)
            future.result()
234
def _main():
    """Enable default logging output, then run this module's tests verbosely."""
    logging.basicConfig()
    unittest.main(verbosity=2)


if __name__ == "__main__":
    _main()
238