• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include <errno.h>
6 #include <fcntl.h>
7 #include <netinet/in.h>
8 #include <netinet/tcp.h>
9 #include <pthread.h>
10 #include <signal.h>
11 #include <stdio.h>
12 #include <stdlib.h>
13 #include <string.h>
14 #include <sys/select.h>
15 #include <sys/socket.h>
16 #include <sys/wait.h>
17 #include <unistd.h>
18 
19 #include "base/command_line.h"
20 #include "base/logging.h"
21 #include "base/posix/eintr_wrapper.h"
22 #include "tools/android/common/adb_connection.h"
23 #include "tools/android/common/daemon.h"
24 #include "tools/android/common/net.h"
25 
26 namespace {
27 
// Sentinel stored in Server::thread_ while no server thread exists. The cast
// assumes pthread_t is an arithmetic type (true wherever this compiles).
const pthread_t kInvalidThread = static_cast<pthread_t>(-1);

// Set by the SIGTERM handler and polled by the server and forwarder loops.
// volatile sig_atomic_t (rather than volatile bool) is the plain object type
// the C/C++ standards guarantee may be written from a signal handler.
// NOTE(review): it is also read from other threads; a lock-free std::atomic
// would be stricter once the toolchain allows C++11.
volatile sig_atomic_t g_killed = 0;
30 
// Closes |fd| when it is a valid (non-negative) descriptor. errno is restored
// afterwards so callers can still report the error that brought them here.
void CloseSocket(int fd) {
  if (fd < 0)
    return;
  const int saved_errno = errno;
  close(fd);
  errno = saved_errno;
}
38 
39 class Buffer {
40  public:
Buffer()41   Buffer()
42       : bytes_read_(0),
43         write_offset_(0) {
44   }
45 
CanRead()46   bool CanRead() {
47     return bytes_read_ == 0;
48   }
49 
CanWrite()50   bool CanWrite() {
51     return write_offset_ < bytes_read_;
52   }
53 
Read(int fd)54   int Read(int fd) {
55     int ret = -1;
56     if (CanRead()) {
57       ret = HANDLE_EINTR(read(fd, buffer_, kBufferSize));
58       if (ret > 0)
59         bytes_read_ = ret;
60     }
61     return ret;
62   }
63 
Write(int fd)64   int Write(int fd) {
65     int ret = -1;
66     if (CanWrite()) {
67       ret = HANDLE_EINTR(write(fd, buffer_ + write_offset_,
68                                bytes_read_ - write_offset_));
69       if (ret > 0) {
70         write_offset_ += ret;
71         if (write_offset_ == bytes_read_) {
72           write_offset_ = 0;
73           bytes_read_ = 0;
74         }
75       }
76     }
77     return ret;
78   }
79 
80  private:
81   // A big buffer to let our file-over-http bridge work more like real file.
82   static const int kBufferSize = 1024 * 128;
83   int bytes_read_;
84   int write_offset_;
85   char buffer_[kBufferSize];
86 
87   DISALLOW_COPY_AND_ASSIGN(Buffer);
88 };
89 
class Server;

// Heap-allocated argument passed to ForwarderThread(). It names the owning
// server and the index of the forwarder slot to service. ForwarderThread()
// deletes it after copying the two fields out.
struct ForwarderThreadInfo {
  ForwarderThreadInfo(Server* a_server, int a_forwarder_index)
      : server(a_server), forwarder_index(a_forwarder_index) {}

  Server* server;
  int forwarder_index;
};
100 
// Bookkeeping for one forwarded connection, kept in Server::forwarders_.
// A |start_time| of 0 marks the slot as free.
struct ForwarderInfo {
  time_t start_time;              // When the connection was set up.
  int socket1;                    // Socket returned by accept().
  time_t socket1_last_byte_time;  // Last time socket1 was seen readable.
  size_t socket1_bytes;           // Bytes read from socket1 so far.
  int socket2;                    // Socket from tools::ConnectAdbHostSocket().
  time_t socket2_last_byte_time;  // Last time socket2 was seen readable.
  size_t socket2_bytes;           // Bytes read from socket2 so far.
};
110 
111 class Server {
112  public:
Server()113   Server()
114       : thread_(kInvalidThread),
115         socket_(-1) {
116     memset(forward_to_, 0, sizeof(forward_to_));
117     memset(&forwarders_, 0, sizeof(forwarders_));
118   }
119 
GetFreeForwarderIndex()120   int GetFreeForwarderIndex() {
121     for (int i = 0; i < kMaxForwarders; i++) {
122       if (forwarders_[i].start_time == 0)
123         return i;
124     }
125     return -1;
126   }
127 
DisposeForwarderInfo(int index)128   void DisposeForwarderInfo(int index) {
129     forwarders_[index].start_time = 0;
130   }
131 
GetForwarderInfo(int index)132   ForwarderInfo* GetForwarderInfo(int index) {
133     return &forwarders_[index];
134   }
135 
DumpInformation()136   void DumpInformation() {
137     LOG(INFO) << "Server information: " << forward_to_;
138     LOG(INFO) << "No.: age up(bytes,idle) down(bytes,idle)";
139     int count = 0;
140     time_t now = time(NULL);
141     for (int i = 0; i < kMaxForwarders; i++) {
142       const ForwarderInfo& info = forwarders_[i];
143       if (info.start_time) {
144         count++;
145         LOG(INFO) << count << ": " << now - info.start_time << " up("
146                   << info.socket1_bytes << ","
147                   << now - info.socket1_last_byte_time << " down("
148                   << info.socket2_bytes << ","
149                   << now - info.socket2_last_byte_time << ")";
150       }
151     }
152   }
153 
Shutdown()154   void Shutdown() {
155     if (socket_ >= 0)
156       shutdown(socket_, SHUT_RDWR);
157   }
158 
159   bool InitSocket(const char* arg);
160 
StartThread()161   void StartThread() {
162     pthread_create(&thread_, NULL, ServerThread, this);
163   }
164 
JoinThread()165   void JoinThread() {
166     if (thread_ != kInvalidThread)
167       pthread_join(thread_, NULL);
168   }
169 
170  private:
171   static void* ServerThread(void* arg);
172 
173   // There are 3 kinds of threads that will access the array:
174   // 1. Server thread will get a free ForwarderInfo and initialize it;
175   // 2. Forwarder threads will dispose the ForwarderInfo when it finishes;
176   // 3. Main thread will iterate and print the forwarders.
177   // Using an array is not optimal, but can avoid locks or other complex
178   // inter-thread communication.
179   static const int kMaxForwarders = 512;
180   ForwarderInfo forwarders_[kMaxForwarders];
181 
182   pthread_t thread_;
183   int socket_;
184   char forward_to_[40];
185 
186   DISALLOW_COPY_AND_ASSIGN(Server);
187 };
188 
189 // Forwards all outputs from one socket to another socket.
ForwarderThread(void * arg)190 void* ForwarderThread(void* arg) {
191   ForwarderThreadInfo* thread_info =
192       reinterpret_cast<ForwarderThreadInfo*>(arg);
193   Server* server = thread_info->server;
194   int index = thread_info->forwarder_index;
195   delete thread_info;
196   ForwarderInfo* info = server->GetForwarderInfo(index);
197   int socket1 = info->socket1;
198   int socket2 = info->socket2;
199   int nfds = socket1 > socket2 ? socket1 + 1 : socket2 + 1;
200   fd_set read_fds;
201   fd_set write_fds;
202   Buffer buffer1;
203   Buffer buffer2;
204 
205   while (!g_killed) {
206     FD_ZERO(&read_fds);
207     if (buffer1.CanRead())
208       FD_SET(socket1, &read_fds);
209     if (buffer2.CanRead())
210       FD_SET(socket2, &read_fds);
211 
212     FD_ZERO(&write_fds);
213     if (buffer1.CanWrite())
214       FD_SET(socket2, &write_fds);
215     if (buffer2.CanWrite())
216       FD_SET(socket1, &write_fds);
217 
218     if (HANDLE_EINTR(select(nfds, &read_fds, &write_fds, NULL, NULL)) <= 0) {
219       LOG(ERROR) << "Select error: " << strerror(errno);
220       break;
221     }
222 
223     int now = time(NULL);
224     if (FD_ISSET(socket1, &read_fds)) {
225       info->socket1_last_byte_time = now;
226       int bytes = buffer1.Read(socket1);
227       if (bytes <= 0)
228         break;
229       info->socket1_bytes += bytes;
230     }
231     if (FD_ISSET(socket2, &read_fds)) {
232       info->socket2_last_byte_time = now;
233       int bytes = buffer2.Read(socket2);
234       if (bytes <= 0)
235         break;
236       info->socket2_bytes += bytes;
237     }
238     if (FD_ISSET(socket1, &write_fds)) {
239       if (buffer2.Write(socket1) <= 0)
240         break;
241     }
242     if (FD_ISSET(socket2, &write_fds)) {
243       if (buffer1.Write(socket2) <= 0)
244         break;
245     }
246   }
247 
248   CloseSocket(socket1);
249   CloseSocket(socket2);
250   server->DisposeForwarderInfo(index);
251   return NULL;
252 }
253 
254 // Listens to a server socket. On incoming request, forward it to the host.
255 // static
ServerThread(void * arg)256 void* Server::ServerThread(void* arg) {
257   Server* server = reinterpret_cast<Server*>(arg);
258   while (!g_killed) {
259     int forwarder_index = server->GetFreeForwarderIndex();
260     if (forwarder_index < 0) {
261       LOG(ERROR) << "Too many forwarders";
262       continue;
263     }
264 
265     struct sockaddr_in addr;
266     socklen_t addr_len = sizeof(addr);
267     int socket = HANDLE_EINTR(accept(server->socket_,
268                                      reinterpret_cast<sockaddr*>(&addr),
269                                      &addr_len));
270     if (socket < 0) {
271       LOG(ERROR) << "Failed to accept: " << strerror(errno);
272       break;
273     }
274     tools::DisableNagle(socket);
275 
276     int host_socket = tools::ConnectAdbHostSocket(server->forward_to_);
277     if (host_socket >= 0) {
278       // Set NONBLOCK flag because we use select().
279       fcntl(socket, F_SETFL, fcntl(socket, F_GETFL) | O_NONBLOCK);
280       fcntl(host_socket, F_SETFL, fcntl(host_socket, F_GETFL) | O_NONBLOCK);
281 
282       ForwarderInfo* forwarder_info = server->GetForwarderInfo(forwarder_index);
283       time_t now = time(NULL);
284       forwarder_info->start_time = now;
285       forwarder_info->socket1 = socket;
286       forwarder_info->socket1_last_byte_time = now;
287       forwarder_info->socket1_bytes = 0;
288       forwarder_info->socket2 = host_socket;
289       forwarder_info->socket2_last_byte_time = now;
290       forwarder_info->socket2_bytes = 0;
291 
292       pthread_t thread;
293       pthread_create(&thread, NULL, ForwarderThread,
294                      new ForwarderThreadInfo(server, forwarder_index));
295     } else {
296       // Close the unused client socket which is failed to connect to host.
297       CloseSocket(socket);
298     }
299   }
300 
301   CloseSocket(server->socket_);
302   server->socket_ = -1;
303   return NULL;
304 }
305 
306 // Format of arg: <Device port>[:<Forward to port>:<Forward to address>]
InitSocket(const char * arg)307 bool Server::InitSocket(const char* arg) {
308   char* endptr;
309   int local_port = static_cast<int>(strtol(arg, &endptr, 10));
310   if (local_port < 0)
311     return false;
312 
313   if (*endptr != ':') {
314     snprintf(forward_to_, sizeof(forward_to_), "%d:127.0.0.1", local_port);
315   } else {
316     strncpy(forward_to_, endptr + 1, sizeof(forward_to_) - 1);
317   }
318 
319   socket_ = socket(AF_INET, SOCK_STREAM, 0);
320   if (socket_ < 0) {
321     perror("server socket");
322     return false;
323   }
324   tools::DisableNagle(socket_);
325 
326   sockaddr_in addr;
327   memset(&addr, 0, sizeof(addr));
328   addr.sin_family = AF_INET;
329   addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
330   addr.sin_port = htons(local_port);
331   int reuse_addr = 1;
332   setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR,
333              &reuse_addr, sizeof(reuse_addr));
334   tools::DeferAccept(socket_);
335   if (HANDLE_EINTR(bind(socket_, reinterpret_cast<sockaddr*>(&addr),
336                         sizeof(addr))) < 0 ||
337       HANDLE_EINTR(listen(socket_, 5)) < 0) {
338     perror("server bind");
339     CloseSocket(socket_);
340     socket_ = -1;
341     return false;
342   }
343 
344   if (local_port == 0) {
345     socklen_t addrlen = sizeof(addr);
346     if (getsockname(socket_, reinterpret_cast<sockaddr*>(&addr), &addrlen)
347         != 0) {
348       perror("get listen address");
349       CloseSocket(socket_);
350       socket_ = -1;
351       return false;
352     }
353     local_port = ntohs(addr.sin_port);
354   }
355 
356   printf("Forwarding device port %d to host %s\n", local_port, forward_to_);
357   return true;
358 }
359 
360 int g_server_count = 0;
361 Server* g_servers = NULL;
362 
KillHandler(int unused)363 void KillHandler(int unused) {
364   g_killed = true;
365   for (int i = 0; i < g_server_count; i++)
366     g_servers[i].Shutdown();
367 }
368 
DumpInformation(int unused)369 void DumpInformation(int unused) {
370   for (int i = 0; i < g_server_count; i++)
371     g_servers[i].DumpInformation();
372 }
373 
374 }  // namespace
375 
main(int argc,char ** argv)376 int main(int argc, char** argv) {
377   printf("Android device to host TCP forwarder\n");
378   printf("Like 'adb forward' but in the reverse direction\n");
379 
380   CommandLine command_line(argc, argv);
381   CommandLine::StringVector server_args = command_line.GetArgs();
382   if (tools::HasHelpSwitch(command_line) || server_args.empty()) {
383     tools::ShowHelp(
384         argv[0],
385         "<Device port>[:<Forward to port>:<Forward to address>] ...",
386         "  <Forward to port> default is <Device port>\n"
387         "  <Forward to address> default is 127.0.0.1\n"
388         "If <Device port> is 0, a port will by dynamically allocated.\n");
389     return 0;
390   }
391 
392   g_servers = new Server[server_args.size()];
393   g_server_count = 0;
394   int failed_count = 0;
395   for (size_t i = 0; i < server_args.size(); i++) {
396     if (!g_servers[g_server_count].InitSocket(server_args[i].c_str())) {
397       printf("Couldn't start forwarder server for port spec: %s\n",
398              server_args[i].c_str());
399       ++failed_count;
400     } else {
401       ++g_server_count;
402     }
403   }
404 
405   if (g_server_count == 0) {
406     printf("No forwarder servers could be started. Exiting.\n");
407     delete [] g_servers;
408     return failed_count;
409   }
410 
411   if (!tools::HasNoSpawnDaemonSwitch(command_line))
412     tools::SpawnDaemon(failed_count);
413 
414   signal(SIGTERM, KillHandler);
415   signal(SIGUSR2, DumpInformation);
416 
417   for (int i = 0; i < g_server_count; i++)
418     g_servers[i].StartThread();
419   for (int i = 0; i < g_server_count; i++)
420     g_servers[i].JoinThread();
421   g_server_count = 0;
422   delete [] g_servers;
423 
424   return 0;
425 }
426 
427