• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "netlink_socket_diag.h"
17 
18 #include <arpa/inet.h>
19 #include <cstring>
20 #include <net/if.h>
21 #include <netinet/tcp.h>
22 #include <sys/uio.h>
23 #include <unistd.h>
24 
25 #include "fwmark.h"
26 #include "net_manager_constants.h"
27 #include "netmanager_base_common_utils.h"
28 #include "netnative_log_wrapper.h"
29 #include "securec.h"
30 
31 namespace OHOS {
32 namespace nmd {
33 using namespace NetManagerStandard;
34 
namespace {
// Receive buffer size used for each read() of a sock_diag dump reply.
constexpr uint32_t KERNEL_BUFFER_SIZE = 8U * 1024U;
// Index of the 32-bit word that carries the embedded IPv4 address inside an IPv4-mapped IPv6 address.
constexpr uint8_t ADDR_POSITION = 3U;
// Capacity of the text buffers that receive inet_ntop() output for IPv6 addresses.
constexpr int32_t DOMAIN_IP_ADDR_MAX_LEN = 128;
// Host-byte-order mask/value pair matching 127.0.0.0/8. "LOCKBACK" is a historical spelling of
// "loopback"; the names are kept because other code in this file refers to them.
constexpr uint32_t LOCKBACK_MASK = 0xFF000000U;
constexpr uint32_t LOCKBACK_DEFINE = 0x7F000000U;
// Uid singled out by the DESTROY_SPECIAL_CELLULAR / DESTROY_DEFAULT_CELLULAR filters.
constexpr uid_t PUSH_UID = 7023;
} // namespace
43 
~NetLinkSocketDiag()44 NetLinkSocketDiag::~NetLinkSocketDiag()
45 {
46     CloseNetlinkSocket();
47 }
48 
InLookBack(uint32_t a)49 bool NetLinkSocketDiag::InLookBack(uint32_t a)
50 {
51     return (a & LOCKBACK_MASK) == LOCKBACK_DEFINE;
52 }
53 
CreateNetlinkSocket()54 bool NetLinkSocketDiag::CreateNetlinkSocket()
55 {
56     dumpSock_ = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
57     if (dumpSock_ < 0) {
58         NETNATIVE_LOGE("Create netlink socket for dump failed, error[%{public}d]: %{public}s", errno, strerror(errno));
59         return false;
60     }
61 
62     destroySock_ = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
63     if (destroySock_ < 0) {
64         NETNATIVE_LOGE("Create netlink socket for destroy failed, error[%{public}d]: %{public}s", errno,
65                        strerror(errno));
66         close(dumpSock_);
67         dumpSock_ = -1;
68         return false;
69     }
70 
71     sockaddr_nl nl = {.nl_family = AF_NETLINK};
72     if ((connect(dumpSock_, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) < 0) ||
73         (connect(destroySock_, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) < 0)) {
74         NETNATIVE_LOGE("Connect to netlink socket failed, error[%{public}d]: %{public}s", errno, strerror(errno));
75         CloseNetlinkSocket();
76         return false;
77     }
78     return true;
79 }
80 
CloseNetlinkSocket()81 void NetLinkSocketDiag::CloseNetlinkSocket()
82 {
83     // avoid double close
84     if (dumpSock_ >= 0) {
85         close(dumpSock_);
86     }
87     if (destroySock_ >= 0) {
88         close(destroySock_);
89     }
90     dumpSock_ = -1;
91     destroySock_ = -1;
92 }
93 
ExecuteDestroySocket(uint8_t proto,const inet_diag_msg * msg)94 int32_t NetLinkSocketDiag::ExecuteDestroySocket(uint8_t proto, const inet_diag_msg *msg)
95 {
96     if (msg == nullptr) {
97         NETNATIVE_LOGE("inet_diag_msg is nullptr");
98         return NETMANAGER_ERR_LOCAL_PTR_NULL;
99     }
100 
101     SockDiagRequest request;
102     request.nlh_.nlmsg_type = SOCK_DESTROY;
103     request.nlh_.nlmsg_flags = NLM_F_REQUEST;
104     request.nlh_.nlmsg_len = sizeof(request);
105 
106     request.req_ = {.sdiag_family = msg->idiag_family,
107                     .sdiag_protocol = proto,
108                     .idiag_states = static_cast<uint32_t>(1 << msg->idiag_state),
109                     .id = msg->id};
110     ssize_t writeLen = write(destroySock_, &request, sizeof(request));
111     if (writeLen < static_cast<ssize_t>(sizeof(request))) {
112         NETNATIVE_LOGE("Write destroy request to socket failed errno[%{public}d]: strerror:%{public}s", errno,
113                        strerror(errno));
114         return NETMANAGER_ERR_INTERNAL;
115     }
116 
117     int32_t ret = GetErrorFromKernel(destroySock_);
118     if (ret == NETMANAGER_SUCCESS) {
119         socketsDestroyed_++;
120     }
121     return ret;
122 }
123 
GetErrorFromKernel(int32_t fd)124 int32_t NetLinkSocketDiag::GetErrorFromKernel(int32_t fd)
125 {
126     Ack ack;
127     ssize_t bytesread = recv(fd, &ack, sizeof(ack), MSG_DONTWAIT | MSG_PEEK);
128     if (bytesread < 0) {
129         return (errno == EAGAIN) ? NETMANAGER_SUCCESS : -errno;
130     }
131     if (bytesread == static_cast<ssize_t>(sizeof(ack)) && ack.hdr_.nlmsg_type == NLMSG_ERROR) {
132         recv(fd, &ack, sizeof(ack), 0);
133         NETNATIVE_LOGE("Receive NLMSG_ERROR:[%{public}d] from kernel", ack.err_.error);
134         return NETMANAGER_ERR_INTERNAL;
135     }
136     return NETMANAGER_SUCCESS;
137 }
138 
IsLoopbackSocket(const inet_diag_msg * msg)139 bool NetLinkSocketDiag::IsLoopbackSocket(const inet_diag_msg *msg)
140 {
141     if (msg->idiag_family == AF_INET) {
142         return InLookBack(htonl(msg->id.idiag_src[0])) || InLookBack(htonl(msg->id.idiag_dst[0]));
143     }
144 
145     if (msg->idiag_family == AF_INET6) {
146         const struct in6_addr *src = (const struct in6_addr *)&msg->id.idiag_src;
147         const struct in6_addr *dst = (const struct in6_addr *)&msg->id.idiag_dst;
148         return (IN6_IS_ADDR_V4MAPPED(src) && InLookBack(src->s6_addr32[ADDR_POSITION])) ||
149                (IN6_IS_ADDR_V4MAPPED(dst) && InLookBack(dst->s6_addr32[ADDR_POSITION])) || IN6_IS_ADDR_LOOPBACK(src) ||
150                IN6_IS_ADDR_LOOPBACK(dst);
151     }
152     return false;
153 }
154 
IsMatchNetwork(const inet_diag_msg * msg,const std::string & ipAddr)155 bool NetLinkSocketDiag::IsMatchNetwork(const inet_diag_msg *msg, const std::string &ipAddr)
156 {
157     if (msg->idiag_family == AF_INET) {
158         if (CommonUtils::GetAddrFamily(ipAddr) != AF_INET) {
159             return false;
160         }
161 
162         in_addr_t addr = inet_addr(ipAddr.c_str());
163         if (addr == msg->id.idiag_src[0] || addr == msg->id.idiag_dst[0]) {
164             return true;
165         }
166     }
167 
168     if (msg->idiag_family == AF_INET6) {
169         if (CommonUtils::GetAddrFamily(ipAddr) != AF_INET6) {
170             return false;
171         }
172 
173         char src[DOMAIN_IP_ADDR_MAX_LEN] = {0};
174         char dst[DOMAIN_IP_ADDR_MAX_LEN] = {0};
175         inet_ntop(AF_INET6, msg->id.idiag_src, src, sizeof(src));
176         inet_ntop(AF_INET6, msg->id.idiag_dst, dst, sizeof(dst));
177         if (src == ipAddr || dst == ipAddr) {
178             return true;
179         }
180     }
181     return false;
182 }
183 
ProcessSockDiagDumpResponse(uint8_t proto,const std::string & ipAddr,bool excludeLoopback)184 int32_t NetLinkSocketDiag::ProcessSockDiagDumpResponse(uint8_t proto, const std::string &ipAddr, bool excludeLoopback)
185 {
186     char buf[KERNEL_BUFFER_SIZE] = {0};
187     ssize_t readBytes = read(dumpSock_, buf, sizeof(buf));
188     if (readBytes < 0) {
189         NETNATIVE_LOGE("Failed to read socket, errno:%{public}d, strerror:%{public}s", errno, strerror(errno));
190         return NETMANAGER_ERR_INTERNAL;
191     }
192     while (readBytes > 0) {
193         uint32_t len = static_cast<uint32_t>(readBytes);
194         for (nlmsghdr *nlh = reinterpret_cast<nlmsghdr *>(buf); NLMSG_OK(nlh, len); nlh = NLMSG_NEXT(nlh, len)) {
195             if (nlh->nlmsg_type == NLMSG_ERROR) {
196                 nlmsgerr *err = reinterpret_cast<nlmsgerr *>(NLMSG_DATA(nlh));
197                 NETNATIVE_LOGE("Error netlink msg, errno:%{public}d, strerror:%{public}s", -err->error,
198                                strerror(-err->error));
199                 return err->error;
200             } else if (nlh->nlmsg_type == NLMSG_DONE) {
201                 return NETMANAGER_SUCCESS;
202             } else {
203                 const auto *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
204                 SockDiagDumpCallback(proto, msg, ipAddr, excludeLoopback);
205             }
206         }
207         readBytes = read(dumpSock_, buf, sizeof(buf));
208         if (readBytes < 0) {
209             return -errno;
210         }
211     }
212     return NETMANAGER_SUCCESS;
213 }
214 
SendSockDiagDumpRequest(uint8_t proto,uint8_t family,uint32_t states)215 int32_t NetLinkSocketDiag::SendSockDiagDumpRequest(uint8_t proto, uint8_t family, uint32_t states)
216 {
217     SockDiagRequest request;
218     size_t len = sizeof(request);
219     iovec iov;
220     iov.iov_base = &request;
221     iov.iov_len = len;
222     request.nlh_.nlmsg_type = SOCK_DIAG_BY_FAMILY;
223     request.nlh_.nlmsg_flags = (NLM_F_REQUEST | NLM_F_DUMP);
224     request.nlh_.nlmsg_len = len;
225 
226     request.req_ = {.sdiag_family = family, .sdiag_protocol = proto, .idiag_states = states};
227 
228     ssize_t writeLen = writev(dumpSock_, &iov, (sizeof(iov) / sizeof(iovec)));
229     if (writeLen != static_cast<ssize_t>(len)) {
230         NETNATIVE_LOGE("Write dump request failed errno:%{public}d, strerror:%{public}s", errno, strerror(errno));
231         return NETMANAGER_ERR_INTERNAL;
232     }
233 
234     return GetErrorFromKernel(dumpSock_);
235 }
236 
SockDiagDumpCallback(uint8_t proto,const inet_diag_msg * msg,const std::string & ipAddr,bool excludeLoopback)237 void NetLinkSocketDiag::SockDiagDumpCallback(uint8_t proto, const inet_diag_msg *msg, const std::string &ipAddr,
238                                              bool excludeLoopback)
239 {
240     if (msg == nullptr) {
241         NETNATIVE_LOGE("msg is nullptr");
242         return;
243     }
244 
245     if (socketDestroyType_ == SocketDestroyType::DESTROY_SPECIAL_CELLULAR && msg->idiag_uid != PUSH_UID) {
246         return;
247     }
248 
249     if (socketDestroyType_ == SocketDestroyType::DESTROY_DEFAULT_CELLULAR && msg->idiag_uid == PUSH_UID) {
250         return;
251     }
252 
253     if (excludeLoopback && IsLoopbackSocket(msg)) {
254         NETNATIVE_LOG_D("Loop back socket, no need to close.");
255         return;
256     }
257 
258     if (!IsMatchNetwork(msg, ipAddr)) {
259         NETNATIVE_LOG_D("Socket is not associated with the network");
260         return;
261     }
262 
263     ExecuteDestroySocket(proto, msg);
264 }
265 
DestroyLiveSockets(const char * ipAddr,bool excludeLoopback)266 void NetLinkSocketDiag::DestroyLiveSockets(const char *ipAddr, bool excludeLoopback)
267 {
268     NETNATIVE_LOG_D("DestroySocketsLackingNetwork in");
269     if (ipAddr == nullptr) {
270         NETNATIVE_LOGE("Ip address is nullptr.");
271         return;
272     }
273 
274     if (!CreateNetlinkSocket()) {
275         NETNATIVE_LOGE("Create netlink diag socket failed.");
276         return;
277     }
278 
279     const int32_t proto = IPPROTO_TCP;
280     const uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
281 
282     for (const int family : {AF_INET, AF_INET6}) {
283         int32_t ret = SendSockDiagDumpRequest(proto, family, states);
284         if (ret != NETMANAGER_SUCCESS) {
285             NETNATIVE_LOGE("Failed to dump %{public}s sockets", family == AF_INET ? "IPv4" : "IPv6");
286             break;
287         }
288         ret = ProcessSockDiagDumpResponse(proto, ipAddr, excludeLoopback);
289         if (ret != NETMANAGER_SUCCESS) {
290             NETNATIVE_LOGE("Failed to destroy %{public}s sockets", family == AF_INET ? "IPv4" : "IPv6");
291             break;
292         }
293     }
294 
295     NETNATIVE_LOGI("Destroyed %{public}d sockets", socketsDestroyed_);
296 }
297 
SetSocketDestroyType(const std::string & netCapabilities)298 int32_t NetLinkSocketDiag::SetSocketDestroyType(const std::string &netCapabilities)
299 {
300     const std::string capSpecialCellularStr = "NET_CAPABILITY_INTERNAL_DEFAULT";
301     const std::string bearerCellularStr = "BEARER_CELLULAR";
302     if (netCapabilities.find(capSpecialCellularStr) != std::string::npos) {
303         socketDestroyType_ = SocketDestroyType::DESTROY_SPECIAL_CELLULAR;
304     } else if (netCapabilities.find(bearerCellularStr) != std::string::npos) {
305         socketDestroyType_ = SocketDestroyType::DESTROY_DEFAULT_CELLULAR;
306     } else {
307         socketDestroyType_ = SocketDestroyType::DESTROY_DEFAULT;
308     }
309     return 0;
310 }
311 
SockDiagUidDumpCallback(uint8_t proto,const inet_diag_msg * msg,const NetLinkSocketDiag::DestroyFilter & needDestroy)312 void NetLinkSocketDiag::SockDiagUidDumpCallback(uint8_t proto, const inet_diag_msg *msg,
313     const NetLinkSocketDiag::DestroyFilter& needDestroy)
314 {
315     NETNATIVE_LOG_D(" SockDiagUidDumpCallback");
316     if (!needDestroy(msg)) {
317         return;
318     }
319 
320     ExecuteDestroySocket(proto, msg);
321 }
322 
ProcessSockDiagUidDumpResponse(uint8_t proto,const NetLinkSocketDiag::DestroyFilter & needDestroy)323 int32_t NetLinkSocketDiag::ProcessSockDiagUidDumpResponse(uint8_t proto,
324     const NetLinkSocketDiag::DestroyFilter& needDestroy)
325 {
326     NETNATIVE_LOG_D("ProcessSockDiagUidDumpResponse");
327     char buf[KERNEL_BUFFER_SIZE] = {0};
328     ssize_t readBytes = read(dumpSock_, buf, sizeof(buf));
329     if (readBytes < 0) {
330         return NETMANAGER_ERR_INTERNAL;
331     }
332     while (readBytes > 0) {
333         int len = readBytes;
334         for (nlmsghdr *nlh = reinterpret_cast<nlmsghdr *>(buf); NLMSG_OK(nlh, len); nlh = NLMSG_NEXT(nlh, len)) {
335             if (nlh->nlmsg_type == NLMSG_ERROR) {
336                 nlmsgerr *err = reinterpret_cast<nlmsgerr *>(NLMSG_DATA(nlh));
337                 NETNATIVE_LOGE("Error netlink msg, errno:%{public}d, strerror:%{public}s", -err->error,
338                     strerror(-err->error));
339                 return err->error;
340             } else if (nlh->nlmsg_type == NLMSG_DONE) {
341                 NETNATIVE_LOGE("ProcessSockDiagUidDumpResponse nlh->nlmsg_type == NLMSG_DONE");
342                 return NETMANAGER_SUCCESS;
343             } else {
344                 const auto *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
345                 SockDiagUidDumpCallback(proto, msg, needDestroy);
346             }
347         }
348         readBytes = read(dumpSock_, buf, sizeof(buf));
349         if (readBytes < 0) {
350             NETNATIVE_LOGE("ProcessSockDiagUidDumpResponse readBytes < 0");
351             return -errno;
352         }
353     }
354     return NETMANAGER_SUCCESS;
355 }
356 
DestroyLiveSocketsWithUid(const std::string & ipAddr,uint32_t uid)357 void NetLinkSocketDiag::DestroyLiveSocketsWithUid(const std::string &ipAddr, uint32_t uid)
358 {
359     NETNATIVE_LOG_D("TCP-RST DestroyLiveSocketsWithUid, uid:%{public}d", uid);
360     if (!CreateNetlinkSocket()) {
361         NETNATIVE_LOGE("Create netlink diag socket failed.");
362         return;
363     }
364     auto needDestroy = [&] (const inet_diag_msg *msg) {
365         return msg != nullptr && uid == msg->idiag_uid && IsMatchNetwork(msg, ipAddr) && !IsLoopbackSocket(msg);
366     };
367     const int32_t proto = IPPROTO_TCP;
368     const uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
369     for (const int family : {AF_INET, AF_INET6}) {
370         int32_t ret = SendSockDiagDumpRequest(proto, family, states);
371         if (ret != NETMANAGER_SUCCESS) {
372             NETNATIVE_LOGE("Failed to dump %{public}s sockets", family == AF_INET ? "IPv4" : "IPv6");
373             break;
374         }
375         ret = ProcessSockDiagUidDumpResponse(proto, needDestroy);
376         if (ret != NETMANAGER_SUCCESS) {
377             NETNATIVE_LOGE("Failed to destroy %{public}s sockets", family == AF_INET ? "IPv4" : "IPv6");
378             break;
379         }
380     }
381 
382     NETNATIVE_LOG_D("TCP-RST Destroyed %{public}d sockets", socketsDestroyed_);
383 }
384 
DestroyLiveSocketsWithUid(uint32_t uid)385 void NetLinkSocketDiag::DestroyLiveSocketsWithUid(uint32_t uid)
386 {
387     NETNATIVE_LOG_D("TCP-RST DestroyLiveSocketsWithUid, uid:%{public}d", uid);
388     if (!CreateNetlinkSocket()) {
389         NETNATIVE_LOGE("Create netlink diag socket failed.");
390         return;
391     }
392     auto needDestroy = [&] (const inet_diag_msg *msg) -> bool {
393         return msg != nullptr && uid == msg->idiag_uid && !IsLoopbackSocket(msg);
394     };
395     const int32_t proto = IPPROTO_TCP;
396     const uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
397     for (const int family : {AF_INET, AF_INET6}) {
398         int32_t ret = SendSockDiagDumpRequest(proto, family, states);
399         if (ret != NETMANAGER_SUCCESS) {
400             NETNATIVE_LOGE("Failed to dump %{public}s sockets", family == AF_INET ? "IPv4" : "IPv6");
401             break;
402         }
403         ret = ProcessSockDiagUidDumpResponse(proto, needDestroy);
404         if (ret != NETMANAGER_SUCCESS) {
405             NETNATIVE_LOGE("Failed to destroy %{public}s sockets", family == AF_INET ? "IPv4" : "IPv6");
406             break;
407         }
408     }
409 
410     NETNATIVE_LOG_D("TCP-RST Destroyed %{public}d sockets", socketsDestroyed_);
411 }
412 } // namespace nmd
413 } // namespace OHOS