• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2014 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15"""Unit tests for oauth2_client and related classes."""
16
17from __future__ import absolute_import
18
19import datetime
20import httplib2
21import logging
22import mox
23import os
24import stat
25import sys
26import unittest
27
28from freezegun import freeze_time
29
30from gcs_oauth2_boto_plugin import oauth2_client
31
LOG = logging.getLogger('test_oauth2_client')

# Fake OAuth2 values shared by the mock clients and the tests below.
ACCESS_TOKEN = 'abc123'
_PROVIDER_URL = 'https://provider.example.com/oauth/provider'
TOKEN_URI = _PROVIDER_URL + '?mode=token'
AUTH_URI = _PROVIDER_URL + '?mode=authorize'

# CA bundle path handed to the mock clients; resolved against the current
# working directory when this module is imported.
DEFAULT_CA_CERTS_FILE = os.path.abspath(
    os.path.join('gslib', 'data', 'cacerts.txt'))

# True when sys.platform identifies a Windows host ('win32').
IS_WINDOWS = 'win32' in sys.platform.lower()
41
42
class MockDateTime(object):
  """Controllable stand-in for a datetime strategy.

  utcnow() reports whatever value is currently stored in ``mock_now``, which
  is None until a test assigns it. This lets tests move the clock explicitly.
  """

  def __init__(self):
    # Tests assign this attribute directly to simulate the passage of time.
    self.mock_now = None

  def utcnow(self):  # pylint: disable=invalid-name
    """Returns the fake current time configured by the test."""
    current = self.mock_now
    return current
49
50
class MockOAuth2ServiceAccountClient(oauth2_client.OAuth2ServiceAccountClient):
  """Service account client whose token fetch is faked for tests.

  FetchAccessToken never contacts a server. It records that a fetch happened
  in ``fetched_token`` and returns ACCESS_TOKEN, expiring one hour after the
  time reported by the client's datetime strategy.
  """

  def __init__(self, client_id, private_key, password, auth_uri, token_uri,
               datetime_strategy):
    base = super(MockOAuth2ServiceAccountClient, self)
    base.__init__(
        client_id, private_key, password,
        auth_uri=auth_uri,
        token_uri=token_uri,
        datetime_strategy=datetime_strategy,
        ca_certs_file=DEFAULT_CA_CERTS_FILE)
    self.Reset()

  def Reset(self):
    """Clears the flag recording whether FetchAccessToken was called."""
    self.fetched_token = False

  def FetchAccessToken(self):
    """Marks that a fetch happened and returns a fresh one-hour fake token."""
    self.fetched_token = True
    clock = self.datetime_strategy
    expires_at = GetExpiry(clock, 3600)
    return oauth2_client.AccessToken(
        ACCESS_TOKEN, expires_at, datetime_strategy=clock)
71
72
class MockOAuth2UserAccountClient(oauth2_client.OAuth2UserAccountClient):
  """User account client whose token fetch is faked for tests.

  FetchAccessToken never contacts a server. It records that a fetch happened
  in ``fetched_token`` and returns ACCESS_TOKEN, expiring one hour after the
  time reported by the client's datetime strategy.
  """

  def __init__(self, token_uri, client_id, client_secret, refresh_token,
               auth_uri, datetime_strategy):
    base = super(MockOAuth2UserAccountClient, self)
    base.__init__(
        token_uri, client_id, client_secret, refresh_token,
        auth_uri=auth_uri,
        datetime_strategy=datetime_strategy,
        ca_certs_file=DEFAULT_CA_CERTS_FILE)
    self.Reset()

  def Reset(self):
    """Clears the flag recording whether FetchAccessToken was called."""
    self.fetched_token = False

  def FetchAccessToken(self):
    """Marks that a fetch happened and returns a fresh one-hour fake token."""
    self.fetched_token = True
    clock = self.datetime_strategy
    expires_at = GetExpiry(clock, 3600)
    return oauth2_client.AccessToken(
        ACCESS_TOKEN, expires_at, datetime_strategy=clock)
93
94
def GetExpiry(datetime_strategy, length_in_seconds):
  """Returns the time ``length_in_seconds`` after datetime_strategy.utcnow().

  Args:
    datetime_strategy: object exposing utcnow(), e.g. MockDateTime.
    length_in_seconds: offset to add, in seconds (may be negative).

  Returns:
    utcnow() plus the offset as a timedelta.
  """
  now = datetime_strategy.utcnow()
  return now + datetime.timedelta(seconds=length_in_seconds)
99
100
def CreateMockUserAccountClient(mock_datetime):
  """Builds a MockOAuth2UserAccountClient with fixed fake credentials.

  Args:
    mock_datetime: datetime strategy the client (and its tokens) will use.
  """
  return MockOAuth2UserAccountClient(
      token_uri=TOKEN_URI,
      client_id='clid',
      client_secret='clsecret',
      refresh_token='ref_token_abc123',
      auth_uri=AUTH_URI,
      datetime_strategy=mock_datetime)
105
106
def CreateMockServiceAccountClient(mock_datetime):
  """Builds a MockOAuth2ServiceAccountClient with fixed fake credentials.

  Args:
    mock_datetime: datetime strategy the client (and its tokens) will use.
  """
  return MockOAuth2ServiceAccountClient(
      client_id='clid',
      private_key='private_key',
      password='password',
      auth_uri=AUTH_URI,
      token_uri=TOKEN_URI,
      datetime_strategy=mock_datetime)
110
111
class OAuth2AccountClientTest(unittest.TestCase):
  """Unit tests for OAuth2UserAccountClient and OAuth2ServiceAccountClient."""

  def setUp(self):
    self.mock_datetime = MockDateTime()
    self.start_time = datetime.datetime(2011, 3, 1, 11, 25, 13, 300826)
    self.mock_datetime.mock_now = self.start_time

  def testGetAccessTokenUserAccount(self):
    self.client = CreateMockUserAccountClient(self.mock_datetime)
    self._RunGetAccessTokenTest()

  def testGetAccessTokenServiceAccount(self):
    self.client = CreateMockServiceAccountClient(self.mock_datetime)
    self._RunGetAccessTokenTest()

  def _RunGetAccessTokenTest(self):
    """Tests access token caching and refresh with self.client.

    The mock clients issue tokens that expire 60 minutes after the fetch.
    These expectations assume GetAccessToken reuses the cached token through
    minute 55 and fetches a new one a second later.
    """
    self.assertFalse(self.client.fetched_token)
    token_1 = self.client.GetAccessToken()

    # There's no access token in the cache; verify that we fetched a fresh
    # token.
    self.assertTrue(self.client.fetched_token)
    self.assertEqual(ACCESS_TOKEN, token_1.token)
    self.assertEqual(self.start_time + datetime.timedelta(minutes=60),
                     token_1.expiry)

    # Advance time by less than expiry time, and fetch another token.
    self.client.Reset()
    self.mock_datetime.mock_now = (
        self.start_time + datetime.timedelta(minutes=55))
    token_2 = self.client.GetAccessToken()

    # Since the access token wasn't expired, we get the cache token, and there
    # was no refresh request.
    self.assertEqual(token_1, token_2)
    self.assertEqual(ACCESS_TOKEN, token_2.token)
    self.assertFalse(self.client.fetched_token)

    # Advance time past expiry time, and fetch another token.
    self.client.Reset()
    self.mock_datetime.mock_now = (
        self.start_time + datetime.timedelta(minutes=55, seconds=1))
    self.client.datetime_strategy = self.mock_datetime
    token_3 = self.client.GetAccessToken()

    # This should have resulted in a refresh request and a fresh access token.
    self.assertTrue(self.client.fetched_token)
    self.assertEqual(
        self.mock_datetime.mock_now + datetime.timedelta(minutes=60),
        token_3.expiry)
167
168
class AccessTokenTest(unittest.TestCase):
  """Unit tests for access token functions."""

  def testShouldRefresh(self):
    """Tests that token.ShouldRefresh returns the correct value."""
    mock_datetime = MockDateTime()
    start = datetime.datetime(2011, 3, 1, 11, 25, 13, 300826)
    expiry = start + datetime.timedelta(minutes=60)
    token = oauth2_client.AccessToken(
        'foo', expiry, datetime_strategy=mock_datetime)

    mock_datetime.mock_now = start
    self.assertFalse(token.ShouldRefresh())

    mock_datetime.mock_now = start + datetime.timedelta(minutes=54)
    self.assertFalse(token.ShouldRefresh())

    # With the default window, a token exactly 5 minutes from expiry is still
    # considered fresh...
    mock_datetime.mock_now = start + datetime.timedelta(minutes=55)
    self.assertFalse(token.ShouldRefresh())

    # ...but one second later it must be refreshed.
    mock_datetime.mock_now = start + datetime.timedelta(
        minutes=55, seconds=1)
    self.assertTrue(token.ShouldRefresh())

    # Already past expiry.
    mock_datetime.mock_now = start + datetime.timedelta(
        minutes=61)
    self.assertTrue(token.ShouldRefresh())

    # An explicit time_delta of 120 seconds narrows the window to 2 minutes.
    mock_datetime.mock_now = start + datetime.timedelta(minutes=58)
    self.assertFalse(token.ShouldRefresh(time_delta=120))

    mock_datetime.mock_now = start + datetime.timedelta(
        minutes=58, seconds=1)
    self.assertTrue(token.ShouldRefresh(time_delta=120))

  def testShouldRefreshNoExpiry(self):
    """Tests token.ShouldRefresh with no expiry time."""
    mock_datetime = MockDateTime()
    start = datetime.datetime(2011, 3, 1, 11, 25, 13, 300826)
    token = oauth2_client.AccessToken(
        'foo', None, datetime_strategy=mock_datetime)

    mock_datetime.mock_now = start
    self.assertFalse(token.ShouldRefresh())

    # Even far in the future, a token without an expiry never needs refresh.
    mock_datetime.mock_now = start + datetime.timedelta(
        minutes=472)
    self.assertFalse(token.ShouldRefresh())

  def testSerialization(self):
    """Tests that a token survives a Serialize/UnSerialize round trip."""
    expiry = datetime.datetime(2011, 3, 1, 11, 25, 13, 300826)
    token = oauth2_client.AccessToken('foo', expiry)
    serialized_token = token.Serialize()
    LOG.debug('testSerialization: serialized_token=%s', serialized_token)

    token2 = oauth2_client.AccessToken.UnSerialize(serialized_token)
    self.assertEqual(token, token2)
227
228
class FileSystemTokenCacheTest(unittest.TestCase):
  """Unit tests for FileSystemTokenCache."""

  def setUp(self):
    self.cache = oauth2_client.FileSystemTokenCache()
    self.start_time = datetime.datetime(2011, 3, 1, 10, 25, 13, 300826)
    self.token_1 = oauth2_client.AccessToken('token1', self.start_time)
    self.token_2 = oauth2_client.AccessToken(
        'token2', self.start_time + datetime.timedelta(seconds=492))
    self.key = 'token1key'

  def tearDown(self):
    # Best-effort cleanup. Some tests never write the cache file, so failing
    # to remove it is expected; only filesystem errors are ignored so that
    # unrelated bugs still surface.
    try:
      os.unlink(self.cache.CacheFileName(self.key))
    except OSError:
      pass

  def testPut(self):
    self.cache.PutToken(self.key, self.token_1)
    # Assert that the cache file exists and is accessible only by its owner
    # (mode 0600). POSIX permission bits are not checked on Windows.
    if not IS_WINDOWS:
      self.assertEqual(
          0o600,
          stat.S_IMODE(os.stat(self.cache.CacheFileName(self.key)).st_mode))

  def testPutGet(self):
    """Tests putting and getting various tokens."""
    # No cache file present.
    self.assertIsNone(self.cache.GetToken(self.key))

    # Put a token
    self.cache.PutToken(self.key, self.token_1)
    cached_token = self.cache.GetToken(self.key)
    self.assertEqual(self.token_1, cached_token)

    # Put a different token
    self.cache.PutToken(self.key, self.token_2)
    cached_token = self.cache.GetToken(self.key)
    self.assertEqual(self.token_2, cached_token)

  def testGetBadFile(self):
    """Tests that an unparseable cache file is reported as a cache miss."""
    with open(self.cache.CacheFileName(self.key), 'w') as f:
      f.write('blah')
    self.assertIsNone(self.cache.GetToken(self.key))

  def testCacheFileName(self):
    """Tests configuring the cache with a specific file name."""
    cache = oauth2_client.FileSystemTokenCache(
        path_pattern='/var/run/ccache/token.%(uid)s.%(key)s')
    if IS_WINDOWS:
      uid = '_'
    else:
      uid = os.getuid()
    self.assertEqual('/var/run/ccache/token.%s.abc123' % uid,
                     cache.CacheFileName('abc123'))

    cache = oauth2_client.FileSystemTokenCache(
        path_pattern='/var/run/ccache/token.%(key)s')
    self.assertEqual('/var/run/ccache/token.abc123',
                     cache.CacheFileName('abc123'))
290
291
class RefreshTokenTest(unittest.TestCase):
  """Unit tests for refresh tokens."""

  def setUp(self):
    self.mock_datetime = MockDateTime()
    self.start_time = datetime.datetime(2011, 3, 1, 10, 25, 13, 300826)
    self.mock_datetime.mock_now = self.start_time
    self.client = CreateMockUserAccountClient(self.mock_datetime)

  def testUniqeId(self):
    # Pins the cache key produced for the client built in setUp; changing the
    # fake credentials used there is expected to change this value.
    cred_id = self.client.CacheKey()
    self.assertEqual('0720afed6871f12761fbea3271f451e6ba184bf5', cred_id)

  def testGetAuthorizationHeader(self):
    self.assertEqual('Bearer %s' % ACCESS_TOKEN,
                     self.client.GetAuthorizationHeader())
308
309
class FakeResponse(object):
  """Minimal stand-in for an HTTP response that exposes only ``status``.

  Derives from object so that, on Python 2 as well, ``status`` is a real
  read-only property rather than an attribute that assignment can shadow.
  """

  def __init__(self, status):
    self._status = status

  @property
  def status(self):
    """The HTTP status code given to the constructor (read-only)."""
    return self._status
317
318
class OAuth2GCEClientTest(unittest.TestCase):
  """Unit tests for OAuth2GCEClient and the _IsGCE metadata-server probe.

  httplib2.Http is replaced with a mox class mock. Each test records the HTTP
  request it expects on self.mock_http, switches mox to replay mode, exercises
  the code under test, and then verifies every expectation was met.
  """

  def setUp(self):
    self.mox = mox.Mox()
    self.mox.StubOutClassWithMocks(httplib2, 'Http')
    # The class is stubbed, so this constructs the mock instance on which
    # the tests record their expected requests.
    self.mock_http = httplib2.Http()

  def tearDown(self):
    self.mox.UnsetStubs()

  def _ReplayAndCheckIsGCE(self, expected):
    """Replays recorded expectations, checks _IsGCE(), and verifies mox."""
    self.mox.ReplayAll()
    if expected:
      self.assertTrue(oauth2_client._IsGCE())
    else:
      self.assertFalse(oauth2_client._IsGCE())
    self.mox.VerifyAll()

  @freeze_time('2014-03-26 01:01:01')
  def testFetchAccessToken(self):
    token = 'my_token'
    payload = '{"access_token":"%(TOKEN)s","expires_in": %(EXPIRES_IN)d}' % {
        'TOKEN': token,
        'EXPIRES_IN': 42,
    }

    self.mock_http.request(
        oauth2_client.META_TOKEN_URI,
        method='GET',
        body=None,
        headers=oauth2_client.META_HEADERS).AndReturn(
            (FakeResponse(200), payload))

    self.mox.ReplayAll()

    client = oauth2_client.OAuth2GCEClient()
    # The clock is frozen at 01:01:01, so a 42-second lifetime ends at 01:01:43.
    fetched = str(client.FetchAccessToken())
    expected = 'AccessToken(token=%s, expiry=2014-03-26 01:01:43Z)' % token
    self.assertEqual(fetched, expected)

    self.mox.VerifyAll()

  def testIsGCENotFound(self):
    self.mock_http.request(oauth2_client.METADATA_SERVER).AndReturn(
        (FakeResponse(404), ''))
    self._ReplayAndCheckIsGCE(False)

  def testIsGCEServerNotFound(self):
    self.mock_http.request(oauth2_client.METADATA_SERVER).AndRaise(
        httplib2.ServerNotFoundError)
    self._ReplayAndCheckIsGCE(False)

  def testIsGCETrue(self):
    self.mock_http.request(oauth2_client.METADATA_SERVER).AndReturn(
        (FakeResponse(200), ''))
    self._ReplayAndCheckIsGCE(True)
382
383
if __name__ == '__main__':
  # Run the TestCase classes in this module when it is executed directly.
  unittest.main()
386