• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16import json
17import os.path
18import shutil
19import tempfile
20import unittest
21
22import mock
23import six
24
25from apitools.base.py import credentials_lib
26from apitools.base.py import util
27
28
class MetadataMock(object):
    """Fake replacement for credentials_lib._GceMetadataRequest.

    Tests install an instance as the ``side_effect`` of a patched
    ``_GceMetadataRequest``. It answers three kinds of request: the scopes
    listing, the service-accounts listing, and the access-token endpoint for
    the configured service account. Any other URL is treated as a test
    failure.
    """

    def __init__(self, scopes=None, service_account_name=None):
        """Create the fake.

        Args:
          scopes: list of scope strings to report. Any falsy value
            (None or an empty list) becomes ['scope1'].
          service_account_name: service account to report and to accept
            token requests for. Any falsy value becomes 'default'.
        """
        self._scopes = scopes or ['scope1']
        self._sa = service_account_name or 'default'

    def __call__(self, request_url):
        """Return a file-like response for ``request_url``.

        Args:
          request_url: the metadata URL the code under test requested.

        Returns:
          A six.StringIO holding the canned response body.

        Raises:
          AssertionError: if ``request_url`` matches none of the expected
            endpoints.
        """
        if request_url.endswith('scopes'):
            # One scope per line. Joining with '' would fuse several scopes
            # into a single token. NOTE(review): assumes credentials_lib
            # reads scopes line by line, as the newline-separated account
            # listing in testGetServiceAccount suggests -- confirm.
            return six.StringIO('\n'.join(self._scopes))
        if request_url.endswith('service-accounts'):
            return six.StringIO(self._sa)
        if request_url.endswith('/service-accounts/%s/token' % self._sa):
            return six.StringIO('{"access_token": "token"}')
        # This class is not a TestCase, so it has no self.fail(). Raise the
        # exception type unittest itself reports as a test failure.
        raise AssertionError('Unexpected HTTP request to %s' % request_url)
44
45
class CredentialsLibTest(unittest.TestCase):
    """Tests for GceAssertionCredentials and application-default lookup."""

    def _RunGceAssertionCredentials(
            self, service_account_name=None, scopes=None, cache_filename=None):
        """Construct GceAssertionCredentials and refresh them once.

        Only ``service_account_name`` and ``cache_filename`` values that were
        explicitly supplied (not None) are forwarded as keyword arguments, so
        the constructor's own defaults apply otherwise.

        Returns:
          The credentials object, after asserting that ``_refresh`` returned
          None.
        """
        kwargs = {}
        if service_account_name is not None:
            kwargs['service_account_name'] = service_account_name
        if cache_filename is not None:
            kwargs['cache_filename'] = cache_filename
        credentials = credentials_lib.GceAssertionCredentials(
            scopes, **kwargs)
        self.assertIsNone(credentials._refresh(None))
        return credentials

    def _GetServiceCreds(self, service_account_name=None, scopes=None):
        """Build credentials against a MetadataMock with GCE detection forced.

        Also asserts that exactly three metadata requests were issued while
        the credentials were built and refreshed.

        Returns:
          The credentials object.
        """
        metadatamock = MetadataMock(scopes, service_account_name)
        with mock.patch.object(util, 'DetectGce', autospec=True) as gce_detect:
            gce_detect.return_value = True
            with mock.patch.object(credentials_lib,
                                   '_GceMetadataRequest',
                                   side_effect=metadatamock,
                                   autospec=True) as opener_mock:
                credentials = self._RunGceAssertionCredentials(
                    service_account_name=service_account_name,
                    scopes=scopes)
            self.assertEqual(3, opener_mock.call_count)
        return credentials

    def testGceServiceAccounts(self):
        """Credentials build for default and explicit account/scope combos."""
        scopes = ['scope1']
        self._GetServiceCreds(service_account_name=None,
                              scopes=None)
        self._GetServiceCreds(service_account_name=None,
                              scopes=scopes)
        self._GetServiceCreds(
            service_account_name='my_service_account',
            scopes=scopes)

    def testGceAssertionCredentialsToJson(self):
        """GceAssertionCredentials serialize to a valid JSON object.

        A JSON representation is needed to put the credentials in a
        credential Storage object.
        """
        scopes = ['scope1']
        service_account_name = 'my_service_account'
        original_creds = self._GetServiceCreds(
            service_account_name=service_account_name,
            scopes=scopes)
        original_creds_json_str = original_creds.to_json()
        self.assertIsInstance(json.loads(original_creds_json_str), dict)

    @mock.patch.object(util, 'DetectGce', autospec=True)
    def testGceServiceAccountsCached(self, mock_detect):
        """A second construction sharing a cache file skips scope lookups."""
        mock_detect.return_value = True
        tempd = tempfile.mkdtemp()
        # Register removal immediately so the directory is cleaned up even
        # if patching or credential construction below fails.
        self.addCleanup(shutil.rmtree, tempd)
        tempname = os.path.join(tempd, 'creds')
        scopes = ['scope1']
        service_account_name = 'some_service_account_name'
        metadatamock = MetadataMock(scopes, service_account_name)
        with mock.patch.object(credentials_lib,
                               '_GceMetadataRequest',
                               side_effect=metadatamock,
                               autospec=True) as opener_mock:
            creds1 = self._RunGceAssertionCredentials(
                service_account_name=service_account_name,
                cache_filename=tempname,
                scopes=scopes)
            pre_cache_call_count = opener_mock.call_count
            creds2 = self._RunGceAssertionCredentials(
                service_account_name=service_account_name,
                cache_filename=tempname,
                scopes=None)
        self.assertEqual(creds1.client_id, creds2.client_id)
        self.assertEqual(3, pre_cache_call_count)
        # Caching obviates the need for extra metadata server requests:
        # only one more request is made when the cache is hit.
        self.assertEqual(4, opener_mock.call_count)

    def testGetServiceAccount(self):
        """GetServiceAccount queries the metadata server with the right header.

        This test mocks urllib's build_opener, so it relies on knowing how
        the HTTP call is made. That coupling is unfortunate but unavoidable.
        """
        creds = self._GetServiceCreds()
        opener = mock.MagicMock()
        opener.open = mock.MagicMock()
        # The listing includes 'default/', so 'default' should be found.
        opener.open.return_value = six.StringIO('default/\nanother')
        with mock.patch.object(six.moves.urllib.request, 'build_opener',
                               return_value=opener,
                               autospec=True) as build_opener:
            self.assertTrue(creds.GetServiceAccount('default'))
            self.assertEqual(1, build_opener.call_count)
            self.assertEqual(1, opener.open.call_count)
            req = opener.open.call_args[0][0]
            self.assertTrue(req.get_full_url().startswith(
                'http://metadata.google.internal/'))
            # The urllib module does weird things with header case.
            self.assertEqual('Google', req.get_header('Metadata-flavor'))

    def testGetAdcNone(self):
        """None is returned when no ADC are present in the well-known file.

        NOTE(review): depends on the test environment having no application
        default credentials configured.
        """
        creds = credentials_lib._GetApplicationDefaultCredentials(
            client_info={'scope': ''})
        self.assertIsNone(creds)
154
155
class TestGetRunFlowFlags(unittest.TestCase):
    """Tests for credentials_lib._GetRunFlowFlags with and without gflags."""

    def setUp(self):
        # Each test may overwrite the module-level FLAGS; remember the
        # original so tearDown can put it back.
        self._flags_actual = credentials_lib.FLAGS

    def tearDown(self):
        credentials_lib.FLAGS = self._flags_actual

    def _assertFlagValues(self, flags, expectations):
        """Check each (attribute name, expected value) pair, in order."""
        for attr_name, wanted in expectations:
            self.assertEqual(getattr(flags, attr_name), wanted)

    def test_with_gflags(self):
        host = 'myhostname'
        port = '144169'

        class MockFlags(object):
            auth_host_name = host
            auth_host_port = port
            auth_local_webserver = False

        credentials_lib.FLAGS = MockFlags
        argv = [
            '--auth_host_name=' + host,
            '--auth_host_port=' + port,
            '--noauth_local_webserver',
        ]
        parsed = credentials_lib._GetRunFlowFlags(argv)
        self._assertFlagValues(parsed, [
            ('auth_host_name', host),
            ('auth_host_port', port),
            ('logging_level', 'ERROR'),
            ('noauth_local_webserver', True),
        ])

    def test_without_gflags(self):
        credentials_lib.FLAGS = None
        parsed = credentials_lib._GetRunFlowFlags([])
        self._assertFlagValues(parsed, [
            ('auth_host_name', 'localhost'),
            ('auth_host_port', [8080, 8090]),
            ('logging_level', 'ERROR'),
            ('noauth_local_webserver', False),
        ])
191