1 /* 2 * Copyright 2021 Google LLC 3 * 4 * Redistribution and use in source and binary forms, with or without 5 * modification, are permitted provided that the following conditions are 6 * met: 7 * 8 * * Redistributions of source code must retain the above copyright 9 * notice, this list of conditions and the following disclaimer. 10 * * Redistributions in binary form must reproduce the above 11 * copyright notice, this list of conditions and the following disclaimer 12 * in the documentation and/or other materials provided with the 13 * distribution. 14 * 15 * * Neither the name of Google LLC nor the names of its 16 * contributors may be used to endorse or promote products derived from 17 * this software without specific prior written permission. 18 * 19 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 23 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 */ 31 32 package com.google.auth.oauth2; 33 34 import static org.junit.Assert.assertEquals; 35 import static org.junit.Assert.assertNotNull; 36 37 import com.google.api.client.http.HttpStatusCodes; 38 import com.google.api.client.http.LowLevelHttpRequest; 39 import com.google.api.client.http.LowLevelHttpResponse; 40 import com.google.api.client.json.GenericJson; 41 import com.google.api.client.json.Json; 42 import com.google.api.client.json.gson.GsonFactory; 43 import com.google.api.client.testing.http.MockHttpTransport; 44 import com.google.api.client.testing.http.MockLowLevelHttpRequest; 45 import com.google.api.client.testing.http.MockLowLevelHttpResponse; 46 import com.google.auth.TestUtils; 47 import com.google.common.base.Joiner; 48 import java.io.IOException; 49 import java.util.ArrayDeque; 50 import java.util.Collections; 51 import java.util.List; 52 import java.util.Map; 53 import java.util.Queue; 54 import java.util.regex.Matcher; 55 import java.util.regex.Pattern; 56 57 /** Transport that mocks a basic STS endpoint. */ 58 public final class MockStsTransport extends MockHttpTransport { 59 60 private static final String EXPECTED_GRANT_TYPE = 61 "urn:ietf:params:oauth:grant-type:token-exchange"; 62 private static final String ISSUED_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token"; 63 private static final String VALID_STS_PATTERN = 64 "https:\\/\\/sts.[a-z-_\\.]+\\/v1\\/(token|oauthtoken)"; 65 private static final String ACCESS_TOKEN = "accessToken"; 66 private static final String TOKEN_TYPE = "Bearer"; 67 private static final Long EXPIRES_IN = 3600L; 68 69 private final Queue<IOException> responseErrorSequence = new ArrayDeque<>(); 70 private final Queue<List<String>> scopeSequence = new ArrayDeque<>(); 71 private final Queue<String> refreshTokenSequence = new ArrayDeque<>(); 72 73 private boolean returnExpiresIn = true; 74 private MockLowLevelHttpRequest request; 75 addResponseErrorSequence(IOException... errors)76 public void addResponseErrorSequence(IOException... errors) { 77 Collections.addAll(responseErrorSequence, errors); 78 } 79 addRefreshTokenSequence(String... refreshTokens)80 public void addRefreshTokenSequence(String... refreshTokens) { 81 Collections.addAll(refreshTokenSequence, refreshTokens); 82 } 83 addScopeSequence(List<String> scopes)84 public void addScopeSequence(List<String> scopes) { 85 Collections.addAll(scopeSequence, scopes); 86 } 87 88 @Override buildRequest(final String method, final String url)89 public LowLevelHttpRequest buildRequest(final String method, final String url) { 90 this.request = 91 new MockLowLevelHttpRequest(url) { 92 @Override 93 public LowLevelHttpResponse execute() throws IOException { 94 // Environment version is prefixed by "aws". e.g. "aws1". 95 Matcher matcher = Pattern.compile(VALID_STS_PATTERN).matcher(url); 96 if (!matcher.matches()) { 97 return makeErrorResponse(); 98 } 99 100 if (!responseErrorSequence.isEmpty()) { 101 throw responseErrorSequence.poll(); 102 } 103 104 GenericJson response = new GenericJson(); 105 response.setFactory(new GsonFactory()); 106 107 Map<String, String> query = TestUtils.parseQuery(getContentAsString()); 108 if (!url.contains("v1/oauthtoken")) { 109 assertEquals(EXPECTED_GRANT_TYPE, query.get("grant_type")); 110 assertNotNull(query.get("subject_token_type")); 111 assertNotNull(query.get("subject_token")); 112 113 response.put("token_type", TOKEN_TYPE); 114 response.put("access_token", ACCESS_TOKEN); 115 response.put("issued_token_type", ISSUED_TOKEN_TYPE); 116 117 if (returnExpiresIn) { 118 response.put("expires_in", EXPIRES_IN); 119 } 120 if (!refreshTokenSequence.isEmpty()) { 121 response.put("refresh_token", refreshTokenSequence.poll()); 122 } 123 if (!scopeSequence.isEmpty()) { 124 response.put("scope", Joiner.on(' ').join(scopeSequence.poll())); 125 } 126 } else { 127 assertEquals("refresh_token", query.get("grant_type")); 128 129 response.put("access_token", ACCESS_TOKEN); 130 response.put("expires_in", EXPIRES_IN); 131 132 if (!refreshTokenSequence.isEmpty()) { 133 response.put("refresh_token", refreshTokenSequence.poll()); 134 } 135 } 136 return new MockLowLevelHttpResponse() 137 .setContentType(Json.MEDIA_TYPE) 138 .setContent(response.toPrettyString()); 139 } 140 }; 141 return this.request; 142 } 143 makeErrorResponse()144 private MockLowLevelHttpResponse makeErrorResponse() { 145 MockLowLevelHttpResponse errorResponse = new MockLowLevelHttpResponse(); 146 errorResponse.setStatusCode(HttpStatusCodes.STATUS_CODE_BAD_REQUEST); 147 errorResponse.setContentType(Json.MEDIA_TYPE); 148 errorResponse.setContent("{\"error\":\"error\"}"); 149 return errorResponse; 150 } 151 getRequest()152 public MockLowLevelHttpRequest getRequest() { 153 return request; 154 } 155 getAccessToken()156 public String getAccessToken() { 157 return ACCESS_TOKEN; 158 } 159 getTokenType()160 public String getTokenType() { 161 return TOKEN_TYPE; 162 } 163 getIssuedTokenType()164 public String getIssuedTokenType() { 165 return ISSUED_TOKEN_TYPE; 166 } 167 getExpiresIn()168 public Long getExpiresIn() { 169 return EXPIRES_IN; 170 } 171 setReturnExpiresIn(boolean returnExpiresIn)172 public void setReturnExpiresIn(boolean returnExpiresIn) { 173 this.returnExpiresIn = returnExpiresIn; 174 } 175 } 176