1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License 15 */ 16 17 package org.conscrypt; 18 19 import static org.junit.Assert.assertArrayEquals; 20 import static org.junit.Assert.assertEquals; 21 import static org.junit.Assert.assertNotEquals; 22 23 import java.io.EOFException; 24 import java.io.IOException; 25 import java.net.InetSocketAddress; 26 import java.nio.ByteBuffer; 27 import java.nio.channels.ServerSocketChannel; 28 import java.nio.channels.SocketChannel; 29 import java.util.Arrays; 30 import java.util.LinkedHashSet; 31 import java.util.Set; 32 import java.util.concurrent.ExecutionException; 33 import java.util.concurrent.ExecutorService; 34 import java.util.concurrent.Executors; 35 import java.util.concurrent.Future; 36 import java.util.concurrent.TimeUnit; 37 import java.util.concurrent.TimeoutException; 38 import javax.net.ssl.SSLContext; 39 import javax.net.ssl.SSLEngine; 40 import javax.net.ssl.SSLEngineResult; 41 import javax.net.ssl.SSLEngineResult.Status; 42 import javax.net.ssl.SSLSocket; 43 import javax.net.ssl.SSLSocketFactory; 44 import org.conscrypt.java.security.TestKeyStore; 45 import org.junit.After; 46 import org.junit.Before; 47 import org.junit.Test; 48 import org.junit.runner.RunWith; 49 import org.junit.runners.Parameterized; 50 import org.junit.runners.Parameterized.Parameter; 51 import org.junit.runners.Parameterized.Parameters; 52 53 /** 54 * This tests that server-initiated cipher renegotiation works properly with a Conscrypt client. 55 * BoringSSL does not support user-initiated renegotiation, so we use the JDK implementation for 56 * the server. 57 */ 58 @RunWith(Parameterized.class) 59 public class RenegotiationTest { 60 private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocateDirect(0); 61 private static final String[] CIPHERS = TestUtils.getCommonCipherSuites(); 62 private static final byte[] MESSAGE_BYTES = "Hello".getBytes(TestUtils.UTF_8); 63 private static final ByteBuffer MESSAGE_BUFFER = 64 ByteBuffer.wrap(MESSAGE_BYTES).asReadOnlyBuffer(); 65 private static final int MESSAGE_LENGTH = MESSAGE_BYTES.length; 66 67 public enum SocketType { 68 FILE_DESCRIPTOR { 69 @Override newClient(int port)70 Client newClient(int port) { 71 return new Client(false, port); 72 } 73 }, 74 ENGINE { 75 @Override newClient(int port)76 Client newClient(int port) { 77 return new Client(true, port); 78 } 79 }; 80 newClient(int port)81 abstract Client newClient(int port); 82 } 83 84 @Parameters(name = "{0}") data()85 public static Object[] data() { 86 return new Object[] {SocketType.FILE_DESCRIPTOR, SocketType.ENGINE}; 87 } 88 89 @Parameter 90 public SocketType socketType; 91 92 private Client client; 93 private Server server; 94 95 @Before setup()96 public void setup() throws Exception { 97 server = new Server(); 98 Future<?> connectedFuture = server.start(); 99 100 client = socketType.newClient(server.port()); 101 client.start(); 102 103 // Wait for the initial connection to complete. 104 connectedFuture.get(5, TimeUnit.SECONDS); 105 } 106 107 @After teardown()108 public void teardown() { 109 client.stop(); 110 server.stop(); 111 } 112 113 @Test test()114 public void test() throws Exception { 115 client.socket.startHandshake(); 116 String initialCipher = client.socket.getSession().getCipherSuite(); 117 118 client.sendMessage(); 119 120 Future<?> repliesFuture = client.readReplies(); 121 server.await(5, TimeUnit.SECONDS); 122 repliesFuture.get(5, TimeUnit.SECONDS); 123 124 // Verify that the cipher has changed. 125 assertNotEquals(initialCipher, client.socket.getSession().getCipherSuite()); 126 } 127 newConscryptClientContext()128 private static SSLContext newConscryptClientContext() { 129 SSLContext context = TestUtils.newContext(TestUtils.getConscryptProvider()); 130 return TestUtils.initSslContext(context, TestKeyStore.getClient()); 131 } 132 newJdkServerContext()133 private static SSLContext newJdkServerContext() { 134 SSLContext context = TestUtils.newContext(TestUtils.getJdkProvider()); 135 return TestUtils.initSslContext(context, TestKeyStore.getServer()); 136 } 137 138 private static final class Client { 139 private final SSLSocket socket; 140 private ExecutorService executor; 141 private volatile boolean stopping; 142 Client(boolean useEngineSocket, int port)143 Client(boolean useEngineSocket, int port) { 144 try { 145 SSLSocketFactory socketFactory = newConscryptClientContext().getSocketFactory(); 146 Conscrypt.setUseEngineSocket(socketFactory, useEngineSocket); 147 socket = (SSLSocket) socketFactory.createSocket( 148 TestUtils.getLoopbackAddress(), port); 149 socket.setEnabledCipherSuites(CIPHERS); 150 } catch (IOException e) { 151 throw new RuntimeException(e); 152 } 153 } 154 start()155 void start() { 156 try { 157 executor = Executors.newSingleThreadExecutor(); 158 socket.startHandshake(); 159 } catch (IOException e) { 160 e.printStackTrace(); 161 throw new RuntimeException(e); 162 } 163 } 164 stop()165 void stop() { 166 try { 167 stopping = true; 168 socket.close(); 169 170 if (executor != null) { 171 executor.shutdown(); 172 executor.awaitTermination(5, TimeUnit.SECONDS); 173 executor = null; 174 } 175 } catch (RuntimeException e) { 176 throw e; 177 } catch (Exception e) { 178 throw new RuntimeException(e); 179 } 180 } 181 readReplies()182 Future<?> readReplies() { 183 return executor.submit(new Runnable() { 184 @Override 185 public void run() { 186 readReply(); 187 } 188 }); 189 } 190 readReply()191 private void readReply() { 192 try { 193 byte[] buffer = new byte[MESSAGE_LENGTH]; 194 int totalBytesRead = 0; 195 while (totalBytesRead < MESSAGE_LENGTH) { 196 int remaining = MESSAGE_LENGTH - totalBytesRead; 197 int bytesRead = socket.getInputStream().read(buffer, totalBytesRead, remaining); 198 if (bytesRead == -1) { 199 throw new EOFException(); 200 } 201 totalBytesRead += bytesRead; 202 } 203 204 // Verify the reply is correct. 205 assertEquals(MESSAGE_LENGTH, totalBytesRead); 206 assertArrayEquals(MESSAGE_BYTES, buffer); 207 } catch (IOException e) { 208 throw new RuntimeException(e); 209 } 210 } 211 sendMessage()212 void sendMessage() throws IOException { 213 try { 214 socket.getOutputStream().write(MESSAGE_BYTES); 215 socket.getOutputStream().flush(); 216 } catch (IOException e) { 217 throw new RuntimeException(e); 218 } 219 } 220 } 221 222 private static final class Server { 223 private final ServerSocketChannel serverChannel; 224 private final SSLEngine engine; 225 private final ByteBuffer inboundPacketBuffer; 226 private final ByteBuffer inboundAppBuffer; 227 private final ByteBuffer outboundPacketBuffer; 228 private final Set<String> ciphers = new LinkedHashSet<String>(Arrays.asList(CIPHERS)); 229 private SocketChannel channel; 230 private ExecutorService executor; 231 private volatile boolean stopping; 232 private volatile Future<?> echoFuture; 233 234 Server() throws IOException { 235 serverChannel = ServerSocketChannel.open(); 236 serverChannel.socket().bind(new InetSocketAddress(TestUtils.getLoopbackAddress(), 0)); 237 engine = newJdkServerContext().createSSLEngine(); 238 engine.setEnabledCipherSuites(CIPHERS); 239 engine.setUseClientMode(false); 240 241 inboundPacketBuffer = 242 ByteBuffer.allocateDirect(engine.getSession().getPacketBufferSize()); 243 inboundAppBuffer = 244 ByteBuffer.allocateDirect(engine.getSession().getApplicationBufferSize()); 245 outboundPacketBuffer = 246 ByteBuffer.allocateDirect(engine.getSession().getPacketBufferSize()); 247 } 248 249 Future<?> start() throws IOException { 250 executor = Executors.newSingleThreadExecutor(); 251 return executor.submit(new AcceptTask()); 252 } 253 254 void await(long timeout, TimeUnit unit) 255 throws InterruptedException, ExecutionException, TimeoutException { 256 echoFuture.get(timeout, unit); 257 } 258 259 void stop() { 260 try { 261 stopping = true; 262 263 if (channel != null) { 264 channel.close(); 265 channel = null; 266 } 267 268 serverChannel.close(); 269 270 if (executor != null) { 271 executor.shutdown(); 272 executor.awaitTermination(5, TimeUnit.SECONDS); 273 executor = null; 274 } 275 } catch (IOException e) { 276 throw new RuntimeException(e); 277 } catch (InterruptedException e) { 278 throw new RuntimeException(e); 279 } 280 } 281 282 int port() { 283 return serverChannel.socket().getLocalPort(); 284 } 285 286 private final class AcceptTask implements Runnable { 287 @Override 288 public void run() { 289 try { 290 if (stopping) { 291 return; 292 } 293 channel = serverChannel.accept(); 294 channel.configureBlocking(false); 295 296 doHandshake(); 297 298 if (stopping) { 299 return; 300 } 301 echoFuture = executor.submit(new EchoTask()); 302 } catch (Throwable e) { 303 e.printStackTrace(); 304 throw new RuntimeException(e); 305 } 306 } 307 } 308 309 private final class EchoTask implements Runnable { 310 @Override 311 public void run() { 312 try { 313 readMessage(); 314 renegotiate(); 315 reply(); 316 } catch (Throwable e) { 317 e.printStackTrace(); 318 throw new RuntimeException(e); 319 } 320 } 321 322 private void renegotiate() throws Exception { 323 // Remove the current cipher from the set and renegotiate to force a new 324 // cipher to be selected. 325 String currentCipher = engine.getSession().getCipherSuite(); 326 ciphers.remove(currentCipher); 327 engine.setEnabledCipherSuites(ciphers.toArray(new String[ciphers.size()])); 328 doHandshake(); 329 } 330 331 private void reply() throws IOException { 332 SSLEngineResult result = wrap(newMessage()); 333 if (result.getStatus() != Status.OK) { 334 throw new RuntimeException("Wrap failed. Status: " + result.getStatus()); 335 } 336 } 337 338 private ByteBuffer newMessage() { 339 return MESSAGE_BUFFER.duplicate(); 340 } 341 342 private void readMessage() throws IOException { 343 int totalProduced = 0; 344 while (!stopping) { 345 SSLEngineResult result = unwrap(); 346 if (result.getStatus() != Status.OK) { 347 throw new RuntimeException("Failed reading message: " + result); 348 } 349 totalProduced += result.bytesProduced(); 350 if (totalProduced == MESSAGE_LENGTH) { 351 return; 352 } 353 } 354 } 355 } 356 357 private SSLEngineResult wrap(ByteBuffer src) throws IOException { 358 outboundPacketBuffer.clear(); 359 360 // Check if the engine has bytes to wrap. 361 SSLEngineResult result = engine.wrap(src, outboundPacketBuffer); 362 363 // Write any wrapped bytes to the socket. 364 outboundPacketBuffer.flip(); 365 366 do { 367 channel.write(outboundPacketBuffer); 368 } while (outboundPacketBuffer.hasRemaining()); 369 370 return result; 371 } 372 373 private SSLEngineResult unwrap() throws IOException { 374 // Unwrap any available bytes from the socket. 375 SSLEngineResult result = null; 376 boolean done = false; 377 while (!done) { 378 if (channel.read(inboundPacketBuffer) == -1) { 379 throw new EOFException(); 380 } 381 // Just clear the app buffer - we don't really use it. 382 inboundAppBuffer.clear(); 383 inboundPacketBuffer.flip(); 384 result = engine.unwrap(inboundPacketBuffer, inboundAppBuffer); 385 switch (result.getStatus()) { 386 case BUFFER_UNDERFLOW: 387 // Continue reading from the socket in a moment. 388 try { 389 Thread.sleep(10); 390 } catch (InterruptedException e) { 391 throw new RuntimeException(e); 392 } 393 break; 394 case OK: 395 done = true; 396 break; 397 default: { throw new RuntimeException("Unexpected unwrap result: " + result); } 398 } 399 400 // Compact for the next socket read. 401 inboundPacketBuffer.compact(); 402 } 403 return result; 404 } 405 406 private void doHandshake() throws IOException { 407 engine.beginHandshake(); 408 409 boolean done = false; 410 while (!done) { 411 switch (engine.getHandshakeStatus()) { 412 case NEED_WRAP: { 413 wrap(EMPTY_BUFFER); 414 break; 415 } 416 case NEED_UNWRAP: { 417 unwrap(); 418 break; 419 } 420 case NEED_TASK: { 421 runDelegatedTasks(); 422 break; 423 } 424 default: { 425 done = true; 426 break; 427 } 428 } 429 } 430 } 431 432 private void runDelegatedTasks() { 433 for (;;) { 434 Runnable task = engine.getDelegatedTask(); 435 if (task == null) { 436 break; 437 } 438 task.run(); 439 } 440 } 441 } 442 } 443