1 // Copyright 2014 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #ifdef UNSAFE_BUFFERS_BUILD
6 // TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
7 #pragma allow_unsafe_buffers
8 #endif
9
10 #include "net/socket/unix_domain_client_socket_posix.h"
11
12 #include <unistd.h>
13
14 #include <memory>
15 #include <utility>
16
17 #include "base/files/file_path.h"
18 #include "base/files/scoped_temp_dir.h"
19 #include "base/functional/bind.h"
20 #include "base/posix/eintr_wrapper.h"
21 #include "build/build_config.h"
22 #include "net/base/io_buffer.h"
23 #include "net/base/net_errors.h"
24 #include "net/base/sockaddr_storage.h"
25 #include "net/base/sockaddr_util_posix.h"
26 #include "net/base/test_completion_callback.h"
27 #include "net/socket/socket_posix.h"
28 #include "net/socket/unix_domain_server_socket_posix.h"
29 #include "net/test/gtest_util.h"
30 #include "net/test/test_with_task_environment.h"
31 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
32 #include "testing/gmock/include/gmock/gmock.h"
33 #include "testing/gtest/include/gtest/gtest.h"
34
35 using net::test::IsError;
36 using net::test::IsOk;
37
38 namespace net {
39 namespace {
40
// Name of the socket file each test binds inside its own temporary directory.
constexpr char kSocketFilename[] = "socket_for_testing";
42
UserCanConnectCallback(bool allow_user,const UnixDomainServerSocket::Credentials & credentials)43 bool UserCanConnectCallback(
44 bool allow_user, const UnixDomainServerSocket::Credentials& credentials) {
45 // Here peers are running in same process.
46 #if BUILDFLAG(IS_LINUX) || BUILDFLAG(IS_CHROMEOS) || BUILDFLAG(IS_ANDROID)
47 EXPECT_EQ(getpid(), credentials.process_id);
48 #endif
49 EXPECT_EQ(getuid(), credentials.user_id);
50 EXPECT_EQ(getgid(), credentials.group_id);
51 return allow_user;
52 }
53
CreateAuthCallback(bool allow_user)54 UnixDomainServerSocket::AuthCallback CreateAuthCallback(bool allow_user) {
55 return base::BindRepeating(&UserCanConnectCallback, allow_user);
56 }
57
58 // Connects socket synchronously.
ConnectSynchronously(StreamSocket * socket)59 int ConnectSynchronously(StreamSocket* socket) {
60 TestCompletionCallback connect_callback;
61 int rv = socket->Connect(connect_callback.callback());
62 if (rv == ERR_IO_PENDING)
63 rv = connect_callback.WaitForResult();
64 return rv;
65 }
66
67 // Reads data from |socket| until it fills |buf| at least up to |min_data_len|.
68 // Returns length of data read, or a net error.
ReadSynchronously(StreamSocket * socket,IOBuffer * buf,int buf_len,int min_data_len)69 int ReadSynchronously(StreamSocket* socket,
70 IOBuffer* buf,
71 int buf_len,
72 int min_data_len) {
73 DCHECK_LE(min_data_len, buf_len);
74 scoped_refptr<DrainableIOBuffer> read_buf =
75 base::MakeRefCounted<DrainableIOBuffer>(buf, buf_len);
76 TestCompletionCallback read_callback;
77 // Iterate reading several times (but not infinite) until it reads at least
78 // |min_data_len| bytes into |buf|.
79 for (int retry_count = 10;
80 retry_count > 0 && (read_buf->BytesConsumed() < min_data_len ||
81 // Try at least once when min_data_len == 0.
82 min_data_len == 0);
83 --retry_count) {
84 int rv = socket->Read(
85 read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
86 EXPECT_GE(read_buf->BytesRemaining(), rv);
87 if (rv == ERR_IO_PENDING) {
88 // If |min_data_len| is 0, returns ERR_IO_PENDING to distinguish the case
89 // when some data has been read.
90 if (min_data_len == 0) {
91 // No data has been read because of for-loop condition.
92 DCHECK_EQ(0, read_buf->BytesConsumed());
93 return ERR_IO_PENDING;
94 }
95 rv = read_callback.WaitForResult();
96 }
97 EXPECT_NE(ERR_IO_PENDING, rv);
98 if (rv < 0)
99 return rv;
100 read_buf->DidConsume(rv);
101 }
102 EXPECT_LE(0, read_buf->BytesRemaining());
103 return read_buf->BytesConsumed();
104 }
105
106 // Writes data to |socket| until it completes writing |buf| up to |buf_len|.
107 // Returns length of data written, or a net error.
WriteSynchronously(StreamSocket * socket,IOBuffer * buf,int buf_len)108 int WriteSynchronously(StreamSocket* socket,
109 IOBuffer* buf,
110 int buf_len) {
111 scoped_refptr<DrainableIOBuffer> write_buf =
112 base::MakeRefCounted<DrainableIOBuffer>(buf, buf_len);
113 TestCompletionCallback write_callback;
114 // Iterate writing several times (but not infinite) until it writes buf fully.
115 for (int retry_count = 10;
116 retry_count > 0 && write_buf->BytesRemaining() > 0;
117 --retry_count) {
118 int rv =
119 socket->Write(write_buf.get(), write_buf->BytesRemaining(),
120 write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
121 EXPECT_GE(write_buf->BytesRemaining(), rv);
122 if (rv == ERR_IO_PENDING)
123 rv = write_callback.WaitForResult();
124 EXPECT_NE(ERR_IO_PENDING, rv);
125 if (rv < 0)
126 return rv;
127 write_buf->DidConsume(rv);
128 }
129 EXPECT_LE(0, write_buf->BytesRemaining());
130 return write_buf->BytesConsumed();
131 }
132
133 class UnixDomainClientSocketTest : public TestWithTaskEnvironment {
134 protected:
UnixDomainClientSocketTest()135 UnixDomainClientSocketTest() {
136 EXPECT_TRUE(temp_dir_.CreateUniqueTempDir());
137 socket_path_ = temp_dir_.GetPath().Append(kSocketFilename).value();
138 }
139
140 base::ScopedTempDir temp_dir_;
141 std::string socket_path_;
142 };
143
// Verifies that a client connects to a listening server and that the server's
// pending Accept() then completes with a connected socket.
TEST_F(UnixDomainClientSocketTest, Connect) {
  const bool kUseAbstractNamespace = false;

  UnixDomainServerSocket server_socket(CreateAuthCallback(true),
                                       kUseAbstractNamespace);
  // Fatal: without a listening server (and a pending accept) the
  // WaitForResult() below would never return.
  ASSERT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());

  std::unique_ptr<StreamSocket> accepted_socket;
  TestCompletionCallback accept_callback;
  ASSERT_EQ(ERR_IO_PENDING,
            server_socket.Accept(&accepted_socket, accept_callback.callback()));
  EXPECT_FALSE(accepted_socket);

  UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
  EXPECT_FALSE(client_socket.IsConnected());

  ASSERT_THAT(ConnectSynchronously(&client_socket), IsOk());
  EXPECT_TRUE(client_socket.IsConnected());
  // Server has not yet been notified of the connection.
  EXPECT_FALSE(accepted_socket);

  EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
  // Fatal: the next line dereferences |accepted_socket|.
  ASSERT_TRUE(accepted_socket);
  EXPECT_TRUE(accepted_socket->IsConnected());
}
169
// Verifies AcceptSocketDescriptor() and ReleaseConnectedSocket(). A released
// client descriptor can be adopted by a new UnixDomainClientSocket and is still
// usable, i.e. it was not accidentally closed.
TEST_F(UnixDomainClientSocketTest, ConnectWithSocketDescriptor) {
  const bool kUseAbstractNamespace = false;

  UnixDomainServerSocket server_socket(CreateAuthCallback(true),
                                       kUseAbstractNamespace);
  // Fatal: without a listening server the accept wait below would never end.
  ASSERT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());

  SocketDescriptor accepted_socket_fd = kInvalidSocket;
  TestCompletionCallback accept_callback;
  ASSERT_EQ(ERR_IO_PENDING,
            server_socket.AcceptSocketDescriptor(&accepted_socket_fd,
                                                 accept_callback.callback()));
  EXPECT_EQ(kInvalidSocket, accepted_socket_fd);

  UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
  EXPECT_FALSE(client_socket.IsConnected());

  ASSERT_THAT(ConnectSynchronously(&client_socket), IsOk());
  EXPECT_TRUE(client_socket.IsConnected());
  // Server has not yet been notified of the connection.
  EXPECT_EQ(kInvalidSocket, accepted_socket_fd);

  EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
  EXPECT_NE(kInvalidSocket, accepted_socket_fd);

  SocketDescriptor client_socket_fd = client_socket.ReleaseConnectedSocket();
  ASSERT_NE(kInvalidSocket, client_socket_fd);

  // Now, re-wrap client_socket_fd in a UnixDomainClientSocket and try a read
  // to be sure it hasn't gotten accidentally closed.
  SockaddrStorage addr;
  ASSERT_TRUE(FillUnixAddress(socket_path_, kUseAbstractNamespace, &addr));
  auto adopter = std::make_unique<SocketPosix>();
  ASSERT_THAT(adopter->AdoptConnectedSocket(client_socket_fd, addr), IsOk());
  UnixDomainClientSocket rewrapped_socket(std::move(adopter));
  EXPECT_TRUE(rewrapped_socket.IsConnected());

  // Nothing has been written by the server, so a read on the live descriptor
  // is expected to pend.
  const int kReadDataSize = 10;
  auto read_buffer = base::MakeRefCounted<IOBufferWithSize>(kReadDataSize);
  TestCompletionCallback read_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            rewrapped_socket.Read(
                read_buffer.get(), kReadDataSize, read_callback.callback()));

  // The accepted raw descriptor is not wrapped in any socket object, so the
  // test closes it explicitly.
  EXPECT_EQ(0, IGNORE_EINTR(close(accepted_socket_fd)));
}
217
TEST_F(UnixDomainClientSocketTest,ConnectWithAbstractNamespace)218 TEST_F(UnixDomainClientSocketTest, ConnectWithAbstractNamespace) {
219 const bool kUseAbstractNamespace = true;
220
221 UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
222 EXPECT_FALSE(client_socket.IsConnected());
223
224 #if BUILDFLAG(IS_ANDROID) || BUILDFLAG(IS_LINUX) || BUILDFLAG(IS_CHROMEOS)
225 UnixDomainServerSocket server_socket(CreateAuthCallback(true),
226 kUseAbstractNamespace);
227 EXPECT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());
228
229 std::unique_ptr<StreamSocket> accepted_socket;
230 TestCompletionCallback accept_callback;
231 EXPECT_EQ(ERR_IO_PENDING,
232 server_socket.Accept(&accepted_socket, accept_callback.callback()));
233 EXPECT_FALSE(accepted_socket);
234
235 EXPECT_THAT(ConnectSynchronously(&client_socket), IsOk());
236 EXPECT_TRUE(client_socket.IsConnected());
237 // Server has not yet beend notified of the connection.
238 EXPECT_FALSE(accepted_socket);
239
240 EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
241 EXPECT_TRUE(accepted_socket);
242 EXPECT_TRUE(accepted_socket->IsConnected());
243 #else
244 EXPECT_THAT(ConnectSynchronously(&client_socket),
245 IsError(ERR_ADDRESS_INVALID));
246 #endif
247 }
248
// Connecting to a filesystem socket path that no server has bound fails with
// ERR_FILE_NOT_FOUND.
TEST_F(UnixDomainClientSocketTest, ConnectToNonExistentSocket) {
  UnixDomainClientSocket client_socket(socket_path_,
                                       /*use_abstract_namespace=*/false);
  EXPECT_FALSE(client_socket.IsConnected());

  const int connect_result = ConnectSynchronously(&client_socket);
  EXPECT_THAT(connect_result, IsError(ERR_FILE_NOT_FOUND));
}
257
TEST_F(UnixDomainClientSocketTest,ConnectToNonExistentSocketWithAbstractNamespace)258 TEST_F(UnixDomainClientSocketTest,
259 ConnectToNonExistentSocketWithAbstractNamespace) {
260 const bool kUseAbstractNamespace = true;
261
262 UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
263 EXPECT_FALSE(client_socket.IsConnected());
264
265 TestCompletionCallback connect_callback;
266 #if BUILDFLAG(IS_ANDROID) || BUILDFLAG(IS_LINUX) || BUILDFLAG(IS_CHROMEOS)
267 EXPECT_THAT(ConnectSynchronously(&client_socket),
268 IsError(ERR_CONNECTION_REFUSED));
269 #else
270 EXPECT_THAT(ConnectSynchronously(&client_socket),
271 IsError(ERR_ADDRESS_INVALID));
272 #endif
273 }
274
// Verifies that disconnecting the client completes the server's pending read
// with EOF and that both ends then report themselves disconnected.
TEST_F(UnixDomainClientSocketTest, DisconnectFromClient) {
  const bool kUseAbstractNamespace = false;
  UnixDomainServerSocket server_socket(CreateAuthCallback(true),
                                       kUseAbstractNamespace);
  // Fatal checks: each later WaitForResult() would never return (or a null
  // socket would be dereferenced) if these preconditions failed.
  ASSERT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());
  std::unique_ptr<StreamSocket> accepted_socket;
  TestCompletionCallback accept_callback;
  ASSERT_EQ(ERR_IO_PENDING,
            server_socket.Accept(&accepted_socket, accept_callback.callback()));
  UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
  ASSERT_THAT(ConnectSynchronously(&client_socket), IsOk());

  ASSERT_THAT(accept_callback.WaitForResult(), IsOk());
  ASSERT_TRUE(accepted_socket);
  EXPECT_TRUE(accepted_socket->IsConnected());
  EXPECT_TRUE(client_socket.IsConnected());

  // Start a read on the server side; nothing has been written, so it pends.
  const int kReadDataSize = 10;
  auto read_buffer = base::MakeRefCounted<IOBufferWithSize>(kReadDataSize);
  TestCompletionCallback read_callback;
  ASSERT_EQ(ERR_IO_PENDING,
            accepted_socket->Read(
                read_buffer.get(), kReadDataSize, read_callback.callback()));

  // Disconnect from client side.
  client_socket.Disconnect();
  EXPECT_FALSE(client_socket.IsConnected());
  EXPECT_FALSE(accepted_socket->IsConnected());

  // Connection closed by peer.
  EXPECT_EQ(0 /* EOF */, read_callback.WaitForResult());
  // Note that read callback won't be called when the connection is closed
  // locally before the peer closes it. SocketPosix just clears callbacks.
}
307
// Verifies that disconnecting the server-side socket completes the client's
// pending read with EOF and that both ends then report themselves
// disconnected.
TEST_F(UnixDomainClientSocketTest, DisconnectFromServer) {
  const bool kUseAbstractNamespace = false;
  UnixDomainServerSocket server_socket(CreateAuthCallback(true),
                                       kUseAbstractNamespace);
  // Fatal checks: each later WaitForResult() would never return (or a null
  // socket would be dereferenced) if these preconditions failed.
  ASSERT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());
  std::unique_ptr<StreamSocket> accepted_socket;
  TestCompletionCallback accept_callback;
  ASSERT_EQ(ERR_IO_PENDING,
            server_socket.Accept(&accepted_socket, accept_callback.callback()));
  UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
  ASSERT_THAT(ConnectSynchronously(&client_socket), IsOk());

  ASSERT_THAT(accept_callback.WaitForResult(), IsOk());
  ASSERT_TRUE(accepted_socket);
  EXPECT_TRUE(accepted_socket->IsConnected());
  EXPECT_TRUE(client_socket.IsConnected());

  // Start a read on the client side; nothing has been written, so it pends.
  const int kReadDataSize = 10;
  auto read_buffer = base::MakeRefCounted<IOBufferWithSize>(kReadDataSize);
  TestCompletionCallback read_callback;
  ASSERT_EQ(ERR_IO_PENDING,
            client_socket.Read(
                read_buffer.get(), kReadDataSize, read_callback.callback()));

  // Disconnect from server side.
  accepted_socket->Disconnect();
  EXPECT_FALSE(accepted_socket->IsConnected());
  EXPECT_FALSE(client_socket.IsConnected());

  // Connection closed by peer.
  EXPECT_EQ(0 /* EOF */, read_callback.WaitForResult());
  // Note that read callback won't be called when the connection is closed
  // locally before the peer closes it. SocketPosix just clears callbacks.
}
340
// Verifies data transfer in both directions when each write happens before
// the matching read, including draining data with several smaller reads.
TEST_F(UnixDomainClientSocketTest, ReadAfterWrite) {
  const bool kUseAbstractNamespace = false;
  UnixDomainServerSocket server_socket(CreateAuthCallback(true),
                                       kUseAbstractNamespace);
  // Fatal checks: each later wait would never return (or a null socket would
  // be dereferenced) if these preconditions failed.
  ASSERT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());
  std::unique_ptr<StreamSocket> accepted_socket;
  TestCompletionCallback accept_callback;
  ASSERT_EQ(ERR_IO_PENDING,
            server_socket.Accept(&accepted_socket, accept_callback.callback()));
  UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
  ASSERT_THAT(ConnectSynchronously(&client_socket), IsOk());

  ASSERT_THAT(accept_callback.WaitForResult(), IsOk());
  ASSERT_TRUE(accepted_socket);
  EXPECT_TRUE(accepted_socket->IsConnected());
  EXPECT_TRUE(client_socket.IsConnected());

  // Send data from client to server. Fatal, because the following read waits
  // for this data.
  const int kWriteDataSize = 10;
  auto write_buffer =
      base::MakeRefCounted<StringIOBuffer>(std::string(kWriteDataSize, 'd'));
  ASSERT_EQ(
      kWriteDataSize,
      WriteSynchronously(&client_socket, write_buffer.get(), kWriteDataSize));

  // The buffer is bigger than write data size.
  const int kReadBufferSize = kWriteDataSize * 2;
  auto read_buffer = base::MakeRefCounted<IOBufferWithSize>(kReadBufferSize);
  EXPECT_EQ(kWriteDataSize,
            ReadSynchronously(accepted_socket.get(),
                              read_buffer.get(),
                              kReadBufferSize,
                              kWriteDataSize));
  EXPECT_EQ(std::string(write_buffer->data(), kWriteDataSize),
            std::string(read_buffer->data(), kWriteDataSize));

  // Send data from server to client.
  ASSERT_EQ(kWriteDataSize,
            WriteSynchronously(
                accepted_socket.get(), write_buffer.get(), kWriteDataSize));

  // Read it back in two pieces: a small read first, then the remainder.
  const int kSmallReadBufferSize = kWriteDataSize / 3;
  EXPECT_EQ(kSmallReadBufferSize,
            ReadSynchronously(&client_socket,
                              read_buffer.get(),
                              kSmallReadBufferSize,
                              kSmallReadBufferSize));
  EXPECT_EQ(std::string(write_buffer->data(), kSmallReadBufferSize),
            std::string(read_buffer->data(), kSmallReadBufferSize));

  EXPECT_EQ(kWriteDataSize - kSmallReadBufferSize,
            ReadSynchronously(&client_socket,
                              read_buffer.get(),
                              kReadBufferSize,
                              kWriteDataSize - kSmallReadBufferSize));
  EXPECT_EQ(std::string(write_buffer->data() + kSmallReadBufferSize,
                        kWriteDataSize - kSmallReadBufferSize),
            std::string(read_buffer->data(),
                        kWriteDataSize - kSmallReadBufferSize));

  // No more data.
  EXPECT_EQ(
      ERR_IO_PENDING,
      ReadSynchronously(&client_socket, read_buffer.get(), kReadBufferSize, 0));

  // Disconnect from server side after read-write.
  accepted_socket->Disconnect();
  EXPECT_FALSE(accepted_socket->IsConnected());
  EXPECT_FALSE(client_socket.IsConnected());
}
409
// Verifies that a server read issued before any data arrives pends, completes
// once the client writes, and that the rest of the data can then be read.
TEST_F(UnixDomainClientSocketTest, ReadBeforeWrite) {
  const bool kUseAbstractNamespace = false;
  UnixDomainServerSocket server_socket(CreateAuthCallback(true),
                                       kUseAbstractNamespace);
  // Fatal checks: each later wait would never return (or a null socket would
  // be dereferenced) if these preconditions failed.
  ASSERT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());
  std::unique_ptr<StreamSocket> accepted_socket;
  TestCompletionCallback accept_callback;
  ASSERT_EQ(ERR_IO_PENDING,
            server_socket.Accept(&accepted_socket, accept_callback.callback()));
  UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
  ASSERT_THAT(ConnectSynchronously(&client_socket), IsOk());

  ASSERT_THAT(accept_callback.WaitForResult(), IsOk());
  ASSERT_TRUE(accepted_socket);
  EXPECT_TRUE(accepted_socket->IsConnected());
  EXPECT_TRUE(client_socket.IsConnected());

  // Wait for data from client.
  const int kWriteDataSize = 10;
  const int kReadBufferSize = kWriteDataSize * 2;
  const int kSmallReadBufferSize = kWriteDataSize / 3;
  // Read smaller than write data size first.
  auto read_buffer = base::MakeRefCounted<IOBufferWithSize>(kReadBufferSize);
  TestCompletionCallback read_callback;
  ASSERT_EQ(
      ERR_IO_PENDING,
      accepted_socket->Read(
          read_buffer.get(), kSmallReadBufferSize, read_callback.callback()));

  // Fatal: the pending read above only completes once this data arrives.
  auto write_buffer =
      base::MakeRefCounted<StringIOBuffer>(std::string(kWriteDataSize, 'd'));
  ASSERT_EQ(
      kWriteDataSize,
      WriteSynchronously(&client_socket, write_buffer.get(), kWriteDataSize));

  // First read completed. A non-positive result would make the remaining-size
  // arithmetic below exceed the read buffer.
  int rv = read_callback.WaitForResult();
  ASSERT_LT(0, rv);
  EXPECT_LE(rv, kSmallReadBufferSize);

  // Read remaining data.
  const int kExpectedRemainingDataSize = kWriteDataSize - rv;
  EXPECT_LE(0, kExpectedRemainingDataSize);
  EXPECT_EQ(kExpectedRemainingDataSize,
            ReadSynchronously(accepted_socket.get(),
                              read_buffer.get(),
                              kReadBufferSize,
                              kExpectedRemainingDataSize));
  // No more data.
  EXPECT_EQ(ERR_IO_PENDING,
            ReadSynchronously(
                accepted_socket.get(), read_buffer.get(), kReadBufferSize, 0));

  // Disconnect from server side after read-write.
  accepted_socket->Disconnect();
  EXPECT_FALSE(accepted_socket->IsConnected());
  EXPECT_FALSE(client_socket.IsConnected());
}
465
466 } // namespace
467 } // namespace net
468