• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
"""Tests that closing a channel cancels the RPCs still active on it."""
15
16import itertools
17import logging
18import threading
19import time
20import unittest
21
22import grpc
23
24from tests.unit import test_common
25from tests.unit.framework.common import test_constants
26
# Durations in seconds: _BEAT and _SOME_TIME are passed to time.sleep(), and
# _MORE_TIME is added to time.time() to form a deadline.
_BEAT = 0.5
_SOME_TIME = 5
_MORE_TIME = 10

# Method names: _STREAM_URI selects the streaming handler in the generic
# handler; any other method name, including _UNARY_URI, gets the unary one.
_STREAM_URI = 'Meffod'
_UNARY_URI = 'MeffodMan'
33
34
class _StreamingMethodHandler(grpc.RpcMethodHandler):
    """Stream-stream handler that answers every request with ``request * 2``.

    No deserializer or serializer is configured, so messages are handled as
    received.
    """

    request_deserializer = None
    response_serializer = None
    request_streaming = True
    response_streaming = True

    def stream_stream(self, request_iterator, servicer_context):
        """Yields one response per incoming request, in arrival order.

        Args:
          request_iterator: Iterator over the incoming requests.
          servicer_context: The call's context; not used.

        Yields:
          Each request multiplied by two.
        """
        for incoming in request_iterator:
            doubled = incoming * 2
            yield doubled
45
46
class _UnaryMethodHandler(grpc.RpcMethodHandler):
    """Unary-unary handler that answers a request with ``request * 2``.

    No deserializer or serializer is configured, so messages are handled as
    received.
    """

    request_deserializer = None
    response_serializer = None
    request_streaming = False
    response_streaming = False

    def unary_unary(self, request, servicer_context):
        """Returns the request multiplied by two.

        Args:
          request: The single incoming request.
          servicer_context: The call's context; not used.
        """
        doubled = request * 2
        return doubled
56
57
# Shared handler instances handed out by the generic handler's service().
_UNARY_METHOD_HANDLER = _UnaryMethodHandler()
_STREAMING_METHOD_HANDLER = _StreamingMethodHandler()
60
61
class _GenericHandler(grpc.GenericRpcHandler):
    """Routes calls to the streaming or the unary handler by method name."""

    def service(self, handler_call_details):
        """Picks the method handler for an incoming call.

        Args:
          handler_call_details: Call details; only ``method`` is inspected.

        Returns:
          The streaming handler when the method equals _STREAM_URI, and the
          unary handler for every other method name.
        """
        if handler_call_details.method != _STREAM_URI:
            return _UNARY_METHOD_HANDLER
        return _STREAMING_METHOD_HANDLER
69
70
# The one generic handler instance registered with the test server in setUp().
_GENERIC_HANDLER = _GenericHandler()
72
73
class _Pipe(object):
    """Thread-safe, closable iterator that can be fed while being consumed.

    Iterating blocks while the pipe is open and empty. Values come out in the
    order they were supplied. Once the pipe is closed and no values remain,
    iteration stops. The pipe is also a context manager that closes itself on
    exit.
    """

    def __init__(self, values):
        """Creates an open pipe holding a copy of the given initial values.

        Args:
          values: Iterable of values to be yielded first.
        """
        self._condition = threading.Condition()
        self._values = list(values)
        self._open = True

    def __iter__(self):
        return self

    def _next(self):
        """Returns the oldest value, blocking while open and empty.

        Raises:
          StopIteration: If the pipe is closed and holds no more values.
        """
        with self._condition:
            # Loop rather than test once: a woken waiter must re-check the
            # state, since another consumer may have taken the value.
            while not self._values and self._open:
                self._condition.wait()
            if self._values:
                return self._values.pop(0)
            else:
                raise StopIteration()

    def next(self):
        # Python 2 iterator protocol.
        return self._next()

    def __next__(self):
        # Python 3 iterator protocol.
        return self._next()

    def add(self, value):
        """Appends a value and wakes one waiting consumer."""
        with self._condition:
            self._values.append(value)
            self._condition.notify()

    def close(self):
        """Marks the pipe closed and wakes every waiting consumer.

        All waiters are notified, not just one: each of them has to observe
        the closed state, otherwise the remaining ones would block forever.
        """
        with self._condition:
            self._open = False
            self._condition.notify_all()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Always close, whether or not the with-block raised; returning None
        # lets any exception propagate.
        self.close()
114
115
116class ChannelCloseTest(unittest.TestCase):
117
118    def setUp(self):
119        self._server = test_common.test_server(
120            max_workers=test_constants.THREAD_CONCURRENCY)
121        self._server.add_generic_rpc_handlers((_GENERIC_HANDLER,))
122        self._port = self._server.add_insecure_port('[::]:0')
123        self._server.start()
124
125    def tearDown(self):
126        self._server.stop(None)
127
128    def test_close_immediately_after_call_invocation(self):
129        channel = grpc.insecure_channel('localhost:{}'.format(self._port))
130        multi_callable = channel.stream_stream(_STREAM_URI)
131        request_iterator = _Pipe(())
132        response_iterator = multi_callable(request_iterator)
133        channel.close()
134        request_iterator.close()
135
136        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)
137
138    def test_close_while_call_active(self):
139        channel = grpc.insecure_channel('localhost:{}'.format(self._port))
140        multi_callable = channel.stream_stream(_STREAM_URI)
141        request_iterator = _Pipe((b'abc',))
142        response_iterator = multi_callable(request_iterator)
143        next(response_iterator)
144        channel.close()
145        request_iterator.close()
146
147        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)
148
149    def test_context_manager_close_while_call_active(self):
150        with grpc.insecure_channel('localhost:{}'.format(
151                self._port)) as channel:  # pylint: disable=bad-continuation
152            multi_callable = channel.stream_stream(_STREAM_URI)
153            request_iterator = _Pipe((b'abc',))
154            response_iterator = multi_callable(request_iterator)
155            next(response_iterator)
156        request_iterator.close()
157
158        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)
159
160    def test_context_manager_close_while_many_calls_active(self):
161        with grpc.insecure_channel('localhost:{}'.format(
162                self._port)) as channel:  # pylint: disable=bad-continuation
163            multi_callable = channel.stream_stream(_STREAM_URI)
164            request_iterators = tuple(
165                _Pipe((b'abc',))
166                for _ in range(test_constants.THREAD_CONCURRENCY))
167            response_iterators = []
168            for request_iterator in request_iterators:
169                response_iterator = multi_callable(request_iterator)
170                next(response_iterator)
171                response_iterators.append(response_iterator)
172        for request_iterator in request_iterators:
173            request_iterator.close()
174
175        for response_iterator in response_iterators:
176            self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)
177
178    def test_many_concurrent_closes(self):
179        channel = grpc.insecure_channel('localhost:{}'.format(self._port))
180        multi_callable = channel.stream_stream(_STREAM_URI)
181        request_iterator = _Pipe((b'abc',))
182        response_iterator = multi_callable(request_iterator)
183        next(response_iterator)
184        start = time.time()
185        end = start + _MORE_TIME
186
187        def sleep_some_time_then_close():
188            time.sleep(_SOME_TIME)
189            channel.close()
190
191        for _ in range(test_constants.THREAD_CONCURRENCY):
192            close_thread = threading.Thread(target=sleep_some_time_then_close)
193            close_thread.start()
194        while True:
195            request_iterator.add(b'def')
196            time.sleep(_BEAT)
197            if end < time.time():
198                break
199        request_iterator.close()
200
201        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)
202
203    def test_exception_in_callback(self):
204        with grpc.insecure_channel('localhost:{}'.format(
205                self._port)) as channel:
206            stream_multi_callable = channel.stream_stream(_STREAM_URI)
207            endless_iterator = itertools.repeat(b'abc')
208            stream_response_iterator = stream_multi_callable(endless_iterator)
209            future = channel.unary_unary(_UNARY_URI).future(b'abc')
210
211            def on_done_callback(future):
212                raise Exception("This should not cause a deadlock.")
213
214            future.add_done_callback(on_done_callback)
215            future.result()
216
217
if __name__ == '__main__':
    # Set up default logging, then run this module's tests verbosely.
    logging.basicConfig()
    unittest.main(verbosity=2)
221