1 /* 2 * Copyright (C) 2022 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package android.adservices.test.scenario.adservices.utils; 18 19 import static org.junit.Assert.assertEquals; 20 import static org.junit.Assert.assertFalse; 21 22 import android.content.Context; 23 import android.net.Uri; 24 25 import com.google.mockwebserver.Dispatcher; 26 import com.google.mockwebserver.MockResponse; 27 import com.google.mockwebserver.MockWebServer; 28 import com.google.mockwebserver.RecordedRequest; 29 30 import org.junit.Assert; 31 import org.junit.rules.TestRule; 32 import org.junit.runner.Description; 33 import org.junit.runners.model.Statement; 34 35 import java.io.IOException; 36 import java.io.InputStream; 37 import java.net.ServerSocket; 38 import java.security.GeneralSecurityException; 39 import java.security.KeyStore; 40 import java.util.HashSet; 41 import java.util.List; 42 import java.util.Objects; 43 import java.util.Set; 44 import java.util.function.Function; 45 46 import javax.net.ssl.KeyManagerFactory; 47 import javax.net.ssl.SSLContext; 48 import javax.net.ssl.SSLSocketFactory; 49 50 /** Instances of this class are not thread safe. */ 51 public class MockWebServerRule implements TestRule { 52 private static final int UNINITIALIZED = -1; 53 private final InputStream mCertificateInputStream; 54 private final char[] mKeyStorePassword; 55 private int mPort = UNINITIALIZED; 56 private MockWebServer mMockWebServer; 57 MockWebServerRule(InputStream inputStream, String keyStorePassword)58 private MockWebServerRule(InputStream inputStream, String keyStorePassword) { 59 mCertificateInputStream = inputStream; 60 mKeyStorePassword = keyStorePassword == null ? null : keyStorePassword.toCharArray(); 61 } 62 forHttp()63 public static MockWebServerRule forHttp() { 64 return new MockWebServerRule(null, null); 65 } 66 67 /** 68 * Builds an instance of the MockWebServerRule configured for HTTPS traffic. 69 * 70 * @param context The app context used to load the PKCS12 key store 71 * @param assetName The name of the key store under the app assets folder 72 * @param keyStorePassword The password of the keystore 73 */ forHttps( Context context, String assetName, String keyStorePassword)74 public static MockWebServerRule forHttps( 75 Context context, String assetName, String keyStorePassword) { 76 try { 77 return new MockWebServerRule(context.getAssets().open(assetName), keyStorePassword); 78 } catch (IOException ioException) { 79 throw new RuntimeException("Unable to initialize MockWebServerRule", ioException); 80 } 81 } 82 83 /** 84 * Builds an instance of the MockWebServerRule configured for HTTPS traffic. 85 * 86 * @param certificateInputStream An input stream to load the content of a PKCS12 key store 87 * @param keyStorePassword The password of the keystore 88 */ forHttps( InputStream certificateInputStream, String keyStorePassword)89 public static MockWebServerRule forHttps( 90 InputStream certificateInputStream, String keyStorePassword) { 91 return new MockWebServerRule(certificateInputStream, keyStorePassword); 92 } 93 useHttps()94 private boolean useHttps() { 95 return Objects.nonNull(mCertificateInputStream); 96 } 97 startMockWebServer(List<MockResponse> responses)98 public MockWebServer startMockWebServer(List<MockResponse> responses) throws Exception { 99 if (mPort == UNINITIALIZED) { 100 reserveServerListeningPort(); 101 } 102 103 mMockWebServer = new MockWebServer(); 104 if (useHttps()) { 105 mMockWebServer.useHttps(getTestingSslSocketFactory(), false); 106 } 107 for (MockResponse response : responses) { 108 mMockWebServer.enqueue(response); 109 } 110 mMockWebServer.play(mPort); 111 return mMockWebServer; 112 } 113 startMockWebServer(Function<RecordedRequest, MockResponse> lambda)114 public MockWebServer startMockWebServer(Function<RecordedRequest, MockResponse> lambda) 115 throws Exception { 116 Dispatcher dispatcher = 117 new Dispatcher() { 118 @Override 119 public MockResponse dispatch(RecordedRequest request) { 120 return lambda.apply(request); 121 } 122 }; 123 return startMockWebServer(dispatcher); 124 } 125 startMockWebServer(Dispatcher dispatcher)126 public MockWebServer startMockWebServer(Dispatcher dispatcher) throws Exception { 127 if (mPort == UNINITIALIZED) { 128 reserveServerListeningPort(); 129 } 130 131 mMockWebServer = new MockWebServer(); 132 if (useHttps()) { 133 mMockWebServer.useHttps(getTestingSslSocketFactory(), false); 134 } 135 mMockWebServer.setDispatcher(dispatcher); 136 137 mMockWebServer.play(mPort); 138 return mMockWebServer; 139 } 140 createMockWebServer()141 public MockWebServer createMockWebServer() throws Exception { 142 if (mPort == UNINITIALIZED) { 143 reserveServerListeningPort(); 144 } 145 146 mMockWebServer = new MockWebServer(); 147 if (useHttps()) { 148 mMockWebServer.useHttps(getTestingSslSocketFactory(), false); 149 } 150 return mMockWebServer; 151 } 152 startCreatedMockWebServer(Dispatcher dispatcher)153 public MockWebServer startCreatedMockWebServer(Dispatcher dispatcher) throws Exception { 154 if (mMockWebServer == null || mPort == UNINITIALIZED) { 155 throw new IllegalStateException( 156 "MockWebServer is not created or the port is not reserved."); 157 } 158 mMockWebServer.setDispatcher(dispatcher); 159 160 mMockWebServer.play(mPort); 161 return mMockWebServer; 162 } 163 164 /** 165 * @return the mock web server for this rull and {@code null} if it hasn't been started yet by 166 * calling {@link #startMockWebServer(List)}. 167 */ getMockWebServer()168 public MockWebServer getMockWebServer() { 169 return mMockWebServer; 170 } 171 172 /** @return the base address the mock web server will be listening to when started. */ getServerBaseAddress()173 public String getServerBaseAddress() { 174 return String.format("%s://localhost:%d", useHttps() ? "https" : "http", mPort); 175 } 176 177 /** 178 * This method is equivalent to {@link MockWebServer#getUrl(String)} but it can be used before 179 * you prepare and start the server if you need to prepare responses that will reference the 180 * same test server. 181 * 182 * @return an Uri to use to reach the given {@code @path} on the mock web server. 183 */ uriForPath(String path)184 public Uri uriForPath(String path) { 185 return Uri.parse( 186 String.format( 187 "%s%s%s", getServerBaseAddress(), path.startsWith("/") ? "" : "/", path)); 188 } 189 reserveServerListeningPort()190 private void reserveServerListeningPort() throws IOException { 191 ServerSocket serverSocket = new ServerSocket(38383); 192 serverSocket.setReuseAddress(true); 193 mPort = serverSocket.getLocalPort(); 194 serverSocket.close(); 195 } 196 197 /** 198 * Provides the ability to define a port before starting the mock web server. Otherwise, if the 199 * port has already been initialized it will throw an {@link IllegalStateException} 200 * 201 * @param port the port to be configured 202 * @throws IOException if port already in used 203 */ reserveServerListeningPort(int port)204 public void reserveServerListeningPort(int port) throws IOException { 205 if (mPort != UNINITIALIZED) { 206 throw new IllegalStateException("Port has already been initialized"); 207 } 208 209 ServerSocket serverSocket = new ServerSocket(port); 210 serverSocket.setReuseAddress(true); 211 mPort = serverSocket.getLocalPort(); 212 serverSocket.close(); 213 } 214 getTestingSslSocketFactory()215 private SSLSocketFactory getTestingSslSocketFactory() 216 throws GeneralSecurityException, IOException { 217 final KeyManagerFactory keyManagerFactory = 218 KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); 219 KeyStore keyStore = KeyStore.getInstance("PKCS12"); 220 keyStore.load(mCertificateInputStream, mKeyStorePassword); 221 keyManagerFactory.init(keyStore, mKeyStorePassword); 222 SSLContext sslContext = SSLContext.getInstance("TLS"); 223 sslContext.init(keyManagerFactory.getKeyManagers(), null, null); 224 return sslContext.getSocketFactory(); 225 } 226 227 /** 228 * A utility that validates that the mock web server got the expected traffic. 229 * 230 * @param mockWebServer server instance used for making requests 231 * @param expectedRequestCount the number of requests expected to be received by the server 232 * @param expectedRequests the list of URLs that should have been requested, in case of repeat 233 * requests the size of expectedRequests list could be less than the expectedRequestCount 234 * @param requestMatcher A custom matcher that dictates if the request meets the criteria of 235 * being hit or not. This allows tests to do partial match of URLs in case of params or 236 * other sub path of URL. 237 */ verifyMockServerRequests( final MockWebServer mockWebServer, final int expectedRequestCount, final List<String> expectedRequests, final RequestMatcher<String> requestMatcher)238 public void verifyMockServerRequests( 239 final MockWebServer mockWebServer, 240 final int expectedRequestCount, 241 final List<String> expectedRequests, 242 final RequestMatcher<String> requestMatcher) { 243 244 assertEquals( 245 "Number of expected requests does not match actual request count", 246 expectedRequestCount, 247 mockWebServer.getRequestCount()); 248 249 // For parallel executions requests should be checked agnostic of order 250 final Set<String> actualRequests = new HashSet<>(); 251 for (int i = 0; i < expectedRequestCount; i++) { 252 try { 253 actualRequests.add(mockWebServer.takeRequest().getPath()); 254 } catch (InterruptedException e) { 255 Thread.currentThread().interrupt(); 256 } 257 } 258 259 assertFalse( 260 String.format( 261 "Expected requests cannot be empty, actual requests <%s>", actualRequests), 262 expectedRequestCount != 0 && expectedRequests.isEmpty()); 263 264 for (String request : expectedRequests) { 265 Assert.assertTrue( 266 String.format( 267 "Actual requests <%s> do not contain request <%s>", 268 actualRequests, request), 269 wasPathRequested(actualRequests, request, requestMatcher)); 270 } 271 } 272 wasPathRequested( final Set<String> actualRequests, final String request, final RequestMatcher<String> requestMatcher)273 private boolean wasPathRequested( 274 final Set<String> actualRequests, 275 final String request, 276 final RequestMatcher<String> requestMatcher) { 277 for (String actualRequest : actualRequests) { 278 // Passing a custom comparator allows tests to do partial match of URLs in case of 279 // params or other sub path of URL 280 if (requestMatcher.wasRequestMade(actualRequest, request)) { 281 return true; 282 } 283 } 284 return false; 285 } 286 287 @Override apply(Statement base, Description description)288 public Statement apply(Statement base, Description description) { 289 return new Statement() { 290 @Override 291 public void evaluate() throws Throwable { 292 reserveServerListeningPort(); 293 try { 294 base.evaluate(); 295 } finally { 296 if (mMockWebServer != null) { 297 mMockWebServer.shutdown(); 298 } 299 } 300 } 301 }; 302 } 303 304 public interface RequestMatcher<T> { 305 boolean wasRequestMade(T actualRequest, T expectedRequest); 306 } 307 } 308