1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/dns/mdns_cache.h"
6
7 #include <algorithm>
8 #include <tuple>
9 #include <utility>
10
11 #include "base/strings/string_number_conversions.h"
12 #include "base/strings/string_util.h"
13 #include "net/dns/public/dns_protocol.h"
14 #include "net/dns/record_parsed.h"
15 #include "net/dns/record_rdata.h"
16
17 // TODO(noamsml): Recursive CNAME closure (backwards and forwards).
18
19 namespace net {
20
namespace {

// Default value for |entry_limit_|. Once the cache holds more entries than the
// limit, CleanupRecords() clears it completely.
constexpr size_t kDefaultEntryLimit = 100'000;

// The effective TTL given to records with a nominal zero TTL. This allows time
// for hosts to send updated records, as detailed in RFC 6762 Section 10.1.
constexpr unsigned kZeroTTLSeconds = 1;

}  // namespace
29
Key(unsigned type,const std::string & name,const std::string & optional)30 MDnsCache::Key::Key(unsigned type,
31 const std::string& name,
32 const std::string& optional)
33 : type_(type),
34 name_lowercase_(base::ToLowerASCII(name)),
35 optional_(optional) {}
36
37 MDnsCache::Key::Key(const MDnsCache::Key& other) = default;
38
39 MDnsCache::Key& MDnsCache::Key::operator=(const MDnsCache::Key& other) =
40 default;
41
42 MDnsCache::Key::~Key() = default;
43
operator <(const MDnsCache::Key & other) const44 bool MDnsCache::Key::operator<(const MDnsCache::Key& other) const {
45 return std::tie(name_lowercase_, type_, optional_) <
46 std::tie(other.name_lowercase_, other.type_, other.optional_);
47 }
48
operator ==(const MDnsCache::Key & key) const49 bool MDnsCache::Key::operator==(const MDnsCache::Key& key) const {
50 return type_ == key.type_ && name_lowercase_ == key.name_lowercase_ &&
51 optional_ == key.optional_;
52 }
53
54 // static
CreateFor(const RecordParsed * record)55 MDnsCache::Key MDnsCache::Key::CreateFor(const RecordParsed* record) {
56 return Key(record->type(),
57 record->name(),
58 GetOptionalFieldForRecord(record));
59 }
60
MDnsCache()61 MDnsCache::MDnsCache() : entry_limit_(kDefaultEntryLimit) {}
62
63 MDnsCache::~MDnsCache() = default;
64
LookupKey(const Key & key)65 const RecordParsed* MDnsCache::LookupKey(const Key& key) {
66 auto found = mdns_cache_.find(key);
67 if (found != mdns_cache_.end()) {
68 return found->second.get();
69 }
70 return nullptr;
71 }
72
UpdateDnsRecord(std::unique_ptr<const RecordParsed> record)73 MDnsCache::UpdateType MDnsCache::UpdateDnsRecord(
74 std::unique_ptr<const RecordParsed> record) {
75 Key cache_key = Key::CreateFor(record.get());
76
77 // Ignore "goodbye" packets for records not in cache.
78 if (record->ttl() == 0 && mdns_cache_.find(cache_key) == mdns_cache_.end())
79 return NoChange;
80
81 base::Time new_expiration = GetEffectiveExpiration(record.get());
82 if (next_expiration_ != base::Time())
83 new_expiration = std::min(new_expiration, next_expiration_);
84
85 std::pair<RecordMap::iterator, bool> insert_result =
86 mdns_cache_.insert(std::make_pair(cache_key, nullptr));
87 UpdateType type = NoChange;
88 if (insert_result.second) {
89 type = RecordAdded;
90 } else {
91 if (record->ttl() != 0 &&
92 !record->IsEqual(insert_result.first->second.get(), true)) {
93 type = RecordChanged;
94 }
95 }
96
97 insert_result.first->second = std::move(record);
98 next_expiration_ = new_expiration;
99 return type;
100 }
101
CleanupRecords(base::Time now,const RecordRemovedCallback & record_removed_callback)102 void MDnsCache::CleanupRecords(
103 base::Time now,
104 const RecordRemovedCallback& record_removed_callback) {
105 base::Time next_expiration;
106
107 // TODO(crbug.com/946688): Make overfill pruning more intelligent than a bulk
108 // clearing of everything.
109 bool clear_cache = IsCacheOverfilled();
110
111 // We are guaranteed that |next_expiration_| will be at or before the next
112 // expiration. This allows clients to eagrely call CleanupRecords with
113 // impunity.
114 if (now < next_expiration_ && !clear_cache)
115 return;
116
117 for (auto i = mdns_cache_.begin(); i != mdns_cache_.end();) {
118 base::Time expiration = GetEffectiveExpiration(i->second.get());
119 if (clear_cache || now >= expiration) {
120 record_removed_callback.Run(i->second.get());
121 i = mdns_cache_.erase(i);
122 } else {
123 if (next_expiration == base::Time() || expiration < next_expiration) {
124 next_expiration = expiration;
125 }
126 ++i;
127 }
128 }
129
130 next_expiration_ = next_expiration;
131 }
132
FindDnsRecords(unsigned type,const std::string & name,std::vector<const RecordParsed * > * results,base::Time now) const133 void MDnsCache::FindDnsRecords(unsigned type,
134 const std::string& name,
135 std::vector<const RecordParsed*>* results,
136 base::Time now) const {
137 DCHECK(results);
138 results->clear();
139
140 const std::string name_lowercase = base::ToLowerASCII(name);
141 auto i = mdns_cache_.lower_bound(Key(type, name, ""));
142 for (; i != mdns_cache_.end(); ++i) {
143 if (i->first.name_lowercase() != name_lowercase ||
144 (type != 0 && i->first.type() != type)) {
145 break;
146 }
147
148 const RecordParsed* record = i->second.get();
149
150 // Records are deleted only upon request.
151 if (now >= GetEffectiveExpiration(record)) continue;
152
153 results->push_back(record);
154 }
155 }
156
RemoveRecord(const RecordParsed * record)157 std::unique_ptr<const RecordParsed> MDnsCache::RemoveRecord(
158 const RecordParsed* record) {
159 Key key = Key::CreateFor(record);
160 auto found = mdns_cache_.find(key);
161
162 if (found != mdns_cache_.end() && found->second.get() == record) {
163 std::unique_ptr<const RecordParsed> result = std::move(found->second);
164 mdns_cache_.erase(key);
165 return result;
166 }
167
168 return nullptr;
169 }
170
IsCacheOverfilled() const171 bool MDnsCache::IsCacheOverfilled() const {
172 return mdns_cache_.size() > entry_limit_;
173 }
174
175 // static
GetOptionalFieldForRecord(const RecordParsed * record)176 std::string MDnsCache::GetOptionalFieldForRecord(const RecordParsed* record) {
177 switch (record->type()) {
178 case PtrRecordRdata::kType: {
179 const PtrRecordRdata* rdata = record->rdata<PtrRecordRdata>();
180 return rdata->ptrdomain();
181 }
182 default: // Most records are considered unique for our purposes
183 return "";
184 }
185 }
186
187 // static
GetEffectiveExpiration(const RecordParsed * record)188 base::Time MDnsCache::GetEffectiveExpiration(const RecordParsed* record) {
189 base::TimeDelta ttl;
190
191 if (record->ttl()) {
192 ttl = base::Seconds(record->ttl());
193 } else {
194 ttl = base::Seconds(kZeroTTLSeconds);
195 }
196
197 return record->time_created() + ttl;
198 }
199
200 } // namespace net
201