1 /*
2 * Copyright (C) 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Netd"
18
19 #include "SockDiag.h"
20
21 #include <errno.h>
22 #include <linux/inet_diag.h>
23 #include <linux/netlink.h>
24 #include <linux/sock_diag.h>
25 #include <netdb.h>
26 #include <netinet/in.h>
27 #include <netinet/tcp.h>
28 #include <string.h>
29 #include <sys/socket.h>
30 #include <sys/uio.h>
31
32 #include <cinttypes>
33
34 #include <android-base/properties.h>
35 #include <android-base/stringprintf.h>
36 #include <android-base/strings.h>
37 #include <log/log.h>
38 #include <netdutils/InternetAddresses.h>
39 #include <netdutils/Stopwatch.h>
40
41 #include "Permission.h"
42
// Fallback for headers that do not define SOCK_DESTROY: the netlink message type sent by
// sockDestroy() to ask the kernel to destroy a socket.
#ifndef SOCK_DESTROY
#define SOCK_DESTROY 21
#endif

// inet_diag bytecode opcode that matches on a socket's fwmark (used by
// destroySocketsLackingPermission()). Defined locally, presumably because the available
// linux/inet_diag.h does not provide it — TODO(review): confirm against current kernel headers.
#define INET_DIAG_BC_MARK_COND 10
48
49 namespace android {
50
51 using android::base::StringPrintf;
52 using netdutils::ScopedAddrinfo;
53 using netdutils::Stopwatch;
54
55 namespace net {
56 namespace {
57
getAdbPort()58 int getAdbPort() {
59 return android::base::GetIntProperty("service.adb.tcp.port", 0);
60 }
61
isAdbSocket(const inet_diag_msg * msg,int adbPort)62 bool isAdbSocket(const inet_diag_msg *msg, int adbPort) {
63 return adbPort > 0 && msg->id.idiag_sport == htons(adbPort) &&
64 (msg->idiag_uid == AID_ROOT || msg->idiag_uid == AID_SHELL);
65 }
66
// Peeks, without blocking, at the next datagram queued on netlink socket |fd|.
// Returns:
//   0       if nothing is queued (recv fails with EAGAIN), or if the queued datagram is not a
//           full-sized NLMSG_ERROR message (it is left queued for the caller to read);
//   -errno  if recv() fails for any other reason;
//   otherwise the `error` field of the queued NLMSG_ERROR message, which is consumed.
int checkError(int fd) {
    struct {
        nlmsghdr h;
        nlmsgerr err;
    } __attribute__((__packed__)) reply;

    const ssize_t peeked = recv(fd, &reply, sizeof(reply), MSG_DONTWAIT | MSG_PEEK);
    if (peeked == -1) {
        // An empty queue is the success case.
        if (errno == EAGAIN) return 0;
        return -errno;
    }

    const bool isErrorMessage =
            peeked == static_cast<ssize_t>(sizeof(reply)) && reply.h.nlmsg_type == NLMSG_ERROR;
    if (!isErrorMessage) {
        // Some other reply is pending; leave it for the caller.
        return 0;
    }

    // Dequeue the error message so later reads see whatever follows it.
    recv(fd, &reply, sizeof(reply), 0);
    return reply.err.error;
}
85
86 } // namespace
87
open()88 bool SockDiag::open() {
89 if (hasSocks()) {
90 return false;
91 }
92
93 mSock = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
94 mWriteSock = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
95 if (!hasSocks()) {
96 closeSocks();
97 return false;
98 }
99
100 sockaddr_nl nl = { .nl_family = AF_NETLINK };
101 if ((connect(mSock, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) == -1) ||
102 (connect(mWriteSock, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) == -1)) {
103 closeSocks();
104 return false;
105 }
106
107 return true;
108 }
109
sendDumpRequest(uint8_t proto,uint8_t family,uint8_t extensions,uint32_t states,iovec * iov,int iovcnt)110 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint8_t extensions, uint32_t states,
111 iovec *iov, int iovcnt) {
112 struct {
113 nlmsghdr nlh;
114 inet_diag_req_v2 req;
115 } __attribute__((__packed__)) request = {
116 .nlh = {
117 .nlmsg_type = SOCK_DIAG_BY_FAMILY,
118 .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
119 },
120 .req = {
121 .sdiag_family = family,
122 .sdiag_protocol = proto,
123 .idiag_ext = extensions,
124 .idiag_states = states,
125 },
126 };
127
128 size_t len = 0;
129 iov[0].iov_base = &request;
130 iov[0].iov_len = sizeof(request);
131 for (int i = 0; i < iovcnt; i++) {
132 len += iov[i].iov_len;
133 }
134 request.nlh.nlmsg_len = len;
135
136 if (writev(mSock, iov, iovcnt) != (ssize_t) len) {
137 return -errno;
138 }
139
140 return checkError(mSock);
141 }
142
sendDumpRequest(uint8_t proto,uint8_t family,uint32_t states)143 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states) {
144 iovec iov[] = {
145 { nullptr, 0 },
146 };
147 return sendDumpRequest(proto, family, 0, states, iov, ARRAY_SIZE(iov));
148 }
149
// Sends a dump request for |proto| sockets of |family| in every TCP state except TIME_WAIT,
// filtered in the kernel to sockets whose source address is |addrstr| (a numeric IPv4 or IPv6
// literal).
//
// The filter is a one-instruction inet_diag bytecode program (INET_DIAG_BC_S_COND) appended to
// the request as an INET_DIAG_REQ_BYTECODE attribute, followed by the raw address bytes. When an
// IPv4 literal is used with AF_INET6, it is converted to its IPv4-mapped form (::ffff:a.b.c.d).
//
// Returns -EINVAL if |addrstr| is not a numeric address, and -EAFNOSUPPORT if the literal's
// family cannot be used with |family| (e.g. an IPv6 literal with AF_INET). Otherwise it returns
// the result of the iovec overload of sendDumpRequest.
int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, const char *addrstr) {
    addrinfo hints = { .ai_flags = AI_NUMERICHOST };
    addrinfo *res;
    // IPv4-mapped IPv6 template (::ffff:0.0.0.0); the last word is filled in below if needed.
    in6_addr mapped = { .s6_addr32 = { 0, 0, htonl(0xffff), 0 } };

    // TODO: refactor the netlink parsing code out of system/core, bring it into netd, and stop
    // doing string conversions when they're not necessary.
    int ret = getaddrinfo(addrstr, nullptr, &hints, &res);
    if (ret != 0) return -EINVAL;

    // So we don't have to call freeaddrinfo on every failure path.
    ScopedAddrinfo resP(res);

    // Select the address bytes to match, in network byte order, and their length.
    void *addr;
    uint8_t addrlen;
    if (res->ai_family == AF_INET && family == AF_INET) {
        in_addr& ina = reinterpret_cast<sockaddr_in*>(res->ai_addr)->sin_addr;
        addr = &ina;
        addrlen = sizeof(ina);
    } else if (res->ai_family == AF_INET && family == AF_INET6) {
        in_addr& ina = reinterpret_cast<sockaddr_in*>(res->ai_addr)->sin_addr;
        mapped.s6_addr32[3] = ina.s_addr;
        addr = &mapped;
        addrlen = sizeof(mapped);
    } else if (res->ai_family == AF_INET6 && family == AF_INET6) {
        in6_addr& in6a = reinterpret_cast<sockaddr_in6*>(res->ai_addr)->sin6_addr;
        addr = &in6a;
        addrlen = sizeof(in6a);
    } else {
        return -EAFNOSUPPORT;
    }

    // Full-length prefix: exact address match.
    uint8_t prefixlen = addrlen * 8;
    // The program is one instruction: op + hostcond + address bytes. "yes" jumps exactly to the
    // end of the program (accept); "no" jumps 4 bytes past the end (reject).
    uint8_t yesjump = sizeof(inet_diag_bc_op) + sizeof(inet_diag_hostcond) + addrlen;
    uint8_t nojump = yesjump + 4;

    struct {
        nlattr nla;
        inet_diag_bc_op op;
        inet_diag_hostcond cond;
    } __attribute__((__packed__)) attrs = {
        .nla = {
            .nla_type = INET_DIAG_REQ_BYTECODE,
        },
        .op = {
            INET_DIAG_BC_S_COND,
            yesjump,
            nojump,
        },
        // family, prefix length, port (-1, presumably "any port"), then the address bytes,
        // which are sent as a separate iovec after this struct.
        .cond = {
            family,
            prefixlen,
            -1,
            {}
        },
    };

    // The attribute covers this struct plus the trailing address bytes.
    attrs.nla.nla_len = sizeof(attrs) + addrlen;

    // iov[0] is filled in with the request header by the iovec overload.
    iovec iov[] = {
        { nullptr, 0 },
        { &attrs, sizeof(attrs) },
        { addr, addrlen },
    };

    // Every TCP state except TIME_WAIT.
    uint32_t states = ~(1 << TCP_TIME_WAIT);
    return sendDumpRequest(proto, family, 0, states, iov, ARRAY_SIZE(iov));
}
218
readDiagMsg(uint8_t proto,const SockDiag::DestroyFilter & shouldDestroy)219 int SockDiag::readDiagMsg(uint8_t proto, const SockDiag::DestroyFilter& shouldDestroy) {
220 NetlinkDumpCallback callback = [this, proto, shouldDestroy] (nlmsghdr *nlh) {
221 const inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
222 if (shouldDestroy(proto, msg)) {
223 sockDestroy(proto, msg);
224 }
225 };
226
227 return processNetlinkDump(mSock, callback);
228 }
229
readDiagMsgWithTcpInfo(const TcpInfoReader & tcpInfoReader)230 int SockDiag::readDiagMsgWithTcpInfo(const TcpInfoReader& tcpInfoReader) {
231 NetlinkDumpCallback callback = [tcpInfoReader] (nlmsghdr *nlh) {
232 if (nlh->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
233 ALOGE("expected nlmsg_type=SOCK_DIAG_BY_FAMILY, got nlmsg_type=%d", nlh->nlmsg_type);
234 return;
235 }
236 Fwmark mark;
237 struct tcp_info *tcpinfo = nullptr;
238 uint32_t tcpinfoLength = 0;
239 inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
240 uint32_t attr_len = nlh->nlmsg_len - NLMSG_LENGTH(sizeof(*msg));
241 struct rtattr *attr = reinterpret_cast<struct rtattr*>(msg+1);
242 while (RTA_OK(attr, attr_len)) {
243 if (attr->rta_type == INET_DIAG_INFO) {
244 tcpinfo = reinterpret_cast<struct tcp_info*>(RTA_DATA(attr));
245 tcpinfoLength = RTA_PAYLOAD(attr);
246 }
247 if (attr->rta_type == INET_DIAG_MARK) {
248 mark.intValue = *reinterpret_cast<uint32_t*>(RTA_DATA(attr));
249 }
250 attr = RTA_NEXT(attr, attr_len);
251 }
252
253 tcpInfoReader(mark, msg, tcpinfo, tcpinfoLength);
254 };
255
256 return processNetlinkDump(mSock, callback);
257 }
258
259 // Determines whether a socket is a loopback socket. Does not check socket state.
isLoopbackSocket(const inet_diag_msg * msg)260 bool SockDiag::isLoopbackSocket(const inet_diag_msg *msg) {
261 switch (msg->idiag_family) {
262 case AF_INET:
263 // Old kernels only copy the IPv4 address and leave the other 12 bytes uninitialized.
264 return IN_LOOPBACK(htonl(msg->id.idiag_src[0])) ||
265 IN_LOOPBACK(htonl(msg->id.idiag_dst[0])) ||
266 msg->id.idiag_src[0] == msg->id.idiag_dst[0];
267
268 case AF_INET6: {
269 const struct in6_addr *src = (const struct in6_addr *) &msg->id.idiag_src;
270 const struct in6_addr *dst = (const struct in6_addr *) &msg->id.idiag_dst;
271 return (IN6_IS_ADDR_V4MAPPED(src) && IN_LOOPBACK(src->s6_addr32[3])) ||
272 (IN6_IS_ADDR_V4MAPPED(dst) && IN_LOOPBACK(dst->s6_addr32[3])) ||
273 IN6_IS_ADDR_LOOPBACK(src) || IN6_IS_ADDR_LOOPBACK(dst) ||
274 !memcmp(src, dst, sizeof(*src));
275 }
276 default:
277 return false;
278 }
279 }
280
sockDestroy(uint8_t proto,const inet_diag_msg * msg)281 int SockDiag::sockDestroy(uint8_t proto, const inet_diag_msg *msg) {
282 if (msg == nullptr) {
283 return 0;
284 }
285
286 DestroyRequest request = {
287 .nlh = {
288 .nlmsg_type = SOCK_DESTROY,
289 .nlmsg_flags = NLM_F_REQUEST,
290 },
291 .req = {
292 .sdiag_family = msg->idiag_family,
293 .sdiag_protocol = proto,
294 .idiag_states = (uint32_t) (1 << msg->idiag_state),
295 .id = msg->id,
296 },
297 };
298 request.nlh.nlmsg_len = sizeof(request);
299
300 if (write(mWriteSock, &request, sizeof(request)) < (ssize_t) sizeof(request)) {
301 return -errno;
302 }
303
304 int ret = checkError(mWriteSock);
305 if (!ret) mSocketsDestroyed++;
306 return ret;
307 }
308
destroySockets(uint8_t proto,int family,const char * addrstr,int ifindex)309 int SockDiag::destroySockets(uint8_t proto, int family, const char* addrstr, int ifindex) {
310 if (!hasSocks()) {
311 return -EBADFD;
312 }
313
314 if (int ret = sendDumpRequest(proto, family, addrstr)) {
315 return ret;
316 }
317
318 auto destroyAll = [ifindex](uint8_t, const inet_diag_msg* msg) {
319 return ifindex == 0 || ifindex == (int)msg->id.idiag_if;
320 };
321
322 return readDiagMsg(proto, destroyAll);
323 }
324
destroySockets(const char * addrstr,int ifindex)325 int SockDiag::destroySockets(const char* addrstr, int ifindex) {
326 Stopwatch s;
327 mSocketsDestroyed = 0;
328
329 std::string where = addrstr;
330 if (ifindex) where += StringPrintf(" ifindex %d", ifindex);
331
332 if (!strchr(addrstr, ':')) { // inet_ntop never returns something like ::ffff:192.0.2.1
333 if (int ret = destroySockets(IPPROTO_TCP, AF_INET, addrstr, ifindex)) {
334 ALOGE("Failed to destroy IPv4 sockets on %s: %s", where.c_str(), strerror(-ret));
335 return ret;
336 }
337 }
338 if (int ret = destroySockets(IPPROTO_TCP, AF_INET6, addrstr, ifindex)) {
339 ALOGE("Failed to destroy IPv6 sockets on %s: %s", where.c_str(), strerror(-ret));
340 return ret;
341 }
342
343 if (mSocketsDestroyed > 0) {
344 ALOGI("Destroyed %d sockets on %s in %" PRId64 "us", mSocketsDestroyed, where.c_str(),
345 s.timeTakenUs());
346 }
347
348 return mSocketsDestroyed;
349 }
350
destroyLiveSockets(const DestroyFilter & destroyFilter,const char * what,iovec * iov,int iovcnt)351 int SockDiag::destroyLiveSockets(const DestroyFilter& destroyFilter, const char *what,
352 iovec *iov, int iovcnt) {
353 const int proto = IPPROTO_TCP;
354 const uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
355
356 for (const int family : {AF_INET, AF_INET6}) {
357 const char *familyName = (family == AF_INET) ? "IPv4" : "IPv6";
358 if (int ret = sendDumpRequest(proto, family, 0, states, iov, iovcnt)) {
359 ALOGE("Failed to dump %s sockets for %s: %s", familyName, what, strerror(-ret));
360 return ret;
361 }
362 if (int ret = readDiagMsg(proto, destroyFilter)) {
363 ALOGE("Failed to destroy %s sockets for %s: %s", familyName, what, strerror(-ret));
364 return ret;
365 }
366 }
367
368 return 0;
369 }
370
getLiveTcpInfos(const TcpInfoReader & tcpInfoReader)371 int SockDiag::getLiveTcpInfos(const TcpInfoReader& tcpInfoReader) {
372 const int proto = IPPROTO_TCP;
373 const uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
374 const uint8_t extensions = (1 << INET_DIAG_MEMINFO); // flag for dumping struct tcp_info.
375
376 iovec iov[] = {
377 { nullptr, 0 },
378 };
379
380 for (const int family : {AF_INET, AF_INET6}) {
381 const char *familyName = (family == AF_INET) ? "IPv4" : "IPv6";
382 if (int ret = sendDumpRequest(proto, family, extensions, states, iov, ARRAY_SIZE(iov))) {
383 ALOGE("Failed to dump %s sockets struct tcp_info: %s", familyName, strerror(-ret));
384 return ret;
385 }
386 if (int ret = readDiagMsgWithTcpInfo(tcpInfoReader)) {
387 ALOGE("Failed to read %s sockets struct tcp_info: %s", familyName, strerror(-ret));
388 return ret;
389 }
390 }
391
392 return 0;
393 }
394
destroySockets(uint8_t proto,const uid_t uid,bool excludeLoopback)395 int SockDiag::destroySockets(uint8_t proto, const uid_t uid, bool excludeLoopback) {
396 mSocketsDestroyed = 0;
397 Stopwatch s;
398
399 auto shouldDestroy = [uid, excludeLoopback] (uint8_t, const inet_diag_msg *msg) {
400 return msg != nullptr &&
401 msg->idiag_uid == uid &&
402 !(excludeLoopback && isLoopbackSocket(msg));
403 };
404
405 for (const int family : {AF_INET, AF_INET6}) {
406 const char *familyName = family == AF_INET ? "IPv4" : "IPv6";
407 uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
408 if (int ret = sendDumpRequest(proto, family, states)) {
409 ALOGE("Failed to dump %s sockets for UID: %s", familyName, strerror(-ret));
410 return ret;
411 }
412 if (int ret = readDiagMsg(proto, shouldDestroy)) {
413 ALOGE("Failed to destroy %s sockets for UID: %s", familyName, strerror(-ret));
414 return ret;
415 }
416 }
417
418 if (mSocketsDestroyed > 0) {
419 ALOGI("Destroyed %d sockets for UID in %" PRId64 "us", mSocketsDestroyed, s.timeTakenUs());
420 }
421
422 return 0;
423 }
424
destroySockets(const UidRanges & uidRanges,const std::set<uid_t> & skipUids,bool excludeLoopback)425 int SockDiag::destroySockets(const UidRanges& uidRanges, const std::set<uid_t>& skipUids,
426 bool excludeLoopback) {
427 mSocketsDestroyed = 0;
428 Stopwatch s;
429
430 auto shouldDestroy = [&] (uint8_t, const inet_diag_msg *msg) {
431 return msg != nullptr &&
432 uidRanges.hasUid(msg->idiag_uid) &&
433 skipUids.find(msg->idiag_uid) == skipUids.end() &&
434 !(excludeLoopback && isLoopbackSocket(msg)) &&
435 !isAdbSocket(msg, getAdbPort());
436 };
437
438 iovec iov[] = {
439 { nullptr, 0 },
440 };
441
442 if (int ret = destroyLiveSockets(shouldDestroy, "UID", iov, ARRAY_SIZE(iov))) {
443 return ret;
444 }
445
446 if (mSocketsDestroyed > 0) {
447 ALOGI("Destroyed %d sockets for %s skip={%s} in %" PRId64 "us", mSocketsDestroyed,
448 uidRanges.toString().c_str(), android::base::Join(skipUids, " ").c_str(),
449 s.timeTakenUs());
450 }
451
452 return 0;
453 }
454
// Destroys all "live" (ESTABLISHED, SYN_SENT, SYN_RECV) TCP sockets on the specified netId where:
// 1. The opening app no longer has permission to use this network, or:
// 2. The opening app does have permission, but did not explicitly select this network.
//
// We destroy sockets without the explicit bit because we want to avoid the situation where a
// privileged app uses its privileges without knowing it is doing so. For example, a privileged app
// might have opened a socket on this network just because it was the default network at the
// time. If we don't kill these sockets, those apps could continue to use them without realizing
// that they are now sending and receiving traffic on a network that is now restricted.
//
// Sockets are selected in the kernel by an inet_diag bytecode program over each socket's fwmark.
// When |excludeLoopback| is set, loopback sockets are additionally spared in userspace.
// Returns 0 on success, or the first error from destroyLiveSockets().
int SockDiag::destroySocketsLackingPermission(unsigned netId, Permission permission,
                                              bool excludeLoopback) {
    // One INET_DIAG_BC_MARK_COND instruction: the op header followed by its mark/mask operands.
    struct markmatch {
        inet_diag_bc_op op;
        // TODO: switch to inet_diag_markcond
        __u32 mark;
        __u32 mask;
    } __attribute__((packed));
    constexpr uint8_t matchlen = sizeof(markmatch);

    // Mark/mask pair selecting sockets whose fwmark netId field equals |netId|.
    Fwmark netIdMark, netIdMask;
    netIdMark.netId = netId;
    netIdMask.netId = 0xffff;

    // Used as both mark and mask. Assuming the kernel's MARK_COND test is
    // (sk_mark & mask) == mark, this matches sockets whose mark has the explicitlySelected bit
    // and every bit set in |permission|.
    Fwmark controlMark;
    controlMark.explicitlySelected = true;
    controlMark.permission = permission;

    // A SOCK_DIAG bytecode program that accepts the sockets we intend to destroy.
    // Layout (byte offsets): netIdMatch @0, controlMatch @matchlen, controlJump @2*matchlen,
    // end of program @bytecodelen (accept), end + rejectoffset (reject). Jump targets are
    // relative to the start of the instruction that contains them.
    struct bytecode {
        markmatch netIdMatch;
        markmatch controlMatch;
        inet_diag_bc_op controlJump;
    } __attribute__((packed)) bytecode;

    // The length of the INET_DIAG_BC_JMP instruction.
    constexpr uint8_t jmplen = sizeof(inet_diag_bc_op);
    // Jump exactly this far past the end of the program to reject.
    constexpr uint8_t rejectoffset = sizeof(inet_diag_bc_op);
    // Total length of the program.
    constexpr uint8_t bytecodelen = sizeof(bytecode);

    bytecode = (struct bytecode) {
        // If netId matches, continue, otherwise, reject (i.e., leave socket alone).
        { { INET_DIAG_BC_MARK_COND, matchlen, bytecodelen + rejectoffset },
          netIdMark.intValue, netIdMask.intValue },

        // If explicit and permission bits match, go to the JMP below which rejects the socket
        // (i.e., we leave it alone). Otherwise, jump to the end of the program, which accepts the
        // socket (so we destroy it).
        { { INET_DIAG_BC_MARK_COND, matchlen, matchlen + jmplen },
          controlMark.intValue, controlMark.intValue },

        // This JMP unconditionally rejects the socket by jumping to the reject target. It is
        // necessary to keep the kernel bytecode verifier happy. If we don't have a JMP the bytecode
        // is invalid because the target of every no jump must always be reachable by yes jumps.
        // Without this JMP, the accept target is not reachable by yes jumps and the program will
        // be rejected by the validator.
        { INET_DIAG_BC_JMP, jmplen, jmplen + rejectoffset },

        // We have reached the end of the program. Accept the socket, and destroy it below.
    };

    // Attribute header for the bytecode, which is sent as a separate iovec after it.
    struct nlattr nla = {
        .nla_len = sizeof(struct nlattr) + bytecodelen,
        .nla_type = INET_DIAG_REQ_BYTECODE,
    };

    // iov[0] is filled in with the request header by sendDumpRequest().
    iovec iov[] = {
        { nullptr, 0 },
        { &nla, sizeof(nla) },
        { &bytecode, bytecodelen },
    };

    mSocketsDestroyed = 0;
    Stopwatch s;

    // The kernel filter has already chosen the sockets; only loopback exclusion is applied here.
    auto shouldDestroy = [&] (uint8_t, const inet_diag_msg *msg) {
        return msg != nullptr && !(excludeLoopback && isLoopbackSocket(msg));
    };

    if (int ret = destroyLiveSockets(shouldDestroy, "permission change", iov, ARRAY_SIZE(iov))) {
        return ret;
    }

    if (mSocketsDestroyed > 0) {
        // NOTE(review): netId is unsigned but is logged with %d.
        ALOGI("Destroyed %d sockets for netId %d permission=%d in %" PRId64 "us", mSocketsDestroyed,
              netId, permission, s.timeTakenUs());
    }

    return 0;
}
546
547 } // namespace net
548 } // namespace android
549