1 /* 2 * Copyright (C) 2014 Square, Inc. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 package com.squareup.okhttp; 17 18 import com.squareup.okhttp.internal.NamedRunnable; 19 import com.squareup.okhttp.internal.Util; 20 import java.io.IOException; 21 import java.net.InetAddress; 22 import java.net.InetSocketAddress; 23 import java.net.ProtocolException; 24 import java.net.Proxy; 25 import java.net.ServerSocket; 26 import java.net.Socket; 27 import java.net.SocketException; 28 import java.util.concurrent.ExecutorService; 29 import java.util.concurrent.Executors; 30 import java.util.concurrent.TimeUnit; 31 import java.util.concurrent.atomic.AtomicInteger; 32 import java.util.logging.Level; 33 import java.util.logging.Logger; 34 import okio.Buffer; 35 import okio.BufferedSink; 36 import okio.BufferedSource; 37 import okio.Okio; 38 39 /** 40 * A limited implementation of SOCKS Protocol Version 5, intended to be similar to MockWebServer. 41 * See <a href="https://www.ietf.org/rfc/rfc1928.txt">RFC 1928</a>. 42 */ 43 public final class SocksProxy { 44 public final String HOSTNAME_THAT_ONLY_THE_PROXY_KNOWS = "onlyProxyCanResolveMe.org"; 45 46 private static final int VERSION_5 = 5; 47 private static final int METHOD_NONE = 0xff; 48 private static final int METHOD_NO_AUTHENTICATION_REQUIRED = 0; 49 private static final int ADDRESS_TYPE_IPV4 = 1; 50 private static final int ADDRESS_TYPE_DOMAIN_NAME = 3; 51 private static final int COMMAND_CONNECT = 1; 52 private static final int REPLY_SUCCEEDED = 0; 53 54 private static final Logger logger = Logger.getLogger(SocksProxy.class.getName()); 55 56 private final ExecutorService executor = Executors.newCachedThreadPool( 57 Util.threadFactory("SocksProxy", false)); 58 59 private ServerSocket serverSocket; 60 private AtomicInteger connectionCount = new AtomicInteger(); 61 play()62 public void play() throws IOException { 63 serverSocket = new ServerSocket(0); 64 executor.execute(new NamedRunnable("SocksProxy %s", serverSocket.getLocalPort()) { 65 @Override protected void execute() { 66 try { 67 while (true) { 68 Socket socket = serverSocket.accept(); 69 connectionCount.incrementAndGet(); 70 service(socket); 71 } 72 } catch (SocketException e) { 73 logger.info(name + " done accepting connections: " + e.getMessage()); 74 } catch (IOException e) { 75 logger.log(Level.WARNING, name + " failed unexpectedly", e); 76 } 77 } 78 }); 79 } 80 proxy()81 public Proxy proxy() { 82 return new Proxy(Proxy.Type.SOCKS, InetSocketAddress.createUnresolved( 83 "localhost", serverSocket.getLocalPort())); 84 } 85 connectionCount()86 public int connectionCount() { 87 return connectionCount.get(); 88 } 89 shutdown()90 public void shutdown() throws Exception { 91 serverSocket.close(); 92 executor.shutdown(); 93 if (!executor.awaitTermination(5, TimeUnit.SECONDS)) { 94 throw new IOException("Gave up waiting for executor to shut down"); 95 } 96 } 97 service(final Socket from)98 private void service(final Socket from) { 99 executor.execute(new NamedRunnable("SocksProxy %s", from.getRemoteSocketAddress()) { 100 @Override protected void execute() { 101 try { 102 BufferedSource fromSource = Okio.buffer(Okio.source(from)); 103 BufferedSink fromSink = Okio.buffer(Okio.sink(from)); 104 hello(fromSource, fromSink); 105 acceptCommand(from.getInetAddress(), fromSource, fromSink); 106 } catch (IOException e) { 107 logger.log(Level.WARNING, name + " failed", e); 108 Util.closeQuietly(from); 109 } 110 } 111 }); 112 } 113 hello(BufferedSource fromSource, BufferedSink fromSink)114 private void hello(BufferedSource fromSource, BufferedSink fromSink) throws IOException { 115 int version = fromSource.readByte() & 0xff; 116 int methodCount = fromSource.readByte() & 0xff; 117 int selectedMethod = METHOD_NONE; 118 119 if (version != VERSION_5) { 120 throw new ProtocolException("unsupported version: " + version); 121 } 122 123 for (int i = 0; i < methodCount; i++) { 124 int candidateMethod = fromSource.readByte() & 0xff; 125 if (candidateMethod == METHOD_NO_AUTHENTICATION_REQUIRED) { 126 selectedMethod = candidateMethod; 127 } 128 } 129 130 switch (selectedMethod) { 131 case METHOD_NO_AUTHENTICATION_REQUIRED: 132 fromSink.writeByte(VERSION_5); 133 fromSink.writeByte(selectedMethod); 134 fromSink.emit(); 135 break; 136 137 default: 138 throw new ProtocolException("unsupported method: " + selectedMethod); 139 } 140 } 141 acceptCommand(InetAddress fromAddress, BufferedSource fromSource, BufferedSink fromSink)142 private void acceptCommand(InetAddress fromAddress, BufferedSource fromSource, 143 BufferedSink fromSink) throws IOException { 144 // Read the command. 145 int version = fromSource.readByte() & 0xff; 146 if (version != VERSION_5) throw new ProtocolException("unexpected version: " + version); 147 int command = fromSource.readByte() & 0xff; 148 int reserved = fromSource.readByte() & 0xff; 149 if (reserved != 0) throw new ProtocolException("unexpected reserved: " + reserved); 150 151 int addressType = fromSource.readByte() & 0xff; 152 InetAddress toAddress; 153 switch (addressType) { 154 case ADDRESS_TYPE_IPV4: 155 toAddress = InetAddress.getByAddress(fromSource.readByteArray(4L)); 156 break; 157 158 case ADDRESS_TYPE_DOMAIN_NAME: 159 int domainNameLength = fromSource.readByte() & 0xff; 160 String domainName = fromSource.readUtf8(domainNameLength); 161 // Resolve HOSTNAME_THAT_ONLY_THE_PROXY_KNOWS to localhost. 162 toAddress = domainName.equalsIgnoreCase(HOSTNAME_THAT_ONLY_THE_PROXY_KNOWS) 163 ? InetAddress.getByName("localhost") 164 : InetAddress.getByName(domainName); 165 break; 166 167 default: 168 throw new ProtocolException("unsupported address type: " + addressType); 169 } 170 171 int port = fromSource.readShort() & 0xffff; 172 173 switch (command) { 174 case COMMAND_CONNECT: 175 Socket toSocket = new Socket(toAddress, port); 176 byte[] localAddress = toSocket.getLocalAddress().getAddress(); 177 if (localAddress.length != 4) { 178 throw new ProtocolException("unexpected address: " + toSocket.getLocalAddress()); 179 } 180 181 // Write the reply. 182 fromSink.writeByte(VERSION_5); 183 fromSink.writeByte(REPLY_SUCCEEDED); 184 fromSink.writeByte(0); 185 fromSink.writeByte(ADDRESS_TYPE_IPV4); 186 fromSink.write(localAddress); 187 fromSink.writeShort(toSocket.getLocalPort()); 188 fromSink.emit(); 189 190 logger.log(Level.INFO, "SocksProxy connected " + fromAddress + " to " + toAddress); 191 192 // Copy sources to sinks in both directions. 193 BufferedSource toSource = Okio.buffer(Okio.source(toSocket)); 194 BufferedSink toSink = Okio.buffer(Okio.sink(toSocket)); 195 transfer(fromAddress, toAddress, fromSource, toSink); 196 transfer(fromAddress, toAddress, toSource, fromSink); 197 break; 198 199 default: 200 throw new ProtocolException("unexpected command: " + command); 201 } 202 } 203 transfer(final InetAddress fromAddress, final InetAddress toAddress, final BufferedSource source, final BufferedSink sink)204 private void transfer(final InetAddress fromAddress, final InetAddress toAddress, 205 final BufferedSource source, final BufferedSink sink) { 206 executor.execute(new NamedRunnable("SocksProxy %s to %s", fromAddress, toAddress) { 207 @Override protected void execute() { 208 Buffer buffer = new Buffer(); 209 try { 210 while (true) { 211 long byteCount = source.read(buffer, 2048L); 212 if (byteCount == -1L) break; 213 sink.write(buffer, byteCount); 214 sink.emit(); 215 } 216 } catch (SocketException e) { 217 logger.info(name + " done: " + e.getMessage()); 218 } catch (IOException e) { 219 logger.log(Level.WARNING, name + " failed", e); 220 } 221 222 try { 223 source.close(); 224 } catch (IOException e) { 225 logger.log(Level.WARNING, name + " failed", e); 226 } 227 228 try { 229 sink.close(); 230 } catch (IOException e) { 231 logger.log(Level.WARNING, name + " failed", e); 232 } 233 } 234 }); 235 } 236 } 237