1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/dns/mdns_client_impl.h"
6
7 #include <algorithm>
8 #include <cstdint>
9 #include <memory>
10 #include <utility>
11 #include <vector>
12
13 #include "base/functional/bind.h"
14 #include "base/location.h"
15 #include "base/observer_list.h"
16 #include "base/strings/string_util.h"
17 #include "base/task/single_thread_task_runner.h"
18 #include "base/time/clock.h"
19 #include "base/time/default_clock.h"
20 #include "base/time/time.h"
21 #include "base/timer/timer.h"
22 #include "net/base/net_errors.h"
23 #include "net/base/rand_callback.h"
24 #include "net/dns/dns_names_util.h"
25 #include "net/dns/public/dns_protocol.h"
26 #include "net/dns/public/util.h"
27 #include "net/dns/record_rdata.h"
28 #include "net/socket/datagram_socket.h"
29 #include "third_party/abseil-cpp/absl/types/optional.h"
30
31 // TODO(gene): Remove this temporary method of disabling NSEC support once it
32 // becomes clear whether this feature should be
33 // supported. http://crbug.com/255232
34 #define ENABLE_NSEC
35
36 namespace net {
37
namespace {

// Fractions of a record's original TTL at which an active listener (one that
// had |SetActiveRefresh(true)| called) sends a query to refresh its cache: a
// first refresh at 85% of the original TTL and a second one at 95%.
constexpr double kListenerRefreshRatio1 = 0.85;
constexpr double kListenerRefreshRatio2 = 0.95;

// Both refreshes must happen before the record expires, and in this order.
static_assert(0.0 < kListenerRefreshRatio1 &&
                  kListenerRefreshRatio1 < kListenerRefreshRatio2 &&
                  kListenerRefreshRatio2 < 1.0,
              "Listener refresh ratios must be increasing fractions of the TTL");

}  // namespace
48
CreateSockets(std::vector<std::unique_ptr<DatagramServerSocket>> * sockets)49 void MDnsSocketFactoryImpl::CreateSockets(
50 std::vector<std::unique_ptr<DatagramServerSocket>>* sockets) {
51 InterfaceIndexFamilyList interfaces(GetMDnsInterfacesToBind());
52 for (const auto& interface : interfaces) {
53 DCHECK(interface.second == ADDRESS_FAMILY_IPV4 ||
54 interface.second == ADDRESS_FAMILY_IPV6);
55 std::unique_ptr<DatagramServerSocket> socket(
56 CreateAndBindMDnsSocket(interface.second, interface.first, net_log_));
57 if (socket)
58 sockets->push_back(std::move(socket));
59 }
60 }
61
SocketHandler(std::unique_ptr<DatagramServerSocket> socket,MDnsConnection * connection)62 MDnsConnection::SocketHandler::SocketHandler(
63 std::unique_ptr<DatagramServerSocket> socket,
64 MDnsConnection* connection)
65 : socket_(std::move(socket)),
66 connection_(connection),
67 response_(dns_protocol::kMaxMulticastSize) {}
68
69 MDnsConnection::SocketHandler::~SocketHandler() = default;
70
Start()71 int MDnsConnection::SocketHandler::Start() {
72 IPEndPoint end_point;
73 int rv = socket_->GetLocalAddress(&end_point);
74 if (rv != OK)
75 return rv;
76 DCHECK(end_point.GetFamily() == ADDRESS_FAMILY_IPV4 ||
77 end_point.GetFamily() == ADDRESS_FAMILY_IPV6);
78 multicast_addr_ = dns_util::GetMdnsGroupEndPoint(end_point.GetFamily());
79 return DoLoop(0);
80 }
81
DoLoop(int rv)82 int MDnsConnection::SocketHandler::DoLoop(int rv) {
83 do {
84 if (rv > 0)
85 connection_->OnDatagramReceived(&response_, recv_addr_, rv);
86
87 rv = socket_->RecvFrom(
88 response_.io_buffer(), response_.io_buffer_size(), &recv_addr_,
89 base::BindOnce(&MDnsConnection::SocketHandler::OnDatagramReceived,
90 base::Unretained(this)));
91 } while (rv > 0);
92
93 if (rv != ERR_IO_PENDING)
94 return rv;
95
96 return OK;
97 }
98
OnDatagramReceived(int rv)99 void MDnsConnection::SocketHandler::OnDatagramReceived(int rv) {
100 if (rv >= OK)
101 rv = DoLoop(rv);
102
103 if (rv != OK)
104 connection_->PostOnError(this, rv);
105 }
106
Send(const scoped_refptr<IOBuffer> & buffer,unsigned size)107 void MDnsConnection::SocketHandler::Send(const scoped_refptr<IOBuffer>& buffer,
108 unsigned size) {
109 if (send_in_progress_) {
110 send_queue_.push(std::make_pair(buffer, size));
111 return;
112 }
113 int rv =
114 socket_->SendTo(buffer.get(), size, multicast_addr_,
115 base::BindOnce(&MDnsConnection::SocketHandler::SendDone,
116 base::Unretained(this)));
117 if (rv == ERR_IO_PENDING) {
118 send_in_progress_ = true;
119 } else if (rv < OK) {
120 connection_->PostOnError(this, rv);
121 }
122 }
123
SendDone(int rv)124 void MDnsConnection::SocketHandler::SendDone(int rv) {
125 DCHECK(send_in_progress_);
126 send_in_progress_ = false;
127 if (rv != OK)
128 connection_->PostOnError(this, rv);
129 while (!send_in_progress_ && !send_queue_.empty()) {
130 std::pair<scoped_refptr<IOBuffer>, unsigned> buffer = send_queue_.front();
131 send_queue_.pop();
132 Send(buffer.first, buffer.second);
133 }
134 }
135
MDnsConnection(MDnsConnection::Delegate * delegate)136 MDnsConnection::MDnsConnection(MDnsConnection::Delegate* delegate)
137 : delegate_(delegate) {}
138
139 MDnsConnection::~MDnsConnection() = default;
140
Init(MDnsSocketFactory * socket_factory)141 int MDnsConnection::Init(MDnsSocketFactory* socket_factory) {
142 std::vector<std::unique_ptr<DatagramServerSocket>> sockets;
143 socket_factory->CreateSockets(&sockets);
144
145 for (std::unique_ptr<DatagramServerSocket>& socket : sockets) {
146 socket_handlers_.push_back(std::make_unique<MDnsConnection::SocketHandler>(
147 std::move(socket), this));
148 }
149
150 // All unbound sockets need to be bound before processing untrusted input.
151 // This is done for security reasons, so that an attacker can't get an unbound
152 // socket.
153 int last_failure = ERR_FAILED;
154 for (size_t i = 0; i < socket_handlers_.size();) {
155 int rv = socket_handlers_[i]->Start();
156 if (rv != OK) {
157 last_failure = rv;
158 socket_handlers_.erase(socket_handlers_.begin() + i);
159 VLOG(1) << "Start failed, socket=" << i << ", error=" << rv;
160 } else {
161 ++i;
162 }
163 }
164 VLOG(1) << "Sockets ready:" << socket_handlers_.size();
165 DCHECK_NE(ERR_IO_PENDING, last_failure);
166 return socket_handlers_.empty() ? last_failure : OK;
167 }
168
Send(const scoped_refptr<IOBuffer> & buffer,unsigned size)169 void MDnsConnection::Send(const scoped_refptr<IOBuffer>& buffer,
170 unsigned size) {
171 for (std::unique_ptr<SocketHandler>& handler : socket_handlers_)
172 handler->Send(buffer, size);
173 }
174
PostOnError(SocketHandler * loop,int rv)175 void MDnsConnection::PostOnError(SocketHandler* loop, int rv) {
176 int id = 0;
177 for (const auto& it : socket_handlers_) {
178 if (it.get() == loop)
179 break;
180 id++;
181 }
182 VLOG(1) << "Socket error. id=" << id << ", error=" << rv;
183 // Post to allow deletion of this object by delegate.
184 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
185 FROM_HERE, base::BindOnce(&MDnsConnection::OnError,
186 weak_ptr_factory_.GetWeakPtr(), rv));
187 }
188
OnError(int rv)189 void MDnsConnection::OnError(int rv) {
190 // TODO(noamsml): Specific handling of intermittent errors that can be handled
191 // in the connection.
192 delegate_->OnConnectionError(rv);
193 }
194
OnDatagramReceived(DnsResponse * response,const IPEndPoint & recv_addr,int bytes_read)195 void MDnsConnection::OnDatagramReceived(
196 DnsResponse* response,
197 const IPEndPoint& recv_addr,
198 int bytes_read) {
199 // TODO(noamsml): More sophisticated error handling.
200 DCHECK_GT(bytes_read, 0);
201 delegate_->HandlePacket(response, bytes_read);
202 }
203
Core(base::Clock * clock,base::OneShotTimer * timer)204 MDnsClientImpl::Core::Core(base::Clock* clock, base::OneShotTimer* timer)
205 : clock_(clock),
206 cleanup_timer_(timer),
207 connection_(
208 std::make_unique<MDnsConnection>((MDnsConnection::Delegate*)this)) {
209 DCHECK(cleanup_timer_);
210 DCHECK(!cleanup_timer_->IsRunning());
211 }
212
~Core()213 MDnsClientImpl::Core::~Core() {
214 cleanup_timer_->Stop();
215 }
216
Init(MDnsSocketFactory * socket_factory)217 int MDnsClientImpl::Core::Init(MDnsSocketFactory* socket_factory) {
218 CHECK(!cleanup_timer_->IsRunning());
219 return connection_->Init(socket_factory);
220 }
221
SendQuery(uint16_t rrtype,const std::string & name)222 bool MDnsClientImpl::Core::SendQuery(uint16_t rrtype, const std::string& name) {
223 absl::optional<std::vector<uint8_t>> name_dns =
224 dns_names_util::DottedNameToNetwork(name);
225 if (!name_dns.has_value())
226 return false;
227
228 DnsQuery query(0, name_dns.value(), rrtype);
229 query.set_flags(0); // Remove the RD flag from the query. It is unneeded.
230
231 connection_->Send(query.io_buffer(), query.io_buffer()->size());
232 return true;
233 }
234
HandlePacket(DnsResponse * response,int bytes_read)235 void MDnsClientImpl::Core::HandlePacket(DnsResponse* response,
236 int bytes_read) {
237 unsigned offset;
238 // Note: We store cache keys rather than record pointers to avoid
239 // erroneous behavior in case a packet contains multiple exclusive
240 // records with the same type and name.
241 std::map<MDnsCache::Key, MDnsCache::UpdateType> update_keys;
242 DCHECK_GT(bytes_read, 0);
243 if (!response->InitParseWithoutQuery(bytes_read)) {
244 DVLOG(1) << "Could not understand an mDNS packet.";
245 return; // Message is unreadable.
246 }
247
248 // TODO(noamsml): duplicate query suppression.
249 if (!(response->flags() & dns_protocol::kFlagResponse))
250 return; // Message is a query. ignore it.
251
252 DnsRecordParser parser = response->Parser();
253 unsigned answer_count = response->answer_count() +
254 response->additional_answer_count();
255
256 for (unsigned i = 0; i < answer_count; i++) {
257 offset = parser.GetOffset();
258 std::unique_ptr<const RecordParsed> record =
259 RecordParsed::CreateFrom(&parser, clock_->Now());
260
261 if (!record) {
262 DVLOG(1) << "Could not understand an mDNS record.";
263
264 if (offset == parser.GetOffset()) {
265 DVLOG(1) << "Abandoned parsing the rest of the packet.";
266 return; // The parser did not advance, abort reading the packet.
267 } else {
268 continue; // We may be able to extract other records from the packet.
269 }
270 }
271
272 if ((record->klass() & dns_protocol::kMDnsClassMask) !=
273 dns_protocol::kClassIN) {
274 DVLOG(1) << "Received an mDNS record with non-IN class. Ignoring.";
275 continue; // Ignore all records not in the IN class.
276 }
277
278 MDnsCache::Key update_key = MDnsCache::Key::CreateFor(record.get());
279 MDnsCache::UpdateType update = cache_.UpdateDnsRecord(std::move(record));
280
281 // Cleanup time may have changed.
282 ScheduleCleanup(cache_.next_expiration());
283
284 update_keys.insert(std::make_pair(update_key, update));
285 }
286
287 for (const auto& update_key : update_keys) {
288 const RecordParsed* record = cache_.LookupKey(update_key.first);
289 if (!record)
290 continue;
291
292 if (record->type() == dns_protocol::kTypeNSEC) {
293 #if defined(ENABLE_NSEC)
294 NotifyNsecRecord(record);
295 #endif
296 } else {
297 AlertListeners(update_key.second,
298 ListenerKey(record->name(), record->type()), record);
299 }
300 }
301 }
302
NotifyNsecRecord(const RecordParsed * record)303 void MDnsClientImpl::Core::NotifyNsecRecord(const RecordParsed* record) {
304 DCHECK_EQ(dns_protocol::kTypeNSEC, record->type());
305 const NsecRecordRdata* rdata = record->rdata<NsecRecordRdata>();
306 DCHECK(rdata);
307
308 // Remove all cached records matching the nonexistent RR types.
309 std::vector<const RecordParsed*> records_to_remove;
310
311 cache_.FindDnsRecords(0, record->name(), &records_to_remove, clock_->Now());
312
313 for (const auto* record_to_remove : records_to_remove) {
314 if (record_to_remove->type() == dns_protocol::kTypeNSEC)
315 continue;
316 if (!rdata->GetBit(record_to_remove->type())) {
317 std::unique_ptr<const RecordParsed> record_removed =
318 cache_.RemoveRecord(record_to_remove);
319 DCHECK(record_removed);
320 OnRecordRemoved(record_removed.get());
321 }
322 }
323
324 // Alert all listeners waiting for the nonexistent RR types.
325 ListenerKey key(record->name(), 0);
326 auto i = listeners_.upper_bound(key);
327 for (; i != listeners_.end() &&
328 i->first.name_lowercase() == key.name_lowercase();
329 i++) {
330 if (!rdata->GetBit(i->first.type())) {
331 for (auto& observer : *i->second)
332 observer.AlertNsecRecord();
333 }
334 }
335 }
336
OnConnectionError(int error)337 void MDnsClientImpl::Core::OnConnectionError(int error) {
338 // TODO(noamsml): On connection error, recreate connection and flush cache.
339 VLOG(1) << "MDNS OnConnectionError (code: " << error << ")";
340 }
341
ListenerKey(const std::string & name,uint16_t type)342 MDnsClientImpl::Core::ListenerKey::ListenerKey(const std::string& name,
343 uint16_t type)
344 : name_lowercase_(base::ToLowerASCII(name)), type_(type) {}
345
operator <(const MDnsClientImpl::Core::ListenerKey & key) const346 bool MDnsClientImpl::Core::ListenerKey::operator<(
347 const MDnsClientImpl::Core::ListenerKey& key) const {
348 if (name_lowercase_ == key.name_lowercase_)
349 return type_ < key.type_;
350 return name_lowercase_ < key.name_lowercase_;
351 }
352
AlertListeners(MDnsCache::UpdateType update_type,const ListenerKey & key,const RecordParsed * record)353 void MDnsClientImpl::Core::AlertListeners(
354 MDnsCache::UpdateType update_type,
355 const ListenerKey& key,
356 const RecordParsed* record) {
357 auto listener_map_iterator = listeners_.find(key);
358 if (listener_map_iterator == listeners_.end()) return;
359
360 for (auto& observer : *listener_map_iterator->second)
361 observer.HandleRecordUpdate(update_type, record);
362 }
363
AddListener(MDnsListenerImpl * listener)364 void MDnsClientImpl::Core::AddListener(
365 MDnsListenerImpl* listener) {
366 ListenerKey key(listener->GetName(), listener->GetType());
367
368 auto& observer_list = listeners_[key];
369 if (!observer_list)
370 observer_list = std::make_unique<ObserverListType>();
371
372 observer_list->AddObserver(listener);
373 }
374
RemoveListener(MDnsListenerImpl * listener)375 void MDnsClientImpl::Core::RemoveListener(MDnsListenerImpl* listener) {
376 ListenerKey key(listener->GetName(), listener->GetType());
377 auto observer_list_iterator = listeners_.find(key);
378
379 DCHECK(observer_list_iterator != listeners_.end());
380 DCHECK(observer_list_iterator->second->HasObserver(listener));
381
382 observer_list_iterator->second->RemoveObserver(listener);
383
384 // Remove the observer list from the map if it is empty
385 if (observer_list_iterator->second->empty()) {
386 // Schedule the actual removal for later in case the listener removal
387 // happens while iterating over the observer list.
388 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
389 FROM_HERE, base::BindOnce(&MDnsClientImpl::Core::CleanupObserverList,
390 AsWeakPtr(), key));
391 }
392 }
393
CleanupObserverList(const ListenerKey & key)394 void MDnsClientImpl::Core::CleanupObserverList(const ListenerKey& key) {
395 auto found = listeners_.find(key);
396 if (found != listeners_.end() && found->second->empty()) {
397 listeners_.erase(found);
398 }
399 }
400
ScheduleCleanup(base::Time cleanup)401 void MDnsClientImpl::Core::ScheduleCleanup(base::Time cleanup) {
402 // If cache is overfilled. Force an immediate cleanup.
403 if (cache_.IsCacheOverfilled())
404 cleanup = clock_->Now();
405
406 // Cleanup is already scheduled, no need to do anything.
407 if (cleanup == scheduled_cleanup_) {
408 return;
409 }
410 scheduled_cleanup_ = cleanup;
411
412 // This cancels the previously scheduled cleanup.
413 cleanup_timer_->Stop();
414
415 // If |cleanup| is empty, then no cleanup necessary.
416 if (cleanup != base::Time()) {
417 cleanup_timer_->Start(FROM_HERE,
418 std::max(base::TimeDelta(), cleanup - clock_->Now()),
419 base::BindOnce(&MDnsClientImpl::Core::DoCleanup,
420 base::Unretained(this)));
421 }
422 }
423
DoCleanup()424 void MDnsClientImpl::Core::DoCleanup() {
425 cache_.CleanupRecords(
426 clock_->Now(), base::BindRepeating(&MDnsClientImpl::Core::OnRecordRemoved,
427 base::Unretained(this)));
428
429 ScheduleCleanup(cache_.next_expiration());
430 }
431
OnRecordRemoved(const RecordParsed * record)432 void MDnsClientImpl::Core::OnRecordRemoved(
433 const RecordParsed* record) {
434 AlertListeners(MDnsCache::RecordRemoved,
435 ListenerKey(record->name(), record->type()), record);
436 }
437
QueryCache(uint16_t rrtype,const std::string & name,std::vector<const RecordParsed * > * records) const438 void MDnsClientImpl::Core::QueryCache(
439 uint16_t rrtype,
440 const std::string& name,
441 std::vector<const RecordParsed*>* records) const {
442 cache_.FindDnsRecords(rrtype, name, records, clock_->Now());
443 }
444
MDnsClientImpl()445 MDnsClientImpl::MDnsClientImpl()
446 : clock_(base::DefaultClock::GetInstance()),
447 cleanup_timer_(std::make_unique<base::OneShotTimer>()) {}
448
MDnsClientImpl(base::Clock * clock,std::unique_ptr<base::OneShotTimer> timer)449 MDnsClientImpl::MDnsClientImpl(base::Clock* clock,
450 std::unique_ptr<base::OneShotTimer> timer)
451 : clock_(clock), cleanup_timer_(std::move(timer)) {}
452
~MDnsClientImpl()453 MDnsClientImpl::~MDnsClientImpl() {
454 StopListening();
455 }
456
StartListening(MDnsSocketFactory * socket_factory)457 int MDnsClientImpl::StartListening(MDnsSocketFactory* socket_factory) {
458 DCHECK(!core_.get());
459 core_ = std::make_unique<Core>(clock_, cleanup_timer_.get());
460 int rv = core_->Init(socket_factory);
461 if (rv != OK) {
462 DCHECK_NE(ERR_IO_PENDING, rv);
463 core_.reset();
464 }
465 return rv;
466 }
467
StopListening()468 void MDnsClientImpl::StopListening() {
469 core_.reset();
470 }
471
IsListening() const472 bool MDnsClientImpl::IsListening() const {
473 return core_.get() != nullptr;
474 }
475
CreateListener(uint16_t rrtype,const std::string & name,MDnsListener::Delegate * delegate)476 std::unique_ptr<MDnsListener> MDnsClientImpl::CreateListener(
477 uint16_t rrtype,
478 const std::string& name,
479 MDnsListener::Delegate* delegate) {
480 return std::make_unique<MDnsListenerImpl>(rrtype, name, clock_, delegate,
481 this);
482 }
483
CreateTransaction(uint16_t rrtype,const std::string & name,int flags,const MDnsTransaction::ResultCallback & callback)484 std::unique_ptr<MDnsTransaction> MDnsClientImpl::CreateTransaction(
485 uint16_t rrtype,
486 const std::string& name,
487 int flags,
488 const MDnsTransaction::ResultCallback& callback) {
489 return std::make_unique<MDnsTransactionImpl>(rrtype, name, flags, callback,
490 this);
491 }
492
MDnsListenerImpl(uint16_t rrtype,const std::string & name,base::Clock * clock,MDnsListener::Delegate * delegate,MDnsClientImpl * client)493 MDnsListenerImpl::MDnsListenerImpl(uint16_t rrtype,
494 const std::string& name,
495 base::Clock* clock,
496 MDnsListener::Delegate* delegate,
497 MDnsClientImpl* client)
498 : rrtype_(rrtype),
499 name_(name),
500 clock_(clock),
501 client_(client),
502 delegate_(delegate) {}
503
~MDnsListenerImpl()504 MDnsListenerImpl::~MDnsListenerImpl() {
505 if (started_) {
506 DCHECK(client_->core());
507 client_->core()->RemoveListener(this);
508 }
509 }
510
Start()511 bool MDnsListenerImpl::Start() {
512 DCHECK(!started_);
513
514 started_ = true;
515
516 DCHECK(client_->core());
517 client_->core()->AddListener(this);
518
519 return true;
520 }
521
SetActiveRefresh(bool active_refresh)522 void MDnsListenerImpl::SetActiveRefresh(bool active_refresh) {
523 active_refresh_ = active_refresh;
524
525 if (started_) {
526 if (!active_refresh_) {
527 next_refresh_.Cancel();
528 } else if (last_update_ != base::Time()) {
529 ScheduleNextRefresh();
530 }
531 }
532 }
533
GetName() const534 const std::string& MDnsListenerImpl::GetName() const {
535 return name_;
536 }
537
GetType() const538 uint16_t MDnsListenerImpl::GetType() const {
539 return rrtype_;
540 }
541
HandleRecordUpdate(MDnsCache::UpdateType update_type,const RecordParsed * record)542 void MDnsListenerImpl::HandleRecordUpdate(MDnsCache::UpdateType update_type,
543 const RecordParsed* record) {
544 DCHECK(started_);
545
546 if (update_type != MDnsCache::RecordRemoved) {
547 ttl_ = record->ttl();
548 last_update_ = record->time_created();
549
550 ScheduleNextRefresh();
551 }
552
553 if (update_type != MDnsCache::NoChange) {
554 MDnsListener::UpdateType update_external;
555
556 switch (update_type) {
557 case MDnsCache::RecordAdded:
558 update_external = MDnsListener::RECORD_ADDED;
559 break;
560 case MDnsCache::RecordChanged:
561 update_external = MDnsListener::RECORD_CHANGED;
562 break;
563 case MDnsCache::RecordRemoved:
564 update_external = MDnsListener::RECORD_REMOVED;
565 break;
566 case MDnsCache::NoChange:
567 default:
568 NOTREACHED();
569 // Dummy assignment to suppress compiler warning.
570 update_external = MDnsListener::RECORD_CHANGED;
571 break;
572 }
573
574 delegate_->OnRecordUpdate(update_external, record);
575 }
576 }
577
AlertNsecRecord()578 void MDnsListenerImpl::AlertNsecRecord() {
579 DCHECK(started_);
580 delegate_->OnNsecRecord(name_, rrtype_);
581 }
582
ScheduleNextRefresh()583 void MDnsListenerImpl::ScheduleNextRefresh() {
584 DCHECK(last_update_ != base::Time());
585
586 if (!active_refresh_)
587 return;
588
589 // A zero TTL is a goodbye packet and should not be refreshed.
590 if (ttl_ == 0) {
591 next_refresh_.Cancel();
592 return;
593 }
594
595 next_refresh_.Reset(
596 base::BindRepeating(&MDnsListenerImpl::DoRefresh, AsWeakPtr()));
597
598 // Schedule refreshes at both 85% and 95% of the original TTL. These will both
599 // be canceled and rescheduled if the record's TTL is updated due to a
600 // response being received.
601 base::Time next_refresh1 =
602 last_update_ +
603 base::Milliseconds(static_cast<int>(base::Time::kMillisecondsPerSecond *
604 kListenerRefreshRatio1 * ttl_));
605
606 base::Time next_refresh2 =
607 last_update_ +
608 base::Milliseconds(static_cast<int>(base::Time::kMillisecondsPerSecond *
609 kListenerRefreshRatio2 * ttl_));
610
611 base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask(
612 FROM_HERE, next_refresh_.callback(), next_refresh1 - clock_->Now());
613
614 base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask(
615 FROM_HERE, next_refresh_.callback(), next_refresh2 - clock_->Now());
616 }
617
DoRefresh()618 void MDnsListenerImpl::DoRefresh() {
619 client_->core()->SendQuery(rrtype_, name_);
620 }
621
MDnsTransactionImpl(uint16_t rrtype,const std::string & name,int flags,const MDnsTransaction::ResultCallback & callback,MDnsClientImpl * client)622 MDnsTransactionImpl::MDnsTransactionImpl(
623 uint16_t rrtype,
624 const std::string& name,
625 int flags,
626 const MDnsTransaction::ResultCallback& callback,
627 MDnsClientImpl* client)
628 : rrtype_(rrtype),
629 name_(name),
630 callback_(callback),
631 client_(client),
632 flags_(flags) {
633 DCHECK((flags_ & MDnsTransaction::FLAG_MASK) == flags_);
634 DCHECK(flags_ & MDnsTransaction::QUERY_CACHE ||
635 flags_ & MDnsTransaction::QUERY_NETWORK);
636 }
637
~MDnsTransactionImpl()638 MDnsTransactionImpl::~MDnsTransactionImpl() {
639 timeout_.Cancel();
640 }
641
Start()642 bool MDnsTransactionImpl::Start() {
643 DCHECK(!started_);
644 started_ = true;
645
646 base::WeakPtr<MDnsTransactionImpl> weak_this = AsWeakPtr();
647 if (flags_ & MDnsTransaction::QUERY_CACHE) {
648 ServeRecordsFromCache();
649
650 if (!weak_this || !is_active()) return true;
651 }
652
653 if (flags_ & MDnsTransaction::QUERY_NETWORK) {
654 return QueryAndListen();
655 }
656
657 // If this is a cache only query, signal that the transaction is over
658 // immediately.
659 SignalTransactionOver();
660 return true;
661 }
662
GetName() const663 const std::string& MDnsTransactionImpl::GetName() const {
664 return name_;
665 }
666
GetType() const667 uint16_t MDnsTransactionImpl::GetType() const {
668 return rrtype_;
669 }
670
CacheRecordFound(const RecordParsed * record)671 void MDnsTransactionImpl::CacheRecordFound(const RecordParsed* record) {
672 DCHECK(started_);
673 OnRecordUpdate(MDnsListener::RECORD_ADDED, record);
674 }
675
TriggerCallback(MDnsTransaction::Result result,const RecordParsed * record)676 void MDnsTransactionImpl::TriggerCallback(MDnsTransaction::Result result,
677 const RecordParsed* record) {
678 DCHECK(started_);
679 if (!is_active()) return;
680
681 // Ensure callback is run after touching all class state, so that
682 // the callback can delete the transaction.
683 MDnsTransaction::ResultCallback callback = callback_;
684
685 // Reset the transaction if it expects a single result, or if the result
686 // is a final one (everything except for a record).
687 if (flags_ & MDnsTransaction::SINGLE_RESULT ||
688 result != MDnsTransaction::RESULT_RECORD) {
689 Reset();
690 }
691
692 callback.Run(result, record);
693 }
694
Reset()695 void MDnsTransactionImpl::Reset() {
696 callback_.Reset();
697 listener_.reset();
698 timeout_.Cancel();
699 }
700
OnRecordUpdate(MDnsListener::UpdateType update,const RecordParsed * record)701 void MDnsTransactionImpl::OnRecordUpdate(MDnsListener::UpdateType update,
702 const RecordParsed* record) {
703 DCHECK(started_);
704 if (update == MDnsListener::RECORD_ADDED ||
705 update == MDnsListener::RECORD_CHANGED)
706 TriggerCallback(MDnsTransaction::RESULT_RECORD, record);
707 }
708
SignalTransactionOver()709 void MDnsTransactionImpl::SignalTransactionOver() {
710 DCHECK(started_);
711 if (flags_ & MDnsTransaction::SINGLE_RESULT) {
712 TriggerCallback(MDnsTransaction::RESULT_NO_RESULTS, nullptr);
713 } else {
714 TriggerCallback(MDnsTransaction::RESULT_DONE, nullptr);
715 }
716 }
717
ServeRecordsFromCache()718 void MDnsTransactionImpl::ServeRecordsFromCache() {
719 std::vector<const RecordParsed*> records;
720 base::WeakPtr<MDnsTransactionImpl> weak_this = AsWeakPtr();
721
722 if (client_->core()) {
723 client_->core()->QueryCache(rrtype_, name_, &records);
724 for (auto i = records.begin(); i != records.end() && weak_this; ++i) {
725 weak_this->TriggerCallback(MDnsTransaction::RESULT_RECORD, *i);
726 }
727
728 #if defined(ENABLE_NSEC)
729 if (records.empty()) {
730 DCHECK(weak_this);
731 client_->core()->QueryCache(dns_protocol::kTypeNSEC, name_, &records);
732 if (!records.empty()) {
733 const NsecRecordRdata* rdata =
734 records.front()->rdata<NsecRecordRdata>();
735 DCHECK(rdata);
736 if (!rdata->GetBit(rrtype_))
737 weak_this->TriggerCallback(MDnsTransaction::RESULT_NSEC, nullptr);
738 }
739 }
740 #endif
741 }
742 }
743
QueryAndListen()744 bool MDnsTransactionImpl::QueryAndListen() {
745 listener_ = client_->CreateListener(rrtype_, name_, this);
746 if (!listener_->Start())
747 return false;
748
749 DCHECK(client_->core());
750 if (!client_->core()->SendQuery(rrtype_, name_))
751 return false;
752
753 timeout_.Reset(
754 base::BindOnce(&MDnsTransactionImpl::SignalTransactionOver, AsWeakPtr()));
755 base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask(
756 FROM_HERE, timeout_.callback(), kTransactionTimeout);
757
758 return true;
759 }
760
OnNsecRecord(const std::string & name,unsigned type)761 void MDnsTransactionImpl::OnNsecRecord(const std::string& name, unsigned type) {
762 TriggerCallback(RESULT_NSEC, nullptr);
763 }
764
OnCachePurged()765 void MDnsTransactionImpl::OnCachePurged() {
766 // TODO(noamsml): Cache purge situations not yet implemented
767 }
768
769 } // namespace net
770