# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests that closing a channel cancels the RPCs active on it."""

import itertools
import logging
import threading
import time
import unittest

import grpc

from tests.unit import test_common
from tests.unit.framework.common import test_constants

# Seconds to sleep between requests written in test_many_concurrent_closes.
_BEAT = 0.5
# Seconds each closer thread in test_many_concurrent_closes waits before
# closing the channel.
_SOME_TIME = 5
# Total seconds test_many_concurrent_closes keeps writing requests; longer
# than _SOME_TIME so the closes happen while the stream is still being fed.
_MORE_TIME = 10

_STREAM_URI = 'Meffod'
_UNARY_URI = 'MeffodMan'


class _StreamingMethodHandler(grpc.RpcMethodHandler):
    """Stream-stream handler that responds to each request with it doubled."""

    request_streaming = True
    response_streaming = True
    # No (de)serializers: the handler works directly on the request bytes.
    request_deserializer = None
    response_serializer = None

    def stream_stream(self, request_iterator, servicer_context):
        for request in request_iterator:
            yield request * 2


class _UnaryMethodHandler(grpc.RpcMethodHandler):
    """Unary-unary handler that responds with the request doubled."""

    request_streaming = False
    response_streaming = False
    # No (de)serializers: the handler works directly on the request bytes.
    request_deserializer = None
    response_serializer = None

    def unary_unary(self, request, servicer_context):
        return request * 2


_STREAMING_METHOD_HANDLER = _StreamingMethodHandler()
_UNARY_METHOD_HANDLER = _UnaryMethodHandler()


class _GenericHandler(grpc.GenericRpcHandler):
    """Routes _STREAM_URI to the streaming handler, everything else to unary."""

    def service(self, handler_call_details):
        if handler_call_details.method == _STREAM_URI:
            return _STREAMING_METHOD_HANDLER
        else:
            return _UNARY_METHOD_HANDLER


_GENERIC_HANDLER = _GenericHandler()


class _Pipe(object):
    """A blocking request iterator fed by add() and terminated by close().

    Iteration blocks until a value is available or the pipe is closed.
    Values added before close() are still yielded; only once the pipe is
    both closed and empty does iteration stop.
    """

    def __init__(self, values):
        self._condition = threading.Condition()
        self._values = list(values)
        self._open = True

    def __iter__(self):
        return self

    def _next(self):
        with self._condition:
            while not self._values and self._open:
                self._condition.wait()
            if self._values:
                return self._values.pop(0)
            else:
                raise StopIteration()

    # Python 2 iterator protocol.
    def next(self):
        return self._next()

    # Python 3 iterator protocol.
    def __next__(self):
        return self._next()

    def add(self, value):
        """Appends a value and wakes a waiting consumer."""
        with self._condition:
            self._values.append(value)
            self._condition.notify()

    def close(self):
        """Marks the pipe closed and wakes every waiter so each sees it."""
        with self._condition:
            self._open = False
            self._condition.notify_all()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Returns None, so exceptions raised in the with-body propagate.
        self.close()


class ChannelCloseTest(unittest.TestCase):
    """Checks that closing a channel cancels the calls active on it."""

    def setUp(self):
        self._server = test_common.test_server(
            max_workers=test_constants.THREAD_CONCURRENCY)
        self._server.add_generic_rpc_handlers((_GENERIC_HANDLER,))
        self._port = self._server.add_insecure_port('[::]:0')
        self._server.start()

    def tearDown(self):
        self._server.stop(None)

    def test_close_immediately_after_call_invocation(self):
        channel = grpc.insecure_channel('localhost:{}'.format(self._port))
        multi_callable = channel.stream_stream(_STREAM_URI)
        request_iterator = _Pipe(())
        response_iterator = multi_callable(request_iterator)
        channel.close()
        request_iterator.close()

        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_close_while_call_active(self):
        channel = grpc.insecure_channel('localhost:{}'.format(self._port))
        multi_callable = channel.stream_stream(_STREAM_URI)
        request_iterator = _Pipe((b'abc',))
        response_iterator = multi_callable(request_iterator)
        # Receive one response so the call is known to be in progress.
        next(response_iterator)
        channel.close()
        request_iterator.close()

        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_context_manager_close_while_call_active(self):
        with grpc.insecure_channel('localhost:{}'.format(
                self._port)) as channel:  # pylint: disable=bad-continuation
            multi_callable = channel.stream_stream(_STREAM_URI)
            request_iterator = _Pipe((b'abc',))
            response_iterator = multi_callable(request_iterator)
            next(response_iterator)
        # Leaving the with-block closed the channel while the call was open.
        request_iterator.close()

        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_context_manager_close_while_many_calls_active(self):
        with grpc.insecure_channel('localhost:{}'.format(
                self._port)) as channel:  # pylint: disable=bad-continuation
            multi_callable = channel.stream_stream(_STREAM_URI)
            request_iterators = tuple(
                _Pipe((b'abc',))
                for _ in range(test_constants.THREAD_CONCURRENCY))
            response_iterators = []
            for request_iterator in request_iterators:
                response_iterator = multi_callable(request_iterator)
                next(response_iterator)
                response_iterators.append(response_iterator)
        for request_iterator in request_iterators:
            request_iterator.close()

        for response_iterator in response_iterators:
            self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_many_concurrent_closes(self):
        channel = grpc.insecure_channel('localhost:{}'.format(self._port))
        multi_callable = channel.stream_stream(_STREAM_URI)
        request_iterator = _Pipe((b'abc',))
        response_iterator = multi_callable(request_iterator)
        next(response_iterator)
        end = time.time() + _MORE_TIME

        def sleep_some_time_then_close():
            time.sleep(_SOME_TIME)
            channel.close()

        close_threads = []
        for _ in range(test_constants.THREAD_CONCURRENCY):
            close_thread = threading.Thread(target=sleep_some_time_then_close)
            close_thread.start()
            close_threads.append(close_thread)
        # Keep feeding the stream until after the closer threads have fired.
        while True:
            request_iterator.add(b'def')
            time.sleep(_BEAT)
            if end < time.time():
                break
        request_iterator.close()
        # Don't let any closer thread outlive the test.
        for close_thread in close_threads:
            close_thread.join()

        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_exception_in_callback(self):
        with grpc.insecure_channel('localhost:{}'.format(
                self._port)) as channel:  # pylint: disable=bad-continuation
            stream_multi_callable = channel.stream_stream(_STREAM_URI)
            endless_iterator = itertools.repeat(b'abc')
            # Deliberately unused: holds a streaming call open on the channel
            # while the unary future below completes.
            stream_response_iterator = stream_multi_callable(endless_iterator)
            future = channel.unary_unary(_UNARY_URI).future(b'abc')

            def on_done_callback(future):
                raise Exception("This should not cause a deadlock.")

            future.add_done_callback(on_done_callback)
            # Must return even though the done-callback raised.
            future.result()


if __name__ == '__main__':
    logging.basicConfig()
    unittest.main(verbosity=2)