• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 #include <algorithm>
16 #include "netmanager_base_common_utils.h"
17 #include "dns_param_cache.h"
18 #include "netnative_log_wrapper.h"
19 
20 #ifdef FEATURE_NET_FIREWALL_ENABLE
21 #include "bpf_netfirewall.h"
22 #include "netfirewall_parcel.h"
23 #include <ctime>
24 #endif
25 
26 namespace OHOS::nmd {
27 using namespace OHOS::NetManagerStandard::CommonUtils;
28 namespace {
GetVectorData(const std::vector<std::string> & data,std::string & result)29 void GetVectorData(const std::vector<std::string> &data, std::string &result)
30 {
31     result.append("{ ");
32     std::for_each(data.begin(), data.end(), [&result](const auto &str) { result.append(ToAnonymousIp(str) + ", "); });
33     result.append("}\n");
34 }
35 constexpr int RES_TIMEOUT = 4000;    // min. milliseconds between retries
36 constexpr int RES_DEFAULT_RETRY = 2; // Default
37 } // namespace
38 
DnsParamCache()39 DnsParamCache::DnsParamCache() : defaultNetId_(0) {}
40 
GetInstance()41 DnsParamCache &DnsParamCache::GetInstance()
42 {
43     static DnsParamCache instance;
44     return instance;
45 }
46 
SelectNameservers(const std::vector<std::string> & servers)47 std::vector<std::string> DnsParamCache::SelectNameservers(const std::vector<std::string> &servers)
48 {
49     std::vector<std::string> res = servers;
50     if (res.size() > MAX_SERVER_NUM_EXT - 1) {
51         res.resize(MAX_SERVER_NUM_EXT - 1);
52     }
53     return res;
54 }
55 
CreateCacheForNet(uint16_t netId,bool isVpnNet)56 int32_t DnsParamCache::CreateCacheForNet(uint16_t netId, bool isVpnNet)
57 {
58     NETNATIVE_LOGI("DnsParamCache::CreateCacheForNet, netid:%{public}d,", netId);
59     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
60     auto it = serverConfigMap_.find(netId);
61     if (it != serverConfigMap_.end()) {
62         NETNATIVE_LOGE("DnsParamCache::CreateCacheForNet, netid already exist, no need to create");
63         return -EEXIST;
64     }
65     serverConfigMap_[netId].SetNetId(netId);
66     if (isVpnNet) {
67         NETNATIVE_LOGI("DnsParamCache::CreateCacheForNet clear all dns cache when vpn net create");
68         for (auto iterator = serverConfigMap_.begin(); iterator != serverConfigMap_.end(); iterator++) {
69             iterator->second.GetCache().Clear();
70         }
71     }
72     return 0;
73 }
74 
DestroyNetworkCache(uint16_t netId,bool isVpnNet)75 int32_t DnsParamCache::DestroyNetworkCache(uint16_t netId, bool isVpnNet)
76 {
77     NETNATIVE_LOGI("DnsParamCache::DestroyNetworkCache, netid:%{public}d, %{public}d", netId, isVpnNet);
78     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
79     auto it = serverConfigMap_.find(netId);
80     if (it == serverConfigMap_.end()) {
81         return -ENOENT;
82     }
83     serverConfigMap_.erase(it);
84     if (defaultNetId_ == netId) {
85         defaultNetId_ = 0;
86     }
87     if (isVpnNet) {
88         NETNATIVE_LOGI("DnsParamCache::DestroyNetworkCache clear all dns cache when vpn net destroy");
89         for (auto it = serverConfigMap_.begin(); it != serverConfigMap_.end(); it++) {
90             it->second.GetCache().Clear();
91         }
92     }
93     return 0;
94 }
95 
SetResolverConfig(uint16_t netId,uint16_t baseTimeoutMsec,uint8_t retryCount,const std::vector<std::string> & servers,const std::vector<std::string> & domains)96 int32_t DnsParamCache::SetResolverConfig(uint16_t netId, uint16_t baseTimeoutMsec, uint8_t retryCount,
97                                          const std::vector<std::string> &servers,
98                                          const std::vector<std::string> &domains)
99 {
100     std::vector<std::string> nameservers = SelectNameservers(servers);
101     NETNATIVE_LOG_D("DnsParamCache::SetResolverConfig, netid:%{public}d, numServers:%{public}d,", netId,
102                     static_cast<int>(nameservers.size()));
103 
104     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
105 
106     // select_domains
107     auto it = serverConfigMap_.find(netId);
108     if (it == serverConfigMap_.end()) {
109         NETNATIVE_LOGE("DnsParamCache::SetResolverConfig failed, netid is non-existent");
110         return -ENOENT;
111     }
112 
113     auto oldDnsServers = it->second.GetServers();
114     std::sort(oldDnsServers.begin(), oldDnsServers.end());
115 
116     auto newDnsServers = servers;
117     std::sort(newDnsServers.begin(), newDnsServers.end());
118 
119     if (oldDnsServers != newDnsServers) {
120         it->second.GetCache().Clear();
121     }
122 
123     it->second.SetNetId(netId);
124     it->second.SetServers(servers);
125     it->second.SetDomains(domains);
126     if (retryCount == 0) {
127         it->second.SetRetryCount(RES_DEFAULT_RETRY);
128     } else {
129         it->second.SetRetryCount(retryCount);
130     }
131     if (baseTimeoutMsec == 0) {
132         it->second.SetTimeoutMsec(RES_TIMEOUT);
133     } else {
134         it->second.SetTimeoutMsec(baseTimeoutMsec);
135     }
136     return 0;
137 }
138 
SetDefaultNetwork(uint16_t netId)139 void DnsParamCache::SetDefaultNetwork(uint16_t netId)
140 {
141     defaultNetId_ = netId;
142 }
143 
EnableIpv6(uint16_t netId)144 void DnsParamCache::EnableIpv6(uint16_t netId)
145 {
146     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
147     auto it = serverConfigMap_.find(netId);
148     if (it == serverConfigMap_.end()) {
149         DNS_CONFIG_PRINT("get Config failed: netid is not have netid:%{public}d,", netId);
150         return;
151     }
152 
153     it->second.EnableIpv6();
154 }
155 
IsIpv6Enable(uint16_t netId)156 bool DnsParamCache::IsIpv6Enable(uint16_t netId)
157 {
158     if (netId == 0) {
159         netId = defaultNetId_;
160     }
161 
162     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
163     auto it = serverConfigMap_.find(netId);
164     if (it == serverConfigMap_.end()) {
165         DNS_CONFIG_PRINT("get Config failed: netid is not have netid:%{public}d,", netId);
166         return false;
167     }
168 
169     return it->second.IsIpv6Enable();
170 }
171 
GetResolverConfig(uint16_t netId,std::vector<std::string> & servers,std::vector<std::string> & domains,uint16_t & baseTimeoutMsec,uint8_t & retryCount)172 int32_t DnsParamCache::GetResolverConfig(uint16_t netId, std::vector<std::string> &servers,
173                                          std::vector<std::string> &domains, uint16_t &baseTimeoutMsec,
174                                          uint8_t &retryCount)
175 {
176     NETNATIVE_LOG_D("DnsParamCache::GetResolverConfig no uid");
177     if (netId == 0) {
178         netId = defaultNetId_;
179         NETNATIVE_LOG_D("defaultNetId_ = [%{public}u]", netId);
180     }
181 
182     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
183     auto it = serverConfigMap_.find(netId);
184     if (it == serverConfigMap_.end()) {
185         DNS_CONFIG_PRINT("get Config failed: netid is not have netid:%{public}d,", netId);
186         return -ENOENT;
187     }
188 
189     servers = it->second.GetServers();
190 #ifdef FEATURE_NET_FIREWALL_ENABLE
191     std::vector<std::string> dns;
192     if (GetDnsServersByAppUid(GetCallingUid(), dns)) {
193         DNS_CONFIG_PRINT("GetResolverConfig hit netfirewall");
194         servers.assign(dns.begin(), dns.end());
195     }
196 #endif
197     domains = it->second.GetDomains();
198     baseTimeoutMsec = it->second.GetTimeoutMsec();
199     retryCount = it->second.GetRetryCount();
200 
201     return 0;
202 }
203 
GetResolverConfig(uint16_t netId,uint32_t uid,std::vector<std::string> & servers,std::vector<std::string> & domains,uint16_t & baseTimeoutMsec,uint8_t & retryCount)204 int32_t DnsParamCache::GetResolverConfig(uint16_t netId, uint32_t uid, std::vector<std::string> &servers,
205                                          std::vector<std::string> &domains, uint16_t &baseTimeoutMsec,
206                                          uint8_t &retryCount)
207 {
208     NETNATIVE_LOG_D("DnsParamCache::GetResolverConfig has uid");
209     if (netId == 0) {
210         netId = defaultNetId_;
211         NETNATIVE_LOG_D("defaultNetId_ = [%{public}u]", netId);
212     }
213 
214     {
215         std::lock_guard<ffrt::mutex> guard(cacheMutex_);
216         for (auto mem : vpnUidRanges_) {
217             if (static_cast<int32_t>(uid) >= mem.begin_ && static_cast<int32_t>(uid) <= mem.end_) {
218                 NETNATIVE_LOG_D("is vpn hap");
219                 auto it = serverConfigMap_.find(vpnNetId_);
220                 if (it == serverConfigMap_.end()) {
221                     NETNATIVE_LOG_D("vpn get Config failed: not have vpnnetid:%{public}d,", vpnNetId_);
222                     break;
223                 }
224                 servers = it->second.GetServers();
225 #ifdef FEATURE_NET_FIREWALL_ENABLE
226                 std::vector<std::string> dns;
227                 if (GetDnsServersByAppUid(GetCallingUid(), dns)) {
228                     DNS_CONFIG_PRINT("GetResolverConfig hit netfirewall");
229                     servers.assign(dns.begin(), dns.end());
230                 }
231 #endif
232                 domains = it->second.GetDomains();
233                 baseTimeoutMsec = it->second.GetTimeoutMsec();
234                 retryCount = it->second.GetRetryCount();
235                 return 0;
236             }
237         }
238     }
239     return GetResolverConfig(netId, servers, domains, baseTimeoutMsec, retryCount);
240 }
241 
GetDefaultNetwork() const242 int32_t DnsParamCache::GetDefaultNetwork() const
243 {
244     return defaultNetId_;
245 }
246 
SetDnsCache(uint16_t netId,const std::string & hostName,const AddrInfo & addrInfo)247 void DnsParamCache::SetDnsCache(uint16_t netId, const std::string &hostName, const AddrInfo &addrInfo)
248 {
249     if (netId == 0) {
250         netId = defaultNetId_;
251     }
252     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
253 #ifdef FEATURE_NET_FIREWALL_ENABLE
254     int32_t appUid = static_cast<int32_t>(GetCallingUid());
255     bool isMatchAllow = false;
256     if (IsInterceptDomain(appUid, hostName, isMatchAllow)) {
257         DNS_CONFIG_PRINT("SetDnsCache failed: domain was Intercepted: %{public}s,", hostName.c_str());
258         return;
259     }
260     if (isMatchAllow && (addrInfo.aiFamily == AF_INET || addrInfo.aiFamily == AF_INET6)) {
261         NetAddrInfo netInfo;
262         netInfo.aiFamily = addrInfo.aiFamily;
263         if (addrInfo.aiFamily == AF_INET) {
264             netInfo.aiAddr.sin = addrInfo.aiAddr.sin.sin_addr;
265         } else {
266             memcpy_s(&netInfo.aiAddr.sin6, sizeof(addrInfo.aiAddr.sin6.sin6_addr), &addrInfo.aiAddr.sin6.sin6_addr,
267                      sizeof(addrInfo.aiAddr.sin6.sin6_addr));
268         }
269         OHOS::NetManagerStandard::NetsysBpfNetFirewall::GetInstance()->AddDomainCache(netInfo);
270     }
271 #endif
272     auto it = serverConfigMap_.find(netId);
273     if (it == serverConfigMap_.end()) {
274         DNS_CONFIG_PRINT("SetDnsCache failed: netid is not have netid:%{public}d,", netId);
275         return;
276     }
277 
278     it->second.GetCache().Put(hostName, addrInfo);
279 }
280 
GetDnsCache(uint16_t netId,const std::string & hostName)281 std::vector<AddrInfo> DnsParamCache::GetDnsCache(uint16_t netId, const std::string &hostName)
282 {
283     if (netId == 0) {
284         netId = defaultNetId_;
285     }
286 
287     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
288 #ifdef FEATURE_NET_FIREWALL_ENABLE
289     int32_t appUid = static_cast<int32_t>(GetCallingUid());
290     bool isMatchAllow = false;
291     if (IsInterceptDomain(appUid, hostName, isMatchAllow)) {
292         NotifyDomianIntercept(appUid, hostName);
293         AddrInfo fakeAddr = { 0 };
294         fakeAddr.aiFamily = AF_UNSPEC;
295         fakeAddr.aiAddr.sin.sin_family = AF_UNSPEC;
296         fakeAddr.aiAddr.sin.sin_addr.s_addr = INADDR_NONE;
297         fakeAddr.aiAddrLen = sizeof(struct sockaddr_in);
298         return { fakeAddr };
299     }
300 #endif
301 
302     auto it = serverConfigMap_.find(netId);
303     if (it == serverConfigMap_.end()) {
304         DNS_CONFIG_PRINT("GetDnsCache failed: netid is not have netid:%{public}d,", netId);
305         return {};
306     }
307 
308     return it->second.GetCache().Get(hostName);
309 }
310 
SetCacheDelayed(uint16_t netId,const std::string & hostName)311 void DnsParamCache::SetCacheDelayed(uint16_t netId, const std::string &hostName)
312 {
313     if (netId == 0) {
314         netId = defaultNetId_;
315     }
316 
317     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
318     auto it = serverConfigMap_.find(netId);
319     if (it == serverConfigMap_.end()) {
320         DNS_CONFIG_PRINT("SetCacheDelayed failed: netid is not have netid:%{public}d,", netId);
321         return;
322     }
323 
324     it->second.SetCacheDelayed(hostName);
325 }
326 
AddUidRange(uint32_t netId,const std::vector<NetManagerStandard::UidRange> & uidRanges)327 int32_t DnsParamCache::AddUidRange(uint32_t netId, const std::vector<NetManagerStandard::UidRange> &uidRanges)
328 {
329     std::lock_guard<ffrt::mutex> guard(uidRangeMutex_);
330     NETNATIVE_LOG_D("DnsParamCache::AddUidRange size = [%{public}zu]", uidRanges.size());
331     vpnNetId_ = netId;
332     auto middle = vpnUidRanges_.insert(vpnUidRanges_.end(), uidRanges.begin(), uidRanges.end());
333     std::inplace_merge(vpnUidRanges_.begin(), middle, vpnUidRanges_.end());
334     return 0;
335 }
336 
DelUidRange(uint32_t netId,const std::vector<NetManagerStandard::UidRange> & uidRanges)337 int32_t DnsParamCache::DelUidRange(uint32_t netId, const std::vector<NetManagerStandard::UidRange> &uidRanges)
338 {
339     std::lock_guard<ffrt::mutex> guard(uidRangeMutex_);
340     NETNATIVE_LOG_D("DnsParamCache::DelUidRange size = [%{public}zu]", uidRanges.size());
341     vpnNetId_ = 0;
342     auto end = std::set_difference(vpnUidRanges_.begin(), vpnUidRanges_.end(), uidRanges.begin(),
343                                    uidRanges.end(), vpnUidRanges_.begin());
344     vpnUidRanges_.erase(end, vpnUidRanges_.end());
345     return 0;
346 }
347 
IsVpnOpen() const348 bool DnsParamCache::IsVpnOpen() const
349 {
350     return vpnUidRanges_.size();
351 }
352 
353 #ifdef FEATURE_NET_FIREWALL_ENABLE
GetUserId(int32_t appUid)354 int32_t DnsParamCache::GetUserId(int32_t appUid)
355 {
356     int32_t userId = appUid / USER_ID_DIVIDOR;
357     return userId > 0 ? userId : currentUserId_;
358 }
359 
GetDnsServersByAppUid(int32_t appUid,std::vector<std::string> & servers)360 bool DnsParamCache::GetDnsServersByAppUid(int32_t appUid, std::vector<std::string> &servers)
361 {
362     if (netFirewallDnsRuleMap_.empty()) {
363         return false;
364     }
365     DNS_CONFIG_PRINT("GetDnsServersByAppUid: appUid=%{public}d", appUid);
366     auto it = netFirewallDnsRuleMap_.find(appUid);
367     if (it == netFirewallDnsRuleMap_.end()) {
368         // if appUid not found, try to find invalid appUid=0;
369         it = netFirewallDnsRuleMap_.find(0);
370     }
371     if (it != netFirewallDnsRuleMap_.end()) {
372         int32_t userId = GetUserId(appUid);
373         std::vector<sptr<NetFirewallDnsRule>> rules = it->second;
374         for (const auto &rule : rules) {
375             if (rule->userId != userId) {
376                 continue;
377             }
378             servers.emplace_back(rule->primaryDns);
379             servers.emplace_back(rule->standbyDns);
380         }
381         return true;
382     }
383     return false;
384 }
385 
SetFirewallRules(NetFirewallRuleType type,const std::vector<sptr<NetFirewallBaseRule>> & ruleList,bool isFinish)386 int32_t DnsParamCache::SetFirewallRules(NetFirewallRuleType type,
387                                         const std::vector<sptr<NetFirewallBaseRule>> &ruleList, bool isFinish)
388 {
389     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
390     NETNATIVE_LOGI("SetFirewallRules: size=%{public}zu isFinish=%{public}" PRId32, ruleList.size(), isFinish);
391     if (ruleList.empty()) {
392         NETNATIVE_LOGE("SetFirewallRules: rules is empty");
393         return -1;
394     }
395     int32_t ret = 0;
396     switch (type) {
397         case NetFirewallRuleType::RULE_DNS: {
398             for (const auto &rule : ruleList) {
399                 firewallDnsRules_.emplace_back(firewall_rule_cast<NetFirewallDnsRule>(rule));
400             }
401             if (isFinish) {
402                 ret = SetFirewallDnsRules(firewallDnsRules_);
403                 firewallDnsRules_.clear();
404             }
405             break;
406         }
407         case NetFirewallRuleType::RULE_DOMAIN: {
408             ClearAllDnsCache();
409             for (const auto &rule : ruleList) {
410                 firewallDomainRules_.emplace_back(firewall_rule_cast<NetFirewallDomainRule>(rule));
411             }
412             if (isFinish) {
413                 ret = SetFirewallDomainRules(firewallDomainRules_);
414                 firewallDomainRules_.clear();
415                 OHOS::NetManagerStandard::NetsysBpfNetFirewall::GetInstance()->ClearDomainCache();
416             }
417             break;
418         }
419         default:
420             break;
421     }
422     return ret;
423 }
424 
SetFirewallDnsRules(const std::vector<sptr<NetFirewallDnsRule>> & ruleList)425 int32_t DnsParamCache::SetFirewallDnsRules(const std::vector<sptr<NetFirewallDnsRule>> &ruleList)
426 {
427     for (const auto &rule : ruleList) {
428         std::vector<sptr<NetFirewallDnsRule>> rules;
429         auto it = netFirewallDnsRuleMap_.find(rule->appUid);
430         if (it != netFirewallDnsRuleMap_.end()) {
431             rules = it->second;
432         }
433         rules.emplace_back(std::move(rule));
434         netFirewallDnsRuleMap_.emplace(rule->appUid, std::move(rules));
435     }
436     return 0;
437 }
438 
GetFirewallRuleAction(int32_t appUid,const std::vector<sptr<NetFirewallDomainRule>> & rules)439 FirewallRuleAction DnsParamCache::GetFirewallRuleAction(int32_t appUid,
440                                                         const std::vector<sptr<NetFirewallDomainRule>> &rules)
441 {
442     int32_t userId = GetUserId(appUid);
443     for (const auto &rule : rules) {
444         if (rule->userId != userId) {
445             continue;
446         }
447         if ((rule->appUid && appUid == rule->appUid) || !rule->appUid) {
448             return rule->ruleAction;
449         }
450     }
451 
452     return FirewallRuleAction::RULE_INVALID;
453 }
454 
checkEmpty4InterceptDomain(const std::string & hostName)455 bool DnsParamCache::checkEmpty4InterceptDomain(const std::string &hostName)
456 {
457     if (hostName.empty()) {
458         return true;
459     }
460     if (!netFirewallDomainRulesAllowMap_.empty() || !netFirewallDomainRulesDenyMap_.empty()) {
461         return false;
462     }
463     if (domainAllowLsmTrie_ && !domainAllowLsmTrie_->Empty()) {
464         return false;
465     }
466     return !domainDenyLsmTrie_ || domainDenyLsmTrie_->Empty();
467 }
468 
IsInterceptDomain(int32_t appUid,const std::string & hostName,bool & isMatchAllow)469 bool DnsParamCache::IsInterceptDomain(int32_t appUid, const std::string &hostName, bool &isMatchAllow)
470 {
471     if (checkEmpty4InterceptDomain(hostName)) {
472         return false;
473     }
474     std::string host = hostName.substr(0, hostName.find(' '));
475     DNS_CONFIG_PRINT("IsInterceptDomain: appUid: %{public}d, hostName: %{private}s", appUid, host.c_str());
476     std::transform(host.begin(), host.end(), host.begin(), ::tolower);
477     std::vector<sptr<NetFirewallDomainRule>> rules;
478     FirewallRuleAction exactAllowAction = FirewallRuleAction::RULE_INVALID;
479     auto it = netFirewallDomainRulesAllowMap_.find(host);
480     if (it != netFirewallDomainRulesAllowMap_.end()) {
481         rules = it->second;
482         exactAllowAction = GetFirewallRuleAction(appUid, rules);
483     }
484     FirewallRuleAction exactDenyAction = FirewallRuleAction::RULE_INVALID;
485     auto iter = netFirewallDomainRulesDenyMap_.find(host);
486     if (iter != netFirewallDomainRulesDenyMap_.end()) {
487         rules = iter->second;
488         exactDenyAction = GetFirewallRuleAction(appUid, rules);
489     }
490     FirewallRuleAction wildcardAllowAction = FirewallRuleAction::RULE_INVALID;
491     if (domainAllowLsmTrie_->LongestSuffixMatch(host, rules)) {
492         wildcardAllowAction = GetFirewallRuleAction(appUid, rules);
493     }
494     FirewallRuleAction wildcardDenyAction = FirewallRuleAction::RULE_INVALID;
495     if (domainDenyLsmTrie_->LongestSuffixMatch(host, rules)) {
496         wildcardDenyAction = GetFirewallRuleAction(appUid, rules);
497     }
498     isMatchAllow = (exactAllowAction != FirewallRuleAction::RULE_INVALID) ||
499                    (wildcardAllowAction != FirewallRuleAction::RULE_INVALID);
500     bool isDeny = (exactDenyAction != FirewallRuleAction::RULE_INVALID) ||
501                   (wildcardDenyAction != FirewallRuleAction::RULE_INVALID);
502     if (isMatchAllow) {
503         // Apply default rules in case of conflict
504         return isDeny && (firewallDefaultAction_ == FirewallRuleAction::RULE_DENY);
505     }
506     return isDeny;
507 }
508 
SetFirewallDefaultAction(FirewallRuleAction inDefault,FirewallRuleAction outDefault)509 int32_t DnsParamCache::SetFirewallDefaultAction(FirewallRuleAction inDefault, FirewallRuleAction outDefault)
510 {
511     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
512     DNS_CONFIG_PRINT("SetFirewallDefaultAction: firewallDefaultAction_: %{public}d", (int)outDefault);
513     firewallDefaultAction_ = outDefault;
514     return 0;
515 }
516 
BuildFirewallDomainLsmTrie(const sptr<NetFirewallDomainRule> & rule,const std::string & domain)517 void DnsParamCache::BuildFirewallDomainLsmTrie(const sptr<NetFirewallDomainRule> &rule, const std::string &domain)
518 {
519     std::vector<sptr<NetFirewallDomainRule>> rules;
520     std::string suffix(domain);
521     auto wildcardCharIndex = suffix.find('*');
522     if (wildcardCharIndex != std::string::npos) {
523         suffix = suffix.substr(wildcardCharIndex + 1);
524     }
525     DNS_CONFIG_PRINT("BuildFirewallDomainLsmTrie: suffix: %{public}s", suffix.c_str());
526     std::transform(suffix.begin(), suffix.end(), suffix.begin(), ::tolower);
527     if (rule->ruleAction == FirewallRuleAction::RULE_DENY) {
528         if (domainDenyLsmTrie_->LongestSuffixMatch(suffix, rules)) {
529             rules.emplace_back(std::move(rule));
530             domainDenyLsmTrie_->Update(suffix, rules);
531             return;
532         }
533         rules.emplace_back(std::move(rule));
534         domainDenyLsmTrie_->Insert(suffix, rules);
535     } else {
536         if (domainAllowLsmTrie_->LongestSuffixMatch(suffix, rules)) {
537             rules.emplace_back(std::move(rule));
538             domainAllowLsmTrie_->Update(suffix, rules);
539             return;
540         }
541         rules.emplace_back(std::move(rule));
542         domainAllowLsmTrie_->Insert(suffix, rules);
543     }
544 }
545 
BuildFirewallDomainMap(const sptr<NetFirewallDomainRule> & rule,const std::string & raw)546 void DnsParamCache::BuildFirewallDomainMap(const sptr<NetFirewallDomainRule> &rule, const std::string &raw)
547 {
548     DNS_CONFIG_PRINT("BuildFirewallDomainMap: domain: %{public}s", raw.c_str());
549     std::string domain(raw);
550     std::vector<sptr<NetFirewallDomainRule>> rules;
551     std::transform(domain.begin(), domain.end(), domain.begin(), ::tolower);
552     if (rule->ruleAction == FirewallRuleAction::RULE_DENY) {
553         auto it = netFirewallDomainRulesDenyMap_.find(domain);
554         if (it != netFirewallDomainRulesDenyMap_.end()) {
555             rules = it->second;
556         }
557 
558         rules.emplace_back(std::move(rule));
559         netFirewallDomainRulesDenyMap_.emplace(domain, std::move(rules));
560     } else {
561         auto it = netFirewallDomainRulesAllowMap_.find(domain);
562         if (it != netFirewallDomainRulesAllowMap_.end()) {
563             rules = it->second;
564         }
565 
566         rules.emplace_back(rule);
567         netFirewallDomainRulesAllowMap_.emplace(domain, std::move(rules));
568     }
569 }
570 
SetFirewallDomainRules(const std::vector<sptr<NetFirewallDomainRule>> & ruleList)571 int32_t DnsParamCache::SetFirewallDomainRules(const std::vector<sptr<NetFirewallDomainRule>> &ruleList)
572 {
573     if (!domainAllowLsmTrie_) {
574         domainAllowLsmTrie_ =
575             std::make_shared<NetManagerStandard::SuffixMatchTrie<std::vector<sptr<NetFirewallDomainRule>>>>();
576     }
577     if (!domainDenyLsmTrie_) {
578         domainDenyLsmTrie_ =
579             std::make_shared<NetManagerStandard::SuffixMatchTrie<std::vector<sptr<NetFirewallDomainRule>>>>();
580     }
581     for (const auto &rule : ruleList) {
582         for (const auto &param : rule->domains) {
583             if (param.isWildcard) {
584                 BuildFirewallDomainLsmTrie(rule, param.domain);
585             } else {
586                 BuildFirewallDomainMap(rule, param.domain);
587             }
588         }
589     }
590     return 0;
591 }
592 
ClearFirewallRules(NetFirewallRuleType type)593 int32_t DnsParamCache::ClearFirewallRules(NetFirewallRuleType type)
594 {
595     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
596     switch (type) {
597         case NetFirewallRuleType::RULE_DNS:
598             firewallDnsRules_.clear();
599             netFirewallDnsRuleMap_.clear();
600             break;
601         case NetFirewallRuleType::RULE_DOMAIN: {
602             firewallDomainRules_.clear();
603             netFirewallDomainRulesAllowMap_.clear();
604             netFirewallDomainRulesDenyMap_.clear();
605             if (domainAllowLsmTrie_) {
606                 domainAllowLsmTrie_ = nullptr;
607             }
608             if (domainDenyLsmTrie_) {
609                 domainDenyLsmTrie_ = nullptr;
610             }
611             OHOS::NetManagerStandard::NetsysBpfNetFirewall::GetInstance()->ClearDomainCache();
612             break;
613         }
614         case NetFirewallRuleType::RULE_ALL: {
615             firewallDnsRules_.clear();
616             netFirewallDnsRuleMap_.clear();
617             firewallDomainRules_.clear();
618             netFirewallDomainRulesAllowMap_.clear();
619             netFirewallDomainRulesDenyMap_.clear();
620             if (domainAllowLsmTrie_) {
621                 domainAllowLsmTrie_ = nullptr;
622             }
623             if (domainDenyLsmTrie_) {
624                 domainDenyLsmTrie_ = nullptr;
625             }
626             OHOS::NetManagerStandard::NetsysBpfNetFirewall::GetInstance()->ClearDomainCache();
627             break;
628         }
629         default:
630             break;
631     }
632     return 0;
633 }
634 
NotifyDomianIntercept(int32_t appUid,const std::string & hostName)635 void DnsParamCache::NotifyDomianIntercept(int32_t appUid, const std::string &hostName)
636 {
637     if (hostName.empty()) {
638         return;
639     }
640     std::string host = hostName.substr(0, hostName.find(' '));
641     NETNATIVE_LOGI("NotifyDomianIntercept: appUid: %{public}d, hostName: %{private}s", appUid, host.c_str());
642     sptr<InterceptRecord> record = sptr<InterceptRecord>::MakeSptr();
643     record->time = (int32_t)time(NULL);
644     record->appUid = appUid;
645     record->domain = host;
646 
647     if (oldRecord_ != nullptr && (record->time - oldRecord_->time) < INTERCEPT_BUFF_INTERVAL_SEC) {
648         if (record->appUid == oldRecord_->appUid && record->domain == oldRecord_->domain) {
649             return;
650         }
651     }
652     oldRecord_ = record;
653     for (const auto &callback : callbacks_) {
654         callback->OnIntercept(record);
655     }
656 }
657 
RegisterNetFirewallCallback(const sptr<NetsysNative::INetFirewallCallback> & callback)658 int32_t DnsParamCache::RegisterNetFirewallCallback(const sptr<NetsysNative::INetFirewallCallback> &callback)
659 {
660     if (!callback) {
661         return -1;
662     }
663 
664     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
665     callbacks_.emplace_back(callback);
666 
667     return 0;
668 }
669 
UnRegisterNetFirewallCallback(const sptr<NetsysNative::INetFirewallCallback> & callback)670 int32_t DnsParamCache::UnRegisterNetFirewallCallback(const sptr<NetsysNative::INetFirewallCallback> &callback)
671 {
672     if (!callback) {
673         return -1;
674     }
675 
676     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
677     for (auto it = callbacks_.begin(); it != callbacks_.end(); ++it) {
678         if (*it == callback) {
679             callbacks_.erase(it);
680             return 0;
681         }
682     }
683     return -1;
684 }
685 
SetFirewallCurrentUserId(int32_t userId)686 int32_t DnsParamCache::SetFirewallCurrentUserId(int32_t userId)
687 {
688     currentUserId_ = userId;
689     ClearAllDnsCache();
690     return 0;
691 }
692 
ClearAllDnsCache()693 void DnsParamCache::ClearAllDnsCache()
694 {
695     NETNATIVE_LOGI("ClearAllDnsCache");
696     for (auto it = serverConfigMap_.begin(); it != serverConfigMap_.end(); it++) {
697         it->second.GetCache().Clear();
698     }
699 }
700 #endif
701 
GetDumpInfo(std::string & info)702 void DnsParamCache::GetDumpInfo(std::string &info)
703 {
704     std::string dnsData;
705     static const std::string TAB = "  ";
706     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
707     std::for_each(serverConfigMap_.begin(), serverConfigMap_.end(), [&dnsData](const auto &serverConfig) {
708         dnsData.append(TAB + "NetId: " + std::to_string(serverConfig.second.GetNetId()) + "\n");
709         dnsData.append(TAB + "TimeoutMsec: " + std::to_string(serverConfig.second.GetTimeoutMsec()) + "\n");
710         dnsData.append(TAB + "RetryCount: " + std::to_string(serverConfig.second.GetRetryCount()) + "\n");
711         dnsData.append(TAB + "Servers:");
712         GetVectorData(serverConfig.second.GetServers(), dnsData);
713         dnsData.append(TAB + "Domains:");
714         GetVectorData(serverConfig.second.GetDomains(), dnsData);
715     });
716     info.append(dnsData);
717 }
718 
SetUserDefinedServerFlag(uint16_t netId,bool flag)719 int32_t DnsParamCache::SetUserDefinedServerFlag(uint16_t netId, bool flag)
720 {
721     NETNATIVE_LOGI("DnsParamCache::SetUserDefinedServerFlag, netid:%{public}d, flag:%{public}d,", netId, flag);
722 
723     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
724     // select_domains
725     auto it = serverConfigMap_.find(netId);
726     if (it == serverConfigMap_.end()) {
727         NETNATIVE_LOGE("DnsParamCache::SetUserDefinedServerFlag failed, netid is non-existent");
728         return -ENOENT;
729     }
730     it->second.SetUserDefinedServerFlag(flag);
731     return 0;
732 }
733 
GetUserDefinedServerFlag(uint16_t netId,bool & flag)734 int32_t DnsParamCache::GetUserDefinedServerFlag(uint16_t netId, bool &flag)
735 {
736     if (netId == 0) {
737         netId = defaultNetId_;
738         NETNATIVE_LOG_D("defaultNetId_ = [%{public}u]", netId);
739     }
740     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
741     auto it = serverConfigMap_.find(netId);
742     if (it == serverConfigMap_.end()) {
743         DNS_CONFIG_PRINT("GetUserDefinedServerFlag failed: netid is not have netid:%{public}d,", netId);
744         return -ENOENT;
745     }
746     flag = it->second.IsUserDefinedServer();
747     return 0;
748 }
749 
GetUserDefinedServerFlag(uint16_t netId,bool & flag,uint32_t uid)750 int32_t DnsParamCache::GetUserDefinedServerFlag(uint16_t netId, bool &flag, uint32_t uid)
751 {
752     if (netId == 0) {
753         netId = defaultNetId_;
754         NETNATIVE_LOG_D("defaultNetId_ = [%{public}u]", netId);
755     }
756     {
757         std::lock_guard<ffrt::mutex> guard(cacheMutex_);
758         for (auto mem : vpnUidRanges_) {
759             if (static_cast<int32_t>(uid) >= mem.begin_ && static_cast<int32_t>(uid) <= mem.end_) {
760                 NETNATIVE_LOG_D("is vpn hap");
761                 auto it = serverConfigMap_.find(vpnNetId_);
762                 if (it == serverConfigMap_.end()) {
763                     NETNATIVE_LOG_D("vpn get Config failed: not have vpnnetid:%{public}d,", vpnNetId_);
764                     break;
765                 }
766                 flag = it->second.IsUserDefinedServer();
767                 return 0;
768             }
769         }
770         auto it = serverConfigMap_.find(netId);
771         if (it == serverConfigMap_.end()) {
772             DNS_CONFIG_PRINT("GetUserDefinedServerFlag failed: netid is not have netid:%{public}d,", netId);
773             return -ENOENT;
774         }
775     }
776     return GetUserDefinedServerFlag(netId, flag);
777 }
778 
IsUseVpnDns(uint32_t uid)779 bool DnsParamCache::IsUseVpnDns(uint32_t uid)
780 {
781     for (auto mem : vpnUidRanges_) {
782         if (static_cast<int32_t>(uid) >= mem.begin_ && static_cast<int32_t>(uid) <= mem.end_) {
783             auto it = serverConfigMap_.find(vpnNetId_);
784             if (it == serverConfigMap_.end()) {
785                 return false;
786             }
787             return true;
788         }
789     }
790     return false;
791 }
792 
FlushDnsCache(uint16_t netId)793 int32_t DnsParamCache::FlushDnsCache(uint16_t netId)
794 {
795     if (netId == 0) {
796         netId = defaultNetId_;
797         NETNATIVE_LOG_D("defaultNetId_ = [%{public}u]", netId);
798     }
799     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
800     auto it = serverConfigMap_.find(netId);
801     if (it == serverConfigMap_.end()) {
802         DNS_CONFIG_PRINT("FlushDnsCache failed: netid is non-existent netid:%{public}d,", netId);
803         return -ENOENT;
804     }
805     it->second.GetCache().Clear();
806     return 0;
807 }
808 } // namespace OHOS::nmd
809