• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests server responding with RESOURCE_EXHAUSTED."""
15
16import logging
17import threading
18import unittest
19
20import grpc
21from grpc import _channel
22from grpc.framework.foundation import logging_pool
23
24from tests.unit import test_common
25from tests.unit.framework.common import test_constants
26
# Opaque payloads: the handlers never inspect request contents, and the tests
# only compare responses for equality against ``_RESPONSE``.
_REQUEST = _RESPONSE = b"\x00\x00\x00"

# Service name and per-arity method names registered on the test server.
_SERVICE_NAME = "test"
_UNARY_UNARY = "UnaryUnary"
_UNARY_STREAM = "UnaryStream"
_STREAM_UNARY = "StreamUnary"
_STREAM_STREAM = "StreamStream"
35
36
class _TestTrigger(object):
    """Holds RPC handlers inside their bodies until the test releases them.

    Handlers call ``await_trigger`` to announce that they are running and
    then block. The test calls ``await_calls`` to wait until
    ``total_call_count`` handlers have announced themselves, and ``trigger``
    to release every blocked handler. The triggered state is never reset, so
    any ``await_trigger`` call made after ``trigger`` returns immediately.
    """

    def __init__(self, total_call_count):
        """Create a trigger that expects ``total_call_count`` blocked handlers."""
        self._total_call_count = total_call_count
        self._pending_calls = 0
        self._triggered = False
        self._finish_condition = threading.Condition()
        self._start_condition = threading.Condition()

    def await_calls(self):
        """Block until ``total_call_count`` handlers have entered ``await_trigger``."""
        with self._start_condition:
            self._start_condition.wait_for(
                lambda: self._pending_calls >= self._total_call_count
            )

    def await_trigger(self):
        """Count this handler as pending, then block until ``trigger`` is called."""
        with self._start_condition:
            self._pending_calls += 1
            self._start_condition.notify_all()

        with self._finish_condition:
            # Re-check the predicate after every wake-up. A bare ``if`` around
            # ``wait()`` would let an early or unrelated wake-up release the
            # handler before ``trigger`` has actually run.
            self._finish_condition.wait_for(lambda: self._triggered)

    def trigger(self):
        """Release all handlers blocked in ``await_trigger`` (idempotent)."""
        with self._finish_condition:
            self._triggered = True
            self._finish_condition.notify_all()
66
67
def handle_unary_unary(trigger, request, servicer_context):
    """Unary-unary handler: block on ``trigger``, then reply with ``_RESPONSE``."""
    del request, servicer_context  # Unused; the payload is never inspected.
    trigger.await_trigger()
    return _RESPONSE
71
72
def handle_unary_stream(trigger, request, servicer_context):
    """Unary-stream handler: block on ``trigger``, then stream the response.

    Yields ``_RESPONSE`` exactly ``test_constants.STREAM_LENGTH`` times.
    """
    del request, servicer_context  # Unused; the payload is never inspected.
    trigger.await_trigger()
    yield from [_RESPONSE] * test_constants.STREAM_LENGTH
77
78
def handle_stream_unary(trigger, request_iterator, servicer_context):
    """Stream-unary handler: block on ``trigger``, drain requests, reply once."""
    del servicer_context  # Unused.
    trigger.await_trigger()
    # TODO(issue:#6891) We should be able to remove this loop
    for _ in request_iterator:
        pass
    return _RESPONSE
85
86
def handle_stream_stream(trigger, request_iterator, servicer_context):
    """Stream-stream handler: block on ``trigger``, then echo one response per request."""
    del servicer_context  # Unused.
    trigger.await_trigger()
    # TODO(issue:#6891) We should be able to remove this loop,
    # and replace with return; yield
    for _ in request_iterator:
        yield _RESPONSE
93
94
class _MethodHandler(grpc.RpcMethodHandler):
    """RpcMethodHandler whose behavior blocks on a shared ``_TestTrigger``.

    Exactly one of ``unary_unary``/``unary_stream``/``stream_unary``/
    ``stream_stream`` is set, chosen by the two streaming flags; the other
    three stay ``None``. No request deserializer or response serializer is
    configured.
    """

    def __init__(self, trigger, request_streaming, response_streaming):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = None
        self.response_serializer = None
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None

        def bind(handler):
            # Adapt a module-level handler to the (request, context) shape,
            # supplying the shared trigger as its first argument.
            return lambda request, context: handler(trigger, request, context)

        if not request_streaming:
            if response_streaming:
                self.unary_stream = bind(handle_unary_stream)
            else:
                self.unary_unary = bind(handle_unary_unary)
        elif response_streaming:
            self.stream_stream = bind(handle_stream_stream)
        else:
            self.stream_unary = bind(handle_stream_unary)
115
116
def get_method_handlers(trigger):
    """Return the method-name -> handler mapping registered on the test server.

    All four handlers share ``trigger`` so the test can hold and release them
    together.
    """
    arities = (
        (_UNARY_UNARY, False, False),
        (_UNARY_STREAM, False, True),
        (_STREAM_UNARY, True, False),
        (_STREAM_STREAM, True, True),
    )
    return {
        name: _MethodHandler(trigger, request_streaming, response_streaming)
        for name, request_streaming, response_streaming in arities
    }
124
125
class ResourceExhaustedTest(unittest.TestCase):
    """Checks that ``maximum_concurrent_rpcs`` is enforced for every RPC arity.

    Each test starts ``THREAD_CONCURRENCY`` calls whose handlers block on a
    shared ``_TestTrigger``. It then checks that one additional call fails
    with ``RESOURCE_EXHAUSTED``, releases the blocked handlers and checks
    their responses. Finally it checks that the server accepts a new call
    again.
    """

    def setUp(self):
        self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
        self._trigger = _TestTrigger(test_constants.THREAD_CONCURRENCY)
        self._server = grpc.server(
            self._server_pool,
            options=(("grpc.so_reuseport", 0),),
            maximum_concurrent_rpcs=test_constants.THREAD_CONCURRENCY,
        )
        self._server.add_registered_method_handlers(
            _SERVICE_NAME, get_method_handlers(self._trigger)
        )
        port = self._server.add_insecure_port("[::]:0")
        self._server.start()
        self._channel = grpc.insecure_channel("localhost:%d" % port)

    def tearDown(self):
        # Release any handlers still blocked on the trigger (e.g. when an
        # assertion failed before ``trigger()`` ran). Otherwise their worker
        # threads would wait forever and could hang the test process.
        self._trigger.trigger()
        self._server.stop(0)
        self._channel.close()

    @staticmethod
    def _request_stream():
        """Return a fresh request iterator for exactly one streaming call.

        Each call needs its own iterator. A shared iterator is exhausted by
        whichever call consumes it first, which leaves later calls (including
        the final "new request" check) with no requests to send.
        """
        return iter([_REQUEST] * test_constants.STREAM_LENGTH)

    def testUnaryUnary(self):
        multi_callable = self._channel.unary_unary(
            grpc._common.fully_qualified_method(_SERVICE_NAME, _UNARY_UNARY),
            _registered_method=True,
        )
        # Saturate the server: each of these calls blocks in its handler.
        futures = [
            multi_callable.future(_REQUEST)
            for _ in range(test_constants.THREAD_CONCURRENCY)
        ]
        self._trigger.await_calls()

        # One more call, blocking or future-based, must be rejected.
        with self.assertRaises(grpc.RpcError) as exception_context:
            multi_callable(_REQUEST)
        self.assertEqual(
            grpc.StatusCode.RESOURCE_EXHAUSTED,
            exception_context.exception.code(),
        )

        future_exception = multi_callable.future(_REQUEST)
        self.assertEqual(
            grpc.StatusCode.RESOURCE_EXHAUSTED,
            future_exception.exception().code(),
        )

        self._trigger.trigger()
        for future in futures:
            self.assertEqual(_RESPONSE, future.result())

        # Ensure a new request can be handled
        self.assertEqual(_RESPONSE, multi_callable(_REQUEST))

    def testUnaryStream(self):
        multi_callable = self._channel.unary_stream(
            grpc._common.fully_qualified_method(_SERVICE_NAME, _UNARY_STREAM),
            _registered_method=True,
        )
        calls = [
            multi_callable(_REQUEST)
            for _ in range(test_constants.THREAD_CONCURRENCY)
        ]
        self._trigger.await_calls()

        with self.assertRaises(grpc.RpcError) as exception_context:
            next(multi_callable(_REQUEST))
        self.assertEqual(
            grpc.StatusCode.RESOURCE_EXHAUSTED,
            exception_context.exception.code(),
        )

        self._trigger.trigger()
        # The handler streams exactly STREAM_LENGTH responses per call.
        expected = [_RESPONSE] * test_constants.STREAM_LENGTH
        for call in calls:
            self.assertEqual(expected, list(call))

        # Ensure a new request can be handled
        self.assertEqual(expected, list(multi_callable(_REQUEST)))

    def testStreamUnary(self):
        multi_callable = self._channel.stream_unary(
            grpc._common.fully_qualified_method(_SERVICE_NAME, _STREAM_UNARY),
            _registered_method=True,
        )
        futures = [
            multi_callable.future(self._request_stream())
            for _ in range(test_constants.THREAD_CONCURRENCY)
        ]
        self._trigger.await_calls()

        with self.assertRaises(grpc.RpcError) as exception_context:
            multi_callable(self._request_stream())
        self.assertEqual(
            grpc.StatusCode.RESOURCE_EXHAUSTED,
            exception_context.exception.code(),
        )

        future_exception = multi_callable.future(self._request_stream())
        self.assertEqual(
            grpc.StatusCode.RESOURCE_EXHAUSTED,
            future_exception.exception().code(),
        )

        self._trigger.trigger()
        for future in futures:
            self.assertEqual(_RESPONSE, future.result())

        # Ensure a new request can be handled
        self.assertEqual(_RESPONSE, multi_callable(self._request_stream()))

    def testStreamStream(self):
        multi_callable = self._channel.stream_stream(
            grpc._common.fully_qualified_method(_SERVICE_NAME, _STREAM_STREAM),
            _registered_method=True,
        )
        calls = [
            multi_callable(self._request_stream())
            for _ in range(test_constants.THREAD_CONCURRENCY)
        ]
        self._trigger.await_calls()

        with self.assertRaises(grpc.RpcError) as exception_context:
            next(multi_callable(self._request_stream()))
        self.assertEqual(
            grpc.StatusCode.RESOURCE_EXHAUSTED,
            exception_context.exception.code(),
        )

        self._trigger.trigger()
        # The handler echoes one response per request, and every call sends
        # STREAM_LENGTH requests, so each call yields STREAM_LENGTH responses.
        expected = [_RESPONSE] * test_constants.STREAM_LENGTH
        for call in calls:
            self.assertEqual(expected, list(call))

        # Ensure a new request can be handled
        self.assertEqual(
            expected, list(multi_callable(self._request_stream()))
        )
272
273
if __name__ == "__main__":
    # Allow running this file directly: configure root logging so log output
    # emitted during the tests is shown, then run every test verbosely.
    logging.basicConfig()
    unittest.main(verbosity=2)
277