• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License
 */
16 
17 package org.conscrypt;
18 
19 import static org.junit.Assert.assertArrayEquals;
20 import static org.junit.Assert.assertEquals;
21 import static org.junit.Assert.assertNotEquals;
22 
23 import java.io.EOFException;
24 import java.io.IOException;
25 import java.net.InetSocketAddress;
26 import java.nio.ByteBuffer;
27 import java.nio.channels.ServerSocketChannel;
28 import java.nio.channels.SocketChannel;
29 import java.util.Arrays;
30 import java.util.LinkedHashSet;
31 import java.util.Set;
32 import java.util.concurrent.ExecutionException;
33 import java.util.concurrent.ExecutorService;
34 import java.util.concurrent.Executors;
35 import java.util.concurrent.Future;
36 import java.util.concurrent.TimeUnit;
37 import java.util.concurrent.TimeoutException;
38 import javax.net.ssl.SSLContext;
39 import javax.net.ssl.SSLEngine;
40 import javax.net.ssl.SSLEngineResult;
41 import javax.net.ssl.SSLEngineResult.Status;
42 import javax.net.ssl.SSLSocket;
43 import javax.net.ssl.SSLSocketFactory;
44 import org.conscrypt.java.security.TestKeyStore;
45 import org.junit.After;
46 import org.junit.Before;
47 import org.junit.Test;
48 import org.junit.runner.RunWith;
49 import org.junit.runners.Parameterized;
50 import org.junit.runners.Parameterized.Parameter;
51 import org.junit.runners.Parameterized.Parameters;
52 
53 /**
54  * This tests that server-initiated cipher renegotiation works properly with a Conscrypt client.
55  * BoringSSL does not support user-initiated renegotiation, so we use the JDK implementation for
56  * the server.
57  */
58 @RunWith(Parameterized.class)
59 public class RenegotiationTest {
60     private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocateDirect(0);
61     private static final String[] CIPHERS = TestUtils.getCommonCipherSuites();
62     private static final byte[] MESSAGE_BYTES = "Hello".getBytes(TestUtils.UTF_8);
63     private static final ByteBuffer MESSAGE_BUFFER =
64             ByteBuffer.wrap(MESSAGE_BYTES).asReadOnlyBuffer();
65     private static final int MESSAGE_LENGTH = MESSAGE_BYTES.length;
66 
67     public enum SocketType {
68         FILE_DESCRIPTOR {
69             @Override
newClient(int port)70             Client newClient(int port) {
71                 return new Client(false, port);
72             }
73         },
74         ENGINE {
75             @Override
newClient(int port)76             Client newClient(int port) {
77                 return new Client(true, port);
78             }
79         };
80 
newClient(int port)81         abstract Client newClient(int port);
82     }
83 
84     @Parameters(name = "{0}")
data()85     public static Object[] data() {
86         return new Object[] {SocketType.FILE_DESCRIPTOR, SocketType.ENGINE};
87     }
88 
89     @Parameter
90     public SocketType socketType;
91 
92     private Client client;
93     private Server server;
94 
95     @Before
setup()96     public void setup() throws Exception {
97         server = new Server();
98         Future<?> connectedFuture = server.start();
99 
100         client = socketType.newClient(server.port());
101         client.start();
102 
103         // Wait for the initial connection to complete.
104         connectedFuture.get(5, TimeUnit.SECONDS);
105     }
106 
107     @After
teardown()108     public void teardown() {
109         client.stop();
110         server.stop();
111     }
112 
113     @Test
test()114     public void test() throws Exception {
115         client.socket.startHandshake();
116         String initialCipher = client.socket.getSession().getCipherSuite();
117 
118         client.sendMessage();
119 
120         Future<?> repliesFuture = client.readReplies();
121         server.await(5, TimeUnit.SECONDS);
122         repliesFuture.get(5, TimeUnit.SECONDS);
123 
124         // Verify that the cipher has changed.
125         assertNotEquals(initialCipher, client.socket.getSession().getCipherSuite());
126     }
127 
newConscryptClientContext()128     private static SSLContext newConscryptClientContext() {
129         SSLContext context = TestUtils.newContext(TestUtils.getConscryptProvider());
130         return TestUtils.initSslContext(context, TestKeyStore.getClient());
131     }
132 
newJdkServerContext()133     private static SSLContext newJdkServerContext() {
134         SSLContext context = TestUtils.newContext(TestUtils.getJdkProvider());
135         return TestUtils.initSslContext(context, TestKeyStore.getServer());
136     }
137 
138     private static final class Client {
139         private final SSLSocket socket;
140         private ExecutorService executor;
141         private volatile boolean stopping;
142 
Client(boolean useEngineSocket, int port)143         Client(boolean useEngineSocket, int port) {
144             try {
145                 SSLSocketFactory socketFactory = newConscryptClientContext().getSocketFactory();
146                 Conscrypt.setUseEngineSocket(socketFactory, useEngineSocket);
147                 socket = (SSLSocket) socketFactory.createSocket(
148                         TestUtils.getLoopbackAddress(), port);
149                 socket.setEnabledCipherSuites(CIPHERS);
150             } catch (IOException e) {
151                 throw new RuntimeException(e);
152             }
153         }
154 
start()155         void start() {
156             try {
157                 executor = Executors.newSingleThreadExecutor();
158                 socket.startHandshake();
159             } catch (IOException e) {
160                 e.printStackTrace();
161                 throw new RuntimeException(e);
162             }
163         }
164 
stop()165         void stop() {
166             try {
167                 stopping = true;
168                 socket.close();
169 
170                 if (executor != null) {
171                     executor.shutdown();
172                     executor.awaitTermination(5, TimeUnit.SECONDS);
173                     executor = null;
174                 }
175             } catch (RuntimeException e) {
176                 throw e;
177             } catch (Exception e) {
178                 throw new RuntimeException(e);
179             }
180         }
181 
readReplies()182         Future<?> readReplies() {
183             return executor.submit(new Runnable() {
184                 @Override
185                 public void run() {
186                     readReply();
187                 }
188             });
189         }
190 
readReply()191         private void readReply() {
192             try {
193                 byte[] buffer = new byte[MESSAGE_LENGTH];
194                 int totalBytesRead = 0;
195                 while (totalBytesRead < MESSAGE_LENGTH) {
196                     int remaining = MESSAGE_LENGTH - totalBytesRead;
197                     int bytesRead = socket.getInputStream().read(buffer, totalBytesRead, remaining);
198                     if (bytesRead == -1) {
199                         throw new EOFException();
200                     }
201                     totalBytesRead += bytesRead;
202                 }
203 
204                 // Verify the reply is correct.
205                 assertEquals(MESSAGE_LENGTH, totalBytesRead);
206                 assertArrayEquals(MESSAGE_BYTES, buffer);
207             } catch (IOException e) {
208                 throw new RuntimeException(e);
209             }
210         }
211 
sendMessage()212         void sendMessage() throws IOException {
213             try {
214                 socket.getOutputStream().write(MESSAGE_BYTES);
215                 socket.getOutputStream().flush();
216             } catch (IOException e) {
217                 throw new RuntimeException(e);
218             }
219         }
220     }
221 
222     private static final class Server {
223         private final ServerSocketChannel serverChannel;
224         private final SSLEngine engine;
225         private final ByteBuffer inboundPacketBuffer;
226         private final ByteBuffer inboundAppBuffer;
227         private final ByteBuffer outboundPacketBuffer;
228         private final Set<String> ciphers = new LinkedHashSet<String>(Arrays.asList(CIPHERS));
229         private SocketChannel channel;
230         private ExecutorService executor;
231         private volatile boolean stopping;
232         private volatile Future<?> echoFuture;
233 
234         Server() throws IOException {
235             serverChannel = ServerSocketChannel.open();
236             serverChannel.socket().bind(new InetSocketAddress(TestUtils.getLoopbackAddress(), 0));
237             engine = newJdkServerContext().createSSLEngine();
238             engine.setEnabledCipherSuites(CIPHERS);
239             engine.setUseClientMode(false);
240 
241             inboundPacketBuffer =
242                     ByteBuffer.allocateDirect(engine.getSession().getPacketBufferSize());
243             inboundAppBuffer =
244                     ByteBuffer.allocateDirect(engine.getSession().getApplicationBufferSize());
245             outboundPacketBuffer =
246                     ByteBuffer.allocateDirect(engine.getSession().getPacketBufferSize());
247         }
248 
249         Future<?> start() throws IOException {
250             executor = Executors.newSingleThreadExecutor();
251             return executor.submit(new AcceptTask());
252         }
253 
254         void await(long timeout, TimeUnit unit)
255                 throws InterruptedException, ExecutionException, TimeoutException {
256             echoFuture.get(timeout, unit);
257         }
258 
259         void stop() {
260             try {
261                 stopping = true;
262 
263                 if (channel != null) {
264                     channel.close();
265                     channel = null;
266                 }
267 
268                 serverChannel.close();
269 
270                 if (executor != null) {
271                     executor.shutdown();
272                     executor.awaitTermination(5, TimeUnit.SECONDS);
273                     executor = null;
274                 }
275             } catch (IOException e) {
276                 throw new RuntimeException(e);
277             } catch (InterruptedException e) {
278                 throw new RuntimeException(e);
279             }
280         }
281 
282         int port() {
283             return serverChannel.socket().getLocalPort();
284         }
285 
286         private final class AcceptTask implements Runnable {
287             @Override
288             public void run() {
289                 try {
290                     if (stopping) {
291                         return;
292                     }
293                     channel = serverChannel.accept();
294                     channel.configureBlocking(false);
295 
296                     doHandshake();
297 
298                     if (stopping) {
299                         return;
300                     }
301                     echoFuture = executor.submit(new EchoTask());
302                 } catch (Throwable e) {
303                     e.printStackTrace();
304                     throw new RuntimeException(e);
305                 }
306             }
307         }
308 
309         private final class EchoTask implements Runnable {
310             @Override
311             public void run() {
312                 try {
313                     readMessage();
314                     renegotiate();
315                     reply();
316                 } catch (Throwable e) {
317                     e.printStackTrace();
318                     throw new RuntimeException(e);
319                 }
320             }
321 
322             private void renegotiate() throws Exception {
323                 // Remove the current cipher from the set and renegotiate to force a new
324                 // cipher to be selected.
325                 String currentCipher = engine.getSession().getCipherSuite();
326                 ciphers.remove(currentCipher);
327                 engine.setEnabledCipherSuites(ciphers.toArray(new String[ciphers.size()]));
328                 doHandshake();
329             }
330 
331             private void reply() throws IOException {
332                 SSLEngineResult result = wrap(newMessage());
333                 if (result.getStatus() != Status.OK) {
334                     throw new RuntimeException("Wrap failed. Status: " + result.getStatus());
335                 }
336             }
337 
338             private ByteBuffer newMessage() {
339                 return MESSAGE_BUFFER.duplicate();
340             }
341 
342             private void readMessage() throws IOException {
343                 int totalProduced = 0;
344                 while (!stopping) {
345                     SSLEngineResult result = unwrap();
346                     if (result.getStatus() != Status.OK) {
347                         throw new RuntimeException("Failed reading message: " + result);
348                     }
349                     totalProduced += result.bytesProduced();
350                     if (totalProduced == MESSAGE_LENGTH) {
351                         return;
352                     }
353                 }
354             }
355         }
356 
357         private SSLEngineResult wrap(ByteBuffer src) throws IOException {
358             outboundPacketBuffer.clear();
359 
360             // Check if the engine has bytes to wrap.
361             SSLEngineResult result = engine.wrap(src, outboundPacketBuffer);
362 
363             // Write any wrapped bytes to the socket.
364             outboundPacketBuffer.flip();
365 
366             do {
367                 channel.write(outboundPacketBuffer);
368             } while (outboundPacketBuffer.hasRemaining());
369 
370             return result;
371         }
372 
373         private SSLEngineResult unwrap() throws IOException {
374             // Unwrap any available bytes from the socket.
375             SSLEngineResult result = null;
376             boolean done = false;
377             while (!done) {
378                 if (channel.read(inboundPacketBuffer) == -1) {
379                     throw new EOFException();
380                 }
381                 // Just clear the app buffer - we don't really use it.
382                 inboundAppBuffer.clear();
383                 inboundPacketBuffer.flip();
384                 result = engine.unwrap(inboundPacketBuffer, inboundAppBuffer);
385                 switch (result.getStatus()) {
386                     case BUFFER_UNDERFLOW:
387                         // Continue reading from the socket in a moment.
388                         try {
389                             Thread.sleep(10);
390                         } catch (InterruptedException e) {
391                             throw new RuntimeException(e);
392                         }
393                         break;
394                     case OK:
395                         done = true;
396                         break;
397                     default: { throw new RuntimeException("Unexpected unwrap result: " + result); }
398                 }
399 
400                 // Compact for the next socket read.
401                 inboundPacketBuffer.compact();
402             }
403             return result;
404         }
405 
406         private void doHandshake() throws IOException {
407             engine.beginHandshake();
408 
409             boolean done = false;
410             while (!done) {
411                 switch (engine.getHandshakeStatus()) {
412                     case NEED_WRAP: {
413                         wrap(EMPTY_BUFFER);
414                         break;
415                     }
416                     case NEED_UNWRAP: {
417                         unwrap();
418                         break;
419                     }
420                     case NEED_TASK: {
421                         runDelegatedTasks();
422                         break;
423                     }
424                     default: {
425                         done = true;
426                         break;
427                     }
428                 }
429             }
430         }
431 
432         private void runDelegatedTasks() {
433             for (;;) {
434                 Runnable task = engine.getDelegatedTask();
435                 if (task == null) {
436                     break;
437                 }
438                 task.run();
439             }
440         }
441     }
442 }
443