# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test of gRPC Python interceptors."""

import collections
from concurrent import futures
from contextvars import ContextVar
import itertools
import logging
import os
import threading
import unittest

import grpc
from grpc.framework.foundation import logging_pool

from tests.unit import test_common
from tests.unit.framework.common import test_constants
from tests.unit.framework.common import test_control

# (De)serializers used by the UnaryStream and StreamUnary methods. Each
# serializer repeats the payload and the matching deserializer takes back a
# single copy, so a round trip returns the original bytes.
_SERIALIZE_REQUEST = lambda bytestring: bytestring * 2
_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2 :]
_SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3
_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[: len(bytestring) // 3]

# Payload that makes each _Handler method raise _ApplicationErrorStandin.
_EXCEPTION_REQUEST = b"\x09\x0a"

_SERVICE_NAME = "test"
_UNARY_UNARY = "UnaryUnary"
_UNARY_STREAM = "UnaryStream"
_STREAM_UNARY = "StreamUnary"
_STREAM_STREAM = "StreamStream"

# Set on the server by _ContextVarSettingInterceptor. It is read by the
# server-side logging helpers so tests can verify the value is visible to
# interceptors that run later and to the handler itself.
_TEST_CONTEXT_VAR: ContextVar[str] = ContextVar("")


class _ApplicationErrorStandin(Exception):
    """Raised by _Handler to simulate an application-level failure."""


class _Callback(object):
    """Thread-safe value holder.

    Calling the instance stores a value and wakes waiters; value() blocks
    until the instance has been called at least once, then returns the most
    recently stored value.
    """

    def __init__(self):
        self._condition = threading.Condition()
        self._value = None
        self._called = False

    def __call__(self, value):
        with self._condition:
            self._value = value
            self._called = True
            self._condition.notify_all()

    def value(self):
        with self._condition:
            while not self._called:
                self._condition.wait()
            return self._value


class _Handler(object):
    """Servicer implementations for the four RPC cardinalities.

    Each method appends an entry to ``record`` (suffixed with the current
    ``_TEST_CONTEXT_VAR`` value in brackets, when set). It calls
    ``control.control()`` at checkpoints and sets trailing metadata when a
    servicer context is supplied. It raises ``_ApplicationErrorStandin``
    when it sees ``_EXCEPTION_REQUEST``.
    """

    def __init__(self, control, record):
        self._control = control
        self._record = record

    def _append_to_log(self, message: str) -> None:
        """Record ``message`` tagged with the handler prefix and context var."""
        context_var_value = _TEST_CONTEXT_VAR.get("")
        if context_var_value:
            context_var_value = "[{}]".format(context_var_value)
        self._record.append("handler:" + message + context_var_value)

    def handle_unary_unary(self, request, servicer_context):
        self._append_to_log("handle_unary_unary")
        self._control.control()
        if servicer_context is not None:
            servicer_context.set_trailing_metadata(
                (
                    (
                        "testkey",
                        "testvalue",
                    ),
                )
            )
        if request == _EXCEPTION_REQUEST:
            raise _ApplicationErrorStandin()
        return request

    def handle_unary_stream(self, request, servicer_context):
        self._append_to_log("handle_unary_stream")
        if request == _EXCEPTION_REQUEST:
            raise _ApplicationErrorStandin()
        for _ in range(test_constants.STREAM_LENGTH):
            self._control.control()
            yield request
        self._control.control()
        if servicer_context is not None:
            servicer_context.set_trailing_metadata(
                (
                    (
                        "testkey",
                        "testvalue",
                    ),
                )
            )

    def handle_stream_unary(self, request_iterator, servicer_context):
        self._append_to_log("handle_stream_unary")
        if servicer_context is not None:
            servicer_context.invocation_metadata()
        self._control.control()
        response_elements = []
        for request in request_iterator:
            self._control.control()
            response_elements.append(request)
        self._control.control()
        if servicer_context is not None:
            servicer_context.set_trailing_metadata(
                (
                    (
                        "testkey",
                        "testvalue",
                    ),
                )
            )
        if _EXCEPTION_REQUEST in response_elements:
            raise _ApplicationErrorStandin()
        return b"".join(response_elements)

    def handle_stream_stream(self, request_iterator, servicer_context):
        self._append_to_log("handle_stream_stream")
        self._control.control()
        if servicer_context is not None:
            servicer_context.set_trailing_metadata(
                (
                    (
                        "testkey",
                        "testvalue",
                    ),
                )
            )
        for request in request_iterator:
            if request == _EXCEPTION_REQUEST:
                raise _ApplicationErrorStandin()
            self._control.control()
            yield request
        self._control.control()


class _MethodHandler(grpc.RpcMethodHandler):
    """Plain attribute container implementing grpc.RpcMethodHandler."""

    def __init__(
        self,
        request_streaming,
        response_streaming,
        request_deserializer,
        response_serializer,
        unary_unary,
        unary_stream,
        stream_unary,
        stream_stream,
    ):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = request_deserializer
        self.response_serializer = response_serializer
        self.unary_unary = unary_unary
        self.unary_stream = unary_stream
        self.stream_unary = stream_unary
        self.stream_stream = stream_stream


def get_method_handlers(handler):
    """Map each test method name to a _MethodHandler backed by ``handler``.

    UnaryUnary and StreamStream use no (de)serializers (raw bytes), while
    UnaryStream and StreamUnary use the repeat/slice (de)serializers.
    """
    return {
        _UNARY_UNARY: _MethodHandler(
            False,
            False,
            None,
            None,
            handler.handle_unary_unary,
            None,
            None,
            None,
        ),
        _UNARY_STREAM: _MethodHandler(
            False,
            True,
            _DESERIALIZE_REQUEST,
            _SERIALIZE_RESPONSE,
            None,
            handler.handle_unary_stream,
            None,
            None,
        ),
        _STREAM_UNARY: _MethodHandler(
            True,
            False,
            _DESERIALIZE_REQUEST,
            _SERIALIZE_RESPONSE,
            None,
            None,
            handler.handle_stream_unary,
            None,
        ),
        _STREAM_STREAM: _MethodHandler(
            True,
            True,
            None,
            None,
            None,
            None,
            None,
            handler.handle_stream_stream,
        ),
    }


def _unary_unary_multi_callable(channel):
    return channel.unary_unary(
        grpc._common.fully_qualified_method(_SERVICE_NAME, _UNARY_UNARY),
        _registered_method=True,
    )


def _unary_stream_multi_callable(channel):
    return channel.unary_stream(
        grpc._common.fully_qualified_method(_SERVICE_NAME, _UNARY_STREAM),
        request_serializer=_SERIALIZE_REQUEST,
        response_deserializer=_DESERIALIZE_RESPONSE,
        _registered_method=True,
    )


def _stream_unary_multi_callable(channel):
    return channel.stream_unary(
        grpc._common.fully_qualified_method(_SERVICE_NAME, _STREAM_UNARY),
        request_serializer=_SERIALIZE_REQUEST,
        response_deserializer=_DESERIALIZE_RESPONSE,
        _registered_method=True,
    )


def _stream_stream_multi_callable(channel):
    return channel.stream_stream(
        grpc._common.fully_qualified_method(_SERVICE_NAME, _STREAM_STREAM),
        _registered_method=True,
    )


class _ClientCallDetails(
    collections.namedtuple(
        "_ClientCallDetails", ("method", "timeout", "metadata", "credentials")
    ),
    grpc.ClientCallDetails,
):
    """Immutable grpc.ClientCallDetails used to rewrite call details."""


class _GenericClientInterceptor(
    grpc.UnaryUnaryClientInterceptor,
    grpc.UnaryStreamClientInterceptor,
    grpc.StreamUnaryClientInterceptor,
    grpc.StreamStreamClientInterceptor,
):
    """Client interceptor that routes every cardinality through one function.

    ``interceptor_function`` is called as
    ``fn(client_call_details, request_iterator, request_streaming,
    response_streaming)``. It must return a
    ``(new_details, new_request_iterator, postprocess)`` triple, where
    ``postprocess`` is None or a callable applied to the continuation's
    result.

    Unary-request methods wrap their single request in a one-element iterator
    before calling the function, then unwrap it with ``next()`` for the
    continuation. Stream-request methods pass the iterator through unchanged.
    """

    def __init__(self, interceptor_function):
        self._fn = interceptor_function

    def intercept_unary_unary(self, continuation, client_call_details, request):
        new_details, new_request_iterator, postprocess = self._fn(
            client_call_details, iter((request,)), False, False
        )
        response = continuation(new_details, next(new_request_iterator))
        return postprocess(response) if postprocess else response

    def intercept_unary_stream(
        self, continuation, client_call_details, request
    ):
        new_details, new_request_iterator, postprocess = self._fn(
            client_call_details, iter((request,)), False, True
        )
        # Unary request: the continuation expects the single message, not
        # the one-element iterator it was wrapped in.
        response_it = continuation(new_details, next(new_request_iterator))
        return postprocess(response_it) if postprocess else response_it

    def intercept_stream_unary(
        self, continuation, client_call_details, request_iterator
    ):
        new_details, new_request_iterator, postprocess = self._fn(
            client_call_details, request_iterator, True, False
        )
        # Streaming request: forward the whole (possibly wrapped) iterator,
        # not just its first element.
        response = continuation(new_details, new_request_iterator)
        return postprocess(response) if postprocess else response

    def intercept_stream_stream(
        self, continuation, client_call_details, request_iterator
    ):
        new_details, new_request_iterator, postprocess = self._fn(
            client_call_details, request_iterator, True, True
        )
        response_it = continuation(new_details, new_request_iterator)
        return postprocess(response_it) if postprocess else response_it


class _ContextVarSettingInterceptor(grpc.ServerInterceptor):
    """Server interceptor that sets _TEST_CONTEXT_VAR to ``value``.

    Asserts the variable is unset on entry so that leakage of the value
    between RPCs is detected.
    """

    def __init__(self, value: str) -> None:
        self.value = value

    def intercept_service(self, continuation, handler_call_details):
        old_value = _TEST_CONTEXT_VAR.get("")
        assert (
            not old_value
        ), "expected context var to have no value but was '{}'".format(
            old_value
        )
        _TEST_CONTEXT_VAR.set(self.value)
        return continuation(handler_call_details)


class _LoggingInterceptor(
    grpc.ServerInterceptor,
    grpc.UnaryUnaryClientInterceptor,
    grpc.UnaryStreamClientInterceptor,
    grpc.StreamUnaryClientInterceptor,
    grpc.StreamStreamClientInterceptor,
):
    """Client/server interceptor that appends ``tag:<event>`` to ``record``.

    For unary-response client calls it also asserts that the continuation's
    result is both a grpc.Call and a grpc.Future.
    """

    def __init__(self, tag, record):
        self.tag = tag
        self.record = record

    def _append_to_log(self, message: str) -> None:
        """Record ``tag + message``, suffixed with the context var if set."""
        context_var_value = _TEST_CONTEXT_VAR.get("")
        if context_var_value:
            context_var_value = "[{}]".format(context_var_value)
        self.record.append(self.tag + message + context_var_value)

    def intercept_service(self, continuation, handler_call_details):
        # A tag containing "check_handler_call_details" logs the method name
        # so the test can verify what the server passes to interceptors.
        if "check_handler_call_details" in self.tag:
            self._append_to_log(f":method={handler_call_details.method}")
        else:
            self._append_to_log(":intercept_service")

        return continuation(handler_call_details)

    def intercept_unary_unary(self, continuation, client_call_details, request):
        self._append_to_log(":intercept_unary_unary")
        result = continuation(client_call_details, request)
        assert isinstance(
            result, grpc.Call
        ), "{} ({}) is not an instance of grpc.Call".format(
            result, type(result)
        )
        assert isinstance(
            result, grpc.Future
        ), "{} ({}) is not an instance of grpc.Future".format(
            result, type(result)
        )
        return result

    def intercept_unary_stream(
        self, continuation, client_call_details, request
    ):
        self._append_to_log(":intercept_unary_stream")
        return continuation(client_call_details, request)

    def intercept_stream_unary(
        self, continuation, client_call_details, request_iterator
    ):
        self._append_to_log(":intercept_stream_unary")
        result = continuation(client_call_details, request_iterator)
        assert isinstance(
            result, grpc.Call
        ), "{} is not an instance of grpc.Call".format(result)
        assert isinstance(
            result, grpc.Future
        ), "{} is not an instance of grpc.Future".format(result)
        return result

    def intercept_stream_stream(
        self, continuation, client_call_details, request_iterator
    ):
        self._append_to_log(":intercept_stream_stream")
        return continuation(client_call_details, request_iterator)


class _DefectiveClientInterceptor(grpc.UnaryUnaryClientInterceptor):
    """Client interceptor that always raises test_control.Defect."""

    def intercept_unary_unary(
        self, ignored_continuation, ignored_client_call_details, ignored_request
    ):
        raise test_control.Defect()


def _wrap_request_iterator_stream_interceptor(wrapper):
    """Return an interceptor applying ``wrapper`` to streaming requests.

    Unary-request calls are forwarded unchanged.
    """

    def intercept_call(
        client_call_details,
        request_iterator,
        request_streaming,
        ignored_response_streaming,
    ):
        if request_streaming:
            return client_call_details, wrapper(request_iterator), None
        else:
            return client_call_details, request_iterator, None

    return _GenericClientInterceptor(intercept_call)


def _append_request_header_interceptor(header, value):
    """Return an interceptor that appends ``(header, value)`` to metadata."""

    def intercept_call(
        client_call_details,
        request_iterator,
        ignored_request_streaming,
        ignored_response_streaming,
    ):
        metadata = []
        if client_call_details.metadata:
            metadata = list(client_call_details.metadata)
        metadata.append(
            (
                header,
                value,
            )
        )
        client_call_details = _ClientCallDetails(
            client_call_details.method,
            client_call_details.timeout,
            metadata,
            client_call_details.credentials,
        )
        return client_call_details, request_iterator, None

    return _GenericClientInterceptor(intercept_call)


class _GenericServerInterceptor(grpc.ServerInterceptor):
    """Server interceptor delegating intercept_service to a plain function."""

    def __init__(self, fn):
        self._fn = fn

    def intercept_service(self, continuation, handler_call_details):
        return self._fn(continuation, handler_call_details)


def _filter_server_interceptor(condition, interceptor):
    """Apply ``interceptor`` only when ``condition(handler_call_details)``.

    Calls that do not satisfy ``condition`` go straight to the continuation.
    """

    def intercept_service(continuation, handler_call_details):
        if condition(handler_call_details):
            return interceptor.intercept_service(
                continuation, handler_call_details
            )
        return continuation(handler_call_details)

    return _GenericServerInterceptor(intercept_service)


class InterceptorTest(unittest.TestCase):
    """End-to-end tests of client and server interceptor ordering.

    The server is built with interceptors s1, s3 (only when metadata carries
    ("secret", "42")), s4 (only when metadata carries
    ("test_case", "check_handler_call_details")), the context-var setter, and
    s2. Tests compare the shared record against the expected call order.
    """

    def setUp(self):
        self._control = test_control.PauseFailControl()
        self._record = []
        self._handler = _Handler(self._control, self._record)
        self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)

        conditional_interceptor = _filter_server_interceptor(
            lambda x: ("secret", "42") in x.invocation_metadata,
            _LoggingInterceptor("s3", self._record),
        )

        conditional_interceptor_check_handler_call_details = (
            _filter_server_interceptor(
                lambda x: ("test_case", "check_handler_call_details")
                in x.invocation_metadata,
                _LoggingInterceptor(
                    "s4:check_handler_call_details", self._record
                ),
            )
        )

        self._server = grpc.server(
            self._server_pool,
            options=(("grpc.so_reuseport", 0),),
            interceptors=(
                _LoggingInterceptor("s1", self._record),
                conditional_interceptor,
                conditional_interceptor_check_handler_call_details,
                _ContextVarSettingInterceptor("context-var-value"),
                _LoggingInterceptor("s2", self._record),
            ),
        )
        port = self._server.add_insecure_port("[::]:0")
        self._server.add_registered_method_handlers(
            _SERVICE_NAME, get_method_handlers(self._handler)
        )
        self._server.start()

        self._channel = grpc.insecure_channel("localhost:%d" % port)

    def tearDown(self):
        self._server.stop(None)
        self._server_pool.shutdown(wait=True)
        self._channel.close()

    def testTripleRequestMessagesClientInterceptor(self):
        def triple(request_iterator):
            # Emit every request three times.
            for item in request_iterator:
                yield item
                yield item
                yield item

        interceptor = _wrap_request_iterator_stream_interceptor(triple)
        channel = grpc.intercept_channel(self._channel, interceptor)
        requests = tuple(
            b"\x07\x08" for _ in range(test_constants.STREAM_LENGTH)
        )

        multi_callable = _stream_stream_multi_callable(channel)
        response_iterator = multi_callable(
            iter(requests),
            metadata=(
                (
                    "test",
                    "InterceptedStreamRequestBlockingUnaryResponseWithCall",
                ),
            ),
        )

        responses = tuple(response_iterator)
        self.assertEqual(len(responses), 3 * test_constants.STREAM_LENGTH)

        # Without the interceptor the stream is echoed one-for-one.
        multi_callable = _stream_stream_multi_callable(self._channel)
        response_iterator = multi_callable(
            iter(requests),
            metadata=(
                (
                    "test",
                    "InterceptedStreamRequestBlockingUnaryResponseWithCall",
                ),
            ),
        )

        responses = tuple(response_iterator)
        self.assertEqual(len(responses), test_constants.STREAM_LENGTH)

    def testDefectiveClientInterceptor(self):
        interceptor = _DefectiveClientInterceptor()
        defective_channel = grpc.intercept_channel(self._channel, interceptor)

        request = b"\x07\x08"

        multi_callable = _unary_unary_multi_callable(defective_channel)
        call_future = multi_callable.future(
            request,
            metadata=(
                ("test", "InterceptedUnaryRequestBlockingUnaryResponse"),
            ),
        )

        self.assertIsNotNone(call_future.exception())
        self.assertEqual(call_future.code(), grpc.StatusCode.INTERNAL)

    def testInterceptedHeaderManipulationWithServerSideVerification(self):
        request = b"\x07\x08"

        channel = grpc.intercept_channel(
            self._channel, _append_request_header_interceptor("secret", "42")
        )
        channel = grpc.intercept_channel(
            channel,
            _LoggingInterceptor("c1", self._record),
            _LoggingInterceptor("c2", self._record),
        )

        self._record[:] = []

        multi_callable = _unary_unary_multi_callable(channel)
        multi_callable.with_call(
            request,
            metadata=(
                (
                    "test",
                    "InterceptedUnaryRequestBlockingUnaryResponseWithCall",
                ),
            ),
        )

        self.assertSequenceEqual(
            self._record,
            [
                "c1:intercept_unary_unary",
                "c2:intercept_unary_unary",
                "s1:intercept_service",
                "s3:intercept_service",
                "s2:intercept_service[context-var-value]",
                "handler:handle_unary_unary[context-var-value]",
            ],
        )

    def testInterceptedUnaryRequestBlockingUnaryResponse(self):
        request = b"\x07\x08"

        self._record[:] = []

        channel = grpc.intercept_channel(
            self._channel,
            _LoggingInterceptor("c1", self._record),
            _LoggingInterceptor("c2", self._record),
        )

        multi_callable = _unary_unary_multi_callable(channel)
        multi_callable(
            request,
            metadata=(
                ("test", "InterceptedUnaryRequestBlockingUnaryResponse"),
            ),
        )

        self.assertSequenceEqual(
            self._record,
            [
                "c1:intercept_unary_unary",
                "c2:intercept_unary_unary",
                "s1:intercept_service",
                "s2:intercept_service[context-var-value]",
                "handler:handle_unary_unary[context-var-value]",
            ],
        )

    def testInterceptedUnaryRequestBlockingUnaryResponseWithError(self):
        request = _EXCEPTION_REQUEST

        self._record[:] = []

        channel = grpc.intercept_channel(
            self._channel,
            _LoggingInterceptor("c1", self._record),
            _LoggingInterceptor("c2", self._record),
        )

        multi_callable = _unary_unary_multi_callable(channel)
        with self.assertRaises(grpc.RpcError) as exception_context:
            multi_callable(
                request,
                metadata=(
                    ("test", "InterceptedUnaryRequestBlockingUnaryResponse"),
                ),
            )
        exception = exception_context.exception
        self.assertFalse(exception.cancelled())
        self.assertFalse(exception.running())
        self.assertTrue(exception.done())
        with self.assertRaises(grpc.RpcError):
            exception.result()
        self.assertIsInstance(exception.exception(), grpc.RpcError)

    def testInterceptedUnaryRequestBlockingUnaryResponseWithCall(self):
        request = b"\x07\x08"

        channel = grpc.intercept_channel(
            self._channel,
            _LoggingInterceptor("c1", self._record),
            _LoggingInterceptor("c2", self._record),
        )

        self._record[:] = []

        multi_callable = _unary_unary_multi_callable(channel)
        multi_callable.with_call(
            request,
            metadata=(
                (
                    "test",
                    "InterceptedUnaryRequestBlockingUnaryResponseWithCall",
                ),
            ),
        )

        self.assertSequenceEqual(
            self._record,
            [
                "c1:intercept_unary_unary",
                "c2:intercept_unary_unary",
                "s1:intercept_service",
                "s2:intercept_service[context-var-value]",
                "handler:handle_unary_unary[context-var-value]",
            ],
        )

    def testInterceptedUnaryRequestFutureUnaryResponse(self):
        request = b"\x07\x08"

        self._record[:] = []
        channel = grpc.intercept_channel(
            self._channel,
            _LoggingInterceptor("c1", self._record),
            _LoggingInterceptor("c2", self._record),
        )

        multi_callable = _unary_unary_multi_callable(channel)
        response_future = multi_callable.future(
            request,
            metadata=(("test", "InterceptedUnaryRequestFutureUnaryResponse"),),
        )
        response_future.result()

        self.assertSequenceEqual(
            self._record,
            [
                "c1:intercept_unary_unary",
                "c2:intercept_unary_unary",
                "s1:intercept_service",
                "s2:intercept_service[context-var-value]",
                "handler:handle_unary_unary[context-var-value]",
            ],
        )

    def testInterceptedUnaryRequestStreamResponse(self):
        request = b"\x37\x58"

        self._record[:] = []
        channel = grpc.intercept_channel(
            self._channel,
            _LoggingInterceptor("c1", self._record),
            _LoggingInterceptor("c2", self._record),
        )

        multi_callable = _unary_stream_multi_callable(channel)
        response_iterator = multi_callable(
            request,
            metadata=(("test", "InterceptedUnaryRequestStreamResponse"),),
        )
        tuple(response_iterator)

        self.assertSequenceEqual(
            self._record,
            [
                "c1:intercept_unary_stream",
                "c2:intercept_unary_stream",
                "s1:intercept_service",
                "s2:intercept_service[context-var-value]",
                "handler:handle_unary_stream[context-var-value]",
            ],
        )

    def testInterceptedUnaryRequestStreamResponseWithError(self):
        request = _EXCEPTION_REQUEST

        self._record[:] = []
        channel = grpc.intercept_channel(
            self._channel,
            _LoggingInterceptor("c1", self._record),
            _LoggingInterceptor("c2", self._record),
        )

        multi_callable = _unary_stream_multi_callable(channel)
        response_iterator = multi_callable(
            request,
            metadata=(("test", "InterceptedUnaryRequestStreamResponse"),),
        )
        with self.assertRaises(grpc.RpcError) as exception_context:
            tuple(response_iterator)
        exception = exception_context.exception
        self.assertFalse(exception.cancelled())
        self.assertFalse(exception.running())
        self.assertTrue(exception.done())
        with self.assertRaises(grpc.RpcError):
            exception.result()
        self.assertIsInstance(exception.exception(), grpc.RpcError)

    def testInterceptedStreamRequestBlockingUnaryResponse(self):
        requests = tuple(
            b"\x07\x08" for _ in range(test_constants.STREAM_LENGTH)
        )
        request_iterator = iter(requests)

        self._record[:] = []
        channel = grpc.intercept_channel(
            self._channel,
            _LoggingInterceptor("c1", self._record),
            _LoggingInterceptor("c2", self._record),
        )

        multi_callable = _stream_unary_multi_callable(channel)
        multi_callable(
            request_iterator,
            metadata=(
                ("test", "InterceptedStreamRequestBlockingUnaryResponse"),
            ),
        )

        self.assertSequenceEqual(
            self._record,
            [
                "c1:intercept_stream_unary",
                "c2:intercept_stream_unary",
                "s1:intercept_service",
                "s2:intercept_service[context-var-value]",
                "handler:handle_stream_unary[context-var-value]",
            ],
        )

    def testInterceptedStreamRequestBlockingUnaryResponseWithCall(self):
        requests = tuple(
            b"\x07\x08" for _ in range(test_constants.STREAM_LENGTH)
        )
        request_iterator = iter(requests)

        self._record[:] = []
        channel = grpc.intercept_channel(
            self._channel,
            _LoggingInterceptor("c1", self._record),
            _LoggingInterceptor("c2", self._record),
        )

        multi_callable = _stream_unary_multi_callable(channel)
        multi_callable.with_call(
            request_iterator,
            metadata=(
                (
                    "test",
                    "InterceptedStreamRequestBlockingUnaryResponseWithCall",
                ),
            ),
        )

        self.assertSequenceEqual(
            self._record,
            [
                "c1:intercept_stream_unary",
                "c2:intercept_stream_unary",
                "s1:intercept_service",
                "s2:intercept_service[context-var-value]",
                "handler:handle_stream_unary[context-var-value]",
            ],
        )

    def testInterceptedStreamRequestFutureUnaryResponse(self):
        requests = tuple(
            b"\x07\x08" for _ in range(test_constants.STREAM_LENGTH)
        )
        request_iterator = iter(requests)

        self._record[:] = []
        channel = grpc.intercept_channel(
            self._channel,
            _LoggingInterceptor("c1", self._record),
            _LoggingInterceptor("c2", self._record),
        )

        multi_callable = _stream_unary_multi_callable(channel)
        response_future = multi_callable.future(
            request_iterator,
            metadata=(("test", "InterceptedStreamRequestFutureUnaryResponse"),),
        )
        response_future.result()

        self.assertSequenceEqual(
            self._record,
            [
                "c1:intercept_stream_unary",
                "c2:intercept_stream_unary",
                "s1:intercept_service",
                "s2:intercept_service[context-var-value]",
                "handler:handle_stream_unary[context-var-value]",
            ],
        )

    def testInterceptedStreamRequestFutureUnaryResponseWithError(self):
        requests = tuple(
            _EXCEPTION_REQUEST for _ in range(test_constants.STREAM_LENGTH)
        )
        request_iterator = iter(requests)

        self._record[:] = []
        channel = grpc.intercept_channel(
            self._channel,
            _LoggingInterceptor("c1", self._record),
            _LoggingInterceptor("c2", self._record),
        )

        multi_callable = _stream_unary_multi_callable(channel)
        response_future = multi_callable.future(
            request_iterator,
            metadata=(("test", "InterceptedStreamRequestFutureUnaryResponse"),),
        )
        with self.assertRaises(grpc.RpcError) as exception_context:
            response_future.result()
        exception = exception_context.exception
        self.assertFalse(exception.cancelled())
        self.assertFalse(exception.running())
        self.assertTrue(exception.done())
        with self.assertRaises(grpc.RpcError):
            exception.result()
        self.assertIsInstance(exception.exception(), grpc.RpcError)

    def testInterceptedStreamRequestStreamResponse(self):
        requests = tuple(
            b"\x77\x58" for _ in range(test_constants.STREAM_LENGTH)
        )
        request_iterator = iter(requests)

        self._record[:] = []
        channel = grpc.intercept_channel(
            self._channel,
            _LoggingInterceptor("c1", self._record),
            _LoggingInterceptor("c2", self._record),
        )

        multi_callable = _stream_stream_multi_callable(channel)
        response_iterator = multi_callable(
            request_iterator,
            metadata=(("test", "InterceptedStreamRequestStreamResponse"),),
        )
        tuple(response_iterator)

        self.assertSequenceEqual(
            self._record,
            [
                "c1:intercept_stream_stream",
                "c2:intercept_stream_stream",
                "s1:intercept_service",
                "s2:intercept_service[context-var-value]",
                "handler:handle_stream_stream[context-var-value]",
            ],
        )

    def testInterceptedStreamRequestStreamResponseWithError(self):
        requests = tuple(
            _EXCEPTION_REQUEST for _ in range(test_constants.STREAM_LENGTH)
        )
        request_iterator = iter(requests)

        self._record[:] = []
        channel = grpc.intercept_channel(
            self._channel,
            _LoggingInterceptor("c1", self._record),
            _LoggingInterceptor("c2", self._record),
        )

        multi_callable = _stream_stream_multi_callable(channel)
        response_iterator = multi_callable(
            request_iterator,
            metadata=(("test", "InterceptedStreamRequestStreamResponse"),),
        )
        with self.assertRaises(grpc.RpcError) as exception_context:
            tuple(response_iterator)
        exception = exception_context.exception
        self.assertFalse(exception.cancelled())
        self.assertFalse(exception.running())
        self.assertTrue(exception.done())
        with self.assertRaises(grpc.RpcError):
            exception.result()
        self.assertIsInstance(exception.exception(), grpc.RpcError)

    def testServerInterceptorWithCorrectHandlerCallDetails(self):
        request = b"\x07\x08"

        self._record[:] = []

        channel = grpc.intercept_channel(
            self._channel,
            _append_request_header_interceptor(
                "test_case", "check_handler_call_details"
            ),
        )

        multi_callable = _unary_unary_multi_callable(channel)
        multi_callable(
            request,
            metadata=(
                ("test", "InterceptedUnaryRequestBlockingUnaryResponse"),
            ),
        )

        self.assertSequenceEqual(
            self._record,
            [
                "s1:intercept_service",
                "s4:check_handler_call_details:method=/test/UnaryUnary",
                "s2:intercept_service[context-var-value]",
                "handler:handle_unary_unary[context-var-value]",
            ],
        )


if __name__ == "__main__":
    logging.basicConfig()
    unittest.main(verbosity=2)