1 /*
2 * Copyright (C) 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Netd"
18
19 #include "SockDiag.h"
20
21 #include <errno.h>
22 #include <linux/inet_diag.h>
23 #include <linux/netlink.h>
24 #include <linux/sock_diag.h>
25 #include <netdb.h>
26 #include <netinet/in.h>
27 #include <netinet/tcp.h>
28 #include <string.h>
29 #include <sys/socket.h>
30 #include <sys/uio.h>
31
32 #include <cinttypes>
33
34 #include <android-base/properties.h>
35 #include <android-base/strings.h>
36 #include <log/log.h>
37 #include <netdutils/InternetAddresses.h>
38 #include <netdutils/Stopwatch.h>
39
40 #include "Permission.h"
41
42 #ifndef SOCK_DESTROY
43 #define SOCK_DESTROY 21
44 #endif
45
46 #define INET_DIAG_BC_MARK_COND 10
47
48 namespace android {
49
50 using netdutils::ScopedAddrinfo;
51 using netdutils::Stopwatch;
52
53 namespace net {
54 namespace {
55
getAdbPort()56 int getAdbPort() {
57 return android::base::GetIntProperty("service.adb.tcp.port", 0);
58 }
59
isAdbSocket(const inet_diag_msg * msg,int adbPort)60 bool isAdbSocket(const inet_diag_msg *msg, int adbPort) {
61 return adbPort > 0 && msg->id.idiag_sport == htons(adbPort) &&
62 (msg->idiag_uid == AID_ROOT || msg->idiag_uid == AID_SHELL);
63 }
64
// Checks, without blocking, whether the next queued message on netlink socket `fd` is an
// NLMSG_ERROR reply.
//
// Returns:
//   0          if nothing is queued (EAGAIN), or the queued message is not a complete
//              NLMSG_ERROR (it is left in the queue for the caller);
//   err.error  from the NLMSG_ERROR message, which is dequeued;
//   -errno     if the peek itself failed for any other reason.
int checkError(int fd) {
    struct {
        nlmsghdr h;
        nlmsgerr err;
    } __attribute__((__packed__)) reply;

    // MSG_PEEK so that a non-error reply stays queued for whoever reads the socket next.
    const ssize_t peeked = recv(fd, &reply, sizeof(reply), MSG_DONTWAIT | MSG_PEEK);
    if (peeked == -1) {
        if (errno == EAGAIN) {
            return 0;  // Nothing to read: no error was reported.
        }
        return -errno;
    }

    const bool isErrorReply =
            peeked == static_cast<ssize_t>(sizeof(reply)) && reply.h.nlmsg_type == NLMSG_ERROR;
    if (!isErrorReply) {
        return 0;  // Something else arrived; leave it for the caller.
    }

    // Dequeue the error message we just peeked at, then report it.
    recv(fd, &reply, sizeof(reply), 0);
    return reply.err.error;
}
83
84 } // namespace
85
open()86 bool SockDiag::open() {
87 if (hasSocks()) {
88 return false;
89 }
90
91 mSock = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
92 mWriteSock = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
93 if (!hasSocks()) {
94 closeSocks();
95 return false;
96 }
97
98 sockaddr_nl nl = { .nl_family = AF_NETLINK };
99 if ((connect(mSock, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) == -1) ||
100 (connect(mWriteSock, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) == -1)) {
101 closeSocks();
102 return false;
103 }
104
105 return true;
106 }
107
sendDumpRequest(uint8_t proto,uint8_t family,uint8_t extensions,uint32_t states,iovec * iov,int iovcnt)108 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint8_t extensions, uint32_t states,
109 iovec *iov, int iovcnt) {
110 struct {
111 nlmsghdr nlh;
112 inet_diag_req_v2 req;
113 } __attribute__((__packed__)) request = {
114 .nlh = {
115 .nlmsg_type = SOCK_DIAG_BY_FAMILY,
116 .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
117 },
118 .req = {
119 .sdiag_family = family,
120 .sdiag_protocol = proto,
121 .idiag_ext = extensions,
122 .idiag_states = states,
123 },
124 };
125
126 size_t len = 0;
127 iov[0].iov_base = &request;
128 iov[0].iov_len = sizeof(request);
129 for (int i = 0; i < iovcnt; i++) {
130 len += iov[i].iov_len;
131 }
132 request.nlh.nlmsg_len = len;
133
134 if (writev(mSock, iov, iovcnt) != (ssize_t) len) {
135 return -errno;
136 }
137
138 return checkError(mSock);
139 }
140
sendDumpRequest(uint8_t proto,uint8_t family,uint32_t states)141 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states) {
142 iovec iov[] = {
143 { nullptr, 0 },
144 };
145 return sendDumpRequest(proto, family, 0, states, iov, ARRAY_SIZE(iov));
146 }
147
sendDumpRequest(uint8_t proto,uint8_t family,const char * addrstr)148 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, const char *addrstr) {
149 addrinfo hints = { .ai_flags = AI_NUMERICHOST };
150 addrinfo *res;
151 in6_addr mapped = { .s6_addr32 = { 0, 0, htonl(0xffff), 0 } };
152
153 // TODO: refactor the netlink parsing code out of system/core, bring it into netd, and stop
154 // doing string conversions when they're not necessary.
155 int ret = getaddrinfo(addrstr, nullptr, &hints, &res);
156 if (ret != 0) return -EINVAL;
157
158 // So we don't have to call freeaddrinfo on every failure path.
159 ScopedAddrinfo resP(res);
160
161 void *addr;
162 uint8_t addrlen;
163 if (res->ai_family == AF_INET && family == AF_INET) {
164 in_addr& ina = reinterpret_cast<sockaddr_in*>(res->ai_addr)->sin_addr;
165 addr = &ina;
166 addrlen = sizeof(ina);
167 } else if (res->ai_family == AF_INET && family == AF_INET6) {
168 in_addr& ina = reinterpret_cast<sockaddr_in*>(res->ai_addr)->sin_addr;
169 mapped.s6_addr32[3] = ina.s_addr;
170 addr = &mapped;
171 addrlen = sizeof(mapped);
172 } else if (res->ai_family == AF_INET6 && family == AF_INET6) {
173 in6_addr& in6a = reinterpret_cast<sockaddr_in6*>(res->ai_addr)->sin6_addr;
174 addr = &in6a;
175 addrlen = sizeof(in6a);
176 } else {
177 return -EAFNOSUPPORT;
178 }
179
180 uint8_t prefixlen = addrlen * 8;
181 uint8_t yesjump = sizeof(inet_diag_bc_op) + sizeof(inet_diag_hostcond) + addrlen;
182 uint8_t nojump = yesjump + 4;
183
184 struct {
185 nlattr nla;
186 inet_diag_bc_op op;
187 inet_diag_hostcond cond;
188 } __attribute__((__packed__)) attrs = {
189 .nla = {
190 .nla_type = INET_DIAG_REQ_BYTECODE,
191 },
192 .op = {
193 INET_DIAG_BC_S_COND,
194 yesjump,
195 nojump,
196 },
197 .cond = {
198 family,
199 prefixlen,
200 -1,
201 {}
202 },
203 };
204
205 attrs.nla.nla_len = sizeof(attrs) + addrlen;
206
207 iovec iov[] = {
208 { nullptr, 0 },
209 { &attrs, sizeof(attrs) },
210 { addr, addrlen },
211 };
212
213 uint32_t states = ~(1 << TCP_TIME_WAIT);
214 return sendDumpRequest(proto, family, 0, states, iov, ARRAY_SIZE(iov));
215 }
216
readDiagMsg(uint8_t proto,const SockDiag::DestroyFilter & shouldDestroy)217 int SockDiag::readDiagMsg(uint8_t proto, const SockDiag::DestroyFilter& shouldDestroy) {
218 NetlinkDumpCallback callback = [this, proto, shouldDestroy] (nlmsghdr *nlh) {
219 const inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
220 if (shouldDestroy(proto, msg)) {
221 sockDestroy(proto, msg);
222 }
223 };
224
225 return processNetlinkDump(mSock, callback);
226 }
227
readDiagMsgWithTcpInfo(const TcpInfoReader & tcpInfoReader)228 int SockDiag::readDiagMsgWithTcpInfo(const TcpInfoReader& tcpInfoReader) {
229 NetlinkDumpCallback callback = [tcpInfoReader] (nlmsghdr *nlh) {
230 if (nlh->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
231 ALOGE("expected nlmsg_type=SOCK_DIAG_BY_FAMILY, got nlmsg_type=%d", nlh->nlmsg_type);
232 return;
233 }
234 Fwmark mark;
235 struct tcp_info *tcpinfo = nullptr;
236 uint32_t tcpinfoLength = 0;
237 inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
238 uint32_t attr_len = nlh->nlmsg_len - NLMSG_LENGTH(sizeof(*msg));
239 struct rtattr *attr = reinterpret_cast<struct rtattr*>(msg+1);
240 while (RTA_OK(attr, attr_len)) {
241 if (attr->rta_type == INET_DIAG_INFO) {
242 tcpinfo = reinterpret_cast<struct tcp_info*>(RTA_DATA(attr));
243 tcpinfoLength = RTA_PAYLOAD(attr);
244 }
245 if (attr->rta_type == INET_DIAG_MARK) {
246 mark.intValue = *reinterpret_cast<uint32_t*>(RTA_DATA(attr));
247 }
248 attr = RTA_NEXT(attr, attr_len);
249 }
250
251 tcpInfoReader(mark, msg, tcpinfo, tcpinfoLength);
252 };
253
254 return processNetlinkDump(mSock, callback);
255 }
256
257 // Determines whether a socket is a loopback socket. Does not check socket state.
isLoopbackSocket(const inet_diag_msg * msg)258 bool SockDiag::isLoopbackSocket(const inet_diag_msg *msg) {
259 switch (msg->idiag_family) {
260 case AF_INET:
261 // Old kernels only copy the IPv4 address and leave the other 12 bytes uninitialized.
262 return IN_LOOPBACK(htonl(msg->id.idiag_src[0])) ||
263 IN_LOOPBACK(htonl(msg->id.idiag_dst[0])) ||
264 msg->id.idiag_src[0] == msg->id.idiag_dst[0];
265
266 case AF_INET6: {
267 const struct in6_addr *src = (const struct in6_addr *) &msg->id.idiag_src;
268 const struct in6_addr *dst = (const struct in6_addr *) &msg->id.idiag_dst;
269 return (IN6_IS_ADDR_V4MAPPED(src) && IN_LOOPBACK(src->s6_addr32[3])) ||
270 (IN6_IS_ADDR_V4MAPPED(dst) && IN_LOOPBACK(dst->s6_addr32[3])) ||
271 IN6_IS_ADDR_LOOPBACK(src) || IN6_IS_ADDR_LOOPBACK(dst) ||
272 !memcmp(src, dst, sizeof(*src));
273 }
274 default:
275 return false;
276 }
277 }
278
sockDestroy(uint8_t proto,const inet_diag_msg * msg)279 int SockDiag::sockDestroy(uint8_t proto, const inet_diag_msg *msg) {
280 if (msg == nullptr) {
281 return 0;
282 }
283
284 DestroyRequest request = {
285 .nlh = {
286 .nlmsg_type = SOCK_DESTROY,
287 .nlmsg_flags = NLM_F_REQUEST,
288 },
289 .req = {
290 .sdiag_family = msg->idiag_family,
291 .sdiag_protocol = proto,
292 .idiag_states = (uint32_t) (1 << msg->idiag_state),
293 .id = msg->id,
294 },
295 };
296 request.nlh.nlmsg_len = sizeof(request);
297
298 if (write(mWriteSock, &request, sizeof(request)) < (ssize_t) sizeof(request)) {
299 return -errno;
300 }
301
302 int ret = checkError(mWriteSock);
303 if (!ret) mSocketsDestroyed++;
304 return ret;
305 }
306
destroySockets(uint8_t proto,int family,const char * addrstr)307 int SockDiag::destroySockets(uint8_t proto, int family, const char *addrstr) {
308 if (!hasSocks()) {
309 return -EBADFD;
310 }
311
312 if (int ret = sendDumpRequest(proto, family, addrstr)) {
313 return ret;
314 }
315
316 auto destroyAll = [] (uint8_t, const inet_diag_msg*) { return true; };
317
318 return readDiagMsg(proto, destroyAll);
319 }
320
destroySockets(const char * addrstr)321 int SockDiag::destroySockets(const char *addrstr) {
322 Stopwatch s;
323 mSocketsDestroyed = 0;
324
325 if (!strchr(addrstr, ':')) {
326 if (int ret = destroySockets(IPPROTO_TCP, AF_INET, addrstr)) {
327 ALOGE("Failed to destroy IPv4 sockets on %s: %s", addrstr, strerror(-ret));
328 return ret;
329 }
330 }
331 if (int ret = destroySockets(IPPROTO_TCP, AF_INET6, addrstr)) {
332 ALOGE("Failed to destroy IPv6 sockets on %s: %s", addrstr, strerror(-ret));
333 return ret;
334 }
335
336 if (mSocketsDestroyed > 0) {
337 ALOGI("Destroyed %d sockets on %s in %" PRId64 "us", mSocketsDestroyed, addrstr,
338 s.timeTakenUs());
339 }
340
341 return mSocketsDestroyed;
342 }
343
destroyLiveSockets(const DestroyFilter & destroyFilter,const char * what,iovec * iov,int iovcnt)344 int SockDiag::destroyLiveSockets(const DestroyFilter& destroyFilter, const char *what,
345 iovec *iov, int iovcnt) {
346 const int proto = IPPROTO_TCP;
347 const uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
348
349 for (const int family : {AF_INET, AF_INET6}) {
350 const char *familyName = (family == AF_INET) ? "IPv4" : "IPv6";
351 if (int ret = sendDumpRequest(proto, family, 0, states, iov, iovcnt)) {
352 ALOGE("Failed to dump %s sockets for %s: %s", familyName, what, strerror(-ret));
353 return ret;
354 }
355 if (int ret = readDiagMsg(proto, destroyFilter)) {
356 ALOGE("Failed to destroy %s sockets for %s: %s", familyName, what, strerror(-ret));
357 return ret;
358 }
359 }
360
361 return 0;
362 }
363
getLiveTcpInfos(const TcpInfoReader & tcpInfoReader)364 int SockDiag::getLiveTcpInfos(const TcpInfoReader& tcpInfoReader) {
365 const int proto = IPPROTO_TCP;
366 const uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
367 const uint8_t extensions = (1 << INET_DIAG_MEMINFO); // flag for dumping struct tcp_info.
368
369 iovec iov[] = {
370 { nullptr, 0 },
371 };
372
373 for (const int family : {AF_INET, AF_INET6}) {
374 const char *familyName = (family == AF_INET) ? "IPv4" : "IPv6";
375 if (int ret = sendDumpRequest(proto, family, extensions, states, iov, ARRAY_SIZE(iov))) {
376 ALOGE("Failed to dump %s sockets struct tcp_info: %s", familyName, strerror(-ret));
377 return ret;
378 }
379 if (int ret = readDiagMsgWithTcpInfo(tcpInfoReader)) {
380 ALOGE("Failed to read %s sockets struct tcp_info: %s", familyName, strerror(-ret));
381 return ret;
382 }
383 }
384
385 return 0;
386 }
387
destroySockets(uint8_t proto,const uid_t uid,bool excludeLoopback)388 int SockDiag::destroySockets(uint8_t proto, const uid_t uid, bool excludeLoopback) {
389 mSocketsDestroyed = 0;
390 Stopwatch s;
391
392 auto shouldDestroy = [uid, excludeLoopback] (uint8_t, const inet_diag_msg *msg) {
393 return msg != nullptr &&
394 msg->idiag_uid == uid &&
395 !(excludeLoopback && isLoopbackSocket(msg));
396 };
397
398 for (const int family : {AF_INET, AF_INET6}) {
399 const char *familyName = family == AF_INET ? "IPv4" : "IPv6";
400 uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
401 if (int ret = sendDumpRequest(proto, family, states)) {
402 ALOGE("Failed to dump %s sockets for UID: %s", familyName, strerror(-ret));
403 return ret;
404 }
405 if (int ret = readDiagMsg(proto, shouldDestroy)) {
406 ALOGE("Failed to destroy %s sockets for UID: %s", familyName, strerror(-ret));
407 return ret;
408 }
409 }
410
411 if (mSocketsDestroyed > 0) {
412 ALOGI("Destroyed %d sockets for UID in %" PRId64 "us", mSocketsDestroyed, s.timeTakenUs());
413 }
414
415 return 0;
416 }
417
destroySockets(const UidRanges & uidRanges,const std::set<uid_t> & skipUids,bool excludeLoopback)418 int SockDiag::destroySockets(const UidRanges& uidRanges, const std::set<uid_t>& skipUids,
419 bool excludeLoopback) {
420 mSocketsDestroyed = 0;
421 Stopwatch s;
422
423 auto shouldDestroy = [&] (uint8_t, const inet_diag_msg *msg) {
424 return msg != nullptr &&
425 uidRanges.hasUid(msg->idiag_uid) &&
426 skipUids.find(msg->idiag_uid) == skipUids.end() &&
427 !(excludeLoopback && isLoopbackSocket(msg)) &&
428 !isAdbSocket(msg, getAdbPort());
429 };
430
431 iovec iov[] = {
432 { nullptr, 0 },
433 };
434
435 if (int ret = destroyLiveSockets(shouldDestroy, "UID", iov, ARRAY_SIZE(iov))) {
436 return ret;
437 }
438
439 if (mSocketsDestroyed > 0) {
440 ALOGI("Destroyed %d sockets for %s skip={%s} in %" PRId64 "us", mSocketsDestroyed,
441 uidRanges.toString().c_str(), android::base::Join(skipUids, " ").c_str(),
442 s.timeTakenUs());
443 }
444
445 return 0;
446 }
447
448 // Destroys all "live" (CONNECTED, SYN_SENT, SYN_RECV) TCP sockets on the specified netId where:
449 // 1. The opening app no longer has permission to use this network, or:
450 // 2. The opening app does have permission, but did not explicitly select this network.
451 //
452 // We destroy sockets without the explicit bit because we want to avoid the situation where a
453 // privileged app uses its privileges without knowing it is doing so. For example, a privileged app
454 // might have opened a socket on this network just because it was the default network at the
455 // time. If we don't kill these sockets, those apps could continue to use them without realizing
456 // that they are now sending and receiving traffic on a network that is now restricted.
destroySocketsLackingPermission(unsigned netId,Permission permission,bool excludeLoopback)457 int SockDiag::destroySocketsLackingPermission(unsigned netId, Permission permission,
458 bool excludeLoopback) {
459 struct markmatch {
460 inet_diag_bc_op op;
461 // TODO: switch to inet_diag_markcond
462 __u32 mark;
463 __u32 mask;
464 } __attribute__((packed));
465 constexpr uint8_t matchlen = sizeof(markmatch);
466
467 Fwmark netIdMark, netIdMask;
468 netIdMark.netId = netId;
469 netIdMask.netId = 0xffff;
470
471 Fwmark controlMark;
472 controlMark.explicitlySelected = true;
473 controlMark.permission = permission;
474
475 // A SOCK_DIAG bytecode program that accepts the sockets we intend to destroy.
476 struct bytecode {
477 markmatch netIdMatch;
478 markmatch controlMatch;
479 inet_diag_bc_op controlJump;
480 } __attribute__((packed)) bytecode;
481
482 // The length of the INET_DIAG_BC_JMP instruction.
483 constexpr uint8_t jmplen = sizeof(inet_diag_bc_op);
484 // Jump exactly this far past the end of the program to reject.
485 constexpr uint8_t rejectoffset = sizeof(inet_diag_bc_op);
486 // Total length of the program.
487 constexpr uint8_t bytecodelen = sizeof(bytecode);
488
489 bytecode = (struct bytecode) {
490 // If netId matches, continue, otherwise, reject (i.e., leave socket alone).
491 { { INET_DIAG_BC_MARK_COND, matchlen, bytecodelen + rejectoffset },
492 netIdMark.intValue, netIdMask.intValue },
493
494 // If explicit and permission bits match, go to the JMP below which rejects the socket
495 // (i.e., we leave it alone). Otherwise, jump to the end of the program, which accepts the
496 // socket (so we destroy it).
497 { { INET_DIAG_BC_MARK_COND, matchlen, matchlen + jmplen },
498 controlMark.intValue, controlMark.intValue },
499
500 // This JMP unconditionally rejects the packet by jumping to the reject target. It is
501 // necessary to keep the kernel bytecode verifier happy. If we don't have a JMP the bytecode
502 // is invalid because the target of every no jump must always be reachable by yes jumps.
503 // Without this JMP, the accept target is not reachable by yes jumps and the program will
504 // be rejected by the validator.
505 { INET_DIAG_BC_JMP, jmplen, jmplen + rejectoffset },
506
507 // We have reached the end of the program. Accept the socket, and destroy it below.
508 };
509
510 struct nlattr nla = {
511 .nla_len = sizeof(struct nlattr) + bytecodelen,
512 .nla_type = INET_DIAG_REQ_BYTECODE,
513 };
514
515 iovec iov[] = {
516 { nullptr, 0 },
517 { &nla, sizeof(nla) },
518 { &bytecode, bytecodelen },
519 };
520
521 mSocketsDestroyed = 0;
522 Stopwatch s;
523
524 auto shouldDestroy = [&] (uint8_t, const inet_diag_msg *msg) {
525 return msg != nullptr && !(excludeLoopback && isLoopbackSocket(msg));
526 };
527
528 if (int ret = destroyLiveSockets(shouldDestroy, "permission change", iov, ARRAY_SIZE(iov))) {
529 return ret;
530 }
531
532 if (mSocketsDestroyed > 0) {
533 ALOGI("Destroyed %d sockets for netId %d permission=%d in %" PRId64 "us", mSocketsDestroyed,
534 netId, permission, s.timeTakenUs());
535 }
536
537 return 0;
538 }
539
540 } // namespace net
541 } // namespace android
542