• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Test of propagation of contextvars to AuthMetadataPlugin threads."""
15
16import contextlib
17import logging
18import os
19import queue
20import tempfile
21import threading
22import unittest
23
24import grpc
25
26from tests.unit import test_common
27
28_SERVICE_NAME = "test"
29_UNARY_UNARY = "UnaryUnary"
30_REQUEST = b"0000"
31_UDS_PATH = os.path.join(tempfile.mkdtemp(), "grpc_fullstack_test.sock")
32
33
34def _unary_unary_handler(request, context):
35    return request
36
37
def contextvars_supported():
    """Report whether the ``contextvars`` module can be imported.

    Returns:
      True if ``import contextvars`` succeeds, False on ImportError.
    """
    try:
        import contextvars  # pylint: disable=unused-import,import-outside-toplevel
    except ImportError:
        return False
    return True
45
46
# Method name -> RPC handler table registered on the test server by _server().
_METHOD_HANDLERS = {}
_METHOD_HANDLERS[_UNARY_UNARY] = grpc.unary_unary_rpc_method_handler(
    _unary_unary_handler
)
50
51
@contextlib.contextmanager
def _server():
    """Run the echo server on a Unix domain socket with local credentials.

    The server is stopped and the socket file removed on exit, whether the
    body of the ``with`` statement succeeds or raises.

    Yields:
      The filesystem path of the Unix domain socket the server listens on.
    """
    # Create the server outside the try block: if construction itself fails
    # there is nothing to stop, and the finally clause must not reference an
    # unbound name (which would mask the original error).
    server = test_common.test_server()
    try:
        server.add_registered_method_handlers(_SERVICE_NAME, _METHOD_HANDLERS)
        server_creds = grpc.local_server_credentials(
            grpc.LocalConnectionType.UDS
        )
        server.add_secure_port(f"unix:{_UDS_PATH}", server_creds)
        server.start()
        yield _UDS_PATH
    finally:
        server.stop(None)
        # The socket file may not exist (e.g. the port was never bound).
        with contextlib.suppress(FileNotFoundError):
            os.remove(_UDS_PATH)
67
68
if contextvars_supported():
    import contextvars

    _EXPECTED_VALUE = 24601
    test_var = contextvars.ContextVar("test_var", default=None)

    def set_up_expected_context():
        """Bind ``test_var`` to ``_EXPECTED_VALUE`` in the current context."""
        test_var.set(_EXPECTED_VALUE)

    class TestCallCredentials(grpc.AuthMetadataPlugin):
        """Metadata plugin that checks ``test_var`` is visible where it runs.

        Each invocation records that it ran and the value of ``test_var`` it
        observed, then fails the RPC (by raising) if that value is not
        ``_EXPECTED_VALUE`` -- except under gevent, where the check is skipped.
        """

        def __init__(self):
            super().__init__()
            # Written by __call__, read by assert_called.
            self._invoked = False
            self._recorded_value = None

        def __call__(self, context, callback):
            value = test_var.get()
            self._invoked = True
            self._recorded_value = value
            if (
                value != _EXPECTED_VALUE
                and not test_common.running_under_gevent()
            ):
                # contextvars do not work under gevent, but the rest of this
                # test is still valuable as a test of concurrent runs of the
                # metadata credentials code path.
                raise AssertionError(f"{value} != {_EXPECTED_VALUE}")
            callback((), None)

        def assert_called(self, test):
            """Assert via *test* that the plugin ran and saw the expected value.

            Args:
              test: A ``unittest.TestCase`` whose assertion methods are used.
                The value checked is the one seen by the most recent call.
            """
            test.assertTrue(self._invoked)
            test.assertEqual(_EXPECTED_VALUE, self._recorded_value)

else:

    def set_up_expected_context():
        """No-op: there is no context variable to set without contextvars."""

    class TestCallCredentials(grpc.AuthMetadataPlugin):
        """Metadata plugin that supplies empty metadata and records the call."""

        def __init__(self):
            super().__init__()
            # Written by __call__, read by assert_called.
            self._invoked = False

        def __call__(self, context, callback):
            self._invoked = True
            callback((), None)

        def assert_called(self, test):
            """Assert via *test* that the plugin was invoked at least once."""
            test.assertTrue(self._invoked)
104
105
# TODO(https://github.com/grpc/grpc/issues/22257)
@unittest.skipIf(os.name == "nt", "LocalCredentials not supported on Windows.")
class ContextVarsPropagationTest(unittest.TestCase):
    """RPCs through TestCallCredentials, issued singly and concurrently.

    TestCallCredentials fails an RPC when the context variable it observes
    differs from the value set by set_up_expected_context().
    """

    @staticmethod
    def _composite_credentials():
        """Local UDS channel credentials combined with TestCallCredentials."""
        local_credentials = grpc.local_channel_credentials(
            grpc.LocalConnectionType.UDS
        )
        test_call_credentials = TestCallCredentials()
        call_credentials = grpc.metadata_call_credentials(
            test_call_credentials, "test call credentials"
        )
        return grpc.composite_channel_credentials(
            local_credentials, call_credentials
        )

    @staticmethod
    def _unary_unary_stub(channel):
        """Return a callable for the registered UnaryUnary echo method."""
        return channel.unary_unary(
            grpc._common.fully_qualified_method(_SERVICE_NAME, _UNARY_UNARY),
            _registered_method=True,
        )

    def test_propagation_to_auth_plugin(self):
        # NOTE(review): the context is set before the credentials are built;
        # presumably gRPC snapshots it at that point -- confirm before
        # reordering.
        set_up_expected_context()
        with _server() as uds_path:
            composite_credentials = self._composite_credentials()
            with grpc.secure_channel(
                f"unix:{uds_path}", composite_credentials
            ) as channel:
                stub = self._unary_unary_stub(channel)
                response = stub(_REQUEST, wait_for_ready=True)
                self.assertEqual(_REQUEST, response)

    def test_concurrent_propagation(self):
        _THREAD_COUNT = 32
        _RPC_COUNT = 32

        set_up_expected_context()
        with _server() as uds_path:
            composite_credentials = self._composite_credentials()
            wait_group = test_common.WaitGroup(_THREAD_COUNT)

            def _run_on_thread(exception_queue):
                """Issue _RPC_COUNT RPCs; report any failure on the queue."""
                checked_in = False
                try:
                    with grpc.secure_channel(
                        f"unix:{uds_path}", composite_credentials
                    ) as channel:
                        stub = self._unary_unary_stub(channel)
                        # Rendezvous with the other workers before issuing
                        # RPCs (assumes WaitGroup.wait() blocks until every
                        # worker has called done()).
                        wait_group.done()
                        checked_in = True
                        wait_group.wait()
                        for _ in range(_RPC_COUNT):
                            response = stub(_REQUEST, wait_for_ready=True)
                            self.assertEqual(_REQUEST, response)
                except Exception as e:  # pylint: disable=broad-except
                    exception_queue.put(e)
                finally:
                    if not checked_in:
                        # Failed before the rendezvous: still count down so
                        # the other workers are not left waiting forever.
                        wait_group.done()

            threads = []

            for _ in range(_THREAD_COUNT):
                q = queue.Queue()
                # daemon= replaces Thread.setDaemon(), deprecated in 3.10.
                thread = threading.Thread(
                    target=_run_on_thread, args=(q,), daemon=True
                )
                thread.start()
                threads.append((thread, q))

            for thread, q in threads:
                thread.join()
                if not q.empty():
                    raise q.get()
184
185
if __name__ == "__main__":
    # When run directly: configure default logging output, then run this
    # module's tests with verbose per-test reporting.
    logging.basicConfig()
    unittest.main(verbosity=2)
189