• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2022 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 package android.adservices.test.scenario.adservices.utils;
18 
19 import static org.junit.Assert.assertEquals;
20 import static org.junit.Assert.assertFalse;
21 
22 import android.content.Context;
23 import android.net.Uri;
24 
25 import com.google.mockwebserver.Dispatcher;
26 import com.google.mockwebserver.MockResponse;
27 import com.google.mockwebserver.MockWebServer;
28 import com.google.mockwebserver.RecordedRequest;
29 
30 import org.junit.Assert;
31 import org.junit.rules.TestRule;
32 import org.junit.runner.Description;
33 import org.junit.runners.model.Statement;
34 
35 import java.io.IOException;
36 import java.io.InputStream;
37 import java.net.ServerSocket;
38 import java.security.GeneralSecurityException;
39 import java.security.KeyStore;
40 import java.util.HashSet;
41 import java.util.List;
42 import java.util.Objects;
43 import java.util.Set;
44 import java.util.function.Function;
45 
46 import javax.net.ssl.KeyManagerFactory;
47 import javax.net.ssl.SSLContext;
48 import javax.net.ssl.SSLSocketFactory;
49 
50 /** Instances of this class are not thread safe. */
51 public class MockWebServerRule implements TestRule {
52     private static final int UNINITIALIZED = -1;
53     private final InputStream mCertificateInputStream;
54     private final char[] mKeyStorePassword;
55     private int mPort = UNINITIALIZED;
56     private MockWebServer mMockWebServer;
57 
MockWebServerRule(InputStream inputStream, String keyStorePassword)58     private MockWebServerRule(InputStream inputStream, String keyStorePassword) {
59         mCertificateInputStream = inputStream;
60         mKeyStorePassword = keyStorePassword == null ? null : keyStorePassword.toCharArray();
61     }
62 
forHttp()63     public static MockWebServerRule forHttp() {
64         return new MockWebServerRule(null, null);
65     }
66 
67     /**
68      * Builds an instance of the MockWebServerRule configured for HTTPS traffic.
69      *
70      * @param context The app context used to load the PKCS12 key store
71      * @param assetName The name of the key store under the app assets folder
72      * @param keyStorePassword The password of the keystore
73      */
forHttps( Context context, String assetName, String keyStorePassword)74     public static MockWebServerRule forHttps(
75             Context context, String assetName, String keyStorePassword) {
76         try {
77             return new MockWebServerRule(context.getAssets().open(assetName), keyStorePassword);
78         } catch (IOException ioException) {
79             throw new RuntimeException("Unable to initialize MockWebServerRule", ioException);
80         }
81     }
82 
83     /**
84      * Builds an instance of the MockWebServerRule configured for HTTPS traffic.
85      *
86      * @param certificateInputStream An input stream to load the content of a PKCS12 key store
87      * @param keyStorePassword The password of the keystore
88      */
forHttps( InputStream certificateInputStream, String keyStorePassword)89     public static MockWebServerRule forHttps(
90             InputStream certificateInputStream, String keyStorePassword) {
91         return new MockWebServerRule(certificateInputStream, keyStorePassword);
92     }
93 
useHttps()94     private boolean useHttps() {
95         return Objects.nonNull(mCertificateInputStream);
96     }
97 
startMockWebServer(List<MockResponse> responses)98     public MockWebServer startMockWebServer(List<MockResponse> responses) throws Exception {
99         if (mPort == UNINITIALIZED) {
100             reserveServerListeningPort();
101         }
102 
103         mMockWebServer = new MockWebServer();
104         if (useHttps()) {
105             mMockWebServer.useHttps(getTestingSslSocketFactory(), false);
106         }
107         for (MockResponse response : responses) {
108             mMockWebServer.enqueue(response);
109         }
110         mMockWebServer.play(mPort);
111         return mMockWebServer;
112     }
113 
startMockWebServer(Function<RecordedRequest, MockResponse> lambda)114     public MockWebServer startMockWebServer(Function<RecordedRequest, MockResponse> lambda)
115             throws Exception {
116         Dispatcher dispatcher =
117                 new Dispatcher() {
118                     @Override
119                     public MockResponse dispatch(RecordedRequest request) {
120                         return lambda.apply(request);
121                     }
122                 };
123         return startMockWebServer(dispatcher);
124     }
125 
startMockWebServer(Dispatcher dispatcher)126     public MockWebServer startMockWebServer(Dispatcher dispatcher) throws Exception {
127         if (mPort == UNINITIALIZED) {
128             reserveServerListeningPort();
129         }
130 
131         mMockWebServer = new MockWebServer();
132         if (useHttps()) {
133             mMockWebServer.useHttps(getTestingSslSocketFactory(), false);
134         }
135         mMockWebServer.setDispatcher(dispatcher);
136 
137         mMockWebServer.play(mPort);
138         return mMockWebServer;
139     }
140 
createMockWebServer()141     public MockWebServer createMockWebServer() throws Exception {
142         if (mPort == UNINITIALIZED) {
143             reserveServerListeningPort();
144         }
145 
146         mMockWebServer = new MockWebServer();
147         if (useHttps()) {
148             mMockWebServer.useHttps(getTestingSslSocketFactory(), false);
149         }
150         return mMockWebServer;
151     }
152 
startCreatedMockWebServer(Dispatcher dispatcher)153     public MockWebServer startCreatedMockWebServer(Dispatcher dispatcher) throws Exception {
154         if (mMockWebServer == null || mPort == UNINITIALIZED) {
155             throw new IllegalStateException(
156                     "MockWebServer is not created or the port is not reserved.");
157         }
158         mMockWebServer.setDispatcher(dispatcher);
159 
160         mMockWebServer.play(mPort);
161         return mMockWebServer;
162     }
163 
164     /**
165      * @return the mock web server for this rull and {@code null} if it hasn't been started yet by
166      *     calling {@link #startMockWebServer(List)}.
167      */
getMockWebServer()168     public MockWebServer getMockWebServer() {
169         return mMockWebServer;
170     }
171 
172     /** @return the base address the mock web server will be listening to when started. */
getServerBaseAddress()173     public String getServerBaseAddress() {
174         return String.format("%s://localhost:%d", useHttps() ? "https" : "http", mPort);
175     }
176 
177     /**
178      * This method is equivalent to {@link MockWebServer#getUrl(String)} but it can be used before
179      * you prepare and start the server if you need to prepare responses that will reference the
180      * same test server.
181      *
182      * @return an Uri to use to reach the given {@code @path} on the mock web server.
183      */
uriForPath(String path)184     public Uri uriForPath(String path) {
185         return Uri.parse(
186                 String.format(
187                         "%s%s%s", getServerBaseAddress(), path.startsWith("/") ? "" : "/", path));
188     }
189 
reserveServerListeningPort()190     private void reserveServerListeningPort() throws IOException {
191         ServerSocket serverSocket = new ServerSocket(38383);
192         serverSocket.setReuseAddress(true);
193         mPort = serverSocket.getLocalPort();
194         serverSocket.close();
195     }
196 
197     /**
198      * Provides the ability to define a port before starting the mock web server. Otherwise, if the
199      * port has already been initialized it will throw an {@link IllegalStateException}
200      *
201      * @param port the port to be configured
202      * @throws IOException if port already in used
203      */
reserveServerListeningPort(int port)204     public void reserveServerListeningPort(int port) throws IOException {
205         if (mPort != UNINITIALIZED) {
206             throw new IllegalStateException("Port has already been initialized");
207         }
208 
209         ServerSocket serverSocket = new ServerSocket(port);
210         serverSocket.setReuseAddress(true);
211         mPort = serverSocket.getLocalPort();
212         serverSocket.close();
213     }
214 
getTestingSslSocketFactory()215     private SSLSocketFactory getTestingSslSocketFactory()
216             throws GeneralSecurityException, IOException {
217         final KeyManagerFactory keyManagerFactory =
218                 KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
219         KeyStore keyStore = KeyStore.getInstance("PKCS12");
220         keyStore.load(mCertificateInputStream, mKeyStorePassword);
221         keyManagerFactory.init(keyStore, mKeyStorePassword);
222         SSLContext sslContext = SSLContext.getInstance("TLS");
223         sslContext.init(keyManagerFactory.getKeyManagers(), null, null);
224         return sslContext.getSocketFactory();
225     }
226 
227     /**
228      * A utility that validates that the mock web server got the expected traffic.
229      *
230      * @param mockWebServer server instance used for making requests
231      * @param expectedRequestCount the number of requests expected to be received by the server
232      * @param expectedRequests the list of URLs that should have been requested, in case of repeat
233      *     requests the size of expectedRequests list could be less than the expectedRequestCount
234      * @param requestMatcher A custom matcher that dictates if the request meets the criteria of
235      *     being hit or not. This allows tests to do partial match of URLs in case of params or
236      *     other sub path of URL.
237      */
verifyMockServerRequests( final MockWebServer mockWebServer, final int expectedRequestCount, final List<String> expectedRequests, final RequestMatcher<String> requestMatcher)238     public void verifyMockServerRequests(
239             final MockWebServer mockWebServer,
240             final int expectedRequestCount,
241             final List<String> expectedRequests,
242             final RequestMatcher<String> requestMatcher) {
243 
244         assertEquals(
245                 "Number of expected requests does not match actual request count",
246                 expectedRequestCount,
247                 mockWebServer.getRequestCount());
248 
249         // For parallel executions requests should be checked agnostic of order
250         final Set<String> actualRequests = new HashSet<>();
251         for (int i = 0; i < expectedRequestCount; i++) {
252             try {
253                 actualRequests.add(mockWebServer.takeRequest().getPath());
254             } catch (InterruptedException e) {
255                 Thread.currentThread().interrupt();
256             }
257         }
258 
259         assertFalse(
260                 String.format(
261                         "Expected requests cannot be empty, actual requests <%s>", actualRequests),
262                 expectedRequestCount != 0 && expectedRequests.isEmpty());
263 
264         for (String request : expectedRequests) {
265             Assert.assertTrue(
266                     String.format(
267                             "Actual requests <%s> do not contain request <%s>",
268                             actualRequests, request),
269                     wasPathRequested(actualRequests, request, requestMatcher));
270         }
271     }
272 
wasPathRequested( final Set<String> actualRequests, final String request, final RequestMatcher<String> requestMatcher)273     private boolean wasPathRequested(
274             final Set<String> actualRequests,
275             final String request,
276             final RequestMatcher<String> requestMatcher) {
277         for (String actualRequest : actualRequests) {
278             // Passing a custom comparator allows tests to do partial match of URLs in case of
279             // params or other sub path of URL
280             if (requestMatcher.wasRequestMade(actualRequest, request)) {
281                 return true;
282             }
283         }
284         return false;
285     }
286 
287     @Override
apply(Statement base, Description description)288     public Statement apply(Statement base, Description description) {
289         return new Statement() {
290             @Override
291             public void evaluate() throws Throwable {
292                 reserveServerListeningPort();
293                 try {
294                     base.evaluate();
295                 } finally {
296                     if (mMockWebServer != null) {
297                         mMockWebServer.shutdown();
298                     }
299                 }
300             }
301         };
302     }
303 
304     public interface RequestMatcher<T> {
305         boolean wasRequestMade(T actualRequest, T expectedRequest);
306     }
307 }
308