• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "discovery/mdns/mdns_querier.h"
6 
7 #include <algorithm>
8 #include <array>
9 #include <bitset>
10 #include <memory>
11 #include <unordered_set>
12 #include <utility>
13 #include <vector>
14 
15 #include "discovery/common/config.h"
16 #include "discovery/common/reporting_client.h"
17 #include "discovery/mdns/mdns_random.h"
18 #include "discovery/mdns/mdns_receiver.h"
19 #include "discovery/mdns/mdns_sender.h"
20 #include "discovery/mdns/public/mdns_constants.h"
21 
22 namespace openscreen {
23 namespace discovery {
24 namespace {
25 
26 constexpr std::array<DnsType, 5> kTranslatedNsecAnyQueryTypes = {
27     DnsType::kA, DnsType::kPTR, DnsType::kTXT, DnsType::kAAAA, DnsType::kSRV};
28 
IsNegativeResponseFor(const MdnsRecord & record,DnsType type)29 bool IsNegativeResponseFor(const MdnsRecord& record, DnsType type) {
30   if (record.dns_type() != DnsType::kNSEC) {
31     return false;
32   }
33 
34   const NsecRecordRdata& nsec = absl::get<NsecRecordRdata>(record.rdata());
35 
36   // RFC 6762 section 6.1, the NSEC bit must NOT be set in the received NSEC
37   // record to indicate this is an mDNS NSEC record rather than a traditional
38   // DNS NSEC record.
39   if (std::find(nsec.types().begin(), nsec.types().end(), DnsType::kNSEC) !=
40       nsec.types().end()) {
41     return false;
42   }
43 
44   return std::find_if(nsec.types().begin(), nsec.types().end(),
45                       [type](DnsType stored_type) {
46                         return stored_type == type ||
47                                stored_type == DnsType::kANY;
48                       }) != nsec.types().end();
49 }
50 
51 struct HashDnsType {
operator ()openscreen::discovery::__anon18feac030111::HashDnsType52   inline size_t operator()(DnsType type) const {
53     return static_cast<size_t>(type);
54   }
55 };
56 
57 // Helper used for sorting MDNS records. This function guarantees the following:
58 // - All MdnsRecords with the same name appear adjacent to each-other.
59 // - An NSEC record with a given name appears before all other records with the
60 //   same name.
CompareRecordByNameAndType(const MdnsRecord & first,const MdnsRecord & second)61 bool CompareRecordByNameAndType(const MdnsRecord& first,
62                                 const MdnsRecord& second) {
63   if (first.name() != second.name()) {
64     return first.name() < second.name();
65   }
66 
67   if ((first.dns_type() == DnsType::kNSEC) !=
68       (second.dns_type() == DnsType::kNSEC)) {
69     return first.dns_type() == DnsType::kNSEC;
70   }
71 
72   return first < second;
73 }
74 
75 class DnsTypeBitSet {
76  public:
77   // Returns whether any types are currently stored in this data structure.
IsEmpty()78   bool IsEmpty() { return !elements_.any(); }
79 
80   // Attempts to insert the given type into this data structure. Returns
81   // true iff the type was not already present.
Insert(DnsType type)82   bool Insert(DnsType type) {
83     uint16_t bit = (type == DnsType::kANY) ? 0 : static_cast<uint16_t>(type);
84     bool was_set = elements_.test(bit);
85     elements_.set(bit);
86     return !was_set;
87   }
88 
89   // Iterates over all members of the provided container, inserting each
90   // DnsType contained within to this instance. Returns true iff any element
91   // inserted was not already present in this instance.
92   template <typename Container>
Insert(const Container & container)93   bool Insert(const Container& container) {
94     bool has_element_been_inserted = false;
95     for (DnsType type : container) {
96       has_element_been_inserted |= Insert(type);
97     }
98     return has_element_been_inserted;
99   }
100 
101   // Attempts to remove the given type from this data structure. Returns true
102   // iff the type was present prior to this call.
Remove(DnsType type)103   bool Remove(DnsType type) {
104     if (IsEmpty()) {
105       return false;
106     } else if (type == DnsType::kANY) {
107       elements_.reset();
108       return true;
109     }
110 
111     uint16_t bit = static_cast<uint16_t>(type);
112     bool was_set = elements_.test(bit);
113     elements_.reset(bit);
114     return was_set;
115   }
116 
117   // Returns the DnsTypes currently stored in this data structure.
GetTypes() const118   std::vector<DnsType> GetTypes() const {
119     if (elements_.test(0)) {
120       return {DnsType::kANY};
121     }
122 
123     std::vector<DnsType> types;
124     for (DnsType type : kSupportedDnsTypes) {
125       if (type == DnsType::kANY) {
126         continue;
127       }
128 
129       uint16_t cast_int = static_cast<uint16_t>(type);
130       if (elements_.test(cast_int)) {
131         types.push_back(type);
132       }
133     }
134     return types;
135   }
136 
137  private:
138   std::bitset<64> elements_;
139 };
140 
141 // Modifies |records| such that no NSEC record signifies the nonexistance of a
142 // record which is also present in the same message. Order of the input vector
143 // is NOT preserved.
144 // NOTE: |records| is not of type MdnsRecord::ConstRef because the members must
145 // be modified.
146 // TODO(b/170353378): Break this logic into a separate processing module between
147 // the MdnsReader and the MdnsQuerier.
RemoveInvalidNsecFlags(std::vector<MdnsRecord> * records)148 void RemoveInvalidNsecFlags(std::vector<MdnsRecord>* records) {
149   // Sort the records so NSEC records are first so that only one iteration
150   // through all records is needed.
151   std::sort(records->begin(), records->end(), CompareRecordByNameAndType);
152 
153   // The set of NSEC records that need to be removed from |records|. This can't
154   // be done as part of the below loop because it would invalidate the iterator
155   // that's still being used.
156   std::vector<std::vector<MdnsRecord>::iterator> nsecs_to_delete;
157 
158   // Process all elements.
159   for (auto it = records->begin(); it != records->end();) {
160     if (it->dns_type() != DnsType::kNSEC) {
161       it++;
162       continue;
163     }
164 
165     // Track whether the current NSEC record in the input vector has been
166     // modified by some step of this algorithm, be that merging with another
167     // record, removing a DnsType, or any other modification.
168     bool has_changed = false;
169 
170     // The types for the new record to create, if |has_changed|.
171     const NsecRecordRdata& nsec_rdata = absl::get<NsecRecordRdata>(it->rdata());
172     DnsTypeBitSet types;
173     for (DnsType type : nsec_rdata.types()) {
174       types.Insert(type);
175     }
176     auto nsec = it;
177     it++;
178 
179     // Combine multiple NSECs to simplify the following code. This probably
180     // won't happen, but the RFC doesn't exclude the possibility, so account for
181     // it. Define the TTL of this new NSEC record created by this merge process
182     // to be the minimum of all merged NSEC records.
183     std::chrono::seconds new_ttl = nsec->ttl();
184     while (it != records->end() && it->name() == nsec->name() &&
185            it->dns_type() == DnsType::kNSEC) {
186       has_changed |=
187           types.Insert(absl::get<NsecRecordRdata>(it->rdata()).types());
188       new_ttl = std::min(new_ttl, it->ttl());
189       it = records->erase(it);
190     }
191 
192     // Remove any types associated with a known record type.
193     for (; it != records->end() && it->name() == nsec->name(); it++) {
194       OSP_DCHECK(it->dns_type() != DnsType::kNSEC);
195       has_changed |= types.Remove(it->dns_type());
196     }
197 
198     // Modify the stored NSEC record, if needed.
199     if (has_changed && types.IsEmpty()) {
200       nsecs_to_delete.push_back(nsec);
201     } else if (has_changed) {
202       NsecRecordRdata new_rdata(nsec_rdata.next_domain_name(),
203                                 types.GetTypes());
204       *nsec = MdnsRecord(nsec->name(), nsec->dns_type(), nsec->dns_class(),
205                          nsec->record_type(), new_ttl, std::move(new_rdata));
206     }
207   }
208 
209   // Erase invalid NSEC records. Go backwards to avoid invalidating the
210   // remaining iterators.
211   for (auto erase_it = nsecs_to_delete.rbegin();
212        erase_it != nsecs_to_delete.rend(); erase_it++) {
213     records->erase(*erase_it);
214   }
215 }
216 
217 }  // namespace
218 
RecordTrackerLruCache(MdnsQuerier * querier,MdnsSender * sender,MdnsRandom * random_delay,TaskRunner * task_runner,ClockNowFunctionPtr now_function,ReportingClient * reporting_client,const Config & config)219 MdnsQuerier::RecordTrackerLruCache::RecordTrackerLruCache(
220     MdnsQuerier* querier,
221     MdnsSender* sender,
222     MdnsRandom* random_delay,
223     TaskRunner* task_runner,
224     ClockNowFunctionPtr now_function,
225     ReportingClient* reporting_client,
226     const Config& config)
227     : querier_(querier),
228       sender_(sender),
229       random_delay_(random_delay),
230       task_runner_(task_runner),
231       now_function_(now_function),
232       reporting_client_(reporting_client),
233       config_(config) {
234   OSP_DCHECK(sender_);
235   OSP_DCHECK(random_delay_);
236   OSP_DCHECK(task_runner_);
237   OSP_DCHECK(reporting_client_);
238   OSP_DCHECK_GT(config_.querier_max_records_cached, 0);
239 }
240 
241 std::vector<std::reference_wrapper<const MdnsRecordTracker>>
Find(const DomainName & name)242 MdnsQuerier::RecordTrackerLruCache::Find(const DomainName& name) {
243   return Find(name, DnsType::kANY, DnsClass::kANY);
244 }
245 
246 std::vector<std::reference_wrapper<const MdnsRecordTracker>>
Find(const DomainName & name,DnsType dns_type,DnsClass dns_class)247 MdnsQuerier::RecordTrackerLruCache::Find(const DomainName& name,
248                                          DnsType dns_type,
249                                          DnsClass dns_class) {
250   std::vector<RecordTrackerConstRef> results;
251   auto pair = records_.equal_range(name);
252   for (auto it = pair.first; it != pair.second; it++) {
253     const MdnsRecordTracker& tracker = *it->second;
254     if ((dns_type == DnsType::kANY || dns_type == tracker.dns_type()) &&
255         (dns_class == DnsClass::kANY || dns_class == tracker.dns_class())) {
256       results.push_back(std::cref(tracker));
257     }
258   }
259 
260   return results;
261 }
262 
Erase(const DomainName & domain,TrackerApplicableCheck check)263 int MdnsQuerier::RecordTrackerLruCache::Erase(const DomainName& domain,
264                                               TrackerApplicableCheck check) {
265   auto pair = records_.equal_range(domain);
266   int count = 0;
267   for (RecordMap::iterator it = pair.first; it != pair.second;) {
268     if (check(*it->second)) {
269       lru_order_.erase(it->second);
270       it = records_.erase(it);
271       count++;
272     } else {
273       it++;
274     }
275   }
276 
277   return count;
278 }
279 
ExpireSoon(const DomainName & domain,TrackerApplicableCheck check)280 int MdnsQuerier::RecordTrackerLruCache::ExpireSoon(
281     const DomainName& domain,
282     TrackerApplicableCheck check) {
283   auto pair = records_.equal_range(domain);
284   int count = 0;
285   for (RecordMap::iterator it = pair.first; it != pair.second; it++) {
286     if (check(*it->second)) {
287       MoveToEnd(it);
288       it->second->ExpireSoon();
289       count++;
290     }
291   }
292 
293   return count;
294 }
295 
Update(const MdnsRecord & record,TrackerApplicableCheck check)296 int MdnsQuerier::RecordTrackerLruCache::Update(const MdnsRecord& record,
297                                                TrackerApplicableCheck check) {
298   return Update(record, check, [](const MdnsRecordTracker& t) {});
299 }
300 
Update(const MdnsRecord & record,TrackerApplicableCheck check,TrackerChangeCallback on_rdata_update)301 int MdnsQuerier::RecordTrackerLruCache::Update(
302     const MdnsRecord& record,
303     TrackerApplicableCheck check,
304     TrackerChangeCallback on_rdata_update) {
305   auto pair = records_.equal_range(record.name());
306   int count = 0;
307   for (RecordMap::iterator it = pair.first; it != pair.second; it++) {
308     if (check(*it->second)) {
309       auto result = it->second->Update(record);
310 
311       if (result.is_error()) {
312         reporting_client_->OnRecoverableError(
313             Error(Error::Code::kUpdateReceivedRecordFailure,
314                   result.error().ToString()));
315         continue;
316       }
317 
318       count++;
319       if (result.value() == MdnsRecordTracker::UpdateType::kGoodbye) {
320         it->second->ExpireSoon();
321         MoveToEnd(it);
322       } else {
323         MoveToBeginning(it);
324         if (result.value() == MdnsRecordTracker::UpdateType::kRdata) {
325           on_rdata_update(*it->second);
326         }
327       }
328     }
329   }
330 
331   return count;
332 }
333 
StartTracking(MdnsRecord record,DnsType dns_type)334 const MdnsRecordTracker& MdnsQuerier::RecordTrackerLruCache::StartTracking(
335     MdnsRecord record,
336     DnsType dns_type) {
337   auto expiration_callback = [this](const MdnsRecordTracker* tracker,
338                                     const MdnsRecord& record) {
339     querier_->OnRecordExpired(tracker, record);
340   };
341 
342   while (lru_order_.size() >=
343          static_cast<size_t>(config_.querier_max_records_cached)) {
344     // This call erases one of the tracked records.
345     OSP_DVLOG << "Maximum cacheable record count exceeded ("
346               << config_.querier_max_records_cached << ")";
347     lru_order_.back().ExpireNow();
348   }
349 
350   auto name = record.name();
351   lru_order_.emplace_front(std::move(record), dns_type, sender_, task_runner_,
352                            now_function_, random_delay_,
353                            std::move(expiration_callback));
354   records_.emplace(std::move(name), lru_order_.begin());
355 
356   return lru_order_.front();
357 }
358 
MoveToBeginning(MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it)359 void MdnsQuerier::RecordTrackerLruCache::MoveToBeginning(
360     MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it) {
361   lru_order_.splice(lru_order_.begin(), lru_order_, it->second);
362   it->second = lru_order_.begin();
363 }
364 
MoveToEnd(MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it)365 void MdnsQuerier::RecordTrackerLruCache::MoveToEnd(
366     MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it) {
367   lru_order_.splice(lru_order_.end(), lru_order_, it->second);
368   it->second = --lru_order_.end();
369 }
370 
MdnsQuerier(MdnsSender * sender,MdnsReceiver * receiver,TaskRunner * task_runner,ClockNowFunctionPtr now_function,MdnsRandom * random_delay,ReportingClient * reporting_client,Config config)371 MdnsQuerier::MdnsQuerier(MdnsSender* sender,
372                          MdnsReceiver* receiver,
373                          TaskRunner* task_runner,
374                          ClockNowFunctionPtr now_function,
375                          MdnsRandom* random_delay,
376                          ReportingClient* reporting_client,
377                          Config config)
378     : sender_(sender),
379       receiver_(receiver),
380       task_runner_(task_runner),
381       now_function_(now_function),
382       random_delay_(random_delay),
383       reporting_client_(reporting_client),
384       config_(std::move(config)),
385       records_(this,
386                sender_,
387                random_delay_,
388                task_runner_,
389                now_function_,
390                reporting_client_,
391                config_) {
392   OSP_DCHECK(sender_);
393   OSP_DCHECK(receiver_);
394   OSP_DCHECK(task_runner_);
395   OSP_DCHECK(now_function_);
396   OSP_DCHECK(random_delay_);
397   OSP_DCHECK(reporting_client_);
398 
399   receiver_->AddResponseCallback(this);
400 }
401 
~MdnsQuerier()402 MdnsQuerier::~MdnsQuerier() {
403   receiver_->RemoveResponseCallback(this);
404 }
405 
// NOTE: The code below uses explicit loops instead of std::find_if for better
// readability, brevity and homogeneity. Using std::find_if results in a few
// more lines of code, and readability suffers from the extra lambdas.
409 
StartQuery(const DomainName & name,DnsType dns_type,DnsClass dns_class,MdnsRecordChangedCallback * callback)410 void MdnsQuerier::StartQuery(const DomainName& name,
411                              DnsType dns_type,
412                              DnsClass dns_class,
413                              MdnsRecordChangedCallback* callback) {
414   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
415   OSP_DCHECK(callback);
416   OSP_DCHECK(CanBeQueried(dns_type));
417 
418   // Add a new callback if haven't seen it before
419   auto callbacks_it = callbacks_.equal_range(name);
420   for (auto entry = callbacks_it.first; entry != callbacks_it.second; ++entry) {
421     const CallbackInfo& callback_info = entry->second;
422     if (dns_type == callback_info.dns_type &&
423         dns_class == callback_info.dns_class &&
424         callback == callback_info.callback) {
425       // Already have this callback
426       return;
427     }
428   }
429   callbacks_.emplace(name, CallbackInfo{callback, dns_type, dns_class});
430 
431   // Notify the new callback with previously cached records.
432   // NOTE: In the future, could allow callers to fetch cached records after
433   // adding a callback, for example to prime the UI.
434   std::vector<PendingQueryChange> pending_changes;
435   const std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
436       records_.Find(name, dns_type, dns_class);
437   for (const MdnsRecordTracker& tracker : trackers) {
438     if (!tracker.is_negative_response()) {
439       MdnsRecord stored_record(name, tracker.dns_type(), tracker.dns_class(),
440                                tracker.record_type(), tracker.ttl(),
441                                tracker.rdata());
442       std::vector<PendingQueryChange> new_changes = callback->OnRecordChanged(
443           std::move(stored_record), RecordChangedEvent::kCreated);
444       pending_changes.insert(pending_changes.end(), new_changes.begin(),
445                              new_changes.end());
446     }
447   }
448 
449   // Add a new question if haven't seen it before
450   auto questions_it = questions_.equal_range(name);
451   const bool is_question_already_tracked =
452       std::find_if(questions_it.first, questions_it.second,
453                    [dns_type, dns_class](const auto& entry) {
454                      const MdnsQuestion& tracked_question =
455                          entry.second->question();
456                      return dns_type == tracked_question.dns_type() &&
457                             dns_class == tracked_question.dns_class();
458                    }) != questions_it.second;
459   if (!is_question_already_tracked) {
460     AddQuestion(
461         MdnsQuestion(name, dns_type, dns_class, ResponseType::kMulticast));
462   }
463 
464   // Apply any pending changes from the OnRecordChanged() callbacks.
465   ApplyPendingChanges(std::move(pending_changes));
466 }
467 
StopQuery(const DomainName & name,DnsType dns_type,DnsClass dns_class,MdnsRecordChangedCallback * callback)468 void MdnsQuerier::StopQuery(const DomainName& name,
469                             DnsType dns_type,
470                             DnsClass dns_class,
471                             MdnsRecordChangedCallback* callback) {
472   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
473   OSP_DCHECK(callback);
474 
475   if (!CanBeQueried(dns_type)) {
476     return;
477   }
478 
479   // Find and remove the callback.
480   int callbacks_for_key = 0;
481   auto callbacks_it = callbacks_.equal_range(name);
482   for (auto entry = callbacks_it.first; entry != callbacks_it.second;) {
483     const CallbackInfo& callback_info = entry->second;
484     if (dns_type == callback_info.dns_type &&
485         dns_class == callback_info.dns_class) {
486       if (callback == callback_info.callback) {
487         entry = callbacks_.erase(entry);
488       } else {
489         ++callbacks_for_key;
490         ++entry;
491       }
492     }
493   }
494 
495   // Exit if there are still callbacks registered for DomainName + DnsType +
496   // DnsClass
497   if (callbacks_for_key > 0) {
498     return;
499   }
500 
501   // Find and delete a question that does not have any associated callbacks
502   auto questions_it = questions_.equal_range(name);
503   for (auto entry = questions_it.first; entry != questions_it.second; ++entry) {
504     const MdnsQuestion& tracked_question = entry->second->question();
505     if (dns_type == tracked_question.dns_type() &&
506         dns_class == tracked_question.dns_class()) {
507       questions_.erase(entry);
508       return;
509     }
510   }
511 }
512 
ReinitializeQueries(const DomainName & name)513 void MdnsQuerier::ReinitializeQueries(const DomainName& name) {
514   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
515 
516   // Get the ongoing queries and their callbacks.
517   std::vector<CallbackInfo> callbacks;
518   auto its = callbacks_.equal_range(name);
519   for (auto it = its.first; it != its.second; it++) {
520     callbacks.push_back(std::move(it->second));
521   }
522   callbacks_.erase(name);
523 
524   // Remove all known questions and answers.
525   questions_.erase(name);
526   records_.Erase(name, [](const MdnsRecordTracker& tracker) { return true; });
527 
528   // Restart the queries.
529   for (const auto& cb : callbacks) {
530     StartQuery(name, cb.dns_type, cb.dns_class, cb.callback);
531   }
532 }
533 
OnMessageReceived(const MdnsMessage & message)534 void MdnsQuerier::OnMessageReceived(const MdnsMessage& message) {
535   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
536   OSP_DCHECK(message.type() == MessageType::Response);
537 
538   OSP_DVLOG << "Received mDNS Response message with "
539             << message.answers().size() << " answers and "
540             << message.additional_records().size()
541             << " additional records. Processing...";
542 
543   std::vector<MdnsRecord> records_to_process;
544 
545   // Add any records that are relevant for this querier.
546   bool found_relevant_records = false;
547   for (const MdnsRecord& record : message.answers()) {
548     if (ShouldAnswerRecordBeProcessed(record)) {
549       records_to_process.push_back(record);
550       found_relevant_records = true;
551     }
552   }
553 
554   // If any of the message's answers are relevant, add all additional records.
555   // Else, since the message has already been received and parsed, use any
556   // individual records relevant to this querier to update the cache.
557   for (const MdnsRecord& record : message.additional_records()) {
558     if (found_relevant_records || ShouldAnswerRecordBeProcessed(record)) {
559       records_to_process.push_back(record);
560     }
561   }
562 
563   // Drop NSEC records associated with a non-NSEC record of the same type.
564   RemoveInvalidNsecFlags(&records_to_process);
565 
566   // Process all remaining records.
567   for (const MdnsRecord& record_to_process : records_to_process) {
568     ProcessRecord(record_to_process);
569   }
570 
571   OSP_DVLOG << "\tmDNS Response processed (" << records_to_process.size()
572             << " records accepted)!";
573 
574   // TODO(crbug.com/openscreen/83): Check authority records.
575 }
576 
ShouldAnswerRecordBeProcessed(const MdnsRecord & answer)577 bool MdnsQuerier::ShouldAnswerRecordBeProcessed(const MdnsRecord& answer) {
578   // First, accept the record if it's associated with an ongoing question.
579   const auto questions_range = questions_.equal_range(answer.name());
580   const auto it = std::find_if(
581       questions_range.first, questions_range.second,
582       [&answer](const auto& pair) {
583         return (pair.second->question().dns_type() == DnsType::kANY ||
584                 IsNegativeResponseFor(answer,
585                                       pair.second->question().dns_type()) ||
586                 pair.second->question().dns_type() == answer.dns_type()) &&
587                (pair.second->question().dns_class() == DnsClass::kANY ||
588                 pair.second->question().dns_class() == answer.dns_class());
589       });
590   if (it != questions_range.second) {
591     return true;
592   }
593 
594   // If not, check if it corresponds to an already existing record. This is
595   // required because records which are already stored may either have been
596   // received in an additional records section, or are associated with a query
597   // which is no longer active.
598   std::vector<DnsType> types{answer.dns_type()};
599   if (answer.dns_type() == DnsType::kNSEC) {
600     const auto& nsec_rdata = absl::get<NsecRecordRdata>(answer.rdata());
601     types = nsec_rdata.types();
602   }
603 
604   for (DnsType type : types) {
605     std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
606         records_.Find(answer.name(), type, answer.dns_class());
607     if (!trackers.empty()) {
608       return true;
609     }
610   }
611 
612   // In all other cases, the record isn't relevant. Drop it.
613   return false;
614 }
615 
OnRecordExpired(const MdnsRecordTracker * tracker,const MdnsRecord & record)616 void MdnsQuerier::OnRecordExpired(const MdnsRecordTracker* tracker,
617                                   const MdnsRecord& record) {
618   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
619 
620   if (!tracker->is_negative_response()) {
621     ProcessCallbacks(record, RecordChangedEvent::kExpired);
622   }
623 
624   records_.Erase(record.name(), [tracker](const MdnsRecordTracker& it_tracker) {
625     return tracker == &it_tracker;
626   });
627 }
628 
ProcessRecord(const MdnsRecord & record)629 void MdnsQuerier::ProcessRecord(const MdnsRecord& record) {
630   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
631 
632   // Skip all records that can't be processed.
633   if (!CanBeProcessed(record.dns_type())) {
634     return;
635   }
636 
637   // Ignore NSEC records if the embedder has configured us to do so.
638   if (config_.ignore_nsec_responses && record.dns_type() == DnsType::kNSEC) {
639     return;
640   }
641 
642   // Get the types which the received record is associated with. In most cases
643   // this will only be the type of the provided record, but in the case of
644   // NSEC records this will be all records which the record dictates the
645   // nonexistence of.
646   std::vector<DnsType> types;
647   int types_count = 0;
648   const DnsType* types_ptr = nullptr;
649   if (record.dns_type() == DnsType::kNSEC) {
650     const auto& nsec_rdata = absl::get<NsecRecordRdata>(record.rdata());
651     if (std::find(nsec_rdata.types().begin(), nsec_rdata.types().end(),
652                   DnsType::kANY) != nsec_rdata.types().end()) {
653       types_ptr = kTranslatedNsecAnyQueryTypes.data();
654       types_count = kTranslatedNsecAnyQueryTypes.size();
655     } else {
656       types_ptr = nsec_rdata.types().data();
657       types_count = nsec_rdata.types().size();
658     }
659   } else {
660     types.push_back(record.dns_type());
661     types_ptr = types.data();
662     types_count = types.size();
663   }
664 
665   // Apply the update for each type that the record is associated with.
666   for (int i = 0; i < types_count; ++i) {
667     DnsType dns_type = types_ptr[i];
668     switch (record.record_type()) {
669       case RecordType::kShared: {
670         ProcessSharedRecord(record, dns_type);
671         break;
672       }
673       case RecordType::kUnique: {
674         ProcessUniqueRecord(record, dns_type);
675         break;
676       }
677     }
678   }
679 }
680 
ProcessSharedRecord(const MdnsRecord & record,DnsType dns_type)681 void MdnsQuerier::ProcessSharedRecord(const MdnsRecord& record,
682                                       DnsType dns_type) {
683   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
684   OSP_DCHECK(record.record_type() == RecordType::kShared);
685 
686   // By design, NSEC records are never shared records.
687   if (record.dns_type() == DnsType::kNSEC) {
688     return;
689   }
690 
691   // For any records updated, this host already has this shared record. Since
692   // the RDATA matches, this is only a TTL update.
693   auto check = [&record](const MdnsRecordTracker& tracker) {
694     return record.dns_type() == tracker.dns_type() &&
695            record.dns_class() == tracker.dns_class() &&
696            record.rdata() == tracker.rdata();
697   };
698   auto updated_count = records_.Update(record, std::move(check));
699 
700   if (!updated_count) {
701     // Have never before seen this shared record, insert a new one.
702     AddRecord(record, dns_type);
703     ProcessCallbacks(record, RecordChangedEvent::kCreated);
704   }
705 }
706 
ProcessUniqueRecord(const MdnsRecord & record,DnsType dns_type)707 void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record,
708                                       DnsType dns_type) {
709   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
710   OSP_DCHECK(record.record_type() == RecordType::kUnique);
711 
712   std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
713       records_.Find(record.name(), dns_type, record.dns_class());
714   size_t num_records_for_key = trackers.size();
715 
716   // Have not seen any records with this key before. This case is expected the
717   // first time a record is received.
718   if (num_records_for_key == size_t{0}) {
719     const bool will_exist = record.dns_type() != DnsType::kNSEC;
720     AddRecord(record, dns_type);
721     if (will_exist) {
722       ProcessCallbacks(record, RecordChangedEvent::kCreated);
723     }
724   } else if (num_records_for_key == size_t{1}) {
725     // There is exactly one tracker associated with this key. This is the
726     // expected case when a record matching this one has already been seen.
727     ProcessSinglyTrackedUniqueRecord(record, trackers[0]);
728   } else {
729     // Multiple records with the same key.
730     ProcessMultiTrackedUniqueRecord(record, dns_type);
731   }
732 }
733 
ProcessSinglyTrackedUniqueRecord(const MdnsRecord & record,const MdnsRecordTracker & tracker)734 void MdnsQuerier::ProcessSinglyTrackedUniqueRecord(
735     const MdnsRecord& record,
736     const MdnsRecordTracker& tracker) {
737   const bool existed_previously = !tracker.is_negative_response();
738   const bool will_exist = record.dns_type() != DnsType::kNSEC;
739 
740   // Calculate the callback to call on record update success while the old
741   // record still exists.
742   MdnsRecord record_for_callback = record;
743   if (existed_previously && !will_exist) {
744     record_for_callback =
745         MdnsRecord(record.name(), tracker.dns_type(), tracker.dns_class(),
746                    tracker.record_type(), tracker.ttl(), tracker.rdata());
747   }
748 
749   auto on_rdata_change = [this, r = std::move(record_for_callback),
750                           existed_previously,
751                           will_exist](const MdnsRecordTracker& tracker) {
752     // If RDATA on the record is different, notify that the record has
753     // been updated.
754     if (existed_previously && will_exist) {
755       ProcessCallbacks(r, RecordChangedEvent::kUpdated);
756     } else if (existed_previously) {
757       // Do not expire the tracker, because it still holds an NSEC record.
758       ProcessCallbacks(r, RecordChangedEvent::kExpired);
759     } else if (will_exist) {
760       ProcessCallbacks(r, RecordChangedEvent::kCreated);
761     }
762   };
763 
764   int updated_count = records_.Update(
765       record, [&tracker](const MdnsRecordTracker& t) { return &tracker == &t; },
766       std::move(on_rdata_change));
767   OSP_DCHECK_EQ(updated_count, 1);
768 }
769 
ProcessMultiTrackedUniqueRecord(const MdnsRecord & record,DnsType dns_type)770 void MdnsQuerier::ProcessMultiTrackedUniqueRecord(const MdnsRecord& record,
771                                                   DnsType dns_type) {
772   auto update_check = [&record, dns_type](const MdnsRecordTracker& tracker) {
773     return tracker.dns_type() == dns_type &&
774            tracker.dns_class() == record.dns_class() &&
775            tracker.rdata() == record.rdata();
776   };
777   int update_count = records_.Update(
778       record, std::move(update_check),
779       [](const MdnsRecordTracker& tracker) { OSP_NOTREACHED(); });
780   OSP_DCHECK_LE(update_count, 1);
781 
782   auto expire_check = [&record, dns_type](const MdnsRecordTracker& tracker) {
783     return tracker.dns_type() == dns_type &&
784            tracker.dns_class() == record.dns_class() &&
785            tracker.rdata() != record.rdata();
786   };
787   int expire_count =
788       records_.ExpireSoon(record.name(), std::move(expire_check));
789   OSP_DCHECK_GE(expire_count, 1);
790 
791   // Did not find an existing record to update.
792   if (!update_count && !expire_count) {
793     AddRecord(record, dns_type);
794     if (record.dns_type() != DnsType::kNSEC) {
795       ProcessCallbacks(record, RecordChangedEvent::kCreated);
796     }
797   }
798 }
799 
ProcessCallbacks(const MdnsRecord & record,RecordChangedEvent event)800 void MdnsQuerier::ProcessCallbacks(const MdnsRecord& record,
801                                    RecordChangedEvent event) {
802   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
803 
804   std::vector<PendingQueryChange> pending_changes;
805   auto callbacks_it = callbacks_.equal_range(record.name());
806   for (auto entry = callbacks_it.first; entry != callbacks_it.second; ++entry) {
807     const CallbackInfo& callback_info = entry->second;
808     if ((callback_info.dns_type == DnsType::kANY ||
809          record.dns_type() == callback_info.dns_type) &&
810         (callback_info.dns_class == DnsClass::kANY ||
811          record.dns_class() == callback_info.dns_class)) {
812       std::vector<PendingQueryChange> new_changes =
813           callback_info.callback->OnRecordChanged(record, event);
814       pending_changes.insert(pending_changes.end(), new_changes.begin(),
815                              new_changes.end());
816     }
817   }
818 
819   ApplyPendingChanges(std::move(pending_changes));
820 }
821 
AddQuestion(const MdnsQuestion & question)822 void MdnsQuerier::AddQuestion(const MdnsQuestion& question) {
823   auto tracker = std::make_unique<MdnsQuestionTracker>(
824       question, sender_, task_runner_, now_function_, random_delay_, config_);
825   MdnsQuestionTracker* ptr = tracker.get();
826   questions_.emplace(question.name(), std::move(tracker));
827 
828   // Let all records associated with this question know that there is a new
829   // query that can be used for their refresh.
830   std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
831       records_.Find(question.name(), question.dns_type(), question.dns_class());
832   for (const MdnsRecordTracker& tracker : trackers) {
833     // NOTE: When the pointed to object is deleted, its dtor removes itself
834     // from all associated records.
835     ptr->AddAssociatedRecord(&tracker);
836   }
837 }
838 
AddRecord(const MdnsRecord & record,DnsType type)839 void MdnsQuerier::AddRecord(const MdnsRecord& record, DnsType type) {
840   // Add the new record.
841   const auto& tracker = records_.StartTracking(record, type);
842 
843   // Let all questions associated with this record know that there is a new
844   // record that answers them (for known answer suppression).
845   auto query_it = questions_.equal_range(record.name());
846   for (auto entry = query_it.first; entry != query_it.second; ++entry) {
847     const MdnsQuestion& query = entry->second->question();
848     const bool is_relevant_type =
849         type == DnsType::kANY || type == query.dns_type();
850     const bool is_relevant_class = record.dns_class() == DnsClass::kANY ||
851                                    record.dns_class() == query.dns_class();
852     if (is_relevant_type && is_relevant_class) {
853       // NOTE: When the pointed to object is deleted, its dtor removes itself
854       // from all associated queries.
855       entry->second->AddAssociatedRecord(&tracker);
856     }
857   }
858 }
859 
ApplyPendingChanges(std::vector<PendingQueryChange> pending_changes)860 void MdnsQuerier::ApplyPendingChanges(
861     std::vector<PendingQueryChange> pending_changes) {
862   for (auto& pending_change : pending_changes) {
863     switch (pending_change.change_type) {
864       case PendingQueryChange::kStartQuery:
865         StartQuery(std::move(pending_change.name), pending_change.dns_type,
866                    pending_change.dns_class, pending_change.callback);
867         break;
868       case PendingQueryChange::kStopQuery:
869         StopQuery(std::move(pending_change.name), pending_change.dns_type,
870                   pending_change.dns_class, pending_change.callback);
871         break;
872     }
873   }
874 }
875 
876 }  // namespace discovery
877 }  // namespace openscreen
878