• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Test of gRPC Python interceptors."""
15
16import collections
17from concurrent import futures
18from contextvars import ContextVar
19import itertools
20import logging
21import os
22import threading
23import unittest
24
25import grpc
26from grpc.framework.foundation import logging_pool
27
28from tests.unit import test_common
29from tests.unit.framework.common import test_constants
30from tests.unit.framework.common import test_control
31
def _SERIALIZE_REQUEST(bytestring):
    """Client-side request serializer used by the test: payload repeated twice."""
    return bytestring * 2


def _DESERIALIZE_REQUEST(bytestring):
    """Server-side inverse of ``_SERIALIZE_REQUEST``: keep the second half."""
    half = len(bytestring) // 2
    return bytestring[half:]


def _SERIALIZE_RESPONSE(bytestring):
    """Server-side response serializer used by the test: payload repeated thrice."""
    return bytestring * 3


def _DESERIALIZE_RESPONSE(bytestring):
    """Client-side inverse of ``_SERIALIZE_RESPONSE``: keep the first third."""
    third = len(bytestring) // 3
    return bytestring[:third]


# Any request equal to this payload makes the test handlers raise.
_EXCEPTION_REQUEST = b"\x09\x0a"

# Fully-qualified method paths served by ``_GenericHandler``.
_UNARY_UNARY = "/test/UnaryUnary"
_UNARY_STREAM = "/test/UnaryStream"
_STREAM_UNARY = "/test/StreamUnary"
_STREAM_STREAM = "/test/StreamStream"

# Set by a server interceptor; loggers append "[<value>]" when it is non-empty.
_TEST_CONTEXT_VAR: ContextVar[str] = ContextVar("")
45
46
class _ApplicationErrorStandin(Exception):
    """Raised by the test handlers to simulate an application-level failure."""
49
50
class _Callback(object):
    """Thread-safe one-shot value holder.

    Calling the instance stores a value and wakes waiters; ``value()`` blocks
    until the instance has been called at least once, then returns the most
    recently stored value.
    """

    def __init__(self):
        self._condition = threading.Condition()
        self._value = None
        self._called = False

    def __call__(self, value):
        with self._condition:
            self._value = value
            self._called = True
            self._condition.notify_all()

    def value(self):
        """Wait (without timeout) for the first call, then return the stored value."""
        with self._condition:
            self._condition.wait_for(lambda: self._called)
            return self._value
68
69
class _Handler(object):
    """Servicer implementations for the four RPC shapes exercised by the tests.

    Every handler first appends ``"handler:<method name>"`` to ``record``
    (suffixed with ``"[<value>]"`` when ``_TEST_CONTEXT_VAR`` holds a
    non-empty value), calls ``control.control()`` at fixed points, and echoes
    the request data back. When a servicer context is supplied it sets
    ``(("testkey", "testvalue"),)`` as trailing metadata. A request equal to
    ``_EXCEPTION_REQUEST`` makes the handler raise ``_ApplicationErrorStandin``.
    """

    _TRAILING_METADATA = (("testkey", "testvalue"),)

    def __init__(self, control, record):
        self._control = control
        self._record = record

    def _append_to_log(self, message: str) -> None:
        raw = _TEST_CONTEXT_VAR.get("")
        suffix = "[{}]".format(raw) if raw else raw
        self._record.append("handler:" + message + suffix)

    def _set_trailing_metadata(self, servicer_context) -> None:
        # Handlers may be invoked without a context; skip metadata then.
        if servicer_context is not None:
            servicer_context.set_trailing_metadata(self._TRAILING_METADATA)

    def handle_unary_unary(self, request, servicer_context):
        self._append_to_log("handle_unary_unary")
        self._control.control()
        self._set_trailing_metadata(servicer_context)
        if request == _EXCEPTION_REQUEST:
            raise _ApplicationErrorStandin()
        return request

    def handle_unary_stream(self, request, servicer_context):
        # Generator: nothing below runs until the first response is pulled.
        self._append_to_log("handle_unary_stream")
        if request == _EXCEPTION_REQUEST:
            raise _ApplicationErrorStandin()
        for _ in range(test_constants.STREAM_LENGTH):
            self._control.control()
            yield request
        self._control.control()
        self._set_trailing_metadata(servicer_context)

    def handle_stream_unary(self, request_iterator, servicer_context):
        self._append_to_log("handle_stream_unary")
        if servicer_context is not None:
            # Result is discarded; presumably this just exercises the accessor.
            servicer_context.invocation_metadata()
        self._control.control()
        received = []
        for message in request_iterator:
            self._control.control()
            received.append(message)
        self._control.control()
        self._set_trailing_metadata(servicer_context)
        # Checked only after the whole stream has been consumed.
        if _EXCEPTION_REQUEST in received:
            raise _ApplicationErrorStandin()
        return b"".join(received)

    def handle_stream_stream(self, request_iterator, servicer_context):
        # Generator: nothing below runs until the first response is pulled.
        self._append_to_log("handle_stream_stream")
        self._control.control()
        self._set_trailing_metadata(servicer_context)
        for message in request_iterator:
            if message == _EXCEPTION_REQUEST:
                raise _ApplicationErrorStandin()
            self._control.control()
            yield message
        self._control.control()
156
157
class _MethodHandler(grpc.RpcMethodHandler):
    """Plain attribute holder implementing the ``grpc.RpcMethodHandler`` interface.

    In this file exactly one of the four behavior callables is non-``None``.
    ``request_streaming`` / ``response_streaming`` say which one it is.
    """

    def __init__(
        self,
        request_streaming,
        response_streaming,
        request_deserializer,
        response_serializer,
        unary_unary,
        unary_stream,
        stream_unary,
        stream_stream,
    ):
        # RPC shape.
        self.request_streaming, self.response_streaming = (
            request_streaming,
            response_streaming,
        )
        # Message conversion hooks (None when no conversion is configured).
        self.request_deserializer, self.response_serializer = (
            request_deserializer,
            response_serializer,
        )
        # One behavior slot per RPC shape.
        (
            self.unary_unary,
            self.unary_stream,
            self.stream_unary,
            self.stream_stream,
        ) = (unary_unary, unary_stream, stream_unary, stream_stream)
178
179
class _GenericHandler(grpc.GenericRpcHandler):
    """Routes the four test method paths to the matching ``_Handler`` method.

    Unary-stream and stream-unary methods use the test request deserializer
    and response serializer. Unary-unary and stream-stream methods are
    registered with neither.
    """

    def __init__(self, handler):
        self._handler = handler

    def service(self, handler_call_details):
        """Return a ``_MethodHandler`` for a known method path, else ``None``."""
        method = handler_call_details.method
        if method == _UNARY_UNARY:
            return _MethodHandler(
                request_streaming=False,
                response_streaming=False,
                request_deserializer=None,
                response_serializer=None,
                unary_unary=self._handler.handle_unary_unary,
                unary_stream=None,
                stream_unary=None,
                stream_stream=None,
            )
        if method == _UNARY_STREAM:
            return _MethodHandler(
                request_streaming=False,
                response_streaming=True,
                request_deserializer=_DESERIALIZE_REQUEST,
                response_serializer=_SERIALIZE_RESPONSE,
                unary_unary=None,
                unary_stream=self._handler.handle_unary_stream,
                stream_unary=None,
                stream_stream=None,
            )
        if method == _STREAM_UNARY:
            return _MethodHandler(
                request_streaming=True,
                response_streaming=False,
                request_deserializer=_DESERIALIZE_REQUEST,
                response_serializer=_SERIALIZE_RESPONSE,
                unary_unary=None,
                unary_stream=None,
                stream_unary=self._handler.handle_stream_unary,
                stream_stream=None,
            )
        if method == _STREAM_STREAM:
            return _MethodHandler(
                request_streaming=True,
                response_streaming=True,
                request_deserializer=None,
                response_serializer=None,
                unary_unary=None,
                unary_stream=None,
                stream_unary=None,
                stream_stream=self._handler.handle_stream_stream,
            )
        # Unknown method: let other generic handlers (if any) try.
        return None
231
232
def _unary_unary_multi_callable(channel):
    """Return the unary-unary multi-callable for ``_UNARY_UNARY`` on *channel*.

    No request serializer or response deserializer is configured.
    """
    factory = channel.unary_unary
    return factory(_UNARY_UNARY, _registered_method=True)
235
236
def _unary_stream_multi_callable(channel):
    """Return the unary-stream multi-callable for ``_UNARY_STREAM`` on *channel*.

    Requests go through ``_SERIALIZE_REQUEST``; responses through
    ``_DESERIALIZE_RESPONSE``.
    """
    call_options = dict(
        request_serializer=_SERIALIZE_REQUEST,
        response_deserializer=_DESERIALIZE_RESPONSE,
        _registered_method=True,
    )
    return channel.unary_stream(_UNARY_STREAM, **call_options)
244
245
def _stream_unary_multi_callable(channel):
    """Return the stream-unary multi-callable for ``_STREAM_UNARY`` on *channel*.

    Requests go through ``_SERIALIZE_REQUEST``; responses through
    ``_DESERIALIZE_RESPONSE``.
    """
    factory = channel.stream_unary
    return factory(
        _STREAM_UNARY,
        _registered_method=True,
        request_serializer=_SERIALIZE_REQUEST,
        response_deserializer=_DESERIALIZE_RESPONSE,
    )
253
254
def _stream_stream_multi_callable(channel):
    """Return the stream-stream multi-callable for ``_STREAM_STREAM`` on *channel*.

    No request serializer or response deserializer is configured.
    """
    method_path = _STREAM_STREAM
    return channel.stream_stream(method_path, _registered_method=True)
257
258
class _ClientCallDetails(
    collections.namedtuple(
        "_ClientCallDetails", "method timeout metadata credentials"
    ),
    grpc.ClientCallDetails,
):
    """Immutable ``grpc.ClientCallDetails`` implementation backed by a namedtuple.

    Client interceptors in this file build one of these to pass modified call
    details (for example, extra metadata) on to the continuation.
    """
266
267
class _GenericClientInterceptor(
    grpc.UnaryUnaryClientInterceptor,
    grpc.UnaryStreamClientInterceptor,
    grpc.StreamUnaryClientInterceptor,
    grpc.StreamStreamClientInterceptor,
):
    """Client interceptor that funnels all four RPC shapes through one function.

    ``interceptor_function`` is called as
    ``interceptor_function(client_call_details, request_iterator,
    request_streaming, response_streaming)`` and must return a
    ``(new_details, new_request_iterator, postprocess)`` triple.
    ``postprocess`` is either falsy (no post-processing) or a callable applied
    to whatever the continuation returns.

    Unary requests are presented to the function as a one-element iterator,
    so a single function can rewrite requests for every RPC shape.
    """

    def __init__(self, interceptor_function):
        self._fn = interceptor_function

    def _intercept(
        self,
        continuation,
        client_call_details,
        request_iterator,
        request_streaming,
        response_streaming,
    ):
        """Run the wrapped function, then invoke ``continuation``.

        What the continuation receives as its second argument depends only on
        the request side:
        - Unary-request hooks (``intercept_unary_unary`` and
          ``intercept_unary_stream``) receive a single request message.
        - Streaming-request hooks receive the request iterator.

        Args:
            continuation: The next interceptor / the underlying RPC call.
            client_call_details: Details of the outgoing call.
            request_iterator: Iterator over the outgoing request message(s).
            request_streaming: Whether the RPC streams requests.
            response_streaming: Whether the RPC streams responses.

        Returns:
            The continuation's result, passed through ``postprocess`` if set.
        """
        new_details, new_request_iterator, postprocess = self._fn(
            client_call_details,
            request_iterator,
            request_streaming,
            response_streaming,
        )
        if request_streaming:
            request_or_iterator = new_request_iterator
        else:
            # Unary request: unwrap the single message from the iterator.
            request_or_iterator = next(new_request_iterator)
        result = continuation(new_details, request_or_iterator)
        return postprocess(result) if postprocess else result

    def intercept_unary_unary(self, continuation, client_call_details, request):
        return self._intercept(
            continuation, client_call_details, iter((request,)), False, False
        )

    def intercept_unary_stream(
        self, continuation, client_call_details, request
    ):
        return self._intercept(
            continuation, client_call_details, iter((request,)), False, True
        )

    def intercept_stream_unary(
        self, continuation, client_call_details, request_iterator
    ):
        return self._intercept(
            continuation, client_call_details, request_iterator, True, False
        )

    def intercept_stream_stream(
        self, continuation, client_call_details, request_iterator
    ):
        return self._intercept(
            continuation, client_call_details, request_iterator, True, True
        )
310
311
class _ContextVarSettingInterceptor(grpc.ServerInterceptor):
    """Server interceptor that stores ``value`` in ``_TEST_CONTEXT_VAR``.

    On entry it asserts the variable is unset or empty; presumably this
    catches a value leaking in from another RPC. The variable is set without
    being reset afterwards.
    """

    def __init__(self, value: str) -> None:
        self.value = value

    def intercept_service(self, continuation, handler_call_details):
        previous = _TEST_CONTEXT_VAR.get("")
        assert not previous, (
            f"expected context var to have no value but was '{previous}'"
        )
        _TEST_CONTEXT_VAR.set(self.value)
        return continuation(handler_call_details)
325
326
class _LoggingInterceptor(
    grpc.ServerInterceptor,
    grpc.UnaryUnaryClientInterceptor,
    grpc.UnaryStreamClientInterceptor,
    grpc.StreamUnaryClientInterceptor,
    grpc.StreamStreamClientInterceptor,
):
    """Client-and-server interceptor that records each invocation.

    Each hook appends ``tag + ":<hook name>"`` to ``record``, suffixed with
    ``"[<value>]"`` when ``_TEST_CONTEXT_VAR`` holds a non-empty value. It then
    forwards to the continuation unchanged. The unary-response client hooks
    also assert that the continuation returned an object that is both a
    ``grpc.Call`` and a ``grpc.Future``.
    """

    def __init__(self, tag, record):
        self.tag = tag
        self.record = record

    def _append_to_log(self, message: str) -> None:
        raw = _TEST_CONTEXT_VAR.get("")
        suffix = "[{}]".format(raw) if raw else raw
        self.record.append(self.tag + message + suffix)

    def intercept_service(self, continuation, handler_call_details):
        self._append_to_log(":intercept_service")
        return continuation(handler_call_details)

    def intercept_unary_unary(self, continuation, client_call_details, request):
        self._append_to_log(":intercept_unary_unary")
        call = continuation(client_call_details, request)
        assert isinstance(
            call, grpc.Call
        ), f"{call} ({type(call)}) is not an instance of grpc.Call"
        assert isinstance(
            call, grpc.Future
        ), f"{call} ({type(call)}) is not an instance of grpc.Future"
        return call

    def intercept_unary_stream(
        self, continuation, client_call_details, request
    ):
        self._append_to_log(":intercept_unary_stream")
        return continuation(client_call_details, request)

    def intercept_stream_unary(
        self, continuation, client_call_details, request_iterator
    ):
        self._append_to_log(":intercept_stream_unary")
        call = continuation(client_call_details, request_iterator)
        assert isinstance(
            call, grpc.Call
        ), f"{call} is not an instance of grpc.Call"
        assert isinstance(
            call, grpc.Future
        ), f"{call} is not an instance of grpc.Future"
        return call

    def intercept_stream_stream(
        self, continuation, client_call_details, request_iterator
    ):
        self._append_to_log(":intercept_stream_stream")
        return continuation(client_call_details, request_iterator)
387
388
class _DefectiveClientInterceptor(grpc.UnaryUnaryClientInterceptor):
    """Unary-unary client interceptor that always raises ``test_control.Defect``.

    Used to check how an exception raised inside a client interceptor
    surfaces to the caller.
    """

    def intercept_unary_unary(
        self, ignored_continuation, ignored_client_call_details, ignored_request
    ):
        # Fail before the continuation is ever invoked.
        raise test_control.Defect()
394
395
def _wrap_request_iterator_stream_interceptor(wrapper):
    """Build a client interceptor that passes streamed requests through *wrapper*.

    For request-streaming RPCs the request iterator is replaced by
    ``wrapper(request_iterator)``. Unary requests and call details are passed
    through unchanged, and no response post-processing is applied.
    """

    def intercept_call(
        client_call_details,
        request_iterator,
        request_streaming,
        ignored_response_streaming,
    ):
        outgoing = (
            wrapper(request_iterator) if request_streaming else request_iterator
        )
        return client_call_details, outgoing, None

    return _GenericClientInterceptor(intercept_call)
409
410
def _append_request_header_interceptor(header, value):
    """Build a client interceptor that appends ``(header, value)`` to call metadata.

    Existing metadata entries are copied into a new list, followed by the new
    pair. Method, timeout, credentials and the request iterator are forwarded
    unchanged, and no response post-processing is applied.
    """

    def intercept_call(
        client_call_details,
        request_iterator,
        ignored_request_streaming,
        ignored_response_streaming,
    ):
        augmented_metadata = [
            *(client_call_details.metadata or ()),
            (header, value),
        ]
        rewritten_details = _ClientCallDetails(
            client_call_details.method,
            client_call_details.timeout,
            augmented_metadata,
            client_call_details.credentials,
        )
        return rewritten_details, request_iterator, None

    return _GenericClientInterceptor(intercept_call)
436
437
class _GenericServerInterceptor(grpc.ServerInterceptor):
    """Adapt ``fn(continuation, handler_call_details)`` into a ``grpc.ServerInterceptor``."""

    def __init__(self, fn):
        self._fn = fn

    def intercept_service(self, continuation, handler_call_details):
        """Delegate to the wrapped function and return its result."""
        delegate = self._fn
        return delegate(continuation, handler_call_details)
444
445
def _filter_server_interceptor(condition, interceptor):
    """Apply *interceptor* only to calls for which *condition* is true.

    Args:
        condition: Predicate called with the ``handler_call_details``.
        interceptor: Server interceptor used when the predicate holds.

    Returns:
        A server interceptor that otherwise forwards straight to the
        continuation.
    """

    def intercept_service(continuation, handler_call_details):
        if not condition(handler_call_details):
            return continuation(handler_call_details)
        return interceptor.intercept_service(
            continuation, handler_call_details
        )

    return _GenericServerInterceptor(intercept_service)
455
456
class InterceptorTest(unittest.TestCase):
    """End-to-end tests for chaining client and server interceptors.

    ``setUp`` starts an in-process server with four server interceptors, in
    this order:
    - ``s1``: logging.
    - ``s3``: logging, but only for calls whose invocation metadata contains
      ``("secret", "42")``.
    - A setter that stores ``"context-var-value"`` in ``_TEST_CONTEXT_VAR``.
    - ``s2``: logging.

    Interceptors and the ``_Handler`` append to ``self._record``, so each test
    can assert the exact order in which they ran.
    """

    def setUp(self):
        self._control = test_control.PauseFailControl()
        self._record = []
        self._handler = _Handler(self._control, self._record)
        self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)

        # "s3" logs only when the caller sent the ("secret", "42") header.
        secret_only_logger = _filter_server_interceptor(
            lambda details: ("secret", "42") in details.invocation_metadata,
            _LoggingInterceptor("s3", self._record),
        )
        server_interceptors = (
            _LoggingInterceptor("s1", self._record),
            secret_only_logger,
            _ContextVarSettingInterceptor("context-var-value"),
            _LoggingInterceptor("s2", self._record),
        )

        self._server = grpc.server(
            self._server_pool,
            options=(("grpc.so_reuseport", 0),),
            interceptors=server_interceptors,
        )
        port = self._server.add_insecure_port("[::]:0")
        self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),))
        self._server.start()

        self._channel = grpc.insecure_channel("localhost:%d" % port)

    def tearDown(self):
        self._server.stop(None)
        self._server_pool.shutdown(wait=True)
        self._channel.close()

    def _with_logging_interceptors(self, channel):
        """Wrap *channel* with client logging interceptors tagged c1 then c2."""
        return grpc.intercept_channel(
            channel,
            _LoggingInterceptor("c1", self._record),
            _LoggingInterceptor("c2", self._record),
        )

    def _assert_record(self, *expected_entries):
        """Assert the record holds exactly *expected_entries*, in order."""
        self.assertSequenceEqual(self._record, list(expected_entries))

    def _assert_terminated_rpc_error(self, error):
        """Assert *error* is a finished, non-cancelled, failed RPC."""
        self.assertFalse(error.cancelled())
        self.assertFalse(error.running())
        self.assertTrue(error.done())
        with self.assertRaises(grpc.RpcError):
            error.result()
        self.assertIsInstance(error.exception(), grpc.RpcError)

    def testTripleRequestMessagesClientInterceptor(self):
        def triple(request_iterator):
            # Forward every request three times in a row.
            for item in request_iterator:
                yield item
                yield item
                yield item

        tripling_channel = grpc.intercept_channel(
            self._channel, _wrap_request_iterator_stream_interceptor(triple)
        )
        requests = (b"\x07\x08",) * test_constants.STREAM_LENGTH
        metadata = (
            ("test", "InterceptedStreamRequestBlockingUnaryResponseWithCall"),
        )

        # The stream-stream handler echoes one response per received request.
        responses = tuple(
            _stream_stream_multi_callable(tripling_channel)(
                iter(requests), metadata=metadata
            )
        )
        self.assertEqual(len(responses), 3 * test_constants.STREAM_LENGTH)

        # Without the interceptor every request is echoed exactly once.
        responses = tuple(
            _stream_stream_multi_callable(self._channel)(
                iter(requests), metadata=metadata
            )
        )
        self.assertEqual(len(responses), test_constants.STREAM_LENGTH)

    def testDefectiveClientInterceptor(self):
        defective_channel = grpc.intercept_channel(
            self._channel, _DefectiveClientInterceptor()
        )
        call_future = _unary_unary_multi_callable(defective_channel).future(
            b"\x07\x08",
            metadata=(
                ("test", "InterceptedUnaryRequestBlockingUnaryResponse"),
            ),
        )

        self.assertIsNotNone(call_future.exception())
        self.assertEqual(call_future.code(), grpc.StatusCode.INTERNAL)

    def testInterceptedHeaderManipulationWithServerSideVerification(self):
        secret_channel = grpc.intercept_channel(
            self._channel, _append_request_header_interceptor("secret", "42")
        )
        channel = self._with_logging_interceptors(secret_channel)
        self._record[:] = []

        metadata = (
            ("test", "InterceptedUnaryRequestBlockingUnaryResponseWithCall"),
        )
        _response, _call = _unary_unary_multi_callable(channel).with_call(
            b"\x07\x08", metadata=metadata
        )

        # The appended header activates the conditional "s3" server logger.
        self._assert_record(
            "c1:intercept_unary_unary",
            "c2:intercept_unary_unary",
            "s1:intercept_service",
            "s3:intercept_service",
            "s2:intercept_service[context-var-value]",
            "handler:handle_unary_unary[context-var-value]",
        )

    def testInterceptedUnaryRequestBlockingUnaryResponse(self):
        self._record[:] = []
        channel = self._with_logging_interceptors(self._channel)

        _unary_unary_multi_callable(channel)(
            b"\x07\x08",
            metadata=(
                ("test", "InterceptedUnaryRequestBlockingUnaryResponse"),
            ),
        )

        self._assert_record(
            "c1:intercept_unary_unary",
            "c2:intercept_unary_unary",
            "s1:intercept_service",
            "s2:intercept_service[context-var-value]",
            "handler:handle_unary_unary[context-var-value]",
        )

    def testInterceptedUnaryRequestBlockingUnaryResponseWithError(self):
        self._record[:] = []
        channel = self._with_logging_interceptors(self._channel)
        multi_callable = _unary_unary_multi_callable(channel)

        with self.assertRaises(grpc.RpcError) as raised:
            multi_callable(
                _EXCEPTION_REQUEST,
                metadata=(
                    ("test", "InterceptedUnaryRequestBlockingUnaryResponse"),
                ),
            )
        self._assert_terminated_rpc_error(raised.exception)

    def testInterceptedUnaryRequestBlockingUnaryResponseWithCall(self):
        channel = self._with_logging_interceptors(self._channel)
        self._record[:] = []

        metadata = (
            ("test", "InterceptedUnaryRequestBlockingUnaryResponseWithCall"),
        )
        _unary_unary_multi_callable(channel).with_call(
            b"\x07\x08", metadata=metadata
        )

        self._assert_record(
            "c1:intercept_unary_unary",
            "c2:intercept_unary_unary",
            "s1:intercept_service",
            "s2:intercept_service[context-var-value]",
            "handler:handle_unary_unary[context-var-value]",
        )

    def testInterceptedUnaryRequestFutureUnaryResponse(self):
        self._record[:] = []
        channel = self._with_logging_interceptors(self._channel)

        response_future = _unary_unary_multi_callable(channel).future(
            b"\x07\x08",
            metadata=(("test", "InterceptedUnaryRequestFutureUnaryResponse"),),
        )
        response_future.result()

        self._assert_record(
            "c1:intercept_unary_unary",
            "c2:intercept_unary_unary",
            "s1:intercept_service",
            "s2:intercept_service[context-var-value]",
            "handler:handle_unary_unary[context-var-value]",
        )

    def testInterceptedUnaryRequestStreamResponse(self):
        self._record[:] = []
        channel = self._with_logging_interceptors(self._channel)

        response_iterator = _unary_stream_multi_callable(channel)(
            b"\x37\x58",
            metadata=(("test", "InterceptedUnaryRequestStreamResponse"),),
        )
        tuple(response_iterator)

        self._assert_record(
            "c1:intercept_unary_stream",
            "c2:intercept_unary_stream",
            "s1:intercept_service",
            "s2:intercept_service[context-var-value]",
            "handler:handle_unary_stream[context-var-value]",
        )

    def testInterceptedUnaryRequestStreamResponseWithError(self):
        self._record[:] = []
        channel = self._with_logging_interceptors(self._channel)

        response_iterator = _unary_stream_multi_callable(channel)(
            _EXCEPTION_REQUEST,
            metadata=(("test", "InterceptedUnaryRequestStreamResponse"),),
        )
        with self.assertRaises(grpc.RpcError) as raised:
            tuple(response_iterator)
        self._assert_terminated_rpc_error(raised.exception)

    def testInterceptedStreamRequestBlockingUnaryResponse(self):
        request_iterator = iter((b"\x07\x08",) * test_constants.STREAM_LENGTH)
        self._record[:] = []
        channel = self._with_logging_interceptors(self._channel)

        _stream_unary_multi_callable(channel)(
            request_iterator,
            metadata=(
                ("test", "InterceptedStreamRequestBlockingUnaryResponse"),
            ),
        )

        self._assert_record(
            "c1:intercept_stream_unary",
            "c2:intercept_stream_unary",
            "s1:intercept_service",
            "s2:intercept_service[context-var-value]",
            "handler:handle_stream_unary[context-var-value]",
        )

    def testInterceptedStreamRequestBlockingUnaryResponseWithCall(self):
        request_iterator = iter((b"\x07\x08",) * test_constants.STREAM_LENGTH)
        self._record[:] = []
        channel = self._with_logging_interceptors(self._channel)

        metadata = (
            ("test", "InterceptedStreamRequestBlockingUnaryResponseWithCall"),
        )
        _stream_unary_multi_callable(channel).with_call(
            request_iterator, metadata=metadata
        )

        self._assert_record(
            "c1:intercept_stream_unary",
            "c2:intercept_stream_unary",
            "s1:intercept_service",
            "s2:intercept_service[context-var-value]",
            "handler:handle_stream_unary[context-var-value]",
        )

    def testInterceptedStreamRequestFutureUnaryResponse(self):
        request_iterator = iter((b"\x07\x08",) * test_constants.STREAM_LENGTH)
        self._record[:] = []
        channel = self._with_logging_interceptors(self._channel)

        response_future = _stream_unary_multi_callable(channel).future(
            request_iterator,
            metadata=(("test", "InterceptedStreamRequestFutureUnaryResponse"),),
        )
        response_future.result()

        self._assert_record(
            "c1:intercept_stream_unary",
            "c2:intercept_stream_unary",
            "s1:intercept_service",
            "s2:intercept_service[context-var-value]",
            "handler:handle_stream_unary[context-var-value]",
        )

    def testInterceptedStreamRequestFutureUnaryResponseWithError(self):
        request_iterator = iter(
            (_EXCEPTION_REQUEST,) * test_constants.STREAM_LENGTH
        )
        self._record[:] = []
        channel = self._with_logging_interceptors(self._channel)

        response_future = _stream_unary_multi_callable(channel).future(
            request_iterator,
            metadata=(("test", "InterceptedStreamRequestFutureUnaryResponse"),),
        )
        with self.assertRaises(grpc.RpcError) as raised:
            response_future.result()
        self._assert_terminated_rpc_error(raised.exception)

    def testInterceptedStreamRequestStreamResponse(self):
        request_iterator = iter((b"\x77\x58",) * test_constants.STREAM_LENGTH)
        self._record[:] = []
        channel = self._with_logging_interceptors(self._channel)

        response_iterator = _stream_stream_multi_callable(channel)(
            request_iterator,
            metadata=(("test", "InterceptedStreamRequestStreamResponse"),),
        )
        tuple(response_iterator)

        self._assert_record(
            "c1:intercept_stream_stream",
            "c2:intercept_stream_stream",
            "s1:intercept_service",
            "s2:intercept_service[context-var-value]",
            "handler:handle_stream_stream[context-var-value]",
        )

    def testInterceptedStreamRequestStreamResponseWithError(self):
        request_iterator = iter(
            (_EXCEPTION_REQUEST,) * test_constants.STREAM_LENGTH
        )
        self._record[:] = []
        channel = self._with_logging_interceptors(self._channel)

        response_iterator = _stream_stream_multi_callable(channel)(
            request_iterator,
            metadata=(("test", "InterceptedStreamRequestStreamResponse"),),
        )
        with self.assertRaises(grpc.RpcError) as raised:
            tuple(response_iterator)
        self._assert_terminated_rpc_error(raised.exception)
944
945
if __name__ == "__main__":
    # Configure root logging before the suite runs, then run verbosely.
    logging.basicConfig()
    unittest.main(verbosity=2)
949