• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "netlink_socket_diag.h"
17 
18 #include <arpa/inet.h>
19 #include <cstring>
20 #include <net/if.h>
21 #include <netinet/tcp.h>
22 #include <sys/uio.h>
23 #include <unistd.h>
24 
25 #include "fwmark.h"
26 #include "net_manager_constants.h"
27 #include "netmanager_base_common_utils.h"
28 #include "netnative_log_wrapper.h"
29 #include "securec.h"
30 
31 namespace OHOS {
32 namespace nmd {
33 using namespace NetManagerStandard;
34 
namespace {
// Size of the receive buffer used when reading sock_diag dump replies.
constexpr uint32_t KERNEL_BUFFER_SIZE{8192U};
// Index of the 32-bit word holding the IPv4 part of an IPv4-mapped IPv6 address (::ffff:a.b.c.d).
constexpr uint8_t ADDR_POSITION{3U};
// Capacity of the text buffers filled by inet_ntop() for IPv6 addresses.
constexpr int32_t DOMAIN_IP_ADDR_MAX_LEN{128};
// Mask / expected value selecting the 127.0.0.0/8 loopback network on a host-order IPv4 address.
constexpr uint32_t LOCKBACK_MASK{0xff000000};
constexpr uint32_t LOCKBACK_DEFINE{0x7f000000};
} // namespace
42 
~NetLinkSocketDiag()43 NetLinkSocketDiag::~NetLinkSocketDiag()
44 {
45     CloseNetlinkSocket();
46 }
47 
InLookBack(uint32_t a)48 bool NetLinkSocketDiag::InLookBack(uint32_t a)
49 {
50     return (a & LOCKBACK_MASK) == LOCKBACK_DEFINE;
51 }
52 
CreateNetlinkSocket()53 bool NetLinkSocketDiag::CreateNetlinkSocket()
54 {
55     dumpSock_ = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
56     if (dumpSock_ < 0) {
57         NETNATIVE_LOGE("Create netlink socket for dump failed, error[%{public}d]: %{public}s", errno, strerror(errno));
58         return false;
59     }
60 
61     destroySock_ = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
62     if (destroySock_ < 0) {
63         NETNATIVE_LOGE("Create netlink socket for destroy failed, error[%{public}d]: %{public}s", errno,
64                        strerror(errno));
65         close(dumpSock_);
66         return false;
67     }
68 
69     sockaddr_nl nl = {.nl_family = AF_NETLINK};
70     if ((connect(dumpSock_, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) < 0) ||
71         (connect(destroySock_, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) < 0)) {
72         NETNATIVE_LOGE("Connect to netlink socket failed, error[%{public}d]: %{public}s", errno, strerror(errno));
73         CloseNetlinkSocket();
74         return false;
75     }
76     return true;
77 }
78 
CloseNetlinkSocket()79 void NetLinkSocketDiag::CloseNetlinkSocket()
80 {
81     close(dumpSock_);
82     close(destroySock_);
83     dumpSock_ = -1;
84     destroySock_ = -1;
85 }
86 
ExecuteDestroySocket(uint8_t proto,const inet_diag_msg * msg)87 int32_t NetLinkSocketDiag::ExecuteDestroySocket(uint8_t proto, const inet_diag_msg *msg)
88 {
89     if (msg == nullptr) {
90         NETNATIVE_LOGE("inet_diag_msg is nullptr");
91         return NETMANAGER_ERR_LOCAL_PTR_NULL;
92     }
93 
94     SockDiagRequest request;
95     request.nlh_.nlmsg_type = SOCK_DESTROY;
96     request.nlh_.nlmsg_flags = NLM_F_REQUEST;
97     request.nlh_.nlmsg_len = sizeof(request);
98 
99     request.req_ = {.sdiag_family = msg->idiag_family,
100                     .sdiag_protocol = proto,
101                     .idiag_states = static_cast<uint32_t>(1 << msg->idiag_state),
102                     .id = msg->id};
103     ssize_t writeLen = write(destroySock_, &request, sizeof(request));
104     if (writeLen < static_cast<ssize_t>(sizeof(request))) {
105         NETNATIVE_LOGE("Write destroy request to socket failed errno[%{public}d]: strerror:%{public}s", errno,
106                        strerror(errno));
107         return NETMANAGER_ERR_INTERNAL;
108     }
109 
110     int32_t ret = GetErrorFromKernel(destroySock_);
111     if (ret == NETMANAGER_SUCCESS) {
112         socketsDestroyed_++;
113     }
114     return ret;
115 }
116 
GetErrorFromKernel(int32_t fd)117 int32_t NetLinkSocketDiag::GetErrorFromKernel(int32_t fd)
118 {
119     Ack ack;
120     ssize_t bytesread = recv(fd, &ack, sizeof(ack), MSG_DONTWAIT | MSG_PEEK);
121     if (bytesread < 0) {
122         NETNATIVE_LOGE("Get error info from kernel failed errno[%{public}d]: strerror:%{public}s", errno,
123                        strerror(errno));
124         return (errno == EAGAIN) ? NETMANAGER_SUCCESS : -errno;
125     }
126     if (bytesread == static_cast<ssize_t>(sizeof(ack)) && ack.hdr_.nlmsg_type == NLMSG_ERROR) {
127         recv(fd, &ack, sizeof(ack), 0);
128         NETNATIVE_LOGE("Receive NLMSG_ERROR:[%{public}d] from kernel", ack.err_.error);
129         return NETMANAGER_ERR_INTERNAL;
130     }
131     return NETMANAGER_SUCCESS;
132 }
133 
IsLoopbackSocket(const inet_diag_msg * msg)134 bool NetLinkSocketDiag::IsLoopbackSocket(const inet_diag_msg *msg)
135 {
136     if (msg->idiag_family == AF_INET) {
137         return InLookBack(htonl(msg->id.idiag_src[0])) || InLookBack(htonl(msg->id.idiag_dst[0]));
138     }
139 
140     if (msg->idiag_family == AF_INET6) {
141         const struct in6_addr *src = (const struct in6_addr *)&msg->id.idiag_src;
142         const struct in6_addr *dst = (const struct in6_addr *)&msg->id.idiag_dst;
143         return (IN6_IS_ADDR_V4MAPPED(src) && InLookBack(src->s6_addr32[ADDR_POSITION])) ||
144                (IN6_IS_ADDR_V4MAPPED(dst) && InLookBack(dst->s6_addr32[ADDR_POSITION])) || IN6_IS_ADDR_LOOPBACK(src) ||
145                IN6_IS_ADDR_LOOPBACK(dst);
146     }
147     return false;
148 }
149 
IsMatchNetwork(const inet_diag_msg * msg,const std::string & ipAddr)150 bool NetLinkSocketDiag::IsMatchNetwork(const inet_diag_msg *msg, const std::string &ipAddr)
151 {
152     if (msg->idiag_family == AF_INET) {
153         if (CommonUtils::GetAddrFamily(ipAddr) != AF_INET) {
154             return false;
155         }
156 
157         in_addr_t addr = inet_addr(ipAddr.c_str());
158         if (addr == msg->id.idiag_src[0] || addr == msg->id.idiag_dst[0]) {
159             return true;
160         }
161     }
162 
163     if (msg->idiag_family == AF_INET6) {
164         if (CommonUtils::GetAddrFamily(ipAddr) != AF_INET6) {
165             return false;
166         }
167 
168         char src[DOMAIN_IP_ADDR_MAX_LEN] = {0};
169         char dst[DOMAIN_IP_ADDR_MAX_LEN] = {0};
170         inet_ntop(AF_INET6, msg->id.idiag_src, src, sizeof(src));
171         inet_ntop(AF_INET6, msg->id.idiag_dst, dst, sizeof(dst));
172         if (src == ipAddr || dst == ipAddr) {
173             return true;
174         }
175     }
176     return false;
177 }
178 
ProcessSockDiagDumpResponse(uint8_t proto,const std::string & ipAddr,bool excludeLoopback)179 int32_t NetLinkSocketDiag::ProcessSockDiagDumpResponse(uint8_t proto, const std::string &ipAddr, bool excludeLoopback)
180 {
181     char buf[KERNEL_BUFFER_SIZE] = {0};
182     ssize_t readBytes = read(dumpSock_, buf, sizeof(buf));
183     if (readBytes < 0) {
184         NETNATIVE_LOGE("Failed to read socket, errno:%{public}d, strerror:%{public}s", errno, strerror(errno));
185         return NETMANAGER_ERR_INTERNAL;
186     }
187     while (readBytes > 0) {
188         uint32_t len = readBytes;
189         for (nlmsghdr *nlh = reinterpret_cast<nlmsghdr *>(buf); NLMSG_OK(nlh, len); nlh = NLMSG_NEXT(nlh, len)) {
190             if (nlh->nlmsg_type == NLMSG_ERROR) {
191                 nlmsgerr *err = reinterpret_cast<nlmsgerr *>(NLMSG_DATA(nlh));
192                 NETNATIVE_LOGE("Error netlink msg, errno:%{public}d, strerror:%{public}s", -err->error,
193                                strerror(-err->error));
194                 return err->error;
195             } else if (nlh->nlmsg_type == NLMSG_DONE) {
196                 return NETMANAGER_SUCCESS;
197             } else {
198                 const auto *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
199                 SockDiagDumpCallback(proto, msg, ipAddr, excludeLoopback);
200             }
201         }
202         readBytes = read(dumpSock_, buf, sizeof(buf));
203         if (readBytes < 0) {
204             return -errno;
205         }
206     }
207     return NETMANAGER_SUCCESS;
208 }
209 
SendSockDiagDumpRequest(uint8_t proto,uint8_t family,uint32_t states)210 int32_t NetLinkSocketDiag::SendSockDiagDumpRequest(uint8_t proto, uint8_t family, uint32_t states)
211 {
212     SockDiagRequest request;
213     size_t len = sizeof(request);
214     iovec iov;
215     iov.iov_base = &request;
216     iov.iov_len = len;
217     request.nlh_.nlmsg_type = SOCK_DIAG_BY_FAMILY;
218     request.nlh_.nlmsg_flags = (NLM_F_REQUEST | NLM_F_DUMP);
219     request.nlh_.nlmsg_len = len;
220 
221     request.req_ = {.sdiag_family = family, .sdiag_protocol = proto, .idiag_states = states};
222 
223     ssize_t writeLen = writev(dumpSock_, &iov, (sizeof(iov) / sizeof(iovec)));
224     if (writeLen != static_cast<ssize_t>(len)) {
225         NETNATIVE_LOGE("Write dump request failed errno:%{public}d, strerror:%{public}s", errno, strerror(errno));
226         return NETMANAGER_ERR_INTERNAL;
227     }
228 
229     return GetErrorFromKernel(dumpSock_);
230 }
231 
SockDiagDumpCallback(uint8_t proto,const inet_diag_msg * msg,const std::string & ipAddr,bool excludeLoopback)232 void NetLinkSocketDiag::SockDiagDumpCallback(uint8_t proto, const inet_diag_msg *msg, const std::string &ipAddr,
233                                              bool excludeLoopback)
234 {
235     if (msg == nullptr) {
236         NETNATIVE_LOGE("msg is nullptr");
237         return;
238     }
239 
240     if (excludeLoopback && IsLoopbackSocket(msg)) {
241         NETNATIVE_LOGE("Loop back socket, no need to close.");
242         return;
243     }
244 
245     if (!IsMatchNetwork(msg, ipAddr)) {
246         NETNATIVE_LOGE("Socket is not associated with the network");
247         return;
248     }
249 
250     ExecuteDestroySocket(proto, msg);
251 }
252 
DestroyLiveSockets(const char * ipAddr,bool excludeLoopback)253 void NetLinkSocketDiag::DestroyLiveSockets(const char *ipAddr, bool excludeLoopback)
254 {
255     NETNATIVE_LOG_D("DestroySocketsLackingNetwork in");
256     if (ipAddr == nullptr) {
257         NETNATIVE_LOGE("Ip address is nullptr.");
258         return;
259     }
260 
261     if (!CreateNetlinkSocket()) {
262         NETNATIVE_LOGE("Create netlink diag socket failed.");
263         return;
264     }
265 
266     const int32_t proto = IPPROTO_TCP;
267     const uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
268 
269     for (const int family : {AF_INET, AF_INET6}) {
270         int32_t ret = SendSockDiagDumpRequest(proto, family, states);
271         if (ret != NETMANAGER_SUCCESS) {
272             NETNATIVE_LOGE("Failed to dump %{public}s sockets", family == AF_INET ? "IPv4" : "IPv6");
273             break;
274         }
275         ret = ProcessSockDiagDumpResponse(proto, ipAddr, excludeLoopback);
276         if (ret != NETMANAGER_SUCCESS) {
277             NETNATIVE_LOGE("Failed to destroy %{public}s sockets", family == AF_INET ? "IPv4" : "IPv6");
278             break;
279         }
280     }
281 
282     NETNATIVE_LOG_D("Destroyed %{public}d sockets", socketsDestroyed_);
283 }
284 } // namespace nmd
285 } // namespace OHOS