1 /*
2 * Copyright (C) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "mdns_protocol_impl.h"
17
#include <algorithm>
#include <arpa/inet.h>
#include <atomic>
#include <cstddef>
#include <fcntl.h>
#include <iostream>
#include <poll.h>
#include <random>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
25
26 #include "mdns_manager.h"
27 #include "mdns_packet_parser.h"
28 #include "net_conn_client.h"
29 #include "netmgr_ext_log_wrapper.h"
30
31 #include "securec.h"
32
33 namespace OHOS {
34 namespace NetManagerStandard {
35
// Minimum spacing, in milliseconds, between two periodic Browse() rounds.
constexpr uint32_t DEFAULT_INTEVAL_MS = 2000;
// A LIVE browse result not refreshed for this many milliseconds is checked as offline by handleOfflineService().
constexpr uint32_t DEFAULT_LOST_MS = 20000;
// TTL (seconds) put on records this host announces or answers with.
constexpr uint32_t DEFAULT_TTL = 120;
// Top bit of the rrclass field, OR-ed onto RRCLASS_IN for every record this host emits (mDNS cache-flush bit).
constexpr uint16_t MDNS_FLUSH_CACHE_BIT = 0x8000;

// How far ProcessQuestion() got while answering a query: PTR < SRV/TXT < address record.
constexpr int PHASE_PTR = 1;
constexpr int PHASE_SRV = 2;
constexpr int PHASE_DOMAIN = 3;
// Screen state. Written by the common-event subscriber callback (SetScreenState) and read by Browse(),
// which runs from the listener's task queue; atomic so the two sides do not race.
static std::atomic<bool> g_isScreenOn{true};
45
// Renders an in_addr / in6_addr held in `addr` as text.
// Returns "" when inet_ntop fails or when `addr` holds neither address type.
std::string AddrToString(const std::any &addr)
{
    // Sized for the longest IPv6 text form, which also fits any IPv4 form.
    char text[INET6_ADDRSTRLEN] = {0};
    int family = AF_UNSPEC;
    const void *raw = nullptr;
    if (const auto *v4 = std::any_cast<in_addr>(&addr); v4 != nullptr) {
        family = AF_INET;
        raw = v4;
    } else if (const auto *v6 = std::any_cast<in6_addr>(&addr); v6 != nullptr) {
        family = AF_INET6;
        raw = v6;
    }
    if (raw != nullptr && inet_ntop(family, raw, text, sizeof(text)) == nullptr) {
        return std::string{};
    }
    return std::string(text);
}
60
// Wall-clock time (system_clock) in milliseconds since the Unix epoch.
int64_t MilliSecondsSinceEpoch()
{
    using Ms = std::chrono::milliseconds;
    const auto sinceEpoch = std::chrono::system_clock::now().time_since_epoch();
    return std::chrono::duration_cast<Ms>(sinceEpoch).count();
}
66
MDnsProtocolImpl()67 MDnsProtocolImpl::MDnsProtocolImpl()
68 {
69 Init();
70 }
71
Init()72 void MDnsProtocolImpl::Init()
73 {
74 NETMGR_EXT_LOG_D("mdns_log MDnsProtocolImpl init");
75 listener_.Stop();
76 listener_.CloseAllSocket();
77
78 if (config_.configAllIface) {
79 listener_.OpenSocketForEachIface(config_.ipv6Support, config_.configLo);
80 } else {
81 listener_.OpenSocketForDefault(config_.ipv6Support);
82 }
83 listener_.SetReceiveHandler(
84 [this](int sock, const MDnsPayload &payload) { return this->ReceivePacket(sock, payload); });
85 listener_.SetFinishedHandler([this](int sock) {
86 std::lock_guard<std::recursive_mutex> guard(mutex_);
87 RunTaskQueue(taskQueue_);
88 });
89 listener_.Start();
90 {
91 std::lock_guard<std::recursive_mutex> guard(mutex_);
92 taskQueue_.clear();
93 taskOnChange_.clear();
94 }
95 AddTask([this]() { return Browse(); }, false);
96
97 SubscribeCes();
98 }
99
SubscribeCes()100 void MDnsProtocolImpl::SubscribeCes()
101 {
102 int32_t COMMON_EVENT_SERVICE_ID = 3299;
103 sptr<ISystemAbilityManager> samgrClient = SystemAbilityManagerClient::GetInstance().GetSystemAbilityManager();
104 if (samgrClient == nullptr || samgrClient->CheckSystemAbility(COMMON_EVENT_SERVICE_ID) == nullptr) {
105 NETMGR_EXT_LOG_E("Subscribe:CES SA not ready, wait for the SA Added callback.");
106 return;
107 }
108 EventFwk::MatchingSkills matchSkills;
109 matchSkills.AddEvent(EventFwk::CommonEventSupport::COMMON_EVENT_SCREEN_ON);
110 matchSkills.AddEvent(EventFwk::CommonEventSupport::COMMON_EVENT_SCREEN_OFF);
111
112 EventFwk::CommonEventSubscribeInfo subcriberInfo(matchSkills);
113 if (subscriber_ == nullptr) {
114 subscriber_ = std::make_shared<MdnsSubscriber>(subcriberInfo);
115 }
116 if (!EventFwk::CommonEventManager::SubscribeCommonEvent(subscriber_)) {
117 NETMGR_EXT_LOG_E("system event register fail.");
118 }
119 }
120
OnReceiveEvent(const EventFwk::CommonEventData & data)121 void MDnsProtocolImpl::MdnsSubscriber::OnReceiveEvent(const EventFwk::CommonEventData &data)
122 {
123 auto eventName = data.GetWant().GetAction();
124 if (eventName == EventFwk::CommonEventSupport::COMMON_EVENT_SCREEN_ON) {
125 MDnsProtocolImpl::SetScreenState(true);
126 } else {
127 MDnsProtocolImpl::SetScreenState(false);
128 }
129 }
130
SetScreenState(bool isOn)131 void MDnsProtocolImpl::SetScreenState(bool isOn)
132 {
133 NETMGR_EXT_LOG_I("Mdns SetScreenState isOn[%{public}d]", isOn);
134 g_isScreenOn = isOn;
135 }
136
Browse()137 bool MDnsProtocolImpl::Browse()
138 {
139 if ((lastRunTime != -1 && MilliSecondsSinceEpoch() - lastRunTime < DEFAULT_INTEVAL_MS) || !g_isScreenOn) {
140 return false;
141 }
142 lastRunTime = MilliSecondsSinceEpoch();
143 std::lock_guard<std::recursive_mutex> guard(mutex_);
144 for (auto &&[key, res] : browserMap_) {
145 NETMGR_EXT_LOG_D("mdns_log Browse browserMap_ key[%{public}s] res.size[%{public}zu]", key.c_str(), res.size());
146 if (nameCbMap_.find(key) != nameCbMap_.end() &&
147 !MDnsManager::GetInstance().IsAvailableCallback(nameCbMap_[key])) {
148 continue;
149 }
150 handleOfflineService(key, res);
151 MDnsPayloadParser parser;
152 MDnsMessage msg{};
153 msg.questions.emplace_back(DNSProto::Question{
154 .name = key,
155 .qtype = DNSProto::RRTYPE_PTR,
156 .qclass = DNSProto::RRCLASS_IN,
157 });
158 listener_.MulticastAll(parser.ToBytes(msg));
159 }
160 return false;
161 }
162
ConnectControl(int32_t sockfd,sockaddr * serverAddr)163 int32_t MDnsProtocolImpl::ConnectControl(int32_t sockfd, sockaddr* serverAddr)
164 {
165 uint32_t flags = static_cast<uint32_t>(fcntl(sockfd, F_GETFL, 0));
166 fcntl(sockfd, F_SETFL, flags | O_NONBLOCK);
167 int32_t ret = connect(sockfd, serverAddr, sizeof(sockaddr));
168 if ((ret < 0) && (errno != EINPROGRESS)) {
169 NETMGR_EXT_LOG_E("connect error: %{public}d", errno);
170 return NETMANAGER_EXT_ERR_INTERNAL;
171 }
172 if (ret == 0) {
173 fcntl(sockfd, F_SETFL, flags); /* restore file status flags */
174 NETMGR_EXT_LOG_I("connect success.");
175 return NETMANAGER_EXT_SUCCESS;
176 }
177
178 fd_set rset;
179 FD_ZERO(&rset);
180 FD_SET(sockfd, &rset);
181 fd_set wset = rset;
182 timeval tval {1, 0};
183 ret = select(sockfd + 1, &rset, &wset, NULL, &tval);
184 if (ret < 0) { // select error.
185 NETMGR_EXT_LOG_E("select error: %{public}d", errno);
186 return NETMANAGER_EXT_ERR_INTERNAL;
187 }
188 if (ret == 0) { // timeout
189 NETMGR_EXT_LOG_E("connect timeout...");
190 return NETMANAGER_EXT_ERR_INTERNAL;
191 }
192 if (!FD_ISSET(sockfd, &rset) && !FD_ISSET(sockfd, &wset)) {
193 NETMGR_EXT_LOG_E("select error: sockfd not set");
194 return NETMANAGER_EXT_ERR_INTERNAL;
195 }
196
197 int32_t result = NETMANAGER_EXT_ERR_INTERNAL;
198 socklen_t len = sizeof(result);
199 if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &result, &len) < 0) {
200 NETMGR_EXT_LOG_E("getsockopt error: %{public}d", errno);
201 return NETMANAGER_EXT_ERR_INTERNAL;
202 }
203 if (result != 0) { // connect failed.
204 NETMGR_EXT_LOG_E("connect failed. error: %{public}d", result);
205 return NETMANAGER_EXT_ERR_INTERNAL;
206 }
207 fcntl(sockfd, F_SETFL, flags); /* restore file status flags */
208 NETMGR_EXT_LOG_I("lost but connect success.");
209 return NETMANAGER_EXT_SUCCESS;
210 }
211
IsConnectivity(const std::string & ip,int32_t port)212 bool MDnsProtocolImpl::IsConnectivity(const std::string &ip, int32_t port)
213 {
214 if (ip.empty()) {
215 NETMGR_EXT_LOG_E("ip is empty");
216 return false;
217 }
218
219 int32_t sockfd = socket(AF_INET, SOCK_STREAM, 0);
220 if (sockfd < 0) {
221 NETMGR_EXT_LOG_E("create socket error: %{public}d", errno);
222 return false;
223 }
224
225 struct sockaddr_in serverAddr;
226 if (memset_s(&serverAddr, sizeof(serverAddr), 0, sizeof(serverAddr)) != EOK) {
227 NETMGR_EXT_LOG_E("memset_s serverAddr failed!");
228 close(sockfd);
229 return false;
230 }
231
232 serverAddr.sin_family = AF_INET;
233 serverAddr.sin_addr.s_addr = inet_addr(ip.c_str());
234 serverAddr.sin_port = htons(port);
235 if (ConnectControl(sockfd, (struct sockaddr*)&serverAddr) != NETMANAGER_EXT_SUCCESS) {
236 NETMGR_EXT_LOG_I("connect error: %{public}d", errno);
237 close(sockfd);
238 return false;
239 }
240
241 close(sockfd);
242 return true;
243 }
244
handleOfflineService(const std::string & key,std::vector<Result> & res)245 void MDnsProtocolImpl::handleOfflineService(const std::string &key, std::vector<Result> &res)
246 {
247 NETMGR_EXT_LOG_D("mdns_log handleOfflineService key:[%{public}s]", key.c_str());
248 for (auto it = res.begin(); it != res.end();) {
249 if (lastRunTime - it->refrehTime > DEFAULT_LOST_MS && it->state == State::LIVE) {
250 std::string fullName = Decorated(it->serviceName + MDNS_DOMAIN_SPLITER_STR + it->serviceType);
251 if ((cacheMap_.find(fullName) != cacheMap_.end()) &&
252 IsConnectivity(cacheMap_[fullName].addr, cacheMap_[fullName].port)) {
253 it++;
254 continue;
255 }
256
257 it->state = State::DEAD;
258 if (nameCbMap_.find(key) != nameCbMap_.end() && nameCbMap_[key] != nullptr) {
259 NETMGR_EXT_LOG_W("mdns_log HandleServiceLost");
260 nameCbMap_[key]->HandleServiceLost(ConvertResultToInfo(*it), NETMANAGER_EXT_SUCCESS);
261 }
262 it = res.erase(it);
263 cacheMap_.erase(fullName);
264 } else {
265 it++;
266 }
267 }
268 }
269
SetConfig(const MDnsConfig & config)270 void MDnsProtocolImpl::SetConfig(const MDnsConfig &config)
271 {
272 config_ = config;
273 }
274
GetConfig() const275 const MDnsConfig &MDnsProtocolImpl::GetConfig() const
276 {
277 return config_;
278 }
279
Decorated(const std::string & name) const280 std::string MDnsProtocolImpl::Decorated(const std::string &name) const
281 {
282 return name + config_.topDomain;
283 }
284
Register(const Result & info)285 int32_t MDnsProtocolImpl::Register(const Result &info)
286 {
287 NETMGR_EXT_LOG_D("mdns_log Register");
288 if (!(IsNameValid(info.serviceName) && IsTypeValid(info.serviceType) && IsPortValid(info.port))) {
289 return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
290 }
291 std::string name = Decorated(info.serviceName + MDNS_DOMAIN_SPLITER_STR + info.serviceType);
292 if (!IsDomainValid(name)) {
293 return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
294 }
295 {
296 std::lock_guard<std::recursive_mutex> guard(mutex_);
297 if (srvMap_.find(name) != srvMap_.end()) {
298 return NET_MDNS_ERR_SERVICE_INSTANCE_DUPLICATE;
299 }
300 srvMap_.emplace(name, info);
301 }
302 return Announce(info, false);
303 }
304
UnRegister(const std::string & key)305 int32_t MDnsProtocolImpl::UnRegister(const std::string &key)
306 {
307 NETMGR_EXT_LOG_D("mdns_log UnRegister");
308 std::string name = Decorated(key);
309 std::lock_guard<std::recursive_mutex> guard(mutex_);
310 if (srvMap_.find(name) != srvMap_.end()) {
311 Announce(srvMap_[name], true);
312 srvMap_.erase(name);
313 return NETMANAGER_EXT_SUCCESS;
314 }
315 return NET_MDNS_ERR_SERVICE_INSTANCE_NOT_FOUND;
316 }
317
DiscoveryFromCache(const std::string & serviceType,const sptr<IDiscoveryCallback> & cb)318 bool MDnsProtocolImpl::DiscoveryFromCache(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)
319 {
320 NETMGR_EXT_LOG_D("mdns_log DiscoveryFromCache");
321 std::string name = Decorated(serviceType);
322 std::lock_guard<std::recursive_mutex> guard(mutex_);
323 if (!IsBrowserAvailable(name)) {
324 return false;
325 }
326
327 if (browserMap_.find(name) == browserMap_.end()) {
328 NETMGR_EXT_LOG_D("mdns_log DiscoveryFromCache browserMap_ not find name");
329 return false;
330 }
331
332 for (auto &res : browserMap_[name]) {
333 if (res.state == State::REMOVE || res.state == State::DEAD) {
334 continue;
335 }
336 AddTask([cb, info = ConvertResultToInfo(res)]() {
337 NETMGR_EXT_LOG_W("mdns_log DiscoveryFromCache ConvertResultToInfo HandleServiceFound");
338 if (MDnsManager::GetInstance().IsAvailableCallback(cb)) {
339 cb->HandleServiceFound(info, NETMANAGER_EXT_SUCCESS);
340 }
341 return true;
342 });
343 }
344 return true;
345 }
346
DiscoveryFromNet(const std::string & serviceType,const sptr<IDiscoveryCallback> & cb)347 bool MDnsProtocolImpl::DiscoveryFromNet(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)
348 {
349 NETMGR_EXT_LOG_D("mdns_log DiscoveryFromNet");
350 std::string name = Decorated(serviceType);
351 std::lock_guard<std::recursive_mutex> guard(mutex_);
352 browserMap_.insert({name, std::vector<Result>{}});
353 nameCbMap_[name] = cb;
354 // key is serviceTYpe
355 AddEvent(name, [this, name, cb]() {
356 std::lock_guard<std::recursive_mutex> guard(mutex_);
357 if (!IsBrowserAvailable(name)) {
358 return false;
359 }
360 if (!MDnsManager::GetInstance().IsAvailableCallback(cb)) {
361 return true;
362 }
363 for (auto &res : browserMap_[name]) {
364 std::string fullName = Decorated(res.serviceName + MDNS_DOMAIN_SPLITER_STR + res.serviceType);
365 NETMGR_EXT_LOG_W("mdns_log DiscoveryFromNet name:[%{public}s] fullName:[%{public}s]", name.c_str(),
366 fullName.c_str());
367 if (cacheMap_.find(fullName) == cacheMap_.end() ||
368 (res.state == State::ADD || res.state == State::REFRESH)) {
369 NETMGR_EXT_LOG_W("mdns_log HandleServiceFound");
370 cb->HandleServiceFound(ConvertResultToInfo(res), NETMANAGER_EXT_SUCCESS);
371 res.state = State::LIVE;
372 }
373 if (res.state == State::REMOVE) {
374 res.state = State::DEAD;
375 NETMGR_EXT_LOG_D("mdns_log HandleServiceLost");
376 cb->HandleServiceLost(ConvertResultToInfo(res), NETMANAGER_EXT_SUCCESS);
377 if (cacheMap_.find(fullName) != cacheMap_.end()) {
378 cacheMap_.erase(fullName);
379 }
380 }
381 }
382 return false;
383 });
384
385 AddTask([=]() {
386 MDnsPayloadParser parser;
387 MDnsMessage msg{};
388 msg.questions.emplace_back(DNSProto::Question{
389 .name = name,
390 .qtype = DNSProto::RRTYPE_PTR,
391 .qclass = DNSProto::RRCLASS_IN,
392 });
393 listener_.MulticastAll(parser.ToBytes(msg));
394 return true;
395 }, false);
396 return true;
397 }
398
Discovery(const std::string & serviceType,const sptr<IDiscoveryCallback> & cb)399 int32_t MDnsProtocolImpl::Discovery(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)
400 {
401 NETMGR_EXT_LOG_D("mdns_log Discovery");
402 DiscoveryFromCache(serviceType, cb);
403 DiscoveryFromNet(serviceType, cb);
404 return NETMANAGER_EXT_SUCCESS;
405 }
406
ResolveInstanceFromCache(const std::string & name,const sptr<IResolveCallback> & cb)407 bool MDnsProtocolImpl::ResolveInstanceFromCache(const std::string &name, const sptr<IResolveCallback> &cb)
408 {
409 NETMGR_EXT_LOG_D("mdns_log ResolveInstanceFromCache");
410 std::lock_guard<std::recursive_mutex> guard(mutex_);
411 if (!IsInstanceCacheAvailable(name)) {
412 NETMGR_EXT_LOG_W("mdns_log ResolveInstanceFromCache cacheMap_ has no element [%{public}s]", name.c_str());
413 return false;
414 }
415
416 NETMGR_EXT_LOG_I("mdns_log rr.name : [%{public}s]", name.c_str());
417 Result r = cacheMap_[name];
418 if (IsDomainCacheAvailable(r.domain)) {
419 r.ipv6 = cacheMap_[r.domain].ipv6;
420 r.addr = cacheMap_[r.domain].addr;
421
422 NETMGR_EXT_LOG_D("mdns_log Add Task DomainCache Available, [%{public}s]", r.domain.c_str());
423 AddTask([cb, info = ConvertResultToInfo(r)]() {
424 if (nullptr != cb) {
425 cb->HandleResolveResult(info, NETMANAGER_EXT_SUCCESS);
426 }
427 return true;
428 });
429 } else {
430 ResolveFromNet(r.domain, nullptr);
431 NETMGR_EXT_LOG_D("mdns_log Add Event DomainCache UnAvailable, [%{public}s]", r.domain.c_str());
432 AddEvent(r.domain, [this, cb, r]() mutable {
433 if (!IsDomainCacheAvailable(r.domain)) {
434 return false;
435 }
436 r.ipv6 = cacheMap_[r.domain].ipv6;
437 r.addr = cacheMap_[r.domain].addr;
438 if (nullptr != cb) {
439 cb->HandleResolveResult(ConvertResultToInfo(r), NETMANAGER_EXT_SUCCESS);
440 }
441 return true;
442 });
443 }
444 return true;
445 }
446
ResolveInstanceFromNet(const std::string & name,const sptr<IResolveCallback> & cb)447 bool MDnsProtocolImpl::ResolveInstanceFromNet(const std::string &name, const sptr<IResolveCallback> &cb)
448 {
449 NETMGR_EXT_LOG_D("mdns_log ResolveInstanceFromNet");
450 {
451 std::lock_guard<std::recursive_mutex> guard(mutex_);
452 cacheMap_[name].state = State::ADD;
453 ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
454 }
455 MDnsPayloadParser parser;
456 MDnsMessage msg{};
457 msg.questions.emplace_back(DNSProto::Question{
458 .name = name,
459 .qtype = DNSProto::RRTYPE_SRV,
460 .qclass = DNSProto::RRCLASS_IN,
461 });
462 msg.questions.emplace_back(DNSProto::Question{
463 .name = name,
464 .qtype = DNSProto::RRTYPE_TXT,
465 .qclass = DNSProto::RRCLASS_IN,
466 });
467 msg.header.qdcount = msg.questions.size();
468 AddEvent(name, [this, name, cb]() { return ResolveInstanceFromCache(name, cb); });
469 ssize_t size = listener_.MulticastAll(parser.ToBytes(msg));
470 return size > 0;
471 }
472
ResolveFromCache(const std::string & domain,const sptr<IResolveCallback> & cb)473 bool MDnsProtocolImpl::ResolveFromCache(const std::string &domain, const sptr<IResolveCallback> &cb)
474 {
475 NETMGR_EXT_LOG_D("mdns_log ResolveFromCache");
476 std::lock_guard<std::recursive_mutex> guard(mutex_);
477 if (!IsDomainCacheAvailable(domain)) {
478 return false;
479 }
480 AddTask([this, cb, info = ConvertResultToInfo(cacheMap_[domain])]() {
481 if (nullptr != cb) {
482 cb->HandleResolveResult(info, NETMANAGER_EXT_SUCCESS);
483 }
484 return true;
485 });
486 return true;
487 }
488
ResolveFromNet(const std::string & domain,const sptr<IResolveCallback> & cb)489 bool MDnsProtocolImpl::ResolveFromNet(const std::string &domain, const sptr<IResolveCallback> &cb)
490 {
491 NETMGR_EXT_LOG_D("mdns_log ResolveFromNet");
492 {
493 std::lock_guard<std::recursive_mutex> guard(mutex_);
494 cacheMap_[domain];
495 cacheMap_[domain].domain = domain;
496 }
497 MDnsPayloadParser parser;
498 MDnsMessage msg{};
499 msg.questions.emplace_back(DNSProto::Question{
500 .name = domain,
501 .qtype = DNSProto::RRTYPE_A,
502 .qclass = DNSProto::RRCLASS_IN,
503 });
504 msg.questions.emplace_back(DNSProto::Question{
505 .name = domain,
506 .qtype = DNSProto::RRTYPE_AAAA,
507 .qclass = DNSProto::RRCLASS_IN,
508 });
509 // key is serviceName
510 AddEvent(domain, [this, cb, domain]() { return ResolveFromCache(domain, cb); });
511 ssize_t size = listener_.MulticastAll(parser.ToBytes(msg));
512 return size > 0;
513 }
514
ResolveInstance(const std::string & instance,const sptr<IResolveCallback> & cb)515 int32_t MDnsProtocolImpl::ResolveInstance(const std::string &instance, const sptr<IResolveCallback> &cb)
516 {
517 NETMGR_EXT_LOG_D("mdns_log execute ResolveInstance");
518 if (!IsInstanceValid(instance)) {
519 return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
520 }
521 std::string name = Decorated(instance);
522 if (!IsDomainValid(name)) {
523 return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
524 }
525 if (ResolveInstanceFromCache(name, cb)) {
526 return NETMANAGER_EXT_SUCCESS;
527 }
528 return ResolveInstanceFromNet(name, cb) ? NETMANAGER_EXT_SUCCESS : NET_MDNS_ERR_SEND;
529 }
530
Announce(const Result & info,bool off)531 int32_t MDnsProtocolImpl::Announce(const Result &info, bool off)
532 {
533 NETMGR_EXT_LOG_I("mdns_log Announce message");
534 MDnsMessage response{};
535 response.header.flags = DNSProto::MDNS_ANSWER_FLAGS;
536 std::string name = Decorated(info.serviceName + MDNS_DOMAIN_SPLITER_STR + info.serviceType);
537 response.answers.emplace_back(DNSProto::ResourceRecord{.name = Decorated(info.serviceType),
538 .rtype = static_cast<uint16_t>(DNSProto::RRTYPE_PTR),
539 .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
540 .ttl = off ? 0U : DEFAULT_TTL,
541 .rdata = name});
542 response.answers.emplace_back(DNSProto::ResourceRecord{.name = name,
543 .rtype = static_cast<uint16_t>(DNSProto::RRTYPE_SRV),
544 .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
545 .ttl = off ? 0U : DEFAULT_TTL,
546 .rdata = DNSProto::RDataSrv{
547 .priority = 0,
548 .weight = 0,
549 .port = static_cast<uint16_t>(info.port),
550 .name = GetHostDomain(),
551 }});
552 response.answers.emplace_back(DNSProto::ResourceRecord{.name = name,
553 .rtype = static_cast<uint16_t>(DNSProto::RRTYPE_TXT),
554 .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
555 .ttl = off ? 0U : DEFAULT_TTL,
556 .rdata = info.txt});
557 MDnsPayloadParser parser;
558 ssize_t size = listener_.MulticastAll(parser.ToBytes(response));
559 return size > 0 ? NETMANAGER_EXT_SUCCESS : NET_MDNS_ERR_SEND;
560 }
561
ReceivePacket(int sock,const MDnsPayload & payload)562 void MDnsProtocolImpl::ReceivePacket(int sock, const MDnsPayload &payload)
563 {
564 if (payload.size() == 0) {
565 return;
566 }
567 MDnsPayloadParser parser;
568 MDnsMessage msg = parser.FromBytes(payload);
569 if (parser.GetError() != 0) {
570 NETMGR_EXT_LOG_E("parser payload failed");
571 return;
572 }
573 if ((msg.header.flags & DNSProto::HEADER_FLAGS_QR_MASK) == 0) {
574 ProcessQuestion(sock, msg);
575 } else {
576 ProcessAnswer(sock, msg);
577 }
578 }
579
AppendRecord(std::vector<DNSProto::ResourceRecord> & rrlist,DNSProto::RRType type,const std::string & name,const std::any & rdata)580 void MDnsProtocolImpl::AppendRecord(std::vector<DNSProto::ResourceRecord> &rrlist, DNSProto::RRType type,
581 const std::string &name, const std::any &rdata)
582 {
583 rrlist.emplace_back(DNSProto::ResourceRecord{.name = name,
584 .rtype = static_cast<uint16_t>(type),
585 .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
586 .ttl = DEFAULT_TTL,
587 .rdata = rdata});
588 }
589
ProcessQuestion(int sock,const MDnsMessage & msg)590 void MDnsProtocolImpl::ProcessQuestion(int sock, const MDnsMessage &msg)
591 {
592 const sockaddr *saddrIf = listener_.GetSockAddr(sock);
593 if (saddrIf == nullptr) {
594 NETMGR_EXT_LOG_W("mdns_log ProcessQuestion saddrIf is null");
595 return;
596 }
597 std::any anyAddr;
598 DNSProto::RRType anyAddrType;
599 if (saddrIf->sa_family == AF_INET6) {
600 anyAddr = reinterpret_cast<const sockaddr_in6 *>(saddrIf)->sin6_addr;
601 anyAddrType = DNSProto::RRTYPE_AAAA;
602 } else {
603 anyAddr = reinterpret_cast<const sockaddr_in *>(saddrIf)->sin_addr;
604 anyAddrType = DNSProto::RRTYPE_A;
605 }
606 int phase = 0;
607 MDnsMessage response{};
608 response.header.flags = DNSProto::MDNS_ANSWER_FLAGS;
609 for (size_t i = 0; i < msg.header.qdcount; ++i) {
610 ProcessQuestionRecord(anyAddr, anyAddrType, msg.questions[i], phase, response);
611 }
612 if (phase < PHASE_DOMAIN) {
613 AppendRecord(response.additional, anyAddrType, GetHostDomain(), anyAddr);
614 }
615
616 if (phase != 0 && response.answers.size() > 0) {
617 listener_.Multicast(sock, MDnsPayloadParser().ToBytes(response));
618 }
619 }
620
ProcessQuestionRecord(const std::any & anyAddr,const DNSProto::RRType & anyAddrType,const DNSProto::Question & qu,int & phase,MDnsMessage & response)621 void MDnsProtocolImpl::ProcessQuestionRecord(const std::any &anyAddr, const DNSProto::RRType &anyAddrType,
622 const DNSProto::Question &qu, int &phase, MDnsMessage &response)
623 {
624 NETMGR_EXT_LOG_D("mdns_log ProcessQuestionRecord");
625 std::lock_guard<std::recursive_mutex> guard(mutex_);
626 std::string name = qu.name;
627 if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_PTR) {
628 std::for_each(srvMap_.begin(), srvMap_.end(), [&](const auto &elem) -> void {
629 if (EndsWith(elem.first, name)) {
630 AppendRecord(response.answers, DNSProto::RRTYPE_PTR, name, elem.first);
631 AppendRecord(response.additional, DNSProto::RRTYPE_SRV, elem.first,
632 DNSProto::RDataSrv{
633 .priority = 0,
634 .weight = 0,
635 .port = static_cast<uint16_t>(elem.second.port),
636 .name = GetHostDomain(),
637 });
638 AppendRecord(response.additional, DNSProto::RRTYPE_TXT, elem.first, elem.second.txt);
639 }
640 });
641 phase = std::max(phase, PHASE_PTR);
642 }
643 if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_SRV) {
644 auto iter = srvMap_.find(name);
645 if (iter == srvMap_.end()) {
646 return;
647 }
648 AppendRecord(response.answers, DNSProto::RRTYPE_SRV, name,
649 DNSProto::RDataSrv{
650 .priority = 0,
651 .weight = 0,
652 .port = static_cast<uint16_t>(iter->second.port),
653 .name = GetHostDomain(),
654 });
655 phase = std::max(phase, PHASE_SRV);
656 }
657 if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_TXT) {
658 auto iter = srvMap_.find(name);
659 if (iter == srvMap_.end()) {
660 return;
661 }
662 AppendRecord(response.answers, DNSProto::RRTYPE_TXT, name, iter->second.txt);
663 phase = std::max(phase, PHASE_SRV);
664 }
665 if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_A || qu.qtype == DNSProto::RRTYPE_AAAA) {
666 if (name != GetHostDomain() || (qu.qtype != DNSProto::RRTYPE_ANY && anyAddrType != qu.qtype)) {
667 return;
668 }
669 AppendRecord(response.answers, anyAddrType, name, anyAddr);
670 phase = std::max(phase, PHASE_DOMAIN);
671 }
672 }
673
ProcessAnswer(int sock,const MDnsMessage & msg)674 void MDnsProtocolImpl::ProcessAnswer(int sock, const MDnsMessage &msg)
675 {
676 const sockaddr *saddrIf = listener_.GetSockAddr(sock);
677 if (saddrIf == nullptr) {
678 return;
679 }
680 bool v6 = (saddrIf->sa_family == AF_INET6);
681 std::set<std::string> changed;
682 for (const auto &answer : msg.answers) {
683 ProcessAnswerRecord(v6, answer, changed);
684 }
685 for (const auto &i : msg.additional) {
686 ProcessAnswerRecord(v6, i, changed);
687 }
688 for (const auto &i : changed) {
689 std::lock_guard<std::recursive_mutex> guard(mutex_);
690 RunTaskQueue(taskOnChange_[i]);
691 KillCache(i);
692 }
693 }
694
UpdatePtr(bool v6,const DNSProto::ResourceRecord & rr,std::set<std::string> & changed)695 void MDnsProtocolImpl::UpdatePtr(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
696 {
697 const std::string *data = std::any_cast<std::string>(&rr.rdata);
698 if (data == nullptr) {
699 return;
700 }
701
702 std::string name = rr.name;
703 if (browserMap_.find(name) == browserMap_.end()) {
704 return;
705 }
706 auto &results = browserMap_[name];
707 std::string srvName;
708 std::string srvType;
709 ExtractNameAndType(*data, srvName, srvType);
710 if (srvName.empty() || srvType.empty()) {
711 return;
712 }
713 auto res =
714 std::find_if(results.begin(), results.end(), [&](const auto &elem) { return elem.serviceName == srvName; });
715 if (res == results.end()) {
716 results.emplace_back(Result{
717 .serviceName = srvName,
718 .serviceType = srvType,
719 .state = State::ADD,
720 });
721 }
722 res = std::find_if(results.begin(), results.end(), [&](const auto &elem) { return elem.serviceName == srvName; });
723 if (res->serviceName != srvName || res->state == State::DEAD) {
724 res->state = State::REFRESH;
725 res->serviceName = srvName;
726 }
727 if (rr.ttl == 0) {
728 res->state = State::REMOVE;
729 }
730 if (res->state != State::LIVE && res->state != State::DEAD) {
731 changed.emplace(name);
732 }
733 res->ttl = rr.ttl;
734 res->refrehTime = MilliSecondsSinceEpoch();
735 }
736
UpdateSrv(bool v6,const DNSProto::ResourceRecord & rr,std::set<std::string> & changed)737 void MDnsProtocolImpl::UpdateSrv(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
738 {
739 const DNSProto::RDataSrv *srv = std::any_cast<DNSProto::RDataSrv>(&rr.rdata);
740 if (srv == nullptr) {
741 return;
742 }
743 std::string name = rr.name;
744 if (cacheMap_.find(name) == cacheMap_.end()) {
745 ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
746 cacheMap_[name].state = State::ADD;
747 cacheMap_[name].domain = srv->name;
748 cacheMap_[name].port = srv->port;
749 }
750 Result &result = cacheMap_[name];
751 if (result.domain != srv->name || result.port != srv->port || result.state == State::DEAD) {
752 if (result.state != State::ADD) {
753 result.state = State::REFRESH;
754 }
755 result.domain = srv->name;
756 result.port = srv->port;
757 }
758 if (rr.ttl == 0) {
759 result.state = State::REMOVE;
760 }
761 if (result.state != State::LIVE && result.state != State::DEAD) {
762 changed.emplace(name);
763 }
764 result.ttl = rr.ttl;
765 result.refrehTime = MilliSecondsSinceEpoch();
766 }
767
UpdateTxt(bool v6,const DNSProto::ResourceRecord & rr,std::set<std::string> & changed)768 void MDnsProtocolImpl::UpdateTxt(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
769 {
770 const TxtRecordEncoded *txt = std::any_cast<TxtRecordEncoded>(&rr.rdata);
771 if (txt == nullptr) {
772 return;
773 }
774 std::string name = rr.name;
775 if (cacheMap_.find(name) == cacheMap_.end()) {
776 ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
777 cacheMap_[name].state = State::ADD;
778 cacheMap_[name].txt = *txt;
779 }
780 Result &result = cacheMap_[name];
781 if (result.txt != *txt || result.state == State::DEAD) {
782 if (result.state != State::ADD) {
783 result.state = State::REFRESH;
784 }
785 result.txt = *txt;
786 }
787 if (rr.ttl == 0) {
788 result.state = State::REMOVE;
789 }
790 if (result.state != State::LIVE && result.state != State::DEAD) {
791 changed.emplace(name);
792 }
793 result.ttl = rr.ttl;
794 result.refrehTime = MilliSecondsSinceEpoch();
795 }
796
UpdateAddr(bool v6,const DNSProto::ResourceRecord & rr,std::set<std::string> & changed)797 void MDnsProtocolImpl::UpdateAddr(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
798 {
799 if (v6 != (rr.rtype == DNSProto::RRTYPE_AAAA)) {
800 return;
801 }
802 const std::string addr = AddrToString(rr.rdata);
803 bool v6rr = (rr.rtype == DNSProto::RRTYPE_AAAA);
804 if (addr.empty()) {
805 return;
806 }
807 std::string name = rr.name;
808 if (cacheMap_.find(name) == cacheMap_.end()) {
809 ExtractNameAndType(name, cacheMap_[name].serviceName, cacheMap_[name].serviceType);
810 cacheMap_[name].state = State::ADD;
811 cacheMap_[name].ipv6 = v6rr;
812 cacheMap_[name].addr = addr;
813 }
814 Result &result = cacheMap_[name];
815 if (result.addr != addr || result.ipv6 != v6rr || result.state == State::DEAD) {
816 result.state = State::REFRESH;
817 result.addr = addr;
818 result.ipv6 = v6rr;
819 }
820 if (rr.ttl == 0) {
821 result.state = State::REMOVE;
822 }
823 if (result.state != State::LIVE && result.state != State::DEAD) {
824 changed.emplace(name);
825 }
826 result.ttl = rr.ttl;
827 result.refrehTime = MilliSecondsSinceEpoch();
828 }
829
ProcessAnswerRecord(bool v6,const DNSProto::ResourceRecord & rr,std::set<std::string> & changed)830 void MDnsProtocolImpl::ProcessAnswerRecord(bool v6, const DNSProto::ResourceRecord &rr, std::set<std::string> &changed)
831 {
832 NETMGR_EXT_LOG_D("mdns_log ProcessAnswerRecord, type=[%{public}d]", rr.rtype);
833 std::lock_guard<std::recursive_mutex> guard(mutex_);
834 std::string name = rr.name;
835 if (cacheMap_.find(name) == cacheMap_.end() && browserMap_.find(name) == browserMap_.end() &&
836 srvMap_.find(name) != srvMap_.end()) {
837 return;
838 }
839 if (rr.rtype == DNSProto::RRTYPE_PTR) {
840 UpdatePtr(v6, rr, changed);
841 } else if (rr.rtype == DNSProto::RRTYPE_SRV) {
842 UpdateSrv(v6, rr, changed);
843 } else if (rr.rtype == DNSProto::RRTYPE_TXT) {
844 UpdateTxt(v6, rr, changed);
845 } else if (rr.rtype == DNSProto::RRTYPE_A || rr.rtype == DNSProto::RRTYPE_AAAA) {
846 UpdateAddr(v6, rr, changed);
847 } else {
848 NETMGR_EXT_LOG_D("mdns_log Unknown packet received, type=[%{public}d]", rr.rtype);
849 }
850 }
851
GetHostDomain()852 std::string MDnsProtocolImpl::GetHostDomain()
853 {
854 if (config_.hostname.empty()) {
855 char buffer[MDNS_MAX_DOMAIN_LABEL];
856 if (gethostname(buffer, sizeof(buffer)) == 0) {
857 config_.hostname = buffer;
858 static auto uid = []() {
859 std::random_device rd;
860 return rd();
861 }();
862 config_.hostname += std::to_string(uid);
863 }
864 }
865 return Decorated(config_.hostname);
866 }
867
AddTask(const Task & task,bool atonce)868 void MDnsProtocolImpl::AddTask(const Task &task, bool atonce)
869 {
870 {
871 std::lock_guard<std::recursive_mutex> guard(mutex_);
872 taskQueue_.emplace_back(task);
873 }
874 if (atonce) {
875 listener_.TriggerRefresh();
876 }
877 }
878
ConvertResultToInfo(const MDnsProtocolImpl::Result & result)879 MDnsServiceInfo MDnsProtocolImpl::ConvertResultToInfo(const MDnsProtocolImpl::Result &result)
880 {
881 MDnsServiceInfo info;
882 info.name = result.serviceName;
883 info.type = result.serviceType;
884 if (!result.addr.empty()) {
885 info.family = result.ipv6 ? MDnsServiceInfo::IPV6 : MDnsServiceInfo::IPV4;
886 }
887 info.addr = result.addr;
888 info.port = result.port;
889 info.txtRecord = result.txt;
890 return info;
891 }
892
IsCacheAvailable(const std::string & key)893 bool MDnsProtocolImpl::IsCacheAvailable(const std::string &key)
894 {
895 constexpr int64_t ms2S = 1000LL;
896 NETMGR_EXT_LOG_D("mdns_log IsCacheAvailable, ttl=[%{public}u]", cacheMap_[key].ttl);
897 return cacheMap_.find(key) != cacheMap_.end() &&
898 (ms2S * cacheMap_[key].ttl) > static_cast<uint32_t>(MilliSecondsSinceEpoch() - cacheMap_[key].refrehTime);
899 }
900
IsDomainCacheAvailable(const std::string & key)901 bool MDnsProtocolImpl::IsDomainCacheAvailable(const std::string &key)
902 {
903 return IsCacheAvailable(key) && !cacheMap_[key].addr.empty();
904 }
905
IsInstanceCacheAvailable(const std::string & key)906 bool MDnsProtocolImpl::IsInstanceCacheAvailable(const std::string &key)
907 {
908 return IsCacheAvailable(key) && !cacheMap_[key].domain.empty();
909 }
910
IsBrowserAvailable(const std::string & key)911 bool MDnsProtocolImpl::IsBrowserAvailable(const std::string &key)
912 {
913 return browserMap_.find(key) != browserMap_.end() && !browserMap_[key].empty();
914 }
915
AddEvent(const std::string & key,const Task & task)916 void MDnsProtocolImpl::AddEvent(const std::string &key, const Task &task)
917 {
918 std::lock_guard<std::recursive_mutex> guard(mutex_);
919 taskOnChange_[key].emplace_back(task);
920 }
921
RunTaskQueue(std::list<Task> & queue)922 void MDnsProtocolImpl::RunTaskQueue(std::list<Task> &queue)
923 {
924 std::list<Task> tmp;
925 for (auto &&func : queue) {
926 if (!func()) {
927 tmp.emplace_back(func);
928 }
929 }
930 tmp.swap(queue);
931 }
932
KillCache(const std::string & key)933 void MDnsProtocolImpl::KillCache(const std::string &key)
934 {
935 NETMGR_EXT_LOG_D("mdns_log KillCache");
936 if (IsBrowserAvailable(key) && browserMap_.find(key) != browserMap_.end()) {
937 for (auto it = browserMap_[key].begin(); it != browserMap_[key].end();) {
938 KillBrowseCache(key, it);
939 }
940 }
941 if (IsCacheAvailable(key)) {
942 std::lock_guard<std::recursive_mutex> guard(mutex_);
943 auto &elem = cacheMap_[key];
944 if (elem.state == State::REMOVE) {
945 elem.state = State::DEAD;
946 cacheMap_.erase(key);
947 } else if (elem.state == State::ADD || elem.state == State::REFRESH) {
948 elem.state = State::LIVE;
949 }
950 }
951 }
952
KillBrowseCache(const std::string & key,std::vector<Result>::iterator & it)953 void MDnsProtocolImpl::KillBrowseCache(const std::string &key, std::vector<Result>::iterator &it)
954 {
955 NETMGR_EXT_LOG_D("mdns_log KillBrowseCache");
956 if (it->state == State::REMOVE) {
957 it->state = State::DEAD;
958 if (nameCbMap_.find(key) != nameCbMap_.end()) {
959 NETMGR_EXT_LOG_D("mdns_log HandleServiceLost");
960 nameCbMap_[key]->HandleServiceLost(ConvertResultToInfo(*it), NETMANAGER_EXT_SUCCESS);
961 }
962 std::string fullName = Decorated(it->serviceName + MDNS_DOMAIN_SPLITER_STR + it->serviceType);
963 cacheMap_.erase(fullName);
964 it = browserMap_[key].erase(it);
965 } else if (it->state == State::ADD || it->state == State::REFRESH) {
966 it->state = State::LIVE;
967 it++;
968 } else {
969 it++;
970 }
971 }
972
StopCbMap(const std::string & serviceType)973 int32_t MDnsProtocolImpl::StopCbMap(const std::string &serviceType)
974 {
975 NETMGR_EXT_LOG_D("mdns_log StopCbMap");
976 std::lock_guard<std::recursive_mutex> guard(mutex_);
977 std::string name = Decorated(serviceType);
978 sptr<IDiscoveryCallback> cb = nullptr;
979 if (nameCbMap_.find(name) != nameCbMap_.end()) {
980 cb = nameCbMap_[name];
981 nameCbMap_.erase(name);
982 }
983 taskOnChange_.erase(name);
984 auto it = browserMap_.find(name);
985 if (it != browserMap_.end()) {
986 if (cb != nullptr) {
987 NETMGR_EXT_LOG_I("mdns_log StopCbMap res size:[%{public}zu]", it->second.size());
988 for (auto &&res : it->second) {
989 NETMGR_EXT_LOG_W("mdns_log HandleServiceLost");
990 cb->HandleServiceLost(ConvertResultToInfo(res), NETMANAGER_EXT_SUCCESS);
991 }
992 }
993 browserMap_.erase(name);
994 }
995 return NETMANAGER_SUCCESS;
996 }
997 } // namespace NetManagerStandard
998 } // namespace OHOS
999