# Copyright 2020 The gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test of propagation of contextvars to AuthMetadataPlugin threads."""

import contextlib
import logging
import os
import queue
import tempfile
import threading
import unittest

import grpc

from tests.unit import test_common

_SERVICE_NAME = "test"
_UNARY_UNARY = "UnaryUnary"
_REQUEST = b"0000"
# Socket path inside a fresh temporary directory created at import time.
_UDS_PATH = os.path.join(tempfile.mkdtemp(), "grpc_fullstack_test.sock")


def _unary_unary_handler(request, context):
    """Echo handler: return the request unchanged."""
    return request


def contextvars_supported():
    """Return True if the ``contextvars`` module can be imported."""
    try:
        import contextvars  # pylint: disable=unused-import,import-outside-toplevel
    except ImportError:
        return False
    return True


_METHOD_HANDLERS = {
    _UNARY_UNARY: grpc.unary_unary_rpc_method_handler(_unary_unary_handler)
}


@contextlib.contextmanager
def _server():
    """Run a test server on a local Unix domain socket for the context's duration.

    Yields:
      The filesystem path of the socket the server listens on.
    """
    # Created outside the try block: if construction fails there is nothing
    # to stop, and referencing an unbound ``server`` in ``finally`` would
    # mask the original error with an UnboundLocalError.
    server = test_common.test_server()
    try:
        server.add_registered_method_handlers(_SERVICE_NAME, _METHOD_HANDLERS)
        server_creds = grpc.local_server_credentials(
            grpc.LocalConnectionType.UDS
        )
        server.add_secure_port(f"unix:{_UDS_PATH}", server_creds)
        server.start()
        yield _UDS_PATH
    finally:
        server.stop(None)
        with contextlib.suppress(FileNotFoundError):
            os.remove(_UDS_PATH)


def _create_composite_credentials(plugin):
    """Combine UDS local channel credentials with *plugin* as call credentials.

    Args:
      plugin: A ``grpc.AuthMetadataPlugin`` invoked for each RPC.

    Returns:
      Composite ``grpc.ChannelCredentials`` for use with ``secure_channel``.
    """
    local_credentials = grpc.local_channel_credentials(
        grpc.LocalConnectionType.UDS
    )
    call_credentials = grpc.metadata_call_credentials(
        plugin, "test call credentials"
    )
    return grpc.composite_channel_credentials(
        local_credentials, call_credentials
    )


def _create_stub(channel):
    """Return a unary-unary callable for the test echo method on *channel*."""
    return channel.unary_unary(
        grpc._common.fully_qualified_method(_SERVICE_NAME, _UNARY_UNARY),
        _registered_method=True,
    )


if contextvars_supported():
    import contextvars

    _EXPECTED_VALUE = 24601
    test_var = contextvars.ContextVar("test_var", default=None)

    def set_up_expected_context():
        """Set ``test_var`` to the expected value in the current context."""
        test_var.set(_EXPECTED_VALUE)

    class TestCallCredentials(grpc.AuthMetadataPlugin):
        """Metadata plugin that checks ``test_var`` is visible when invoked."""

        def __init__(self):
            super().__init__()
            # Whether __call__ has run at least once.
            self._invoked = False
            # Value of test_var observed on the most recent invocation.
            self._recorded_value = None

        def __call__(self, context, callback):
            value = test_var.get()
            # Record before checking so assert_called can report what was seen.
            self._invoked = True
            self._recorded_value = value
            if (
                value != _EXPECTED_VALUE
                and not test_common.running_under_gevent()
            ):
                # contextvars do not work under gevent, but the rest of this
                # test is still valuable as a test of concurrent runs of the
                # metadata credentials code path.
                raise AssertionError(
                    "{} != {}".format(value, _EXPECTED_VALUE)
                )
            callback((), None)

        def assert_called(self, test):
            """Assert via *test* that the plugin ran and saw the expected value."""
            test.assertTrue(self._invoked)
            # Mirrors the gevent exemption in __call__.
            if not test_common.running_under_gevent():
                test.assertEqual(_EXPECTED_VALUE, self._recorded_value)

else:

    def set_up_expected_context():
        """No-op: contextvars are unavailable on this interpreter."""

    class TestCallCredentials(grpc.AuthMetadataPlugin):
        """Metadata plugin that attaches no metadata and records invocation."""

        def __init__(self):
            super().__init__()
            # Whether __call__ has run at least once.
            self._invoked = False

        def __call__(self, context, callback):
            self._invoked = True
            callback((), None)

        def assert_called(self, test):
            """Assert via *test* that the plugin was invoked."""
            test.assertTrue(self._invoked)


# TODO(https://github.com/grpc/grpc/issues/22257)
@unittest.skipIf(os.name == "nt", "LocalCredentials not supported on Windows.")
class ContextVarsPropagationTest(unittest.TestCase):
    def test_propagation_to_auth_plugin(self):
        # NOTE(review): the context var is set before the credentials are
        # built; keep this order in case gRPC snapshots the context when the
        # plugin is wrapped.
        set_up_expected_context()
        with _server() as uds_path:
            test_call_credentials = TestCallCredentials()
            composite_credentials = _create_composite_credentials(
                test_call_credentials
            )
            with grpc.secure_channel(
                f"unix:{uds_path}", composite_credentials
            ) as channel:
                stub = _create_stub(channel)
                response = stub(_REQUEST, wait_for_ready=True)
                self.assertEqual(_REQUEST, response)
            # Without this the test would pass even if the plugin never ran.
            test_call_credentials.assert_called(self)

    def test_concurrent_propagation(self):
        _THREAD_COUNT = 32
        _RPC_COUNT = 32

        set_up_expected_context()
        with _server() as uds_path:
            composite_credentials = _create_composite_credentials(
                TestCallCredentials()
            )
            # Assumes WaitGroup.wait() blocks until done() has been called
            # _THREAD_COUNT times, so all threads start their RPCs together.
            wait_group = test_common.WaitGroup(_THREAD_COUNT)
            # Worker-thread failures are collected here and re-raised on the
            # main thread, where unittest can see them.
            exceptions = queue.Queue()

            def _run_on_thread():
                checked_in = False
                try:
                    with grpc.secure_channel(
                        f"unix:{uds_path}", composite_credentials
                    ) as channel:
                        stub = _create_stub(channel)
                        wait_group.done()
                        checked_in = True
                        wait_group.wait()
                        for _ in range(_RPC_COUNT):
                            response = stub(_REQUEST, wait_for_ready=True)
                            self.assertEqual(_REQUEST, response)
                except Exception as e:  # pylint: disable=broad-except
                    exceptions.put(e)
                finally:
                    if not checked_in:
                        # A failure before the rendezvous must still check in,
                        # otherwise every other thread blocks in wait() and
                        # the join below hangs forever.
                        wait_group.done()

            threads = [
                threading.Thread(target=_run_on_thread, daemon=True)
                for _ in range(_THREAD_COUNT)
            ]
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()
            if not exceptions.empty():
                raise exceptions.get()


if __name__ == "__main__":
    logging.basicConfig()
    unittest.main(verbosity=2)