1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "tools/android/forwarder2/socket.h"
6
7 #include <arpa/inet.h>
8 #include <fcntl.h>
9 #include <netdb.h>
10 #include <netinet/in.h>
11 #include <stdio.h>
12 #include <string.h>
13 #include <sys/socket.h>
14 #include <sys/types.h>
15 #include <unistd.h>
16
17 #include "base/logging.h"
18 #include "base/posix/eintr_wrapper.h"
19 #include "base/safe_strerror_posix.h"
20 #include "tools/android/common/net.h"
21 #include "tools/android/forwarder2/common.h"
22
namespace {

// Passed to WaitForEvent() to wait without any deadline (only strictly
// positive values install a select() timeout).
const int kNoTimeout = -1;
// How long, in seconds, Connect() waits for a pending connect() to finish.
const int kConnectTimeOut = 10;  // Seconds.

// Returns true for the address families this file treats as TCP (IPv4 and
// IPv6) and false for anything else, e.g. AF_UNIX.
bool FamilyIsTCP(int family) {
  switch (family) {
    case AF_INET:
    case AF_INET6:
      return true;
    default:
      return false;
  }
}

}  // namespace
31
32 namespace forwarder2 {
33
BindUnix(const std::string & path)34 bool Socket::BindUnix(const std::string& path) {
35 errno = 0;
36 if (!InitUnixSocket(path) || !BindAndListen()) {
37 Close();
38 return false;
39 }
40 return true;
41 }
42
BindTcp(const std::string & host,int port)43 bool Socket::BindTcp(const std::string& host, int port) {
44 errno = 0;
45 if (!InitTcpSocket(host, port) || !BindAndListen()) {
46 Close();
47 return false;
48 }
49 return true;
50 }
51
ConnectUnix(const std::string & path)52 bool Socket::ConnectUnix(const std::string& path) {
53 errno = 0;
54 if (!InitUnixSocket(path) || !Connect()) {
55 Close();
56 return false;
57 }
58 return true;
59 }
60
ConnectTcp(const std::string & host,int port)61 bool Socket::ConnectTcp(const std::string& host, int port) {
62 errno = 0;
63 if (!InitTcpSocket(host, port) || !Connect()) {
64 Close();
65 return false;
66 }
67 return true;
68 }
69
Socket()70 Socket::Socket()
71 : socket_(-1),
72 port_(0),
73 socket_error_(false),
74 family_(AF_INET),
75 addr_ptr_(reinterpret_cast<sockaddr*>(&addr_.addr4)),
76 addr_len_(sizeof(sockaddr)) {
77 memset(&addr_, 0, sizeof(addr_));
78 }
79
~Socket()80 Socket::~Socket() {
81 Close();
82 }
83
Shutdown()84 void Socket::Shutdown() {
85 if (!IsClosed()) {
86 PRESERVE_ERRNO_HANDLE_EINTR(shutdown(socket_, SHUT_RDWR));
87 }
88 }
89
Close()90 void Socket::Close() {
91 if (!IsClosed()) {
92 CloseFD(socket_);
93 socket_ = -1;
94 }
95 }
96
InitSocketInternal()97 bool Socket::InitSocketInternal() {
98 socket_ = socket(family_, SOCK_STREAM, 0);
99 if (socket_ < 0) {
100 PLOG(ERROR) << "socket";
101 return false;
102 }
103 tools::DisableNagle(socket_);
104 int reuse_addr = 1;
105 setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &reuse_addr,
106 sizeof(reuse_addr));
107 if (!SetNonBlocking())
108 return false;
109 return true;
110 }
111
SetNonBlocking()112 bool Socket::SetNonBlocking() {
113 const int flags = fcntl(socket_, F_GETFL);
114 if (flags < 0) {
115 PLOG(ERROR) << "fcntl";
116 return false;
117 }
118 if (flags & O_NONBLOCK)
119 return true;
120 if (fcntl(socket_, F_SETFL, flags | O_NONBLOCK) < 0) {
121 PLOG(ERROR) << "fcntl";
122 return false;
123 }
124 return true;
125 }
126
InitUnixSocket(const std::string & path)127 bool Socket::InitUnixSocket(const std::string& path) {
128 static const size_t kPathMax = sizeof(addr_.addr_un.sun_path);
129 // For abstract sockets we need one extra byte for the leading zero.
130 if (path.size() + 2 /* '\0' */ > kPathMax) {
131 LOG(ERROR) << "The provided path is too big to create a unix "
132 << "domain socket: " << path;
133 return false;
134 }
135 family_ = PF_UNIX;
136 addr_.addr_un.sun_family = family_;
137 // Copied from net/socket/unix_domain_socket_posix.cc
138 // Convert the path given into abstract socket name. It must start with
139 // the '\0' character, so we are adding it. |addr_len| must specify the
140 // length of the structure exactly, as potentially the socket name may
141 // have '\0' characters embedded (although we don't support this).
142 // Note that addr_.addr_un.sun_path is already zero initialized.
143 memcpy(addr_.addr_un.sun_path + 1, path.c_str(), path.size());
144 addr_len_ = path.size() + offsetof(struct sockaddr_un, sun_path) + 1;
145 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr_un);
146 return InitSocketInternal();
147 }
148
InitTcpSocket(const std::string & host,int port)149 bool Socket::InitTcpSocket(const std::string& host, int port) {
150 port_ = port;
151 if (host.empty()) {
152 // Use localhost: INADDR_LOOPBACK
153 family_ = AF_INET;
154 addr_.addr4.sin_family = family_;
155 addr_.addr4.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
156 } else if (!Resolve(host)) {
157 return false;
158 }
159 CHECK(FamilyIsTCP(family_)) << "Invalid socket family.";
160 if (family_ == AF_INET) {
161 addr_.addr4.sin_port = htons(port_);
162 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr4);
163 addr_len_ = sizeof(addr_.addr4);
164 } else if (family_ == AF_INET6) {
165 addr_.addr6.sin6_port = htons(port_);
166 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr6);
167 addr_len_ = sizeof(addr_.addr6);
168 }
169 return InitSocketInternal();
170 }
171
BindAndListen()172 bool Socket::BindAndListen() {
173 errno = 0;
174 if (HANDLE_EINTR(bind(socket_, addr_ptr_, addr_len_)) < 0 ||
175 HANDLE_EINTR(listen(socket_, SOMAXCONN)) < 0) {
176 PLOG(ERROR) << "bind/listen";
177 SetSocketError();
178 return false;
179 }
180 if (port_ == 0 && FamilyIsTCP(family_)) {
181 SockAddr addr;
182 memset(&addr, 0, sizeof(addr));
183 socklen_t addrlen = 0;
184 sockaddr* addr_ptr = NULL;
185 uint16* port_ptr = NULL;
186 if (family_ == AF_INET) {
187 addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr4);
188 port_ptr = &addr.addr4.sin_port;
189 addrlen = sizeof(addr.addr4);
190 } else if (family_ == AF_INET6) {
191 addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr6);
192 port_ptr = &addr.addr6.sin6_port;
193 addrlen = sizeof(addr.addr6);
194 }
195 errno = 0;
196 if (getsockname(socket_, addr_ptr, &addrlen) != 0) {
197 PLOG(ERROR) << "getsockname";
198 SetSocketError();
199 return false;
200 }
201 port_ = ntohs(*port_ptr);
202 }
203 return true;
204 }
205
Accept(Socket * new_socket)206 bool Socket::Accept(Socket* new_socket) {
207 DCHECK(new_socket != NULL);
208 if (!WaitForEvent(READ, kNoTimeout)) {
209 SetSocketError();
210 return false;
211 }
212 errno = 0;
213 int new_socket_fd = HANDLE_EINTR(accept(socket_, NULL, NULL));
214 if (new_socket_fd < 0) {
215 SetSocketError();
216 return false;
217 }
218 tools::DisableNagle(new_socket_fd);
219 new_socket->socket_ = new_socket_fd;
220 if (!new_socket->SetNonBlocking())
221 return false;
222 return true;
223 }
224
Connect()225 bool Socket::Connect() {
226 DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
227 errno = 0;
228 if (HANDLE_EINTR(connect(socket_, addr_ptr_, addr_len_)) < 0 &&
229 errno != EINPROGRESS) {
230 SetSocketError();
231 return false;
232 }
233 // Wait for connection to complete, or receive a notification.
234 if (!WaitForEvent(WRITE, kConnectTimeOut)) {
235 SetSocketError();
236 return false;
237 }
238 int socket_errno;
239 socklen_t opt_len = sizeof(socket_errno);
240 if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, &socket_errno, &opt_len) < 0) {
241 PLOG(ERROR) << "getsockopt()";
242 SetSocketError();
243 return false;
244 }
245 if (socket_errno != 0) {
246 LOG(ERROR) << "Could not connect to host: " << safe_strerror(socket_errno);
247 SetSocketError();
248 return false;
249 }
250 return true;
251 }
252
// Resolves |host| with getaddrinfo() (any family, stream sockets). It copies
// the first result's address into |addr_| and its family into |family_|.
// Returns false, after logging and calling SetSocketError(), if resolution
// fails.
bool Socket::Resolve(const std::string& host) {
  struct addrinfo hints;
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_flags |= AI_CANONNAME;

  struct addrinfo* res = NULL;
  const int errcode = getaddrinfo(host.c_str(), NULL, &hints, &res);
  if (errcode != 0) {
    // getaddrinfo() does not hand back a result list on failure, so |res|
    // must not be passed to freeaddrinfo() here. The original code freed an
    // uninitialized pointer on this path.
    LOG(ERROR) << "getaddrinfo(" << host << "): " << gai_strerror(errcode);
    // getaddrinfo() reports errors through |errcode|, not errno; clear errno
    // so SetSocketError() does not see a stale value.
    errno = 0;
    SetSocketError();
    return false;
  }
  family_ = res->ai_family;
  switch (res->ai_family) {
    case AF_INET:
      memcpy(&addr_.addr4, res->ai_addr, sizeof(sockaddr_in));
      break;
    case AF_INET6:
      memcpy(&addr_.addr6, res->ai_addr, sizeof(sockaddr_in6));
      break;
  }
  freeaddrinfo(res);
  return true;
}
284
GetPort()285 int Socket::GetPort() {
286 if (!FamilyIsTCP(family_)) {
287 LOG(ERROR) << "Can't call GetPort() on an unix domain socket.";
288 return 0;
289 }
290 return port_;
291 }
292
ReadNumBytes(void * buffer,size_t num_bytes)293 int Socket::ReadNumBytes(void* buffer, size_t num_bytes) {
294 int bytes_read = 0;
295 int ret = 1;
296 while (bytes_read < num_bytes && ret > 0) {
297 ret = Read(static_cast<char*>(buffer) + bytes_read, num_bytes - bytes_read);
298 if (ret >= 0)
299 bytes_read += ret;
300 }
301 return bytes_read;
302 }
303
SetSocketError()304 void Socket::SetSocketError() {
305 socket_error_ = true;
306 DCHECK_NE(EAGAIN, errno);
307 DCHECK_NE(EWOULDBLOCK, errno);
308 Close();
309 }
310
Read(void * buffer,size_t buffer_size)311 int Socket::Read(void* buffer, size_t buffer_size) {
312 if (!WaitForEvent(READ, kNoTimeout)) {
313 SetSocketError();
314 return 0;
315 }
316 int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size));
317 if (ret < 0) {
318 PLOG(ERROR) << "read";
319 SetSocketError();
320 }
321 return ret;
322 }
323
NonBlockingRead(void * buffer,size_t buffer_size)324 int Socket::NonBlockingRead(void* buffer, size_t buffer_size) {
325 DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
326 int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size));
327 if (ret < 0) {
328 PLOG(ERROR) << "read";
329 SetSocketError();
330 }
331 return ret;
332 }
333
Write(const void * buffer,size_t count)334 int Socket::Write(const void* buffer, size_t count) {
335 if (!WaitForEvent(WRITE, kNoTimeout)) {
336 SetSocketError();
337 return 0;
338 }
339 int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL));
340 if (ret < 0) {
341 PLOG(ERROR) << "send";
342 SetSocketError();
343 }
344 return ret;
345 }
346
NonBlockingWrite(const void * buffer,size_t count)347 int Socket::NonBlockingWrite(const void* buffer, size_t count) {
348 DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
349 int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL));
350 if (ret < 0) {
351 PLOG(ERROR) << "send";
352 SetSocketError();
353 }
354 return ret;
355 }
356
WriteString(const std::string & buffer)357 int Socket::WriteString(const std::string& buffer) {
358 return WriteNumBytes(buffer.c_str(), buffer.size());
359 }
360
AddEventFd(int event_fd)361 void Socket::AddEventFd(int event_fd) {
362 Event event;
363 event.fd = event_fd;
364 event.was_fired = false;
365 events_.push_back(event);
366 }
367
DidReceiveEventOnFd(int fd) const368 bool Socket::DidReceiveEventOnFd(int fd) const {
369 for (size_t i = 0; i < events_.size(); ++i)
370 if (events_[i].fd == fd)
371 return events_[i].was_fired;
372 return false;
373 }
374
DidReceiveEvent() const375 bool Socket::DidReceiveEvent() const {
376 for (size_t i = 0; i < events_.size(); ++i)
377 if (events_[i].was_fired)
378 return true;
379 return false;
380 }
381
WriteNumBytes(const void * buffer,size_t num_bytes)382 int Socket::WriteNumBytes(const void* buffer, size_t num_bytes) {
383 int bytes_written = 0;
384 int ret = 1;
385 while (bytes_written < num_bytes && ret > 0) {
386 ret = Write(static_cast<const char*>(buffer) + bytes_written,
387 num_bytes - bytes_written);
388 if (ret >= 0)
389 bytes_written += ret;
390 }
391 return bytes_written;
392 }
393
WaitForEvent(EventType type,int timeout_secs)394 bool Socket::WaitForEvent(EventType type, int timeout_secs) {
395 if (socket_ == -1)
396 return true;
397 DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
398 fd_set read_fds;
399 fd_set write_fds;
400 FD_ZERO(&read_fds);
401 FD_ZERO(&write_fds);
402 if (type == READ)
403 FD_SET(socket_, &read_fds);
404 else
405 FD_SET(socket_, &write_fds);
406 for (size_t i = 0; i < events_.size(); ++i)
407 FD_SET(events_[i].fd, &read_fds);
408 timeval tv = {};
409 timeval* tv_ptr = NULL;
410 if (timeout_secs > 0) {
411 tv.tv_sec = timeout_secs;
412 tv.tv_usec = 0;
413 tv_ptr = &tv;
414 }
415 int max_fd = socket_;
416 for (size_t i = 0; i < events_.size(); ++i)
417 if (events_[i].fd > max_fd)
418 max_fd = events_[i].fd;
419 if (HANDLE_EINTR(
420 select(max_fd + 1, &read_fds, &write_fds, NULL, tv_ptr)) <= 0) {
421 PLOG(ERROR) << "select";
422 return false;
423 }
424 bool event_was_fired = false;
425 for (size_t i = 0; i < events_.size(); ++i) {
426 if (FD_ISSET(events_[i].fd, &read_fds)) {
427 events_[i].was_fired = true;
428 event_was_fired = true;
429 }
430 }
431 return !event_was_fired;
432 }
433
434 // static
GetUnixDomainSocketProcessOwner(const std::string & path)435 pid_t Socket::GetUnixDomainSocketProcessOwner(const std::string& path) {
436 Socket socket;
437 if (!socket.ConnectUnix(path))
438 return -1;
439 ucred ucred;
440 socklen_t len = sizeof(ucred);
441 if (getsockopt(socket.socket_, SOL_SOCKET, SO_PEERCRED, &ucred, &len) == -1) {
442 CHECK_NE(ENOPROTOOPT, errno);
443 return -1;
444 }
445 return ucred.pid;
446 }
447
448 } // namespace forwarder2
449