# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests that closing a channel cancels the RPCs still in flight on it."""

import itertools
import logging
import threading
import time
import unittest

import grpc

from tests.unit import test_common
from tests.unit.framework.common import test_constants

# Interval between requests fed into a still-open stream.
_BEAT = 0.5
# Delay before each background thread closes the channel.
_SOME_TIME = 5
# Total time requests keep being fed; longer than _SOME_TIME so the closes
# happen while the stream is still being written to.
_MORE_TIME = 10

_STREAM_URI = "Meffod"
_UNARY_URI = "MeffodMan"


class _StreamingMethodHandler(grpc.RpcMethodHandler):
    """Stream-stream handler yielding each request repeated twice."""

    request_streaming = True
    response_streaming = True
    request_deserializer = None
    response_serializer = None

    def stream_stream(self, request_iterator, servicer_context):
        for request in request_iterator:
            yield request * 2


class _UnaryMethodHandler(grpc.RpcMethodHandler):
    """Unary-unary handler returning the request repeated twice."""

    request_streaming = False
    response_streaming = False
    request_deserializer = None
    response_serializer = None

    def unary_unary(self, request, servicer_context):
        return request * 2


_STREAMING_METHOD_HANDLER = _StreamingMethodHandler()
_UNARY_METHOD_HANDLER = _UnaryMethodHandler()


class _GenericHandler(grpc.GenericRpcHandler):
    """Routes _STREAM_URI to the streaming handler, all else to unary."""

    def service(self, handler_call_details):
        if handler_call_details.method == _STREAM_URI:
            return _STREAMING_METHOD_HANDLER
        else:
            return _UNARY_METHOD_HANDLER


_GENERIC_HANDLER = _GenericHandler()


class _Pipe(object):
    """A thread-safe, closable request iterator.

    Iteration blocks until a value is added or the pipe is closed; once the
    pipe is closed and all queued values are consumed, iteration stops.
    """

    def __init__(self, values):
        self._condition = threading.Condition()
        self._values = list(values)
        self._open = True

    def __iter__(self):
        return self

    def _next(self):
        with self._condition:
            while not self._values and self._open:
                self._condition.wait()
            if self._values:
                return self._values.pop(0)
            else:
                raise StopIteration()

    def next(self):
        return self._next()

    def __next__(self):
        return self._next()

    def add(self, value):
        """Queues ``value`` and wakes one blocked consumer."""
        with self._condition:
            self._values.append(value)
            self._condition.notify()

    def close(self):
        """Marks the pipe closed; iteration ends once the queue drains."""
        with self._condition:
            self._open = False
            # Wake every blocked consumer, not just one: after close no
            # further notify() will arrive, so any waiter left asleep here
            # would block forever.
            self._condition.notify_all()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.close()


class ChannelCloseTest(unittest.TestCase):
    """Closing a channel must leave its in-flight calls CANCELLED."""

    def setUp(self):
        self._server = test_common.test_server(
            max_workers=test_constants.THREAD_CONCURRENCY
        )
        self._server.add_generic_rpc_handlers((_GENERIC_HANDLER,))
        self._port = self._server.add_insecure_port("[::]:0")
        self._server.start()

    def tearDown(self):
        self._server.stop(None)

    def test_close_immediately_after_call_invocation(self):
        channel = grpc.insecure_channel("localhost:{}".format(self._port))
        multi_callable = channel.stream_stream(
            _STREAM_URI,
            _registered_method=True,
        )
        request_iterator = _Pipe(())
        response_iterator = multi_callable(request_iterator)
        channel.close()
        request_iterator.close()

        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_close_while_call_active(self):
        channel = grpc.insecure_channel("localhost:{}".format(self._port))
        multi_callable = channel.stream_stream(
            _STREAM_URI,
            _registered_method=True,
        )
        request_iterator = _Pipe((b"abc",))
        response_iterator = multi_callable(request_iterator)
        # Receive one response so the call is known to be active.
        next(response_iterator)
        channel.close()
        request_iterator.close()

        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_context_manager_close_while_call_active(self):
        with grpc.insecure_channel(
            "localhost:{}".format(self._port)
        ) as channel:  # pylint: disable=bad-continuation
            multi_callable = channel.stream_stream(
                _STREAM_URI,
                _registered_method=True,
            )
            request_iterator = _Pipe((b"abc",))
            response_iterator = multi_callable(request_iterator)
            next(response_iterator)
        # Leaving the ``with`` block closed the channel.
        request_iterator.close()

        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_context_manager_close_while_many_calls_active(self):
        with grpc.insecure_channel(
            "localhost:{}".format(self._port)
        ) as channel:  # pylint: disable=bad-continuation
            multi_callable = channel.stream_stream(
                _STREAM_URI,
                _registered_method=True,
            )
            request_iterators = tuple(
                _Pipe((b"abc",))
                for _ in range(test_constants.THREAD_CONCURRENCY)
            )
            response_iterators = []
            for request_iterator in request_iterators:
                response_iterator = multi_callable(request_iterator)
                next(response_iterator)
                response_iterators.append(response_iterator)
        for request_iterator in request_iterators:
            request_iterator.close()

        for response_iterator in response_iterators:
            self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_many_concurrent_closes(self):
        channel = grpc.insecure_channel("localhost:{}".format(self._port))
        multi_callable = channel.stream_stream(
            _STREAM_URI,
            _registered_method=True,
        )
        request_iterator = _Pipe((b"abc",))
        response_iterator = multi_callable(request_iterator)
        next(response_iterator)
        # Monotonic clock so wall-clock adjustments cannot shorten or
        # stretch the feeding window.
        deadline = time.monotonic() + _MORE_TIME

        def sleep_some_time_then_close():
            time.sleep(_SOME_TIME)
            channel.close()

        close_threads = []
        for _ in range(test_constants.THREAD_CONCURRENCY):
            close_thread = threading.Thread(target=sleep_some_time_then_close)
            close_thread.start()
            close_threads.append(close_thread)
        # Keep writing to the stream while the close threads fire.
        while time.monotonic() <= deadline:
            request_iterator.add(b"def")
            time.sleep(_BEAT)
        request_iterator.close()
        # Join the closers so none outlives the test and races tearDown.
        for close_thread in close_threads:
            close_thread.join()

        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_exception_in_callback(self):
        """A raising done-callback must not deadlock channel close."""
        with grpc.insecure_channel(
            "localhost:{}".format(self._port)
        ) as channel:
            stream_multi_callable = channel.stream_stream(
                _STREAM_URI,
                _registered_method=True,
            )
            # NOTE(review): presumably keeps an endless streaming call active
            # on the channel when it is closed on exiting the ``with`` block.
            endless_iterator = itertools.repeat(b"abc")
            stream_response_iterator = stream_multi_callable(endless_iterator)
            future = channel.unary_unary(
                _UNARY_URI,
                _registered_method=True,
            ).future(b"abc")

            def on_done_callback(future):
                raise Exception("This should not cause a deadlock.")

            future.add_done_callback(on_done_callback)
            future.result()


if __name__ == "__main__":
    logging.basicConfig()
    unittest.main(verbosity=2)