# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test helpers for RPC invocation tests."""

import datetime
import threading

import grpc
from grpc.framework.foundation import logging_pool

from tests.unit import test_common
from tests.unit import thread_pool
from tests.unit.framework.common import test_constants
from tests.unit.framework.common import test_control

# Asymmetric test codecs. A request is doubled when serialized and its
# deserializer keeps the second half; a response is tripled when serialized and
# its deserializer keeps the first third.
_SERIALIZE_REQUEST = lambda bytestring: bytestring * 2
_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2 :]
_SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3
_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[: len(bytestring) // 3]

# Service name and the bare (unqualified) names of its six test methods.
_SERVICE_NAME = "test"
_UNARY_UNARY = "UnaryUnary"
_UNARY_STREAM = "UnaryStream"
_UNARY_STREAM_NON_BLOCKING = "UnaryStreamNonBlocking"
_STREAM_UNARY = "StreamUnary"
_STREAM_STREAM = "StreamStream"
_STREAM_STREAM_NON_BLOCKING = "StreamStreamNonBlocking"

# Four seconds, expressed as a float number of seconds.
TIMEOUT_SHORT = datetime.timedelta(seconds=4).total_seconds()


class Callback(object):
    """Callable that hands its argument over to a thread waiting in `value`.

    Calling the instance stores the argument and marks the callback as called;
    `value` blocks until that has happened. A later call overwrites the stored
    argument.
    """

    def __init__(self):
        # The condition guards both fields below.
        self._condition = threading.Condition()
        self._value = None
        self._called = False

    def __call__(self, value):
        """Store `value` and wake every thread currently blocked in `value`.

        Args:
          value: the object to hand over; any value, including None.
        """
        with self._condition:
            self._value = value
            self._called = True
            self._condition.notify_all()

    def value(self):
        """Block until the callback has been called, then return its argument.

        There is no timeout: if the instance is never called, this never
        returns.

        Returns:
          The argument most recently passed to the instance.
        """
        with self._condition:
            # wait_for re-checks the flag after every wake-up, so spurious
            # wake-ups are handled the same way the explicit loop would.
            self._condition.wait_for(lambda: self._called)
            return self._value


class _Handler(object):
    """Server-side behaviours for the six test RPC methods.

    Each behaviour calls `control.control()` at fixed points (before, between
    and after messages) and, when a servicer context is supplied, attaches the
    `("testkey", "testvalue")` trailing metadata pair. Responses echo the
    requests; the stream-unary behaviour returns the concatenated requests.
    """

    _TRAILING_METADATA = (("testkey", "testvalue"),)

    def __init__(self, control, thread_pool):
        """Store collaborators and tag the non-blocking behaviours.

        Args:
          control: object whose `control()` method is invoked at each
            checkpoint of every behaviour.
          thread_pool: stored on the instance and published on the
            non-blocking behaviours as `experimental_thread_pool`.
        """
        self._control = control
        self._thread_pool = thread_pool
        # NOTE: `__func__` is the plain function stored on the class, so these
        # attributes are shared by every _Handler instance; the most recently
        # constructed instance's thread pool is the one left in place.
        # Presumably the gRPC server reads both attributes to pick the
        # callback-style calling convention - confirm against grpc._server.
        for non_blocking_function in (
            self.handle_unary_stream_non_blocking,
            self.handle_stream_stream_non_blocking,
        ):
            behavior = non_blocking_function.__func__
            behavior.experimental_non_blocking = True
            behavior.experimental_thread_pool = self._thread_pool

    def _set_test_trailing_metadata(self, servicer_context):
        """Attach the test trailing metadata unless the context is None."""
        if servicer_context is not None:
            servicer_context.set_trailing_metadata(self._TRAILING_METADATA)

    def handle_unary_unary(self, request, servicer_context):
        """Echo `request` after one control checkpoint.

        Returns:
          The request object, unchanged.
        """
        self._control.control()
        if servicer_context is not None:
            self._set_test_trailing_metadata(servicer_context)
            # TODO(https://github.com/grpc/grpc/issues/8483): test the values
            # returned by these methods rather than only "smoke" testing that
            # they return after having been called.
            servicer_context.is_active()
            servicer_context.time_remaining()
        return request

    def handle_unary_stream(self, request, servicer_context):
        """Yield `request` STREAM_LENGTH times, one checkpoint per message.

        A final checkpoint and the trailing metadata follow the last message.
        """
        for _ in range(test_constants.STREAM_LENGTH):
            self._control.control()
            yield request
        self._control.control()
        self._set_test_trailing_metadata(servicer_context)

    def handle_unary_stream_non_blocking(
        self, request, servicer_context, on_next
    ):
        """Callback-style variant of `handle_unary_stream`.

        Args:
          request: the single request message.
          servicer_context: context for the RPC, or None.
          on_next: called once per response message and finally with None
            (presumably the end-of-stream marker - confirm against the
            server's non-blocking contract).
        """
        for _ in range(test_constants.STREAM_LENGTH):
            self._control.control()
            on_next(request)
        self._control.control()
        self._set_test_trailing_metadata(servicer_context)
        on_next(None)

    def handle_stream_unary(self, request_iterator, servicer_context):
        """Consume every request and return them joined into one bytestring.

        The invocation metadata is read (and discarded) first when a context
        is supplied; a checkpoint runs before, per request, and after the
        request stream.

        Returns:
          `b"".join(...)` of the requests, in arrival order.
        """
        if servicer_context is not None:
            servicer_context.invocation_metadata()
        self._control.control()
        response_elements = []
        for request in request_iterator:
            self._control.control()
            response_elements.append(request)
        self._control.control()
        self._set_test_trailing_metadata(servicer_context)
        return b"".join(response_elements)

    def handle_stream_stream(self, request_iterator, servicer_context):
        """Yield each request back as a response.

        Trailing metadata is set up front, after the first checkpoint; a
        checkpoint also runs before each response and after the last one.
        """
        self._control.control()
        self._set_test_trailing_metadata(servicer_context)
        for request in request_iterator:
            self._control.control()
            yield request
        self._control.control()

    def handle_stream_stream_non_blocking(
        self, request_iterator, servicer_context, on_next
    ):
        """Callback-style variant of `handle_stream_stream`.

        Args:
          request_iterator: iterator over the request messages.
          servicer_context: context for the RPC, or None.
          on_next: called once per echoed request and finally with None
            (presumably the end-of-stream marker - confirm against the
            server's non-blocking contract).
        """
        self._control.control()
        self._set_test_trailing_metadata(servicer_context)
        for request in request_iterator:
            self._control.control()
            on_next(request)
        self._control.control()
        on_next(None)


class _MethodHandler(grpc.RpcMethodHandler):
    """Plain value holder implementing `grpc.RpcMethodHandler`.

    The constructor copies its eight arguments onto same-named attributes
    without validation.
    """

    def __init__(
        self,
        request_streaming,
        response_streaming,
        request_deserializer,
        response_serializer,
        unary_unary,
        unary_stream,
        stream_unary,
        stream_stream,
    ):
        """Store the handler description.

        Args:
          request_streaming: whether the RPC takes a stream of requests.
          response_streaming: whether the RPC returns a stream of responses.
          request_deserializer: callable applied to incoming request bytes, or
            None.
          response_serializer: callable applied to outgoing responses, or None.
          unary_unary: behaviour for a unary-unary RPC, or None.
          unary_stream: behaviour for a unary-stream RPC, or None.
          stream_unary: behaviour for a stream-unary RPC, or None.
          stream_stream: behaviour for a stream-stream RPC, or None.
        """
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = request_deserializer
        self.response_serializer = response_serializer
        self.unary_unary = unary_unary
        self.unary_stream = unary_stream
        self.stream_unary = stream_unary
        self.stream_stream = stream_stream


class _GenericHandler(grpc.GenericRpcHandler):
    """Generic handler that routes the test service's methods to a `_Handler`."""

    def __init__(self, handler):
        """
        Args:
          handler: object providing the six `handle_*` behaviours.
        """
        self._handler = handler

    def service(self, handler_call_details):
        """Return the method handler for the called RPC, or None if unknown.

        gRPC reports the method as a full path ("/<service>/<method>"), while
        the method-name constants in this module are bare names, so the
        "/test/" prefix is stripped before the lookup. A bare method name is
        still accepted as before.

        Args:
          handler_call_details: object whose `method` attribute names the RPC.

        Returns:
          A `_MethodHandler` for a known method, otherwise None.
        """
        method = handler_call_details.method
        service_prefix = "/{}/".format(_SERVICE_NAME)
        if isinstance(method, str) and method.startswith(service_prefix):
            method = method[len(service_prefix) :]
        # Use the same table the registered-handler path uses, so the two
        # routes cannot drift apart.
        return get_method_handlers(self._handler).get(method)


def get_method_handlers(handler):
    """Build the method-name to method-handler table for the test service.

    The unary-unary and both stream-stream methods use no (de)serializer; the
    unary-stream (blocking and non-blocking) and stream-unary methods use the
    module's request deserializer and response serializer.

    Args:
      handler: object providing the six `handle_*` behaviours.

    Returns:
      A new dict mapping each bare method name to a fresh `_MethodHandler`.
    """
    return {
        _UNARY_UNARY: _MethodHandler(
            request_streaming=False,
            response_streaming=False,
            request_deserializer=None,
            response_serializer=None,
            unary_unary=handler.handle_unary_unary,
            unary_stream=None,
            stream_unary=None,
            stream_stream=None,
        ),
        _UNARY_STREAM: _MethodHandler(
            request_streaming=False,
            response_streaming=True,
            request_deserializer=_DESERIALIZE_REQUEST,
            response_serializer=_SERIALIZE_RESPONSE,
            unary_unary=None,
            unary_stream=handler.handle_unary_stream,
            stream_unary=None,
            stream_stream=None,
        ),
        _UNARY_STREAM_NON_BLOCKING: _MethodHandler(
            request_streaming=False,
            response_streaming=True,
            request_deserializer=_DESERIALIZE_REQUEST,
            response_serializer=_SERIALIZE_RESPONSE,
            unary_unary=None,
            unary_stream=handler.handle_unary_stream_non_blocking,
            stream_unary=None,
            stream_stream=None,
        ),
        _STREAM_UNARY: _MethodHandler(
            request_streaming=True,
            response_streaming=False,
            request_deserializer=_DESERIALIZE_REQUEST,
            response_serializer=_SERIALIZE_RESPONSE,
            unary_unary=None,
            unary_stream=None,
            stream_unary=handler.handle_stream_unary,
            stream_stream=None,
        ),
        _STREAM_STREAM: _MethodHandler(
            request_streaming=True,
            response_streaming=True,
            request_deserializer=None,
            response_serializer=None,
            unary_unary=None,
            unary_stream=None,
            stream_unary=None,
            stream_stream=handler.handle_stream_stream,
        ),
        _STREAM_STREAM_NON_BLOCKING: _MethodHandler(
            request_streaming=True,
            response_streaming=True,
            request_deserializer=None,
            response_serializer=None,
            unary_unary=None,
            unary_stream=None,
            stream_unary=None,
            stream_stream=handler.handle_stream_stream_non_blocking,
        ),
    }


def unary_unary_multi_callable(channel):
    """Return the unary-unary multi-callable for the test service.

    No request serializer or response deserializer is supplied.

    Args:
      channel: the channel on which to create the multi-callable.
    """
    method = grpc._common.fully_qualified_method(_SERVICE_NAME, _UNARY_UNARY)
    return channel.unary_unary(method, _registered_method=True)


def unary_stream_multi_callable(channel):
    """Return the unary-stream multi-callable for the test service.

    Requests go through `_SERIALIZE_REQUEST` and responses through
    `_DESERIALIZE_RESPONSE`.

    Args:
      channel: the channel on which to create the multi-callable.
    """
    method = grpc._common.fully_qualified_method(_SERVICE_NAME, _UNARY_STREAM)
    return channel.unary_stream(
        method,
        request_serializer=_SERIALIZE_REQUEST,
        response_deserializer=_DESERIALIZE_RESPONSE,
        _registered_method=True,
    )


def unary_stream_non_blocking_multi_callable(channel):
    """Return the multi-callable for the non-blocking unary-stream method.

    Requests go through `_SERIALIZE_REQUEST` and responses through
    `_DESERIALIZE_RESPONSE`.

    Args:
      channel: the channel on which to create the multi-callable.
    """
    method = grpc._common.fully_qualified_method(
        _SERVICE_NAME, _UNARY_STREAM_NON_BLOCKING
    )
    return channel.unary_stream(
        method,
        request_serializer=_SERIALIZE_REQUEST,
        response_deserializer=_DESERIALIZE_RESPONSE,
        _registered_method=True,
    )


def stream_unary_multi_callable(channel):
    """Return the stream-unary multi-callable for the test service.

    Requests go through `_SERIALIZE_REQUEST` and the response through
    `_DESERIALIZE_RESPONSE`.

    Args:
      channel: the channel on which to create the multi-callable.
    """
    method = grpc._common.fully_qualified_method(_SERVICE_NAME, _STREAM_UNARY)
    return channel.stream_unary(
        method,
        request_serializer=_SERIALIZE_REQUEST,
        response_deserializer=_DESERIALIZE_RESPONSE,
        _registered_method=True,
    )


def stream_stream_multi_callable(channel):
    """Return the stream-stream multi-callable for the test service.

    No request serializer or response deserializer is supplied.

    Args:
      channel: the channel on which to create the multi-callable.
    """
    method = grpc._common.fully_qualified_method(_SERVICE_NAME, _STREAM_STREAM)
    return channel.stream_stream(method, _registered_method=True)


def stream_stream_non_blocking_multi_callable(channel):
    """Return the multi-callable for the non-blocking stream-stream method.

    No request serializer or response deserializer is supplied.

    Args:
      channel: the channel on which to create the multi-callable.
    """
    method = grpc._common.fully_qualified_method(
        _SERVICE_NAME, _STREAM_STREAM_NON_BLOCKING
    )
    return channel.stream_stream(method, _registered_method=True)


class BaseRPCTest(object):
    """Server/channel fixture plus shared scenario bodies for RPC tests.

    The scenario methods call `assertRaises`, `assertIs` and
    `assertIsNotNone`, which this class does not define; it is meant to be
    combined with a class that provides them (e.g. `unittest.TestCase`). Each
    scenario receives the multi-callable to exercise, so the same body can be
    run against different invocation styles.
    """

    def setUp(self):
        """Start a server with the test handlers and open a channel to it."""
        self._control = test_control.PauseFailControl()
        self._thread_pool = thread_pool.RecordingThreadPool(max_workers=None)
        self._handler = _Handler(self._control, self._thread_pool)

        self._server = test_common.test_server()
        # Port 0: the actual port is whatever add_insecure_port returns.
        port = self._server.add_insecure_port("[::]:0")
        self._server.add_registered_method_handlers(
            _SERVICE_NAME, get_method_handlers(self._handler)
        )
        self._server.start()

        self._channel = grpc.insecure_channel("localhost:%d" % port)

    def tearDown(self):
        """Stop the server with a grace of None, then close the channel."""
        self._server.stop(None)
        self._channel.close()

    def _consume_one_stream_response_unary_request(self, multi_callable):
        """Send one request and read only the first streamed response."""
        response_iterator = multi_callable(
            b"\x57\x38",
            metadata=(("test", "ConsumingOneStreamResponseUnaryRequest"),),
        )
        next(response_iterator)

    def _consume_some_but_not_all_stream_responses_unary_request(
        self, multi_callable
    ):
        """Send one request and read half of STREAM_LENGTH responses."""
        response_iterator = multi_callable(
            b"\x57\x38",
            metadata=(
                ("test", "ConsumingSomeButNotAllStreamResponsesUnaryRequest"),
            ),
        )
        for _ in range(test_constants.STREAM_LENGTH // 2):
            next(response_iterator)

    def _consume_some_but_not_all_stream_responses_stream_request(
        self, multi_callable
    ):
        """Stream STREAM_LENGTH requests and read half that many responses."""
        requests = (b"\x67\x88",) * test_constants.STREAM_LENGTH

        response_iterator = multi_callable(
            iter(requests),
            metadata=(
                ("test", "ConsumingSomeButNotAllStreamResponsesStreamRequest"),
            ),
        )
        for _ in range(test_constants.STREAM_LENGTH // 2):
            next(response_iterator)

    def _consume_too_many_stream_responses_stream_request(self, multi_callable):
        """Read every response, then check that further reads stop cleanly.

        After STREAM_LENGTH successful reads, STREAM_LENGTH more reads must
        each raise StopIteration, and the call must report status OK with
        non-None metadata and details.
        """
        requests = (b"\x67\x88",) * test_constants.STREAM_LENGTH

        response_iterator = multi_callable(
            iter(requests),
            metadata=(
                ("test", "ConsumingTooManyStreamResponsesStreamRequest"),
            ),
        )
        for _ in range(test_constants.STREAM_LENGTH):
            next(response_iterator)
        # An exhausted iterator must keep raising StopIteration.
        for _ in range(test_constants.STREAM_LENGTH):
            with self.assertRaises(StopIteration):
                next(response_iterator)

        self.assertIsNotNone(response_iterator.initial_metadata())
        self.assertIs(grpc.StatusCode.OK, response_iterator.code())
        self.assertIsNotNone(response_iterator.details())
        self.assertIsNotNone(response_iterator.trailing_metadata())

    def _cancelled_unary_request_stream_response(self, multi_callable):
        """Cancel a unary-request call once the control reports it paused.

        The next read must raise `grpc.RpcError` with code CANCELLED, and the
        call object must report CANCELLED with non-None metadata and details.
        """
        with self._control.pause():
            response_iterator = multi_callable(
                b"\x07\x19",
                metadata=(("test", "CancelledUnaryRequestStreamResponse"),),
            )
            # Wait for the pause to take effect before cancelling (presumably
            # this means the handler has reached a control() checkpoint -
            # confirm against test_control).
            self._control.block_until_paused()
            response_iterator.cancel()

        with self.assertRaises(grpc.RpcError) as exception_context:
            next(response_iterator)
        self.assertIs(
            grpc.StatusCode.CANCELLED, exception_context.exception.code()
        )
        self.assertIsNotNone(response_iterator.initial_metadata())
        self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code())
        self.assertIsNotNone(response_iterator.details())
        self.assertIsNotNone(response_iterator.trailing_metadata())

    def _cancelled_stream_request_stream_response(self, multi_callable):
        """Cancel a stream-request call right after starting it.

        Unlike the unary-request variant, this does not wait for the control
        to report paused before cancelling. The next read must raise
        `grpc.RpcError` and the call must report CANCELLED.
        """
        requests = (b"\x07\x08",) * test_constants.STREAM_LENGTH

        with self._control.pause():
            response_iterator = multi_callable(
                iter(requests),
                metadata=(("test", "CancelledStreamRequestStreamResponse"),),
            )
            response_iterator.cancel()

        with self.assertRaises(grpc.RpcError):
            next(response_iterator)
        self.assertIsNotNone(response_iterator.initial_metadata())
        self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code())
        self.assertIsNotNone(response_iterator.details())
        self.assertIsNotNone(response_iterator.trailing_metadata())

    def _expired_unary_request_stream_response(self, multi_callable):
        """Let a unary-request call hit its deadline while the control pauses.

        With a SHORT_TIMEOUT deadline, the first read must raise
        `grpc.RpcError` with DEADLINE_EXCEEDED, and the call object must
        report the same code.
        """
        with self._control.pause():
            with self.assertRaises(grpc.RpcError) as exception_context:
                response_iterator = multi_callable(
                    b"\x07\x19",
                    timeout=test_constants.SHORT_TIMEOUT,
                    metadata=(("test", "ExpiredUnaryRequestStreamResponse"),),
                )
                next(response_iterator)

        self.assertIs(
            grpc.StatusCode.DEADLINE_EXCEEDED,
            exception_context.exception.code(),
        )
        self.assertIs(
            grpc.StatusCode.DEADLINE_EXCEEDED, response_iterator.code()
        )

    def _expired_stream_request_stream_response(self, multi_callable):
        """Let a stream-request call hit its deadline while the control pauses.

        With a SHORT_TIMEOUT deadline, the first read must raise
        `grpc.RpcError` with DEADLINE_EXCEEDED, and the call object must
        report the same code.
        """
        requests = (b"\x67\x18",) * test_constants.STREAM_LENGTH

        with self._control.pause():
            with self.assertRaises(grpc.RpcError) as exception_context:
                response_iterator = multi_callable(
                    iter(requests),
                    timeout=test_constants.SHORT_TIMEOUT,
                    metadata=(("test", "ExpiredStreamRequestStreamResponse"),),
                )
                next(response_iterator)

        self.assertIs(
            grpc.StatusCode.DEADLINE_EXCEEDED,
            exception_context.exception.code(),
        )
        self.assertIs(
            grpc.StatusCode.DEADLINE_EXCEEDED, response_iterator.code()
        )

    def _failed_unary_request_stream_response(self, multi_callable):
        """Run a unary-request call inside `control.fail()`.

        The first read must raise `grpc.RpcError` with code UNKNOWN.
        """
        with self.assertRaises(grpc.RpcError) as exception_context:
            with self._control.fail():
                response_iterator = multi_callable(
                    b"\x37\x17",
                    metadata=(("test", "FailedUnaryRequestStreamResponse"),),
                )
                next(response_iterator)

        self.assertIs(
            grpc.StatusCode.UNKNOWN, exception_context.exception.code()
        )

    def _failed_stream_request_stream_response(self, multi_callable):
        """Run a stream-request call inside `control.fail()`.

        Draining the responses must raise `grpc.RpcError` with code UNKNOWN,
        and the call object must report the same code.
        """
        requests = (b"\x67\x88",) * test_constants.STREAM_LENGTH

        with self._control.fail():
            with self.assertRaises(grpc.RpcError) as exception_context:
                response_iterator = multi_callable(
                    iter(requests),
                    metadata=(("test", "FailedStreamRequestStreamResponse"),),
                )
                tuple(response_iterator)

        self.assertIs(
            grpc.StatusCode.UNKNOWN, exception_context.exception.code()
        )
        self.assertIs(grpc.StatusCode.UNKNOWN, response_iterator.code())

    def _ignored_unary_stream_request_future_unary_response(
        self, multi_callable
    ):
        """Start a unary-request call and discard the returned call object."""
        multi_callable(
            b"\x37\x17",
            metadata=(("test", "IgnoredUnaryRequestStreamResponse"),),
        )

    def _ignored_stream_request_stream_response(self, multi_callable):
        """Start a stream-request call and discard the returned call object."""
        requests = (b"\x67\x88",) * test_constants.STREAM_LENGTH

        multi_callable(
            iter(requests),
            metadata=(("test", "IgnoredStreamRequestStreamResponse"),),
        )