1 // Copyright 2019 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #ifndef DISCOVERY_MDNS_MDNS_QUERIER_H_ 6 #define DISCOVERY_MDNS_MDNS_QUERIER_H_ 7 8 #include <list> 9 #include <map> 10 #include <memory> 11 #include <vector> 12 13 #include "discovery/common/config.h" 14 #include "discovery/mdns/mdns_receiver.h" 15 #include "discovery/mdns/mdns_record_changed_callback.h" 16 #include "discovery/mdns/mdns_records.h" 17 #include "discovery/mdns/mdns_trackers.h" 18 #include "platform/api/task_runner.h" 19 20 namespace openscreen { 21 namespace discovery { 22 23 class MdnsRandom; 24 class MdnsSender; 25 class MdnsQuestionTracker; 26 class MdnsRecordTracker; 27 class ReportingClient; 28 29 class MdnsQuerier : public MdnsReceiver::ResponseClient { 30 public: 31 MdnsQuerier(MdnsSender* sender, 32 MdnsReceiver* receiver, 33 TaskRunner* task_runner, 34 ClockNowFunctionPtr now_function, 35 MdnsRandom* random_delay, 36 ReportingClient* reporting_client, 37 Config config); 38 MdnsQuerier(const MdnsQuerier& other) = delete; 39 MdnsQuerier(MdnsQuerier&& other) noexcept = delete; 40 MdnsQuerier& operator=(const MdnsQuerier& other) = delete; 41 MdnsQuerier& operator=(MdnsQuerier&& other) noexcept = delete; 42 ~MdnsQuerier() override; 43 44 // Starts an mDNS query with the given name, DNS type, and DNS class. Updated 45 // records are passed to |callback|. The caller must ensure |callback| 46 // remains alive while it is registered with a query. 47 // NOTE: This call is only valid for |dns_type| values: 48 // - DnsType::kA 49 // - DnsType::kPTR 50 // - DnsType::kTXT 51 // - DnsType::kAAAA 52 // - DnsType::kSRV 53 // - DnsType::kANY 54 void StartQuery(const DomainName& name, 55 DnsType dns_type, 56 DnsClass dns_class, 57 MdnsRecordChangedCallback* callback); 58 59 // Stops an mDNS query with the given name, DNS type, and DNS class. 60 // |callback| must be the same callback pointer that was previously passed to 61 // StartQuery. 62 void StopQuery(const DomainName& name, 63 DnsType dns_type, 64 DnsClass dns_class, 65 MdnsRecordChangedCallback* callback); 66 67 // Re-initializes the process of service discovery for the provided domain 68 // name. All ongoing queries for this domain are restarted and any previously 69 // received query results are discarded. 70 void ReinitializeQueries(const DomainName& name); 71 72 private: 73 struct CallbackInfo { 74 MdnsRecordChangedCallback* const callback; 75 const DnsType dns_type; 76 const DnsClass dns_class; 77 }; 78 79 // Represents a Least Recently Used cache of MdnsRecordTrackers. 80 class RecordTrackerLruCache { 81 public: 82 using RecordTrackerConstRef = 83 std::reference_wrapper<const MdnsRecordTracker>; 84 using TrackerApplicableCheck = 85 std::function<bool(const MdnsRecordTracker&)>; 86 using TrackerChangeCallback = std::function<void(const MdnsRecordTracker&)>; 87 88 RecordTrackerLruCache(MdnsQuerier* querier, 89 MdnsSender* sender, 90 MdnsRandom* random_delay, 91 TaskRunner* task_runner, 92 ClockNowFunctionPtr now_function, 93 ReportingClient* reporting_client, 94 const Config& config); 95 96 // Returns all trackers with the associated |name| such that its type 97 // represents a type corresponding to |dns_type| and class corresponding to 98 // |dns_class|. 99 std::vector<RecordTrackerConstRef> Find(const DomainName& name); 100 std::vector<RecordTrackerConstRef> Find(const DomainName& name, 101 DnsType dns_type, 102 DnsClass dns_class); 103 104 // Calls ExpireSoon on all record trackers in the provided domain which 105 // match the provided applicability check. Returns the number of trackers 106 // marked for expiry. 107 int ExpireSoon(const DomainName& name, TrackerApplicableCheck check); 108 109 // Erases all record trackers in the provided domain which match the 110 // provided applicability check. Returns the number of trackers erased. 111 int Erase(const DomainName& name, TrackerApplicableCheck check); 112 113 // Updates all record trackers in the domain |record.name()| which match the 114 // provided applicability check using the provided record. Returns the 115 // number of records successfully updated. 116 int Update(const MdnsRecord& record, TrackerApplicableCheck check); 117 int Update(const MdnsRecord& record, 118 TrackerApplicableCheck check, 119 TrackerChangeCallback on_rdata_update); 120 121 // Creates a record tracker of the given type associated with the provided 122 // record. 123 const MdnsRecordTracker& StartTracking(MdnsRecord record, DnsType type); 124 size()125 size_t size() { return records_.size(); } 126 127 private: 128 using LruList = std::list<MdnsRecordTracker>; 129 using RecordMap = std::multimap<DomainName, LruList::iterator>; 130 131 void MoveToBeginning(RecordMap::iterator iterator); 132 void MoveToEnd(RecordMap::iterator iterator); 133 134 MdnsQuerier* const querier_; 135 MdnsSender* const sender_; 136 MdnsRandom* const random_delay_; 137 TaskRunner* const task_runner_; 138 ClockNowFunctionPtr now_function_; 139 ReportingClient* reporting_client_; 140 const Config& config_; 141 142 // List of RecordTracker instances used by this instance where the least 143 // recently updated element (or next to be deleted element) appears at the 144 // end of the list. 145 LruList lru_order_; 146 147 // A collection of active known record trackers, each is identified by 148 // domain name, DNS record type, and DNS record class. Multimap key is 149 // domain name only to allow easy support for wildcard processing for DNS 150 // record type and class and allow storing shared records that differ only 151 // in RDATA. 152 // 153 // MdnsRecordTracker instances are stored as unique_ptr so they are not 154 // moved around in memory when the collection is modified. This allows 155 // passing a pointer to MdnsQuestionTracker to a task running on the 156 // TaskRunner. 157 RecordMap records_; 158 }; 159 160 friend class MdnsQuerierTest; 161 162 // MdnsReceiver::ResponseClient overrides. 163 void OnMessageReceived(const MdnsMessage& message) override; 164 165 // Expires the record tracker provided. This callback is passed to owned 166 // MdnsRecordTracker instances in |records_|. 167 void OnRecordExpired(const MdnsRecordTracker* tracker, 168 const MdnsRecord& record); 169 170 // Determines whether a record received by this querier should be processed 171 // or dropped. 172 bool ShouldAnswerRecordBeProcessed(const MdnsRecord& answer); 173 174 // Processes any record update, calling into the below methods as needed. 175 // NOTE: All records of type OPT are dropped, as they should not be cached per 176 // RFC6891. 177 void ProcessRecord(const MdnsRecord& records); 178 179 // Processes a shared record update as a record of type |type|. 180 void ProcessSharedRecord(const MdnsRecord& record, DnsType type); 181 182 // Processes a unique record update as a record of type |type|. 183 void ProcessUniqueRecord(const MdnsRecord& record, DnsType type); 184 185 // Called when exactly one tracker is associated with a provided key. 186 // Determines the type of update being executed by this update call, then 187 // fires the appropriate callback. 188 void ProcessSinglyTrackedUniqueRecord(const MdnsRecord& record, 189 const MdnsRecordTracker& tracker); 190 191 // Called when multiple records are associated with the same key. Expire all 192 // record with non-matching RDATA. Update the record with the matching RDATA 193 // if it exists, otherwise insert a new record. 194 void ProcessMultiTrackedUniqueRecord(const MdnsRecord& record, 195 DnsType dns_type); 196 197 // Calls all callbacks associated with the provided record. 198 void ProcessCallbacks(const MdnsRecord& record, RecordChangedEvent event); 199 200 // Begins tracking the provided question. 201 void AddQuestion(const MdnsQuestion& question); 202 203 // Begins tracking the provided record. 204 void AddRecord(const MdnsRecord& record, DnsType type); 205 206 // Applies the supplied pending changes. 207 void ApplyPendingChanges(std::vector<PendingQueryChange> pending_changes); 208 209 MdnsSender* const sender_; 210 MdnsReceiver* const receiver_; 211 TaskRunner* const task_runner_; 212 const ClockNowFunctionPtr now_function_; 213 MdnsRandom* const random_delay_; 214 ReportingClient* reporting_client_; 215 Config config_; 216 217 // A collection of active question trackers, each is uniquely identified by 218 // domain name, DNS record type, and DNS record class. Multimap key is domain 219 // name only to allow easy support for wildcard processing for DNS record type 220 // and class. MdnsQuestionTracker instances are stored as unique_ptr so they 221 // are not moved around in memory when the collection is modified. This allows 222 // passing a pointer to MdnsQuestionTracker to a task running on the 223 // TaskRunner. 224 std::multimap<DomainName, std::unique_ptr<MdnsQuestionTracker>> questions_; 225 226 // Set of records tracked by this querier. 227 RecordTrackerLruCache records_; 228 229 // A collection of callbacks passed to StartQuery method. Each is identified 230 // by domain name, DNS record type, and DNS record class, but there can be 231 // more than one callback for a particular query. Multimap key is domain name 232 // only to allow easy matching of records against callbacks that have wildcard 233 // DNS class and/or DNS type. 234 std::multimap<DomainName, CallbackInfo> callbacks_; 235 }; 236 237 } // namespace discovery 238 } // namespace openscreen 239 240 #endif // DISCOVERY_MDNS_MDNS_QUERIER_H_ 241