1 /*
2 * Copyright (c) 2021, The OpenThread Authors.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * 3. Neither the name of the copyright holder nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 * POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 /**
30 * @file
31 * This file implements the DNS-SD server.
32 */
33
34 #include "dnssd_server.hpp"
35
36 #if OPENTHREAD_CONFIG_DNSSD_SERVER_ENABLE
37
38 #include "common/array.hpp"
39 #include "common/as_core_type.hpp"
40 #include "common/code_utils.hpp"
41 #include "common/debug.hpp"
42 #include "common/instance.hpp"
43 #include "common/locator_getters.hpp"
44 #include "common/log.hpp"
45 #include "common/string.hpp"
46 #include "net/srp_server.hpp"
47 #include "net/udp6.hpp"
48
49 namespace ot {
50 namespace Dns {
51 namespace ServiceDiscovery {
52
// Module name under which this file's log output is registered.
RegisterLogModule("DnssdServer");

// Protocol labels recognized when locating the <Protocol> component of a name
// (see `FindNameComponents()`).
const char Server::kDnssdProtocolUdp[] = "_udp";
const char Server::kDnssdProtocolTcp[] = "_tcp";
// Marker that separates a sub-type label from the base service name; includes
// the surrounding dots.
const char Server::kDnssdSubTypeLabel[] = "._sub.";
// Domain name used to initialize the name compression info of every response.
const char Server::kDefaultDomainName[] = "default.service.arpa.";
59
// Constructs the DNS-SD server for the given instance.
//
// The query callbacks and their context start out unset (`nullptr`), the
// timer is bound to `Server::HandleTimer`, and all counters are cleared.
Server::Server(Instance &aInstance)
    : InstanceLocator(aInstance)
    , mSocket(aInstance)
    , mQueryCallbackContext(nullptr)
    , mQuerySubscribe(nullptr)
    , mQueryUnsubscribe(nullptr)
    , mTimer(aInstance, Server::HandleTimer)
{
    mCounters.Clear();
}
70
Start(void)71 Error Server::Start(void)
72 {
73 Error error = kErrorNone;
74
75 VerifyOrExit(!IsRunning());
76
77 SuccessOrExit(error = mSocket.Open(&Server::HandleUdpReceive, this));
78 SuccessOrExit(error = mSocket.Bind(kPort, kBindUnspecifiedNetif ? OT_NETIF_UNSPECIFIED : OT_NETIF_THREAD));
79
80 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
81 Get<Srp::Server>().HandleDnssdServerStateChange();
82 #endif
83
84 exit:
85 LogInfo("started: %s", ErrorToString(error));
86
87 if (error != kErrorNone)
88 {
89 IgnoreError(mSocket.Close());
90 }
91
92 return error;
93 }
94
Stop(void)95 void Server::Stop(void)
96 {
97 // Abort all query transactions
98 for (QueryTransaction &query : mQueryTransactions)
99 {
100 if (query.IsValid())
101 {
102 FinalizeQuery(query, Header::kResponseServerFailure);
103 }
104 }
105
106 mTimer.Stop();
107
108 IgnoreError(mSocket.Close());
109 LogInfo("stopped");
110
111 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
112 Get<Srp::Server>().HandleDnssdServerStateChange();
113 #endif
114 }
115
HandleUdpReceive(void * aContext,otMessage * aMessage,const otMessageInfo * aMessageInfo)116 void Server::HandleUdpReceive(void *aContext, otMessage *aMessage, const otMessageInfo *aMessageInfo)
117 {
118 static_cast<Server *>(aContext)->HandleUdpReceive(AsCoreType(aMessage), AsCoreType(aMessageInfo));
119 }
120
HandleUdpReceive(Message & aMessage,const Ip6::MessageInfo & aMessageInfo)121 void Server::HandleUdpReceive(Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
122 {
123 Header requestHeader;
124
125 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
126 // We first let the `Srp::Server` process the received message.
127 // It returns `kErrorNone` to indicate that it successfully
128 // processed the message.
129
130 VerifyOrExit(Get<Srp::Server>().HandleDnssdServerUdpReceive(aMessage, aMessageInfo) != kErrorNone);
131 #endif
132
133 SuccessOrExit(aMessage.Read(aMessage.GetOffset(), requestHeader));
134 VerifyOrExit(requestHeader.GetType() == Header::kTypeQuery);
135
136 ProcessQuery(requestHeader, aMessage, aMessageInfo);
137
138 exit:
139 return;
140 }
141
ProcessQuery(const Header & aRequestHeader,Message & aRequestMessage,const Ip6::MessageInfo & aMessageInfo)142 void Server::ProcessQuery(const Header &aRequestHeader, Message &aRequestMessage, const Ip6::MessageInfo &aMessageInfo)
143 {
144 Error error = kErrorNone;
145 Message * responseMessage = nullptr;
146 Header responseHeader;
147 NameCompressInfo compressInfo(kDefaultDomainName);
148 Header::Response response = Header::kResponseSuccess;
149 bool resolveByQueryCallbacks = false;
150
151 responseMessage = mSocket.NewMessage(0);
152 VerifyOrExit(responseMessage != nullptr, error = kErrorNoBufs);
153
154 // Allocate space for DNS header
155 SuccessOrExit(error = responseMessage->SetLength(sizeof(Header)));
156
157 // Setup initial DNS response header
158 responseHeader.Clear();
159 responseHeader.SetType(Header::kTypeResponse);
160 responseHeader.SetMessageId(aRequestHeader.GetMessageId());
161
162 // Validate the query
163 VerifyOrExit(aRequestHeader.GetQueryType() == Header::kQueryTypeStandard,
164 response = Header::kResponseNotImplemented);
165 VerifyOrExit(!aRequestHeader.IsTruncationFlagSet(), response = Header::kResponseFormatError);
166 VerifyOrExit(aRequestHeader.GetQuestionCount() > 0, response = Header::kResponseFormatError);
167
168 response = AddQuestions(aRequestHeader, aRequestMessage, responseHeader, *responseMessage, compressInfo);
169 VerifyOrExit(response == Header::kResponseSuccess);
170
171 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
172 // Answer the questions
173 response = ResolveBySrp(responseHeader, *responseMessage, compressInfo);
174 #endif
175
176 // Resolve the question using query callbacks if SRP server failed to resolve the questions.
177 if (responseHeader.GetAnswerCount() == 0)
178 {
179 if (kErrorNone == ResolveByQueryCallbacks(responseHeader, *responseMessage, compressInfo, aMessageInfo))
180 {
181 resolveByQueryCallbacks = true;
182 }
183 }
184 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
185 else
186 {
187 ++mCounters.mResolvedBySrp;
188 }
189 #endif
190
191 exit:
192 if (error == kErrorNone && !resolveByQueryCallbacks)
193 {
194 SendResponse(responseHeader, response, *responseMessage, aMessageInfo, mSocket);
195 }
196
197 FreeMessageOnError(responseMessage, error);
198 }
199
SendResponse(Header aHeader,Header::Response aResponseCode,Message & aMessage,const Ip6::MessageInfo & aMessageInfo,Ip6::Udp::Socket & aSocket)200 void Server::SendResponse(Header aHeader,
201 Header::Response aResponseCode,
202 Message & aMessage,
203 const Ip6::MessageInfo &aMessageInfo,
204 Ip6::Udp::Socket & aSocket)
205 {
206 Error error;
207
208 if (aResponseCode == Header::kResponseServerFailure)
209 {
210 LogWarn("failed to handle DNS query due to server failure");
211 aHeader.SetQuestionCount(0);
212 aHeader.SetAnswerCount(0);
213 aHeader.SetAdditionalRecordCount(0);
214 IgnoreError(aMessage.SetLength(sizeof(Header)));
215 }
216
217 aHeader.SetResponseCode(aResponseCode);
218 aMessage.Write(0, aHeader);
219
220 error = aSocket.SendTo(aMessage, aMessageInfo);
221
222 FreeMessageOnError(&aMessage, error);
223
224 if (error != kErrorNone)
225 {
226 LogWarn("failed to send DNS-SD reply: %s", ErrorToString(error));
227 }
228 else
229 {
230 LogInfo("send DNS-SD reply: %s, RCODE=%d", ErrorToString(error), aResponseCode);
231 }
232
233 UpdateResponseCounters(aResponseCode);
234 }
235
// Validates every question in the request and copies it into the response.
//
// Returns `kResponseSuccess` when all questions were copied (the response
// header's question count is then set), or the response code describing the
// first problem encountered:
//  - `kResponseFormatError`    the name or question could not be read,
//  - `kResponseNotImplemented` the record type is not PTR/SRV/TXT/AAAA,
//  - `kResponseNameError`      the name does not fit the queried record type,
//  - `kResponseServerFailure`  the question could not be appended.
Header::Response Server::AddQuestions(const Header &    aRequestHeader,
                                      const Message &   aRequestMessage,
                                      Header &          aResponseHeader,
                                      Message &         aResponseMessage,
                                      NameCompressInfo &aCompressInfo)
{
    Header::Response rcode  = Header::kResponseSuccess;
    uint16_t         offset = sizeof(Header);
    Question         question;
    char             name[Name::kMaxNameSize];

    for (uint16_t index = 0; index < aRequestHeader.GetQuestionCount(); index++)
    {
        NameComponentsOffsetInfo components;
        uint16_t                 qtype;
        bool                     nameFitsType = false;

        VerifyOrExit(Name::ReadName(aRequestMessage, offset, name, sizeof(name)) == kErrorNone,
                     rcode = Header::kResponseFormatError);
        VerifyOrExit(aRequestMessage.Read(offset, question) == kErrorNone, rcode = Header::kResponseFormatError);
        offset += sizeof(question);

        qtype = question.GetType();

        VerifyOrExit(qtype == ResourceRecord::kTypePtr || qtype == ResourceRecord::kTypeSrv ||
                         qtype == ResourceRecord::kTypeTxt || qtype == ResourceRecord::kTypeAaaa,
                     rcode = Header::kResponseNotImplemented);

        VerifyOrExit(FindNameComponents(name, aCompressInfo.GetDomainName(), components) == kErrorNone,
                     rcode = Header::kResponseNameError);

        // Each record type requires a particular kind of name.
        switch (qtype)
        {
        case ResourceRecord::kTypePtr:
            nameFitsType = components.IsServiceName();
            break;
        case ResourceRecord::kTypeSrv:
        case ResourceRecord::kTypeTxt:
            nameFitsType = components.IsServiceInstanceName();
            break;
        case ResourceRecord::kTypeAaaa:
            nameFitsType = components.IsHostName();
            break;
        default:
            ExitNow(rcode = Header::kResponseNotImplemented);
        }

        VerifyOrExit(nameFitsType, rcode = Header::kResponseNameError);

        VerifyOrExit(AppendQuestion(name, question, aResponseMessage, aCompressInfo) == kErrorNone,
                     rcode = Header::kResponseServerFailure);
    }

    aResponseHeader.SetQuestionCount(aRequestHeader.GetQuestionCount());

exit:
    return rcode;
}
296
AppendQuestion(const char * aName,const Question & aQuestion,Message & aMessage,NameCompressInfo & aCompressInfo)297 Error Server::AppendQuestion(const char * aName,
298 const Question & aQuestion,
299 Message & aMessage,
300 NameCompressInfo &aCompressInfo)
301 {
302 Error error = kErrorNone;
303
304 switch (aQuestion.GetType())
305 {
306 case ResourceRecord::kTypePtr:
307 SuccessOrExit(error = AppendServiceName(aMessage, aName, aCompressInfo));
308 break;
309 case ResourceRecord::kTypeSrv:
310 case ResourceRecord::kTypeTxt:
311 SuccessOrExit(error = AppendInstanceName(aMessage, aName, aCompressInfo));
312 break;
313 case ResourceRecord::kTypeAaaa:
314 SuccessOrExit(error = AppendHostName(aMessage, aName, aCompressInfo));
315 break;
316 default:
317 OT_ASSERT(false);
318 }
319
320 error = aMessage.Append(aQuestion);
321
322 exit:
323 return error;
324 }
325
AppendPtrRecord(Message & aMessage,const char * aServiceName,const char * aInstanceName,uint32_t aTtl,NameCompressInfo & aCompressInfo)326 Error Server::AppendPtrRecord(Message & aMessage,
327 const char * aServiceName,
328 const char * aInstanceName,
329 uint32_t aTtl,
330 NameCompressInfo &aCompressInfo)
331 {
332 Error error;
333 PtrRecord ptrRecord;
334 uint16_t recordOffset;
335
336 ptrRecord.Init();
337 ptrRecord.SetTtl(aTtl);
338
339 SuccessOrExit(error = AppendServiceName(aMessage, aServiceName, aCompressInfo));
340
341 recordOffset = aMessage.GetLength();
342 SuccessOrExit(error = aMessage.SetLength(recordOffset + sizeof(ptrRecord)));
343
344 SuccessOrExit(error = AppendInstanceName(aMessage, aInstanceName, aCompressInfo));
345
346 ptrRecord.SetLength(aMessage.GetLength() - (recordOffset + sizeof(ResourceRecord)));
347 aMessage.Write(recordOffset, ptrRecord);
348
349 exit:
350 return error;
351 }
352
AppendSrvRecord(Message & aMessage,const char * aInstanceName,const char * aHostName,uint32_t aTtl,uint16_t aPriority,uint16_t aWeight,uint16_t aPort,NameCompressInfo & aCompressInfo)353 Error Server::AppendSrvRecord(Message & aMessage,
354 const char * aInstanceName,
355 const char * aHostName,
356 uint32_t aTtl,
357 uint16_t aPriority,
358 uint16_t aWeight,
359 uint16_t aPort,
360 NameCompressInfo &aCompressInfo)
361 {
362 SrvRecord srvRecord;
363 Error error = kErrorNone;
364 uint16_t recordOffset;
365
366 srvRecord.Init();
367 srvRecord.SetTtl(aTtl);
368 srvRecord.SetPriority(aPriority);
369 srvRecord.SetWeight(aWeight);
370 srvRecord.SetPort(aPort);
371
372 SuccessOrExit(error = AppendInstanceName(aMessage, aInstanceName, aCompressInfo));
373
374 recordOffset = aMessage.GetLength();
375 SuccessOrExit(error = aMessage.SetLength(recordOffset + sizeof(srvRecord)));
376
377 SuccessOrExit(error = AppendHostName(aMessage, aHostName, aCompressInfo));
378
379 srvRecord.SetLength(aMessage.GetLength() - (recordOffset + sizeof(ResourceRecord)));
380 aMessage.Write(recordOffset, srvRecord);
381
382 exit:
383 return error;
384 }
385
AppendAaaaRecord(Message & aMessage,const char * aHostName,const Ip6::Address & aAddress,uint32_t aTtl,NameCompressInfo & aCompressInfo)386 Error Server::AppendAaaaRecord(Message & aMessage,
387 const char * aHostName,
388 const Ip6::Address &aAddress,
389 uint32_t aTtl,
390 NameCompressInfo & aCompressInfo)
391 {
392 AaaaRecord aaaaRecord;
393 Error error;
394
395 aaaaRecord.Init();
396 aaaaRecord.SetTtl(aTtl);
397 aaaaRecord.SetAddress(aAddress);
398
399 SuccessOrExit(error = AppendHostName(aMessage, aHostName, aCompressInfo));
400 error = aMessage.Append(aaaaRecord);
401
402 exit:
403 return error;
404 }
405
AppendServiceName(Message & aMessage,const char * aName,NameCompressInfo & aCompressInfo)406 Error Server::AppendServiceName(Message &aMessage, const char *aName, NameCompressInfo &aCompressInfo)
407 {
408 Error error;
409 uint16_t serviceCompressOffset = aCompressInfo.GetServiceNameOffset(aMessage, aName);
410 const char *serviceName;
411
412 // Check whether `aName` is a sub-type service name.
413 serviceName = StringFind(aName, kDnssdSubTypeLabel, kStringCaseInsensitiveMatch);
414
415 if (serviceName != nullptr)
416 {
417 uint8_t subTypeLabelLength = static_cast<uint8_t>(serviceName - aName) + sizeof(kDnssdSubTypeLabel) - 1;
418
419 SuccessOrExit(error = Name::AppendMultipleLabels(aName, subTypeLabelLength, aMessage));
420
421 // Skip over the "._sub." label to get to the root service name.
422 serviceName += sizeof(kDnssdSubTypeLabel) - 1;
423 }
424 else
425 {
426 serviceName = aName;
427 }
428
429 if (serviceCompressOffset != NameCompressInfo::kUnknownOffset)
430 {
431 error = Name::AppendPointerLabel(serviceCompressOffset, aMessage);
432 }
433 else
434 {
435 uint8_t domainStart = static_cast<uint8_t>(StringLength(serviceName, Name::kMaxNameSize - 1) -
436 StringLength(aCompressInfo.GetDomainName(), Name::kMaxNameSize - 1));
437 uint16_t domainCompressOffset = aCompressInfo.GetDomainNameOffset();
438
439 serviceCompressOffset = aMessage.GetLength();
440 aCompressInfo.SetServiceNameOffset(serviceCompressOffset);
441
442 if (domainCompressOffset == NameCompressInfo::kUnknownOffset)
443 {
444 aCompressInfo.SetDomainNameOffset(serviceCompressOffset + domainStart);
445 error = Name::AppendName(serviceName, aMessage);
446 }
447 else
448 {
449 SuccessOrExit(error = Name::AppendMultipleLabels(serviceName, domainStart, aMessage));
450 error = Name::AppendPointerLabel(domainCompressOffset, aMessage);
451 }
452 }
453
454 exit:
455 return error;
456 }
457
AppendInstanceName(Message & aMessage,const char * aName,NameCompressInfo & aCompressInfo)458 Error Server::AppendInstanceName(Message &aMessage, const char *aName, NameCompressInfo &aCompressInfo)
459 {
460 Error error;
461 uint16_t instanceCompressOffset = aCompressInfo.GetInstanceNameOffset(aMessage, aName);
462
463 if (instanceCompressOffset != NameCompressInfo::kUnknownOffset)
464 {
465 error = Name::AppendPointerLabel(instanceCompressOffset, aMessage);
466 }
467 else
468 {
469 NameComponentsOffsetInfo nameComponentsInfo;
470
471 IgnoreError(FindNameComponents(aName, aCompressInfo.GetDomainName(), nameComponentsInfo));
472 OT_ASSERT(nameComponentsInfo.IsServiceInstanceName());
473
474 aCompressInfo.SetInstanceNameOffset(aMessage.GetLength());
475
476 // Append the instance name as one label
477 SuccessOrExit(error = Name::AppendLabel(aName, nameComponentsInfo.mServiceOffset - 1, aMessage));
478
479 {
480 const char *serviceName = aName + nameComponentsInfo.mServiceOffset;
481 uint16_t serviceCompressOffset = aCompressInfo.GetServiceNameOffset(aMessage, serviceName);
482
483 if (serviceCompressOffset != NameCompressInfo::kUnknownOffset)
484 {
485 error = Name::AppendPointerLabel(serviceCompressOffset, aMessage);
486 }
487 else
488 {
489 aCompressInfo.SetServiceNameOffset(aMessage.GetLength());
490 error = Name::AppendName(serviceName, aMessage);
491 }
492 }
493 }
494
495 exit:
496 return error;
497 }
498
AppendTxtRecord(Message & aMessage,const char * aInstanceName,const void * aTxtData,uint16_t aTxtLength,uint32_t aTtl,NameCompressInfo & aCompressInfo)499 Error Server::AppendTxtRecord(Message & aMessage,
500 const char * aInstanceName,
501 const void * aTxtData,
502 uint16_t aTxtLength,
503 uint32_t aTtl,
504 NameCompressInfo &aCompressInfo)
505 {
506 Error error = kErrorNone;
507 TxtRecord txtRecord;
508 const uint8_t kEmptyTxt = 0;
509
510 SuccessOrExit(error = AppendInstanceName(aMessage, aInstanceName, aCompressInfo));
511
512 txtRecord.Init();
513 txtRecord.SetTtl(aTtl);
514 txtRecord.SetLength(aTxtLength > 0 ? aTxtLength : sizeof(kEmptyTxt));
515
516 SuccessOrExit(error = aMessage.Append(txtRecord));
517 if (aTxtLength > 0)
518 {
519 error = aMessage.AppendBytes(aTxtData, aTxtLength);
520 }
521 else
522 {
523 error = aMessage.Append(kEmptyTxt);
524 }
525
526 exit:
527 return error;
528 }
529
AppendHostName(Message & aMessage,const char * aName,NameCompressInfo & aCompressInfo)530 Error Server::AppendHostName(Message &aMessage, const char *aName, NameCompressInfo &aCompressInfo)
531 {
532 Error error;
533 uint16_t hostCompressOffset = aCompressInfo.GetHostNameOffset(aMessage, aName);
534
535 if (hostCompressOffset != NameCompressInfo::kUnknownOffset)
536 {
537 error = Name::AppendPointerLabel(hostCompressOffset, aMessage);
538 }
539 else
540 {
541 uint8_t domainStart = static_cast<uint8_t>(StringLength(aName, Name::kMaxNameLength) -
542 StringLength(aCompressInfo.GetDomainName(), Name::kMaxNameSize - 1));
543 uint16_t domainCompressOffset = aCompressInfo.GetDomainNameOffset();
544
545 hostCompressOffset = aMessage.GetLength();
546 aCompressInfo.SetHostNameOffset(hostCompressOffset);
547
548 if (domainCompressOffset == NameCompressInfo::kUnknownOffset)
549 {
550 aCompressInfo.SetDomainNameOffset(hostCompressOffset + domainStart);
551 error = Name::AppendName(aName, aMessage);
552 }
553 else
554 {
555 SuccessOrExit(error = Name::AppendMultipleLabels(aName, domainStart, aMessage));
556 error = Name::AppendPointerLabel(domainCompressOffset, aMessage);
557 }
558 }
559
560 exit:
561 return error;
562 }
563
IncResourceRecordCount(Header & aHeader,bool aAdditional)564 void Server::IncResourceRecordCount(Header &aHeader, bool aAdditional)
565 {
566 if (aAdditional)
567 {
568 aHeader.SetAdditionalRecordCount(aHeader.GetAdditionalRecordCount() + 1);
569 }
570 else
571 {
572 aHeader.SetAnswerCount(aHeader.GetAnswerCount() + 1);
573 }
574 }
575
FindNameComponents(const char * aName,const char * aDomain,NameComponentsOffsetInfo & aInfo)576 Error Server::FindNameComponents(const char *aName, const char *aDomain, NameComponentsOffsetInfo &aInfo)
577 {
578 uint8_t nameLen = static_cast<uint8_t>(StringLength(aName, Name::kMaxNameLength));
579 uint8_t domainLen = static_cast<uint8_t>(StringLength(aDomain, Name::kMaxNameLength));
580 Error error = kErrorNone;
581 uint8_t labelBegin, labelEnd;
582
583 VerifyOrExit(Name::IsSubDomainOf(aName, aDomain), error = kErrorInvalidArgs);
584
585 labelBegin = nameLen - domainLen;
586 aInfo.mDomainOffset = labelBegin;
587
588 while (true)
589 {
590 error = FindPreviousLabel(aName, labelBegin, labelEnd);
591
592 VerifyOrExit(error == kErrorNone, error = (error == kErrorNotFound ? kErrorNone : error));
593
594 if (labelEnd == labelBegin + kProtocolLabelLength &&
595 (StringStartsWith(&aName[labelBegin], kDnssdProtocolUdp, kStringCaseInsensitiveMatch) ||
596 StringStartsWith(&aName[labelBegin], kDnssdProtocolTcp, kStringCaseInsensitiveMatch)))
597 {
598 // <Protocol> label found
599 aInfo.mProtocolOffset = labelBegin;
600 break;
601 }
602 }
603
604 // Get service label <Service>
605 error = FindPreviousLabel(aName, labelBegin, labelEnd);
606 VerifyOrExit(error == kErrorNone, error = (error == kErrorNotFound ? kErrorNone : error));
607
608 aInfo.mServiceOffset = labelBegin;
609
610 // Check for service subtype
611 error = FindPreviousLabel(aName, labelBegin, labelEnd);
612 VerifyOrExit(error == kErrorNone, error = (error == kErrorNotFound ? kErrorNone : error));
613
614 // Note that `kDnssdSubTypeLabel` is "._sub.". Here we get the
615 // label only so we want to compare it with "_sub".
616 if ((labelEnd == labelBegin + kSubTypeLabelLength) &&
617 StringStartsWith(&aName[labelBegin], kDnssdSubTypeLabel + 1, kStringCaseInsensitiveMatch))
618 {
619 SuccessOrExit(error = FindPreviousLabel(aName, labelBegin, labelEnd));
620 VerifyOrExit(labelBegin == 0, error = kErrorInvalidArgs);
621 aInfo.mSubTypeOffset = labelBegin;
622 ExitNow();
623 }
624
625 // Treat everything before <Service> as <Instance> label
626 aInfo.mInstanceOffset = 0;
627
628 exit:
629 return error;
630 }
631
FindPreviousLabel(const char * aName,uint8_t & aStart,uint8_t & aStop)632 Error Server::FindPreviousLabel(const char *aName, uint8_t &aStart, uint8_t &aStop)
633 {
634 // This method finds the previous label before the current label (whose start index is @p aStart), and updates @p
635 // aStart to the start index of the label and @p aStop to the index of the dot just after the label.
636 // @note The input value of @p aStop does not matter because it is only used to output.
637
638 Error error = kErrorNone;
639 uint8_t start = aStart;
640 uint8_t end;
641
642 VerifyOrExit(start > 0, error = kErrorNotFound);
643 VerifyOrExit(aName[--start] == Name::kLabelSeperatorChar, error = kErrorInvalidArgs);
644
645 end = start;
646 while (start > 0 && aName[start - 1] != Name::kLabelSeperatorChar)
647 {
648 start--;
649 }
650
651 VerifyOrExit(start < end, error = kErrorInvalidArgs);
652
653 aStart = start;
654 aStop = end;
655
656 exit:
657 return error;
658 }
659
660 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
ResolveBySrp(Header & aResponseHeader,Message & aResponseMessage,Server::NameCompressInfo & aCompressInfo)661 Header::Response Server::ResolveBySrp(Header & aResponseHeader,
662 Message & aResponseMessage,
663 Server::NameCompressInfo &aCompressInfo)
664 {
665 Question question;
666 uint16_t readOffset = sizeof(Header);
667 Header::Response response = Header::kResponseSuccess;
668 char name[Name::kMaxNameSize];
669
670 for (uint16_t i = 0; i < aResponseHeader.GetQuestionCount(); i++)
671 {
672 IgnoreError(Name::ReadName(aResponseMessage, readOffset, name, sizeof(name)));
673 IgnoreError(aResponseMessage.Read(readOffset, question));
674 readOffset += sizeof(question);
675
676 response = ResolveQuestionBySrp(name, question, aResponseHeader, aResponseMessage, aCompressInfo,
677 /* aAdditional */ false);
678
679 LogInfo("ANSWER: TRANSACTION=0x%04x, QUESTION=[%s %d %d], RCODE=%d", aResponseHeader.GetMessageId(), name,
680 question.GetClass(), question.GetType(), response);
681 }
682
683 // Answer the questions with additional RRs if required
684 if (aResponseHeader.GetAnswerCount() > 0)
685 {
686 readOffset = sizeof(Header);
687 for (uint16_t i = 0; i < aResponseHeader.GetQuestionCount(); i++)
688 {
689 IgnoreError(Name::ReadName(aResponseMessage, readOffset, name, sizeof(name)));
690 IgnoreError(aResponseMessage.Read(readOffset, question));
691 readOffset += sizeof(question);
692
693 VerifyOrExit(Header::kResponseServerFailure != ResolveQuestionBySrp(name, question, aResponseHeader,
694 aResponseMessage, aCompressInfo,
695 /* aAdditional */ true),
696 response = Header::kResponseServerFailure);
697
698 LogInfo("ADDITIONAL: TRANSACTION=0x%04x, QUESTION=[%s %d %d], RCODE=%d", aResponseHeader.GetMessageId(),
699 name, question.GetClass(), question.GetType(), response);
700 }
701 }
702 exit:
703 return response;
704 }
705
ResolveQuestionBySrp(const char * aName,const Question & aQuestion,Header & aResponseHeader,Message & aResponseMessage,NameCompressInfo & aCompressInfo,bool aAdditional)706 Header::Response Server::ResolveQuestionBySrp(const char * aName,
707 const Question & aQuestion,
708 Header & aResponseHeader,
709 Message & aResponseMessage,
710 NameCompressInfo &aCompressInfo,
711 bool aAdditional)
712 {
713 Error error = kErrorNone;
714 const Srp::Server::Host *host = nullptr;
715 TimeMilli now = TimerMilli::GetNow();
716 uint16_t qtype = aQuestion.GetType();
717 Header::Response response = Header::kResponseNameError;
718
719 while ((host = GetNextSrpHost(host)) != nullptr)
720 {
721 bool needAdditionalAaaaRecord = false;
722 const char *hostName = host->GetFullName();
723
724 // Handle PTR/SRV/TXT query
725 if (qtype == ResourceRecord::kTypePtr || qtype == ResourceRecord::kTypeSrv || qtype == ResourceRecord::kTypeTxt)
726 {
727 const Srp::Server::Service *service = nullptr;
728
729 while ((service = GetNextSrpService(*host, service)) != nullptr)
730 {
731 uint32_t instanceTtl = TimeMilli::MsecToSec(service->GetExpireTime() - TimerMilli::GetNow());
732 const char *instanceName = service->GetInstanceName();
733 bool serviceNameMatched = service->MatchesServiceName(aName);
734 bool instanceNameMatched = service->MatchesInstanceName(aName);
735 bool ptrQueryMatched = qtype == ResourceRecord::kTypePtr && serviceNameMatched;
736 bool srvQueryMatched = qtype == ResourceRecord::kTypeSrv && instanceNameMatched;
737 bool txtQueryMatched = qtype == ResourceRecord::kTypeTxt && instanceNameMatched;
738
739 if (ptrQueryMatched || srvQueryMatched)
740 {
741 needAdditionalAaaaRecord = true;
742 }
743
744 if (!aAdditional && ptrQueryMatched)
745 {
746 SuccessOrExit(
747 error = AppendPtrRecord(aResponseMessage, aName, instanceName, instanceTtl, aCompressInfo));
748 IncResourceRecordCount(aResponseHeader, aAdditional);
749 response = Header::kResponseSuccess;
750 }
751
752 if ((!aAdditional && srvQueryMatched) ||
753 (aAdditional && ptrQueryMatched &&
754 !HasQuestion(aResponseHeader, aResponseMessage, instanceName, ResourceRecord::kTypeSrv)))
755 {
756 SuccessOrExit(error = AppendSrvRecord(aResponseMessage, instanceName, hostName, instanceTtl,
757 service->GetPriority(), service->GetWeight(),
758 service->GetPort(), aCompressInfo));
759 IncResourceRecordCount(aResponseHeader, aAdditional);
760 response = Header::kResponseSuccess;
761 }
762
763 if ((!aAdditional && txtQueryMatched) ||
764 (aAdditional && ptrQueryMatched &&
765 !HasQuestion(aResponseHeader, aResponseMessage, instanceName, ResourceRecord::kTypeTxt)))
766 {
767 SuccessOrExit(error = AppendTxtRecord(aResponseMessage, instanceName, service->GetTxtData(),
768 service->GetTxtDataLength(), instanceTtl, aCompressInfo));
769 IncResourceRecordCount(aResponseHeader, aAdditional);
770 response = Header::kResponseSuccess;
771 }
772 }
773 }
774
775 // Handle AAAA query
776 if ((!aAdditional && qtype == ResourceRecord::kTypeAaaa && host->Matches(aName)) ||
777 (aAdditional && needAdditionalAaaaRecord &&
778 !HasQuestion(aResponseHeader, aResponseMessage, hostName, ResourceRecord::kTypeAaaa)))
779 {
780 uint8_t addrNum;
781 const Ip6::Address *addrs = host->GetAddresses(addrNum);
782 uint32_t hostTtl = TimeMilli::MsecToSec(host->GetExpireTime() - now);
783
784 for (uint8_t i = 0; i < addrNum; i++)
785 {
786 SuccessOrExit(error = AppendAaaaRecord(aResponseMessage, hostName, addrs[i], hostTtl, aCompressInfo));
787 IncResourceRecordCount(aResponseHeader, aAdditional);
788 }
789
790 response = Header::kResponseSuccess;
791 }
792 }
793
794 exit:
795 return error == kErrorNone ? response : Header::kResponseServerFailure;
796 }
797
// Returns the next SRP host after `aHost` that is not marked as deleted, or
// `nullptr` when there is none.
const Srp::Server::Host *Server::GetNextSrpHost(const Srp::Server::Host *aHost)
{
    const Srp::Server::Host *candidate = aHost;

    do
    {
        candidate = Get<Srp::Server>().GetNextHost(candidate);
    } while ((candidate != nullptr) && candidate->IsDeleted());

    return candidate;
}
809
// Returns the service of `aHost` that follows `aService`, as selected by the
// `kFlagsAnyTypeActiveService` filter.
const Srp::Server::Service *Server::GetNextSrpService(const Srp::Server::Host &   aHost,
                                                      const Srp::Server::Service *aService)
{
    const Srp::Server::Service *next = aHost.FindNextService(aService, Srp::Server::kFlagsAnyTypeActiveService);

    return next;
}
815 #endif // OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
816
ResolveByQueryCallbacks(Header & aResponseHeader,Message & aResponseMessage,NameCompressInfo & aCompressInfo,const Ip6::MessageInfo & aMessageInfo)817 Error Server::ResolveByQueryCallbacks(Header & aResponseHeader,
818 Message & aResponseMessage,
819 NameCompressInfo & aCompressInfo,
820 const Ip6::MessageInfo &aMessageInfo)
821 {
822 QueryTransaction *query = nullptr;
823 DnsQueryType queryType;
824 char name[Name::kMaxNameSize];
825
826 Error error = kErrorNone;
827
828 VerifyOrExit(mQuerySubscribe != nullptr, error = kErrorFailed);
829
830 queryType = GetQueryTypeAndName(aResponseHeader, aResponseMessage, name);
831 VerifyOrExit(queryType != kDnsQueryNone, error = kErrorNotImplemented);
832
833 query = NewQuery(aResponseHeader, aResponseMessage, aCompressInfo, aMessageInfo);
834 VerifyOrExit(query != nullptr, error = kErrorNoBufs);
835
836 mQuerySubscribe(mQueryCallbackContext, name);
837
838 exit:
839 return error;
840 }
841
NewQuery(const Header & aResponseHeader,Message & aResponseMessage,const NameCompressInfo & aCompressInfo,const Ip6::MessageInfo & aMessageInfo)842 Server::QueryTransaction *Server::NewQuery(const Header & aResponseHeader,
843 Message & aResponseMessage,
844 const NameCompressInfo &aCompressInfo,
845 const Ip6::MessageInfo &aMessageInfo)
846 {
847 QueryTransaction *newQuery = nullptr;
848
849 for (QueryTransaction &query : mQueryTransactions)
850 {
851 if (query.IsValid())
852 {
853 continue;
854 }
855
856 query.Init(aResponseHeader, aResponseMessage, aCompressInfo, aMessageInfo, GetInstance());
857 ExitNow(newQuery = &query);
858 }
859
860 exit:
861 if (newQuery != nullptr)
862 {
863 ResetTimer();
864 }
865
866 return newQuery;
867 }
868
CanAnswerQuery(const QueryTransaction & aQuery,const char * aServiceFullName,const otDnssdServiceInstanceInfo & aInstanceInfo)869 bool Server::CanAnswerQuery(const QueryTransaction & aQuery,
870 const char * aServiceFullName,
871 const otDnssdServiceInstanceInfo &aInstanceInfo)
872 {
873 char name[Name::kMaxNameSize];
874 DnsQueryType sdType;
875 bool canAnswer = false;
876
877 sdType = GetQueryTypeAndName(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), name);
878
879 switch (sdType)
880 {
881 case kDnsQueryBrowse:
882 canAnswer = StringMatch(name, aServiceFullName, kStringCaseInsensitiveMatch);
883 break;
884 case kDnsQueryResolve:
885 canAnswer = StringMatch(name, aInstanceInfo.mFullName, kStringCaseInsensitiveMatch);
886 break;
887 default:
888 break;
889 }
890
891 return canAnswer;
892 }
893
CanAnswerQuery(const Server::QueryTransaction & aQuery,const char * aHostFullName)894 bool Server::CanAnswerQuery(const Server::QueryTransaction &aQuery, const char *aHostFullName)
895 {
896 char name[Name::kMaxNameSize];
897 DnsQueryType sdType;
898
899 sdType = GetQueryTypeAndName(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), name);
900 return (sdType == kDnsQueryResolveHost) && StringMatch(name, aHostFullName, kStringCaseInsensitiveMatch);
901 }
902
AnswerQuery(QueryTransaction & aQuery,const char * aServiceFullName,const otDnssdServiceInstanceInfo & aInstanceInfo)903 void Server::AnswerQuery(QueryTransaction & aQuery,
904 const char * aServiceFullName,
905 const otDnssdServiceInstanceInfo &aInstanceInfo)
906 {
907 Header & responseHeader = aQuery.GetResponseHeader();
908 Message & responseMessage = aQuery.GetResponseMessage();
909 Error error = kErrorNone;
910 NameCompressInfo &compressInfo = aQuery.GetNameCompressInfo();
911
912 if (HasQuestion(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), aServiceFullName,
913 ResourceRecord::kTypePtr))
914 {
915 SuccessOrExit(error = AppendPtrRecord(responseMessage, aServiceFullName, aInstanceInfo.mFullName,
916 aInstanceInfo.mTtl, compressInfo));
917 IncResourceRecordCount(responseHeader, false);
918 }
919
920 for (uint8_t additional = 0; additional <= 1; additional++)
921 {
922 if (HasQuestion(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), aInstanceInfo.mFullName,
923 ResourceRecord::kTypeSrv) == !additional)
924 {
925 SuccessOrExit(error = AppendSrvRecord(responseMessage, aInstanceInfo.mFullName, aInstanceInfo.mHostName,
926 aInstanceInfo.mTtl, aInstanceInfo.mPriority, aInstanceInfo.mWeight,
927 aInstanceInfo.mPort, compressInfo));
928 IncResourceRecordCount(responseHeader, additional);
929 }
930
931 if (HasQuestion(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), aInstanceInfo.mFullName,
932 ResourceRecord::kTypeTxt) == !additional)
933 {
934 SuccessOrExit(error = AppendTxtRecord(responseMessage, aInstanceInfo.mFullName, aInstanceInfo.mTxtData,
935 aInstanceInfo.mTxtLength, aInstanceInfo.mTtl, compressInfo));
936 IncResourceRecordCount(responseHeader, additional);
937 }
938
939 if (HasQuestion(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), aInstanceInfo.mHostName,
940 ResourceRecord::kTypeAaaa) == !additional)
941 {
942 for (uint8_t i = 0; i < aInstanceInfo.mAddressNum; i++)
943 {
944 const Ip6::Address &address = AsCoreType(&aInstanceInfo.mAddresses[i]);
945
946 OT_ASSERT(!address.IsUnspecified() && !address.IsLinkLocal() && !address.IsMulticast() &&
947 !address.IsLoopback());
948
949 SuccessOrExit(error = AppendAaaaRecord(responseMessage, aInstanceInfo.mHostName, address,
950 aInstanceInfo.mTtl, compressInfo));
951 IncResourceRecordCount(responseHeader, additional);
952 }
953 }
954 }
955
956 exit:
957 FinalizeQuery(aQuery, error == kErrorNone ? Header::kResponseSuccess : Header::kResponseServerFailure);
958 ResetTimer();
959 }
960
AnswerQuery(QueryTransaction & aQuery,const char * aHostFullName,const otDnssdHostInfo & aHostInfo)961 void Server::AnswerQuery(QueryTransaction &aQuery, const char *aHostFullName, const otDnssdHostInfo &aHostInfo)
962 {
963 Header & responseHeader = aQuery.GetResponseHeader();
964 Message & responseMessage = aQuery.GetResponseMessage();
965 Error error = kErrorNone;
966 NameCompressInfo &compressInfo = aQuery.GetNameCompressInfo();
967
968 if (HasQuestion(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), aHostFullName, ResourceRecord::kTypeAaaa))
969 {
970 for (uint8_t i = 0; i < aHostInfo.mAddressNum; i++)
971 {
972 const Ip6::Address &address = AsCoreType(&aHostInfo.mAddresses[i]);
973
974 OT_ASSERT(!address.IsUnspecified() && !address.IsMulticast() && !address.IsLinkLocal() &&
975 !address.IsLoopback());
976
977 SuccessOrExit(error =
978 AppendAaaaRecord(responseMessage, aHostFullName, address, aHostInfo.mTtl, compressInfo));
979 IncResourceRecordCount(responseHeader, /* aAdditional */ false);
980 }
981 }
982
983 exit:
984 FinalizeQuery(aQuery, error == kErrorNone ? Header::kResponseSuccess : Header::kResponseServerFailure);
985 ResetTimer();
986 }
987
SetQueryCallbacks(otDnssdQuerySubscribeCallback aSubscribe,otDnssdQueryUnsubscribeCallback aUnsubscribe,void * aContext)988 void Server::SetQueryCallbacks(otDnssdQuerySubscribeCallback aSubscribe,
989 otDnssdQueryUnsubscribeCallback aUnsubscribe,
990 void * aContext)
991 {
992 OT_ASSERT((aSubscribe == nullptr) == (aUnsubscribe == nullptr));
993
994 mQuerySubscribe = aSubscribe;
995 mQueryUnsubscribe = aUnsubscribe;
996 mQueryCallbackContext = aContext;
997 }
998
HandleDiscoveredServiceInstance(const char * aServiceFullName,const otDnssdServiceInstanceInfo & aInstanceInfo)999 void Server::HandleDiscoveredServiceInstance(const char * aServiceFullName,
1000 const otDnssdServiceInstanceInfo &aInstanceInfo)
1001 {
1002 OT_ASSERT(StringEndsWith(aServiceFullName, Name::kLabelSeperatorChar));
1003 OT_ASSERT(StringEndsWith(aInstanceInfo.mFullName, Name::kLabelSeperatorChar));
1004 OT_ASSERT(StringEndsWith(aInstanceInfo.mHostName, Name::kLabelSeperatorChar));
1005
1006 for (QueryTransaction &query : mQueryTransactions)
1007 {
1008 if (query.IsValid() && CanAnswerQuery(query, aServiceFullName, aInstanceInfo))
1009 {
1010 AnswerQuery(query, aServiceFullName, aInstanceInfo);
1011 }
1012 }
1013 }
1014
HandleDiscoveredHost(const char * aHostFullName,const otDnssdHostInfo & aHostInfo)1015 void Server::HandleDiscoveredHost(const char *aHostFullName, const otDnssdHostInfo &aHostInfo)
1016 {
1017 OT_ASSERT(StringEndsWith(aHostFullName, Name::kLabelSeperatorChar));
1018
1019 for (QueryTransaction &query : mQueryTransactions)
1020 {
1021 if (query.IsValid() && CanAnswerQuery(query, aHostFullName))
1022 {
1023 AnswerQuery(query, aHostFullName, aHostInfo);
1024 }
1025 }
1026 }
1027
// Iterates over the valid entries of `mQueryTransactions`.
//
// With `aQuery == nullptr` the scan starts at the first entry; otherwise it starts at the entry
// following `aQuery`. Returns the first valid entry found, or `nullptr` when none remains.
const otDnssdQuery *Server::GetNextQuery(const otDnssdQuery *aQuery) const
{
    const QueryTransaction *const end  = GetArrayEnd(mQueryTransactions);
    const QueryTransaction *      next = &mQueryTransactions[0];

    if (aQuery != nullptr)
    {
        next = static_cast<const QueryTransaction *>(aQuery) + 1;
    }

    // Skip over unused entries until a valid one or the end of the array is reached.
    while ((next < end) && !next->IsValid())
    {
        next++;
    }

    return static_cast<const otDnssdQuery *>((next < end) ? next : nullptr);
}
1050
GetQueryTypeAndName(const otDnssdQuery * aQuery,char (& aName)[Name::kMaxNameSize])1051 Server::DnsQueryType Server::GetQueryTypeAndName(const otDnssdQuery *aQuery, char (&aName)[Name::kMaxNameSize])
1052 {
1053 const QueryTransaction *query = static_cast<const QueryTransaction *>(aQuery);
1054
1055 OT_ASSERT(query->IsValid());
1056 return GetQueryTypeAndName(query->GetResponseHeader(), query->GetResponseMessage(), aName);
1057 }
1058
GetQueryTypeAndName(const Header & aHeader,const Message & aMessage,char (& aName)[Name::kMaxNameSize])1059 Server::DnsQueryType Server::GetQueryTypeAndName(const Header & aHeader,
1060 const Message &aMessage,
1061 char (&aName)[Name::kMaxNameSize])
1062 {
1063 DnsQueryType sdType = kDnsQueryNone;
1064
1065 for (uint16_t i = 0, readOffset = sizeof(Header); i < aHeader.GetQuestionCount(); i++)
1066 {
1067 Question question;
1068
1069 IgnoreError(Name::ReadName(aMessage, readOffset, aName, sizeof(aName)));
1070 IgnoreError(aMessage.Read(readOffset, question));
1071 readOffset += sizeof(question);
1072
1073 switch (question.GetType())
1074 {
1075 case ResourceRecord::kTypePtr:
1076 ExitNow(sdType = kDnsQueryBrowse);
1077 case ResourceRecord::kTypeSrv:
1078 case ResourceRecord::kTypeTxt:
1079 ExitNow(sdType = kDnsQueryResolve);
1080 }
1081 }
1082
1083 for (uint16_t i = 0, readOffset = sizeof(Header); i < aHeader.GetQuestionCount(); i++)
1084 {
1085 Question question;
1086
1087 IgnoreError(Name::ReadName(aMessage, readOffset, aName, sizeof(aName)));
1088 IgnoreError(aMessage.Read(readOffset, question));
1089 readOffset += sizeof(question);
1090
1091 switch (question.GetType())
1092 {
1093 case ResourceRecord::kTypeAaaa:
1094 case ResourceRecord::kTypeA:
1095 ExitNow(sdType = kDnsQueryResolveHost);
1096 }
1097 }
1098
1099 exit:
1100 return sdType;
1101 }
1102
HasQuestion(const Header & aHeader,const Message & aMessage,const char * aName,uint16_t aQuestionType)1103 bool Server::HasQuestion(const Header &aHeader, const Message &aMessage, const char *aName, uint16_t aQuestionType)
1104 {
1105 bool found = false;
1106
1107 for (uint16_t i = 0, readOffset = sizeof(Header); i < aHeader.GetQuestionCount(); i++)
1108 {
1109 Question question;
1110 Error error;
1111
1112 error = Name::CompareName(aMessage, readOffset, aName);
1113 IgnoreError(aMessage.Read(readOffset, question));
1114 readOffset += sizeof(question);
1115
1116 if (error == kErrorNone && aQuestionType == question.GetType())
1117 {
1118 ExitNow(found = true);
1119 }
1120 }
1121
1122 exit:
1123 return found;
1124 }
1125
HandleTimer(Timer & aTimer)1126 void Server::HandleTimer(Timer &aTimer)
1127 {
1128 aTimer.Get<Server>().HandleTimer();
1129 }
1130
HandleTimer(void)1131 void Server::HandleTimer(void)
1132 {
1133 TimeMilli now = TimerMilli::GetNow();
1134
1135 for (QueryTransaction &query : mQueryTransactions)
1136 {
1137 TimeMilli expire;
1138
1139 if (!query.IsValid())
1140 {
1141 continue;
1142 }
1143
1144 expire = query.GetStartTime() + kQueryTimeout;
1145 if (expire <= now)
1146 {
1147 FinalizeQuery(query, Header::kResponseSuccess);
1148 }
1149 }
1150
1151 ResetTimer();
1152 }
1153
ResetTimer(void)1154 void Server::ResetTimer(void)
1155 {
1156 TimeMilli now = TimerMilli::GetNow();
1157 TimeMilli nextExpire = now.GetDistantFuture();
1158
1159 for (QueryTransaction &query : mQueryTransactions)
1160 {
1161 TimeMilli expire;
1162
1163 if (!query.IsValid())
1164 {
1165 continue;
1166 }
1167
1168 expire = query.GetStartTime() + kQueryTimeout;
1169 if (expire <= now)
1170 {
1171 nextExpire = now;
1172 }
1173 else if (expire < nextExpire)
1174 {
1175 nextExpire = expire;
1176 }
1177 }
1178
1179 if (nextExpire < now.GetDistantFuture())
1180 {
1181 mTimer.FireAt(nextExpire);
1182 }
1183 else
1184 {
1185 mTimer.Stop();
1186 }
1187 }
1188
FinalizeQuery(QueryTransaction & aQuery,Header::Response aResponseCode)1189 void Server::FinalizeQuery(QueryTransaction &aQuery, Header::Response aResponseCode)
1190 {
1191 char name[Name::kMaxNameSize];
1192 DnsQueryType sdType;
1193
1194 OT_ASSERT(mQueryUnsubscribe != nullptr);
1195
1196 sdType = GetQueryTypeAndName(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), name);
1197
1198 OT_ASSERT(sdType != kDnsQueryNone);
1199 OT_UNUSED_VARIABLE(sdType);
1200
1201 mQueryUnsubscribe(mQueryCallbackContext, name);
1202 aQuery.Finalize(aResponseCode, mSocket);
1203 }
1204
Init(const Header & aResponseHeader,Message & aResponseMessage,const NameCompressInfo & aCompressInfo,const Ip6::MessageInfo & aMessageInfo,Instance & aInstance)1205 void Server::QueryTransaction::Init(const Header & aResponseHeader,
1206 Message & aResponseMessage,
1207 const NameCompressInfo &aCompressInfo,
1208 const Ip6::MessageInfo &aMessageInfo,
1209 Instance & aInstance)
1210 {
1211 OT_ASSERT(mResponseMessage == nullptr);
1212
1213 InstanceLocatorInit::Init(aInstance);
1214 mResponseHeader = aResponseHeader;
1215 mResponseMessage = &aResponseMessage;
1216 mCompressInfo = aCompressInfo;
1217 mMessageInfo = aMessageInfo;
1218 mStartTime = TimerMilli::GetNow();
1219 }
1220
Finalize(Header::Response aResponseMessage,Ip6::Udp::Socket & aSocket)1221 void Server::QueryTransaction::Finalize(Header::Response aResponseMessage, Ip6::Udp::Socket &aSocket)
1222 {
1223 OT_ASSERT(mResponseMessage != nullptr);
1224
1225 Get<Server>().SendResponse(mResponseHeader, aResponseMessage, *mResponseMessage, mMessageInfo, aSocket);
1226 mResponseMessage = nullptr;
1227 }
1228
UpdateResponseCounters(Header::Response aResponseCode)1229 void Server::UpdateResponseCounters(Header::Response aResponseCode)
1230 {
1231 switch (aResponseCode)
1232 {
1233 case UpdateHeader::kResponseSuccess:
1234 ++mCounters.mSuccessResponse;
1235 break;
1236 case UpdateHeader::kResponseServerFailure:
1237 ++mCounters.mServerFailureResponse;
1238 break;
1239 case UpdateHeader::kResponseFormatError:
1240 ++mCounters.mFormatErrorResponse;
1241 break;
1242 case UpdateHeader::kResponseNameError:
1243 ++mCounters.mNameErrorResponse;
1244 break;
1245 case UpdateHeader::kResponseNotImplemented:
1246 ++mCounters.mNotImplementedResponse;
1247 break;
1248 default:
1249 ++mCounters.mOtherResponse;
1250 break;
1251 }
1252 }
1253
1254 } // namespace ServiceDiscovery
1255 } // namespace Dns
1256 } // namespace ot
1257
#endif // OPENTHREAD_CONFIG_DNSSD_SERVER_ENABLE
1259