• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import datetime
16import os
17import time
18
19import mock
20import pytest
21
22from google.auth import _helpers
23from google.auth import credentials
24from google.auth import environment_vars
25from google.auth import exceptions
26from google.auth import transport
27from google.oauth2 import service_account
28
29try:
30    # pylint: disable=ungrouped-imports
31    import grpc
32    import google.auth.transport.grpc
33
34    HAS_GRPC = True
35except ImportError:  # pragma: NO COVER
36    HAS_GRPC = False
37
38DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
39METADATA_PATH = os.path.join(DATA_DIR, "context_aware_metadata.json")
40with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh:
41    PRIVATE_KEY_BYTES = fh.read()
42with open(os.path.join(DATA_DIR, "public_cert.pem"), "rb") as fh:
43    PUBLIC_CERT_BYTES = fh.read()
44
45pytestmark = pytest.mark.skipif(not HAS_GRPC, reason="gRPC is unavailable.")
46
47
class CredentialsStub(credentials.Credentials):
    """Minimal credentials double used by the tests in this module.

    ``refresh`` appends ``"1"`` to the token, so a test can tell whether (and
    how often) a refresh happened by inspecting ``token`` afterwards.
    """

    def __init__(self, token="token"):
        super(CredentialsStub, self).__init__()
        # No expiry by default; individual tests assign one when they need it.
        self.expiry = None
        self.token = token

    def refresh(self, request):
        """Record a refresh by appending ``"1"`` to the token.

        ``request`` is accepted but not used.
        """
        self.token += "1"

    def with_quota_project(self, quota_project_id):
        """Not supported by this stub; always raises ``NotImplementedError``."""
        raise NotImplementedError()
59
60
class TestAuthMetadataPlugin(object):
    """Tests for ``google.auth.transport.grpc.AuthMetadataPlugin``."""

    def test_call_no_refresh(self):
        """The callback receives a bearer header built from the stub's token."""
        creds = CredentialsStub()
        auth_request = mock.create_autospec(transport.Request)
        plugin = google.auth.transport.grpc.AuthMetadataPlugin(creds, auth_request)

        on_metadata = mock.create_autospec(grpc.AuthMetadataPluginCallback)
        call_context = mock.create_autospec(grpc.AuthMetadataContext, instance=True)
        call_context.method_name = mock.sentinel.method_name
        call_context.service_url = mock.sentinel.service_url

        plugin(call_context, on_metadata)

        # NOTE(review): presumably gives an asynchronous callback time to
        # run -- confirm against the plugin implementation.
        time.sleep(2)

        expected_metadata = [("authorization", "Bearer {}".format(creds.token))]
        on_metadata.assert_called_once_with(expected_metadata, None)

    def test_call_refresh(self):
        """Credentials with a long-past expiry are refreshed exactly once."""
        creds = CredentialsStub()
        # datetime.min plus the threshold lies far in the past; the token
        # assertion below shows that this leads to one refresh.
        creds.expiry = datetime.datetime.min + _helpers.REFRESH_THRESHOLD
        auth_request = mock.create_autospec(transport.Request)
        plugin = google.auth.transport.grpc.AuthMetadataPlugin(creds, auth_request)

        on_metadata = mock.create_autospec(grpc.AuthMetadataPluginCallback)
        call_context = mock.create_autospec(grpc.AuthMetadataContext, instance=True)
        call_context.method_name = mock.sentinel.method_name
        call_context.service_url = mock.sentinel.service_url

        plugin(call_context, on_metadata)

        # NOTE(review): presumably gives an asynchronous callback time to
        # run -- confirm against the plugin implementation.
        time.sleep(2)

        # The stub appends "1" per refresh, so "token1" means one refresh.
        assert creds.token == "token1"
        expected_metadata = [("authorization", "Bearer {}".format(creds.token))]
        on_metadata.assert_called_once_with(expected_metadata, None)

    def test__get_authorization_headers_with_service_account(self):
        """Without a default host the self-signed JWT audience is ``None``."""
        sa_creds = mock.create_autospec(service_account.Credentials)
        auth_request = mock.create_autospec(transport.Request)
        plugin = google.auth.transport.grpc.AuthMetadataPlugin(sa_creds, auth_request)

        call_context = mock.create_autospec(grpc.AuthMetadataContext, instance=True)
        call_context.service_url = "https://pubsub.googleapis.com/methodName"
        call_context.method_name = "methodName"

        plugin._get_authorization_headers(call_context)

        sa_creds._create_self_signed_jwt.assert_called_once_with(None)

    def test__get_authorization_headers_with_service_account_and_default_host(self):
        """A default host is turned into the ``https://<host>/`` JWT audience."""
        host = "pubsub.googleapis.com"
        sa_creds = mock.create_autospec(service_account.Credentials)
        auth_request = mock.create_autospec(transport.Request)
        plugin = google.auth.transport.grpc.AuthMetadataPlugin(
            sa_creds, auth_request, default_host=host
        )

        call_context = mock.create_autospec(grpc.AuthMetadataContext, instance=True)
        call_context.service_url = "https://pubsub.googleapis.com/methodName"
        call_context.method_name = "methodName"

        plugin._get_authorization_headers(call_context)

        expected_audience = "https://{}/".format(host)
        sa_creds._create_self_signed_jwt.assert_called_once_with(expected_audience)
134
135
@mock.patch(
    "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True
)
@mock.patch("grpc.composite_channel_credentials", autospec=True)
@mock.patch("grpc.metadata_call_credentials", autospec=True)
@mock.patch("grpc.ssl_channel_credentials", autospec=True)
@mock.patch("grpc.secure_channel", autospec=True)
class TestSecureAuthorizedChannel(object):
    """Tests for ``google.auth.transport.grpc.secure_authorized_channel``.

    The class-level patches replace the gRPC channel/credential factories and
    the mTLS helper; each test receives those mocks as its trailing arguments
    (innermost decorator first).
    """

    @mock.patch(
        "google.auth.transport._mtls_helper._read_dca_metadata_file", autospec=True
    )
    @mock.patch(
        "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True
    )
    def test_secure_authorized_channel_adc(
        self,
        check_dca_metadata_path,
        read_dca_metadata_file,
        secure_channel,
        ssl_channel_credentials,
        metadata_call_credentials,
        composite_channel_credentials,
        get_client_ssl_credentials,
    ):
        """With the client-cert env var on, the mocked cert/key are used."""
        credentials = CredentialsStub()
        request = mock.create_autospec(transport.Request)
        target = "example.com:80"

        # Mock the context aware metadata and client cert/key so mTLS SSL channel
        # will be used.
        check_dca_metadata_path.return_value = METADATA_PATH
        read_dca_metadata_file.return_value = {
            "cert_provider_command": ["some command"]
        }
        get_client_ssl_credentials.return_value = (
            True,
            PUBLIC_CERT_BYTES,
            PRIVATE_KEY_BYTES,
            None,
        )

        with mock.patch.dict(
            os.environ, {environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE: "true"}
        ):
            channel = google.auth.transport.grpc.secure_authorized_channel(
                credentials, request, target, options=mock.sentinel.options
            )

        # Check the auth plugin construction.
        auth_plugin = metadata_call_credentials.call_args[0][0]
        assert isinstance(auth_plugin, google.auth.transport.grpc.AuthMetadataPlugin)
        assert auth_plugin._credentials == credentials
        assert auth_plugin._request == request

        # Check the ssl channel call.
        ssl_channel_credentials.assert_called_once_with(
            certificate_chain=PUBLIC_CERT_BYTES, private_key=PRIVATE_KEY_BYTES
        )

        # Check the composite credentials call.
        composite_channel_credentials.assert_called_once_with(
            ssl_channel_credentials.return_value, metadata_call_credentials.return_value
        )

        # Check the channel call.
        secure_channel.assert_called_once_with(
            target,
            composite_channel_credentials.return_value,
            options=mock.sentinel.options,
        )
        assert channel == secure_channel.return_value

    @mock.patch("google.auth.transport.grpc.SslCredentials", autospec=True)
    def test_secure_authorized_channel_adc_without_client_cert_env(
        self,
        ssl_credentials_adc_method,
        secure_channel,
        ssl_channel_credentials,
        metadata_call_credentials,
        composite_channel_credentials,
        get_client_ssl_credentials,
    ):
        """Without the client-cert env var, ``SslCredentials`` is never built."""
        # Test client cert won't be used if GOOGLE_API_USE_CLIENT_CERTIFICATE
        # environment variable is not set.
        credentials = CredentialsStub()
        request = mock.create_autospec(transport.Request)
        target = "example.com:80"

        # Remove the variable for the duration of the call so the test does not
        # depend on the ambient environment; patch.dict restores os.environ.
        with mock.patch.dict(os.environ):
            os.environ.pop(environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE, None)
            channel = google.auth.transport.grpc.secure_authorized_channel(
                credentials, request, target, options=mock.sentinel.options
            )

        # Check the auth plugin construction.
        auth_plugin = metadata_call_credentials.call_args[0][0]
        assert isinstance(auth_plugin, google.auth.transport.grpc.AuthMetadataPlugin)
        assert auth_plugin._credentials == credentials
        assert auth_plugin._request == request

        # Check the ssl channel call.
        ssl_channel_credentials.assert_called_once()
        ssl_credentials_adc_method.assert_not_called()

        # Check the composite credentials call.
        composite_channel_credentials.assert_called_once_with(
            ssl_channel_credentials.return_value, metadata_call_credentials.return_value
        )

        # Check the channel call.
        secure_channel.assert_called_once_with(
            target,
            composite_channel_credentials.return_value,
            options=mock.sentinel.options,
        )
        assert channel == secure_channel.return_value

    def test_secure_authorized_channel_explicit_ssl(
        self,
        secure_channel,
        ssl_channel_credentials,
        metadata_call_credentials,
        composite_channel_credentials,
        get_client_ssl_credentials,
    ):
        """Explicit ``ssl_credentials`` are passed through to the composite call."""
        credentials = mock.Mock()
        request = mock.Mock()
        target = "example.com:80"
        ssl_credentials = mock.Mock()

        google.auth.transport.grpc.secure_authorized_channel(
            credentials, request, target, ssl_credentials=ssl_credentials
        )

        # Since explicit SSL credentials are provided, get_client_ssl_credentials
        # shouldn't be called.
        assert not get_client_ssl_credentials.called

        # Check the ssl channel call.
        assert not ssl_channel_credentials.called

        # Check the composite credentials call.
        composite_channel_credentials.assert_called_once_with(
            ssl_credentials, metadata_call_credentials.return_value
        )

    def test_secure_authorized_channel_mutual_exclusive(
        self,
        secure_channel,
        ssl_channel_credentials,
        metadata_call_credentials,
        composite_channel_credentials,
        get_client_ssl_credentials,
    ):
        """Passing both ``ssl_credentials`` and a cert callback raises ValueError."""
        credentials = mock.Mock()
        request = mock.Mock()
        target = "example.com:80"
        ssl_credentials = mock.Mock()
        client_cert_callback = mock.Mock()

        with pytest.raises(ValueError):
            google.auth.transport.grpc.secure_authorized_channel(
                credentials,
                request,
                target,
                ssl_credentials=ssl_credentials,
                client_cert_callback=client_cert_callback,
            )

    def test_secure_authorized_channel_with_client_cert_callback_success(
        self,
        secure_channel,
        ssl_channel_credentials,
        metadata_call_credentials,
        composite_channel_credentials,
        get_client_ssl_credentials,
    ):
        """The cert/key returned by the callback feed the SSL channel credentials."""
        credentials = mock.Mock()
        request = mock.Mock()
        target = "example.com:80"
        client_cert_callback = mock.Mock()
        client_cert_callback.return_value = (PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES)

        with mock.patch.dict(
            os.environ, {environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE: "true"}
        ):
            google.auth.transport.grpc.secure_authorized_channel(
                credentials, request, target, client_cert_callback=client_cert_callback
            )

        client_cert_callback.assert_called_once()

        # Check we are using the cert and key provided by client_cert_callback.
        ssl_channel_credentials.assert_called_once_with(
            certificate_chain=PUBLIC_CERT_BYTES, private_key=PRIVATE_KEY_BYTES
        )

        # Check the composite credentials call.
        composite_channel_credentials.assert_called_once_with(
            ssl_channel_credentials.return_value, metadata_call_credentials.return_value
        )

    @mock.patch(
        "google.auth.transport._mtls_helper._read_dca_metadata_file", autospec=True
    )
    @mock.patch(
        "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True
    )
    def test_secure_authorized_channel_with_client_cert_callback_failure(
        self,
        check_dca_metadata_path,
        read_dca_metadata_file,
        secure_channel,
        ssl_channel_credentials,
        metadata_call_credentials,
        composite_channel_credentials,
        get_client_ssl_credentials,
    ):
        """An exception raised by the cert callback propagates to the caller."""
        credentials = mock.Mock()
        request = mock.Mock()
        target = "example.com:80"

        client_cert_callback = mock.Mock()
        client_cert_callback.side_effect = Exception("callback exception")

        with pytest.raises(Exception) as excinfo:
            with mock.patch.dict(
                os.environ, {environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE: "true"}
            ):
                google.auth.transport.grpc.secure_authorized_channel(
                    credentials,
                    request,
                    target,
                    client_cert_callback=client_cert_callback,
                )

        # Matching the message confirms it is the callback's own exception.
        assert str(excinfo.value) == "callback exception"

    def test_secure_authorized_channel_cert_callback_without_client_cert_env(
        self,
        secure_channel,
        ssl_channel_credentials,
        metadata_call_credentials,
        composite_channel_credentials,
        get_client_ssl_credentials,
    ):
        """Without the client-cert env var, the cert callback is never invoked."""
        # Test client cert won't be used if GOOGLE_API_USE_CLIENT_CERTIFICATE
        # environment variable is not set.
        credentials = mock.Mock()
        request = mock.Mock()
        target = "example.com:80"
        client_cert_callback = mock.Mock()

        # Remove the variable for the duration of the call so the test does not
        # depend on the ambient environment; patch.dict restores os.environ.
        with mock.patch.dict(os.environ):
            os.environ.pop(environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE, None)
            google.auth.transport.grpc.secure_authorized_channel(
                credentials, request, target, client_cert_callback=client_cert_callback
            )

        # Check client_cert_callback is not called because GOOGLE_API_USE_CLIENT_CERTIFICATE
        # is not set.
        client_cert_callback.assert_not_called()

        ssl_channel_credentials.assert_called_once()

        # Check the composite credentials call.
        composite_channel_credentials.assert_called_once_with(
            ssl_channel_credentials.return_value, metadata_call_credentials.return_value
        )
402
403
@mock.patch("grpc.ssl_channel_credentials", autospec=True)
@mock.patch(
    "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True
)
@mock.patch("google.auth.transport._mtls_helper._read_dca_metadata_file", autospec=True)
@mock.patch(
    "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True
)
class TestSslCredentials(object):
    """Tests for ``google.auth.transport.grpc.SslCredentials``.

    The class-level patches replace the mTLS helper functions and
    ``grpc.ssl_channel_credentials``; each test receives those mocks as its
    trailing arguments (innermost decorator first).
    """

    def test_no_context_aware_metadata(
        self,
        mock_check_dca_metadata_path,
        mock_read_dca_metadata_file,
        mock_get_client_ssl_credentials,
        mock_ssl_channel_credentials,
    ):
        """With no metadata file, plain (non-mTLS) SSL credentials are created."""
        # Mock that the metadata file doesn't exist.
        mock_check_dca_metadata_path.return_value = None

        with mock.patch.dict(
            os.environ, {environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE: "true"}
        ):
            ssl_credentials = google.auth.transport.grpc.SslCredentials()

        # Since no context aware metadata is found, we wouldn't call
        # get_client_ssl_credentials, and the SSL channel credentials created is
        # non mTLS.
        assert ssl_credentials.ssl_credentials is not None
        assert not ssl_credentials.is_mtls
        mock_get_client_ssl_credentials.assert_not_called()
        mock_ssl_channel_credentials.assert_called_once_with()

    def test_get_client_ssl_credentials_failure(
        self,
        mock_check_dca_metadata_path,
        mock_read_dca_metadata_file,
        mock_get_client_ssl_credentials,
        mock_ssl_channel_credentials,
    ):
        """A ``ClientCertError`` from the helper becomes ``MutualTLSChannelError``."""
        mock_check_dca_metadata_path.return_value = METADATA_PATH
        mock_read_dca_metadata_file.return_value = {
            "cert_provider_command": ["some command"]
        }

        # Mock that client cert and key are not loaded and exception is raised.
        mock_get_client_ssl_credentials.side_effect = exceptions.ClientCertError()

        with pytest.raises(exceptions.MutualTLSChannelError):
            with mock.patch.dict(
                os.environ, {environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE: "true"}
            ):
                assert google.auth.transport.grpc.SslCredentials().ssl_credentials

    def test_get_client_ssl_credentials_success(
        self,
        mock_check_dca_metadata_path,
        mock_read_dca_metadata_file,
        mock_get_client_ssl_credentials,
        mock_ssl_channel_credentials,
    ):
        """A loaded cert/key pair yields mTLS SSL channel credentials."""
        mock_check_dca_metadata_path.return_value = METADATA_PATH
        mock_read_dca_metadata_file.return_value = {
            "cert_provider_command": ["some command"]
        }
        mock_get_client_ssl_credentials.return_value = (
            True,
            PUBLIC_CERT_BYTES,
            PRIVATE_KEY_BYTES,
            None,
        )

        with mock.patch.dict(
            os.environ, {environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE: "true"}
        ):
            ssl_credentials = google.auth.transport.grpc.SslCredentials()

        assert ssl_credentials.ssl_credentials is not None
        assert ssl_credentials.is_mtls
        mock_get_client_ssl_credentials.assert_called_once()
        mock_ssl_channel_credentials.assert_called_once_with(
            certificate_chain=PUBLIC_CERT_BYTES, private_key=PRIVATE_KEY_BYTES
        )

    def test_get_client_ssl_credentials_without_client_cert_env(
        self,
        mock_check_dca_metadata_path,
        mock_read_dca_metadata_file,
        mock_get_client_ssl_credentials,
        mock_ssl_channel_credentials,
    ):
        """Without the client-cert env var, none of the mTLS helpers are called."""
        # Test client cert won't be used if GOOGLE_API_USE_CLIENT_CERTIFICATE is not set.
        # Remove the variable while the credentials are built so the test does
        # not depend on the ambient environment; patch.dict restores os.environ.
        with mock.patch.dict(os.environ):
            os.environ.pop(environment_vars.GOOGLE_API_USE_CLIENT_CERTIFICATE, None)
            ssl_credentials = google.auth.transport.grpc.SslCredentials()

            assert ssl_credentials.ssl_credentials is not None
            assert not ssl_credentials.is_mtls

        mock_check_dca_metadata_path.assert_not_called()
        mock_read_dca_metadata_file.assert_not_called()
        mock_get_client_ssl_credentials.assert_not_called()
        mock_ssl_channel_credentials.assert_called_once()
503