1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "discovery/mdns/mdns_trackers.h"
6
7 #include <array>
8 #include <limits>
9 #include <utility>
10
11 #include "discovery/common/config.h"
12 #include "discovery/mdns/mdns_random.h"
13 #include "discovery/mdns/mdns_record_changed_callback.h"
14 #include "discovery/mdns/mdns_sender.h"
15 #include "util/std_util.h"
16
17 namespace openscreen {
18 namespace discovery {
19
20 namespace {
21
// Query and refresh timing constants, per RFC 6762 Section 5.2.
// https://tools.ietf.org/html/rfc6762#section-5.2

// Fractions of a record's TTL at which the record is revisited: refresh
// attempts at 80%, 85%, 90% and 95% of TTL, followed by expiration at 100%.
constexpr double kTtlFractions[] = {0.8, 0.85, 0.9, 0.95, 1.0};

// Intervals between successive queries must increase by at least a factor
// of 2.
constexpr int kIntervalIncreaseFactor{2};

// The interval between the first two queries must be at least one second.
constexpr auto kMinimumQueryInterval = std::chrono::seconds(1);

// The querier may cap the question refresh interval to a maximum of 60
// minutes.
constexpr auto kMaximumQueryInterval = std::chrono::minutes(60);
36
37 // RFC 6762 Section 10.1
38 // https://tools.ietf.org/html/rfc6762#section-10.1
39
40 // A goodbye record is a record with TTL of 0.
IsGoodbyeRecord(const MdnsRecord & record)41 bool IsGoodbyeRecord(const MdnsRecord& record) {
42 return record.ttl() == std::chrono::seconds(0);
43 }
44
IsNegativeResponseForType(const MdnsRecord & record,DnsType dns_type)45 bool IsNegativeResponseForType(const MdnsRecord& record, DnsType dns_type) {
46 if (record.dns_type() != DnsType::kNSEC) {
47 return false;
48 }
49
50 const auto& nsec_types = absl::get<NsecRecordRdata>(record.rdata()).types();
51 return std::find_if(nsec_types.begin(), nsec_types.end(),
52 [dns_type](DnsType type) {
53 return type == dns_type || type == DnsType::kANY;
54 }) != nsec_types.end();
55 }
56
// RFC 6762 Section 10.1
// https://tools.ietf.org/html/rfc6762#section-10.1
// TTL given to a record once it is known to be going away: the record is
// rebuilt with a TTL of 1 second instead of keeping the goodbye TTL of 0.
constexpr auto kGoodbyeRecordTtl = std::chrono::seconds(1);
61
62 } // namespace
63
MdnsTracker(MdnsSender * sender,TaskRunner * task_runner,ClockNowFunctionPtr now_function,MdnsRandom * random_delay,TrackerType tracker_type)64 MdnsTracker::MdnsTracker(MdnsSender* sender,
65 TaskRunner* task_runner,
66 ClockNowFunctionPtr now_function,
67 MdnsRandom* random_delay,
68 TrackerType tracker_type)
69 : sender_(sender),
70 task_runner_(task_runner),
71 now_function_(now_function),
72 send_alarm_(now_function, task_runner),
73 random_delay_(random_delay),
74 tracker_type_(tracker_type) {
75 OSP_DCHECK(task_runner_);
76 OSP_DCHECK(now_function_);
77 OSP_DCHECK(random_delay_);
78 OSP_DCHECK(sender_);
79 }
80
~MdnsTracker()81 MdnsTracker::~MdnsTracker() {
82 send_alarm_.Cancel();
83
84 for (const MdnsTracker* node : adjacent_nodes_) {
85 node->RemovedReverseAdjacency(this);
86 }
87 }
88
AddAdjacentNode(const MdnsTracker * node) const89 bool MdnsTracker::AddAdjacentNode(const MdnsTracker* node) const {
90 OSP_DCHECK(node);
91 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
92
93 auto it = std::find(adjacent_nodes_.begin(), adjacent_nodes_.end(), node);
94 if (it != adjacent_nodes_.end()) {
95 return false;
96 }
97
98 adjacent_nodes_.push_back(node);
99 node->AddReverseAdjacency(this);
100 return true;
101 }
102
RemoveAdjacentNode(const MdnsTracker * node) const103 bool MdnsTracker::RemoveAdjacentNode(const MdnsTracker* node) const {
104 OSP_DCHECK(node);
105 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
106
107 auto it = std::find(adjacent_nodes_.begin(), adjacent_nodes_.end(), node);
108 if (it == adjacent_nodes_.end()) {
109 return false;
110 }
111
112 adjacent_nodes_.erase(it);
113 node->RemovedReverseAdjacency(this);
114 return true;
115 }
116
AddReverseAdjacency(const MdnsTracker * node) const117 void MdnsTracker::AddReverseAdjacency(const MdnsTracker* node) const {
118 OSP_DCHECK(std::find(adjacent_nodes_.begin(), adjacent_nodes_.end(), node) ==
119 adjacent_nodes_.end());
120
121 adjacent_nodes_.push_back(node);
122 }
123
RemovedReverseAdjacency(const MdnsTracker * node) const124 void MdnsTracker::RemovedReverseAdjacency(const MdnsTracker* node) const {
125 auto it = std::find(adjacent_nodes_.begin(), adjacent_nodes_.end(), node);
126 OSP_DCHECK(it != adjacent_nodes_.end());
127
128 adjacent_nodes_.erase(it);
129 }
130
MdnsRecordTracker(MdnsRecord record,DnsType dns_type,MdnsSender * sender,TaskRunner * task_runner,ClockNowFunctionPtr now_function,MdnsRandom * random_delay,RecordExpiredCallback record_expired_callback)131 MdnsRecordTracker::MdnsRecordTracker(
132 MdnsRecord record,
133 DnsType dns_type,
134 MdnsSender* sender,
135 TaskRunner* task_runner,
136 ClockNowFunctionPtr now_function,
137 MdnsRandom* random_delay,
138 RecordExpiredCallback record_expired_callback)
139 : MdnsTracker(sender,
140 task_runner,
141 now_function,
142 random_delay,
143 TrackerType::kRecordTracker),
144 record_(std::move(record)),
145 dns_type_(dns_type),
146 start_time_(now_function_()),
147 record_expired_callback_(std::move(record_expired_callback)) {
148 OSP_DCHECK(record_expired_callback_);
149
150 // RecordTrackers cannot be created for tracking NSEC types or ANY types.
151 OSP_DCHECK(dns_type_ != DnsType::kNSEC);
152 OSP_DCHECK(dns_type_ != DnsType::kANY);
153
154 // Validate that, if the provided |record| is an NSEC record, then it provides
155 // a negative response for |dns_type|.
156 OSP_DCHECK(record_.dns_type() != DnsType::kNSEC ||
157 IsNegativeResponseForType(record_, dns_type_));
158
159 ScheduleFollowUpQuery();
160 }
161
// Defaulted: no extra teardown is performed here. The MdnsTracker base
// destructor cancels the pending alarm and detaches this tracker from its
// adjacent trackers.
162 MdnsRecordTracker::~MdnsRecordTracker() = default;
163
Update(const MdnsRecord & new_record)164 ErrorOr<MdnsRecordTracker::UpdateType> MdnsRecordTracker::Update(
165 const MdnsRecord& new_record) {
166 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
167 const bool has_same_rdata = record_.dns_type() == new_record.dns_type() &&
168 record_.rdata() == new_record.rdata();
169 const bool new_is_negative_response = new_record.dns_type() == DnsType::kNSEC;
170 const bool current_is_negative_response =
171 record_.dns_type() == DnsType::kNSEC;
172
173 if ((record_.dns_class() != new_record.dns_class()) ||
174 (record_.name() != new_record.name())) {
175 // The new record has been passed to a wrong tracker.
176 return Error::Code::kParameterInvalid;
177 }
178
179 // New response record must correspond to the correct type.
180 if ((!new_is_negative_response && new_record.dns_type() != dns_type_) ||
181 (new_is_negative_response &&
182 !IsNegativeResponseForType(new_record, dns_type_))) {
183 // The new record has been passed to a wrong tracker.
184 return Error::Code::kParameterInvalid;
185 }
186
187 // Goodbye records must have the same RDATA but TTL of 0.
188 // RFC 6762 Section 10.1.
189 // https://tools.ietf.org/html/rfc6762#section-10.1
190 if (!new_is_negative_response && !current_is_negative_response &&
191 IsGoodbyeRecord(new_record) && !has_same_rdata) {
192 // The new record has been passed to a wrong tracker.
193 return Error::Code::kParameterInvalid;
194 }
195
196 UpdateType result = UpdateType::kGoodbye;
197 if (IsGoodbyeRecord(new_record)) {
198 record_ = MdnsRecord(new_record.name(), new_record.dns_type(),
199 new_record.dns_class(), new_record.record_type(),
200 kGoodbyeRecordTtl, new_record.rdata());
201
202 // Goodbye records do not need to be re-queried, set the attempt count to
203 // the last item, which is 100% of TTL, i.e. record expiration.
204 attempt_count_ = countof(kTtlFractions) - 1;
205 } else {
206 record_ = new_record;
207 attempt_count_ = 0;
208 result = has_same_rdata ? UpdateType::kTTLOnly : UpdateType::kRdata;
209 }
210
211 start_time_ = now_function_();
212 ScheduleFollowUpQuery();
213
214 return result;
215 }
216
AddAssociatedQuery(const MdnsQuestionTracker * question_tracker) const217 bool MdnsRecordTracker::AddAssociatedQuery(
218 const MdnsQuestionTracker* question_tracker) const {
219 return AddAdjacentNode(question_tracker);
220 }
221
RemoveAssociatedQuery(const MdnsQuestionTracker * question_tracker) const222 bool MdnsRecordTracker::RemoveAssociatedQuery(
223 const MdnsQuestionTracker* question_tracker) const {
224 return RemoveAdjacentNode(question_tracker);
225 }
226
ExpireSoon()227 void MdnsRecordTracker::ExpireSoon() {
228 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
229
230 record_ =
231 MdnsRecord(record_.name(), record_.dns_type(), record_.dns_class(),
232 record_.record_type(), kGoodbyeRecordTtl, record_.rdata());
233
234 // Set the attempt count to the last item, which is 100% of TTL, i.e. record
235 // expiration, to prevent any re-queries
236 attempt_count_ = countof(kTtlFractions) - 1;
237 start_time_ = now_function_();
238 ScheduleFollowUpQuery();
239 }
240
ExpireNow()241 void MdnsRecordTracker::ExpireNow() {
242 record_expired_callback_(this, record_);
243 }
244
IsNearingExpiry() const245 bool MdnsRecordTracker::IsNearingExpiry() const {
246 return (now_function_() - start_time_) > record_.ttl() / 2;
247 }
248
SendQuery() const249 bool MdnsRecordTracker::SendQuery() const {
250 const Clock::time_point expiration_time = start_time_ + record_.ttl();
251 bool is_expired = (now_function_() >= expiration_time);
252 if (!is_expired) {
253 for (const MdnsTracker* tracker : adjacent_nodes()) {
254 tracker->SendQuery();
255 }
256 } else {
257 record_expired_callback_(this, record_);
258 }
259
260 return !is_expired;
261 }
262
ScheduleFollowUpQuery()263 void MdnsRecordTracker::ScheduleFollowUpQuery() {
264 send_alarm_.Schedule(
265 [this] {
266 if (SendQuery()) {
267 ScheduleFollowUpQuery();
268 }
269 },
270 GetNextSendTime());
271 }
272
GetRecords() const273 std::vector<MdnsRecord> MdnsRecordTracker::GetRecords() const {
274 return {record_};
275 }
276
GetNextSendTime()277 Clock::time_point MdnsRecordTracker::GetNextSendTime() {
278 OSP_DCHECK(attempt_count_ < countof(kTtlFractions));
279
280 double ttl_fraction = kTtlFractions[attempt_count_++];
281
282 // Do not add random variation to the expiration time (last fraction of TTL)
283 if (attempt_count_ != countof(kTtlFractions)) {
284 ttl_fraction += random_delay_->GetRecordTtlVariation();
285 }
286
287 const Clock::duration delay =
288 Clock::to_duration(record_.ttl() * ttl_fraction);
289 return start_time_ + delay;
290 }
291
MdnsQuestionTracker(MdnsQuestion question,MdnsSender * sender,TaskRunner * task_runner,ClockNowFunctionPtr now_function,MdnsRandom * random_delay,const Config & config,QueryType query_type)292 MdnsQuestionTracker::MdnsQuestionTracker(MdnsQuestion question,
293 MdnsSender* sender,
294 TaskRunner* task_runner,
295 ClockNowFunctionPtr now_function,
296 MdnsRandom* random_delay,
297 const Config& config,
298 QueryType query_type)
299 : MdnsTracker(sender,
300 task_runner,
301 now_function,
302 random_delay,
303 TrackerType::kQuestionTracker),
304 question_(std::move(question)),
305 send_delay_(kMinimumQueryInterval),
306 query_type_(query_type),
307 maximum_announcement_count_(config.new_query_announcement_count < 0
308 ? INT_MAX
309 : config.new_query_announcement_count) {
310 // Initialize the last send time to time_point::min() so that the next call to
311 // SendQuery() is guaranteed to query the network.
312 last_send_time_ = TrivialClockTraits::time_point::min();
313
314 // The initial query has to be sent after a random delay of 20-120
315 // milliseconds.
316 if (announcements_so_far_ < maximum_announcement_count_) {
317 announcements_so_far_++;
318
319 if (query_type_ == QueryType::kOneShot) {
320 task_runner_->PostTask([this] { MdnsQuestionTracker::SendQuery(); });
321 } else {
322 OSP_DCHECK(query_type_ == QueryType::kContinuous);
323 send_alarm_.ScheduleFromNow(
324 [this]() {
325 MdnsQuestionTracker::SendQuery();
326 ScheduleFollowUpQuery();
327 },
328 random_delay_->GetInitialQueryDelay());
329 }
330 }
331 }
332
// Defaulted: no extra teardown is performed here. The MdnsTracker base
// destructor cancels the pending alarm and detaches this tracker from its
// adjacent trackers.
333 MdnsQuestionTracker::~MdnsQuestionTracker() = default;
334
AddAssociatedRecord(const MdnsRecordTracker * record_tracker) const335 bool MdnsQuestionTracker::AddAssociatedRecord(
336 const MdnsRecordTracker* record_tracker) const {
337 return AddAdjacentNode(record_tracker);
338 }
339
RemoveAssociatedRecord(const MdnsRecordTracker * record_tracker) const340 bool MdnsQuestionTracker::RemoveAssociatedRecord(
341 const MdnsRecordTracker* record_tracker) const {
342 return RemoveAdjacentNode(record_tracker);
343 }
344
GetRecords() const345 std::vector<MdnsRecord> MdnsQuestionTracker::GetRecords() const {
346 std::vector<MdnsRecord> records;
347 for (const MdnsTracker* tracker : adjacent_nodes()) {
348 OSP_DCHECK(tracker->tracker_type() == TrackerType::kRecordTracker);
349
350 // This call cannot result in an infinite loop because MdnsRecordTracker
351 // instances only return a single record from this call.
352 std::vector<MdnsRecord> node_records = tracker->GetRecords();
353 OSP_DCHECK(node_records.size() == 1);
354
355 records.push_back(std::move(node_records[0]));
356 }
357
358 return records;
359 }
360
SendQuery() const361 bool MdnsQuestionTracker::SendQuery() const {
362 // NOTE: The RFC does not specify the minimum interval between queries for
363 // multiple records of the same query when initiated for different reasons
364 // (such as for different record refreshes or for one record refresh and the
365 // periodic re-querying for a continuous query). For this reason, a constant
366 // outside of scope of the RFC has been chosen.
367 TrivialClockTraits::time_point now = now_function_();
368 if (now < last_send_time_ + kMinimumQueryInterval) {
369 return true;
370 }
371 last_send_time_ = now;
372
373 MdnsMessage message(CreateMessageId(), MessageType::Query);
374 message.AddQuestion(question_);
375
376 // Send the message and additional known answer packets as needed.
377 for (auto it = adjacent_nodes().begin(); it != adjacent_nodes().end();) {
378 OSP_DCHECK((*it)->tracker_type() == TrackerType::kRecordTracker);
379
380 const MdnsRecordTracker* record_tracker =
381 static_cast<const MdnsRecordTracker*>(*it);
382 if (record_tracker->IsNearingExpiry()) {
383 it++;
384 continue;
385 }
386
387 // A record tracker should only contain one record.
388 std::vector<MdnsRecord> node_records = (*it)->GetRecords();
389 OSP_DCHECK(node_records.size() == 1);
390 MdnsRecord node_record = std::move(node_records[0]);
391
392 if (message.CanAddRecord(node_record)) {
393 message.AddAnswer(std::move(node_record));
394 it++;
395 } else if (message.questions().empty() && message.answers().empty()) {
396 // This case should never happen, because it means a record is too large
397 // to fit into its own message.
398 OSP_LOG_INFO
399 << "Encountered unreasonably large message in cache. Skipping "
400 << "known answer in suppressions...";
401 it++;
402 } else {
403 message.set_truncated();
404 sender_->SendMulticast(message);
405 message = MdnsMessage(CreateMessageId(), MessageType::Query);
406 }
407 }
408 sender_->SendMulticast(message);
409 return true;
410 }
411
ScheduleFollowUpQuery()412 void MdnsQuestionTracker::ScheduleFollowUpQuery() {
413 if (announcements_so_far_ >= maximum_announcement_count_) {
414 return;
415 }
416 announcements_so_far_++;
417
418 send_alarm_.ScheduleFromNow(
419 [this] {
420 if (SendQuery()) {
421 ScheduleFollowUpQuery();
422 }
423 },
424 send_delay_);
425 send_delay_ = send_delay_ * kIntervalIncreaseFactor;
426 if (send_delay_ > kMaximumQueryInterval) {
427 send_delay_ = kMaximumQueryInterval;
428 }
429 }
430
431 } // namespace discovery
432 } // namespace openscreen
433