# Copyright 2014 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""oauth2client.file tests.

Unit tests for oauth2client.file.
"""

import copy
import datetime
import json
import os
import pickle
import stat
import tempfile

import six
from six.moves import http_client
import unittest2

from oauth2client import client
from oauth2client import file
from .http_mock import HttpMockSequence

try:
    # Python 2: use the Python 3 style oct() so that it yields the '0o'
    # prefix expected by test_access_token_credentials.
    from future_builtins import oct
except ImportError:  # pragma: NO COVER
    pass

__author__ = 'jcgregorio@google.com (Joe Gregorio)'

# Reserve a unique temporary path for the tests. mkstemp creates the file, so
# the handle is closed here; each test removes the file in setUp/tearDown so
# every test starts with no file at this path.
_filehandle, FILENAME = tempfile.mkstemp('oauth2client_test.data')
os.close(_filehandle)


class OAuth2ClientFileTests(unittest2.TestCase):
    """Tests for oauth2client.file.Storage backed by FILENAME."""

    @staticmethod
    def _remove_test_file():
        """Delete FILENAME, ignoring the OSError raised if it is missing."""
        try:
            os.unlink(FILENAME)
        except OSError:
            pass

    def tearDown(self):
        self._remove_test_file()

    def setUp(self):
        self._remove_test_file()

    def _create_test_credentials(self, client_id='some_client_id',
                                 expiration=None):
        """Build an OAuth2Credentials with fixed test values.

        Args:
            client_id: client id to embed in the credentials.
            expiration: token expiry datetime; defaults to utcnow() when
                falsy.

        Returns:
            A client.OAuth2Credentials whose access token is 'foo'.
        """
        access_token = 'foo'
        client_secret = 'cOuDdkfjxxnv+'
        refresh_token = '1/0/a.df219fjls0'
        token_expiry = expiration or datetime.datetime.utcnow()
        token_uri = 'https://www.google.com/accounts/o8/oauth2/token'
        user_agent = 'refresh_checker/1.0'

        credentials = client.OAuth2Credentials(
            access_token, client_id, client_secret,
            refresh_token, token_expiry, token_uri,
            user_agent)
        return credentials

    def test_non_existent_file_storage(self):
        """Reading from a path with no file yields None."""
        s = file.Storage(FILENAME)
        credentials = s.get()
        self.assertIsNone(credentials)

    @unittest2.skipIf(not hasattr(os, 'symlink'), 'No symlink available')
    def test_no_sym_link_credentials(self):
        """Storage refuses to read credentials through a symbolic link."""
        SYMFILENAME = FILENAME + '.sym'
        os.symlink(FILENAME, SYMFILENAME)
        # Everything after the symlink exists runs inside try/finally so the
        # link is always removed, even if constructing Storage fails.
        try:
            s = file.Storage(SYMFILENAME)
            with self.assertRaises(file.CredentialsFileSymbolicLinkError):
                s.get()
        finally:
            os.unlink(SYMFILENAME)

    def test_pickle_and_json_interop(self):
        """A pickled credentials file is ignored, then rewritten as JSON."""
        # Write a file with a pickled OAuth2Credentials.
        credentials = self._create_test_credentials()

        with open(FILENAME, 'wb') as f:
            pickle.dump(credentials, f)

        # Storage should not be able to read that object, as the capability
        # to read and write credentials as pickled objects has been removed.
        s = file.Storage(FILENAME)
        read_credentials = s.get()
        self.assertIsNone(read_credentials)

        # Now write it back out and confirm it has been rewritten as JSON.
        s.put(credentials)
        with open(FILENAME) as f:
            data = json.load(f)

        self.assertEqual(data['access_token'], 'foo')
        self.assertEqual(data['_class'], 'OAuth2Credentials')
        self.assertEqual(data['_module'], client.OAuth2Credentials.__module__)

    def test_token_refresh_store_expired(self):
        """With an expired token in the store too, _refresh hits HTTP."""
        expiration = (datetime.datetime.utcnow() -
                      datetime.timedelta(minutes=15))
        credentials = self._create_test_credentials(expiration=expiration)

        s = file.Storage(FILENAME)
        s.put(credentials)
        credentials = s.get()
        # Replace the stored token; it keeps the same (past) expiry.
        new_cred = copy.copy(credentials)
        new_cred.access_token = 'bar'
        s.put(new_cred)

        access_token = '1/3w'
        token_response = {'access_token': access_token, 'expires_in': 3600}
        http = HttpMockSequence([
            ({'status': '200'}, json.dumps(token_response).encode('utf-8')),
        ])

        credentials._refresh(http.request)
        self.assertEqual(credentials.access_token, access_token)

    def test_token_refresh_store_expires_soon(self):
        """Token valid when read from the store expires mid-request.

        Tests the case where an access token that is valid when it is read
        from the store expires before the original request succeeds.
        """
        expiration = (datetime.datetime.utcnow() +
                      datetime.timedelta(minutes=15))
        credentials = self._create_test_credentials(expiration=expiration)

        s = file.Storage(FILENAME)
        s.put(credentials)
        credentials = s.get()
        new_cred = copy.copy(credentials)
        new_cred.access_token = 'bar'
        s.put(new_cred)

        access_token = '1/3w'
        token_response = {'access_token': access_token, 'expires_in': 3600}
        http = HttpMockSequence([
            ({'status': str(int(http_client.UNAUTHORIZED))},
             b'Initial token expired'),
            ({'status': str(int(http_client.UNAUTHORIZED))},
             b'Store token expired'),
            ({'status': str(int(http_client.OK))},
             json.dumps(token_response).encode('utf-8')),
            ({'status': str(int(http_client.OK))},
             b'Valid response to original request')
        ])

        credentials.authorize(http)
        http.request('https://example.com')
        self.assertEqual(credentials.access_token, access_token)

    def test_token_refresh_good_store(self):
        """An unexpired token in the store is used without any HTTP call."""
        expiration = (datetime.datetime.utcnow() +
                      datetime.timedelta(minutes=15))
        credentials = self._create_test_credentials(expiration=expiration)

        s = file.Storage(FILENAME)
        s.put(credentials)
        credentials = s.get()
        new_cred = copy.copy(credentials)
        new_cred.access_token = 'bar'
        s.put(new_cred)

        # No HTTP object: the refresh must be satisfied from the store.
        credentials._refresh(None)
        self.assertEqual(credentials.access_token, 'bar')

    def test_token_refresh_stream_body(self):
        """A streaming request body survives the token refresh retries.

        The last mocked response echoes the request body, so the content
        check verifies the stream was sent in full on the final attempt.
        """
        expiration = (datetime.datetime.utcnow() +
                      datetime.timedelta(minutes=15))
        credentials = self._create_test_credentials(expiration=expiration)

        s = file.Storage(FILENAME)
        s.put(credentials)
        credentials = s.get()
        new_cred = copy.copy(credentials)
        new_cred.access_token = 'bar'
        s.put(new_cred)

        valid_access_token = '1/3w'
        token_response = {'access_token': valid_access_token,
                          'expires_in': 3600}
        http = HttpMockSequence([
            ({'status': str(int(http_client.UNAUTHORIZED))},
             b'Initial token expired'),
            ({'status': str(int(http_client.UNAUTHORIZED))},
             b'Store token expired'),
            ({'status': str(int(http_client.OK))},
             json.dumps(token_response).encode('utf-8')),
            ({'status': str(int(http_client.OK))}, 'echo_request_body')
        ])

        body = six.StringIO('streaming body')

        credentials.authorize(http)
        _, content = http.request('https://example.com', body=body)
        self.assertEqual(content, 'streaming body')
        self.assertEqual(credentials.access_token, valid_access_token)

    def test_credentials_delete(self):
        """Storage.delete removes stored credentials."""
        credentials = self._create_test_credentials()

        s = file.Storage(FILENAME)
        s.put(credentials)
        credentials = s.get()
        self.assertIsNotNone(credentials)
        s.delete()
        credentials = s.get()
        self.assertIsNone(credentials)

    def test_access_token_credentials(self):
        """AccessTokenCredentials round-trip; on POSIX the file is 0o600."""
        access_token = 'foo'
        user_agent = 'refresh_checker/1.0'

        credentials = client.AccessTokenCredentials(access_token, user_agent)

        s = file.Storage(FILENAME)
        s.put(credentials)
        credentials = s.get()

        self.assertIsNotNone(credentials)
        self.assertEqual('foo', credentials.access_token)

        self.assertTrue(os.path.exists(FILENAME))

        if os.name == 'posix':  # pragma: NO COVER
            mode = os.stat(FILENAME).st_mode
            self.assertEqual('0o600', oct(stat.S_IMODE(mode)))