1 //===-- TCPSocket.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #if defined(_MSC_VER)
10 #define _WINSOCK_DEPRECATED_NO_WARNINGS
11 #endif
12
13 #include "lldb/Host/common/TCPSocket.h"
14
15 #include "lldb/Host/Config.h"
16 #include "lldb/Host/MainLoop.h"
17 #include "lldb/Utility/Log.h"
18
19 #include "llvm/Config/llvm-config.h"
20 #include "llvm/Support/Errno.h"
21 #include "llvm/Support/WindowsError.h"
22 #include "llvm/Support/raw_ostream.h"
23
24 #if LLDB_ENABLE_POSIX
25 #include <arpa/inet.h>
26 #include <netinet/tcp.h>
27 #include <sys/socket.h>
28 #endif
29
30 #if defined(_WIN32)
31 #include <winsock2.h>
32 #endif
33
// Platform shims: the call used to close a native socket, and the pointer type
// that setsockopt() takes for its option-value argument.
#ifdef _WIN32
#define CLOSE_SOCKET closesocket
using set_socket_option_arg_type = const char *;
#else
#include <unistd.h>
#define CLOSE_SOCKET ::close
using set_socket_option_arg_type = const void *;
#endif
42
43 using namespace lldb;
44 using namespace lldb_private;
45
GetLastSocketError()46 static Status GetLastSocketError() {
47 std::error_code EC;
48 #ifdef _WIN32
49 EC = llvm::mapWindowsError(WSAGetLastError());
50 #else
51 EC = std::error_code(errno, std::generic_category());
52 #endif
53 return EC;
54 }
55
namespace {
// Socket type used for every socket created by this file.
constexpr int kType = SOCK_STREAM;
} // namespace
59
TCPSocket(bool should_close,bool child_processes_inherit)60 TCPSocket::TCPSocket(bool should_close, bool child_processes_inherit)
61 : Socket(ProtocolTcp, should_close, child_processes_inherit) {}
62
TCPSocket(NativeSocket socket,const TCPSocket & listen_socket)63 TCPSocket::TCPSocket(NativeSocket socket, const TCPSocket &listen_socket)
64 : Socket(ProtocolTcp, listen_socket.m_should_close_fd,
65 listen_socket.m_child_processes_inherit) {
66 m_socket = socket;
67 }
68
TCPSocket(NativeSocket socket,bool should_close,bool child_processes_inherit)69 TCPSocket::TCPSocket(NativeSocket socket, bool should_close,
70 bool child_processes_inherit)
71 : Socket(ProtocolTcp, should_close, child_processes_inherit) {
72 m_socket = socket;
73 }
74
~TCPSocket()75 TCPSocket::~TCPSocket() { CloseListenSockets(); }
76
IsValid() const77 bool TCPSocket::IsValid() const {
78 return m_socket != kInvalidSocketValue || m_listen_sockets.size() != 0;
79 }
80
81 // Return the port number that is being used by the socket.
GetLocalPortNumber() const82 uint16_t TCPSocket::GetLocalPortNumber() const {
83 if (m_socket != kInvalidSocketValue) {
84 SocketAddress sock_addr;
85 socklen_t sock_addr_len = sock_addr.GetMaxLength();
86 if (::getsockname(m_socket, sock_addr, &sock_addr_len) == 0)
87 return sock_addr.GetPort();
88 } else if (!m_listen_sockets.empty()) {
89 SocketAddress sock_addr;
90 socklen_t sock_addr_len = sock_addr.GetMaxLength();
91 if (::getsockname(m_listen_sockets.begin()->first, sock_addr,
92 &sock_addr_len) == 0)
93 return sock_addr.GetPort();
94 }
95 return 0;
96 }
97
GetLocalIPAddress() const98 std::string TCPSocket::GetLocalIPAddress() const {
99 // We bound to port zero, so we need to figure out which port we actually
100 // bound to
101 if (m_socket != kInvalidSocketValue) {
102 SocketAddress sock_addr;
103 socklen_t sock_addr_len = sock_addr.GetMaxLength();
104 if (::getsockname(m_socket, sock_addr, &sock_addr_len) == 0)
105 return sock_addr.GetIPAddress();
106 }
107 return "";
108 }
109
GetRemotePortNumber() const110 uint16_t TCPSocket::GetRemotePortNumber() const {
111 if (m_socket != kInvalidSocketValue) {
112 SocketAddress sock_addr;
113 socklen_t sock_addr_len = sock_addr.GetMaxLength();
114 if (::getpeername(m_socket, sock_addr, &sock_addr_len) == 0)
115 return sock_addr.GetPort();
116 }
117 return 0;
118 }
119
GetRemoteIPAddress() const120 std::string TCPSocket::GetRemoteIPAddress() const {
121 // We bound to port zero, so we need to figure out which port we actually
122 // bound to
123 if (m_socket != kInvalidSocketValue) {
124 SocketAddress sock_addr;
125 socklen_t sock_addr_len = sock_addr.GetMaxLength();
126 if (::getpeername(m_socket, sock_addr, &sock_addr_len) == 0)
127 return sock_addr.GetIPAddress();
128 }
129 return "";
130 }
131
GetRemoteConnectionURI() const132 std::string TCPSocket::GetRemoteConnectionURI() const {
133 if (m_socket != kInvalidSocketValue) {
134 return std::string(llvm::formatv(
135 "connect://[{0}]:{1}", GetRemoteIPAddress(), GetRemotePortNumber()));
136 }
137 return "";
138 }
139
CreateSocket(int domain)140 Status TCPSocket::CreateSocket(int domain) {
141 Status error;
142 if (IsValid())
143 error = Close();
144 if (error.Fail())
145 return error;
146 m_socket = Socket::CreateSocket(domain, kType, IPPROTO_TCP,
147 m_child_processes_inherit, error);
148 return error;
149 }
150
Connect(llvm::StringRef name)151 Status TCPSocket::Connect(llvm::StringRef name) {
152
153 Log *log(lldb_private::GetLogIfAnyCategoriesSet(LIBLLDB_LOG_COMMUNICATION));
154 LLDB_LOGF(log, "TCPSocket::%s (host/port = %s)", __FUNCTION__, name.data());
155
156 Status error;
157 std::string host_str;
158 std::string port_str;
159 int32_t port = INT32_MIN;
160 if (!DecodeHostAndPort(name, host_str, port_str, port, &error))
161 return error;
162
163 std::vector<SocketAddress> addresses = SocketAddress::GetAddressInfo(
164 host_str.c_str(), nullptr, AF_UNSPEC, SOCK_STREAM, IPPROTO_TCP);
165 for (SocketAddress &address : addresses) {
166 error = CreateSocket(address.GetFamily());
167 if (error.Fail())
168 continue;
169
170 address.SetPort(port);
171
172 if (-1 == llvm::sys::RetryAfterSignal(-1, ::connect,
173 GetNativeSocket(), &address.sockaddr(), address.GetLength())) {
174 CLOSE_SOCKET(GetNativeSocket());
175 continue;
176 }
177
178 SetOptionNoDelay();
179
180 error.Clear();
181 return error;
182 }
183
184 error.SetErrorString("Failed to connect port");
185 return error;
186 }
187
Listen(llvm::StringRef name,int backlog)188 Status TCPSocket::Listen(llvm::StringRef name, int backlog) {
189 Log *log(lldb_private::GetLogIfAnyCategoriesSet(LIBLLDB_LOG_CONNECTION));
190 LLDB_LOGF(log, "TCPSocket::%s (%s)", __FUNCTION__, name.data());
191
192 Status error;
193 std::string host_str;
194 std::string port_str;
195 int32_t port = INT32_MIN;
196 if (!DecodeHostAndPort(name, host_str, port_str, port, &error))
197 return error;
198
199 if (host_str == "*")
200 host_str = "0.0.0.0";
201 std::vector<SocketAddress> addresses = SocketAddress::GetAddressInfo(
202 host_str.c_str(), nullptr, AF_UNSPEC, SOCK_STREAM, IPPROTO_TCP);
203 for (SocketAddress &address : addresses) {
204 int fd = Socket::CreateSocket(address.GetFamily(), kType, IPPROTO_TCP,
205 m_child_processes_inherit, error);
206 if (error.Fail())
207 continue;
208
209 // enable local address reuse
210 int option_value = 1;
211 set_socket_option_arg_type option_value_p =
212 reinterpret_cast<set_socket_option_arg_type>(&option_value);
213 ::setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, option_value_p,
214 sizeof(option_value));
215
216 SocketAddress listen_address = address;
217 if(!listen_address.IsLocalhost())
218 listen_address.SetToAnyAddress(address.GetFamily(), port);
219 else
220 listen_address.SetPort(port);
221
222 int err =
223 ::bind(fd, &listen_address.sockaddr(), listen_address.GetLength());
224 if (-1 != err)
225 err = ::listen(fd, backlog);
226
227 if (-1 == err) {
228 error = GetLastSocketError();
229 CLOSE_SOCKET(fd);
230 continue;
231 }
232
233 if (port == 0) {
234 socklen_t sa_len = address.GetLength();
235 if (getsockname(fd, &address.sockaddr(), &sa_len) == 0)
236 port = address.GetPort();
237 }
238 m_listen_sockets[fd] = address;
239 }
240
241 if (m_listen_sockets.empty()) {
242 assert(error.Fail());
243 return error;
244 }
245 return Status();
246 }
247
CloseListenSockets()248 void TCPSocket::CloseListenSockets() {
249 for (auto socket : m_listen_sockets)
250 CLOSE_SOCKET(socket.first);
251 m_listen_sockets.clear();
252 }
253
Accept(Socket * & conn_socket)254 Status TCPSocket::Accept(Socket *&conn_socket) {
255 Status error;
256 if (m_listen_sockets.size() == 0) {
257 error.SetErrorString("No open listening sockets!");
258 return error;
259 }
260
261 int sock = -1;
262 int listen_sock = -1;
263 lldb_private::SocketAddress AcceptAddr;
264 MainLoop accept_loop;
265 std::vector<MainLoopBase::ReadHandleUP> handles;
266 for (auto socket : m_listen_sockets) {
267 auto fd = socket.first;
268 auto inherit = this->m_child_processes_inherit;
269 auto io_sp = IOObjectSP(new TCPSocket(socket.first, false, inherit));
270 handles.emplace_back(accept_loop.RegisterReadObject(
271 io_sp, [fd, inherit, &sock, &AcceptAddr, &error,
272 &listen_sock](MainLoopBase &loop) {
273 socklen_t sa_len = AcceptAddr.GetMaxLength();
274 sock = AcceptSocket(fd, &AcceptAddr.sockaddr(), &sa_len, inherit,
275 error);
276 listen_sock = fd;
277 loop.RequestTermination();
278 }, error));
279 if (error.Fail())
280 return error;
281 }
282
283 bool accept_connection = false;
284 std::unique_ptr<TCPSocket> accepted_socket;
285 // Loop until we are happy with our connection
286 while (!accept_connection) {
287 accept_loop.Run();
288
289 if (error.Fail())
290 return error;
291
292 lldb_private::SocketAddress &AddrIn = m_listen_sockets[listen_sock];
293 if (!AddrIn.IsAnyAddr() && AcceptAddr != AddrIn) {
294 CLOSE_SOCKET(sock);
295 llvm::errs() << llvm::formatv(
296 "error: rejecting incoming connection from {0} (expecting {1})",
297 AcceptAddr.GetIPAddress(), AddrIn.GetIPAddress());
298 continue;
299 }
300 accept_connection = true;
301 accepted_socket.reset(new TCPSocket(sock, *this));
302 }
303
304 if (!accepted_socket)
305 return error;
306
307 // Keep our TCP packets coming without any delays.
308 accepted_socket->SetOptionNoDelay();
309 error.Clear();
310 conn_socket = accepted_socket.release();
311 return error;
312 }
313
SetOptionNoDelay()314 int TCPSocket::SetOptionNoDelay() {
315 return SetOption(IPPROTO_TCP, TCP_NODELAY, 1);
316 }
317
SetOptionReuseAddress()318 int TCPSocket::SetOptionReuseAddress() {
319 return SetOption(SOL_SOCKET, SO_REUSEADDR, 1);
320 }
321