1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/dns/mdns_client_impl.h"
6
7 #include <algorithm>
8 #include <cstdint>
9 #include <memory>
10 #include <optional>
11 #include <utility>
12 #include <vector>
13
14 #include "base/containers/fixed_flat_set.h"
15 #include "base/functional/bind.h"
16 #include "base/location.h"
17 #include "base/metrics/histogram_functions.h"
18 #include "base/not_fatal_until.h"
19 #include "base/observer_list.h"
20 #include "base/ranges/algorithm.h"
21 #include "base/strings/string_util.h"
22 #include "base/task/single_thread_task_runner.h"
23 #include "base/time/clock.h"
24 #include "base/time/default_clock.h"
25 #include "base/time/time.h"
26 #include "base/timer/timer.h"
27 #include "net/base/net_errors.h"
28 #include "net/base/rand_callback.h"
29 #include "net/dns/dns_names_util.h"
30 #include "net/dns/public/dns_protocol.h"
31 #include "net/dns/public/util.h"
32 #include "net/dns/record_rdata.h"
33 #include "net/socket/datagram_socket.h"
34
35 // TODO(gene): Remove this temporary method of disabling NSEC support once it
36 // becomes clear whether this feature should be
37 // supported. http://crbug.com/255232
38 #define ENABLE_NSEC
39
40 namespace net {
41
42 namespace {
43
// An active listener (one on which |SetActiveRefresh(true)| was called)
// re-queries to refresh its cache twice per record lifetime: once after 85% of
// the record's original TTL has elapsed, and again after 95%.
constexpr double kListenerRefreshRatio1 = 0.85;
constexpr double kListenerRefreshRatio2 = 0.95;

// Kind of mDNS query recorded in UMA. These values are persisted to logs:
// entries must never be renumbered and numeric values must never be reused.
enum class mdnsQueryType {
  kInitial = 0,  // An initial mDNS query was sent.
  kRefresh = 1,  // A refresh mDNS query was sent.
  kMaxValue = kRefresh,
};
58
RecordQueryMetric(mdnsQueryType query_type,std::string_view host)59 void RecordQueryMetric(mdnsQueryType query_type, std::string_view host) {
60 constexpr auto kPrintScanServices = base::MakeFixedFlatSet<std::string_view>({
61 "_ipps._tcp.local",
62 "_ipp._tcp.local",
63 "_pdl-datastream._tcp.local",
64 "_printer._tcp.local",
65 "_print._sub._ipps._tcp.local",
66 "_print._sub._ipp._tcp.local",
67 "_scanner._tcp.local",
68 "_uscans._tcp.local",
69 "_uscan._tcp.local",
70 });
71
72 if (host.ends_with("_googlecast._tcp.local")) {
73 base::UmaHistogramEnumeration("Network.Mdns.Googlecast", query_type);
74 } else if (base::ranges::any_of(kPrintScanServices,
75 [&host](std::string_view service) {
76 return host.ends_with(service);
77 })) {
78 base::UmaHistogramEnumeration("Network.Mdns.PrintScan", query_type);
79 } else {
80 base::UmaHistogramEnumeration("Network.Mdns.Other", query_type);
81 }
82 }
83
84 } // namespace
85
CreateSockets(std::vector<std::unique_ptr<DatagramServerSocket>> * sockets)86 void MDnsSocketFactoryImpl::CreateSockets(
87 std::vector<std::unique_ptr<DatagramServerSocket>>* sockets) {
88 InterfaceIndexFamilyList interfaces(GetMDnsInterfacesToBind());
89 for (const auto& interface : interfaces) {
90 DCHECK(interface.second == ADDRESS_FAMILY_IPV4 ||
91 interface.second == ADDRESS_FAMILY_IPV6);
92 std::unique_ptr<DatagramServerSocket> socket(
93 CreateAndBindMDnsSocket(interface.second, interface.first, net_log_));
94 if (socket)
95 sockets->push_back(std::move(socket));
96 }
97 }
98
SocketHandler(std::unique_ptr<DatagramServerSocket> socket,MDnsConnection * connection)99 MDnsConnection::SocketHandler::SocketHandler(
100 std::unique_ptr<DatagramServerSocket> socket,
101 MDnsConnection* connection)
102 : socket_(std::move(socket)),
103 connection_(connection),
104 response_(dns_protocol::kMaxMulticastSize) {}
105
106 MDnsConnection::SocketHandler::~SocketHandler() = default;
107
Start()108 int MDnsConnection::SocketHandler::Start() {
109 IPEndPoint end_point;
110 int rv = socket_->GetLocalAddress(&end_point);
111 if (rv != OK)
112 return rv;
113 DCHECK(end_point.GetFamily() == ADDRESS_FAMILY_IPV4 ||
114 end_point.GetFamily() == ADDRESS_FAMILY_IPV6);
115 multicast_addr_ = dns_util::GetMdnsGroupEndPoint(end_point.GetFamily());
116 return DoLoop(0);
117 }
118
DoLoop(int rv)119 int MDnsConnection::SocketHandler::DoLoop(int rv) {
120 do {
121 if (rv > 0)
122 connection_->OnDatagramReceived(&response_, recv_addr_, rv);
123
124 rv = socket_->RecvFrom(
125 response_.io_buffer(), response_.io_buffer_size(), &recv_addr_,
126 base::BindOnce(&MDnsConnection::SocketHandler::OnDatagramReceived,
127 base::Unretained(this)));
128 } while (rv > 0);
129
130 if (rv != ERR_IO_PENDING)
131 return rv;
132
133 return OK;
134 }
135
OnDatagramReceived(int rv)136 void MDnsConnection::SocketHandler::OnDatagramReceived(int rv) {
137 if (rv >= OK)
138 rv = DoLoop(rv);
139
140 if (rv != OK)
141 connection_->PostOnError(this, rv);
142 }
143
Send(const scoped_refptr<IOBuffer> & buffer,unsigned size)144 void MDnsConnection::SocketHandler::Send(const scoped_refptr<IOBuffer>& buffer,
145 unsigned size) {
146 if (send_in_progress_) {
147 send_queue_.emplace(buffer, size);
148 return;
149 }
150 int rv =
151 socket_->SendTo(buffer.get(), size, multicast_addr_,
152 base::BindOnce(&MDnsConnection::SocketHandler::SendDone,
153 base::Unretained(this)));
154 if (rv == ERR_IO_PENDING) {
155 send_in_progress_ = true;
156 } else if (rv < OK) {
157 connection_->PostOnError(this, rv);
158 }
159 }
160
SendDone(int rv)161 void MDnsConnection::SocketHandler::SendDone(int rv) {
162 DCHECK(send_in_progress_);
163 send_in_progress_ = false;
164 if (rv != OK)
165 connection_->PostOnError(this, rv);
166 while (!send_in_progress_ && !send_queue_.empty()) {
167 std::pair<scoped_refptr<IOBuffer>, unsigned> buffer = send_queue_.front();
168 send_queue_.pop();
169 Send(buffer.first, buffer.second);
170 }
171 }
172
MDnsConnection(MDnsConnection::Delegate * delegate)173 MDnsConnection::MDnsConnection(MDnsConnection::Delegate* delegate)
174 : delegate_(delegate) {}
175
176 MDnsConnection::~MDnsConnection() = default;
177
Init(MDnsSocketFactory * socket_factory)178 int MDnsConnection::Init(MDnsSocketFactory* socket_factory) {
179 std::vector<std::unique_ptr<DatagramServerSocket>> sockets;
180 socket_factory->CreateSockets(&sockets);
181
182 for (std::unique_ptr<DatagramServerSocket>& socket : sockets) {
183 socket_handlers_.push_back(std::make_unique<MDnsConnection::SocketHandler>(
184 std::move(socket), this));
185 }
186
187 // All unbound sockets need to be bound before processing untrusted input.
188 // This is done for security reasons, so that an attacker can't get an unbound
189 // socket.
190 int last_failure = ERR_FAILED;
191 for (size_t i = 0; i < socket_handlers_.size();) {
192 int rv = socket_handlers_[i]->Start();
193 if (rv != OK) {
194 last_failure = rv;
195 socket_handlers_.erase(socket_handlers_.begin() + i);
196 VLOG(1) << "Start failed, socket=" << i << ", error=" << rv;
197 } else {
198 ++i;
199 }
200 }
201 VLOG(1) << "Sockets ready:" << socket_handlers_.size();
202 DCHECK_NE(ERR_IO_PENDING, last_failure);
203 return socket_handlers_.empty() ? last_failure : OK;
204 }
205
Send(const scoped_refptr<IOBuffer> & buffer,unsigned size)206 void MDnsConnection::Send(const scoped_refptr<IOBuffer>& buffer,
207 unsigned size) {
208 for (std::unique_ptr<SocketHandler>& handler : socket_handlers_)
209 handler->Send(buffer, size);
210 }
211
PostOnError(SocketHandler * loop,int rv)212 void MDnsConnection::PostOnError(SocketHandler* loop, int rv) {
213 int id = 0;
214 for (const auto& it : socket_handlers_) {
215 if (it.get() == loop)
216 break;
217 id++;
218 }
219 VLOG(1) << "Socket error. id=" << id << ", error=" << rv;
220 // Post to allow deletion of this object by delegate.
221 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
222 FROM_HERE, base::BindOnce(&MDnsConnection::OnError,
223 weak_ptr_factory_.GetWeakPtr(), rv));
224 }
225
OnError(int rv)226 void MDnsConnection::OnError(int rv) {
227 // TODO(noamsml): Specific handling of intermittent errors that can be handled
228 // in the connection.
229 delegate_->OnConnectionError(rv);
230 }
231
OnDatagramReceived(DnsResponse * response,const IPEndPoint & recv_addr,int bytes_read)232 void MDnsConnection::OnDatagramReceived(
233 DnsResponse* response,
234 const IPEndPoint& recv_addr,
235 int bytes_read) {
236 // TODO(noamsml): More sophisticated error handling.
237 DCHECK_GT(bytes_read, 0);
238 delegate_->HandlePacket(response, bytes_read);
239 }
240
Core(base::Clock * clock,base::OneShotTimer * timer)241 MDnsClientImpl::Core::Core(base::Clock* clock, base::OneShotTimer* timer)
242 : clock_(clock),
243 cleanup_timer_(timer),
244 connection_(
245 std::make_unique<MDnsConnection>((MDnsConnection::Delegate*)this)) {
246 DCHECK(cleanup_timer_);
247 DCHECK(!cleanup_timer_->IsRunning());
248 }
249
~Core()250 MDnsClientImpl::Core::~Core() {
251 cleanup_timer_->Stop();
252 }
253
Init(MDnsSocketFactory * socket_factory)254 int MDnsClientImpl::Core::Init(MDnsSocketFactory* socket_factory) {
255 CHECK(!cleanup_timer_->IsRunning());
256 return connection_->Init(socket_factory);
257 }
258
SendQuery(uint16_t rrtype,const std::string & name)259 bool MDnsClientImpl::Core::SendQuery(uint16_t rrtype, const std::string& name) {
260 std::optional<std::vector<uint8_t>> name_dns =
261 dns_names_util::DottedNameToNetwork(name);
262 if (!name_dns.has_value())
263 return false;
264
265 DnsQuery query(0, name_dns.value(), rrtype);
266 query.set_flags(0); // Remove the RD flag from the query. It is unneeded.
267
268 connection_->Send(query.io_buffer(), query.io_buffer()->size());
269 return true;
270 }
271
HandlePacket(DnsResponse * response,int bytes_read)272 void MDnsClientImpl::Core::HandlePacket(DnsResponse* response,
273 int bytes_read) {
274 unsigned offset;
275 // Note: We store cache keys rather than record pointers to avoid
276 // erroneous behavior in case a packet contains multiple exclusive
277 // records with the same type and name.
278 std::map<MDnsCache::Key, MDnsCache::UpdateType> update_keys;
279 DCHECK_GT(bytes_read, 0);
280 if (!response->InitParseWithoutQuery(bytes_read)) {
281 DVLOG(1) << "Could not understand an mDNS packet.";
282 return; // Message is unreadable.
283 }
284
285 // TODO(noamsml): duplicate query suppression.
286 if (!(response->flags() & dns_protocol::kFlagResponse))
287 return; // Message is a query. ignore it.
288
289 DnsRecordParser parser = response->Parser();
290 unsigned answer_count = response->answer_count() +
291 response->additional_answer_count();
292
293 for (unsigned i = 0; i < answer_count; i++) {
294 offset = parser.GetOffset();
295 std::unique_ptr<const RecordParsed> record =
296 RecordParsed::CreateFrom(&parser, clock_->Now());
297
298 if (!record) {
299 DVLOG(1) << "Could not understand an mDNS record.";
300
301 if (offset == parser.GetOffset()) {
302 DVLOG(1) << "Abandoned parsing the rest of the packet.";
303 return; // The parser did not advance, abort reading the packet.
304 } else {
305 continue; // We may be able to extract other records from the packet.
306 }
307 }
308
309 if ((record->klass() & dns_protocol::kMDnsClassMask) !=
310 dns_protocol::kClassIN) {
311 DVLOG(1) << "Received an mDNS record with non-IN class. Ignoring.";
312 continue; // Ignore all records not in the IN class.
313 }
314
315 MDnsCache::Key update_key = MDnsCache::Key::CreateFor(record.get());
316 MDnsCache::UpdateType update = cache_.UpdateDnsRecord(std::move(record));
317
318 // Cleanup time may have changed.
319 ScheduleCleanup(cache_.next_expiration());
320
321 update_keys.emplace(update_key, update);
322 }
323
324 for (const auto& update_key : update_keys) {
325 const RecordParsed* record = cache_.LookupKey(update_key.first);
326 if (!record)
327 continue;
328
329 if (record->type() == dns_protocol::kTypeNSEC) {
330 #if defined(ENABLE_NSEC)
331 NotifyNsecRecord(record);
332 #endif
333 } else {
334 AlertListeners(update_key.second,
335 ListenerKey(record->name(), record->type()), record);
336 }
337 }
338 }
339
NotifyNsecRecord(const RecordParsed * record)340 void MDnsClientImpl::Core::NotifyNsecRecord(const RecordParsed* record) {
341 DCHECK_EQ(dns_protocol::kTypeNSEC, record->type());
342 const NsecRecordRdata* rdata = record->rdata<NsecRecordRdata>();
343 DCHECK(rdata);
344
345 // Remove all cached records matching the nonexistent RR types.
346 std::vector<const RecordParsed*> records_to_remove;
347
348 cache_.FindDnsRecords(0, record->name(), &records_to_remove, clock_->Now());
349
350 for (const auto* record_to_remove : records_to_remove) {
351 if (record_to_remove->type() == dns_protocol::kTypeNSEC)
352 continue;
353 if (!rdata->GetBit(record_to_remove->type())) {
354 std::unique_ptr<const RecordParsed> record_removed =
355 cache_.RemoveRecord(record_to_remove);
356 DCHECK(record_removed);
357 OnRecordRemoved(record_removed.get());
358 }
359 }
360
361 // Alert all listeners waiting for the nonexistent RR types.
362 ListenerKey key(record->name(), 0);
363 auto i = listeners_.upper_bound(key);
364 for (; i != listeners_.end() &&
365 i->first.name_lowercase() == key.name_lowercase();
366 i++) {
367 if (!rdata->GetBit(i->first.type())) {
368 for (auto& observer : *i->second)
369 observer.AlertNsecRecord();
370 }
371 }
372 }
373
OnConnectionError(int error)374 void MDnsClientImpl::Core::OnConnectionError(int error) {
375 // TODO(noamsml): On connection error, recreate connection and flush cache.
376 VLOG(1) << "MDNS OnConnectionError (code: " << error << ")";
377 }
378
ListenerKey(const std::string & name,uint16_t type)379 MDnsClientImpl::Core::ListenerKey::ListenerKey(const std::string& name,
380 uint16_t type)
381 : name_lowercase_(base::ToLowerASCII(name)), type_(type) {}
382
operator <(const MDnsClientImpl::Core::ListenerKey & key) const383 bool MDnsClientImpl::Core::ListenerKey::operator<(
384 const MDnsClientImpl::Core::ListenerKey& key) const {
385 if (name_lowercase_ == key.name_lowercase_)
386 return type_ < key.type_;
387 return name_lowercase_ < key.name_lowercase_;
388 }
389
AlertListeners(MDnsCache::UpdateType update_type,const ListenerKey & key,const RecordParsed * record)390 void MDnsClientImpl::Core::AlertListeners(
391 MDnsCache::UpdateType update_type,
392 const ListenerKey& key,
393 const RecordParsed* record) {
394 auto listener_map_iterator = listeners_.find(key);
395 if (listener_map_iterator == listeners_.end()) return;
396
397 for (auto& observer : *listener_map_iterator->second)
398 observer.HandleRecordUpdate(update_type, record);
399 }
400
AddListener(MDnsListenerImpl * listener)401 void MDnsClientImpl::Core::AddListener(
402 MDnsListenerImpl* listener) {
403 ListenerKey key(listener->GetName(), listener->GetType());
404
405 auto& observer_list = listeners_[key];
406 if (!observer_list)
407 observer_list = std::make_unique<ObserverListType>();
408
409 observer_list->AddObserver(listener);
410 }
411
RemoveListener(MDnsListenerImpl * listener)412 void MDnsClientImpl::Core::RemoveListener(MDnsListenerImpl* listener) {
413 ListenerKey key(listener->GetName(), listener->GetType());
414 auto observer_list_iterator = listeners_.find(key);
415
416 CHECK(observer_list_iterator != listeners_.end(), base::NotFatalUntil::M130);
417 DCHECK(observer_list_iterator->second->HasObserver(listener));
418
419 observer_list_iterator->second->RemoveObserver(listener);
420
421 // Remove the observer list from the map if it is empty
422 if (observer_list_iterator->second->empty()) {
423 // Schedule the actual removal for later in case the listener removal
424 // happens while iterating over the observer list.
425 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
426 FROM_HERE, base::BindOnce(&MDnsClientImpl::Core::CleanupObserverList,
427 weak_ptr_factory_.GetWeakPtr(), key));
428 }
429 }
430
CleanupObserverList(const ListenerKey & key)431 void MDnsClientImpl::Core::CleanupObserverList(const ListenerKey& key) {
432 auto found = listeners_.find(key);
433 if (found != listeners_.end() && found->second->empty()) {
434 listeners_.erase(found);
435 }
436 }
437
ScheduleCleanup(base::Time cleanup)438 void MDnsClientImpl::Core::ScheduleCleanup(base::Time cleanup) {
439 // If cache is overfilled. Force an immediate cleanup.
440 if (cache_.IsCacheOverfilled())
441 cleanup = clock_->Now();
442
443 // Cleanup is already scheduled, no need to do anything.
444 if (cleanup == scheduled_cleanup_) {
445 return;
446 }
447 scheduled_cleanup_ = cleanup;
448
449 // This cancels the previously scheduled cleanup.
450 cleanup_timer_->Stop();
451
452 // If |cleanup| is empty, then no cleanup necessary.
453 if (cleanup != base::Time()) {
454 cleanup_timer_->Start(FROM_HERE,
455 std::max(base::TimeDelta(), cleanup - clock_->Now()),
456 base::BindOnce(&MDnsClientImpl::Core::DoCleanup,
457 base::Unretained(this)));
458 }
459 }
460
DoCleanup()461 void MDnsClientImpl::Core::DoCleanup() {
462 cache_.CleanupRecords(
463 clock_->Now(), base::BindRepeating(&MDnsClientImpl::Core::OnRecordRemoved,
464 base::Unretained(this)));
465
466 ScheduleCleanup(cache_.next_expiration());
467 }
468
OnRecordRemoved(const RecordParsed * record)469 void MDnsClientImpl::Core::OnRecordRemoved(
470 const RecordParsed* record) {
471 AlertListeners(MDnsCache::RecordRemoved,
472 ListenerKey(record->name(), record->type()), record);
473 }
474
QueryCache(uint16_t rrtype,const std::string & name,std::vector<const RecordParsed * > * records) const475 void MDnsClientImpl::Core::QueryCache(
476 uint16_t rrtype,
477 const std::string& name,
478 std::vector<const RecordParsed*>* records) const {
479 cache_.FindDnsRecords(rrtype, name, records, clock_->Now());
480 }
481
MDnsClientImpl()482 MDnsClientImpl::MDnsClientImpl()
483 : clock_(base::DefaultClock::GetInstance()),
484 cleanup_timer_(std::make_unique<base::OneShotTimer>()) {}
485
MDnsClientImpl(base::Clock * clock,std::unique_ptr<base::OneShotTimer> timer)486 MDnsClientImpl::MDnsClientImpl(base::Clock* clock,
487 std::unique_ptr<base::OneShotTimer> timer)
488 : clock_(clock), cleanup_timer_(std::move(timer)) {}
489
~MDnsClientImpl()490 MDnsClientImpl::~MDnsClientImpl() {
491 StopListening();
492 }
493
StartListening(MDnsSocketFactory * socket_factory)494 int MDnsClientImpl::StartListening(MDnsSocketFactory* socket_factory) {
495 DCHECK(!core_.get());
496 core_ = std::make_unique<Core>(clock_, cleanup_timer_.get());
497 int rv = core_->Init(socket_factory);
498 if (rv != OK) {
499 DCHECK_NE(ERR_IO_PENDING, rv);
500 core_.reset();
501 }
502 return rv;
503 }
504
StopListening()505 void MDnsClientImpl::StopListening() {
506 core_.reset();
507 }
508
IsListening() const509 bool MDnsClientImpl::IsListening() const {
510 return core_.get() != nullptr;
511 }
512
CreateListener(uint16_t rrtype,const std::string & name,MDnsListener::Delegate * delegate)513 std::unique_ptr<MDnsListener> MDnsClientImpl::CreateListener(
514 uint16_t rrtype,
515 const std::string& name,
516 MDnsListener::Delegate* delegate) {
517 return std::make_unique<MDnsListenerImpl>(rrtype, name, clock_, delegate,
518 this);
519 }
520
CreateTransaction(uint16_t rrtype,const std::string & name,int flags,const MDnsTransaction::ResultCallback & callback)521 std::unique_ptr<MDnsTransaction> MDnsClientImpl::CreateTransaction(
522 uint16_t rrtype,
523 const std::string& name,
524 int flags,
525 const MDnsTransaction::ResultCallback& callback) {
526 return std::make_unique<MDnsTransactionImpl>(rrtype, name, flags, callback,
527 this);
528 }
529
MDnsListenerImpl(uint16_t rrtype,const std::string & name,base::Clock * clock,MDnsListener::Delegate * delegate,MDnsClientImpl * client)530 MDnsListenerImpl::MDnsListenerImpl(uint16_t rrtype,
531 const std::string& name,
532 base::Clock* clock,
533 MDnsListener::Delegate* delegate,
534 MDnsClientImpl* client)
535 : rrtype_(rrtype),
536 name_(name),
537 clock_(clock),
538 client_(client),
539 delegate_(delegate) {}
540
~MDnsListenerImpl()541 MDnsListenerImpl::~MDnsListenerImpl() {
542 if (started_) {
543 DCHECK(client_->core());
544 client_->core()->RemoveListener(this);
545 }
546 }
547
Start()548 bool MDnsListenerImpl::Start() {
549 DCHECK(!started_);
550
551 started_ = true;
552
553 DCHECK(client_->core());
554 client_->core()->AddListener(this);
555
556 return true;
557 }
558
SetActiveRefresh(bool active_refresh)559 void MDnsListenerImpl::SetActiveRefresh(bool active_refresh) {
560 active_refresh_ = active_refresh;
561
562 if (started_) {
563 if (!active_refresh_) {
564 next_refresh_.Cancel();
565 } else if (last_update_ != base::Time()) {
566 ScheduleNextRefresh();
567 }
568 }
569 }
570
GetName() const571 const std::string& MDnsListenerImpl::GetName() const {
572 return name_;
573 }
574
GetType() const575 uint16_t MDnsListenerImpl::GetType() const {
576 return rrtype_;
577 }
578
HandleRecordUpdate(MDnsCache::UpdateType update_type,const RecordParsed * record)579 void MDnsListenerImpl::HandleRecordUpdate(MDnsCache::UpdateType update_type,
580 const RecordParsed* record) {
581 DCHECK(started_);
582
583 if (update_type != MDnsCache::RecordRemoved) {
584 ttl_ = record->ttl();
585 last_update_ = record->time_created();
586
587 ScheduleNextRefresh();
588 }
589
590 if (update_type != MDnsCache::NoChange) {
591 MDnsListener::UpdateType update_external;
592
593 switch (update_type) {
594 case MDnsCache::RecordAdded:
595 update_external = MDnsListener::RECORD_ADDED;
596 break;
597 case MDnsCache::RecordChanged:
598 update_external = MDnsListener::RECORD_CHANGED;
599 break;
600 case MDnsCache::RecordRemoved:
601 update_external = MDnsListener::RECORD_REMOVED;
602 break;
603 case MDnsCache::NoChange:
604 default:
605 NOTREACHED();
606 }
607
608 delegate_->OnRecordUpdate(update_external, record);
609 }
610 }
611
AlertNsecRecord()612 void MDnsListenerImpl::AlertNsecRecord() {
613 DCHECK(started_);
614 delegate_->OnNsecRecord(name_, rrtype_);
615 }
616
ScheduleNextRefresh()617 void MDnsListenerImpl::ScheduleNextRefresh() {
618 DCHECK(last_update_ != base::Time());
619
620 if (!active_refresh_)
621 return;
622
623 // A zero TTL is a goodbye packet and should not be refreshed.
624 if (ttl_ == 0) {
625 next_refresh_.Cancel();
626 return;
627 }
628
629 next_refresh_.Reset(base::BindRepeating(&MDnsListenerImpl::DoRefresh,
630 weak_ptr_factory_.GetWeakPtr()));
631
632 // Schedule refreshes at both 85% and 95% of the original TTL. These will both
633 // be canceled and rescheduled if the record's TTL is updated due to a
634 // response being received.
635 base::Time next_refresh1 =
636 last_update_ +
637 base::Milliseconds(static_cast<int>(base::Time::kMillisecondsPerSecond *
638 kListenerRefreshRatio1 * ttl_));
639
640 base::Time next_refresh2 =
641 last_update_ +
642 base::Milliseconds(static_cast<int>(base::Time::kMillisecondsPerSecond *
643 kListenerRefreshRatio2 * ttl_));
644
645 base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask(
646 FROM_HERE, next_refresh_.callback(), next_refresh1 - clock_->Now());
647
648 base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask(
649 FROM_HERE, next_refresh_.callback(), next_refresh2 - clock_->Now());
650 }
651
DoRefresh()652 void MDnsListenerImpl::DoRefresh() {
653 RecordQueryMetric(mdnsQueryType::kRefresh, name_);
654 client_->core()->SendQuery(rrtype_, name_);
655 }
656
MDnsTransactionImpl(uint16_t rrtype,const std::string & name,int flags,const MDnsTransaction::ResultCallback & callback,MDnsClientImpl * client)657 MDnsTransactionImpl::MDnsTransactionImpl(
658 uint16_t rrtype,
659 const std::string& name,
660 int flags,
661 const MDnsTransaction::ResultCallback& callback,
662 MDnsClientImpl* client)
663 : rrtype_(rrtype),
664 name_(name),
665 callback_(callback),
666 client_(client),
667 flags_(flags) {
668 DCHECK((flags_ & MDnsTransaction::FLAG_MASK) == flags_);
669 DCHECK(flags_ & MDnsTransaction::QUERY_CACHE ||
670 flags_ & MDnsTransaction::QUERY_NETWORK);
671 }
672
~MDnsTransactionImpl()673 MDnsTransactionImpl::~MDnsTransactionImpl() {
674 timeout_.Cancel();
675 }
676
Start()677 bool MDnsTransactionImpl::Start() {
678 DCHECK(!started_);
679 started_ = true;
680
681 base::WeakPtr<MDnsTransactionImpl> weak_this = weak_ptr_factory_.GetWeakPtr();
682 if (flags_ & MDnsTransaction::QUERY_CACHE) {
683 ServeRecordsFromCache();
684
685 if (!weak_this || !is_active()) return true;
686 }
687
688 if (flags_ & MDnsTransaction::QUERY_NETWORK) {
689 return QueryAndListen();
690 }
691
692 // If this is a cache only query, signal that the transaction is over
693 // immediately.
694 SignalTransactionOver();
695 return true;
696 }
697
GetName() const698 const std::string& MDnsTransactionImpl::GetName() const {
699 return name_;
700 }
701
GetType() const702 uint16_t MDnsTransactionImpl::GetType() const {
703 return rrtype_;
704 }
705
CacheRecordFound(const RecordParsed * record)706 void MDnsTransactionImpl::CacheRecordFound(const RecordParsed* record) {
707 DCHECK(started_);
708 OnRecordUpdate(MDnsListener::RECORD_ADDED, record);
709 }
710
TriggerCallback(MDnsTransaction::Result result,const RecordParsed * record)711 void MDnsTransactionImpl::TriggerCallback(MDnsTransaction::Result result,
712 const RecordParsed* record) {
713 DCHECK(started_);
714 if (!is_active()) return;
715
716 // Ensure callback is run after touching all class state, so that
717 // the callback can delete the transaction.
718 MDnsTransaction::ResultCallback callback = callback_;
719
720 // Reset the transaction if it expects a single result, or if the result
721 // is a final one (everything except for a record).
722 if (flags_ & MDnsTransaction::SINGLE_RESULT ||
723 result != MDnsTransaction::RESULT_RECORD) {
724 Reset();
725 }
726
727 callback.Run(result, record);
728 }
729
Reset()730 void MDnsTransactionImpl::Reset() {
731 callback_.Reset();
732 listener_.reset();
733 timeout_.Cancel();
734 }
735
OnRecordUpdate(MDnsListener::UpdateType update,const RecordParsed * record)736 void MDnsTransactionImpl::OnRecordUpdate(MDnsListener::UpdateType update,
737 const RecordParsed* record) {
738 DCHECK(started_);
739 if (update == MDnsListener::RECORD_ADDED ||
740 update == MDnsListener::RECORD_CHANGED)
741 TriggerCallback(MDnsTransaction::RESULT_RECORD, record);
742 }
743
SignalTransactionOver()744 void MDnsTransactionImpl::SignalTransactionOver() {
745 DCHECK(started_);
746 if (flags_ & MDnsTransaction::SINGLE_RESULT) {
747 TriggerCallback(MDnsTransaction::RESULT_NO_RESULTS, nullptr);
748 } else {
749 TriggerCallback(MDnsTransaction::RESULT_DONE, nullptr);
750 }
751 }
752
ServeRecordsFromCache()753 void MDnsTransactionImpl::ServeRecordsFromCache() {
754 std::vector<const RecordParsed*> records;
755 base::WeakPtr<MDnsTransactionImpl> weak_this = weak_ptr_factory_.GetWeakPtr();
756
757 if (client_->core()) {
758 client_->core()->QueryCache(rrtype_, name_, &records);
759 for (auto i = records.begin(); i != records.end() && weak_this; ++i) {
760 weak_this->TriggerCallback(MDnsTransaction::RESULT_RECORD, *i);
761 }
762
763 #if defined(ENABLE_NSEC)
764 if (records.empty()) {
765 DCHECK(weak_this);
766 client_->core()->QueryCache(dns_protocol::kTypeNSEC, name_, &records);
767 if (!records.empty()) {
768 const NsecRecordRdata* rdata =
769 records.front()->rdata<NsecRecordRdata>();
770 DCHECK(rdata);
771 if (!rdata->GetBit(rrtype_))
772 weak_this->TriggerCallback(MDnsTransaction::RESULT_NSEC, nullptr);
773 }
774 }
775 #endif
776 }
777 }
778
QueryAndListen()779 bool MDnsTransactionImpl::QueryAndListen() {
780 listener_ = client_->CreateListener(rrtype_, name_, this);
781 if (!listener_->Start())
782 return false;
783
784 DCHECK(client_->core());
785 RecordQueryMetric(mdnsQueryType::kInitial, name_);
786 if (!client_->core()->SendQuery(rrtype_, name_))
787 return false;
788
789 timeout_.Reset(base::BindOnce(&MDnsTransactionImpl::SignalTransactionOver,
790 weak_ptr_factory_.GetWeakPtr()));
791 base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask(
792 FROM_HERE, timeout_.callback(), kTransactionTimeout);
793
794 return true;
795 }
796
OnNsecRecord(const std::string & name,unsigned type)797 void MDnsTransactionImpl::OnNsecRecord(const std::string& name, unsigned type) {
798 TriggerCallback(RESULT_NSEC, nullptr);
799 }
800
OnCachePurged()801 void MDnsTransactionImpl::OnCachePurged() {
802 // TODO(noamsml): Cache purge situations not yet implemented
803 }
804
805 } // namespace net
806