# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for ``google.oauth2._service_account_async``."""

import datetime

import mock
import pytest

from google.auth import _helpers
from google.auth import crypt
from google.auth import jwt
from google.auth import transport
from google.oauth2 import _service_account_async as service_account
from tests.oauth2 import test_service_account


class TestCredentials(object):
    """Tests for the async ``service_account.Credentials`` class."""

    SERVICE_ACCOUNT_EMAIL = "service-account@example.com"
    TOKEN_URI = "https://example.com/oauth2/token"

    @classmethod
    def make_credentials(cls):
        """Build credentials from the shared test signer, email and token URI."""
        return service_account.Credentials(
            test_service_account.SIGNER, cls.SERVICE_ACCOUNT_EMAIL, cls.TOKEN_URI
        )

    def test_from_service_account_info(self):
        """Info-dict construction copies key id, email and token URI."""
        credentials = service_account.Credentials.from_service_account_info(
            test_service_account.SERVICE_ACCOUNT_INFO
        )

        assert (
            credentials._signer.key_id
            == test_service_account.SERVICE_ACCOUNT_INFO["private_key_id"]
        )
        assert (
            credentials.service_account_email
            == test_service_account.SERVICE_ACCOUNT_INFO["client_email"]
        )
        assert (
            credentials._token_uri
            == test_service_account.SERVICE_ACCOUNT_INFO["token_uri"]
        )

    def test_from_service_account_info_args(self):
        """Extra keyword args are stored alongside the info-dict fields."""
        info = test_service_account.SERVICE_ACCOUNT_INFO.copy()
        scopes = ["email", "profile"]
        subject = "subject"
        additional_claims = {"meta": "data"}

        credentials = service_account.Credentials.from_service_account_info(
            info, scopes=scopes, subject=subject, additional_claims=additional_claims
        )

        assert credentials.service_account_email == info["client_email"]
        assert credentials.project_id == info["project_id"]
        assert credentials._signer.key_id == info["private_key_id"]
        assert credentials._token_uri == info["token_uri"]
        assert credentials._scopes == scopes
        assert credentials._subject == subject
        assert credentials._additional_claims == additional_claims

    def test_from_service_account_file(self):
        """File-based construction matches the JSON fixture's contents."""
        info = test_service_account.SERVICE_ACCOUNT_INFO.copy()

        credentials = service_account.Credentials.from_service_account_file(
            test_service_account.SERVICE_ACCOUNT_JSON_FILE
        )

        assert credentials.service_account_email == info["client_email"]
        assert credentials.project_id == info["project_id"]
        assert credentials._signer.key_id == info["private_key_id"]
        assert credentials._token_uri == info["token_uri"]

    def test_from_service_account_file_args(self):
        """File-based construction also stores extra keyword args."""
        info = test_service_account.SERVICE_ACCOUNT_INFO.copy()
        scopes = ["email", "profile"]
        subject = "subject"
        additional_claims = {"meta": "data"}

        credentials = service_account.Credentials.from_service_account_file(
            test_service_account.SERVICE_ACCOUNT_JSON_FILE,
            subject=subject,
            scopes=scopes,
            additional_claims=additional_claims,
        )

        assert credentials.service_account_email == info["client_email"]
        assert credentials.project_id == info["project_id"]
        assert credentials._signer.key_id == info["private_key_id"]
        assert credentials._token_uri == info["token_uri"]
        assert credentials._scopes == scopes
        assert credentials._subject == subject
        assert credentials._additional_claims == additional_claims

    def test_default_state(self):
        """Fresh credentials are invalid, not expired, and need scopes."""
        credentials = self.make_credentials()
        assert not credentials.valid
        # Expiration hasn't been set yet
        assert not credentials.expired
        # Scopes haven't been specified yet
        assert credentials.requires_scopes

    def test_sign_bytes(self):
        """sign_bytes output verifies against the fixture public cert."""
        credentials = self.make_credentials()
        to_sign = b"123"
        signature = credentials.sign_bytes(to_sign)
        assert crypt.verify_signature(
            to_sign, signature, test_service_account.PUBLIC_CERT_BYTES
        )

    def test_signer(self):
        """The signer property exposes a crypt.Signer."""
        credentials = self.make_credentials()
        assert isinstance(credentials.signer, crypt.Signer)

    def test_signer_email(self):
        """signer_email reports the service account email."""
        credentials = self.make_credentials()
        assert credentials.signer_email == self.SERVICE_ACCOUNT_EMAIL

    def test_create_scoped(self):
        """with_scopes stores the requested scopes."""
        credentials = self.make_credentials()
        scopes = ["email", "profile"]
        credentials = credentials.with_scopes(scopes)
        assert credentials._scopes == scopes

    def test_with_claims(self):
        """with_claims stores the additional claims."""
        credentials = self.make_credentials()
        new_credentials = credentials.with_claims({"meep": "moop"})
        assert new_credentials._additional_claims == {"meep": "moop"}

    def test_with_quota_project(self):
        """with_quota_project sets the project and its request header."""
        credentials = self.make_credentials()
        new_credentials = credentials.with_quota_project("new-project-456")
        assert new_credentials.quota_project_id == "new-project-456"
        hdrs = {}
        new_credentials.apply(hdrs, token="tok")
        assert "x-goog-user-project" in hdrs

    def test__make_authorization_grant_assertion(self):
        """The grant assertion carries the expected issuer and audience."""
        credentials = self.make_credentials()
        token = credentials._make_authorization_grant_assertion()
        payload = jwt.decode(token, test_service_account.PUBLIC_CERT_BYTES)
        assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL
        assert (
            payload["aud"]
            == service_account.service_account._GOOGLE_OAUTH2_TOKEN_ENDPOINT
        )

    def test__make_authorization_grant_assertion_scoped(self):
        """Scopes are space-joined into the assertion's scope claim."""
        credentials = self.make_credentials()
        scopes = ["email", "profile"]
        credentials = credentials.with_scopes(scopes)
        token = credentials._make_authorization_grant_assertion()
        payload = jwt.decode(token, test_service_account.PUBLIC_CERT_BYTES)
        assert payload["scope"] == "email profile"

    def test__make_authorization_grant_assertion_subject(self):
        """with_subject sets the assertion's sub claim."""
        credentials = self.make_credentials()
        subject = "user@example.com"
        credentials = credentials.with_subject(subject)
        token = credentials._make_authorization_grant_assertion()
        payload = jwt.decode(token, test_service_account.PUBLIC_CERT_BYTES)
        assert payload["sub"] == subject

    @mock.patch("google.oauth2._client_async.jwt_grant", autospec=True)
    @pytest.mark.asyncio
    async def test_refresh_success(self, jwt_grant):
        """refresh() calls jwt_grant and stores the returned token."""
        credentials = self.make_credentials()
        token = "token"
        jwt_grant.return_value = (
            token,
            _helpers.utcnow() + datetime.timedelta(seconds=500),
            {},
        )
        request = mock.create_autospec(transport.Request, instance=True)

        # Refresh credentials
        await credentials.refresh(request)

        # Check jwt grant call.
        assert jwt_grant.called

        called_request, token_uri, assertion = jwt_grant.call_args[0]
        assert called_request == request
        assert token_uri == credentials._token_uri
        assert jwt.decode(assertion, test_service_account.PUBLIC_CERT_BYTES)
        # No further assertion done on the token, as there are separate tests
        # for checking the authorization grant assertion.

        # Check that the credentials have the token.
        assert credentials.token == token

        # Check that the credentials are valid (have a token and are not
        # expired)
        assert credentials.valid

    @mock.patch("google.oauth2._client_async.jwt_grant", autospec=True)
    @pytest.mark.asyncio
    async def test_before_request_refreshes(self, jwt_grant):
        """before_request() refreshes invalid credentials."""
        credentials = self.make_credentials()
        token = "token"
        jwt_grant.return_value = (
            token,
            _helpers.utcnow() + datetime.timedelta(seconds=500),
            None,
        )
        request = mock.create_autospec(transport.Request, instance=True)

        # Credentials should start as invalid
        assert not credentials.valid

        # before_request should cause a refresh
        await credentials.before_request(request, "GET", "http://example.com?a=1#3", {})

        # The refresh endpoint should've been called.
        assert jwt_grant.called

        # Credentials should now be valid.
        assert credentials.valid


class TestIDTokenCredentials(object):
    """Tests for the async ``service_account.IDTokenCredentials`` class."""

    SERVICE_ACCOUNT_EMAIL = "service-account@example.com"
    TOKEN_URI = "https://example.com/oauth2/token"
    TARGET_AUDIENCE = "https://example.com"

    @classmethod
    def make_credentials(cls):
        """Build ID-token credentials for this class's target audience."""
        return service_account.IDTokenCredentials(
            test_service_account.SIGNER,
            cls.SERVICE_ACCOUNT_EMAIL,
            cls.TOKEN_URI,
            cls.TARGET_AUDIENCE,
        )

    def test_from_service_account_info(self):
        """Info-dict construction copies fields and the target audience."""
        credentials = service_account.IDTokenCredentials.from_service_account_info(
            test_service_account.SERVICE_ACCOUNT_INFO,
            target_audience=self.TARGET_AUDIENCE,
        )

        assert (
            credentials._signer.key_id
            == test_service_account.SERVICE_ACCOUNT_INFO["private_key_id"]
        )
        assert (
            credentials.service_account_email
            == test_service_account.SERVICE_ACCOUNT_INFO["client_email"]
        )
        assert (
            credentials._token_uri
            == test_service_account.SERVICE_ACCOUNT_INFO["token_uri"]
        )
        assert credentials._target_audience == self.TARGET_AUDIENCE

    def test_from_service_account_file(self):
        """File-based construction copies fields and the target audience."""
        info = test_service_account.SERVICE_ACCOUNT_INFO.copy()

        credentials = service_account.IDTokenCredentials.from_service_account_file(
            test_service_account.SERVICE_ACCOUNT_JSON_FILE,
            target_audience=self.TARGET_AUDIENCE,
        )

        assert credentials.service_account_email == info["client_email"]
        assert credentials._signer.key_id == info["private_key_id"]
        assert credentials._token_uri == info["token_uri"]
        assert credentials._target_audience == self.TARGET_AUDIENCE

    def test_default_state(self):
        """Fresh ID-token credentials are invalid and not expired."""
        credentials = self.make_credentials()
        assert not credentials.valid
        # Expiration hasn't been set yet
        assert not credentials.expired

    def test_sign_bytes(self):
        """sign_bytes output verifies against the fixture public cert."""
        credentials = self.make_credentials()
        to_sign = b"123"
        signature = credentials.sign_bytes(to_sign)
        assert crypt.verify_signature(
            to_sign, signature, test_service_account.PUBLIC_CERT_BYTES
        )

    def test_signer(self):
        """The signer property exposes a crypt.Signer."""
        credentials = self.make_credentials()
        assert isinstance(credentials.signer, crypt.Signer)

    def test_signer_email(self):
        """signer_email reports the service account email."""
        credentials = self.make_credentials()
        assert credentials.signer_email == self.SERVICE_ACCOUNT_EMAIL

    def test_with_target_audience(self):
        """with_target_audience stores the new audience."""
        credentials = self.make_credentials()
        new_credentials = credentials.with_target_audience("https://new.example.com")
        assert new_credentials._target_audience == "https://new.example.com"

    def test_with_quota_project(self):
        """with_quota_project stores the quota project."""
        credentials = self.make_credentials()
        new_credentials = credentials.with_quota_project("project-foo")
        assert new_credentials._quota_project_id == "project-foo"

    def test__make_authorization_grant_assertion(self):
        """The grant assertion carries issuer, audience and target audience."""
        credentials = self.make_credentials()
        token = credentials._make_authorization_grant_assertion()
        payload = jwt.decode(token, test_service_account.PUBLIC_CERT_BYTES)
        assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL
        assert (
            payload["aud"]
            == service_account.service_account._GOOGLE_OAUTH2_TOKEN_ENDPOINT
        )
        assert payload["target_audience"] == self.TARGET_AUDIENCE

    @mock.patch("google.oauth2._client_async.id_token_jwt_grant", autospec=True)
    @pytest.mark.asyncio
    async def test_refresh_success(self, id_token_jwt_grant):
        """refresh() calls id_token_jwt_grant and stores the returned token."""
        credentials = self.make_credentials()
        token = "token"
        id_token_jwt_grant.return_value = (
            token,
            _helpers.utcnow() + datetime.timedelta(seconds=500),
            {},
        )

        # Autospec against transport.Request so the mock exposes that
        # interface; a list passed as ``spec`` would be read as attribute names.
        request = mock.create_autospec(transport.Request, instance=True)

        # Refresh credentials
        await credentials.refresh(request)

        # Check jwt grant call.
        assert id_token_jwt_grant.called

        called_request, token_uri, assertion = id_token_jwt_grant.call_args[0]
        assert called_request == request
        assert token_uri == credentials._token_uri
        assert jwt.decode(assertion, test_service_account.PUBLIC_CERT_BYTES)
        # No further assertion done on the token, as there are separate tests
        # for checking the authorization grant assertion.

        # Check that the credentials have the token.
        assert credentials.token == token

        # Check that the credentials are valid (have a token and are not
        # expired)
        assert credentials.valid

    @mock.patch("google.oauth2._client_async.id_token_jwt_grant", autospec=True)
    @pytest.mark.asyncio
    async def test_before_request_refreshes(self, id_token_jwt_grant):
        """before_request() refreshes invalid ID-token credentials."""
        credentials = self.make_credentials()
        token = "token"
        id_token_jwt_grant.return_value = (
            token,
            _helpers.utcnow() + datetime.timedelta(seconds=500),
            None,
        )
        # Autospec against transport.Request so the mock exposes that
        # interface; a list passed as ``spec`` would be read as attribute names.
        request = mock.create_autospec(transport.Request, instance=True)

        # Credentials should start as invalid
        assert not credentials.valid

        # before_request should cause a refresh
        await credentials.before_request(request, "GET", "http://example.com?a=1#3", {})

        # The refresh endpoint should've been called.
        assert id_token_jwt_grant.called

        # Credentials should now be valid.
        assert credentials.valid