1 /*
2 * Copyright (C) 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <errno.h>
18 #include <netdb.h>
19 #include <string.h>
20 #include <netinet/in.h>
21 #include <netinet/tcp.h>
22 #include <sys/socket.h>
23 #include <sys/uio.h>
24
25 #include <linux/netlink.h>
26 #include <linux/sock_diag.h>
27 #include <linux/inet_diag.h>
28
29 #define LOG_TAG "Netd"
30
31 #include <android-base/strings.h>
32 #include <log/log.h>
33 #include <netdutils/Stopwatch.h>
34
35 #include "NetdConstants.h"
36 #include "Permission.h"
37 #include "SockDiag.h"
38
39 #ifndef SOCK_DESTROY
40 #define SOCK_DESTROY 21
41 #endif
42
43 #define INET_DIAG_BC_MARK_COND 10
44
45 namespace android {
46
47 using netdutils::Stopwatch;
48
49 namespace net {
50 namespace {
51
// Peeks, without blocking, at the next message queued on |fd| and reports a netlink error
// if that is what it is.
//
// Returns:
//   - 0 if nothing is queued (EAGAIN), or if the queued message is not an error-sized
//     NLMSG_ERROR; in the latter case the message is left in the queue for the caller.
//   - a negative errno value if the peek itself fails for any other reason.
//   - the error field of the message if it is an NLMSG_ERROR; the message is consumed.
int checkError(int fd) {
    struct {
        nlmsghdr h;
        nlmsgerr err;
    } __attribute__((__packed__)) reply;

    const ssize_t peeked = recv(fd, &reply, sizeof(reply), MSG_DONTWAIT | MSG_PEEK);
    if (peeked == -1) {
        // An empty queue is the good case; everything else is a genuine read failure.
        if (errno == EAGAIN) {
            return 0;
        }
        return -errno;
    }

    // Only inspect the header when a complete header + error payload was peeked.
    const bool isNetlinkError = (peeked == static_cast<ssize_t>(sizeof(reply))) &&
                                (reply.h.nlmsg_type == NLMSG_ERROR);
    if (!isNetlinkError) {
        // The kernel replied with something else. Leave it to the caller.
        return 0;
    }

    // Remove the error message from the queue before reporting it.
    recv(fd, &reply, sizeof(reply), 0);
    return reply.err.error;
}
70
71 } // namespace
72
open()73 bool SockDiag::open() {
74 if (hasSocks()) {
75 return false;
76 }
77
78 mSock = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
79 mWriteSock = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
80 if (!hasSocks()) {
81 closeSocks();
82 return false;
83 }
84
85 sockaddr_nl nl = { .nl_family = AF_NETLINK };
86 if ((connect(mSock, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) == -1) ||
87 (connect(mWriteSock, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) == -1)) {
88 closeSocks();
89 return false;
90 }
91
92 return true;
93 }
94
sendDumpRequest(uint8_t proto,uint8_t family,uint8_t extensions,uint32_t states,iovec * iov,int iovcnt)95 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint8_t extensions, uint32_t states,
96 iovec *iov, int iovcnt) {
97 struct {
98 nlmsghdr nlh;
99 inet_diag_req_v2 req;
100 } __attribute__((__packed__)) request = {
101 .nlh = {
102 .nlmsg_type = SOCK_DIAG_BY_FAMILY,
103 .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
104 },
105 .req = {
106 .sdiag_family = family,
107 .sdiag_protocol = proto,
108 .idiag_ext = extensions,
109 .idiag_states = states,
110 },
111 };
112
113 size_t len = 0;
114 iov[0].iov_base = &request;
115 iov[0].iov_len = sizeof(request);
116 for (int i = 0; i < iovcnt; i++) {
117 len += iov[i].iov_len;
118 }
119 request.nlh.nlmsg_len = len;
120
121 if (writev(mSock, iov, iovcnt) != (ssize_t) len) {
122 return -errno;
123 }
124
125 return checkError(mSock);
126 }
127
sendDumpRequest(uint8_t proto,uint8_t family,uint32_t states)128 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states) {
129 iovec iov[] = {
130 { nullptr, 0 },
131 };
132 return sendDumpRequest(proto, family, 0, states, iov, ARRAY_SIZE(iov));
133 }
134
sendDumpRequest(uint8_t proto,uint8_t family,const char * addrstr)135 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, const char *addrstr) {
136 addrinfo hints = { .ai_flags = AI_NUMERICHOST };
137 addrinfo *res;
138 in6_addr mapped = { .s6_addr32 = { 0, 0, htonl(0xffff), 0 } };
139 int ret;
140
141 // TODO: refactor the netlink parsing code out of system/core, bring it into netd, and stop
142 // doing string conversions when they're not necessary.
143 if ((ret = getaddrinfo(addrstr, nullptr, &hints, &res)) != 0) {
144 return -EINVAL;
145 }
146
147 // So we don't have to call freeaddrinfo on every failure path.
148 ScopedAddrinfo resP(res);
149
150 void *addr;
151 uint8_t addrlen;
152 if (res->ai_family == AF_INET && family == AF_INET) {
153 in_addr& ina = reinterpret_cast<sockaddr_in*>(res->ai_addr)->sin_addr;
154 addr = &ina;
155 addrlen = sizeof(ina);
156 } else if (res->ai_family == AF_INET && family == AF_INET6) {
157 in_addr& ina = reinterpret_cast<sockaddr_in*>(res->ai_addr)->sin_addr;
158 mapped.s6_addr32[3] = ina.s_addr;
159 addr = &mapped;
160 addrlen = sizeof(mapped);
161 } else if (res->ai_family == AF_INET6 && family == AF_INET6) {
162 in6_addr& in6a = reinterpret_cast<sockaddr_in6*>(res->ai_addr)->sin6_addr;
163 addr = &in6a;
164 addrlen = sizeof(in6a);
165 } else {
166 return -EAFNOSUPPORT;
167 }
168
169 uint8_t prefixlen = addrlen * 8;
170 uint8_t yesjump = sizeof(inet_diag_bc_op) + sizeof(inet_diag_hostcond) + addrlen;
171 uint8_t nojump = yesjump + 4;
172
173 struct {
174 nlattr nla;
175 inet_diag_bc_op op;
176 inet_diag_hostcond cond;
177 } __attribute__((__packed__)) attrs = {
178 .nla = {
179 .nla_type = INET_DIAG_REQ_BYTECODE,
180 },
181 .op = {
182 INET_DIAG_BC_S_COND,
183 yesjump,
184 nojump,
185 },
186 .cond = {
187 family,
188 prefixlen,
189 -1,
190 {}
191 },
192 };
193
194 attrs.nla.nla_len = sizeof(attrs) + addrlen;
195
196 iovec iov[] = {
197 { nullptr, 0 },
198 { &attrs, sizeof(attrs) },
199 { addr, addrlen },
200 };
201
202 uint32_t states = ~(1 << TCP_TIME_WAIT);
203 return sendDumpRequest(proto, family, 0, states, iov, ARRAY_SIZE(iov));
204 }
205
readDiagMsg(uint8_t proto,const SockDiag::DestroyFilter & shouldDestroy)206 int SockDiag::readDiagMsg(uint8_t proto, const SockDiag::DestroyFilter& shouldDestroy) {
207 NetlinkDumpCallback callback = [this, proto, shouldDestroy] (nlmsghdr *nlh) {
208 const inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
209 if (shouldDestroy(proto, msg)) {
210 sockDestroy(proto, msg);
211 }
212 };
213
214 return processNetlinkDump(mSock, callback);
215 }
216
readDiagMsgWithTcpInfo(const TcpInfoReader & tcpInfoReader)217 int SockDiag::readDiagMsgWithTcpInfo(const TcpInfoReader& tcpInfoReader) {
218 NetlinkDumpCallback callback = [tcpInfoReader] (nlmsghdr *nlh) {
219 if (nlh->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
220 ALOGE("expected nlmsg_type=SOCK_DIAG_BY_FAMILY, got nlmsg_type=%d", nlh->nlmsg_type);
221 return;
222 }
223 Fwmark mark;
224 struct tcp_info *tcpinfo = nullptr;
225 uint32_t tcpinfoLength = 0;
226 inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
227 uint32_t attr_len = nlh->nlmsg_len - NLMSG_LENGTH(sizeof(*msg));
228 struct rtattr *attr = reinterpret_cast<struct rtattr*>(msg+1);
229 while (RTA_OK(attr, attr_len)) {
230 if (attr->rta_type == INET_DIAG_INFO) {
231 tcpinfo = reinterpret_cast<struct tcp_info*>(RTA_DATA(attr));
232 tcpinfoLength = RTA_PAYLOAD(attr);
233 }
234 if (attr->rta_type == INET_DIAG_MARK) {
235 mark.intValue = *reinterpret_cast<uint32_t*>(RTA_DATA(attr));
236 }
237 attr = RTA_NEXT(attr, attr_len);
238 }
239
240 tcpInfoReader(mark, msg, tcpinfo, tcpinfoLength);
241 };
242
243 return processNetlinkDump(mSock, callback);
244 }
245
246 // Determines whether a socket is a loopback socket. Does not check socket state.
isLoopbackSocket(const inet_diag_msg * msg)247 bool SockDiag::isLoopbackSocket(const inet_diag_msg *msg) {
248 switch (msg->idiag_family) {
249 case AF_INET:
250 // Old kernels only copy the IPv4 address and leave the other 12 bytes uninitialized.
251 return IN_LOOPBACK(htonl(msg->id.idiag_src[0])) ||
252 IN_LOOPBACK(htonl(msg->id.idiag_dst[0])) ||
253 msg->id.idiag_src[0] == msg->id.idiag_dst[0];
254
255 case AF_INET6: {
256 const struct in6_addr *src = (const struct in6_addr *) &msg->id.idiag_src;
257 const struct in6_addr *dst = (const struct in6_addr *) &msg->id.idiag_dst;
258 return (IN6_IS_ADDR_V4MAPPED(src) && IN_LOOPBACK(src->s6_addr32[3])) ||
259 (IN6_IS_ADDR_V4MAPPED(dst) && IN_LOOPBACK(dst->s6_addr32[3])) ||
260 IN6_IS_ADDR_LOOPBACK(src) || IN6_IS_ADDR_LOOPBACK(dst) ||
261 !memcmp(src, dst, sizeof(*src));
262 }
263 default:
264 return false;
265 }
266 }
267
sockDestroy(uint8_t proto,const inet_diag_msg * msg)268 int SockDiag::sockDestroy(uint8_t proto, const inet_diag_msg *msg) {
269 if (msg == nullptr) {
270 return 0;
271 }
272
273 DestroyRequest request = {
274 .nlh = {
275 .nlmsg_type = SOCK_DESTROY,
276 .nlmsg_flags = NLM_F_REQUEST,
277 },
278 .req = {
279 .sdiag_family = msg->idiag_family,
280 .sdiag_protocol = proto,
281 .idiag_states = (uint32_t) (1 << msg->idiag_state),
282 .id = msg->id,
283 },
284 };
285 request.nlh.nlmsg_len = sizeof(request);
286
287 if (write(mWriteSock, &request, sizeof(request)) < (ssize_t) sizeof(request)) {
288 return -errno;
289 }
290
291 int ret = checkError(mWriteSock);
292 if (!ret) mSocketsDestroyed++;
293 return ret;
294 }
295
destroySockets(uint8_t proto,int family,const char * addrstr)296 int SockDiag::destroySockets(uint8_t proto, int family, const char *addrstr) {
297 if (!hasSocks()) {
298 return -EBADFD;
299 }
300
301 if (int ret = sendDumpRequest(proto, family, addrstr)) {
302 return ret;
303 }
304
305 auto destroyAll = [] (uint8_t, const inet_diag_msg*) { return true; };
306
307 return readDiagMsg(proto, destroyAll);
308 }
309
destroySockets(const char * addrstr)310 int SockDiag::destroySockets(const char *addrstr) {
311 Stopwatch s;
312 mSocketsDestroyed = 0;
313
314 if (!strchr(addrstr, ':')) {
315 if (int ret = destroySockets(IPPROTO_TCP, AF_INET, addrstr)) {
316 ALOGE("Failed to destroy IPv4 sockets on %s: %s", addrstr, strerror(-ret));
317 return ret;
318 }
319 }
320 if (int ret = destroySockets(IPPROTO_TCP, AF_INET6, addrstr)) {
321 ALOGE("Failed to destroy IPv6 sockets on %s: %s", addrstr, strerror(-ret));
322 return ret;
323 }
324
325 if (mSocketsDestroyed > 0) {
326 ALOGI("Destroyed %d sockets on %s in %.1f ms", mSocketsDestroyed, addrstr, s.timeTaken());
327 }
328
329 return mSocketsDestroyed;
330 }
331
destroyLiveSockets(const DestroyFilter & destroyFilter,const char * what,iovec * iov,int iovcnt)332 int SockDiag::destroyLiveSockets(const DestroyFilter& destroyFilter, const char *what,
333 iovec *iov, int iovcnt) {
334 const int proto = IPPROTO_TCP;
335 const uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
336
337 for (const int family : {AF_INET, AF_INET6}) {
338 const char *familyName = (family == AF_INET) ? "IPv4" : "IPv6";
339 if (int ret = sendDumpRequest(proto, family, 0, states, iov, iovcnt)) {
340 ALOGE("Failed to dump %s sockets for %s: %s", familyName, what, strerror(-ret));
341 return ret;
342 }
343 if (int ret = readDiagMsg(proto, destroyFilter)) {
344 ALOGE("Failed to destroy %s sockets for %s: %s", familyName, what, strerror(-ret));
345 return ret;
346 }
347 }
348
349 return 0;
350 }
351
getLiveTcpInfos(const TcpInfoReader & tcpInfoReader)352 int SockDiag::getLiveTcpInfos(const TcpInfoReader& tcpInfoReader) {
353 const int proto = IPPROTO_TCP;
354 const uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
355 const uint8_t extensions = (1 << INET_DIAG_MEMINFO); // flag for dumping struct tcp_info.
356
357 iovec iov[] = {
358 { nullptr, 0 },
359 };
360
361 for (const int family : {AF_INET, AF_INET6}) {
362 const char *familyName = (family == AF_INET) ? "IPv4" : "IPv6";
363 if (int ret = sendDumpRequest(proto, family, extensions, states, iov, ARRAY_SIZE(iov))) {
364 ALOGE("Failed to dump %s sockets struct tcp_info: %s", familyName, strerror(-ret));
365 return ret;
366 }
367 if (int ret = readDiagMsgWithTcpInfo(tcpInfoReader)) {
368 ALOGE("Failed to read %s sockets struct tcp_info: %s", familyName, strerror(-ret));
369 return ret;
370 }
371 }
372
373 return 0;
374 }
375
destroySockets(uint8_t proto,const uid_t uid,bool excludeLoopback)376 int SockDiag::destroySockets(uint8_t proto, const uid_t uid, bool excludeLoopback) {
377 mSocketsDestroyed = 0;
378 Stopwatch s;
379
380 auto shouldDestroy = [uid, excludeLoopback] (uint8_t, const inet_diag_msg *msg) {
381 return msg != nullptr &&
382 msg->idiag_uid == uid &&
383 !(excludeLoopback && isLoopbackSocket(msg));
384 };
385
386 for (const int family : {AF_INET, AF_INET6}) {
387 const char *familyName = family == AF_INET ? "IPv4" : "IPv6";
388 uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
389 if (int ret = sendDumpRequest(proto, family, states)) {
390 ALOGE("Failed to dump %s sockets for UID: %s", familyName, strerror(-ret));
391 return ret;
392 }
393 if (int ret = readDiagMsg(proto, shouldDestroy)) {
394 ALOGE("Failed to destroy %s sockets for UID: %s", familyName, strerror(-ret));
395 return ret;
396 }
397 }
398
399 if (mSocketsDestroyed > 0) {
400 ALOGI("Destroyed %d sockets for UID in %.1f ms", mSocketsDestroyed, s.timeTaken());
401 }
402
403 return 0;
404 }
405
destroySockets(const UidRanges & uidRanges,const std::set<uid_t> & skipUids,bool excludeLoopback)406 int SockDiag::destroySockets(const UidRanges& uidRanges, const std::set<uid_t>& skipUids,
407 bool excludeLoopback) {
408 mSocketsDestroyed = 0;
409 Stopwatch s;
410
411 auto shouldDestroy = [&] (uint8_t, const inet_diag_msg *msg) {
412 return msg != nullptr &&
413 uidRanges.hasUid(msg->idiag_uid) &&
414 skipUids.find(msg->idiag_uid) == skipUids.end() &&
415 !(excludeLoopback && isLoopbackSocket(msg));
416 };
417
418 iovec iov[] = {
419 { nullptr, 0 },
420 };
421
422 if (int ret = destroyLiveSockets(shouldDestroy, "UID", iov, ARRAY_SIZE(iov))) {
423 return ret;
424 }
425
426 if (mSocketsDestroyed > 0) {
427 ALOGI("Destroyed %d sockets for %s skip={%s} in %.1f ms",
428 mSocketsDestroyed, uidRanges.toString().c_str(),
429 android::base::Join(skipUids, " ").c_str(), s.timeTaken());
430 }
431
432 return 0;
433 }
434
435 // Destroys all "live" (CONNECTED, SYN_SENT, SYN_RECV) TCP sockets on the specified netId where:
436 // 1. The opening app no longer has permission to use this network, or:
437 // 2. The opening app does have permission, but did not explicitly select this network.
438 //
439 // We destroy sockets without the explicit bit because we want to avoid the situation where a
440 // privileged app uses its privileges without knowing it is doing so. For example, a privileged app
441 // might have opened a socket on this network just because it was the default network at the
442 // time. If we don't kill these sockets, those apps could continue to use them without realizing
443 // that they are now sending and receiving traffic on a network that is now restricted.
destroySocketsLackingPermission(unsigned netId,Permission permission,bool excludeLoopback)444 int SockDiag::destroySocketsLackingPermission(unsigned netId, Permission permission,
445 bool excludeLoopback) {
446 struct markmatch {
447 inet_diag_bc_op op;
448 // TODO: switch to inet_diag_markcond
449 __u32 mark;
450 __u32 mask;
451 } __attribute__((packed));
452 constexpr uint8_t matchlen = sizeof(markmatch);
453
454 Fwmark netIdMark, netIdMask;
455 netIdMark.netId = netId;
456 netIdMask.netId = 0xffff;
457
458 Fwmark controlMark;
459 controlMark.explicitlySelected = true;
460 controlMark.permission = permission;
461
462 // A SOCK_DIAG bytecode program that accepts the sockets we intend to destroy.
463 struct bytecode {
464 markmatch netIdMatch;
465 markmatch controlMatch;
466 inet_diag_bc_op controlJump;
467 } __attribute__((packed)) bytecode;
468
469 // The length of the INET_DIAG_BC_JMP instruction.
470 constexpr uint8_t jmplen = sizeof(inet_diag_bc_op);
471 // Jump exactly this far past the end of the program to reject.
472 constexpr uint8_t rejectoffset = sizeof(inet_diag_bc_op);
473 // Total length of the program.
474 constexpr uint8_t bytecodelen = sizeof(bytecode);
475
476 bytecode = (struct bytecode) {
477 // If netId matches, continue, otherwise, reject (i.e., leave socket alone).
478 { { INET_DIAG_BC_MARK_COND, matchlen, bytecodelen + rejectoffset },
479 netIdMark.intValue, netIdMask.intValue },
480
481 // If explicit and permission bits match, go to the JMP below which rejects the socket
482 // (i.e., we leave it alone). Otherwise, jump to the end of the program, which accepts the
483 // socket (so we destroy it).
484 { { INET_DIAG_BC_MARK_COND, matchlen, matchlen + jmplen },
485 controlMark.intValue, controlMark.intValue },
486
487 // This JMP unconditionally rejects the packet by jumping to the reject target. It is
488 // necessary to keep the kernel bytecode verifier happy. If we don't have a JMP the bytecode
489 // is invalid because the target of every no jump must always be reachable by yes jumps.
490 // Without this JMP, the accept target is not reachable by yes jumps and the program will
491 // be rejected by the validator.
492 { INET_DIAG_BC_JMP, jmplen, jmplen + rejectoffset },
493
494 // We have reached the end of the program. Accept the socket, and destroy it below.
495 };
496
497 struct nlattr nla = {
498 .nla_type = INET_DIAG_REQ_BYTECODE,
499 .nla_len = sizeof(struct nlattr) + bytecodelen,
500 };
501
502 iovec iov[] = {
503 { nullptr, 0 },
504 { &nla, sizeof(nla) },
505 { &bytecode, bytecodelen },
506 };
507
508 mSocketsDestroyed = 0;
509 Stopwatch s;
510
511 auto shouldDestroy = [&] (uint8_t, const inet_diag_msg *msg) {
512 return msg != nullptr && !(excludeLoopback && isLoopbackSocket(msg));
513 };
514
515 if (int ret = destroyLiveSockets(shouldDestroy, "permission change", iov, ARRAY_SIZE(iov))) {
516 return ret;
517 }
518
519 if (mSocketsDestroyed > 0) {
520 ALOGI("Destroyed %d sockets for netId %d permission=%d in %.1f ms",
521 mSocketsDestroyed, netId, permission, s.timeTaken());
522 }
523
524 return 0;
525 }
526
527 } // namespace net
528 } // namespace android
529