• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Tint Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "tools/src/cmd/remote-compile/socket.h"
16 
17 #include "tools/src/cmd/remote-compile/rwmutex.h"
18 
19 #if defined(_WIN32)
20 #include <winsock2.h>
21 #include <ws2tcpip.h>
22 #else
23 #include <netdb.h>
24 #include <netinet/in.h>
25 #include <netinet/tcp.h>
26 #include <sys/select.h>
27 #include <sys/socket.h>
28 #include <sys/time.h>
29 #include <unistd.h>
30 #endif
31 
#if defined(_WIN32)
#include <atomic>
#else
#include <fcntl.h>
#endif

namespace {
#if defined(_WIN32)
// Number of init() calls not yet matched by a term() call. Winsock is started
// when this goes from 0 to 1 and cleaned up when it returns to 0.
std::atomic<int> wsaInitCount = {0};
#else
// On non-Windows platforms a socket handle is a plain int file descriptor.
using SOCKET = int;
#endif
}  // anonymous namespace
43 
44 namespace {
45 constexpr SOCKET InvalidSocket = static_cast<SOCKET>(-1);
// Takes a reference on the platform socket library. On Windows the first
// outstanding reference starts Winsock 2.2; on other platforms this does
// nothing. Callers balance each init() with a term().
void init() {
#if defined(_WIN32)
  const int previous = wsaInitCount.fetch_add(1);
  if (previous != 0) {
    return;  // Winsock was already started by an earlier init().
  }
  WSADATA wsaData;
  (void)WSAStartup(MAKEWORD(2, 2), &wsaData);
#endif
}
54 
// Drops a reference taken by init(). On Windows, Winsock is cleaned up when
// the last outstanding reference is released; on other platforms this does
// nothing.
void term() {
#if defined(_WIN32)
  const int remaining = wsaInitCount.fetch_sub(1) - 1;
  if (remaining == 0) {
    WSACleanup();
  }
#endif
}
62 
setBlocking(SOCKET s,bool blocking)63 bool setBlocking(SOCKET s, bool blocking) {
64 #if defined(_WIN32)
65   u_long mode = blocking ? 0 : 1;
66   return ioctlsocket(s, FIONBIO, &mode) == NO_ERROR;
67 #else
68   auto arg = fcntl(s, F_GETFL, nullptr);
69   if (arg < 0) {
70     return false;
71   }
72   arg = blocking ? (arg & ~O_NONBLOCK) : (arg | O_NONBLOCK);
73   return fcntl(s, F_SETFL, arg) >= 0;
74 #endif
75 }
76 
errored(SOCKET s)77 bool errored(SOCKET s) {
78   if (s == InvalidSocket) {
79     return true;
80   }
81   char error = 0;
82   socklen_t len = sizeof(error);
83   getsockopt(s, SOL_SOCKET, SO_ERROR, &error, &len);
84   return error != 0;
85 }
86 
// Socket implementation wrapping a native socket handle `s`.
//
// Access to `s` is guarded by `mutex`: operations that use the handle hold a
// reader lock (RLock), and operations that invalidate it take a writer lock
// (WLock).
class Impl : public Socket {
 public:
  // Resolves `address` and `port` with getaddrinfo() (IPv4, TCP, AI_PASSIVE)
  // and creates a socket for the first result. The socket is neither bound
  // nor connected.
  //
  // Returns nullptr if resolution produced no result. If ::socket() itself
  // fails, the returned Impl wraps InvalidSocket.
  static std::shared_ptr<Impl> create(const char* address, const char* port) {
    init();

    addrinfo hints = {};
    hints.ai_family = AF_INET;
    hints.ai_socktype = SOCK_STREAM;
    hints.ai_protocol = IPPROTO_TCP;
    hints.ai_flags = AI_PASSIVE;

    addrinfo* info = nullptr;
    auto err = getaddrinfo(address, port, &hints, &info);
#if !defined(_WIN32)
    if (err) {
      printf("getaddrinfo(%s, %s) error: %s\n", address, port,
             gai_strerror(err));
    }
#endif

    if (!info) {
      // Nothing was allocated, so there is nothing to free. freeaddrinfo() is
      // only defined for lists returned by getaddrinfo(), not for nullptr.
      // Balance the init() above.
      term();
      return nullptr;
    }

    auto socket =
        ::socket(info->ai_family, info->ai_socktype, info->ai_protocol);
    // The Impl takes ownership of `info` and frees it on destruction.
    auto out = std::make_shared<Impl>(info, socket);
    out->setOptions();
    return out;
  }

  // Wraps a socket that has no getaddrinfo() result, e.g. one returned by
  // accept().
  explicit Impl(SOCKET socket) : info(nullptr), s(socket) {}
  // Wraps `socket`, taking ownership of the getaddrinfo() list `info`.
  Impl(addrinfo* info, SOCKET socket) : info(info), s(socket) {}

  ~Impl() {
    Close();
    // `info` is null for accepted sockets; only free real getaddrinfo() lists.
    if (info) {
      freeaddrinfo(info);
    }
    term();
  }

  // Calls f(s, info) while holding a reader lock on this socket.
  // `f` must not destroy this Impl or take a writer lock on it: the reader
  // lock is still held while `f` runs.
  template <typename FUNCTION>
  void lock(FUNCTION&& f) {
    RLock l(mutex);
    f(s, info);
  }

  // Applies the options used for every socket (listening, connecting and
  // accepted). Does nothing if the socket is invalid.
  void setOptions() {
    RLock l(mutex);
    if (s == InvalidSocket) {
      return;
    }

    int enable = 1;

#if !defined(_WIN32)
    // Prevent sockets lingering after process termination, causing
    // reconnection issues on the same port.
    setsockopt(s, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<char*>(&enable),
               sizeof(enable));

    // Use the system's `struct linger` rather than a hand-written look-alike,
    // so the option value always has the layout the OS expects.
    struct linger lingerOpt = {};
    lingerOpt.l_onoff = 0;  // linger disabled
    lingerOpt.l_linger = 0;
    setsockopt(s, SOL_SOCKET, SO_LINGER, reinterpret_cast<char*>(&lingerOpt),
               sizeof(lingerOpt));
#endif  // !defined(_WIN32)

    // Enable TCP_NODELAY.
    setsockopt(s, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char*>(&enable),
               sizeof(enable));
  }

  // Returns true if the socket is valid and has no pending error.
  // Otherwise the socket is closed and invalidated, and false is returned.
  bool IsOpen() override {
    {
      RLock l(mutex);
      if ((s != InvalidSocket) && !errored(s)) {
        return true;
      }
    }
    WLock w(mutex);
    if (s != InvalidSocket) {
      // Release the OS handle before forgetting it. Just overwriting `s`
      // would leak the descriptor, because Close() and ~Impl() only release
      // a handle that is still stored in `s`.
#if defined(_WIN32)
      closesocket(s);
#else
      ::close(s);
#endif
      s = InvalidSocket;
    }
    return false;
  }

  // Closes the socket. Safe to call more than once.
  void Close() override {
    {
      // NOTE(review): presumably shutdown/closesocket is issued under the
      // reader lock so that threads blocked in Read()/Accept() (which also
      // hold reader locks) return before the writer lock below is requested.
      RLock l(mutex);
      if (s != InvalidSocket) {
#if defined(_WIN32)
        closesocket(s);
#else
        ::shutdown(s, SHUT_RDWR);
#endif
      }
    }

    WLock l(mutex);
    if (s != InvalidSocket) {
#if !defined(_WIN32)
      ::close(s);
#endif
      s = InvalidSocket;
    }
  }

  // Reads up to `bytes` bytes into `buffer`. Returns the number of bytes
  // read, or 0 if the socket is invalid or recv() fails.
  size_t Read(void* buffer, size_t bytes) override {
    RLock lock(mutex);
    if (s == InvalidSocket) {
      return 0;
    }
    auto len =
        recv(s, reinterpret_cast<char*>(buffer), clampLength(bytes), 0);
    return (len < 0) ? 0 : static_cast<size_t>(len);
  }

  // Writes all `bytes` bytes from `buffer`. send() may transfer fewer bytes
  // than requested, so it is called repeatedly until everything has been
  // sent. Returns false if the socket is invalid or any send() fails.
  bool Write(const void* buffer, size_t bytes) override {
    RLock lock(mutex);
    if (s == InvalidSocket) {
      return false;
    }
    const char* data = reinterpret_cast<const char*>(buffer);
    while (bytes > 0) {
      const auto sent = ::send(s, data, clampLength(bytes), 0);
      if (sent <= 0) {
        return false;
      }
      data += sent;
      bytes -= static_cast<size_t>(sent);
    }
    return true;
  }

  // Blocks until a connection arrives on this listening socket and returns
  // it. Returns nullptr if this socket is invalid; if accept() fails, the
  // returned socket wraps InvalidSocket.
  std::shared_ptr<Socket> Accept() override {
    std::shared_ptr<Impl> out;
    lock([&](SOCKET socket, const addrinfo*) {
      if (socket != InvalidSocket) {
        // Every Impl calls term() on destruction, so take a matching init().
        init();
        out = std::make_shared<Impl>(::accept(socket, nullptr, nullptr));
        out->setOptions();
      }
    });
    return out;
  }

 private:
  // recv()/send() lengths are passed as an int, so clamp requests to INT_MAX
  // instead of letting the cast truncate (or go negative for) large sizes.
  static int clampLength(size_t bytes) {
    constexpr size_t kMaxLength = 0x7fffffff;
    return static_cast<int>(bytes < kMaxLength ? bytes : kMaxLength);
  }

  addrinfo* const info;
  SOCKET s = InvalidSocket;
  RWMutex mutex;
};
234 
235 }  // anonymous namespace
236 
Listen(const char * address,const char * port)237 std::shared_ptr<Socket> Socket::Listen(const char* address, const char* port) {
238   auto impl = Impl::create(address, port);
239   if (!impl) {
240     return nullptr;
241   }
242   impl->lock([&](SOCKET socket, const addrinfo* info) {
243     if (bind(socket, info->ai_addr, static_cast<int>(info->ai_addrlen)) != 0) {
244       impl.reset();
245       return;
246     }
247 
248     if (listen(socket, 0) != 0) {
249       impl.reset();
250       return;
251     }
252   });
253   return impl;
254 }
255 
Connect(const char * address,const char * port,uint32_t timeoutMillis)256 std::shared_ptr<Socket> Socket::Connect(const char* address,
257                                         const char* port,
258                                         uint32_t timeoutMillis) {
259   auto impl = Impl::create(address, port);
260   if (!impl) {
261     return nullptr;
262   }
263 
264   std::shared_ptr<Socket> out;
265   impl->lock([&](SOCKET socket, const addrinfo* info) {
266     if (socket == InvalidSocket) {
267       return;
268     }
269 
270     if (timeoutMillis == 0) {
271       if (::connect(socket, info->ai_addr,
272                     static_cast<int>(info->ai_addrlen)) == 0) {
273         out = impl;
274       }
275       return;
276     }
277 
278     if (!setBlocking(socket, false)) {
279       return;
280     }
281 
282     auto res =
283         ::connect(socket, info->ai_addr, static_cast<int>(info->ai_addrlen));
284     if (res == 0) {
285       if (setBlocking(socket, true)) {
286         out = impl;
287       }
288     } else {
289       const auto microseconds = timeoutMillis * 1000;
290 
291       fd_set fdset;
292       FD_ZERO(&fdset);
293       FD_SET(socket, &fdset);
294 
295       timeval tv;
296       tv.tv_sec = microseconds / 1000000;
297       tv.tv_usec = microseconds - static_cast<uint32_t>(tv.tv_sec * 1000000);
298       res = select(static_cast<int>(socket + 1), nullptr, &fdset, nullptr, &tv);
299       if (res > 0 && !errored(socket) && setBlocking(socket, true)) {
300         out = impl;
301       }
302     }
303   });
304 
305   if (!out) {
306     return nullptr;
307   }
308 
309   return out->IsOpen() ? out : nullptr;
310 }
311