• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
#include "tools/android/forwarder2/socket.h"

#include <arpa/inet.h>
#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include "base/logging.h"
#include "base/posix/eintr_wrapper.h"
#include "base/safe_strerror_posix.h"
#include "tools/android/common/net.h"
#include "tools/android/forwarder2/common.h"
22 
namespace {

// Timeout value for Socket::WaitForEvent() meaning "no deadline": only
// positive timeouts are handed to select(), everything else blocks.
const int kNoTimeout = -1;

// How long Socket::Connect() waits for a non-blocking connect() to finish.
const int kConnectTimeOut = 10;  // Seconds.

// Returns true for the IP address families (IPv4 and IPv6), i.e. when the
// socket is a TCP socket rather than a Unix domain socket.
bool FamilyIsTCP(int family) {
  switch (family) {
    case AF_INET:
    case AF_INET6:
      return true;
    default:
      return false;
  }
}

}  // namespace
31 
32 namespace forwarder2 {
33 
BindUnix(const std::string & path)34 bool Socket::BindUnix(const std::string& path) {
35   errno = 0;
36   if (!InitUnixSocket(path) || !BindAndListen()) {
37     Close();
38     return false;
39   }
40   return true;
41 }
42 
BindTcp(const std::string & host,int port)43 bool Socket::BindTcp(const std::string& host, int port) {
44   errno = 0;
45   if (!InitTcpSocket(host, port) || !BindAndListen()) {
46     Close();
47     return false;
48   }
49   return true;
50 }
51 
ConnectUnix(const std::string & path)52 bool Socket::ConnectUnix(const std::string& path) {
53   errno = 0;
54   if (!InitUnixSocket(path) || !Connect()) {
55     Close();
56     return false;
57   }
58   return true;
59 }
60 
ConnectTcp(const std::string & host,int port)61 bool Socket::ConnectTcp(const std::string& host, int port) {
62   errno = 0;
63   if (!InitTcpSocket(host, port) || !Connect()) {
64     Close();
65     return false;
66   }
67   return true;
68 }
69 
Socket()70 Socket::Socket()
71     : socket_(-1),
72       port_(0),
73       socket_error_(false),
74       family_(AF_INET),
75       addr_ptr_(reinterpret_cast<sockaddr*>(&addr_.addr4)),
76       addr_len_(sizeof(sockaddr)) {
77   memset(&addr_, 0, sizeof(addr_));
78 }
79 
~Socket()80 Socket::~Socket() {
81   Close();
82 }
83 
Shutdown()84 void Socket::Shutdown() {
85   if (!IsClosed()) {
86     PRESERVE_ERRNO_HANDLE_EINTR(shutdown(socket_, SHUT_RDWR));
87   }
88 }
89 
Close()90 void Socket::Close() {
91   if (!IsClosed()) {
92     CloseFD(socket_);
93     socket_ = -1;
94   }
95 }
96 
InitSocketInternal()97 bool Socket::InitSocketInternal() {
98   socket_ = socket(family_, SOCK_STREAM, 0);
99   if (socket_ < 0)
100     return false;
101   tools::DisableNagle(socket_);
102   int reuse_addr = 1;
103   setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &reuse_addr,
104              sizeof(reuse_addr));
105   if (!SetNonBlocking())
106     return false;
107   return true;
108 }
109 
SetNonBlocking()110 bool Socket::SetNonBlocking() {
111   const int flags = fcntl(socket_, F_GETFL);
112   if (flags < 0) {
113     PLOG(ERROR) << "fcntl";
114     return false;
115   }
116   if (flags & O_NONBLOCK)
117     return true;
118   if (fcntl(socket_, F_SETFL, flags | O_NONBLOCK) < 0) {
119     PLOG(ERROR) << "fcntl";
120     return false;
121   }
122   return true;
123 }
124 
InitUnixSocket(const std::string & path)125 bool Socket::InitUnixSocket(const std::string& path) {
126   static const size_t kPathMax = sizeof(addr_.addr_un.sun_path);
127   // For abstract sockets we need one extra byte for the leading zero.
128   if (path.size() + 2 /* '\0' */ > kPathMax) {
129     LOG(ERROR) << "The provided path is too big to create a unix "
130                << "domain socket: " << path;
131     return false;
132   }
133   family_ = PF_UNIX;
134   addr_.addr_un.sun_family = family_;
135   // Copied from net/socket/unix_domain_socket_posix.cc
136   // Convert the path given into abstract socket name. It must start with
137   // the '\0' character, so we are adding it. |addr_len| must specify the
138   // length of the structure exactly, as potentially the socket name may
139   // have '\0' characters embedded (although we don't support this).
140   // Note that addr_.addr_un.sun_path is already zero initialized.
141   memcpy(addr_.addr_un.sun_path + 1, path.c_str(), path.size());
142   addr_len_ = path.size() + offsetof(struct sockaddr_un, sun_path) + 1;
143   addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr_un);
144   return InitSocketInternal();
145 }
146 
InitTcpSocket(const std::string & host,int port)147 bool Socket::InitTcpSocket(const std::string& host, int port) {
148   port_ = port;
149   if (host.empty()) {
150     // Use localhost: INADDR_LOOPBACK
151     family_ = AF_INET;
152     addr_.addr4.sin_family = family_;
153     addr_.addr4.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
154   } else if (!Resolve(host)) {
155     return false;
156   }
157   CHECK(FamilyIsTCP(family_)) << "Invalid socket family.";
158   if (family_ == AF_INET) {
159     addr_.addr4.sin_port = htons(port_);
160     addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr4);
161     addr_len_ = sizeof(addr_.addr4);
162   } else if (family_ == AF_INET6) {
163     addr_.addr6.sin6_port = htons(port_);
164     addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr6);
165     addr_len_ = sizeof(addr_.addr6);
166   }
167   return InitSocketInternal();
168 }
169 
BindAndListen()170 bool Socket::BindAndListen() {
171   errno = 0;
172   if (HANDLE_EINTR(bind(socket_, addr_ptr_, addr_len_)) < 0 ||
173       HANDLE_EINTR(listen(socket_, SOMAXCONN)) < 0) {
174     SetSocketError();
175     return false;
176   }
177   if (port_ == 0 && FamilyIsTCP(family_)) {
178     SockAddr addr;
179     memset(&addr, 0, sizeof(addr));
180     socklen_t addrlen = 0;
181     sockaddr* addr_ptr = NULL;
182     uint16* port_ptr = NULL;
183     if (family_ == AF_INET) {
184       addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr4);
185       port_ptr = &addr.addr4.sin_port;
186       addrlen = sizeof(addr.addr4);
187     } else if (family_ == AF_INET6) {
188       addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr6);
189       port_ptr = &addr.addr6.sin6_port;
190       addrlen = sizeof(addr.addr6);
191     }
192     errno = 0;
193     if (getsockname(socket_, addr_ptr, &addrlen) != 0) {
194       PLOG(ERROR) << "getsockname";
195       SetSocketError();
196       return false;
197     }
198     port_ = ntohs(*port_ptr);
199   }
200   return true;
201 }
202 
Accept(Socket * new_socket)203 bool Socket::Accept(Socket* new_socket) {
204   DCHECK(new_socket != NULL);
205   if (!WaitForEvent(READ, kNoTimeout)) {
206     SetSocketError();
207     return false;
208   }
209   errno = 0;
210   int new_socket_fd = HANDLE_EINTR(accept(socket_, NULL, NULL));
211   if (new_socket_fd < 0) {
212     SetSocketError();
213     return false;
214   }
215   tools::DisableNagle(new_socket_fd);
216   new_socket->socket_ = new_socket_fd;
217   if (!new_socket->SetNonBlocking())
218     return false;
219   return true;
220 }
221 
Connect()222 bool Socket::Connect() {
223   DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
224   errno = 0;
225   if (HANDLE_EINTR(connect(socket_, addr_ptr_, addr_len_)) < 0 &&
226       errno != EINPROGRESS) {
227     SetSocketError();
228     return false;
229   }
230   // Wait for connection to complete, or receive a notification.
231   if (!WaitForEvent(WRITE, kConnectTimeOut)) {
232     SetSocketError();
233     return false;
234   }
235   int socket_errno;
236   socklen_t opt_len = sizeof(socket_errno);
237   if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, &socket_errno, &opt_len) < 0) {
238     PLOG(ERROR) << "getsockopt()";
239     SetSocketError();
240     return false;
241   }
242   if (socket_errno != 0) {
243     LOG(ERROR) << "Could not connect to host: " << safe_strerror(socket_errno);
244     SetSocketError();
245     return false;
246   }
247   return true;
248 }
249 
Resolve(const std::string & host)250 bool Socket::Resolve(const std::string& host) {
251   struct addrinfo hints;
252   struct addrinfo* res;
253   memset(&hints, 0, sizeof(hints));
254   hints.ai_family = AF_UNSPEC;
255   hints.ai_socktype = SOCK_STREAM;
256   hints.ai_flags |= AI_CANONNAME;
257 
258   int errcode = getaddrinfo(host.c_str(), NULL, &hints, &res);
259   if (errcode != 0) {
260     errno = 0;
261     SetSocketError();
262     freeaddrinfo(res);
263     return false;
264   }
265   family_ = res->ai_family;
266   switch (res->ai_family) {
267     case AF_INET:
268       memcpy(&addr_.addr4,
269              reinterpret_cast<sockaddr_in*>(res->ai_addr),
270              sizeof(sockaddr_in));
271       break;
272     case AF_INET6:
273       memcpy(&addr_.addr6,
274              reinterpret_cast<sockaddr_in6*>(res->ai_addr),
275              sizeof(sockaddr_in6));
276       break;
277   }
278   freeaddrinfo(res);
279   return true;
280 }
281 
GetPort()282 int Socket::GetPort() {
283   if (!FamilyIsTCP(family_)) {
284     LOG(ERROR) << "Can't call GetPort() on an unix domain socket.";
285     return 0;
286   }
287   return port_;
288 }
289 
IsFdInSet(const fd_set & fds) const290 bool Socket::IsFdInSet(const fd_set& fds) const {
291   if (IsClosed())
292     return false;
293   return FD_ISSET(socket_, &fds);
294 }
295 
AddFdToSet(fd_set * fds) const296 bool Socket::AddFdToSet(fd_set* fds) const {
297   if (IsClosed())
298     return false;
299   FD_SET(socket_, fds);
300   return true;
301 }
302 
ReadNumBytes(void * buffer,size_t num_bytes)303 int Socket::ReadNumBytes(void* buffer, size_t num_bytes) {
304   int bytes_read = 0;
305   int ret = 1;
306   while (bytes_read < num_bytes && ret > 0) {
307     ret = Read(static_cast<char*>(buffer) + bytes_read, num_bytes - bytes_read);
308     if (ret >= 0)
309       bytes_read += ret;
310   }
311   return bytes_read;
312 }
313 
SetSocketError()314 void Socket::SetSocketError() {
315   socket_error_ = true;
316   DCHECK_NE(EAGAIN, errno);
317   DCHECK_NE(EWOULDBLOCK, errno);
318   Close();
319 }
320 
Read(void * buffer,size_t buffer_size)321 int Socket::Read(void* buffer, size_t buffer_size) {
322   if (!WaitForEvent(READ, kNoTimeout)) {
323     SetSocketError();
324     return 0;
325   }
326   int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size));
327   if (ret < 0) {
328     PLOG(ERROR) << "read";
329     SetSocketError();
330   }
331   return ret;
332 }
333 
NonBlockingRead(void * buffer,size_t buffer_size)334 int Socket::NonBlockingRead(void* buffer, size_t buffer_size) {
335   DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
336   int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size));
337   if (ret < 0) {
338     PLOG(ERROR) << "read";
339     SetSocketError();
340   }
341   return ret;
342 }
343 
Write(const void * buffer,size_t count)344 int Socket::Write(const void* buffer, size_t count) {
345   if (!WaitForEvent(WRITE, kNoTimeout)) {
346     SetSocketError();
347     return 0;
348   }
349   int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL));
350   if (ret < 0) {
351     PLOG(ERROR) << "send";
352     SetSocketError();
353   }
354   return ret;
355 }
356 
NonBlockingWrite(const void * buffer,size_t count)357 int Socket::NonBlockingWrite(const void* buffer, size_t count) {
358   DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
359   int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL));
360   if (ret < 0) {
361     PLOG(ERROR) << "send";
362     SetSocketError();
363   }
364   return ret;
365 }
366 
WriteString(const std::string & buffer)367 int Socket::WriteString(const std::string& buffer) {
368   return WriteNumBytes(buffer.c_str(), buffer.size());
369 }
370 
AddEventFd(int event_fd)371 void Socket::AddEventFd(int event_fd) {
372   Event event;
373   event.fd = event_fd;
374   event.was_fired = false;
375   events_.push_back(event);
376 }
377 
DidReceiveEventOnFd(int fd) const378 bool Socket::DidReceiveEventOnFd(int fd) const {
379   for (size_t i = 0; i < events_.size(); ++i)
380     if (events_[i].fd == fd)
381       return events_[i].was_fired;
382   return false;
383 }
384 
DidReceiveEvent() const385 bool Socket::DidReceiveEvent() const {
386   for (size_t i = 0; i < events_.size(); ++i)
387     if (events_[i].was_fired)
388       return true;
389   return false;
390 }
391 
WriteNumBytes(const void * buffer,size_t num_bytes)392 int Socket::WriteNumBytes(const void* buffer, size_t num_bytes) {
393   int bytes_written = 0;
394   int ret = 1;
395   while (bytes_written < num_bytes && ret > 0) {
396     ret = Write(static_cast<const char*>(buffer) + bytes_written,
397                 num_bytes - bytes_written);
398     if (ret >= 0)
399       bytes_written += ret;
400   }
401   return bytes_written;
402 }
403 
WaitForEvent(EventType type,int timeout_secs)404 bool Socket::WaitForEvent(EventType type, int timeout_secs) {
405   if (socket_ == -1)
406     return true;
407   DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
408   fd_set read_fds;
409   fd_set write_fds;
410   FD_ZERO(&read_fds);
411   FD_ZERO(&write_fds);
412   if (type == READ)
413     FD_SET(socket_, &read_fds);
414   else
415     FD_SET(socket_, &write_fds);
416   for (size_t i = 0; i < events_.size(); ++i)
417     FD_SET(events_[i].fd, &read_fds);
418   timeval tv = {};
419   timeval* tv_ptr = NULL;
420   if (timeout_secs > 0) {
421     tv.tv_sec = timeout_secs;
422     tv.tv_usec = 0;
423     tv_ptr = &tv;
424   }
425   int max_fd = socket_;
426   for (size_t i = 0; i < events_.size(); ++i)
427     if (events_[i].fd > max_fd)
428       max_fd = events_[i].fd;
429   if (HANDLE_EINTR(
430           select(max_fd + 1, &read_fds, &write_fds, NULL, tv_ptr)) <= 0) {
431     PLOG(ERROR) << "select";
432     return false;
433   }
434   bool event_was_fired = false;
435   for (size_t i = 0; i < events_.size(); ++i) {
436     if (FD_ISSET(events_[i].fd, &read_fds)) {
437       events_[i].was_fired = true;
438       event_was_fired = true;
439     }
440   }
441   return !event_was_fired;
442 }
443 
444 // static
GetHighestFileDescriptor(const Socket & s1,const Socket & s2)445 int Socket::GetHighestFileDescriptor(const Socket& s1, const Socket& s2) {
446   return std::max(s1.socket_, s2.socket_);
447 }
448 
449 // static
GetUnixDomainSocketProcessOwner(const std::string & path)450 pid_t Socket::GetUnixDomainSocketProcessOwner(const std::string& path) {
451   Socket socket;
452   if (!socket.ConnectUnix(path))
453     return -1;
454   ucred ucred;
455   socklen_t len = sizeof(ucred);
456   if (getsockopt(socket.socket_, SOL_SOCKET, SO_PEERCRED, &ucred, &len) == -1) {
457     CHECK_NE(ENOPROTOOPT, errno);
458     return -1;
459   }
460   return ucred.pid;
461 }
462 
463 }  // namespace forwarder2
464