# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests server responding with RESOURCE_EXHAUSTED."""

import logging
import threading
import unittest

import grpc
from grpc import _channel
from grpc.framework.foundation import logging_pool

from tests.unit import test_common
from tests.unit.framework.common import test_constants

_REQUEST = b"\x00\x00\x00"
_RESPONSE = b"\x00\x00\x00"

_UNARY_UNARY = "/test/UnaryUnary"
_UNARY_STREAM = "/test/UnaryStream"
_STREAM_UNARY = "/test/StreamUnary"
_STREAM_STREAM = "/test/StreamStream"


class _TestTrigger(object):
    """Rendezvous between the test thread and the blocked RPC handlers.

    Handlers call ``await_trigger`` and park there; the test thread calls
    ``await_calls`` to wait until ``total_call_count`` handlers have arrived,
    and later ``trigger`` to let all of them proceed.
    """

    def __init__(self, total_call_count):
        """Create a trigger expecting ``total_call_count`` parked handlers."""
        self._total_call_count = total_call_count
        self._pending_calls = 0
        self._triggered = False
        self._finish_condition = threading.Condition()
        self._start_condition = threading.Condition()

    # Wait for all calls to be blocked in their handler
    def await_calls(self):
        with self._start_condition:
            while self._pending_calls < self._total_call_count:
                self._start_condition.wait()

    # Block in a response handler and wait for a trigger
    def await_trigger(self):
        with self._start_condition:
            self._pending_calls += 1
            self._start_condition.notify()

        with self._finish_condition:
            # Re-check the predicate in a loop: Condition.wait() may return
            # without trigger() having been called (spurious wakeup), and the
            # handler must stay blocked until the flag is really set.
            while not self._triggered:
                self._finish_condition.wait()

    # Finish all response handlers
    def trigger(self):
        with self._finish_condition:
            self._triggered = True
            self._finish_condition.notify_all()


def handle_unary_unary(trigger, request, servicer_context):
    """Block until ``trigger`` fires, then return a single ``_RESPONSE``."""
    trigger.await_trigger()
    return _RESPONSE


def handle_unary_stream(trigger, request, servicer_context):
    """Block until ``trigger`` fires, then yield STREAM_LENGTH responses."""
    trigger.await_trigger()
    for _ in range(test_constants.STREAM_LENGTH):
        yield _RESPONSE


def handle_stream_unary(trigger, request_iterator, servicer_context):
    """Block until ``trigger`` fires, drain the requests, return one response."""
    trigger.await_trigger()
    # TODO(issue:#6891) We should be able to remove this loop
    for _ in request_iterator:
        pass
    return _RESPONSE


def handle_stream_stream(trigger, request_iterator, servicer_context):
    """Block until ``trigger`` fires, then yield one response per request."""
    trigger.await_trigger()
    # TODO(issue:#6891) We should be able to remove this loop,
    # and replace with return; yield
    for _ in request_iterator:
        yield _RESPONSE


class _MethodHandler(grpc.RpcMethodHandler):
    """Method handler exposing exactly one of the four handler attributes.

    The populated attribute is chosen from ``request_streaming`` and
    ``response_streaming``; the other three stay ``None``. Every behavior
    is bound to the shared ``trigger`` so it blocks until the test fires it.
    """

    def __init__(self, trigger, request_streaming, response_streaming):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = None
        self.response_serializer = None
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None
        if self.request_streaming and self.response_streaming:
            self.stream_stream = lambda x, y: handle_stream_stream(
                trigger, x, y
            )
        elif self.request_streaming:
            self.stream_unary = lambda x, y: handle_stream_unary(trigger, x, y)
        elif self.response_streaming:
            self.unary_stream = lambda x, y: handle_unary_stream(trigger, x, y)
        else:
            self.unary_unary = lambda x, y: handle_unary_unary(trigger, x, y)


class _GenericHandler(grpc.GenericRpcHandler):
    """Routes the four test method names to a matching ``_MethodHandler``."""

    def __init__(self, trigger):
        self._trigger = trigger

    def service(self, handler_call_details):
        """Return the handler for the called method, or None if unknown."""
        if handler_call_details.method == _UNARY_UNARY:
            return _MethodHandler(self._trigger, False, False)
        elif handler_call_details.method == _UNARY_STREAM:
            return _MethodHandler(self._trigger, False, True)
        elif handler_call_details.method == _STREAM_UNARY:
            return _MethodHandler(self._trigger, True, False)
        elif handler_call_details.method == _STREAM_STREAM:
            return _MethodHandler(self._trigger, True, True)
        else:
            return None


class ResourceExhaustedTest(unittest.TestCase):
    # NOTE: the test methods are documented with comments rather than
    # docstrings so that unittest's verbose output keeps showing method names.

    def setUp(self):
        # The pool size, the number of handlers the trigger waits for and the
        # server's maximum_concurrent_rpcs are all THREAD_CONCURRENCY, so the
        # server is saturated exactly when await_calls() returns.
        self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
        self._trigger = _TestTrigger(test_constants.THREAD_CONCURRENCY)
        self._server = grpc.server(
            self._server_pool,
            handlers=(_GenericHandler(self._trigger),),
            options=(("grpc.so_reuseport", 0),),
            maximum_concurrent_rpcs=test_constants.THREAD_CONCURRENCY,
        )
        port = self._server.add_insecure_port("[::]:0")
        self._server.start()
        self._channel = grpc.insecure_channel("localhost:%d" % port)

    def tearDown(self):
        # Release any handler still parked in await_trigger() (for example
        # when an assertion failed before the test reached trigger()), so the
        # pool threads are not left blocked. trigger() is idempotent.
        self._trigger.trigger()
        self._server.stop(0)
        self._channel.close()

    # Saturate the server with unary-unary calls, check that extra blocking
    # and future calls fail with RESOURCE_EXHAUSTED, then release the handlers
    # and check that the server accepts calls again.
    def testUnaryUnary(self):
        multi_callable = self._channel.unary_unary(
            _UNARY_UNARY,
            _registered_method=True,
        )
        futures = [
            multi_callable.future(_REQUEST)
            for _ in range(test_constants.THREAD_CONCURRENCY)
        ]

        self._trigger.await_calls()

        with self.assertRaises(grpc.RpcError) as exception_context:
            multi_callable(_REQUEST)

        self.assertEqual(
            grpc.StatusCode.RESOURCE_EXHAUSTED,
            exception_context.exception.code(),
        )

        future_exception = multi_callable.future(_REQUEST)
        self.assertEqual(
            grpc.StatusCode.RESOURCE_EXHAUSTED,
            future_exception.exception().code(),
        )

        self._trigger.trigger()
        for future in futures:
            self.assertEqual(_RESPONSE, future.result())

        # Ensure a new request can be handled
        self.assertEqual(_RESPONSE, multi_callable(_REQUEST))

    # Same scenario as testUnaryUnary for the unary-stream method.
    def testUnaryStream(self):
        multi_callable = self._channel.unary_stream(
            _UNARY_STREAM,
            _registered_method=True,
        )
        calls = [
            multi_callable(_REQUEST)
            for _ in range(test_constants.THREAD_CONCURRENCY)
        ]

        self._trigger.await_calls()

        with self.assertRaises(grpc.RpcError) as exception_context:
            next(multi_callable(_REQUEST))

        self.assertEqual(
            grpc.StatusCode.RESOURCE_EXHAUSTED,
            exception_context.exception.code(),
        )

        self._trigger.trigger()

        for call in calls:
            for response in call:
                self.assertEqual(_RESPONSE, response)

        # Ensure a new request can be handled
        new_call = multi_callable(_REQUEST)
        for response in new_call:
            self.assertEqual(_RESPONSE, response)

    # Same scenario as testUnaryUnary for the stream-unary method.
    def testStreamUnary(self):
        multi_callable = self._channel.stream_unary(
            _STREAM_UNARY,
            _registered_method=True,
        )
        # NOTE(review): one iterator object is handed to every call below;
        # presumably intentional, confirm before changing.
        request = iter([_REQUEST] * test_constants.STREAM_LENGTH)
        futures = [
            multi_callable.future(request)
            for _ in range(test_constants.THREAD_CONCURRENCY)
        ]

        self._trigger.await_calls()

        with self.assertRaises(grpc.RpcError) as exception_context:
            multi_callable(request)

        self.assertEqual(
            grpc.StatusCode.RESOURCE_EXHAUSTED,
            exception_context.exception.code(),
        )

        future_exception = multi_callable.future(request)
        self.assertEqual(
            grpc.StatusCode.RESOURCE_EXHAUSTED,
            future_exception.exception().code(),
        )

        self._trigger.trigger()

        for future in futures:
            self.assertEqual(_RESPONSE, future.result())

        # Ensure a new request can be handled
        self.assertEqual(_RESPONSE, multi_callable(request))

    # Same scenario as testUnaryUnary for the stream-stream method.
    def testStreamStream(self):
        multi_callable = self._channel.stream_stream(
            _STREAM_STREAM,
            _registered_method=True,
        )
        # NOTE(review): one iterator object is handed to every call below;
        # presumably intentional, confirm before changing.
        request = iter([_REQUEST] * test_constants.STREAM_LENGTH)
        calls = [
            multi_callable(request)
            for _ in range(test_constants.THREAD_CONCURRENCY)
        ]

        self._trigger.await_calls()

        with self.assertRaises(grpc.RpcError) as exception_context:
            next(multi_callable(request))

        self.assertEqual(
            grpc.StatusCode.RESOURCE_EXHAUSTED,
            exception_context.exception.code(),
        )

        self._trigger.trigger()

        for call in calls:
            for response in call:
                self.assertEqual(_RESPONSE, response)

        # Ensure a new request can be handled
        new_call = multi_callable(request)
        for response in new_call:
            self.assertEqual(_RESPONSE, response)


if __name__ == "__main__":
    logging.basicConfig()
    unittest.main(verbosity=2)