• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_stream/socket_stream.h"
16 
17 #if defined(_WIN32) && _WIN32
18 #include <fcntl.h>
19 #include <io.h>
20 #include <winsock2.h>
21 #include <ws2tcpip.h>
22 #define SHUT_RDWR SD_BOTH
23 #else
24 #include <arpa/inet.h>
25 #include <netdb.h>
26 #include <netinet/in.h>
27 #include <poll.h>
28 #include <sys/socket.h>
29 #include <sys/types.h>
30 #include <unistd.h>
31 #endif  // defined(_WIN32) && _WIN32
32 
33 #include <cerrno>
34 #include <cstring>
35 
36 #include "pw_assert/check.h"
37 #include "pw_log/log.h"
38 #include "pw_status/status.h"
39 #include "pw_string/to_string.h"
40 
41 namespace pw::stream {
42 namespace {
43 
// Maximum number of pending, not-yet-accepted connections passed to listen().
constexpr uint32_t kServerBacklogLength{1};
// Host used by SocketStream::Connect() when the caller passes a null host.
constexpr const char* kLocalhostAddress{"localhost"};
46 
// Applies the per-socket options this module relies on to `socket`.
//
// Only macOS needs anything: the socket is marked SO_NOSIGPIPE so that a peer
// dropping the connection does not raise SIGPIPE. A failure to set the option
// is logged as a warning and otherwise ignored. On every other platform this
// function does nothing.
void ConfigureSocket([[maybe_unused]] int socket) {
#if defined(__APPLE__)
  constexpr int kEnable = 1;
  const int result =
      setsockopt(socket, SOL_SOCKET, SO_NOSIGPIPE, &kEnable, sizeof(kEnable));
  if (result < 0) {
    PW_LOG_WARN("Failed to set SO_NOSIGPIPE: %s", std::strerror(errno));
  }
#endif  // defined(__APPLE__)
}
58 
#if defined(_WIN32) && _WIN32
// POSIX-style shims so the socket code in this file can be shared between
// Windows and POSIX hosts.
//
// NOTE(review): pipe() below yields CRT file descriptors, while poll() maps to
// WSAPoll() and close() maps to closesocket(), both of which operate on
// Winsock sockets. Polling/closing the tear-down pipe through these shims may
// not work on Windows — confirm.

// Closes a Winsock socket handle.
int close(SOCKET s) { return closesocket(s); }

// Writes to a CRT file descriptor.
ssize_t write(int fd, const void* buf, size_t count) {
  return _write(fd, buf, count);
}

// Waits on socket events via WSAPoll().
int poll(struct pollfd* fds, unsigned int nfds, int timeout) {
  return WSAPoll(fds, nfds, timeout);
}

// Creates a binary-mode CRT pipe with a 256-byte buffer.
int pipe(int pipefd[2]) { return _pipe(pipefd, 256, O_BINARY); }

// Adapts the POSIX setsockopt() signature (int fd, const void* optval) to
// Winsock's (SOCKET, const char*, int).
int setsockopt(
    int fd, int level, int optname, const void* optval, unsigned int optlen) {
  // Must be qualified: unqualified lookup from inside this namespace finds this
  // very wrapper (hiding the global Winsock function), which would make the
  // call recurse into itself forever.
  return ::setsockopt(static_cast<SOCKET>(fd),
                      level,
                      optname,
                      static_cast<const char*>(optval),
                      static_cast<int>(optlen));
}

// Initializes Winsock 2.2 for the lifetime of the program; construction
// failure is fatal (PW_CHECK).
class WinsockInitializer {
 public:
  WinsockInitializer() {
    WSADATA data = {};
    PW_CHECK_INT_EQ(
        WSAStartup(MAKEWORD(2, 2), &data), 0, "Failed to initialize winsock");
  }
  ~WinsockInitializer() {
    // TODO: b/301545011 - This currently fails, probably a cleanup race.
    WSACleanup();
  }
};

[[maybe_unused]] WinsockInitializer initializer;

#endif  // defined(_WIN32) && _WIN32
97 
98 }  // namespace
99 
Connect(const char * host,uint16_t port)100 Status SocketStream::SocketStream::Connect(const char* host, uint16_t port) {
101   if (host == nullptr) {
102     host = kLocalhostAddress;
103   }
104 
105   struct addrinfo hints = {};
106   struct addrinfo* res;
107   char port_buffer[6];
108   PW_CHECK(ToString(port, port_buffer).ok());
109   hints.ai_family = AF_UNSPEC;
110   hints.ai_socktype = SOCK_STREAM;
111   hints.ai_flags = AI_NUMERICSERV;
112   if (getaddrinfo(host, port_buffer, &hints, &res) != 0) {
113     PW_LOG_ERROR("Failed to configure connection address for socket");
114     return Status::InvalidArgument();
115   }
116 
117   struct addrinfo* rp;
118   int connection_fd;
119   for (rp = res; rp != nullptr; rp = rp->ai_next) {
120     connection_fd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
121     if (connection_fd != kInvalidFd) {
122       break;
123     }
124   }
125 
126   if (connection_fd == kInvalidFd) {
127     PW_LOG_ERROR("Failed to create a socket: %s", std::strerror(errno));
128     freeaddrinfo(res);
129     return Status::Unknown();
130   }
131 
132   ConfigureSocket(connection_fd);
133   if (connect(connection_fd, rp->ai_addr, rp->ai_addrlen) == -1) {
134     close(connection_fd);
135     PW_LOG_ERROR(
136         "Failed to connect to %s:%d: %s", host, port, std::strerror(errno));
137     freeaddrinfo(res);
138     return Status::Unknown();
139   }
140 
141   // Mark as ready and take ownership of the connection by this object.
142   {
143     std::lock_guard lock(connection_mutex_);
144     connection_fd_ = connection_fd;
145     TakeConnectionWithLockHeld();
146     ready_ = true;
147   }
148 
149   freeaddrinfo(res);
150   return OkStatus();
151 }
152 
153 // Configures socket options.
SetSockOpt(int level,int optname,const void * optval,unsigned int optlen)154 int SocketStream::SetSockOpt(int level,
155                              int optname,
156                              const void* optval,
157                              unsigned int optlen) {
158   ConnectionOwnership ownership(this);
159   if (ownership.fd() == kInvalidFd) {
160     return EBADF;
161   }
162   return setsockopt(ownership.fd(), level, optname, optval, optlen);
163 }
164 
IsReady()165 bool SocketStream::IsReady() {
166   std::lock_guard lock(connection_mutex_);
167   return ready_;
168 }
169 
Close()170 void SocketStream::Close() {
171   ConnectionOwnership ownership(this);
172   {
173     std::lock_guard lock(connection_mutex_);
174     if (ready_) {
175       // Shutdown the connection and send tear down notification to unblock any
176       // waiters.
177       if (connection_fd_ != kInvalidFd) {
178         shutdown(connection_fd_, SHUT_RDWR);
179       }
180       if (connection_pipe_w_fd_ != kInvalidFd) {
181         write(connection_pipe_w_fd_, "T", 1);
182       }
183 
184       // Release ownership of the connection by this object and mark as no
185       // longer ready.
186       ReleaseConnectionWithLockHeld();
187       ready_ = false;
188     }
189   }
190 }
191 
DoWrite(span<const std::byte> data)192 Status SocketStream::DoWrite(span<const std::byte> data) {
193   int send_flags = 0;
194 #if defined(__linux__)
195   // Use MSG_NOSIGNAL to avoid getting a SIGPIPE signal when the remote
196   // peer drops the connection. This is supported on Linux only.
197   send_flags |= MSG_NOSIGNAL;
198 #endif  // defined(__linux__)
199 
200   ssize_t bytes_sent;
201   {
202     ConnectionOwnership ownership(this);
203     if (ownership.fd() == kInvalidFd) {
204       return Status::Unknown();
205     }
206     bytes_sent = send(ownership.fd(),
207                       reinterpret_cast<const char*>(data.data()),
208                       data.size_bytes(),
209                       send_flags);
210   }
211 
212   if (bytes_sent < 0 || static_cast<size_t>(bytes_sent) != data.size()) {
213     if (errno == EPIPE) {
214       // An EPIPE indicates that the connection is closed.  Return an OutOfRange
215       // error.
216       return Status::OutOfRange();
217     }
218 
219     return Status::Unknown();
220   }
221   return OkStatus();
222 }
223 
DoRead(ByteSpan dest)224 StatusWithSize SocketStream::DoRead(ByteSpan dest) {
225   ConnectionOwnership ownership(this);
226   if (ownership.fd() == kInvalidFd) {
227     return StatusWithSize::Unknown();
228   }
229 
230   // Wait for data to read or a tear down notification.
231   pollfd fds_to_poll[2];
232   fds_to_poll[0].fd = ownership.fd();
233   fds_to_poll[0].events = POLLIN | POLLERR | POLLHUP;
234   fds_to_poll[1].fd = ownership.pipe_r_fd();
235   fds_to_poll[1].events = POLLIN;
236   poll(fds_to_poll, 2, -1);
237   if (!(fds_to_poll[0].revents & POLLIN)) {
238     return StatusWithSize::Unknown();
239   }
240 
241   ssize_t bytes_rcvd = recv(ownership.fd(),
242                             reinterpret_cast<char*>(dest.data()),
243                             dest.size_bytes(),
244                             0);
245   if (bytes_rcvd == 0) {
246     // Remote peer has closed the connection.
247     Close();
248     return StatusWithSize::OutOfRange();
249   } else if (bytes_rcvd < 0) {
250     if (errno == EAGAIN || errno == EWOULDBLOCK) {
251       // Socket timed out when trying to read.
252       // This should only occur if SO_RCVTIMEO was configured to be nonzero, or
253       // if the socket was opened with the O_NONBLOCK flag to prevent any
254       // blocking when performing reads or writes.
255       return StatusWithSize::ResourceExhausted();
256     }
257     return StatusWithSize::Unknown();
258   }
259   return StatusWithSize(bytes_rcvd);
260 }
261 
TakeConnection()262 int SocketStream::TakeConnection() {
263   std::lock_guard lock(connection_mutex_);
264   return TakeConnectionWithLockHeld();
265 }
266 
TakeConnectionWithLockHeld()267 int SocketStream::TakeConnectionWithLockHeld() {
268   ++connection_own_count_;
269 
270   if (ready_ && (connection_fd_ != kInvalidFd) &&
271       (connection_pipe_r_fd_ == kInvalidFd)) {
272     int fd_list[2];
273     if (pipe(fd_list) >= 0) {
274       connection_pipe_r_fd_ = fd_list[0];
275       connection_pipe_w_fd_ = fd_list[1];
276     }
277   }
278 
279   if (!ready_ || (connection_pipe_r_fd_ == kInvalidFd) ||
280       (connection_pipe_w_fd_ == kInvalidFd)) {
281     return kInvalidFd;
282   }
283   return connection_fd_;
284 }
285 
ReleaseConnection()286 void SocketStream::ReleaseConnection() {
287   std::lock_guard lock(connection_mutex_);
288   ReleaseConnectionWithLockHeld();
289 }
290 
ReleaseConnectionWithLockHeld()291 void SocketStream::ReleaseConnectionWithLockHeld() {
292   --connection_own_count_;
293 
294   if (connection_own_count_ <= 0) {
295     ready_ = false;
296     if (connection_fd_ != kInvalidFd) {
297       close(connection_fd_);
298       connection_fd_ = kInvalidFd;
299     }
300     if (connection_pipe_r_fd_ != kInvalidFd) {
301       close(connection_pipe_r_fd_);
302       connection_pipe_r_fd_ = kInvalidFd;
303     }
304     if (connection_pipe_w_fd_ != kInvalidFd) {
305       close(connection_pipe_w_fd_);
306       connection_pipe_w_fd_ = kInvalidFd;
307     }
308   }
309 }
310 
311 // Listen for connections on the given port.
312 // If port is 0, a random unused port is chosen and can be retrieved with
313 // port().
Listen(uint16_t port)314 Status ServerSocket::Listen(uint16_t port) {
315   int socket_fd = socket(AF_INET6, SOCK_STREAM, 0);
316   if (socket_fd == kInvalidFd) {
317     return Status::Unknown();
318   }
319 
320   // Allow binding to an address that may still be in use by a closed socket.
321   constexpr int value = 1;
322   setsockopt(socket_fd,
323              SOL_SOCKET,
324              SO_REUSEADDR,
325              reinterpret_cast<const char*>(&value),
326              sizeof(int));
327 
328   if (port != 0) {
329     struct sockaddr_in6 addr = {};
330     socklen_t addr_len = sizeof(addr);
331     addr.sin6_family = AF_INET6;
332     addr.sin6_port = htons(port);
333     addr.sin6_addr = in6addr_any;
334     if (bind(socket_fd, reinterpret_cast<sockaddr*>(&addr), addr_len) < 0) {
335       close(socket_fd);
336       return Status::Unknown();
337     }
338   }
339 
340   if (listen(socket_fd, kServerBacklogLength) < 0) {
341     close(socket_fd);
342     return Status::Unknown();
343   }
344 
345   // Find out which port the socket is listening on, and fill in port_.
346   struct sockaddr_in6 addr = {};
347   socklen_t addr_len = sizeof(addr);
348   if (getsockname(socket_fd, reinterpret_cast<sockaddr*>(&addr), &addr_len) <
349           0 ||
350       static_cast<size_t>(addr_len) > sizeof(addr)) {
351     close(socket_fd);
352     return Status::Unknown();
353   }
354 
355   port_ = ntohs(addr.sin6_port);
356 
357   // Mark as ready and take ownership of the socket by this object.
358   {
359     std::lock_guard lock(socket_mutex_);
360     socket_fd_ = socket_fd;
361     TakeSocketWithLockHeld();
362     ready_ = true;
363   }
364 
365   return OkStatus();
366 }
367 
368 // Accept a connection. Blocks until after a client is connected.
369 // On success, returns a SocketStream connected to the new client.
Accept()370 Result<SocketStream> ServerSocket::Accept() {
371   struct sockaddr_in6 sockaddr_client_ = {};
372   socklen_t len = sizeof(sockaddr_client_);
373 
374   SocketOwnership ownership(this);
375   if (ownership.fd() == kInvalidFd) {
376     return Status::Unknown();
377   }
378 
379   // Wait for a connection or a tear down notification.
380   pollfd fds_to_poll[2];
381   fds_to_poll[0].fd = ownership.fd();
382   fds_to_poll[0].events = POLLIN | POLLERR | POLLHUP;
383   fds_to_poll[1].fd = ownership.pipe_r_fd();
384   fds_to_poll[1].events = POLLIN;
385   int rv = poll(fds_to_poll, 2, -1);
386   if ((rv <= 0) || !(fds_to_poll[0].revents & POLLIN)) {
387     return Status::Unknown();
388   }
389 
390   int connection_fd = accept(
391       ownership.fd(), reinterpret_cast<sockaddr*>(&sockaddr_client_), &len);
392   if (connection_fd == kInvalidFd) {
393     return Status::Unknown();
394   }
395   ConfigureSocket(connection_fd);
396 
397   return SocketStream(connection_fd);
398 }
399 
400 // Close the server socket, preventing further connections.
Close()401 void ServerSocket::Close() {
402   SocketOwnership ownership(this);
403   {
404     std::lock_guard lock(socket_mutex_);
405     if (ready_) {
406       // Shutdown the socket and send tear down notification to unblock any
407       // waiters.
408       if (socket_fd_ != kInvalidFd) {
409         shutdown(socket_fd_, SHUT_RDWR);
410       }
411       if (socket_pipe_w_fd_ != kInvalidFd) {
412         write(socket_pipe_w_fd_, "T", 1);
413       }
414 
415       // Release ownership of the socket by this object and mark as no longer
416       // ready.
417       ReleaseSocketWithLockHeld();
418       ready_ = false;
419     }
420   }
421 }
422 
TakeSocket()423 int ServerSocket::TakeSocket() {
424   std::lock_guard lock(socket_mutex_);
425   return TakeSocketWithLockHeld();
426 }
427 
TakeSocketWithLockHeld()428 int ServerSocket::TakeSocketWithLockHeld() {
429   ++socket_own_count_;
430 
431   if (ready_ && (socket_fd_ != kInvalidFd) &&
432       (socket_pipe_r_fd_ == kInvalidFd)) {
433     int fd_list[2];
434     if (pipe(fd_list) >= 0) {
435       socket_pipe_r_fd_ = fd_list[0];
436       socket_pipe_w_fd_ = fd_list[1];
437     }
438   }
439 
440   if (!ready_ || (socket_pipe_r_fd_ == kInvalidFd) ||
441       (socket_pipe_w_fd_ == kInvalidFd)) {
442     return kInvalidFd;
443   }
444   return socket_fd_;
445 }
446 
ReleaseSocket()447 void ServerSocket::ReleaseSocket() {
448   std::lock_guard lock(socket_mutex_);
449   ReleaseSocketWithLockHeld();
450 }
451 
ReleaseSocketWithLockHeld()452 void ServerSocket::ReleaseSocketWithLockHeld() {
453   --socket_own_count_;
454 
455   if (socket_own_count_ <= 0) {
456     ready_ = false;
457     if (socket_fd_ != kInvalidFd) {
458       close(socket_fd_);
459       socket_fd_ = kInvalidFd;
460     }
461     if (socket_pipe_r_fd_ != kInvalidFd) {
462       close(socket_pipe_r_fd_);
463       socket_pipe_r_fd_ = kInvalidFd;
464     }
465     if (socket_pipe_w_fd_ != kInvalidFd) {
466       close(socket_pipe_w_fd_);
467       socket_pipe_w_fd_ = kInvalidFd;
468     }
469   }
470 }
471 
472 }  // namespace pw::stream
473