#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import mock
import six
import unittest2

from apitools.base.py import credentials_lib
from apitools.base.py import util


class CredentialsLibTest(unittest2.TestCase):
    """Tests for the GCE and application-default credential helpers."""

    def _GetServiceCreds(self, service_account_name=None, scopes=None):
        """Build GceAssertionCredentials against a mocked metadata server.

        Args:
          service_account_name: Account name passed to the constructor. When
            None the keyword is omitted entirely, and the mock answers as
            the 'default' account.
          scopes: Scopes passed to the constructor. When None (or empty) the
            mocked metadata server reports ['scope1'].

        Returns:
          The constructed credentials, after one refresh that did not raise.
        """
        kwargs = {}
        if service_account_name is not None:
            kwargs['service_account_name'] = service_account_name
        service_account_name = service_account_name or 'default'

        def MockMetadataCalls(request_url):
            # Fake _GceMetadataRequest: answer the three endpoints this
            # flow is expected to hit and fail loudly on anything else.
            default_scopes = scopes or ['scope1']
            if request_url.endswith('scopes'):
                # The metadata server lists one scope per line; joining with
                # '' would fuse multi-scope fixtures into one bogus scope.
                return six.StringIO('\n'.join(default_scopes))
            elif request_url.endswith('service-accounts'):
                return six.StringIO(service_account_name)
            elif request_url.endswith(
                    '/service-accounts/%s/token' % service_account_name):
                return six.StringIO('{"access_token": "token"}')
            self.fail('Unexpected HTTP request to %s' % request_url)

        with mock.patch.object(credentials_lib, '_GceMetadataRequest',
                               side_effect=MockMetadataCalls,
                               autospec=True) as opener_mock:
            # Pretend we are running on GCE so the constructor proceeds.
            with mock.patch.object(util, 'DetectGce',
                                   autospec=True) as mock_detect:
                mock_detect.return_value = True
                credentials = credentials_lib.GceAssertionCredentials(
                    scopes, **kwargs)
                self.assertIsNone(credentials._refresh(None))
            # Construction plus one refresh should cost exactly three
            # metadata requests.
            self.assertEqual(3, opener_mock.call_count)
        return credentials

    def testGceServiceAccounts(self):
        """Credentials build with default, explicit-scope and named accounts."""
        scopes = ['scope1']
        self._GetServiceCreds()
        self._GetServiceCreds(scopes=scopes)
        self._GetServiceCreds(service_account_name='my_service_account',
                              scopes=scopes)

    def testGetServiceAccount(self):
        """GetServiceAccount queries the metadata host with the GCE header."""
        # We'd also like to test the metadata calls, which requires
        # having some knowledge about how HTTP calls are made (so that
        # we can mock them). It's unfortunate, but there's no way
        # around it.
        creds = self._GetServiceCreds()
        # MagicMock auto-creates `open` as a child MagicMock; only its
        # return value needs configuring.
        opener = mock.MagicMock()
        opener.open.return_value = six.StringIO('default/\nanother')
        with mock.patch.object(six.moves.urllib.request, 'build_opener',
                               return_value=opener,
                               autospec=True) as build_opener:
            creds.GetServiceAccount('default')
            self.assertEqual(1, build_opener.call_count)
            self.assertEqual(1, opener.open.call_count)
            req = opener.open.call_args[0][0]
            self.assertTrue(req.get_full_url().startswith(
                'http://metadata.google.internal/'))
            # urllib.request.Request stores header names via
            # str.capitalize(), so 'Metadata-Flavor' is looked up as
            # 'Metadata-flavor'.
            self.assertEqual('Google', req.get_header('Metadata-flavor'))

    def testGetAdcNone(self):
        """ADC lookup returns None when no usable ADC is found."""
        # Tests that we correctly return None when ADC aren't present in
        # the well-known file.
        # NOTE(review): this consults the real environment; it may fail on
        # a machine that has application-default credentials configured.
        creds = credentials_lib._GetApplicationDefaultCredentials(
            client_info={'scope': ''})
        self.assertIsNone(creds)


class TestGetRunFlowFlags(unittest2.TestCase):
    """Tests for credentials_lib._GetRunFlowFlags with and without FLAGS."""

    def setUp(self):
        # Tests overwrite the module-level FLAGS; remember the real value.
        self._flags_actual = credentials_lib.FLAGS

    def tearDown(self):
        credentials_lib.FLAGS = self._flags_actual

    def test_with_gflags(self):
        """With a FLAGS object installed, the given values round-trip."""
        HOST = 'myhostname'
        # Not a valid TCP port; it is only compared, never bound.
        PORT = '144169'

        class MockFlags(object):
            auth_host_name = HOST
            auth_host_port = PORT
            auth_local_webserver = False

        credentials_lib.FLAGS = MockFlags
        flags = credentials_lib._GetRunFlowFlags([
            '--auth_host_name=%s' % HOST,
            '--auth_host_port=%s' % PORT,
            '--noauth_local_webserver',
        ])
        self.assertEqual(flags.auth_host_name, HOST)
        self.assertEqual(flags.auth_host_port, PORT)
        self.assertEqual(flags.logging_level, 'ERROR')
        self.assertEqual(flags.noauth_local_webserver, True)

    def test_without_gflags(self):
        """With FLAGS set to None, the defaults asserted below are used."""
        credentials_lib.FLAGS = None
        flags = credentials_lib._GetRunFlowFlags([])
        self.assertEqual(flags.auth_host_name, 'localhost')
        self.assertEqual(flags.auth_host_port, [8080, 8090])
        self.assertEqual(flags.logging_level, 'ERROR')
        self.assertEqual(flags.noauth_local_webserver, False)