• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "netlink_socket_diag.h"
17 
18 #include <arpa/inet.h>
19 #include <cstring>
20 #include <net/if.h>
21 #include <netinet/tcp.h>
22 #include <sys/uio.h>
23 #include <unistd.h>
24 
25 #include "fwmark.h"
26 #include "net_manager_constants.h"
27 #include "netmanager_base_common_utils.h"
28 #include "netnative_log_wrapper.h"
29 #include "securec.h"
30 
31 namespace OHOS {
32 namespace nmd {
33 using namespace NetManagerStandard;
34 
namespace {
// Size of the buffer used for each read() of a sock_diag dump reply.
constexpr uint32_t KERNEL_BUFFER_SIZE = 8192U;
// Index of the 32-bit word of an in6_addr holding the IPv4 part of an
// IPv4-mapped IPv6 address (::ffff:a.b.c.d).
constexpr uint8_t ADDR_POSITION = 3U;
// Capacity of the text buffers passed to inet_ntop().
constexpr int32_t DOMAIN_IP_ADDR_MAX_LEN = 128;
// Mask/value pair that identifies 127.0.0.0/8 in an IPv4 address held in host byte order.
// NOTE(review): "LOCKBACK" is a misspelling of "loopback"; the names are kept because other code uses them.
constexpr uint32_t LOCKBACK_MASK = 0xFF000000U;
constexpr uint32_t LOCKBACK_DEFINE = 0x7F000000U;
// UID whose sockets are singled out by the cellular socket-destroy policies.
constexpr uid_t PUSH_UID = 7023;
} // namespace
43 
~NetLinkSocketDiag()44 NetLinkSocketDiag::~NetLinkSocketDiag()
45 {
46     CloseNetlinkSocket();
47 }
48 
InLookBack(uint32_t a)49 bool NetLinkSocketDiag::InLookBack(uint32_t a)
50 {
51     return (a & LOCKBACK_MASK) == LOCKBACK_DEFINE;
52 }
53 
CreateNetlinkSocket()54 bool NetLinkSocketDiag::CreateNetlinkSocket()
55 {
56     dumpSock_ = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
57     if (dumpSock_ < 0) {
58         NETNATIVE_LOGE("Create netlink socket for dump failed, error[%{public}d]: %{public}s", errno, strerror(errno));
59         return false;
60     }
61 
62     destroySock_ = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
63     if (destroySock_ < 0) {
64         NETNATIVE_LOGE("Create netlink socket for destroy failed, error[%{public}d]: %{public}s", errno,
65                        strerror(errno));
66         close(dumpSock_);
67         return false;
68     }
69 
70     sockaddr_nl nl = {.nl_family = AF_NETLINK};
71     if ((connect(dumpSock_, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) < 0) ||
72         (connect(destroySock_, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) < 0)) {
73         NETNATIVE_LOGE("Connect to netlink socket failed, error[%{public}d]: %{public}s", errno, strerror(errno));
74         CloseNetlinkSocket();
75         return false;
76     }
77     return true;
78 }
79 
CloseNetlinkSocket()80 void NetLinkSocketDiag::CloseNetlinkSocket()
81 {
82     close(dumpSock_);
83     close(destroySock_);
84     dumpSock_ = -1;
85     destroySock_ = -1;
86 }
87 
ExecuteDestroySocket(uint8_t proto,const inet_diag_msg * msg)88 int32_t NetLinkSocketDiag::ExecuteDestroySocket(uint8_t proto, const inet_diag_msg *msg)
89 {
90     if (msg == nullptr) {
91         NETNATIVE_LOGE("inet_diag_msg is nullptr");
92         return NETMANAGER_ERR_LOCAL_PTR_NULL;
93     }
94 
95     SockDiagRequest request;
96     request.nlh_.nlmsg_type = SOCK_DESTROY;
97     request.nlh_.nlmsg_flags = NLM_F_REQUEST;
98     request.nlh_.nlmsg_len = sizeof(request);
99 
100     request.req_ = {.sdiag_family = msg->idiag_family,
101                     .sdiag_protocol = proto,
102                     .idiag_states = static_cast<uint32_t>(1 << msg->idiag_state),
103                     .id = msg->id};
104     ssize_t writeLen = write(destroySock_, &request, sizeof(request));
105     if (writeLen < static_cast<ssize_t>(sizeof(request))) {
106         NETNATIVE_LOGE("Write destroy request to socket failed errno[%{public}d]: strerror:%{public}s", errno,
107                        strerror(errno));
108         return NETMANAGER_ERR_INTERNAL;
109     }
110 
111     int32_t ret = GetErrorFromKernel(destroySock_);
112     if (ret == NETMANAGER_SUCCESS) {
113         socketsDestroyed_++;
114     }
115     return ret;
116 }
117 
GetErrorFromKernel(int32_t fd)118 int32_t NetLinkSocketDiag::GetErrorFromKernel(int32_t fd)
119 {
120     Ack ack;
121     ssize_t bytesread = recv(fd, &ack, sizeof(ack), MSG_DONTWAIT | MSG_PEEK);
122     if (bytesread < 0) {
123         return (errno == EAGAIN) ? NETMANAGER_SUCCESS : -errno;
124     }
125     if (bytesread == static_cast<ssize_t>(sizeof(ack)) && ack.hdr_.nlmsg_type == NLMSG_ERROR) {
126         recv(fd, &ack, sizeof(ack), 0);
127         NETNATIVE_LOGE("Receive NLMSG_ERROR:[%{public}d] from kernel", ack.err_.error);
128         return NETMANAGER_ERR_INTERNAL;
129     }
130     return NETMANAGER_SUCCESS;
131 }
132 
IsLoopbackSocket(const inet_diag_msg * msg)133 bool NetLinkSocketDiag::IsLoopbackSocket(const inet_diag_msg *msg)
134 {
135     if (msg->idiag_family == AF_INET) {
136         return InLookBack(htonl(msg->id.idiag_src[0])) || InLookBack(htonl(msg->id.idiag_dst[0]));
137     }
138 
139     if (msg->idiag_family == AF_INET6) {
140         const struct in6_addr *src = (const struct in6_addr *)&msg->id.idiag_src;
141         const struct in6_addr *dst = (const struct in6_addr *)&msg->id.idiag_dst;
142         return (IN6_IS_ADDR_V4MAPPED(src) && InLookBack(src->s6_addr32[ADDR_POSITION])) ||
143                (IN6_IS_ADDR_V4MAPPED(dst) && InLookBack(dst->s6_addr32[ADDR_POSITION])) || IN6_IS_ADDR_LOOPBACK(src) ||
144                IN6_IS_ADDR_LOOPBACK(dst);
145     }
146     return false;
147 }
148 
IsMatchNetwork(const inet_diag_msg * msg,const std::string & ipAddr)149 bool NetLinkSocketDiag::IsMatchNetwork(const inet_diag_msg *msg, const std::string &ipAddr)
150 {
151     if (msg->idiag_family == AF_INET) {
152         if (CommonUtils::GetAddrFamily(ipAddr) != AF_INET) {
153             return false;
154         }
155 
156         in_addr_t addr = inet_addr(ipAddr.c_str());
157         if (addr == msg->id.idiag_src[0] || addr == msg->id.idiag_dst[0]) {
158             return true;
159         }
160     }
161 
162     if (msg->idiag_family == AF_INET6) {
163         if (CommonUtils::GetAddrFamily(ipAddr) != AF_INET6) {
164             return false;
165         }
166 
167         char src[DOMAIN_IP_ADDR_MAX_LEN] = {0};
168         char dst[DOMAIN_IP_ADDR_MAX_LEN] = {0};
169         inet_ntop(AF_INET6, msg->id.idiag_src, src, sizeof(src));
170         inet_ntop(AF_INET6, msg->id.idiag_dst, dst, sizeof(dst));
171         if (src == ipAddr || dst == ipAddr) {
172             return true;
173         }
174     }
175     return false;
176 }
177 
ProcessSockDiagDumpResponse(uint8_t proto,const std::string & ipAddr,bool excludeLoopback)178 int32_t NetLinkSocketDiag::ProcessSockDiagDumpResponse(uint8_t proto, const std::string &ipAddr, bool excludeLoopback)
179 {
180     char buf[KERNEL_BUFFER_SIZE] = {0};
181     ssize_t readBytes = read(dumpSock_, buf, sizeof(buf));
182     if (readBytes < 0) {
183         NETNATIVE_LOGE("Failed to read socket, errno:%{public}d, strerror:%{public}s", errno, strerror(errno));
184         return NETMANAGER_ERR_INTERNAL;
185     }
186     while (readBytes > 0) {
187         uint32_t len = static_cast<uint32_t>(readBytes);
188         for (nlmsghdr *nlh = reinterpret_cast<nlmsghdr *>(buf); NLMSG_OK(nlh, len); nlh = NLMSG_NEXT(nlh, len)) {
189             if (nlh->nlmsg_type == NLMSG_ERROR) {
190                 nlmsgerr *err = reinterpret_cast<nlmsgerr *>(NLMSG_DATA(nlh));
191                 NETNATIVE_LOGE("Error netlink msg, errno:%{public}d, strerror:%{public}s", -err->error,
192                                strerror(-err->error));
193                 return err->error;
194             } else if (nlh->nlmsg_type == NLMSG_DONE) {
195                 return NETMANAGER_SUCCESS;
196             } else {
197                 const auto *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
198                 SockDiagDumpCallback(proto, msg, ipAddr, excludeLoopback);
199             }
200         }
201         readBytes = read(dumpSock_, buf, sizeof(buf));
202         if (readBytes < 0) {
203             return -errno;
204         }
205     }
206     return NETMANAGER_SUCCESS;
207 }
208 
SendSockDiagDumpRequest(uint8_t proto,uint8_t family,uint32_t states)209 int32_t NetLinkSocketDiag::SendSockDiagDumpRequest(uint8_t proto, uint8_t family, uint32_t states)
210 {
211     SockDiagRequest request;
212     size_t len = sizeof(request);
213     iovec iov;
214     iov.iov_base = &request;
215     iov.iov_len = len;
216     request.nlh_.nlmsg_type = SOCK_DIAG_BY_FAMILY;
217     request.nlh_.nlmsg_flags = (NLM_F_REQUEST | NLM_F_DUMP);
218     request.nlh_.nlmsg_len = len;
219 
220     request.req_ = {.sdiag_family = family, .sdiag_protocol = proto, .idiag_states = states};
221 
222     ssize_t writeLen = writev(dumpSock_, &iov, (sizeof(iov) / sizeof(iovec)));
223     if (writeLen != static_cast<ssize_t>(len)) {
224         NETNATIVE_LOGE("Write dump request failed errno:%{public}d, strerror:%{public}s", errno, strerror(errno));
225         return NETMANAGER_ERR_INTERNAL;
226     }
227 
228     return GetErrorFromKernel(dumpSock_);
229 }
230 
SockDiagDumpCallback(uint8_t proto,const inet_diag_msg * msg,const std::string & ipAddr,bool excludeLoopback)231 void NetLinkSocketDiag::SockDiagDumpCallback(uint8_t proto, const inet_diag_msg *msg, const std::string &ipAddr,
232                                              bool excludeLoopback)
233 {
234     if (msg == nullptr) {
235         NETNATIVE_LOGE("msg is nullptr");
236         return;
237     }
238 
239     if (socketDestroyType_ == SocketDestroyType::DESTROY_SPECIAL_CELLULAR && msg->idiag_uid != PUSH_UID) {
240         return;
241     }
242 
243     if (socketDestroyType_ == SocketDestroyType::DESTROY_DEFAULT_CELLULAR && msg->idiag_uid == PUSH_UID) {
244         return;
245     }
246 
247     if (excludeLoopback && IsLoopbackSocket(msg)) {
248         NETNATIVE_LOGE("Loop back socket, no need to close.");
249         return;
250     }
251 
252     if (!IsMatchNetwork(msg, ipAddr)) {
253         NETNATIVE_LOG_D("Socket is not associated with the network");
254         return;
255     }
256 
257     ExecuteDestroySocket(proto, msg);
258 }
259 
DestroyLiveSockets(const char * ipAddr,bool excludeLoopback)260 void NetLinkSocketDiag::DestroyLiveSockets(const char *ipAddr, bool excludeLoopback)
261 {
262     NETNATIVE_LOG_D("DestroySocketsLackingNetwork in");
263     if (ipAddr == nullptr) {
264         NETNATIVE_LOGE("Ip address is nullptr.");
265         return;
266     }
267 
268     if (!CreateNetlinkSocket()) {
269         NETNATIVE_LOGE("Create netlink diag socket failed.");
270         return;
271     }
272 
273     const int32_t proto = IPPROTO_TCP;
274     const uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
275 
276     for (const int family : {AF_INET, AF_INET6}) {
277         int32_t ret = SendSockDiagDumpRequest(proto, family, states);
278         if (ret != NETMANAGER_SUCCESS) {
279             NETNATIVE_LOGE("Failed to dump %{public}s sockets", family == AF_INET ? "IPv4" : "IPv6");
280             break;
281         }
282         ret = ProcessSockDiagDumpResponse(proto, ipAddr, excludeLoopback);
283         if (ret != NETMANAGER_SUCCESS) {
284             NETNATIVE_LOGE("Failed to destroy %{public}s sockets", family == AF_INET ? "IPv4" : "IPv6");
285             break;
286         }
287     }
288 
289     NETNATIVE_LOGI("Destroyed %{public}d sockets", socketsDestroyed_);
290 }
291 
SetSocketDestroyType(const std::string & netCapabilities)292 int32_t NetLinkSocketDiag::SetSocketDestroyType(const std::string &netCapabilities)
293 {
294     const std::string capSpecialCellularStr = "NET_CAPABILITY_INTERNAL_DEFAULT";
295     const std::string bearerCellularStr = "BEARER_CELLULAR";
296     if (netCapabilities.find(capSpecialCellularStr) != std::string::npos) {
297         socketDestroyType_ = SocketDestroyType::DESTROY_SPECIAL_CELLULAR;
298     } else if (netCapabilities.find(bearerCellularStr) != std::string::npos) {
299         socketDestroyType_ = SocketDestroyType::DESTROY_DEFAULT_CELLULAR;
300     } else {
301         socketDestroyType_ = SocketDestroyType::DESTROY_DEFAULT;
302     }
303     return 0;
304 }
305 
SockDiagUidDumpCallback(uint8_t proto,const inet_diag_msg * msg,const NetLinkSocketDiag::DestroyFilter & needDestroy)306 void NetLinkSocketDiag::SockDiagUidDumpCallback(uint8_t proto, const inet_diag_msg *msg,
307     const NetLinkSocketDiag::DestroyFilter& needDestroy)
308 {
309     NETNATIVE_LOG_D(" SockDiagUidDumpCallback");
310     if (!needDestroy(msg)) {
311         return;
312     }
313 
314     ExecuteDestroySocket(proto, msg);
315 }
316 
ProcessSockDiagUidDumpResponse(uint8_t proto,const NetLinkSocketDiag::DestroyFilter & needDestroy)317 int32_t NetLinkSocketDiag::ProcessSockDiagUidDumpResponse(uint8_t proto,
318     const NetLinkSocketDiag::DestroyFilter& needDestroy)
319 {
320     NETNATIVE_LOG_D("ProcessSockDiagUidDumpResponse");
321     char buf[KERNEL_BUFFER_SIZE] = {0};
322     ssize_t readBytes = read(dumpSock_, buf, sizeof(buf));
323     if (readBytes < 0) {
324         return NETMANAGER_ERR_INTERNAL;
325     }
326     while (readBytes > 0) {
327         int len = readBytes;
328         for (nlmsghdr *nlh = reinterpret_cast<nlmsghdr *>(buf); NLMSG_OK(nlh, len); nlh = NLMSG_NEXT(nlh, len)) {
329             if (nlh->nlmsg_type == NLMSG_ERROR) {
330                 nlmsgerr *err = reinterpret_cast<nlmsgerr *>(NLMSG_DATA(nlh));
331                 NETNATIVE_LOGE("Error netlink msg, errno:%{public}d, strerror:%{public}s", -err->error,
332                     strerror(-err->error));
333                 return err->error;
334             } else if (nlh->nlmsg_type == NLMSG_DONE) {
335                 NETNATIVE_LOGE("ProcessSockDiagUidDumpResponse nlh->nlmsg_type == NLMSG_DONE");
336                 return NETMANAGER_SUCCESS;
337             } else {
338                 const auto *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
339                 SockDiagUidDumpCallback(proto, msg, needDestroy);
340             }
341         }
342         readBytes = read(dumpSock_, buf, sizeof(buf));
343         if (readBytes < 0) {
344             NETNATIVE_LOGE("ProcessSockDiagUidDumpResponse readBytes < 0");
345             return -errno;
346         }
347     }
348     return NETMANAGER_SUCCESS;
349 }
350 
DestroyLiveSocketsWithUid(const std::string & ipAddr,uint32_t uid)351 void NetLinkSocketDiag::DestroyLiveSocketsWithUid(const std::string &ipAddr, uint32_t uid)
352 {
353     NETNATIVE_LOG_D("TCP-RST DestroyLiveSocketsWithUid, uid:%{public}d", uid);
354     if (!CreateNetlinkSocket()) {
355         NETNATIVE_LOGE("Create netlink diag socket failed.");
356         return;
357     }
358     auto needDestroy = [&] (const inet_diag_msg *msg) {
359         return msg != nullptr && uid == msg->idiag_uid && IsMatchNetwork(msg, ipAddr) && !IsLoopbackSocket(msg);
360     };
361     const int32_t proto = IPPROTO_TCP;
362     const uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
363     for (const int family : {AF_INET, AF_INET6}) {
364         int32_t ret = SendSockDiagDumpRequest(proto, family, states);
365         if (ret != NETMANAGER_SUCCESS) {
366             NETNATIVE_LOGE("Failed to dump %{public}s sockets", family == AF_INET ? "IPv4" : "IPv6");
367             break;
368         }
369         ret = ProcessSockDiagUidDumpResponse(proto, needDestroy);
370         if (ret != NETMANAGER_SUCCESS) {
371             NETNATIVE_LOGE("Failed to destroy %{public}s sockets", family == AF_INET ? "IPv4" : "IPv6");
372             break;
373         }
374     }
375 
376     NETNATIVE_LOG_D("TCP-RST Destroyed %{public}d sockets", socketsDestroyed_);
377 }
378 
DestroyLiveSocketsWithUid(uint32_t uid)379 void NetLinkSocketDiag::DestroyLiveSocketsWithUid(uint32_t uid)
380 {
381     NETNATIVE_LOG_D("TCP-RST DestroyLiveSocketsWithUid, uid:%{public}d", uid);
382     if (!CreateNetlinkSocket()) {
383         NETNATIVE_LOGE("Create netlink diag socket failed.");
384         return;
385     }
386     auto needDestroy = [&] (const inet_diag_msg *msg) -> bool {
387         return msg != nullptr && uid == msg->idiag_uid && !IsLoopbackSocket(msg);
388     };
389     const int32_t proto = IPPROTO_TCP;
390     const uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
391     for (const int family : {AF_INET, AF_INET6}) {
392         int32_t ret = SendSockDiagDumpRequest(proto, family, states);
393         if (ret != NETMANAGER_SUCCESS) {
394             NETNATIVE_LOGE("Failed to dump %{public}s sockets", family == AF_INET ? "IPv4" : "IPv6");
395             break;
396         }
397         ret = ProcessSockDiagUidDumpResponse(proto, needDestroy);
398         if (ret != NETMANAGER_SUCCESS) {
399             NETNATIVE_LOGE("Failed to destroy %{public}s sockets", family == AF_INET ? "IPv4" : "IPv6");
400             break;
401         }
402     }
403 
404     NETNATIVE_LOG_D("TCP-RST Destroyed %{public}d sockets", socketsDestroyed_);
405 }
406 } // namespace nmd
407 } // namespace OHOS