1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "tools/android/forwarder2/socket.h"
6
7 #include <arpa/inet.h>
8 #include <fcntl.h>
9 #include <netdb.h>
10 #include <netinet/in.h>
11 #include <stdio.h>
12 #include <string.h>
13 #include <sys/socket.h>
14 #include <sys/types.h>
15 #include <unistd.h>
16
17 #include "base/logging.h"
18 #include "base/posix/eintr_wrapper.h"
19 #include "base/safe_strerror_posix.h"
20 #include "tools/android/common/net.h"
21 #include "tools/android/forwarder2/common.h"
22
namespace {

// Passed to Socket::WaitForEvent() when select() should be given no deadline
// (WaitForEvent only builds a timeval for strictly positive timeouts).
const int kNoTimeout = -1;

// How long, in seconds, Socket::Connect() waits for a non-blocking connect()
// to become writable.
const int kConnectTimeOut = 10;

// Returns true for the address families this file treats as TCP (IPv4 and
// IPv6) and false for everything else, e.g. PF_UNIX.
bool FamilyIsTCP(int family) {
  switch (family) {
    case AF_INET:
    case AF_INET6:
      return true;
    default:
      return false;
  }
}

}  // namespace
31
32 namespace forwarder2 {
33
BindUnix(const std::string & path)34 bool Socket::BindUnix(const std::string& path) {
35 errno = 0;
36 if (!InitUnixSocket(path) || !BindAndListen()) {
37 Close();
38 return false;
39 }
40 return true;
41 }
42
BindTcp(const std::string & host,int port)43 bool Socket::BindTcp(const std::string& host, int port) {
44 errno = 0;
45 if (!InitTcpSocket(host, port) || !BindAndListen()) {
46 Close();
47 return false;
48 }
49 return true;
50 }
51
ConnectUnix(const std::string & path)52 bool Socket::ConnectUnix(const std::string& path) {
53 errno = 0;
54 if (!InitUnixSocket(path) || !Connect()) {
55 Close();
56 return false;
57 }
58 return true;
59 }
60
ConnectTcp(const std::string & host,int port)61 bool Socket::ConnectTcp(const std::string& host, int port) {
62 errno = 0;
63 if (!InitTcpSocket(host, port) || !Connect()) {
64 Close();
65 return false;
66 }
67 return true;
68 }
69
Socket()70 Socket::Socket()
71 : socket_(-1),
72 port_(0),
73 socket_error_(false),
74 family_(AF_INET),
75 addr_ptr_(reinterpret_cast<sockaddr*>(&addr_.addr4)),
76 addr_len_(sizeof(sockaddr)) {
77 memset(&addr_, 0, sizeof(addr_));
78 }
79
~Socket()80 Socket::~Socket() {
81 Close();
82 }
83
Shutdown()84 void Socket::Shutdown() {
85 if (!IsClosed()) {
86 PRESERVE_ERRNO_HANDLE_EINTR(shutdown(socket_, SHUT_RDWR));
87 }
88 }
89
Close()90 void Socket::Close() {
91 if (!IsClosed()) {
92 CloseFD(socket_);
93 socket_ = -1;
94 }
95 }
96
InitSocketInternal()97 bool Socket::InitSocketInternal() {
98 socket_ = socket(family_, SOCK_STREAM, 0);
99 if (socket_ < 0)
100 return false;
101 tools::DisableNagle(socket_);
102 int reuse_addr = 1;
103 setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &reuse_addr,
104 sizeof(reuse_addr));
105 if (!SetNonBlocking())
106 return false;
107 return true;
108 }
109
SetNonBlocking()110 bool Socket::SetNonBlocking() {
111 const int flags = fcntl(socket_, F_GETFL);
112 if (flags < 0) {
113 PLOG(ERROR) << "fcntl";
114 return false;
115 }
116 if (flags & O_NONBLOCK)
117 return true;
118 if (fcntl(socket_, F_SETFL, flags | O_NONBLOCK) < 0) {
119 PLOG(ERROR) << "fcntl";
120 return false;
121 }
122 return true;
123 }
124
InitUnixSocket(const std::string & path)125 bool Socket::InitUnixSocket(const std::string& path) {
126 static const size_t kPathMax = sizeof(addr_.addr_un.sun_path);
127 // For abstract sockets we need one extra byte for the leading zero.
128 if (path.size() + 2 /* '\0' */ > kPathMax) {
129 LOG(ERROR) << "The provided path is too big to create a unix "
130 << "domain socket: " << path;
131 return false;
132 }
133 family_ = PF_UNIX;
134 addr_.addr_un.sun_family = family_;
135 // Copied from net/socket/unix_domain_socket_posix.cc
136 // Convert the path given into abstract socket name. It must start with
137 // the '\0' character, so we are adding it. |addr_len| must specify the
138 // length of the structure exactly, as potentially the socket name may
139 // have '\0' characters embedded (although we don't support this).
140 // Note that addr_.addr_un.sun_path is already zero initialized.
141 memcpy(addr_.addr_un.sun_path + 1, path.c_str(), path.size());
142 addr_len_ = path.size() + offsetof(struct sockaddr_un, sun_path) + 1;
143 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr_un);
144 return InitSocketInternal();
145 }
146
InitTcpSocket(const std::string & host,int port)147 bool Socket::InitTcpSocket(const std::string& host, int port) {
148 port_ = port;
149 if (host.empty()) {
150 // Use localhost: INADDR_LOOPBACK
151 family_ = AF_INET;
152 addr_.addr4.sin_family = family_;
153 addr_.addr4.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
154 } else if (!Resolve(host)) {
155 return false;
156 }
157 CHECK(FamilyIsTCP(family_)) << "Invalid socket family.";
158 if (family_ == AF_INET) {
159 addr_.addr4.sin_port = htons(port_);
160 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr4);
161 addr_len_ = sizeof(addr_.addr4);
162 } else if (family_ == AF_INET6) {
163 addr_.addr6.sin6_port = htons(port_);
164 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr6);
165 addr_len_ = sizeof(addr_.addr6);
166 }
167 return InitSocketInternal();
168 }
169
BindAndListen()170 bool Socket::BindAndListen() {
171 errno = 0;
172 if (HANDLE_EINTR(bind(socket_, addr_ptr_, addr_len_)) < 0 ||
173 HANDLE_EINTR(listen(socket_, SOMAXCONN)) < 0) {
174 SetSocketError();
175 return false;
176 }
177 if (port_ == 0 && FamilyIsTCP(family_)) {
178 SockAddr addr;
179 memset(&addr, 0, sizeof(addr));
180 socklen_t addrlen = 0;
181 sockaddr* addr_ptr = NULL;
182 uint16* port_ptr = NULL;
183 if (family_ == AF_INET) {
184 addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr4);
185 port_ptr = &addr.addr4.sin_port;
186 addrlen = sizeof(addr.addr4);
187 } else if (family_ == AF_INET6) {
188 addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr6);
189 port_ptr = &addr.addr6.sin6_port;
190 addrlen = sizeof(addr.addr6);
191 }
192 errno = 0;
193 if (getsockname(socket_, addr_ptr, &addrlen) != 0) {
194 PLOG(ERROR) << "getsockname";
195 SetSocketError();
196 return false;
197 }
198 port_ = ntohs(*port_ptr);
199 }
200 return true;
201 }
202
Accept(Socket * new_socket)203 bool Socket::Accept(Socket* new_socket) {
204 DCHECK(new_socket != NULL);
205 if (!WaitForEvent(READ, kNoTimeout)) {
206 SetSocketError();
207 return false;
208 }
209 errno = 0;
210 int new_socket_fd = HANDLE_EINTR(accept(socket_, NULL, NULL));
211 if (new_socket_fd < 0) {
212 SetSocketError();
213 return false;
214 }
215 tools::DisableNagle(new_socket_fd);
216 new_socket->socket_ = new_socket_fd;
217 if (!new_socket->SetNonBlocking())
218 return false;
219 return true;
220 }
221
Connect()222 bool Socket::Connect() {
223 DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
224 errno = 0;
225 if (HANDLE_EINTR(connect(socket_, addr_ptr_, addr_len_)) < 0 &&
226 errno != EINPROGRESS) {
227 SetSocketError();
228 return false;
229 }
230 // Wait for connection to complete, or receive a notification.
231 if (!WaitForEvent(WRITE, kConnectTimeOut)) {
232 SetSocketError();
233 return false;
234 }
235 int socket_errno;
236 socklen_t opt_len = sizeof(socket_errno);
237 if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, &socket_errno, &opt_len) < 0) {
238 PLOG(ERROR) << "getsockopt()";
239 SetSocketError();
240 return false;
241 }
242 if (socket_errno != 0) {
243 LOG(ERROR) << "Could not connect to host: " << safe_strerror(socket_errno);
244 SetSocketError();
245 return false;
246 }
247 return true;
248 }
249
Resolve(const std::string & host)250 bool Socket::Resolve(const std::string& host) {
251 struct addrinfo hints;
252 struct addrinfo* res;
253 memset(&hints, 0, sizeof(hints));
254 hints.ai_family = AF_UNSPEC;
255 hints.ai_socktype = SOCK_STREAM;
256 hints.ai_flags |= AI_CANONNAME;
257
258 int errcode = getaddrinfo(host.c_str(), NULL, &hints, &res);
259 if (errcode != 0) {
260 errno = 0;
261 SetSocketError();
262 freeaddrinfo(res);
263 return false;
264 }
265 family_ = res->ai_family;
266 switch (res->ai_family) {
267 case AF_INET:
268 memcpy(&addr_.addr4,
269 reinterpret_cast<sockaddr_in*>(res->ai_addr),
270 sizeof(sockaddr_in));
271 break;
272 case AF_INET6:
273 memcpy(&addr_.addr6,
274 reinterpret_cast<sockaddr_in6*>(res->ai_addr),
275 sizeof(sockaddr_in6));
276 break;
277 }
278 freeaddrinfo(res);
279 return true;
280 }
281
GetPort()282 int Socket::GetPort() {
283 if (!FamilyIsTCP(family_)) {
284 LOG(ERROR) << "Can't call GetPort() on an unix domain socket.";
285 return 0;
286 }
287 return port_;
288 }
289
IsFdInSet(const fd_set & fds) const290 bool Socket::IsFdInSet(const fd_set& fds) const {
291 if (IsClosed())
292 return false;
293 return FD_ISSET(socket_, &fds);
294 }
295
AddFdToSet(fd_set * fds) const296 bool Socket::AddFdToSet(fd_set* fds) const {
297 if (IsClosed())
298 return false;
299 FD_SET(socket_, fds);
300 return true;
301 }
302
ReadNumBytes(void * buffer,size_t num_bytes)303 int Socket::ReadNumBytes(void* buffer, size_t num_bytes) {
304 int bytes_read = 0;
305 int ret = 1;
306 while (bytes_read < num_bytes && ret > 0) {
307 ret = Read(static_cast<char*>(buffer) + bytes_read, num_bytes - bytes_read);
308 if (ret >= 0)
309 bytes_read += ret;
310 }
311 return bytes_read;
312 }
313
SetSocketError()314 void Socket::SetSocketError() {
315 socket_error_ = true;
316 DCHECK_NE(EAGAIN, errno);
317 DCHECK_NE(EWOULDBLOCK, errno);
318 Close();
319 }
320
Read(void * buffer,size_t buffer_size)321 int Socket::Read(void* buffer, size_t buffer_size) {
322 if (!WaitForEvent(READ, kNoTimeout)) {
323 SetSocketError();
324 return 0;
325 }
326 int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size));
327 if (ret < 0) {
328 PLOG(ERROR) << "read";
329 SetSocketError();
330 }
331 return ret;
332 }
333
NonBlockingRead(void * buffer,size_t buffer_size)334 int Socket::NonBlockingRead(void* buffer, size_t buffer_size) {
335 DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
336 int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size));
337 if (ret < 0) {
338 PLOG(ERROR) << "read";
339 SetSocketError();
340 }
341 return ret;
342 }
343
Write(const void * buffer,size_t count)344 int Socket::Write(const void* buffer, size_t count) {
345 if (!WaitForEvent(WRITE, kNoTimeout)) {
346 SetSocketError();
347 return 0;
348 }
349 int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL));
350 if (ret < 0) {
351 PLOG(ERROR) << "send";
352 SetSocketError();
353 }
354 return ret;
355 }
356
NonBlockingWrite(const void * buffer,size_t count)357 int Socket::NonBlockingWrite(const void* buffer, size_t count) {
358 DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
359 int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL));
360 if (ret < 0) {
361 PLOG(ERROR) << "send";
362 SetSocketError();
363 }
364 return ret;
365 }
366
WriteString(const std::string & buffer)367 int Socket::WriteString(const std::string& buffer) {
368 return WriteNumBytes(buffer.c_str(), buffer.size());
369 }
370
AddEventFd(int event_fd)371 void Socket::AddEventFd(int event_fd) {
372 Event event;
373 event.fd = event_fd;
374 event.was_fired = false;
375 events_.push_back(event);
376 }
377
DidReceiveEventOnFd(int fd) const378 bool Socket::DidReceiveEventOnFd(int fd) const {
379 for (size_t i = 0; i < events_.size(); ++i)
380 if (events_[i].fd == fd)
381 return events_[i].was_fired;
382 return false;
383 }
384
DidReceiveEvent() const385 bool Socket::DidReceiveEvent() const {
386 for (size_t i = 0; i < events_.size(); ++i)
387 if (events_[i].was_fired)
388 return true;
389 return false;
390 }
391
WriteNumBytes(const void * buffer,size_t num_bytes)392 int Socket::WriteNumBytes(const void* buffer, size_t num_bytes) {
393 int bytes_written = 0;
394 int ret = 1;
395 while (bytes_written < num_bytes && ret > 0) {
396 ret = Write(static_cast<const char*>(buffer) + bytes_written,
397 num_bytes - bytes_written);
398 if (ret >= 0)
399 bytes_written += ret;
400 }
401 return bytes_written;
402 }
403
WaitForEvent(EventType type,int timeout_secs)404 bool Socket::WaitForEvent(EventType type, int timeout_secs) {
405 if (socket_ == -1)
406 return true;
407 DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
408 fd_set read_fds;
409 fd_set write_fds;
410 FD_ZERO(&read_fds);
411 FD_ZERO(&write_fds);
412 if (type == READ)
413 FD_SET(socket_, &read_fds);
414 else
415 FD_SET(socket_, &write_fds);
416 for (size_t i = 0; i < events_.size(); ++i)
417 FD_SET(events_[i].fd, &read_fds);
418 timeval tv = {};
419 timeval* tv_ptr = NULL;
420 if (timeout_secs > 0) {
421 tv.tv_sec = timeout_secs;
422 tv.tv_usec = 0;
423 tv_ptr = &tv;
424 }
425 int max_fd = socket_;
426 for (size_t i = 0; i < events_.size(); ++i)
427 if (events_[i].fd > max_fd)
428 max_fd = events_[i].fd;
429 if (HANDLE_EINTR(
430 select(max_fd + 1, &read_fds, &write_fds, NULL, tv_ptr)) <= 0) {
431 PLOG(ERROR) << "select";
432 return false;
433 }
434 bool event_was_fired = false;
435 for (size_t i = 0; i < events_.size(); ++i) {
436 if (FD_ISSET(events_[i].fd, &read_fds)) {
437 events_[i].was_fired = true;
438 event_was_fired = true;
439 }
440 }
441 return !event_was_fired;
442 }
443
444 // static
GetHighestFileDescriptor(const Socket & s1,const Socket & s2)445 int Socket::GetHighestFileDescriptor(const Socket& s1, const Socket& s2) {
446 return std::max(s1.socket_, s2.socket_);
447 }
448
449 // static
GetUnixDomainSocketProcessOwner(const std::string & path)450 pid_t Socket::GetUnixDomainSocketProcessOwner(const std::string& path) {
451 Socket socket;
452 if (!socket.ConnectUnix(path))
453 return -1;
454 ucred ucred;
455 socklen_t len = sizeof(ucred);
456 if (getsockopt(socket.socket_, SOL_SOCKET, SO_PEERCRED, &ucred, &len) == -1) {
457 CHECK_NE(ENOPROTOOPT, errno);
458 return -1;
459 }
460 return ucred.pid;
461 }
462
463 } // namespace forwarder2
464