1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "dns_param_cache.h"

#include <algorithm>
#include <iterator>

#include "netmanager_base_common_utils.h"
#ifdef FEATURE_NET_FIREWALL_ENABLE
#include "bpf_netfirewall.h"
#include "netfirewall_parcel.h"
#include <ctime>
#endif
26
27 namespace OHOS::nmd {
28 using namespace OHOS::NetManagerStandard::CommonUtils;
29 namespace {
GetVectorData(const std::vector<std::string> & data,std::string & result)30 void GetVectorData(const std::vector<std::string> &data, std::string &result)
31 {
32 result.append("{ ");
33 std::for_each(data.begin(), data.end(), [&result](const auto &str) { result.append(ToAnonymousIp(str) + ", "); });
34 result.append("}\n");
35 }
36 constexpr int RES_TIMEOUT = 4000; // min. milliseconds between retries
37 constexpr int RES_DEFAULT_RETRY = 2; // Default
38 } // namespace
39
DnsParamCache()40 DnsParamCache::DnsParamCache() : defaultNetId_(0) {}
41
GetInstance()42 DnsParamCache &DnsParamCache::GetInstance()
43 {
44 static DnsParamCache instance;
45 return instance;
46 }
47
SelectNameservers(const std::vector<std::string> & servers)48 std::vector<std::string> DnsParamCache::SelectNameservers(const std::vector<std::string> &servers)
49 {
50 std::vector<std::string> res = servers;
51 if (res.size() > MAX_SERVER_NUM - 1) {
52 res.resize(MAX_SERVER_NUM - 1);
53 }
54 return res;
55 }
56
CreateCacheForNet(uint16_t netId,bool isVpnNet)57 int32_t DnsParamCache::CreateCacheForNet(uint16_t netId, bool isVpnNet)
58 {
59 NETNATIVE_LOGI("DnsParamCache::CreateCacheForNet, netid:%{public}d,", netId);
60 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
61 auto it = serverConfigMap_.find(netId);
62 if (it != serverConfigMap_.end()) {
63 NETNATIVE_LOGE("DnsParamCache::CreateCacheForNet, netid already exist, no need to create");
64 return -EEXIST;
65 }
66 serverConfigMap_[netId].SetNetId(netId);
67 if (isVpnNet) {
68 NETNATIVE_LOGI("DnsParamCache::CreateCacheForNet clear all dns cache when vpn net create");
69 for (auto iterator = serverConfigMap_.begin(); iterator != serverConfigMap_.end(); iterator++) {
70 iterator->second.GetCache().Clear();
71 }
72 }
73 return 0;
74 }
75
DestroyNetworkCache(uint16_t netId,bool isVpnNet)76 int32_t DnsParamCache::DestroyNetworkCache(uint16_t netId, bool isVpnNet)
77 {
78 NETNATIVE_LOGI("DnsParamCache::DestroyNetworkCache, netid:%{public}d, %{public}d", netId, isVpnNet);
79 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
80 auto it = serverConfigMap_.find(netId);
81 if (it == serverConfigMap_.end()) {
82 return -ENOENT;
83 }
84 serverConfigMap_.erase(it);
85 if (defaultNetId_ == netId) {
86 defaultNetId_ = 0;
87 }
88 if (isVpnNet) {
89 NETNATIVE_LOGI("DnsParamCache::DestroyNetworkCache clear all dns cache when vpn net destroy");
90 for (auto it = serverConfigMap_.begin(); it != serverConfigMap_.end(); it++) {
91 it->second.GetCache().Clear();
92 }
93 }
94 return 0;
95 }
96
SetResolverConfig(uint16_t netId,uint16_t baseTimeoutMsec,uint8_t retryCount,const std::vector<std::string> & servers,const std::vector<std::string> & domains)97 int32_t DnsParamCache::SetResolverConfig(uint16_t netId, uint16_t baseTimeoutMsec, uint8_t retryCount,
98 const std::vector<std::string> &servers,
99 const std::vector<std::string> &domains)
100 {
101 std::vector<std::string> nameservers = SelectNameservers(servers);
102 NETNATIVE_LOG_D("DnsParamCache::SetResolverConfig, netid:%{public}d, numServers:%{public}d,", netId,
103 static_cast<int>(nameservers.size()));
104
105 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
106
107 // select_domains
108 auto it = serverConfigMap_.find(netId);
109 if (it == serverConfigMap_.end()) {
110 NETNATIVE_LOGE("DnsParamCache::SetResolverConfig failed, netid is non-existent");
111 return -ENOENT;
112 }
113
114 auto oldDnsServers = it->second.GetServers();
115 std::sort(oldDnsServers.begin(), oldDnsServers.end());
116
117 auto newDnsServers = servers;
118 std::sort(newDnsServers.begin(), newDnsServers.end());
119
120 if (oldDnsServers != newDnsServers) {
121 it->second.GetCache().Clear();
122 }
123
124 it->second.SetNetId(netId);
125 it->second.SetServers(servers);
126 it->second.SetDomains(domains);
127 if (retryCount == 0) {
128 it->second.SetRetryCount(RES_DEFAULT_RETRY);
129 } else {
130 it->second.SetRetryCount(retryCount);
131 }
132 if (baseTimeoutMsec == 0) {
133 it->second.SetTimeoutMsec(RES_TIMEOUT);
134 } else {
135 it->second.SetTimeoutMsec(baseTimeoutMsec);
136 }
137 return 0;
138 }
139
SetDefaultNetwork(uint16_t netId)140 void DnsParamCache::SetDefaultNetwork(uint16_t netId)
141 {
142 defaultNetId_ = netId;
143 }
144
EnableIpv6(uint16_t netId)145 void DnsParamCache::EnableIpv6(uint16_t netId)
146 {
147 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
148 auto it = serverConfigMap_.find(netId);
149 if (it == serverConfigMap_.end()) {
150 DNS_CONFIG_PRINT("get Config failed: netid is not have netid:%{public}d,", netId);
151 return;
152 }
153
154 it->second.EnableIpv6();
155 }
156
IsIpv6Enable(uint16_t netId)157 bool DnsParamCache::IsIpv6Enable(uint16_t netId)
158 {
159 if (netId == 0) {
160 netId = defaultNetId_;
161 }
162
163 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
164 auto it = serverConfigMap_.find(netId);
165 if (it == serverConfigMap_.end()) {
166 DNS_CONFIG_PRINT("get Config failed: netid is not have netid:%{public}d,", netId);
167 return false;
168 }
169
170 return it->second.IsIpv6Enable();
171 }
172
GetResolverConfig(uint16_t netId,std::vector<std::string> & servers,std::vector<std::string> & domains,uint16_t & baseTimeoutMsec,uint8_t & retryCount)173 int32_t DnsParamCache::GetResolverConfig(uint16_t netId, std::vector<std::string> &servers,
174 std::vector<std::string> &domains, uint16_t &baseTimeoutMsec,
175 uint8_t &retryCount)
176 {
177 NETNATIVE_LOG_D("DnsParamCache::GetResolverConfig no uid");
178 if (netId == 0) {
179 netId = defaultNetId_;
180 NETNATIVE_LOG_D("defaultNetId_ = [%{public}u]", netId);
181 }
182
183 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
184 auto it = serverConfigMap_.find(netId);
185 if (it == serverConfigMap_.end()) {
186 DNS_CONFIG_PRINT("get Config failed: netid is not have netid:%{public}d,", netId);
187 return -ENOENT;
188 }
189
190 servers = it->second.GetServers();
191 #ifdef FEATURE_NET_FIREWALL_ENABLE
192 std::vector<std::string> dns;
193 if (GetDnsServersByAppUid(GetCallingUid(), dns)) {
194 DNS_CONFIG_PRINT("GetResolverConfig hit netfirewall");
195 servers.assign(dns.begin(), dns.end());
196 }
197 #endif
198 domains = it->second.GetDomains();
199 baseTimeoutMsec = it->second.GetTimeoutMsec();
200 retryCount = it->second.GetRetryCount();
201
202 return 0;
203 }
204
GetResolverConfig(uint16_t netId,uint32_t uid,std::vector<std::string> & servers,std::vector<std::string> & domains,uint16_t & baseTimeoutMsec,uint8_t & retryCount)205 int32_t DnsParamCache::GetResolverConfig(uint16_t netId, uint32_t uid, std::vector<std::string> &servers,
206 std::vector<std::string> &domains, uint16_t &baseTimeoutMsec,
207 uint8_t &retryCount)
208 {
209 NETNATIVE_LOG_D("DnsParamCache::GetResolverConfig has uid");
210 if (netId == 0) {
211 netId = defaultNetId_;
212 NETNATIVE_LOG_D("defaultNetId_ = [%{public}u]", netId);
213 }
214
215 {
216 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
217 for (auto mem : vpnUidRanges_) {
218 if (static_cast<int32_t>(uid) >= mem.begin_ && static_cast<int32_t>(uid) <= mem.end_) {
219 NETNATIVE_LOG_D("is vpn hap");
220 auto it = serverConfigMap_.find(vpnNetId_);
221 if (it == serverConfigMap_.end()) {
222 NETNATIVE_LOG_D("vpn get Config failed: not have vpnnetid:%{public}d,", vpnNetId_);
223 break;
224 }
225 servers = it->second.GetServers();
226 #ifdef FEATURE_NET_FIREWALL_ENABLE
227 std::vector<std::string> dns;
228 if (GetDnsServersByAppUid(GetCallingUid(), dns)) {
229 DNS_CONFIG_PRINT("GetResolverConfig hit netfirewall");
230 servers.assign(dns.begin(), dns.end());
231 }
232 #endif
233 domains = it->second.GetDomains();
234 baseTimeoutMsec = it->second.GetTimeoutMsec();
235 retryCount = it->second.GetRetryCount();
236 return 0;
237 }
238 }
239 }
240 return GetResolverConfig(netId, servers, domains, baseTimeoutMsec, retryCount);
241 }
242
GetDefaultNetwork() const243 int32_t DnsParamCache::GetDefaultNetwork() const
244 {
245 return defaultNetId_;
246 }
247
SetDnsCache(uint16_t netId,const std::string & hostName,const AddrInfo & addrInfo)248 void DnsParamCache::SetDnsCache(uint16_t netId, const std::string &hostName, const AddrInfo &addrInfo)
249 {
250 if (netId == 0) {
251 netId = defaultNetId_;
252 }
253 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
254 #ifdef FEATURE_NET_FIREWALL_ENABLE
255 int32_t appUid = static_cast<int32_t>(GetCallingUid());
256 bool isMatchAllow = false;
257 if (IsInterceptDomain(appUid, hostName, isMatchAllow)) {
258 DNS_CONFIG_PRINT("SetDnsCache failed: domain was Intercepted: %{public}s,", hostName.c_str());
259 return;
260 }
261 if (isMatchAllow && (addrInfo.aiFamily == AF_INET || addrInfo.aiFamily == AF_INET6)) {
262 NetAddrInfo netInfo;
263 netInfo.aiFamily = addrInfo.aiFamily;
264 if (addrInfo.aiFamily == AF_INET) {
265 netInfo.aiAddr.sin = addrInfo.aiAddr.sin.sin_addr;
266 } else {
267 memcpy_s(&netInfo.aiAddr.sin6, sizeof(addrInfo.aiAddr.sin6.sin6_addr), &addrInfo.aiAddr.sin6.sin6_addr,
268 sizeof(addrInfo.aiAddr.sin6.sin6_addr));
269 }
270 OHOS::NetManagerStandard::NetsysBpfNetFirewall::GetInstance()->AddDomainCache(netInfo);
271 }
272 #endif
273 auto it = serverConfigMap_.find(netId);
274 if (it == serverConfigMap_.end()) {
275 DNS_CONFIG_PRINT("SetDnsCache failed: netid is not have netid:%{public}d,", netId);
276 return;
277 }
278
279 it->second.GetCache().Put(hostName, addrInfo);
280 }
281
GetDnsCache(uint16_t netId,const std::string & hostName)282 std::vector<AddrInfo> DnsParamCache::GetDnsCache(uint16_t netId, const std::string &hostName)
283 {
284 if (netId == 0) {
285 netId = defaultNetId_;
286 }
287
288 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
289 #ifdef FEATURE_NET_FIREWALL_ENABLE
290 int32_t appUid = static_cast<int32_t>(GetCallingUid());
291 bool isMatchAllow = false;
292 if (IsInterceptDomain(appUid, hostName, isMatchAllow)) {
293 NotifyDomianIntercept(appUid, hostName);
294 AddrInfo fakeAddr = { 0 };
295 fakeAddr.aiFamily = AF_UNSPEC;
296 fakeAddr.aiAddr.sin.sin_family = AF_UNSPEC;
297 fakeAddr.aiAddr.sin.sin_addr.s_addr = INADDR_NONE;
298 fakeAddr.aiAddrLen = sizeof(struct sockaddr_in);
299 return { fakeAddr };
300 }
301 #endif
302
303 auto it = serverConfigMap_.find(netId);
304 if (it == serverConfigMap_.end()) {
305 DNS_CONFIG_PRINT("GetDnsCache failed: netid is not have netid:%{public}d,", netId);
306 return {};
307 }
308
309 return it->second.GetCache().Get(hostName);
310 }
311
SetCacheDelayed(uint16_t netId,const std::string & hostName)312 void DnsParamCache::SetCacheDelayed(uint16_t netId, const std::string &hostName)
313 {
314 if (netId == 0) {
315 netId = defaultNetId_;
316 }
317
318 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
319 auto it = serverConfigMap_.find(netId);
320 if (it == serverConfigMap_.end()) {
321 DNS_CONFIG_PRINT("SetCacheDelayed failed: netid is not have netid:%{public}d,", netId);
322 return;
323 }
324
325 it->second.SetCacheDelayed(hostName);
326 }
327
AddUidRange(uint32_t netId,const std::vector<NetManagerStandard::UidRange> & uidRanges)328 int32_t DnsParamCache::AddUidRange(uint32_t netId, const std::vector<NetManagerStandard::UidRange> &uidRanges)
329 {
330 std::lock_guard<ffrt::mutex> guard(uidRangeMutex_);
331 NETNATIVE_LOG_D("DnsParamCache::AddUidRange size = [%{public}zu]", uidRanges.size());
332 vpnNetId_ = netId;
333 auto middle = vpnUidRanges_.insert(vpnUidRanges_.end(), uidRanges.begin(), uidRanges.end());
334 std::inplace_merge(vpnUidRanges_.begin(), middle, vpnUidRanges_.end());
335 return 0;
336 }
337
DelUidRange(uint32_t netId,const std::vector<NetManagerStandard::UidRange> & uidRanges)338 int32_t DnsParamCache::DelUidRange(uint32_t netId, const std::vector<NetManagerStandard::UidRange> &uidRanges)
339 {
340 std::lock_guard<ffrt::mutex> guard(uidRangeMutex_);
341 NETNATIVE_LOG_D("DnsParamCache::DelUidRange size = [%{public}zu]", uidRanges.size());
342 vpnNetId_ = 0;
343 auto end = std::set_difference(vpnUidRanges_.begin(), vpnUidRanges_.end(), uidRanges.begin(),
344 uidRanges.end(), vpnUidRanges_.begin());
345 vpnUidRanges_.erase(end, vpnUidRanges_.end());
346 return 0;
347 }
348
IsVpnOpen() const349 bool DnsParamCache::IsVpnOpen() const
350 {
351 return vpnUidRanges_.size();
352 }
353
354 #ifdef FEATURE_NET_FIREWALL_ENABLE
GetUserId(int32_t appUid)355 int32_t DnsParamCache::GetUserId(int32_t appUid)
356 {
357 int32_t userId = appUid / USER_ID_DIVIDOR;
358 return userId > 0 ? userId : currentUserId_;
359 }
360
GetDnsServersByAppUid(int32_t appUid,std::vector<std::string> & servers)361 bool DnsParamCache::GetDnsServersByAppUid(int32_t appUid, std::vector<std::string> &servers)
362 {
363 if (netFirewallDnsRuleMap_.empty()) {
364 return false;
365 }
366 DNS_CONFIG_PRINT("GetDnsServersByAppUid: appUid=%{public}d", appUid);
367 auto it = netFirewallDnsRuleMap_.find(appUid);
368 if (it == netFirewallDnsRuleMap_.end()) {
369 // if appUid not found, try to find invalid appUid=0;
370 it = netFirewallDnsRuleMap_.find(0);
371 }
372 if (it != netFirewallDnsRuleMap_.end()) {
373 int32_t userId = GetUserId(appUid);
374 std::vector<sptr<NetFirewallDnsRule>> rules = it->second;
375 for (const auto &rule : rules) {
376 if (rule->userId != userId) {
377 continue;
378 }
379 servers.emplace_back(rule->primaryDns);
380 servers.emplace_back(rule->standbyDns);
381 }
382 return true;
383 }
384 return false;
385 }
386
SetFirewallRules(NetFirewallRuleType type,const std::vector<sptr<NetFirewallBaseRule>> & ruleList,bool isFinish)387 int32_t DnsParamCache::SetFirewallRules(NetFirewallRuleType type,
388 const std::vector<sptr<NetFirewallBaseRule>> &ruleList, bool isFinish)
389 {
390 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
391 NETNATIVE_LOGI("SetFirewallRules: size=%{public}zu isFinish=%{public}" PRId32, ruleList.size(), isFinish);
392 if (ruleList.empty()) {
393 NETNATIVE_LOGE("SetFirewallRules: rules is empty");
394 return -1;
395 }
396 int32_t ret = 0;
397 switch (type) {
398 case NetFirewallRuleType::RULE_DNS: {
399 for (const auto &rule : ruleList) {
400 firewallDnsRules_.emplace_back(firewall_rule_cast<NetFirewallDnsRule>(rule));
401 }
402 if (isFinish) {
403 ret = SetFirewallDnsRules(firewallDnsRules_);
404 firewallDnsRules_.clear();
405 }
406 break;
407 }
408 case NetFirewallRuleType::RULE_DOMAIN: {
409 ClearAllDnsCache();
410 for (const auto &rule : ruleList) {
411 firewallDomainRules_.emplace_back(firewall_rule_cast<NetFirewallDomainRule>(rule));
412 }
413 if (isFinish) {
414 ret = SetFirewallDomainRules(firewallDomainRules_);
415 firewallDomainRules_.clear();
416 OHOS::NetManagerStandard::NetsysBpfNetFirewall::GetInstance()->ClearDomainCache();
417 }
418 break;
419 }
420 default:
421 break;
422 }
423 return ret;
424 }
425
SetFirewallDnsRules(const std::vector<sptr<NetFirewallDnsRule>> & ruleList)426 int32_t DnsParamCache::SetFirewallDnsRules(const std::vector<sptr<NetFirewallDnsRule>> &ruleList)
427 {
428 for (const auto &rule : ruleList) {
429 std::vector<sptr<NetFirewallDnsRule>> rules;
430 auto it = netFirewallDnsRuleMap_.find(rule->appUid);
431 if (it != netFirewallDnsRuleMap_.end()) {
432 rules = it->second;
433 }
434 rules.emplace_back(std::move(rule));
435 netFirewallDnsRuleMap_.emplace(rule->appUid, std::move(rules));
436 }
437 return 0;
438 }
439
GetFirewallRuleAction(int32_t appUid,const std::vector<sptr<NetFirewallDomainRule>> & rules)440 FirewallRuleAction DnsParamCache::GetFirewallRuleAction(int32_t appUid,
441 const std::vector<sptr<NetFirewallDomainRule>> &rules)
442 {
443 int32_t userId = GetUserId(appUid);
444 for (const auto &rule : rules) {
445 if (rule->userId != userId) {
446 continue;
447 }
448 if ((rule->appUid && appUid == rule->appUid) || !rule->appUid) {
449 return rule->ruleAction;
450 }
451 }
452
453 return FirewallRuleAction::RULE_INVALID;
454 }
455
checkEmpty4InterceptDomain(const std::string & hostName)456 bool DnsParamCache::checkEmpty4InterceptDomain(const std::string &hostName)
457 {
458 if (hostName.empty()) {
459 return true;
460 }
461 if (!netFirewallDomainRulesAllowMap_.empty() || !netFirewallDomainRulesDenyMap_.empty()) {
462 return false;
463 }
464 if (domainAllowLsmTrie_ && !domainAllowLsmTrie_->Empty()) {
465 return false;
466 }
467 return !domainDenyLsmTrie_ || domainDenyLsmTrie_->Empty();
468 }
469
IsInterceptDomain(int32_t appUid,const std::string & hostName,bool & isMatchAllow)470 bool DnsParamCache::IsInterceptDomain(int32_t appUid, const std::string &hostName, bool &isMatchAllow)
471 {
472 if (checkEmpty4InterceptDomain(hostName)) {
473 return false;
474 }
475 std::string host = hostName.substr(0, hostName.find(' '));
476 DNS_CONFIG_PRINT("IsInterceptDomain: appUid: %{public}d, hostName: %{private}s", appUid, host.c_str());
477 std::transform(host.begin(), host.end(), host.begin(), ::tolower);
478 std::vector<sptr<NetFirewallDomainRule>> rules;
479 FirewallRuleAction exactAllowAction = FirewallRuleAction::RULE_INVALID;
480 auto it = netFirewallDomainRulesAllowMap_.find(host);
481 if (it != netFirewallDomainRulesAllowMap_.end()) {
482 rules = it->second;
483 exactAllowAction = GetFirewallRuleAction(appUid, rules);
484 }
485 FirewallRuleAction exactDenyAction = FirewallRuleAction::RULE_INVALID;
486 auto iter = netFirewallDomainRulesDenyMap_.find(host);
487 if (iter != netFirewallDomainRulesDenyMap_.end()) {
488 rules = iter->second;
489 exactDenyAction = GetFirewallRuleAction(appUid, rules);
490 }
491 FirewallRuleAction wildcardAllowAction = FirewallRuleAction::RULE_INVALID;
492 if (domainAllowLsmTrie_->LongestSuffixMatch(host, rules)) {
493 wildcardAllowAction = GetFirewallRuleAction(appUid, rules);
494 }
495 FirewallRuleAction wildcardDenyAction = FirewallRuleAction::RULE_INVALID;
496 if (domainDenyLsmTrie_->LongestSuffixMatch(host, rules)) {
497 wildcardDenyAction = GetFirewallRuleAction(appUid, rules);
498 }
499 isMatchAllow = (exactAllowAction != FirewallRuleAction::RULE_INVALID) ||
500 (wildcardAllowAction != FirewallRuleAction::RULE_INVALID);
501 bool isDeny = (exactDenyAction != FirewallRuleAction::RULE_INVALID) ||
502 (wildcardDenyAction != FirewallRuleAction::RULE_INVALID);
503 if (isMatchAllow) {
504 // Apply default rules in case of conflict
505 return isDeny && (firewallDefaultAction_ == FirewallRuleAction::RULE_DENY);
506 }
507 return isDeny;
508 }
509
SetFirewallDefaultAction(FirewallRuleAction inDefault,FirewallRuleAction outDefault)510 int32_t DnsParamCache::SetFirewallDefaultAction(FirewallRuleAction inDefault, FirewallRuleAction outDefault)
511 {
512 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
513 DNS_CONFIG_PRINT("SetFirewallDefaultAction: firewallDefaultAction_: %{public}d", (int)outDefault);
514 firewallDefaultAction_ = outDefault;
515 return 0;
516 }
517
BuildFirewallDomainLsmTrie(const sptr<NetFirewallDomainRule> & rule,const std::string & domain)518 void DnsParamCache::BuildFirewallDomainLsmTrie(const sptr<NetFirewallDomainRule> &rule, const std::string &domain)
519 {
520 std::vector<sptr<NetFirewallDomainRule>> rules;
521 std::string suffix(domain);
522 auto wildcardCharIndex = suffix.find('*');
523 if (wildcardCharIndex != std::string::npos) {
524 suffix = suffix.substr(wildcardCharIndex + 1);
525 }
526 DNS_CONFIG_PRINT("BuildFirewallDomainLsmTrie: suffix: %{public}s", suffix.c_str());
527 std::transform(suffix.begin(), suffix.end(), suffix.begin(), ::tolower);
528 if (rule->ruleAction == FirewallRuleAction::RULE_DENY) {
529 if (domainDenyLsmTrie_->LongestSuffixMatch(suffix, rules)) {
530 rules.emplace_back(std::move(rule));
531 domainDenyLsmTrie_->Update(suffix, rules);
532 return;
533 }
534 rules.emplace_back(std::move(rule));
535 domainDenyLsmTrie_->Insert(suffix, rules);
536 } else {
537 if (domainAllowLsmTrie_->LongestSuffixMatch(suffix, rules)) {
538 rules.emplace_back(std::move(rule));
539 domainAllowLsmTrie_->Update(suffix, rules);
540 return;
541 }
542 rules.emplace_back(std::move(rule));
543 domainAllowLsmTrie_->Insert(suffix, rules);
544 }
545 }
546
BuildFirewallDomainMap(const sptr<NetFirewallDomainRule> & rule,const std::string & raw)547 void DnsParamCache::BuildFirewallDomainMap(const sptr<NetFirewallDomainRule> &rule, const std::string &raw)
548 {
549 DNS_CONFIG_PRINT("BuildFirewallDomainMap: domain: %{public}s", raw.c_str());
550 std::string domain(raw);
551 std::vector<sptr<NetFirewallDomainRule>> rules;
552 std::transform(domain.begin(), domain.end(), domain.begin(), ::tolower);
553 if (rule->ruleAction == FirewallRuleAction::RULE_DENY) {
554 auto it = netFirewallDomainRulesDenyMap_.find(domain);
555 if (it != netFirewallDomainRulesDenyMap_.end()) {
556 rules = it->second;
557 }
558
559 rules.emplace_back(std::move(rule));
560 netFirewallDomainRulesDenyMap_.emplace(domain, std::move(rules));
561 } else {
562 auto it = netFirewallDomainRulesAllowMap_.find(domain);
563 if (it != netFirewallDomainRulesAllowMap_.end()) {
564 rules = it->second;
565 }
566
567 rules.emplace_back(rule);
568 netFirewallDomainRulesAllowMap_.emplace(domain, std::move(rules));
569 }
570 }
571
SetFirewallDomainRules(const std::vector<sptr<NetFirewallDomainRule>> & ruleList)572 int32_t DnsParamCache::SetFirewallDomainRules(const std::vector<sptr<NetFirewallDomainRule>> &ruleList)
573 {
574 if (!domainAllowLsmTrie_) {
575 domainAllowLsmTrie_ =
576 std::make_shared<NetManagerStandard::SuffixMatchTrie<std::vector<sptr<NetFirewallDomainRule>>>>();
577 }
578 if (!domainDenyLsmTrie_) {
579 domainDenyLsmTrie_ =
580 std::make_shared<NetManagerStandard::SuffixMatchTrie<std::vector<sptr<NetFirewallDomainRule>>>>();
581 }
582 for (const auto &rule : ruleList) {
583 for (const auto ¶m : rule->domains) {
584 if (param.isWildcard) {
585 BuildFirewallDomainLsmTrie(rule, param.domain);
586 } else {
587 BuildFirewallDomainMap(rule, param.domain);
588 }
589 }
590 }
591 return 0;
592 }
593
ClearFirewallRules(NetFirewallRuleType type)594 int32_t DnsParamCache::ClearFirewallRules(NetFirewallRuleType type)
595 {
596 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
597 switch (type) {
598 case NetFirewallRuleType::RULE_DNS:
599 firewallDnsRules_.clear();
600 netFirewallDnsRuleMap_.clear();
601 break;
602 case NetFirewallRuleType::RULE_DOMAIN: {
603 firewallDomainRules_.clear();
604 netFirewallDomainRulesAllowMap_.clear();
605 netFirewallDomainRulesDenyMap_.clear();
606 if (domainAllowLsmTrie_) {
607 domainAllowLsmTrie_ = nullptr;
608 }
609 if (domainDenyLsmTrie_) {
610 domainDenyLsmTrie_ = nullptr;
611 }
612 OHOS::NetManagerStandard::NetsysBpfNetFirewall::GetInstance()->ClearDomainCache();
613 break;
614 }
615 case NetFirewallRuleType::RULE_ALL: {
616 firewallDnsRules_.clear();
617 netFirewallDnsRuleMap_.clear();
618 firewallDomainRules_.clear();
619 netFirewallDomainRulesAllowMap_.clear();
620 netFirewallDomainRulesDenyMap_.clear();
621 if (domainAllowLsmTrie_) {
622 domainAllowLsmTrie_ = nullptr;
623 }
624 if (domainDenyLsmTrie_) {
625 domainDenyLsmTrie_ = nullptr;
626 }
627 OHOS::NetManagerStandard::NetsysBpfNetFirewall::GetInstance()->ClearDomainCache();
628 break;
629 }
630 default:
631 break;
632 }
633 return 0;
634 }
635
NotifyDomianIntercept(int32_t appUid,const std::string & hostName)636 void DnsParamCache::NotifyDomianIntercept(int32_t appUid, const std::string &hostName)
637 {
638 if (hostName.empty()) {
639 return;
640 }
641 std::string host = hostName.substr(0, hostName.find(' '));
642 NETNATIVE_LOGI("NotifyDomianIntercept: appUid: %{public}d, hostName: %{private}s", appUid, host.c_str());
643 sptr<InterceptRecord> record = sptr<InterceptRecord>::MakeSptr();
644 record->time = (int32_t)time(NULL);
645 record->appUid = appUid;
646 record->domain = host;
647
648 if (oldRecord_ != nullptr && (record->time - oldRecord_->time) < INTERCEPT_BUFF_INTERVAL_SEC) {
649 if (record->appUid == oldRecord_->appUid && record->domain == oldRecord_->domain) {
650 return;
651 }
652 }
653 oldRecord_ = record;
654 for (const auto &callback : callbacks_) {
655 callback->OnIntercept(record);
656 }
657 }
658
RegisterNetFirewallCallback(const sptr<NetsysNative::INetFirewallCallback> & callback)659 int32_t DnsParamCache::RegisterNetFirewallCallback(const sptr<NetsysNative::INetFirewallCallback> &callback)
660 {
661 if (!callback) {
662 return -1;
663 }
664
665 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
666 callbacks_.emplace_back(callback);
667
668 return 0;
669 }
670
UnRegisterNetFirewallCallback(const sptr<NetsysNative::INetFirewallCallback> & callback)671 int32_t DnsParamCache::UnRegisterNetFirewallCallback(const sptr<NetsysNative::INetFirewallCallback> &callback)
672 {
673 if (!callback) {
674 return -1;
675 }
676
677 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
678 for (auto it = callbacks_.begin(); it != callbacks_.end(); ++it) {
679 if (*it == callback) {
680 callbacks_.erase(it);
681 return 0;
682 }
683 }
684 return -1;
685 }
686
SetFirewallCurrentUserId(int32_t userId)687 int32_t DnsParamCache::SetFirewallCurrentUserId(int32_t userId)
688 {
689 currentUserId_ = userId;
690 ClearAllDnsCache();
691 return 0;
692 }
693
ClearAllDnsCache()694 void DnsParamCache::ClearAllDnsCache()
695 {
696 NETNATIVE_LOGI("ClearAllDnsCache");
697 for (auto it = serverConfigMap_.begin(); it != serverConfigMap_.end(); it++) {
698 it->second.GetCache().Clear();
699 }
700 }
701 #endif
702
GetDumpInfo(std::string & info)703 void DnsParamCache::GetDumpInfo(std::string &info)
704 {
705 std::string dnsData;
706 static const std::string TAB = " ";
707 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
708 std::for_each(serverConfigMap_.begin(), serverConfigMap_.end(), [&dnsData](const auto &serverConfig) {
709 dnsData.append(TAB + "NetId: " + std::to_string(serverConfig.second.GetNetId()) + "\n");
710 dnsData.append(TAB + "TimeoutMsec: " + std::to_string(serverConfig.second.GetTimeoutMsec()) + "\n");
711 dnsData.append(TAB + "RetryCount: " + std::to_string(serverConfig.second.GetRetryCount()) + "\n");
712 dnsData.append(TAB + "Servers:");
713 GetVectorData(serverConfig.second.GetServers(), dnsData);
714 dnsData.append(TAB + "Domains:");
715 GetVectorData(serverConfig.second.GetDomains(), dnsData);
716 });
717 info.append(dnsData);
718 }
719
SetUserDefinedServerFlag(uint16_t netId,bool flag)720 int32_t DnsParamCache::SetUserDefinedServerFlag(uint16_t netId, bool flag)
721 {
722 NETNATIVE_LOGI("DnsParamCache::SetUserDefinedServerFlag, netid:%{public}d, flag:%{public}d,", netId, flag);
723
724 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
725 // select_domains
726 auto it = serverConfigMap_.find(netId);
727 if (it == serverConfigMap_.end()) {
728 NETNATIVE_LOGE("DnsParamCache::SetUserDefinedServerFlag failed, netid is non-existent");
729 return -ENOENT;
730 }
731 it->second.SetUserDefinedServerFlag(flag);
732 return 0;
733 }
734
GetUserDefinedServerFlag(uint16_t netId,bool & flag)735 int32_t DnsParamCache::GetUserDefinedServerFlag(uint16_t netId, bool &flag)
736 {
737 if (netId == 0) {
738 netId = defaultNetId_;
739 NETNATIVE_LOG_D("defaultNetId_ = [%{public}u]", netId);
740 }
741 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
742 auto it = serverConfigMap_.find(netId);
743 if (it == serverConfigMap_.end()) {
744 DNS_CONFIG_PRINT("GetUserDefinedServerFlag failed: netid is not have netid:%{public}d,", netId);
745 return -ENOENT;
746 }
747 flag = it->second.IsUserDefinedServer();
748 return 0;
749 }
750
GetUserDefinedServerFlag(uint16_t netId,bool & flag,uint32_t uid)751 int32_t DnsParamCache::GetUserDefinedServerFlag(uint16_t netId, bool &flag, uint32_t uid)
752 {
753 if (netId == 0) {
754 netId = defaultNetId_;
755 NETNATIVE_LOG_D("defaultNetId_ = [%{public}u]", netId);
756 }
757 {
758 std::lock_guard<ffrt::mutex> guard(cacheMutex_);
759 for (auto mem : vpnUidRanges_) {
760 if (static_cast<int32_t>(uid) >= mem.begin_ && static_cast<int32_t>(uid) <= mem.end_) {
761 NETNATIVE_LOG_D("is vpn hap");
762 auto it = serverConfigMap_.find(vpnNetId_);
763 if (it == serverConfigMap_.end()) {
764 NETNATIVE_LOG_D("vpn get Config failed: not have vpnnetid:%{public}d,", vpnNetId_);
765 break;
766 }
767 flag = it->second.IsUserDefinedServer();
768 return 0;
769 }
770 }
771 auto it = serverConfigMap_.find(netId);
772 if (it == serverConfigMap_.end()) {
773 DNS_CONFIG_PRINT("GetUserDefinedServerFlag failed: netid is not have netid:%{public}d,", netId);
774 return -ENOENT;
775 }
776 }
777 return GetUserDefinedServerFlag(netId, flag);
778 }
779 } // namespace OHOS::nmd
780