• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2015, Google Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions are
6  * met:
7  *
8  *    * Redistributions of source code must retain the above copyright
9  * notice, this list of conditions and the following disclaimer.
10  *    * Redistributions in binary form must reproduce the above
11  * copyright notice, this list of conditions and the following disclaimer
12  * in the documentation and/or other materials provided with the
13  * distribution.
14  *
15  *    * Neither the name of Google Inc. nor the names of its
16  * contributors may be used to endorse or promote products derived from
17  * this software without specific prior written permission.
18  *
19  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23  * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30  */
31 
32 package com.google.auth.oauth2;
33 
34 import static java.util.concurrent.TimeUnit.HOURS;
35 import static org.junit.Assert.assertEquals;
36 import static org.junit.Assert.assertFalse;
37 import static org.junit.Assert.assertNotNull;
38 import static org.junit.Assert.assertNull;
39 import static org.junit.Assert.assertSame;
40 import static org.junit.Assert.assertThrows;
41 import static org.junit.Assert.assertTrue;
42 import static org.junit.Assert.fail;
43 
44 import com.google.api.client.util.Clock;
45 import com.google.auth.TestClock;
46 import com.google.auth.TestUtils;
47 import com.google.auth.http.AuthHttpConstants;
48 import com.google.auth.oauth2.OAuth2Credentials.OAuthValue;
49 import com.google.auth.oauth2.OAuth2Credentials.RefreshTask;
50 import com.google.common.collect.ImmutableList;
51 import com.google.common.collect.ImmutableMap;
52 import com.google.common.util.concurrent.ListenableFutureTask;
53 import com.google.common.util.concurrent.SettableFuture;
54 import java.io.IOException;
55 import java.net.URI;
56 import java.time.Duration;
57 import java.time.Instant;
58 import java.util.ArrayList;
59 import java.util.Arrays;
60 import java.util.Date;
61 import java.util.HashMap;
62 import java.util.List;
63 import java.util.Map;
64 import java.util.concurrent.Callable;
65 import java.util.concurrent.ExecutionException;
66 import java.util.concurrent.ExecutorService;
67 import java.util.concurrent.Executors;
68 import java.util.concurrent.Future;
69 import java.util.concurrent.FutureTask;
70 import java.util.concurrent.TimeoutException;
71 import java.util.concurrent.atomic.AtomicInteger;
72 import java.util.concurrent.atomic.AtomicReference;
73 import org.junit.After;
74 import org.junit.Before;
75 import org.junit.Ignore;
76 import org.junit.Test;
77 import org.junit.function.ThrowingRunnable;
78 import org.junit.runner.RunWith;
79 import org.junit.runners.JUnit4;
80 
81 /** Test case for {@link OAuth2Credentials}. */
82 @RunWith(JUnit4.class)
83 public class OAuth2CredentialsTest extends BaseSerializationTest {
84 
85   private static final String CLIENT_SECRET = "jakuaL9YyieakhECKL2SwZcu";
86   private static final String CLIENT_ID = "ya29.1.AADtN_UtlxN3PuGAxrN2XQnZTVRvDyVWnYq4I6dws";
87   private static final String REFRESH_TOKEN = "1/Tl6awhpFjkMkSJoj1xsli0H2eL5YsMgU_NKPY2TyGWY";
88   private static final String ACCESS_TOKEN = "aashpFjkMkSJoj1xsli0H2eL5YsMgU_NKPY2TyGWY";
89   private static final URI CALL_URI = URI.create("http://googleapis.com/testapi/v1/foo");
90 
91   private ExecutorService realExecutor;
92 
93   @Before
setUp()94   public void setUp() {
95     realExecutor = Executors.newCachedThreadPool();
96   }
97 
98   @After
tearDown()99   public void tearDown() {
100     realExecutor.shutdown();
101   }
102 
103   @Test
constructor_storesAccessToken()104   public void constructor_storesAccessToken() {
105     OAuth2Credentials credentials =
106         OAuth2Credentials.newBuilder().setAccessToken(new AccessToken(ACCESS_TOKEN, null)).build();
107     assertEquals(credentials.getAccessToken().getTokenValue(), ACCESS_TOKEN);
108   }
109 
110   @Test
constructor_overrideMargin()111   public void constructor_overrideMargin() throws Throwable {
112     Duration staleMargin = Duration.ofMinutes(3);
113     Duration expirationMargin = Duration.ofMinutes(2);
114 
115     Instant actualExpiration = Instant.now();
116     Instant clientStale = actualExpiration.minus(staleMargin);
117     Instant clientExpired = actualExpiration.minus(expirationMargin);
118 
119     AccessToken initialToken = new AccessToken(ACCESS_TOKEN, Date.from(actualExpiration));
120     AtomicInteger refreshCount = new AtomicInteger();
121     AtomicReference<AccessToken> currentToken = new AtomicReference<>(initialToken);
122 
123     OAuth2Credentials credentials =
124         new OAuth2Credentials(
125             currentToken.get(),
126             /* refreshMargin= */ Duration.ofMinutes(3),
127             /* expirationMargin= */ Duration.ofMinutes(2)) {
128           @Override
129           public AccessToken refreshAccessToken() throws IOException {
130             refreshCount.incrementAndGet();
131             // Inject delay to model network latency
132             // This is needed to make to deflake the stale tests:
133             // if the refresh is super quick, then a stale refresh will return the new token
134             try {
135               Thread.sleep(100);
136             } catch (InterruptedException e) {
137               throw new IOException(e);
138             }
139 
140             return currentToken.get();
141           }
142         };
143 
144     TestClock clock = new TestClock();
145     credentials.clock = clock;
146 
147     // Rewind time to when the token is fresh
148     clock.setCurrentTime(clientStale.toEpochMilli() - 1);
149     MockRequestMetadataCallback callback = new MockRequestMetadataCallback();
150     credentials.getRequestMetadata(CALL_URI, realExecutor, callback);
151     synchronized (credentials.lock) {
152       assertNull(credentials.refreshTask);
153     }
154     assertEquals(0, refreshCount.get());
155     Map<String, List<String>> lastMetadata = credentials.getRequestMetadata(CALL_URI);
156 
157     // Fast forward to when the token just turned STALE
158     clock.setCurrentTime(clientStale.toEpochMilli());
159     currentToken.set(new AccessToken(ACCESS_TOKEN + "-1", Date.from(actualExpiration)));
160     callback.reset();
161     credentials.getRequestMetadata(CALL_URI, realExecutor, callback);
162     assertEquals(lastMetadata, callback.awaitResult());
163     waitForRefreshTaskCompletion(credentials);
164     assertEquals(1, refreshCount.get());
165     lastMetadata = credentials.getRequestMetadata(CALL_URI);
166     refreshCount.set(0);
167 
168     // Fast forward to when the token turned STALE just before expiration
169     clock.setCurrentTime(clientExpired.toEpochMilli() - 1);
170     currentToken.set(new AccessToken(ACCESS_TOKEN + "-2", Date.from(actualExpiration)));
171     callback.reset();
172     credentials.getRequestMetadata(CALL_URI, realExecutor, callback);
173     assertEquals(lastMetadata, callback.awaitResult());
174     waitForRefreshTaskCompletion(credentials);
175     assertEquals(1, refreshCount.get());
176     lastMetadata = credentials.getRequestMetadata();
177     refreshCount.set(0);
178 
179     // Fast forward to expired
180     clock.setCurrentTime(clientExpired.toEpochMilli());
181     AccessToken newToken = new AccessToken(ACCESS_TOKEN + "-3", Date.from(actualExpiration));
182     currentToken.set(newToken);
183     callback.reset();
184     credentials.getRequestMetadata(CALL_URI, realExecutor, callback);
185     TestUtils.assertContainsBearerToken(callback.awaitResult(), newToken.getTokenValue());
186     assertEquals(1, refreshCount.get());
187     waitForRefreshTaskCompletion(credentials);
188     lastMetadata = credentials.getRequestMetadata();
189   }
190 
191   @Test
getAuthenticationType_returnsOAuth2()192   public void getAuthenticationType_returnsOAuth2() {
193     OAuth2Credentials credentials =
194         UserCredentials.newBuilder()
195             .setClientId(CLIENT_ID)
196             .setClientSecret(CLIENT_SECRET)
197             .setRefreshToken(REFRESH_TOKEN)
198             .build();
199     assertEquals(credentials.getAuthenticationType(), "OAuth2");
200   }
201 
202   @Test
hasRequestMetadata_returnsTrue()203   public void hasRequestMetadata_returnsTrue() {
204     OAuth2Credentials credentials =
205         UserCredentials.newBuilder()
206             .setClientId(CLIENT_ID)
207             .setClientSecret(CLIENT_SECRET)
208             .setRefreshToken(REFRESH_TOKEN)
209             .build();
210     assertTrue(credentials.hasRequestMetadata());
211   }
212 
213   @Test
hasRequestMetadataOnly_returnsTrue()214   public void hasRequestMetadataOnly_returnsTrue() {
215     OAuth2Credentials credentials =
216         UserCredentials.newBuilder()
217             .setClientId(CLIENT_ID)
218             .setClientSecret(CLIENT_SECRET)
219             .setRefreshToken(REFRESH_TOKEN)
220             .build();
221     assertTrue(credentials.hasRequestMetadata());
222   }
223 
224   @Test
addChangeListener_notifiesOnRefresh()225   public void addChangeListener_notifiesOnRefresh() throws IOException {
226     final String accessToken1 = "1/MkSJoj1xsli0AccessToken_NKPY2";
227     final String accessToken2 = "2/MkSJoj1xsli0AccessToken_NKPY2";
228     MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory();
229     transportFactory.transport.addClient(CLIENT_ID, CLIENT_SECRET);
230     transportFactory.transport.addRefreshToken(REFRESH_TOKEN, accessToken1);
231     OAuth2Credentials userCredentials =
232         UserCredentials.newBuilder()
233             .setClientId(CLIENT_ID)
234             .setClientSecret(CLIENT_SECRET)
235             .setRefreshToken(REFRESH_TOKEN)
236             .setHttpTransportFactory(transportFactory)
237             .build();
238     // Use a fixed clock so tokens don't expire
239     userCredentials.clock = new TestClock();
240     TestChangeListener listener = new TestChangeListener();
241     userCredentials.addChangeListener(listener);
242     Map<String, List<String>> metadata;
243     assertEquals(0, listener.callCount);
244 
245     // Get a first token
246     metadata = userCredentials.getRequestMetadata(CALL_URI);
247     TestUtils.assertContainsBearerToken(metadata, accessToken1);
248     assertEquals(accessToken1, listener.accessToken.getTokenValue());
249     assertEquals(1, listener.callCount);
250 
251     // Change server to a different token and refresh
252     transportFactory.transport.addRefreshToken(REFRESH_TOKEN, accessToken2);
253     // Refresh to force getting next token
254     userCredentials.refresh();
255 
256     metadata = userCredentials.getRequestMetadata(CALL_URI);
257     TestUtils.assertContainsBearerToken(metadata, accessToken2);
258     assertEquals(accessToken2, listener.accessToken.getTokenValue());
259     assertEquals(2, listener.callCount);
260   }
261 
262   @Test
removeChangeListener_unregisters_observer()263   public void removeChangeListener_unregisters_observer() throws IOException {
264     final String accessToken1 = "1/MkSJoj1xsli0AccessToken_NKPY2";
265     final String accessToken2 = "2/MkSJoj1xsli0AccessToken_NKPY2";
266     MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory();
267     transportFactory.transport.addClient(CLIENT_ID, CLIENT_SECRET);
268     transportFactory.transport.addRefreshToken(REFRESH_TOKEN, accessToken1);
269     OAuth2Credentials userCredentials =
270         UserCredentials.newBuilder()
271             .setClientId(CLIENT_ID)
272             .setClientSecret(CLIENT_SECRET)
273             .setRefreshToken(REFRESH_TOKEN)
274             .setHttpTransportFactory(transportFactory)
275             .build();
276     // Use a fixed clock so tokens don't expire
277     userCredentials.clock = new TestClock();
278     TestChangeListener listener = new TestChangeListener();
279     userCredentials.addChangeListener(listener);
280     assertEquals(0, listener.callCount);
281 
282     // Get a first token
283     userCredentials.getRequestMetadata(CALL_URI);
284     assertEquals(1, listener.callCount);
285 
286     // Change server to a different token and refresh
287     transportFactory.transport.addRefreshToken(REFRESH_TOKEN, accessToken2);
288     // Refresh to force getting next token
289     userCredentials.refresh();
290     assertEquals(2, listener.callCount);
291 
292     // Remove the listener and refresh the credential again
293     userCredentials.removeChangeListener(listener);
294     transportFactory.transport.addRefreshToken(REFRESH_TOKEN, accessToken2);
295     userCredentials.refresh();
296     assertEquals(2, listener.callCount);
297   }
298 
299   @Test
getRequestMetadata_blocking_cachesExpiringToken()300   public void getRequestMetadata_blocking_cachesExpiringToken() throws IOException {
301     final String accessToken1 = "1/MkSJoj1xsli0AccessToken_NKPY2";
302     final String accessToken2 = "2/MkSJoj1xsli0AccessToken_NKPY2";
303     MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory();
304     transportFactory.transport.addClient(CLIENT_ID, CLIENT_SECRET);
305     transportFactory.transport.addRefreshToken(REFRESH_TOKEN, accessToken1);
306     TestClock clock = new TestClock();
307     OAuth2Credentials credentials =
308         UserCredentials.newBuilder()
309             .setClientId(CLIENT_ID)
310             .setClientSecret(CLIENT_SECRET)
311             .setRefreshToken(REFRESH_TOKEN)
312             .setHttpTransportFactory(transportFactory)
313             .build();
314     credentials.clock = clock;
315 
316     // Verify getting the first token
317     assertEquals(0, transportFactory.transport.buildRequestCount);
318     Map<String, List<String>> metadata = credentials.getRequestMetadata(CALL_URI);
319     TestUtils.assertContainsBearerToken(metadata, accessToken1);
320     assertEquals(1, transportFactory.transport.buildRequestCount--);
321 
322     // Change server to a different token
323     transportFactory.transport.addRefreshToken(REFRESH_TOKEN, accessToken2);
324 
325     // Make transport fail when used next time.
326     IOException error = new IOException("error");
327     transportFactory.transport.setError(error);
328 
329     // Advance 5 minutes and verify original token
330     clock.addToCurrentTime(5 * 60 * 1000);
331     metadata = credentials.getRequestMetadata(CALL_URI);
332     TestUtils.assertContainsBearerToken(metadata, accessToken1);
333 
334     // Advance 60 minutes and verify revised token
335     clock.addToCurrentTime(60 * 60 * 1000);
336     assertEquals(0, transportFactory.transport.buildRequestCount);
337 
338     try {
339       credentials.getRequestMetadata(CALL_URI);
340       fail("Should throw");
341     } catch (IOException e) {
342       assertSame(error, e.getCause());
343       assertEquals(1, transportFactory.transport.buildRequestCount--);
344     }
345 
346     // Reset the error and try again
347     transportFactory.transport.setError(null);
348     metadata = credentials.getRequestMetadata(CALL_URI);
349     TestUtils.assertContainsBearerToken(metadata, accessToken2);
350     assertEquals(1, transportFactory.transport.buildRequestCount--);
351   }
352 
353   @Test
getRequestMetadata_async()354   public void getRequestMetadata_async() throws IOException {
355     final String accessToken1 = "1/MkSJoj1xsli0AccessToken_NKPY2";
356     final String accessToken2 = "2/MkSJoj1xsli0AccessToken_NKPY2";
357     MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory();
358     transportFactory.transport.addClient(CLIENT_ID, CLIENT_SECRET);
359     transportFactory.transport.addRefreshToken(REFRESH_TOKEN, accessToken1);
360     TestClock clock = new TestClock();
361     OAuth2Credentials credentials =
362         UserCredentials.newBuilder()
363             .setClientId(CLIENT_ID)
364             .setClientSecret(CLIENT_SECRET)
365             .setRefreshToken(REFRESH_TOKEN)
366             .setHttpTransportFactory(transportFactory)
367             .build();
368     credentials.clock = clock;
369 
370     MockExecutor executor = new MockExecutor();
371     MockRequestMetadataCallback callback = new MockRequestMetadataCallback();
372     // Verify getting the first token, which uses the transport and calls the callback in the
373     // executor.
374     credentials.getRequestMetadata(CALL_URI, executor, callback);
375     assertEquals(0, transportFactory.transport.buildRequestCount);
376     assertNull(callback.metadata);
377 
378     assertEquals(1, executor.runTasksExhaustively());
379     assertNotNull(callback.metadata);
380     TestUtils.assertContainsBearerToken(callback.metadata, accessToken1);
381     assertEquals(1, transportFactory.transport.buildRequestCount--);
382 
383     // Change server to a different token
384     transportFactory.transport.addRefreshToken(REFRESH_TOKEN, accessToken2);
385 
386     // Make transport fail when used next time.
387     IOException error = new IOException("error");
388     transportFactory.transport.setError(error);
389 
390     // Advance 5 minutes and verify original token. Callback is called inline.
391     callback.reset();
392     clock.addToCurrentTime(5 * 60 * 1000);
393     assertNull(callback.metadata);
394     credentials.getRequestMetadata(CALL_URI, executor, callback);
395     assertNotNull(callback.metadata);
396     assertEquals(0, executor.numTasks());
397     TestUtils.assertContainsBearerToken(callback.metadata, accessToken1);
398 
399     // Advance 60 minutes and verify revised token, which uses the executor.
400     callback.reset();
401     clock.addToCurrentTime(60 * 60 * 1000);
402     credentials.getRequestMetadata(CALL_URI, executor, callback);
403     assertEquals(0, transportFactory.transport.buildRequestCount);
404     assertNull(callback.exception);
405 
406     assertEquals(1, executor.runTasksExhaustively());
407     assertSame(error, callback.exception.getCause());
408     assertEquals(1, transportFactory.transport.buildRequestCount--);
409 
410     // Reset the error and try again
411     transportFactory.transport.setError(null);
412     callback.reset();
413     credentials.getRequestMetadata(CALL_URI, executor, callback);
414     assertEquals(0, transportFactory.transport.buildRequestCount);
415     assertNull(callback.metadata);
416 
417     assertEquals(1, executor.runTasksExhaustively());
418     assertNotNull(callback.metadata);
419     TestUtils.assertContainsBearerToken(callback.metadata, accessToken2);
420     assertEquals(1, transportFactory.transport.buildRequestCount--);
421   }
422 
423   @Test
getRequestMetadata_async_refreshRace()424   public void getRequestMetadata_async_refreshRace()
425       throws ExecutionException, InterruptedException {
426     final String accessToken1 = "1/MkSJoj1xsli0AccessToken_NKPY2";
427     MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory();
428     transportFactory.transport.addClient(CLIENT_ID, CLIENT_SECRET);
429     transportFactory.transport.addRefreshToken(REFRESH_TOKEN, accessToken1);
430     TestClock clock = new TestClock();
431     final OAuth2Credentials credentials =
432         UserCredentials.newBuilder()
433             .setClientId(CLIENT_ID)
434             .setClientSecret(CLIENT_SECRET)
435             .setRefreshToken(REFRESH_TOKEN)
436             .setHttpTransportFactory(transportFactory)
437             .build();
438     credentials.clock = clock;
439 
440     MockExecutor executor = new MockExecutor();
441     MockRequestMetadataCallback callback = new MockRequestMetadataCallback();
442     // Getting the first token, which uses the transport and calls the callback in the executor.
443     credentials.getRequestMetadata(CALL_URI, executor, callback);
444     assertEquals(0, transportFactory.transport.buildRequestCount);
445     assertNull(callback.metadata);
446 
447     // Asynchronous task is scheduled, and a blocking call follows it
448     assertEquals(1, executor.numTasks());
449 
450     ExecutorService testExecutor = Executors.newFixedThreadPool(1);
451 
452     FutureTask<Map<String, List<String>>> blockingTask =
453         new FutureTask<>(
454             new Callable<Map<String, List<String>>>() {
455               @Override
456               public Map<String, List<String>> call() throws Exception {
457                 return credentials.getRequestMetadata(CALL_URI);
458               }
459             });
460 
461     @SuppressWarnings("FutureReturnValueIgnored")
462     Future<?> ignored = testExecutor.submit(blockingTask);
463     testExecutor.shutdown();
464 
465     // give the blockingTask a chance to run
466     for (int i = 0; i < 10; i++) {
467       Thread.yield();
468     }
469 
470     // blocking task is waiting on the async task to finish
471     assertFalse(blockingTask.isDone());
472     assertEquals(0, transportFactory.transport.buildRequestCount);
473 
474     // When the task is run, the result is shared
475     assertEquals(1, executor.runTasksExhaustively());
476     assertEquals(1, transportFactory.transport.buildRequestCount--);
477     Map<String, List<String>> metadata = blockingTask.get();
478     assertEquals(0, transportFactory.transport.buildRequestCount);
479     assertEquals(metadata, callback.metadata);
480   }
481 
482   @Test
getRequestMetadata_temporaryToken_hasToken()483   public void getRequestMetadata_temporaryToken_hasToken() throws IOException {
484     OAuth2Credentials credentials =
485         OAuth2Credentials.newBuilder().setAccessToken(new AccessToken(ACCESS_TOKEN, null)).build();
486 
487     // Verify getting the first token
488     Map<String, List<String>> metadata = credentials.getRequestMetadata(CALL_URI);
489     TestUtils.assertContainsBearerToken(metadata, ACCESS_TOKEN);
490   }
491 
492   @Test
getRequestMetadata_staleTemporaryToken()493   public void getRequestMetadata_staleTemporaryToken() throws IOException, InterruptedException {
494     Instant actualExpiration = Instant.now();
495     Instant clientStale = actualExpiration.minus(OAuth2Credentials.DEFAULT_REFRESH_MARGIN);
496 
497     TestClock testClock = new TestClock();
498     testClock.setCurrentTime(clientStale.toEpochMilli());
499 
500     // Initialize credentials which are initially stale and set to refresh
501     final SettableFuture<AccessToken> refreshedTokenFuture = SettableFuture.create();
502     OAuth2Credentials creds =
503         new OAuth2Credentials(new AccessToken(ACCESS_TOKEN, Date.from(actualExpiration))) {
504           @Override
505           public AccessToken refreshAccessToken() {
506 
507             try {
508               return refreshedTokenFuture.get();
509             } catch (Exception e) {
510               throw new RuntimeException(e);
511             }
512           }
513         };
514     creds.clock = testClock;
515     synchronized (creds.lock) {
516       assertNull(creds.refreshTask);
517     }
518 
519     // Calls should return immediately with stale token
520     MockRequestMetadataCallback callback = new MockRequestMetadataCallback();
521     creds.getRequestMetadata(CALL_URI, realExecutor, callback);
522     TestUtils.assertContainsBearerToken(callback.metadata, ACCESS_TOKEN);
523     TestUtils.assertContainsBearerToken(creds.getRequestMetadata(CALL_URI), ACCESS_TOKEN);
524 
525     // But a refresh task should be scheduled
526     synchronized (creds.lock) {
527       assertNotNull(creds.refreshTask);
528     }
529 
530     // Resolve the outstanding refresh
531     AccessToken refreshedToken =
532         new AccessToken(
533             "2/MkSJoj1xsli0AccessToken_NKPY2",
534             new Date(testClock.currentTimeMillis() + HOURS.toMillis(1)));
535     refreshedTokenFuture.set(refreshedToken);
536 
537     // The access token should available once the refresh thread completes
538     // However it will be populated asynchronously, so we need to wait until it propagates
539     // Wait at most 1 minute are 100ms intervals. It should never come close to this.
540     for (int i = 0; i < 600; i++) {
541       Map<String, List<String>> requestMetadata = creds.getRequestMetadata(CALL_URI);
542       String s = requestMetadata.get(AuthHttpConstants.AUTHORIZATION).get(0);
543       if (s.contains(refreshedToken.getTokenValue())) {
544         break;
545       }
546       Thread.sleep(100);
547     }
548 
549     // Everything should return the new token
550     callback = new MockRequestMetadataCallback();
551     creds.getRequestMetadata(CALL_URI, realExecutor, callback);
552     TestUtils.assertContainsBearerToken(callback.metadata, refreshedToken.getTokenValue());
553     TestUtils.assertContainsBearerToken(
554         creds.getRequestMetadata(CALL_URI), refreshedToken.getTokenValue());
555 
556     // And the task slot is reset
557     synchronized (creds.lock) {
558       assertNull(creds.refreshTask);
559     }
560   }
561 
562   @Test
getRequestMetadata_staleTemporaryToken_expirationWaits()563   public void getRequestMetadata_staleTemporaryToken_expirationWaits() throws Throwable {
564     Instant actualExpiration = Instant.now();
565     Instant clientStale = actualExpiration.minus(OAuth2Credentials.DEFAULT_REFRESH_MARGIN);
566     Instant clientExpired = actualExpiration.minus(OAuth2Credentials.DEFAULT_EXPIRATION_MARGIN);
567 
568     TestClock testClock = new TestClock();
569 
570     // Initialize credentials which are initially stale and set to refresh
571     final SettableFuture<AccessToken> refreshedTokenFuture = SettableFuture.create();
572     OAuth2Credentials creds =
573         new OAuth2Credentials(new AccessToken(ACCESS_TOKEN, Date.from(actualExpiration))) {
574           @Override
575           public AccessToken refreshAccessToken() {
576 
577             try {
578               return refreshedTokenFuture.get();
579             } catch (Exception e) {
580               throw new RuntimeException(e);
581             }
582           }
583         };
584     creds.clock = testClock;
585     synchronized (creds.lock) {
586       assertNull(creds.refreshTask);
587     }
588 
589     // Calls should return immediately with stale token, but a refresh is scheduled
590     testClock.setCurrentTime(clientStale.toEpochMilli());
591     MockRequestMetadataCallback callback = new MockRequestMetadataCallback();
592     creds.getRequestMetadata(CALL_URI, realExecutor, callback);
593     TestUtils.assertContainsBearerToken(callback.metadata, ACCESS_TOKEN);
594     assertNotNull(creds.refreshTask);
595     RefreshTask refreshTask = creds.refreshTask;
596 
597     // Fast forward to expiration, which will hang cause the callback to hang
598     testClock.setCurrentTime(clientExpired.toEpochMilli());
599     // Make sure that the callback is hung (while giving it a chance to run)
600     for (int i = 0; i < 10; i++) {
601       Thread.sleep(10);
602       callback = new MockRequestMetadataCallback();
603       creds.getRequestMetadata(CALL_URI, realExecutor, callback);
604       assertNull(callback.metadata);
605     }
606     // The original refresh task should still be active
607     synchronized (creds.lock) {
608       assertSame(refreshTask, creds.refreshTask);
609     }
610 
611     // Resolve the outstanding refresh
612     AccessToken refreshedToken =
613         new AccessToken(
614             "2/MkSJoj1xsli0AccessToken_NKPY2",
615             new Date(testClock.currentTimeMillis() + HOURS.toMillis(1)));
616     refreshedTokenFuture.set(refreshedToken);
617 
618     // The access token should available once the refresh thread completes
619     TestUtils.assertContainsBearerToken(
620         creds.getRequestMetadata(CALL_URI), refreshedToken.getTokenValue());
621     callback = new MockRequestMetadataCallback();
622     creds.getRequestMetadata(CALL_URI, realExecutor, callback);
623     TestUtils.assertContainsBearerToken(callback.awaitResult(), refreshedToken.getTokenValue());
624 
625     // The refresh slot should be cleared
626     synchronized (creds.lock) {
627       assertNull(creds.refreshTask);
628     }
629   }
630 
631   @Test
getRequestMetadata_singleFlightErrorSharing()632   public void getRequestMetadata_singleFlightErrorSharing() {
633     Instant actualExpiration = Instant.now();
634     Instant clientStale = actualExpiration.minus(OAuth2Credentials.DEFAULT_REFRESH_MARGIN);
635     Instant clientExpired = actualExpiration.minus(OAuth2Credentials.DEFAULT_EXPIRATION_MARGIN);
636 
637     TestClock testClock = new TestClock();
638     testClock.setCurrentTime(clientStale.toEpochMilli());
639 
640     // Initialize credentials which are initially expired
641     final SettableFuture<RuntimeException> refreshErrorFuture = SettableFuture.create();
642     final OAuth2Credentials creds =
643         new OAuth2Credentials(new AccessToken(ACCESS_TOKEN, Date.from(clientExpired))) {
644           @Override
645           public AccessToken refreshAccessToken() {
646             RuntimeException injectedError;
647 
648             try {
649               injectedError = refreshErrorFuture.get();
650             } catch (Exception e) {
651               throw new IllegalStateException("Unexpected error fetching injected error");
652             }
653             throw injectedError;
654           }
655         };
656     creds.clock = testClock;
657 
658     // Calls will hang waiting for the refresh
659     final MockRequestMetadataCallback callback1 = new MockRequestMetadataCallback();
660     creds.getRequestMetadata(CALL_URI, realExecutor, callback1);
661 
662     final Future<Map<String, List<String>>> blockingCall =
663         realExecutor.submit(
664             new Callable<Map<String, List<String>>>() {
665               @Override
666               public Map<String, List<String>> call() throws Exception {
667                 return creds.getRequestMetadata(CALL_URI);
668               }
669             });
670 
671     RuntimeException error = new RuntimeException("fake error");
672     refreshErrorFuture.set(error);
673 
674     // Get the error that getRequestMetadata(uri) created
675     Throwable actualBlockingError =
676         assertThrows(
677                 ExecutionException.class,
678                 new ThrowingRunnable() {
679                   @Override
680                   public void run() throws Throwable {
681                     blockingCall.get();
682                   }
683                 })
684             .getCause();
685 
686     assertEquals(error, actualBlockingError);
687 
688     RuntimeException actualAsyncError =
689         assertThrows(
690             RuntimeException.class,
691             new ThrowingRunnable() {
692               @Override
693               public void run() throws Throwable {
694                 callback1.awaitResult();
695               }
696             });
697     assertEquals(error, actualAsyncError);
698   }
699 
700   @Test
getRequestMetadata_syncErrorsIncludeCallingStackframe()701   public void getRequestMetadata_syncErrorsIncludeCallingStackframe() {
702     final OAuth2Credentials creds =
703         new OAuth2Credentials() {
704           @Override
705           public AccessToken refreshAccessToken() {
706             throw new RuntimeException("fake error");
707           }
708         };
709 
710     List<StackTraceElement> expectedStacktrace =
711         new ArrayList<>(Arrays.asList(new Exception().getStackTrace()));
712     expectedStacktrace = expectedStacktrace.subList(1, expectedStacktrace.size());
713 
714     AtomicReference<Exception> actualError = new AtomicReference<>();
715     try {
716       creds.getRequestMetadata(CALL_URI);
717       fail("Should not be able to use credential without exception.");
718     } catch (Exception refreshError) {
719       actualError.set(refreshError);
720     }
721 
722     List<StackTraceElement> actualStacktrace = Arrays.asList(actualError.get().getStackTrace());
723     actualStacktrace =
724         actualStacktrace.subList(
725             actualStacktrace.size() - expectedStacktrace.size(), actualStacktrace.size());
726 
727     // ensure the remaining frames are identical
728     assertEquals(expectedStacktrace, actualStacktrace);
729   }
730 
731   @Test
refresh_refreshesToken()732   public void refresh_refreshesToken() throws IOException {
733     final String accessToken1 = "1/MkSJoj1xsli0AccessToken_NKPY2";
734     final String accessToken2 = "2/MkSJoj1xsli0AccessToken_NKPY2";
735     MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory();
736     transportFactory.transport.addClient(CLIENT_ID, CLIENT_SECRET);
737     transportFactory.transport.addRefreshToken(REFRESH_TOKEN, accessToken1);
738     OAuth2Credentials userCredentials =
739         UserCredentials.newBuilder()
740             .setClientId(CLIENT_ID)
741             .setClientSecret(CLIENT_SECRET)
742             .setRefreshToken(REFRESH_TOKEN)
743             .setHttpTransportFactory(transportFactory)
744             .build();
745     // Use a fixed clock so tokens don't expire
746     userCredentials.clock = new TestClock();
747 
748     // Get a first token
749     Map<String, List<String>> metadata = userCredentials.getRequestMetadata(CALL_URI);
750     TestUtils.assertContainsBearerToken(metadata, accessToken1);
751     assertEquals(1, transportFactory.transport.buildRequestCount--);
752 
753     // Change server to a different token
754     transportFactory.transport.addRefreshToken(REFRESH_TOKEN, accessToken2);
755 
756     // Confirm token being cached
757     TestUtils.assertContainsBearerToken(metadata, accessToken1);
758     assertEquals(0, transportFactory.transport.buildRequestCount);
759 
760     // Refresh to force getting next token
761     userCredentials.refresh();
762     metadata = userCredentials.getRequestMetadata(CALL_URI);
763     TestUtils.assertContainsBearerToken(metadata, accessToken2);
764     assertEquals(1, transportFactory.transport.buildRequestCount--);
765   }
766 
767   @Test
refreshIfExpired_refreshesToken()768   public void refreshIfExpired_refreshesToken() throws IOException {
769     final String accessToken1 = "1/MkSJoj1xsli0AccessToken_NKPY2";
770     final String accessToken2 = "2/MkSJoj1xsli0AccessToken_NKPY2";
771     MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory();
772     transportFactory.transport.addClient(CLIENT_ID, CLIENT_SECRET);
773     transportFactory.transport.addRefreshToken(REFRESH_TOKEN, accessToken1);
774     OAuth2Credentials userCredentials =
775         UserCredentials.newBuilder()
776             .setClientId(CLIENT_ID)
777             .setClientSecret(CLIENT_SECRET)
778             .setRefreshToken(REFRESH_TOKEN)
779             .setHttpTransportFactory(transportFactory)
780             .build();
781     // Use a fixed clock so tokens don't expire
782     TestClock mockClock = new TestClock();
783     userCredentials.clock = mockClock;
784 
785     // Get a first token
786     Map<String, List<String>> metadata = userCredentials.getRequestMetadata(CALL_URI);
787     TestUtils.assertContainsBearerToken(metadata, accessToken1);
788     assertEquals(1, transportFactory.transport.buildRequestCount--);
789 
790     // Change server to a different token
791     transportFactory.transport.addRefreshToken(REFRESH_TOKEN, accessToken2);
792 
793     // Confirm token being cached
794     TestUtils.assertContainsBearerToken(metadata, accessToken1);
795     assertEquals(0, transportFactory.transport.buildRequestCount);
796 
797     // Should not refresh yet
798     userCredentials.refreshIfExpired();
799     metadata = userCredentials.getRequestMetadata(CALL_URI);
800     TestUtils.assertNotContainsBearerToken(metadata, accessToken2);
801 
802     // Jump ahead to expire the token
803     mockClock.addToCurrentTime(3600000);
804     userCredentials.refreshIfExpired();
805     metadata = userCredentials.getRequestMetadata(CALL_URI);
806     TestUtils.assertContainsBearerToken(metadata, accessToken2);
807 
808     assertEquals(1, transportFactory.transport.buildRequestCount--);
809   }
810 
811   @Test(expected = IllegalStateException.class)
refresh_temporaryToken_throws()812   public void refresh_temporaryToken_throws() throws IOException {
813     OAuth2Credentials credentials =
814         OAuth2Credentials.newBuilder().setAccessToken(new AccessToken(ACCESS_TOKEN, null)).build();
815     credentials.refresh();
816   }
817 
818   @Test
equals_true()819   public void equals_true() throws IOException {
820     final String accessToken1 = "1/MkSJoj1xsli0AccessToken_NKPY2";
821     OAuth2Credentials credentials =
822         OAuth2Credentials.newBuilder().setAccessToken(new AccessToken(accessToken1, null)).build();
823     OAuth2Credentials otherCredentials =
824         OAuth2Credentials.newBuilder().setAccessToken(new AccessToken(accessToken1, null)).build();
825     assertTrue(credentials.equals(otherCredentials));
826     assertTrue(otherCredentials.equals(credentials));
827   }
828 
829   @Test
equals_false_accessToken()830   public void equals_false_accessToken() throws IOException {
831     final String accessToken1 = "1/MkSJoj1xsli0AccessToken_NKPY2";
832     final String accessToken2 = "2/MkSJoj1xsli0AccessToken_NKPY2";
833     OAuth2Credentials credentials =
834         OAuth2Credentials.newBuilder().setAccessToken(new AccessToken(accessToken1, null)).build();
835     OAuth2Credentials otherCredentials =
836         OAuth2Credentials.newBuilder().setAccessToken(new AccessToken(accessToken2, null)).build();
837     assertFalse(credentials.equals(otherCredentials));
838     assertFalse(otherCredentials.equals(credentials));
839   }
840 
841   @Test
toString_containsFields()842   public void toString_containsFields() throws IOException {
843     AccessToken accessToken = new AccessToken("1/MkSJoj1xsli0AccessToken_NKPY2", null);
844     OAuth2Credentials credentials =
845         OAuth2Credentials.newBuilder().setAccessToken(accessToken).build();
846     String expectedToString =
847         String.format(
848             "OAuth2Credentials{requestMetadata=%s, temporaryAccess=%s}",
849             ImmutableMap.of(
850                 AuthHttpConstants.AUTHORIZATION,
851                 ImmutableList.of(OAuth2Utils.BEARER_PREFIX + accessToken.getTokenValue())),
852             accessToken.toString());
853     assertEquals(expectedToString, credentials.toString());
854   }
855 
856   @Test
hashCode_equals()857   public void hashCode_equals() throws IOException {
858     final String accessToken = "1/MkSJoj1xsli0AccessToken_NKPY2";
859     OAuth2Credentials credentials =
860         OAuth2Credentials.newBuilder().setAccessToken(new AccessToken(accessToken, null)).build();
861     OAuth2Credentials otherCredentials =
862         OAuth2Credentials.create(new AccessToken(accessToken, null));
863     assertEquals(credentials.hashCode(), otherCredentials.hashCode());
864   }
865 
866   @Test
serialize()867   public void serialize() throws IOException, ClassNotFoundException {
868     final String accessToken = "1/MkSJoj1xsli0AccessToken_NKPY2";
869     OAuth2Credentials credentials =
870         OAuth2Credentials.newBuilder().setAccessToken(new AccessToken(accessToken, null)).build();
871     OAuth2Credentials deserializedCredentials = serializeAndDeserialize(credentials);
872     assertEquals(credentials, deserializedCredentials);
873     assertEquals(credentials.hashCode(), deserializedCredentials.hashCode());
874     assertEquals(credentials.toString(), deserializedCredentials.toString());
875     assertSame(deserializedCredentials.clock, Clock.SYSTEM);
876   }
877 
/**
 * Regression test: after {@code refresh()} on the main thread returns, {@code getAccessToken()}
 * must be non-null, even while a refresh started by a worker thread (t) is still delivering its
 * token through a deliberately delayed task listener.
 *
 * <p>Sequence exercised:
 * <ol>
 *   <li>Thread t calls {@code creds.refresh()}, which enters the overridden
 *       {@code refreshAccessToken()}.
 *   <li>That method calls {@code notify()} on the credentials object, waking the main thread.
 *   <li>It then installs a {@code RefreshTask} whose listener sleeps 300 ms before running.
 *   <li>The main thread observes a null token and calls {@code refresh()} itself.
 * </ol>
 *
 * <p>NOTE(review): this test is {@code @Ignore}d; the reason is not visible here.
 */
878   @Test
879   @Ignore
updateTokenValueBeforeWake()880   public void updateTokenValueBeforeWake() throws IOException, InterruptedException {
881     final SettableFuture<AccessToken> refreshedTokenFuture = SettableFuture.create();
882     AccessToken refreshedToken = new AccessToken("2/MkSJoj1xsli0AccessToken_NKPY2", null);
883     refreshedTokenFuture.set(refreshedToken);
884 
      // Task that yields an OAuthValue wrapping the same refreshed token with empty metadata.
885     final ListenableFutureTask<OAuthValue> task =
886         ListenableFutureTask.create(
887             new Callable<OAuthValue>() {
888               @Override
889               public OAuthValue call() throws Exception {
890                 return OAuthValue.create(refreshedToken, new HashMap<>());
891               }
892             });
893 
894     OAuth2Credentials creds =
895         new OAuth2Credentials() {
896           @Override
897           public AccessToken refreshAccessToken() {
                // `this` is the anonymous credentials object, i.e. the same monitor as `creds`
                // that the main thread waits on below.
898             synchronized (this) {
899               // Wake up the main thread. This is done now because the child thread (t) is known to
900               // have the refresh task. Now we want the main thread to wake up and create a future
901               // in order to wait for the refresh to complete.
902               this.notify();
903             }
904             RefreshTaskListener listener =
905                 new RefreshTaskListener(task) {
906                   @Override
907                   public void run() {
908                     try {
909                       // Sleep before setting accessToken to new accessToken. Refresh should not
910                       // complete before this, and the accessToken is `null` until it is.
911                       Thread.sleep(300);
912                       super.run();
913                     } catch (Exception e) {
                          // NOTE(review): fail() throws AssertionError on whichever thread runs
                          // this listener; that alone does not fail the JUnit test.
914                       fail("Unexpected error. Exception: " + e);
915                     }
916                   }
917                 };
918 
                // Replace the credentials' refresh task with one whose listener is delayed.
919             this.refreshTask = new RefreshTask(task, listener);
920 
921             try {
922               // Sleep for 100 milliseconds to give parent thread time to create a refresh future.
923               Thread.sleep(100);
924               return refreshedTokenFuture.get();
925             } catch (Exception e) {
926               throw new RuntimeException(e);
927             }
928           }
929         };
930 
931     Thread t =
932         new Thread(
933             new Runnable() {
934               @Override
935               public void run() {
936                 try {
937                   creds.refresh();
938                   assertNotNull(creds.getAccessToken());
939                 } catch (Exception e) {
                      // NOTE(review): AssertionError is not an Exception, so a failed
                      // assertNotNull above escapes this catch and ends only thread t; likewise
                      // fail() below runs on t. Neither reaches the test thread directly.
940                   fail("Unexpected error. Exception: " + e);
941                 }
942               }
943             });
944     t.start();
945 
      // NOTE(review): lost-wakeup hazard. If t calls notify() before this thread reaches
      // creds.wait(), the notification is lost and the test blocks forever. wait() is also not
      // guarded by a condition loop, so a spurious wakeup would proceed early.
946     synchronized (creds) {
947       // Grab a lock on creds object. This thread (the main thread) will wait here until the child
948       // thread (t) calls `notify` on the creds object.
949       creds.wait();
950     }
951 
      // The delayed listener has not stored the new token yet, so none is visible.
952     AccessToken token = creds.getAccessToken();
953     assertNull(token);
954 
955     creds.refresh();
956     token = creds.getAccessToken();
957     // Token should never be NULL after a refresh that succeeded.
958     // Previously the token could be NULL due to an internal race condition between the future
959     // completing and the task listener updating the value of the access token.
960     assertNotNull(token);
961     t.join();
962   }
963 
waitForRefreshTaskCompletion(OAuth2Credentials credentials)964   private void waitForRefreshTaskCompletion(OAuth2Credentials credentials)
965       throws TimeoutException, InterruptedException {
966     for (int i = 0; i < 100; i++) {
967       synchronized (credentials.lock) {
968         if (credentials.refreshTask == null) {
969           return;
970         }
971       }
972       Thread.sleep(100);
973     }
974     throw new TimeoutException("timed out waiting for refresh task to finish");
975   }
976 
977   private static class TestChangeListener implements OAuth2Credentials.CredentialsChangedListener {
978 
979     public AccessToken accessToken = null;
980     public int callCount = 0;
981 
982     @Override
onChanged(OAuth2Credentials credentials)983     public void onChanged(OAuth2Credentials credentials) throws IOException {
984       accessToken = credentials.getAccessToken();
985       callCount++;
986     }
987   }
988 }
989