• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2015 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package org.conscrypt;
18 
19 import static org.junit.Assert.assertEquals;
20 import static org.junit.Assert.assertFalse;
21 
22 import java.io.FileNotFoundException;
23 import java.io.IOException;
24 import java.io.InputStream;
25 import java.lang.reflect.Method;
26 import java.net.ServerSocket;
27 import java.nio.ByteBuffer;
28 import java.nio.charset.Charset;
29 import java.security.NoSuchAlgorithmException;
30 import java.security.Provider;
31 import java.security.Security;
32 import javax.net.ssl.SSLContext;
33 import javax.net.ssl.SSLEngine;
34 import javax.net.ssl.SSLEngineResult;
35 import javax.net.ssl.SSLException;
36 import javax.net.ssl.SSLServerSocketFactory;
37 import javax.net.ssl.SSLSocketFactory;
38 import libcore.io.Streams;
39 import libcore.java.security.TestKeyStore;
40 
41 /**
42  * Utility methods to support testing.
43  */
44 public final class TestUtils {
45     static final Charset UTF_8 = Charset.forName("UTF-8");
46 
47     private static final Provider JDK_PROVIDER = getDefaultTlsProvider();
48     private static final byte[] CHARS =
49             "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789".getBytes(UTF_8);
50     private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocateDirect(0);
51 
52     public static final String PROTOCOL_TLS_V1_2 = "TLSv1.2";
53     public static final String PROVIDER_PROPERTY = "SSLContext.TLSv1.2";
54     public static final String LOCALHOST = "localhost";
55 
56     static final String TEST_CIPHER = "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256";
57 
TestUtils()58     private TestUtils() {}
59 
getDefaultTlsProvider()60     private static Provider getDefaultTlsProvider() {
61         for (Provider p : Security.getProviders()) {
62             if (p.get(PROVIDER_PROPERTY) != null) {
63                 return p;
64             }
65         }
66         throw new RuntimeException("Unable to find a default provider for " + PROVIDER_PROPERTY);
67     }
68 
getJdkProvider()69     static Provider getJdkProvider() {
70         return JDK_PROVIDER;
71     }
72 
getConscryptProvider()73     public static Provider getConscryptProvider() {
74         try {
75             return (Provider) conscryptClass("OpenSSLProvider")
76                 .getConstructor()
77                 .newInstance();
78         } catch (Exception e) {
79             throw new RuntimeException(e);
80         }
81     }
82 
installConscryptAsDefaultProvider()83     public static void installConscryptAsDefaultProvider() {
84         final Provider conscryptProvider = getConscryptProvider();
85         synchronized (getConscryptProvider()) {
86             Provider[] providers = Security.getProviders();
87             if (providers.length == 0 || !providers[0].equals(conscryptProvider)) {
88                 Security.insertProviderAt(conscryptProvider, 1);
89                 return;
90             }
91         }
92     }
93 
openTestFile(String name)94     public static InputStream openTestFile(String name) throws FileNotFoundException {
95         InputStream is = TestUtils.class.getResourceAsStream("/" + name);
96         if (is == null) {
97             throw new FileNotFoundException(name);
98         }
99         return is;
100     }
101 
readTestFile(String name)102     public static byte[] readTestFile(String name) throws IOException {
103         return Streams.readFully(openTestFile(name));
104     }
105 
106     /**
107      * Looks up the conscrypt class for the given simple name (i.e. no package prefix).
108      */
conscryptClass(String simpleName)109     public static Class<?> conscryptClass(String simpleName) throws ClassNotFoundException {
110         ClassNotFoundException ex = null;
111         for (String packageName : new String[]{"com.android.org.conscrypt", "org.conscrypt"}) {
112             String name = packageName + "." + simpleName;
113             try {
114                 return Class.forName(name);
115             } catch (ClassNotFoundException e) {
116                 ex = e;
117             }
118         }
119         throw ex;
120     }
121 
122     /**
123      * Returns an array containing only {@link #PROTOCOL_TLS_V1_2}.
124      */
getProtocols()125     public static String[] getProtocols() {
126         return new String[] {PROTOCOL_TLS_V1_2};
127     }
128 
getJdkSocketFactory()129     public static SSLSocketFactory getJdkSocketFactory() {
130         return getSocketFactory(JDK_PROVIDER);
131     }
132 
getJdkServerSocketFactory()133     public static SSLServerSocketFactory getJdkServerSocketFactory() {
134         return getServerSocketFactory(JDK_PROVIDER);
135     }
136 
setUseEngineSocket(SSLSocketFactory conscryptFactory, boolean useEngineSocket)137     static SSLSocketFactory setUseEngineSocket(SSLSocketFactory conscryptFactory, boolean useEngineSocket) {
138         try {
139             Class<?> clazz = conscryptClass("Conscrypt$SocketFactories");
140             Method method = clazz.getMethod("setUseEngineSocket", SSLSocketFactory.class, boolean.class);
141             method.invoke(null, conscryptFactory, useEngineSocket);
142             return conscryptFactory;
143         } catch (Exception e) {
144             throw new RuntimeException(e);
145         }
146     }
147 
setUseEngineSocket(SSLServerSocketFactory conscryptFactory, boolean useEngineSocket)148     static SSLServerSocketFactory setUseEngineSocket(SSLServerSocketFactory conscryptFactory, boolean useEngineSocket) {
149         try {
150             Class<?> clazz = conscryptClass("Conscrypt$ServerSocketFactories");
151             Method method = clazz.getMethod("setUseEngineSocket", SSLServerSocketFactory.class, boolean.class);
152             method.invoke(null, conscryptFactory, useEngineSocket);
153             return conscryptFactory;
154         } catch (Exception e) {
155             throw new RuntimeException(e);
156         }
157     }
158 
getConscryptSocketFactory(boolean useEngineSocket)159     public static SSLSocketFactory getConscryptSocketFactory(boolean useEngineSocket) {
160         return setUseEngineSocket(getSocketFactory(getConscryptProvider()), useEngineSocket);
161     }
162 
getConscryptServerSocketFactory(boolean useEngineSocket)163     public static SSLServerSocketFactory getConscryptServerSocketFactory(boolean useEngineSocket) {
164         return setUseEngineSocket(getServerSocketFactory(getConscryptProvider()), useEngineSocket);
165     }
166 
getSocketFactory(Provider provider)167     private static SSLSocketFactory getSocketFactory(Provider provider) {
168         SSLContext clientContext = initClientSslContext(newContext(provider));
169         return clientContext.getSocketFactory();
170     }
171 
getServerSocketFactory(Provider provider)172     private static SSLServerSocketFactory getServerSocketFactory(Provider provider) {
173         SSLContext serverContext = initServerSslContext(newContext(provider));
174         return serverContext.getServerSocketFactory();
175     }
176 
newContext(Provider provider)177     private static SSLContext newContext(Provider provider) {
178         try {
179             return SSLContext.getInstance("TLS", provider);
180         } catch (NoSuchAlgorithmException e) {
181             throw new RuntimeException(e);
182         }
183     }
184 
185     /**
186      * Picks a port that is not used right at this moment.
187      * Warning: Not thread safe. May see "BindException: Address already in use: bind" if using the
188      * returned port to create a new server socket when other threads/processes are concurrently
189      * creating new sockets without a specific port.
190      */
pickUnusedPort()191     public static int pickUnusedPort() {
192         try {
193             ServerSocket serverSocket = new ServerSocket(0);
194             int port = serverSocket.getLocalPort();
195             serverSocket.close();
196             return port;
197         } catch (IOException e) {
198             throw new RuntimeException(e);
199         }
200     }
201 
202     /**
203      * Creates a text message of the given length.
204      */
newTextMessage(int length)205     public static byte[] newTextMessage(int length) {
206         byte[] msg = new byte[length];
207         for (int msgIndex = 0; msgIndex < length;) {
208             int remaining = length - msgIndex;
209             int numChars = Math.min(remaining, CHARS.length);
210             System.arraycopy(CHARS, 0, msg, msgIndex, numChars);
211             msgIndex += numChars;
212         }
213         return msg;
214     }
215 
216     /**
217      * Initializes the given engine with the cipher and client mode.
218      */
initEngine(SSLEngine engine, String cipher, boolean client)219     static SSLEngine initEngine(SSLEngine engine, String cipher, boolean client) {
220         engine.setEnabledProtocols(getProtocols());
221         engine.setEnabledCipherSuites(new String[] {cipher});
222         engine.setUseClientMode(client);
223         return engine;
224     }
225 
newClientSslContext(Provider provider)226     static SSLContext newClientSslContext(Provider provider) {
227         SSLContext context = newContext(provider);
228         return initClientSslContext(context);
229     }
230 
newServerSslContext(Provider provider)231     static SSLContext newServerSslContext(Provider provider) {
232         SSLContext context = newContext(provider);
233         return initServerSslContext(context);
234     }
235 
236     /**
237      * Initializes the given client-side {@code context} with a default cert.
238      */
initClientSslContext(SSLContext context)239     public static SSLContext initClientSslContext(SSLContext context) {
240         return initSslContext(context, TestKeyStore.getClient());
241     }
242 
243     /**
244      * Initializes the given server-side {@code context} with the given cert chain and private key.
245      */
initServerSslContext(SSLContext context)246     public static SSLContext initServerSslContext(SSLContext context) {
247         return initSslContext(context, TestKeyStore.getServer());
248     }
249 
250     /**
251      * Initializes the given {@code context} from the {@code keyStore}.
252      */
initSslContext(SSLContext context, TestKeyStore keyStore)253     static SSLContext initSslContext(SSLContext context, TestKeyStore keyStore) {
254         try {
255             context.init(keyStore.keyManagers, keyStore.trustManagers, null);
256             return context;
257         } catch (Exception e) {
258             throw new RuntimeException(e);
259         }
260     }
261 
262     /**
263      * Performs the intial TLS handshake between the two {@link SSLEngine} instances.
264      */
doEngineHandshake(SSLEngine clientEngine, SSLEngine serverEngine, ByteBuffer clientAppBuffer, ByteBuffer clientPacketBuffer, ByteBuffer serverAppBuffer, ByteBuffer serverPacketBuffer)265     public static void doEngineHandshake(SSLEngine clientEngine, SSLEngine serverEngine,
266             ByteBuffer clientAppBuffer, ByteBuffer clientPacketBuffer, ByteBuffer serverAppBuffer,
267             ByteBuffer serverPacketBuffer) throws SSLException {
268         clientEngine.beginHandshake();
269         serverEngine.beginHandshake();
270 
271         SSLEngineResult clientResult;
272         SSLEngineResult serverResult;
273 
274         boolean clientHandshakeFinished = false;
275         boolean serverHandshakeFinished = false;
276 
277         do {
278             int cTOsPos = clientPacketBuffer.position();
279             int sTOcPos = serverPacketBuffer.position();
280 
281             clientResult = clientEngine.wrap(EMPTY_BUFFER, clientPacketBuffer);
282             runDelegatedTasks(clientResult, clientEngine);
283             serverResult = serverEngine.wrap(EMPTY_BUFFER, serverPacketBuffer);
284             runDelegatedTasks(serverResult, serverEngine);
285 
286             // Verify that the consumed and produced number match what is in the buffers now.
287             assertEquals(0, clientResult.bytesConsumed());
288             assertEquals(0, serverResult.bytesConsumed());
289             assertEquals(clientPacketBuffer.position() - cTOsPos, clientResult.bytesProduced());
290             assertEquals(serverPacketBuffer.position() - sTOcPos, serverResult.bytesProduced());
291 
292             clientPacketBuffer.flip();
293             serverPacketBuffer.flip();
294 
295             // Verify that we only had one SSLEngineResult.HandshakeStatus.FINISHED
296             if (isHandshakeFinished(clientResult)) {
297                 assertFalse(clientHandshakeFinished);
298                 clientHandshakeFinished = true;
299             }
300             if (isHandshakeFinished(serverResult)) {
301                 assertFalse(serverHandshakeFinished);
302                 serverHandshakeFinished = true;
303             }
304 
305             cTOsPos = clientPacketBuffer.position();
306             sTOcPos = serverPacketBuffer.position();
307 
308             int clientAppReadBufferPos = clientAppBuffer.position();
309             int serverAppReadBufferPos = serverAppBuffer.position();
310 
311             clientResult = clientEngine.unwrap(serverPacketBuffer, clientAppBuffer);
312             runDelegatedTasks(clientResult, clientEngine);
313             serverResult = serverEngine.unwrap(clientPacketBuffer, serverAppBuffer);
314             runDelegatedTasks(serverResult, serverEngine);
315 
316             // Verify that the consumed and produced number match what is in the buffers now.
317             assertEquals(serverPacketBuffer.position() - sTOcPos, clientResult.bytesConsumed());
318             assertEquals(clientPacketBuffer.position() - cTOsPos, serverResult.bytesConsumed());
319             assertEquals(clientAppBuffer.position() - clientAppReadBufferPos,
320                     clientResult.bytesProduced());
321             assertEquals(serverAppBuffer.position() - serverAppReadBufferPos,
322                     serverResult.bytesProduced());
323 
324             clientPacketBuffer.compact();
325             serverPacketBuffer.compact();
326 
327             // Verify that we only had one SSLEngineResult.HandshakeStatus.FINISHED
328             if (isHandshakeFinished(clientResult)) {
329                 assertFalse(clientHandshakeFinished);
330                 clientHandshakeFinished = true;
331             }
332             if (isHandshakeFinished(serverResult)) {
333                 assertFalse(serverHandshakeFinished);
334                 serverHandshakeFinished = true;
335             }
336         } while (!clientHandshakeFinished || !serverHandshakeFinished);
337     }
338 
isHandshakeFinished(SSLEngineResult result)339     private static boolean isHandshakeFinished(SSLEngineResult result) {
340         return result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED;
341     }
342 
runDelegatedTasks(SSLEngineResult result, SSLEngine engine)343     private static void runDelegatedTasks(SSLEngineResult result, SSLEngine engine) {
344         if (result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_TASK) {
345             for (;;) {
346                 Runnable task = engine.getDelegatedTask();
347                 if (task == null) {
348                     break;
349                 }
350                 task.run();
351             }
352         }
353     }
354 }
355