#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for apitools.base.py.credentials_lib."""

import json
import os.path
import shutil
import tempfile
import unittest

import mock
import six

from apitools.base.py import credentials_lib
from apitools.base.py import util


class MetadataMock(object):

    """Fake replacement for credentials_lib._GceMetadataRequest.

    An instance is installed as the ``side_effect`` of a patched
    ``_GceMetadataRequest``: each call receives the requested URL and
    returns a file-like ``six.StringIO`` holding a canned response.
    """

    def __init__(self, scopes=None, service_account_name=None):
        """Configure the canned responses.

        Args:
          scopes: list of scope strings to report; defaults to ['scope1'].
          service_account_name: account name to report; defaults to
            'default'.
        """
        self._scopes = scopes or ['scope1']
        self._sa = service_account_name or 'default'

    def __call__(self, request_url):
        """Return the canned metadata response for ``request_url``.

        Args:
          request_url: the metadata URL the code under test requested.

        Returns:
          A six.StringIO with the response body.

        Raises:
          AssertionError: if ``request_url`` matches none of the
            endpoints this fake knows how to answer.
        """
        if request_url.endswith('scopes'):
            # One scope per line. A plain ''.join would fuse several
            # scopes into a single bogus token. NOTE(review): assumes the
            # library parses this response line by line; confirm against
            # credentials_lib.
            return six.StringIO('\n'.join(self._scopes))
        elif request_url.endswith('service-accounts'):
            return six.StringIO(self._sa)
        elif request_url.endswith(
                '/service-accounts/%s/token' % self._sa):
            return six.StringIO('{"access_token": "token"}')
        # This class is not a TestCase, so it has no self.fail(). Raise the
        # exception TestCase.fail() would raise, so that an unexpected
        # request surfaces as a test failure rather than an AttributeError.
        raise AssertionError('Unexpected HTTP request to %s' % request_url)


class CredentialsLibTest(unittest.TestCase):

    """Tests for GCE assertion credentials and ADC lookup."""

    def _RunGceAssertionCredentials(
            self, service_account_name=None, scopes=None, cache_filename=None):
        """Build GceAssertionCredentials and check that _refresh returns None.

        Only arguments that were explicitly provided are forwarded, so the
        library's own defaults are exercised when they are omitted.

        Returns:
          The constructed credentials object.
        """
        kwargs = {}
        if service_account_name is not None:
            kwargs['service_account_name'] = service_account_name
        if cache_filename is not None:
            kwargs['cache_filename'] = cache_filename
        credentials = credentials_lib.GceAssertionCredentials(
            scopes, **kwargs)
        self.assertIsNone(credentials._refresh(None))
        return credentials

    def _GetServiceCreds(self, service_account_name=None, scopes=None):
        """Create GCE credentials against a mocked metadata server.

        GCE detection is forced on and every metadata request is served
        by MetadataMock. The test also asserts that exactly three metadata
        requests were issued.

        Returns:
          The constructed credentials object.
        """
        metadatamock = MetadataMock(scopes, service_account_name)
        with mock.patch.object(util, 'DetectGce', autospec=True) as gce_detect:
            gce_detect.return_value = True
            with mock.patch.object(credentials_lib,
                                   '_GceMetadataRequest',
                                   side_effect=metadatamock,
                                   autospec=True) as opener_mock:
                credentials = self._RunGceAssertionCredentials(
                    service_account_name=service_account_name,
                    scopes=scopes)
                self.assertEqual(3, opener_mock.call_count)
        return credentials

    def testGceServiceAccounts(self):
        scopes = ['scope1']
        self._GetServiceCreds(service_account_name=None,
                              scopes=None)
        self._GetServiceCreds(service_account_name=None,
                              scopes=scopes)
        self._GetServiceCreds(
            service_account_name='my_service_account',
            scopes=scopes)

    def testGceAssertionCredentialsToJson(self):
        scopes = ['scope1']
        service_account_name = 'my_service_account'
        # Ensure that we can obtain a JSON representation of
        # GceAssertionCredentials to put in a credential Storage object, and
        # that the JSON representation is valid.
        original_creds = self._GetServiceCreds(
            service_account_name=service_account_name,
            scopes=scopes)
        original_creds_json_str = original_creds.to_json()
        json.loads(original_creds_json_str)

    @mock.patch.object(util, 'DetectGce', autospec=True)
    def testGceServiceAccountsCached(self, mock_detect):
        mock_detect.return_value = True
        tempd = tempfile.mkdtemp()
        tempname = os.path.join(tempd, 'creds')
        scopes = ['scope1']
        service_account_name = 'some_service_account_name'
        metadatamock = MetadataMock(scopes, service_account_name)
        with mock.patch.object(credentials_lib,
                               '_GceMetadataRequest',
                               side_effect=metadatamock,
                               autospec=True) as opener_mock:
            try:
                creds1 = self._RunGceAssertionCredentials(
                    service_account_name=service_account_name,
                    cache_filename=tempname,
                    scopes=scopes)
                pre_cache_call_count = opener_mock.call_count
                # The second construction passes no scopes, so they must
                # come from the cache file written by the first one.
                creds2 = self._RunGceAssertionCredentials(
                    service_account_name=service_account_name,
                    cache_filename=tempname,
                    scopes=None)
            finally:
                shutil.rmtree(tempd)
        self.assertEqual(creds1.client_id, creds2.client_id)
        self.assertEqual(pre_cache_call_count, 3)
        # Caching obviates the need for extra metadata server requests.
        # Only one metadata request is made if the cache is hit.
        self.assertEqual(opener_mock.call_count, 4)

    def testGetServiceAccount(self):
        # We'd also like to test the metadata calls, which requires
        # having some knowledge about how HTTP calls are made (so that
        # we can mock them). It's unfortunate, but there's no way
        # around it.
        creds = self._GetServiceCreds()
        opener = mock.MagicMock()
        opener.open = mock.MagicMock()
        opener.open.return_value = six.StringIO('default/\nanother')
        with mock.patch.object(six.moves.urllib.request, 'build_opener',
                               return_value=opener,
                               autospec=True) as build_opener:
            creds.GetServiceAccount('default')
            self.assertEqual(1, build_opener.call_count)
            self.assertEqual(1, opener.open.call_count)
            req = opener.open.call_args[0][0]
            self.assertTrue(req.get_full_url().startswith(
                'http://metadata.google.internal/'))
            # The urllib module does weird things with header case.
            self.assertEqual('Google', req.get_header('Metadata-flavor'))

    def testGetAdcNone(self):
        # Tests that we correctly return None when ADC aren't present in
        # the well-known file.
        # NOTE(review): nothing is mocked here, so this presumably depends
        # on the host having no application-default credentials file;
        # confirm it is isolated from the developer's environment.
        creds = credentials_lib._GetApplicationDefaultCredentials(
            client_info={'scope': ''})
        self.assertIsNone(creds)


class TestGetRunFlowFlags(unittest.TestCase):

    """Tests for credentials_lib._GetRunFlowFlags with and without gflags."""

    def setUp(self):
        # Tests overwrite the module-level FLAGS; remember the real value.
        self._flags_actual = credentials_lib.FLAGS

    def tearDown(self):
        credentials_lib.FLAGS = self._flags_actual

    def test_with_gflags(self):
        HOST = 'myhostname'
        PORT = '144169'

        class MockFlags(object):
            auth_host_name = HOST
            auth_host_port = PORT
            auth_local_webserver = False

        credentials_lib.FLAGS = MockFlags
        flags = credentials_lib._GetRunFlowFlags([
            '--auth_host_name=%s' % HOST,
            '--auth_host_port=%s' % PORT,
            '--noauth_local_webserver',
        ])
        self.assertEqual(flags.auth_host_name, HOST)
        self.assertEqual(flags.auth_host_port, PORT)
        self.assertEqual(flags.logging_level, 'ERROR')
        self.assertEqual(flags.noauth_local_webserver, True)

    def test_without_gflags(self):
        credentials_lib.FLAGS = None
        flags = credentials_lib._GetRunFlowFlags([])
        self.assertEqual(flags.auth_host_name, 'localhost')
        self.assertEqual(flags.auth_host_port, [8080, 8090])
        self.assertEqual(flags.logging_level, 'ERROR')
        self.assertEqual(flags.noauth_local_webserver, False)