# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests server responding with RESOURCE_EXHAUSTED."""

import logging
import threading
import unittest

import grpc
from grpc import _channel
from grpc.framework.foundation import logging_pool

from tests.unit import test_common
from tests.unit.framework.common import test_constants

_REQUEST = b"\x00\x00\x00"
_RESPONSE = b"\x00\x00\x00"

_SERVICE_NAME = "test"
_UNARY_UNARY = "UnaryUnary"
_UNARY_STREAM = "UnaryStream"
_STREAM_UNARY = "StreamUnary"
_STREAM_STREAM = "StreamStream"


def _request_iterator():
    """Return a fresh iterator over STREAM_LENGTH copies of _REQUEST.

    Every request-streaming call gets its own iterator so that concurrent
    calls do not race over, and exhaust, a single shared iterator.
    """
    return iter([_REQUEST] * test_constants.STREAM_LENGTH)


class _TestTrigger(object):
    """Keeps server handlers blocked until the test releases them.

    Handlers call await_trigger(), which counts them as pending and then
    blocks. The test calls await_calls() to wait until total_call_count
    handlers are pending, and trigger() to release all of them. Once
    triggered, later await_trigger() calls return without blocking.
    """

    def __init__(self, total_call_count):
        self._total_call_count = total_call_count
        self._pending_calls = 0
        self._triggered = False
        self._finish_condition = threading.Condition()
        self._start_condition = threading.Condition()

    def await_calls(self):
        """Block until total_call_count handlers are blocked in their handler."""
        with self._start_condition:
            self._start_condition.wait_for(
                lambda: self._pending_calls >= self._total_call_count
            )

    def await_trigger(self):
        """Register the calling handler as pending, then block until trigger()."""
        with self._start_condition:
            self._pending_calls += 1
            self._start_condition.notify()

        with self._finish_condition:
            # wait_for re-checks the predicate after every wakeup, so the
            # handler is only released once trigger() has actually run.
            self._finish_condition.wait_for(lambda: self._triggered)

    def trigger(self):
        """Release every handler blocked in await_trigger(); safe to repeat."""
        with self._finish_condition:
            self._triggered = True
            self._finish_condition.notify_all()


def handle_unary_unary(trigger, request, servicer_context):
    """Block until triggered, then return a single _RESPONSE."""
    trigger.await_trigger()
    return _RESPONSE


def handle_unary_stream(trigger, request, servicer_context):
    """Block until triggered, then yield STREAM_LENGTH copies of _RESPONSE."""
    trigger.await_trigger()
    for _ in range(test_constants.STREAM_LENGTH):
        yield _RESPONSE


def handle_stream_unary(trigger, request_iterator, servicer_context):
    """Block until triggered, drain the request stream, return _RESPONSE."""
    trigger.await_trigger()
    # TODO(issue:#6891) We should be able to remove this loop
    for _ in request_iterator:
        pass
    return _RESPONSE


def handle_stream_stream(trigger, request_iterator, servicer_context):
    """Block until triggered, then yield one _RESPONSE per received request."""
    trigger.await_trigger()
    # TODO(issue:#6891) We should be able to remove this loop,
    # and replace with return; yield
    for _ in request_iterator:
        yield _RESPONSE


class _MethodHandler(grpc.RpcMethodHandler):
    """RpcMethodHandler whose single populated behavior blocks on `trigger`.

    Exactly one of unary_unary / unary_stream / stream_unary / stream_stream
    is set, chosen by the request_streaming / response_streaming flags. No
    (de)serializers are configured, so raw bytes pass through.
    """

    def __init__(self, trigger, request_streaming, response_streaming):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = None
        self.response_serializer = None
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None
        if self.request_streaming and self.response_streaming:
            self.stream_stream = lambda x, y: handle_stream_stream(
                trigger, x, y
            )
        elif self.request_streaming:
            self.stream_unary = lambda x, y: handle_stream_unary(trigger, x, y)
        elif self.response_streaming:
            self.unary_stream = lambda x, y: handle_unary_stream(trigger, x, y)
        else:
            self.unary_unary = lambda x, y: handle_unary_unary(trigger, x, y)


def get_method_handlers(trigger):
    """Map each test method name to a _MethodHandler sharing `trigger`."""
    return {
        _UNARY_UNARY: _MethodHandler(trigger, False, False),
        _UNARY_STREAM: _MethodHandler(trigger, False, True),
        _STREAM_UNARY: _MethodHandler(trigger, True, False),
        _STREAM_STREAM: _MethodHandler(trigger, True, True),
    }


class ResourceExhaustedTest(unittest.TestCase):
    """Checks that calls beyond maximum_concurrent_rpcs get RESOURCE_EXHAUSTED.

    Each test fills every concurrency slot with a call blocked in its
    handler. It verifies that one extra call is rejected, then releases the
    blocked calls and verifies they, and a fresh call, succeed.
    """

    def setUp(self):
        self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
        self._trigger = _TestTrigger(test_constants.THREAD_CONCURRENCY)
        self._server = grpc.server(
            self._server_pool,
            options=(("grpc.so_reuseport", 0),),
            maximum_concurrent_rpcs=test_constants.THREAD_CONCURRENCY,
        )
        self._server.add_registered_method_handlers(
            _SERVICE_NAME, get_method_handlers(self._trigger)
        )
        port = self._server.add_insecure_port("[::]:0")
        self._server.start()
        self._channel = grpc.insecure_channel("localhost:%d" % port)

    def tearDown(self):
        # If a test failed before calling trigger(), handlers would stay
        # blocked in await_trigger() forever and their server-pool threads
        # would never finish; release them unconditionally (idempotent).
        self._trigger.trigger()
        self._server.stop(0)
        self._channel.close()

    def testUnaryUnary(self):
        multi_callable = self._channel.unary_unary(
            grpc._common.fully_qualified_method(_SERVICE_NAME, _UNARY_UNARY),
            _registered_method=True,
        )
        # Occupy every concurrency slot with a call blocked in its handler.
        futures = [
            multi_callable.future(_REQUEST)
            for _ in range(test_constants.THREAD_CONCURRENCY)
        ]

        self._trigger.await_calls()

        # One more call is rejected, via the blocking API...
        with self.assertRaises(grpc.RpcError) as exception_context:
            multi_callable(_REQUEST)

        self.assertEqual(
            grpc.StatusCode.RESOURCE_EXHAUSTED,
            exception_context.exception.code(),
        )

        # ...and via the future API.
        future_exception = multi_callable.future(_REQUEST)
        self.assertEqual(
            grpc.StatusCode.RESOURCE_EXHAUSTED,
            future_exception.exception().code(),
        )

        self._trigger.trigger()
        for future in futures:
            self.assertEqual(_RESPONSE, future.result())

        # Ensure a new request can be handled
        self.assertEqual(_RESPONSE, multi_callable(_REQUEST))

    def testUnaryStream(self):
        multi_callable = self._channel.unary_stream(
            grpc._common.fully_qualified_method(_SERVICE_NAME, _UNARY_STREAM),
            _registered_method=True,
        )
        expected_responses = [_RESPONSE] * test_constants.STREAM_LENGTH
        calls = [
            multi_callable(_REQUEST)
            for _ in range(test_constants.THREAD_CONCURRENCY)
        ]

        self._trigger.await_calls()

        with self.assertRaises(grpc.RpcError) as exception_context:
            next(multi_callable(_REQUEST))

        self.assertEqual(
            grpc.StatusCode.RESOURCE_EXHAUSTED,
            exception_context.exception.code(),
        )

        self._trigger.trigger()

        for call in calls:
            self.assertEqual(expected_responses, list(call))

        # Ensure a new request can be handled
        self.assertEqual(expected_responses, list(multi_callable(_REQUEST)))

    def testStreamUnary(self):
        multi_callable = self._channel.stream_unary(
            grpc._common.fully_qualified_method(_SERVICE_NAME, _STREAM_UNARY),
            _registered_method=True,
        )
        futures = [
            multi_callable.future(_request_iterator())
            for _ in range(test_constants.THREAD_CONCURRENCY)
        ]

        self._trigger.await_calls()

        with self.assertRaises(grpc.RpcError) as exception_context:
            multi_callable(_request_iterator())

        self.assertEqual(
            grpc.StatusCode.RESOURCE_EXHAUSTED,
            exception_context.exception.code(),
        )

        future_exception = multi_callable.future(_request_iterator())
        self.assertEqual(
            grpc.StatusCode.RESOURCE_EXHAUSTED,
            future_exception.exception().code(),
        )

        self._trigger.trigger()

        for future in futures:
            self.assertEqual(_RESPONSE, future.result())

        # Ensure a new request can be handled
        self.assertEqual(_RESPONSE, multi_callable(_request_iterator()))

    def testStreamStream(self):
        multi_callable = self._channel.stream_stream(
            grpc._common.fully_qualified_method(_SERVICE_NAME, _STREAM_STREAM),
            _registered_method=True,
        )
        # The handler echoes one response per request, and each call sends
        # STREAM_LENGTH requests from its own iterator.
        expected_responses = [_RESPONSE] * test_constants.STREAM_LENGTH
        calls = [
            multi_callable(_request_iterator())
            for _ in range(test_constants.THREAD_CONCURRENCY)
        ]

        self._trigger.await_calls()

        with self.assertRaises(grpc.RpcError) as exception_context:
            next(multi_callable(_request_iterator()))

        self.assertEqual(
            grpc.StatusCode.RESOURCE_EXHAUSTED,
            exception_context.exception.code(),
        )

        self._trigger.trigger()

        for call in calls:
            self.assertEqual(expected_responses, list(call))

        # Ensure a new request can be handled
        new_call = multi_callable(_request_iterator())
        self.assertEqual(expected_responses, list(new_call))


if __name__ == "__main__":
    logging.basicConfig()
    unittest.main(verbosity=2)