• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/dns/mdns_cache.h"
6 
7 #include <algorithm>
8 #include <tuple>
9 #include <utility>
10 
11 #include "base/strings/string_number_conversions.h"
12 #include "base/strings/string_util.h"
13 #include "net/dns/public/dns_protocol.h"
14 #include "net/dns/record_parsed.h"
15 #include "net/dns/record_rdata.h"
16 
17 // TODO(noamsml): Recursive CNAME closure (backwards and forwards).
18 
19 namespace net {
20 
namespace {

// Default value of |entry_limit_|. When the cache holds more entries than its
// limit, the next CleanupRecords() call clears every entry.
constexpr size_t kDefaultEntryLimit = 100'000;

// The effective TTL given to records with a nominal zero TTL.
// Allows time for hosts to send updated records, as detailed in RFC 6762
// Section 10.1.
constexpr unsigned kZeroTTLSeconds = 1;

}  // namespace
29 
Key(unsigned type,const std::string & name,const std::string & optional)30 MDnsCache::Key::Key(unsigned type,
31                     const std::string& name,
32                     const std::string& optional)
33     : type_(type),
34       name_lowercase_(base::ToLowerASCII(name)),
35       optional_(optional) {}
36 
37 MDnsCache::Key::Key(const MDnsCache::Key& other) = default;
38 
39 MDnsCache::Key& MDnsCache::Key::operator=(const MDnsCache::Key& other) =
40     default;
41 
42 MDnsCache::Key::~Key() = default;
43 
operator <(const MDnsCache::Key & other) const44 bool MDnsCache::Key::operator<(const MDnsCache::Key& other) const {
45   return std::tie(name_lowercase_, type_, optional_) <
46          std::tie(other.name_lowercase_, other.type_, other.optional_);
47 }
48 
operator ==(const MDnsCache::Key & key) const49 bool MDnsCache::Key::operator==(const MDnsCache::Key& key) const {
50   return type_ == key.type_ && name_lowercase_ == key.name_lowercase_ &&
51          optional_ == key.optional_;
52 }
53 
54 // static
CreateFor(const RecordParsed * record)55 MDnsCache::Key MDnsCache::Key::CreateFor(const RecordParsed* record) {
56   return Key(record->type(),
57              record->name(),
58              GetOptionalFieldForRecord(record));
59 }
60 
MDnsCache()61 MDnsCache::MDnsCache() : entry_limit_(kDefaultEntryLimit) {}
62 
63 MDnsCache::~MDnsCache() = default;
64 
LookupKey(const Key & key)65 const RecordParsed* MDnsCache::LookupKey(const Key& key) {
66   auto found = mdns_cache_.find(key);
67   if (found != mdns_cache_.end()) {
68     return found->second.get();
69   }
70   return nullptr;
71 }
72 
UpdateDnsRecord(std::unique_ptr<const RecordParsed> record)73 MDnsCache::UpdateType MDnsCache::UpdateDnsRecord(
74     std::unique_ptr<const RecordParsed> record) {
75   Key cache_key = Key::CreateFor(record.get());
76 
77   // Ignore "goodbye" packets for records not in cache.
78   if (record->ttl() == 0 && mdns_cache_.find(cache_key) == mdns_cache_.end())
79     return NoChange;
80 
81   base::Time new_expiration = GetEffectiveExpiration(record.get());
82   if (next_expiration_ != base::Time())
83     new_expiration = std::min(new_expiration, next_expiration_);
84 
85   std::pair<RecordMap::iterator, bool> insert_result =
86       mdns_cache_.insert(std::make_pair(cache_key, nullptr));
87   UpdateType type = NoChange;
88   if (insert_result.second) {
89     type = RecordAdded;
90   } else {
91     if (record->ttl() != 0 &&
92         !record->IsEqual(insert_result.first->second.get(), true)) {
93       type = RecordChanged;
94     }
95   }
96 
97   insert_result.first->second = std::move(record);
98   next_expiration_ = new_expiration;
99   return type;
100 }
101 
CleanupRecords(base::Time now,const RecordRemovedCallback & record_removed_callback)102 void MDnsCache::CleanupRecords(
103     base::Time now,
104     const RecordRemovedCallback& record_removed_callback) {
105   base::Time next_expiration;
106 
107   // TODO(crbug.com/946688): Make overfill pruning more intelligent than a bulk
108   // clearing of everything.
109   bool clear_cache = IsCacheOverfilled();
110 
111   // We are guaranteed that |next_expiration_| will be at or before the next
112   // expiration. This allows clients to eagrely call CleanupRecords with
113   // impunity.
114   if (now < next_expiration_ && !clear_cache)
115     return;
116 
117   for (auto i = mdns_cache_.begin(); i != mdns_cache_.end();) {
118     base::Time expiration = GetEffectiveExpiration(i->second.get());
119     if (clear_cache || now >= expiration) {
120       record_removed_callback.Run(i->second.get());
121       i = mdns_cache_.erase(i);
122     } else {
123       if (next_expiration == base::Time() ||  expiration < next_expiration) {
124         next_expiration = expiration;
125       }
126       ++i;
127     }
128   }
129 
130   next_expiration_ = next_expiration;
131 }
132 
FindDnsRecords(unsigned type,const std::string & name,std::vector<const RecordParsed * > * results,base::Time now) const133 void MDnsCache::FindDnsRecords(unsigned type,
134                                const std::string& name,
135                                std::vector<const RecordParsed*>* results,
136                                base::Time now) const {
137   DCHECK(results);
138   results->clear();
139 
140   const std::string name_lowercase = base::ToLowerASCII(name);
141   auto i = mdns_cache_.lower_bound(Key(type, name, ""));
142   for (; i != mdns_cache_.end(); ++i) {
143     if (i->first.name_lowercase() != name_lowercase ||
144         (type != 0 && i->first.type() != type)) {
145       break;
146     }
147 
148     const RecordParsed* record = i->second.get();
149 
150     // Records are deleted only upon request.
151     if (now >= GetEffectiveExpiration(record)) continue;
152 
153     results->push_back(record);
154   }
155 }
156 
RemoveRecord(const RecordParsed * record)157 std::unique_ptr<const RecordParsed> MDnsCache::RemoveRecord(
158     const RecordParsed* record) {
159   Key key = Key::CreateFor(record);
160   auto found = mdns_cache_.find(key);
161 
162   if (found != mdns_cache_.end() && found->second.get() == record) {
163     std::unique_ptr<const RecordParsed> result = std::move(found->second);
164     mdns_cache_.erase(key);
165     return result;
166   }
167 
168   return nullptr;
169 }
170 
IsCacheOverfilled() const171 bool MDnsCache::IsCacheOverfilled() const {
172   return mdns_cache_.size() > entry_limit_;
173 }
174 
175 // static
GetOptionalFieldForRecord(const RecordParsed * record)176 std::string MDnsCache::GetOptionalFieldForRecord(const RecordParsed* record) {
177   switch (record->type()) {
178     case PtrRecordRdata::kType: {
179       const PtrRecordRdata* rdata = record->rdata<PtrRecordRdata>();
180       return rdata->ptrdomain();
181     }
182     default:  // Most records are considered unique for our purposes
183       return "";
184   }
185 }
186 
187 // static
GetEffectiveExpiration(const RecordParsed * record)188 base::Time MDnsCache::GetEffectiveExpiration(const RecordParsed* record) {
189   base::TimeDelta ttl;
190 
191   if (record->ttl()) {
192     ttl = base::Seconds(record->ttl());
193   } else {
194     ttl = base::Seconds(kZeroTTLSeconds);
195   }
196 
197   return record->time_created() + ttl;
198 }
199 
200 }  // namespace net
201