1 /* 2 * Copyright (C) 2015 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package org.conscrypt; 18 19 import static org.junit.Assert.assertEquals; 20 import static org.junit.Assert.assertFalse; 21 22 import java.io.FileNotFoundException; 23 import java.io.IOException; 24 import java.io.InputStream; 25 import java.lang.reflect.Method; 26 import java.net.ServerSocket; 27 import java.nio.ByteBuffer; 28 import java.nio.charset.Charset; 29 import java.security.NoSuchAlgorithmException; 30 import java.security.Provider; 31 import java.security.Security; 32 import javax.net.ssl.SSLContext; 33 import javax.net.ssl.SSLEngine; 34 import javax.net.ssl.SSLEngineResult; 35 import javax.net.ssl.SSLException; 36 import javax.net.ssl.SSLServerSocketFactory; 37 import javax.net.ssl.SSLSocketFactory; 38 import libcore.io.Streams; 39 import libcore.java.security.TestKeyStore; 40 41 /** 42 * Utility methods to support testing. 43 */ 44 public final class TestUtils { 45 static final Charset UTF_8 = Charset.forName("UTF-8"); 46 47 private static final Provider JDK_PROVIDER = getDefaultTlsProvider(); 48 private static final byte[] CHARS = 49 "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789".getBytes(UTF_8); 50 private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocateDirect(0); 51 52 public static final String PROTOCOL_TLS_V1_2 = "TLSv1.2"; 53 public static final String PROVIDER_PROPERTY = "SSLContext.TLSv1.2"; 54 public static final String LOCALHOST = "localhost"; 55 56 static final String TEST_CIPHER = "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"; 57 TestUtils()58 private TestUtils() {} 59 getDefaultTlsProvider()60 private static Provider getDefaultTlsProvider() { 61 for (Provider p : Security.getProviders()) { 62 if (p.get(PROVIDER_PROPERTY) != null) { 63 return p; 64 } 65 } 66 throw new RuntimeException("Unable to find a default provider for " + PROVIDER_PROPERTY); 67 } 68 getJdkProvider()69 static Provider getJdkProvider() { 70 return JDK_PROVIDER; 71 } 72 getConscryptProvider()73 public static Provider getConscryptProvider() { 74 try { 75 return (Provider) conscryptClass("OpenSSLProvider") 76 .getConstructor() 77 .newInstance(); 78 } catch (Exception e) { 79 throw new RuntimeException(e); 80 } 81 } 82 installConscryptAsDefaultProvider()83 public static void installConscryptAsDefaultProvider() { 84 final Provider conscryptProvider = getConscryptProvider(); 85 synchronized (getConscryptProvider()) { 86 Provider[] providers = Security.getProviders(); 87 if (providers.length == 0 || !providers[0].equals(conscryptProvider)) { 88 Security.insertProviderAt(conscryptProvider, 1); 89 return; 90 } 91 } 92 } 93 openTestFile(String name)94 public static InputStream openTestFile(String name) throws FileNotFoundException { 95 InputStream is = TestUtils.class.getResourceAsStream("/" + name); 96 if (is == null) { 97 throw new FileNotFoundException(name); 98 } 99 return is; 100 } 101 readTestFile(String name)102 public static byte[] readTestFile(String name) throws IOException { 103 return Streams.readFully(openTestFile(name)); 104 } 105 106 /** 107 * Looks up the conscrypt class for the given simple name (i.e. no package prefix). 108 */ conscryptClass(String simpleName)109 public static Class<?> conscryptClass(String simpleName) throws ClassNotFoundException { 110 ClassNotFoundException ex = null; 111 for (String packageName : new String[]{"com.android.org.conscrypt", "org.conscrypt"}) { 112 String name = packageName + "." + simpleName; 113 try { 114 return Class.forName(name); 115 } catch (ClassNotFoundException e) { 116 ex = e; 117 } 118 } 119 throw ex; 120 } 121 122 /** 123 * Returns an array containing only {@link #PROTOCOL_TLS_V1_2}. 124 */ getProtocols()125 public static String[] getProtocols() { 126 return new String[] {PROTOCOL_TLS_V1_2}; 127 } 128 getJdkSocketFactory()129 public static SSLSocketFactory getJdkSocketFactory() { 130 return getSocketFactory(JDK_PROVIDER); 131 } 132 getJdkServerSocketFactory()133 public static SSLServerSocketFactory getJdkServerSocketFactory() { 134 return getServerSocketFactory(JDK_PROVIDER); 135 } 136 setUseEngineSocket(SSLSocketFactory conscryptFactory, boolean useEngineSocket)137 static SSLSocketFactory setUseEngineSocket(SSLSocketFactory conscryptFactory, boolean useEngineSocket) { 138 try { 139 Class<?> clazz = conscryptClass("Conscrypt$SocketFactories"); 140 Method method = clazz.getMethod("setUseEngineSocket", SSLSocketFactory.class, boolean.class); 141 method.invoke(null, conscryptFactory, useEngineSocket); 142 return conscryptFactory; 143 } catch (Exception e) { 144 throw new RuntimeException(e); 145 } 146 } 147 setUseEngineSocket(SSLServerSocketFactory conscryptFactory, boolean useEngineSocket)148 static SSLServerSocketFactory setUseEngineSocket(SSLServerSocketFactory conscryptFactory, boolean useEngineSocket) { 149 try { 150 Class<?> clazz = conscryptClass("Conscrypt$ServerSocketFactories"); 151 Method method = clazz.getMethod("setUseEngineSocket", SSLServerSocketFactory.class, boolean.class); 152 method.invoke(null, conscryptFactory, useEngineSocket); 153 return conscryptFactory; 154 } catch (Exception e) { 155 throw new RuntimeException(e); 156 } 157 } 158 getConscryptSocketFactory(boolean useEngineSocket)159 public static SSLSocketFactory getConscryptSocketFactory(boolean useEngineSocket) { 160 return setUseEngineSocket(getSocketFactory(getConscryptProvider()), useEngineSocket); 161 } 162 getConscryptServerSocketFactory(boolean useEngineSocket)163 public static SSLServerSocketFactory getConscryptServerSocketFactory(boolean useEngineSocket) { 164 return setUseEngineSocket(getServerSocketFactory(getConscryptProvider()), useEngineSocket); 165 } 166 getSocketFactory(Provider provider)167 private static SSLSocketFactory getSocketFactory(Provider provider) { 168 SSLContext clientContext = initClientSslContext(newContext(provider)); 169 return clientContext.getSocketFactory(); 170 } 171 getServerSocketFactory(Provider provider)172 private static SSLServerSocketFactory getServerSocketFactory(Provider provider) { 173 SSLContext serverContext = initServerSslContext(newContext(provider)); 174 return serverContext.getServerSocketFactory(); 175 } 176 newContext(Provider provider)177 private static SSLContext newContext(Provider provider) { 178 try { 179 return SSLContext.getInstance("TLS", provider); 180 } catch (NoSuchAlgorithmException e) { 181 throw new RuntimeException(e); 182 } 183 } 184 185 /** 186 * Picks a port that is not used right at this moment. 187 * Warning: Not thread safe. May see "BindException: Address already in use: bind" if using the 188 * returned port to create a new server socket when other threads/processes are concurrently 189 * creating new sockets without a specific port. 190 */ pickUnusedPort()191 public static int pickUnusedPort() { 192 try { 193 ServerSocket serverSocket = new ServerSocket(0); 194 int port = serverSocket.getLocalPort(); 195 serverSocket.close(); 196 return port; 197 } catch (IOException e) { 198 throw new RuntimeException(e); 199 } 200 } 201 202 /** 203 * Creates a text message of the given length. 204 */ newTextMessage(int length)205 public static byte[] newTextMessage(int length) { 206 byte[] msg = new byte[length]; 207 for (int msgIndex = 0; msgIndex < length;) { 208 int remaining = length - msgIndex; 209 int numChars = Math.min(remaining, CHARS.length); 210 System.arraycopy(CHARS, 0, msg, msgIndex, numChars); 211 msgIndex += numChars; 212 } 213 return msg; 214 } 215 216 /** 217 * Initializes the given engine with the cipher and client mode. 218 */ initEngine(SSLEngine engine, String cipher, boolean client)219 static SSLEngine initEngine(SSLEngine engine, String cipher, boolean client) { 220 engine.setEnabledProtocols(getProtocols()); 221 engine.setEnabledCipherSuites(new String[] {cipher}); 222 engine.setUseClientMode(client); 223 return engine; 224 } 225 newClientSslContext(Provider provider)226 static SSLContext newClientSslContext(Provider provider) { 227 SSLContext context = newContext(provider); 228 return initClientSslContext(context); 229 } 230 newServerSslContext(Provider provider)231 static SSLContext newServerSslContext(Provider provider) { 232 SSLContext context = newContext(provider); 233 return initServerSslContext(context); 234 } 235 236 /** 237 * Initializes the given client-side {@code context} with a default cert. 238 */ initClientSslContext(SSLContext context)239 public static SSLContext initClientSslContext(SSLContext context) { 240 return initSslContext(context, TestKeyStore.getClient()); 241 } 242 243 /** 244 * Initializes the given server-side {@code context} with the given cert chain and private key. 245 */ initServerSslContext(SSLContext context)246 public static SSLContext initServerSslContext(SSLContext context) { 247 return initSslContext(context, TestKeyStore.getServer()); 248 } 249 250 /** 251 * Initializes the given {@code context} from the {@code keyStore}. 252 */ initSslContext(SSLContext context, TestKeyStore keyStore)253 static SSLContext initSslContext(SSLContext context, TestKeyStore keyStore) { 254 try { 255 context.init(keyStore.keyManagers, keyStore.trustManagers, null); 256 return context; 257 } catch (Exception e) { 258 throw new RuntimeException(e); 259 } 260 } 261 262 /** 263 * Performs the intial TLS handshake between the two {@link SSLEngine} instances. 264 */ doEngineHandshake(SSLEngine clientEngine, SSLEngine serverEngine, ByteBuffer clientAppBuffer, ByteBuffer clientPacketBuffer, ByteBuffer serverAppBuffer, ByteBuffer serverPacketBuffer)265 public static void doEngineHandshake(SSLEngine clientEngine, SSLEngine serverEngine, 266 ByteBuffer clientAppBuffer, ByteBuffer clientPacketBuffer, ByteBuffer serverAppBuffer, 267 ByteBuffer serverPacketBuffer) throws SSLException { 268 clientEngine.beginHandshake(); 269 serverEngine.beginHandshake(); 270 271 SSLEngineResult clientResult; 272 SSLEngineResult serverResult; 273 274 boolean clientHandshakeFinished = false; 275 boolean serverHandshakeFinished = false; 276 277 do { 278 int cTOsPos = clientPacketBuffer.position(); 279 int sTOcPos = serverPacketBuffer.position(); 280 281 clientResult = clientEngine.wrap(EMPTY_BUFFER, clientPacketBuffer); 282 runDelegatedTasks(clientResult, clientEngine); 283 serverResult = serverEngine.wrap(EMPTY_BUFFER, serverPacketBuffer); 284 runDelegatedTasks(serverResult, serverEngine); 285 286 // Verify that the consumed and produced number match what is in the buffers now. 287 assertEquals(0, clientResult.bytesConsumed()); 288 assertEquals(0, serverResult.bytesConsumed()); 289 assertEquals(clientPacketBuffer.position() - cTOsPos, clientResult.bytesProduced()); 290 assertEquals(serverPacketBuffer.position() - sTOcPos, serverResult.bytesProduced()); 291 292 clientPacketBuffer.flip(); 293 serverPacketBuffer.flip(); 294 295 // Verify that we only had one SSLEngineResult.HandshakeStatus.FINISHED 296 if (isHandshakeFinished(clientResult)) { 297 assertFalse(clientHandshakeFinished); 298 clientHandshakeFinished = true; 299 } 300 if (isHandshakeFinished(serverResult)) { 301 assertFalse(serverHandshakeFinished); 302 serverHandshakeFinished = true; 303 } 304 305 cTOsPos = clientPacketBuffer.position(); 306 sTOcPos = serverPacketBuffer.position(); 307 308 int clientAppReadBufferPos = clientAppBuffer.position(); 309 int serverAppReadBufferPos = serverAppBuffer.position(); 310 311 clientResult = clientEngine.unwrap(serverPacketBuffer, clientAppBuffer); 312 runDelegatedTasks(clientResult, clientEngine); 313 serverResult = serverEngine.unwrap(clientPacketBuffer, serverAppBuffer); 314 runDelegatedTasks(serverResult, serverEngine); 315 316 // Verify that the consumed and produced number match what is in the buffers now. 317 assertEquals(serverPacketBuffer.position() - sTOcPos, clientResult.bytesConsumed()); 318 assertEquals(clientPacketBuffer.position() - cTOsPos, serverResult.bytesConsumed()); 319 assertEquals(clientAppBuffer.position() - clientAppReadBufferPos, 320 clientResult.bytesProduced()); 321 assertEquals(serverAppBuffer.position() - serverAppReadBufferPos, 322 serverResult.bytesProduced()); 323 324 clientPacketBuffer.compact(); 325 serverPacketBuffer.compact(); 326 327 // Verify that we only had one SSLEngineResult.HandshakeStatus.FINISHED 328 if (isHandshakeFinished(clientResult)) { 329 assertFalse(clientHandshakeFinished); 330 clientHandshakeFinished = true; 331 } 332 if (isHandshakeFinished(serverResult)) { 333 assertFalse(serverHandshakeFinished); 334 serverHandshakeFinished = true; 335 } 336 } while (!clientHandshakeFinished || !serverHandshakeFinished); 337 } 338 isHandshakeFinished(SSLEngineResult result)339 private static boolean isHandshakeFinished(SSLEngineResult result) { 340 return result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED; 341 } 342 runDelegatedTasks(SSLEngineResult result, SSLEngine engine)343 private static void runDelegatedTasks(SSLEngineResult result, SSLEngine engine) { 344 if (result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_TASK) { 345 for (;;) { 346 Runnable task = engine.getDelegatedTask(); 347 if (task == null) { 348 break; 349 } 350 task.run(); 351 } 352 } 353 } 354 } 355