• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include "dns_param_cache.h"

#include <algorithm>
#include <iterator>

#include "netmanager_base_common_utils.h"
#ifdef FEATURE_NET_FIREWALL_ENABLE
#include "bpf_netfirewall.h"
#include "netfirewall_parcel.h"
#include <ctime>
#endif
26 
27 namespace OHOS::nmd {
28 using namespace OHOS::NetManagerStandard::CommonUtils;
29 namespace {
GetVectorData(const std::vector<std::string> & data,std::string & result)30 void GetVectorData(const std::vector<std::string> &data, std::string &result)
31 {
32     result.append("{ ");
33     std::for_each(data.begin(), data.end(), [&result](const auto &str) { result.append(ToAnonymousIp(str) + ", "); });
34     result.append("}\n");
35 }
36 constexpr int RES_TIMEOUT = 4000;    // min. milliseconds between retries
37 constexpr int RES_DEFAULT_RETRY = 2; // Default
38 } // namespace
39 
DnsParamCache()40 DnsParamCache::DnsParamCache() : defaultNetId_(0) {}
41 
GetInstance()42 DnsParamCache &DnsParamCache::GetInstance()
43 {
44     static DnsParamCache instance;
45     return instance;
46 }
47 
SelectNameservers(const std::vector<std::string> & servers)48 std::vector<std::string> DnsParamCache::SelectNameservers(const std::vector<std::string> &servers)
49 {
50     std::vector<std::string> res = servers;
51     if (res.size() > MAX_SERVER_NUM - 1) {
52         res.resize(MAX_SERVER_NUM - 1);
53     }
54     return res;
55 }
56 
CreateCacheForNet(uint16_t netId,bool isVpnNet)57 int32_t DnsParamCache::CreateCacheForNet(uint16_t netId, bool isVpnNet)
58 {
59     NETNATIVE_LOGI("DnsParamCache::CreateCacheForNet, netid:%{public}d,", netId);
60     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
61     auto it = serverConfigMap_.find(netId);
62     if (it != serverConfigMap_.end()) {
63         NETNATIVE_LOGE("DnsParamCache::CreateCacheForNet, netid already exist, no need to create");
64         return -EEXIST;
65     }
66     serverConfigMap_[netId].SetNetId(netId);
67     if (isVpnNet) {
68         NETNATIVE_LOGI("DnsParamCache::CreateCacheForNet clear all dns cache when vpn net create");
69         for (auto iterator = serverConfigMap_.begin(); iterator != serverConfigMap_.end(); iterator++) {
70             iterator->second.GetCache().Clear();
71         }
72     }
73     return 0;
74 }
75 
DestroyNetworkCache(uint16_t netId,bool isVpnNet)76 int32_t DnsParamCache::DestroyNetworkCache(uint16_t netId, bool isVpnNet)
77 {
78     NETNATIVE_LOGI("DnsParamCache::DestroyNetworkCache, netid:%{public}d, %{public}d", netId, isVpnNet);
79     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
80     auto it = serverConfigMap_.find(netId);
81     if (it == serverConfigMap_.end()) {
82         return -ENOENT;
83     }
84     serverConfigMap_.erase(it);
85     if (defaultNetId_ == netId) {
86         defaultNetId_ = 0;
87     }
88     if (isVpnNet) {
89         NETNATIVE_LOGI("DnsParamCache::DestroyNetworkCache clear all dns cache when vpn net destroy");
90         for (auto it = serverConfigMap_.begin(); it != serverConfigMap_.end(); it++) {
91             it->second.GetCache().Clear();
92         }
93     }
94     return 0;
95 }
96 
SetResolverConfig(uint16_t netId,uint16_t baseTimeoutMsec,uint8_t retryCount,const std::vector<std::string> & servers,const std::vector<std::string> & domains)97 int32_t DnsParamCache::SetResolverConfig(uint16_t netId, uint16_t baseTimeoutMsec, uint8_t retryCount,
98                                          const std::vector<std::string> &servers,
99                                          const std::vector<std::string> &domains)
100 {
101     std::vector<std::string> nameservers = SelectNameservers(servers);
102     NETNATIVE_LOG_D("DnsParamCache::SetResolverConfig, netid:%{public}d, numServers:%{public}d,", netId,
103                     static_cast<int>(nameservers.size()));
104 
105     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
106 
107     // select_domains
108     auto it = serverConfigMap_.find(netId);
109     if (it == serverConfigMap_.end()) {
110         NETNATIVE_LOGE("DnsParamCache::SetResolverConfig failed, netid is non-existent");
111         return -ENOENT;
112     }
113 
114     auto oldDnsServers = it->second.GetServers();
115     std::sort(oldDnsServers.begin(), oldDnsServers.end());
116 
117     auto newDnsServers = servers;
118     std::sort(newDnsServers.begin(), newDnsServers.end());
119 
120     if (oldDnsServers != newDnsServers) {
121         it->second.GetCache().Clear();
122     }
123 
124     it->second.SetNetId(netId);
125     it->second.SetServers(servers);
126     it->second.SetDomains(domains);
127     if (retryCount == 0) {
128         it->second.SetRetryCount(RES_DEFAULT_RETRY);
129     } else {
130         it->second.SetRetryCount(retryCount);
131     }
132     if (baseTimeoutMsec == 0) {
133         it->second.SetTimeoutMsec(RES_TIMEOUT);
134     } else {
135         it->second.SetTimeoutMsec(baseTimeoutMsec);
136     }
137     return 0;
138 }
139 
SetDefaultNetwork(uint16_t netId)140 void DnsParamCache::SetDefaultNetwork(uint16_t netId)
141 {
142     defaultNetId_ = netId;
143 }
144 
EnableIpv6(uint16_t netId)145 void DnsParamCache::EnableIpv6(uint16_t netId)
146 {
147     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
148     auto it = serverConfigMap_.find(netId);
149     if (it == serverConfigMap_.end()) {
150         DNS_CONFIG_PRINT("get Config failed: netid is not have netid:%{public}d,", netId);
151         return;
152     }
153 
154     it->second.EnableIpv6();
155 }
156 
IsIpv6Enable(uint16_t netId)157 bool DnsParamCache::IsIpv6Enable(uint16_t netId)
158 {
159     if (netId == 0) {
160         netId = defaultNetId_;
161     }
162 
163     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
164     auto it = serverConfigMap_.find(netId);
165     if (it == serverConfigMap_.end()) {
166         DNS_CONFIG_PRINT("get Config failed: netid is not have netid:%{public}d,", netId);
167         return false;
168     }
169 
170     return it->second.IsIpv6Enable();
171 }
172 
GetResolverConfig(uint16_t netId,std::vector<std::string> & servers,std::vector<std::string> & domains,uint16_t & baseTimeoutMsec,uint8_t & retryCount)173 int32_t DnsParamCache::GetResolverConfig(uint16_t netId, std::vector<std::string> &servers,
174                                          std::vector<std::string> &domains, uint16_t &baseTimeoutMsec,
175                                          uint8_t &retryCount)
176 {
177     NETNATIVE_LOG_D("DnsParamCache::GetResolverConfig no uid");
178     if (netId == 0) {
179         netId = defaultNetId_;
180         NETNATIVE_LOG_D("defaultNetId_ = [%{public}u]", netId);
181     }
182 
183     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
184     auto it = serverConfigMap_.find(netId);
185     if (it == serverConfigMap_.end()) {
186         DNS_CONFIG_PRINT("get Config failed: netid is not have netid:%{public}d,", netId);
187         return -ENOENT;
188     }
189 
190     servers = it->second.GetServers();
191 #ifdef FEATURE_NET_FIREWALL_ENABLE
192     std::vector<std::string> dns;
193     if (GetDnsServersByAppUid(GetCallingUid(), dns)) {
194         DNS_CONFIG_PRINT("GetResolverConfig hit netfirewall");
195         servers.assign(dns.begin(), dns.end());
196     }
197 #endif
198     domains = it->second.GetDomains();
199     baseTimeoutMsec = it->second.GetTimeoutMsec();
200     retryCount = it->second.GetRetryCount();
201 
202     return 0;
203 }
204 
GetResolverConfig(uint16_t netId,uint32_t uid,std::vector<std::string> & servers,std::vector<std::string> & domains,uint16_t & baseTimeoutMsec,uint8_t & retryCount)205 int32_t DnsParamCache::GetResolverConfig(uint16_t netId, uint32_t uid, std::vector<std::string> &servers,
206                                          std::vector<std::string> &domains, uint16_t &baseTimeoutMsec,
207                                          uint8_t &retryCount)
208 {
209     NETNATIVE_LOG_D("DnsParamCache::GetResolverConfig has uid");
210     if (netId == 0) {
211         netId = defaultNetId_;
212         NETNATIVE_LOG_D("defaultNetId_ = [%{public}u]", netId);
213     }
214 
215     {
216         std::lock_guard<ffrt::mutex> guard(cacheMutex_);
217         for (auto mem : vpnUidRanges_) {
218             if (static_cast<int32_t>(uid) >= mem.begin_ && static_cast<int32_t>(uid) <= mem.end_) {
219                 NETNATIVE_LOG_D("is vpn hap");
220                 auto it = serverConfigMap_.find(vpnNetId_);
221                 if (it == serverConfigMap_.end()) {
222                     NETNATIVE_LOG_D("vpn get Config failed: not have vpnnetid:%{public}d,", vpnNetId_);
223                     break;
224                 }
225                 servers = it->second.GetServers();
226 #ifdef FEATURE_NET_FIREWALL_ENABLE
227                 std::vector<std::string> dns;
228                 if (GetDnsServersByAppUid(GetCallingUid(), dns)) {
229                     DNS_CONFIG_PRINT("GetResolverConfig hit netfirewall");
230                     servers.assign(dns.begin(), dns.end());
231                 }
232 #endif
233                 domains = it->second.GetDomains();
234                 baseTimeoutMsec = it->second.GetTimeoutMsec();
235                 retryCount = it->second.GetRetryCount();
236                 return 0;
237             }
238         }
239     }
240     return GetResolverConfig(netId, servers, domains, baseTimeoutMsec, retryCount);
241 }
242 
GetDefaultNetwork() const243 int32_t DnsParamCache::GetDefaultNetwork() const
244 {
245     return defaultNetId_;
246 }
247 
SetDnsCache(uint16_t netId,const std::string & hostName,const AddrInfo & addrInfo)248 void DnsParamCache::SetDnsCache(uint16_t netId, const std::string &hostName, const AddrInfo &addrInfo)
249 {
250     if (netId == 0) {
251         netId = defaultNetId_;
252     }
253     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
254 #ifdef FEATURE_NET_FIREWALL_ENABLE
255     int32_t appUid = static_cast<int32_t>(GetCallingUid());
256     bool isMatchAllow = false;
257     if (IsInterceptDomain(appUid, hostName, isMatchAllow)) {
258         DNS_CONFIG_PRINT("SetDnsCache failed: domain was Intercepted: %{public}s,", hostName.c_str());
259         return;
260     }
261     if (isMatchAllow && (addrInfo.aiFamily == AF_INET || addrInfo.aiFamily == AF_INET6)) {
262         NetAddrInfo netInfo;
263         netInfo.aiFamily = addrInfo.aiFamily;
264         if (addrInfo.aiFamily == AF_INET) {
265             netInfo.aiAddr.sin = addrInfo.aiAddr.sin.sin_addr;
266         } else {
267             memcpy_s(&netInfo.aiAddr.sin6, sizeof(addrInfo.aiAddr.sin6.sin6_addr), &addrInfo.aiAddr.sin6.sin6_addr,
268                      sizeof(addrInfo.aiAddr.sin6.sin6_addr));
269         }
270         OHOS::NetManagerStandard::NetsysBpfNetFirewall::GetInstance()->AddDomainCache(netInfo);
271     }
272 #endif
273     auto it = serverConfigMap_.find(netId);
274     if (it == serverConfigMap_.end()) {
275         DNS_CONFIG_PRINT("SetDnsCache failed: netid is not have netid:%{public}d,", netId);
276         return;
277     }
278 
279     it->second.GetCache().Put(hostName, addrInfo);
280 }
281 
GetDnsCache(uint16_t netId,const std::string & hostName)282 std::vector<AddrInfo> DnsParamCache::GetDnsCache(uint16_t netId, const std::string &hostName)
283 {
284     if (netId == 0) {
285         netId = defaultNetId_;
286     }
287 
288     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
289 #ifdef FEATURE_NET_FIREWALL_ENABLE
290     int32_t appUid = static_cast<int32_t>(GetCallingUid());
291     bool isMatchAllow = false;
292     if (IsInterceptDomain(appUid, hostName, isMatchAllow)) {
293         NotifyDomianIntercept(appUid, hostName);
294         AddrInfo fakeAddr = { 0 };
295         fakeAddr.aiFamily = AF_UNSPEC;
296         fakeAddr.aiAddr.sin.sin_family = AF_UNSPEC;
297         fakeAddr.aiAddr.sin.sin_addr.s_addr = INADDR_NONE;
298         fakeAddr.aiAddrLen = sizeof(struct sockaddr_in);
299         return { fakeAddr };
300     }
301 #endif
302 
303     auto it = serverConfigMap_.find(netId);
304     if (it == serverConfigMap_.end()) {
305         DNS_CONFIG_PRINT("GetDnsCache failed: netid is not have netid:%{public}d,", netId);
306         return {};
307     }
308 
309     return it->second.GetCache().Get(hostName);
310 }
311 
SetCacheDelayed(uint16_t netId,const std::string & hostName)312 void DnsParamCache::SetCacheDelayed(uint16_t netId, const std::string &hostName)
313 {
314     if (netId == 0) {
315         netId = defaultNetId_;
316     }
317 
318     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
319     auto it = serverConfigMap_.find(netId);
320     if (it == serverConfigMap_.end()) {
321         DNS_CONFIG_PRINT("SetCacheDelayed failed: netid is not have netid:%{public}d,", netId);
322         return;
323     }
324 
325     it->second.SetCacheDelayed(hostName);
326 }
327 
AddUidRange(uint32_t netId,const std::vector<NetManagerStandard::UidRange> & uidRanges)328 int32_t DnsParamCache::AddUidRange(uint32_t netId, const std::vector<NetManagerStandard::UidRange> &uidRanges)
329 {
330     std::lock_guard<ffrt::mutex> guard(uidRangeMutex_);
331     NETNATIVE_LOG_D("DnsParamCache::AddUidRange size = [%{public}zu]", uidRanges.size());
332     vpnNetId_ = netId;
333     auto middle = vpnUidRanges_.insert(vpnUidRanges_.end(), uidRanges.begin(), uidRanges.end());
334     std::inplace_merge(vpnUidRanges_.begin(), middle, vpnUidRanges_.end());
335     return 0;
336 }
337 
DelUidRange(uint32_t netId,const std::vector<NetManagerStandard::UidRange> & uidRanges)338 int32_t DnsParamCache::DelUidRange(uint32_t netId, const std::vector<NetManagerStandard::UidRange> &uidRanges)
339 {
340     std::lock_guard<ffrt::mutex> guard(uidRangeMutex_);
341     NETNATIVE_LOG_D("DnsParamCache::DelUidRange size = [%{public}zu]", uidRanges.size());
342     vpnNetId_ = 0;
343     auto end = std::set_difference(vpnUidRanges_.begin(), vpnUidRanges_.end(), uidRanges.begin(),
344                                    uidRanges.end(), vpnUidRanges_.begin());
345     vpnUidRanges_.erase(end, vpnUidRanges_.end());
346     return 0;
347 }
348 
IsVpnOpen() const349 bool DnsParamCache::IsVpnOpen() const
350 {
351     return vpnUidRanges_.size();
352 }
353 
354 #ifdef FEATURE_NET_FIREWALL_ENABLE
GetUserId(int32_t appUid)355 int32_t DnsParamCache::GetUserId(int32_t appUid)
356 {
357     int32_t userId = appUid / USER_ID_DIVIDOR;
358     return userId > 0 ? userId : currentUserId_;
359 }
360 
GetDnsServersByAppUid(int32_t appUid,std::vector<std::string> & servers)361 bool DnsParamCache::GetDnsServersByAppUid(int32_t appUid, std::vector<std::string> &servers)
362 {
363     if (netFirewallDnsRuleMap_.empty()) {
364         return false;
365     }
366     DNS_CONFIG_PRINT("GetDnsServersByAppUid: appUid=%{public}d", appUid);
367     auto it = netFirewallDnsRuleMap_.find(appUid);
368     if (it == netFirewallDnsRuleMap_.end()) {
369         // if appUid not found, try to find invalid appUid=0;
370         it = netFirewallDnsRuleMap_.find(0);
371     }
372     if (it != netFirewallDnsRuleMap_.end()) {
373         int32_t userId = GetUserId(appUid);
374         std::vector<sptr<NetFirewallDnsRule>> rules = it->second;
375         for (const auto &rule : rules) {
376             if (rule->userId != userId) {
377                 continue;
378             }
379             servers.emplace_back(rule->primaryDns);
380             servers.emplace_back(rule->standbyDns);
381         }
382         return true;
383     }
384     return false;
385 }
386 
SetFirewallRules(NetFirewallRuleType type,const std::vector<sptr<NetFirewallBaseRule>> & ruleList,bool isFinish)387 int32_t DnsParamCache::SetFirewallRules(NetFirewallRuleType type,
388                                         const std::vector<sptr<NetFirewallBaseRule>> &ruleList, bool isFinish)
389 {
390     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
391     NETNATIVE_LOGI("SetFirewallRules: size=%{public}zu isFinish=%{public}" PRId32, ruleList.size(), isFinish);
392     if (ruleList.empty()) {
393         NETNATIVE_LOGE("SetFirewallRules: rules is empty");
394         return -1;
395     }
396     int32_t ret = 0;
397     switch (type) {
398         case NetFirewallRuleType::RULE_DNS: {
399             for (const auto &rule : ruleList) {
400                 firewallDnsRules_.emplace_back(firewall_rule_cast<NetFirewallDnsRule>(rule));
401             }
402             if (isFinish) {
403                 ret = SetFirewallDnsRules(firewallDnsRules_);
404                 firewallDnsRules_.clear();
405             }
406             break;
407         }
408         case NetFirewallRuleType::RULE_DOMAIN: {
409             ClearAllDnsCache();
410             for (const auto &rule : ruleList) {
411                 firewallDomainRules_.emplace_back(firewall_rule_cast<NetFirewallDomainRule>(rule));
412             }
413             if (isFinish) {
414                 ret = SetFirewallDomainRules(firewallDomainRules_);
415                 firewallDomainRules_.clear();
416                 OHOS::NetManagerStandard::NetsysBpfNetFirewall::GetInstance()->ClearDomainCache();
417             }
418             break;
419         }
420         default:
421             break;
422     }
423     return ret;
424 }
425 
SetFirewallDnsRules(const std::vector<sptr<NetFirewallDnsRule>> & ruleList)426 int32_t DnsParamCache::SetFirewallDnsRules(const std::vector<sptr<NetFirewallDnsRule>> &ruleList)
427 {
428     for (const auto &rule : ruleList) {
429         std::vector<sptr<NetFirewallDnsRule>> rules;
430         auto it = netFirewallDnsRuleMap_.find(rule->appUid);
431         if (it != netFirewallDnsRuleMap_.end()) {
432             rules = it->second;
433         }
434         rules.emplace_back(std::move(rule));
435         netFirewallDnsRuleMap_.emplace(rule->appUid, std::move(rules));
436     }
437     return 0;
438 }
439 
GetFirewallRuleAction(int32_t appUid,const std::vector<sptr<NetFirewallDomainRule>> & rules)440 FirewallRuleAction DnsParamCache::GetFirewallRuleAction(int32_t appUid,
441                                                         const std::vector<sptr<NetFirewallDomainRule>> &rules)
442 {
443     int32_t userId = GetUserId(appUid);
444     for (const auto &rule : rules) {
445         if (rule->userId != userId) {
446             continue;
447         }
448         if ((rule->appUid && appUid == rule->appUid) || !rule->appUid) {
449             return rule->ruleAction;
450         }
451     }
452 
453     return FirewallRuleAction::RULE_INVALID;
454 }
455 
checkEmpty4InterceptDomain(const std::string & hostName)456 bool DnsParamCache::checkEmpty4InterceptDomain(const std::string &hostName)
457 {
458     if (hostName.empty()) {
459         return true;
460     }
461     if (!netFirewallDomainRulesAllowMap_.empty() || !netFirewallDomainRulesDenyMap_.empty()) {
462         return false;
463     }
464     if (domainAllowLsmTrie_ && !domainAllowLsmTrie_->Empty()) {
465         return false;
466     }
467     return !domainDenyLsmTrie_ || domainDenyLsmTrie_->Empty();
468 }
469 
IsInterceptDomain(int32_t appUid,const std::string & hostName,bool & isMatchAllow)470 bool DnsParamCache::IsInterceptDomain(int32_t appUid, const std::string &hostName, bool &isMatchAllow)
471 {
472     if (checkEmpty4InterceptDomain(hostName)) {
473         return false;
474     }
475     std::string host = hostName.substr(0, hostName.find(' '));
476     DNS_CONFIG_PRINT("IsInterceptDomain: appUid: %{public}d, hostName: %{private}s", appUid, host.c_str());
477     std::transform(host.begin(), host.end(), host.begin(), ::tolower);
478     std::vector<sptr<NetFirewallDomainRule>> rules;
479     FirewallRuleAction exactAllowAction = FirewallRuleAction::RULE_INVALID;
480     auto it = netFirewallDomainRulesAllowMap_.find(host);
481     if (it != netFirewallDomainRulesAllowMap_.end()) {
482         rules = it->second;
483         exactAllowAction = GetFirewallRuleAction(appUid, rules);
484     }
485     FirewallRuleAction exactDenyAction = FirewallRuleAction::RULE_INVALID;
486     auto iter = netFirewallDomainRulesDenyMap_.find(host);
487     if (iter != netFirewallDomainRulesDenyMap_.end()) {
488         rules = iter->second;
489         exactDenyAction = GetFirewallRuleAction(appUid, rules);
490     }
491     FirewallRuleAction wildcardAllowAction = FirewallRuleAction::RULE_INVALID;
492     if (domainAllowLsmTrie_->LongestSuffixMatch(host, rules)) {
493         wildcardAllowAction = GetFirewallRuleAction(appUid, rules);
494     }
495     FirewallRuleAction wildcardDenyAction = FirewallRuleAction::RULE_INVALID;
496     if (domainDenyLsmTrie_->LongestSuffixMatch(host, rules)) {
497         wildcardDenyAction = GetFirewallRuleAction(appUid, rules);
498     }
499     isMatchAllow = (exactAllowAction != FirewallRuleAction::RULE_INVALID) ||
500                    (wildcardAllowAction != FirewallRuleAction::RULE_INVALID);
501     bool isDeny = (exactDenyAction != FirewallRuleAction::RULE_INVALID) ||
502                   (wildcardDenyAction != FirewallRuleAction::RULE_INVALID);
503     if (isMatchAllow) {
504         // Apply default rules in case of conflict
505         return isDeny && (firewallDefaultAction_ == FirewallRuleAction::RULE_DENY);
506     }
507     return isDeny;
508 }
509 
SetFirewallDefaultAction(FirewallRuleAction inDefault,FirewallRuleAction outDefault)510 int32_t DnsParamCache::SetFirewallDefaultAction(FirewallRuleAction inDefault, FirewallRuleAction outDefault)
511 {
512     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
513     DNS_CONFIG_PRINT("SetFirewallDefaultAction: firewallDefaultAction_: %{public}d", (int)outDefault);
514     firewallDefaultAction_ = outDefault;
515     return 0;
516 }
517 
BuildFirewallDomainLsmTrie(const sptr<NetFirewallDomainRule> & rule,const std::string & domain)518 void DnsParamCache::BuildFirewallDomainLsmTrie(const sptr<NetFirewallDomainRule> &rule, const std::string &domain)
519 {
520     std::vector<sptr<NetFirewallDomainRule>> rules;
521     std::string suffix(domain);
522     auto wildcardCharIndex = suffix.find('*');
523     if (wildcardCharIndex != std::string::npos) {
524         suffix = suffix.substr(wildcardCharIndex + 1);
525     }
526     DNS_CONFIG_PRINT("BuildFirewallDomainLsmTrie: suffix: %{public}s", suffix.c_str());
527     std::transform(suffix.begin(), suffix.end(), suffix.begin(), ::tolower);
528     if (rule->ruleAction == FirewallRuleAction::RULE_DENY) {
529         if (domainDenyLsmTrie_->LongestSuffixMatch(suffix, rules)) {
530             rules.emplace_back(std::move(rule));
531             domainDenyLsmTrie_->Update(suffix, rules);
532             return;
533         }
534         rules.emplace_back(std::move(rule));
535         domainDenyLsmTrie_->Insert(suffix, rules);
536     } else {
537         if (domainAllowLsmTrie_->LongestSuffixMatch(suffix, rules)) {
538             rules.emplace_back(std::move(rule));
539             domainAllowLsmTrie_->Update(suffix, rules);
540             return;
541         }
542         rules.emplace_back(std::move(rule));
543         domainAllowLsmTrie_->Insert(suffix, rules);
544     }
545 }
546 
BuildFirewallDomainMap(const sptr<NetFirewallDomainRule> & rule,const std::string & raw)547 void DnsParamCache::BuildFirewallDomainMap(const sptr<NetFirewallDomainRule> &rule, const std::string &raw)
548 {
549     DNS_CONFIG_PRINT("BuildFirewallDomainMap: domain: %{public}s", raw.c_str());
550     std::string domain(raw);
551     std::vector<sptr<NetFirewallDomainRule>> rules;
552     std::transform(domain.begin(), domain.end(), domain.begin(), ::tolower);
553     if (rule->ruleAction == FirewallRuleAction::RULE_DENY) {
554         auto it = netFirewallDomainRulesDenyMap_.find(domain);
555         if (it != netFirewallDomainRulesDenyMap_.end()) {
556             rules = it->second;
557         }
558 
559         rules.emplace_back(std::move(rule));
560         netFirewallDomainRulesDenyMap_.emplace(domain, std::move(rules));
561     } else {
562         auto it = netFirewallDomainRulesAllowMap_.find(domain);
563         if (it != netFirewallDomainRulesAllowMap_.end()) {
564             rules = it->second;
565         }
566 
567         rules.emplace_back(rule);
568         netFirewallDomainRulesAllowMap_.emplace(domain, std::move(rules));
569     }
570 }
571 
SetFirewallDomainRules(const std::vector<sptr<NetFirewallDomainRule>> & ruleList)572 int32_t DnsParamCache::SetFirewallDomainRules(const std::vector<sptr<NetFirewallDomainRule>> &ruleList)
573 {
574     if (!domainAllowLsmTrie_) {
575         domainAllowLsmTrie_ =
576             std::make_shared<NetManagerStandard::SuffixMatchTrie<std::vector<sptr<NetFirewallDomainRule>>>>();
577     }
578     if (!domainDenyLsmTrie_) {
579         domainDenyLsmTrie_ =
580             std::make_shared<NetManagerStandard::SuffixMatchTrie<std::vector<sptr<NetFirewallDomainRule>>>>();
581     }
582     for (const auto &rule : ruleList) {
583         for (const auto &param : rule->domains) {
584             if (param.isWildcard) {
585                 BuildFirewallDomainLsmTrie(rule, param.domain);
586             } else {
587                 BuildFirewallDomainMap(rule, param.domain);
588             }
589         }
590     }
591     return 0;
592 }
593 
ClearFirewallRules(NetFirewallRuleType type)594 int32_t DnsParamCache::ClearFirewallRules(NetFirewallRuleType type)
595 {
596     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
597     switch (type) {
598         case NetFirewallRuleType::RULE_DNS:
599             firewallDnsRules_.clear();
600             netFirewallDnsRuleMap_.clear();
601             break;
602         case NetFirewallRuleType::RULE_DOMAIN: {
603             firewallDomainRules_.clear();
604             netFirewallDomainRulesAllowMap_.clear();
605             netFirewallDomainRulesDenyMap_.clear();
606             if (domainAllowLsmTrie_) {
607                 domainAllowLsmTrie_ = nullptr;
608             }
609             if (domainDenyLsmTrie_) {
610                 domainDenyLsmTrie_ = nullptr;
611             }
612             OHOS::NetManagerStandard::NetsysBpfNetFirewall::GetInstance()->ClearDomainCache();
613             break;
614         }
615         case NetFirewallRuleType::RULE_ALL: {
616             firewallDnsRules_.clear();
617             netFirewallDnsRuleMap_.clear();
618             firewallDomainRules_.clear();
619             netFirewallDomainRulesAllowMap_.clear();
620             netFirewallDomainRulesDenyMap_.clear();
621             if (domainAllowLsmTrie_) {
622                 domainAllowLsmTrie_ = nullptr;
623             }
624             if (domainDenyLsmTrie_) {
625                 domainDenyLsmTrie_ = nullptr;
626             }
627             OHOS::NetManagerStandard::NetsysBpfNetFirewall::GetInstance()->ClearDomainCache();
628             break;
629         }
630         default:
631             break;
632     }
633     return 0;
634 }
635 
NotifyDomianIntercept(int32_t appUid,const std::string & hostName)636 void DnsParamCache::NotifyDomianIntercept(int32_t appUid, const std::string &hostName)
637 {
638     if (hostName.empty()) {
639         return;
640     }
641     std::string host = hostName.substr(0, hostName.find(' '));
642     NETNATIVE_LOGI("NotifyDomianIntercept: appUid: %{public}d, hostName: %{private}s", appUid, host.c_str());
643     sptr<InterceptRecord> record = sptr<InterceptRecord>::MakeSptr();
644     record->time = (int32_t)time(NULL);
645     record->appUid = appUid;
646     record->domain = host;
647 
648     if (oldRecord_ != nullptr && (record->time - oldRecord_->time) < INTERCEPT_BUFF_INTERVAL_SEC) {
649         if (record->appUid == oldRecord_->appUid && record->domain == oldRecord_->domain) {
650             return;
651         }
652     }
653     oldRecord_ = record;
654     for (const auto &callback : callbacks_) {
655         callback->OnIntercept(record);
656     }
657 }
658 
RegisterNetFirewallCallback(const sptr<NetsysNative::INetFirewallCallback> & callback)659 int32_t DnsParamCache::RegisterNetFirewallCallback(const sptr<NetsysNative::INetFirewallCallback> &callback)
660 {
661     if (!callback) {
662         return -1;
663     }
664 
665     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
666     callbacks_.emplace_back(callback);
667 
668     return 0;
669 }
670 
UnRegisterNetFirewallCallback(const sptr<NetsysNative::INetFirewallCallback> & callback)671 int32_t DnsParamCache::UnRegisterNetFirewallCallback(const sptr<NetsysNative::INetFirewallCallback> &callback)
672 {
673     if (!callback) {
674         return -1;
675     }
676 
677     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
678     for (auto it = callbacks_.begin(); it != callbacks_.end(); ++it) {
679         if (*it == callback) {
680             callbacks_.erase(it);
681             return 0;
682         }
683     }
684     return -1;
685 }
686 
SetFirewallCurrentUserId(int32_t userId)687 int32_t DnsParamCache::SetFirewallCurrentUserId(int32_t userId)
688 {
689     currentUserId_ = userId;
690     ClearAllDnsCache();
691     return 0;
692 }
693 
ClearAllDnsCache()694 void DnsParamCache::ClearAllDnsCache()
695 {
696     NETNATIVE_LOGI("ClearAllDnsCache");
697     for (auto it = serverConfigMap_.begin(); it != serverConfigMap_.end(); it++) {
698         it->second.GetCache().Clear();
699     }
700 }
701 #endif
702 
GetDumpInfo(std::string & info)703 void DnsParamCache::GetDumpInfo(std::string &info)
704 {
705     std::string dnsData;
706     static const std::string TAB = "  ";
707     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
708     std::for_each(serverConfigMap_.begin(), serverConfigMap_.end(), [&dnsData](const auto &serverConfig) {
709         dnsData.append(TAB + "NetId: " + std::to_string(serverConfig.second.GetNetId()) + "\n");
710         dnsData.append(TAB + "TimeoutMsec: " + std::to_string(serverConfig.second.GetTimeoutMsec()) + "\n");
711         dnsData.append(TAB + "RetryCount: " + std::to_string(serverConfig.second.GetRetryCount()) + "\n");
712         dnsData.append(TAB + "Servers:");
713         GetVectorData(serverConfig.second.GetServers(), dnsData);
714         dnsData.append(TAB + "Domains:");
715         GetVectorData(serverConfig.second.GetDomains(), dnsData);
716     });
717     info.append(dnsData);
718 }
719 
SetUserDefinedServerFlag(uint16_t netId,bool flag)720 int32_t DnsParamCache::SetUserDefinedServerFlag(uint16_t netId, bool flag)
721 {
722     NETNATIVE_LOGI("DnsParamCache::SetUserDefinedServerFlag, netid:%{public}d, flag:%{public}d,", netId, flag);
723 
724     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
725     // select_domains
726     auto it = serverConfigMap_.find(netId);
727     if (it == serverConfigMap_.end()) {
728         NETNATIVE_LOGE("DnsParamCache::SetUserDefinedServerFlag failed, netid is non-existent");
729         return -ENOENT;
730     }
731     it->second.SetUserDefinedServerFlag(flag);
732     return 0;
733 }
734 
GetUserDefinedServerFlag(uint16_t netId,bool & flag)735 int32_t DnsParamCache::GetUserDefinedServerFlag(uint16_t netId, bool &flag)
736 {
737     if (netId == 0) {
738         netId = defaultNetId_;
739         NETNATIVE_LOG_D("defaultNetId_ = [%{public}u]", netId);
740     }
741     std::lock_guard<ffrt::mutex> guard(cacheMutex_);
742     auto it = serverConfigMap_.find(netId);
743     if (it == serverConfigMap_.end()) {
744         DNS_CONFIG_PRINT("GetUserDefinedServerFlag failed: netid is not have netid:%{public}d,", netId);
745         return -ENOENT;
746     }
747     flag = it->second.IsUserDefinedServer();
748     return 0;
749 }
750 
GetUserDefinedServerFlag(uint16_t netId,bool & flag,uint32_t uid)751 int32_t DnsParamCache::GetUserDefinedServerFlag(uint16_t netId, bool &flag, uint32_t uid)
752 {
753     if (netId == 0) {
754         netId = defaultNetId_;
755         NETNATIVE_LOG_D("defaultNetId_ = [%{public}u]", netId);
756     }
757     {
758         std::lock_guard<ffrt::mutex> guard(cacheMutex_);
759         for (auto mem : vpnUidRanges_) {
760             if (static_cast<int32_t>(uid) >= mem.begin_ && static_cast<int32_t>(uid) <= mem.end_) {
761                 NETNATIVE_LOG_D("is vpn hap");
762                 auto it = serverConfigMap_.find(vpnNetId_);
763                 if (it == serverConfigMap_.end()) {
764                     NETNATIVE_LOG_D("vpn get Config failed: not have vpnnetid:%{public}d,", vpnNetId_);
765                     break;
766                 }
767                 flag = it->second.IsUserDefinedServer();
768                 return 0;
769             }
770         }
771         auto it = serverConfigMap_.find(netId);
772         if (it == serverConfigMap_.end()) {
773             DNS_CONFIG_PRINT("GetUserDefinedServerFlag failed: netid is not have netid:%{public}d,", netId);
774             return -ENOENT;
775         }
776     }
777     return GetUserDefinedServerFlag(netId, flag);
778 }
779 } // namespace OHOS::nmd
780