• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "tools/android/forwarder2/socket.h"
6 
7 #include <arpa/inet.h>
8 #include <fcntl.h>
9 #include <netdb.h>
10 #include <netinet/in.h>
11 #include <stdio.h>
12 #include <string.h>
13 #include <sys/socket.h>
14 #include <sys/types.h>
15 #include <unistd.h>
16 
17 #include "base/logging.h"
18 #include "base/posix/eintr_wrapper.h"
19 #include "base/safe_strerror_posix.h"
20 #include "tools/android/common/net.h"
21 #include "tools/android/forwarder2/common.h"
22 
namespace {

// Passed to WaitForEvent() to request an unbounded wait (any timeout <= 0
// makes WaitForEvent() hand select() a NULL timeval).
const int kNoTimeout = -1;
// How long Connect() waits for a pending connection to become writable.
const int kConnectTimeOut = 10;  // Seconds.

// Returns true for the IPv4/IPv6 address families, i.e. the families this
// file treats as TCP (everything else is handled as a unix domain socket).
bool FamilyIsTCP(int family) {
  switch (family) {
    case AF_INET:
    case AF_INET6:
      return true;
    default:
      return false;
  }
}

}  // namespace
31 
32 namespace forwarder2 {
33 
BindUnix(const std::string & path)34 bool Socket::BindUnix(const std::string& path) {
35   errno = 0;
36   if (!InitUnixSocket(path) || !BindAndListen()) {
37     Close();
38     return false;
39   }
40   return true;
41 }
42 
BindTcp(const std::string & host,int port)43 bool Socket::BindTcp(const std::string& host, int port) {
44   errno = 0;
45   if (!InitTcpSocket(host, port) || !BindAndListen()) {
46     Close();
47     return false;
48   }
49   return true;
50 }
51 
ConnectUnix(const std::string & path)52 bool Socket::ConnectUnix(const std::string& path) {
53   errno = 0;
54   if (!InitUnixSocket(path) || !Connect()) {
55     Close();
56     return false;
57   }
58   return true;
59 }
60 
ConnectTcp(const std::string & host,int port)61 bool Socket::ConnectTcp(const std::string& host, int port) {
62   errno = 0;
63   if (!InitTcpSocket(host, port) || !Connect()) {
64     Close();
65     return false;
66   }
67   return true;
68 }
69 
Socket()70 Socket::Socket()
71     : socket_(-1),
72       port_(0),
73       socket_error_(false),
74       family_(AF_INET),
75       addr_ptr_(reinterpret_cast<sockaddr*>(&addr_.addr4)),
76       addr_len_(sizeof(sockaddr)) {
77   memset(&addr_, 0, sizeof(addr_));
78 }
79 
~Socket()80 Socket::~Socket() {
81   Close();
82 }
83 
Shutdown()84 void Socket::Shutdown() {
85   if (!IsClosed()) {
86     PRESERVE_ERRNO_HANDLE_EINTR(shutdown(socket_, SHUT_RDWR));
87   }
88 }
89 
Close()90 void Socket::Close() {
91   if (!IsClosed()) {
92     CloseFD(socket_);
93     socket_ = -1;
94   }
95 }
96 
InitSocketInternal()97 bool Socket::InitSocketInternal() {
98   socket_ = socket(family_, SOCK_STREAM, 0);
99   if (socket_ < 0) {
100     PLOG(ERROR) << "socket";
101     return false;
102   }
103   tools::DisableNagle(socket_);
104   int reuse_addr = 1;
105   setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &reuse_addr,
106              sizeof(reuse_addr));
107   if (!SetNonBlocking())
108     return false;
109   return true;
110 }
111 
SetNonBlocking()112 bool Socket::SetNonBlocking() {
113   const int flags = fcntl(socket_, F_GETFL);
114   if (flags < 0) {
115     PLOG(ERROR) << "fcntl";
116     return false;
117   }
118   if (flags & O_NONBLOCK)
119     return true;
120   if (fcntl(socket_, F_SETFL, flags | O_NONBLOCK) < 0) {
121     PLOG(ERROR) << "fcntl";
122     return false;
123   }
124   return true;
125 }
126 
InitUnixSocket(const std::string & path)127 bool Socket::InitUnixSocket(const std::string& path) {
128   static const size_t kPathMax = sizeof(addr_.addr_un.sun_path);
129   // For abstract sockets we need one extra byte for the leading zero.
130   if (path.size() + 2 /* '\0' */ > kPathMax) {
131     LOG(ERROR) << "The provided path is too big to create a unix "
132                << "domain socket: " << path;
133     return false;
134   }
135   family_ = PF_UNIX;
136   addr_.addr_un.sun_family = family_;
137   // Copied from net/socket/unix_domain_socket_posix.cc
138   // Convert the path given into abstract socket name. It must start with
139   // the '\0' character, so we are adding it. |addr_len| must specify the
140   // length of the structure exactly, as potentially the socket name may
141   // have '\0' characters embedded (although we don't support this).
142   // Note that addr_.addr_un.sun_path is already zero initialized.
143   memcpy(addr_.addr_un.sun_path + 1, path.c_str(), path.size());
144   addr_len_ = path.size() + offsetof(struct sockaddr_un, sun_path) + 1;
145   addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr_un);
146   return InitSocketInternal();
147 }
148 
InitTcpSocket(const std::string & host,int port)149 bool Socket::InitTcpSocket(const std::string& host, int port) {
150   port_ = port;
151   if (host.empty()) {
152     // Use localhost: INADDR_LOOPBACK
153     family_ = AF_INET;
154     addr_.addr4.sin_family = family_;
155     addr_.addr4.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
156   } else if (!Resolve(host)) {
157     return false;
158   }
159   CHECK(FamilyIsTCP(family_)) << "Invalid socket family.";
160   if (family_ == AF_INET) {
161     addr_.addr4.sin_port = htons(port_);
162     addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr4);
163     addr_len_ = sizeof(addr_.addr4);
164   } else if (family_ == AF_INET6) {
165     addr_.addr6.sin6_port = htons(port_);
166     addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr6);
167     addr_len_ = sizeof(addr_.addr6);
168   }
169   return InitSocketInternal();
170 }
171 
BindAndListen()172 bool Socket::BindAndListen() {
173   errno = 0;
174   if (HANDLE_EINTR(bind(socket_, addr_ptr_, addr_len_)) < 0 ||
175       HANDLE_EINTR(listen(socket_, SOMAXCONN)) < 0) {
176     PLOG(ERROR) << "bind/listen";
177     SetSocketError();
178     return false;
179   }
180   if (port_ == 0 && FamilyIsTCP(family_)) {
181     SockAddr addr;
182     memset(&addr, 0, sizeof(addr));
183     socklen_t addrlen = 0;
184     sockaddr* addr_ptr = NULL;
185     uint16* port_ptr = NULL;
186     if (family_ == AF_INET) {
187       addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr4);
188       port_ptr = &addr.addr4.sin_port;
189       addrlen = sizeof(addr.addr4);
190     } else if (family_ == AF_INET6) {
191       addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr6);
192       port_ptr = &addr.addr6.sin6_port;
193       addrlen = sizeof(addr.addr6);
194     }
195     errno = 0;
196     if (getsockname(socket_, addr_ptr, &addrlen) != 0) {
197       PLOG(ERROR) << "getsockname";
198       SetSocketError();
199       return false;
200     }
201     port_ = ntohs(*port_ptr);
202   }
203   return true;
204 }
205 
Accept(Socket * new_socket)206 bool Socket::Accept(Socket* new_socket) {
207   DCHECK(new_socket != NULL);
208   if (!WaitForEvent(READ, kNoTimeout)) {
209     SetSocketError();
210     return false;
211   }
212   errno = 0;
213   int new_socket_fd = HANDLE_EINTR(accept(socket_, NULL, NULL));
214   if (new_socket_fd < 0) {
215     SetSocketError();
216     return false;
217   }
218   tools::DisableNagle(new_socket_fd);
219   new_socket->socket_ = new_socket_fd;
220   if (!new_socket->SetNonBlocking())
221     return false;
222   return true;
223 }
224 
Connect()225 bool Socket::Connect() {
226   DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
227   errno = 0;
228   if (HANDLE_EINTR(connect(socket_, addr_ptr_, addr_len_)) < 0 &&
229       errno != EINPROGRESS) {
230     SetSocketError();
231     return false;
232   }
233   // Wait for connection to complete, or receive a notification.
234   if (!WaitForEvent(WRITE, kConnectTimeOut)) {
235     SetSocketError();
236     return false;
237   }
238   int socket_errno;
239   socklen_t opt_len = sizeof(socket_errno);
240   if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, &socket_errno, &opt_len) < 0) {
241     PLOG(ERROR) << "getsockopt()";
242     SetSocketError();
243     return false;
244   }
245   if (socket_errno != 0) {
246     LOG(ERROR) << "Could not connect to host: " << safe_strerror(socket_errno);
247     SetSocketError();
248     return false;
249   }
250   return true;
251 }
252 
// Resolves |host| with getaddrinfo() (any family, stream sockets) and copies
// the first result into |addr_| (IPv4 or IPv6 member), setting |family_| to
// that result's family. On failure the error is logged, the socket is put
// into the error state and false is returned.
bool Socket::Resolve(const std::string& host) {
  struct addrinfo hints;
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_flags |= AI_CANONNAME;

  struct addrinfo* res = NULL;
  const int errcode = getaddrinfo(host.c_str(), NULL, &hints, &res);
  if (errcode != 0) {
    // getaddrinfo() does not hand back a list on failure, so there is nothing
    // to free. (Previously an uninitialized |res| was passed to
    // freeaddrinfo() here, which is undefined behavior.)
    LOG(ERROR) << "getaddrinfo(" << host << "): " << gai_strerror(errcode);
    // getaddrinfo() reports errors through its return value, not errno.
    errno = 0;
    SetSocketError();
    return false;
  }
  family_ = res->ai_family;
  switch (res->ai_family) {
    case AF_INET:
      memcpy(&addr_.addr4,
             reinterpret_cast<sockaddr_in*>(res->ai_addr),
             sizeof(sockaddr_in));
      break;
    case AF_INET6:
      memcpy(&addr_.addr6,
             reinterpret_cast<sockaddr_in6*>(res->ai_addr),
             sizeof(sockaddr_in6));
      break;
  }
  freeaddrinfo(res);
  return true;
}
284 
GetPort()285 int Socket::GetPort() {
286   if (!FamilyIsTCP(family_)) {
287     LOG(ERROR) << "Can't call GetPort() on an unix domain socket.";
288     return 0;
289   }
290   return port_;
291 }
292 
ReadNumBytes(void * buffer,size_t num_bytes)293 int Socket::ReadNumBytes(void* buffer, size_t num_bytes) {
294   int bytes_read = 0;
295   int ret = 1;
296   while (bytes_read < num_bytes && ret > 0) {
297     ret = Read(static_cast<char*>(buffer) + bytes_read, num_bytes - bytes_read);
298     if (ret >= 0)
299       bytes_read += ret;
300   }
301   return bytes_read;
302 }
303 
SetSocketError()304 void Socket::SetSocketError() {
305   socket_error_ = true;
306   DCHECK_NE(EAGAIN, errno);
307   DCHECK_NE(EWOULDBLOCK, errno);
308   Close();
309 }
310 
Read(void * buffer,size_t buffer_size)311 int Socket::Read(void* buffer, size_t buffer_size) {
312   if (!WaitForEvent(READ, kNoTimeout)) {
313     SetSocketError();
314     return 0;
315   }
316   int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size));
317   if (ret < 0) {
318     PLOG(ERROR) << "read";
319     SetSocketError();
320   }
321   return ret;
322 }
323 
NonBlockingRead(void * buffer,size_t buffer_size)324 int Socket::NonBlockingRead(void* buffer, size_t buffer_size) {
325   DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
326   int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size));
327   if (ret < 0) {
328     PLOG(ERROR) << "read";
329     SetSocketError();
330   }
331   return ret;
332 }
333 
Write(const void * buffer,size_t count)334 int Socket::Write(const void* buffer, size_t count) {
335   if (!WaitForEvent(WRITE, kNoTimeout)) {
336     SetSocketError();
337     return 0;
338   }
339   int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL));
340   if (ret < 0) {
341     PLOG(ERROR) << "send";
342     SetSocketError();
343   }
344   return ret;
345 }
346 
NonBlockingWrite(const void * buffer,size_t count)347 int Socket::NonBlockingWrite(const void* buffer, size_t count) {
348   DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
349   int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL));
350   if (ret < 0) {
351     PLOG(ERROR) << "send";
352     SetSocketError();
353   }
354   return ret;
355 }
356 
WriteString(const std::string & buffer)357 int Socket::WriteString(const std::string& buffer) {
358   return WriteNumBytes(buffer.c_str(), buffer.size());
359 }
360 
AddEventFd(int event_fd)361 void Socket::AddEventFd(int event_fd) {
362   Event event;
363   event.fd = event_fd;
364   event.was_fired = false;
365   events_.push_back(event);
366 }
367 
DidReceiveEventOnFd(int fd) const368 bool Socket::DidReceiveEventOnFd(int fd) const {
369   for (size_t i = 0; i < events_.size(); ++i)
370     if (events_[i].fd == fd)
371       return events_[i].was_fired;
372   return false;
373 }
374 
DidReceiveEvent() const375 bool Socket::DidReceiveEvent() const {
376   for (size_t i = 0; i < events_.size(); ++i)
377     if (events_[i].was_fired)
378       return true;
379   return false;
380 }
381 
WriteNumBytes(const void * buffer,size_t num_bytes)382 int Socket::WriteNumBytes(const void* buffer, size_t num_bytes) {
383   int bytes_written = 0;
384   int ret = 1;
385   while (bytes_written < num_bytes && ret > 0) {
386     ret = Write(static_cast<const char*>(buffer) + bytes_written,
387                 num_bytes - bytes_written);
388     if (ret >= 0)
389       bytes_written += ret;
390   }
391   return bytes_written;
392 }
393 
WaitForEvent(EventType type,int timeout_secs)394 bool Socket::WaitForEvent(EventType type, int timeout_secs) {
395   if (socket_ == -1)
396     return true;
397   DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
398   fd_set read_fds;
399   fd_set write_fds;
400   FD_ZERO(&read_fds);
401   FD_ZERO(&write_fds);
402   if (type == READ)
403     FD_SET(socket_, &read_fds);
404   else
405     FD_SET(socket_, &write_fds);
406   for (size_t i = 0; i < events_.size(); ++i)
407     FD_SET(events_[i].fd, &read_fds);
408   timeval tv = {};
409   timeval* tv_ptr = NULL;
410   if (timeout_secs > 0) {
411     tv.tv_sec = timeout_secs;
412     tv.tv_usec = 0;
413     tv_ptr = &tv;
414   }
415   int max_fd = socket_;
416   for (size_t i = 0; i < events_.size(); ++i)
417     if (events_[i].fd > max_fd)
418       max_fd = events_[i].fd;
419   if (HANDLE_EINTR(
420           select(max_fd + 1, &read_fds, &write_fds, NULL, tv_ptr)) <= 0) {
421     PLOG(ERROR) << "select";
422     return false;
423   }
424   bool event_was_fired = false;
425   for (size_t i = 0; i < events_.size(); ++i) {
426     if (FD_ISSET(events_[i].fd, &read_fds)) {
427       events_[i].was_fired = true;
428       event_was_fired = true;
429     }
430   }
431   return !event_was_fired;
432 }
433 
434 // static
GetUnixDomainSocketProcessOwner(const std::string & path)435 pid_t Socket::GetUnixDomainSocketProcessOwner(const std::string& path) {
436   Socket socket;
437   if (!socket.ConnectUnix(path))
438     return -1;
439   ucred ucred;
440   socklen_t len = sizeof(ucred);
441   if (getsockopt(socket.socket_, SOL_SOCKET, SO_PEERCRED, &ucred, &len) == -1) {
442     CHECK_NE(ENOPROTOOPT, errno);
443     return -1;
444   }
445   return ucred.pid;
446 }
447 
448 }  // namespace forwarder2
449