1 // Copyright 2021 The Tint Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
#include "tools/src/cmd/remote-compile/socket.h"

#include <cerrno>
#include <cstdio>
#include <limits>

#include "tools/src/cmd/remote-compile/rwmutex.h"

#if defined(_WIN32)
#include <winsock2.h>
#include <ws2tcpip.h>
#else
#include <netdb.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <unistd.h>
#endif
31
#if defined(_WIN32)
#include <atomic>
namespace {
// Number of outstanding init() calls not yet matched by term(). WSAStartup()
// runs when this goes from 0 to 1 and WSACleanup() when it returns to 0.
std::atomic<int> wsaInitCount{0};
}  // anonymous namespace
#else
#include <fcntl.h>
namespace {
// POSIX sockets are plain int file descriptors. Alias them to the Winsock
// type name so the rest of this file can share one spelling on all platforms.
using SOCKET = int;
}  // anonymous namespace
#endif
43
44 namespace {
45 constexpr SOCKET InvalidSocket = static_cast<SOCKET>(-1);
init()46 void init() {
47 #if defined(_WIN32)
48 if (wsaInitCount++ == 0) {
49 WSADATA winsockData;
50 (void)WSAStartup(MAKEWORD(2, 2), &winsockData);
51 }
52 #endif
53 }
54
term()55 void term() {
56 #if defined(_WIN32)
57 if (--wsaInitCount == 0) {
58 WSACleanup();
59 }
60 #endif
61 }
62
setBlocking(SOCKET s,bool blocking)63 bool setBlocking(SOCKET s, bool blocking) {
64 #if defined(_WIN32)
65 u_long mode = blocking ? 0 : 1;
66 return ioctlsocket(s, FIONBIO, &mode) == NO_ERROR;
67 #else
68 auto arg = fcntl(s, F_GETFL, nullptr);
69 if (arg < 0) {
70 return false;
71 }
72 arg = blocking ? (arg & ~O_NONBLOCK) : (arg | O_NONBLOCK);
73 return fcntl(s, F_SETFL, arg) >= 0;
74 #endif
75 }
76
errored(SOCKET s)77 bool errored(SOCKET s) {
78 if (s == InvalidSocket) {
79 return true;
80 }
81 char error = 0;
82 socklen_t len = sizeof(error);
83 getsockopt(s, SOL_SOCKET, SO_ERROR, &error, &len);
84 return error != 0;
85 }
86
// Socket implementation over a BSD socket (POSIX) or a Winsock SOCKET.
//
// Locking: every operation that uses the handle runs under a read lock on
// |mutex|, including the potentially blocking accept()/recv()/send(). The
// handle is only replaced with InvalidSocket under the write lock.
//
// Every Impl built by create() or Accept() is paired with one init() call;
// the destructor releases it with term().
class Impl final : public Socket {
 public:
  // Resolves |address|:|port| as IPv4/TCP and creates an unconnected socket
  // for the first result, with options applied by setOptions().
  // Returns nullptr if resolution fails. The returned Impl holds
  // InvalidSocket if ::socket() failed.
  static std::shared_ptr<Impl> create(const char* address, const char* port) {
    init();

    addrinfo hints = {};
    hints.ai_family = AF_INET;
    hints.ai_socktype = SOCK_STREAM;
    hints.ai_protocol = IPPROTO_TCP;
    hints.ai_flags = AI_PASSIVE;

    addrinfo* info = nullptr;
    const auto err = getaddrinfo(address, port, &hints, &info);
    if (err != 0 || info == nullptr) {
#if !defined(_WIN32)
      if (err != 0) {
        printf("getaddrinfo(%s, %s) error: %s\n", address, port,
               gai_strerror(err));
      }
#endif
      // No address list was produced, so there is nothing to free.
      // freeaddrinfo(nullptr) is not guaranteed to be safe on POSIX.
      term();
      return nullptr;
    }

    auto socket =
        ::socket(info->ai_family, info->ai_socktype, info->ai_protocol);
    auto out = std::make_shared<Impl>(info, socket);
    out->setOptions();
    return out;
  }

  // Wraps an already-created socket (e.g. one returned by ::accept()).
  explicit Impl(SOCKET socket) : info(nullptr), s(socket) {}
  // Takes ownership of |info| (freed in the destructor) and |socket|.
  Impl(addrinfo* info, SOCKET socket) : info(info), s(socket) {}

  // Owns a raw socket handle and an addrinfo list, so copying is disallowed.
  Impl(const Impl&) = delete;
  Impl& operator=(const Impl&) = delete;

  ~Impl() {
    // Impls created from an accepted socket have no address list.
    if (info) {
      freeaddrinfo(info);
    }
    Close();
    term();
  }

  // Invokes f(s, info) while holding the read lock.
  // |f| must not release the last reference to this Impl, because the lock
  // is still held when |f| returns.
  template <typename FUNCTION>
  void lock(FUNCTION&& f) {
    RLock l(mutex);
    f(s, info);
  }

  // Applies the socket options used for both listening and connecting
  // sockets. Failures are ignored: every option is best-effort.
  void setOptions() {
    RLock l(mutex);
    if (s == InvalidSocket) {
      return;
    }

    int enable = 1;

#if !defined(_WIN32)
    // Allow the address to be re-bound immediately, even while connections
    // from a previous process on the same port are still in TIME_WAIT.
    setsockopt(s, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<char*>(&enable),
               sizeof(enable));

    // Lingering disabled: close() returns without waiting for unsent data.
    struct linger noLinger = {};
    noLinger.l_onoff = 0;
    noLinger.l_linger = 0;
    setsockopt(s, SOL_SOCKET, SO_LINGER, reinterpret_cast<char*>(&noLinger),
               sizeof(noLinger));
#endif  // !defined(_WIN32)

    // Disable Nagle's algorithm (TCP_NODELAY) so small writes go out
    // immediately.
    setsockopt(s, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char*>(&enable),
               sizeof(enable));
  }

  // Returns true if the socket is valid and has no pending error. Otherwise
  // it releases the handle, marks the socket invalid, and returns false.
  bool IsOpen() override {
    {
      RLock l(mutex);
      if ((s != InvalidSocket) && !errored(s)) {
        return true;
      }
    }
    WLock lock(mutex);
    if (s != InvalidSocket) {
#if !defined(_WIN32)
      // Close() skips sockets already marked invalid, so the descriptor has
      // to be released here or it would leak.
      ::close(s);
#else
      // NOTE(review): the Winsock handle is not closed here. Close() calls
      // closesocket() under the read lock before invalidating under the
      // write lock, so closing here could double-close a handle.
#endif
      s = InvalidSocket;
    }
    return false;
  }

  // Closes the socket in two phases.
  //  1. Under the read lock: shutdown() (POSIX) or closesocket() (Windows).
  //     This is intended to make threads blocked in accept()/recv() on this
  //     socket return, so they release their read locks.
  //  2. Under the write lock: close() the descriptor (POSIX) and mark the
  //     socket invalid.
  void Close() override {
    {
      RLock l(mutex);
      if (s != InvalidSocket) {
#if defined(_WIN32)
        closesocket(s);
#else
        ::shutdown(s, SHUT_RDWR);
#endif
      }
    }

    WLock l(mutex);
    if (s != InvalidSocket) {
#if !defined(_WIN32)
      ::close(s);
#endif
      s = InvalidSocket;
    }
  }

  // Receives up to |bytes| bytes into |buffer|. Returns the number of bytes
  // read, or 0 if the socket is invalid, closed by the peer, or on error.
  size_t Read(void* buffer, size_t bytes) override {
    RLock lock(mutex);
    if (s == InvalidSocket) {
      return 0;
    }
    auto len = recv(s, reinterpret_cast<char*>(buffer), clampLength(bytes), 0);
    return (len < 0) ? 0 : static_cast<size_t>(len);
  }

  // Sends all |bytes| bytes of |buffer|. Returns true only if the whole
  // buffer was sent. Writing 0 bytes to a valid socket succeeds.
  bool Write(const void* buffer, size_t bytes) override {
    RLock lock(mutex);
    if (s == InvalidSocket) {
      return false;
    }
    // send() may transmit fewer bytes than requested, so keep sending until
    // the whole buffer is written or an error occurs.
    auto* data = reinterpret_cast<const char*>(buffer);
    while (bytes > 0) {
      const auto sent = ::send(s, data, clampLength(bytes), 0);
      if (sent <= 0) {
        return false;
      }
      data += sent;
      bytes -= static_cast<size_t>(sent);
    }
    return true;
  }

  // Blocks until a connection arrives on this listening socket and returns
  // it. Returns nullptr if this socket is invalid. If ::accept() fails, the
  // returned Impl holds InvalidSocket; IsOpen() reports this.
  std::shared_ptr<Socket> Accept() override {
    std::shared_ptr<Impl> out;
    lock([&](SOCKET socket, const addrinfo*) {
      if (socket != InvalidSocket) {
        init();  // Balanced by term() in the new Impl's destructor.
        out = std::make_shared<Impl>(::accept(socket, 0, 0));
        out->setOptions();
      }
    });
    return out;
  }

 private:
  // Clamps an I/O length to a value that fits in an int: Winsock's
  // recv()/send() take the length as an int. Larger requests are served
  // by partial reads, or by looping in Write().
  static int clampLength(size_t bytes) {
    constexpr size_t kMax =
        static_cast<size_t>(std::numeric_limits<int>::max());
    return static_cast<int>(bytes < kMax ? bytes : kMax);
  }

  addrinfo* const info;
  SOCKET s = InvalidSocket;
  RWMutex mutex;
};
234
235 } // anonymous namespace
236
Listen(const char * address,const char * port)237 std::shared_ptr<Socket> Socket::Listen(const char* address, const char* port) {
238 auto impl = Impl::create(address, port);
239 if (!impl) {
240 return nullptr;
241 }
242 impl->lock([&](SOCKET socket, const addrinfo* info) {
243 if (bind(socket, info->ai_addr, static_cast<int>(info->ai_addrlen)) != 0) {
244 impl.reset();
245 return;
246 }
247
248 if (listen(socket, 0) != 0) {
249 impl.reset();
250 return;
251 }
252 });
253 return impl;
254 }
255
Connect(const char * address,const char * port,uint32_t timeoutMillis)256 std::shared_ptr<Socket> Socket::Connect(const char* address,
257 const char* port,
258 uint32_t timeoutMillis) {
259 auto impl = Impl::create(address, port);
260 if (!impl) {
261 return nullptr;
262 }
263
264 std::shared_ptr<Socket> out;
265 impl->lock([&](SOCKET socket, const addrinfo* info) {
266 if (socket == InvalidSocket) {
267 return;
268 }
269
270 if (timeoutMillis == 0) {
271 if (::connect(socket, info->ai_addr,
272 static_cast<int>(info->ai_addrlen)) == 0) {
273 out = impl;
274 }
275 return;
276 }
277
278 if (!setBlocking(socket, false)) {
279 return;
280 }
281
282 auto res =
283 ::connect(socket, info->ai_addr, static_cast<int>(info->ai_addrlen));
284 if (res == 0) {
285 if (setBlocking(socket, true)) {
286 out = impl;
287 }
288 } else {
289 const auto microseconds = timeoutMillis * 1000;
290
291 fd_set fdset;
292 FD_ZERO(&fdset);
293 FD_SET(socket, &fdset);
294
295 timeval tv;
296 tv.tv_sec = microseconds / 1000000;
297 tv.tv_usec = microseconds - static_cast<uint32_t>(tv.tv_sec * 1000000);
298 res = select(static_cast<int>(socket + 1), nullptr, &fdset, nullptr, &tv);
299 if (res > 0 && !errored(socket) && setBlocking(socket, true)) {
300 out = impl;
301 }
302 }
303 });
304
305 if (!out) {
306 return nullptr;
307 }
308
309 return out->IsOpen() ? out : nullptr;
310 }
311