• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2016 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #define LOG_TAG "Netd"
18 
19 #include "SockDiag.h"
20 
21 #include <errno.h>
22 #include <linux/inet_diag.h>
23 #include <linux/netlink.h>
24 #include <linux/sock_diag.h>
25 #include <netdb.h>
26 #include <netinet/in.h>
27 #include <netinet/tcp.h>
28 #include <string.h>
29 #include <sys/socket.h>
30 #include <sys/uio.h>
31 
32 #include <cinttypes>
33 
34 #include <android-base/properties.h>
35 #include <android-base/stringprintf.h>
36 #include <android-base/strings.h>
37 #include <log/log.h>
38 #include <netdutils/InternetAddresses.h>
39 #include <netdutils/Stopwatch.h>
40 
41 #include "Permission.h"
42 
43 #ifndef SOCK_DESTROY
44 #define SOCK_DESTROY 21
45 #endif
46 
47 #define INET_DIAG_BC_MARK_COND 10
48 
49 namespace android {
50 
51 using android::base::StringPrintf;
52 using netdutils::ScopedAddrinfo;
53 using netdutils::Stopwatch;
54 
55 namespace net {
56 namespace {
57 
getAdbPort()58 int getAdbPort() {
59     return android::base::GetIntProperty("service.adb.tcp.port", 0);
60 }
61 
isAdbSocket(const inet_diag_msg * msg,int adbPort)62 bool isAdbSocket(const inet_diag_msg *msg, int adbPort) {
63     return adbPort > 0 && msg->id.idiag_sport == htons(adbPort) &&
64         (msg->idiag_uid == AID_ROOT || msg->idiag_uid == AID_SHELL);
65 }
66 
// Non-blocking check for a netlink error reply queued on |fd|.
//
// Peeks at the next datagram. If it is exactly an nlmsghdr + nlmsgerr of type NLMSG_ERROR, the
// datagram is consumed and its error field is returned. An empty queue (EAGAIN) yields 0. Any
// other queued message also yields 0 and is left in place for the caller. Other recv() failures
// return -errno.
int checkError(int fd) {
    struct {
        nlmsghdr h;
        nlmsgerr err;
    } __attribute__((__packed__)) reply;

    const ssize_t peeked = recv(fd, &reply, sizeof(reply), MSG_DONTWAIT | MSG_PEEK);
    if (peeked == -1) {
        // Nothing to read is the good case; anything else is a genuine failure.
        return (errno == EAGAIN) ? 0 : -errno;
    }

    const bool isErrorReply = peeked == static_cast<ssize_t>(sizeof(reply)) &&
                              reply.h.nlmsg_type == NLMSG_ERROR;
    if (!isErrorReply) {
        // Some other reply; leave it queued for whoever reads next.
        return 0;
    }

    // Dequeue the error message we just peeked at.
    recv(fd, &reply, sizeof(reply), 0);
    return reply.err.error;
}
85 
86 }  // namespace
87 
open()88 bool SockDiag::open() {
89     if (hasSocks()) {
90         return false;
91     }
92 
93     mSock = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
94     mWriteSock = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
95     if (!hasSocks()) {
96         closeSocks();
97         return false;
98     }
99 
100     sockaddr_nl nl = { .nl_family = AF_NETLINK };
101     if ((connect(mSock, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) == -1) ||
102         (connect(mWriteSock, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) == -1)) {
103         closeSocks();
104         return false;
105     }
106 
107     return true;
108 }
109 
sendDumpRequest(uint8_t proto,uint8_t family,uint8_t extensions,uint32_t states,iovec * iov,int iovcnt)110 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint8_t extensions, uint32_t states,
111                               iovec *iov, int iovcnt) {
112     struct {
113         nlmsghdr nlh;
114         inet_diag_req_v2 req;
115     } __attribute__((__packed__)) request = {
116         .nlh = {
117             .nlmsg_type = SOCK_DIAG_BY_FAMILY,
118             .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
119         },
120         .req = {
121             .sdiag_family = family,
122             .sdiag_protocol = proto,
123             .idiag_ext = extensions,
124             .idiag_states = states,
125         },
126     };
127 
128     size_t len = 0;
129     iov[0].iov_base = &request;
130     iov[0].iov_len = sizeof(request);
131     for (int i = 0; i < iovcnt; i++) {
132         len += iov[i].iov_len;
133     }
134     request.nlh.nlmsg_len = len;
135 
136     ssize_t writevRet = writev(mSock, iov, iovcnt);
137     // Don't let pointers to the stack escape.
138     iov[0] = {nullptr, 0};
139     if (writevRet != (ssize_t)len) {
140         return -errno;
141     }
142 
143     return checkError(mSock);
144 }
145 
sendDumpRequest(uint8_t proto,uint8_t family,uint32_t states)146 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states) {
147     iovec iov[] = {
148         { nullptr, 0 },
149     };
150     return sendDumpRequest(proto, family, 0, states, iov, ARRAY_SIZE(iov));
151 }
152 
// Sends a dump request for |proto| sockets of |family| whose source address equals |addrstr|.
// The request covers all TCP states except TIME_WAIT.
//
// |addrstr| must be a numeric IPv4 or IPv6 literal. An IPv4 literal may be used with AF_INET6,
// in which case it is matched in v4-mapped form (::ffff:a.b.c.d).
// Returns -EINVAL if |addrstr| does not parse, -EAFNOSUPPORT if the parsed family cannot be
// used with |family|, otherwise the result of the general sendDumpRequest().
int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, const char *addrstr) {
    addrinfo hints = { .ai_flags = AI_NUMERICHOST };
    addrinfo *res;

    // TODO: refactor the netlink parsing code out of system/core, bring it into netd, and stop
    // doing string conversions when they're not necessary.
    if (getaddrinfo(addrstr, nullptr, &hints, &res) != 0) {
        return -EINVAL;
    }
    // Owns |res| from here on, so every early return below frees it.
    ScopedAddrinfo resP(res);

    // ::ffff:0.0.0.0; the IPv4 address is written into the last 32-bit word when needed.
    in6_addr mapped = { .s6_addr32 = { 0, 0, htonl(0xffff), 0 } };

    void *addr = nullptr;
    uint8_t addrlen = 0;
    switch (res->ai_family) {
        case AF_INET: {
            in_addr& ina = reinterpret_cast<sockaddr_in*>(res->ai_addr)->sin_addr;
            if (family == AF_INET) {
                addr = &ina;
                addrlen = sizeof(ina);
            } else if (family == AF_INET6) {
                mapped.s6_addr32[3] = ina.s_addr;
                addr = &mapped;
                addrlen = sizeof(mapped);
            } else {
                return -EAFNOSUPPORT;
            }
            break;
        }
        case AF_INET6: {
            if (family != AF_INET6) {
                return -EAFNOSUPPORT;
            }
            in6_addr& in6a = reinterpret_cast<sockaddr_in6*>(res->ai_addr)->sin6_addr;
            addr = &in6a;
            addrlen = sizeof(in6a);
            break;
        }
        default:
            return -EAFNOSUPPORT;
    }

    // Single-instruction bytecode program: a source-address condition with a full-length prefix.
    // On match, advance by the instruction's own length, landing exactly at the end (accept).
    // Otherwise jump 4 bytes past the end (reject).
    const uint8_t prefixlen = addrlen * 8;
    const uint8_t yesjump = sizeof(inet_diag_bc_op) + sizeof(inet_diag_hostcond) + addrlen;
    const uint8_t nojump = yesjump + 4;

    struct {
        nlattr nla;
        inet_diag_bc_op op;
        inet_diag_hostcond cond;
    } __attribute__((__packed__)) attrs = {
        .nla = {
            .nla_type = INET_DIAG_REQ_BYTECODE,
        },
        .op = {
            INET_DIAG_BC_S_COND,
            yesjump,
            nojump,
        },
        .cond = {
            family,
            prefixlen,
            -1,
            {}
        },
    };
    // The address bytes follow the hostcond directly; they are sent as a separate iovec.
    attrs.nla.nla_len = sizeof(attrs) + addrlen;

    iovec iov[] = {
        { nullptr,           0 },
        { &attrs,            sizeof(attrs) },
        { addr,              addrlen },
    };

    const uint32_t allButTimeWait = ~(1 << TCP_TIME_WAIT);
    return sendDumpRequest(proto, family, 0, allButTimeWait, iov, ARRAY_SIZE(iov));
}
221 
readDiagMsg(uint8_t proto,const SockDiag::DestroyFilter & shouldDestroy)222 int SockDiag::readDiagMsg(uint8_t proto, const SockDiag::DestroyFilter& shouldDestroy) {
223     NetlinkDumpCallback callback = [this, proto, shouldDestroy] (nlmsghdr *nlh) {
224         const inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
225         if (shouldDestroy(proto, msg)) {
226             sockDestroy(proto, msg);
227         }
228     };
229 
230     return processNetlinkDump(mSock, callback);
231 }
232 
readDiagMsgWithTcpInfo(const TcpInfoReader & tcpInfoReader)233 int SockDiag::readDiagMsgWithTcpInfo(const TcpInfoReader& tcpInfoReader) {
234     NetlinkDumpCallback callback = [tcpInfoReader] (nlmsghdr *nlh) {
235         if (nlh->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
236             ALOGE("expected nlmsg_type=SOCK_DIAG_BY_FAMILY, got nlmsg_type=%d", nlh->nlmsg_type);
237             return;
238         }
239         Fwmark mark;
240         struct tcp_info *tcpinfo = nullptr;
241         uint32_t tcpinfoLength = 0;
242         inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
243         uint32_t attr_len = nlh->nlmsg_len - NLMSG_LENGTH(sizeof(*msg));
244         struct rtattr *attr = reinterpret_cast<struct rtattr*>(msg+1);
245         while (RTA_OK(attr, attr_len)) {
246             if (attr->rta_type == INET_DIAG_INFO) {
247                 tcpinfo = reinterpret_cast<struct tcp_info*>(RTA_DATA(attr));
248                 tcpinfoLength = RTA_PAYLOAD(attr);
249             }
250             if (attr->rta_type == INET_DIAG_MARK) {
251                 mark.intValue = *reinterpret_cast<uint32_t*>(RTA_DATA(attr));
252             }
253             attr = RTA_NEXT(attr, attr_len);
254         }
255 
256         tcpInfoReader(mark, msg, tcpinfo, tcpinfoLength);
257     };
258 
259     return processNetlinkDump(mSock, callback);
260 }
261 
262 // Determines whether a socket is a loopback socket. Does not check socket state.
isLoopbackSocket(const inet_diag_msg * msg)263 bool SockDiag::isLoopbackSocket(const inet_diag_msg *msg) {
264     switch (msg->idiag_family) {
265         case AF_INET:
266             // Old kernels only copy the IPv4 address and leave the other 12 bytes uninitialized.
267             return IN_LOOPBACK(htonl(msg->id.idiag_src[0])) ||
268                    IN_LOOPBACK(htonl(msg->id.idiag_dst[0])) ||
269                    msg->id.idiag_src[0] == msg->id.idiag_dst[0];
270 
271         case AF_INET6: {
272             const struct in6_addr *src = (const struct in6_addr *) &msg->id.idiag_src;
273             const struct in6_addr *dst = (const struct in6_addr *) &msg->id.idiag_dst;
274             return (IN6_IS_ADDR_V4MAPPED(src) && IN_LOOPBACK(src->s6_addr32[3])) ||
275                    (IN6_IS_ADDR_V4MAPPED(dst) && IN_LOOPBACK(dst->s6_addr32[3])) ||
276                    IN6_IS_ADDR_LOOPBACK(src) || IN6_IS_ADDR_LOOPBACK(dst) ||
277                    !memcmp(src, dst, sizeof(*src));
278         }
279         default:
280             return false;
281     }
282 }
283 
sockDestroy(uint8_t proto,const inet_diag_msg * msg)284 int SockDiag::sockDestroy(uint8_t proto, const inet_diag_msg *msg) {
285     if (msg == nullptr) {
286        return 0;
287     }
288 
289     DestroyRequest request = {
290         .nlh = {
291             .nlmsg_type = SOCK_DESTROY,
292             .nlmsg_flags = NLM_F_REQUEST,
293         },
294         .req = {
295             .sdiag_family = msg->idiag_family,
296             .sdiag_protocol = proto,
297             .idiag_states = (uint32_t) (1 << msg->idiag_state),
298             .id = msg->id,
299         },
300     };
301     request.nlh.nlmsg_len = sizeof(request);
302 
303     if (write(mWriteSock, &request, sizeof(request)) < (ssize_t) sizeof(request)) {
304         return -errno;
305     }
306 
307     int ret = checkError(mWriteSock);
308     if (!ret) mSocketsDestroyed++;
309     return ret;
310 }
311 
destroySockets(uint8_t proto,int family,const char * addrstr,int ifindex)312 int SockDiag::destroySockets(uint8_t proto, int family, const char* addrstr, int ifindex) {
313     if (!hasSocks()) {
314         return -EBADFD;
315     }
316 
317     if (int ret = sendDumpRequest(proto, family, addrstr)) {
318         return ret;
319     }
320 
321     // Destroy all sockets on the address, except link-local sockets where ifindex doesn't match.
322     auto shouldDestroy = [ifindex](uint8_t, const inet_diag_msg* msg) {
323         return ifindex == 0 || ifindex == (int)msg->id.idiag_if;
324     };
325 
326     return readDiagMsg(proto, shouldDestroy);
327 }
328 
destroySockets(const char * addrstr,int ifindex)329 int SockDiag::destroySockets(const char* addrstr, int ifindex) {
330     Stopwatch s;
331     mSocketsDestroyed = 0;
332 
333     std::string where = addrstr;
334     if (ifindex) where += StringPrintf(" ifindex %d", ifindex);
335 
336     if (!strchr(addrstr, ':')) {  // inet_ntop never returns something like ::ffff:192.0.2.1
337         if (int ret = destroySockets(IPPROTO_TCP, AF_INET, addrstr, ifindex)) {
338             ALOGE("Failed to destroy IPv4 sockets on %s: %s", where.c_str(), strerror(-ret));
339             return ret;
340         }
341     }
342     if (int ret = destroySockets(IPPROTO_TCP, AF_INET6, addrstr, ifindex)) {
343         ALOGE("Failed to destroy IPv6 sockets on %s: %s", where.c_str(), strerror(-ret));
344         return ret;
345     }
346 
347     if (mSocketsDestroyed > 0) {
348         ALOGI("Destroyed %d sockets on %s in %" PRId64 "us", mSocketsDestroyed, where.c_str(),
349               s.timeTakenUs());
350     }
351 
352     return mSocketsDestroyed;
353 }
354 
destroyLiveSockets(const DestroyFilter & destroyFilter,const char * what,iovec * iov,int iovcnt)355 int SockDiag::destroyLiveSockets(const DestroyFilter& destroyFilter, const char *what,
356                                  iovec *iov, int iovcnt) {
357     const int proto = IPPROTO_TCP;
358     const uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
359 
360     for (const int family : {AF_INET, AF_INET6}) {
361         const char *familyName = (family == AF_INET) ? "IPv4" : "IPv6";
362         if (int ret = sendDumpRequest(proto, family, 0, states, iov, iovcnt)) {
363             ALOGE("Failed to dump %s sockets for %s: %s", familyName, what, strerror(-ret));
364             return ret;
365         }
366         if (int ret = readDiagMsg(proto, destroyFilter)) {
367             ALOGE("Failed to destroy %s sockets for %s: %s", familyName, what, strerror(-ret));
368             return ret;
369         }
370     }
371 
372     return 0;
373 }
374 
getLiveTcpInfos(const TcpInfoReader & tcpInfoReader)375 int SockDiag::getLiveTcpInfos(const TcpInfoReader& tcpInfoReader) {
376     const int proto = IPPROTO_TCP;
377     const uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
378     const uint8_t extensions = (1 << INET_DIAG_MEMINFO); // flag for dumping struct tcp_info.
379 
380     iovec iov[] = {
381         { nullptr, 0 },
382     };
383 
384     for (const int family : {AF_INET, AF_INET6}) {
385         const char *familyName = (family == AF_INET) ? "IPv4" : "IPv6";
386         if (int ret = sendDumpRequest(proto, family, extensions, states, iov, ARRAY_SIZE(iov))) {
387             ALOGE("Failed to dump %s sockets struct tcp_info: %s", familyName, strerror(-ret));
388             return ret;
389         }
390         if (int ret = readDiagMsgWithTcpInfo(tcpInfoReader)) {
391             ALOGE("Failed to read %s sockets struct tcp_info: %s", familyName, strerror(-ret));
392             return ret;
393         }
394     }
395 
396     return 0;
397 }
398 
destroySockets(uint8_t proto,const uid_t uid,bool excludeLoopback)399 int SockDiag::destroySockets(uint8_t proto, const uid_t uid, bool excludeLoopback) {
400     mSocketsDestroyed = 0;
401     Stopwatch s;
402 
403     auto shouldDestroy = [uid, excludeLoopback] (uint8_t, const inet_diag_msg *msg) {
404         return msg != nullptr &&
405                msg->idiag_uid == uid &&
406                !(excludeLoopback && isLoopbackSocket(msg));
407     };
408 
409     for (const int family : {AF_INET, AF_INET6}) {
410         const char *familyName = family == AF_INET ? "IPv4" : "IPv6";
411         uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
412         if (int ret = sendDumpRequest(proto, family, states)) {
413             ALOGE("Failed to dump %s sockets for UID: %s", familyName, strerror(-ret));
414             return ret;
415         }
416         if (int ret = readDiagMsg(proto, shouldDestroy)) {
417             ALOGE("Failed to destroy %s sockets for UID: %s", familyName, strerror(-ret));
418             return ret;
419         }
420     }
421 
422     if (mSocketsDestroyed > 0) {
423         ALOGI("Destroyed %d sockets for UID in %" PRId64 "us", mSocketsDestroyed, s.timeTakenUs());
424     }
425 
426     return 0;
427 }
428 
destroySockets(const UidRanges & uidRanges,const std::set<uid_t> & skipUids,bool excludeLoopback)429 int SockDiag::destroySockets(const UidRanges& uidRanges, const std::set<uid_t>& skipUids,
430                              bool excludeLoopback) {
431     mSocketsDestroyed = 0;
432     Stopwatch s;
433 
434     auto shouldDestroy = [&] (uint8_t, const inet_diag_msg *msg) {
435         return msg != nullptr &&
436                uidRanges.hasUid(msg->idiag_uid) &&
437                skipUids.find(msg->idiag_uid) == skipUids.end() &&
438                !(excludeLoopback && isLoopbackSocket(msg)) &&
439                !isAdbSocket(msg, getAdbPort());
440     };
441 
442     iovec iov[] = {
443         { nullptr, 0 },
444     };
445 
446     if (int ret = destroyLiveSockets(shouldDestroy, "UID", iov, ARRAY_SIZE(iov))) {
447         return ret;
448     }
449 
450     if (mSocketsDestroyed > 0) {
451         ALOGI("Destroyed %d sockets for %s skip={%s} in %" PRId64 "us", mSocketsDestroyed,
452               uidRanges.toString().c_str(), android::base::Join(skipUids, " ").c_str(),
453               s.timeTakenUs());
454     }
455 
456     return 0;
457 }
458 
459 // Destroys all "live" (CONNECTED, SYN_SENT, SYN_RECV) TCP sockets on the specified netId where:
460 // 1. The opening app no longer has permission to use this network, or:
461 // 2. The opening app does have permission, but did not explicitly select this network.
462 //
463 // We destroy sockets without the explicit bit because we want to avoid the situation where a
464 // privileged app uses its privileges without knowing it is doing so. For example, a privileged app
465 // might have opened a socket on this network just because it was the default network at the
466 // time. If we don't kill these sockets, those apps could continue to use them without realizing
467 // that they are now sending and receiving traffic on a network that is now restricted.
destroySocketsLackingPermission(unsigned netId,Permission permission,bool excludeLoopback)468 int SockDiag::destroySocketsLackingPermission(unsigned netId, Permission permission,
469                                               bool excludeLoopback) {
470     struct markmatch {
471         inet_diag_bc_op op;
472         // TODO: switch to inet_diag_markcond
473         __u32 mark;
474         __u32 mask;
475     } __attribute__((packed));
476     constexpr uint8_t matchlen = sizeof(markmatch);
477 
478     Fwmark netIdMark, netIdMask;
479     netIdMark.netId = netId;
480     netIdMask.netId = 0xffff;
481 
482     Fwmark controlMark;
483     controlMark.explicitlySelected = true;
484     controlMark.permission = permission;
485 
486     // A SOCK_DIAG bytecode program that accepts the sockets we intend to destroy.
487     struct bytecode {
488         markmatch netIdMatch;
489         markmatch controlMatch;
490         inet_diag_bc_op controlJump;
491     } __attribute__((packed)) bytecode;
492 
493     // The length of the INET_DIAG_BC_JMP instruction.
494     constexpr uint8_t jmplen = sizeof(inet_diag_bc_op);
495     // Jump exactly this far past the end of the program to reject.
496     constexpr uint8_t rejectoffset = sizeof(inet_diag_bc_op);
497     // Total length of the program.
498     constexpr uint8_t bytecodelen = sizeof(bytecode);
499 
500     bytecode = (struct bytecode) {
501         // If netId matches, continue, otherwise, reject (i.e., leave socket alone).
502         { { INET_DIAG_BC_MARK_COND, matchlen, bytecodelen + rejectoffset },
503           netIdMark.intValue, netIdMask.intValue },
504 
505         // If explicit and permission bits match, go to the JMP below which rejects the socket
506         // (i.e., we leave it alone). Otherwise, jump to the end of the program, which accepts the
507         // socket (so we destroy it).
508         { { INET_DIAG_BC_MARK_COND, matchlen, matchlen + jmplen },
509           controlMark.intValue, controlMark.intValue },
510 
511         // This JMP unconditionally rejects the packet by jumping to the reject target. It is
512         // necessary to keep the kernel bytecode verifier happy. If we don't have a JMP the bytecode
513         // is invalid because the target of every no jump must always be reachable by yes jumps.
514         // Without this JMP, the accept target is not reachable by yes jumps and the program will
515         // be rejected by the validator.
516         { INET_DIAG_BC_JMP, jmplen, jmplen + rejectoffset },
517 
518         // We have reached the end of the program. Accept the socket, and destroy it below.
519     };
520 
521     struct nlattr nla = {
522             .nla_len = sizeof(struct nlattr) + bytecodelen,
523             .nla_type = INET_DIAG_REQ_BYTECODE,
524     };
525 
526     iovec iov[] = {
527         { nullptr,   0 },
528         { &nla,      sizeof(nla) },
529         { &bytecode, bytecodelen },
530     };
531 
532     mSocketsDestroyed = 0;
533     Stopwatch s;
534 
535     auto shouldDestroy = [&] (uint8_t, const inet_diag_msg *msg) {
536         return msg != nullptr && !(excludeLoopback && isLoopbackSocket(msg));
537     };
538 
539     if (int ret = destroyLiveSockets(shouldDestroy, "permission change", iov, ARRAY_SIZE(iov))) {
540         return ret;
541     }
542 
543     if (mSocketsDestroyed > 0) {
544         ALOGI("Destroyed %d sockets for netId %d permission=%d in %" PRId64 "us", mSocketsDestroyed,
545               netId, permission, s.timeTakenUs());
546     }
547 
548     return 0;
549 }
550 
551 }  // namespace net
552 }  // namespace android
553