1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #ifdef UNSAFE_BUFFERS_BUILD
6 // TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
7 #pragma allow_unsafe_buffers
8 #endif
9
10 #include "base/posix/unix_domain_socket.h"
11
12 #include <stddef.h>
13 #include <stdint.h>
14 #include <sys/socket.h>
15 #include <sys/types.h>
16 #include <unistd.h>
17
18 #include "base/files/file_util.h"
19 #include "base/files/scoped_file.h"
20 #include "base/functional/bind.h"
21 #include "base/functional/callback_helpers.h"
22 #include "base/location.h"
23 #include "base/pickle.h"
24 #include "base/synchronization/waitable_event.h"
25 #include "base/task/single_thread_task_runner.h"
26 #include "base/threading/thread.h"
27 #include "build/build_config.h"
28 #include "testing/gtest/include/gtest/gtest.h"
29
30 namespace base {
31
32 namespace {
33
34 // Callers should use ASSERT_NO_FATAL_FAILURE with this function, to
35 // ensure that execution is aborted if the function has assertion failure.
CreateSocketPair(int fds[2])36 void CreateSocketPair(int fds[2]) {
37 #if BUILDFLAG(IS_APPLE)
38 // Mac OS does not support SOCK_SEQPACKET.
39 int flags = SOCK_STREAM;
40 #else
41 int flags = SOCK_SEQPACKET;
42 #endif
43 ASSERT_EQ(0, socketpair(AF_UNIX, flags, 0, fds));
44 }
45
// A SendRecvMsg() call running on a worker thread must not stay blocked when
// the receiving side drops the reply descriptor without ever answering.
TEST(UnixDomainSocketTest, SendRecvMsgAbortOnReplyFDClose) {
  Thread message_thread("UnixDomainSocketTest");
  ASSERT_TRUE(message_thread.Start());

  int sockets[2];
  ASSERT_NO_FATAL_FAILURE(CreateSocketPair(sockets));
  ScopedFD receiver(sockets[0]);
  ScopedFD sender(sockets[1]);

  // Issue a synchronous request from the worker thread. The request Pickle is
  // copied into the bound callback.
  Pickle request;
  message_thread.task_runner()->PostTask(
      FROM_HERE, BindOnce(IgnoreResult(&UnixDomainSocket::SendRecvMsg),
                          sender.get(), nullptr, 0U, nullptr, request));

  // Read the request here; exactly one descriptor must come with it.
  std::vector<ScopedFD> received_fds;
  uint8_t payload[16];
  const ssize_t received_size = UnixDomainSocket::RecvMsg(
      receiver.get(), payload, sizeof(payload), &received_fds);
  ASSERT_EQ(static_cast<int>(request.size()), received_size);
  ASSERT_EQ(1U, received_fds.size());

  // Drop the attached descriptor without sending any reply on it.
  received_fds.clear();

  // The worker thread runs its tasks in posting order, so this signal can only
  // fire after SendRecvMsg() has returned, i.e. the thread did not get stuck.
  WaitableEvent done(WaitableEvent::ResetPolicy::AUTOMATIC,
                     WaitableEvent::InitialState::NOT_SIGNALED);
  message_thread.task_runner()->PostTask(
      FROM_HERE, BindOnce(&WaitableEvent::Signal, Unretained(&done)));
  ASSERT_TRUE(done.TimedWait(Milliseconds(5000)));
}
78
// SendRecvMsg() on a socket whose peer is closed must fail with EPIPE rather
// than raising SIGPIPE (i.e. it must send with MSG_NOSIGNAL or equivalent).
TEST(UnixDomainSocketTest, SendRecvMsgAvoidsSIGPIPE) {
  // Make sure SIGPIPE isn't being ignored, so that a missing MSG_NOSIGNAL
  // would actually terminate the process instead of passing silently.
  struct sigaction act = {};
  act.sa_handler = SIG_DFL;
  struct sigaction oldact = {};
  ASSERT_EQ(0, sigaction(SIGPIPE, &act, &oldact));

  // Restore the previous SIGPIPE disposition on every exit path, including
  // early returns from failing ASSERTs below. Otherwise later tests in this
  // process would run with a changed SIGPIPE handler.
  ScopedClosureRunner restore_sigpipe(BindOnce(
      [](struct sigaction previous) {
        EXPECT_EQ(0, sigaction(SIGPIPE, &previous, nullptr));
      },
      oldact));

  int fds[2];
  ASSERT_NO_FATAL_FAILURE(CreateSocketPair(fds));
  ScopedFD scoped_fd1(fds[1]);
  // Close the peer so that writing to fds[1] hits a broken connection.
  ASSERT_EQ(0, IGNORE_EINTR(close(fds[0])));

  // Send a synchronous message via the socket on this thread. Unless the
  // message is sent with MSG_NOSIGNAL, this results in SIGPIPE.
  Pickle request;
  const ssize_t result =
      UnixDomainSocket::SendRecvMsg(fds[1], nullptr, 0U, nullptr, request);
  // Capture errno immediately, before any other call can overwrite it.
  const int saved_errno = errno;
  ASSERT_EQ(-1, result);
  ASSERT_EQ(EPIPE, saved_errno);
}
98
99 // Simple sanity check within a single process that receiving PIDs works.
TEST(UnixDomainSocketTest,RecvPid)100 TEST(UnixDomainSocketTest, RecvPid) {
101 int fds[2];
102 ASSERT_NO_FATAL_FAILURE(CreateSocketPair(fds));
103 ScopedFD recv_sock(fds[0]);
104 ScopedFD send_sock(fds[1]);
105
106 ASSERT_TRUE(UnixDomainSocket::EnableReceiveProcessId(recv_sock.get()));
107
108 static const char kHello[] = "hello";
109 ASSERT_TRUE(UnixDomainSocket::SendMsg(send_sock.get(), kHello, sizeof(kHello),
110 std::vector<int>()));
111
112 // Extra receiving buffer space to make sure we really received only
113 // sizeof(kHello) bytes and it wasn't just truncated to fit the buffer.
114 char buf[sizeof(kHello) + 1];
115 ProcessId sender_pid;
116 std::vector<ScopedFD> fd_vec;
117 const ssize_t nread = UnixDomainSocket::RecvMsgWithPid(
118 recv_sock.get(), buf, sizeof(buf), &fd_vec, &sender_pid);
119 ASSERT_EQ(sizeof(kHello), static_cast<size_t>(nread));
120 ASSERT_EQ(0, memcmp(buf, kHello, sizeof(kHello)));
121 ASSERT_EQ(0U, fd_vec.size());
122
123 ASSERT_EQ(getpid(), sender_pid);
124 }
125
126 // Same as above, but send the max number of file descriptors too.
TEST(UnixDomainSocketTest,RecvPidWithMaxDescriptors)127 TEST(UnixDomainSocketTest, RecvPidWithMaxDescriptors) {
128 int fds[2];
129 ASSERT_NO_FATAL_FAILURE(CreateSocketPair(fds));
130 ScopedFD recv_sock(fds[0]);
131 ScopedFD send_sock(fds[1]);
132
133 ASSERT_TRUE(UnixDomainSocket::EnableReceiveProcessId(recv_sock.get()));
134
135 static const char kHello[] = "hello";
136 std::vector<int> send_fds(UnixDomainSocket::kMaxFileDescriptors,
137 send_sock.get());
138 ASSERT_TRUE(UnixDomainSocket::SendMsg(send_sock.get(), kHello, sizeof(kHello),
139 send_fds));
140
141 // Extra receiving buffer space to make sure we really received only
142 // sizeof(kHello) bytes and it wasn't just truncated to fit the buffer.
143 char buf[sizeof(kHello) + 1];
144 ProcessId sender_pid;
145 std::vector<ScopedFD> recv_fds;
146 const ssize_t nread = UnixDomainSocket::RecvMsgWithPid(
147 recv_sock.get(), buf, sizeof(buf), &recv_fds, &sender_pid);
148 ASSERT_EQ(sizeof(kHello), static_cast<size_t>(nread));
149 ASSERT_EQ(0, memcmp(buf, kHello, sizeof(kHello)));
150 ASSERT_EQ(UnixDomainSocket::kMaxFileDescriptors, recv_fds.size());
151
152 ASSERT_EQ(getpid(), sender_pid);
153 }
154
155 // Check that RecvMsgWithPid doesn't DCHECK fail when reading EOF from a
156 // disconnected socket.
TEST(UnixDomianSocketTest,RecvPidDisconnectedSocket)157 TEST(UnixDomianSocketTest, RecvPidDisconnectedSocket) {
158 int fds[2];
159 ASSERT_NO_FATAL_FAILURE(CreateSocketPair(fds));
160 ScopedFD recv_sock(fds[0]);
161 ScopedFD send_sock(fds[1]);
162
163 ASSERT_TRUE(UnixDomainSocket::EnableReceiveProcessId(recv_sock.get()));
164
165 send_sock.reset();
166
167 char ch;
168 ProcessId sender_pid;
169 std::vector<ScopedFD> recv_fds;
170 const ssize_t nread = UnixDomainSocket::RecvMsgWithPid(
171 recv_sock.get(), &ch, sizeof(ch), &recv_fds, &sender_pid);
172 ASSERT_EQ(0, nread);
173 ASSERT_EQ(-1, sender_pid);
174 ASSERT_EQ(0U, recv_fds.size());
175 }
176
177 } // namespace
178
179 } // namespace base
180