1 // Copyright 2014 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/unix_domain_server_socket_posix.h"
6
7 #include <errno.h>
8 #include <sys/socket.h>
9 #include <sys/un.h>
10 #include <unistd.h>
11 #include <utility>
12
13 #include "base/functional/bind.h"
14 #include "base/logging.h"
15 #include "build/build_config.h"
16 #include "net/base/net_errors.h"
17 #include "net/base/sockaddr_storage.h"
18 #include "net/base/sockaddr_util_posix.h"
19 #include "net/socket/socket_posix.h"
20 #include "net/socket/unix_domain_client_socket_posix.h"
21
22 namespace net {
23
UnixDomainServerSocket(const AuthCallback & auth_callback,bool use_abstract_namespace)24 UnixDomainServerSocket::UnixDomainServerSocket(
25 const AuthCallback& auth_callback,
26 bool use_abstract_namespace)
27 : auth_callback_(auth_callback),
28 use_abstract_namespace_(use_abstract_namespace) {
29 DCHECK(!auth_callback_.is_null());
30 }
31
32 UnixDomainServerSocket::~UnixDomainServerSocket() = default;
33
34 // static
GetPeerCredentials(SocketDescriptor socket,Credentials * credentials)35 bool UnixDomainServerSocket::GetPeerCredentials(SocketDescriptor socket,
36 Credentials* credentials) {
37 #if BUILDFLAG(IS_LINUX) || BUILDFLAG(IS_CHROMEOS) || BUILDFLAG(IS_ANDROID) || \
38 BUILDFLAG(IS_FUCHSIA)
39 struct ucred user_cred;
40 socklen_t len = sizeof(user_cred);
41 if (getsockopt(socket, SOL_SOCKET, SO_PEERCRED, &user_cred, &len) < 0)
42 return false;
43 credentials->process_id = user_cred.pid;
44 credentials->user_id = user_cred.uid;
45 credentials->group_id = user_cred.gid;
46 return true;
47 #else
48 return getpeereid(
49 socket, &credentials->user_id, &credentials->group_id) == 0;
50 #endif
51 }
52
Listen(const IPEndPoint & address,int backlog,absl::optional<bool> ipv6_only)53 int UnixDomainServerSocket::Listen(const IPEndPoint& address,
54 int backlog,
55 absl::optional<bool> ipv6_only) {
56 NOTIMPLEMENTED();
57 return ERR_NOT_IMPLEMENTED;
58 }
59
ListenWithAddressAndPort(const std::string & address_string,uint16_t port,int backlog)60 int UnixDomainServerSocket::ListenWithAddressAndPort(
61 const std::string& address_string,
62 uint16_t port,
63 int backlog) {
64 NOTIMPLEMENTED();
65 return ERR_NOT_IMPLEMENTED;
66 }
67
BindAndListen(const std::string & socket_path,int backlog)68 int UnixDomainServerSocket::BindAndListen(const std::string& socket_path,
69 int backlog) {
70 DCHECK(!listen_socket_);
71
72 SockaddrStorage address;
73 if (!FillUnixAddress(socket_path, use_abstract_namespace_, &address)) {
74 return ERR_ADDRESS_INVALID;
75 }
76
77 auto socket = std::make_unique<SocketPosix>();
78 int rv = socket->Open(AF_UNIX);
79 DCHECK_NE(ERR_IO_PENDING, rv);
80 if (rv != OK)
81 return rv;
82
83 rv = socket->Bind(address);
84 DCHECK_NE(ERR_IO_PENDING, rv);
85 if (rv != OK) {
86 PLOG(ERROR)
87 << "Could not bind unix domain socket to " << socket_path
88 << (use_abstract_namespace_ ? " (with abstract namespace)" : "");
89 return rv;
90 }
91
92 rv = socket->Listen(backlog);
93 DCHECK_NE(ERR_IO_PENDING, rv);
94 if (rv != OK)
95 return rv;
96
97 listen_socket_.swap(socket);
98 return rv;
99 }
100
GetLocalAddress(IPEndPoint * address) const101 int UnixDomainServerSocket::GetLocalAddress(IPEndPoint* address) const {
102 DCHECK(address);
103
104 // Unix domain sockets have no valid associated addr/port;
105 // return address invalid.
106 return ERR_ADDRESS_INVALID;
107 }
108
Accept(std::unique_ptr<StreamSocket> * socket,CompletionOnceCallback callback)109 int UnixDomainServerSocket::Accept(std::unique_ptr<StreamSocket>* socket,
110 CompletionOnceCallback callback) {
111 DCHECK(socket);
112 DCHECK(callback);
113 DCHECK(!callback_ && !out_socket_.stream && !out_socket_.descriptor);
114
115 out_socket_ = {socket, nullptr};
116 int rv = DoAccept();
117 if (rv == ERR_IO_PENDING)
118 callback_ = std::move(callback);
119 else
120 CancelCallback();
121 return rv;
122 }
123
AcceptSocketDescriptor(SocketDescriptor * socket,CompletionOnceCallback callback)124 int UnixDomainServerSocket::AcceptSocketDescriptor(
125 SocketDescriptor* socket,
126 CompletionOnceCallback callback) {
127 DCHECK(socket);
128 DCHECK(callback);
129 DCHECK(!callback_ && !out_socket_.stream && !out_socket_.descriptor);
130
131 out_socket_ = {nullptr, socket};
132 int rv = DoAccept();
133 if (rv == ERR_IO_PENDING)
134 callback_ = std::move(callback);
135 else
136 CancelCallback();
137 return rv;
138 }
139
DoAccept()140 int UnixDomainServerSocket::DoAccept() {
141 DCHECK(listen_socket_);
142 DCHECK(!accept_socket_);
143
144 while (true) {
145 int rv = listen_socket_->Accept(
146 &accept_socket_,
147 base::BindOnce(&UnixDomainServerSocket::AcceptCompleted,
148 base::Unretained(this)));
149 if (rv != OK)
150 return rv;
151 if (AuthenticateAndGetStreamSocket())
152 return OK;
153 // Accept another socket because authentication error should be transparent
154 // to the caller.
155 }
156 }
157
AcceptCompleted(int rv)158 void UnixDomainServerSocket::AcceptCompleted(int rv) {
159 DCHECK(!callback_.is_null());
160
161 if (rv != OK) {
162 RunCallback(rv);
163 return;
164 }
165
166 if (AuthenticateAndGetStreamSocket()) {
167 RunCallback(OK);
168 return;
169 }
170
171 // Accept another socket because authentication error should be transparent
172 // to the caller.
173 rv = DoAccept();
174 if (rv != ERR_IO_PENDING)
175 RunCallback(rv);
176 }
177
AuthenticateAndGetStreamSocket()178 bool UnixDomainServerSocket::AuthenticateAndGetStreamSocket() {
179 DCHECK(accept_socket_);
180
181 Credentials credentials;
182 if (!GetPeerCredentials(accept_socket_->socket_fd(), &credentials) ||
183 !auth_callback_.Run(credentials)) {
184 accept_socket_.reset();
185 return false;
186 }
187
188 SetSocketResult(std::move(accept_socket_));
189 return true;
190 }
191
SetSocketResult(std::unique_ptr<SocketPosix> accepted_socket)192 void UnixDomainServerSocket::SetSocketResult(
193 std::unique_ptr<SocketPosix> accepted_socket) {
194 // Exactly one of the output pointers should be set.
195 DCHECK_NE(!!out_socket_.stream, !!out_socket_.descriptor);
196
197 // Pass ownership of |accepted_socket|.
198 if (out_socket_.descriptor) {
199 *out_socket_.descriptor = accepted_socket->ReleaseConnectedSocket();
200 return;
201 }
202 *out_socket_.stream =
203 std::make_unique<UnixDomainClientSocket>(std::move(accepted_socket));
204 }
205
RunCallback(int rv)206 void UnixDomainServerSocket::RunCallback(int rv) {
207 out_socket_ = SocketDestination();
208 std::move(callback_).Run(rv);
209 }
210
CancelCallback()211 void UnixDomainServerSocket::CancelCallback() {
212 out_socket_ = SocketDestination();
213 callback_.Reset();
214 }
215
216 } // namespace net
217