1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "discovery/mdns/mdns_querier.h"
6
7 #include <algorithm>
8 #include <array>
9 #include <bitset>
10 #include <memory>
11 #include <unordered_set>
12 #include <utility>
13 #include <vector>
14
15 #include "discovery/common/config.h"
16 #include "discovery/common/reporting_client.h"
17 #include "discovery/mdns/mdns_random.h"
18 #include "discovery/mdns/mdns_receiver.h"
19 #include "discovery/mdns/mdns_sender.h"
20 #include "discovery/mdns/public/mdns_constants.h"
21
22 namespace openscreen {
23 namespace discovery {
24 namespace {
25
26 constexpr std::array<DnsType, 5> kTranslatedNsecAnyQueryTypes = {
27 DnsType::kA, DnsType::kPTR, DnsType::kTXT, DnsType::kAAAA, DnsType::kSRV};
28
IsNegativeResponseFor(const MdnsRecord & record,DnsType type)29 bool IsNegativeResponseFor(const MdnsRecord& record, DnsType type) {
30 if (record.dns_type() != DnsType::kNSEC) {
31 return false;
32 }
33
34 const NsecRecordRdata& nsec = absl::get<NsecRecordRdata>(record.rdata());
35
36 // RFC 6762 section 6.1, the NSEC bit must NOT be set in the received NSEC
37 // record to indicate this is an mDNS NSEC record rather than a traditional
38 // DNS NSEC record.
39 if (std::find(nsec.types().begin(), nsec.types().end(), DnsType::kNSEC) !=
40 nsec.types().end()) {
41 return false;
42 }
43
44 return std::find_if(nsec.types().begin(), nsec.types().end(),
45 [type](DnsType stored_type) {
46 return stored_type == type ||
47 stored_type == DnsType::kANY;
48 }) != nsec.types().end();
49 }
50
51 struct HashDnsType {
operator ()openscreen::discovery::__anon18feac030111::HashDnsType52 inline size_t operator()(DnsType type) const {
53 return static_cast<size_t>(type);
54 }
55 };
56
57 // Helper used for sorting MDNS records. This function guarantees the following:
58 // - All MdnsRecords with the same name appear adjacent to each-other.
59 // - An NSEC record with a given name appears before all other records with the
60 // same name.
CompareRecordByNameAndType(const MdnsRecord & first,const MdnsRecord & second)61 bool CompareRecordByNameAndType(const MdnsRecord& first,
62 const MdnsRecord& second) {
63 if (first.name() != second.name()) {
64 return first.name() < second.name();
65 }
66
67 if ((first.dns_type() == DnsType::kNSEC) !=
68 (second.dns_type() == DnsType::kNSEC)) {
69 return first.dns_type() == DnsType::kNSEC;
70 }
71
72 return first < second;
73 }
74
75 class DnsTypeBitSet {
76 public:
77 // Returns whether any types are currently stored in this data structure.
IsEmpty()78 bool IsEmpty() { return !elements_.any(); }
79
80 // Attempts to insert the given type into this data structure. Returns
81 // true iff the type was not already present.
Insert(DnsType type)82 bool Insert(DnsType type) {
83 uint16_t bit = (type == DnsType::kANY) ? 0 : static_cast<uint16_t>(type);
84 bool was_set = elements_.test(bit);
85 elements_.set(bit);
86 return !was_set;
87 }
88
89 // Iterates over all members of the provided container, inserting each
90 // DnsType contained within to this instance. Returns true iff any element
91 // inserted was not already present in this instance.
92 template <typename Container>
Insert(const Container & container)93 bool Insert(const Container& container) {
94 bool has_element_been_inserted = false;
95 for (DnsType type : container) {
96 has_element_been_inserted |= Insert(type);
97 }
98 return has_element_been_inserted;
99 }
100
101 // Attempts to remove the given type from this data structure. Returns true
102 // iff the type was present prior to this call.
Remove(DnsType type)103 bool Remove(DnsType type) {
104 if (IsEmpty()) {
105 return false;
106 } else if (type == DnsType::kANY) {
107 elements_.reset();
108 return true;
109 }
110
111 uint16_t bit = static_cast<uint16_t>(type);
112 bool was_set = elements_.test(bit);
113 elements_.reset(bit);
114 return was_set;
115 }
116
117 // Returns the DnsTypes currently stored in this data structure.
GetTypes() const118 std::vector<DnsType> GetTypes() const {
119 if (elements_.test(0)) {
120 return {DnsType::kANY};
121 }
122
123 std::vector<DnsType> types;
124 for (DnsType type : kSupportedDnsTypes) {
125 if (type == DnsType::kANY) {
126 continue;
127 }
128
129 uint16_t cast_int = static_cast<uint16_t>(type);
130 if (elements_.test(cast_int)) {
131 types.push_back(type);
132 }
133 }
134 return types;
135 }
136
137 private:
138 std::bitset<64> elements_;
139 };
140
141 // Modifies |records| such that no NSEC record signifies the nonexistance of a
142 // record which is also present in the same message. Order of the input vector
143 // is NOT preserved.
144 // NOTE: |records| is not of type MdnsRecord::ConstRef because the members must
145 // be modified.
146 // TODO(b/170353378): Break this logic into a separate processing module between
147 // the MdnsReader and the MdnsQuerier.
RemoveInvalidNsecFlags(std::vector<MdnsRecord> * records)148 void RemoveInvalidNsecFlags(std::vector<MdnsRecord>* records) {
149 // Sort the records so NSEC records are first so that only one iteration
150 // through all records is needed.
151 std::sort(records->begin(), records->end(), CompareRecordByNameAndType);
152
153 // The set of NSEC records that need to be removed from |records|. This can't
154 // be done as part of the below loop because it would invalidate the iterator
155 // that's still being used.
156 std::vector<std::vector<MdnsRecord>::iterator> nsecs_to_delete;
157
158 // Process all elements.
159 for (auto it = records->begin(); it != records->end();) {
160 if (it->dns_type() != DnsType::kNSEC) {
161 it++;
162 continue;
163 }
164
165 // Track whether the current NSEC record in the input vector has been
166 // modified by some step of this algorithm, be that merging with another
167 // record, removing a DnsType, or any other modification.
168 bool has_changed = false;
169
170 // The types for the new record to create, if |has_changed|.
171 const NsecRecordRdata& nsec_rdata = absl::get<NsecRecordRdata>(it->rdata());
172 DnsTypeBitSet types;
173 for (DnsType type : nsec_rdata.types()) {
174 types.Insert(type);
175 }
176 auto nsec = it;
177 it++;
178
179 // Combine multiple NSECs to simplify the following code. This probably
180 // won't happen, but the RFC doesn't exclude the possibility, so account for
181 // it. Define the TTL of this new NSEC record created by this merge process
182 // to be the minimum of all merged NSEC records.
183 std::chrono::seconds new_ttl = nsec->ttl();
184 while (it != records->end() && it->name() == nsec->name() &&
185 it->dns_type() == DnsType::kNSEC) {
186 has_changed |=
187 types.Insert(absl::get<NsecRecordRdata>(it->rdata()).types());
188 new_ttl = std::min(new_ttl, it->ttl());
189 it = records->erase(it);
190 }
191
192 // Remove any types associated with a known record type.
193 for (; it != records->end() && it->name() == nsec->name(); it++) {
194 OSP_DCHECK(it->dns_type() != DnsType::kNSEC);
195 has_changed |= types.Remove(it->dns_type());
196 }
197
198 // Modify the stored NSEC record, if needed.
199 if (has_changed && types.IsEmpty()) {
200 nsecs_to_delete.push_back(nsec);
201 } else if (has_changed) {
202 NsecRecordRdata new_rdata(nsec_rdata.next_domain_name(),
203 types.GetTypes());
204 *nsec = MdnsRecord(nsec->name(), nsec->dns_type(), nsec->dns_class(),
205 nsec->record_type(), new_ttl, std::move(new_rdata));
206 }
207 }
208
209 // Erase invalid NSEC records. Go backwards to avoid invalidating the
210 // remaining iterators.
211 for (auto erase_it = nsecs_to_delete.rbegin();
212 erase_it != nsecs_to_delete.rend(); erase_it++) {
213 records->erase(*erase_it);
214 }
215 }
216
217 } // namespace
218
RecordTrackerLruCache(MdnsQuerier * querier,MdnsSender * sender,MdnsRandom * random_delay,TaskRunner * task_runner,ClockNowFunctionPtr now_function,ReportingClient * reporting_client,const Config & config)219 MdnsQuerier::RecordTrackerLruCache::RecordTrackerLruCache(
220 MdnsQuerier* querier,
221 MdnsSender* sender,
222 MdnsRandom* random_delay,
223 TaskRunner* task_runner,
224 ClockNowFunctionPtr now_function,
225 ReportingClient* reporting_client,
226 const Config& config)
227 : querier_(querier),
228 sender_(sender),
229 random_delay_(random_delay),
230 task_runner_(task_runner),
231 now_function_(now_function),
232 reporting_client_(reporting_client),
233 config_(config) {
234 OSP_DCHECK(sender_);
235 OSP_DCHECK(random_delay_);
236 OSP_DCHECK(task_runner_);
237 OSP_DCHECK(reporting_client_);
238 OSP_DCHECK_GT(config_.querier_max_records_cached, 0);
239 }
240
241 std::vector<std::reference_wrapper<const MdnsRecordTracker>>
Find(const DomainName & name)242 MdnsQuerier::RecordTrackerLruCache::Find(const DomainName& name) {
243 return Find(name, DnsType::kANY, DnsClass::kANY);
244 }
245
246 std::vector<std::reference_wrapper<const MdnsRecordTracker>>
Find(const DomainName & name,DnsType dns_type,DnsClass dns_class)247 MdnsQuerier::RecordTrackerLruCache::Find(const DomainName& name,
248 DnsType dns_type,
249 DnsClass dns_class) {
250 std::vector<RecordTrackerConstRef> results;
251 auto pair = records_.equal_range(name);
252 for (auto it = pair.first; it != pair.second; it++) {
253 const MdnsRecordTracker& tracker = *it->second;
254 if ((dns_type == DnsType::kANY || dns_type == tracker.dns_type()) &&
255 (dns_class == DnsClass::kANY || dns_class == tracker.dns_class())) {
256 results.push_back(std::cref(tracker));
257 }
258 }
259
260 return results;
261 }
262
Erase(const DomainName & domain,TrackerApplicableCheck check)263 int MdnsQuerier::RecordTrackerLruCache::Erase(const DomainName& domain,
264 TrackerApplicableCheck check) {
265 auto pair = records_.equal_range(domain);
266 int count = 0;
267 for (RecordMap::iterator it = pair.first; it != pair.second;) {
268 if (check(*it->second)) {
269 lru_order_.erase(it->second);
270 it = records_.erase(it);
271 count++;
272 } else {
273 it++;
274 }
275 }
276
277 return count;
278 }
279
ExpireSoon(const DomainName & domain,TrackerApplicableCheck check)280 int MdnsQuerier::RecordTrackerLruCache::ExpireSoon(
281 const DomainName& domain,
282 TrackerApplicableCheck check) {
283 auto pair = records_.equal_range(domain);
284 int count = 0;
285 for (RecordMap::iterator it = pair.first; it != pair.second; it++) {
286 if (check(*it->second)) {
287 MoveToEnd(it);
288 it->second->ExpireSoon();
289 count++;
290 }
291 }
292
293 return count;
294 }
295
Update(const MdnsRecord & record,TrackerApplicableCheck check)296 int MdnsQuerier::RecordTrackerLruCache::Update(const MdnsRecord& record,
297 TrackerApplicableCheck check) {
298 return Update(record, check, [](const MdnsRecordTracker& t) {});
299 }
300
Update(const MdnsRecord & record,TrackerApplicableCheck check,TrackerChangeCallback on_rdata_update)301 int MdnsQuerier::RecordTrackerLruCache::Update(
302 const MdnsRecord& record,
303 TrackerApplicableCheck check,
304 TrackerChangeCallback on_rdata_update) {
305 auto pair = records_.equal_range(record.name());
306 int count = 0;
307 for (RecordMap::iterator it = pair.first; it != pair.second; it++) {
308 if (check(*it->second)) {
309 auto result = it->second->Update(record);
310
311 if (result.is_error()) {
312 reporting_client_->OnRecoverableError(
313 Error(Error::Code::kUpdateReceivedRecordFailure,
314 result.error().ToString()));
315 continue;
316 }
317
318 count++;
319 if (result.value() == MdnsRecordTracker::UpdateType::kGoodbye) {
320 it->second->ExpireSoon();
321 MoveToEnd(it);
322 } else {
323 MoveToBeginning(it);
324 if (result.value() == MdnsRecordTracker::UpdateType::kRdata) {
325 on_rdata_update(*it->second);
326 }
327 }
328 }
329 }
330
331 return count;
332 }
333
StartTracking(MdnsRecord record,DnsType dns_type)334 const MdnsRecordTracker& MdnsQuerier::RecordTrackerLruCache::StartTracking(
335 MdnsRecord record,
336 DnsType dns_type) {
337 auto expiration_callback = [this](const MdnsRecordTracker* tracker,
338 const MdnsRecord& record) {
339 querier_->OnRecordExpired(tracker, record);
340 };
341
342 while (lru_order_.size() >=
343 static_cast<size_t>(config_.querier_max_records_cached)) {
344 // This call erases one of the tracked records.
345 OSP_DVLOG << "Maximum cacheable record count exceeded ("
346 << config_.querier_max_records_cached << ")";
347 lru_order_.back().ExpireNow();
348 }
349
350 auto name = record.name();
351 lru_order_.emplace_front(std::move(record), dns_type, sender_, task_runner_,
352 now_function_, random_delay_,
353 std::move(expiration_callback));
354 records_.emplace(std::move(name), lru_order_.begin());
355
356 return lru_order_.front();
357 }
358
MoveToBeginning(MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it)359 void MdnsQuerier::RecordTrackerLruCache::MoveToBeginning(
360 MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it) {
361 lru_order_.splice(lru_order_.begin(), lru_order_, it->second);
362 it->second = lru_order_.begin();
363 }
364
MoveToEnd(MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it)365 void MdnsQuerier::RecordTrackerLruCache::MoveToEnd(
366 MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it) {
367 lru_order_.splice(lru_order_.end(), lru_order_, it->second);
368 it->second = --lru_order_.end();
369 }
370
MdnsQuerier(MdnsSender * sender,MdnsReceiver * receiver,TaskRunner * task_runner,ClockNowFunctionPtr now_function,MdnsRandom * random_delay,ReportingClient * reporting_client,Config config)371 MdnsQuerier::MdnsQuerier(MdnsSender* sender,
372 MdnsReceiver* receiver,
373 TaskRunner* task_runner,
374 ClockNowFunctionPtr now_function,
375 MdnsRandom* random_delay,
376 ReportingClient* reporting_client,
377 Config config)
378 : sender_(sender),
379 receiver_(receiver),
380 task_runner_(task_runner),
381 now_function_(now_function),
382 random_delay_(random_delay),
383 reporting_client_(reporting_client),
384 config_(std::move(config)),
385 records_(this,
386 sender_,
387 random_delay_,
388 task_runner_,
389 now_function_,
390 reporting_client_,
391 config_) {
392 OSP_DCHECK(sender_);
393 OSP_DCHECK(receiver_);
394 OSP_DCHECK(task_runner_);
395 OSP_DCHECK(now_function_);
396 OSP_DCHECK(random_delay_);
397 OSP_DCHECK(reporting_client_);
398
399 receiver_->AddResponseCallback(this);
400 }
401
~MdnsQuerier()402 MdnsQuerier::~MdnsQuerier() {
403 receiver_->RemoveResponseCallback(this);
404 }
405
// NOTE: The code below mostly uses range loops instead of std::find_if, for
// better readability, brevity and homogeneity. Using std::find_if results in a
// few more lines of code, and readability suffers from the extra lambdas.
409
StartQuery(const DomainName & name,DnsType dns_type,DnsClass dns_class,MdnsRecordChangedCallback * callback)410 void MdnsQuerier::StartQuery(const DomainName& name,
411 DnsType dns_type,
412 DnsClass dns_class,
413 MdnsRecordChangedCallback* callback) {
414 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
415 OSP_DCHECK(callback);
416 OSP_DCHECK(CanBeQueried(dns_type));
417
418 // Add a new callback if haven't seen it before
419 auto callbacks_it = callbacks_.equal_range(name);
420 for (auto entry = callbacks_it.first; entry != callbacks_it.second; ++entry) {
421 const CallbackInfo& callback_info = entry->second;
422 if (dns_type == callback_info.dns_type &&
423 dns_class == callback_info.dns_class &&
424 callback == callback_info.callback) {
425 // Already have this callback
426 return;
427 }
428 }
429 callbacks_.emplace(name, CallbackInfo{callback, dns_type, dns_class});
430
431 // Notify the new callback with previously cached records.
432 // NOTE: In the future, could allow callers to fetch cached records after
433 // adding a callback, for example to prime the UI.
434 std::vector<PendingQueryChange> pending_changes;
435 const std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
436 records_.Find(name, dns_type, dns_class);
437 for (const MdnsRecordTracker& tracker : trackers) {
438 if (!tracker.is_negative_response()) {
439 MdnsRecord stored_record(name, tracker.dns_type(), tracker.dns_class(),
440 tracker.record_type(), tracker.ttl(),
441 tracker.rdata());
442 std::vector<PendingQueryChange> new_changes = callback->OnRecordChanged(
443 std::move(stored_record), RecordChangedEvent::kCreated);
444 pending_changes.insert(pending_changes.end(), new_changes.begin(),
445 new_changes.end());
446 }
447 }
448
449 // Add a new question if haven't seen it before
450 auto questions_it = questions_.equal_range(name);
451 const bool is_question_already_tracked =
452 std::find_if(questions_it.first, questions_it.second,
453 [dns_type, dns_class](const auto& entry) {
454 const MdnsQuestion& tracked_question =
455 entry.second->question();
456 return dns_type == tracked_question.dns_type() &&
457 dns_class == tracked_question.dns_class();
458 }) != questions_it.second;
459 if (!is_question_already_tracked) {
460 AddQuestion(
461 MdnsQuestion(name, dns_type, dns_class, ResponseType::kMulticast));
462 }
463
464 // Apply any pending changes from the OnRecordChanged() callbacks.
465 ApplyPendingChanges(std::move(pending_changes));
466 }
467
StopQuery(const DomainName & name,DnsType dns_type,DnsClass dns_class,MdnsRecordChangedCallback * callback)468 void MdnsQuerier::StopQuery(const DomainName& name,
469 DnsType dns_type,
470 DnsClass dns_class,
471 MdnsRecordChangedCallback* callback) {
472 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
473 OSP_DCHECK(callback);
474
475 if (!CanBeQueried(dns_type)) {
476 return;
477 }
478
479 // Find and remove the callback.
480 int callbacks_for_key = 0;
481 auto callbacks_it = callbacks_.equal_range(name);
482 for (auto entry = callbacks_it.first; entry != callbacks_it.second;) {
483 const CallbackInfo& callback_info = entry->second;
484 if (dns_type == callback_info.dns_type &&
485 dns_class == callback_info.dns_class) {
486 if (callback == callback_info.callback) {
487 entry = callbacks_.erase(entry);
488 } else {
489 ++callbacks_for_key;
490 ++entry;
491 }
492 }
493 }
494
495 // Exit if there are still callbacks registered for DomainName + DnsType +
496 // DnsClass
497 if (callbacks_for_key > 0) {
498 return;
499 }
500
501 // Find and delete a question that does not have any associated callbacks
502 auto questions_it = questions_.equal_range(name);
503 for (auto entry = questions_it.first; entry != questions_it.second; ++entry) {
504 const MdnsQuestion& tracked_question = entry->second->question();
505 if (dns_type == tracked_question.dns_type() &&
506 dns_class == tracked_question.dns_class()) {
507 questions_.erase(entry);
508 return;
509 }
510 }
511 }
512
ReinitializeQueries(const DomainName & name)513 void MdnsQuerier::ReinitializeQueries(const DomainName& name) {
514 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
515
516 // Get the ongoing queries and their callbacks.
517 std::vector<CallbackInfo> callbacks;
518 auto its = callbacks_.equal_range(name);
519 for (auto it = its.first; it != its.second; it++) {
520 callbacks.push_back(std::move(it->second));
521 }
522 callbacks_.erase(name);
523
524 // Remove all known questions and answers.
525 questions_.erase(name);
526 records_.Erase(name, [](const MdnsRecordTracker& tracker) { return true; });
527
528 // Restart the queries.
529 for (const auto& cb : callbacks) {
530 StartQuery(name, cb.dns_type, cb.dns_class, cb.callback);
531 }
532 }
533
OnMessageReceived(const MdnsMessage & message)534 void MdnsQuerier::OnMessageReceived(const MdnsMessage& message) {
535 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
536 OSP_DCHECK(message.type() == MessageType::Response);
537
538 OSP_DVLOG << "Received mDNS Response message with "
539 << message.answers().size() << " answers and "
540 << message.additional_records().size()
541 << " additional records. Processing...";
542
543 std::vector<MdnsRecord> records_to_process;
544
545 // Add any records that are relevant for this querier.
546 bool found_relevant_records = false;
547 for (const MdnsRecord& record : message.answers()) {
548 if (ShouldAnswerRecordBeProcessed(record)) {
549 records_to_process.push_back(record);
550 found_relevant_records = true;
551 }
552 }
553
554 // If any of the message's answers are relevant, add all additional records.
555 // Else, since the message has already been received and parsed, use any
556 // individual records relevant to this querier to update the cache.
557 for (const MdnsRecord& record : message.additional_records()) {
558 if (found_relevant_records || ShouldAnswerRecordBeProcessed(record)) {
559 records_to_process.push_back(record);
560 }
561 }
562
563 // Drop NSEC records associated with a non-NSEC record of the same type.
564 RemoveInvalidNsecFlags(&records_to_process);
565
566 // Process all remaining records.
567 for (const MdnsRecord& record_to_process : records_to_process) {
568 ProcessRecord(record_to_process);
569 }
570
571 OSP_DVLOG << "\tmDNS Response processed (" << records_to_process.size()
572 << " records accepted)!";
573
574 // TODO(crbug.com/openscreen/83): Check authority records.
575 }
576
ShouldAnswerRecordBeProcessed(const MdnsRecord & answer)577 bool MdnsQuerier::ShouldAnswerRecordBeProcessed(const MdnsRecord& answer) {
578 // First, accept the record if it's associated with an ongoing question.
579 const auto questions_range = questions_.equal_range(answer.name());
580 const auto it = std::find_if(
581 questions_range.first, questions_range.second,
582 [&answer](const auto& pair) {
583 return (pair.second->question().dns_type() == DnsType::kANY ||
584 IsNegativeResponseFor(answer,
585 pair.second->question().dns_type()) ||
586 pair.second->question().dns_type() == answer.dns_type()) &&
587 (pair.second->question().dns_class() == DnsClass::kANY ||
588 pair.second->question().dns_class() == answer.dns_class());
589 });
590 if (it != questions_range.second) {
591 return true;
592 }
593
594 // If not, check if it corresponds to an already existing record. This is
595 // required because records which are already stored may either have been
596 // received in an additional records section, or are associated with a query
597 // which is no longer active.
598 std::vector<DnsType> types{answer.dns_type()};
599 if (answer.dns_type() == DnsType::kNSEC) {
600 const auto& nsec_rdata = absl::get<NsecRecordRdata>(answer.rdata());
601 types = nsec_rdata.types();
602 }
603
604 for (DnsType type : types) {
605 std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
606 records_.Find(answer.name(), type, answer.dns_class());
607 if (!trackers.empty()) {
608 return true;
609 }
610 }
611
612 // In all other cases, the record isn't relevant. Drop it.
613 return false;
614 }
615
OnRecordExpired(const MdnsRecordTracker * tracker,const MdnsRecord & record)616 void MdnsQuerier::OnRecordExpired(const MdnsRecordTracker* tracker,
617 const MdnsRecord& record) {
618 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
619
620 if (!tracker->is_negative_response()) {
621 ProcessCallbacks(record, RecordChangedEvent::kExpired);
622 }
623
624 records_.Erase(record.name(), [tracker](const MdnsRecordTracker& it_tracker) {
625 return tracker == &it_tracker;
626 });
627 }
628
ProcessRecord(const MdnsRecord & record)629 void MdnsQuerier::ProcessRecord(const MdnsRecord& record) {
630 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
631
632 // Skip all records that can't be processed.
633 if (!CanBeProcessed(record.dns_type())) {
634 return;
635 }
636
637 // Ignore NSEC records if the embedder has configured us to do so.
638 if (config_.ignore_nsec_responses && record.dns_type() == DnsType::kNSEC) {
639 return;
640 }
641
642 // Get the types which the received record is associated with. In most cases
643 // this will only be the type of the provided record, but in the case of
644 // NSEC records this will be all records which the record dictates the
645 // nonexistence of.
646 std::vector<DnsType> types;
647 int types_count = 0;
648 const DnsType* types_ptr = nullptr;
649 if (record.dns_type() == DnsType::kNSEC) {
650 const auto& nsec_rdata = absl::get<NsecRecordRdata>(record.rdata());
651 if (std::find(nsec_rdata.types().begin(), nsec_rdata.types().end(),
652 DnsType::kANY) != nsec_rdata.types().end()) {
653 types_ptr = kTranslatedNsecAnyQueryTypes.data();
654 types_count = kTranslatedNsecAnyQueryTypes.size();
655 } else {
656 types_ptr = nsec_rdata.types().data();
657 types_count = nsec_rdata.types().size();
658 }
659 } else {
660 types.push_back(record.dns_type());
661 types_ptr = types.data();
662 types_count = types.size();
663 }
664
665 // Apply the update for each type that the record is associated with.
666 for (int i = 0; i < types_count; ++i) {
667 DnsType dns_type = types_ptr[i];
668 switch (record.record_type()) {
669 case RecordType::kShared: {
670 ProcessSharedRecord(record, dns_type);
671 break;
672 }
673 case RecordType::kUnique: {
674 ProcessUniqueRecord(record, dns_type);
675 break;
676 }
677 }
678 }
679 }
680
ProcessSharedRecord(const MdnsRecord & record,DnsType dns_type)681 void MdnsQuerier::ProcessSharedRecord(const MdnsRecord& record,
682 DnsType dns_type) {
683 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
684 OSP_DCHECK(record.record_type() == RecordType::kShared);
685
686 // By design, NSEC records are never shared records.
687 if (record.dns_type() == DnsType::kNSEC) {
688 return;
689 }
690
691 // For any records updated, this host already has this shared record. Since
692 // the RDATA matches, this is only a TTL update.
693 auto check = [&record](const MdnsRecordTracker& tracker) {
694 return record.dns_type() == tracker.dns_type() &&
695 record.dns_class() == tracker.dns_class() &&
696 record.rdata() == tracker.rdata();
697 };
698 auto updated_count = records_.Update(record, std::move(check));
699
700 if (!updated_count) {
701 // Have never before seen this shared record, insert a new one.
702 AddRecord(record, dns_type);
703 ProcessCallbacks(record, RecordChangedEvent::kCreated);
704 }
705 }
706
ProcessUniqueRecord(const MdnsRecord & record,DnsType dns_type)707 void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record,
708 DnsType dns_type) {
709 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
710 OSP_DCHECK(record.record_type() == RecordType::kUnique);
711
712 std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
713 records_.Find(record.name(), dns_type, record.dns_class());
714 size_t num_records_for_key = trackers.size();
715
716 // Have not seen any records with this key before. This case is expected the
717 // first time a record is received.
718 if (num_records_for_key == size_t{0}) {
719 const bool will_exist = record.dns_type() != DnsType::kNSEC;
720 AddRecord(record, dns_type);
721 if (will_exist) {
722 ProcessCallbacks(record, RecordChangedEvent::kCreated);
723 }
724 } else if (num_records_for_key == size_t{1}) {
725 // There is exactly one tracker associated with this key. This is the
726 // expected case when a record matching this one has already been seen.
727 ProcessSinglyTrackedUniqueRecord(record, trackers[0]);
728 } else {
729 // Multiple records with the same key.
730 ProcessMultiTrackedUniqueRecord(record, dns_type);
731 }
732 }
733
ProcessSinglyTrackedUniqueRecord(const MdnsRecord & record,const MdnsRecordTracker & tracker)734 void MdnsQuerier::ProcessSinglyTrackedUniqueRecord(
735 const MdnsRecord& record,
736 const MdnsRecordTracker& tracker) {
737 const bool existed_previously = !tracker.is_negative_response();
738 const bool will_exist = record.dns_type() != DnsType::kNSEC;
739
740 // Calculate the callback to call on record update success while the old
741 // record still exists.
742 MdnsRecord record_for_callback = record;
743 if (existed_previously && !will_exist) {
744 record_for_callback =
745 MdnsRecord(record.name(), tracker.dns_type(), tracker.dns_class(),
746 tracker.record_type(), tracker.ttl(), tracker.rdata());
747 }
748
749 auto on_rdata_change = [this, r = std::move(record_for_callback),
750 existed_previously,
751 will_exist](const MdnsRecordTracker& tracker) {
752 // If RDATA on the record is different, notify that the record has
753 // been updated.
754 if (existed_previously && will_exist) {
755 ProcessCallbacks(r, RecordChangedEvent::kUpdated);
756 } else if (existed_previously) {
757 // Do not expire the tracker, because it still holds an NSEC record.
758 ProcessCallbacks(r, RecordChangedEvent::kExpired);
759 } else if (will_exist) {
760 ProcessCallbacks(r, RecordChangedEvent::kCreated);
761 }
762 };
763
764 int updated_count = records_.Update(
765 record, [&tracker](const MdnsRecordTracker& t) { return &tracker == &t; },
766 std::move(on_rdata_change));
767 OSP_DCHECK_EQ(updated_count, 1);
768 }
769
ProcessMultiTrackedUniqueRecord(const MdnsRecord & record,DnsType dns_type)770 void MdnsQuerier::ProcessMultiTrackedUniqueRecord(const MdnsRecord& record,
771 DnsType dns_type) {
772 auto update_check = [&record, dns_type](const MdnsRecordTracker& tracker) {
773 return tracker.dns_type() == dns_type &&
774 tracker.dns_class() == record.dns_class() &&
775 tracker.rdata() == record.rdata();
776 };
777 int update_count = records_.Update(
778 record, std::move(update_check),
779 [](const MdnsRecordTracker& tracker) { OSP_NOTREACHED(); });
780 OSP_DCHECK_LE(update_count, 1);
781
782 auto expire_check = [&record, dns_type](const MdnsRecordTracker& tracker) {
783 return tracker.dns_type() == dns_type &&
784 tracker.dns_class() == record.dns_class() &&
785 tracker.rdata() != record.rdata();
786 };
787 int expire_count =
788 records_.ExpireSoon(record.name(), std::move(expire_check));
789 OSP_DCHECK_GE(expire_count, 1);
790
791 // Did not find an existing record to update.
792 if (!update_count && !expire_count) {
793 AddRecord(record, dns_type);
794 if (record.dns_type() != DnsType::kNSEC) {
795 ProcessCallbacks(record, RecordChangedEvent::kCreated);
796 }
797 }
798 }
799
ProcessCallbacks(const MdnsRecord & record,RecordChangedEvent event)800 void MdnsQuerier::ProcessCallbacks(const MdnsRecord& record,
801 RecordChangedEvent event) {
802 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
803
804 std::vector<PendingQueryChange> pending_changes;
805 auto callbacks_it = callbacks_.equal_range(record.name());
806 for (auto entry = callbacks_it.first; entry != callbacks_it.second; ++entry) {
807 const CallbackInfo& callback_info = entry->second;
808 if ((callback_info.dns_type == DnsType::kANY ||
809 record.dns_type() == callback_info.dns_type) &&
810 (callback_info.dns_class == DnsClass::kANY ||
811 record.dns_class() == callback_info.dns_class)) {
812 std::vector<PendingQueryChange> new_changes =
813 callback_info.callback->OnRecordChanged(record, event);
814 pending_changes.insert(pending_changes.end(), new_changes.begin(),
815 new_changes.end());
816 }
817 }
818
819 ApplyPendingChanges(std::move(pending_changes));
820 }
821
AddQuestion(const MdnsQuestion & question)822 void MdnsQuerier::AddQuestion(const MdnsQuestion& question) {
823 auto tracker = std::make_unique<MdnsQuestionTracker>(
824 question, sender_, task_runner_, now_function_, random_delay_, config_);
825 MdnsQuestionTracker* ptr = tracker.get();
826 questions_.emplace(question.name(), std::move(tracker));
827
828 // Let all records associated with this question know that there is a new
829 // query that can be used for their refresh.
830 std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
831 records_.Find(question.name(), question.dns_type(), question.dns_class());
832 for (const MdnsRecordTracker& tracker : trackers) {
833 // NOTE: When the pointed to object is deleted, its dtor removes itself
834 // from all associated records.
835 ptr->AddAssociatedRecord(&tracker);
836 }
837 }
838
AddRecord(const MdnsRecord & record,DnsType type)839 void MdnsQuerier::AddRecord(const MdnsRecord& record, DnsType type) {
840 // Add the new record.
841 const auto& tracker = records_.StartTracking(record, type);
842
843 // Let all questions associated with this record know that there is a new
844 // record that answers them (for known answer suppression).
845 auto query_it = questions_.equal_range(record.name());
846 for (auto entry = query_it.first; entry != query_it.second; ++entry) {
847 const MdnsQuestion& query = entry->second->question();
848 const bool is_relevant_type =
849 type == DnsType::kANY || type == query.dns_type();
850 const bool is_relevant_class = record.dns_class() == DnsClass::kANY ||
851 record.dns_class() == query.dns_class();
852 if (is_relevant_type && is_relevant_class) {
853 // NOTE: When the pointed to object is deleted, its dtor removes itself
854 // from all associated queries.
855 entry->second->AddAssociatedRecord(&tracker);
856 }
857 }
858 }
859
ApplyPendingChanges(std::vector<PendingQueryChange> pending_changes)860 void MdnsQuerier::ApplyPendingChanges(
861 std::vector<PendingQueryChange> pending_changes) {
862 for (auto& pending_change : pending_changes) {
863 switch (pending_change.change_type) {
864 case PendingQueryChange::kStartQuery:
865 StartQuery(std::move(pending_change.name), pending_change.dns_type,
866 pending_change.dns_class, pending_change.callback);
867 break;
868 case PendingQueryChange::kStopQuery:
869 StopQuery(std::move(pending_change.name), pending_change.dns_type,
870 pending_change.dns_class, pending_change.callback);
871 break;
872 }
873 }
874 }
875
876 } // namespace discovery
877 } // namespace openscreen
878