1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include "pw_stream/socket_stream.h"
16
17 #if defined(_WIN32) && _WIN32
18 #include <fcntl.h>
19 #include <io.h>
20 #include <winsock2.h>
21 #include <ws2tcpip.h>
22 #define SHUT_RDWR SD_BOTH
23 #else
24 #include <arpa/inet.h>
25 #include <netdb.h>
26 #include <netinet/in.h>
27 #include <poll.h>
28 #include <sys/socket.h>
29 #include <sys/types.h>
30 #include <unistd.h>
31 #endif // defined(_WIN32) && _WIN32
32
33 #include <cerrno>
34 #include <cstring>
35
36 #include "pw_assert/check.h"
37 #include "pw_log/log.h"
38 #include "pw_status/status.h"
39 #include "pw_string/to_string.h"
40
41 namespace pw::stream {
42 namespace {
43
// Backlog argument passed to listen() by ServerSocket::Listen().
constexpr uint32_t kServerBacklogLength{1};
// Host that SocketStream::Connect() falls back to when given a null host.
constexpr const char* kLocalhostAddress{"localhost"};
46
// Applies the per-socket options this file relies on to a newly created
// socket descriptor.
//
// On macOS this enables SO_NOSIGPIPE so that a peer dropping the connection
// does not raise SIGPIPE (MSG_NOSIGNAL is only used on Linux, see DoWrite).
// A failure is logged as a warning and otherwise ignored. On every other
// platform this function does nothing.
void ConfigureSocket([[maybe_unused]] int socket) {
#if defined(__APPLE__)
  constexpr int kEnable = 1;
  const int result =
      setsockopt(socket, SOL_SOCKET, SO_NOSIGPIPE, &kEnable, sizeof(kEnable));
  if (result < 0) {
    PW_LOG_WARN("Failed to set SO_NOSIGPIPE: %s", std::strerror(errno));
  }
#endif  // defined(__APPLE__)
}
58
#if defined(_WIN32) && _WIN32
// POSIX-style shims over the Winsock / CRT equivalents, so the rest of this
// file can use a single spelling for these calls on every platform.

int close(SOCKET s) { return closesocket(s); }

ssize_t write(int fd, const void* buf, size_t count) {
  return _write(fd, buf, count);
}

int poll(struct pollfd* fds, unsigned int nfds, int timeout) {
  return WSAPoll(fds, nfds, timeout);
}

// NOTE(review): _pipe() yields CRT file descriptors, not sockets. Whether
// WSAPoll() and closesocket() (reached through the close() shim above) accept
// them is not established here -- confirm on Windows.
int pipe(int pipefd[2]) { return _pipe(pipefd, 256, O_BINARY); }

// Adapts the POSIX setsockopt() signature to Winsock's, which takes a SOCKET,
// a const char* option value and an int length.
int setsockopt(
    int fd, int level, int optname, const void* optval, unsigned int optlen) {
  // The call must be qualified. An unqualified `setsockopt` here resolves to
  // this very function, because it hides the global Winsock declaration, and
  // the function would recurse until the stack overflows.
  return ::setsockopt(static_cast<SOCKET>(fd),
                      level,
                      optname,
                      static_cast<const char*>(optval),
                      static_cast<int>(optlen));
}

// Starts Winsock 2.2 on construction and cleans it up on destruction.
class WinsockInitializer {
 public:
  WinsockInitializer() {
    WSADATA data = {};
    PW_CHECK_INT_EQ(
        WSAStartup(MAKEWORD(2, 2), &data), 0, "Failed to initialize winsock");
  }
  ~WinsockInitializer() {
    // TODO: b/301545011 - This currently fails, probably a cleanup race.
    WSACleanup();
  }
};

// Namespace-scope instance: WSAStartup() runs during static initialization
// and WSACleanup() during static destruction.
[[maybe_unused]] WinsockInitializer initializer;

#endif  // defined(_WIN32) && _WIN32
97
98 } // namespace
99
Connect(const char * host,uint16_t port)100 Status SocketStream::SocketStream::Connect(const char* host, uint16_t port) {
101 if (host == nullptr) {
102 host = kLocalhostAddress;
103 }
104
105 struct addrinfo hints = {};
106 struct addrinfo* res;
107 char port_buffer[6];
108 PW_CHECK(ToString(port, port_buffer).ok());
109 hints.ai_family = AF_UNSPEC;
110 hints.ai_socktype = SOCK_STREAM;
111 hints.ai_flags = AI_NUMERICSERV;
112 if (getaddrinfo(host, port_buffer, &hints, &res) != 0) {
113 PW_LOG_ERROR("Failed to configure connection address for socket");
114 return Status::InvalidArgument();
115 }
116
117 struct addrinfo* rp;
118 int connection_fd;
119 for (rp = res; rp != nullptr; rp = rp->ai_next) {
120 connection_fd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
121 if (connection_fd != kInvalidFd) {
122 break;
123 }
124 }
125
126 if (connection_fd == kInvalidFd) {
127 PW_LOG_ERROR("Failed to create a socket: %s", std::strerror(errno));
128 freeaddrinfo(res);
129 return Status::Unknown();
130 }
131
132 ConfigureSocket(connection_fd);
133 if (connect(connection_fd, rp->ai_addr, rp->ai_addrlen) == -1) {
134 close(connection_fd);
135 PW_LOG_ERROR(
136 "Failed to connect to %s:%d: %s", host, port, std::strerror(errno));
137 freeaddrinfo(res);
138 return Status::Unknown();
139 }
140
141 // Mark as ready and take ownership of the connection by this object.
142 {
143 std::lock_guard lock(connection_mutex_);
144 connection_fd_ = connection_fd;
145 TakeConnectionWithLockHeld();
146 ready_ = true;
147 }
148
149 freeaddrinfo(res);
150 return OkStatus();
151 }
152
153 // Configures socket options.
SetSockOpt(int level,int optname,const void * optval,unsigned int optlen)154 int SocketStream::SetSockOpt(int level,
155 int optname,
156 const void* optval,
157 unsigned int optlen) {
158 ConnectionOwnership ownership(this);
159 if (ownership.fd() == kInvalidFd) {
160 return EBADF;
161 }
162 return setsockopt(ownership.fd(), level, optname, optval, optlen);
163 }
164
IsReady()165 bool SocketStream::IsReady() {
166 std::lock_guard lock(connection_mutex_);
167 return ready_;
168 }
169
Close()170 void SocketStream::Close() {
171 ConnectionOwnership ownership(this);
172 {
173 std::lock_guard lock(connection_mutex_);
174 if (ready_) {
175 // Shutdown the connection and send tear down notification to unblock any
176 // waiters.
177 if (connection_fd_ != kInvalidFd) {
178 shutdown(connection_fd_, SHUT_RDWR);
179 }
180 if (connection_pipe_w_fd_ != kInvalidFd) {
181 write(connection_pipe_w_fd_, "T", 1);
182 }
183
184 // Release ownership of the connection by this object and mark as no
185 // longer ready.
186 ReleaseConnectionWithLockHeld();
187 ready_ = false;
188 }
189 }
190 }
191
DoWrite(span<const std::byte> data)192 Status SocketStream::DoWrite(span<const std::byte> data) {
193 int send_flags = 0;
194 #if defined(__linux__)
195 // Use MSG_NOSIGNAL to avoid getting a SIGPIPE signal when the remote
196 // peer drops the connection. This is supported on Linux only.
197 send_flags |= MSG_NOSIGNAL;
198 #endif // defined(__linux__)
199
200 ssize_t bytes_sent;
201 {
202 ConnectionOwnership ownership(this);
203 if (ownership.fd() == kInvalidFd) {
204 return Status::Unknown();
205 }
206 bytes_sent = send(ownership.fd(),
207 reinterpret_cast<const char*>(data.data()),
208 data.size_bytes(),
209 send_flags);
210 }
211
212 if (bytes_sent < 0 || static_cast<size_t>(bytes_sent) != data.size()) {
213 if (errno == EPIPE) {
214 // An EPIPE indicates that the connection is closed. Return an OutOfRange
215 // error.
216 return Status::OutOfRange();
217 }
218
219 return Status::Unknown();
220 }
221 return OkStatus();
222 }
223
DoRead(ByteSpan dest)224 StatusWithSize SocketStream::DoRead(ByteSpan dest) {
225 ConnectionOwnership ownership(this);
226 if (ownership.fd() == kInvalidFd) {
227 return StatusWithSize::Unknown();
228 }
229
230 // Wait for data to read or a tear down notification.
231 pollfd fds_to_poll[2];
232 fds_to_poll[0].fd = ownership.fd();
233 fds_to_poll[0].events = POLLIN | POLLERR | POLLHUP;
234 fds_to_poll[1].fd = ownership.pipe_r_fd();
235 fds_to_poll[1].events = POLLIN;
236 poll(fds_to_poll, 2, -1);
237 if (!(fds_to_poll[0].revents & POLLIN)) {
238 return StatusWithSize::Unknown();
239 }
240
241 ssize_t bytes_rcvd = recv(ownership.fd(),
242 reinterpret_cast<char*>(dest.data()),
243 dest.size_bytes(),
244 0);
245 if (bytes_rcvd == 0) {
246 // Remote peer has closed the connection.
247 Close();
248 return StatusWithSize::OutOfRange();
249 } else if (bytes_rcvd < 0) {
250 if (errno == EAGAIN || errno == EWOULDBLOCK) {
251 // Socket timed out when trying to read.
252 // This should only occur if SO_RCVTIMEO was configured to be nonzero, or
253 // if the socket was opened with the O_NONBLOCK flag to prevent any
254 // blocking when performing reads or writes.
255 return StatusWithSize::ResourceExhausted();
256 }
257 return StatusWithSize::Unknown();
258 }
259 return StatusWithSize(bytes_rcvd);
260 }
261
TakeConnection()262 int SocketStream::TakeConnection() {
263 std::lock_guard lock(connection_mutex_);
264 return TakeConnectionWithLockHeld();
265 }
266
TakeConnectionWithLockHeld()267 int SocketStream::TakeConnectionWithLockHeld() {
268 ++connection_own_count_;
269
270 if (ready_ && (connection_fd_ != kInvalidFd) &&
271 (connection_pipe_r_fd_ == kInvalidFd)) {
272 int fd_list[2];
273 if (pipe(fd_list) >= 0) {
274 connection_pipe_r_fd_ = fd_list[0];
275 connection_pipe_w_fd_ = fd_list[1];
276 }
277 }
278
279 if (!ready_ || (connection_pipe_r_fd_ == kInvalidFd) ||
280 (connection_pipe_w_fd_ == kInvalidFd)) {
281 return kInvalidFd;
282 }
283 return connection_fd_;
284 }
285
ReleaseConnection()286 void SocketStream::ReleaseConnection() {
287 std::lock_guard lock(connection_mutex_);
288 ReleaseConnectionWithLockHeld();
289 }
290
ReleaseConnectionWithLockHeld()291 void SocketStream::ReleaseConnectionWithLockHeld() {
292 --connection_own_count_;
293
294 if (connection_own_count_ <= 0) {
295 ready_ = false;
296 if (connection_fd_ != kInvalidFd) {
297 close(connection_fd_);
298 connection_fd_ = kInvalidFd;
299 }
300 if (connection_pipe_r_fd_ != kInvalidFd) {
301 close(connection_pipe_r_fd_);
302 connection_pipe_r_fd_ = kInvalidFd;
303 }
304 if (connection_pipe_w_fd_ != kInvalidFd) {
305 close(connection_pipe_w_fd_);
306 connection_pipe_w_fd_ = kInvalidFd;
307 }
308 }
309 }
310
311 // Listen for connections on the given port.
312 // If port is 0, a random unused port is chosen and can be retrieved with
313 // port().
Listen(uint16_t port)314 Status ServerSocket::Listen(uint16_t port) {
315 int socket_fd = socket(AF_INET6, SOCK_STREAM, 0);
316 if (socket_fd == kInvalidFd) {
317 return Status::Unknown();
318 }
319
320 // Allow binding to an address that may still be in use by a closed socket.
321 constexpr int value = 1;
322 setsockopt(socket_fd,
323 SOL_SOCKET,
324 SO_REUSEADDR,
325 reinterpret_cast<const char*>(&value),
326 sizeof(int));
327
328 if (port != 0) {
329 struct sockaddr_in6 addr = {};
330 socklen_t addr_len = sizeof(addr);
331 addr.sin6_family = AF_INET6;
332 addr.sin6_port = htons(port);
333 addr.sin6_addr = in6addr_any;
334 if (bind(socket_fd, reinterpret_cast<sockaddr*>(&addr), addr_len) < 0) {
335 close(socket_fd);
336 return Status::Unknown();
337 }
338 }
339
340 if (listen(socket_fd, kServerBacklogLength) < 0) {
341 close(socket_fd);
342 return Status::Unknown();
343 }
344
345 // Find out which port the socket is listening on, and fill in port_.
346 struct sockaddr_in6 addr = {};
347 socklen_t addr_len = sizeof(addr);
348 if (getsockname(socket_fd, reinterpret_cast<sockaddr*>(&addr), &addr_len) <
349 0 ||
350 static_cast<size_t>(addr_len) > sizeof(addr)) {
351 close(socket_fd);
352 return Status::Unknown();
353 }
354
355 port_ = ntohs(addr.sin6_port);
356
357 // Mark as ready and take ownership of the socket by this object.
358 {
359 std::lock_guard lock(socket_mutex_);
360 socket_fd_ = socket_fd;
361 TakeSocketWithLockHeld();
362 ready_ = true;
363 }
364
365 return OkStatus();
366 }
367
368 // Accept a connection. Blocks until after a client is connected.
369 // On success, returns a SocketStream connected to the new client.
Accept()370 Result<SocketStream> ServerSocket::Accept() {
371 struct sockaddr_in6 sockaddr_client_ = {};
372 socklen_t len = sizeof(sockaddr_client_);
373
374 SocketOwnership ownership(this);
375 if (ownership.fd() == kInvalidFd) {
376 return Status::Unknown();
377 }
378
379 // Wait for a connection or a tear down notification.
380 pollfd fds_to_poll[2];
381 fds_to_poll[0].fd = ownership.fd();
382 fds_to_poll[0].events = POLLIN | POLLERR | POLLHUP;
383 fds_to_poll[1].fd = ownership.pipe_r_fd();
384 fds_to_poll[1].events = POLLIN;
385 int rv = poll(fds_to_poll, 2, -1);
386 if ((rv <= 0) || !(fds_to_poll[0].revents & POLLIN)) {
387 return Status::Unknown();
388 }
389
390 int connection_fd = accept(
391 ownership.fd(), reinterpret_cast<sockaddr*>(&sockaddr_client_), &len);
392 if (connection_fd == kInvalidFd) {
393 return Status::Unknown();
394 }
395 ConfigureSocket(connection_fd);
396
397 return SocketStream(connection_fd);
398 }
399
400 // Close the server socket, preventing further connections.
Close()401 void ServerSocket::Close() {
402 SocketOwnership ownership(this);
403 {
404 std::lock_guard lock(socket_mutex_);
405 if (ready_) {
406 // Shutdown the socket and send tear down notification to unblock any
407 // waiters.
408 if (socket_fd_ != kInvalidFd) {
409 shutdown(socket_fd_, SHUT_RDWR);
410 }
411 if (socket_pipe_w_fd_ != kInvalidFd) {
412 write(socket_pipe_w_fd_, "T", 1);
413 }
414
415 // Release ownership of the socket by this object and mark as no longer
416 // ready.
417 ReleaseSocketWithLockHeld();
418 ready_ = false;
419 }
420 }
421 }
422
TakeSocket()423 int ServerSocket::TakeSocket() {
424 std::lock_guard lock(socket_mutex_);
425 return TakeSocketWithLockHeld();
426 }
427
TakeSocketWithLockHeld()428 int ServerSocket::TakeSocketWithLockHeld() {
429 ++socket_own_count_;
430
431 if (ready_ && (socket_fd_ != kInvalidFd) &&
432 (socket_pipe_r_fd_ == kInvalidFd)) {
433 int fd_list[2];
434 if (pipe(fd_list) >= 0) {
435 socket_pipe_r_fd_ = fd_list[0];
436 socket_pipe_w_fd_ = fd_list[1];
437 }
438 }
439
440 if (!ready_ || (socket_pipe_r_fd_ == kInvalidFd) ||
441 (socket_pipe_w_fd_ == kInvalidFd)) {
442 return kInvalidFd;
443 }
444 return socket_fd_;
445 }
446
ReleaseSocket()447 void ServerSocket::ReleaseSocket() {
448 std::lock_guard lock(socket_mutex_);
449 ReleaseSocketWithLockHeld();
450 }
451
ReleaseSocketWithLockHeld()452 void ServerSocket::ReleaseSocketWithLockHeld() {
453 --socket_own_count_;
454
455 if (socket_own_count_ <= 0) {
456 ready_ = false;
457 if (socket_fd_ != kInvalidFd) {
458 close(socket_fd_);
459 socket_fd_ = kInvalidFd;
460 }
461 if (socket_pipe_r_fd_ != kInvalidFd) {
462 close(socket_pipe_r_fd_);
463 socket_pipe_r_fd_ = kInvalidFd;
464 }
465 if (socket_pipe_w_fd_ != kInvalidFd) {
466 close(socket_pipe_w_fd_);
467 socket_pipe_w_fd_ = kInvalidFd;
468 }
469 }
470 }
471
472 } // namespace pw::stream
473