• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Netd"
18 
19 #include "SockDiag.h"
20 
21 #include <errno.h>
22 #include <linux/inet_diag.h>
23 #include <linux/netlink.h>
24 #include <linux/sock_diag.h>
25 #include <netdb.h>
26 #include <netinet/in.h>
27 #include <netinet/tcp.h>
28 #include <string.h>
29 #include <sys/socket.h>
30 #include <sys/uio.h>
31 
32 #include <cinttypes>
33 
34 #include <android-base/properties.h>
35 #include <android-base/strings.h>
36 #include <log/log.h>
37 #include <netdutils/InternetAddresses.h>
38 #include <netdutils/Stopwatch.h>
39 
40 #include "Permission.h"
41 
42 #ifndef SOCK_DESTROY
43 #define SOCK_DESTROY 21
44 #endif
45 
46 #define INET_DIAG_BC_MARK_COND 10
47 
48 namespace android {
49 
50 using netdutils::ScopedAddrinfo;
51 using netdutils::Stopwatch;
52 
53 namespace net {
54 namespace {
55 
getAdbPort()56 int getAdbPort() {
57     return android::base::GetIntProperty("service.adb.tcp.port", 0);
58 }
59 
isAdbSocket(const inet_diag_msg * msg,int adbPort)60 bool isAdbSocket(const inet_diag_msg *msg, int adbPort) {
61     return adbPort > 0 && msg->id.idiag_sport == htons(adbPort) &&
62         (msg->idiag_uid == AID_ROOT || msg->idiag_uid == AID_SHELL);
63 }
64 
// Non-blocking check for a netlink error reply queued on |fd|.
//
// Returns:
//   0       if nothing is queued (recv fails with EAGAIN), or if the queued message is not a
//           full-sized NLMSG_ERROR; in the latter case the message is left for the caller.
//   error   the nlmsgerr.error field of a queued NLMSG_ERROR, after consuming that message.
//   -errno  if peeking fails for any reason other than EAGAIN.
int checkError(int fd) {
    struct {
        nlmsghdr h;
        nlmsgerr err;
    } __attribute__((__packed__)) ack;
    const ssize_t ackSize = (ssize_t) sizeof(ack);

    const ssize_t peeked = recv(fd, &ack, sizeof(ack), MSG_DONTWAIT | MSG_PEEK);
    if (peeked == -1) {
        // An empty queue is the success case.
        if (errno == EAGAIN) {
            return 0;
        }
        return -errno;
    }

    if (peeked != ackSize || ack.h.nlmsg_type != NLMSG_ERROR) {
        // Something other than an error reply is queued; leave it to the caller.
        return 0;
    }

    // An error reply is at the head of the queue: dequeue it, then report it.
    recv(fd, &ack, sizeof(ack), 0);
    return ack.err.error;
}
83 
84 }  // namespace
85 
open()86 bool SockDiag::open() {
87     if (hasSocks()) {
88         return false;
89     }
90 
91     mSock = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
92     mWriteSock = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
93     if (!hasSocks()) {
94         closeSocks();
95         return false;
96     }
97 
98     sockaddr_nl nl = { .nl_family = AF_NETLINK };
99     if ((connect(mSock, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) == -1) ||
100         (connect(mWriteSock, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) == -1)) {
101         closeSocks();
102         return false;
103     }
104 
105     return true;
106 }
107 
sendDumpRequest(uint8_t proto,uint8_t family,uint8_t extensions,uint32_t states,iovec * iov,int iovcnt)108 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint8_t extensions, uint32_t states,
109                               iovec *iov, int iovcnt) {
110     struct {
111         nlmsghdr nlh;
112         inet_diag_req_v2 req;
113     } __attribute__((__packed__)) request = {
114         .nlh = {
115             .nlmsg_type = SOCK_DIAG_BY_FAMILY,
116             .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
117         },
118         .req = {
119             .sdiag_family = family,
120             .sdiag_protocol = proto,
121             .idiag_ext = extensions,
122             .idiag_states = states,
123         },
124     };
125 
126     size_t len = 0;
127     iov[0].iov_base = &request;
128     iov[0].iov_len = sizeof(request);
129     for (int i = 0; i < iovcnt; i++) {
130         len += iov[i].iov_len;
131     }
132     request.nlh.nlmsg_len = len;
133 
134     if (writev(mSock, iov, iovcnt) != (ssize_t) len) {
135         return -errno;
136     }
137 
138     return checkError(mSock);
139 }
140 
sendDumpRequest(uint8_t proto,uint8_t family,uint32_t states)141 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states) {
142     iovec iov[] = {
143         { nullptr, 0 },
144     };
145     return sendDumpRequest(proto, family, 0, states, iov, ARRAY_SIZE(iov));
146 }
147 
sendDumpRequest(uint8_t proto,uint8_t family,const char * addrstr)148 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, const char *addrstr) {
149     addrinfo hints = { .ai_flags = AI_NUMERICHOST };
150     addrinfo *res;
151     in6_addr mapped = { .s6_addr32 = { 0, 0, htonl(0xffff), 0 } };
152 
153     // TODO: refactor the netlink parsing code out of system/core, bring it into netd, and stop
154     // doing string conversions when they're not necessary.
155     int ret = getaddrinfo(addrstr, nullptr, &hints, &res);
156     if (ret != 0) return -EINVAL;
157 
158     // So we don't have to call freeaddrinfo on every failure path.
159     ScopedAddrinfo resP(res);
160 
161     void *addr;
162     uint8_t addrlen;
163     if (res->ai_family == AF_INET && family == AF_INET) {
164         in_addr& ina = reinterpret_cast<sockaddr_in*>(res->ai_addr)->sin_addr;
165         addr = &ina;
166         addrlen = sizeof(ina);
167     } else if (res->ai_family == AF_INET && family == AF_INET6) {
168         in_addr& ina = reinterpret_cast<sockaddr_in*>(res->ai_addr)->sin_addr;
169         mapped.s6_addr32[3] = ina.s_addr;
170         addr = &mapped;
171         addrlen = sizeof(mapped);
172     } else if (res->ai_family == AF_INET6 && family == AF_INET6) {
173         in6_addr& in6a = reinterpret_cast<sockaddr_in6*>(res->ai_addr)->sin6_addr;
174         addr = &in6a;
175         addrlen = sizeof(in6a);
176     } else {
177         return -EAFNOSUPPORT;
178     }
179 
180     uint8_t prefixlen = addrlen * 8;
181     uint8_t yesjump = sizeof(inet_diag_bc_op) + sizeof(inet_diag_hostcond) + addrlen;
182     uint8_t nojump = yesjump + 4;
183 
184     struct {
185         nlattr nla;
186         inet_diag_bc_op op;
187         inet_diag_hostcond cond;
188     } __attribute__((__packed__)) attrs = {
189         .nla = {
190             .nla_type = INET_DIAG_REQ_BYTECODE,
191         },
192         .op = {
193             INET_DIAG_BC_S_COND,
194             yesjump,
195             nojump,
196         },
197         .cond = {
198             family,
199             prefixlen,
200             -1,
201             {}
202         },
203     };
204 
205     attrs.nla.nla_len = sizeof(attrs) + addrlen;
206 
207     iovec iov[] = {
208         { nullptr,           0 },
209         { &attrs,            sizeof(attrs) },
210         { addr,              addrlen },
211     };
212 
213     uint32_t states = ~(1 << TCP_TIME_WAIT);
214     return sendDumpRequest(proto, family, 0, states, iov, ARRAY_SIZE(iov));
215 }
216 
readDiagMsg(uint8_t proto,const SockDiag::DestroyFilter & shouldDestroy)217 int SockDiag::readDiagMsg(uint8_t proto, const SockDiag::DestroyFilter& shouldDestroy) {
218     NetlinkDumpCallback callback = [this, proto, shouldDestroy] (nlmsghdr *nlh) {
219         const inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
220         if (shouldDestroy(proto, msg)) {
221             sockDestroy(proto, msg);
222         }
223     };
224 
225     return processNetlinkDump(mSock, callback);
226 }
227 
readDiagMsgWithTcpInfo(const TcpInfoReader & tcpInfoReader)228 int SockDiag::readDiagMsgWithTcpInfo(const TcpInfoReader& tcpInfoReader) {
229     NetlinkDumpCallback callback = [tcpInfoReader] (nlmsghdr *nlh) {
230         if (nlh->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
231             ALOGE("expected nlmsg_type=SOCK_DIAG_BY_FAMILY, got nlmsg_type=%d", nlh->nlmsg_type);
232             return;
233         }
234         Fwmark mark;
235         struct tcp_info *tcpinfo = nullptr;
236         uint32_t tcpinfoLength = 0;
237         inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
238         uint32_t attr_len = nlh->nlmsg_len - NLMSG_LENGTH(sizeof(*msg));
239         struct rtattr *attr = reinterpret_cast<struct rtattr*>(msg+1);
240         while (RTA_OK(attr, attr_len)) {
241             if (attr->rta_type == INET_DIAG_INFO) {
242                 tcpinfo = reinterpret_cast<struct tcp_info*>(RTA_DATA(attr));
243                 tcpinfoLength = RTA_PAYLOAD(attr);
244             }
245             if (attr->rta_type == INET_DIAG_MARK) {
246                 mark.intValue = *reinterpret_cast<uint32_t*>(RTA_DATA(attr));
247             }
248             attr = RTA_NEXT(attr, attr_len);
249         }
250 
251         tcpInfoReader(mark, msg, tcpinfo, tcpinfoLength);
252     };
253 
254     return processNetlinkDump(mSock, callback);
255 }
256 
257 // Determines whether a socket is a loopback socket. Does not check socket state.
isLoopbackSocket(const inet_diag_msg * msg)258 bool SockDiag::isLoopbackSocket(const inet_diag_msg *msg) {
259     switch (msg->idiag_family) {
260         case AF_INET:
261             // Old kernels only copy the IPv4 address and leave the other 12 bytes uninitialized.
262             return IN_LOOPBACK(htonl(msg->id.idiag_src[0])) ||
263                    IN_LOOPBACK(htonl(msg->id.idiag_dst[0])) ||
264                    msg->id.idiag_src[0] == msg->id.idiag_dst[0];
265 
266         case AF_INET6: {
267             const struct in6_addr *src = (const struct in6_addr *) &msg->id.idiag_src;
268             const struct in6_addr *dst = (const struct in6_addr *) &msg->id.idiag_dst;
269             return (IN6_IS_ADDR_V4MAPPED(src) && IN_LOOPBACK(src->s6_addr32[3])) ||
270                    (IN6_IS_ADDR_V4MAPPED(dst) && IN_LOOPBACK(dst->s6_addr32[3])) ||
271                    IN6_IS_ADDR_LOOPBACK(src) || IN6_IS_ADDR_LOOPBACK(dst) ||
272                    !memcmp(src, dst, sizeof(*src));
273         }
274         default:
275             return false;
276     }
277 }
278 
sockDestroy(uint8_t proto,const inet_diag_msg * msg)279 int SockDiag::sockDestroy(uint8_t proto, const inet_diag_msg *msg) {
280     if (msg == nullptr) {
281        return 0;
282     }
283 
284     DestroyRequest request = {
285         .nlh = {
286             .nlmsg_type = SOCK_DESTROY,
287             .nlmsg_flags = NLM_F_REQUEST,
288         },
289         .req = {
290             .sdiag_family = msg->idiag_family,
291             .sdiag_protocol = proto,
292             .idiag_states = (uint32_t) (1 << msg->idiag_state),
293             .id = msg->id,
294         },
295     };
296     request.nlh.nlmsg_len = sizeof(request);
297 
298     if (write(mWriteSock, &request, sizeof(request)) < (ssize_t) sizeof(request)) {
299         return -errno;
300     }
301 
302     int ret = checkError(mWriteSock);
303     if (!ret) mSocketsDestroyed++;
304     return ret;
305 }
306 
destroySockets(uint8_t proto,int family,const char * addrstr)307 int SockDiag::destroySockets(uint8_t proto, int family, const char *addrstr) {
308     if (!hasSocks()) {
309         return -EBADFD;
310     }
311 
312     if (int ret = sendDumpRequest(proto, family, addrstr)) {
313         return ret;
314     }
315 
316     auto destroyAll = [] (uint8_t, const inet_diag_msg*) { return true; };
317 
318     return readDiagMsg(proto, destroyAll);
319 }
320 
destroySockets(const char * addrstr)321 int SockDiag::destroySockets(const char *addrstr) {
322     Stopwatch s;
323     mSocketsDestroyed = 0;
324 
325     if (!strchr(addrstr, ':')) {
326         if (int ret = destroySockets(IPPROTO_TCP, AF_INET, addrstr)) {
327             ALOGE("Failed to destroy IPv4 sockets on %s: %s", addrstr, strerror(-ret));
328             return ret;
329         }
330     }
331     if (int ret = destroySockets(IPPROTO_TCP, AF_INET6, addrstr)) {
332         ALOGE("Failed to destroy IPv6 sockets on %s: %s", addrstr, strerror(-ret));
333         return ret;
334     }
335 
336     if (mSocketsDestroyed > 0) {
337         ALOGI("Destroyed %d sockets on %s in %" PRId64 "us", mSocketsDestroyed, addrstr,
338               s.timeTakenUs());
339     }
340 
341     return mSocketsDestroyed;
342 }
343 
destroyLiveSockets(const DestroyFilter & destroyFilter,const char * what,iovec * iov,int iovcnt)344 int SockDiag::destroyLiveSockets(const DestroyFilter& destroyFilter, const char *what,
345                                  iovec *iov, int iovcnt) {
346     const int proto = IPPROTO_TCP;
347     const uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
348 
349     for (const int family : {AF_INET, AF_INET6}) {
350         const char *familyName = (family == AF_INET) ? "IPv4" : "IPv6";
351         if (int ret = sendDumpRequest(proto, family, 0, states, iov, iovcnt)) {
352             ALOGE("Failed to dump %s sockets for %s: %s", familyName, what, strerror(-ret));
353             return ret;
354         }
355         if (int ret = readDiagMsg(proto, destroyFilter)) {
356             ALOGE("Failed to destroy %s sockets for %s: %s", familyName, what, strerror(-ret));
357             return ret;
358         }
359     }
360 
361     return 0;
362 }
363 
getLiveTcpInfos(const TcpInfoReader & tcpInfoReader)364 int SockDiag::getLiveTcpInfos(const TcpInfoReader& tcpInfoReader) {
365     const int proto = IPPROTO_TCP;
366     const uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
367     const uint8_t extensions = (1 << INET_DIAG_MEMINFO); // flag for dumping struct tcp_info.
368 
369     iovec iov[] = {
370         { nullptr, 0 },
371     };
372 
373     for (const int family : {AF_INET, AF_INET6}) {
374         const char *familyName = (family == AF_INET) ? "IPv4" : "IPv6";
375         if (int ret = sendDumpRequest(proto, family, extensions, states, iov, ARRAY_SIZE(iov))) {
376             ALOGE("Failed to dump %s sockets struct tcp_info: %s", familyName, strerror(-ret));
377             return ret;
378         }
379         if (int ret = readDiagMsgWithTcpInfo(tcpInfoReader)) {
380             ALOGE("Failed to read %s sockets struct tcp_info: %s", familyName, strerror(-ret));
381             return ret;
382         }
383     }
384 
385     return 0;
386 }
387 
destroySockets(uint8_t proto,const uid_t uid,bool excludeLoopback)388 int SockDiag::destroySockets(uint8_t proto, const uid_t uid, bool excludeLoopback) {
389     mSocketsDestroyed = 0;
390     Stopwatch s;
391 
392     auto shouldDestroy = [uid, excludeLoopback] (uint8_t, const inet_diag_msg *msg) {
393         return msg != nullptr &&
394                msg->idiag_uid == uid &&
395                !(excludeLoopback && isLoopbackSocket(msg));
396     };
397 
398     for (const int family : {AF_INET, AF_INET6}) {
399         const char *familyName = family == AF_INET ? "IPv4" : "IPv6";
400         uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
401         if (int ret = sendDumpRequest(proto, family, states)) {
402             ALOGE("Failed to dump %s sockets for UID: %s", familyName, strerror(-ret));
403             return ret;
404         }
405         if (int ret = readDiagMsg(proto, shouldDestroy)) {
406             ALOGE("Failed to destroy %s sockets for UID: %s", familyName, strerror(-ret));
407             return ret;
408         }
409     }
410 
411     if (mSocketsDestroyed > 0) {
412         ALOGI("Destroyed %d sockets for UID in %" PRId64 "us", mSocketsDestroyed, s.timeTakenUs());
413     }
414 
415     return 0;
416 }
417 
destroySockets(const UidRanges & uidRanges,const std::set<uid_t> & skipUids,bool excludeLoopback)418 int SockDiag::destroySockets(const UidRanges& uidRanges, const std::set<uid_t>& skipUids,
419                              bool excludeLoopback) {
420     mSocketsDestroyed = 0;
421     Stopwatch s;
422 
423     auto shouldDestroy = [&] (uint8_t, const inet_diag_msg *msg) {
424         return msg != nullptr &&
425                uidRanges.hasUid(msg->idiag_uid) &&
426                skipUids.find(msg->idiag_uid) == skipUids.end() &&
427                !(excludeLoopback && isLoopbackSocket(msg)) &&
428                !isAdbSocket(msg, getAdbPort());
429     };
430 
431     iovec iov[] = {
432         { nullptr, 0 },
433     };
434 
435     if (int ret = destroyLiveSockets(shouldDestroy, "UID", iov, ARRAY_SIZE(iov))) {
436         return ret;
437     }
438 
439     if (mSocketsDestroyed > 0) {
440         ALOGI("Destroyed %d sockets for %s skip={%s} in %" PRId64 "us", mSocketsDestroyed,
441               uidRanges.toString().c_str(), android::base::Join(skipUids, " ").c_str(),
442               s.timeTakenUs());
443     }
444 
445     return 0;
446 }
447 
448 // Destroys all "live" (CONNECTED, SYN_SENT, SYN_RECV) TCP sockets on the specified netId where:
449 // 1. The opening app no longer has permission to use this network, or:
450 // 2. The opening app does have permission, but did not explicitly select this network.
451 //
452 // We destroy sockets without the explicit bit because we want to avoid the situation where a
453 // privileged app uses its privileges without knowing it is doing so. For example, a privileged app
454 // might have opened a socket on this network just because it was the default network at the
455 // time. If we don't kill these sockets, those apps could continue to use them without realizing
456 // that they are now sending and receiving traffic on a network that is now restricted.
destroySocketsLackingPermission(unsigned netId,Permission permission,bool excludeLoopback)457 int SockDiag::destroySocketsLackingPermission(unsigned netId, Permission permission,
458                                               bool excludeLoopback) {
459     struct markmatch {
460         inet_diag_bc_op op;
461         // TODO: switch to inet_diag_markcond
462         __u32 mark;
463         __u32 mask;
464     } __attribute__((packed));
465     constexpr uint8_t matchlen = sizeof(markmatch);
466 
467     Fwmark netIdMark, netIdMask;
468     netIdMark.netId = netId;
469     netIdMask.netId = 0xffff;
470 
471     Fwmark controlMark;
472     controlMark.explicitlySelected = true;
473     controlMark.permission = permission;
474 
475     // A SOCK_DIAG bytecode program that accepts the sockets we intend to destroy.
476     struct bytecode {
477         markmatch netIdMatch;
478         markmatch controlMatch;
479         inet_diag_bc_op controlJump;
480     } __attribute__((packed)) bytecode;
481 
482     // The length of the INET_DIAG_BC_JMP instruction.
483     constexpr uint8_t jmplen = sizeof(inet_diag_bc_op);
484     // Jump exactly this far past the end of the program to reject.
485     constexpr uint8_t rejectoffset = sizeof(inet_diag_bc_op);
486     // Total length of the program.
487     constexpr uint8_t bytecodelen = sizeof(bytecode);
488 
489     bytecode = (struct bytecode) {
490         // If netId matches, continue, otherwise, reject (i.e., leave socket alone).
491         { { INET_DIAG_BC_MARK_COND, matchlen, bytecodelen + rejectoffset },
492           netIdMark.intValue, netIdMask.intValue },
493 
494         // If explicit and permission bits match, go to the JMP below which rejects the socket
495         // (i.e., we leave it alone). Otherwise, jump to the end of the program, which accepts the
496         // socket (so we destroy it).
497         { { INET_DIAG_BC_MARK_COND, matchlen, matchlen + jmplen },
498           controlMark.intValue, controlMark.intValue },
499 
500         // This JMP unconditionally rejects the packet by jumping to the reject target. It is
501         // necessary to keep the kernel bytecode verifier happy. If we don't have a JMP the bytecode
502         // is invalid because the target of every no jump must always be reachable by yes jumps.
503         // Without this JMP, the accept target is not reachable by yes jumps and the program will
504         // be rejected by the validator.
505         { INET_DIAG_BC_JMP, jmplen, jmplen + rejectoffset },
506 
507         // We have reached the end of the program. Accept the socket, and destroy it below.
508     };
509 
510     struct nlattr nla = {
511             .nla_len = sizeof(struct nlattr) + bytecodelen,
512             .nla_type = INET_DIAG_REQ_BYTECODE,
513     };
514 
515     iovec iov[] = {
516         { nullptr,   0 },
517         { &nla,      sizeof(nla) },
518         { &bytecode, bytecodelen },
519     };
520 
521     mSocketsDestroyed = 0;
522     Stopwatch s;
523 
524     auto shouldDestroy = [&] (uint8_t, const inet_diag_msg *msg) {
525         return msg != nullptr && !(excludeLoopback && isLoopbackSocket(msg));
526     };
527 
528     if (int ret = destroyLiveSockets(shouldDestroy, "permission change", iov, ARRAY_SIZE(iov))) {
529         return ret;
530     }
531 
532     if (mSocketsDestroyed > 0) {
533         ALOGI("Destroyed %d sockets for netId %d permission=%d in %" PRId64 "us", mSocketsDestroyed,
534               netId, permission, s.timeTakenUs());
535     }
536 
537     return 0;
538 }
539 
540 }  // namespace net
541 }  // namespace android
542