//
// Copyright (C) 2021 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "net/posix/posix_async_socket_connector.h"

#include <arpa/inet.h>   // for inet_addr, inet_ntoa
#include <errno.h>       // for errno, EAGAIN, EINPROGRESS
#include <netdb.h>       // for gethostbyname, addrinfo
#include <netinet/in.h>  // for sockaddr_in, in_addr
#include <poll.h>        // for poll, POLLHUP, POLLIN, POL...
#include <string.h>      // for strerror, NULL
#include <sys/socket.h>  // for connect, getpeername, gets...

#include <algorithm>    // for clamp
#include <chrono>       // for milliseconds
#include <limits>       // for numeric_limits
#include <type_traits>  // for remove_extent_t

#include "log.h"                           // for LOG_INFO
#include "net/posix/posix_async_socket.h"  // for PosixAsyncSocket
29
30 namespace android {
31 namespace net {
32
PosixAsyncSocketConnector(AsyncManager * am)33 PosixAsyncSocketConnector::PosixAsyncSocketConnector(AsyncManager* am)
34 : am_(am) {}
35
36 std::shared_ptr<AsyncDataChannel>
ConnectToRemoteServer(const std::string & server,int port,const std::chrono::milliseconds timeout)37 PosixAsyncSocketConnector::ConnectToRemoteServer(
38 const std::string& server, int port,
39 const std::chrono::milliseconds timeout) {
40 LOG_INFO("Connecting to %s:%d in %d ms", server.c_str(), port,
41 (int)timeout.count());
42 int socket_fd = socket(AF_INET, SOCK_STREAM, 0);
43 std::shared_ptr<PosixAsyncSocket> pas =
44 std::make_shared<PosixAsyncSocket>(socket_fd, am_);
45
46 if (socket_fd < 1) {
47 LOG_INFO("socket() call failed: %s", strerror(errno));
48 return pas;
49 }
50
51 struct hostent* host;
52 host = gethostbyname(server.c_str());
53 if (host == NULL) {
54 LOG_INFO("gethostbyname() failed for %s: %s", server.c_str(),
55 strerror(errno));
56 pas->Close();
57 return pas;
58 }
59
60 struct in_addr** addr_list = (struct in_addr**)host->h_addr_list;
61 struct sockaddr_in serv_addr {};
62 serv_addr.sin_family = AF_INET;
63 serv_addr.sin_addr.s_addr = inet_addr(inet_ntoa(*addr_list[0]));
64 serv_addr.sin_port = htons(port);
65
66 int result =
67 connect(socket_fd, (struct sockaddr*)&serv_addr, sizeof(serv_addr));
68
69 if (result != 0 && errno != EWOULDBLOCK && errno != EAGAIN &&
70 errno != EINPROGRESS) {
71 LOG_INFO("Failed to connect to %s:%d, error: %s", server.c_str(), port,
72 strerror(errno));
73 pas->Close();
74 return pas;
75 }
76
77 // wait for the connection.
78 struct pollfd fds[] = {
79 {
80 .fd = socket_fd,
81 .events = POLLIN | POLLOUT | POLLHUP,
82 .revents = 0,
83 },
84 };
85
86 int numFdsReady = 0;
87 REPEAT_UNTIL_NO_INTR(numFdsReady = ::poll(fds, 1, timeout.count()));
88
89 if (numFdsReady <= 0) {
90 LOG_INFO("Failed to connect to %s:%d, error: %s", server.c_str(), port,
91 strerror(errno));
92 pas->Close();
93 return pas;
94 }
95
96 // As per https://cr.yp.to/docs/connect.html, we should get the peername
97 // for validating if a connection was established.
98 struct sockaddr_storage ss;
99 socklen_t sslen = sizeof(ss);
100
101 if (getpeername(socket_fd, (struct sockaddr*)&ss, &sslen) < 0) {
102 LOG_INFO("Failed to connect to %s:%d, error: %s", server.c_str(), port,
103 strerror(errno));
104 pas->Close();
105 return pas;
106 }
107
108 int err = 0;
109 socklen_t optLen = sizeof(err);
110 if (getsockopt(socket_fd, SOL_SOCKET, SO_ERROR, reinterpret_cast<char*>(&err),
111 &optLen) ||
112 err) {
113 // Either getsockopt failed or there was an error associated
114 // with the socket. The connection did not succeed.
115 LOG_INFO("Failed to connect to %s:%d, error: %s", server.c_str(), port,
116 strerror(err));
117 pas->Close();
118 return pas;
119 }
120
121 LOG_INFO("Connected to %s:%d (%d)", server.c_str(), port, socket_fd);
122 return pas;
123 }
124
125 } // namespace net
126 } // namespace android
127