• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/dns/mock_host_resolver.h"
6 
7 #include <stdint.h>
8 
9 #include <memory>
10 #include <string>
11 #include <utility>
12 #include <vector>
13 
14 #include "base/check_op.h"
15 #include "base/functional/bind.h"
16 #include "base/functional/callback_helpers.h"
17 #include "base/location.h"
18 #include "base/logging.h"
19 #include "base/memory/ptr_util.h"
20 #include "base/memory/raw_ptr.h"
21 #include "base/memory/ref_counted.h"
22 #include "base/no_destructor.h"
23 #include "base/notreached.h"
24 #include "base/strings/pattern.h"
25 #include "base/strings/string_piece.h"
26 #include "base/strings/string_split.h"
27 #include "base/strings/string_util.h"
28 #include "base/task/single_thread_task_runner.h"
29 #include "base/threading/platform_thread.h"
30 #include "base/time/default_tick_clock.h"
31 #include "base/time/tick_clock.h"
32 #include "base/time/time.h"
33 #include "base/types/optional_util.h"
34 #include "build/build_config.h"
35 #include "net/base/address_family.h"
36 #include "net/base/address_list.h"
37 #include "net/base/host_port_pair.h"
38 #include "net/base/ip_address.h"
39 #include "net/base/ip_endpoint.h"
40 #include "net/base/net_errors.h"
41 #include "net/base/net_export.h"
42 #include "net/base/network_anonymization_key.h"
43 #include "net/base/test_completion_callback.h"
44 #include "net/dns/dns_alias_utility.h"
45 #include "net/dns/dns_names_util.h"
46 #include "net/dns/dns_util.h"
47 #include "net/dns/host_cache.h"
48 #include "net/dns/host_resolver.h"
49 #include "net/dns/host_resolver_manager.h"
50 #include "net/dns/host_resolver_system_task.h"
51 #include "net/dns/https_record_rdata.h"
52 #include "net/dns/public/dns_query_type.h"
53 #include "net/dns/public/host_resolver_results.h"
54 #include "net/dns/public/host_resolver_source.h"
55 #include "net/dns/public/mdns_listener_update_type.h"
56 #include "net/dns/public/resolve_error_info.h"
57 #include "net/dns/public/secure_dns_policy.h"
58 #include "net/log/net_log_with_source.h"
59 #include "net/url_request/url_request_context.h"
60 #include "third_party/abseil-cpp/absl/types/optional.h"
61 #include "third_party/abseil-cpp/absl/types/variant.h"
62 #include "url/scheme_host_port.h"
63 
64 #if BUILDFLAG(IS_WIN)
65 #include "net/base/winsock_init.h"
66 #endif
67 
68 namespace net {
69 
70 namespace {
71 
// Maximum number of entries held by the MockCachingHostResolver's cache.
constexpr unsigned kMaxCacheEntries = 100;
// Lifetime, in seconds, of cached successful resolutions. Failures are not
// cached.
constexpr unsigned kCacheEntryTTLSeconds = 60;
76 
GetCacheHost(const HostResolver::Host & endpoint)77 absl::variant<url::SchemeHostPort, std::string> GetCacheHost(
78     const HostResolver::Host& endpoint) {
79   if (endpoint.HasScheme()) {
80     return endpoint.AsSchemeHostPort();
81   }
82 
83   return endpoint.GetHostname();
84 }
85 
CreateCacheEntry(base::StringPiece canonical_name,const std::vector<HostResolverEndpointResult> & endpoint_results,const std::set<std::string> & aliases)86 absl::optional<HostCache::Entry> CreateCacheEntry(
87     base::StringPiece canonical_name,
88     const std::vector<HostResolverEndpointResult>& endpoint_results,
89     const std::set<std::string>& aliases) {
90   absl::optional<std::vector<net::IPEndPoint>> ip_endpoints;
91   std::multimap<HttpsRecordPriority, ConnectionEndpointMetadata>
92       endpoint_metadatas;
93   for (const auto& endpoint_result : endpoint_results) {
94     if (!ip_endpoints) {
95       ip_endpoints = endpoint_result.ip_endpoints;
96     } else {
97       // TODO(crbug.com/1264933): Support caching different IP endpoints
98       // resutls.
99       CHECK(*ip_endpoints == endpoint_result.ip_endpoints)
100           << "Currently caching MockHostResolver only supports same IP "
101              "endpoints results.";
102     }
103 
104     if (!endpoint_result.metadata.supported_protocol_alpns.empty()) {
105       endpoint_metadatas.emplace(/*priority=*/1, endpoint_result.metadata);
106     }
107   }
108   DCHECK(ip_endpoints);
109   auto endpoint_entry = HostCache::Entry(OK, *ip_endpoints, aliases,
110                                          HostCache::Entry::SOURCE_UNKNOWN);
111   endpoint_entry.set_canonical_names(std::set{std::string(canonical_name)});
112   if (endpoint_metadatas.empty()) {
113     return endpoint_entry;
114   }
115   return HostCache::Entry::MergeEntries(
116       HostCache::Entry(OK, std::move(endpoint_metadatas),
117                        HostCache::Entry::SOURCE_UNKNOWN),
118       endpoint_entry);
119 }
120 }  // namespace
121 
ParseAddressList(base::StringPiece host_list,std::vector<net::IPEndPoint> * ip_endpoints)122 int ParseAddressList(base::StringPiece host_list,
123                      std::vector<net::IPEndPoint>* ip_endpoints) {
124   ip_endpoints->clear();
125   for (base::StringPiece address : base::SplitStringPiece(
126            host_list, ",", base::TRIM_WHITESPACE, base::SPLIT_WANT_ALL)) {
127     IPAddress ip_address;
128     if (!ip_address.AssignFromIPLiteral(address)) {
129       LOG(WARNING) << "Not a supported IP literal: " << address;
130       return ERR_UNEXPECTED;
131     }
132     ip_endpoints->push_back(IPEndPoint(ip_address, 0));
133   }
134   return OK;
135 }
136 
137 class MockHostResolverBase::RequestImpl
138     : public HostResolver::ResolveHostRequest {
139  public:
RequestImpl(Host request_endpoint,const NetworkAnonymizationKey & network_anonymization_key,const absl::optional<ResolveHostParameters> & optional_parameters,base::WeakPtr<MockHostResolverBase> resolver)140   RequestImpl(Host request_endpoint,
141               const NetworkAnonymizationKey& network_anonymization_key,
142               const absl::optional<ResolveHostParameters>& optional_parameters,
143               base::WeakPtr<MockHostResolverBase> resolver)
144       : request_endpoint_(std::move(request_endpoint)),
145         network_anonymization_key_(network_anonymization_key),
146         parameters_(optional_parameters ? optional_parameters.value()
147                                         : ResolveHostParameters()),
148         priority_(parameters_.initial_priority),
149         host_resolver_flags_(ParametersToHostResolverFlags(parameters_)),
150         resolve_error_info_(ResolveErrorInfo(ERR_IO_PENDING)),
151         resolver_(resolver) {}
152 
153   RequestImpl(const RequestImpl&) = delete;
154   RequestImpl& operator=(const RequestImpl&) = delete;
155 
~RequestImpl()156   ~RequestImpl() override {
157     if (id_ > 0) {
158       if (resolver_)
159         resolver_->DetachRequest(id_);
160       id_ = 0;
161       resolver_ = nullptr;
162     }
163   }
164 
DetachFromResolver()165   void DetachFromResolver() {
166     id_ = 0;
167     resolver_ = nullptr;
168   }
169 
Start(CompletionOnceCallback callback)170   int Start(CompletionOnceCallback callback) override {
171     DCHECK(callback);
172     // Start() may only be called once per request.
173     DCHECK_EQ(0u, id_);
174     DCHECK(!complete_);
175     DCHECK(!callback_);
176     // Parent HostResolver must still be alive to call Start().
177     DCHECK(resolver_);
178 
179     int rv = resolver_->Resolve(this);
180     DCHECK(!complete_);
181     if (rv == ERR_IO_PENDING) {
182       DCHECK_GT(id_, 0u);
183       callback_ = std::move(callback);
184     } else {
185       DCHECK_EQ(0u, id_);
186       complete_ = true;
187     }
188 
189     return rv;
190   }
191 
GetAddressResults() const192   const AddressList* GetAddressResults() const override {
193     DCHECK(complete_);
194     return base::OptionalToPtr(address_results_);
195   }
196 
GetEndpointResults() const197   const std::vector<HostResolverEndpointResult>* GetEndpointResults()
198       const override {
199     DCHECK(complete_);
200     return base::OptionalToPtr(endpoint_results_);
201   }
202 
GetTextResults() const203   const absl::optional<std::vector<std::string>>& GetTextResults()
204       const override {
205     DCHECK(complete_);
206     static const base::NoDestructor<absl::optional<std::vector<std::string>>>
207         nullopt_result;
208     return *nullopt_result;
209   }
210 
GetHostnameResults() const211   const absl::optional<std::vector<HostPortPair>>& GetHostnameResults()
212       const override {
213     DCHECK(complete_);
214     static const base::NoDestructor<absl::optional<std::vector<HostPortPair>>>
215         nullopt_result;
216     return *nullopt_result;
217   }
218 
GetDnsAliasResults() const219   const std::set<std::string>* GetDnsAliasResults() const override {
220     DCHECK(complete_);
221     return base::OptionalToPtr(fixed_up_dns_alias_results_);
222   }
223 
GetResolveErrorInfo() const224   net::ResolveErrorInfo GetResolveErrorInfo() const override {
225     DCHECK(complete_);
226     return resolve_error_info_;
227   }
228 
GetStaleInfo() const229   const absl::optional<HostCache::EntryStaleness>& GetStaleInfo()
230       const override {
231     DCHECK(complete_);
232     return staleness_;
233   }
234 
ChangeRequestPriority(RequestPriority priority)235   void ChangeRequestPriority(RequestPriority priority) override {
236     priority_ = priority;
237   }
238 
SetError(int error)239   void SetError(int error) {
240     // Should only be called before request is marked completed.
241     DCHECK(!complete_);
242     resolve_error_info_ = ResolveErrorInfo(error);
243   }
244 
245   // Sets `endpoint_results_`, `fixed_up_dns_alias_results_`,
246   // `address_results_` and `staleness_` after fixing them up.
247   // Also sets `error` to OK.
SetEndpointResults(std::vector<HostResolverEndpointResult> endpoint_results,std::set<std::string> aliases,absl::optional<HostCache::EntryStaleness> staleness)248   void SetEndpointResults(
249       std::vector<HostResolverEndpointResult> endpoint_results,
250       std::set<std::string> aliases,
251       absl::optional<HostCache::EntryStaleness> staleness) {
252     DCHECK(!complete_);
253     DCHECK(!endpoint_results_);
254     DCHECK(!parameters_.is_speculative);
255 
256     endpoint_results_ = std::move(endpoint_results);
257     for (auto& result : *endpoint_results_) {
258       result.ip_endpoints = FixupEndPoints(result.ip_endpoints);
259     }
260 
261     fixed_up_dns_alias_results_ = FixupAliases(aliases);
262 
263     // `HostResolver` implementations are expected to provide an `AddressList`
264     // result whenever `HostResolverEndpointResult` is also available.
265     address_results_ = EndpointResultToAddressList(
266         *endpoint_results_, *fixed_up_dns_alias_results_);
267 
268     staleness_ = std::move(staleness);
269 
270     SetError(OK);
271   }
272 
OnAsyncCompleted(size_t id,int error)273   void OnAsyncCompleted(size_t id, int error) {
274     DCHECK_EQ(id_, id);
275     id_ = 0;
276 
277     // Check that error information has been set and that the top-level error
278     // code is valid.
279     DCHECK(resolve_error_info_.error != ERR_IO_PENDING);
280     DCHECK(error == OK || error == ERR_NAME_NOT_RESOLVED ||
281            error == ERR_DNS_NAME_HTTPS_ONLY);
282 
283     DCHECK(!complete_);
284     complete_ = true;
285 
286     DCHECK(callback_);
287     std::move(callback_).Run(error);
288   }
289 
request_endpoint() const290   const Host& request_endpoint() const { return request_endpoint_; }
291 
network_anonymization_key() const292   const NetworkAnonymizationKey& network_anonymization_key() const {
293     return network_anonymization_key_;
294   }
295 
parameters() const296   const ResolveHostParameters& parameters() const { return parameters_; }
297 
host_resolver_flags() const298   int host_resolver_flags() const { return host_resolver_flags_; }
299 
id()300   size_t id() { return id_; }
301 
priority() const302   RequestPriority priority() const { return priority_; }
303 
set_id(size_t id)304   void set_id(size_t id) {
305     DCHECK_GT(id, 0u);
306     DCHECK_EQ(0u, id_);
307 
308     id_ = id;
309   }
310 
complete()311   bool complete() { return complete_; }
312 
313   // Similar get GetAddressResults() and GetResolveErrorInfo(), but only exposed
314   // through the HostResolver::ResolveHostRequest interface, and don't have the
315   // DCHECKs that `complete_` is true.
address_results() const316   const absl::optional<AddressList>& address_results() const {
317     return address_results_;
318   }
resolve_error_info() const319   ResolveErrorInfo resolve_error_info() const { return resolve_error_info_; }
320 
321  private:
FixupEndPoints(const std::vector<IPEndPoint> & endpoints)322   std::vector<IPEndPoint> FixupEndPoints(
323       const std::vector<IPEndPoint>& endpoints) {
324     std::vector<IPEndPoint> corrected;
325     for (const IPEndPoint& endpoint : endpoints) {
326       DCHECK_NE(endpoint.GetFamily(), ADDRESS_FAMILY_UNSPECIFIED);
327       if (parameters_.dns_query_type == DnsQueryType::UNSPECIFIED ||
328           parameters_.dns_query_type ==
329               AddressFamilyToDnsQueryType(endpoint.GetFamily())) {
330         if (endpoint.port() == 0) {
331           corrected.emplace_back(endpoint.address(),
332                                  request_endpoint_.GetPort());
333         } else {
334           corrected.push_back(endpoint);
335         }
336       }
337     }
338     return corrected;
339   }
FixupAliases(const std::set<std::string> aliases)340   std::set<std::string> FixupAliases(const std::set<std::string> aliases) {
341     if (aliases.empty())
342       return std::set<std::string>{
343           std::string(request_endpoint_.GetHostnameWithoutBrackets())};
344     return aliases;
345   }
346 
347   const Host request_endpoint_;
348   const NetworkAnonymizationKey network_anonymization_key_;
349   const ResolveHostParameters parameters_;
350   RequestPriority priority_;
351   int host_resolver_flags_;
352 
353   absl::optional<AddressList> address_results_;
354   absl::optional<std::vector<HostResolverEndpointResult>> endpoint_results_;
355   absl::optional<std::set<std::string>> fixed_up_dns_alias_results_;
356   absl::optional<HostCache::EntryStaleness> staleness_;
357   ResolveErrorInfo resolve_error_info_;
358 
359   // Used while stored with the resolver for async resolution.  Otherwise 0.
360   size_t id_ = 0;
361 
362   CompletionOnceCallback callback_;
363   // Use a WeakPtr as the resolver may be destroyed while there are still
364   // outstanding request objects.
365   base::WeakPtr<MockHostResolverBase> resolver_;
366   bool complete_ = false;
367 };
368 
369 class MockHostResolverBase::ProbeRequestImpl
370     : public HostResolver::ProbeRequest {
371  public:
ProbeRequestImpl(base::WeakPtr<MockHostResolverBase> resolver)372   explicit ProbeRequestImpl(base::WeakPtr<MockHostResolverBase> resolver)
373       : resolver_(std::move(resolver)) {}
374 
375   ProbeRequestImpl(const ProbeRequestImpl&) = delete;
376   ProbeRequestImpl& operator=(const ProbeRequestImpl&) = delete;
377 
~ProbeRequestImpl()378   ~ProbeRequestImpl() override {
379     if (resolver_) {
380       resolver_->state_->ClearDohProbeRequestIfMatching(this);
381     }
382   }
383 
Start()384   int Start() override {
385     DCHECK(resolver_);
386     resolver_->state_->set_doh_probe_request(this);
387 
388     return ERR_IO_PENDING;
389   }
390 
391  private:
392   base::WeakPtr<MockHostResolverBase> resolver_;
393 };
394 
395 class MockHostResolverBase::MdnsListenerImpl
396     : public HostResolver::MdnsListener {
397  public:
MdnsListenerImpl(const HostPortPair & host,DnsQueryType query_type,base::WeakPtr<MockHostResolverBase> resolver)398   MdnsListenerImpl(const HostPortPair& host,
399                    DnsQueryType query_type,
400                    base::WeakPtr<MockHostResolverBase> resolver)
401       : host_(host), query_type_(query_type), resolver_(resolver) {
402     DCHECK_NE(DnsQueryType::UNSPECIFIED, query_type_);
403     DCHECK(resolver_);
404   }
405 
~MdnsListenerImpl()406   ~MdnsListenerImpl() override {
407     if (resolver_)
408       resolver_->RemoveCancelledListener(this);
409   }
410 
Start(Delegate * delegate)411   int Start(Delegate* delegate) override {
412     DCHECK(delegate);
413     DCHECK(!delegate_);
414     DCHECK(resolver_);
415 
416     delegate_ = delegate;
417     resolver_->AddListener(this);
418 
419     return OK;
420   }
421 
TriggerAddressResult(MdnsListenerUpdateType update_type,IPEndPoint address)422   void TriggerAddressResult(MdnsListenerUpdateType update_type,
423                             IPEndPoint address) {
424     delegate_->OnAddressResult(update_type, query_type_, std::move(address));
425   }
426 
TriggerTextResult(MdnsListenerUpdateType update_type,std::vector<std::string> text_records)427   void TriggerTextResult(MdnsListenerUpdateType update_type,
428                          std::vector<std::string> text_records) {
429     delegate_->OnTextResult(update_type, query_type_, std::move(text_records));
430   }
431 
TriggerHostnameResult(MdnsListenerUpdateType update_type,HostPortPair host)432   void TriggerHostnameResult(MdnsListenerUpdateType update_type,
433                              HostPortPair host) {
434     delegate_->OnHostnameResult(update_type, query_type_, std::move(host));
435   }
436 
TriggerUnhandledResult(MdnsListenerUpdateType update_type)437   void TriggerUnhandledResult(MdnsListenerUpdateType update_type) {
438     delegate_->OnUnhandledResult(update_type, query_type_);
439   }
440 
host() const441   const HostPortPair& host() const { return host_; }
query_type() const442   DnsQueryType query_type() const { return query_type_; }
443 
444  private:
445   const HostPortPair host_;
446   const DnsQueryType query_type_;
447 
448   raw_ptr<Delegate> delegate_ = nullptr;
449 
450   // Use a WeakPtr as the resolver may be destroyed while there are still
451   // outstanding listener objects.
452   base::WeakPtr<MockHostResolverBase> resolver_;
453 };
454 
455 MockHostResolverBase::RuleResolver::RuleKey::RuleKey() = default;
456 
457 MockHostResolverBase::RuleResolver::RuleKey::~RuleKey() = default;
458 
459 MockHostResolverBase::RuleResolver::RuleKey::RuleKey(const RuleKey&) = default;
460 
461 MockHostResolverBase::RuleResolver::RuleKey&
462 MockHostResolverBase::RuleResolver::RuleKey::operator=(const RuleKey&) =
463     default;
464 
465 MockHostResolverBase::RuleResolver::RuleKey::RuleKey(RuleKey&&) = default;
466 
467 MockHostResolverBase::RuleResolver::RuleKey&
468 MockHostResolverBase::RuleResolver::RuleKey::operator=(RuleKey&&) = default;
469 
470 MockHostResolverBase::RuleResolver::RuleResult::RuleResult() = default;
471 
RuleResult(std::vector<HostResolverEndpointResult> endpoints,std::set<std::string> aliases)472 MockHostResolverBase::RuleResolver::RuleResult::RuleResult(
473     std::vector<HostResolverEndpointResult> endpoints,
474     std::set<std::string> aliases)
475     : endpoints(std::move(endpoints)), aliases(std::move(aliases)) {}
476 
477 MockHostResolverBase::RuleResolver::RuleResult::~RuleResult() = default;
478 
479 MockHostResolverBase::RuleResolver::RuleResult::RuleResult(const RuleResult&) =
480     default;
481 
482 MockHostResolverBase::RuleResolver::RuleResult&
483 MockHostResolverBase::RuleResolver::RuleResult::operator=(const RuleResult&) =
484     default;
485 
486 MockHostResolverBase::RuleResolver::RuleResult::RuleResult(RuleResult&&) =
487     default;
488 
489 MockHostResolverBase::RuleResolver::RuleResult&
490 MockHostResolverBase::RuleResolver::RuleResult::operator=(RuleResult&&) =
491     default;
492 
RuleResolver(absl::optional<RuleResultOrError> default_result)493 MockHostResolverBase::RuleResolver::RuleResolver(
494     absl::optional<RuleResultOrError> default_result)
495     : default_result_(std::move(default_result)) {}
496 
497 MockHostResolverBase::RuleResolver::~RuleResolver() = default;
498 
499 MockHostResolverBase::RuleResolver::RuleResolver(const RuleResolver&) = default;
500 
501 MockHostResolverBase::RuleResolver&
502 MockHostResolverBase::RuleResolver::operator=(const RuleResolver&) = default;
503 
504 MockHostResolverBase::RuleResolver::RuleResolver(RuleResolver&&) = default;
505 
506 MockHostResolverBase::RuleResolver&
507 MockHostResolverBase::RuleResolver::operator=(RuleResolver&&) = default;
508 
509 const MockHostResolverBase::RuleResolver::RuleResultOrError&
Resolve(const Host & request_endpoint,DnsQueryTypeSet request_types,HostResolverSource request_source) const510 MockHostResolverBase::RuleResolver::Resolve(
511     const Host& request_endpoint,
512     DnsQueryTypeSet request_types,
513     HostResolverSource request_source) const {
514   for (const auto& rule : rules_) {
515     const RuleKey& key = rule.first;
516     const RuleResultOrError& result = rule.second;
517 
518     if (absl::holds_alternative<RuleKey::NoScheme>(key.scheme) &&
519         request_endpoint.HasScheme()) {
520       continue;
521     }
522 
523     if (key.port.has_value() &&
524         key.port.value() != request_endpoint.GetPort()) {
525       continue;
526     }
527 
528     DCHECK(!key.query_type.has_value() ||
529            key.query_type.value() != DnsQueryType::UNSPECIFIED);
530     if (key.query_type.has_value() &&
531         !request_types.Has(key.query_type.value())) {
532       continue;
533     }
534 
535     if (key.query_source.has_value() &&
536         request_source != key.query_source.value()) {
537       continue;
538     }
539 
540     if (absl::holds_alternative<RuleKey::Scheme>(key.scheme) &&
541         (!request_endpoint.HasScheme() ||
542          request_endpoint.GetScheme() !=
543              absl::get<RuleKey::Scheme>(key.scheme))) {
544       continue;
545     }
546 
547     if (!base::MatchPattern(request_endpoint.GetHostnameWithoutBrackets(),
548                             key.hostname_pattern)) {
549       continue;
550     }
551 
552     return result;
553   }
554 
555   if (default_result_)
556     return default_result_.value();
557 
558   NOTREACHED() << "Request " << request_endpoint.GetHostname()
559                << " did not match any MockHostResolver rules.";
560   static const RuleResultOrError kUnexpected = ERR_UNEXPECTED;
561   return kUnexpected;
562 }
563 
ClearRules()564 void MockHostResolverBase::RuleResolver::ClearRules() {
565   rules_.clear();
566 }
567 
568 // static
569 MockHostResolverBase::RuleResolver::RuleResultOrError
GetLocalhostResult()570 MockHostResolverBase::RuleResolver::GetLocalhostResult() {
571   HostResolverEndpointResult endpoint;
572   endpoint.ip_endpoints = {IPEndPoint(IPAddress::IPv4Localhost(), /*port=*/0)};
573   return RuleResult(std::vector{endpoint});
574 }
575 
AddRule(RuleKey key,RuleResultOrError result)576 void MockHostResolverBase::RuleResolver::AddRule(RuleKey key,
577                                                  RuleResultOrError result) {
578   // Literals are always resolved to themselves by MockHostResolverBase,
579   // consequently we do not support remapping them.
580   IPAddress ip_address;
581   DCHECK(!ip_address.AssignFromIPLiteral(key.hostname_pattern));
582 
583   CHECK(rules_.emplace(std::move(key), std::move(result)).second)
584       << "Duplicate rule key";
585 }
586 
AddRule(RuleKey key,base::StringPiece ip_literal)587 void MockHostResolverBase::RuleResolver::AddRule(RuleKey key,
588                                                  base::StringPiece ip_literal) {
589   std::vector<HostResolverEndpointResult> endpoints;
590   endpoints.emplace_back();
591   CHECK_EQ(ParseAddressList(ip_literal, &endpoints[0].ip_endpoints), OK);
592   AddRule(std::move(key), RuleResult(std::move(endpoints)));
593 }
594 
AddRule(base::StringPiece hostname_pattern,RuleResultOrError result)595 void MockHostResolverBase::RuleResolver::AddRule(
596     base::StringPiece hostname_pattern,
597     RuleResultOrError result) {
598   RuleKey key;
599   key.hostname_pattern = std::string(hostname_pattern);
600   AddRule(std::move(key), std::move(result));
601 }
602 
AddRule(base::StringPiece hostname_pattern,base::StringPiece ip_literal)603 void MockHostResolverBase::RuleResolver::AddRule(
604     base::StringPiece hostname_pattern,
605     base::StringPiece ip_literal) {
606   std::vector<HostResolverEndpointResult> endpoints;
607   endpoints.emplace_back();
608   CHECK_EQ(ParseAddressList(ip_literal, &endpoints[0].ip_endpoints), OK);
609   AddRule(hostname_pattern, RuleResult(std::move(endpoints)));
610 }
611 
AddRule(base::StringPiece hostname_pattern,Error error)612 void MockHostResolverBase::RuleResolver::AddRule(
613     base::StringPiece hostname_pattern,
614     Error error) {
615   RuleKey key;
616   key.hostname_pattern = std::string(hostname_pattern);
617 
618   AddRule(std::move(key), error);
619 }
620 
AddIPLiteralRule(base::StringPiece hostname_pattern,base::StringPiece ip_literal,base::StringPiece canonical_name)621 void MockHostResolverBase::RuleResolver::AddIPLiteralRule(
622     base::StringPiece hostname_pattern,
623     base::StringPiece ip_literal,
624     base::StringPiece canonical_name) {
625   RuleKey key;
626   key.hostname_pattern = std::string(hostname_pattern);
627 
628   std::set<std::string> aliases;
629   if (!canonical_name.empty())
630     aliases.emplace(canonical_name);
631 
632   std::vector<HostResolverEndpointResult> endpoints;
633   endpoints.emplace_back();
634   CHECK_EQ(ParseAddressList(ip_literal, &endpoints[0].ip_endpoints), OK);
635   AddRule(std::move(key), RuleResult(std::move(endpoints), std::move(aliases)));
636 }
637 
AddIPLiteralRuleWithDnsAliases(base::StringPiece hostname_pattern,base::StringPiece ip_literal,std::vector<std::string> dns_aliases)638 void MockHostResolverBase::RuleResolver::AddIPLiteralRuleWithDnsAliases(
639     base::StringPiece hostname_pattern,
640     base::StringPiece ip_literal,
641     std::vector<std::string> dns_aliases) {
642   std::vector<HostResolverEndpointResult> endpoints;
643   endpoints.emplace_back();
644   CHECK_EQ(ParseAddressList(ip_literal, &endpoints[0].ip_endpoints), OK);
645   AddRule(hostname_pattern,
646           RuleResult(
647               std::move(endpoints),
648               std::set<std::string>(dns_aliases.begin(), dns_aliases.end())));
649 }
650 
AddIPLiteralRuleWithDnsAliases(base::StringPiece hostname_pattern,base::StringPiece ip_literal,std::set<std::string> dns_aliases)651 void MockHostResolverBase::RuleResolver::AddIPLiteralRuleWithDnsAliases(
652     base::StringPiece hostname_pattern,
653     base::StringPiece ip_literal,
654     std::set<std::string> dns_aliases) {
655   std::vector<std::string> aliases_vector;
656   base::ranges::move(dns_aliases, std::back_inserter(aliases_vector));
657 
658   AddIPLiteralRuleWithDnsAliases(hostname_pattern, ip_literal,
659                                  std::move(aliases_vector));
660 }
661 
AddSimulatedFailure(base::StringPiece hostname_pattern)662 void MockHostResolverBase::RuleResolver::AddSimulatedFailure(
663     base::StringPiece hostname_pattern) {
664   AddRule(hostname_pattern, ERR_NAME_NOT_RESOLVED);
665 }
666 
AddSimulatedTimeoutFailure(base::StringPiece hostname_pattern)667 void MockHostResolverBase::RuleResolver::AddSimulatedTimeoutFailure(
668     base::StringPiece hostname_pattern) {
669   AddRule(hostname_pattern, ERR_DNS_TIMED_OUT);
670 }
671 
AddRuleWithFlags(base::StringPiece host_pattern,base::StringPiece ip_literal,HostResolverFlags,std::vector<std::string> dns_aliases)672 void MockHostResolverBase::RuleResolver::AddRuleWithFlags(
673     base::StringPiece host_pattern,
674     base::StringPiece ip_literal,
675     HostResolverFlags /*flags*/,
676     std::vector<std::string> dns_aliases) {
677   std::vector<HostResolverEndpointResult> endpoints;
678   endpoints.emplace_back();
679   CHECK_EQ(ParseAddressList(ip_literal, &endpoints[0].ip_endpoints), OK);
680   AddRule(host_pattern, RuleResult(std::move(endpoints),
681                                    std::set<std::string>(dns_aliases.begin(),
682                                                          dns_aliases.end())));
683 }
684 
685 MockHostResolverBase::State::State() = default;
686 MockHostResolverBase::State::~State() = default;
687 
~MockHostResolverBase()688 MockHostResolverBase::~MockHostResolverBase() {
689   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
690 
691   // Sanity check that pending requests are always cleaned up, by waiting for
692   // completion, manually cancelling, or calling OnShutdown().
693   DCHECK(!state_->has_pending_requests());
694 }
695 
OnShutdown()696 void MockHostResolverBase::OnShutdown() {
697   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
698 
699   // Cancel all pending requests.
700   for (auto& request : state_->mutable_requests()) {
701     request.second->DetachFromResolver();
702   }
703   state_->mutable_requests().clear();
704 
705   // Prevent future requests by clearing resolution rules and the cache.
706   rule_resolver_.ClearRules();
707   cache_ = nullptr;
708 
709   state_->ClearDohProbeRequest();
710 }
711 
712 std::unique_ptr<HostResolver::ResolveHostRequest>
CreateRequest(url::SchemeHostPort host,NetworkAnonymizationKey network_anonymization_key,NetLogWithSource net_log,absl::optional<ResolveHostParameters> optional_parameters)713 MockHostResolverBase::CreateRequest(
714     url::SchemeHostPort host,
715     NetworkAnonymizationKey network_anonymization_key,
716     NetLogWithSource net_log,
717     absl::optional<ResolveHostParameters> optional_parameters) {
718   return std::make_unique<RequestImpl>(Host(std::move(host)),
719                                        network_anonymization_key,
720                                        optional_parameters, AsWeakPtr());
721 }
722 
723 std::unique_ptr<HostResolver::ResolveHostRequest>
CreateRequest(const HostPortPair & host,const NetworkAnonymizationKey & network_anonymization_key,const NetLogWithSource & source_net_log,const absl::optional<ResolveHostParameters> & optional_parameters)724 MockHostResolverBase::CreateRequest(
725     const HostPortPair& host,
726     const NetworkAnonymizationKey& network_anonymization_key,
727     const NetLogWithSource& source_net_log,
728     const absl::optional<ResolveHostParameters>& optional_parameters) {
729   return std::make_unique<RequestImpl>(Host(host), network_anonymization_key,
730                                        optional_parameters, AsWeakPtr());
731 }
732 
733 std::unique_ptr<HostResolver::ProbeRequest>
CreateDohProbeRequest()734 MockHostResolverBase::CreateDohProbeRequest() {
735   return std::make_unique<ProbeRequestImpl>(AsWeakPtr());
736 }
737 
738 std::unique_ptr<HostResolver::MdnsListener>
CreateMdnsListener(const HostPortPair & host,DnsQueryType query_type)739 MockHostResolverBase::CreateMdnsListener(const HostPortPair& host,
740                                          DnsQueryType query_type) {
741   return std::make_unique<MdnsListenerImpl>(host, query_type, AsWeakPtr());
742 }
743 
GetHostCache()744 HostCache* MockHostResolverBase::GetHostCache() {
745   return cache_.get();
746 }
747 
LoadIntoCache(absl::variant<url::SchemeHostPort,HostPortPair> endpoint,const NetworkAnonymizationKey & network_anonymization_key,const absl::optional<ResolveHostParameters> & optional_parameters)748 int MockHostResolverBase::LoadIntoCache(
749     absl::variant<url::SchemeHostPort, HostPortPair> endpoint,
750     const NetworkAnonymizationKey& network_anonymization_key,
751     const absl::optional<ResolveHostParameters>& optional_parameters) {
752   return LoadIntoCache(Host(std::move(endpoint)), network_anonymization_key,
753                        optional_parameters);
754 }
755 
LoadIntoCache(const Host & endpoint,const NetworkAnonymizationKey & network_anonymization_key,const absl::optional<ResolveHostParameters> & optional_parameters)756 int MockHostResolverBase::LoadIntoCache(
757     const Host& endpoint,
758     const NetworkAnonymizationKey& network_anonymization_key,
759     const absl::optional<ResolveHostParameters>& optional_parameters) {
760   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
761   DCHECK(cache_);
762 
763   ResolveHostParameters parameters =
764       optional_parameters.value_or(ResolveHostParameters());
765 
766   std::vector<HostResolverEndpointResult> endpoints;
767   std::set<std::string> aliases;
768   absl::optional<HostCache::EntryStaleness> stale_info;
769   int rv = ResolveFromIPLiteralOrCache(
770       endpoint, network_anonymization_key, parameters.dns_query_type,
771       ParametersToHostResolverFlags(parameters), parameters.source,
772       parameters.cache_usage, &endpoints, &aliases, &stale_info);
773   if (rv != ERR_DNS_CACHE_MISS) {
774     // Request already in cache (or IP literal). No need to load it.
775     return rv;
776   }
777 
778   // Just like the real resolver, refuse to do anything with invalid
779   // hostnames.
780   if (!dns_names_util::IsValidDnsName(endpoint.GetHostnameWithoutBrackets()))
781     return ERR_NAME_NOT_RESOLVED;
782 
783   RequestImpl request(endpoint, network_anonymization_key, optional_parameters,
784                       AsWeakPtr());
785   return DoSynchronousResolution(request);
786 }
787 
ResolveAllPending()788 void MockHostResolverBase::ResolveAllPending() {
789   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
790   DCHECK(ondemand_mode_);
791   for (auto& [id, request] : state_->mutable_requests()) {
792     base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
793         FROM_HERE,
794         base::BindOnce(&MockHostResolverBase::ResolveNow, AsWeakPtr(), id));
795   }
796 }
797 
last_id()798 size_t MockHostResolverBase::last_id() {
799   if (!has_pending_requests())
800     return 0;
801   return state_->mutable_requests().rbegin()->first;
802 }
803 
ResolveNow(size_t id)804 void MockHostResolverBase::ResolveNow(size_t id) {
805   auto it = state_->mutable_requests().find(id);
806   if (it == state_->mutable_requests().end())
807     return;  // was canceled
808 
809   RequestImpl* req = it->second;
810   state_->mutable_requests().erase(it);
811 
812   int error = DoSynchronousResolution(*req);
813   req->OnAsyncCompleted(id, error);
814 }
815 
DetachRequest(size_t id)816 void MockHostResolverBase::DetachRequest(size_t id) {
817   auto it = state_->mutable_requests().find(id);
818   CHECK(it != state_->mutable_requests().end());
819   state_->mutable_requests().erase(it);
820 }
821 
request_host(size_t id)822 base::StringPiece MockHostResolverBase::request_host(size_t id) {
823   DCHECK(request(id));
824   return request(id)->request_endpoint().GetHostnameWithoutBrackets();
825 }
826 
request_priority(size_t id)827 RequestPriority MockHostResolverBase::request_priority(size_t id) {
828   DCHECK(request(id));
829   return request(id)->priority();
830 }
831 
832 const NetworkAnonymizationKey&
request_network_anonymization_key(size_t id)833 MockHostResolverBase::request_network_anonymization_key(size_t id) {
834   DCHECK(request(id));
835   return request(id)->network_anonymization_key();
836 }
837 
ResolveOnlyRequestNow()838 void MockHostResolverBase::ResolveOnlyRequestNow() {
839   DCHECK_EQ(1u, state_->mutable_requests().size());
840   ResolveNow(state_->mutable_requests().begin()->first);
841 }
842 
TriggerMdnsListeners(const HostPortPair & host,DnsQueryType query_type,MdnsListenerUpdateType update_type,const IPEndPoint & address_result)843 void MockHostResolverBase::TriggerMdnsListeners(
844     const HostPortPair& host,
845     DnsQueryType query_type,
846     MdnsListenerUpdateType update_type,
847     const IPEndPoint& address_result) {
848   for (auto* listener : listeners_) {
849     if (listener->host() == host && listener->query_type() == query_type)
850       listener->TriggerAddressResult(update_type, address_result);
851   }
852 }
853 
TriggerMdnsListeners(const HostPortPair & host,DnsQueryType query_type,MdnsListenerUpdateType update_type,const std::vector<std::string> & text_result)854 void MockHostResolverBase::TriggerMdnsListeners(
855     const HostPortPair& host,
856     DnsQueryType query_type,
857     MdnsListenerUpdateType update_type,
858     const std::vector<std::string>& text_result) {
859   for (auto* listener : listeners_) {
860     if (listener->host() == host && listener->query_type() == query_type)
861       listener->TriggerTextResult(update_type, text_result);
862   }
863 }
864 
TriggerMdnsListeners(const HostPortPair & host,DnsQueryType query_type,MdnsListenerUpdateType update_type,const HostPortPair & host_result)865 void MockHostResolverBase::TriggerMdnsListeners(
866     const HostPortPair& host,
867     DnsQueryType query_type,
868     MdnsListenerUpdateType update_type,
869     const HostPortPair& host_result) {
870   for (auto* listener : listeners_) {
871     if (listener->host() == host && listener->query_type() == query_type)
872       listener->TriggerHostnameResult(update_type, host_result);
873   }
874 }
875 
TriggerMdnsListeners(const HostPortPair & host,DnsQueryType query_type,MdnsListenerUpdateType update_type)876 void MockHostResolverBase::TriggerMdnsListeners(
877     const HostPortPair& host,
878     DnsQueryType query_type,
879     MdnsListenerUpdateType update_type) {
880   for (auto* listener : listeners_) {
881     if (listener->host() == host && listener->query_type() == query_type)
882       listener->TriggerUnhandledResult(update_type);
883   }
884 }
885 
request(size_t id)886 MockHostResolverBase::RequestImpl* MockHostResolverBase::request(size_t id) {
887   RequestMap::iterator request = state_->mutable_requests().find(id);
888   CHECK(request != state_->mutable_requests().end());
889   CHECK_EQ(request->second->id(), id);
890   return (*request).second;
891 }
892 
893 // start id from 1 to distinguish from NULL RequestHandle
MockHostResolverBase(bool use_caching,int cache_invalidation_num,RuleResolver rule_resolver)894 MockHostResolverBase::MockHostResolverBase(bool use_caching,
895                                            int cache_invalidation_num,
896                                            RuleResolver rule_resolver)
897     : rule_resolver_(std::move(rule_resolver)),
898       initial_cache_invalidation_num_(cache_invalidation_num),
899       tick_clock_(base::DefaultTickClock::GetInstance()),
900       state_(base::MakeRefCounted<State>()) {
901   if (use_caching)
902     cache_ = std::make_unique<HostCache>(kMaxCacheEntries);
903   else
904     DCHECK_GE(0, cache_invalidation_num);
905 }
906 
Resolve(RequestImpl * request)907 int MockHostResolverBase::Resolve(RequestImpl* request) {
908   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
909 
910   last_request_priority_ = request->parameters().initial_priority;
911   last_request_network_anonymization_key_ =
912       request->network_anonymization_key();
913   last_secure_dns_policy_ = request->parameters().secure_dns_policy;
914   state_->IncrementNumResolve();
915   std::vector<HostResolverEndpointResult> endpoints;
916   std::set<std::string> aliases;
917   absl::optional<HostCache::EntryStaleness> stale_info;
918   // TODO(crbug.com/1264933): Allow caching `ConnectionEndpoint` results.
919   int rv = ResolveFromIPLiteralOrCache(
920       request->request_endpoint(), request->network_anonymization_key(),
921       request->parameters().dns_query_type, request->host_resolver_flags(),
922       request->parameters().source, request->parameters().cache_usage,
923       &endpoints, &aliases, &stale_info);
924 
925   if (rv == OK && !request->parameters().is_speculative) {
926     request->SetEndpointResults(std::move(endpoints), std::move(aliases),
927                                 std::move(stale_info));
928   } else {
929     request->SetError(rv);
930   }
931 
932   if (rv != ERR_DNS_CACHE_MISS ||
933       request->parameters().source == HostResolverSource::LOCAL_ONLY) {
934     return SquashErrorCode(rv);
935   }
936 
937   // Just like the real resolver, refuse to do anything with invalid
938   // hostnames.
939   if (!dns_names_util::IsValidDnsName(
940           request->request_endpoint().GetHostnameWithoutBrackets())) {
941     request->SetError(ERR_NAME_NOT_RESOLVED);
942     return ERR_NAME_NOT_RESOLVED;
943   }
944 
945   if (synchronous_mode_)
946     return DoSynchronousResolution(*request);
947 
948   // Store the request for asynchronous resolution
949   size_t id = next_request_id_++;
950   request->set_id(id);
951   state_->mutable_requests()[id] = request;
952 
953   if (!ondemand_mode_) {
954     base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
955         FROM_HERE,
956         base::BindOnce(&MockHostResolverBase::ResolveNow, AsWeakPtr(), id));
957   }
958 
959   return ERR_IO_PENDING;
960 }
961 
ResolveFromIPLiteralOrCache(const Host & endpoint,const NetworkAnonymizationKey & network_anonymization_key,DnsQueryType dns_query_type,HostResolverFlags flags,HostResolverSource source,HostResolver::ResolveHostParameters::CacheUsage cache_usage,std::vector<HostResolverEndpointResult> * out_endpoints,std::set<std::string> * out_aliases,absl::optional<HostCache::EntryStaleness> * out_stale_info)962 int MockHostResolverBase::ResolveFromIPLiteralOrCache(
963     const Host& endpoint,
964     const NetworkAnonymizationKey& network_anonymization_key,
965     DnsQueryType dns_query_type,
966     HostResolverFlags flags,
967     HostResolverSource source,
968     HostResolver::ResolveHostParameters::CacheUsage cache_usage,
969     std::vector<HostResolverEndpointResult>* out_endpoints,
970     std::set<std::string>* out_aliases,
971     absl::optional<HostCache::EntryStaleness>* out_stale_info) {
972   DCHECK(out_endpoints);
973   DCHECK(out_aliases);
974   DCHECK(out_stale_info);
975   out_endpoints->clear();
976   out_aliases->clear();
977   *out_stale_info = absl::nullopt;
978 
979   IPAddress ip_address;
980   if (ip_address.AssignFromIPLiteral(endpoint.GetHostnameWithoutBrackets())) {
981     const DnsQueryType desired_address_query =
982         AddressFamilyToDnsQueryType(GetAddressFamily(ip_address));
983     DCHECK_NE(desired_address_query, DnsQueryType::UNSPECIFIED);
984 
985     // This matches the behavior HostResolverImpl.
986     if (dns_query_type != DnsQueryType::UNSPECIFIED &&
987         dns_query_type != desired_address_query) {
988       return ERR_NAME_NOT_RESOLVED;
989     }
990 
991     *out_endpoints = std::vector<HostResolverEndpointResult>(1);
992     (*out_endpoints)[0].ip_endpoints.emplace_back(ip_address,
993                                                   endpoint.GetPort());
994     if (flags & HOST_RESOLVER_CANONNAME)
995       *out_aliases = {ip_address.ToString()};
996     return OK;
997   }
998 
999   std::vector<IPEndPoint> localhost_endpoints;
1000   // Immediately resolve any "localhost" or recognized similar names.
1001   if (IsAddressType(dns_query_type) &&
1002       ResolveLocalHostname(endpoint.GetHostnameWithoutBrackets(),
1003                            &localhost_endpoints)) {
1004     *out_endpoints = std::vector<HostResolverEndpointResult>(1);
1005     (*out_endpoints)[0].ip_endpoints = localhost_endpoints;
1006     return OK;
1007   }
1008   int rv = ERR_DNS_CACHE_MISS;
1009   bool cache_allowed =
1010       cache_usage == HostResolver::ResolveHostParameters::CacheUsage::ALLOWED ||
1011       cache_usage ==
1012           HostResolver::ResolveHostParameters::CacheUsage::STALE_ALLOWED;
1013   if (cache_.get() && cache_allowed) {
1014     // Local-only requests search the cache for non-local-only results.
1015     HostResolverSource effective_source =
1016         source == HostResolverSource::LOCAL_ONLY ? HostResolverSource::ANY
1017                                                  : source;
1018     HostCache::Key key(GetCacheHost(endpoint), dns_query_type, flags,
1019                        effective_source, network_anonymization_key);
1020     const std::pair<const HostCache::Key, HostCache::Entry>* cache_result;
1021     HostCache::EntryStaleness stale_info = HostCache::kNotStale;
1022     if (cache_usage ==
1023         HostResolver::ResolveHostParameters::CacheUsage::STALE_ALLOWED) {
1024       cache_result = cache_->LookupStale(key, tick_clock_->NowTicks(),
1025                                          &stale_info, true /* ignore_secure */);
1026     } else {
1027       cache_result = cache_->Lookup(key, tick_clock_->NowTicks(),
1028                                     true /* ignore_secure */);
1029     }
1030     if (cache_result) {
1031       rv = cache_result->second.error();
1032       if (rv == OK) {
1033         *out_endpoints = cache_result->second.GetEndpoints().value();
1034 
1035         if (cache_result->second.aliases()) {
1036           *out_aliases = *cache_result->second.aliases();
1037         }
1038         *out_stale_info = std::move(stale_info);
1039       }
1040 
1041       auto cache_invalidation_iterator = cache_invalidation_nums_.find(key);
1042       if (cache_invalidation_iterator != cache_invalidation_nums_.end()) {
1043         DCHECK_LE(1, cache_invalidation_iterator->second);
1044         cache_invalidation_iterator->second--;
1045         if (cache_invalidation_iterator->second == 0) {
1046           HostCache::Entry new_entry(cache_result->second);
1047           cache_->Set(key, new_entry, tick_clock_->NowTicks(),
1048                       base::TimeDelta());
1049           cache_invalidation_nums_.erase(cache_invalidation_iterator);
1050         }
1051       }
1052     }
1053   }
1054   return rv;
1055 }
1056 
DoSynchronousResolution(RequestImpl & request)1057 int MockHostResolverBase::DoSynchronousResolution(RequestImpl& request) {
1058   state_->IncrementNumNonLocalResolves();
1059 
1060   const RuleResolver::RuleResultOrError& result = rule_resolver_.Resolve(
1061       request.request_endpoint(), request.parameters().dns_query_type,
1062       request.parameters().source);
1063 
1064   int error = ERR_UNEXPECTED;
1065   absl::optional<HostCache::Entry> cache_entry;
1066   if (absl::holds_alternative<RuleResolver::RuleResult>(result)) {
1067     const auto& rule_result = absl::get<RuleResolver::RuleResult>(result);
1068     const auto& endpoint_results = rule_result.endpoints;
1069     const auto& aliases = rule_result.aliases;
1070     request.SetEndpointResults(endpoint_results, aliases,
1071                                /*staleness=*/absl::nullopt);
1072     // TODO(crbug.com/1264933): Change `error` on empty results?
1073     error = OK;
1074     if (cache_.get()) {
1075       cache_entry = CreateCacheEntry(request.request_endpoint().GetHostname(),
1076                                      endpoint_results, aliases);
1077     }
1078   } else {
1079     DCHECK(absl::holds_alternative<RuleResolver::ErrorResult>(result));
1080     error = absl::get<RuleResolver::ErrorResult>(result);
1081     request.SetError(error);
1082     if (cache_.get()) {
1083       cache_entry.emplace(error, HostCache::Entry::SOURCE_UNKNOWN);
1084     }
1085   }
1086   if (cache_.get() && cache_entry.has_value()) {
1087     HostCache::Key key(
1088         GetCacheHost(request.request_endpoint()),
1089         request.parameters().dns_query_type, request.host_resolver_flags(),
1090         request.parameters().source, request.network_anonymization_key());
1091     // Storing a failure with TTL 0 so that it overwrites previous value.
1092     base::TimeDelta ttl;
1093     if (error == OK) {
1094       ttl = base::Seconds(kCacheEntryTTLSeconds);
1095       if (initial_cache_invalidation_num_ > 0)
1096         cache_invalidation_nums_[key] = initial_cache_invalidation_num_;
1097     }
1098     cache_->Set(key, cache_entry.value(), tick_clock_->NowTicks(), ttl);
1099   }
1100 
1101   return SquashErrorCode(error);
1102 }
1103 
AddListener(MdnsListenerImpl * listener)1104 void MockHostResolverBase::AddListener(MdnsListenerImpl* listener) {
1105   listeners_.insert(listener);
1106 }
1107 
RemoveCancelledListener(MdnsListenerImpl * listener)1108 void MockHostResolverBase::RemoveCancelledListener(MdnsListenerImpl* listener) {
1109   listeners_.erase(listener);
1110 }
1111 
MockHostResolverFactory(MockHostResolverBase::RuleResolver rules,bool use_caching,int cache_invalidation_num)1112 MockHostResolverFactory::MockHostResolverFactory(
1113     MockHostResolverBase::RuleResolver rules,
1114     bool use_caching,
1115     int cache_invalidation_num)
1116     : rules_(std::move(rules)),
1117       use_caching_(use_caching),
1118       cache_invalidation_num_(cache_invalidation_num) {}
1119 
1120 MockHostResolverFactory::~MockHostResolverFactory() = default;
1121 
CreateResolver(HostResolverManager * manager,base::StringPiece host_mapping_rules,bool enable_caching)1122 std::unique_ptr<HostResolver> MockHostResolverFactory::CreateResolver(
1123     HostResolverManager* manager,
1124     base::StringPiece host_mapping_rules,
1125     bool enable_caching) {
1126   DCHECK(host_mapping_rules.empty());
1127 
1128   // Explicit new to access private constructor.
1129   auto resolver = base::WrapUnique(new MockHostResolverBase(
1130       enable_caching && use_caching_, cache_invalidation_num_, rules_));
1131   return resolver;
1132 }
1133 
CreateStandaloneResolver(NetLog * net_log,const HostResolver::ManagerOptions & options,base::StringPiece host_mapping_rules,bool enable_caching)1134 std::unique_ptr<HostResolver> MockHostResolverFactory::CreateStandaloneResolver(
1135     NetLog* net_log,
1136     const HostResolver::ManagerOptions& options,
1137     base::StringPiece host_mapping_rules,
1138     bool enable_caching) {
1139   return CreateResolver(nullptr, host_mapping_rules, enable_caching);
1140 }
1141 
1142 //-----------------------------------------------------------------------------
1143 
Rule(ResolverType resolver_type,base::StringPiece host_pattern,AddressFamily address_family,HostResolverFlags host_resolver_flags,base::StringPiece replacement,std::vector<std::string> dns_aliases,int latency_ms)1144 RuleBasedHostResolverProc::Rule::Rule(ResolverType resolver_type,
1145                                       base::StringPiece host_pattern,
1146                                       AddressFamily address_family,
1147                                       HostResolverFlags host_resolver_flags,
1148                                       base::StringPiece replacement,
1149                                       std::vector<std::string> dns_aliases,
1150                                       int latency_ms)
1151     : resolver_type(resolver_type),
1152       host_pattern(host_pattern),
1153       address_family(address_family),
1154       host_resolver_flags(host_resolver_flags),
1155       replacement(replacement),
1156       dns_aliases(std::move(dns_aliases)),
1157       latency_ms(latency_ms) {
1158   DCHECK(this->dns_aliases != std::vector<std::string>({""}));
1159 }
1160 
1161 RuleBasedHostResolverProc::Rule::Rule(const Rule& other) = default;
1162 
1163 RuleBasedHostResolverProc::Rule::~Rule() = default;
1164 
RuleBasedHostResolverProc(scoped_refptr<HostResolverProc> previous,bool allow_fallback)1165 RuleBasedHostResolverProc::RuleBasedHostResolverProc(
1166     scoped_refptr<HostResolverProc> previous,
1167     bool allow_fallback)
1168     : HostResolverProc(std::move(previous), allow_fallback) {}
1169 
AddRule(base::StringPiece host_pattern,base::StringPiece replacement)1170 void RuleBasedHostResolverProc::AddRule(base::StringPiece host_pattern,
1171                                         base::StringPiece replacement) {
1172   AddRuleForAddressFamily(host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
1173                           replacement);
1174 }
1175 
AddRuleForAddressFamily(base::StringPiece host_pattern,AddressFamily address_family,base::StringPiece replacement)1176 void RuleBasedHostResolverProc::AddRuleForAddressFamily(
1177     base::StringPiece host_pattern,
1178     AddressFamily address_family,
1179     base::StringPiece replacement) {
1180   DCHECK(!replacement.empty());
1181   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY;
1182   Rule rule(Rule::kResolverTypeSystem, host_pattern, address_family, flags,
1183             replacement, {} /* dns_aliases */, 0);
1184   AddRuleInternal(rule);
1185 }
1186 
AddRuleWithFlags(base::StringPiece host_pattern,base::StringPiece replacement,HostResolverFlags flags,std::vector<std::string> dns_aliases)1187 void RuleBasedHostResolverProc::AddRuleWithFlags(
1188     base::StringPiece host_pattern,
1189     base::StringPiece replacement,
1190     HostResolverFlags flags,
1191     std::vector<std::string> dns_aliases) {
1192   DCHECK(!replacement.empty());
1193   Rule rule(Rule::kResolverTypeSystem, host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
1194             flags, replacement, std::move(dns_aliases), 0);
1195   AddRuleInternal(rule);
1196 }
1197 
AddIPLiteralRule(base::StringPiece host_pattern,base::StringPiece ip_literal,base::StringPiece canonical_name)1198 void RuleBasedHostResolverProc::AddIPLiteralRule(
1199     base::StringPiece host_pattern,
1200     base::StringPiece ip_literal,
1201     base::StringPiece canonical_name) {
1202   // Literals are always resolved to themselves by HostResolverImpl,
1203   // consequently we do not support remapping them.
1204   IPAddress ip_address;
1205   DCHECK(!ip_address.AssignFromIPLiteral(host_pattern));
1206   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY;
1207   std::vector<std::string> aliases;
1208   if (!canonical_name.empty()) {
1209     flags |= HOST_RESOLVER_CANONNAME;
1210     aliases.emplace_back(canonical_name);
1211   }
1212 
1213   Rule rule(Rule::kResolverTypeIPLiteral, host_pattern,
1214             ADDRESS_FAMILY_UNSPECIFIED, flags, ip_literal, std::move(aliases),
1215             0);
1216   AddRuleInternal(rule);
1217 }
1218 
AddIPLiteralRuleWithDnsAliases(base::StringPiece host_pattern,base::StringPiece ip_literal,std::vector<std::string> dns_aliases)1219 void RuleBasedHostResolverProc::AddIPLiteralRuleWithDnsAliases(
1220     base::StringPiece host_pattern,
1221     base::StringPiece ip_literal,
1222     std::vector<std::string> dns_aliases) {
1223   // Literals are always resolved to themselves by HostResolverImpl,
1224   // consequently we do not support remapping them.
1225   IPAddress ip_address;
1226   DCHECK(!ip_address.AssignFromIPLiteral(host_pattern));
1227   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY;
1228   if (!dns_aliases.empty())
1229     flags |= HOST_RESOLVER_CANONNAME;
1230 
1231   Rule rule(Rule::kResolverTypeIPLiteral, host_pattern,
1232             ADDRESS_FAMILY_UNSPECIFIED, flags, ip_literal,
1233             std::move(dns_aliases), 0);
1234   AddRuleInternal(rule);
1235 }
1236 
AddRuleWithLatency(base::StringPiece host_pattern,base::StringPiece replacement,int latency_ms)1237 void RuleBasedHostResolverProc::AddRuleWithLatency(
1238     base::StringPiece host_pattern,
1239     base::StringPiece replacement,
1240     int latency_ms) {
1241   DCHECK(!replacement.empty());
1242   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY;
1243   Rule rule(Rule::kResolverTypeSystem, host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
1244             flags, replacement, /*dns_aliases=*/{}, latency_ms);
1245   AddRuleInternal(rule);
1246 }
1247 
AllowDirectLookup(base::StringPiece host_pattern)1248 void RuleBasedHostResolverProc::AllowDirectLookup(
1249     base::StringPiece host_pattern) {
1250   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY;
1251   Rule rule(Rule::kResolverTypeSystem, host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
1252             flags, std::string(), /*dns_aliases=*/{}, 0);
1253   AddRuleInternal(rule);
1254 }
1255 
AddSimulatedFailure(base::StringPiece host_pattern,HostResolverFlags flags)1256 void RuleBasedHostResolverProc::AddSimulatedFailure(
1257     base::StringPiece host_pattern,
1258     HostResolverFlags flags) {
1259   Rule rule(Rule::kResolverTypeFail, host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
1260             flags, std::string(), /*dns_aliases=*/{}, 0);
1261   AddRuleInternal(rule);
1262 }
1263 
AddSimulatedTimeoutFailure(base::StringPiece host_pattern,HostResolverFlags flags)1264 void RuleBasedHostResolverProc::AddSimulatedTimeoutFailure(
1265     base::StringPiece host_pattern,
1266     HostResolverFlags flags) {
1267   Rule rule(Rule::kResolverTypeFailTimeout, host_pattern,
1268             ADDRESS_FAMILY_UNSPECIFIED, flags, std::string(),
1269             /*dns_aliases=*/{}, 0);
1270   AddRuleInternal(rule);
1271 }
1272 
ClearRules()1273 void RuleBasedHostResolverProc::ClearRules() {
1274   CHECK(modifications_allowed_);
1275   base::AutoLock lock(rule_lock_);
1276   rules_.clear();
1277 }
1278 
DisableModifications()1279 void RuleBasedHostResolverProc::DisableModifications() {
1280   modifications_allowed_ = false;
1281 }
1282 
GetRules()1283 RuleBasedHostResolverProc::RuleList RuleBasedHostResolverProc::GetRules() {
1284   RuleList rv;
1285   {
1286     base::AutoLock lock(rule_lock_);
1287     rv = rules_;
1288   }
1289   return rv;
1290 }
1291 
NumResolvesForHostPattern(base::StringPiece host_pattern)1292 size_t RuleBasedHostResolverProc::NumResolvesForHostPattern(
1293     base::StringPiece host_pattern) {
1294   base::AutoLock lock(rule_lock_);
1295   return num_resolves_per_host_pattern_[host_pattern];
1296 }
1297 
Resolve(const std::string & host,AddressFamily address_family,HostResolverFlags host_resolver_flags,AddressList * addrlist,int * os_error)1298 int RuleBasedHostResolverProc::Resolve(const std::string& host,
1299                                        AddressFamily address_family,
1300                                        HostResolverFlags host_resolver_flags,
1301                                        AddressList* addrlist,
1302                                        int* os_error) {
1303   base::AutoLock lock(rule_lock_);
1304   RuleList::iterator r;
1305   for (r = rules_.begin(); r != rules_.end(); ++r) {
1306     bool matches_address_family =
1307         r->address_family == ADDRESS_FAMILY_UNSPECIFIED ||
1308         r->address_family == address_family;
1309     // Ignore HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6, since it should
1310     // have no impact on whether a rule matches.
1311     HostResolverFlags flags =
1312         host_resolver_flags & ~HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
1313     // Flags match if all of the bitflags in host_resolver_flags are enabled
1314     // in the rule's host_resolver_flags. However, the rule may have additional
1315     // flags specified, in which case the flags should still be considered a
1316     // match.
1317     bool matches_flags = (r->host_resolver_flags & flags) == flags;
1318     if (matches_flags && matches_address_family &&
1319         base::MatchPattern(host, r->host_pattern)) {
1320       num_resolves_per_host_pattern_[r->host_pattern]++;
1321 
1322       if (r->latency_ms != 0) {
1323         base::PlatformThread::Sleep(base::Milliseconds(r->latency_ms));
1324       }
1325 
1326       // Remap to a new host.
1327       const std::string& effective_host =
1328           r->replacement.empty() ? host : r->replacement;
1329 
1330       // Apply the resolving function to the remapped hostname.
1331       switch (r->resolver_type) {
1332         case Rule::kResolverTypeFail:
1333           return ERR_NAME_NOT_RESOLVED;
1334         case Rule::kResolverTypeFailTimeout:
1335           return ERR_DNS_TIMED_OUT;
1336         case Rule::kResolverTypeSystem:
1337           EnsureSystemHostResolverCallReady();
1338           return SystemHostResolverCall(effective_host, address_family,
1339                                         host_resolver_flags, addrlist,
1340                                         os_error);
1341         case Rule::kResolverTypeIPLiteral: {
1342           AddressList raw_addr_list;
1343           std::vector<std::string> aliases;
1344           aliases = (!r->dns_aliases.empty())
1345                         ? r->dns_aliases
1346                         : std::vector<std::string>({host});
1347           std::vector<net::IPEndPoint> ip_endpoints;
1348           int result = ParseAddressList(effective_host, &ip_endpoints);
1349           // Filter out addresses with the wrong family.
1350           *addrlist = AddressList();
1351           for (const auto& address : ip_endpoints) {
1352             if (address_family == ADDRESS_FAMILY_UNSPECIFIED ||
1353                 address_family == address.GetFamily()) {
1354               addrlist->push_back(address);
1355             }
1356           }
1357           addrlist->SetDnsAliases(aliases);
1358 
1359           if (result == OK && addrlist->empty())
1360             return ERR_NAME_NOT_RESOLVED;
1361           return result;
1362         }
1363         default:
1364           NOTREACHED();
1365           return ERR_UNEXPECTED;
1366       }
1367     }
1368   }
1369 
1370   return ResolveUsingPrevious(host, address_family, host_resolver_flags,
1371                               addrlist, os_error);
1372 }
1373 
1374 RuleBasedHostResolverProc::~RuleBasedHostResolverProc() = default;
1375 
AddRuleInternal(const Rule & rule)1376 void RuleBasedHostResolverProc::AddRuleInternal(const Rule& rule) {
1377   Rule fixed_rule = rule;
1378   // SystemResolverProc expects valid DNS addresses.
1379   // So for kResolverTypeSystem rules:
1380   // * CHECK that replacement is empty (empty domain names mean use a direct
1381   //   lookup) or a valid DNS name (which includes IP addresses).
1382   // * If the replacement is an IP address, switch to an IP literal rule.
1383   if (fixed_rule.resolver_type == Rule::kResolverTypeSystem) {
1384     CHECK(fixed_rule.replacement.empty() ||
1385           dns_names_util::IsValidDnsName(fixed_rule.replacement));
1386 
1387     IPAddress ip_address;
1388     bool valid_address = ip_address.AssignFromIPLiteral(fixed_rule.replacement);
1389     if (valid_address) {
1390       fixed_rule.resolver_type = Rule::kResolverTypeIPLiteral;
1391     }
1392   }
1393 
1394   CHECK(modifications_allowed_);
1395   base::AutoLock lock(rule_lock_);
1396   rules_.push_back(fixed_rule);
1397 }
1398 
CreateCatchAllHostResolverProc()1399 scoped_refptr<RuleBasedHostResolverProc> CreateCatchAllHostResolverProc() {
1400   auto catchall =
1401       base::MakeRefCounted<RuleBasedHostResolverProc>(/*previous=*/nullptr,
1402                                                       /*allow_fallback=*/false);
1403   // Note that IPv6 lookups fail.
1404   catchall->AddIPLiteralRule("*", "127.0.0.1", "localhost");
1405 
1406   // Next add a rules-based layer that the test controls.
1407   return base::MakeRefCounted<RuleBasedHostResolverProc>(
1408       std::move(catchall), /*allow_fallback=*/false);
1409 }
1410 
1411 //-----------------------------------------------------------------------------
1412 
1413 // Implementation of ResolveHostRequest that tracks cancellations when the
1414 // request is destroyed after being started.
1415 class HangingHostResolver::RequestImpl
1416     : public HostResolver::ResolveHostRequest,
1417       public HostResolver::ProbeRequest {
1418  public:
RequestImpl(base::WeakPtr<HangingHostResolver> resolver)1419   explicit RequestImpl(base::WeakPtr<HangingHostResolver> resolver)
1420       : resolver_(resolver) {}
1421 
1422   RequestImpl(const RequestImpl&) = delete;
1423   RequestImpl& operator=(const RequestImpl&) = delete;
1424 
~RequestImpl()1425   ~RequestImpl() override {
1426     if (is_running_ && resolver_)
1427       resolver_->state_->IncrementNumCancellations();
1428   }
1429 
Start(CompletionOnceCallback callback)1430   int Start(CompletionOnceCallback callback) override { return Start(); }
1431 
Start()1432   int Start() override {
1433     DCHECK(resolver_);
1434     is_running_ = true;
1435     return ERR_IO_PENDING;
1436   }
1437 
GetAddressResults() const1438   const AddressList* GetAddressResults() const override {
1439     base::ImmediateCrash();
1440   }
1441 
GetEndpointResults() const1442   const std::vector<HostResolverEndpointResult>* GetEndpointResults()
1443       const override {
1444     base::ImmediateCrash();
1445   }
1446 
GetTextResults() const1447   const absl::optional<std::vector<std::string>>& GetTextResults()
1448       const override {
1449     base::ImmediateCrash();
1450   }
1451 
GetHostnameResults() const1452   const absl::optional<std::vector<HostPortPair>>& GetHostnameResults()
1453       const override {
1454     base::ImmediateCrash();
1455   }
1456 
GetDnsAliasResults() const1457   const std::set<std::string>* GetDnsAliasResults() const override {
1458     base::ImmediateCrash();
1459   }
1460 
GetResolveErrorInfo() const1461   net::ResolveErrorInfo GetResolveErrorInfo() const override {
1462     base::ImmediateCrash();
1463   }
1464 
GetStaleInfo() const1465   const absl::optional<HostCache::EntryStaleness>& GetStaleInfo()
1466       const override {
1467     base::ImmediateCrash();
1468   }
1469 
ChangeRequestPriority(RequestPriority priority)1470   void ChangeRequestPriority(RequestPriority priority) override {}
1471 
1472  private:
1473   // Use a WeakPtr as the resolver may be destroyed while there are still
1474   // outstanding request objects.
1475   base::WeakPtr<HangingHostResolver> resolver_;
1476   bool is_running_ = false;
1477 };
1478 
1479 HangingHostResolver::State::State() = default;
1480 HangingHostResolver::State::~State() = default;
1481 
HangingHostResolver()1482 HangingHostResolver::HangingHostResolver()
1483     : state_(base::MakeRefCounted<State>()) {}
1484 
1485 HangingHostResolver::~HangingHostResolver() = default;
1486 
OnShutdown()1487 void HangingHostResolver::OnShutdown() {
1488   shutting_down_ = true;
1489 }
1490 
1491 std::unique_ptr<HostResolver::ResolveHostRequest>
CreateRequest(url::SchemeHostPort host,NetworkAnonymizationKey network_anonymization_key,NetLogWithSource net_log,absl::optional<ResolveHostParameters> optional_parameters)1492 HangingHostResolver::CreateRequest(
1493     url::SchemeHostPort host,
1494     NetworkAnonymizationKey network_anonymization_key,
1495     NetLogWithSource net_log,
1496     absl::optional<ResolveHostParameters> optional_parameters) {
1497   // TODO(crbug.com/1206799): Propagate scheme and make affect behavior.
1498   return CreateRequest(HostPortPair::FromSchemeHostPort(host),
1499                        network_anonymization_key, net_log, optional_parameters);
1500 }
1501 
1502 std::unique_ptr<HostResolver::ResolveHostRequest>
CreateRequest(const HostPortPair & host,const NetworkAnonymizationKey & network_anonymization_key,const NetLogWithSource & source_net_log,const absl::optional<ResolveHostParameters> & optional_parameters)1503 HangingHostResolver::CreateRequest(
1504     const HostPortPair& host,
1505     const NetworkAnonymizationKey& network_anonymization_key,
1506     const NetLogWithSource& source_net_log,
1507     const absl::optional<ResolveHostParameters>& optional_parameters) {
1508   last_host_ = host;
1509   last_network_anonymization_key_ = network_anonymization_key;
1510 
1511   if (shutting_down_)
1512     return CreateFailingRequest(ERR_CONTEXT_SHUT_DOWN);
1513 
1514   if (optional_parameters &&
1515       optional_parameters.value().source == HostResolverSource::LOCAL_ONLY) {
1516     return CreateFailingRequest(ERR_DNS_CACHE_MISS);
1517   }
1518 
1519   return std::make_unique<RequestImpl>(weak_ptr_factory_.GetWeakPtr());
1520 }
1521 
1522 std::unique_ptr<HostResolver::ProbeRequest>
CreateDohProbeRequest()1523 HangingHostResolver::CreateDohProbeRequest() {
1524   if (shutting_down_)
1525     return CreateFailingProbeRequest(ERR_CONTEXT_SHUT_DOWN);
1526 
1527   return std::make_unique<RequestImpl>(weak_ptr_factory_.GetWeakPtr());
1528 }
1529 
SetRequestContext(URLRequestContext * url_request_context)1530 void HangingHostResolver::SetRequestContext(
1531     URLRequestContext* url_request_context) {}
1532 
1533 //-----------------------------------------------------------------------------
1534 
1535 ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc() = default;
1536 
ScopedDefaultHostResolverProc(HostResolverProc * proc)1537 ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc(
1538     HostResolverProc* proc) {
1539   Init(proc);
1540 }
1541 
~ScopedDefaultHostResolverProc()1542 ScopedDefaultHostResolverProc::~ScopedDefaultHostResolverProc() {
1543   HostResolverProc* old_proc =
1544       HostResolverProc::SetDefault(previous_proc_.get());
1545   // The lifetimes of multiple instances must be nested.
1546   CHECK_EQ(old_proc, current_proc_.get());
1547 }
1548 
Init(HostResolverProc * proc)1549 void ScopedDefaultHostResolverProc::Init(HostResolverProc* proc) {
1550   current_proc_ = proc;
1551   previous_proc_ = HostResolverProc::SetDefault(current_proc_.get());
1552   current_proc_->SetLastProc(previous_proc_);
1553 }
1554 
1555 }  // namespace net
1556