• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "mdns_protocol_impl.h"
17 
#include <algorithm>
#include <arpa/inet.h>
#include <atomic>
#include <chrono>
#include <cstddef>
#include <fcntl.h>
#include <iostream>
#include <poll.h>
#include <random>
#include <sys/types.h>
#include <unistd.h>
25 
26 #include "mdns_manager.h"
27 #include "mdns_packet_parser.h"
28 #include "net_conn_client.h"
29 #include "netmgr_ext_log_wrapper.h"
30 
31 #include "securec.h"
32 
33 namespace OHOS {
34 namespace NetManagerStandard {
35 
// Minimum spacing, in milliseconds, between two Browse() rounds.
constexpr uint32_t DEFAULT_INTEVAL_MS = 2000;
// A LIVE browse result not refreshed for this many milliseconds is probed and, if unreachable, dropped.
constexpr uint32_t DEFAULT_LOST_MS = 20000;
// TTL (seconds, DNS semantics) placed on records this host publishes.
constexpr uint32_t DEFAULT_TTL = 120;
// OR'ed into rclass of published records: the mDNS cache-flush bit (RFC 6762 section 10.2).
constexpr uint16_t MDNS_FLUSH_CACHE_BIT = 0x8000;

// Answer "phases" collected while processing a query; higher values mean more record groups answered.
constexpr int PHASE_PTR = 1;
constexpr int PHASE_SRV = 2;
constexpr int PHASE_DOMAIN = 3;
// Screen state: written by SetScreenState() from the common-event subscriber callback and read by
// Browse() from the listener's task queue. Those presumably run on different threads, so the flag is
// atomic to avoid a data race.
static std::atomic<bool> g_isScreenOn{true};
45 
// Format the IPv4 (in_addr) or IPv6 (in6_addr) address stored in `addr` as text.
// Returns an empty string if inet_ntop fails or if `addr` holds neither type.
std::string AddrToString(const std::any &addr)
{
    char text[INET6_ADDRSTRLEN] = {0};
    const void *raw = nullptr;
    int family = AF_INET;
    if (const auto *v4 = std::any_cast<in_addr>(&addr); v4 != nullptr) {
        raw = v4;
    } else if (const auto *v6 = std::any_cast<in6_addr>(&addr); v6 != nullptr) {
        family = AF_INET6;
        raw = v6;
    }
    // Only attempt conversion when a supported address type was found; otherwise `text` stays "".
    if (raw != nullptr && inet_ntop(family, raw, text, sizeof(text)) == nullptr) {
        return std::string{};
    }
    return std::string(text);
}
60 
// Current wall-clock time (std::chrono::system_clock) as milliseconds since the Unix epoch.
int64_t MilliSecondsSinceEpoch()
{
    using std::chrono::duration_cast;
    using std::chrono::milliseconds;
    const auto sinceEpoch = std::chrono::system_clock::now().time_since_epoch();
    const milliseconds ms = duration_cast<milliseconds>(sinceEpoch);
    return ms.count();
}
66 
MDnsProtocolImpl()67 MDnsProtocolImpl::MDnsProtocolImpl()
68 {
69     Init();
70 }
71 
Init()72 void MDnsProtocolImpl::Init()
73 {
74     NETMGR_EXT_LOG_D("mdns_log MDnsProtocolImpl init");
75     listener_.Stop();
76     listener_.CloseAllSocket();
77 
78     if (config_.configAllIface) {
79         listener_.OpenSocketForEachIface(config_.ipv6Support, config_.configLo);
80     } else {
81         listener_.OpenSocketForDefault(config_.ipv6Support);
82     }
83     listener_.SetReceiveHandler(
84         [this](int sock, const MDnsPayload &payload) { return this->ReceivePacket(sock, payload); });
85     listener_.SetFinishedHandler([this](int sock) {
86         std::lock_guard<std::recursive_mutex> guard(mutex_);
87         RunTaskQueue(taskQueue_);
88     });
89     listener_.Start();
90     {
91         std::lock_guard<std::recursive_mutex> guard(mutex_);
92         taskQueue_.clear();
93         taskOnChange_.clear();
94     }
95     AddTask([this]() { return Browse(); }, false);
96 
97     SubscribeCes();
98 }
99 
SubscribeCes()100 void MDnsProtocolImpl::SubscribeCes()
101 {
102     int32_t COMMON_EVENT_SERVICE_ID = 3299;
103     sptr<ISystemAbilityManager> samgrClient = SystemAbilityManagerClient::GetInstance().GetSystemAbilityManager();
104     if (samgrClient == nullptr || samgrClient->CheckSystemAbility(COMMON_EVENT_SERVICE_ID) == nullptr) {
105         NETMGR_EXT_LOG_E("Subscribe:CES SA not ready, wait for the SA Added callback.");
106         return;
107     }
108     EventFwk::MatchingSkills matchSkills;
109     matchSkills.AddEvent(EventFwk::CommonEventSupport::COMMON_EVENT_SCREEN_ON);
110     matchSkills.AddEvent(EventFwk::CommonEventSupport::COMMON_EVENT_SCREEN_OFF);
111 
112     EventFwk::CommonEventSubscribeInfo subcriberInfo(matchSkills);
113     if (subscriber_ == nullptr) {
114         subscriber_ = std::make_shared<MdnsSubscriber>(subcriberInfo);
115     }
116     if (!EventFwk::CommonEventManager::SubscribeCommonEvent(subscriber_)) {
117         NETMGR_EXT_LOG_E("system event register fail.");
118     }
119 }
120 
OnReceiveEvent(const EventFwk::CommonEventData & data)121 void MDnsProtocolImpl::MdnsSubscriber::OnReceiveEvent(const EventFwk::CommonEventData &data)
122 {
123     auto eventName = data.GetWant().GetAction();
124     if (eventName == EventFwk::CommonEventSupport::COMMON_EVENT_SCREEN_ON) {
125         MDnsProtocolImpl::SetScreenState(true);
126     } else {
127         MDnsProtocolImpl::SetScreenState(false);
128     }
129 }
130 
SetScreenState(bool isOn)131 void MDnsProtocolImpl::SetScreenState(bool isOn)
132 {
133     NETMGR_EXT_LOG_I("Mdns SetScreenState isOn[%{public}d]", isOn);
134     g_isScreenOn = isOn;
135 }
136 
Browse()137 bool MDnsProtocolImpl::Browse()
138 {
139     if ((lastRunTime != -1 && MilliSecondsSinceEpoch() - lastRunTime < DEFAULT_INTEVAL_MS) || !g_isScreenOn) {
140         return false;
141     }
142     lastRunTime = MilliSecondsSinceEpoch();
143     std::lock_guard<std::recursive_mutex> guard(mutex_);
144     for (auto &&[key, res] : browserMap_) {
145         NETMGR_EXT_LOG_D("mdns_log Browse browserMap_ key[%{public}s] res.size[%{public}zu]", key.c_str(), res.size());
146         if (nameCbMap_.find(key) != nameCbMap_.end() &&
147             !MDnsManager::GetInstance().IsAvailableCallback(nameCbMap_[key])) {
148             continue;
149         }
150         handleOfflineService(key, res);
151         MDnsPayloadParser parser;
152         MDnsMessage msg{};
153         msg.questions.emplace_back(DNSProto::Question{
154             .name = key,
155             .qtype = DNSProto::RRTYPE_PTR,
156             .qclass = DNSProto::RRCLASS_IN,
157         });
158         listener_.MulticastAll(parser.ToBytes(msg));
159     }
160     return false;
161 }
162 
ConnectControl(int32_t sockfd,sockaddr * serverAddr)163 int32_t MDnsProtocolImpl::ConnectControl(int32_t sockfd, sockaddr* serverAddr)
164 {
165     uint32_t flags = static_cast<uint32_t>(fcntl(sockfd, F_GETFL, 0));
166     fcntl(sockfd, F_SETFL, flags | O_NONBLOCK);
167     int32_t ret = connect(sockfd, serverAddr, sizeof(sockaddr));
168     if ((ret < 0) && (errno != EINPROGRESS)) {
169         NETMGR_EXT_LOG_E("connect error: %{public}d", errno);
170         return NETMANAGER_EXT_ERR_INTERNAL;
171     }
172     if (ret == 0) {
173         fcntl(sockfd, F_SETFL, flags); /* restore file status flags */
174         NETMGR_EXT_LOG_I("connect success.");
175         return NETMANAGER_EXT_SUCCESS;
176     }
177 
178     fd_set rset;
179     FD_ZERO(&rset);
180     FD_SET(sockfd, &rset);
181     fd_set wset = rset;
182     timeval tval {1, 0};
183     ret = select(sockfd + 1, &rset, &wset, NULL, &tval);
184     if (ret < 0) { // select error.
185         NETMGR_EXT_LOG_E("select error: %{public}d", errno);
186         return NETMANAGER_EXT_ERR_INTERNAL;
187     }
188     if (ret == 0) { // timeout
189         NETMGR_EXT_LOG_E("connect timeout...");
190         return NETMANAGER_EXT_ERR_INTERNAL;
191     }
192     if (!FD_ISSET(sockfd, &rset) && !FD_ISSET(sockfd, &wset)) {
193         NETMGR_EXT_LOG_E("select error: sockfd not set");
194         return NETMANAGER_EXT_ERR_INTERNAL;
195     }
196 
197     int32_t result = NETMANAGER_EXT_ERR_INTERNAL;
198     socklen_t len = sizeof(result);
199     if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &result, &len) < 0) {
200         NETMGR_EXT_LOG_E("getsockopt error: %{public}d", errno);
201         return NETMANAGER_EXT_ERR_INTERNAL;
202     }
203     if (result != 0) { // connect failed.
204         NETMGR_EXT_LOG_E("connect failed. error: %{public}d", result);
205         return NETMANAGER_EXT_ERR_INTERNAL;
206     }
207     fcntl(sockfd, F_SETFL, flags); /* restore file status flags */
208     NETMGR_EXT_LOG_I("lost but connect success.");
209     return NETMANAGER_EXT_SUCCESS;
210 }
211 
IsConnectivity(const std::string & ip,int32_t port)212 bool MDnsProtocolImpl::IsConnectivity(const std::string &ip, int32_t port)
213 {
214     if (ip.empty()) {
215         NETMGR_EXT_LOG_E("ip is empty");
216         return false;
217     }
218 
219     int32_t sockfd = socket(AF_INET, SOCK_STREAM, 0);
220     if (sockfd < 0) {
221         NETMGR_EXT_LOG_E("create socket error: %{public}d", errno);
222         return false;
223     }
224 
225     struct sockaddr_in serverAddr;
226     if (memset_s(&serverAddr, sizeof(serverAddr), 0, sizeof(serverAddr)) != EOK) {
227         NETMGR_EXT_LOG_E("memset_s serverAddr failed!");
228         close(sockfd);
229         return false;
230     }
231 
232     serverAddr.sin_family = AF_INET;
233     serverAddr.sin_addr.s_addr = inet_addr(ip.c_str());
234     serverAddr.sin_port = htons(port);
235     if (ConnectControl(sockfd, (struct sockaddr*)&serverAddr) != NETMANAGER_EXT_SUCCESS) {
236         NETMGR_EXT_LOG_I("connect error: %{public}d", errno);
237         close(sockfd);
238         return false;
239     }
240 
241     close(sockfd);
242     return true;
243 }
244 
handleOfflineService(const std::string & key,std::vector<Result> & res)245 void MDnsProtocolImpl::handleOfflineService(const std::string &key, std::vector<Result> &res)
246 {
247     NETMGR_EXT_LOG_D("mdns_log handleOfflineService key:[%{public}s]", key.c_str());
248     for (auto it = res.begin(); it != res.end();) {
249         if (lastRunTime - it->refrehTime > DEFAULT_LOST_MS && it->state == State::LIVE) {
250             std::string fullName = Decorated(it->serviceName + MDNS_DOMAIN_SPLITER_STR + it->serviceType);
251             if ((cacheMap_.find(fullName) != cacheMap_.end()) &&
252                 IsConnectivity(cacheMap_[fullName].addr, cacheMap_[fullName].port)) {
253                 it++;
254                 continue;
255             }
256 
257             it->state = State::DEAD;
258             if (nameCbMap_.find(key) != nameCbMap_.end() && nameCbMap_[key] != nullptr) {
259                 NETMGR_EXT_LOG_W("mdns_log HandleServiceLost");
260                 nameCbMap_[key]->HandleServiceLost(ConvertResultToInfo(*it), NETMANAGER_EXT_SUCCESS);
261             }
262             it = res.erase(it);
263             cacheMap_.erase(fullName);
264         } else {
265             it++;
266         }
267     }
268 }
269 
SetConfig(const MDnsConfig & config)270 void MDnsProtocolImpl::SetConfig(const MDnsConfig &config)
271 {
272     config_ = config;
273 }
274 
GetConfig() const275 const MDnsConfig &MDnsProtocolImpl::GetConfig() const
276 {
277     return config_;
278 }
279 
Decorated(const std::string & name) const280 std::string MDnsProtocolImpl::Decorated(const std::string &name) const
281 {
282     return name + config_.topDomain;
283 }
284 
Register(const Result & info)285 int32_t MDnsProtocolImpl::Register(const Result &info)
286 {
287     NETMGR_EXT_LOG_D("mdns_log Register");
288     if (!(IsNameValid(info.serviceName) && IsTypeValid(info.serviceType) && IsPortValid(info.port))) {
289         return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
290     }
291     std::string name = Decorated(info.serviceName + MDNS_DOMAIN_SPLITER_STR + info.serviceType);
292     if (!IsDomainValid(name)) {
293         return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
294     }
295     {
296         std::lock_guard<std::recursive_mutex> guard(mutex_);
297         if (srvMap_.find(name) != srvMap_.end()) {
298             return NET_MDNS_ERR_SERVICE_INSTANCE_DUPLICATE;
299         }
300         srvMap_.emplace(name, info);
301     }
302     return Announce(info, false);
303 }
304 
UnRegister(const std::string & key)305 int32_t MDnsProtocolImpl::UnRegister(const std::string &key)
306 {
307     NETMGR_EXT_LOG_D("mdns_log UnRegister");
308     std::string name = Decorated(key);
309     std::lock_guard<std::recursive_mutex> guard(mutex_);
310     if (srvMap_.find(name) != srvMap_.end()) {
311         Announce(srvMap_[name], true);
312         srvMap_.erase(name);
313         return NETMANAGER_EXT_SUCCESS;
314     }
315     return NET_MDNS_ERR_SERVICE_INSTANCE_NOT_FOUND;
316 }
317 
DiscoveryFromCache(const std::string & serviceType,const sptr<IDiscoveryCallback> & cb)318 bool MDnsProtocolImpl::DiscoveryFromCache(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)
319 {
320     NETMGR_EXT_LOG_D("mdns_log DiscoveryFromCache");
321     std::string name = Decorated(serviceType);
322     std::lock_guard<std::recursive_mutex> guard(mutex_);
323     if (!IsBrowserAvailable(name)) {
324         return false;
325     }
326 
327     if (browserMap_.find(name) == browserMap_.end()) {
328         NETMGR_EXT_LOG_D("mdns_log DiscoveryFromCache browserMap_ not find name");
329         return false;
330     }
331 
332     for (auto &res : browserMap_[name]) {
333         if (res.state == State::REMOVE || res.state == State::DEAD) {
334             continue;
335         }
336         AddTask([cb, info = ConvertResultToInfo(res)]() {
337             NETMGR_EXT_LOG_W("mdns_log DiscoveryFromCache ConvertResultToInfo HandleServiceFound");
338             if (MDnsManager::GetInstance().IsAvailableCallback(cb)) {
339                 cb->HandleServiceFound(info, NETMANAGER_EXT_SUCCESS);
340             }
341             return true;
342         });
343     }
344     return true;
345 }
346 
DiscoveryFromNet(const std::string & serviceType,const sptr<IDiscoveryCallback> & cb)347 bool MDnsProtocolImpl::DiscoveryFromNet(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)
348 {
349     NETMGR_EXT_LOG_D("mdns_log DiscoveryFromNet");
350     std::string name = Decorated(serviceType);
351     std::lock_guard<std::recursive_mutex> guard(mutex_);
352     browserMap_.insert({name, std::vector<Result>{}});
353     nameCbMap_[name] = cb;
354     // key is serviceTYpe
355     AddEvent(name, [this, name, cb]() {
356         std::lock_guard<std::recursive_mutex> guard(mutex_);
357         if (!IsBrowserAvailable(name)) {
358             return false;
359         }
360         if (!MDnsManager::GetInstance().IsAvailableCallback(cb)) {
361             return true;
362         }
363         for (auto &res : browserMap_[name]) {
364             std::string fullName = Decorated(res.serviceName + MDNS_DOMAIN_SPLITER_STR + res.serviceType);
365             NETMGR_EXT_LOG_W("mdns_log DiscoveryFromNet name:[%{public}s] fullName:[%{public}s]", name.c_str(),
366                              fullName.c_str());
367             if (cacheMap_.find(fullName) == cacheMap_.end() ||
368                 (res.state == State::ADD || res.state == State::REFRESH)) {
369                 NETMGR_EXT_LOG_W("mdns_log HandleServiceFound");
370                 cb->HandleServiceFound(ConvertResultToInfo(res), NETMANAGER_EXT_SUCCESS);
371                 res.state = State::LIVE;
372             }
373             if (res.state == State::REMOVE) {
374                 res.state = State::DEAD;
375                 NETMGR_EXT_LOG_D("mdns_log HandleServiceLost");
376                 cb->HandleServiceLost(ConvertResultToInfo(res), NETMANAGER_EXT_SUCCESS);
377                 if (cacheMap_.find(fullName) != cacheMap_.end()) {
378                     cacheMap_.erase(fullName);
379                 }
380             }
381         }
382         return false;
383     });
384 
385     AddTask([=]() {
386             MDnsPayloadParser parser;
387             MDnsMessage msg{};
388             msg.questions.emplace_back(DNSProto::Question{
389                 .name = name,
390                 .qtype = DNSProto::RRTYPE_PTR,
391                 .qclass = DNSProto::RRCLASS_IN,
392             });
393             listener_.MulticastAll(parser.ToBytes(msg));
394             return true;
395         }, false);
396     return true;
397 }
398 
Discovery(const std::string & serviceType,const sptr<IDiscoveryCallback> & cb)399 int32_t MDnsProtocolImpl::Discovery(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)
400 {
401     NETMGR_EXT_LOG_D("mdns_log Discovery");
402     DiscoveryFromCache(serviceType, cb);
403     DiscoveryFromNet(serviceType, cb);
404     return NETMANAGER_EXT_SUCCESS;
405 }
406 
ResolveInstanceFromCache(const std::string & name,const sptr<IResolveCallback> & cb)407 bool MDnsProtocolImpl::ResolveInstanceFromCache(const std::string &name, const sptr<IResolveCallback> &cb)
408 {
409     NETMGR_EXT_LOG_D("mdns_log ResolveInstanceFromCache");
410     std::lock_guard<std::recursive_mutex> guard(mutex_);
411     if (!IsInstanceCacheAvailable(name)) {
412         NETMGR_EXT_LOG_W("mdns_log ResolveInstanceFromCache cacheMap_ has no element [%{public}s]", name.c_str());
413         return false;
414     }
415 
416     NETMGR_EXT_LOG_I("mdns_log rr.name : [%{public}s]", name.c_str());
417     Result r = cacheMap_[name];
418     if (IsDomainCacheAvailable(r.domain)) {
419         r.ipv6 = cacheMap_[r.domain].ipv6;
420         r.addr = cacheMap_[r.domain].addr;
421 
422         NETMGR_EXT_LOG_D("mdns_log Add Task DomainCache Available, [%{public}s]", r.domain.c_str());
423         AddTask([cb, info = ConvertResultToInfo(r)]() {
424             if (nullptr != cb) {
425                 cb->HandleResolveResult(info, NETMANAGER_EXT_SUCCESS);
426             }
427             return true;
428         });
429     } else {
430         ResolveFromNet(r.domain, nullptr);
431         NETMGR_EXT_LOG_D("mdns_log Add Event DomainCache UnAvailable, [%{public}s]", r.domain.c_str());
432         AddEvent(r.domain, [this, cb, r]() mutable {
433             if (!IsDomainCacheAvailable(r.domain)) {
434                 return false;
435             }
436             r.ipv6 = cacheMap_[r.domain].ipv6;
437             r.addr = cacheMap_[r.domain].addr;
438             if (nullptr != cb) {
439                 cb->HandleResolveResult(ConvertResultToInfo(r), NETMANAGER_EXT_SUCCESS);
440             }
441             return true;
442         });
443     }
444     return true;
445 }
446 
ResolveInstanceFromNet(const std::string & name,const sptr<IResolveCallback> & cb)447 bool MDnsProtocolImpl::ResolveInstanceFromNet(const std::string &name, const sptr<IResolveCallback> &cb)
448 {
449     NETMGR_EXT_LOG_D("mdns_log ResolveInstanceFromNet");
450     {
451         std::lock_guard<std::recursive_mutex> guard(mutex_);
452         cacheMap_[name].state = State::ADD;
453         ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
454     }
455     MDnsPayloadParser parser;
456     MDnsMessage msg{};
457     msg.questions.emplace_back(DNSProto::Question{
458         .name = name,
459         .qtype = DNSProto::RRTYPE_SRV,
460         .qclass = DNSProto::RRCLASS_IN,
461     });
462     msg.questions.emplace_back(DNSProto::Question{
463         .name = name,
464         .qtype = DNSProto::RRTYPE_TXT,
465         .qclass = DNSProto::RRCLASS_IN,
466     });
467     msg.header.qdcount = msg.questions.size();
468     AddEvent(name, [this, name, cb]() { return ResolveInstanceFromCache(name, cb); });
469     ssize_t size = listener_.MulticastAll(parser.ToBytes(msg));
470     return size > 0;
471 }
472 
ResolveFromCache(const std::string & domain,const sptr<IResolveCallback> & cb)473 bool MDnsProtocolImpl::ResolveFromCache(const std::string &domain, const sptr<IResolveCallback> &cb)
474 {
475     NETMGR_EXT_LOG_D("mdns_log ResolveFromCache");
476     std::lock_guard<std::recursive_mutex> guard(mutex_);
477     if (!IsDomainCacheAvailable(domain)) {
478         return false;
479     }
480     AddTask([this, cb, info = ConvertResultToInfo(cacheMap_[domain])]() {
481         if (nullptr != cb) {
482             cb->HandleResolveResult(info, NETMANAGER_EXT_SUCCESS);
483         }
484         return true;
485     });
486     return true;
487 }
488 
ResolveFromNet(const std::string & domain,const sptr<IResolveCallback> & cb)489 bool MDnsProtocolImpl::ResolveFromNet(const std::string &domain, const sptr<IResolveCallback> &cb)
490 {
491     NETMGR_EXT_LOG_D("mdns_log ResolveFromNet");
492     {
493         std::lock_guard<std::recursive_mutex> guard(mutex_);
494         cacheMap_[domain];
495         cacheMap_[domain].domain = domain;
496     }
497     MDnsPayloadParser parser;
498     MDnsMessage msg{};
499     msg.questions.emplace_back(DNSProto::Question{
500         .name = domain,
501         .qtype = DNSProto::RRTYPE_A,
502         .qclass = DNSProto::RRCLASS_IN,
503     });
504     msg.questions.emplace_back(DNSProto::Question{
505         .name = domain,
506         .qtype = DNSProto::RRTYPE_AAAA,
507         .qclass = DNSProto::RRCLASS_IN,
508     });
509     // key is serviceName
510     AddEvent(domain, [this, cb, domain]() { return ResolveFromCache(domain, cb); });
511     ssize_t size = listener_.MulticastAll(parser.ToBytes(msg));
512     return size > 0;
513 }
514 
ResolveInstance(const std::string & instance,const sptr<IResolveCallback> & cb)515 int32_t MDnsProtocolImpl::ResolveInstance(const std::string &instance, const sptr<IResolveCallback> &cb)
516 {
517     NETMGR_EXT_LOG_D("mdns_log execute ResolveInstance");
518     if (!IsInstanceValid(instance)) {
519         return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
520     }
521     std::string name = Decorated(instance);
522     if (!IsDomainValid(name)) {
523         return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
524     }
525     if (ResolveInstanceFromCache(name, cb)) {
526         return NETMANAGER_EXT_SUCCESS;
527     }
528     return ResolveInstanceFromNet(name, cb) ? NETMANAGER_EXT_SUCCESS : NET_MDNS_ERR_SEND;
529 }
530 
Announce(const Result & info,bool off)531 int32_t MDnsProtocolImpl::Announce(const Result &info, bool off)
532 {
533     NETMGR_EXT_LOG_I("mdns_log Announce message");
534     MDnsMessage response{};
535     response.header.flags = DNSProto::MDNS_ANSWER_FLAGS;
536     std::string name = Decorated(info.serviceName + MDNS_DOMAIN_SPLITER_STR + info.serviceType);
537     response.answers.emplace_back(DNSProto::ResourceRecord{.name = Decorated(info.serviceType),
538                                                            .rtype = static_cast<uint16_t>(DNSProto::RRTYPE_PTR),
539                                                            .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
540                                                            .ttl = off ? 0U : DEFAULT_TTL,
541                                                            .rdata = name});
542     response.answers.emplace_back(DNSProto::ResourceRecord{.name = name,
543                                                            .rtype = static_cast<uint16_t>(DNSProto::RRTYPE_SRV),
544                                                            .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
545                                                            .ttl = off ? 0U : DEFAULT_TTL,
546                                                            .rdata = DNSProto::RDataSrv{
547                                                                .priority = 0,
548                                                                .weight = 0,
549                                                                .port = static_cast<uint16_t>(info.port),
550                                                                .name = GetHostDomain(),
551                                                            }});
552     response.answers.emplace_back(DNSProto::ResourceRecord{.name = name,
553                                                            .rtype = static_cast<uint16_t>(DNSProto::RRTYPE_TXT),
554                                                            .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
555                                                            .ttl = off ? 0U : DEFAULT_TTL,
556                                                            .rdata = info.txt});
557     MDnsPayloadParser parser;
558     ssize_t size = listener_.MulticastAll(parser.ToBytes(response));
559     return size > 0 ? NETMANAGER_EXT_SUCCESS : NET_MDNS_ERR_SEND;
560 }
561 
ReceivePacket(int sock,const MDnsPayload & payload)562 void MDnsProtocolImpl::ReceivePacket(int sock, const MDnsPayload &payload)
563 {
564     if (payload.size() == 0) {
565         return;
566     }
567     MDnsPayloadParser parser;
568     MDnsMessage msg = parser.FromBytes(payload);
569     if (parser.GetError() != 0) {
570         NETMGR_EXT_LOG_E("parser payload failed");
571         return;
572     }
573     if ((msg.header.flags & DNSProto::HEADER_FLAGS_QR_MASK) == 0) {
574         ProcessQuestion(sock, msg);
575     } else {
576         ProcessAnswer(sock, msg);
577     }
578 }
579 
AppendRecord(std::vector<DNSProto::ResourceRecord> & rrlist,DNSProto::RRType type,const std::string & name,const std::any & rdata)580 void MDnsProtocolImpl::AppendRecord(std::vector<DNSProto::ResourceRecord> &rrlist, DNSProto::RRType type,
581                                     const std::string &name, const std::any &rdata)
582 {
583     rrlist.emplace_back(DNSProto::ResourceRecord{.name = name,
584                                                  .rtype = static_cast<uint16_t>(type),
585                                                  .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
586                                                  .ttl = DEFAULT_TTL,
587                                                  .rdata = rdata});
588 }
589 
ProcessQuestion(int sock,const MDnsMessage & msg)590 void MDnsProtocolImpl::ProcessQuestion(int sock, const MDnsMessage &msg)
591 {
592     const sockaddr *saddrIf = listener_.GetSockAddr(sock);
593     if (saddrIf == nullptr) {
594         NETMGR_EXT_LOG_W("mdns_log ProcessQuestion saddrIf is null");
595         return;
596     }
597     std::any anyAddr;
598     DNSProto::RRType anyAddrType;
599     if (saddrIf->sa_family == AF_INET6) {
600         anyAddr = reinterpret_cast<const sockaddr_in6 *>(saddrIf)->sin6_addr;
601         anyAddrType = DNSProto::RRTYPE_AAAA;
602     } else {
603         anyAddr = reinterpret_cast<const sockaddr_in *>(saddrIf)->sin_addr;
604         anyAddrType = DNSProto::RRTYPE_A;
605     }
606     int phase = 0;
607     MDnsMessage response{};
608     response.header.flags = DNSProto::MDNS_ANSWER_FLAGS;
609     for (size_t i = 0; i < msg.header.qdcount; ++i) {
610         ProcessQuestionRecord(anyAddr, anyAddrType, msg.questions[i], phase, response);
611     }
612     if (phase < PHASE_DOMAIN) {
613         AppendRecord(response.additional, anyAddrType, GetHostDomain(), anyAddr);
614     }
615 
616     if (phase != 0 && response.answers.size() > 0) {
617         listener_.Multicast(sock, MDnsPayloadParser().ToBytes(response));
618     }
619 }
620 
ProcessQuestionRecord(const std::any & anyAddr,const DNSProto::RRType & anyAddrType,const DNSProto::Question & qu,int & phase,MDnsMessage & response)621 void MDnsProtocolImpl::ProcessQuestionRecord(const std::any &anyAddr, const DNSProto::RRType &anyAddrType,
622                                              const DNSProto::Question &qu, int &phase, MDnsMessage &response)
623 {
624     NETMGR_EXT_LOG_D("mdns_log ProcessQuestionRecord");
625     std::lock_guard<std::recursive_mutex> guard(mutex_);
626     std::string name = qu.name;
627     if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_PTR) {
628         std::for_each(srvMap_.begin(), srvMap_.end(), [&](const auto &elem) -> void {
629             if (EndsWith(elem.first, name)) {
630                 AppendRecord(response.answers, DNSProto::RRTYPE_PTR, name, elem.first);
631                 AppendRecord(response.additional, DNSProto::RRTYPE_SRV, elem.first,
632                              DNSProto::RDataSrv{
633                                  .priority = 0,
634                                  .weight = 0,
635                                  .port = static_cast<uint16_t>(elem.second.port),
636                                  .name = GetHostDomain(),
637                              });
638                 AppendRecord(response.additional, DNSProto::RRTYPE_TXT, elem.first, elem.second.txt);
639             }
640         });
641         phase = std::max(phase, PHASE_PTR);
642     }
643     if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_SRV) {
644         auto iter = srvMap_.find(name);
645         if (iter == srvMap_.end()) {
646             return;
647         }
648         AppendRecord(response.answers, DNSProto::RRTYPE_SRV, name,
649                      DNSProto::RDataSrv{
650                          .priority = 0,
651                          .weight = 0,
652                          .port = static_cast<uint16_t>(iter->second.port),
653                          .name = GetHostDomain(),
654                      });
655         phase = std::max(phase, PHASE_SRV);
656     }
657     if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_TXT) {
658         auto iter = srvMap_.find(name);
659         if (iter == srvMap_.end()) {
660             return;
661         }
662         AppendRecord(response.answers, DNSProto::RRTYPE_TXT, name, iter->second.txt);
663         phase = std::max(phase, PHASE_SRV);
664     }
665     if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_A || qu.qtype == DNSProto::RRTYPE_AAAA) {
666         if (name != GetHostDomain() || (qu.qtype != DNSProto::RRTYPE_ANY && anyAddrType != qu.qtype)) {
667             return;
668         }
669         AppendRecord(response.answers, anyAddrType, name, anyAddr);
670         phase = std::max(phase, PHASE_DOMAIN);
671     }
672 }
673 
ProcessAnswer(int sock,const MDnsMessage & msg)674 void MDnsProtocolImpl::ProcessAnswer(int sock, const MDnsMessage &msg)
675 {
676     const sockaddr *saddrIf = listener_.GetSockAddr(sock);
677     if (saddrIf == nullptr) {
678         return;
679     }
680     bool v6 = (saddrIf->sa_family == AF_INET6);
681     std::set<std::string> changed;
682     for (const auto &answer : msg.answers) {
683         ProcessAnswerRecord(v6, answer, changed);
684     }
685     for (const auto &i : msg.additional) {
686         ProcessAnswerRecord(v6, i, changed);
687     }
688     for (const auto &i : changed) {
689         std::lock_guard<std::recursive_mutex> guard(mutex_);
690         RunTaskQueue(taskOnChange_[i]);
691         KillCache(i);
692     }
693 }
694 
UpdatePtr(bool v6,const DNSProto::ResourceRecord & rr,std::set<std::string> & changed)695 void MDnsProtocolImpl::UpdatePtr(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
696 {
697     const std::string *data = std::any_cast<std::string>(&rr.rdata);
698     if (data == nullptr) {
699         return;
700     }
701 
702     std::string name = rr.name;
703     if (browserMap_.find(name) == browserMap_.end()) {
704         return;
705     }
706     auto &results = browserMap_[name];
707     std::string srvName;
708     std::string srvType;
709     ExtractNameAndType(*data, srvName, srvType);
710     if (srvName.empty() || srvType.empty()) {
711         return;
712     }
713     auto res =
714         std::find_if(results.begin(), results.end(), [&](const auto &elem) { return elem.serviceName == srvName; });
715     if (res == results.end()) {
716         results.emplace_back(Result{
717             .serviceName = srvName,
718             .serviceType = srvType,
719             .state = State::ADD,
720         });
721     }
722     res = std::find_if(results.begin(), results.end(), [&](const auto &elem) { return elem.serviceName == srvName; });
723     if (res->serviceName != srvName || res->state == State::DEAD) {
724         res->state = State::REFRESH;
725         res->serviceName = srvName;
726     }
727     if (rr.ttl == 0) {
728         res->state = State::REMOVE;
729     }
730     if (res->state != State::LIVE && res->state != State::DEAD) {
731         changed.emplace(name);
732     }
733     res->ttl = rr.ttl;
734     res->refrehTime = MilliSecondsSinceEpoch();
735 }
736 
UpdateSrv(bool v6,const DNSProto::ResourceRecord & rr,std::set<std::string> & changed)737 void MDnsProtocolImpl::UpdateSrv(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
738 {
739     const DNSProto::RDataSrv *srv = std::any_cast<DNSProto::RDataSrv>(&rr.rdata);
740     if (srv == nullptr) {
741         return;
742     }
743     std::string name = rr.name;
744     if (cacheMap_.find(name) == cacheMap_.end()) {
745         ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
746         cacheMap_[name].state = State::ADD;
747         cacheMap_[name].domain = srv->name;
748         cacheMap_[name].port = srv->port;
749     }
750     Result &result = cacheMap_[name];
751     if (result.domain != srv->name || result.port != srv->port || result.state == State::DEAD) {
752         if (result.state != State::ADD) {
753             result.state = State::REFRESH;
754         }
755         result.domain = srv->name;
756         result.port = srv->port;
757     }
758     if (rr.ttl == 0) {
759         result.state = State::REMOVE;
760     }
761     if (result.state != State::LIVE && result.state != State::DEAD) {
762         changed.emplace(name);
763     }
764     result.ttl = rr.ttl;
765     result.refrehTime = MilliSecondsSinceEpoch();
766 }
767 
UpdateTxt(bool v6,const DNSProto::ResourceRecord & rr,std::set<std::string> & changed)768 void MDnsProtocolImpl::UpdateTxt(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
769 {
770     const TxtRecordEncoded *txt = std::any_cast<TxtRecordEncoded>(&rr.rdata);
771     if (txt == nullptr) {
772         return;
773     }
774     std::string name = rr.name;
775     if (cacheMap_.find(name) == cacheMap_.end()) {
776         ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
777         cacheMap_[name].state = State::ADD;
778         cacheMap_[name].txt = *txt;
779     }
780     Result &result = cacheMap_[name];
781     if (result.txt != *txt || result.state == State::DEAD) {
782         if (result.state != State::ADD) {
783             result.state = State::REFRESH;
784         }
785         result.txt = *txt;
786     }
787     if (rr.ttl == 0) {
788         result.state = State::REMOVE;
789     }
790     if (result.state != State::LIVE && result.state != State::DEAD) {
791         changed.emplace(name);
792     }
793     result.ttl = rr.ttl;
794     result.refrehTime = MilliSecondsSinceEpoch();
795 }
796 
UpdateAddr(bool v6,const DNSProto::ResourceRecord & rr,std::set<std::string> & changed)797 void MDnsProtocolImpl::UpdateAddr(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
798 {
799     if (v6 != (rr.rtype == DNSProto::RRTYPE_AAAA)) {
800         return;
801     }
802     const std::string addr = AddrToString(rr.rdata);
803     bool v6rr = (rr.rtype == DNSProto::RRTYPE_AAAA);
804     if (addr.empty()) {
805         return;
806     }
807     std::string name = rr.name;
808     if (cacheMap_.find(name) == cacheMap_.end()) {
809         ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
810         cacheMap_[name].state = State::ADD;
811         cacheMap_[name].ipv6 = v6rr;
812         cacheMap_[name].addr = addr;
813     }
814     Result &result = cacheMap_[name];
815     if (result.addr != addr || result.ipv6 != v6rr || result.state == State::DEAD) {
816         result.state = State::REFRESH;
817         result.addr = addr;
818         result.ipv6 = v6rr;
819     }
820     if (rr.ttl == 0) {
821         result.state = State::REMOVE;
822     }
823     if (result.state != State::LIVE && result.state != State::DEAD) {
824         changed.emplace(name);
825     }
826     result.ttl = rr.ttl;
827     result.refrehTime = MilliSecondsSinceEpoch();
828 }
829 
ProcessAnswerRecord(bool v6,const DNSProto::ResourceRecord & rr,std::set<std::string> & changed)830 void MDnsProtocolImpl::ProcessAnswerRecord(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
831 {
832     NETMGR_EXT_LOG_D("mdns_log ProcessAnswerRecord, type=[%{public}d]", rr.rtype);
833     std::lock_guard<std::recursive_mutex> guard(mutex_);
834     std::string name = rr.name;
835     if (cacheMap_.find(name) == cacheMap_.end() && browserMap_.find(name) == browserMap_.end() &&
836         srvMap_.find(name) != srvMap_.end()) {
837         return;
838     }
839     if (rr.rtype == DNSProto::RRTYPE_PTR) {
840         UpdatePtr(v6, rr, changed);
841     } else if (rr.rtype == DNSProto::RRTYPE_SRV) {
842         UpdateSrv(v6, rr, changed);
843     } else if (rr.rtype == DNSProto::RRTYPE_TXT) {
844         UpdateTxt(v6, rr, changed);
845     } else if (rr.rtype == DNSProto::RRTYPE_A || rr.rtype == DNSProto::RRTYPE_AAAA) {
846         UpdateAddr(v6, rr, changed);
847     } else {
848         NETMGR_EXT_LOG_D("mdns_log Unknown packet received, type=[%{public}d]", rr.rtype);
849     }
850 }
851 
GetHostDomain()852 std::string MDnsProtocolImpl::GetHostDomain()
853 {
854     if (config_.hostname.empty()) {
855         char buffer[MDNS_MAX_DOMAIN_LABEL];
856         if (gethostname(buffer, sizeof(buffer)) == 0) {
857             config_.hostname = buffer;
858             static auto uid = []() {
859                 std::random_device rd;
860                 return rd();
861             }();
862             config_.hostname += std::to_string(uid);
863         }
864     }
865     return Decorated(config_.hostname);
866 }
867 
AddTask(const Task & task,bool atonce)868 void MDnsProtocolImpl::AddTask(const Task &task, bool atonce)
869 {
870     {
871         std::lock_guard<std::recursive_mutex> guard(mutex_);
872         taskQueue_.emplace_back(task);
873     }
874     if (atonce) {
875         listener_.TriggerRefresh();
876     }
877 }
878 
ConvertResultToInfo(const MDnsProtocolImpl::Result & result)879 MDnsServiceInfo MDnsProtocolImpl::ConvertResultToInfo(const MDnsProtocolImpl::Result &result)
880 {
881     MDnsServiceInfo info;
882     info.name = result.serviceName;
883     info.type = result.serviceType;
884     if (!result.addr.empty()) {
885         info.family = result.ipv6 ? MDnsServiceInfo::IPV6 : MDnsServiceInfo::IPV4;
886     }
887     info.addr = result.addr;
888     info.port = result.port;
889     info.txtRecord = result.txt;
890     return info;
891 }
892 
IsCacheAvailable(const std::string & key)893 bool MDnsProtocolImpl::IsCacheAvailable(const std::string &key)
894 {
895     constexpr int64_t ms2S = 1000LL;
896     NETMGR_EXT_LOG_D("mdns_log IsCacheAvailable, ttl=[%{public}u]", cacheMap_[key].ttl);
897     return cacheMap_.find(key) != cacheMap_.end() &&
898            (ms2S * cacheMap_[key].ttl) > static_cast<uint32_t>(MilliSecondsSinceEpoch() - cacheMap_[key].refrehTime);
899 }
900 
IsDomainCacheAvailable(const std::string & key)901 bool MDnsProtocolImpl::IsDomainCacheAvailable(const std::string &key)
902 {
903     return IsCacheAvailable(key) && !cacheMap_[key].addr.empty();
904 }
905 
IsInstanceCacheAvailable(const std::string & key)906 bool MDnsProtocolImpl::IsInstanceCacheAvailable(const std::string &key)
907 {
908     return IsCacheAvailable(key) && !cacheMap_[key].domain.empty();
909 }
910 
IsBrowserAvailable(const std::string & key)911 bool MDnsProtocolImpl::IsBrowserAvailable(const std::string &key)
912 {
913     return browserMap_.find(key) != browserMap_.end() && !browserMap_[key].empty();
914 }
915 
AddEvent(const std::string & key,const Task & task)916 void MDnsProtocolImpl::AddEvent(const std::string &key, const Task &task)
917 {
918     std::lock_guard<std::recursive_mutex> guard(mutex_);
919     taskOnChange_[key].emplace_back(task);
920 }
921 
RunTaskQueue(std::list<Task> & queue)922 void MDnsProtocolImpl::RunTaskQueue(std::list<Task> &queue)
923 {
924     std::list<Task> tmp;
925     for (auto &&func : queue) {
926         if (!func()) {
927             tmp.emplace_back(func);
928         }
929     }
930     tmp.swap(queue);
931 }
932 
KillCache(const std::string & key)933 void MDnsProtocolImpl::KillCache(const std::string &key)
934 {
935     NETMGR_EXT_LOG_D("mdns_log KillCache");
936     if (IsBrowserAvailable(key) && browserMap_.find(key) != browserMap_.end()) {
937         for (auto it = browserMap_[key].begin(); it != browserMap_[key].end();) {
938             KillBrowseCache(key, it);
939         }
940     }
941     if (IsCacheAvailable(key)) {
942         std::lock_guard<std::recursive_mutex> guard(mutex_);
943         auto &elem = cacheMap_[key];
944         if (elem.state == State::REMOVE) {
945             elem.state = State::DEAD;
946             cacheMap_.erase(key);
947         } else if (elem.state == State::ADD || elem.state == State::REFRESH) {
948             elem.state = State::LIVE;
949         }
950     }
951 }
952 
KillBrowseCache(const std::string & key,std::vector<Result>::iterator & it)953 void MDnsProtocolImpl::KillBrowseCache(const std::string &key, std::vector<Result>::iterator &it)
954 {
955     NETMGR_EXT_LOG_D("mdns_log KillBrowseCache");
956     if (it->state == State::REMOVE) {
957         it->state = State::DEAD;
958         if (nameCbMap_.find(key) != nameCbMap_.end()) {
959             NETMGR_EXT_LOG_D("mdns_log HandleServiceLost");
960             nameCbMap_[key]->HandleServiceLost(ConvertResultToInfo(*it), NETMANAGER_EXT_SUCCESS);
961         }
962         std::string fullName = Decorated(it->serviceName + MDNS_DOMAIN_SPLITER_STR + it->serviceType);
963         cacheMap_.erase(fullName);
964         it = browserMap_[key].erase(it);
965     } else if (it->state == State::ADD || it->state == State::REFRESH) {
966         it->state = State::LIVE;
967         it++;
968     } else {
969         it++;
970     }
971 }
972 
StopCbMap(const std::string & serviceType)973 int32_t MDnsProtocolImpl::StopCbMap(const std::string &serviceType)
974 {
975     NETMGR_EXT_LOG_D("mdns_log StopCbMap");
976     std::lock_guard<std::recursive_mutex> guard(mutex_);
977     std::string name = Decorated(serviceType);
978     sptr<IDiscoveryCallback> cb = nullptr;
979     if (nameCbMap_.find(name) != nameCbMap_.end()) {
980         cb = nameCbMap_[name];
981         nameCbMap_.erase(name);
982     }
983     taskOnChange_.erase(name);
984     auto it = browserMap_.find(name);
985     if (it != browserMap_.end()) {
986         if (cb != nullptr) {
987             NETMGR_EXT_LOG_I("mdns_log StopCbMap res size:[%{public}zu]", it->second.size());
988             for (auto &&res : it->second) {
989                 NETMGR_EXT_LOG_W("mdns_log HandleServiceLost");
990                 cb->HandleServiceLost(ConvertResultToInfo(res), NETMANAGER_EXT_SUCCESS);
991             }
992         }
993         browserMap_.erase(name);
994     }
995     return NETMANAGER_SUCCESS;
996 }
997 } // namespace NetManagerStandard
998 } // namespace OHOS
999