• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/dns/mdns_client_impl.h"
6 
7 #include <algorithm>
8 #include <cstdint>
9 #include <memory>
10 #include <optional>
11 #include <utility>
12 #include <vector>
13 
14 #include "base/containers/fixed_flat_set.h"
15 #include "base/functional/bind.h"
16 #include "base/location.h"
17 #include "base/metrics/histogram_functions.h"
18 #include "base/not_fatal_until.h"
19 #include "base/observer_list.h"
20 #include "base/ranges/algorithm.h"
21 #include "base/strings/string_util.h"
22 #include "base/task/single_thread_task_runner.h"
23 #include "base/time/clock.h"
24 #include "base/time/default_clock.h"
25 #include "base/time/time.h"
26 #include "base/timer/timer.h"
27 #include "net/base/net_errors.h"
28 #include "net/base/rand_callback.h"
29 #include "net/dns/dns_names_util.h"
30 #include "net/dns/public/dns_protocol.h"
31 #include "net/dns/public/util.h"
32 #include "net/dns/record_rdata.h"
33 #include "net/socket/datagram_socket.h"
34 
// TODO(gene): Remove this temporary method of disabling NSEC support once it
// becomes clear whether this feature should be supported.
// http://crbug.com/255232
//
// While defined, MDnsClientImpl::Core::HandlePacket() forwards received NSEC
// records to NotifyNsecRecord(), and
// MDnsTransactionImpl::ServeRecordsFromCache() consults cached NSEC records
// when no record of the requested type is cached. Removing the define skips
// both code paths.
#define ENABLE_NSEC
39 
40 namespace net {
41 
42 namespace {
43 
44 // The fractions of the record's original TTL after which an active listener
45 // (one that had |SetActiveRefresh(true)| called) will send a query to refresh
46 // its cache. This happens both at 85% of the original TTL and again at 95% of
47 // the original TTL.
48 const double kListenerRefreshRatio1 = 0.85;
49 const double kListenerRefreshRatio2 = 0.95;
50 
51 // These values are persisted to logs. Entries should not be renumbered and
52 // numeric values should never be reused.
53 enum class mdnsQueryType {
54   kInitial = 0,  // Initial mDNS query sent.
55   kRefresh = 1,  // Refresh mDNS query sent.
56   kMaxValue = kRefresh,
57 };
58 
RecordQueryMetric(mdnsQueryType query_type,std::string_view host)59 void RecordQueryMetric(mdnsQueryType query_type, std::string_view host) {
60   constexpr auto kPrintScanServices = base::MakeFixedFlatSet<std::string_view>({
61       "_ipps._tcp.local",
62       "_ipp._tcp.local",
63       "_pdl-datastream._tcp.local",
64       "_printer._tcp.local",
65       "_print._sub._ipps._tcp.local",
66       "_print._sub._ipp._tcp.local",
67       "_scanner._tcp.local",
68       "_uscans._tcp.local",
69       "_uscan._tcp.local",
70   });
71 
72   if (host.ends_with("_googlecast._tcp.local")) {
73     base::UmaHistogramEnumeration("Network.Mdns.Googlecast", query_type);
74   } else if (base::ranges::any_of(kPrintScanServices,
75                                   [&host](std::string_view service) {
76                                     return host.ends_with(service);
77                                   })) {
78     base::UmaHistogramEnumeration("Network.Mdns.PrintScan", query_type);
79   } else {
80     base::UmaHistogramEnumeration("Network.Mdns.Other", query_type);
81   }
82 }
83 
84 }  // namespace
85 
CreateSockets(std::vector<std::unique_ptr<DatagramServerSocket>> * sockets)86 void MDnsSocketFactoryImpl::CreateSockets(
87     std::vector<std::unique_ptr<DatagramServerSocket>>* sockets) {
88   InterfaceIndexFamilyList interfaces(GetMDnsInterfacesToBind());
89   for (const auto& interface : interfaces) {
90     DCHECK(interface.second == ADDRESS_FAMILY_IPV4 ||
91            interface.second == ADDRESS_FAMILY_IPV6);
92     std::unique_ptr<DatagramServerSocket> socket(
93         CreateAndBindMDnsSocket(interface.second, interface.first, net_log_));
94     if (socket)
95       sockets->push_back(std::move(socket));
96   }
97 }
98 
SocketHandler(std::unique_ptr<DatagramServerSocket> socket,MDnsConnection * connection)99 MDnsConnection::SocketHandler::SocketHandler(
100     std::unique_ptr<DatagramServerSocket> socket,
101     MDnsConnection* connection)
102     : socket_(std::move(socket)),
103       connection_(connection),
104       response_(dns_protocol::kMaxMulticastSize) {}
105 
106 MDnsConnection::SocketHandler::~SocketHandler() = default;
107 
Start()108 int MDnsConnection::SocketHandler::Start() {
109   IPEndPoint end_point;
110   int rv = socket_->GetLocalAddress(&end_point);
111   if (rv != OK)
112     return rv;
113   DCHECK(end_point.GetFamily() == ADDRESS_FAMILY_IPV4 ||
114          end_point.GetFamily() == ADDRESS_FAMILY_IPV6);
115   multicast_addr_ = dns_util::GetMdnsGroupEndPoint(end_point.GetFamily());
116   return DoLoop(0);
117 }
118 
DoLoop(int rv)119 int MDnsConnection::SocketHandler::DoLoop(int rv) {
120   do {
121     if (rv > 0)
122       connection_->OnDatagramReceived(&response_, recv_addr_, rv);
123 
124     rv = socket_->RecvFrom(
125         response_.io_buffer(), response_.io_buffer_size(), &recv_addr_,
126         base::BindOnce(&MDnsConnection::SocketHandler::OnDatagramReceived,
127                        base::Unretained(this)));
128   } while (rv > 0);
129 
130   if (rv != ERR_IO_PENDING)
131     return rv;
132 
133   return OK;
134 }
135 
OnDatagramReceived(int rv)136 void MDnsConnection::SocketHandler::OnDatagramReceived(int rv) {
137   if (rv >= OK)
138     rv = DoLoop(rv);
139 
140   if (rv != OK)
141     connection_->PostOnError(this, rv);
142 }
143 
Send(const scoped_refptr<IOBuffer> & buffer,unsigned size)144 void MDnsConnection::SocketHandler::Send(const scoped_refptr<IOBuffer>& buffer,
145                                          unsigned size) {
146   if (send_in_progress_) {
147     send_queue_.emplace(buffer, size);
148     return;
149   }
150   int rv =
151       socket_->SendTo(buffer.get(), size, multicast_addr_,
152                       base::BindOnce(&MDnsConnection::SocketHandler::SendDone,
153                                      base::Unretained(this)));
154   if (rv == ERR_IO_PENDING) {
155     send_in_progress_ = true;
156   } else if (rv < OK) {
157     connection_->PostOnError(this, rv);
158   }
159 }
160 
SendDone(int rv)161 void MDnsConnection::SocketHandler::SendDone(int rv) {
162   DCHECK(send_in_progress_);
163   send_in_progress_ = false;
164   if (rv != OK)
165     connection_->PostOnError(this, rv);
166   while (!send_in_progress_ && !send_queue_.empty()) {
167     std::pair<scoped_refptr<IOBuffer>, unsigned> buffer = send_queue_.front();
168     send_queue_.pop();
169     Send(buffer.first, buffer.second);
170   }
171 }
172 
MDnsConnection(MDnsConnection::Delegate * delegate)173 MDnsConnection::MDnsConnection(MDnsConnection::Delegate* delegate)
174     : delegate_(delegate) {}
175 
176 MDnsConnection::~MDnsConnection() = default;
177 
Init(MDnsSocketFactory * socket_factory)178 int MDnsConnection::Init(MDnsSocketFactory* socket_factory) {
179   std::vector<std::unique_ptr<DatagramServerSocket>> sockets;
180   socket_factory->CreateSockets(&sockets);
181 
182   for (std::unique_ptr<DatagramServerSocket>& socket : sockets) {
183     socket_handlers_.push_back(std::make_unique<MDnsConnection::SocketHandler>(
184         std::move(socket), this));
185   }
186 
187   // All unbound sockets need to be bound before processing untrusted input.
188   // This is done for security reasons, so that an attacker can't get an unbound
189   // socket.
190   int last_failure = ERR_FAILED;
191   for (size_t i = 0; i < socket_handlers_.size();) {
192     int rv = socket_handlers_[i]->Start();
193     if (rv != OK) {
194       last_failure = rv;
195       socket_handlers_.erase(socket_handlers_.begin() + i);
196       VLOG(1) << "Start failed, socket=" << i << ", error=" << rv;
197     } else {
198       ++i;
199     }
200   }
201   VLOG(1) << "Sockets ready:" << socket_handlers_.size();
202   DCHECK_NE(ERR_IO_PENDING, last_failure);
203   return socket_handlers_.empty() ? last_failure : OK;
204 }
205 
Send(const scoped_refptr<IOBuffer> & buffer,unsigned size)206 void MDnsConnection::Send(const scoped_refptr<IOBuffer>& buffer,
207                           unsigned size) {
208   for (std::unique_ptr<SocketHandler>& handler : socket_handlers_)
209     handler->Send(buffer, size);
210 }
211 
PostOnError(SocketHandler * loop,int rv)212 void MDnsConnection::PostOnError(SocketHandler* loop, int rv) {
213   int id = 0;
214   for (const auto& it : socket_handlers_) {
215     if (it.get() == loop)
216       break;
217     id++;
218   }
219   VLOG(1) << "Socket error. id=" << id << ", error=" << rv;
220   // Post to allow deletion of this object by delegate.
221   base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
222       FROM_HERE, base::BindOnce(&MDnsConnection::OnError,
223                                 weak_ptr_factory_.GetWeakPtr(), rv));
224 }
225 
OnError(int rv)226 void MDnsConnection::OnError(int rv) {
227   // TODO(noamsml): Specific handling of intermittent errors that can be handled
228   // in the connection.
229   delegate_->OnConnectionError(rv);
230 }
231 
OnDatagramReceived(DnsResponse * response,const IPEndPoint & recv_addr,int bytes_read)232 void MDnsConnection::OnDatagramReceived(
233     DnsResponse* response,
234     const IPEndPoint& recv_addr,
235     int bytes_read) {
236   // TODO(noamsml): More sophisticated error handling.
237   DCHECK_GT(bytes_read, 0);
238   delegate_->HandlePacket(response, bytes_read);
239 }
240 
Core(base::Clock * clock,base::OneShotTimer * timer)241 MDnsClientImpl::Core::Core(base::Clock* clock, base::OneShotTimer* timer)
242     : clock_(clock),
243       cleanup_timer_(timer),
244       connection_(
245           std::make_unique<MDnsConnection>((MDnsConnection::Delegate*)this)) {
246   DCHECK(cleanup_timer_);
247   DCHECK(!cleanup_timer_->IsRunning());
248 }
249 
~Core()250 MDnsClientImpl::Core::~Core() {
251   cleanup_timer_->Stop();
252 }
253 
Init(MDnsSocketFactory * socket_factory)254 int MDnsClientImpl::Core::Init(MDnsSocketFactory* socket_factory) {
255   CHECK(!cleanup_timer_->IsRunning());
256   return connection_->Init(socket_factory);
257 }
258 
SendQuery(uint16_t rrtype,const std::string & name)259 bool MDnsClientImpl::Core::SendQuery(uint16_t rrtype, const std::string& name) {
260   std::optional<std::vector<uint8_t>> name_dns =
261       dns_names_util::DottedNameToNetwork(name);
262   if (!name_dns.has_value())
263     return false;
264 
265   DnsQuery query(0, name_dns.value(), rrtype);
266   query.set_flags(0);  // Remove the RD flag from the query. It is unneeded.
267 
268   connection_->Send(query.io_buffer(), query.io_buffer()->size());
269   return true;
270 }
271 
HandlePacket(DnsResponse * response,int bytes_read)272 void MDnsClientImpl::Core::HandlePacket(DnsResponse* response,
273                                         int bytes_read) {
274   unsigned offset;
275   // Note: We store cache keys rather than record pointers to avoid
276   // erroneous behavior in case a packet contains multiple exclusive
277   // records with the same type and name.
278   std::map<MDnsCache::Key, MDnsCache::UpdateType> update_keys;
279   DCHECK_GT(bytes_read, 0);
280   if (!response->InitParseWithoutQuery(bytes_read)) {
281     DVLOG(1) << "Could not understand an mDNS packet.";
282     return;  // Message is unreadable.
283   }
284 
285   // TODO(noamsml): duplicate query suppression.
286   if (!(response->flags() & dns_protocol::kFlagResponse))
287     return;  // Message is a query. ignore it.
288 
289   DnsRecordParser parser = response->Parser();
290   unsigned answer_count = response->answer_count() +
291       response->additional_answer_count();
292 
293   for (unsigned i = 0; i < answer_count; i++) {
294     offset = parser.GetOffset();
295     std::unique_ptr<const RecordParsed> record =
296         RecordParsed::CreateFrom(&parser, clock_->Now());
297 
298     if (!record) {
299       DVLOG(1) << "Could not understand an mDNS record.";
300 
301       if (offset == parser.GetOffset()) {
302         DVLOG(1) << "Abandoned parsing the rest of the packet.";
303         return;  // The parser did not advance, abort reading the packet.
304       } else {
305         continue;  // We may be able to extract other records from the packet.
306       }
307     }
308 
309     if ((record->klass() & dns_protocol::kMDnsClassMask) !=
310         dns_protocol::kClassIN) {
311       DVLOG(1) << "Received an mDNS record with non-IN class. Ignoring.";
312       continue;  // Ignore all records not in the IN class.
313     }
314 
315     MDnsCache::Key update_key = MDnsCache::Key::CreateFor(record.get());
316     MDnsCache::UpdateType update = cache_.UpdateDnsRecord(std::move(record));
317 
318     // Cleanup time may have changed.
319     ScheduleCleanup(cache_.next_expiration());
320 
321     update_keys.emplace(update_key, update);
322   }
323 
324   for (const auto& update_key : update_keys) {
325     const RecordParsed* record = cache_.LookupKey(update_key.first);
326     if (!record)
327       continue;
328 
329     if (record->type() == dns_protocol::kTypeNSEC) {
330 #if defined(ENABLE_NSEC)
331       NotifyNsecRecord(record);
332 #endif
333     } else {
334       AlertListeners(update_key.second,
335                      ListenerKey(record->name(), record->type()), record);
336     }
337   }
338 }
339 
NotifyNsecRecord(const RecordParsed * record)340 void MDnsClientImpl::Core::NotifyNsecRecord(const RecordParsed* record) {
341   DCHECK_EQ(dns_protocol::kTypeNSEC, record->type());
342   const NsecRecordRdata* rdata = record->rdata<NsecRecordRdata>();
343   DCHECK(rdata);
344 
345   // Remove all cached records matching the nonexistent RR types.
346   std::vector<const RecordParsed*> records_to_remove;
347 
348   cache_.FindDnsRecords(0, record->name(), &records_to_remove, clock_->Now());
349 
350   for (const auto* record_to_remove : records_to_remove) {
351     if (record_to_remove->type() == dns_protocol::kTypeNSEC)
352       continue;
353     if (!rdata->GetBit(record_to_remove->type())) {
354       std::unique_ptr<const RecordParsed> record_removed =
355           cache_.RemoveRecord(record_to_remove);
356       DCHECK(record_removed);
357       OnRecordRemoved(record_removed.get());
358     }
359   }
360 
361   // Alert all listeners waiting for the nonexistent RR types.
362   ListenerKey key(record->name(), 0);
363   auto i = listeners_.upper_bound(key);
364   for (; i != listeners_.end() &&
365          i->first.name_lowercase() == key.name_lowercase();
366        i++) {
367     if (!rdata->GetBit(i->first.type())) {
368       for (auto& observer : *i->second)
369         observer.AlertNsecRecord();
370     }
371   }
372 }
373 
OnConnectionError(int error)374 void MDnsClientImpl::Core::OnConnectionError(int error) {
375   // TODO(noamsml): On connection error, recreate connection and flush cache.
376   VLOG(1) << "MDNS OnConnectionError (code: " << error << ")";
377 }
378 
ListenerKey(const std::string & name,uint16_t type)379 MDnsClientImpl::Core::ListenerKey::ListenerKey(const std::string& name,
380                                                uint16_t type)
381     : name_lowercase_(base::ToLowerASCII(name)), type_(type) {}
382 
operator <(const MDnsClientImpl::Core::ListenerKey & key) const383 bool MDnsClientImpl::Core::ListenerKey::operator<(
384     const MDnsClientImpl::Core::ListenerKey& key) const {
385   if (name_lowercase_ == key.name_lowercase_)
386     return type_ < key.type_;
387   return name_lowercase_ < key.name_lowercase_;
388 }
389 
AlertListeners(MDnsCache::UpdateType update_type,const ListenerKey & key,const RecordParsed * record)390 void MDnsClientImpl::Core::AlertListeners(
391     MDnsCache::UpdateType update_type,
392     const ListenerKey& key,
393     const RecordParsed* record) {
394   auto listener_map_iterator = listeners_.find(key);
395   if (listener_map_iterator == listeners_.end()) return;
396 
397   for (auto& observer : *listener_map_iterator->second)
398     observer.HandleRecordUpdate(update_type, record);
399 }
400 
AddListener(MDnsListenerImpl * listener)401 void MDnsClientImpl::Core::AddListener(
402     MDnsListenerImpl* listener) {
403   ListenerKey key(listener->GetName(), listener->GetType());
404 
405   auto& observer_list = listeners_[key];
406   if (!observer_list)
407     observer_list = std::make_unique<ObserverListType>();
408 
409   observer_list->AddObserver(listener);
410 }
411 
RemoveListener(MDnsListenerImpl * listener)412 void MDnsClientImpl::Core::RemoveListener(MDnsListenerImpl* listener) {
413   ListenerKey key(listener->GetName(), listener->GetType());
414   auto observer_list_iterator = listeners_.find(key);
415 
416   CHECK(observer_list_iterator != listeners_.end(), base::NotFatalUntil::M130);
417   DCHECK(observer_list_iterator->second->HasObserver(listener));
418 
419   observer_list_iterator->second->RemoveObserver(listener);
420 
421   // Remove the observer list from the map if it is empty
422   if (observer_list_iterator->second->empty()) {
423     // Schedule the actual removal for later in case the listener removal
424     // happens while iterating over the observer list.
425     base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
426         FROM_HERE, base::BindOnce(&MDnsClientImpl::Core::CleanupObserverList,
427                                   weak_ptr_factory_.GetWeakPtr(), key));
428   }
429 }
430 
CleanupObserverList(const ListenerKey & key)431 void MDnsClientImpl::Core::CleanupObserverList(const ListenerKey& key) {
432   auto found = listeners_.find(key);
433   if (found != listeners_.end() && found->second->empty()) {
434     listeners_.erase(found);
435   }
436 }
437 
ScheduleCleanup(base::Time cleanup)438 void MDnsClientImpl::Core::ScheduleCleanup(base::Time cleanup) {
439   // If cache is overfilled. Force an immediate cleanup.
440   if (cache_.IsCacheOverfilled())
441     cleanup = clock_->Now();
442 
443   // Cleanup is already scheduled, no need to do anything.
444   if (cleanup == scheduled_cleanup_) {
445     return;
446   }
447   scheduled_cleanup_ = cleanup;
448 
449   // This cancels the previously scheduled cleanup.
450   cleanup_timer_->Stop();
451 
452   // If |cleanup| is empty, then no cleanup necessary.
453   if (cleanup != base::Time()) {
454     cleanup_timer_->Start(FROM_HERE,
455                           std::max(base::TimeDelta(), cleanup - clock_->Now()),
456                           base::BindOnce(&MDnsClientImpl::Core::DoCleanup,
457                                          base::Unretained(this)));
458   }
459 }
460 
DoCleanup()461 void MDnsClientImpl::Core::DoCleanup() {
462   cache_.CleanupRecords(
463       clock_->Now(), base::BindRepeating(&MDnsClientImpl::Core::OnRecordRemoved,
464                                          base::Unretained(this)));
465 
466   ScheduleCleanup(cache_.next_expiration());
467 }
468 
OnRecordRemoved(const RecordParsed * record)469 void MDnsClientImpl::Core::OnRecordRemoved(
470     const RecordParsed* record) {
471   AlertListeners(MDnsCache::RecordRemoved,
472                  ListenerKey(record->name(), record->type()), record);
473 }
474 
QueryCache(uint16_t rrtype,const std::string & name,std::vector<const RecordParsed * > * records) const475 void MDnsClientImpl::Core::QueryCache(
476     uint16_t rrtype,
477     const std::string& name,
478     std::vector<const RecordParsed*>* records) const {
479   cache_.FindDnsRecords(rrtype, name, records, clock_->Now());
480 }
481 
MDnsClientImpl()482 MDnsClientImpl::MDnsClientImpl()
483     : clock_(base::DefaultClock::GetInstance()),
484       cleanup_timer_(std::make_unique<base::OneShotTimer>()) {}
485 
MDnsClientImpl(base::Clock * clock,std::unique_ptr<base::OneShotTimer> timer)486 MDnsClientImpl::MDnsClientImpl(base::Clock* clock,
487                                std::unique_ptr<base::OneShotTimer> timer)
488     : clock_(clock), cleanup_timer_(std::move(timer)) {}
489 
~MDnsClientImpl()490 MDnsClientImpl::~MDnsClientImpl() {
491   StopListening();
492 }
493 
StartListening(MDnsSocketFactory * socket_factory)494 int MDnsClientImpl::StartListening(MDnsSocketFactory* socket_factory) {
495   DCHECK(!core_.get());
496   core_ = std::make_unique<Core>(clock_, cleanup_timer_.get());
497   int rv = core_->Init(socket_factory);
498   if (rv != OK) {
499     DCHECK_NE(ERR_IO_PENDING, rv);
500     core_.reset();
501   }
502   return rv;
503 }
504 
StopListening()505 void MDnsClientImpl::StopListening() {
506   core_.reset();
507 }
508 
IsListening() const509 bool MDnsClientImpl::IsListening() const {
510   return core_.get() != nullptr;
511 }
512 
CreateListener(uint16_t rrtype,const std::string & name,MDnsListener::Delegate * delegate)513 std::unique_ptr<MDnsListener> MDnsClientImpl::CreateListener(
514     uint16_t rrtype,
515     const std::string& name,
516     MDnsListener::Delegate* delegate) {
517   return std::make_unique<MDnsListenerImpl>(rrtype, name, clock_, delegate,
518                                             this);
519 }
520 
CreateTransaction(uint16_t rrtype,const std::string & name,int flags,const MDnsTransaction::ResultCallback & callback)521 std::unique_ptr<MDnsTransaction> MDnsClientImpl::CreateTransaction(
522     uint16_t rrtype,
523     const std::string& name,
524     int flags,
525     const MDnsTransaction::ResultCallback& callback) {
526   return std::make_unique<MDnsTransactionImpl>(rrtype, name, flags, callback,
527                                                this);
528 }
529 
MDnsListenerImpl(uint16_t rrtype,const std::string & name,base::Clock * clock,MDnsListener::Delegate * delegate,MDnsClientImpl * client)530 MDnsListenerImpl::MDnsListenerImpl(uint16_t rrtype,
531                                    const std::string& name,
532                                    base::Clock* clock,
533                                    MDnsListener::Delegate* delegate,
534                                    MDnsClientImpl* client)
535     : rrtype_(rrtype),
536       name_(name),
537       clock_(clock),
538       client_(client),
539       delegate_(delegate) {}
540 
~MDnsListenerImpl()541 MDnsListenerImpl::~MDnsListenerImpl() {
542   if (started_) {
543     DCHECK(client_->core());
544     client_->core()->RemoveListener(this);
545   }
546 }
547 
Start()548 bool MDnsListenerImpl::Start() {
549   DCHECK(!started_);
550 
551   started_ = true;
552 
553   DCHECK(client_->core());
554   client_->core()->AddListener(this);
555 
556   return true;
557 }
558 
SetActiveRefresh(bool active_refresh)559 void MDnsListenerImpl::SetActiveRefresh(bool active_refresh) {
560   active_refresh_ = active_refresh;
561 
562   if (started_) {
563     if (!active_refresh_) {
564       next_refresh_.Cancel();
565     } else if (last_update_ != base::Time()) {
566       ScheduleNextRefresh();
567     }
568   }
569 }
570 
GetName() const571 const std::string& MDnsListenerImpl::GetName() const {
572   return name_;
573 }
574 
GetType() const575 uint16_t MDnsListenerImpl::GetType() const {
576   return rrtype_;
577 }
578 
HandleRecordUpdate(MDnsCache::UpdateType update_type,const RecordParsed * record)579 void MDnsListenerImpl::HandleRecordUpdate(MDnsCache::UpdateType update_type,
580                                           const RecordParsed* record) {
581   DCHECK(started_);
582 
583   if (update_type != MDnsCache::RecordRemoved) {
584     ttl_ = record->ttl();
585     last_update_ = record->time_created();
586 
587     ScheduleNextRefresh();
588   }
589 
590   if (update_type != MDnsCache::NoChange) {
591     MDnsListener::UpdateType update_external;
592 
593     switch (update_type) {
594       case MDnsCache::RecordAdded:
595         update_external = MDnsListener::RECORD_ADDED;
596         break;
597       case MDnsCache::RecordChanged:
598         update_external = MDnsListener::RECORD_CHANGED;
599         break;
600       case MDnsCache::RecordRemoved:
601         update_external = MDnsListener::RECORD_REMOVED;
602         break;
603       case MDnsCache::NoChange:
604       default:
605         NOTREACHED();
606     }
607 
608     delegate_->OnRecordUpdate(update_external, record);
609   }
610 }
611 
AlertNsecRecord()612 void MDnsListenerImpl::AlertNsecRecord() {
613   DCHECK(started_);
614   delegate_->OnNsecRecord(name_, rrtype_);
615 }
616 
ScheduleNextRefresh()617 void MDnsListenerImpl::ScheduleNextRefresh() {
618   DCHECK(last_update_ != base::Time());
619 
620   if (!active_refresh_)
621     return;
622 
623   // A zero TTL is a goodbye packet and should not be refreshed.
624   if (ttl_ == 0) {
625     next_refresh_.Cancel();
626     return;
627   }
628 
629   next_refresh_.Reset(base::BindRepeating(&MDnsListenerImpl::DoRefresh,
630                                           weak_ptr_factory_.GetWeakPtr()));
631 
632   // Schedule refreshes at both 85% and 95% of the original TTL. These will both
633   // be canceled and rescheduled if the record's TTL is updated due to a
634   // response being received.
635   base::Time next_refresh1 =
636       last_update_ +
637       base::Milliseconds(static_cast<int>(base::Time::kMillisecondsPerSecond *
638                                           kListenerRefreshRatio1 * ttl_));
639 
640   base::Time next_refresh2 =
641       last_update_ +
642       base::Milliseconds(static_cast<int>(base::Time::kMillisecondsPerSecond *
643                                           kListenerRefreshRatio2 * ttl_));
644 
645   base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask(
646       FROM_HERE, next_refresh_.callback(), next_refresh1 - clock_->Now());
647 
648   base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask(
649       FROM_HERE, next_refresh_.callback(), next_refresh2 - clock_->Now());
650 }
651 
DoRefresh()652 void MDnsListenerImpl::DoRefresh() {
653   RecordQueryMetric(mdnsQueryType::kRefresh, name_);
654   client_->core()->SendQuery(rrtype_, name_);
655 }
656 
MDnsTransactionImpl(uint16_t rrtype,const std::string & name,int flags,const MDnsTransaction::ResultCallback & callback,MDnsClientImpl * client)657 MDnsTransactionImpl::MDnsTransactionImpl(
658     uint16_t rrtype,
659     const std::string& name,
660     int flags,
661     const MDnsTransaction::ResultCallback& callback,
662     MDnsClientImpl* client)
663     : rrtype_(rrtype),
664       name_(name),
665       callback_(callback),
666       client_(client),
667       flags_(flags) {
668   DCHECK((flags_ & MDnsTransaction::FLAG_MASK) == flags_);
669   DCHECK(flags_ & MDnsTransaction::QUERY_CACHE ||
670          flags_ & MDnsTransaction::QUERY_NETWORK);
671 }
672 
~MDnsTransactionImpl()673 MDnsTransactionImpl::~MDnsTransactionImpl() {
674   timeout_.Cancel();
675 }
676 
Start()677 bool MDnsTransactionImpl::Start() {
678   DCHECK(!started_);
679   started_ = true;
680 
681   base::WeakPtr<MDnsTransactionImpl> weak_this = weak_ptr_factory_.GetWeakPtr();
682   if (flags_ & MDnsTransaction::QUERY_CACHE) {
683     ServeRecordsFromCache();
684 
685     if (!weak_this || !is_active()) return true;
686   }
687 
688   if (flags_ & MDnsTransaction::QUERY_NETWORK) {
689     return QueryAndListen();
690   }
691 
692   // If this is a cache only query, signal that the transaction is over
693   // immediately.
694   SignalTransactionOver();
695   return true;
696 }
697 
GetName() const698 const std::string& MDnsTransactionImpl::GetName() const {
699   return name_;
700 }
701 
GetType() const702 uint16_t MDnsTransactionImpl::GetType() const {
703   return rrtype_;
704 }
705 
CacheRecordFound(const RecordParsed * record)706 void MDnsTransactionImpl::CacheRecordFound(const RecordParsed* record) {
707   DCHECK(started_);
708   OnRecordUpdate(MDnsListener::RECORD_ADDED, record);
709 }
710 
TriggerCallback(MDnsTransaction::Result result,const RecordParsed * record)711 void MDnsTransactionImpl::TriggerCallback(MDnsTransaction::Result result,
712                                           const RecordParsed* record) {
713   DCHECK(started_);
714   if (!is_active()) return;
715 
716   // Ensure callback is run after touching all class state, so that
717   // the callback can delete the transaction.
718   MDnsTransaction::ResultCallback callback = callback_;
719 
720   // Reset the transaction if it expects a single result, or if the result
721   // is a final one (everything except for a record).
722   if (flags_ & MDnsTransaction::SINGLE_RESULT ||
723       result != MDnsTransaction::RESULT_RECORD) {
724     Reset();
725   }
726 
727   callback.Run(result, record);
728 }
729 
Reset()730 void MDnsTransactionImpl::Reset() {
731   callback_.Reset();
732   listener_.reset();
733   timeout_.Cancel();
734 }
735 
OnRecordUpdate(MDnsListener::UpdateType update,const RecordParsed * record)736 void MDnsTransactionImpl::OnRecordUpdate(MDnsListener::UpdateType update,
737                                          const RecordParsed* record) {
738   DCHECK(started_);
739   if (update ==  MDnsListener::RECORD_ADDED ||
740       update == MDnsListener::RECORD_CHANGED)
741     TriggerCallback(MDnsTransaction::RESULT_RECORD, record);
742 }
743 
SignalTransactionOver()744 void MDnsTransactionImpl::SignalTransactionOver() {
745   DCHECK(started_);
746   if (flags_ & MDnsTransaction::SINGLE_RESULT) {
747     TriggerCallback(MDnsTransaction::RESULT_NO_RESULTS, nullptr);
748   } else {
749     TriggerCallback(MDnsTransaction::RESULT_DONE, nullptr);
750   }
751 }
752 
ServeRecordsFromCache()753 void MDnsTransactionImpl::ServeRecordsFromCache() {
754   std::vector<const RecordParsed*> records;
755   base::WeakPtr<MDnsTransactionImpl> weak_this = weak_ptr_factory_.GetWeakPtr();
756 
757   if (client_->core()) {
758     client_->core()->QueryCache(rrtype_, name_, &records);
759     for (auto i = records.begin(); i != records.end() && weak_this; ++i) {
760       weak_this->TriggerCallback(MDnsTransaction::RESULT_RECORD, *i);
761     }
762 
763 #if defined(ENABLE_NSEC)
764     if (records.empty()) {
765       DCHECK(weak_this);
766       client_->core()->QueryCache(dns_protocol::kTypeNSEC, name_, &records);
767       if (!records.empty()) {
768         const NsecRecordRdata* rdata =
769             records.front()->rdata<NsecRecordRdata>();
770         DCHECK(rdata);
771         if (!rdata->GetBit(rrtype_))
772           weak_this->TriggerCallback(MDnsTransaction::RESULT_NSEC, nullptr);
773       }
774     }
775 #endif
776   }
777 }
778 
QueryAndListen()779 bool MDnsTransactionImpl::QueryAndListen() {
780   listener_ = client_->CreateListener(rrtype_, name_, this);
781   if (!listener_->Start())
782     return false;
783 
784   DCHECK(client_->core());
785   RecordQueryMetric(mdnsQueryType::kInitial, name_);
786   if (!client_->core()->SendQuery(rrtype_, name_))
787     return false;
788 
789   timeout_.Reset(base::BindOnce(&MDnsTransactionImpl::SignalTransactionOver,
790                                 weak_ptr_factory_.GetWeakPtr()));
791   base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask(
792       FROM_HERE, timeout_.callback(), kTransactionTimeout);
793 
794   return true;
795 }
796 
OnNsecRecord(const std::string & name,unsigned type)797 void MDnsTransactionImpl::OnNsecRecord(const std::string& name, unsigned type) {
798   TriggerCallback(RESULT_NSEC, nullptr);
799 }
800 
OnCachePurged()801 void MDnsTransactionImpl::OnCachePurged() {
802   // TODO(noamsml): Cache purge situations not yet implemented
803 }
804 
805 }  // namespace net
806