• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2014 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "NetdClient.h"
18 
19 #include <arpa/inet.h>
20 #include <errno.h>
21 #include <math.h>
22 #include <resolv.h>
23 #include <stdlib.h>
24 #include <sys/socket.h>
25 #include <sys/un.h>
26 #include <unistd.h>
27 
28 #include <atomic>
29 #include <string>
30 #include <vector>
31 
32 #include <android-base/parseint.h>
33 #include <android-base/unique_fd.h>
34 
35 #include "Fwmark.h"
36 #include "FwmarkClient.h"
37 #include "FwmarkCommand.h"
38 #include "netdclient_priv.h"
39 #include "netdutils/ResponseCode.h"
40 #include "netdutils/Stopwatch.h"
41 #include "netid_client.h"
42 
43 using android::base::ParseInt;
44 using android::base::unique_fd;
45 using android::netdutils::ResponseCode;
46 using android::netdutils::Stopwatch;
47 
48 namespace {
49 
50 // Keep this in sync with CMD_BUF_SIZE in FrameworkListener.cpp.
51 constexpr size_t MAX_CMD_SIZE = 4096;
52 
53 std::atomic_uint netIdForProcess(NETID_UNSET);
54 std::atomic_uint netIdForResolv(NETID_UNSET);
55 
56 typedef int (*Accept4FunctionType)(int, sockaddr*, socklen_t*, int);
57 typedef int (*ConnectFunctionType)(int, const sockaddr*, socklen_t);
58 typedef int (*SocketFunctionType)(int, int, int);
59 typedef unsigned (*NetIdForResolvFunctionType)(unsigned);
60 typedef int (*DnsOpenProxyType)();
61 
62 // These variables are only modified at startup (when libc.so is loaded) and never afterwards, so
63 // it's okay that they are read later at runtime without a lock.
64 Accept4FunctionType libcAccept4 = nullptr;
65 ConnectFunctionType libcConnect = nullptr;
66 SocketFunctionType libcSocket = nullptr;
67 
checkSocket(int socketFd)68 int checkSocket(int socketFd) {
69     if (socketFd < 0) {
70         return -EBADF;
71     }
72     int family;
73     socklen_t familyLen = sizeof(family);
74     if (getsockopt(socketFd, SOL_SOCKET, SO_DOMAIN, &family, &familyLen) == -1) {
75         return -errno;
76     }
77     if (!FwmarkClient::shouldSetFwmark(family)) {
78         return -EAFNOSUPPORT;
79     }
80     return 0;
81 }
82 
shouldMarkSocket(int socketFd,const sockaddr * dst)83 bool shouldMarkSocket(int socketFd, const sockaddr* dst) {
84     // Only mark inet sockets that are connecting to inet destinations. This excludes, for example,
85     // inet sockets connecting to AF_UNSPEC (i.e., being disconnected), and non-inet sockets that
86     // for some reason the caller wants to attempt to connect to an inet destination.
87     return dst && FwmarkClient::shouldSetFwmark(dst->sa_family) && (checkSocket(socketFd) == 0);
88 }
89 
// Closes |fd|, sets errno to -|error| (|error| is a negative errno value) and returns -1, so
// callers can report a libc-style failure in one statement.
int closeFdAndSetErrno(int fd, int error) {
    static_cast<void>(close(fd));
    // Assigned after close() so that close() cannot overwrite the reported error.
    errno = -error;
    return -1;
}
95 
netdClientAccept4(int sockfd,sockaddr * addr,socklen_t * addrlen,int flags)96 int netdClientAccept4(int sockfd, sockaddr* addr, socklen_t* addrlen, int flags) {
97     int acceptedSocket = libcAccept4(sockfd, addr, addrlen, flags);
98     if (acceptedSocket == -1) {
99         return -1;
100     }
101     int family;
102     if (addr) {
103         family = addr->sa_family;
104     } else {
105         socklen_t familyLen = sizeof(family);
106         if (getsockopt(acceptedSocket, SOL_SOCKET, SO_DOMAIN, &family, &familyLen) == -1) {
107             return closeFdAndSetErrno(acceptedSocket, -errno);
108         }
109     }
110     if (FwmarkClient::shouldSetFwmark(family)) {
111         FwmarkCommand command = {FwmarkCommand::ON_ACCEPT, 0, 0, 0};
112         if (int error = FwmarkClient().send(&command, acceptedSocket, nullptr)) {
113             return closeFdAndSetErrno(acceptedSocket, error);
114         }
115     }
116     return acceptedSocket;
117 }
118 
netdClientConnect(int sockfd,const sockaddr * addr,socklen_t addrlen)119 int netdClientConnect(int sockfd, const sockaddr* addr, socklen_t addrlen) {
120     const bool shouldSetFwmark = shouldMarkSocket(sockfd, addr);
121     if (shouldSetFwmark) {
122         FwmarkCommand command = {FwmarkCommand::ON_CONNECT, 0, 0, 0};
123         if (int error = FwmarkClient().send(&command, sockfd, nullptr)) {
124             errno = -error;
125             return -1;
126         }
127     }
128     // Latency measurement does not include time of sending commands to Fwmark
129     Stopwatch s;
130     const int ret = libcConnect(sockfd, addr, addrlen);
131     // Save errno so it isn't clobbered by sending ON_CONNECT_COMPLETE
132     const int connectErrno = errno;
133     const unsigned latencyMs = lround(s.timeTaken());
134     // Send an ON_CONNECT_COMPLETE command that includes sockaddr and connect latency for reporting
135     if (shouldSetFwmark && FwmarkClient::shouldReportConnectComplete(addr->sa_family)) {
136         FwmarkConnectInfo connectInfo(ret == 0 ? 0 : connectErrno, latencyMs, addr);
137         // TODO: get the netId from the socket mark once we have continuous benchmark runs
138         FwmarkCommand command = {FwmarkCommand::ON_CONNECT_COMPLETE, /* netId (ignored) */ 0,
139                                  /* uid (filled in by the server) */ 0, 0};
140         // Ignore return value since it's only used for logging
141         FwmarkClient().send(&command, sockfd, &connectInfo);
142     }
143     errno = connectErrno;
144     return ret;
145 }
146 
netdClientSocket(int domain,int type,int protocol)147 int netdClientSocket(int domain, int type, int protocol) {
148     int socketFd = libcSocket(domain, type, protocol);
149     if (socketFd == -1) {
150         return -1;
151     }
152     unsigned netId = netIdForProcess & ~NETID_USE_LOCAL_NAMESERVERS;
153     if (netId != NETID_UNSET && FwmarkClient::shouldSetFwmark(domain)) {
154         if (int error = setNetworkForSocket(netId, socketFd)) {
155             return closeFdAndSetErrno(socketFd, error);
156         }
157     }
158     return socketFd;
159 }
160 
getNetworkForResolv(unsigned netId)161 unsigned getNetworkForResolv(unsigned netId) {
162     if (netId != NETID_UNSET) {
163         return netId;
164     }
165     // Special case for DNS-over-TLS bypass; b/72345192 .
166     if ((netIdForResolv & ~NETID_USE_LOCAL_NAMESERVERS) != NETID_UNSET) {
167         return netIdForResolv;
168     }
169     netId = netIdForProcess;
170     if (netId != NETID_UNSET) {
171         return netId;
172     }
173     return netIdForResolv;
174 }
175 
setNetworkForTarget(unsigned netId,std::atomic_uint * target)176 int setNetworkForTarget(unsigned netId, std::atomic_uint* target) {
177     const unsigned requestedNetId = netId;
178     netId &= ~NETID_USE_LOCAL_NAMESERVERS;
179 
180     if (netId == NETID_UNSET) {
181         *target = netId;
182         return 0;
183     }
184     // Verify that we are allowed to use |netId|, by creating a socket and trying to have it marked
185     // with the netId. Call libcSocket() directly; else the socket creation (via netdClientSocket())
186     // might itself cause another check with the fwmark server, which would be wasteful.
187 
188     const auto socketFunc = libcSocket ? libcSocket : socket;
189     int socketFd = socketFunc(AF_INET6, SOCK_DGRAM | SOCK_CLOEXEC, 0);
190     if (socketFd < 0) {
191         return -errno;
192     }
193     int error = setNetworkForSocket(netId, socketFd);
194     if (!error) {
195         *target = requestedNetId;
196     }
197     close(socketFd);
198     return error;
199 }
200 
dns_open_proxy()201 int dns_open_proxy() {
202     const char* cache_mode = getenv("ANDROID_DNS_MODE");
203     const bool use_proxy = (cache_mode == NULL || strcmp(cache_mode, "local") != 0);
204     if (!use_proxy) {
205         errno = ENOSYS;
206         return -1;
207     }
208 
209     const auto socketFunc = libcSocket ? libcSocket : socket;
210     int s = socketFunc(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
211     if (s == -1) {
212         return -1;
213     }
214     const int one = 1;
215     setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one));
216 
217     static const struct sockaddr_un proxy_addr = {
218             .sun_family = AF_UNIX,
219             .sun_path = "/dev/socket/dnsproxyd",
220     };
221 
222     const auto connectFunc = libcConnect ? libcConnect : connect;
223     if (TEMP_FAILURE_RETRY(
224                 connectFunc(s, (const struct sockaddr*) &proxy_addr, sizeof(proxy_addr))) != 0) {
225         // Store the errno for connect because we only care about why we can't connect to dnsproxyd
226         int storedErrno = errno;
227         close(s);
228         errno = storedErrno;
229         return -1;
230     }
231 
232     return s;
233 }
234 
// Returns |dividend| / |divisor| rounded up. |divisor| must be non-zero.
// NOTE(review): dividend + divisor - 1 wraps around for dividends near SIZE_MAX.
auto divCeil(size_t dividend, size_t divisor) {
    const size_t biased = dividend + divisor - 1;
    return biased / divisor;
}
238 
// Writes |buf| to |fd| with a single write() call. FrameworkListener performs only one read()
// and fails if the data it reads does not contain a '\0', so the command must go out in one
// piece. Returns the number of bytes written, -EBADF for a negative |fd|, -EIO if nothing was
// written, or the negative errno from write().
// NOTE(review): a short write is returned as a positive count, not treated as an error.
int sendData(int fd, const void* buf, size_t size) {
    if (fd < 0) return -EBADF;

    const ssize_t written = TEMP_FAILURE_RETRY(write(fd, static_cast<const char*>(buf), size));
    if (written < 0) return -errno;
    if (written == 0) return -EIO;
    return written;
}
255 
// Reads exactly |size| bytes from |fd| into |buf|, retrying short reads and EINTR.
// Returns 0 on success (immediately when |size| is 0), -EBADF for a negative |fd|, -EIO if the
// stream ends before |size| bytes arrive, or the negative errno from read().
int readData(int fd, void* buf, size_t size) {
    if (fd < 0) {
        return -EBADF;
    }

    char* const dst = static_cast<char*>(buf);
    size_t current = 0;
    // Check the remaining count before reading: with size == 0 a zero-length read() returns 0,
    // which would otherwise be misreported as end-of-stream (-EIO).
    while (current < size) {
        const ssize_t rc = TEMP_FAILURE_RETRY(read(fd, dst + current, size - current));
        if (rc < 0) {
            return -errno;
        }
        if (rc == 0) {
            // Peer closed the stream before all requested bytes arrived.
            return -EIO;
        }
        current += static_cast<size_t>(rc);
    }
    return 0;
}
277 
// Reads a big-endian 32-bit integer from |fd| into |*result| (host order).
// Returns true on success. On failure returns false with errno set: the errno from read(), or
// EIO if the stream ends before 4 bytes arrive. Callers report failures as -errno, so errno must
// always be meaningful here.
bool readBE32(int fd, int32_t* result) {
    int32_t tmp;
    char* const dst = reinterpret_cast<char*>(&tmp);
    size_t got = 0;
    // A stream socket may deliver the 4 bytes across several reads; keep reading until complete.
    while (got < sizeof(tmp)) {
        const ssize_t n = TEMP_FAILURE_RETRY(read(fd, dst + got, sizeof(tmp) - got));
        if (n < 0) {
            return false;  // errno already set by read()
        }
        if (n == 0) {
            errno = EIO;  // premature end of stream; read() left errno untouched
            return false;
        }
        got += static_cast<size_t>(n);
    }
    *result = static_cast<int32_t>(ntohl(static_cast<uint32_t>(tmp)));
    return true;
}
287 
readResponseCode(int fd,int * result)288 bool readResponseCode(int fd, int* result) {
289     char buf[4];
290     ssize_t n = TEMP_FAILURE_RETRY(read(fd, &buf, sizeof(buf)));
291     if (n < static_cast<ssize_t>(sizeof(buf))) {
292         return false;
293     }
294 
295     // The format of response code is 3 bytes followed by a space.
296     buf[3] = '\0';
297     if (!ParseInt(buf, result)) {
298         errno = EINVAL;
299         return false;
300     }
301 
302     return true;
303 }
304 
305 }  // namespace
306 
// Makes the *calling* function return checkSocket()'s negative errno when |sock| cannot be
// marked by netd. The do { } while (false) wrapper has no trailing semicolon, so a use such as
// `CHECK_SOCKET_IS_MARKABLE(fd);` is exactly one statement and is safe inside if/else.
#define CHECK_SOCKET_IS_MARKABLE(sock)              \
    do {                                            \
        const int err = checkSocket(sock);          \
        if (err != 0) {                             \
            return err;                             \
        }                                           \
    } while (false)
314 
315 // accept() just calls accept4(..., 0), so there's no need to handle accept() separately.
netdClientInitAccept4(Accept4FunctionType * function)316 extern "C" void netdClientInitAccept4(Accept4FunctionType* function) {
317     if (function && *function) {
318         libcAccept4 = *function;
319         *function = netdClientAccept4;
320     }
321 }
322 
netdClientInitConnect(ConnectFunctionType * function)323 extern "C" void netdClientInitConnect(ConnectFunctionType* function) {
324     if (function && *function) {
325         libcConnect = *function;
326         *function = netdClientConnect;
327     }
328 }
329 
netdClientInitSocket(SocketFunctionType * function)330 extern "C" void netdClientInitSocket(SocketFunctionType* function) {
331     if (function && *function) {
332         libcSocket = *function;
333         *function = netdClientSocket;
334     }
335 }
336 
netdClientInitNetIdForResolv(NetIdForResolvFunctionType * function)337 extern "C" void netdClientInitNetIdForResolv(NetIdForResolvFunctionType* function) {
338     if (function) {
339         *function = getNetworkForResolv;
340     }
341 }
342 
netdClientInitDnsOpenProxy(DnsOpenProxyType * function)343 extern "C" void netdClientInitDnsOpenProxy(DnsOpenProxyType* function) {
344     if (function) {
345         *function = dns_open_proxy;
346     }
347 }
348 
getNetworkForSocket(unsigned * netId,int socketFd)349 extern "C" int getNetworkForSocket(unsigned* netId, int socketFd) {
350     if (!netId || socketFd < 0) {
351         return -EBADF;
352     }
353     Fwmark fwmark;
354     socklen_t fwmarkLen = sizeof(fwmark.intValue);
355     if (getsockopt(socketFd, SOL_SOCKET, SO_MARK, &fwmark.intValue, &fwmarkLen) == -1) {
356         return -errno;
357     }
358     *netId = fwmark.netId;
359     return 0;
360 }
361 
getNetworkForProcess()362 extern "C" unsigned getNetworkForProcess() {
363     return netIdForProcess & ~NETID_USE_LOCAL_NAMESERVERS;
364 }
365 
setNetworkForSocket(unsigned netId,int socketFd)366 extern "C" int setNetworkForSocket(unsigned netId, int socketFd) {
367     CHECK_SOCKET_IS_MARKABLE(socketFd);
368     FwmarkCommand command = {FwmarkCommand::SELECT_NETWORK, netId, 0, 0};
369     return FwmarkClient().send(&command, socketFd, nullptr);
370 }
371 
setNetworkForProcess(unsigned netId)372 extern "C" int setNetworkForProcess(unsigned netId) {
373     return setNetworkForTarget(netId, &netIdForProcess);
374 }
375 
setNetworkForResolv(unsigned netId)376 extern "C" int setNetworkForResolv(unsigned netId) {
377     return setNetworkForTarget(netId, &netIdForResolv);
378 }
379 
protectFromVpn(int socketFd)380 extern "C" int protectFromVpn(int socketFd) {
381     if (socketFd < 0) {
382         return -EBADF;
383     }
384     FwmarkCommand command = {FwmarkCommand::PROTECT_FROM_VPN, 0, 0, 0};
385     return FwmarkClient().send(&command, socketFd, nullptr);
386 }
387 
setNetworkForUser(uid_t uid,int socketFd)388 extern "C" int setNetworkForUser(uid_t uid, int socketFd) {
389     CHECK_SOCKET_IS_MARKABLE(socketFd);
390     FwmarkCommand command = {FwmarkCommand::SELECT_FOR_USER, 0, uid, 0};
391     return FwmarkClient().send(&command, socketFd, nullptr);
392 }
393 
queryUserAccess(uid_t uid,unsigned netId)394 extern "C" int queryUserAccess(uid_t uid, unsigned netId) {
395     FwmarkCommand command = {FwmarkCommand::QUERY_USER_ACCESS, netId, uid, 0};
396     return FwmarkClient().send(&command, -1, nullptr);
397 }
398 
tagSocket(int socketFd,uint32_t tag,uid_t uid)399 extern "C" int tagSocket(int socketFd, uint32_t tag, uid_t uid) {
400     CHECK_SOCKET_IS_MARKABLE(socketFd);
401     FwmarkCommand command = {FwmarkCommand::TAG_SOCKET, 0, uid, tag};
402     return FwmarkClient().send(&command, socketFd, nullptr);
403 }
404 
untagSocket(int socketFd)405 extern "C" int untagSocket(int socketFd) {
406     CHECK_SOCKET_IS_MARKABLE(socketFd);
407     FwmarkCommand command = {FwmarkCommand::UNTAG_SOCKET, 0, 0, 0};
408     return FwmarkClient().send(&command, socketFd, nullptr);
409 }
410 
setCounterSet(uint32_t counterSet,uid_t uid)411 extern "C" int setCounterSet(uint32_t counterSet, uid_t uid) {
412     FwmarkCommand command = {FwmarkCommand::SET_COUNTERSET, 0, uid, counterSet};
413     return FwmarkClient().send(&command, -1, nullptr);
414 }
415 
deleteTagData(uint32_t tag,uid_t uid)416 extern "C" int deleteTagData(uint32_t tag, uid_t uid) {
417     FwmarkCommand command = {FwmarkCommand::DELETE_TAGDATA, 0, uid, tag};
418     return FwmarkClient().send(&command, -1, nullptr);
419 }
420 
resNetworkQuery(unsigned netId,const char * dname,int ns_class,int ns_type,uint32_t flags)421 extern "C" int resNetworkQuery(unsigned netId, const char* dname, int ns_class, int ns_type,
422                                uint32_t flags) {
423     std::vector<uint8_t> buf(MAX_CMD_SIZE, 0);
424     int len = res_mkquery(ns_o_query, dname, ns_class, ns_type, nullptr, 0, nullptr, buf.data(),
425                           MAX_CMD_SIZE);
426 
427     return resNetworkSend(netId, buf.data(), len, flags);
428 }
429 
resNetworkSend(unsigned netId,const uint8_t * msg,size_t msglen,uint32_t flags)430 extern "C" int resNetworkSend(unsigned netId, const uint8_t* msg, size_t msglen, uint32_t flags) {
431     // Encode
432     // Base 64 encodes every 3 bytes into 4 characters, but then adds padding to the next
433     // multiple of 4 and a \0
434     const size_t encodedLen = divCeil(msglen, 3) * 4 + 1;
435     std::string encodedQuery(encodedLen - 1, 0);
436     int enLen = b64_ntop(msg, msglen, encodedQuery.data(), encodedLen);
437 
438     if (enLen < 0) {
439         // Unexpected behavior, encode failed
440         // b64_ntop only fails when size is too long.
441         return -EMSGSIZE;
442     }
443     // Send
444     netId = getNetworkForResolv(netId);
445     const std::string cmd = "resnsend " + std::to_string(netId) + " " + std::to_string(flags) +
446                             " " + encodedQuery + '\0';
447     if (cmd.size() > MAX_CMD_SIZE) {
448         // Cmd size must less than buffer size of FrameworkListener
449         return -EMSGSIZE;
450     }
451     int fd = dns_open_proxy();
452     if (fd == -1) {
453         return -errno;
454     }
455     ssize_t rc = sendData(fd, cmd.c_str(), cmd.size());
456     if (rc < 0) {
457         close(fd);
458         return rc;
459     }
460     shutdown(fd, SHUT_WR);
461     return fd;
462 }
463 
resNetworkResult(int fd,int * rcode,uint8_t * answer,size_t anslen)464 extern "C" int resNetworkResult(int fd, int* rcode, uint8_t* answer, size_t anslen) {
465     int32_t result = 0;
466     unique_fd ufd(fd);
467     // Read -errno/rcode
468     if (!readBE32(fd, &result)) {
469         // Unexpected behavior, read -errno/rcode fail
470         return -errno;
471     }
472     if (result < 0) {
473         // result < 0, it's -errno
474         return result;
475     }
476     // result >= 0, it's rcode
477     *rcode = result;
478 
479     // Read answer
480     int32_t size = 0;
481     if (!readBE32(fd, &size)) {
482         // Unexpected behavior, read ans len fail
483         return -EREMOTEIO;
484     }
485     if (anslen < static_cast<size_t>(size)) {
486         // Answer buffer is too small
487         return -EMSGSIZE;
488     }
489     int rc = readData(fd, answer, size);
490     if (rc < 0) {
491         // Reading the answer failed.
492         return rc;
493     }
494     return size;
495 }
496 
// Closes |fd|, presumably a socket from resNetworkSend() whose answer is no longer wanted.
extern "C" void resNetworkCancel(int fd) {
    static_cast<void>(close(fd));
}
500 
getNetworkForDns(unsigned * dnsNetId)501 extern "C" int getNetworkForDns(unsigned* dnsNetId) {
502     if (dnsNetId == nullptr) return -EFAULT;
503     int fd = dns_open_proxy();
504     if (fd == -1) {
505         return -errno;
506     }
507     unique_fd ufd(fd);
508     return getNetworkForDnsInternal(fd, dnsNetId);
509 }
510 
getNetworkForDnsInternal(int fd,unsigned * dnsNetId)511 int getNetworkForDnsInternal(int fd, unsigned* dnsNetId) {
512     if (fd == -1) {
513         return -EBADF;
514     }
515 
516     unsigned resolvNetId = getNetworkForResolv(NETID_UNSET);
517 
518     const std::string cmd = "getdnsnetid " + std::to_string(resolvNetId);
519     ssize_t rc = sendData(fd, cmd.c_str(), cmd.size() + 1);
520     if (rc < 0) {
521         return rc;
522     }
523 
524     int responseCode = 0;
525     // Read responseCode
526     if (!readResponseCode(fd, &responseCode)) {
527         // Unexpected behavior, read responseCode fail
528         return -errno;
529     }
530 
531     if (responseCode != ResponseCode::DnsProxyQueryResult) {
532         return -EOPNOTSUPP;
533     }
534 
535     int32_t result = 0;
536     // Read -errno/dnsnetid
537     if (!readBE32(fd, &result)) {
538         // Unexpected behavior, read -errno/dnsnetid fail
539         return -errno;
540     }
541 
542     *dnsNetId = result;
543 
544     return 0;
545 }
546