• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests closing a channel while RPCs are still in flight on it."""
15
import collections
import itertools
import logging
import threading
import time
import unittest

import grpc

from tests.unit import test_common
from tests.unit.framework.common import test_constants
26
# Timing knobs, in seconds, used by test_many_concurrent_closes.
# Pause between two requests fed into the active stream.
_BEAT = 0.5
# How long each closer thread sleeps before calling channel.close().
_SOME_TIME = 5
# How long the request-feeding loop keeps running; larger than _SOME_TIME so
# the concurrent closes happen while requests are still being added.
_MORE_TIME = 10

# Method name that _GenericHandler routes to the streaming handler.
_STREAM_URI = "Meffod"
# Method name used for the unary call; _GenericHandler serves every method
# other than _STREAM_URI with the unary handler.
_UNARY_URI = "MeffodMan"
33
34
class _StreamingMethodHandler(grpc.RpcMethodHandler):
    """Stream-stream handler that answers every request with it doubled."""

    request_streaming = True
    response_streaming = True
    # No deserializer/serializer is configured for this handler.
    request_deserializer = None
    response_serializer = None

    def stream_stream(self, request_iterator, servicer_context):
        """Yield ``request * 2`` for each request read from the stream.

        Args:
            request_iterator: Iterator over the incoming requests. The tests
                in this file send ``bytes`` payloads such as ``b"abc"``.
            servicer_context: Call context; not used.

        Yields:
            One response per request, equal to the request repeated twice.
        """
        for request in request_iterator:
            yield request * 2
44
45
class _UnaryMethodHandler(grpc.RpcMethodHandler):
    """Unary-unary handler that answers the request with it doubled."""

    request_streaming = False
    response_streaming = False
    # No deserializer/serializer is configured for this handler.
    request_deserializer = None
    response_serializer = None

    def unary_unary(self, request, servicer_context):
        """Return ``request * 2``.

        Args:
            request: The single incoming request. The test in this file
                sends the ``bytes`` payload ``b"abc"``.
            servicer_context: Call context; not used.

        Returns:
            The request repeated twice.
        """
        return request * 2
54
55
# Shared handler instances returned by _GenericHandler.service.
_STREAMING_METHOD_HANDLER = _StreamingMethodHandler()
_UNARY_METHOD_HANDLER = _UnaryMethodHandler()
58
59
class _GenericHandler(grpc.GenericRpcHandler):
    """Routes ``_STREAM_URI`` to the streaming handler, the rest to unary."""

    def service(self, handler_call_details):
        """Pick the method handler for an incoming call.

        Args:
            handler_call_details: Call details; only ``method`` is read.

        Returns:
            ``_STREAMING_METHOD_HANDLER`` when the method equals
            ``_STREAM_URI``; ``_UNARY_METHOD_HANDLER`` for any other method.
        """
        if handler_call_details.method == _STREAM_URI:
            return _STREAMING_METHOD_HANDLER
        return _UNARY_METHOD_HANDLER
66
67
# Single handler instance registered with the test server in setUp.
_GENERIC_HANDLER = _GenericHandler()
69
70
71class _Pipe(object):
72    def __init__(self, values):
73        self._condition = threading.Condition()
74        self._values = list(values)
75        self._open = True
76
77    def __iter__(self):
78        return self
79
80    def _next(self):
81        with self._condition:
82            while not self._values and self._open:
83                self._condition.wait()
84            if self._values:
85                return self._values.pop(0)
86            else:
87                raise StopIteration()
88
89    def next(self):
90        return self._next()
91
92    def __next__(self):
93        return self._next()
94
95    def add(self, value):
96        with self._condition:
97            self._values.append(value)
98            self._condition.notify()
99
100    def close(self):
101        with self._condition:
102            self._open = False
103            self._condition.notify()
104
105    def __enter__(self):
106        return self
107
108    def __exit__(self, type, value, traceback):
109        self.close()
110
111
class ChannelCloseTest(unittest.TestCase):
    """Checks that closing a channel cancels the calls still active on it."""

    def setUp(self):
        """Start a local insecure server that serves _GENERIC_HANDLER."""
        self._server = test_common.test_server(
            max_workers=test_constants.THREAD_CONCURRENCY
        )
        self._server.add_generic_rpc_handlers((_GENERIC_HANDLER,))
        self._port = self._server.add_insecure_port("[::]:0")
        self._server.start()

    def tearDown(self):
        """Stop the server without a grace period."""
        self._server.stop(None)

    def _target(self):
        """Return the address of the server started in setUp."""
        return "localhost:{}".format(self._port)

    @staticmethod
    def _stream_stream(channel):
        """Return the stream-stream multi-callable for _STREAM_URI."""
        return channel.stream_stream(_STREAM_URI, _registered_method=True)

    def test_close_immediately_after_call_invocation(self):
        """Closing right after invoking the call cancels it."""
        channel = grpc.insecure_channel(self._target())
        request_iterator = _Pipe(())
        response_iterator = self._stream_stream(channel)(request_iterator)
        channel.close()
        request_iterator.close()

        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_close_while_call_active(self):
        """Closing after the first response was read cancels the call."""
        channel = grpc.insecure_channel(self._target())
        request_iterator = _Pipe((b"abc",))
        response_iterator = self._stream_stream(channel)(request_iterator)
        # Reading one response proves the call is running before the close.
        next(response_iterator)
        channel.close()
        request_iterator.close()

        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_context_manager_close_while_call_active(self):
        """Leaving the channel's ``with`` block cancels the active call."""
        with grpc.insecure_channel(self._target()) as channel:
            request_iterator = _Pipe((b"abc",))
            response_iterator = self._stream_stream(channel)(request_iterator)
            next(response_iterator)
        request_iterator.close()

        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_context_manager_close_while_many_calls_active(self):
        """Leaving the ``with`` block cancels every one of many calls."""
        with grpc.insecure_channel(self._target()) as channel:
            multi_callable = self._stream_stream(channel)
            request_iterators = tuple(
                _Pipe((b"abc",))
                for _ in range(test_constants.THREAD_CONCURRENCY)
            )
            response_iterators = []
            for request_iterator in request_iterators:
                response_iterator = multi_callable(request_iterator)
                next(response_iterator)
                response_iterators.append(response_iterator)
        for request_iterator in request_iterators:
            request_iterator.close()

        for response_iterator in response_iterators:
            self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_many_concurrent_closes(self):
        """Many threads closing the same channel still cancel the call."""
        channel = grpc.insecure_channel(self._target())
        request_iterator = _Pipe((b"abc",))
        response_iterator = self._stream_stream(channel)(request_iterator)
        next(response_iterator)
        # A monotonic clock keeps the loop duration immune to wall-clock
        # adjustments made while the test is running.
        deadline = time.monotonic() + _MORE_TIME

        def sleep_some_time_then_close():
            time.sleep(_SOME_TIME)
            channel.close()

        close_threads = [
            threading.Thread(target=sleep_some_time_then_close)
            for _ in range(test_constants.THREAD_CONCURRENCY)
        ]
        for close_thread in close_threads:
            close_thread.start()
        # Keep feeding requests while the closer threads fire; the body runs
        # at least once before the deadline is checked.
        while True:
            request_iterator.add(b"def")
            time.sleep(_BEAT)
            if deadline < time.monotonic():
                break
        request_iterator.close()
        # Join so that no closer thread outlives this test.
        for close_thread in close_threads:
            close_thread.join()

        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_exception_in_callback(self):
        """A done-callback that raises must not deadlock the channel."""
        with grpc.insecure_channel(self._target()) as channel:
            stream_multi_callable = self._stream_stream(channel)
            endless_iterator = itertools.repeat(b"abc")
            # NOTE(review): the response iterator is never read; presumably
            # it is bound to a name so a streaming call stays active while
            # the unary future completes and the channel closes - confirm.
            stream_response_iterator = stream_multi_callable(endless_iterator)
            future = channel.unary_unary(
                _UNARY_URI,
                _registered_method=True,
            ).future(b"abc")

            def on_done_callback(future):
                raise Exception("This should not cause a deadlock.")

            future.add_done_callback(on_done_callback)
            future.result()
237
238
if __name__ == "__main__":
    # Set up default logging, then run the tests with verbose output.
    logging.basicConfig()
    unittest.main(verbosity=2)
242