• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions are
6  * met:
7  *
8  *    * Redistributions of source code must retain the above copyright
9  * notice, this list of conditions and the following disclaimer.
10  *    * Redistributions in binary form must reproduce the above
11  * copyright notice, this list of conditions and the following disclaimer
12  * in the documentation and/or other materials provided with the
13  * distribution.
14  *
15  *    * Neither the name of Google LLC nor the names of its
16  * contributors may be used to endorse or promote products derived from
17  * this software without specific prior written permission.
18  *
19  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23  * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30  */
31 
32 package com.google.auth.oauth2;
33 
34 import static org.junit.Assert.assertEquals;
35 import static org.junit.Assert.assertNotNull;
36 
37 import com.google.api.client.http.HttpStatusCodes;
38 import com.google.api.client.http.LowLevelHttpRequest;
39 import com.google.api.client.http.LowLevelHttpResponse;
40 import com.google.api.client.json.GenericJson;
41 import com.google.api.client.json.Json;
42 import com.google.api.client.json.gson.GsonFactory;
43 import com.google.api.client.testing.http.MockHttpTransport;
44 import com.google.api.client.testing.http.MockLowLevelHttpRequest;
45 import com.google.api.client.testing.http.MockLowLevelHttpResponse;
46 import com.google.auth.TestUtils;
47 import com.google.common.base.Joiner;
48 import java.io.IOException;
49 import java.util.ArrayDeque;
50 import java.util.Collections;
51 import java.util.List;
52 import java.util.Map;
53 import java.util.Queue;
54 import java.util.regex.Matcher;
55 import java.util.regex.Pattern;
56 
57 /** Transport that mocks a basic STS endpoint. */
58 public final class MockStsTransport extends MockHttpTransport {
59 
60   private static final String EXPECTED_GRANT_TYPE =
61       "urn:ietf:params:oauth:grant-type:token-exchange";
62   private static final String ISSUED_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token";
63   private static final String VALID_STS_PATTERN =
64       "https:\\/\\/sts.[a-z-_\\.]+\\/v1\\/(token|oauthtoken)";
65   private static final String ACCESS_TOKEN = "accessToken";
66   private static final String TOKEN_TYPE = "Bearer";
67   private static final Long EXPIRES_IN = 3600L;
68 
69   private final Queue<IOException> responseErrorSequence = new ArrayDeque<>();
70   private final Queue<List<String>> scopeSequence = new ArrayDeque<>();
71   private final Queue<String> refreshTokenSequence = new ArrayDeque<>();
72 
73   private boolean returnExpiresIn = true;
74   private MockLowLevelHttpRequest request;
75 
addResponseErrorSequence(IOException... errors)76   public void addResponseErrorSequence(IOException... errors) {
77     Collections.addAll(responseErrorSequence, errors);
78   }
79 
addRefreshTokenSequence(String... refreshTokens)80   public void addRefreshTokenSequence(String... refreshTokens) {
81     Collections.addAll(refreshTokenSequence, refreshTokens);
82   }
83 
addScopeSequence(List<String> scopes)84   public void addScopeSequence(List<String> scopes) {
85     Collections.addAll(scopeSequence, scopes);
86   }
87 
88   @Override
buildRequest(final String method, final String url)89   public LowLevelHttpRequest buildRequest(final String method, final String url) {
90     this.request =
91         new MockLowLevelHttpRequest(url) {
92           @Override
93           public LowLevelHttpResponse execute() throws IOException {
94             // Environment version is prefixed by "aws". e.g. "aws1".
95             Matcher matcher = Pattern.compile(VALID_STS_PATTERN).matcher(url);
96             if (!matcher.matches()) {
97               return makeErrorResponse();
98             }
99 
100             if (!responseErrorSequence.isEmpty()) {
101               throw responseErrorSequence.poll();
102             }
103 
104             GenericJson response = new GenericJson();
105             response.setFactory(new GsonFactory());
106 
107             Map<String, String> query = TestUtils.parseQuery(getContentAsString());
108             if (!url.contains("v1/oauthtoken")) {
109               assertEquals(EXPECTED_GRANT_TYPE, query.get("grant_type"));
110               assertNotNull(query.get("subject_token_type"));
111               assertNotNull(query.get("subject_token"));
112 
113               response.put("token_type", TOKEN_TYPE);
114               response.put("access_token", ACCESS_TOKEN);
115               response.put("issued_token_type", ISSUED_TOKEN_TYPE);
116 
117               if (returnExpiresIn) {
118                 response.put("expires_in", EXPIRES_IN);
119               }
120               if (!refreshTokenSequence.isEmpty()) {
121                 response.put("refresh_token", refreshTokenSequence.poll());
122               }
123               if (!scopeSequence.isEmpty()) {
124                 response.put("scope", Joiner.on(' ').join(scopeSequence.poll()));
125               }
126             } else {
127               assertEquals("refresh_token", query.get("grant_type"));
128 
129               response.put("access_token", ACCESS_TOKEN);
130               response.put("expires_in", EXPIRES_IN);
131 
132               if (!refreshTokenSequence.isEmpty()) {
133                 response.put("refresh_token", refreshTokenSequence.poll());
134               }
135             }
136             return new MockLowLevelHttpResponse()
137                 .setContentType(Json.MEDIA_TYPE)
138                 .setContent(response.toPrettyString());
139           }
140         };
141     return this.request;
142   }
143 
makeErrorResponse()144   private MockLowLevelHttpResponse makeErrorResponse() {
145     MockLowLevelHttpResponse errorResponse = new MockLowLevelHttpResponse();
146     errorResponse.setStatusCode(HttpStatusCodes.STATUS_CODE_BAD_REQUEST);
147     errorResponse.setContentType(Json.MEDIA_TYPE);
148     errorResponse.setContent("{\"error\":\"error\"}");
149     return errorResponse;
150   }
151 
getRequest()152   public MockLowLevelHttpRequest getRequest() {
153     return request;
154   }
155 
getAccessToken()156   public String getAccessToken() {
157     return ACCESS_TOKEN;
158   }
159 
getTokenType()160   public String getTokenType() {
161     return TOKEN_TYPE;
162   }
163 
getIssuedTokenType()164   public String getIssuedTokenType() {
165     return ISSUED_TOKEN_TYPE;
166   }
167 
getExpiresIn()168   public Long getExpiresIn() {
169     return EXPIRES_IN;
170   }
171 
setReturnExpiresIn(boolean returnExpiresIn)172   public void setReturnExpiresIn(boolean returnExpiresIn) {
173     this.returnExpiresIn = returnExpiresIn;
174   }
175 }
176