• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import datetime
16
17import mock
18import pytest
19
20from google.auth import _helpers
21from google.auth import crypt
22from google.auth import jwt
23from google.auth import transport
24from google.oauth2 import _service_account_async as service_account
25from tests.oauth2 import test_service_account
26
27
class TestCredentials(object):
    """Tests for the asyncio service-account ``Credentials`` implementation."""

    SERVICE_ACCOUNT_EMAIL = "service-account@example.com"
    TOKEN_URI = "https://example.com/oauth2/token"

    @classmethod
    def make_credentials(cls):
        """Return credentials using the shared test signer, email and token URI."""
        signer = test_service_account.SIGNER
        return service_account.Credentials(
            signer, cls.SERVICE_ACCOUNT_EMAIL, cls.TOKEN_URI
        )

    @staticmethod
    def _decode(token):
        """Decode *token*, verifying its signature with the test public cert."""
        return jwt.decode(token, test_service_account.PUBLIC_CERT_BYTES)

    @staticmethod
    def _check_identity(credentials, info):
        """Check key id, client email and token URI against the *info* dict."""
        assert credentials._signer.key_id == info["private_key_id"]
        assert credentials.service_account_email == info["client_email"]
        assert credentials._token_uri == info["token_uri"]

    @staticmethod
    def _check_options(credentials, scopes, subject, additional_claims):
        """Check that the optional factory arguments were stored as given."""
        assert credentials._scopes == scopes
        assert credentials._subject == subject
        assert credentials._additional_claims == additional_claims

    def test_from_service_account_info(self):
        info = test_service_account.SERVICE_ACCOUNT_INFO
        credentials = service_account.Credentials.from_service_account_info(info)
        self._check_identity(credentials, info)

    def test_from_service_account_info_args(self):
        info = test_service_account.SERVICE_ACCOUNT_INFO.copy()
        options = dict(
            scopes=["email", "profile"],
            subject="subject",
            additional_claims={"meta": "data"},
        )

        credentials = service_account.Credentials.from_service_account_info(
            info, **options
        )

        self._check_identity(credentials, info)
        assert credentials.project_id == info["project_id"]
        self._check_options(credentials, **options)

    def test_from_service_account_file(self):
        info = test_service_account.SERVICE_ACCOUNT_INFO.copy()

        credentials = service_account.Credentials.from_service_account_file(
            test_service_account.SERVICE_ACCOUNT_JSON_FILE
        )

        self._check_identity(credentials, info)
        assert credentials.project_id == info["project_id"]

    def test_from_service_account_file_args(self):
        info = test_service_account.SERVICE_ACCOUNT_INFO.copy()
        options = dict(
            subject="subject",
            scopes=["email", "profile"],
            additional_claims={"meta": "data"},
        )

        credentials = service_account.Credentials.from_service_account_file(
            test_service_account.SERVICE_ACCOUNT_JSON_FILE, **options
        )

        self._check_identity(credentials, info)
        assert credentials.project_id == info["project_id"]
        self._check_options(credentials, **options)

    def test_default_state(self):
        credentials = self.make_credentials()
        # A freshly built instance has no token, no expiry set, and no scopes.
        assert not credentials.valid
        assert not credentials.expired
        assert credentials.requires_scopes

    def test_sign_bytes(self):
        credentials = self.make_credentials()
        message = b"123"
        signature = credentials.sign_bytes(message)
        assert crypt.verify_signature(
            message, signature, test_service_account.PUBLIC_CERT_BYTES
        )

    def test_signer(self):
        assert isinstance(self.make_credentials().signer, crypt.Signer)

    def test_signer_email(self):
        assert self.make_credentials().signer_email == self.SERVICE_ACCOUNT_EMAIL

    def test_create_scoped(self):
        scopes = ["email", "profile"]
        scoped = self.make_credentials().with_scopes(scopes)
        assert scoped._scopes == scopes

    def test_with_claims(self):
        updated = self.make_credentials().with_claims({"meep": "moop"})
        assert updated._additional_claims == {"meep": "moop"}

    def test_with_quota_project(self):
        updated = self.make_credentials().with_quota_project("new-project-456")
        assert updated.quota_project_id == "new-project-456"

        # Applying the credentials must add the quota-project header.
        headers = {}
        updated.apply(headers, token="tok")
        assert "x-goog-user-project" in headers

    def test__make_authorization_grant_assertion(self):
        credentials = self.make_credentials()
        payload = self._decode(credentials._make_authorization_grant_assertion())
        assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL
        expected_audience = (
            service_account.service_account._GOOGLE_OAUTH2_TOKEN_ENDPOINT
        )
        assert payload["aud"] == expected_audience

    def test__make_authorization_grant_assertion_scoped(self):
        credentials = self.make_credentials().with_scopes(["email", "profile"])
        payload = self._decode(credentials._make_authorization_grant_assertion())
        assert payload["scope"] == "email profile"

    def test__make_authorization_grant_assertion_subject(self):
        subject = "user@example.com"
        credentials = self.make_credentials().with_subject(subject)
        payload = self._decode(credentials._make_authorization_grant_assertion())
        assert payload["sub"] == subject

    @mock.patch("google.oauth2._client_async.jwt_grant", autospec=True)
    @pytest.mark.asyncio
    async def test_refresh_success(self, jwt_grant):
        token = "token"
        expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
        jwt_grant.return_value = (token, expiry, {})
        credentials = self.make_credentials()
        request = mock.create_autospec(transport.Request, instance=True)

        await credentials.refresh(request)

        # The patched grant must have been called with the request, the
        # configured token URI and a JWT signed by the test key. The JWT's
        # claims are covered by the _make_authorization_grant_assertion tests.
        assert jwt_grant.called
        called_request, token_uri, assertion = jwt_grant.call_args[0]
        assert called_request == request
        assert token_uri == credentials._token_uri
        assert self._decode(assertion)

        # The granted token is stored and, being unexpired, makes the
        # credentials valid.
        assert credentials.token == token
        assert credentials.valid

    @mock.patch("google.oauth2._client_async.jwt_grant", autospec=True)
    @pytest.mark.asyncio
    async def test_before_request_refreshes(self, jwt_grant):
        jwt_grant.return_value = (
            "token",
            _helpers.utcnow() + datetime.timedelta(seconds=500),
            None,
        )
        credentials = self.make_credentials()
        request = mock.create_autospec(transport.Request, instance=True)
        assert not credentials.valid

        # Invalid credentials are expected to refresh during before_request.
        await credentials.before_request(request, "GET", "http://example.com?a=1#3", {})

        assert jwt_grant.called
        assert credentials.valid
231
232
class TestIDTokenCredentials(object):
    """Tests for the asyncio service-account ``IDTokenCredentials`` class."""

    SERVICE_ACCOUNT_EMAIL = "service-account@example.com"
    TOKEN_URI = "https://example.com/oauth2/token"
    TARGET_AUDIENCE = "https://example.com"

    @classmethod
    def make_credentials(cls):
        """Return ID-token credentials built from the shared test signer."""
        return service_account.IDTokenCredentials(
            test_service_account.SIGNER,
            cls.SERVICE_ACCOUNT_EMAIL,
            cls.TOKEN_URI,
            cls.TARGET_AUDIENCE,
        )

    @staticmethod
    def _make_request():
        """Return a mock shaped like a ``transport.Request`` instance."""
        # A *list* passed as ``spec`` names the attributes the mock may have,
        # so ``spec=["transport.Request"]`` would allow only an attribute
        # literally called "transport.Request". Autospec the class instead.
        return mock.create_autospec(transport.Request, instance=True)

    def test_from_service_account_info(self):
        credentials = service_account.IDTokenCredentials.from_service_account_info(
            test_service_account.SERVICE_ACCOUNT_INFO,
            target_audience=self.TARGET_AUDIENCE,
        )

        assert (
            credentials._signer.key_id
            == test_service_account.SERVICE_ACCOUNT_INFO["private_key_id"]
        )
        assert (
            credentials.service_account_email
            == test_service_account.SERVICE_ACCOUNT_INFO["client_email"]
        )
        assert (
            credentials._token_uri
            == test_service_account.SERVICE_ACCOUNT_INFO["token_uri"]
        )
        assert credentials._target_audience == self.TARGET_AUDIENCE

    def test_from_service_account_file(self):
        info = test_service_account.SERVICE_ACCOUNT_INFO.copy()

        credentials = service_account.IDTokenCredentials.from_service_account_file(
            test_service_account.SERVICE_ACCOUNT_JSON_FILE,
            target_audience=self.TARGET_AUDIENCE,
        )

        assert credentials.service_account_email == info["client_email"]
        assert credentials._signer.key_id == info["private_key_id"]
        assert credentials._token_uri == info["token_uri"]
        assert credentials._target_audience == self.TARGET_AUDIENCE

    def test_default_state(self):
        credentials = self.make_credentials()
        # No token has been fetched yet.
        assert not credentials.valid
        # Expiration hasn't been set yet.
        assert not credentials.expired

    def test_sign_bytes(self):
        credentials = self.make_credentials()
        to_sign = b"123"
        signature = credentials.sign_bytes(to_sign)
        assert crypt.verify_signature(
            to_sign, signature, test_service_account.PUBLIC_CERT_BYTES
        )

    def test_signer(self):
        credentials = self.make_credentials()
        assert isinstance(credentials.signer, crypt.Signer)

    def test_signer_email(self):
        credentials = self.make_credentials()
        assert credentials.signer_email == self.SERVICE_ACCOUNT_EMAIL

    def test_with_target_audience(self):
        credentials = self.make_credentials()
        new_credentials = credentials.with_target_audience("https://new.example.com")
        assert new_credentials._target_audience == "https://new.example.com"

    def test_with_quota_project(self):
        credentials = self.make_credentials()
        new_credentials = credentials.with_quota_project("project-foo")
        assert new_credentials._quota_project_id == "project-foo"

    def test__make_authorization_grant_assertion(self):
        credentials = self.make_credentials()
        token = credentials._make_authorization_grant_assertion()
        payload = jwt.decode(token, test_service_account.PUBLIC_CERT_BYTES)
        assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL
        assert (
            payload["aud"]
            == service_account.service_account._GOOGLE_OAUTH2_TOKEN_ENDPOINT
        )
        assert payload["target_audience"] == self.TARGET_AUDIENCE

    @mock.patch("google.oauth2._client_async.id_token_jwt_grant", autospec=True)
    @pytest.mark.asyncio
    async def test_refresh_success(self, id_token_jwt_grant):
        credentials = self.make_credentials()
        token = "token"
        id_token_jwt_grant.return_value = (
            token,
            _helpers.utcnow() + datetime.timedelta(seconds=500),
            {},
        )

        request = self._make_request()

        # Refresh credentials.
        await credentials.refresh(request)

        # Check the ID-token grant call.
        assert id_token_jwt_grant.called

        called_request, token_uri, assertion = id_token_jwt_grant.call_args[0]
        assert called_request == request
        assert token_uri == credentials._token_uri
        assert jwt.decode(assertion, test_service_account.PUBLIC_CERT_BYTES)
        # No further assertion is done on the token, as there are separate
        # tests for checking the authorization grant assertion.

        # Check that the credentials have the token.
        assert credentials.token == token

        # Check that the credentials are valid (have a token and are not
        # expired).
        assert credentials.valid

    @mock.patch("google.oauth2._client_async.id_token_jwt_grant", autospec=True)
    @pytest.mark.asyncio
    async def test_before_request_refreshes(self, id_token_jwt_grant):
        credentials = self.make_credentials()
        token = "token"
        id_token_jwt_grant.return_value = (
            token,
            _helpers.utcnow() + datetime.timedelta(seconds=500),
            None,
        )
        request = self._make_request()

        # Credentials should start as invalid.
        assert not credentials.valid

        # before_request should cause a refresh.
        await credentials.before_request(request, "GET", "http://example.com?a=1#3", {})

        # The refresh endpoint should've been called.
        assert id_token_jwt_grant.called

        # Credentials should now be valid.
        assert credentials.valid
379