• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "mdns_protocol_impl.h"
17 
18 #include <arpa/inet.h>
19 #include <iostream>
20 #include <numeric>
21 #include <random>
22 #include <unistd.h>
23 
24 #include "mdns_packet_parser.h"
25 #include "netmgr_ext_log_wrapper.h"
26 
27 namespace OHOS {
28 namespace NetManagerStandard {
29 
namespace {

constexpr uint32_t DEFAULT_TTL = 120;
constexpr uint16_t MDNS_FLUSH_CACHE_BIT = 0x8000;

constexpr int PHASE_PTR = 1;
constexpr int PHASE_SRV = 2;
constexpr int PHASE_DOMAIN = 3;

// Renders an address held in a std::any as text.
// Accepts an in_addr (formatted as AF_INET) or an in6_addr (formatted as AF_INET6).
// Returns an empty string for any other payload, an empty any, or if inet_ntop fails.
std::string AddrToString(const std::any &addr)
{
    char text[INET6_ADDRSTRLEN] = {0};
    const char *converted = nullptr;
    if (const in_addr *v4 = std::any_cast<in_addr>(&addr)) {
        converted = inet_ntop(AF_INET, v4, text, sizeof(text));
    } else if (const in6_addr *v6 = std::any_cast<in6_addr>(&addr)) {
        converted = inet_ntop(AF_INET6, v6, text, sizeof(text));
    }
    return (converted == nullptr) ? std::string{} : std::string(text);
}

} // namespace
57 
MDnsProtocolImpl()58 MDnsProtocolImpl::MDnsProtocolImpl()
59 {
60     Init();
61 }
62 
Init()63 void MDnsProtocolImpl::Init()
64 {
65     listener_.CloseAllSocket();
66     if (config_.configAllIface) {
67         listener_.OpenSocketForEachIface(config_.ipv6Support, config_.configLo);
68     } else {
69         listener_.OpenSocketForDefault(config_.ipv6Support);
70     }
71     listener_.SetReceiveHandler(
72         [this](int sock, const MDnsPayload &payload) { return this->ReceivePacket(sock, payload); });
73     listener_.SetRefreshHandler([this](int sock) { return this->OnRefresh(sock); });
74 }
75 
SetConfig(const MDnsConfig & config)76 void MDnsProtocolImpl::SetConfig(const MDnsConfig &config)
77 {
78     config_ = config;
79 }
80 
GetConfig() const81 const MDnsConfig &MDnsProtocolImpl::GetConfig() const
82 {
83     return config_;
84 }
85 
SetHandler(const Handler & handler)86 void MDnsProtocolImpl::SetHandler(const Handler &handler)
87 {
88     handler_ = handler;
89 }
90 
Decorated(const std::string & name) const91 std::string MDnsProtocolImpl::Decorated(const std::string &name) const
92 {
93     return name + config_.topDomain + MDNS_DOMAIN_SPLITER_STR;
94 }
95 
Dotted(const std::string & name) const96 std::string MDnsProtocolImpl::Dotted(const std::string &name) const
97 {
98     return EndsWith(name, MDNS_DOMAIN_SPLITER_STR) ? name : name + MDNS_DOMAIN_SPLITER_STR;
99 }
100 
UnDotted(const std::string & name) const101 std::string MDnsProtocolImpl::UnDotted(const std::string &name) const
102 {
103     return EndsWith(name, MDNS_DOMAIN_SPLITER_STR) ? name.substr(0, name.size() - 1) : name;
104 }
105 
ExtractInstance(const Result & info) const106 std::string MDnsProtocolImpl::ExtractInstance(const Result &info) const
107 {
108     return Decorated(info.serviceName + MDNS_DOMAIN_SPLITER_STR + info.serviceType);
109 }
110 
Register(const Result & info)111 int32_t MDnsProtocolImpl::Register(const Result &info)
112 {
113     if (!(IsNameValid(info.serviceName) && IsTypeValid(info.serviceType) && IsPortValid(info.port))) {
114         return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
115     }
116     std::string name = ExtractInstance(info);
117     if (!IsDomainValid(name)) {
118         return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
119     }
120     {
121         std::lock_guard<std::mutex> guard(mutex_);
122         if (srvMap_.find(name) != srvMap_.end()) {
123             return NET_MDNS_ERR_SERVICE_INSTANCE_DUPLICATE;
124         }
125         srvMap_.emplace(name, info);
126     }
127 
128     listener_.Start();
129     return Announce(info, false);
130 }
131 
Discovery(const std::string & serviceType)132 int32_t MDnsProtocolImpl::Discovery(const std::string &serviceType)
133 {
134     if (!IsTypeValid(serviceType)) {
135         return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
136     }
137     std::string name = Decorated(serviceType);
138     if (!IsDomainValid(name)) {
139         return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
140     }
141     {
142         std::lock_guard<std::mutex> guard(mutex_);
143         ++reqMap_[name];
144         ++reqCount_;
145     }
146     MDnsPayloadParser parser;
147     MDnsMessage msg{};
148     msg.questions.emplace_back(DNSProto::Question{
149         .name = name,
150         .qtype = DNSProto::RRTYPE_PTR,
151         .qclass = DNSProto::RRCLASS_IN,
152     });
153     msg.header.qdcount = msg.questions.size();
154 
155     listener_.Start();
156     ssize_t size = listener_.MulticastAll(parser.ToBytes(msg));
157 
158     return size > 0 ? NETMANAGER_EXT_SUCCESS : NET_MDNS_ERR_SEND;
159 }
160 
ResolveInstance(const std::string & instance)161 int32_t MDnsProtocolImpl::ResolveInstance(const std::string &instance)
162 {
163     if (!IsInstanceValid(instance)) {
164         return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
165     }
166     std::string name = Decorated(instance);
167     if (!IsDomainValid(name)) {
168         return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
169     }
170     {
171         std::lock_guard<std::mutex> guard(mutex_);
172         ++reqMap_[name];
173         ++reqCount_;
174     }
175     MDnsPayloadParser parser;
176     MDnsMessage msg{};
177     msg.questions.emplace_back(DNSProto::Question{
178         .name = name,
179         .qtype = DNSProto::RRTYPE_SRV,
180         .qclass = DNSProto::RRCLASS_IN,
181     });
182     msg.questions.emplace_back(DNSProto::Question{
183         .name = name,
184         .qtype = DNSProto::RRTYPE_TXT,
185         .qclass = DNSProto::RRCLASS_IN,
186     });
187     msg.header.qdcount = msg.questions.size();
188 
189     listener_.Start();
190     ssize_t size = listener_.MulticastAll(parser.ToBytes(msg));
191 
192     return size > 0 ? NETMANAGER_EXT_SUCCESS : NET_MDNS_ERR_SEND;
193 }
194 
Resolve(const std::string & domain)195 int32_t MDnsProtocolImpl::Resolve(const std::string &domain)
196 {
197     if (!IsDomainValid(domain)) {
198         return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
199     }
200     std::string name = Dotted(domain);
201     if (!IsDomainValid(name)) {
202         return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
203     }
204     {
205         std::lock_guard<std::mutex> guard(mutex_);
206         ++reqMap_[name];
207         ++reqCount_;
208     }
209     MDnsPayloadParser parser;
210     MDnsMessage msg{};
211     msg.questions.emplace_back(DNSProto::Question{
212         .name = name,
213         .qtype = DNSProto::RRTYPE_A,
214         .qclass = DNSProto::RRCLASS_IN,
215     });
216     msg.questions.emplace_back(DNSProto::Question{
217         .name = name,
218         .qtype = DNSProto::RRTYPE_AAAA,
219         .qclass = DNSProto::RRCLASS_IN,
220     });
221     msg.header.qdcount = msg.questions.size();
222 
223     listener_.Start();
224     ssize_t size = listener_.MulticastAll(parser.ToBytes(msg));
225 
226     return size > 0 ? NETMANAGER_EXT_SUCCESS : NET_MDNS_ERR_SEND;
227 }
228 
UnRegister(const std::string & key)229 int32_t MDnsProtocolImpl::UnRegister(const std::string &key)
230 {
231     std::string name = Decorated(key);
232     std::lock_guard<std::mutex> guard(mutex_);
233     if (srvMap_.find(name) != srvMap_.end()) {
234         Announce(srvMap_[name], true);
235         srvMap_.erase(name);
236         return NETMANAGER_EXT_SUCCESS;
237     }
238     return NET_MDNS_ERR_SERVICE_INSTANCE_NOT_FOUND;
239 }
240 
StopDiscovery(const std::string & key)241 int32_t MDnsProtocolImpl::StopDiscovery(const std::string &key)
242 {
243     return Stop(Decorated(key));
244 }
245 
StopResolveInstance(const std::string & key)246 int32_t MDnsProtocolImpl::StopResolveInstance(const std::string &key)
247 {
248     return Stop(Decorated(key));
249 }
250 
StopResolve(const std::string & key)251 int32_t MDnsProtocolImpl::StopResolve(const std::string &key)
252 {
253     return Stop(Dotted(key));
254 }
255 
Stop(const std::string & key)256 int32_t MDnsProtocolImpl::Stop(const std::string &key)
257 {
258     std::lock_guard<std::mutex> guard(mutex_);
259     if (reqMap_.find(key) != reqMap_.end() && reqMap_[key] > 0) {
260         --reqMap_[key];
261         --reqCount_;
262     }
263     return NETMANAGER_EXT_SUCCESS;
264 }
265 
Announce(const Result & info,bool off)266 int32_t MDnsProtocolImpl::Announce(const Result &info, bool off)
267 {
268     MDnsMessage response{};
269     response.header.flags = DNSProto::MDNS_ANSWER_FLAGS;
270     std::string name = Decorated(info.serviceName + MDNS_DOMAIN_SPLITER_STR + info.serviceType);
271     response.answers.emplace_back(DNSProto::ResourceRecord{.name = Decorated(info.serviceType),
272                                                            .rtype = static_cast<uint16_t>(DNSProto::RRTYPE_PTR),
273                                                            .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
274                                                            .ttl = off ? 0U : DEFAULT_TTL,
275                                                            .rdata = name});
276     response.answers.emplace_back(DNSProto::ResourceRecord{.name = name,
277                                                            .rtype = static_cast<uint16_t>(DNSProto::RRTYPE_SRV),
278                                                            .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
279                                                            .ttl = off ? 0U : DEFAULT_TTL,
280                                                            .rdata = DNSProto::RDataSrv{
281                                                                .priority = 0,
282                                                                .weight = 0,
283                                                                .port = static_cast<uint16_t>(info.port),
284                                                                .name = GetHostDomain(),
285                                                            }});
286     response.answers.emplace_back(DNSProto::ResourceRecord{.name = name,
287                                                            .rtype = static_cast<uint16_t>(DNSProto::RRTYPE_TXT),
288                                                            .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
289                                                            .ttl = off ? 0U : DEFAULT_TTL,
290                                                            .rdata = info.txt});
291     MDnsPayloadParser parser;
292     ssize_t size = listener_.MulticastAll(parser.ToBytes(response));
293     return size > 0 ? NETMANAGER_EXT_SUCCESS : NET_MDNS_ERR_SEND;
294 }
295 
ReceivePacket(int sock,const MDnsPayload & payload)296 void MDnsProtocolImpl::ReceivePacket(int sock, const MDnsPayload &payload)
297 {
298     if (payload.size() == 0) {
299         NETMGR_EXT_LOG_W("empty payload received");
300         return;
301     }
302     MDnsPayloadParser parser;
303     MDnsMessage msg = parser.FromBytes(payload);
304     if (parser.GetError() != 0) {
305         NETMGR_EXT_LOG_W("payload parse failed");
306         return;
307     }
308     if ((msg.header.flags & DNSProto::HEADER_FLAGS_QR_MASK) == 0) {
309         ProcessQuestion(sock, msg);
310     } else {
311         ProcessAnswer(sock, msg);
312     }
313 }
314 
OnRefresh(int sock)315 void MDnsProtocolImpl::OnRefresh(int sock)
316 {
317     std::lock_guard<std::mutex> guard(mutex_);
318     NETMGR_EXT_LOG_W("taskQueue_ size: %u", static_cast<uint32_t>(taskQueue_.size()));
319     while (taskQueue_.size() > 0) {
320         taskQueue_.front()();
321         taskQueue_.pop();
322     }
323 }
324 
AppendRecord(std::vector<DNSProto::ResourceRecord> & rrlist,DNSProto::RRType type,const std::string & name,const std::any & rdata)325 void MDnsProtocolImpl::AppendRecord(std::vector<DNSProto::ResourceRecord> &rrlist, DNSProto::RRType type,
326                                     const std::string &name, const std::any &rdata)
327 {
328     rrlist.emplace_back(DNSProto::ResourceRecord{.name = name,
329                                                  .rtype = static_cast<uint16_t>(type),
330                                                  .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
331                                                  .ttl = DEFAULT_TTL,
332                                                  .rdata = rdata});
333 }
334 
ProcessQuestion(int sock,const MDnsMessage & msg)335 void MDnsProtocolImpl::ProcessQuestion(int sock, const MDnsMessage &msg)
336 {
337     const sockaddr *saddrIf = listener_.GetSockAddr(sock);
338     if (saddrIf == nullptr) {
339         return;
340     }
341     std::any anyAddr;
342     DNSProto::RRType anyAddrType;
343     if (saddrIf->sa_family == AF_INET6) {
344         anyAddr = reinterpret_cast<const sockaddr_in6 *>(saddrIf)->sin6_addr;
345         anyAddrType = DNSProto::RRTYPE_AAAA;
346     } else {
347         anyAddr = reinterpret_cast<const sockaddr_in *>(saddrIf)->sin_addr;
348         anyAddrType = DNSProto::RRTYPE_A;
349     }
350     int phase = 0;
351     MDnsMessage response{};
352     response.header.flags = DNSProto::MDNS_ANSWER_FLAGS;
353     for (size_t i = 0; i < msg.header.qdcount; ++i) {
354         ProcessQuestionRecord(anyAddr, anyAddrType, msg.questions[i], phase, response);
355     }
356     if (phase == PHASE_SRV) {
357         AppendRecord(response.additional, anyAddrType, GetHostDomain(), anyAddr);
358     }
359     if (phase != 0 && response.answers.size() > 0) {
360         listener_.Multicast(sock, MDnsPayloadParser().ToBytes(response));
361     }
362 }
363 
ProcessQuestionRecord(const std::any & anyAddr,const DNSProto::RRType & anyAddrType,const DNSProto::Question & qu,int & phase,MDnsMessage & response)364 void MDnsProtocolImpl::ProcessQuestionRecord(const std::any &anyAddr, const DNSProto::RRType &anyAddrType,
365                                              const DNSProto::Question &qu, int &phase, MDnsMessage &response)
366 {
367     std::string name = qu.name;
368     if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_PTR) {
369         std::lock_guard<std::mutex> guard(mutex_);
370         std::for_each(srvMap_.begin(), srvMap_.end(), [&](const auto &elem) -> void {
371             if (EndsWith(elem.first, name)) {
372                 AppendRecord(response.answers, DNSProto::RRTYPE_PTR, name, elem.first);
373             }
374         });
375         phase = std::max(phase, PHASE_PTR);
376     }
377     if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_SRV) {
378         std::lock_guard<std::mutex> guard(mutex_);
379         auto iter = srvMap_.find(name);
380         if (iter == srvMap_.end()) {
381             return;
382         }
383         AppendRecord(response.answers, DNSProto::RRTYPE_SRV, name,
384                      DNSProto::RDataSrv{
385                          .priority = 0,
386                          .weight = 0,
387                          .port = static_cast<uint16_t>(iter->second.port),
388                          .name = GetHostDomain(),
389                      });
390         phase = std::max(phase, PHASE_SRV);
391     }
392     if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_TXT) {
393         std::lock_guard<std::mutex> guard(mutex_);
394         auto iter = srvMap_.find(name);
395         if (iter == srvMap_.end()) {
396             return;
397         }
398         AppendRecord(response.answers, DNSProto::RRTYPE_TXT, name, iter->second.txt);
399         phase = std::max(phase, PHASE_SRV);
400     }
401     if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_A || qu.qtype == DNSProto::RRTYPE_AAAA) {
402         if (name != GetHostDomain() || (qu.qtype != DNSProto::RRTYPE_ANY && anyAddrType != qu.qtype)) {
403             return;
404         }
405         AppendRecord(response.answers, anyAddrType, name, anyAddr);
406         phase = std::max(phase, PHASE_DOMAIN);
407     }
408 }
409 
ProcessAnswer(int sock,const MDnsMessage & msg)410 void MDnsProtocolImpl::ProcessAnswer(int sock, const MDnsMessage &msg)
411 {
412     const sockaddr *saddrIf = listener_.GetSockAddr(sock);
413     if (saddrIf == nullptr) {
414         return;
415     }
416     bool v6 = (saddrIf->sa_family == AF_INET6);
417 
418     std::vector<Result> matches;
419     std::map<std::string, Result> results;
420     std::map<std::string, std::string> needMore;
421     for (size_t i = 0; i < msg.answers.size(); ++i) {
422         ProcessAnswerRecord(v6, msg.answers[i], matches, results, needMore);
423     }
424 
425     for (size_t i = 0; i < msg.additional.size() && !needMore.empty(); ++i) {
426         std::string name = msg.additional[i].name;
427         if (needMore.find(name) == needMore.end()) {
428             continue;
429         }
430         if (msg.additional[i].rtype == DNSProto::RRTYPE_A || msg.additional[i].rtype == DNSProto::RRTYPE_AAAA) {
431             if (v6 != (msg.additional[i].rtype == DNSProto::RRTYPE_AAAA)) {
432                 continue;
433             }
434             Result &result = results[needMore[name]];
435             result.domain = UnDotted(name);
436             result.ipv6 = (msg.additional[i].rtype == DNSProto::RRTYPE_AAAA);
437             result.addr = AddrToString(msg.additional[i].rdata);
438         }
439     }
440 
441     for (auto i = matches.begin(); i != matches.end() && handler_ != nullptr; ++i) {
442         handler_(*i, NETMANAGER_EXT_SUCCESS);
443     }
444     for (auto i = results.begin(); i != results.end() && handler_ != nullptr; ++i) {
445         i->second.iface = listener_.GetIface(sock);
446         i->second.ipv6 = v6;
447         handler_(i->second, NETMANAGER_EXT_SUCCESS);
448     }
449 }
450 
ProcessAnswerRecord(bool v6,const DNSProto::ResourceRecord & rr,std::vector<Result> & matches,std::map<std::string,Result> & results,std::map<std::string,std::string> & needMore)451 void MDnsProtocolImpl::ProcessAnswerRecord(bool v6, const DNSProto::ResourceRecord &rr, std::vector<Result> &matches,
452                                            std::map<std::string, Result> &results,
453                                            std::map<std::string, std::string> &needMore)
454 {
455     std::string name = rr.name;
456     mutex_.lock();
457     if (reqMap_[name] <= 0) {
458         return mutex_.unlock();
459     }
460     mutex_.unlock();
461     if (rr.rtype == DNSProto::RRTYPE_PTR) {
462         const std::string *data = std::any_cast<std::string>(&rr.rdata);
463         if (data == nullptr) {
464             return;
465         }
466         Result result;
467         result.type = (rr.ttl == 0) ? SERVICE_LOST : SERVICE_FOUND;
468         ExtractNameAndType(*data, result.serviceName, result.serviceType);
469         if (std::find_if(matches.begin(), matches.end(), [&](const auto &elem) {
470                 return elem.serviceName == result.serviceName && elem.serviceType == result.serviceType;
471             }) == matches.end()) {
472             matches.emplace_back(std::move(result));
473         }
474     } else if (rr.rtype == DNSProto::RRTYPE_SRV) {
475         const DNSProto::RDataSrv *srv = std::any_cast<DNSProto::RDataSrv>(&rr.rdata);
476         if (rr.ttl == 0 || srv == nullptr) {
477             return;
478         }
479         Result &result = results[name];
480         result.type = INSTANCE_RESOLVED;
481         ExtractNameAndType(name, result.serviceName, result.serviceType);
482         result.domain = UnDotted(srv->name);
483         result.port = srv->port;
484         needMore[srv->name] = name;
485     } else if (rr.rtype == DNSProto::RRTYPE_TXT) {
486         const TxtRecordEncoded *txt = std::any_cast<TxtRecordEncoded>(&rr.rdata);
487         if (rr.ttl == 0 || txt == nullptr) {
488             return;
489         }
490         Result &result = results[name];
491         result.txt = *txt;
492     } else if (rr.rtype == DNSProto::RRTYPE_A || rr.rtype == DNSProto::RRTYPE_AAAA) {
493         if (rr.ttl == 0 || v6 != (rr.rtype == DNSProto::RRTYPE_AAAA)) {
494             return;
495         }
496         Result &result = results[name];
497         result.type = DOMAIN_RESOLVED;
498         result.domain = UnDotted(name);
499         result.ipv6 = (rr.rtype == DNSProto::RRTYPE_AAAA);
500         result.addr = AddrToString(rr.rdata);
501     } else {
502         NETMGR_EXT_LOG_D("Unknown packet received, type=[%{public}d]", rr.rtype);
503     }
504 }
505 
GetHostDomain()506 std::string MDnsProtocolImpl::GetHostDomain()
507 {
508     if (config_.hostname.empty()) {
509         char buffer[MDNS_MAX_DOMAIN_LABEL];
510         if (gethostname(buffer, sizeof(buffer)) == 0) {
511             config_.hostname = buffer;
512             static auto uid = []() {
513                 std::random_device rd;
514                 return rd();
515             }();
516             config_.hostname += std::to_string(uid);
517         }
518     }
519     return Decorated(config_.hostname);
520 }
521 
RunTaskLater(const Task & task)522 void MDnsProtocolImpl::RunTaskLater(const Task &task)
523 {
524     taskQueue_.push(task);
525     listener_.TriggerRefresh();
526 }
527 
528 } // namespace NetManagerStandard
529 } // namespace OHOS
530