• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "dhcp_arp_checker.h"
17 
18 #include <cerrno>
19 #include <chrono>
20 #include <fcntl.h>
21 #include <net/if_arp.h>
22 #include <net/if.h>
23 #include <netpacket/packet.h>
24 #include <poll.h>
25 #include <sys/socket.h>
26 #include <unistd.h>
27 
28 #include "securec.h"
29 #include "dhcp_common_utils.h"
30 #include "dhcp_logger.h"
31 
32 namespace OHOS {
33 namespace DHCP {
34 DEFINE_DHCPLOG_DHCP_LABEL("DhcpArpChecker");
35 constexpr const char *DHCP_ARP_CHECKER_THREAD = "DHCP_ARP_CHECKER_THREAD";
36 constexpr int32_t MIN_WAIT_TIME_MS_THREAD = 2000;
37 constexpr int32_t MAX_LENGTH = 1500;
38 constexpr int32_t OPT_SUCC = 0;
39 constexpr int32_t OPT_FAIL = -1;
40 
DhcpArpChecker()41 DhcpArpChecker::DhcpArpChecker() : m_isSocketCreated(false), m_socketFd(-1), m_ifaceIndex(0), m_protocol(0)
42 {
43     DHCP_LOGI("DhcpArpChecker()");
44     dhcpArpCheckerThread_ = std::make_unique<DhcpThread>(DHCP_ARP_CHECKER_THREAD);
45 }
46 
~DhcpArpChecker()47 DhcpArpChecker::~DhcpArpChecker()
48 {
49     DHCP_LOGI("~DhcpArpChecker()");
50     Stop();
51     if (dhcpArpCheckerThread_) {
52         dhcpArpCheckerThread_.reset();
53     }
54 }
55 
Start(std::string & ifname,std::string & hwAddr,std::string & senderIp,std::string & targetIp)56 bool DhcpArpChecker::Start(std::string& ifname, std::string& hwAddr, std::string& senderIp, std::string& targetIp)
57 {
58     if (m_isSocketCreated) {
59         Stop();
60     }
61     uint8_t mac[ETH_ALEN + sizeof(uint32_t)];
62     if (sscanf_s(hwAddr.c_str(), "%02x:%02x:%02x:%02x:%02x:%02x",
63         &mac[0], &mac[1], &mac[2], &mac[3], &mac[4], &mac[5]) != ETH_ALEN) {  // mac address
64         DHCP_LOGE("invalid hwAddr:%{private}s", hwAddr.c_str());
65         if (memset_s(mac, sizeof(mac), 0, sizeof(mac)) != EOK) {
66             DHCP_LOGE("ArpChecker memset fail");
67         }
68     }
69     auto func = [this, ifname]() {
70         return this->CreateSocket(ifname.c_str(), ETH_P_ARP);
71     };
72     auto ret = dhcpArpCheckerThread_->PostSyncTimeOutTask(func, MIN_WAIT_TIME_MS_THREAD);
73     if (ret == false) {
74         DHCP_LOGE("DhcpArpChecker CreateSocket failed");
75         return false;
76     }
77     inet_aton(senderIp.c_str(), &m_localIpAddr);
78     if (memcpy_s(m_localMacAddr, ETH_ALEN, mac, ETH_ALEN) != EOK) {
79         DHCP_LOGE("DhcpArpChecker memcpy fail");
80         return false;
81     }
82     if (memset_s(m_l2Broadcast, ETH_ALEN, 0xFF, ETH_ALEN) != EOK) {
83         DHCP_LOGE("DhcpArpChecker memset fail");
84         return false;
85     }
86     inet_aton(targetIp.c_str(), &m_targetIpAddr);
87     return true;
88 }
89 
Stop()90 void DhcpArpChecker::Stop()
91 {
92     if (!m_isSocketCreated) {
93         return;
94     }
95     auto func = [this]() {
96         return this->CloseSocket();
97     };
98     dhcpArpCheckerThread_->PostSyncTimeOutTask(func, MIN_WAIT_TIME_MS_THREAD);
99     m_isSocketCreated = false;
100 }
101 
SetArpPacket(ArpPacket & arpPacket,bool isFillSenderIp)102 bool DhcpArpChecker::SetArpPacket(ArpPacket& arpPacket, bool isFillSenderIp)
103 {
104     arpPacket.ar_hrd = htons(ARPHRD_ETHER);
105     arpPacket.ar_pro = htons(ETH_P_IP);
106     arpPacket.ar_hln = ETH_ALEN;
107     arpPacket.ar_pln = IPV4_ALEN;
108     arpPacket.ar_op = htons(ARPOP_REQUEST);
109     if (memcpy_s(arpPacket.ar_sha, ETH_ALEN, m_localMacAddr, ETH_ALEN) != EOK) {
110         DHCP_LOGE("DoArpCheck memcpy fail");
111         return false;
112     }
113     if (isFillSenderIp) {
114         if (memcpy_s(arpPacket.ar_spa, IPV4_ALEN, &m_localIpAddr, sizeof(m_localIpAddr)) != EOK) {
115             DHCP_LOGE("DoArpCheck memcpy fail");
116             return false;
117         }
118     } else {
119         if (memset_s(arpPacket.ar_spa, IPV4_ALEN, 0, IPV4_ALEN) != EOK) {
120             DHCP_LOGE("DoArpCheck memset fail");
121             return false;
122         }
123     }
124     if (memset_s(arpPacket.ar_tha, ETH_ALEN, 0, ETH_ALEN) != EOK) {
125         DHCP_LOGE("DoArpCheck memset fail");
126         return false;
127     }
128     if (memcpy_s(arpPacket.ar_tpa, IPV4_ALEN, &m_targetIpAddr, sizeof(m_targetIpAddr)) != EOK) {
129         DHCP_LOGE("DoArpCheck memcpy fail");
130         return false;
131     }
132     return true;
133 }
134 
DoArpCheck(int32_t timeoutMillis,bool isFillSenderIp,uint64_t & timeCost)135 bool DhcpArpChecker::DoArpCheck(int32_t timeoutMillis, bool isFillSenderIp, uint64_t &timeCost)
136 {
137     if (!m_isSocketCreated) {
138         DHCP_LOGE("DoArpCheck failed, socket not created");
139         return false;
140     }
141 
142     struct ArpPacket arpPacket;
143     if (!SetArpPacket(arpPacket, isFillSenderIp)) {
144         return false;
145     }
146 
147     if (SendData(reinterpret_cast<uint8_t *>(&arpPacket), sizeof(arpPacket), m_l2Broadcast) != 0) {
148         return false;
149     }
150 
151     timeCost = 0;
152     int32_t leftMillis = timeoutMillis;
153     uint8_t recvBuff[MAX_LENGTH];
154 
155     // Add overall timeout tracking to prevent infinite loop
156     std::chrono::steady_clock::time_point overallStartTime = std::chrono::steady_clock::now();
157 
158     while (leftMillis > 0) {
159         std::chrono::steady_clock::time_point startTime = std::chrono::steady_clock::now();
160         int32_t readLen = RecvData(recvBuff, sizeof(recvBuff), leftMillis);
161 
162         // Always calculate elapsed time after each operation
163         std::chrono::steady_clock::time_point current = std::chrono::steady_clock::now();
164         int64_t elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(current - startTime).count();
165         if (elapsed <= 0) {
166             elapsed = 1;  // Force at least 1ms progress
167         }
168         leftMillis -= static_cast<int32_t>(elapsed);
169 
170         // Double check overall timeout to prevent any edge cases
171         int64_t overallElapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
172             current - overallStartTime).count();
173         if (overallElapsed >= timeoutMillis) {
174             DHCP_LOGW("DoArpCheck overall timeout reached");
175             break;
176         }
177 
178         if (readLen < 0) {
179             DHCP_LOGE("readLen < 0, stop arp");
180             return false;
181         }
182         if (readLen < static_cast<int32_t>(sizeof(struct ArpPacket))) {
183             continue;
184         }
185 
186         struct ArpPacket *respPacket = reinterpret_cast<struct ArpPacket*>(recvBuff);
187         if (ntohs(respPacket->ar_hrd) == ARPHRD_ETHER && ntohs(respPacket->ar_pro) == ETH_P_IP &&
188             respPacket->ar_hln == ETH_ALEN && respPacket->ar_pln == IPV4_ALEN &&
189             ntohs(respPacket->ar_op) == ARPOP_REPLY &&
190             memcmp(respPacket->ar_sha, m_localMacAddr, ETH_ALEN) != 0 &&
191             memcmp(respPacket->ar_spa, &m_targetIpAddr, IPV4_ALEN) == 0) {
192             timeCost = static_cast<uint64_t>(overallElapsed);
193             return true;
194         }
195     }
196     return false;
197 }
198 
GetGwMacAddrList(int32_t timeoutMillis,bool isFillSenderIp,std::vector<std::string> & gwMacLists)199 void DhcpArpChecker::GetGwMacAddrList(int32_t timeoutMillis, bool isFillSenderIp, std::vector<std::string>& gwMacLists)
200 {
201     gwMacLists.clear();
202     if (!m_isSocketCreated) {
203         DHCP_LOGE("GetGwMacAddrList failed, socket not created");
204         return;
205     }
206     struct ArpPacket arpPacket;
207     if (!SetArpPacket(arpPacket, isFillSenderIp)) {
208         DHCP_LOGE("GetGwMacAddrList SetArpPacket failed");
209         return;
210     }
211 
212     if (SendData(reinterpret_cast<uint8_t *>(&arpPacket), sizeof(arpPacket), m_l2Broadcast) != 0) {
213         DHCP_LOGE("GetGwMacAddrList SendData failed");
214         return;
215     }
216     int32_t readLen = 0;
217     int32_t leftMillis = timeoutMillis;
218     uint8_t recvBuff[MAX_LENGTH];
219 
220     // Add overall timeout tracking to prevent infinite loop
221     std::chrono::steady_clock::time_point overallStartTime = std::chrono::steady_clock::now();
222 
223     while (leftMillis > 0) {
224         std::chrono::steady_clock::time_point startTime = std::chrono::steady_clock::now();
225         readLen = RecvData(recvBuff, sizeof(recvBuff), leftMillis);
226         if (readLen >= static_cast<int32_t>(sizeof(struct ArpPacket))) {
227             struct ArpPacket *respPacket = reinterpret_cast<struct ArpPacket*>(recvBuff);
228             if (ntohs(respPacket->ar_hrd) == ARPHRD_ETHER &&
229                 ntohs(respPacket->ar_pro) == ETH_P_IP &&
230                 respPacket->ar_hln == ETH_ALEN &&
231                 respPacket->ar_pln == IPV4_ALEN &&
232                 ntohs(respPacket->ar_op) == ARPOP_REPLY &&
233                 memcmp(respPacket->ar_sha, m_localMacAddr, ETH_ALEN) != 0 &&
234                 memcmp(respPacket->ar_spa, &m_targetIpAddr, IPV4_ALEN) == 0) {
235                 std::string gwMacAddr = MacArray2Str(respPacket->ar_sha, ETH_ALEN);
236                 SaveGwMacAddr(gwMacAddr, gwMacLists);
237             }
238         }
239 
240         std::chrono::steady_clock::time_point current = std::chrono::steady_clock::now();
241         int64_t elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(current - startTime).count();
242 
243         // Ensure minimum progress to prevent infinite loop
244         if (elapsed <= 0) {
245             elapsed = 1;  // Force at least 1ms progress
246         }
247 
248         leftMillis -= static_cast<int32_t>(elapsed);
249 
250         // Double check overall timeout as safety net
251         int64_t overallElapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
252             current - overallStartTime).count();
253         if (overallElapsed >= timeoutMillis) {
254             DHCP_LOGW("GetGwMacAddrList overall timeout reached");
255             break;
256         }
257     }
258 }
259 
SaveGwMacAddr(std::string gwMacAddr,std::vector<std::string> & gwMacLists)260 void DhcpArpChecker::SaveGwMacAddr(std::string gwMacAddr, std::vector<std::string>& gwMacLists)
261 {
262     auto it = std::find(gwMacLists.begin(), gwMacLists.end(), gwMacAddr);
263     if (!gwMacAddr.empty() && (it == gwMacLists.end())) {
264         gwMacLists.push_back(gwMacAddr);
265     }
266 }
267 
CreateSocket(const char * iface,uint16_t protocol)268 int32_t DhcpArpChecker::CreateSocket(const char *iface, uint16_t protocol)
269 {
270     if (iface == nullptr) {
271         DHCP_LOGE("iface is null");
272         return OPT_FAIL;
273     }
274 
275     int32_t ifaceIndex = static_cast<int32_t>(if_nametoindex(iface));
276     if (ifaceIndex == 0) {
277         DHCP_LOGE("get iface index fail: %{public}s", iface);
278         return OPT_FAIL;
279     }
280     if (ifaceIndex > INTEGER_MAX) {
281         DHCP_LOGE("ifaceIndex > max interger, fail:%{public}s ifaceIndex:%{public}d", iface, ifaceIndex);
282         return OPT_FAIL;
283     }
284     int32_t socketFd = socket(PF_PACKET, SOCK_DGRAM, htons(protocol));
285     if (socketFd < 0) {
286         DHCP_LOGE("create socket fail");
287         return OPT_FAIL;
288     }
289 
290     if (SetNonBlock(socketFd)) {
291         DHCP_LOGE("set non block fail");
292         (void)close(socketFd);
293         return OPT_FAIL;
294     }
295 
296     struct sockaddr_ll rawAddr;
297     rawAddr.sll_ifindex = ifaceIndex;
298     rawAddr.sll_protocol = htons(protocol);
299     rawAddr.sll_family = AF_PACKET;
300 
301     int32_t ret = bind(socketFd, reinterpret_cast<struct sockaddr *>(&rawAddr), sizeof(rawAddr));
302     if (ret != 0) {
303         DHCP_LOGE("bind fail");
304         (void)close(socketFd);
305         return OPT_FAIL;
306     }
307 
308     m_socketFd = socketFd;
309     m_ifaceIndex = ifaceIndex;
310     m_protocol = protocol;
311     m_isSocketCreated = true;
312     return OPT_SUCC;
313 }
314 
SendData(uint8_t * buff,int32_t count,uint8_t * destHwaddr)315 int32_t DhcpArpChecker::SendData(uint8_t *buff, int32_t count, uint8_t *destHwaddr)
316 {
317     if (buff == nullptr || destHwaddr == nullptr) {
318         DHCP_LOGE("buff or dest hwaddr is null");
319         return OPT_FAIL;
320     }
321 
322     if (m_socketFd < 0 || m_ifaceIndex == 0) {
323         DHCP_LOGE("invalid socket fd");
324         return OPT_FAIL;
325     }
326 
327     struct sockaddr_ll rawAddr;
328     (void)memset_s(&rawAddr, sizeof(rawAddr), 0, sizeof(rawAddr));
329     rawAddr.sll_ifindex = m_ifaceIndex;
330     rawAddr.sll_protocol = htons(m_protocol);
331     rawAddr.sll_family = AF_PACKET;
332     if (memcpy_s(rawAddr.sll_addr, sizeof(rawAddr.sll_addr), destHwaddr, ETH_ALEN) != EOK) {
333         DHCP_LOGE("Send: memcpy fail");
334         return OPT_FAIL;
335     }
336 
337     int32_t ret;
338     do {
339         ret = sendto(m_socketFd, buff, count, 0, reinterpret_cast<struct sockaddr *>(&rawAddr), sizeof(rawAddr));
340         if (ret == -1) {
341             DHCP_LOGE("Send: sendto fail");
342             if (errno != EINTR) {
343                 break;
344             }
345         }
346     } while (ret == -1);
347     return ret > 0 ? OPT_SUCC : OPT_FAIL;
348 }
349 
RecvData(uint8_t * buff,int32_t count,int32_t timeoutMillis)350 int32_t DhcpArpChecker::RecvData(uint8_t *buff, int32_t count, int32_t timeoutMillis)
351 {
352     DHCP_LOGI("RecvData timeoutMillis:%{public}d", timeoutMillis);
353     if (m_socketFd < 0) {
354         DHCP_LOGE("invalid socket fd");
355         return -1;
356     }
357 
358     pollfd fds[1];
359     fds[0].fd = m_socketFd;
360     fds[0].events = POLLIN;
361     if (poll(fds, 1, timeoutMillis) <= 0) {
362         DHCP_LOGW("RecvData poll timeout or error");
363         return 0;
364     }
365     DHCP_LOGI("RecvData poll finished");
366     int32_t nBytes;
367     do {
368         nBytes = read(m_socketFd, buff, count);
369         if (nBytes == -1) {
370             if (errno != EINTR) {
371                 break;
372             }
373         }
374     } while (nBytes == -1);
375     return nBytes < 0 ? 0 : nBytes;
376 }
377 
CloseSocket(void)378 int32_t DhcpArpChecker::CloseSocket(void)
379 {
380     int32_t ret = OPT_FAIL;
381     if (m_socketFd >= 0) {
382         ret = close(m_socketFd);
383         if (ret != OPT_SUCC) {
384             DHCP_LOGE("close fail.");
385         }
386     }
387     m_socketFd = -1;
388     m_ifaceIndex = 0;
389     m_protocol = 0;
390     m_isSocketCreated = false;
391     return ret;
392 }
393 
SetNonBlock(int32_t fd)394 bool DhcpArpChecker::SetNonBlock(int32_t fd)
395 {
396     int32_t ret = fcntl(fd, F_GETFL);
397     if (ret < 0) {
398         return false;
399     }
400 
401     uint32_t flags = (static_cast<uint32_t>(ret) | O_NONBLOCK);
402     return fcntl(fd, F_SETFL, flags);
403 }
404 }
405 }