• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package org.conscrypt;
18 
19 import java.io.EOFException;
20 import java.io.IOException;
21 import java.io.InputStream;
22 import java.io.OutputStream;
23 import java.net.ServerSocket;
24 import java.net.SocketException;
25 import java.nio.channels.ClosedChannelException;
26 import java.util.concurrent.ExecutionException;
27 import java.util.concurrent.ExecutorService;
28 import java.util.concurrent.Executors;
29 import java.util.concurrent.Future;
30 import java.util.concurrent.TimeUnit;
31 import java.util.concurrent.TimeoutException;
32 import javax.net.ssl.SSLException;
33 import javax.net.ssl.SSLServerSocketFactory;
34 import javax.net.ssl.SSLSocket;
35 import javax.net.ssl.SSLSocketFactory;
36 
37 /**
38  * A simple socket-based test server.
39  */
40 final class ServerEndpoint {
41     /**
42      * A processor for receipt of a single message.
43      */
44     public interface MessageProcessor {
processMessage(byte[] message, int numBytes, OutputStream os)45         void processMessage(byte[] message, int numBytes, OutputStream os);
46     }
47 
48     /**
49      * A {@link MessageProcessor} that simply echos back the received message to the client.
50      */
51     public static final class EchoProcessor implements MessageProcessor {
52         @Override
processMessage(byte[] message, int numBytes, OutputStream os)53         public void processMessage(byte[] message, int numBytes, OutputStream os) {
54             try {
55                 os.write(message, 0, numBytes);
56                 os.flush();
57             } catch (IOException e) {
58                 throw new RuntimeException(e);
59             }
60         }
61     }
62 
63     private final ServerSocket serverSocket;
64     private final ChannelType channelType;
65     private final SSLSocketFactory socketFactory;
66     private final int messageSize;
67     private final String[] protocols;
68     private final String[] cipherSuites;
69     private final byte[] buffer;
70     private SSLSocket socket;
71     private ExecutorService executor;
72     private InputStream inputStream;
73     private OutputStream outputStream;
74     private volatile boolean stopping;
75     private volatile MessageProcessor messageProcessor = new EchoProcessor();
76     private volatile Future<?> processFuture;
77 
ServerEndpoint(SSLSocketFactory socketFactory, SSLServerSocketFactory serverSocketFactory, ChannelType channelType, int messageSize, String[] protocols, String[] cipherSuites)78     ServerEndpoint(SSLSocketFactory socketFactory, SSLServerSocketFactory serverSocketFactory,
79             ChannelType channelType, int messageSize, String[] protocols,
80             String[] cipherSuites) throws IOException {
81         this.serverSocket = channelType.newServerSocket(serverSocketFactory);
82         this.socketFactory = socketFactory;
83         this.channelType = channelType;
84         this.messageSize = messageSize;
85         this.protocols = protocols;
86         this.cipherSuites = cipherSuites;
87         buffer = new byte[messageSize];
88     }
89 
setMessageProcessor(MessageProcessor messageProcessor)90     void setMessageProcessor(MessageProcessor messageProcessor) {
91         this.messageProcessor = messageProcessor;
92     }
93 
start()94     Future<?> start() throws IOException {
95         executor = Executors.newSingleThreadExecutor();
96         return executor.submit(new AcceptTask());
97     }
98 
stop()99     void stop() {
100         try {
101             stopping = true;
102 
103             if (socket != null) {
104                 socket.close();
105                 socket = null;
106             }
107 
108             if (processFuture != null) {
109                 processFuture.get(5, TimeUnit.SECONDS);
110             }
111 
112             serverSocket.close();
113 
114             if (executor != null) {
115                 executor.shutdown();
116                 executor.awaitTermination(5, TimeUnit.SECONDS);
117                 executor = null;
118             }
119         } catch (IOException | InterruptedException | ExecutionException | TimeoutException e) {
120             throw new RuntimeException(e);
121         }
122     }
123 
port()124     public int port() {
125         return serverSocket.getLocalPort();
126     }
127 
128     private final class AcceptTask implements Runnable {
129         @Override
run()130         public void run() {
131             try {
132                 if (stopping) {
133                     return;
134                 }
135                 socket = channelType.accept(serverSocket, socketFactory);
136                 socket.setEnabledProtocols(protocols);
137                 socket.setEnabledCipherSuites(cipherSuites);
138 
139                 socket.startHandshake();
140 
141                 inputStream = socket.getInputStream();
142                 outputStream = socket.getOutputStream();
143 
144                 if (stopping) {
145                     return;
146                 }
147                 processFuture = executor.submit(new ProcessTask());
148             } catch (IOException e) {
149                 e.printStackTrace();
150                 throw new RuntimeException(e);
151             }
152         }
153     }
154 
155     private final class ProcessTask implements Runnable {
156         @Override
run()157         public void run() {
158             try {
159                 Thread thread = Thread.currentThread();
160                 while (!stopping && !thread.isInterrupted()) {
161                     int bytesRead = readMessage();
162                     if (!stopping && !thread.isInterrupted()) {
163                         messageProcessor.processMessage(buffer, bytesRead, outputStream);
164                     }
165                 }
166             } catch (Throwable e) {
167                 throw new RuntimeException(e);
168             }
169         }
170 
readMessage()171         private int readMessage() throws IOException {
172             int totalBytesRead = 0;
173             while (!stopping && totalBytesRead < messageSize) {
174                 try {
175                     int remaining = messageSize - totalBytesRead;
176                     int bytesRead = inputStream.read(buffer, totalBytesRead, remaining);
177                     if (bytesRead == -1) {
178                         break;
179                     }
180                     totalBytesRead += bytesRead;
181                 } catch (SSLException e) {
182                     if (e.getCause() instanceof EOFException) {
183                         break;
184                     }
185                     throw e;
186                 } catch (ClosedChannelException e) {
187                     // Thrown for channel-based sockets. Just treat like EOF.
188                     break;
189                 } catch (SocketException e) {
190                     // The socket was broken. Just treat like EOF.
191                     break;
192                 }
193             }
194             return totalBytesRead;
195         }
196     }
197 }
198