• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2018 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 package org.chromium.net.test.util;
6 
7 import android.util.Base64;
8 
9 import androidx.annotation.GuardedBy;
10 
11 import org.chromium.base.Log;
12 
13 import java.io.ByteArrayInputStream;
14 import java.io.ByteArrayOutputStream;
15 import java.io.IOException;
16 import java.io.InputStream;
17 import java.io.OutputStream;
18 import java.io.UnsupportedEncodingException;
19 import java.net.MalformedURLException;
20 import java.net.ServerSocket;
21 import java.net.Socket;
22 import java.net.SocketException;
23 import java.security.KeyStore;
24 import java.util.ArrayList;
25 import java.util.Arrays;
26 import java.util.HashSet;
27 import java.util.List;
28 import java.util.Set;
29 
30 import javax.net.ssl.KeyManager;
31 import javax.net.ssl.KeyManagerFactory;
32 import javax.net.ssl.SSLContext;
33 
34 /**
35  * Simple http test server for testing.
36  *
37  * This server runs in a thread in the current process, so it is convenient
38  * for loopback testing without the need to setup TCP forwarding to the
39  * host computer.
40  */
41 public class WebServer implements AutoCloseable {
42     private static final String TAG = "WebServer";
43 
44     private static Set<WebServer> sInstances = new HashSet<>();
45     private static Set<WebServer> sSecureInstances = new HashSet<>();
46 
47     private final ServerThread mServerThread;
48     private String mServerUri;
49     private final boolean mSsl;
50     private final int mPort;
51 
52     public static final String STATUS_OK = "200 OK";
53 
54     /**
55      * Writes an HTTP response to |output|.
56      * |status| should be one of the STATUS_* values above.
57      */
writeResponse(OutputStream output, String status, byte[] body)58     public static void writeResponse(OutputStream output, String status, byte[] body)
59             throws IOException {
60         if (body == null) {
61             body = new byte[0];
62         }
63         output.write(
64                 String.format("HTTP/1.1 %s\r\nContent-Length: %s\r\n\r\n", status, body.length)
65                         .getBytes());
66         output.write(body);
67         output.flush();
68     }
69 
70     /** Represents an HTTP header. */
71     public static class HTTPHeader {
72         public final String key;
73         public final String value;
74 
75         /** Constructs an HTTP header. */
HTTPHeader(String key, String value)76         public HTTPHeader(String key, String value) {
77             this.key = key;
78             this.value = value;
79         }
80 
81         /**
82          * Parse an HTTP header from a string line. Returns null if the line is not a valid HTTP
83          * header.
84          */
parseLine(String line)85         public static HTTPHeader parseLine(String line) {
86             String[] parts = line.split(":", 2);
87             if (parts.length == 2) {
88                 return new HTTPHeader(parts[0].trim(), parts[1].trim());
89             }
90             return null;
91         }
92 
93         @Override
toString()94         public String toString() {
95             return key + ": " + value;
96         }
97     }
98 
99     /** Thrown when an HTTP request could not be parsed. */
100     public static class InvalidRequest extends Exception {
101         /** Constructor */
InvalidRequest()102         public InvalidRequest() {
103             super("Invalid HTTP request");
104         }
105     }
106 
107     /** A parsed HTTP request. */
108     public static class HTTPRequest {
109         private String mMethod;
110         private String mURI;
111         private String mHTTPVersion;
112         private HTTPHeader[] mHeaders;
113         private byte[] mBody;
114 
115         @Override
toString()116         public String toString() {
117             StringBuilder builder = new StringBuilder();
118             builder.append(requestLine());
119             builder.append("\r\n");
120             for (HTTPHeader header : mHeaders) {
121                 builder.append(header.toString());
122                 builder.append("\r\n");
123             }
124             if (mBody != null) {
125                 builder.append("\r\n");
126                 try {
127                     builder.append(new String(mBody, "UTF-8"));
128                 } catch (UnsupportedEncodingException e) {
129                     builder.append("<binary body, length=").append(mBody.length).append(">\r\n");
130                 }
131             }
132             return builder.toString();
133         }
134 
135         /** Returns the request line as a String. */
requestLine()136         public String requestLine() {
137             return mMethod + " " + mURI + " " + mHTTPVersion;
138         }
139 
140         /** Returns the request method. */
getMethod()141         public String getMethod() {
142             return mMethod;
143         }
144 
145         /** Returns the request URI. */
getURI()146         public String getURI() {
147             return mURI;
148         }
149 
150         /** Returns the request HTTP version. */
getHTTPVersion()151         public String getHTTPVersion() {
152             return mHTTPVersion;
153         }
154 
155         /** Returns the request headers. */
getHeaders()156         public HTTPHeader[] getHeaders() {
157             return mHeaders;
158         }
159 
160         /** Returns the request body. */
getBody()161         public byte[] getBody() {
162             return mBody;
163         }
164 
165         /**
166          * Returns the header value for the given header name. If a header is present multiple
167          * times, this only returns the first occurence. Returns "" if the header is not found.
168          */
headerValue(String headerName)169         public String headerValue(String headerName) {
170             for (String value : headerValues(headerName)) {
171                 return value;
172             }
173             return "";
174         }
175 
176         /** Returns all header values for the given header name. */
headerValues(String headerName)177         public List<String> headerValues(String headerName) {
178             List<String> matchingHeaders = new ArrayList<String>();
179             for (HTTPHeader header : mHeaders) {
180                 if (header.key.equalsIgnoreCase(headerName)) {
181                     matchingHeaders.add(header.value);
182                 }
183             }
184             return matchingHeaders;
185         }
186 
hasChunkedTransferEncoding(HTTPRequest req)187         private static boolean hasChunkedTransferEncoding(HTTPRequest req) {
188             List<String> transferEncodings = req.headerValues("Transfer-Encoding");
189             for (String encoding : transferEncodings) {
190                 if (encoding.equals("chunked")) {
191                     return true;
192                 }
193             }
194             return false;
195         }
196 
197         /** Parses an HTTP request from an input stream. */
parse(InputStream stream)198         public static HTTPRequest parse(InputStream stream) throws InvalidRequest, IOException {
199             boolean firstLine = true;
200             HTTPRequest req = new HTTPRequest();
201             ArrayList<HTTPHeader> mHeaders = new ArrayList<HTTPHeader>();
202             ByteArrayOutputStream line = new ByteArrayOutputStream();
203             for (int b = stream.read(); b != -1; b = stream.read()) {
204                 if (b == '\r') {
205                     int next = stream.read();
206                     if (next == '\n') {
207                         String lineString;
208                         try {
209                             lineString = new String(line.toByteArray(), "UTF-8");
210                         } catch (UnsupportedEncodingException e) {
211                             throw new InvalidRequest();
212                         }
213                         line.reset();
214                         if (firstLine) {
215                             String[] parts = lineString.split(" ", 3);
216                             if (parts.length != 3) {
217                                 throw new InvalidRequest();
218                             }
219                             req.mMethod = parts[0];
220                             req.mURI = parts[1];
221                             req.mHTTPVersion = parts[2];
222                             firstLine = false;
223                         } else {
224                             if (lineString.length() == 0) {
225                                 break;
226                             }
227                             HTTPHeader header = HTTPHeader.parseLine(lineString);
228                             if (header != null) {
229                                 mHeaders.add(header);
230                             }
231                         }
232                     } else if (next == -1) {
233                         throw new InvalidRequest();
234                     } else {
235                         line.write(b);
236                         line.write(next);
237                     }
238                 } else {
239                     line.write(b);
240                 }
241             }
242             if (firstLine) {
243                 if (line.size() == 0) return null;
244                 throw new InvalidRequest();
245             }
246             req.mHeaders = mHeaders.toArray(new HTTPHeader[0]);
247             int contentLength = -1;
248             if (req.mMethod.equals("GET") || req.mMethod.equals("HEAD")) {
249                 contentLength = 0;
250             }
251             try {
252                 contentLength = Integer.parseInt(req.headerValue("Content-Length"));
253             } catch (NumberFormatException e) {
254             }
255             if (contentLength >= 0) {
256                 byte[] content = new byte[contentLength];
257                 for (int offset = 0; offset < contentLength; ) {
258                     int bytesRead = stream.read(content, offset, contentLength);
259                     if (bytesRead == -1) { // short read, keep truncated content.
260                         content = Arrays.copyOf(content, offset);
261                         break;
262                     }
263                     offset += bytesRead;
264                 }
265                 req.mBody = content;
266             } else if (hasChunkedTransferEncoding(req)) {
267                 ByteArrayOutputStream mBody = new ByteArrayOutputStream();
268                 byte[] buffer = new byte[1000];
269                 int bytesRead;
270                 while ((bytesRead = stream.read(buffer, 0, buffer.length)) != -1) {
271                     mBody.write(buffer, 0, bytesRead);
272                 }
273                 req.mBody = mBody.toByteArray();
274             }
275             return req;
276         }
277     }
278 
279     /** An interface for handling HTTP requests. */
280     public interface RequestHandler {
281         /** handleRequest is called when an HTTP request is received. handleRequest should write a
282          * response to stream. */
handleRequest(HTTPRequest request, OutputStream stream)283         void handleRequest(HTTPRequest request, OutputStream stream);
284     }
285 
286     private RequestHandler mRequestHandler;
287 
288     /** Sets the request handler. */
setRequestHandler(RequestHandler handler)289     public void setRequestHandler(RequestHandler handler) {
290         mRequestHandler = handler;
291     }
292 
293     /** Handle an HTTP request. Calls |mRequestHandler| if set. */
handleRequest(HTTPRequest request, OutputStream stream)294     private void handleRequest(HTTPRequest request, OutputStream stream) {
295         assert Thread.currentThread() == mServerThread
296                 : "handleRequest called from non-server thread";
297         if (mRequestHandler != null) {
298             mRequestHandler.handleRequest(request, stream);
299         }
300     }
301 
setServerHost(String hostname)302     public void setServerHost(String hostname) {
303         try {
304             mServerUri =
305                     new java.net.URI(
306                                     mSsl ? "https" : "http",
307                                     null,
308                                     hostname,
309                                     mServerThread.mSocket.getLocalPort(),
310                                     null,
311                                     null,
312                                     null)
313                             .toString();
314         } catch (java.net.URISyntaxException e) {
315             Log.wtf(TAG, e.getMessage());
316         }
317     }
318 
319     /**
320      * Create and start a local HTTP server instance. Additional must only be true
321      * if an instance was already created. You are responsible for calling
322      * shutdown() on each instance you create.
323      *
324      * @param port Port number the server must use, or 0 to automatically choose a free port.
325      * @param ssl True if the server should be using secure sockets.
326      * @param additional True if creating an additional server instance.
327      * @throws Exception
328      */
WebServer(int port, boolean ssl, boolean additional)329     public WebServer(int port, boolean ssl, boolean additional) throws Exception {
330         mPort = port;
331         mSsl = ssl;
332 
333         if (mSsl) {
334             if ((additional && WebServer.sSecureInstances.isEmpty())
335                     || (!additional && !WebServer.sSecureInstances.isEmpty())) {
336                 throw new IllegalStateException(
337                         "There are "
338                                 + WebServer.sSecureInstances.size()
339                                 + " SSL WebServer instances. Expected "
340                                 + (additional ? ">=1" : "0")
341                                 + " because additional is "
342                                 + additional);
343             }
344         } else {
345             if ((additional && WebServer.sInstances.isEmpty())
346                     || (!additional && !WebServer.sInstances.isEmpty())) {
347                 throw new IllegalStateException(
348                         "There are "
349                                 + WebServer.sSecureInstances.size()
350                                 + " WebServer instances. Expected "
351                                 + (additional ? ">=1" : "0")
352                                 + " because additional is "
353                                 + additional);
354             }
355         }
356         mServerThread = new ServerThread(mPort, mSsl);
357 
358         setServerHost("localhost");
359 
360         mServerThread.start();
361         if (mSsl) {
362             WebServer.sSecureInstances.add(this);
363         } else {
364             WebServer.sInstances.add(this);
365         }
366     }
367 
368     /**
369      * Create and start a local HTTP server instance.
370      *
371      * @param port Port number the server must use, or 0 to automatically choose a free port.
372      * @param ssl True if the server should be using secure sockets.
373      * @throws Exception
374      */
WebServer(int port, boolean ssl)375     public WebServer(int port, boolean ssl) throws Exception {
376         this(port, ssl, false);
377     }
378 
379     /** Terminate the http server. */
shutdown()380     public void shutdown() {
381         if (mSsl) {
382             WebServer.sSecureInstances.remove(this);
383         } else {
384             WebServer.sInstances.remove(this);
385         }
386 
387         try {
388             mServerThread.cancelAllRequests();
389             // Block until the server thread is done shutting down.
390             mServerThread.join();
391         } catch (MalformedURLException e) {
392             throw new IllegalStateException(e);
393         } catch (InterruptedException | IOException e) {
394             throw new RuntimeException(e);
395         }
396     }
397 
398     /**
399      * Make the WebServer AutoCloseable.
400      * Calls the shutdown method.
401      */
402     @Override
close()403     public void close() {
404         shutdown();
405     }
406 
getBaseUrl()407     public String getBaseUrl() {
408         return mServerUri + "/";
409     }
410 
411     /**
412      * Gets the URL on the server under which a particular request path will be accessible.
413      *
414      * This only gets the URL, you still need to set the response if you intend to access it.
415      *
416      * @param requestPath The path to respond to.
417      * @return The full URL including the requestPath.
418      */
getResponseUrl(String requestPath)419     public String getResponseUrl(String requestPath) {
420         return mServerUri + requestPath;
421     }
422 
423     private class ServerThread extends Thread {
424         private final boolean mIsSsl;
425         private ServerSocket mSocket;
426         private SSLContext mSslContext;
427 
428         private final Object mLock = new Object();
429 
430         @GuardedBy("mLock")
431         private boolean mIsCancelled;
432 
433         @GuardedBy("mLock")
434         private Socket mCurrentRequestSocket;
435 
436         /**
437          * Defines the keystore contents for the server, BKS version. Holds just a single
438          * self-generated key. The subject name is "Test Server".
439          */
440         private static final String SERVER_KEYS_BKS =
441                 "AAAAAQAAABQDkebzoP1XwqyWKRCJEpn/t8dqIQAABDkEAAVteWtleQAAARpYl20nAAAAAQAFWC41"
442                     + "MDkAAAJNMIICSTCCAbKgAwIBAgIESEfU1jANBgkqhkiG9w0BAQUFADBpMQswCQYDVQQGEwJVUzET"
443                     + "MBEGA1UECBMKQ2FsaWZvcm5pYTEMMAoGA1UEBxMDTVRWMQ8wDQYDVQQKEwZHb29nbGUxEDAOBgNV"
444                     + "BAsTB0FuZHJvaWQxFDASBgNVBAMTC1Rlc3QgU2VydmVyMB4XDTA4MDYwNTExNTgxNFoXDTA4MDkw"
445                     + "MzExNTgxNFowaTELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExDDAKBgNVBAcTA01U"
446                     + "VjEPMA0GA1UEChMGR29vZ2xlMRAwDgYDVQQLEwdBbmRyb2lkMRQwEgYDVQQDEwtUZXN0IFNlcnZl"
447                     + "cjCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA0LIdKaIr9/vsTq8BZlA3R+NFWRaH4lGsTAQy"
448                     + "DPMF9ZqEDOaL6DJuu0colSBBBQ85hQTPa9m9nyJoN3pEi1hgamqOvQIWcXBk+SOpUGRZZFXwniJV"
449                     + "zDKU5nE9MYgn2B9AoiH3CSuMz6HRqgVaqtppIe1jhukMc/kHVJvlKRNy9XMCAwEAATANBgkqhkiG"
450                     + "9w0BAQUFAAOBgQC7yBmJ9O/eWDGtSH9BH0R3dh2NdST3W9hNZ8hIa8U8klhNHbUCSSktZmZkvbPU"
451                     + "hse5LI3dh6RyNDuqDrbYwcqzKbFJaq/jX9kCoeb3vgbQElMRX8D2ID1vRjxwlALFISrtaN4VpWzV"
452                     + "yeoHPW4xldeZmoVtjn8zXNzQhLuBqX2MmAAAAqwAAAAUvkUScfw9yCSmALruURNmtBai7kQAAAZx"
453                     + "4Jmijxs/l8EBaleaUru6EOPioWkUAEVWCxjM/TxbGHOi2VMsQWqRr/DZ3wsDmtQgw3QTrUK666sR"
454                     + "MBnbqdnyCyvM1J2V1xxLXPUeRBmR2CXorYGF9Dye7NkgVdfA+9g9L/0Au6Ugn+2Cj5leoIgkgApN"
455                     + "vuEcZegFlNOUPVEs3SlBgUF1BY6OBM0UBHTPwGGxFBBcetcuMRbUnu65vyDG0pslT59qpaR0TMVs"
456                     + "P+tcheEzhyjbfM32/vwhnL9dBEgM8qMt0sqF6itNOQU/F4WGkK2Cm2v4CYEyKYw325fEhzTXosck"
457                     + "MhbqmcyLab8EPceWF3dweoUT76+jEZx8lV2dapR+CmczQI43tV9btsd1xiBbBHAKvymm9Ep9bPzM"
458                     + "J0MQi+OtURL9Lxke/70/MRueqbPeUlOaGvANTmXQD2OnW7PISwJ9lpeLfTG0LcqkoqkbtLKQLYHI"
459                     + "rQfV5j0j+wmvmpMxzjN3uvNajLa4zQ8l0Eok9SFaRr2RL0gN8Q2JegfOL4pUiHPsh64WWya2NB7f"
460                     + "V+1s65eA5ospXYsShRjo046QhGTmymwXXzdzuxu8IlnTEont6P4+J+GsWk6cldGbl20hctuUKzyx"
461                     + "OptjEPOKejV60iDCYGmHbCWAzQ8h5MILV82IclzNViZmzAapeeCnexhpXhWTs+xDEYSKEiG/camt"
462                     + "bhmZc3BcyVJrW23PktSfpBQ6D8ZxoMfF0L7V2GQMaUg+3r7ucrx82kpqotjv0xHghNIm95aBr1Qw"
463                     + "1gaEjsC/0wGmmBDg1dTDH+F1p9TInzr3EFuYD0YiQ7YlAHq3cPuyGoLXJ5dXYuSBfhDXJSeddUkl"
464                     + "k1ufZyOOcskeInQge7jzaRfmKg3U94r+spMEvb0AzDQVOKvjjo1ivxMSgFRZaDb/4qw=";
465 
466         private static final String PASSWORD = "android";
467 
468         /**
469          * Loads a keystore from a base64-encoded String. Returns the KeyManager[]
470          * for the result.
471          */
getKeyManagers()472         private KeyManager[] getKeyManagers() throws Exception {
473             byte[] bytes = Base64.decode(SERVER_KEYS_BKS, Base64.DEFAULT);
474             InputStream inputStream = new ByteArrayInputStream(bytes);
475 
476             KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
477             keyStore.load(inputStream, PASSWORD.toCharArray());
478             inputStream.close();
479 
480             String algorithm = KeyManagerFactory.getDefaultAlgorithm();
481             KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(algorithm);
482             keyManagerFactory.init(keyStore, PASSWORD.toCharArray());
483 
484             return keyManagerFactory.getKeyManagers();
485         }
486 
setCurrentRequestSocket(Socket socket)487         private void setCurrentRequestSocket(Socket socket) {
488             synchronized (mLock) {
489                 mCurrentRequestSocket = socket;
490             }
491         }
492 
getIsCancelled()493         private boolean getIsCancelled() {
494             synchronized (mLock) {
495                 return mIsCancelled;
496             }
497         }
498 
499         // Called from non-server thread.
cancelAllRequests()500         public void cancelAllRequests() throws IOException {
501             synchronized (mLock) {
502                 mIsCancelled = true;
503                 if (mCurrentRequestSocket != null) {
504                     try {
505                         mCurrentRequestSocket.close();
506                     } catch (IOException ignored) {
507                         // Catching this to ensure the server socket is closed as well.
508                     }
509                 }
510             }
511             // Any current and subsequent accept call will throw instead of block.
512             mSocket.close();
513         }
514 
ServerThread(int port, boolean ssl)515         public ServerThread(int port, boolean ssl) throws Exception {
516             super("ServerThread");
517             mIsSsl = ssl;
518             // If tests are run back-to-back, it may take time for the port to become available.
519             // Retry a few times with a sleep to wait for the port.
520             int retry = 3;
521             while (true) {
522                 try {
523                     if (mIsSsl) {
524                         mSslContext = SSLContext.getInstance("TLS");
525                         mSslContext.init(getKeyManagers(), null, null);
526                         mSocket = mSslContext.getServerSocketFactory().createServerSocket(port);
527                     } else {
528                         mSocket = new ServerSocket(port);
529                     }
530                     return;
531                 } catch (IOException e) {
532                     Log.w(TAG, e.getMessage());
533                     if (--retry == 0) {
534                         throw e;
535                     }
536                     // sleep in case server socket is still being closed
537                     Thread.sleep(1000);
538                 }
539             }
540         }
541 
542         @Override
run()543         public void run() {
544             try {
545                 while (!getIsCancelled()) {
546                     Socket socket = mSocket.accept();
547                     try {
548                         setCurrentRequestSocket(socket);
549                         HTTPRequest request = HTTPRequest.parse(socket.getInputStream());
550                         if (request != null) {
551                             handleRequest(request, socket.getOutputStream());
552                         }
553                     } catch (InvalidRequest | IOException e) {
554                         Log.e(TAG, e.getMessage());
555                     } finally {
556                         socket.close();
557                     }
558                 }
559             } catch (SocketException e) {
560             } catch (IOException e) {
561                 Log.w(TAG, e.getMessage());
562             }
563         }
564     }
565 }
566