# Copyright 2014 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import os

import mock
import unittest2

from oauth2client import _helpers
from oauth2client import client
from oauth2client import crypt
from oauth2client import service_account


def data_filename(filename):
    """Return the path of ``filename`` in the ``data`` dir beside this module."""
    return os.path.join(os.path.dirname(__file__), 'data', filename)


def datafile(filename):
    """Return the raw bytes of ``filename`` from the test ``data`` dir."""
    with open(data_filename(filename), 'rb') as file_obj:
        return file_obj.read()


class Test__bad_pkcs12_key_as_pem(unittest2.TestCase):
    """Tests for ``crypt._bad_pkcs12_key_as_pem``."""

    def test_fails(self):
        with self.assertRaises(NotImplementedError):
            crypt._bad_pkcs12_key_as_pem()


class Test_pkcs12_key_as_pem(unittest2.TestCase):
    """Tests for ``crypt.pkcs12_key_as_pem``."""

    def _make_svc_account_creds(self, private_key_file='privatekey.p12'):
        """Build service-account credentials from a PKCS#12 data file.

        Args:
            private_key_file: name of the key file in the test data dir.

        Returns:
            ``ServiceAccountCredentials`` created with ``scopes='read+write'``
            and with ``'sub'`` set in its ``_kwargs``.
        """
        filename = data_filename(private_key_file)
        credentials = (
            service_account.ServiceAccountCredentials.from_p12_keyfile(
                'some_account@example.com', filename,
                scopes='read+write'))
        credentials._kwargs['sub'] = 'joe@example.org'
        return credentials

    def _succeeds_helper(self, password=None):
        """Convert the fixture PKCS#12 key to PEM and check the result.

        Args:
            password: password for the PKCS#12 blob; when None, the
                credentials' own ``_private_key_password`` is used.
        """
        self.assertEqual(True, client.HAS_OPENSSL)

        credentials = self._make_svc_account_creds()
        if password is None:
            password = credentials._private_key_password
        pem_contents = crypt.pkcs12_key_as_pem(
            credentials._private_key_pkcs12, password)

        # The primary fixture is normalized with _parse_pem_key before the
        # comparison; the alternate fixture is compared as raw bytes.
        expected_pem = _helpers._parse_pem_key(
            datafile('pem_from_pkcs12.pem'))
        alternate_pem = datafile('pem_from_pkcs12_alternate.pem')
        # assertIn shows the actual value and both candidates on failure.
        self.assertIn(pem_contents, [expected_pem, alternate_pem])

    def test_succeeds(self):
        self._succeeds_helper()

    def test_succeeds_with_unicode_password(self):
        password = u'notasecret'
        self._succeeds_helper(password)


class Test__verify_signature(unittest2.TestCase):
    """Tests for ``crypt._verify_signature`` with a mocked ``Verifier``."""

    def test_success_single_cert(self):
        cert_value = 'cert-value'
        certs = [cert_value]
        message = object()
        signature = object()

        verifier = mock.MagicMock()
        verifier.verify = mock.MagicMock(name='verify', return_value=True)
        with mock.patch('oauth2client.crypt.Verifier') as Verifier:
            Verifier.from_string = mock.MagicMock(name='from_string',
                                                  return_value=verifier)
            result = crypt._verify_signature(message, signature, certs)
            self.assertIsNone(result)

            # Make sure our mocks were called as expected.
            Verifier.from_string.assert_called_once_with(cert_value,
                                                         is_x509_cert=True)
            verifier.verify.assert_called_once_with(message, signature)

    def test_success_multiple_certs(self):
        cert_value1 = 'cert-value1'
        cert_value2 = 'cert-value2'
        cert_value3 = 'cert-value3'
        certs = [cert_value1, cert_value2, cert_value3]
        message = object()
        signature = object()

        verifier = mock.MagicMock()
        # Use side_effect to force all 3 cert values to be used by failing
        # to verify on the first two.
        verifier.verify = mock.MagicMock(name='verify',
                                         side_effect=[False, False, True])
        with mock.patch('oauth2client.crypt.Verifier') as Verifier:
            Verifier.from_string = mock.MagicMock(name='from_string',
                                                  return_value=verifier)
            result = crypt._verify_signature(message, signature, certs)
            self.assertIsNone(result)

            # Make sure our mocks were called three times.
            expected_from_string_calls = [
                mock.call(cert_value1, is_x509_cert=True),
                mock.call(cert_value2, is_x509_cert=True),
                mock.call(cert_value3, is_x509_cert=True),
            ]
            self.assertEqual(Verifier.from_string.mock_calls,
                             expected_from_string_calls)
            expected_verify_calls = [mock.call(message, signature)] * 3
            self.assertEqual(verifier.verify.mock_calls,
                             expected_verify_calls)

    def test_failure(self):
        cert_value = 'cert-value'
        certs = [cert_value]
        message = object()
        signature = object()

        verifier = mock.MagicMock()
        verifier.verify = mock.MagicMock(name='verify', return_value=False)
        with mock.patch('oauth2client.crypt.Verifier') as Verifier:
            Verifier.from_string = mock.MagicMock(name='from_string',
                                                  return_value=verifier)
            with self.assertRaises(crypt.AppIdentityError):
                crypt._verify_signature(message, signature, certs)

            # Make sure our mocks were called as expected.
            Verifier.from_string.assert_called_once_with(cert_value,
                                                         is_x509_cert=True)
            verifier.verify.assert_called_once_with(message, signature)


class Test__check_audience(unittest2.TestCase):
    """Tests for ``crypt._check_audience``."""

    def test_null_audience(self):
        result = crypt._check_audience(None, None)
        self.assertIsNone(result)

    def test_success(self):
        audience = 'audience'
        payload_dict = {'aud': audience}
        result = crypt._check_audience(payload_dict, audience)
        # No exception and no result.
        self.assertIsNone(result)

    def test_missing_aud(self):
        audience = 'audience'
        payload_dict = {}
        with self.assertRaises(crypt.AppIdentityError):
            crypt._check_audience(payload_dict, audience)

    def test_wrong_aud(self):
        audience1 = 'audience1'
        audience2 = 'audience2'
        self.assertNotEqual(audience1, audience2)
        payload_dict = {'aud': audience1}
        with self.assertRaises(crypt.AppIdentityError):
            crypt._check_audience(payload_dict, audience2)


class Test__verify_time_range(unittest2.TestCase):
    """Tests for ``crypt._verify_time_range``."""

    def _exception_helper(self, payload_dict):
        """Run ``_verify_time_range`` and capture its error.

        Returns:
            The ``crypt.AppIdentityError`` raised, or None if none was raised.
        """
        try:
            crypt._verify_time_range(payload_dict)
        except crypt.AppIdentityError as exc:
            return exc
        return None

    def _exception_helper_at(self, payload_dict, current_time):
        """Like ``_exception_helper``, with ``crypt.time.time()`` mocked.

        Patches the ``time`` name inside ``oauth2client.crypt`` (not the
        global ``time`` module) so the mocked clock returns ``current_time``.
        """
        with mock.patch('oauth2client.crypt.time') as time_mod:
            time_mod.time = mock.MagicMock(name='time',
                                           return_value=current_time)
            return self._exception_helper(payload_dict)

    def _assert_error_prefix(self, exception_caught, prefix):
        """Assert an error was captured and its message starts with prefix."""
        self.assertIsNotNone(exception_caught)
        message = str(exception_caught)
        self.assertTrue(message.startswith(prefix), msg=message)

    def test_without_issued_at(self):
        payload_dict = {}
        exception_caught = self._exception_helper(payload_dict)
        self._assert_error_prefix(exception_caught, 'No iat field in token')

    def test_without_expiration(self):
        payload_dict = {'iat': 'iat'}
        exception_caught = self._exception_helper(payload_dict)
        self._assert_error_prefix(exception_caught, 'No exp field in token')

    def test_with_bad_token_lifetime(self):
        current_time = 123456
        payload_dict = {
            'iat': 'iat',
            'exp': current_time + crypt.MAX_TOKEN_LIFETIME_SECS + 1,
        }
        exception_caught = self._exception_helper_at(payload_dict,
                                                     current_time)
        self._assert_error_prefix(exception_caught,
                                  'exp field too far in future')

    def test_with_issued_at_in_future(self):
        current_time = 123456
        payload_dict = {
            'iat': current_time + crypt.CLOCK_SKEW_SECS + 1,
            'exp': current_time + crypt.MAX_TOKEN_LIFETIME_SECS - 1,
        }
        exception_caught = self._exception_helper_at(payload_dict,
                                                     current_time)
        self._assert_error_prefix(exception_caught, 'Token used too early')

    def test_with_expiration_in_the_past(self):
        current_time = 123456
        payload_dict = {
            'iat': current_time,
            'exp': current_time - crypt.CLOCK_SKEW_SECS - 1,
        }
        exception_caught = self._exception_helper_at(payload_dict,
                                                     current_time)
        self._assert_error_prefix(exception_caught, 'Token used too late')

    def test_success(self):
        current_time = 123456
        payload_dict = {
            'iat': current_time,
            'exp': current_time + crypt.MAX_TOKEN_LIFETIME_SECS - 1,
        }
        exception_caught = self._exception_helper_at(payload_dict,
                                                     current_time)
        self.assertIsNone(exception_caught)


class Test_verify_signed_jwt_with_certs(unittest2.TestCase):
    """Tests for ``crypt.verify_signed_jwt_with_certs``."""

    def test_jwt_no_segments(self):
        with self.assertRaises(crypt.AppIdentityError) as exc_manager:
            crypt.verify_signed_jwt_with_certs(b'', None)

        message = str(exc_manager.exception)
        self.assertTrue(
            message.startswith('Wrong number of segments in token'),
            msg=message)

    def test_jwt_payload_bad_json(self):
        header = signature = b''
        payload = base64.b64encode(b'{BADJSON')
        jwt = b'.'.join([header, payload, signature])

        with self.assertRaises(crypt.AppIdentityError) as exc_manager:
            crypt.verify_signed_jwt_with_certs(jwt, None)

        message = str(exc_manager.exception)
        self.assertTrue(message.startswith('Can\'t parse token'),
                        msg=message)

    @mock.patch('oauth2client.crypt._check_audience')
    @mock.patch('oauth2client.crypt._verify_time_range')
    @mock.patch('oauth2client.crypt._verify_signature')
    def test_success(self, verify_sig, verify_time, check_aud):
        certs = mock.MagicMock()
        cert_values = object()
        certs.values = mock.MagicMock(name='values',
                                      return_value=cert_values)
        audience = object()

        header = b'header'
        signature_bytes = b'signature'
        signature = base64.b64encode(signature_bytes)
        payload_dict = {'a': 'b'}
        payload = base64.b64encode(b'{"a": "b"}')
        jwt = b'.'.join([header, payload, signature])

        result = crypt.verify_signed_jwt_with_certs(
            jwt, certs, audience=audience)
        self.assertEqual(result, payload_dict)

        message_to_sign = header + b'.' + payload
        verify_sig.assert_called_once_with(
            message_to_sign, signature_bytes, cert_values)
        verify_time.assert_called_once_with(payload_dict)
        check_aud.assert_called_once_with(payload_dict, audience)
        certs.values.assert_called_once_with()