• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2014 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifdef UNSAFE_BUFFERS_BUILD
6 // TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
7 #pragma allow_unsafe_buffers
8 #endif
9 
10 #include "net/socket/unix_domain_client_socket_posix.h"
11 
12 #include <unistd.h>
13 
14 #include <memory>
15 #include <utility>
16 
17 #include "base/files/file_path.h"
18 #include "base/files/scoped_temp_dir.h"
19 #include "base/functional/bind.h"
20 #include "base/posix/eintr_wrapper.h"
21 #include "build/build_config.h"
22 #include "net/base/io_buffer.h"
23 #include "net/base/net_errors.h"
24 #include "net/base/sockaddr_storage.h"
25 #include "net/base/sockaddr_util_posix.h"
26 #include "net/base/test_completion_callback.h"
27 #include "net/socket/socket_posix.h"
28 #include "net/socket/unix_domain_server_socket_posix.h"
29 #include "net/test/gtest_util.h"
30 #include "net/test/test_with_task_environment.h"
31 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
32 #include "testing/gmock/include/gmock/gmock.h"
33 #include "testing/gtest/include/gtest/gtest.h"
34 
35 using net::test::IsError;
36 using net::test::IsOk;
37 
38 namespace net {
39 namespace {
40 
// Name of the socket file created inside each test's temporary directory.
constexpr char kSocketFilename[] = "socket_for_testing";
42 
UserCanConnectCallback(bool allow_user,const UnixDomainServerSocket::Credentials & credentials)43 bool UserCanConnectCallback(
44     bool allow_user, const UnixDomainServerSocket::Credentials& credentials) {
45   // Here peers are running in same process.
46 #if BUILDFLAG(IS_LINUX) || BUILDFLAG(IS_CHROMEOS) || BUILDFLAG(IS_ANDROID)
47   EXPECT_EQ(getpid(), credentials.process_id);
48 #endif
49   EXPECT_EQ(getuid(), credentials.user_id);
50   EXPECT_EQ(getgid(), credentials.group_id);
51   return allow_user;
52 }
53 
CreateAuthCallback(bool allow_user)54 UnixDomainServerSocket::AuthCallback CreateAuthCallback(bool allow_user) {
55   return base::BindRepeating(&UserCanConnectCallback, allow_user);
56 }
57 
58 // Connects socket synchronously.
ConnectSynchronously(StreamSocket * socket)59 int ConnectSynchronously(StreamSocket* socket) {
60   TestCompletionCallback connect_callback;
61   int rv = socket->Connect(connect_callback.callback());
62   if (rv == ERR_IO_PENDING)
63     rv = connect_callback.WaitForResult();
64   return rv;
65 }
66 
67 // Reads data from |socket| until it fills |buf| at least up to |min_data_len|.
68 // Returns length of data read, or a net error.
ReadSynchronously(StreamSocket * socket,IOBuffer * buf,int buf_len,int min_data_len)69 int ReadSynchronously(StreamSocket* socket,
70                       IOBuffer* buf,
71                       int buf_len,
72                       int min_data_len) {
73   DCHECK_LE(min_data_len, buf_len);
74   scoped_refptr<DrainableIOBuffer> read_buf =
75       base::MakeRefCounted<DrainableIOBuffer>(buf, buf_len);
76   TestCompletionCallback read_callback;
77   // Iterate reading several times (but not infinite) until it reads at least
78   // |min_data_len| bytes into |buf|.
79   for (int retry_count = 10;
80        retry_count > 0 && (read_buf->BytesConsumed() < min_data_len ||
81                            // Try at least once when min_data_len == 0.
82                            min_data_len == 0);
83        --retry_count) {
84     int rv = socket->Read(
85         read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
86     EXPECT_GE(read_buf->BytesRemaining(), rv);
87     if (rv == ERR_IO_PENDING) {
88       // If |min_data_len| is 0, returns ERR_IO_PENDING to distinguish the case
89       // when some data has been read.
90       if (min_data_len == 0) {
91         // No data has been read because of for-loop condition.
92         DCHECK_EQ(0, read_buf->BytesConsumed());
93         return ERR_IO_PENDING;
94       }
95       rv = read_callback.WaitForResult();
96     }
97     EXPECT_NE(ERR_IO_PENDING, rv);
98     if (rv < 0)
99       return rv;
100     read_buf->DidConsume(rv);
101   }
102   EXPECT_LE(0, read_buf->BytesRemaining());
103   return read_buf->BytesConsumed();
104 }
105 
106 // Writes data to |socket| until it completes writing |buf| up to |buf_len|.
107 // Returns length of data written, or a net error.
WriteSynchronously(StreamSocket * socket,IOBuffer * buf,int buf_len)108 int WriteSynchronously(StreamSocket* socket,
109                        IOBuffer* buf,
110                        int buf_len) {
111   scoped_refptr<DrainableIOBuffer> write_buf =
112       base::MakeRefCounted<DrainableIOBuffer>(buf, buf_len);
113   TestCompletionCallback write_callback;
114   // Iterate writing several times (but not infinite) until it writes buf fully.
115   for (int retry_count = 10;
116        retry_count > 0 && write_buf->BytesRemaining() > 0;
117        --retry_count) {
118     int rv =
119         socket->Write(write_buf.get(), write_buf->BytesRemaining(),
120                       write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
121     EXPECT_GE(write_buf->BytesRemaining(), rv);
122     if (rv == ERR_IO_PENDING)
123       rv = write_callback.WaitForResult();
124     EXPECT_NE(ERR_IO_PENDING, rv);
125     if (rv < 0)
126       return rv;
127     write_buf->DidConsume(rv);
128   }
129   EXPECT_LE(0, write_buf->BytesRemaining());
130   return write_buf->BytesConsumed();
131 }
132 
133 class UnixDomainClientSocketTest : public TestWithTaskEnvironment {
134  protected:
UnixDomainClientSocketTest()135   UnixDomainClientSocketTest() {
136     EXPECT_TRUE(temp_dir_.CreateUniqueTempDir());
137     socket_path_ = temp_dir_.GetPath().Append(kSocketFilename).value();
138   }
139 
140   base::ScopedTempDir temp_dir_;
141   std::string socket_path_;
142 };
143 
// Verifies that a client can connect to a server listening on a filesystem
// path, and that the server's pending Accept() then yields a connected socket.
TEST_F(UnixDomainClientSocketTest, Connect) {
  const bool kUseAbstractNamespace = false;

  UnixDomainServerSocket server_socket(CreateAuthCallback(true),
                                       kUseAbstractNamespace);
  // Fatal: without a listening socket the accept below would never complete
  // and WaitForResult() would hang.
  ASSERT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());

  std::unique_ptr<StreamSocket> accepted_socket;
  TestCompletionCallback accept_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            server_socket.Accept(&accepted_socket, accept_callback.callback()));
  EXPECT_FALSE(accepted_socket);

  UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
  EXPECT_FALSE(client_socket.IsConnected());

  EXPECT_THAT(ConnectSynchronously(&client_socket), IsOk());
  EXPECT_TRUE(client_socket.IsConnected());
  // Server has not yet been notified of the connection.
  EXPECT_FALSE(accepted_socket);

  EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
  // Fatal: |accepted_socket| is dereferenced on the next line.
  ASSERT_TRUE(accepted_socket);
  EXPECT_TRUE(accepted_socket->IsConnected());
}
169 
// Verifies AcceptSocketDescriptor() and ReleaseConnectedSocket(): the released
// client fd must remain open and usable when re-wrapped in a new
// UnixDomainClientSocket.
TEST_F(UnixDomainClientSocketTest, ConnectWithSocketDescriptor) {
  const bool kUseAbstractNamespace = false;

  UnixDomainServerSocket server_socket(CreateAuthCallback(true),
                                       kUseAbstractNamespace);
  // Fatal: without a listening socket the accept below would never complete.
  ASSERT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());

  SocketDescriptor accepted_socket_fd = kInvalidSocket;
  TestCompletionCallback accept_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            server_socket.AcceptSocketDescriptor(&accepted_socket_fd,
                                                 accept_callback.callback()));
  EXPECT_EQ(kInvalidSocket, accepted_socket_fd);

  UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
  EXPECT_FALSE(client_socket.IsConnected());

  EXPECT_THAT(ConnectSynchronously(&client_socket), IsOk());
  EXPECT_TRUE(client_socket.IsConnected());
  // Server has not yet been notified of the connection.
  EXPECT_EQ(kInvalidSocket, accepted_socket_fd);

  EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
  // Fatal: the fd is closed at the end of the test.
  ASSERT_NE(kInvalidSocket, accepted_socket_fd);

  SocketDescriptor client_socket_fd = client_socket.ReleaseConnectedSocket();
  // Fatal: the fd is adopted and read from below.
  ASSERT_NE(kInvalidSocket, client_socket_fd);

  // Now, re-wrap client_socket_fd in a UnixDomainClientSocket and try a read
  // to be sure it hasn't gotten accidentally closed.
  SockaddrStorage addr;
  ASSERT_TRUE(FillUnixAddress(socket_path_, kUseAbstractNamespace, &addr));
  auto adopter = std::make_unique<SocketPosix>();
  adopter->AdoptConnectedSocket(client_socket_fd, addr);
  UnixDomainClientSocket rewrapped_socket(std::move(adopter));
  EXPECT_TRUE(rewrapped_socket.IsConnected());

  // Try to read data.
  const int kReadDataSize = 10;
  auto read_buffer = base::MakeRefCounted<IOBufferWithSize>(kReadDataSize);
  TestCompletionCallback read_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            rewrapped_socket.Read(
                read_buffer.get(), kReadDataSize, read_callback.callback()));

  EXPECT_EQ(0, IGNORE_EINTR(close(accepted_socket_fd)));
}
217 
TEST_F(UnixDomainClientSocketTest,ConnectWithAbstractNamespace)218 TEST_F(UnixDomainClientSocketTest, ConnectWithAbstractNamespace) {
219   const bool kUseAbstractNamespace = true;
220 
221   UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
222   EXPECT_FALSE(client_socket.IsConnected());
223 
224 #if BUILDFLAG(IS_ANDROID) || BUILDFLAG(IS_LINUX) || BUILDFLAG(IS_CHROMEOS)
225   UnixDomainServerSocket server_socket(CreateAuthCallback(true),
226                                        kUseAbstractNamespace);
227   EXPECT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());
228 
229   std::unique_ptr<StreamSocket> accepted_socket;
230   TestCompletionCallback accept_callback;
231   EXPECT_EQ(ERR_IO_PENDING,
232             server_socket.Accept(&accepted_socket, accept_callback.callback()));
233   EXPECT_FALSE(accepted_socket);
234 
235   EXPECT_THAT(ConnectSynchronously(&client_socket), IsOk());
236   EXPECT_TRUE(client_socket.IsConnected());
237   // Server has not yet beend notified of the connection.
238   EXPECT_FALSE(accepted_socket);
239 
240   EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
241   EXPECT_TRUE(accepted_socket);
242   EXPECT_TRUE(accepted_socket->IsConnected());
243 #else
244   EXPECT_THAT(ConnectSynchronously(&client_socket),
245               IsError(ERR_ADDRESS_INVALID));
246 #endif
247 }
248 
// Connecting to a filesystem path on which no server is listening is expected
// to fail with ERR_FILE_NOT_FOUND.
TEST_F(UnixDomainClientSocketTest, ConnectToNonExistentSocket) {
  UnixDomainClientSocket client_socket(socket_path_,
                                       /*use_abstract_namespace=*/false);
  EXPECT_FALSE(client_socket.IsConnected());

  const int connect_result = ConnectSynchronously(&client_socket);
  EXPECT_THAT(connect_result, IsError(ERR_FILE_NOT_FOUND));
}
257 
TEST_F(UnixDomainClientSocketTest,ConnectToNonExistentSocketWithAbstractNamespace)258 TEST_F(UnixDomainClientSocketTest,
259        ConnectToNonExistentSocketWithAbstractNamespace) {
260   const bool kUseAbstractNamespace = true;
261 
262   UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
263   EXPECT_FALSE(client_socket.IsConnected());
264 
265   TestCompletionCallback connect_callback;
266 #if BUILDFLAG(IS_ANDROID) || BUILDFLAG(IS_LINUX) || BUILDFLAG(IS_CHROMEOS)
267   EXPECT_THAT(ConnectSynchronously(&client_socket),
268               IsError(ERR_CONNECTION_REFUSED));
269 #else
270   EXPECT_THAT(ConnectSynchronously(&client_socket),
271               IsError(ERR_ADDRESS_INVALID));
272 #endif
273 }
274 
// Verifies that disconnecting the client marks both ends disconnected and
// completes the server's pending read with 0 (EOF).
TEST_F(UnixDomainClientSocketTest, DisconnectFromClient) {
  UnixDomainServerSocket server_socket(CreateAuthCallback(true), false);
  // Fatal: without a listening socket the accept below would never complete.
  ASSERT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());
  std::unique_ptr<StreamSocket> accepted_socket;
  TestCompletionCallback accept_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            server_socket.Accept(&accepted_socket, accept_callback.callback()));
  UnixDomainClientSocket client_socket(socket_path_, false);
  EXPECT_THAT(ConnectSynchronously(&client_socket), IsOk());

  // Fatal: |accepted_socket| is dereferenced below.
  ASSERT_THAT(accept_callback.WaitForResult(), IsOk());
  ASSERT_TRUE(accepted_socket);
  EXPECT_TRUE(accepted_socket->IsConnected());
  EXPECT_TRUE(client_socket.IsConnected());

  // Try to read data.
  const int kReadDataSize = 10;
  auto read_buffer = base::MakeRefCounted<IOBufferWithSize>(kReadDataSize);
  TestCompletionCallback read_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            accepted_socket->Read(
                read_buffer.get(), kReadDataSize, read_callback.callback()));

  // Disconnect from client side.
  client_socket.Disconnect();
  EXPECT_FALSE(client_socket.IsConnected());
  EXPECT_FALSE(accepted_socket->IsConnected());

  // Connection closed by peer.
  EXPECT_EQ(0 /* EOF */, read_callback.WaitForResult());
  // Note that read callback won't be called when the connection is closed
  // locally before the peer closes it. SocketPosix just clears callbacks.
}
307 
// Verifies that disconnecting the server-side accepted socket marks both ends
// disconnected and completes the client's pending read with 0 (EOF).
TEST_F(UnixDomainClientSocketTest, DisconnectFromServer) {
  UnixDomainServerSocket server_socket(CreateAuthCallback(true), false);
  // Fatal: without a listening socket the accept below would never complete.
  ASSERT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());
  std::unique_ptr<StreamSocket> accepted_socket;
  TestCompletionCallback accept_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            server_socket.Accept(&accepted_socket, accept_callback.callback()));
  UnixDomainClientSocket client_socket(socket_path_, false);
  EXPECT_THAT(ConnectSynchronously(&client_socket), IsOk());

  // Fatal: |accepted_socket| is dereferenced below.
  ASSERT_THAT(accept_callback.WaitForResult(), IsOk());
  ASSERT_TRUE(accepted_socket);
  EXPECT_TRUE(accepted_socket->IsConnected());
  EXPECT_TRUE(client_socket.IsConnected());

  // Try to read data.
  const int kReadDataSize = 10;
  auto read_buffer = base::MakeRefCounted<IOBufferWithSize>(kReadDataSize);
  TestCompletionCallback read_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            client_socket.Read(
                read_buffer.get(), kReadDataSize, read_callback.callback()));

  // Disconnect from server side.
  accepted_socket->Disconnect();
  EXPECT_FALSE(accepted_socket->IsConnected());
  EXPECT_FALSE(client_socket.IsConnected());

  // Connection closed by peer.
  EXPECT_EQ(0 /* EOF */, read_callback.WaitForResult());
  // Note that read callback won't be called when the connection is closed
  // locally before the peer closes it. SocketPosix just clears callbacks.
}
340 
TEST_F(UnixDomainClientSocketTest,ReadAfterWrite)341 TEST_F(UnixDomainClientSocketTest, ReadAfterWrite) {
342   UnixDomainServerSocket server_socket(CreateAuthCallback(true), false);
343   EXPECT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());
344   std::unique_ptr<StreamSocket> accepted_socket;
345   TestCompletionCallback accept_callback;
346   EXPECT_EQ(ERR_IO_PENDING,
347             server_socket.Accept(&accepted_socket, accept_callback.callback()));
348   UnixDomainClientSocket client_socket(socket_path_, false);
349   EXPECT_THAT(ConnectSynchronously(&client_socket), IsOk());
350 
351   EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
352   EXPECT_TRUE(accepted_socket->IsConnected());
353   EXPECT_TRUE(client_socket.IsConnected());
354 
355   // Send data from client to server.
356   const int kWriteDataSize = 10;
357   auto write_buffer =
358       base::MakeRefCounted<StringIOBuffer>(std::string(kWriteDataSize, 'd'));
359   EXPECT_EQ(
360       kWriteDataSize,
361       WriteSynchronously(&client_socket, write_buffer.get(), kWriteDataSize));
362 
363   // The buffer is bigger than write data size.
364   const int kReadBufferSize = kWriteDataSize * 2;
365   auto read_buffer = base::MakeRefCounted<IOBufferWithSize>(kReadBufferSize);
366   EXPECT_EQ(kWriteDataSize,
367             ReadSynchronously(accepted_socket.get(),
368                               read_buffer.get(),
369                               kReadBufferSize,
370                               kWriteDataSize));
371   EXPECT_EQ(std::string(write_buffer->data(), kWriteDataSize),
372             std::string(read_buffer->data(), kWriteDataSize));
373 
374   // Send data from server and client.
375   EXPECT_EQ(kWriteDataSize,
376             WriteSynchronously(
377                 accepted_socket.get(), write_buffer.get(), kWriteDataSize));
378 
379   // Read multiple times.
380   const int kSmallReadBufferSize = kWriteDataSize / 3;
381   EXPECT_EQ(kSmallReadBufferSize,
382             ReadSynchronously(&client_socket,
383                               read_buffer.get(),
384                               kSmallReadBufferSize,
385                               kSmallReadBufferSize));
386   EXPECT_EQ(std::string(write_buffer->data(), kSmallReadBufferSize),
387             std::string(read_buffer->data(), kSmallReadBufferSize));
388 
389   EXPECT_EQ(kWriteDataSize - kSmallReadBufferSize,
390             ReadSynchronously(&client_socket,
391                               read_buffer.get(),
392                               kReadBufferSize,
393                               kWriteDataSize - kSmallReadBufferSize));
394   EXPECT_EQ(std::string(write_buffer->data() + kSmallReadBufferSize,
395                         kWriteDataSize - kSmallReadBufferSize),
396             std::string(read_buffer->data(),
397                         kWriteDataSize - kSmallReadBufferSize));
398 
399   // No more data.
400   EXPECT_EQ(
401       ERR_IO_PENDING,
402       ReadSynchronously(&client_socket, read_buffer.get(), kReadBufferSize, 0));
403 
404   // Disconnect from server side after read-write.
405   accepted_socket->Disconnect();
406   EXPECT_FALSE(accepted_socket->IsConnected());
407   EXPECT_FALSE(client_socket.IsConnected());
408 }
409 
// Verifies that a read issued before any data arrives completes once the peer
// writes, and that the remainder of the write can then be drained.
TEST_F(UnixDomainClientSocketTest, ReadBeforeWrite) {
  UnixDomainServerSocket server_socket(CreateAuthCallback(true), false);
  // Fatal: without a listening socket the accept below would never complete.
  ASSERT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());
  std::unique_ptr<StreamSocket> accepted_socket;
  TestCompletionCallback accept_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            server_socket.Accept(&accepted_socket, accept_callback.callback()));
  UnixDomainClientSocket client_socket(socket_path_, false);
  EXPECT_THAT(ConnectSynchronously(&client_socket), IsOk());

  // Fatal: |accepted_socket| is dereferenced below.
  ASSERT_THAT(accept_callback.WaitForResult(), IsOk());
  ASSERT_TRUE(accepted_socket);
  EXPECT_TRUE(accepted_socket->IsConnected());
  EXPECT_TRUE(client_socket.IsConnected());

  // Wait for data from client.
  const int kWriteDataSize = 10;
  const int kReadBufferSize = kWriteDataSize * 2;
  const int kSmallReadBufferSize = kWriteDataSize / 3;
  // Read smaller than write data size first.
  auto read_buffer = base::MakeRefCounted<IOBufferWithSize>(kReadBufferSize);
  TestCompletionCallback read_callback;
  EXPECT_EQ(
      ERR_IO_PENDING,
      accepted_socket->Read(
          read_buffer.get(), kSmallReadBufferSize, read_callback.callback()));

  auto write_buffer =
      base::MakeRefCounted<StringIOBuffer>(std::string(kWriteDataSize, 'd'));
  EXPECT_EQ(
      kWriteDataSize,
      WriteSynchronously(&client_socket, write_buffer.get(), kWriteDataSize));

  // First read completed.
  int rv = read_callback.WaitForResult();
  // Fatal: a non-positive |rv| would make the remaining size computed below
  // exceed |kReadBufferSize| and trip ReadSynchronously's DCHECK.
  ASSERT_LT(0, rv);
  EXPECT_LE(rv, kSmallReadBufferSize);

  // Read remaining data.
  const int kExpectedRemainingDataSize = kWriteDataSize - rv;
  EXPECT_LE(0, kExpectedRemainingDataSize);
  EXPECT_EQ(kExpectedRemainingDataSize,
            ReadSynchronously(accepted_socket.get(),
                              read_buffer.get(),
                              kReadBufferSize,
                              kExpectedRemainingDataSize));
  // No more data.
  EXPECT_EQ(ERR_IO_PENDING,
            ReadSynchronously(
                accepted_socket.get(), read_buffer.get(), kReadBufferSize, 0));

  // Disconnect from server side after read-write.
  accepted_socket->Disconnect();
  EXPECT_FALSE(accepted_socket->IsConnected());
  EXPECT_FALSE(client_socket.IsConnected());
}
465 
466 }  // namespace
467 }  // namespace net
468