• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The gRPC Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Test helpers for RPC invocation tests."""
15
16import datetime
17import threading
18
19import grpc
20from grpc.framework.foundation import logging_pool
21
22from tests.unit import test_common
23from tests.unit import thread_pool
24from tests.unit.framework.common import test_constants
25from tests.unit.framework.common import test_control
26
def _SERIALIZE_REQUEST(bytestring):
    """Encode a request as two back-to-back copies of itself."""
    return bytestring * 2


def _DESERIALIZE_REQUEST(bytestring):
    """Decode a request by keeping the second half of the wire bytes."""
    return bytestring[len(bytestring) // 2:]


def _SERIALIZE_RESPONSE(bytestring):
    """Encode a response as three back-to-back copies of itself."""
    return bytestring * 3


def _DESERIALIZE_RESPONSE(bytestring):
    """Decode a response by keeping the first third of the wire bytes."""
    return bytestring[:len(bytestring) // 3]


# Method paths routed by _GenericHandler and used by the multi-callable
# factories below.
_UNARY_UNARY = '/test/UnaryUnary'
_UNARY_STREAM = '/test/UnaryStream'
_UNARY_STREAM_NON_BLOCKING = '/test/UnaryStreamNonBlocking'
_STREAM_UNARY = '/test/StreamUnary'
_STREAM_STREAM = '/test/StreamStream'
_STREAM_STREAM_NON_BLOCKING = '/test/StreamStreamNonBlocking'

# One second, as a float number of seconds.
TIMEOUT_SHORT = datetime.timedelta(seconds=1).total_seconds()
40
41
class Callback(object):
    """Thread-safe holder for a value delivered by invoking the instance.

    Calling the instance stores the value (replacing any earlier one) and
    wakes every waiter; ``value()`` blocks until the instance has been
    called at least once.
    """

    def __init__(self):
        self._condition = threading.Condition()
        self._value = None
        self._called = False

    def __call__(self, value):
        """Store ``value``, mark the callback as fired and notify waiters."""
        with self._condition:
            self._value, self._called = value, True
            self._condition.notify_all()

    def value(self):
        """Block until the callback has fired, then return the stored value."""
        with self._condition:
            while True:
                if self._called:
                    return self._value
                self._condition.wait()
60
61
class _Handler(object):
    """Test servicer whose behaviors echo requests under a test control.

    Every behavior calls ``self._control.control()`` at fixed checkpoints so a
    test can intervene there, and, when a servicer context is supplied, sets
    the trailing metadata ``(('testkey', 'testvalue'),)``.
    """

    # Trailing metadata attached by every behavior that receives a context.
    _TRAILING_METADATA = (('testkey', 'testvalue'),)

    def __init__(self, control, thread_pool):
        """Store the control and tag the non-blocking behaviors.

        Args:
          control: Object whose ``control()`` method is called at each
            checkpoint.
          thread_pool: Advertised on the non-blocking behaviors through their
            ``experimental_thread_pool`` attribute.
        """
        import functools

        self._control = control
        self._thread_pool = thread_pool
        # Bound methods reject new attributes, and tagging the shared
        # underlying function (``method.__func__``) would let the most recently
        # constructed _Handler overwrite every other instance's thread pool.
        # Instead, shadow each non-blocking method with a per-instance partial
        # that carries its own experimental_* attributes (presumably read by
        # the gRPC server's experimental non-blocking support).
        for name in ('handle_unary_stream_non_blocking',
                     'handle_stream_stream_non_blocking'):
            behavior = functools.partial(getattr(type(self), name), self)
            behavior.experimental_non_blocking = True
            behavior.experimental_thread_pool = self._thread_pool
            setattr(self, name, behavior)

    def _set_trailing_metadata(self, servicer_context):
        """Attach the test trailing metadata if a servicer context is given."""
        if servicer_context is not None:
            servicer_context.set_trailing_metadata(self._TRAILING_METADATA)

    def handle_unary_unary(self, request, servicer_context):
        """Return ``request`` unchanged after a single checkpoint."""
        self._control.control()
        self._set_trailing_metadata(servicer_context)
        if servicer_context is not None:
            # TODO(https://github.com/grpc/grpc/issues/8483): test the values
            # returned by these methods rather than only "smoke" testing that
            # they return after having been called.
            servicer_context.is_active()
            servicer_context.time_remaining()
        return request

    def handle_unary_stream(self, request, servicer_context):
        """Yield ``request`` STREAM_LENGTH times, checkpointing before each."""
        for _ in range(test_constants.STREAM_LENGTH):
            self._control.control()
            yield request
        self._control.control()
        self._set_trailing_metadata(servicer_context)

    def handle_unary_stream_non_blocking(self, request, servicer_context,
                                         on_next):
        """Push ``request`` STREAM_LENGTH times through ``on_next``.

        Finishes by calling ``on_next(None)``, which presumably marks the end
        of the response stream for the non-blocking API.
        """
        for _ in range(test_constants.STREAM_LENGTH):
            self._control.control()
            on_next(request)
        self._control.control()
        self._set_trailing_metadata(servicer_context)
        on_next(None)

    def handle_stream_unary(self, request_iterator, servicer_context):
        """Return the concatenation of all requests as one bytes value."""
        if servicer_context is not None:
            servicer_context.invocation_metadata()
        self._control.control()
        response_elements = []
        for request in request_iterator:
            self._control.control()
            response_elements.append(request)
        self._control.control()
        self._set_trailing_metadata(servicer_context)
        return b''.join(response_elements)

    def handle_stream_stream(self, request_iterator, servicer_context):
        """Yield each request back, checkpointing before each one."""
        self._control.control()
        self._set_trailing_metadata(servicer_context)
        for request in request_iterator:
            self._control.control()
            yield request
        self._control.control()

    def handle_stream_stream_non_blocking(self, request_iterator,
                                          servicer_context, on_next):
        """Push each request back through ``on_next``, then ``on_next(None)``."""
        self._control.control()
        self._set_trailing_metadata(servicer_context)
        for request in request_iterator:
            self._control.control()
            on_next(request)
        self._control.control()
        on_next(None)
152
153
class _MethodHandler(grpc.RpcMethodHandler):
    """Minimal grpc.RpcMethodHandler that simply stores its eight fields.

    In this file it is built by _GenericHandler.service with exactly one of
    the four behavior slots set and the other three left as None.
    """

    def __init__(self, request_streaming, response_streaming,
                 request_deserializer, response_serializer, unary_unary,
                 unary_stream, stream_unary, stream_stream):
        # Whether each side of the RPC is a stream.
        self.request_streaming, self.response_streaming = (
            request_streaming, response_streaming)
        # Message codec hooks for this method.
        self.request_deserializer, self.response_serializer = (
            request_deserializer, response_serializer)
        # Behavior slots, one per RPC cardinality.
        (self.unary_unary, self.unary_stream, self.stream_unary,
         self.stream_stream) = (unary_unary, unary_stream, stream_unary,
                                stream_stream)
167
168
class _GenericHandler(grpc.GenericRpcHandler):
    """Routes the six test method paths to behaviors of a wrapped _Handler."""

    def __init__(self, handler):
        self._handler = handler

    def service(self, handler_call_details):
        """Return a _MethodHandler for a recognized method path, else None."""
        # path -> (request_streaming, response_streaming, request_deserializer,
        #          response_serializer, behavior slot, _Handler attribute)
        routes = {
            _UNARY_UNARY: (False, False, None, None, 'unary_unary',
                           'handle_unary_unary'),
            _UNARY_STREAM: (False, True, _DESERIALIZE_REQUEST,
                            _SERIALIZE_RESPONSE, 'unary_stream',
                            'handle_unary_stream'),
            _UNARY_STREAM_NON_BLOCKING: (False, True, _DESERIALIZE_REQUEST,
                                         _SERIALIZE_RESPONSE, 'unary_stream',
                                         'handle_unary_stream_non_blocking'),
            _STREAM_UNARY: (True, False, _DESERIALIZE_REQUEST,
                            _SERIALIZE_RESPONSE, 'stream_unary',
                            'handle_stream_unary'),
            _STREAM_STREAM: (True, True, None, None, 'stream_stream',
                             'handle_stream_stream'),
            _STREAM_STREAM_NON_BLOCKING: (True, True, None, None,
                                          'stream_stream',
                                          'handle_stream_stream_non_blocking'),
        }
        route = routes.get(handler_call_details.method)
        if route is None:
            return None
        (request_streaming, response_streaming, request_deserializer,
         response_serializer, slot, attribute) = route
        # Every behavior slot except the routed one stays None.
        behaviors = {
            'unary_unary': None,
            'unary_stream': None,
            'stream_unary': None,
            'stream_stream': None,
        }
        behaviors[slot] = getattr(self._handler, attribute)
        return _MethodHandler(request_streaming, response_streaming,
                              request_deserializer, response_serializer,
                              **behaviors)
200
201
def unary_unary_multi_callable(channel):
    """Return a multi-callable for the unary-unary test method.

    No request serializer or response deserializer is configured.
    """
    multi_callable = channel.unary_unary(_UNARY_UNARY)
    return multi_callable
204
205
def unary_stream_multi_callable(channel):
    """Return a multi-callable for the blocking unary-stream test method.

    Requests are encoded with _SERIALIZE_REQUEST and responses decoded with
    _DESERIALIZE_RESPONSE.
    """
    codec = {
        'request_serializer': _SERIALIZE_REQUEST,
        'response_deserializer': _DESERIALIZE_RESPONSE,
    }
    return channel.unary_stream(_UNARY_STREAM, **codec)
210
211
def unary_stream_non_blocking_multi_callable(channel):
    """Return a multi-callable for the non-blocking unary-stream test method.

    Requests are encoded with _SERIALIZE_REQUEST and responses decoded with
    _DESERIALIZE_RESPONSE.
    """
    codec = {
        'request_serializer': _SERIALIZE_REQUEST,
        'response_deserializer': _DESERIALIZE_RESPONSE,
    }
    return channel.unary_stream(_UNARY_STREAM_NON_BLOCKING, **codec)
216
217
def stream_unary_multi_callable(channel):
    """Return a multi-callable for the stream-unary test method.

    Requests are encoded with _SERIALIZE_REQUEST and the response decoded
    with _DESERIALIZE_RESPONSE.
    """
    codec = {
        'request_serializer': _SERIALIZE_REQUEST,
        'response_deserializer': _DESERIALIZE_RESPONSE,
    }
    return channel.stream_unary(_STREAM_UNARY, **codec)
222
223
def stream_stream_multi_callable(channel):
    """Return a multi-callable for the blocking stream-stream test method.

    No request serializer or response deserializer is configured.
    """
    multi_callable = channel.stream_stream(_STREAM_STREAM)
    return multi_callable
226
227
def stream_stream_non_blocking_multi_callable(channel):
    """Return a multi-callable for the non-blocking stream-stream test method.

    No request serializer or response deserializer is configured.
    """
    multi_callable = channel.stream_stream(_STREAM_STREAM_NON_BLOCKING)
    return multi_callable
230
231
class BaseRPCTest(object):
    """Shared fixture and RPC scenarios for the RPC invocation tests.

    This class calls ``assertRaises``, ``assertIs`` and ``assertIsNotNone``
    without defining them, so it is presumably mixed into a
    ``unittest.TestCase`` subclass. ``setUp`` serves this module's handlers on
    an ephemeral port and opens an insecure channel to that server. Each
    ``_``-prefixed scenario method drives one RPC pattern through the
    multi-callable it is given.
    """

    def setUp(self):
        self._control = test_control.PauseFailControl()
        self._thread_pool = thread_pool.RecordingThreadPool(max_workers=None)
        self._handler = _Handler(self._control, self._thread_pool)

        self._server = test_common.test_server()
        # Bind to port 0 and use the port the server reports back.
        port = self._server.add_insecure_port('[::]:0')
        self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),))
        self._server.start()

        self._channel = grpc.insecure_channel('localhost:%d' % port)

    def tearDown(self):
        """Stop the server (grace ``None``) and close the client channel.

        The channel is closed in a ``finally`` so it is not leaked when
        stopping the server raises.
        """
        try:
            self._server.stop(None)
        finally:
            self._channel.close()

    @staticmethod
    def _request_stream(payload):
        """Return an iterator over STREAM_LENGTH copies of ``payload``."""
        return iter((payload,) * test_constants.STREAM_LENGTH)

    def _consume_one_stream_response_unary_request(self, multi_callable):
        """Start a unary-request RPC and read only its first response."""
        request = b'\x57\x38'

        response_iterator = multi_callable(
            request,
            metadata=(('test', 'ConsumingOneStreamResponseUnaryRequest'),))
        next(response_iterator)

    def _consume_some_but_not_all_stream_responses_unary_request(
            self, multi_callable):
        """Read half of the response stream of a unary-request RPC."""
        request = b'\x57\x38'

        response_iterator = multi_callable(
            request,
            metadata=(('test',
                       'ConsumingSomeButNotAllStreamResponsesUnaryRequest'),))
        for _ in range(test_constants.STREAM_LENGTH // 2):
            next(response_iterator)

    def _consume_some_but_not_all_stream_responses_stream_request(
            self, multi_callable):
        """Read half of the response stream of a stream-request RPC."""
        request_iterator = self._request_stream(b'\x67\x88')

        response_iterator = multi_callable(
            request_iterator,
            metadata=(('test',
                       'ConsumingSomeButNotAllStreamResponsesStreamRequest'),))
        for _ in range(test_constants.STREAM_LENGTH // 2):
            next(response_iterator)

    def _consume_too_many_stream_responses_stream_request(self, multi_callable):
        """Drain the response stream, then keep reading past its end.

        After STREAM_LENGTH responses every further ``next`` must raise
        StopIteration. The finished call must report status OK and expose
        initial metadata, details and trailing metadata.
        """
        request_iterator = self._request_stream(b'\x67\x88')

        response_iterator = multi_callable(
            request_iterator,
            metadata=(('test',
                       'ConsumingTooManyStreamResponsesStreamRequest'),))
        for _ in range(test_constants.STREAM_LENGTH):
            next(response_iterator)
        for _ in range(test_constants.STREAM_LENGTH):
            with self.assertRaises(StopIteration):
                next(response_iterator)

        self.assertIsNotNone(response_iterator.initial_metadata())
        self.assertIs(grpc.StatusCode.OK, response_iterator.code())
        self.assertIsNotNone(response_iterator.details())
        self.assertIsNotNone(response_iterator.trailing_metadata())

    def _cancelled_unary_request_stream_response(self, multi_callable):
        """Cancel a unary-request RPC once its handler has reached the pause.

        The next read must raise an RpcError with status CANCELLED, and the
        call must still expose metadata and details.
        """
        request = b'\x07\x19'

        with self._control.pause():
            response_iterator = multi_callable(
                request,
                metadata=(('test', 'CancelledUnaryRequestStreamResponse'),))
            self._control.block_until_paused()
            response_iterator.cancel()

        with self.assertRaises(grpc.RpcError) as exception_context:
            next(response_iterator)
        self.assertIs(grpc.StatusCode.CANCELLED,
                      exception_context.exception.code())
        self.assertIsNotNone(response_iterator.initial_metadata())
        self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code())
        self.assertIsNotNone(response_iterator.details())
        self.assertIsNotNone(response_iterator.trailing_metadata())

    def _cancelled_stream_request_stream_response(self, multi_callable):
        """Cancel a stream-request RPC while the control is paused.

        Unlike the unary variant, this does not wait for the handler to reach
        the pause before cancelling.
        """
        request_iterator = self._request_stream(b'\x07\x08')

        with self._control.pause():
            response_iterator = multi_callable(
                request_iterator,
                metadata=(('test', 'CancelledStreamRequestStreamResponse'),))
            response_iterator.cancel()

        with self.assertRaises(grpc.RpcError):
            next(response_iterator)
        self.assertIsNotNone(response_iterator.initial_metadata())
        self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code())
        self.assertIsNotNone(response_iterator.details())
        self.assertIsNotNone(response_iterator.trailing_metadata())

    def _expired_unary_request_stream_response(self, multi_callable):
        """Pause the handler past a SHORT_TIMEOUT deadline on a unary request.

        The read must fail with DEADLINE_EXCEEDED.
        """
        request = b'\x07\x19'

        with self._control.pause():
            with self.assertRaises(grpc.RpcError) as exception_context:
                response_iterator = multi_callable(
                    request,
                    timeout=test_constants.SHORT_TIMEOUT,
                    metadata=(('test', 'ExpiredUnaryRequestStreamResponse'),))
                next(response_iterator)

        self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED,
                      exception_context.exception.code())
        self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED,
                      response_iterator.code())

    def _expired_stream_request_stream_response(self, multi_callable):
        """Pause the handler past a SHORT_TIMEOUT deadline on a stream request.

        The read must fail with DEADLINE_EXCEEDED.
        """
        request_iterator = self._request_stream(b'\x67\x18')

        with self._control.pause():
            with self.assertRaises(grpc.RpcError) as exception_context:
                response_iterator = multi_callable(
                    request_iterator,
                    timeout=test_constants.SHORT_TIMEOUT,
                    metadata=(('test', 'ExpiredStreamRequestStreamResponse'),))
                next(response_iterator)

        self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED,
                      exception_context.exception.code())
        self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED,
                      response_iterator.code())

    def _failed_unary_request_stream_response(self, multi_callable):
        """Run a unary-request RPC with the control in fail mode.

        Expects an RpcError with status UNKNOWN.
        """
        request = b'\x37\x17'

        with self.assertRaises(grpc.RpcError) as exception_context:
            with self._control.fail():
                response_iterator = multi_callable(
                    request,
                    metadata=(('test', 'FailedUnaryRequestStreamResponse'),))
                next(response_iterator)

        self.assertIs(grpc.StatusCode.UNKNOWN,
                      exception_context.exception.code())

    def _failed_stream_request_stream_response(self, multi_callable):
        """Drain a stream-request RPC with the control in fail mode.

        Expects an RpcError with status UNKNOWN.
        """
        request_iterator = self._request_stream(b'\x67\x88')

        with self._control.fail():
            with self.assertRaises(grpc.RpcError) as exception_context:
                response_iterator = multi_callable(
                    request_iterator,
                    metadata=(('test', 'FailedStreamRequestStreamResponse'),))
                tuple(response_iterator)

        self.assertIs(grpc.StatusCode.UNKNOWN,
                      exception_context.exception.code())
        self.assertIs(grpc.StatusCode.UNKNOWN, response_iterator.code())

    def _ignored_unary_stream_request_future_unary_response(
            self, multi_callable):
        """Start a unary-request RPC and discard the result unread."""
        request = b'\x37\x17'

        multi_callable(request,
                       metadata=(('test',
                                  'IgnoredUnaryRequestStreamResponse'),))

    def _ignored_stream_request_stream_response(self, multi_callable):
        """Start a stream-request RPC and discard the result unread."""
        request_iterator = self._request_stream(b'\x67\x88')

        multi_callable(request_iterator,
                       metadata=(('test',
                                  'IgnoredStreamRequestStreamResponse'),))
418