• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests server responding with RESOURCE_EXHAUSTED."""
15
16import logging
17import threading
18import unittest
19
20import grpc
21from grpc import _channel
22from grpc.framework.foundation import logging_pool
23
24from tests.unit import test_common
25from tests.unit.framework.common import test_constants
26
# Request and response share the same opaque three-byte payload.
_REQUEST = _RESPONSE = b"\x00\x00\x00"

# Fully-qualified method paths routed by _GenericHandler.
_UNARY_UNARY = "/test/UnaryUnary"
_UNARY_STREAM = "/test/UnaryStream"
_STREAM_UNARY = "/test/StreamUnary"
_STREAM_STREAM = "/test/StreamStream"
34
35
class _TestTrigger(object):
    """Rendezvous between parked server handlers and the test thread.

    Handlers call ``await_trigger`` to register themselves and then block.
    The test calls ``await_calls`` to wait until ``total_call_count``
    handlers have registered. It then calls ``trigger`` to release all of
    them at once.
    """

    def __init__(self, total_call_count):
        self._total_call_count = total_call_count
        self._pending_calls = 0
        self._triggered = False
        self._finish_condition = threading.Condition()
        self._start_condition = threading.Condition()

    def await_calls(self):
        """Block until ``total_call_count`` handlers are in ``await_trigger``."""
        with self._start_condition:
            while self._pending_calls < self._total_call_count:
                self._start_condition.wait()

    def await_trigger(self):
        """Record this handler as pending, then block until ``trigger``."""
        with self._start_condition:
            self._pending_calls += 1
            # Wake every await_calls() waiter so each re-checks the count.
            self._start_condition.notify_all()

        with self._finish_condition:
            # Re-check the predicate in a loop. With a single ``if``, any
            # wakeup not caused by trigger() would release the handler early.
            while not self._triggered:
                self._finish_condition.wait()

    def trigger(self):
        """Release every handler blocked in ``await_trigger``.

        Calling it again only re-sets the flag, so repeat calls are harmless.
        Handlers that arrive afterwards return immediately.
        """
        with self._finish_condition:
            self._triggered = True
            self._finish_condition.notify_all()
65
66
def handle_unary_unary(trigger, request, servicer_context):
    """Unary-unary behavior: wait for the trigger, then reply with one payload."""
    trigger.await_trigger()
    reply = _RESPONSE
    return reply
70
71
def handle_unary_stream(trigger, request, servicer_context):
    """Unary-stream behavior: once triggered, stream STREAM_LENGTH payloads.

    This is a generator, so it only reaches ``await_trigger()`` when the
    first response is pulled.
    """
    trigger.await_trigger()
    yield from (_RESPONSE for _ in range(test_constants.STREAM_LENGTH))
76
77
def handle_stream_unary(trigger, request_iterator, servicer_context):
    """Stream-unary behavior: wait for the trigger, drain requests, reply once."""
    trigger.await_trigger()
    # TODO(issue:#6891) We should be able to remove this loop
    for _ in request_iterator:
        pass
    return _RESPONSE
84
85
def handle_stream_stream(trigger, request_iterator, servicer_context):
    """Stream-stream behavior: once triggered, echo one payload per request."""
    trigger.await_trigger()
    # TODO(issue:#6891) We should be able to remove this loop,
    # and replace with return; yield
    yield from (_RESPONSE for _ in request_iterator)
92
93
class _MethodHandler(grpc.RpcMethodHandler):
    """Method handler whose single behavior blocks on a shared trigger.

    Exactly one of the four behavior attributes is populated, selected by
    the two streaming flags; the other three stay None. No request
    deserializer or response serializer is configured.
    """

    def __init__(self, trigger, request_streaming, response_streaming):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = None
        self.response_serializer = None
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None

        if not self.request_streaming:
            if self.response_streaming:
                self.unary_stream = lambda request, context: handle_unary_stream(
                    trigger, request, context
                )
            else:
                self.unary_unary = lambda request, context: handle_unary_unary(
                    trigger, request, context
                )
        elif self.response_streaming:
            self.stream_stream = lambda requests, context: handle_stream_stream(
                trigger, requests, context
            )
        else:
            self.stream_unary = lambda requests, context: handle_stream_unary(
                trigger, requests, context
            )
114
115
class _GenericHandler(grpc.GenericRpcHandler):
    """Routes the four test method paths to trigger-blocked method handlers."""

    def __init__(self, trigger):
        self._trigger = trigger

    def service(self, handler_call_details):
        """Return a _MethodHandler for a known path, or None otherwise."""
        method = handler_call_details.method
        # (path, request_streaming, response_streaming), checked in order.
        routes = (
            (_UNARY_UNARY, False, False),
            (_UNARY_STREAM, False, True),
            (_STREAM_UNARY, True, False),
            (_STREAM_STREAM, True, True),
        )
        for path, request_streaming, response_streaming in routes:
            if method == path:
                return _MethodHandler(
                    self._trigger, request_streaming, response_streaming
                )
        return None
131
132
class ResourceExhaustedTest(unittest.TestCase):
    """Checks that a server capped at THREAD_CONCURRENCY RPCs rejects extras.

    Each test:
    - parks THREAD_CONCURRENCY calls inside their handlers;
    - asserts that one more call fails with RESOURCE_EXHAUSTED;
    - releases the parked calls;
    - asserts that a fresh call succeeds again.
    """

    def setUp(self):
        self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
        self._trigger = _TestTrigger(test_constants.THREAD_CONCURRENCY)
        self._server = grpc.server(
            self._server_pool,
            handlers=(_GenericHandler(self._trigger),),
            options=(("grpc.so_reuseport", 0),),
            maximum_concurrent_rpcs=test_constants.THREAD_CONCURRENCY,
        )
        port = self._server.add_insecure_port("[::]:0")
        self._server.start()
        self._channel = grpc.insecure_channel("localhost:%d" % port)

    def tearDown(self):
        # Release any handler still parked in await_trigger(). A test may fail
        # before reaching its own trigger() call; those pool threads would
        # otherwise stay blocked forever. After a passing test, which has
        # already fired the trigger, this call is harmless.
        self._trigger.trigger()
        self._server.stop(0)
        self._channel.close()

    @staticmethod
    def _request_iterator():
        """Return a fresh stream of STREAM_LENGTH request payloads.

        Each streaming RPC needs its own iterator. If concurrent calls shared
        one, its items would be split among them, and later calls would get
        an already-exhausted (empty) stream.
        """
        return iter([_REQUEST] * test_constants.STREAM_LENGTH)

    def testUnaryUnary(self):
        multi_callable = self._channel.unary_unary(
            _UNARY_UNARY,
            _registered_method=True,
        )
        futures = [
            multi_callable.future(_REQUEST)
            for _ in range(test_constants.THREAD_CONCURRENCY)
        ]

        # Wait until every call above is parked in its handler.
        self._trigger.await_calls()

        # One call beyond the limit is rejected: blocking form ...
        with self.assertRaises(grpc.RpcError) as exception_context:
            multi_callable(_REQUEST)
        self.assertEqual(
            grpc.StatusCode.RESOURCE_EXHAUSTED,
            exception_context.exception.code(),
        )

        # ... and future form.
        future_exception = multi_callable.future(_REQUEST)
        self.assertEqual(
            grpc.StatusCode.RESOURCE_EXHAUSTED,
            future_exception.exception().code(),
        )

        self._trigger.trigger()
        for future in futures:
            self.assertEqual(_RESPONSE, future.result())

        # Ensure a new request can be handled
        self.assertEqual(_RESPONSE, multi_callable(_REQUEST))

    def testUnaryStream(self):
        multi_callable = self._channel.unary_stream(
            _UNARY_STREAM,
            _registered_method=True,
        )
        calls = [
            multi_callable(_REQUEST)
            for _ in range(test_constants.THREAD_CONCURRENCY)
        ]

        # Wait until every call above is parked in its handler.
        self._trigger.await_calls()

        with self.assertRaises(grpc.RpcError) as exception_context:
            next(multi_callable(_REQUEST))
        self.assertEqual(
            grpc.StatusCode.RESOURCE_EXHAUSTED,
            exception_context.exception.code(),
        )

        self._trigger.trigger()

        for call in calls:
            for response in call:
                self.assertEqual(_RESPONSE, response)

        # Ensure a new request can be handled
        new_call = multi_callable(_REQUEST)
        for response in new_call:
            self.assertEqual(_RESPONSE, response)

    def testStreamUnary(self):
        multi_callable = self._channel.stream_unary(
            _STREAM_UNARY,
            _registered_method=True,
        )
        futures = [
            multi_callable.future(self._request_iterator())
            for _ in range(test_constants.THREAD_CONCURRENCY)
        ]

        # Wait until every call above is parked in its handler.
        self._trigger.await_calls()

        # One call beyond the limit is rejected: blocking form ...
        with self.assertRaises(grpc.RpcError) as exception_context:
            multi_callable(self._request_iterator())
        self.assertEqual(
            grpc.StatusCode.RESOURCE_EXHAUSTED,
            exception_context.exception.code(),
        )

        # ... and future form.
        future_exception = multi_callable.future(self._request_iterator())
        self.assertEqual(
            grpc.StatusCode.RESOURCE_EXHAUSTED,
            future_exception.exception().code(),
        )

        self._trigger.trigger()

        for future in futures:
            self.assertEqual(_RESPONSE, future.result())

        # Ensure a new request can be handled
        self.assertEqual(
            _RESPONSE, multi_callable(self._request_iterator())
        )

    def testStreamStream(self):
        multi_callable = self._channel.stream_stream(
            _STREAM_STREAM,
            _registered_method=True,
        )
        calls = [
            multi_callable(self._request_iterator())
            for _ in range(test_constants.THREAD_CONCURRENCY)
        ]

        # Wait until every call above is parked in its handler.
        self._trigger.await_calls()

        with self.assertRaises(grpc.RpcError) as exception_context:
            next(multi_callable(self._request_iterator()))
        self.assertEqual(
            grpc.StatusCode.RESOURCE_EXHAUSTED,
            exception_context.exception.code(),
        )

        self._trigger.trigger()

        for call in calls:
            for response in call:
                self.assertEqual(_RESPONSE, response)

        # Ensure a new request can be handled
        new_call = multi_callable(self._request_iterator())
        for response in new_call:
            self.assertEqual(_RESPONSE, response)
277
278
if __name__ == "__main__":
    # Attach a default stderr handler to the root logger, then run the suite.
    logging.basicConfig()
    unittest.main(module="__main__", verbosity=2)
282