• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Netd"
18 
19 #include "SockDiag.h"
20 
21 #include <errno.h>
22 #include <linux/inet_diag.h>
23 #include <linux/netlink.h>
24 #include <linux/sock_diag.h>
25 #include <netdb.h>
26 #include <netinet/in.h>
27 #include <netinet/tcp.h>
28 #include <string.h>
29 #include <sys/socket.h>
30 #include <sys/uio.h>
31 
32 #include <cinttypes>
33 
34 #include <android-base/properties.h>
35 #include <android-base/stringprintf.h>
36 #include <android-base/strings.h>
37 #include <log/log.h>
38 #include <netdutils/InternetAddresses.h>
39 #include <netdutils/Stopwatch.h>
40 
41 #include "Permission.h"
42 
43 #ifndef SOCK_DESTROY
44 #define SOCK_DESTROY 21
45 #endif
46 
47 #define INET_DIAG_BC_MARK_COND 10
48 
49 namespace android {
50 
51 using android::base::StringPrintf;
52 using netdutils::ScopedAddrinfo;
53 using netdutils::Stopwatch;
54 
55 namespace net {
56 namespace {
57 
58 static const bool isUser = (android::base::GetProperty("ro.build.type", "") == "user");
59 
// Checks, without blocking, whether a netlink error message is waiting on |fd|.
//
// Returns:
//   - 0 if there is nothing to read (EAGAIN), or if whatever is pending is not a complete
//     NLMSG_ERROR message. In the latter case the message is left queued for the caller.
//   - -errno if peeking at the socket fails for any other reason.
//   - the error field of the NLMSG_ERROR message, which is consumed from the socket.
int checkError(int fd) {
    struct {
        nlmsghdr h;
        nlmsgerr err;
    } __attribute__((__packed__)) reply;

    // Peek first so that a non-error reply stays in the queue.
    const ssize_t peeked = recv(fd, &reply, sizeof(reply), MSG_DONTWAIT | MSG_PEEK);
    if (peeked == -1) {
        // Either nothing to read (good) or the read itself failed (error).
        if (errno == EAGAIN) {
            return 0;
        }
        return -errno;
    }

    const bool isNetlinkError = (peeked == static_cast<ssize_t>(sizeof(reply))) &&
                                (reply.h.nlmsg_type == NLMSG_ERROR);
    if (!isNetlinkError) {
        // The kernel replied with something else. Leave it to the caller.
        return 0;
    }

    // We got an error. Consume it so it does not get mistaken for part of a later reply.
    recv(fd, &reply, sizeof(reply), 0);
    return reply.err.error;
}
78 
79 }  // namespace
80 
open()81 bool SockDiag::open() {
82     if (hasSocks()) {
83         return false;
84     }
85 
86     mSock = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
87     mWriteSock = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
88     if (!hasSocks()) {
89         closeSocks();
90         return false;
91     }
92 
93     sockaddr_nl nl = { .nl_family = AF_NETLINK };
94     if ((connect(mSock, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) == -1) ||
95         (connect(mWriteSock, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) == -1)) {
96         closeSocks();
97         return false;
98     }
99 
100     return true;
101 }
102 
sendDumpRequest(uint8_t proto,uint8_t family,uint8_t extensions,uint32_t states,iovec * iov,int iovcnt)103 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint8_t extensions, uint32_t states,
104                               iovec *iov, int iovcnt) {
105     struct {
106         nlmsghdr nlh;
107         inet_diag_req_v2 req;
108     } __attribute__((__packed__)) request = {
109         .nlh = {
110             .nlmsg_type = SOCK_DIAG_BY_FAMILY,
111             .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
112         },
113         .req = {
114             .sdiag_family = family,
115             .sdiag_protocol = proto,
116             .idiag_ext = extensions,
117             .idiag_states = states,
118         },
119     };
120 
121     size_t len = 0;
122     iov[0].iov_base = &request;
123     iov[0].iov_len = sizeof(request);
124     for (int i = 0; i < iovcnt; i++) {
125         len += iov[i].iov_len;
126     }
127     request.nlh.nlmsg_len = len;
128 
129     ssize_t writevRet = writev(mSock, iov, iovcnt);
130     // Don't let pointers to the stack escape.
131     iov[0] = {nullptr, 0};
132     if (writevRet != (ssize_t)len) {
133         return -errno;
134     }
135 
136     return checkError(mSock);
137 }
138 
sendDumpRequest(uint8_t proto,uint8_t family,uint32_t states)139 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states) {
140     iovec iov[] = {
141         { nullptr, 0 },
142     };
143     return sendDumpRequest(proto, family, 0, states, iov, ARRAY_SIZE(iov));
144 }
145 
sendDumpRequest(uint8_t proto,uint8_t family,const char * addrstr)146 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, const char *addrstr) {
147     addrinfo hints = { .ai_flags = AI_NUMERICHOST };
148     addrinfo *res;
149     in6_addr mapped = { .s6_addr32 = { 0, 0, htonl(0xffff), 0 } };
150 
151     // TODO: refactor the netlink parsing code out of system/core, bring it into netd, and stop
152     // doing string conversions when they're not necessary.
153     int ret = getaddrinfo(addrstr, nullptr, &hints, &res);
154     if (ret != 0) return -EINVAL;
155 
156     // So we don't have to call freeaddrinfo on every failure path.
157     ScopedAddrinfo resP(res);
158 
159     void *addr;
160     uint8_t addrlen;
161     if (res->ai_family == AF_INET && family == AF_INET) {
162         in_addr& ina = reinterpret_cast<sockaddr_in*>(res->ai_addr)->sin_addr;
163         addr = &ina;
164         addrlen = sizeof(ina);
165     } else if (res->ai_family == AF_INET && family == AF_INET6) {
166         in_addr& ina = reinterpret_cast<sockaddr_in*>(res->ai_addr)->sin_addr;
167         mapped.s6_addr32[3] = ina.s_addr;
168         addr = &mapped;
169         addrlen = sizeof(mapped);
170     } else if (res->ai_family == AF_INET6 && family == AF_INET6) {
171         in6_addr& in6a = reinterpret_cast<sockaddr_in6*>(res->ai_addr)->sin6_addr;
172         addr = &in6a;
173         addrlen = sizeof(in6a);
174     } else {
175         return -EAFNOSUPPORT;
176     }
177 
178     uint8_t prefixlen = addrlen * 8;
179     uint8_t yesjump = sizeof(inet_diag_bc_op) + sizeof(inet_diag_hostcond) + addrlen;
180     uint8_t nojump = yesjump + 4;
181 
182     struct {
183         nlattr nla;
184         inet_diag_bc_op op;
185         inet_diag_hostcond cond;
186     } __attribute__((__packed__)) attrs = {
187         .nla = {
188             .nla_type = INET_DIAG_REQ_BYTECODE,
189         },
190         .op = {
191             INET_DIAG_BC_S_COND,
192             yesjump,
193             nojump,
194         },
195         .cond = {
196             family,
197             prefixlen,
198             -1,
199             {}
200         },
201     };
202 
203     attrs.nla.nla_len = sizeof(attrs) + addrlen;
204 
205     iovec iov[] = {
206         { nullptr,           0 },
207         { &attrs,            sizeof(attrs) },
208         { addr,              addrlen },
209     };
210 
211     uint32_t states = ~(1 << TCP_TIME_WAIT);
212     return sendDumpRequest(proto, family, 0, states, iov, ARRAY_SIZE(iov));
213 }
214 
readDiagMsg(uint8_t proto,const SockDiag::DestroyFilter & shouldDestroy)215 int SockDiag::readDiagMsg(uint8_t proto, const SockDiag::DestroyFilter& shouldDestroy) {
216     NetlinkDumpCallback callback = [this, proto, shouldDestroy] (nlmsghdr *nlh) {
217         const inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
218         if (shouldDestroy(proto, msg)) {
219             sockDestroy(proto, msg);
220         }
221     };
222 
223     return processNetlinkDump(mSock, callback);
224 }
225 
readDiagMsgWithTcpInfo(const TcpInfoReader & tcpInfoReader)226 int SockDiag::readDiagMsgWithTcpInfo(const TcpInfoReader& tcpInfoReader) {
227     NetlinkDumpCallback callback = [tcpInfoReader] (nlmsghdr *nlh) {
228         if (nlh->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
229             ALOGE("expected nlmsg_type=SOCK_DIAG_BY_FAMILY, got nlmsg_type=%d", nlh->nlmsg_type);
230             return;
231         }
232         Fwmark mark;
233         struct tcp_info *tcpinfo = nullptr;
234         uint32_t tcpinfoLength = 0;
235         inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
236         uint32_t attr_len = nlh->nlmsg_len - NLMSG_LENGTH(sizeof(*msg));
237         struct rtattr *attr = reinterpret_cast<struct rtattr*>(msg+1);
238         while (RTA_OK(attr, attr_len)) {
239             if (attr->rta_type == INET_DIAG_INFO) {
240                 tcpinfo = reinterpret_cast<struct tcp_info*>(RTA_DATA(attr));
241                 tcpinfoLength = RTA_PAYLOAD(attr);
242             }
243             if (attr->rta_type == INET_DIAG_MARK) {
244                 mark.intValue = *reinterpret_cast<uint32_t*>(RTA_DATA(attr));
245             }
246             attr = RTA_NEXT(attr, attr_len);
247         }
248 
249         tcpInfoReader(mark, msg, tcpinfo, tcpinfoLength);
250     };
251 
252     return processNetlinkDump(mSock, callback);
253 }
254 
255 // Determines whether a socket is a loopback socket. Does not check socket state.
isLoopbackSocket(const inet_diag_msg * msg)256 bool SockDiag::isLoopbackSocket(const inet_diag_msg *msg) {
257     switch (msg->idiag_family) {
258         case AF_INET:
259             // Old kernels only copy the IPv4 address and leave the other 12 bytes uninitialized.
260             return IN_LOOPBACK(htonl(msg->id.idiag_src[0])) ||
261                    IN_LOOPBACK(htonl(msg->id.idiag_dst[0])) ||
262                    msg->id.idiag_src[0] == msg->id.idiag_dst[0];
263 
264         case AF_INET6: {
265             const struct in6_addr *src = (const struct in6_addr *) &msg->id.idiag_src;
266             const struct in6_addr *dst = (const struct in6_addr *) &msg->id.idiag_dst;
267             return (IN6_IS_ADDR_V4MAPPED(src) && IN_LOOPBACK(src->s6_addr32[3])) ||
268                    (IN6_IS_ADDR_V4MAPPED(dst) && IN_LOOPBACK(dst->s6_addr32[3])) ||
269                    IN6_IS_ADDR_LOOPBACK(src) || IN6_IS_ADDR_LOOPBACK(dst) ||
270                    !memcmp(src, dst, sizeof(*src));
271         }
272         default:
273             return false;
274     }
275 }
276 
sockDestroy(uint8_t proto,const inet_diag_msg * msg)277 int SockDiag::sockDestroy(uint8_t proto, const inet_diag_msg *msg) {
278     if (msg == nullptr) {
279        return 0;
280     }
281 
282     DestroyRequest request = {
283         .nlh = {
284             .nlmsg_type = SOCK_DESTROY,
285             .nlmsg_flags = NLM_F_REQUEST,
286         },
287         .req = {
288             .sdiag_family = msg->idiag_family,
289             .sdiag_protocol = proto,
290             .idiag_states = (uint32_t) (1 << msg->idiag_state),
291             .id = msg->id,
292         },
293     };
294     request.nlh.nlmsg_len = sizeof(request);
295 
296     if (write(mWriteSock, &request, sizeof(request)) < (ssize_t) sizeof(request)) {
297         return -errno;
298     }
299 
300     int ret = checkError(mWriteSock);
301     if (!ret) mSocketsDestroyed++;
302     return ret;
303 }
304 
destroySockets(uint8_t proto,int family,const char * addrstr,int ifindex)305 int SockDiag::destroySockets(uint8_t proto, int family, const char* addrstr, int ifindex) {
306     if (!hasSocks()) {
307         return -EBADFD;
308     }
309 
310     if (int ret = sendDumpRequest(proto, family, addrstr)) {
311         return ret;
312     }
313 
314     // Destroy all sockets on the address, except link-local sockets where ifindex doesn't match.
315     auto shouldDestroy = [ifindex](uint8_t, const inet_diag_msg* msg) {
316         return ifindex == 0 || ifindex == (int)msg->id.idiag_if;
317     };
318 
319     return readDiagMsg(proto, shouldDestroy);
320 }
321 
destroySockets(const char * addrstr,int ifindex)322 int SockDiag::destroySockets(const char* addrstr, int ifindex) {
323     Stopwatch s;
324     mSocketsDestroyed = 0;
325 
326     std::string where = addrstr;
327     if (ifindex) where += StringPrintf(" ifindex %d", ifindex);
328 
329     if (!strchr(addrstr, ':')) {  // inet_ntop never returns something like ::ffff:192.0.2.1
330         if (int ret = destroySockets(IPPROTO_TCP, AF_INET, addrstr, ifindex)) {
331             ALOGE("Failed to destroy IPv4 sockets on %s: %s",
332                 (isUser ? "[hidden: user build]" : where.c_str()), strerror(-ret));
333             return ret;
334         }
335     }
336     if (int ret = destroySockets(IPPROTO_TCP, AF_INET6, addrstr, ifindex)) {
337         ALOGE("Failed to destroy IPv6 sockets on %s: %s",
338             (isUser ? "[hidden: user build]" : where.c_str()), strerror(-ret));
339         return ret;
340     }
341 
342     if (mSocketsDestroyed > 0) {
343         ALOGI("Destroyed %d sockets on %s in %" PRId64 "us", mSocketsDestroyed,
344             (isUser ? "[hidden: user build]" : where.c_str()), s.timeTakenUs());
345     }
346 
347     return mSocketsDestroyed;
348 }
349 
destroyLiveSockets(const DestroyFilter & destroyFilter,const char * what,iovec * iov,int iovcnt)350 int SockDiag::destroyLiveSockets(const DestroyFilter& destroyFilter, const char *what,
351                                  iovec *iov, int iovcnt) {
352     const int proto = IPPROTO_TCP;
353     const uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
354 
355     for (const int family : {AF_INET, AF_INET6}) {
356         const char *familyName = (family == AF_INET) ? "IPv4" : "IPv6";
357         if (int ret = sendDumpRequest(proto, family, 0, states, iov, iovcnt)) {
358             ALOGE("Failed to dump %s sockets for %s: %s", familyName, what, strerror(-ret));
359             return ret;
360         }
361         if (int ret = readDiagMsg(proto, destroyFilter)) {
362             ALOGE("Failed to destroy %s sockets for %s: %s", familyName, what, strerror(-ret));
363             return ret;
364         }
365     }
366 
367     return 0;
368 }
369 
getLiveTcpInfos(const TcpInfoReader & tcpInfoReader)370 int SockDiag::getLiveTcpInfos(const TcpInfoReader& tcpInfoReader) {
371     const int proto = IPPROTO_TCP;
372     const uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
373     const uint8_t extensions = (1 << INET_DIAG_MEMINFO); // flag for dumping struct tcp_info.
374 
375     iovec iov[] = {
376         { nullptr, 0 },
377     };
378 
379     for (const int family : {AF_INET, AF_INET6}) {
380         const char *familyName = (family == AF_INET) ? "IPv4" : "IPv6";
381         if (int ret = sendDumpRequest(proto, family, extensions, states, iov, ARRAY_SIZE(iov))) {
382             ALOGE("Failed to dump %s sockets struct tcp_info: %s", familyName, strerror(-ret));
383             return ret;
384         }
385         if (int ret = readDiagMsgWithTcpInfo(tcpInfoReader)) {
386             ALOGE("Failed to read %s sockets struct tcp_info: %s", familyName, strerror(-ret));
387             return ret;
388         }
389     }
390 
391     return 0;
392 }
393 
destroySockets(uint8_t proto,const uid_t uid,bool excludeLoopback)394 int SockDiag::destroySockets(uint8_t proto, const uid_t uid, bool excludeLoopback) {
395     mSocketsDestroyed = 0;
396     Stopwatch s;
397 
398     auto shouldDestroy = [uid, excludeLoopback] (uint8_t, const inet_diag_msg *msg) {
399         return msg != nullptr &&
400                msg->idiag_uid == uid &&
401                !(excludeLoopback && isLoopbackSocket(msg));
402     };
403 
404     for (const int family : {AF_INET, AF_INET6}) {
405         const char *familyName = family == AF_INET ? "IPv4" : "IPv6";
406         uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
407         if (int ret = sendDumpRequest(proto, family, states)) {
408             ALOGE("Failed to dump %s sockets for UID: %s", familyName, strerror(-ret));
409             return ret;
410         }
411         if (int ret = readDiagMsg(proto, shouldDestroy)) {
412             ALOGE("Failed to destroy %s sockets for UID: %s", familyName, strerror(-ret));
413             return ret;
414         }
415     }
416 
417     if (mSocketsDestroyed > 0) {
418         ALOGI("Destroyed %d sockets for UID in %" PRId64 "us", mSocketsDestroyed, s.timeTakenUs());
419     }
420 
421     return 0;
422 }
423 
424 // Destroys all "live" (CONNECTED, SYN_SENT, SYN_RECV) TCP sockets on the specified netId where:
425 // 1. The opening app no longer has permission to use this network, or:
426 // 2. The opening app does have permission, but did not explicitly select this network.
427 //
428 // We destroy sockets without the explicit bit because we want to avoid the situation where a
429 // privileged app uses its privileges without knowing it is doing so. For example, a privileged app
430 // might have opened a socket on this network just because it was the default network at the
431 // time. If we don't kill these sockets, those apps could continue to use them without realizing
432 // that they are now sending and receiving traffic on a network that is now restricted.
destroySocketsLackingPermission(unsigned netId,Permission permission,bool excludeLoopback)433 int SockDiag::destroySocketsLackingPermission(unsigned netId, Permission permission,
434                                               bool excludeLoopback) {
435     struct markmatch {
436         inet_diag_bc_op op;
437         // TODO: switch to inet_diag_markcond
438         __u32 mark;
439         __u32 mask;
440     } __attribute__((packed));
441     constexpr uint8_t matchlen = sizeof(markmatch);
442 
443     Fwmark netIdMark, netIdMask;
444     netIdMark.netId = netId;
445     netIdMask.netId = 0xffff;
446 
447     Fwmark controlMark;
448     controlMark.explicitlySelected = true;
449     controlMark.permission = permission;
450 
451     // A SOCK_DIAG bytecode program that accepts the sockets we intend to destroy.
452     struct bytecode {
453         markmatch netIdMatch;
454         markmatch controlMatch;
455         inet_diag_bc_op controlJump;
456     } __attribute__((packed)) bytecode;
457 
458     // The length of the INET_DIAG_BC_JMP instruction.
459     constexpr uint8_t jmplen = sizeof(inet_diag_bc_op);
460     // Jump exactly this far past the end of the program to reject.
461     constexpr uint8_t rejectoffset = sizeof(inet_diag_bc_op);
462     // Total length of the program.
463     constexpr uint8_t bytecodelen = sizeof(bytecode);
464 
465     bytecode = (struct bytecode) {
466         // If netId matches, continue, otherwise, reject (i.e., leave socket alone).
467         { { INET_DIAG_BC_MARK_COND, matchlen, bytecodelen + rejectoffset },
468           netIdMark.intValue, netIdMask.intValue },
469 
470         // If explicit and permission bits match, go to the JMP below which rejects the socket
471         // (i.e., we leave it alone). Otherwise, jump to the end of the program, which accepts the
472         // socket (so we destroy it).
473         { { INET_DIAG_BC_MARK_COND, matchlen, matchlen + jmplen },
474           controlMark.intValue, controlMark.intValue },
475 
476         // This JMP unconditionally rejects the packet by jumping to the reject target. It is
477         // necessary to keep the kernel bytecode verifier happy. If we don't have a JMP the bytecode
478         // is invalid because the target of every no jump must always be reachable by yes jumps.
479         // Without this JMP, the accept target is not reachable by yes jumps and the program will
480         // be rejected by the validator.
481         { INET_DIAG_BC_JMP, jmplen, jmplen + rejectoffset },
482 
483         // We have reached the end of the program. Accept the socket, and destroy it below.
484     };
485 
486     struct nlattr nla = {
487             .nla_len = sizeof(struct nlattr) + bytecodelen,
488             .nla_type = INET_DIAG_REQ_BYTECODE,
489     };
490 
491     iovec iov[] = {
492         { nullptr,   0 },
493         { &nla,      sizeof(nla) },
494         { &bytecode, bytecodelen },
495     };
496 
497     mSocketsDestroyed = 0;
498     Stopwatch s;
499 
500     auto shouldDestroy = [&] (uint8_t, const inet_diag_msg *msg) {
501         return msg != nullptr && !(excludeLoopback && isLoopbackSocket(msg));
502     };
503 
504     if (int ret = destroyLiveSockets(shouldDestroy, "permission change", iov, ARRAY_SIZE(iov))) {
505         return ret;
506     }
507 
508     if (mSocketsDestroyed > 0) {
509         ALOGI("Destroyed %d sockets for netId %d permission=%d in %" PRId64 "us", mSocketsDestroyed,
510               netId, permission, s.timeTakenUs());
511     }
512 
513     return 0;
514 }
515 
516 }  // namespace net
517 }  // namespace android
518