# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests closing a channel while RPCs are active on it."""

import itertools
import logging
import threading
import time
import unittest

import grpc

from tests.unit import test_common
from tests.unit.framework.common import test_constants

# Seconds slept between requests in test_many_concurrent_closes.
_BEAT = 0.5
# Seconds each closer thread sleeps before calling channel.close().
_SOME_TIME = 5
# Seconds test_many_concurrent_closes keeps feeding requests.
_MORE_TIME = 10

_SERVICE_NAME = ""
_STREAM_URI = "Meffod"
_UNARY_URI = "MeffodMan"


class _StreamingMethodHandler(grpc.RpcMethodHandler):
    """Stream-stream handler that answers every request with `request * 2`."""

    request_streaming = True
    response_streaming = True
    request_deserializer = None
    response_serializer = None

    def stream_stream(self, request_iterator, servicer_context):
        """Yield `request * 2` for each request received."""
        for request in request_iterator:
            yield request * 2


class _UnaryMethodHandler(grpc.RpcMethodHandler):
    """Unary-unary handler that answers a request with `request * 2`."""

    request_streaming = False
    response_streaming = False
    request_deserializer = None
    response_serializer = None

    def unary_unary(self, request, servicer_context):
        """Return `request * 2`."""
        return request * 2


_STREAMING_METHOD_HANDLER = _StreamingMethodHandler()
_UNARY_METHOD_HANDLER = _UnaryMethodHandler()


_METHOD_HANDLERS = {
    _STREAM_URI: _STREAMING_METHOD_HANDLER,
    _UNARY_URI: _UNARY_METHOD_HANDLER,
}


class _Pipe(object):
    """Blocking iterator whose values are supplied through `add`.

    Iteration blocks while the pipe is open and empty, and stops once the
    pipe has been closed and every buffered value has been consumed. All
    state is guarded by a single condition variable.
    """

    def __init__(self, values):
        """Start an open pipe pre-loaded with a copy of `values`."""
        self._condition = threading.Condition()
        self._values = list(values)
        self._open = True

    def __iter__(self):
        return self

    def _next(self):
        """Return the oldest buffered value, waiting for one if necessary.

        Raises:
          StopIteration: the pipe is closed and no values remain.
        """
        with self._condition:
            # Loop to guard against wake-ups that leave nothing to consume.
            while not self._values and self._open:
                self._condition.wait()
            if self._values:
                return self._values.pop(0)
            raise StopIteration()

    def next(self):
        # Python 2 spelling of the iterator protocol.
        return self._next()

    def __next__(self):
        return self._next()

    def add(self, value):
        """Append `value` and wake a waiting consumer."""
        with self._condition:
            self._values.append(value)
            self._condition.notify()

    def close(self):
        """Mark the pipe closed and wake a waiting consumer."""
        with self._condition:
            self._open = False
            self._condition.notify()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.close()


class ChannelCloseTest(unittest.TestCase):
    """Checks that closing a channel cancels the calls made on it."""

    def setUp(self):
        self._server = test_common.test_server(
            max_workers=test_constants.THREAD_CONCURRENCY
        )
        self._server.add_registered_method_handlers(
            _SERVICE_NAME, _METHOD_HANDLERS
        )
        self._port = self._server.add_insecure_port("[::]:0")
        self._server.start()

    def tearDown(self):
        self._server.stop(None)

    def _target(self):
        """Return the address of the server started in `setUp`."""
        return "localhost:{}".format(self._port)

    def _stream_multi_callable(self, channel):
        """Return the stream-stream multi-callable for `_STREAM_URI`."""
        return channel.stream_stream(
            grpc._common.fully_qualified_method(_SERVICE_NAME, _STREAM_URI),
            _registered_method=True,
        )

    def test_close_immediately_after_call_invocation(self):
        """A call is CANCELLED when the channel closes right after it starts."""
        channel = grpc.insecure_channel(self._target())
        multi_callable = self._stream_multi_callable(channel)
        request_iterator = _Pipe(())
        response_iterator = multi_callable(request_iterator)
        channel.close()
        request_iterator.close()

        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_close_while_call_active(self):
        """A call that already got a response is CANCELLED by `close`."""
        channel = grpc.insecure_channel(self._target())
        multi_callable = self._stream_multi_callable(channel)
        request_iterator = _Pipe((b"abc",))
        response_iterator = multi_callable(request_iterator)
        next(response_iterator)
        channel.close()
        request_iterator.close()

        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_context_manager_close_while_call_active(self):
        """Leaving the channel's `with` block cancels the active call."""
        with grpc.insecure_channel(self._target()) as channel:
            multi_callable = self._stream_multi_callable(channel)
            request_iterator = _Pipe((b"abc",))
            response_iterator = multi_callable(request_iterator)
            next(response_iterator)
        request_iterator.close()

        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_context_manager_close_while_many_calls_active(self):
        """Leaving the channel's `with` block cancels every active call."""
        with grpc.insecure_channel(self._target()) as channel:
            multi_callable = self._stream_multi_callable(channel)
            request_iterators = tuple(
                _Pipe((b"abc",))
                for _ in range(test_constants.THREAD_CONCURRENCY)
            )
            response_iterators = []
            for request_iterator in request_iterators:
                response_iterator = multi_callable(request_iterator)
                next(response_iterator)
                response_iterators.append(response_iterator)
        for request_iterator in request_iterators:
            request_iterator.close()

        for response_iterator in response_iterators:
            self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_many_concurrent_closes(self):
        """Many threads closing the same channel still yield CANCELLED."""
        channel = grpc.insecure_channel(self._target())
        multi_callable = self._stream_multi_callable(channel)
        request_iterator = _Pipe((b"abc",))
        response_iterator = multi_callable(request_iterator)
        next(response_iterator)
        end = time.time() + _MORE_TIME

        def sleep_some_time_then_close():
            time.sleep(_SOME_TIME)
            channel.close()

        close_threads = []
        for _ in range(test_constants.THREAD_CONCURRENCY):
            close_thread = threading.Thread(target=sleep_some_time_then_close)
            close_thread.start()
            close_threads.append(close_thread)
        # Keep sending requests until after the closer threads have woken up
        # (_MORE_TIME exceeds _SOME_TIME).
        while True:
            request_iterator.add(b"def")
            time.sleep(_BEAT)
            if end < time.time():
                break
        request_iterator.close()
        # Do not leave the closer threads running past the end of the test.
        for close_thread in close_threads:
            close_thread.join()

        self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)

    def test_exception_in_callback(self):
        """A raising done-callback must not hang the test.

        There is no assertion; the test passes by running to completion.
        """
        with grpc.insecure_channel(self._target()) as channel:
            stream_multi_callable = self._stream_multi_callable(channel)
            endless_iterator = itertools.repeat(b"abc")
            # Never consumed: the binding just keeps a reference to a
            # streaming call that is fed an endless request iterator.
            # NOTE(review): presumably this keeps a call outstanding when the
            # channel closes - confirm before removing the "unused" local.
            stream_response_iterator = stream_multi_callable(endless_iterator)
            future = channel.unary_unary(
                grpc._common.fully_qualified_method(_SERVICE_NAME, _UNARY_URI),
                _registered_method=True,
            ).future(b"abc")

            def on_done_callback(unused_future):
                raise Exception("This should not cause a deadlock.")

            future.add_done_callback(on_done_callback)
            future.result()


if __name__ == "__main__":
    logging.basicConfig()
    unittest.main(verbosity=2)