1 /*
2 * Copyright (C) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "mdns_protocol_impl.h"
17
#include <algorithm>
#include <arpa/inet.h>
#include <iostream>
#include <numeric>
#include <random>
#include <unistd.h>
23
24 #include "mdns_packet_parser.h"
25 #include "netmgr_ext_log_wrapper.h"
26
27 namespace OHOS {
28 namespace NetManagerStandard {
29
namespace {

// TTL (in seconds) given to every record this host announces or answers with.
constexpr uint32_t DEFAULT_TTL = 120;
// Bit OR-ed into the record class field to set the mDNS cache-flush flag.
constexpr uint16_t MDNS_FLUSH_CACHE_BIT = 0x8000;

// Progress markers for answering a query: ProcessQuestion() keeps the highest value reached and adds the
// host address to the additional section only when that value is PHASE_SRV.
constexpr int PHASE_PTR = 1;
constexpr int PHASE_SRV = 2;
constexpr int PHASE_DOMAIN = 3;

// Formats an address held in a std::any as text.
// Accepts an in_addr (dotted IPv4) or an in6_addr (IPv6 text form). Any other content, or a failed
// inet_ntop() conversion, yields an empty string.
std::string AddrToString(const std::any &addr)
{
    char text[INET6_ADDRSTRLEN] = {0};
    const char *converted = nullptr;
    if (const auto *v4 = std::any_cast<in_addr>(&addr); v4 != nullptr) {
        converted = inet_ntop(AF_INET, v4, text, sizeof(text));
    } else if (const auto *v6 = std::any_cast<in6_addr>(&addr); v6 != nullptr) {
        converted = inet_ntop(AF_INET6, v6, text, sizeof(text));
    }
    return (converted == nullptr) ? std::string{} : std::string(text);
}

} // namespace
57
MDnsProtocolImpl()58 MDnsProtocolImpl::MDnsProtocolImpl()
59 {
60 Init();
61 }
62
Init()63 void MDnsProtocolImpl::Init()
64 {
65 listener_.CloseAllSocket();
66 if (config_.configAllIface) {
67 listener_.OpenSocketForEachIface(config_.ipv6Support, config_.configLo);
68 } else {
69 listener_.OpenSocketForDefault(config_.ipv6Support);
70 }
71 listener_.SetReceiveHandler(
72 [this](int sock, const MDnsPayload &payload) { return this->ReceivePacket(sock, payload); });
73 listener_.SetRefreshHandler([this](int sock) { return this->OnRefresh(sock); });
74 }
75
SetConfig(const MDnsConfig & config)76 void MDnsProtocolImpl::SetConfig(const MDnsConfig &config)
77 {
78 config_ = config;
79 }
80
GetConfig() const81 const MDnsConfig &MDnsProtocolImpl::GetConfig() const
82 {
83 return config_;
84 }
85
SetHandler(const Handler & handler)86 void MDnsProtocolImpl::SetHandler(const Handler &handler)
87 {
88 handler_ = handler;
89 }
90
Decorated(const std::string & name) const91 std::string MDnsProtocolImpl::Decorated(const std::string &name) const
92 {
93 return name + config_.topDomain + MDNS_DOMAIN_SPLITER_STR;
94 }
95
Dotted(const std::string & name) const96 std::string MDnsProtocolImpl::Dotted(const std::string &name) const
97 {
98 return EndsWith(name, MDNS_DOMAIN_SPLITER_STR) ? name : name + MDNS_DOMAIN_SPLITER_STR;
99 }
100
UnDotted(const std::string & name) const101 std::string MDnsProtocolImpl::UnDotted(const std::string &name) const
102 {
103 return EndsWith(name, MDNS_DOMAIN_SPLITER_STR) ? name.substr(0, name.size() - 1) : name;
104 }
105
ExtractInstance(const Result & info) const106 std::string MDnsProtocolImpl::ExtractInstance(const Result &info) const
107 {
108 return Decorated(info.serviceName + MDNS_DOMAIN_SPLITER_STR + info.serviceType);
109 }
110
Register(const Result & info)111 int32_t MDnsProtocolImpl::Register(const Result &info)
112 {
113 if (!(IsNameValid(info.serviceName) && IsTypeValid(info.serviceType) && IsPortValid(info.port))) {
114 return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
115 }
116 std::string name = ExtractInstance(info);
117 if (!IsDomainValid(name)) {
118 return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
119 }
120 {
121 std::lock_guard<std::mutex> guard(mutex_);
122 if (srvMap_.find(name) != srvMap_.end()) {
123 return NET_MDNS_ERR_SERVICE_INSTANCE_DUPLICATE;
124 }
125 srvMap_.emplace(name, info);
126 }
127
128 listener_.Start();
129 return Announce(info, false);
130 }
131
Discovery(const std::string & serviceType)132 int32_t MDnsProtocolImpl::Discovery(const std::string &serviceType)
133 {
134 if (!IsTypeValid(serviceType)) {
135 return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
136 }
137 std::string name = Decorated(serviceType);
138 if (!IsDomainValid(name)) {
139 return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
140 }
141 {
142 std::lock_guard<std::mutex> guard(mutex_);
143 ++reqMap_[name];
144 ++reqCount_;
145 }
146 MDnsPayloadParser parser;
147 MDnsMessage msg{};
148 msg.questions.emplace_back(DNSProto::Question{
149 .name = name,
150 .qtype = DNSProto::RRTYPE_PTR,
151 .qclass = DNSProto::RRCLASS_IN,
152 });
153 msg.header.qdcount = msg.questions.size();
154
155 listener_.Start();
156 ssize_t size = listener_.MulticastAll(parser.ToBytes(msg));
157
158 return size > 0 ? NETMANAGER_EXT_SUCCESS : NET_MDNS_ERR_SEND;
159 }
160
ResolveInstance(const std::string & instance)161 int32_t MDnsProtocolImpl::ResolveInstance(const std::string &instance)
162 {
163 if (!IsInstanceValid(instance)) {
164 return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
165 }
166 std::string name = Decorated(instance);
167 if (!IsDomainValid(name)) {
168 return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
169 }
170 {
171 std::lock_guard<std::mutex> guard(mutex_);
172 ++reqMap_[name];
173 ++reqCount_;
174 }
175 MDnsPayloadParser parser;
176 MDnsMessage msg{};
177 msg.questions.emplace_back(DNSProto::Question{
178 .name = name,
179 .qtype = DNSProto::RRTYPE_SRV,
180 .qclass = DNSProto::RRCLASS_IN,
181 });
182 msg.questions.emplace_back(DNSProto::Question{
183 .name = name,
184 .qtype = DNSProto::RRTYPE_TXT,
185 .qclass = DNSProto::RRCLASS_IN,
186 });
187 msg.header.qdcount = msg.questions.size();
188
189 listener_.Start();
190 ssize_t size = listener_.MulticastAll(parser.ToBytes(msg));
191
192 return size > 0 ? NETMANAGER_EXT_SUCCESS : NET_MDNS_ERR_SEND;
193 }
194
Resolve(const std::string & domain)195 int32_t MDnsProtocolImpl::Resolve(const std::string &domain)
196 {
197 if (!IsDomainValid(domain)) {
198 return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
199 }
200 std::string name = Dotted(domain);
201 if (!IsDomainValid(name)) {
202 return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
203 }
204 {
205 std::lock_guard<std::mutex> guard(mutex_);
206 ++reqMap_[name];
207 ++reqCount_;
208 }
209 MDnsPayloadParser parser;
210 MDnsMessage msg{};
211 msg.questions.emplace_back(DNSProto::Question{
212 .name = name,
213 .qtype = DNSProto::RRTYPE_A,
214 .qclass = DNSProto::RRCLASS_IN,
215 });
216 msg.questions.emplace_back(DNSProto::Question{
217 .name = name,
218 .qtype = DNSProto::RRTYPE_AAAA,
219 .qclass = DNSProto::RRCLASS_IN,
220 });
221 msg.header.qdcount = msg.questions.size();
222
223 listener_.Start();
224 ssize_t size = listener_.MulticastAll(parser.ToBytes(msg));
225
226 return size > 0 ? NETMANAGER_EXT_SUCCESS : NET_MDNS_ERR_SEND;
227 }
228
UnRegister(const std::string & key)229 int32_t MDnsProtocolImpl::UnRegister(const std::string &key)
230 {
231 std::string name = Decorated(key);
232 std::lock_guard<std::mutex> guard(mutex_);
233 if (srvMap_.find(name) != srvMap_.end()) {
234 Announce(srvMap_[name], true);
235 srvMap_.erase(name);
236 return NETMANAGER_EXT_SUCCESS;
237 }
238 return NET_MDNS_ERR_SERVICE_INSTANCE_NOT_FOUND;
239 }
240
StopDiscovery(const std::string & key)241 int32_t MDnsProtocolImpl::StopDiscovery(const std::string &key)
242 {
243 return Stop(Decorated(key));
244 }
245
StopResolveInstance(const std::string & key)246 int32_t MDnsProtocolImpl::StopResolveInstance(const std::string &key)
247 {
248 return Stop(Decorated(key));
249 }
250
StopResolve(const std::string & key)251 int32_t MDnsProtocolImpl::StopResolve(const std::string &key)
252 {
253 return Stop(Dotted(key));
254 }
255
Stop(const std::string & key)256 int32_t MDnsProtocolImpl::Stop(const std::string &key)
257 {
258 std::lock_guard<std::mutex> guard(mutex_);
259 if (reqMap_.find(key) != reqMap_.end() && reqMap_[key] > 0) {
260 --reqMap_[key];
261 --reqCount_;
262 }
263 return NETMANAGER_EXT_SUCCESS;
264 }
265
Announce(const Result & info,bool off)266 int32_t MDnsProtocolImpl::Announce(const Result &info, bool off)
267 {
268 MDnsMessage response{};
269 response.header.flags = DNSProto::MDNS_ANSWER_FLAGS;
270 std::string name = Decorated(info.serviceName + MDNS_DOMAIN_SPLITER_STR + info.serviceType);
271 response.answers.emplace_back(DNSProto::ResourceRecord{.name = Decorated(info.serviceType),
272 .rtype = static_cast<uint16_t>(DNSProto::RRTYPE_PTR),
273 .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
274 .ttl = off ? 0U : DEFAULT_TTL,
275 .rdata = name});
276 response.answers.emplace_back(DNSProto::ResourceRecord{.name = name,
277 .rtype = static_cast<uint16_t>(DNSProto::RRTYPE_SRV),
278 .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
279 .ttl = off ? 0U : DEFAULT_TTL,
280 .rdata = DNSProto::RDataSrv{
281 .priority = 0,
282 .weight = 0,
283 .port = static_cast<uint16_t>(info.port),
284 .name = GetHostDomain(),
285 }});
286 response.answers.emplace_back(DNSProto::ResourceRecord{.name = name,
287 .rtype = static_cast<uint16_t>(DNSProto::RRTYPE_TXT),
288 .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
289 .ttl = off ? 0U : DEFAULT_TTL,
290 .rdata = info.txt});
291 MDnsPayloadParser parser;
292 ssize_t size = listener_.MulticastAll(parser.ToBytes(response));
293 return size > 0 ? NETMANAGER_EXT_SUCCESS : NET_MDNS_ERR_SEND;
294 }
295
ReceivePacket(int sock,const MDnsPayload & payload)296 void MDnsProtocolImpl::ReceivePacket(int sock, const MDnsPayload &payload)
297 {
298 if (payload.size() == 0) {
299 NETMGR_EXT_LOG_W("empty payload received");
300 return;
301 }
302 MDnsPayloadParser parser;
303 MDnsMessage msg = parser.FromBytes(payload);
304 if (parser.GetError() != 0) {
305 NETMGR_EXT_LOG_W("payload parse failed");
306 return;
307 }
308 if ((msg.header.flags & DNSProto::HEADER_FLAGS_QR_MASK) == 0) {
309 ProcessQuestion(sock, msg);
310 } else {
311 ProcessAnswer(sock, msg);
312 }
313 }
314
OnRefresh(int sock)315 void MDnsProtocolImpl::OnRefresh(int sock)
316 {
317 std::lock_guard<std::mutex> guard(mutex_);
318 NETMGR_EXT_LOG_W("taskQueue_ size: %u", static_cast<uint32_t>(taskQueue_.size()));
319 while (taskQueue_.size() > 0) {
320 taskQueue_.front()();
321 taskQueue_.pop();
322 }
323 }
324
AppendRecord(std::vector<DNSProto::ResourceRecord> & rrlist,DNSProto::RRType type,const std::string & name,const std::any & rdata)325 void MDnsProtocolImpl::AppendRecord(std::vector<DNSProto::ResourceRecord> &rrlist, DNSProto::RRType type,
326 const std::string &name, const std::any &rdata)
327 {
328 rrlist.emplace_back(DNSProto::ResourceRecord{.name = name,
329 .rtype = static_cast<uint16_t>(type),
330 .rclass = DNSProto::RRCLASS_IN | MDNS_FLUSH_CACHE_BIT,
331 .ttl = DEFAULT_TTL,
332 .rdata = rdata});
333 }
334
ProcessQuestion(int sock,const MDnsMessage & msg)335 void MDnsProtocolImpl::ProcessQuestion(int sock, const MDnsMessage &msg)
336 {
337 const sockaddr *saddrIf = listener_.GetSockAddr(sock);
338 if (saddrIf == nullptr) {
339 return;
340 }
341 std::any anyAddr;
342 DNSProto::RRType anyAddrType;
343 if (saddrIf->sa_family == AF_INET6) {
344 anyAddr = reinterpret_cast<const sockaddr_in6 *>(saddrIf)->sin6_addr;
345 anyAddrType = DNSProto::RRTYPE_AAAA;
346 } else {
347 anyAddr = reinterpret_cast<const sockaddr_in *>(saddrIf)->sin_addr;
348 anyAddrType = DNSProto::RRTYPE_A;
349 }
350 int phase = 0;
351 MDnsMessage response{};
352 response.header.flags = DNSProto::MDNS_ANSWER_FLAGS;
353 for (size_t i = 0; i < msg.header.qdcount; ++i) {
354 ProcessQuestionRecord(anyAddr, anyAddrType, msg.questions[i], phase, response);
355 }
356 if (phase == PHASE_SRV) {
357 AppendRecord(response.additional, anyAddrType, GetHostDomain(), anyAddr);
358 }
359 if (phase != 0 && response.answers.size() > 0) {
360 listener_.Multicast(sock, MDnsPayloadParser().ToBytes(response));
361 }
362 }
363
ProcessQuestionRecord(const std::any & anyAddr,const DNSProto::RRType & anyAddrType,const DNSProto::Question & qu,int & phase,MDnsMessage & response)364 void MDnsProtocolImpl::ProcessQuestionRecord(const std::any &anyAddr, const DNSProto::RRType &anyAddrType,
365 const DNSProto::Question &qu, int &phase, MDnsMessage &response)
366 {
367 std::string name = qu.name;
368 if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_PTR) {
369 std::lock_guard<std::mutex> guard(mutex_);
370 std::for_each(srvMap_.begin(), srvMap_.end(), [&](const auto &elem) -> void {
371 if (EndsWith(elem.first, name)) {
372 AppendRecord(response.answers, DNSProto::RRTYPE_PTR, name, elem.first);
373 }
374 });
375 phase = std::max(phase, PHASE_PTR);
376 }
377 if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_SRV) {
378 std::lock_guard<std::mutex> guard(mutex_);
379 auto iter = srvMap_.find(name);
380 if (iter == srvMap_.end()) {
381 return;
382 }
383 AppendRecord(response.answers, DNSProto::RRTYPE_SRV, name,
384 DNSProto::RDataSrv{
385 .priority = 0,
386 .weight = 0,
387 .port = static_cast<uint16_t>(iter->second.port),
388 .name = GetHostDomain(),
389 });
390 phase = std::max(phase, PHASE_SRV);
391 }
392 if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_TXT) {
393 std::lock_guard<std::mutex> guard(mutex_);
394 auto iter = srvMap_.find(name);
395 if (iter == srvMap_.end()) {
396 return;
397 }
398 AppendRecord(response.answers, DNSProto::RRTYPE_TXT, name, iter->second.txt);
399 phase = std::max(phase, PHASE_SRV);
400 }
401 if (qu.qtype == DNSProto::RRTYPE_ANY || qu.qtype == DNSProto::RRTYPE_A || qu.qtype == DNSProto::RRTYPE_AAAA) {
402 if (name != GetHostDomain() || (qu.qtype != DNSProto::RRTYPE_ANY && anyAddrType != qu.qtype)) {
403 return;
404 }
405 AppendRecord(response.answers, anyAddrType, name, anyAddr);
406 phase = std::max(phase, PHASE_DOMAIN);
407 }
408 }
409
ProcessAnswer(int sock,const MDnsMessage & msg)410 void MDnsProtocolImpl::ProcessAnswer(int sock, const MDnsMessage &msg)
411 {
412 const sockaddr *saddrIf = listener_.GetSockAddr(sock);
413 if (saddrIf == nullptr) {
414 return;
415 }
416 bool v6 = (saddrIf->sa_family == AF_INET6);
417
418 std::vector<Result> matches;
419 std::map<std::string, Result> results;
420 std::map<std::string, std::string> needMore;
421 for (size_t i = 0; i < msg.answers.size(); ++i) {
422 ProcessAnswerRecord(v6, msg.answers[i], matches, results, needMore);
423 }
424
425 for (size_t i = 0; i < msg.additional.size() && !needMore.empty(); ++i) {
426 std::string name = msg.additional[i].name;
427 if (needMore.find(name) == needMore.end()) {
428 continue;
429 }
430 if (msg.additional[i].rtype == DNSProto::RRTYPE_A || msg.additional[i].rtype == DNSProto::RRTYPE_AAAA) {
431 if (v6 != (msg.additional[i].rtype == DNSProto::RRTYPE_AAAA)) {
432 continue;
433 }
434 Result &result = results[needMore[name]];
435 result.domain = UnDotted(name);
436 result.ipv6 = (msg.additional[i].rtype == DNSProto::RRTYPE_AAAA);
437 result.addr = AddrToString(msg.additional[i].rdata);
438 }
439 }
440
441 for (auto i = matches.begin(); i != matches.end() && handler_ != nullptr; ++i) {
442 handler_(*i, NETMANAGER_EXT_SUCCESS);
443 }
444 for (auto i = results.begin(); i != results.end() && handler_ != nullptr; ++i) {
445 i->second.iface = listener_.GetIface(sock);
446 i->second.ipv6 = v6;
447 handler_(i->second, NETMANAGER_EXT_SUCCESS);
448 }
449 }
450
ProcessAnswerRecord(bool v6,const DNSProto::ResourceRecord & rr,std::vector<Result> & matches,std::map<std::string,Result> & results,std::map<std::string,std::string> & needMore)451 void MDnsProtocolImpl::ProcessAnswerRecord(bool v6, const DNSProto::ResourceRecord &rr, std::vector<Result> &matches,
452 std::map<std::string, Result> &results,
453 std::map<std::string, std::string> &needMore)
454 {
455 std::string name = rr.name;
456 mutex_.lock();
457 if (reqMap_[name] <= 0) {
458 return mutex_.unlock();
459 }
460 mutex_.unlock();
461 if (rr.rtype == DNSProto::RRTYPE_PTR) {
462 const std::string *data = std::any_cast<std::string>(&rr.rdata);
463 if (data == nullptr) {
464 return;
465 }
466 Result result;
467 result.type = (rr.ttl == 0) ? SERVICE_LOST : SERVICE_FOUND;
468 ExtractNameAndType(*data, result.serviceName, result.serviceType);
469 if (std::find_if(matches.begin(), matches.end(), [&](const auto &elem) {
470 return elem.serviceName == result.serviceName && elem.serviceType == result.serviceType;
471 }) == matches.end()) {
472 matches.emplace_back(std::move(result));
473 }
474 } else if (rr.rtype == DNSProto::RRTYPE_SRV) {
475 const DNSProto::RDataSrv *srv = std::any_cast<DNSProto::RDataSrv>(&rr.rdata);
476 if (rr.ttl == 0 || srv == nullptr) {
477 return;
478 }
479 Result &result = results[name];
480 result.type = INSTANCE_RESOLVED;
481 ExtractNameAndType(name, result.serviceName, result.serviceType);
482 result.domain = UnDotted(srv->name);
483 result.port = srv->port;
484 needMore[srv->name] = name;
485 } else if (rr.rtype == DNSProto::RRTYPE_TXT) {
486 const TxtRecordEncoded *txt = std::any_cast<TxtRecordEncoded>(&rr.rdata);
487 if (rr.ttl == 0 || txt == nullptr) {
488 return;
489 }
490 Result &result = results[name];
491 result.txt = *txt;
492 } else if (rr.rtype == DNSProto::RRTYPE_A || rr.rtype == DNSProto::RRTYPE_AAAA) {
493 if (rr.ttl == 0 || v6 != (rr.rtype == DNSProto::RRTYPE_AAAA)) {
494 return;
495 }
496 Result &result = results[name];
497 result.type = DOMAIN_RESOLVED;
498 result.domain = UnDotted(name);
499 result.ipv6 = (rr.rtype == DNSProto::RRTYPE_AAAA);
500 result.addr = AddrToString(rr.rdata);
501 } else {
502 NETMGR_EXT_LOG_D("Unknown packet received, type=[%{public}d]", rr.rtype);
503 }
504 }
505
GetHostDomain()506 std::string MDnsProtocolImpl::GetHostDomain()
507 {
508 if (config_.hostname.empty()) {
509 char buffer[MDNS_MAX_DOMAIN_LABEL];
510 if (gethostname(buffer, sizeof(buffer)) == 0) {
511 config_.hostname = buffer;
512 static auto uid = []() {
513 std::random_device rd;
514 return rd();
515 }();
516 config_.hostname += std::to_string(uid);
517 }
518 }
519 return Decorated(config_.hostname);
520 }
521
RunTaskLater(const Task & task)522 void MDnsProtocolImpl::RunTaskLater(const Task &task)
523 {
524 taskQueue_.push(task);
525 listener_.TriggerRefresh();
526 }
527
528 } // namespace NetManagerStandard
529 } // namespace OHOS
530