# Copyright 2014 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for oauth2_client and related classes."""

from __future__ import absolute_import

import datetime
import httplib2
import logging
import mox
import os
import stat
import sys
import unittest

from freezegun import freeze_time

from gcs_oauth2_boto_plugin import oauth2_client

LOG = logging.getLogger('test_oauth2_client')

ACCESS_TOKEN = 'abc123'
TOKEN_URI = 'https://provider.example.com/oauth/provider?mode=token'
AUTH_URI = 'https://provider.example.com/oauth/provider?mode=authorize'
# NOTE(review): this path is resolved against the current working directory,
# so it presumably expects the tests to be run from the repository root.
DEFAULT_CA_CERTS_FILE = os.path.abspath(
    os.path.join('gslib', 'data', 'cacerts.txt'))

IS_WINDOWS = 'win32' in str(sys.platform).lower()


class MockDateTime(object):
  """Controllable clock handed to the code under test as datetime_strategy.

  utcnow() simply returns whatever the test last stored in mock_now.
  """

  def __init__(self):
    self.mock_now = None

  def utcnow(self):  # pylint: disable=invalid-name
    return self.mock_now


class _MockTokenFetchMixin(object):
  """Fake token fetching shared by the mock OAuth2 account clients.

  Instead of contacting a token endpoint, FetchAccessToken records that a
  fetch happened and returns ACCESS_TOKEN expiring 3600 seconds after the
  client's current datetime_strategy time. The mixin must precede the real
  client class in the bases so these methods take precedence.
  """

  def Reset(self):
    """Clears the flag that records whether FetchAccessToken was called."""
    self.fetched_token = False

  def FetchAccessToken(self):
    """Returns a fresh fake AccessToken and marks that a fetch happened."""
    self.fetched_token = True
    return oauth2_client.AccessToken(
        ACCESS_TOKEN,
        GetExpiry(self.datetime_strategy, 3600),
        datetime_strategy=self.datetime_strategy)


class MockOAuth2ServiceAccountClient(
    _MockTokenFetchMixin, oauth2_client.OAuth2ServiceAccountClient):
  """Mock service account client for testing OAuth2 with service accounts."""

  def __init__(self, client_id, private_key, password, auth_uri, token_uri,
               datetime_strategy):
    super(MockOAuth2ServiceAccountClient, self).__init__(
        client_id, private_key, password, auth_uri=auth_uri,
        token_uri=token_uri, datetime_strategy=datetime_strategy,
        ca_certs_file=DEFAULT_CA_CERTS_FILE)
    self.Reset()


class MockOAuth2UserAccountClient(
    _MockTokenFetchMixin, oauth2_client.OAuth2UserAccountClient):
  """Mock user account client for testing OAuth2 with user accounts."""

  def __init__(self, token_uri, client_id, client_secret, refresh_token,
               auth_uri, datetime_strategy):
    super(MockOAuth2UserAccountClient, self).__init__(
        token_uri, client_id, client_secret, refresh_token, auth_uri=auth_uri,
        datetime_strategy=datetime_strategy,
        ca_certs_file=DEFAULT_CA_CERTS_FILE)
    self.Reset()


def GetExpiry(datetime_strategy, length_in_seconds):
  """Returns the time length_in_seconds after datetime_strategy.utcnow()."""
  return (datetime_strategy.utcnow()
          + datetime.timedelta(seconds=length_in_seconds))


def CreateMockUserAccountClient(mock_datetime):
  """Returns a MockOAuth2UserAccountClient with fixed test credentials."""
  return MockOAuth2UserAccountClient(
      TOKEN_URI, 'clid', 'clsecret', 'ref_token_abc123', AUTH_URI,
      mock_datetime)


def CreateMockServiceAccountClient(mock_datetime):
  """Returns a MockOAuth2ServiceAccountClient with fixed test credentials."""
  return MockOAuth2ServiceAccountClient(
      'clid', 'private_key', 'password', AUTH_URI, TOKEN_URI, mock_datetime)


class OAuth2AccountClientTest(unittest.TestCase):
  """Unit tests for OAuth2UserAccountClient and OAuth2ServiceAccountClient."""

  def setUp(self):
    self.mock_datetime = MockDateTime()
    self.start_time = datetime.datetime(2011, 3, 1, 11, 25, 13, 300826)
    self.mock_datetime.mock_now = self.start_time

  def testGetAccessTokenUserAccount(self):
    self.client = CreateMockUserAccountClient(self.mock_datetime)
    self._RunGetAccessTokenTest()

  def testGetAccessTokenServiceAccount(self):
    self.client = CreateMockServiceAccountClient(self.mock_datetime)
    self._RunGetAccessTokenTest()

  def _RunGetAccessTokenTest(self):
    """Tests access token gets with self.client.

    self.client must be one of the mock clients above and must share
    self.mock_datetime as its datetime_strategy, so that advancing mock_now
    here moves the client's clock as well.
    """
    self.assertFalse(self.client.fetched_token)
    token_1 = self.client.GetAccessToken()

    # There's no access token in the cache; verify that we fetched a fresh
    # token.
    self.assertTrue(self.client.fetched_token)
    self.assertEqual(ACCESS_TOKEN, token_1.token)
    self.assertEqual(self.start_time + datetime.timedelta(minutes=60),
                     token_1.expiry)

    # Advance time by 55 minutes: not yet due for refresh (AccessTokenTest
    # below expects ShouldRefresh() to stay False up to exactly 55 minutes
    # before a 60-minute expiry), then fetch another token.
    self.client.Reset()
    self.mock_datetime.mock_now = (
        self.start_time + datetime.timedelta(minutes=55))
    token_2 = self.client.GetAccessToken()

    # Since the access token wasn't due for refresh, we get the cached token,
    # and there was no refresh request.
    self.assertEqual(token_1, token_2)
    self.assertEqual(ACCESS_TOKEN, token_2.token)
    self.assertFalse(self.client.fetched_token)

    # Advance time one second past the refresh point and fetch another token.
    self.client.Reset()
    self.mock_datetime.mock_now = (
        self.start_time + datetime.timedelta(minutes=55, seconds=1))
    token_3 = self.client.GetAccessToken()

    # This should have resulted in a refresh request and a fresh access token
    # whose expiry is measured from the new "now".
    self.assertTrue(self.client.fetched_token)
    self.assertEqual(
        self.mock_datetime.mock_now + datetime.timedelta(minutes=60),
        token_3.expiry)


class AccessTokenTest(unittest.TestCase):
  """Unit tests for access token functions."""

  def testShouldRefresh(self):
    """Tests that token.ShouldRefresh returns the correct value."""
    mock_datetime = MockDateTime()
    start = datetime.datetime(2011, 3, 1, 11, 25, 13, 300826)
    expiry = start + datetime.timedelta(minutes=60)
    token = oauth2_client.AccessToken(
        'foo', expiry, datetime_strategy=mock_datetime)

    mock_datetime.mock_now = start
    self.assertFalse(token.ShouldRefresh())

    mock_datetime.mock_now = start + datetime.timedelta(minutes=54)
    self.assertFalse(token.ShouldRefresh())

    # With the default window, exactly 5 minutes before expiry is the last
    # moment at which no refresh is expected.
    mock_datetime.mock_now = start + datetime.timedelta(minutes=55)
    self.assertFalse(token.ShouldRefresh())

    mock_datetime.mock_now = start + datetime.timedelta(
        minutes=55, seconds=1)
    self.assertTrue(token.ShouldRefresh())

    mock_datetime.mock_now = start + datetime.timedelta(
        minutes=61)
    self.assertTrue(token.ShouldRefresh())

    # An explicit time_delta of 120 seconds moves the boundary to 2 minutes
    # before expiry.
    mock_datetime.mock_now = start + datetime.timedelta(minutes=58)
    self.assertFalse(token.ShouldRefresh(time_delta=120))

    mock_datetime.mock_now = start + datetime.timedelta(
        minutes=58, seconds=1)
    self.assertTrue(token.ShouldRefresh(time_delta=120))

  def testShouldRefreshNoExpiry(self):
    """Tests token.ShouldRefresh with no expiry time."""
    mock_datetime = MockDateTime()
    start = datetime.datetime(2011, 3, 1, 11, 25, 13, 300826)
    token = oauth2_client.AccessToken(
        'foo', None, datetime_strategy=mock_datetime)

    mock_datetime.mock_now = start
    self.assertFalse(token.ShouldRefresh())

    mock_datetime.mock_now = start + datetime.timedelta(
        minutes=472)
    self.assertFalse(token.ShouldRefresh())

  def testSerialization(self):
    """Tests token serialization."""
    expiry = datetime.datetime(2011, 3, 1, 11, 25, 13, 300826)
    token = oauth2_client.AccessToken('foo', expiry)
    serialized_token = token.Serialize()
    LOG.debug('testSerialization: serialized_token=%s', serialized_token)

    token2 = oauth2_client.AccessToken.UnSerialize(serialized_token)
    self.assertEqual(token, token2)


class FileSystemTokenCacheTest(unittest.TestCase):
  """Unit tests for FileSystemTokenCache."""

  def setUp(self):
    self.cache = oauth2_client.FileSystemTokenCache()
    self.start_time = datetime.datetime(2011, 3, 1, 10, 25, 13, 300826)
    self.token_1 = oauth2_client.AccessToken('token1', self.start_time)
    self.token_2 = oauth2_client.AccessToken(
        'token2', self.start_time + datetime.timedelta(seconds=492))
    self.key = 'token1key'

  def tearDown(self):
    # Best-effort cleanup: not every test leaves a cache file behind, so a
    # missing file (or any other filesystem error) is deliberately ignored.
    try:
      os.unlink(self.cache.CacheFileName(self.key))
    except OSError:
      pass

  def testPut(self):
    """Tests that PutToken writes a cache file readable only by its owner."""
    self.cache.PutToken(self.key, self.token_1)
    # Assert that the cache file exists and has correct permissions. The mode
    # is written 0o600 (not 0600) so the file also parses on Python 3.
    if not IS_WINDOWS:
      self.assertEqual(
          0o600,
          stat.S_IMODE(os.stat(self.cache.CacheFileName(self.key)).st_mode))

  def testPutGet(self):
    """Tests putting and getting various tokens."""
    # No cache file present.
    self.assertEqual(None, self.cache.GetToken(self.key))

    # Put a token
    self.cache.PutToken(self.key, self.token_1)
    cached_token = self.cache.GetToken(self.key)
    self.assertEqual(self.token_1, cached_token)

    # Put a different token
    self.cache.PutToken(self.key, self.token_2)
    cached_token = self.cache.GetToken(self.key)
    self.assertEqual(self.token_2, cached_token)

  def testGetBadFile(self):
    """Tests that a corrupt cache file yields no token."""
    with open(self.cache.CacheFileName(self.key), 'w') as f:
      f.write('blah')
    self.assertEqual(None, self.cache.GetToken(self.key))

  def testCacheFileName(self):
    """Tests configuring the cache with a specific file name."""
    cache = oauth2_client.FileSystemTokenCache(
        path_pattern='/var/run/ccache/token.%(uid)s.%(key)s')
    if IS_WINDOWS:
      uid = '_'
    else:
      uid = os.getuid()
    self.assertEqual('/var/run/ccache/token.%s.abc123' % uid,
                     cache.CacheFileName('abc123'))

    cache = oauth2_client.FileSystemTokenCache(
        path_pattern='/var/run/ccache/token.%(key)s')
    self.assertEqual('/var/run/ccache/token.abc123',
                     cache.CacheFileName('abc123'))


class RefreshTokenTest(unittest.TestCase):
  """Unit tests for refresh tokens."""

  def setUp(self):
    self.mock_datetime = MockDateTime()
    self.start_time = datetime.datetime(2011, 3, 1, 10, 25, 13, 300826)
    self.mock_datetime.mock_now = self.start_time
    self.client = CreateMockUserAccountClient(self.mock_datetime)

  def testUniqueId(self):
    """Tests that CacheKey is stable for the fixed mock credentials."""
    cred_id = self.client.CacheKey()
    self.assertEqual('0720afed6871f12761fbea3271f451e6ba184bf5', cred_id)

  def testGetAuthorizationHeader(self):
    self.assertEqual('Bearer %s' % ACCESS_TOKEN,
                     self.client.GetAuthorizationHeader())


class FakeResponse(object):
  """Minimal HTTP response stand-in exposing only a read-only `status`.

  Used as the response half of the (response, content) tuples returned by
  the mocked httplib2.Http.request calls below.
  """

  def __init__(self, status):
    self._status = status

  @property
  def status(self):
    return self._status


class OAuth2GCEClientTest(unittest.TestCase):
  """Unit tests for OAuth2GCEClient."""

  def setUp(self):
    # Replace httplib2.Http with a mox mock class; the instance created here
    # is the one the code under test receives when it constructs an Http.
    self.mox = mox.Mox()
    self.mox.StubOutClassWithMocks(httplib2, 'Http')
    self.mock_http = httplib2.Http()

  def tearDown(self):
    self.mox.UnsetStubs()

  @freeze_time('2014-03-26 01:01:01')
  def testFetchAccessToken(self):
    token = 'my_token'

    self.mock_http.request(
        oauth2_client.META_TOKEN_URI,
        method='GET',
        body=None,
        headers=oauth2_client.META_HEADERS).AndReturn((
            FakeResponse(200),
            '{"access_token":"%(TOKEN)s",'
            '"expires_in": %(EXPIRES_IN)d}' % {
                'TOKEN': token,
                'EXPIRES_IN': 42
            }))

    self.mox.ReplayAll()

    client = oauth2_client.OAuth2GCEClient()

    # Frozen time 01:01:01 plus expires_in=42 seconds gives 01:01:43.
    self.assertEqual(
        str(client.FetchAccessToken()),
        'AccessToken(token=%s, expiry=2014-03-26 01:01:43Z)' % token)

    self.mox.VerifyAll()

  def testIsGCENotFound(self):
    self.mock_http.request(oauth2_client.METADATA_SERVER).AndReturn((
        FakeResponse(404), ''))

    self.mox.ReplayAll()
    self.assertFalse(oauth2_client._IsGCE())

    self.mox.VerifyAll()

  def testIsGCEServerNotFound(self):
    self.mock_http.request(oauth2_client.METADATA_SERVER).AndRaise(
        httplib2.ServerNotFoundError)

    self.mox.ReplayAll()
    self.assertFalse(oauth2_client._IsGCE())

    self.mox.VerifyAll()

  def testIsGCETrue(self):
    self.mock_http.request(oauth2_client.METADATA_SERVER).AndReturn((
        FakeResponse(200), ''))

    self.mox.ReplayAll()
    self.assertTrue(oauth2_client._IsGCE())

    self.mox.VerifyAll()


if __name__ == '__main__':
  unittest.main()