• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/dns/mock_host_resolver.h"
6 
7 #include <stdint.h>
8 
9 #include <memory>
10 #include <optional>
11 #include <string>
12 #include <string_view>
13 #include <utility>
14 #include <vector>
15 
16 #include "base/check_op.h"
17 #include "base/functional/bind.h"
18 #include "base/functional/callback_helpers.h"
19 #include "base/location.h"
20 #include "base/logging.h"
21 #include "base/memory/ptr_util.h"
22 #include "base/memory/raw_ptr.h"
23 #include "base/memory/ref_counted.h"
24 #include "base/memory/weak_ptr.h"
25 #include "base/no_destructor.h"
26 #include "base/notreached.h"
27 #include "base/strings/pattern.h"
28 #include "base/strings/string_split.h"
29 #include "base/strings/string_util.h"
30 #include "base/task/single_thread_task_runner.h"
31 #include "base/threading/platform_thread.h"
32 #include "base/time/default_tick_clock.h"
33 #include "base/time/tick_clock.h"
34 #include "base/time/time.h"
35 #include "base/types/optional_util.h"
36 #include "build/build_config.h"
37 #include "net/base/address_family.h"
38 #include "net/base/address_list.h"
39 #include "net/base/host_port_pair.h"
40 #include "net/base/ip_address.h"
41 #include "net/base/ip_endpoint.h"
42 #include "net/base/net_errors.h"
43 #include "net/base/net_export.h"
44 #include "net/base/network_anonymization_key.h"
45 #include "net/base/request_priority.h"
46 #include "net/base/test_completion_callback.h"
47 #include "net/dns/dns_alias_utility.h"
48 #include "net/dns/dns_names_util.h"
49 #include "net/dns/dns_util.h"
50 #include "net/dns/host_cache.h"
51 #include "net/dns/host_resolver.h"
52 #include "net/dns/host_resolver_manager.h"
53 #include "net/dns/host_resolver_system_task.h"
54 #include "net/dns/https_record_rdata.h"
55 #include "net/dns/public/dns_query_type.h"
56 #include "net/dns/public/host_resolver_results.h"
57 #include "net/dns/public/host_resolver_source.h"
58 #include "net/dns/public/mdns_listener_update_type.h"
59 #include "net/dns/public/resolve_error_info.h"
60 #include "net/dns/public/secure_dns_policy.h"
61 #include "net/log/net_log_with_source.h"
62 #include "net/url_request/url_request_context.h"
63 #include "third_party/abseil-cpp/absl/types/variant.h"
64 #include "url/scheme_host_port.h"
65 
66 #if BUILDFLAG(IS_WIN)
67 #include "net/base/winsock_init.h"
68 #endif
69 
70 namespace net {
71 
72 namespace {
73 
// Maximum number of entries kept by the MockCachingHostResolver's cache.
constexpr unsigned kMaxCacheEntries = 100;
// Lifetime, in seconds, of cached successful resolutions. Failures are not
// cached.
constexpr unsigned kCacheEntryTTLSeconds = 60;
78 
GetCacheHost(const HostResolver::Host & endpoint)79 absl::variant<url::SchemeHostPort, std::string> GetCacheHost(
80     const HostResolver::Host& endpoint) {
81   if (endpoint.HasScheme()) {
82     return endpoint.AsSchemeHostPort();
83   }
84 
85   return endpoint.GetHostname();
86 }
87 
CreateCacheEntry(std::string_view canonical_name,const std::vector<HostResolverEndpointResult> & endpoint_results,const std::set<std::string> & aliases)88 std::optional<HostCache::Entry> CreateCacheEntry(
89     std::string_view canonical_name,
90     const std::vector<HostResolverEndpointResult>& endpoint_results,
91     const std::set<std::string>& aliases) {
92   std::optional<std::vector<net::IPEndPoint>> ip_endpoints;
93   std::multimap<HttpsRecordPriority, ConnectionEndpointMetadata>
94       endpoint_metadatas;
95   for (const auto& endpoint_result : endpoint_results) {
96     if (!ip_endpoints) {
97       ip_endpoints = endpoint_result.ip_endpoints;
98     } else {
99       // TODO(crbug.com/40203587): Support caching different IP endpoints
100       // resutls.
101       CHECK(*ip_endpoints == endpoint_result.ip_endpoints)
102           << "Currently caching MockHostResolver only supports same IP "
103              "endpoints results.";
104     }
105 
106     if (!endpoint_result.metadata.supported_protocol_alpns.empty()) {
107       endpoint_metadatas.emplace(/*priority=*/1, endpoint_result.metadata);
108     }
109   }
110   DCHECK(ip_endpoints);
111   auto endpoint_entry = HostCache::Entry(OK, *ip_endpoints, aliases,
112                                          HostCache::Entry::SOURCE_UNKNOWN);
113   endpoint_entry.set_canonical_names(std::set{std::string(canonical_name)});
114   if (endpoint_metadatas.empty()) {
115     return endpoint_entry;
116   }
117   return HostCache::Entry::MergeEntries(
118       HostCache::Entry(OK, std::move(endpoint_metadatas),
119                        HostCache::Entry::SOURCE_UNKNOWN),
120       endpoint_entry);
121 }
122 }  // namespace
123 
ParseAddressList(std::string_view host_list,std::vector<net::IPEndPoint> * ip_endpoints)124 int ParseAddressList(std::string_view host_list,
125                      std::vector<net::IPEndPoint>* ip_endpoints) {
126   ip_endpoints->clear();
127   for (std::string_view address : base::SplitStringPiece(
128            host_list, ",", base::TRIM_WHITESPACE, base::SPLIT_WANT_ALL)) {
129     IPAddress ip_address;
130     if (!ip_address.AssignFromIPLiteral(address)) {
131       LOG(WARNING) << "Not a supported IP literal: " << address;
132       return ERR_UNEXPECTED;
133     }
134     ip_endpoints->push_back(IPEndPoint(ip_address, 0));
135   }
136   return OK;
137 }
138 
139 // Base class for
140 // MockHostResolverBase::{RequestImpl,ServiceEndpointRequestImpl}.
141 class MockHostResolverBase::RequestBase {
142  public:
RequestBase(Host request_endpoint,const NetworkAnonymizationKey & network_anonymization_key,const std::optional<ResolveHostParameters> & optional_parameters,base::WeakPtr<MockHostResolverBase> resolver)143   RequestBase(Host request_endpoint,
144               const NetworkAnonymizationKey& network_anonymization_key,
145               const std::optional<ResolveHostParameters>& optional_parameters,
146               base::WeakPtr<MockHostResolverBase> resolver)
147       : request_endpoint_(std::move(request_endpoint)),
148         network_anonymization_key_(network_anonymization_key),
149         parameters_(optional_parameters ? optional_parameters.value()
150                                         : ResolveHostParameters()),
151         priority_(parameters_.initial_priority),
152         host_resolver_flags_(ParametersToHostResolverFlags(parameters_)),
153         resolve_error_info_(ResolveErrorInfo(ERR_IO_PENDING)),
154         resolver_(resolver) {}
155 
156   RequestBase(const RequestBase&) = delete;
157   RequestBase& operator=(const RequestBase&) = delete;
158 
~RequestBase()159   virtual ~RequestBase() {
160     if (id_ > 0) {
161       if (resolver_) {
162         resolver_->DetachRequest(id_);
163       }
164       id_ = 0;
165       resolver_ = nullptr;
166     }
167   }
168 
DetachFromResolver()169   void DetachFromResolver() {
170     id_ = 0;
171     resolver_ = nullptr;
172   }
173 
SetError(int error)174   void SetError(int error) {
175     // Should only be called before request is marked completed.
176     DCHECK(!complete_);
177     resolve_error_info_ = ResolveErrorInfo(error);
178   }
179 
180   // Sets `endpoint_results_`, `fixed_up_dns_alias_results_`,
181   // `address_results_` and `staleness_` after fixing them up.
182   // Also sets `error` to OK.
SetEndpointResults(std::vector<HostResolverEndpointResult> endpoint_results,std::set<std::string> aliases,std::optional<HostCache::EntryStaleness> staleness)183   void SetEndpointResults(
184       std::vector<HostResolverEndpointResult> endpoint_results,
185       std::set<std::string> aliases,
186       std::optional<HostCache::EntryStaleness> staleness) {
187     DCHECK(!complete_);
188     DCHECK(!endpoint_results_);
189     DCHECK(!parameters_.is_speculative);
190 
191     endpoint_results_ = std::move(endpoint_results);
192     for (auto& result : *endpoint_results_) {
193       result.ip_endpoints = FixupEndPoints(result.ip_endpoints);
194     }
195 
196     fixed_up_dns_alias_results_ = FixupAliases(aliases);
197 
198     // `HostResolver` implementations are expected to provide an `AddressList`
199     // result whenever `HostResolverEndpointResult` is also available.
200     address_results_ = EndpointResultToAddressList(
201         *endpoint_results_, *fixed_up_dns_alias_results_);
202 
203     staleness_ = std::move(staleness);
204 
205     SetError(OK);
206     SetEndpointResultsInternal();
207   }
208 
OnAsyncCompleted(size_t id,int error)209   void OnAsyncCompleted(size_t id, int error) {
210     DCHECK_EQ(id_, id);
211     id_ = 0;
212 
213     // Check that error information has been set and that the top-level error
214     // code is valid.
215     DCHECK(resolve_error_info_.error != ERR_IO_PENDING);
216     DCHECK(error == OK || error == ERR_NAME_NOT_RESOLVED ||
217            error == ERR_DNS_NAME_HTTPS_ONLY);
218 
219     DCHECK(!complete_);
220     complete_ = true;
221 
222     DCHECK(callback_);
223     std::move(callback_).Run(error);
224   }
225 
request_endpoint() const226   const Host& request_endpoint() const { return request_endpoint_; }
227 
network_anonymization_key() const228   const NetworkAnonymizationKey& network_anonymization_key() const {
229     return network_anonymization_key_;
230   }
231 
parameters() const232   const ResolveHostParameters& parameters() const { return parameters_; }
233 
host_resolver_flags() const234   int host_resolver_flags() const { return host_resolver_flags_; }
235 
id()236   size_t id() { return id_; }
237 
priority() const238   RequestPriority priority() const { return priority_; }
239 
set_id(size_t id)240   void set_id(size_t id) {
241     DCHECK_GT(id, 0u);
242     DCHECK_EQ(0u, id_);
243 
244     id_ = id;
245   }
246 
complete()247   bool complete() { return complete_; }
248 
249   // Similar get GetAddressResults() and GetResolveErrorInfo(), but only exposed
250   // through the HostResolver::ResolveHostRequest interface, and don't have the
251   // DCHECKs that `complete_` is true.
address_results() const252   const std::optional<AddressList>& address_results() const {
253     return address_results_;
254   }
resolve_error_info() const255   ResolveErrorInfo resolve_error_info() const { return resolve_error_info_; }
256 
257  protected:
FixupEndPoints(const std::vector<IPEndPoint> & endpoints)258   std::vector<IPEndPoint> FixupEndPoints(
259       const std::vector<IPEndPoint>& endpoints) {
260     std::vector<IPEndPoint> corrected;
261     for (const IPEndPoint& endpoint : endpoints) {
262       DCHECK_NE(endpoint.GetFamily(), ADDRESS_FAMILY_UNSPECIFIED);
263       if (parameters_.dns_query_type == DnsQueryType::UNSPECIFIED ||
264           parameters_.dns_query_type ==
265               AddressFamilyToDnsQueryType(endpoint.GetFamily())) {
266         if (endpoint.port() == 0) {
267           corrected.emplace_back(endpoint.address(),
268                                  request_endpoint_.GetPort());
269         } else {
270           corrected.push_back(endpoint);
271         }
272       }
273     }
274     return corrected;
275   }
276 
FixupAliases(const std::set<std::string> aliases)277   std::set<std::string> FixupAliases(const std::set<std::string> aliases) {
278     if (aliases.empty()) {
279       return std::set<std::string>{
280           std::string(request_endpoint_.GetHostnameWithoutBrackets())};
281     }
282     return aliases;
283   }
284 
285   // Helper method of SetEndpointResults() for subclass specific logic.
SetEndpointResultsInternal()286   virtual void SetEndpointResultsInternal() {}
287 
288   const Host request_endpoint_;
289   const NetworkAnonymizationKey network_anonymization_key_;
290   const ResolveHostParameters parameters_;
291   RequestPriority priority_;
292   int host_resolver_flags_;
293 
294   std::optional<AddressList> address_results_;
295   std::optional<std::vector<HostResolverEndpointResult>> endpoint_results_;
296   std::optional<std::set<std::string>> fixed_up_dns_alias_results_;
297   std::optional<HostCache::EntryStaleness> staleness_;
298   ResolveErrorInfo resolve_error_info_;
299 
300   // Used while stored with the resolver for async resolution.  Otherwise 0.
301   size_t id_ = 0;
302 
303   CompletionOnceCallback callback_;
304   // Use a WeakPtr as the resolver may be destroyed while there are still
305   // outstanding request objects.
306   base::WeakPtr<MockHostResolverBase> resolver_;
307   bool complete_ = false;
308 };
309 
310 class MockHostResolverBase::RequestImpl
311     : public RequestBase,
312       public HostResolver::ResolveHostRequest {
313  public:
RequestImpl(Host request_endpoint,const NetworkAnonymizationKey & network_anonymization_key,const std::optional<ResolveHostParameters> & optional_parameters,base::WeakPtr<MockHostResolverBase> resolver)314   RequestImpl(Host request_endpoint,
315               const NetworkAnonymizationKey& network_anonymization_key,
316               const std::optional<ResolveHostParameters>& optional_parameters,
317               base::WeakPtr<MockHostResolverBase> resolver)
318       : RequestBase(std::move(request_endpoint),
319                     network_anonymization_key,
320                     optional_parameters,
321                     std::move(resolver)) {}
322 
323   RequestImpl(const RequestImpl&) = delete;
324   RequestImpl& operator=(const RequestImpl&) = delete;
325 
326   ~RequestImpl() override = default;
327 
Start(CompletionOnceCallback callback)328   int Start(CompletionOnceCallback callback) override {
329     DCHECK(callback);
330     // Start() may only be called once per request.
331     DCHECK_EQ(0u, id_);
332     DCHECK(!complete_);
333     DCHECK(!callback_);
334     // Parent HostResolver must still be alive to call Start().
335     DCHECK(resolver_);
336 
337     int rv = resolver_->Resolve(this);
338     DCHECK(!complete_);
339     if (rv == ERR_IO_PENDING) {
340       DCHECK_GT(id_, 0u);
341       callback_ = std::move(callback);
342     } else {
343       DCHECK_EQ(0u, id_);
344       complete_ = true;
345     }
346 
347     return rv;
348   }
349 
GetAddressResults() const350   const AddressList* GetAddressResults() const override {
351     DCHECK(complete_);
352     return base::OptionalToPtr(address_results_);
353   }
354 
GetEndpointResults() const355   const std::vector<HostResolverEndpointResult>* GetEndpointResults()
356       const override {
357     DCHECK(complete_);
358     return base::OptionalToPtr(endpoint_results_);
359   }
360 
GetTextResults() const361   const std::vector<std::string>* GetTextResults() const override {
362     DCHECK(complete_);
363     static const base::NoDestructor<std::vector<std::string>> empty_result;
364     return empty_result.get();
365   }
366 
GetHostnameResults() const367   const std::vector<HostPortPair>* GetHostnameResults() const override {
368     DCHECK(complete_);
369     static const base::NoDestructor<std::vector<HostPortPair>> empty_result;
370     return empty_result.get();
371   }
372 
GetDnsAliasResults() const373   const std::set<std::string>* GetDnsAliasResults() const override {
374     DCHECK(complete_);
375     return base::OptionalToPtr(fixed_up_dns_alias_results_);
376   }
377 
GetResolveErrorInfo() const378   net::ResolveErrorInfo GetResolveErrorInfo() const override {
379     DCHECK(complete_);
380     return resolve_error_info_;
381   }
382 
GetStaleInfo() const383   const std::optional<HostCache::EntryStaleness>& GetStaleInfo()
384       const override {
385     DCHECK(complete_);
386     return staleness_;
387   }
388 
ChangeRequestPriority(RequestPriority priority)389   void ChangeRequestPriority(RequestPriority priority) override {
390     priority_ = priority;
391   }
392 };
393 
394 class MockHostResolverBase::ServiceEndpointRequestImpl
395     : public RequestBase,
396       public HostResolver::ServiceEndpointRequest {
397  public:
ServiceEndpointRequestImpl(Host request_endpoint,const NetworkAnonymizationKey & network_anonymization_key,const std::optional<ResolveHostParameters> & optional_parameters,base::WeakPtr<MockHostResolverBase> resolver)398   ServiceEndpointRequestImpl(
399       Host request_endpoint,
400       const NetworkAnonymizationKey& network_anonymization_key,
401       const std::optional<ResolveHostParameters>& optional_parameters,
402       base::WeakPtr<MockHostResolverBase> resolver)
403       : RequestBase(std::move(request_endpoint),
404                     network_anonymization_key,
405                     optional_parameters,
406                     std::move(resolver)) {}
407 
408   ServiceEndpointRequestImpl(const ServiceEndpointRequestImpl&) = delete;
409   ServiceEndpointRequestImpl& operator=(const ServiceEndpointRequestImpl&) =
410       delete;
411 
412   ~ServiceEndpointRequestImpl() override = default;
413 
414   // HostResolver::ServiceEndpointRequest implementations:
Start(Delegate * delegate)415   int Start(Delegate* delegate) override {
416     CHECK(delegate);
417     CHECK(!delegate_);
418     CHECK_EQ(id_, 0u);
419     CHECK(!complete_);
420     CHECK(resolver_);
421 
422     int rv = resolver_->Resolve(this);
423     DCHECK(!complete_);
424     if (rv == ERR_IO_PENDING) {
425       CHECK_GT(id_, 0u);
426       delegate_ = delegate;
427       callback_ = base::BindOnce(
428           &ServiceEndpointRequestImpl::NotifyDelegateOfCompletion,
429           weak_ptr_factory_.GetWeakPtr());
430     } else {
431       CHECK_EQ(id_, 0u);
432       complete_ = true;
433     }
434 
435     return rv;
436   }
437 
GetEndpointResults()438   const std::vector<ServiceEndpoint>& GetEndpointResults() override {
439     return service_endpoint_results_;
440   }
441 
GetDnsAliasResults()442   const std::set<std::string>& GetDnsAliasResults() override {
443     if (fixed_up_dns_alias_results_.has_value()) {
444       return *fixed_up_dns_alias_results_;
445     }
446     static const base::NoDestructor<std::set<std::string>> kEmptyDnsAliases;
447     return *kEmptyDnsAliases.get();
448   }
449 
EndpointsCryptoReady()450   bool EndpointsCryptoReady() override { return true; }
451 
GetResolveErrorInfo()452   ResolveErrorInfo GetResolveErrorInfo() override {
453     return resolve_error_info_;
454   }
455 
ChangeRequestPriority(RequestPriority priority)456   void ChangeRequestPriority(RequestPriority priority) override {
457     priority_ = priority;
458   }
459 
460  private:
SetEndpointResultsInternal()461   void SetEndpointResultsInternal() override {
462     if (!endpoint_results_.has_value()) {
463       return;
464     }
465 
466     std::vector<ServiceEndpoint> service_endpoints;
467     for (const auto& endpoint : *endpoint_results_) {
468       std::vector<IPEndPoint> ipv4_endpoints;
469       std::vector<IPEndPoint> ipv6_endpoints;
470       for (const auto& ip_endpoint : endpoint.ip_endpoints) {
471         if (ip_endpoint.address().IsIPv6()) {
472           ipv6_endpoints.emplace_back(ip_endpoint);
473         } else {
474           ipv4_endpoints.emplace_back(ip_endpoint);
475         }
476       }
477       service_endpoints.emplace_back(std::move(ipv4_endpoints),
478                                      std::move(ipv6_endpoints),
479                                      endpoint.metadata);
480     }
481 
482     service_endpoint_results_ = std::move(service_endpoints);
483   }
484 
NotifyDelegateOfCompletion(int rv)485   void NotifyDelegateOfCompletion(int rv) {
486     CHECK(delegate_);
487     CHECK_NE(rv, ERR_IO_PENDING);
488     delegate_.ExtractAsDangling()->OnServiceEndpointRequestFinished(rv);
489   }
490 
491   raw_ptr<Delegate> delegate_;
492   std::vector<ServiceEndpoint> service_endpoint_results_;
493 
494   base::WeakPtrFactory<ServiceEndpointRequestImpl> weak_ptr_factory_{this};
495 };
496 
497 class MockHostResolverBase::ProbeRequestImpl
498     : public HostResolver::ProbeRequest {
499  public:
ProbeRequestImpl(base::WeakPtr<MockHostResolverBase> resolver)500   explicit ProbeRequestImpl(base::WeakPtr<MockHostResolverBase> resolver)
501       : resolver_(std::move(resolver)) {}
502 
503   ProbeRequestImpl(const ProbeRequestImpl&) = delete;
504   ProbeRequestImpl& operator=(const ProbeRequestImpl&) = delete;
505 
~ProbeRequestImpl()506   ~ProbeRequestImpl() override {
507     if (resolver_) {
508       resolver_->state_->ClearDohProbeRequestIfMatching(this);
509     }
510   }
511 
Start()512   int Start() override {
513     DCHECK(resolver_);
514     resolver_->state_->set_doh_probe_request(this);
515 
516     return ERR_IO_PENDING;
517   }
518 
519  private:
520   base::WeakPtr<MockHostResolverBase> resolver_;
521 };
522 
523 class MockHostResolverBase::MdnsListenerImpl
524     : public HostResolver::MdnsListener {
525  public:
MdnsListenerImpl(const HostPortPair & host,DnsQueryType query_type,base::WeakPtr<MockHostResolverBase> resolver)526   MdnsListenerImpl(const HostPortPair& host,
527                    DnsQueryType query_type,
528                    base::WeakPtr<MockHostResolverBase> resolver)
529       : host_(host), query_type_(query_type), resolver_(resolver) {
530     DCHECK_NE(DnsQueryType::UNSPECIFIED, query_type_);
531     DCHECK(resolver_);
532   }
533 
~MdnsListenerImpl()534   ~MdnsListenerImpl() override {
535     if (resolver_)
536       resolver_->RemoveCancelledListener(this);
537   }
538 
Start(Delegate * delegate)539   int Start(Delegate* delegate) override {
540     DCHECK(delegate);
541     DCHECK(!delegate_);
542     DCHECK(resolver_);
543 
544     delegate_ = delegate;
545     resolver_->AddListener(this);
546 
547     return OK;
548   }
549 
TriggerAddressResult(MdnsListenerUpdateType update_type,IPEndPoint address)550   void TriggerAddressResult(MdnsListenerUpdateType update_type,
551                             IPEndPoint address) {
552     delegate_->OnAddressResult(update_type, query_type_, std::move(address));
553   }
554 
TriggerTextResult(MdnsListenerUpdateType update_type,std::vector<std::string> text_records)555   void TriggerTextResult(MdnsListenerUpdateType update_type,
556                          std::vector<std::string> text_records) {
557     delegate_->OnTextResult(update_type, query_type_, std::move(text_records));
558   }
559 
TriggerHostnameResult(MdnsListenerUpdateType update_type,HostPortPair host)560   void TriggerHostnameResult(MdnsListenerUpdateType update_type,
561                              HostPortPair host) {
562     delegate_->OnHostnameResult(update_type, query_type_, std::move(host));
563   }
564 
TriggerUnhandledResult(MdnsListenerUpdateType update_type)565   void TriggerUnhandledResult(MdnsListenerUpdateType update_type) {
566     delegate_->OnUnhandledResult(update_type, query_type_);
567   }
568 
host() const569   const HostPortPair& host() const { return host_; }
query_type() const570   DnsQueryType query_type() const { return query_type_; }
571 
572  private:
573   const HostPortPair host_;
574   const DnsQueryType query_type_;
575 
576   raw_ptr<Delegate> delegate_ = nullptr;
577 
578   // Use a WeakPtr as the resolver may be destroyed while there are still
579   // outstanding listener objects.
580   base::WeakPtr<MockHostResolverBase> resolver_;
581 };
582 
583 MockHostResolverBase::RuleResolver::RuleKey::RuleKey() = default;
584 
585 MockHostResolverBase::RuleResolver::RuleKey::~RuleKey() = default;
586 
587 MockHostResolverBase::RuleResolver::RuleKey::RuleKey(const RuleKey&) = default;
588 
589 MockHostResolverBase::RuleResolver::RuleKey&
590 MockHostResolverBase::RuleResolver::RuleKey::operator=(const RuleKey&) =
591     default;
592 
593 MockHostResolverBase::RuleResolver::RuleKey::RuleKey(RuleKey&&) = default;
594 
595 MockHostResolverBase::RuleResolver::RuleKey&
596 MockHostResolverBase::RuleResolver::RuleKey::operator=(RuleKey&&) = default;
597 
598 MockHostResolverBase::RuleResolver::RuleResult::RuleResult() = default;
599 
RuleResult(std::vector<HostResolverEndpointResult> endpoints,std::set<std::string> aliases)600 MockHostResolverBase::RuleResolver::RuleResult::RuleResult(
601     std::vector<HostResolverEndpointResult> endpoints,
602     std::set<std::string> aliases)
603     : endpoints(std::move(endpoints)), aliases(std::move(aliases)) {}
604 
605 MockHostResolverBase::RuleResolver::RuleResult::~RuleResult() = default;
606 
607 MockHostResolverBase::RuleResolver::RuleResult::RuleResult(const RuleResult&) =
608     default;
609 
610 MockHostResolverBase::RuleResolver::RuleResult&
611 MockHostResolverBase::RuleResolver::RuleResult::operator=(const RuleResult&) =
612     default;
613 
614 MockHostResolverBase::RuleResolver::RuleResult::RuleResult(RuleResult&&) =
615     default;
616 
617 MockHostResolverBase::RuleResolver::RuleResult&
618 MockHostResolverBase::RuleResolver::RuleResult::operator=(RuleResult&&) =
619     default;
620 
RuleResolver(std::optional<RuleResultOrError> default_result)621 MockHostResolverBase::RuleResolver::RuleResolver(
622     std::optional<RuleResultOrError> default_result)
623     : default_result_(std::move(default_result)) {}
624 
625 MockHostResolverBase::RuleResolver::~RuleResolver() = default;
626 
627 MockHostResolverBase::RuleResolver::RuleResolver(const RuleResolver&) = default;
628 
629 MockHostResolverBase::RuleResolver&
630 MockHostResolverBase::RuleResolver::operator=(const RuleResolver&) = default;
631 
632 MockHostResolverBase::RuleResolver::RuleResolver(RuleResolver&&) = default;
633 
634 MockHostResolverBase::RuleResolver&
635 MockHostResolverBase::RuleResolver::operator=(RuleResolver&&) = default;
636 
637 const MockHostResolverBase::RuleResolver::RuleResultOrError&
Resolve(const Host & request_endpoint,DnsQueryTypeSet request_types,HostResolverSource request_source) const638 MockHostResolverBase::RuleResolver::Resolve(
639     const Host& request_endpoint,
640     DnsQueryTypeSet request_types,
641     HostResolverSource request_source) const {
642   for (const auto& rule : rules_) {
643     const RuleKey& key = rule.first;
644     const RuleResultOrError& result = rule.second;
645 
646     if (absl::holds_alternative<RuleKey::NoScheme>(key.scheme) &&
647         request_endpoint.HasScheme()) {
648       continue;
649     }
650 
651     if (key.port.has_value() &&
652         key.port.value() != request_endpoint.GetPort()) {
653       continue;
654     }
655 
656     DCHECK(!key.query_type.has_value() ||
657            key.query_type.value() != DnsQueryType::UNSPECIFIED);
658     if (key.query_type.has_value() &&
659         !request_types.Has(key.query_type.value())) {
660       continue;
661     }
662 
663     if (key.query_source.has_value() &&
664         request_source != key.query_source.value()) {
665       continue;
666     }
667 
668     if (absl::holds_alternative<RuleKey::Scheme>(key.scheme) &&
669         (!request_endpoint.HasScheme() ||
670          request_endpoint.GetScheme() !=
671              absl::get<RuleKey::Scheme>(key.scheme))) {
672       continue;
673     }
674 
675     if (!base::MatchPattern(request_endpoint.GetHostnameWithoutBrackets(),
676                             key.hostname_pattern)) {
677       continue;
678     }
679 
680     return result;
681   }
682 
683   if (default_result_)
684     return default_result_.value();
685 
686   NOTREACHED() << "Request " << request_endpoint.GetHostname()
687                << " did not match any MockHostResolver rules.";
688 }
689 
ClearRules()690 void MockHostResolverBase::RuleResolver::ClearRules() {
691   rules_.clear();
692 }
693 
694 // static
695 MockHostResolverBase::RuleResolver::RuleResultOrError
GetLocalhostResult()696 MockHostResolverBase::RuleResolver::GetLocalhostResult() {
697   HostResolverEndpointResult endpoint;
698   endpoint.ip_endpoints = {IPEndPoint(IPAddress::IPv4Localhost(), /*port=*/0)};
699   return RuleResult(std::vector{endpoint});
700 }
701 
AddRule(RuleKey key,RuleResultOrError result)702 void MockHostResolverBase::RuleResolver::AddRule(RuleKey key,
703                                                  RuleResultOrError result) {
704   // Literals are always resolved to themselves by MockHostResolverBase,
705   // consequently we do not support remapping them.
706   IPAddress ip_address;
707   DCHECK(!ip_address.AssignFromIPLiteral(key.hostname_pattern));
708 
709   CHECK(rules_.emplace(std::move(key), std::move(result)).second)
710       << "Duplicate rule key";
711 }
712 
AddRule(RuleKey key,std::string_view ip_literal)713 void MockHostResolverBase::RuleResolver::AddRule(RuleKey key,
714                                                  std::string_view ip_literal) {
715   std::vector<HostResolverEndpointResult> endpoints;
716   endpoints.emplace_back();
717   CHECK_EQ(ParseAddressList(ip_literal, &endpoints[0].ip_endpoints), OK);
718   AddRule(std::move(key), RuleResult(std::move(endpoints)));
719 }
720 
AddRule(std::string_view hostname_pattern,RuleResultOrError result)721 void MockHostResolverBase::RuleResolver::AddRule(
722     std::string_view hostname_pattern,
723     RuleResultOrError result) {
724   RuleKey key;
725   key.hostname_pattern = std::string(hostname_pattern);
726   AddRule(std::move(key), std::move(result));
727 }
728 
AddRule(std::string_view hostname_pattern,std::string_view ip_literal)729 void MockHostResolverBase::RuleResolver::AddRule(
730     std::string_view hostname_pattern,
731     std::string_view ip_literal) {
732   std::vector<HostResolverEndpointResult> endpoints;
733   endpoints.emplace_back();
734   CHECK_EQ(ParseAddressList(ip_literal, &endpoints[0].ip_endpoints), OK);
735   AddRule(hostname_pattern, RuleResult(std::move(endpoints)));
736 }
737 
AddRule(std::string_view hostname_pattern,Error error)738 void MockHostResolverBase::RuleResolver::AddRule(
739     std::string_view hostname_pattern,
740     Error error) {
741   RuleKey key;
742   key.hostname_pattern = std::string(hostname_pattern);
743 
744   AddRule(std::move(key), error);
745 }
746 
AddIPLiteralRule(std::string_view hostname_pattern,std::string_view ip_literal,std::string_view canonical_name)747 void MockHostResolverBase::RuleResolver::AddIPLiteralRule(
748     std::string_view hostname_pattern,
749     std::string_view ip_literal,
750     std::string_view canonical_name) {
751   RuleKey key;
752   key.hostname_pattern = std::string(hostname_pattern);
753 
754   std::set<std::string> aliases;
755   if (!canonical_name.empty())
756     aliases.emplace(canonical_name);
757 
758   std::vector<HostResolverEndpointResult> endpoints;
759   endpoints.emplace_back();
760   CHECK_EQ(ParseAddressList(ip_literal, &endpoints[0].ip_endpoints), OK);
761   AddRule(std::move(key), RuleResult(std::move(endpoints), std::move(aliases)));
762 }
763 
AddIPLiteralRuleWithDnsAliases(std::string_view hostname_pattern,std::string_view ip_literal,std::vector<std::string> dns_aliases)764 void MockHostResolverBase::RuleResolver::AddIPLiteralRuleWithDnsAliases(
765     std::string_view hostname_pattern,
766     std::string_view ip_literal,
767     std::vector<std::string> dns_aliases) {
768   std::vector<HostResolverEndpointResult> endpoints;
769   endpoints.emplace_back();
770   CHECK_EQ(ParseAddressList(ip_literal, &endpoints[0].ip_endpoints), OK);
771   AddRule(hostname_pattern,
772           RuleResult(
773               std::move(endpoints),
774               std::set<std::string>(dns_aliases.begin(), dns_aliases.end())));
775 }
776 
AddIPLiteralRuleWithDnsAliases(std::string_view hostname_pattern,std::string_view ip_literal,std::set<std::string> dns_aliases)777 void MockHostResolverBase::RuleResolver::AddIPLiteralRuleWithDnsAliases(
778     std::string_view hostname_pattern,
779     std::string_view ip_literal,
780     std::set<std::string> dns_aliases) {
781   std::vector<std::string> aliases_vector;
782   base::ranges::move(dns_aliases, std::back_inserter(aliases_vector));
783 
784   AddIPLiteralRuleWithDnsAliases(hostname_pattern, ip_literal,
785                                  std::move(aliases_vector));
786 }
787 
// Makes hosts matching `hostname_pattern` fail with ERR_NAME_NOT_RESOLVED.
void MockHostResolverBase::RuleResolver::AddSimulatedFailure(
    std::string_view hostname_pattern) {
  AddRule(hostname_pattern, ERR_NAME_NOT_RESOLVED);
}
792 
// Makes hosts matching `hostname_pattern` fail with ERR_DNS_TIMED_OUT.
void MockHostResolverBase::RuleResolver::AddSimulatedTimeoutFailure(
    std::string_view hostname_pattern) {
  AddRule(hostname_pattern, ERR_DNS_TIMED_OUT);
}
797 
AddRuleWithFlags(std::string_view host_pattern,std::string_view ip_literal,HostResolverFlags,std::vector<std::string> dns_aliases)798 void MockHostResolverBase::RuleResolver::AddRuleWithFlags(
799     std::string_view host_pattern,
800     std::string_view ip_literal,
801     HostResolverFlags /*flags*/,
802     std::vector<std::string> dns_aliases) {
803   std::vector<HostResolverEndpointResult> endpoints;
804   endpoints.emplace_back();
805   CHECK_EQ(ParseAddressList(ip_literal, &endpoints[0].ip_endpoints), OK);
806   AddRule(host_pattern, RuleResult(std::move(endpoints),
807                                    std::set<std::string>(dns_aliases.begin(),
808                                                          dns_aliases.end())));
809 }
810 
// State is created ref-counted by the MockHostResolverBase constructor. Its
// special members need no custom logic.
MockHostResolverBase::State::State() = default;
MockHostResolverBase::State::~State() = default;
813 
// Must run on the resolver's thread. DCHECKs that no request is still
// pending.
MockHostResolverBase::~MockHostResolverBase() {
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);

  // Sanity check that pending requests are always cleaned up, by waiting for
  // completion, manually cancelling, or calling OnShutdown().
  DCHECK(!state_->has_pending_requests());
}
821 
OnShutdown()822 void MockHostResolverBase::OnShutdown() {
823   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
824 
825   // Cancel all pending requests.
826   for (auto& request : state_->mutable_requests()) {
827     request.second->DetachFromResolver();
828   }
829   state_->mutable_requests().clear();
830 
831   // Prevent future requests by clearing resolution rules and the cache.
832   rule_resolver_.ClearRules();
833   cache_ = nullptr;
834 
835   state_->ClearDohProbeRequest();
836 }
837 
838 std::unique_ptr<HostResolver::ResolveHostRequest>
CreateRequest(url::SchemeHostPort host,NetworkAnonymizationKey network_anonymization_key,NetLogWithSource net_log,std::optional<ResolveHostParameters> optional_parameters)839 MockHostResolverBase::CreateRequest(
840     url::SchemeHostPort host,
841     NetworkAnonymizationKey network_anonymization_key,
842     NetLogWithSource net_log,
843     std::optional<ResolveHostParameters> optional_parameters) {
844   return std::make_unique<RequestImpl>(
845       Host(std::move(host)), network_anonymization_key, optional_parameters,
846       weak_ptr_factory_.GetWeakPtr());
847 }
848 
849 std::unique_ptr<HostResolver::ResolveHostRequest>
CreateRequest(const HostPortPair & host,const NetworkAnonymizationKey & network_anonymization_key,const NetLogWithSource & source_net_log,const std::optional<ResolveHostParameters> & optional_parameters)850 MockHostResolverBase::CreateRequest(
851     const HostPortPair& host,
852     const NetworkAnonymizationKey& network_anonymization_key,
853     const NetLogWithSource& source_net_log,
854     const std::optional<ResolveHostParameters>& optional_parameters) {
855   return std::make_unique<RequestImpl>(Host(host), network_anonymization_key,
856                                        optional_parameters,
857                                        weak_ptr_factory_.GetWeakPtr());
858 }
859 
860 std::unique_ptr<HostResolver::ServiceEndpointRequest>
CreateServiceEndpointRequest(Host host,NetworkAnonymizationKey network_anonymization_key,NetLogWithSource net_log,ResolveHostParameters parameters)861 MockHostResolverBase::CreateServiceEndpointRequest(
862     Host host,
863     NetworkAnonymizationKey network_anonymization_key,
864     NetLogWithSource net_log,
865     ResolveHostParameters parameters) {
866   return std::make_unique<ServiceEndpointRequestImpl>(
867       std::move(host), network_anonymization_key, parameters,
868       weak_ptr_factory_.GetWeakPtr());
869 }
870 
// Creates a DoH probe request that refers back to this resolver through a
// weak pointer.
std::unique_ptr<HostResolver::ProbeRequest>
MockHostResolverBase::CreateDohProbeRequest() {
  return std::make_unique<ProbeRequestImpl>(weak_ptr_factory_.GetWeakPtr());
}
875 
// Creates an mDNS listener for `host` / `query_type`. Results are pushed to
// registered listeners via TriggerMdnsListeners(). Registration presumably
// happens through AddListener() once the listener starts; verify in
// MdnsListenerImpl.
std::unique_ptr<HostResolver::MdnsListener>
MockHostResolverBase::CreateMdnsListener(const HostPortPair& host,
                                         DnsQueryType query_type) {
  return std::make_unique<MdnsListenerImpl>(host, query_type,
                                            weak_ptr_factory_.GetWeakPtr());
}
882 
// Returns the resolver's cache. This is null when caching was disabled at
// construction, and also after OnShutdown().
HostCache* MockHostResolverBase::GetHostCache() {
  return cache_.get();
}
886 
// Convenience overload that wraps `endpoint` in a Host and forwards to the
// Host-based LoadIntoCache().
int MockHostResolverBase::LoadIntoCache(
    absl::variant<url::SchemeHostPort, HostPortPair> endpoint,
    const NetworkAnonymizationKey& network_anonymization_key,
    const std::optional<ResolveHostParameters>& optional_parameters) {
  return LoadIntoCache(Host(std::move(endpoint)), network_anonymization_key,
                       optional_parameters);
}
894 
LoadIntoCache(const Host & endpoint,const NetworkAnonymizationKey & network_anonymization_key,const std::optional<ResolveHostParameters> & optional_parameters)895 int MockHostResolverBase::LoadIntoCache(
896     const Host& endpoint,
897     const NetworkAnonymizationKey& network_anonymization_key,
898     const std::optional<ResolveHostParameters>& optional_parameters) {
899   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
900   DCHECK(cache_);
901 
902   ResolveHostParameters parameters =
903       optional_parameters.value_or(ResolveHostParameters());
904 
905   std::vector<HostResolverEndpointResult> endpoints;
906   std::set<std::string> aliases;
907   std::optional<HostCache::EntryStaleness> stale_info;
908   int rv = ResolveFromIPLiteralOrCache(
909       endpoint, network_anonymization_key, parameters.dns_query_type,
910       ParametersToHostResolverFlags(parameters), parameters.source,
911       parameters.cache_usage, &endpoints, &aliases, &stale_info);
912   if (rv != ERR_DNS_CACHE_MISS) {
913     // Request already in cache (or IP literal). No need to load it.
914     return rv;
915   }
916 
917   // Just like the real resolver, refuse to do anything with invalid
918   // hostnames.
919   if (!dns_names_util::IsValidDnsName(endpoint.GetHostnameWithoutBrackets()))
920     return ERR_NAME_NOT_RESOLVED;
921 
922   RequestImpl request(endpoint, network_anonymization_key, optional_parameters,
923                       weak_ptr_factory_.GetWeakPtr());
924   return DoSynchronousResolution(request);
925 }
926 
ResolveAllPending()927 void MockHostResolverBase::ResolveAllPending() {
928   DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
929   DCHECK(ondemand_mode_);
930   for (auto& [id, request] : state_->mutable_requests()) {
931     base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
932         FROM_HERE, base::BindOnce(&MockHostResolverBase::ResolveNow,
933                                   weak_ptr_factory_.GetWeakPtr(), id));
934   }
935 }
936 
last_id()937 size_t MockHostResolverBase::last_id() {
938   if (!has_pending_requests())
939     return 0;
940   return state_->mutable_requests().rbegin()->first;
941 }
942 
ResolveNow(size_t id)943 void MockHostResolverBase::ResolveNow(size_t id) {
944   auto it = state_->mutable_requests().find(id);
945   if (it == state_->mutable_requests().end())
946     return;  // was canceled
947 
948   RequestBase* req = it->second;
949   state_->mutable_requests().erase(it);
950 
951   int error = DoSynchronousResolution(*req);
952   req->OnAsyncCompleted(id, error);
953 }
954 
DetachRequest(size_t id)955 void MockHostResolverBase::DetachRequest(size_t id) {
956   auto it = state_->mutable_requests().find(id);
957   CHECK(it != state_->mutable_requests().end());
958   state_->mutable_requests().erase(it);
959 }
960 
request_host(size_t id)961 std::string_view MockHostResolverBase::request_host(size_t id) {
962   DCHECK(request(id));
963   return request(id)->request_endpoint().GetHostnameWithoutBrackets();
964 }
965 
request_priority(size_t id)966 RequestPriority MockHostResolverBase::request_priority(size_t id) {
967   DCHECK(request(id));
968   return request(id)->priority();
969 }
970 
971 const NetworkAnonymizationKey&
request_network_anonymization_key(size_t id)972 MockHostResolverBase::request_network_anonymization_key(size_t id) {
973   DCHECK(request(id));
974   return request(id)->network_anonymization_key();
975 }
976 
ResolveOnlyRequestNow()977 void MockHostResolverBase::ResolveOnlyRequestNow() {
978   DCHECK_EQ(1u, state_->mutable_requests().size());
979   ResolveNow(state_->mutable_requests().begin()->first);
980 }
981 
TriggerMdnsListeners(const HostPortPair & host,DnsQueryType query_type,MdnsListenerUpdateType update_type,const IPEndPoint & address_result)982 void MockHostResolverBase::TriggerMdnsListeners(
983     const HostPortPair& host,
984     DnsQueryType query_type,
985     MdnsListenerUpdateType update_type,
986     const IPEndPoint& address_result) {
987   for (MdnsListenerImpl* listener : listeners_) {
988     if (listener->host() == host && listener->query_type() == query_type)
989       listener->TriggerAddressResult(update_type, address_result);
990   }
991 }
992 
TriggerMdnsListeners(const HostPortPair & host,DnsQueryType query_type,MdnsListenerUpdateType update_type,const std::vector<std::string> & text_result)993 void MockHostResolverBase::TriggerMdnsListeners(
994     const HostPortPair& host,
995     DnsQueryType query_type,
996     MdnsListenerUpdateType update_type,
997     const std::vector<std::string>& text_result) {
998   for (MdnsListenerImpl* listener : listeners_) {
999     if (listener->host() == host && listener->query_type() == query_type)
1000       listener->TriggerTextResult(update_type, text_result);
1001   }
1002 }
1003 
TriggerMdnsListeners(const HostPortPair & host,DnsQueryType query_type,MdnsListenerUpdateType update_type,const HostPortPair & host_result)1004 void MockHostResolverBase::TriggerMdnsListeners(
1005     const HostPortPair& host,
1006     DnsQueryType query_type,
1007     MdnsListenerUpdateType update_type,
1008     const HostPortPair& host_result) {
1009   for (MdnsListenerImpl* listener : listeners_) {
1010     if (listener->host() == host && listener->query_type() == query_type)
1011       listener->TriggerHostnameResult(update_type, host_result);
1012   }
1013 }
1014 
TriggerMdnsListeners(const HostPortPair & host,DnsQueryType query_type,MdnsListenerUpdateType update_type)1015 void MockHostResolverBase::TriggerMdnsListeners(
1016     const HostPortPair& host,
1017     DnsQueryType query_type,
1018     MdnsListenerUpdateType update_type) {
1019   for (MdnsListenerImpl* listener : listeners_) {
1020     if (listener->host() == host && listener->query_type() == query_type)
1021       listener->TriggerUnhandledResult(update_type);
1022   }
1023 }
1024 
request(size_t id)1025 MockHostResolverBase::RequestBase* MockHostResolverBase::request(size_t id) {
1026   RequestMap::iterator request = state_->mutable_requests().find(id);
1027   CHECK(request != state_->mutable_requests().end());
1028   CHECK_EQ(request->second->id(), id);
1029   return (*request).second;
1030 }
1031 
1032 // start id from 1 to distinguish from NULL RequestHandle
MockHostResolverBase(bool use_caching,int cache_invalidation_num,RuleResolver rule_resolver)1033 MockHostResolverBase::MockHostResolverBase(bool use_caching,
1034                                            int cache_invalidation_num,
1035                                            RuleResolver rule_resolver)
1036     : rule_resolver_(std::move(rule_resolver)),
1037       initial_cache_invalidation_num_(cache_invalidation_num),
1038       tick_clock_(base::DefaultTickClock::GetInstance()),
1039       state_(base::MakeRefCounted<State>()) {
1040   if (use_caching)
1041     cache_ = std::make_unique<HostCache>(kMaxCacheEntries);
1042   else
1043     DCHECK_GE(0, cache_invalidation_num);
1044 }
1045 
// Starts resolving `request`.
//
// IP literals, localhost names, and cache hits are answered immediately.
// LOCAL_ONLY requests stop there. Any other cache miss is resolved against
// the rules in one of three ways:
//   - synchronously, in `synchronous_mode_`;
//   - via a posted ResolveNow() task;
//   - in `ondemand_mode_`, left pending until ResolveNow(),
//     ResolveAllPending(), or ResolveOnlyRequestNow() runs.
// Returns a net error code, or ERR_IO_PENDING if the request was queued.
int MockHostResolverBase::Resolve(RequestBase* request) {
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);

  // Remember properties of the most recent request.
  last_request_priority_ = request->parameters().initial_priority;
  last_request_network_anonymization_key_ =
      request->network_anonymization_key();
  last_secure_dns_policy_ = request->parameters().secure_dns_policy;
  state_->IncrementNumResolve();
  std::vector<HostResolverEndpointResult> endpoints;
  std::set<std::string> aliases;
  std::optional<HostCache::EntryStaleness> stale_info;
  // TODO(crbug.com/40203587): Allow caching `ConnectionEndpoint` results.
  int rv = ResolveFromIPLiteralOrCache(
      request->request_endpoint(), request->network_anonymization_key(),
      request->parameters().dns_query_type, request->host_resolver_flags(),
      request->parameters().source, request->parameters().cache_usage,
      &endpoints, &aliases, &stale_info);

  // Speculative requests get SetError(rv) instead of results, even on
  // success. An error recorded here (including ERR_DNS_CACHE_MISS) is
  // overwritten later if the request goes on to the rules.
  if (rv == OK && !request->parameters().is_speculative) {
    request->SetEndpointResults(std::move(endpoints), std::move(aliases),
                                std::move(stale_info));
  } else {
    request->SetError(rv);
  }

  // Finished unless this was a cache miss that may consult the rules.
  // LOCAL_ONLY requests never do.
  if (rv != ERR_DNS_CACHE_MISS ||
      request->parameters().source == HostResolverSource::LOCAL_ONLY) {
    return SquashErrorCode(rv);
  }

  // Just like the real resolver, refuse to do anything with invalid
  // hostnames.
  if (!dns_names_util::IsValidDnsName(
          request->request_endpoint().GetHostnameWithoutBrackets())) {
    request->SetError(ERR_NAME_NOT_RESOLVED);
    return ERR_NAME_NOT_RESOLVED;
  }

  if (synchronous_mode_)
    return DoSynchronousResolution(*request);

  // Store the request for asynchronous resolution
  size_t id = next_request_id_++;
  request->set_id(id);
  state_->mutable_requests()[id] = request;

  // In on-demand mode the test decides when to resolve. Otherwise resolve on
  // the next task; the weak pointer cancels the task if `this` is destroyed.
  if (!ondemand_mode_) {
    base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
        FROM_HERE, base::BindOnce(&MockHostResolverBase::ResolveNow,
                                  weak_ptr_factory_.GetWeakPtr(), id));
  }

  return ERR_IO_PENDING;
}
1100 
// Tries to answer a request without consulting the rules. In order:
//  1. An IP-literal host resolves to itself on the host's port. If
//     `dns_query_type` asks for the other address family, the result is
//     ERR_NAME_NOT_RESOLVED instead. With HOST_RESOLVER_CANONNAME, the
//     literal's string form is the only alias.
//  2. For address query types, names recognized by ResolveLocalHostname()
//     (e.g. "localhost") resolve immediately.
//  3. Otherwise the host cache is consulted, if a cache exists and
//     `cache_usage` allows it. LookupStale() is used only for STALE_ALLOWED.
// Returns OK, a cached error, or ERR_DNS_CACHE_MISS when none of these
// applies. All outputs are cleared first and are written only when OK is
// returned.
int MockHostResolverBase::ResolveFromIPLiteralOrCache(
    const Host& endpoint,
    const NetworkAnonymizationKey& network_anonymization_key,
    DnsQueryType dns_query_type,
    HostResolverFlags flags,
    HostResolverSource source,
    HostResolver::ResolveHostParameters::CacheUsage cache_usage,
    std::vector<HostResolverEndpointResult>* out_endpoints,
    std::set<std::string>* out_aliases,
    std::optional<HostCache::EntryStaleness>* out_stale_info) {
  DCHECK(out_endpoints);
  DCHECK(out_aliases);
  DCHECK(out_stale_info);
  out_endpoints->clear();
  out_aliases->clear();
  *out_stale_info = std::nullopt;

  IPAddress ip_address;
  if (ip_address.AssignFromIPLiteral(endpoint.GetHostnameWithoutBrackets())) {
    const DnsQueryType desired_address_query =
        AddressFamilyToDnsQueryType(GetAddressFamily(ip_address));
    DCHECK_NE(desired_address_query, DnsQueryType::UNSPECIFIED);

    // This matches the behavior of HostResolverImpl.
    if (dns_query_type != DnsQueryType::UNSPECIFIED &&
        dns_query_type != desired_address_query) {
      return ERR_NAME_NOT_RESOLVED;
    }

    *out_endpoints = std::vector<HostResolverEndpointResult>(1);
    (*out_endpoints)[0].ip_endpoints.emplace_back(ip_address,
                                                  endpoint.GetPort());
    if (flags & HOST_RESOLVER_CANONNAME)
      *out_aliases = {ip_address.ToString()};
    return OK;
  }

  std::vector<IPEndPoint> localhost_endpoints;
  // Immediately resolve any "localhost" or recognized similar names.
  if (IsAddressType(dns_query_type) &&
      ResolveLocalHostname(endpoint.GetHostnameWithoutBrackets(),
                           &localhost_endpoints)) {
    *out_endpoints = std::vector<HostResolverEndpointResult>(1);
    (*out_endpoints)[0].ip_endpoints = localhost_endpoints;
    return OK;
  }
  int rv = ERR_DNS_CACHE_MISS;
  bool cache_allowed =
      cache_usage == HostResolver::ResolveHostParameters::CacheUsage::ALLOWED ||
      cache_usage ==
          HostResolver::ResolveHostParameters::CacheUsage::STALE_ALLOWED;
  if (cache_.get() && cache_allowed) {
    // Local-only requests search the cache for non-local-only results.
    HostResolverSource effective_source =
        source == HostResolverSource::LOCAL_ONLY ? HostResolverSource::ANY
                                                 : source;
    HostCache::Key key(GetCacheHost(endpoint), dns_query_type, flags,
                       effective_source, network_anonymization_key);
    const std::pair<const HostCache::Key, HostCache::Entry>* cache_result;
    HostCache::EntryStaleness stale_info = HostCache::kNotStale;
    if (cache_usage ==
        HostResolver::ResolveHostParameters::CacheUsage::STALE_ALLOWED) {
      cache_result = cache_->LookupStale(key, tick_clock_->NowTicks(),
                                         &stale_info, true /* ignore_secure */);
    } else {
      cache_result = cache_->Lookup(key, tick_clock_->NowTicks(),
                                    true /* ignore_secure */);
    }
    if (cache_result) {
      // A cached error is returned as-is. Outputs are only filled on success.
      rv = cache_result->second.error();
      if (rv == OK) {
        *out_endpoints = cache_result->second.GetEndpoints();

        *out_aliases = cache_result->second.aliases();
        *out_stale_info = std::move(stale_info);
      }

      // DoSynchronousResolution() seeds a per-key countdown from
      // `initial_cache_invalidation_num_`. Each cache hit decrements it. At
      // zero, the entry is re-stored with a zero TTL (presumably so later
      // non-stale lookups miss) and the countdown is removed. The entry is
      // copied before Set() since `cache_result` points at cache-owned data.
      auto cache_invalidation_iterator = cache_invalidation_nums_.find(key);
      if (cache_invalidation_iterator != cache_invalidation_nums_.end()) {
        DCHECK_LE(1, cache_invalidation_iterator->second);
        cache_invalidation_iterator->second--;
        if (cache_invalidation_iterator->second == 0) {
          HostCache::Entry new_entry(cache_result->second);
          cache_->Set(key, new_entry, tick_clock_->NowTicks(),
                      base::TimeDelta());
          cache_invalidation_nums_.erase(cache_invalidation_iterator);
        }
      }
    }
  }
  return rv;
}
1193 
// Resolves `request` against the rules and delivers the result (or error) to
// the request. Also increments the non-local resolve counter.
//
// When caching is enabled, the outcome is stored in the cache:
//   - successes with a TTL of kCacheEntryTTLSeconds, also seeding the
//     invalidation countdown when `initial_cache_invalidation_num_` > 0;
//   - failures with a zero TTL.
// Returns the squashed error code.
int MockHostResolverBase::DoSynchronousResolution(RequestBase& request) {
  state_->IncrementNumNonLocalResolves();

  const RuleResolver::RuleResultOrError& result = rule_resolver_.Resolve(
      request.request_endpoint(), {request.parameters().dns_query_type},
      request.parameters().source);

  int error = ERR_UNEXPECTED;
  std::optional<HostCache::Entry> cache_entry;
  if (absl::holds_alternative<RuleResolver::RuleResult>(result)) {
    const auto& rule_result = absl::get<RuleResolver::RuleResult>(result);
    const auto& endpoint_results = rule_result.endpoints;
    const auto& aliases = rule_result.aliases;
    request.SetEndpointResults(endpoint_results, aliases,
                               /*staleness=*/std::nullopt);
    // TODO(crbug.com/40203587): Change `error` on empty results?
    error = OK;
    if (cache_.get()) {
      cache_entry = CreateCacheEntry(request.request_endpoint().GetHostname(),
                                     endpoint_results, aliases);
    }
  } else {
    DCHECK(absl::holds_alternative<RuleResolver::ErrorResult>(result));
    error = absl::get<RuleResolver::ErrorResult>(result);
    request.SetError(error);
    if (cache_.get()) {
      cache_entry.emplace(error, HostCache::Entry::SOURCE_UNKNOWN);
    }
  }
  if (cache_.get() && cache_entry.has_value()) {
    // NOTE(review): this key uses the request's own source. Lookups in
    // ResolveFromIPLiteralOrCache() map LOCAL_ONLY to ANY, so a LOCAL_ONLY
    // entry stored here (e.g. via LoadIntoCache()) would not be found.
    HostCache::Key key(
        GetCacheHost(request.request_endpoint()),
        request.parameters().dns_query_type, request.host_resolver_flags(),
        request.parameters().source, request.network_anonymization_key());
    // Storing a failure with TTL 0 so that it overwrites previous value.
    base::TimeDelta ttl;
    if (error == OK) {
      ttl = base::Seconds(kCacheEntryTTLSeconds);
      if (initial_cache_invalidation_num_ > 0)
        cache_invalidation_nums_[key] = initial_cache_invalidation_num_;
    }
    cache_->Set(key, cache_entry.value(), tick_clock_->NowTicks(), ttl);
  }

  return SquashErrorCode(error);
}
1240 
// Registers `listener` to receive TriggerMdnsListeners() updates.
void MockHostResolverBase::AddListener(MdnsListenerImpl* listener) {
  listeners_.insert(listener);
}
1244 
// Unregisters `listener` so it no longer receives TriggerMdnsListeners()
// updates.
void MockHostResolverBase::RemoveCancelledListener(MdnsListenerImpl* listener) {
  listeners_.erase(listener);
}
1248 
// Stores the rules and caching configuration that CreateResolver() hands to
// every resolver it creates.
MockHostResolverFactory::MockHostResolverFactory(
    MockHostResolverBase::RuleResolver rules,
    bool use_caching,
    int cache_invalidation_num)
    : rules_(std::move(rules)),
      use_caching_(use_caching),
      cache_invalidation_num_(cache_invalidation_num) {}

MockHostResolverFactory::~MockHostResolverFactory() = default;
1258 
CreateResolver(HostResolverManager * manager,std::string_view host_mapping_rules,bool enable_caching)1259 std::unique_ptr<HostResolver> MockHostResolverFactory::CreateResolver(
1260     HostResolverManager* manager,
1261     std::string_view host_mapping_rules,
1262     bool enable_caching) {
1263   DCHECK(host_mapping_rules.empty());
1264 
1265   // Explicit new to access private constructor.
1266   auto resolver = base::WrapUnique(new MockHostResolverBase(
1267       enable_caching && use_caching_, cache_invalidation_num_, rules_));
1268   return resolver;
1269 }
1270 
// Ignores `net_log` and `options` and behaves like CreateResolver() with no
// manager.
std::unique_ptr<HostResolver> MockHostResolverFactory::CreateStandaloneResolver(
    NetLog* net_log,
    const HostResolver::ManagerOptions& options,
    std::string_view host_mapping_rules,
    bool enable_caching) {
  return CreateResolver(nullptr, host_mapping_rules, enable_caching);
}
1278 
1279 //-----------------------------------------------------------------------------
1280 
Rule(ResolverType resolver_type,std::string_view host_pattern,AddressFamily address_family,HostResolverFlags host_resolver_flags,std::string_view replacement,std::vector<std::string> dns_aliases,int latency_ms)1281 RuleBasedHostResolverProc::Rule::Rule(ResolverType resolver_type,
1282                                       std::string_view host_pattern,
1283                                       AddressFamily address_family,
1284                                       HostResolverFlags host_resolver_flags,
1285                                       std::string_view replacement,
1286                                       std::vector<std::string> dns_aliases,
1287                                       int latency_ms)
1288     : resolver_type(resolver_type),
1289       host_pattern(host_pattern),
1290       address_family(address_family),
1291       host_resolver_flags(host_resolver_flags),
1292       replacement(replacement),
1293       dns_aliases(std::move(dns_aliases)),
1294       latency_ms(latency_ms) {
1295   DCHECK(this->dns_aliases != std::vector<std::string>({""}));
1296 }
1297 
// Rules are plain values. Copying and destruction are member-wise.
RuleBasedHostResolverProc::Rule::Rule(const Rule& other) = default;

RuleBasedHostResolverProc::Rule::~Rule() = default;
1301 
// Forwards `previous` and `allow_fallback` to HostResolverProc. The base
// class presumably uses them to chain to another proc; confirm in
// host_resolver_proc.h.
RuleBasedHostResolverProc::RuleBasedHostResolverProc(
    scoped_refptr<HostResolverProc> previous,
    bool allow_fallback)
    : HostResolverProc(std::move(previous), allow_fallback) {}
1306 
// Remaps hosts matching `host_pattern` to `replacement` for any address
// family. See AddRuleForAddressFamily().
void RuleBasedHostResolverProc::AddRule(std::string_view host_pattern,
                                        std::string_view replacement) {
  AddRuleForAddressFamily(host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
                          replacement);
}
1312 
AddRuleForAddressFamily(std::string_view host_pattern,AddressFamily address_family,std::string_view replacement)1313 void RuleBasedHostResolverProc::AddRuleForAddressFamily(
1314     std::string_view host_pattern,
1315     AddressFamily address_family,
1316     std::string_view replacement) {
1317   DCHECK(!replacement.empty());
1318   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY;
1319   Rule rule(Rule::kResolverTypeSystem, host_pattern, address_family, flags,
1320             replacement, {} /* dns_aliases */, 0);
1321   AddRuleInternal(rule);
1322 }
1323 
AddRuleWithFlags(std::string_view host_pattern,std::string_view replacement,HostResolverFlags flags,std::vector<std::string> dns_aliases)1324 void RuleBasedHostResolverProc::AddRuleWithFlags(
1325     std::string_view host_pattern,
1326     std::string_view replacement,
1327     HostResolverFlags flags,
1328     std::vector<std::string> dns_aliases) {
1329   DCHECK(!replacement.empty());
1330   Rule rule(Rule::kResolverTypeSystem, host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
1331             flags, replacement, std::move(dns_aliases), 0);
1332   AddRuleInternal(rule);
1333 }
1334 
AddIPLiteralRule(std::string_view host_pattern,std::string_view ip_literal,std::string_view canonical_name)1335 void RuleBasedHostResolverProc::AddIPLiteralRule(
1336     std::string_view host_pattern,
1337     std::string_view ip_literal,
1338     std::string_view canonical_name) {
1339   // Literals are always resolved to themselves by HostResolverImpl,
1340   // consequently we do not support remapping them.
1341   IPAddress ip_address;
1342   DCHECK(!ip_address.AssignFromIPLiteral(host_pattern));
1343   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY;
1344   std::vector<std::string> aliases;
1345   if (!canonical_name.empty()) {
1346     flags |= HOST_RESOLVER_CANONNAME;
1347     aliases.emplace_back(canonical_name);
1348   }
1349 
1350   Rule rule(Rule::kResolverTypeIPLiteral, host_pattern,
1351             ADDRESS_FAMILY_UNSPECIFIED, flags, ip_literal, std::move(aliases),
1352             0);
1353   AddRuleInternal(rule);
1354 }
1355 
AddIPLiteralRuleWithDnsAliases(std::string_view host_pattern,std::string_view ip_literal,std::vector<std::string> dns_aliases)1356 void RuleBasedHostResolverProc::AddIPLiteralRuleWithDnsAliases(
1357     std::string_view host_pattern,
1358     std::string_view ip_literal,
1359     std::vector<std::string> dns_aliases) {
1360   // Literals are always resolved to themselves by HostResolverImpl,
1361   // consequently we do not support remapping them.
1362   IPAddress ip_address;
1363   DCHECK(!ip_address.AssignFromIPLiteral(host_pattern));
1364   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY;
1365   if (!dns_aliases.empty())
1366     flags |= HOST_RESOLVER_CANONNAME;
1367 
1368   Rule rule(Rule::kResolverTypeIPLiteral, host_pattern,
1369             ADDRESS_FAMILY_UNSPECIFIED, flags, ip_literal,
1370             std::move(dns_aliases), 0);
1371   AddRuleInternal(rule);
1372 }
1373 
AddRuleWithLatency(std::string_view host_pattern,std::string_view replacement,int latency_ms)1374 void RuleBasedHostResolverProc::AddRuleWithLatency(
1375     std::string_view host_pattern,
1376     std::string_view replacement,
1377     int latency_ms) {
1378   DCHECK(!replacement.empty());
1379   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY;
1380   Rule rule(Rule::kResolverTypeSystem, host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
1381             flags, replacement, /*dns_aliases=*/{}, latency_ms);
1382   AddRuleInternal(rule);
1383 }
1384 
AllowDirectLookup(std::string_view host_pattern)1385 void RuleBasedHostResolverProc::AllowDirectLookup(
1386     std::string_view host_pattern) {
1387   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY;
1388   Rule rule(Rule::kResolverTypeSystem, host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
1389             flags, std::string(), /*dns_aliases=*/{}, 0);
1390   AddRuleInternal(rule);
1391 }
1392 
AddSimulatedFailure(std::string_view host_pattern,HostResolverFlags flags)1393 void RuleBasedHostResolverProc::AddSimulatedFailure(
1394     std::string_view host_pattern,
1395     HostResolverFlags flags) {
1396   Rule rule(Rule::kResolverTypeFail, host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
1397             flags, std::string(), /*dns_aliases=*/{}, 0);
1398   AddRuleInternal(rule);
1399 }
1400 
AddSimulatedTimeoutFailure(std::string_view host_pattern,HostResolverFlags flags)1401 void RuleBasedHostResolverProc::AddSimulatedTimeoutFailure(
1402     std::string_view host_pattern,
1403     HostResolverFlags flags) {
1404   Rule rule(Rule::kResolverTypeFailTimeout, host_pattern,
1405             ADDRESS_FAMILY_UNSPECIFIED, flags, std::string(),
1406             /*dns_aliases=*/{}, 0);
1407   AddRuleInternal(rule);
1408 }
1409 
// Removes every rule. CHECK-fails after DisableModifications() has been
// called.
// NOTE(review): `modifications_allowed_` is read outside `rule_lock_`.
// Confirm it is only toggled while no other thread uses this proc.
void RuleBasedHostResolverProc::ClearRules() {
  CHECK(modifications_allowed_);
  base::AutoLock lock(rule_lock_);
  rules_.clear();
}
1415 
// Makes later ClearRules() calls CHECK-fail.
void RuleBasedHostResolverProc::DisableModifications() {
  modifications_allowed_ = false;
}
1419 
GetRules()1420 RuleBasedHostResolverProc::RuleList RuleBasedHostResolverProc::GetRules() {
1421   RuleList rv;
1422   {
1423     base::AutoLock lock(rule_lock_);
1424     rv = rules_;
1425   }
1426   return rv;
1427 }
1428 
NumResolvesForHostPattern(std::string_view host_pattern)1429 size_t RuleBasedHostResolverProc::NumResolvesForHostPattern(
1430     std::string_view host_pattern) {
1431   base::AutoLock lock(rule_lock_);
1432   return num_resolves_per_host_pattern_[host_pattern];
1433 }
1434 
Resolve(const std::string & host,AddressFamily address_family,HostResolverFlags host_resolver_flags,AddressList * addrlist,int * os_error)1435 int RuleBasedHostResolverProc::Resolve(const std::string& host,
1436                                        AddressFamily address_family,
1437                                        HostResolverFlags host_resolver_flags,
1438                                        AddressList* addrlist,
1439                                        int* os_error) {
1440   base::AutoLock lock(rule_lock_);
1441   RuleList::iterator r;
1442   for (r = rules_.begin(); r != rules_.end(); ++r) {
1443     bool matches_address_family =
1444         r->address_family == ADDRESS_FAMILY_UNSPECIFIED ||
1445         r->address_family == address_family;
1446     // Ignore HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6, since it should
1447     // have no impact on whether a rule matches.
1448     HostResolverFlags flags =
1449         host_resolver_flags & ~HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
1450     // Flags match if all of the bitflags in host_resolver_flags are enabled
1451     // in the rule's host_resolver_flags. However, the rule may have additional
1452     // flags specified, in which case the flags should still be considered a
1453     // match.
1454     bool matches_flags = (r->host_resolver_flags & flags) == flags;
1455     if (matches_flags && matches_address_family &&
1456         base::MatchPattern(host, r->host_pattern)) {
1457       num_resolves_per_host_pattern_[r->host_pattern]++;
1458 
1459       if (r->latency_ms != 0) {
1460         base::PlatformThread::Sleep(base::Milliseconds(r->latency_ms));
1461       }
1462 
1463       // Remap to a new host.
1464       const std::string& effective_host =
1465           r->replacement.empty() ? host : r->replacement;
1466 
1467       // Apply the resolving function to the remapped hostname.
1468       switch (r->resolver_type) {
1469         case Rule::kResolverTypeFail:
1470           return ERR_NAME_NOT_RESOLVED;
1471         case Rule::kResolverTypeFailTimeout:
1472           return ERR_DNS_TIMED_OUT;
1473         case Rule::kResolverTypeSystem:
1474           EnsureSystemHostResolverCallReady();
1475           return SystemHostResolverCall(effective_host, address_family,
1476                                         host_resolver_flags, addrlist,
1477                                         os_error);
1478         case Rule::kResolverTypeIPLiteral: {
1479           AddressList raw_addr_list;
1480           std::vector<std::string> aliases;
1481           aliases = (!r->dns_aliases.empty())
1482                         ? r->dns_aliases
1483                         : std::vector<std::string>({host});
1484           std::vector<net::IPEndPoint> ip_endpoints;
1485           int result = ParseAddressList(effective_host, &ip_endpoints);
1486           // Filter out addresses with the wrong family.
1487           *addrlist = AddressList();
1488           for (const auto& address : ip_endpoints) {
1489             if (address_family == ADDRESS_FAMILY_UNSPECIFIED ||
1490                 address_family == address.GetFamily()) {
1491               addrlist->push_back(address);
1492             }
1493           }
1494           addrlist->SetDnsAliases(aliases);
1495 
1496           if (result == OK && addrlist->empty())
1497             return ERR_NAME_NOT_RESOLVED;
1498           return result;
1499         }
1500         default:
1501           NOTREACHED();
1502       }
1503     }
1504   }
1505 
1506   return ResolveUsingPrevious(host, address_family, host_resolver_flags,
1507                               addrlist, os_error);
1508 }
1509 
1510 RuleBasedHostResolverProc::~RuleBasedHostResolverProc() = default;
1511 
AddRuleInternal(const Rule & rule)1512 void RuleBasedHostResolverProc::AddRuleInternal(const Rule& rule) {
1513   Rule fixed_rule = rule;
1514   // SystemResolverProc expects valid DNS addresses.
1515   // So for kResolverTypeSystem rules:
1516   // * CHECK that replacement is empty (empty domain names mean use a direct
1517   //   lookup) or a valid DNS name (which includes IP addresses).
1518   // * If the replacement is an IP address, switch to an IP literal rule.
1519   if (fixed_rule.resolver_type == Rule::kResolverTypeSystem) {
1520     CHECK(fixed_rule.replacement.empty() ||
1521           dns_names_util::IsValidDnsName(fixed_rule.replacement));
1522 
1523     IPAddress ip_address;
1524     bool valid_address = ip_address.AssignFromIPLiteral(fixed_rule.replacement);
1525     if (valid_address) {
1526       fixed_rule.resolver_type = Rule::kResolverTypeIPLiteral;
1527     }
1528   }
1529 
1530   CHECK(modifications_allowed_);
1531   base::AutoLock lock(rule_lock_);
1532   rules_.push_back(fixed_rule);
1533 }
1534 
CreateCatchAllHostResolverProc()1535 scoped_refptr<RuleBasedHostResolverProc> CreateCatchAllHostResolverProc() {
1536   auto catchall =
1537       base::MakeRefCounted<RuleBasedHostResolverProc>(/*previous=*/nullptr,
1538                                                       /*allow_fallback=*/false);
1539   // Note that IPv6 lookups fail.
1540   catchall->AddIPLiteralRule("*", "127.0.0.1", "localhost");
1541 
1542   // Next add a rules-based layer that the test controls.
1543   return base::MakeRefCounted<RuleBasedHostResolverProc>(
1544       std::move(catchall), /*allow_fallback=*/false);
1545 }
1546 
1547 //-----------------------------------------------------------------------------
1548 
1549 // Implementation of ResolveHostRequest that tracks cancellations when the
1550 // request is destroyed after being started.
1551 class HangingHostResolver::RequestImpl
1552     : public HostResolver::ResolveHostRequest,
1553       public HostResolver::ProbeRequest {
1554  public:
RequestImpl(base::WeakPtr<HangingHostResolver> resolver)1555   explicit RequestImpl(base::WeakPtr<HangingHostResolver> resolver)
1556       : resolver_(resolver) {}
1557 
1558   RequestImpl(const RequestImpl&) = delete;
1559   RequestImpl& operator=(const RequestImpl&) = delete;
1560 
~RequestImpl()1561   ~RequestImpl() override {
1562     if (is_running_ && resolver_)
1563       resolver_->state_->IncrementNumCancellations();
1564   }
1565 
Start(CompletionOnceCallback callback)1566   int Start(CompletionOnceCallback callback) override { return Start(); }
1567 
Start()1568   int Start() override {
1569     DCHECK(resolver_);
1570     is_running_ = true;
1571     return ERR_IO_PENDING;
1572   }
1573 
GetAddressResults() const1574   const AddressList* GetAddressResults() const override {
1575     base::ImmediateCrash();
1576   }
1577 
GetEndpointResults() const1578   const std::vector<HostResolverEndpointResult>* GetEndpointResults()
1579       const override {
1580     base::ImmediateCrash();
1581   }
1582 
GetTextResults() const1583   const std::vector<std::string>* GetTextResults() const override {
1584     base::ImmediateCrash();
1585   }
1586 
GetHostnameResults() const1587   const std::vector<HostPortPair>* GetHostnameResults() const override {
1588     base::ImmediateCrash();
1589   }
1590 
GetDnsAliasResults() const1591   const std::set<std::string>* GetDnsAliasResults() const override {
1592     base::ImmediateCrash();
1593   }
1594 
GetResolveErrorInfo() const1595   net::ResolveErrorInfo GetResolveErrorInfo() const override {
1596     base::ImmediateCrash();
1597   }
1598 
GetStaleInfo() const1599   const std::optional<HostCache::EntryStaleness>& GetStaleInfo()
1600       const override {
1601     base::ImmediateCrash();
1602   }
1603 
ChangeRequestPriority(RequestPriority priority)1604   void ChangeRequestPriority(RequestPriority priority) override {}
1605 
1606  private:
1607   // Use a WeakPtr as the resolver may be destroyed while there are still
1608   // outstanding request objects.
1609   base::WeakPtr<HangingHostResolver> resolver_;
1610   bool is_running_ = false;
1611 };
1612 
1613 HangingHostResolver::State::State() = default;
1614 HangingHostResolver::State::~State() = default;
1615 
HangingHostResolver()1616 HangingHostResolver::HangingHostResolver()
1617     : state_(base::MakeRefCounted<State>()) {}
1618 
1619 HangingHostResolver::~HangingHostResolver() = default;
1620 
OnShutdown()1621 void HangingHostResolver::OnShutdown() {
1622   shutting_down_ = true;
1623 }
1624 
1625 std::unique_ptr<HostResolver::ResolveHostRequest>
CreateRequest(url::SchemeHostPort host,NetworkAnonymizationKey network_anonymization_key,NetLogWithSource net_log,std::optional<ResolveHostParameters> optional_parameters)1626 HangingHostResolver::CreateRequest(
1627     url::SchemeHostPort host,
1628     NetworkAnonymizationKey network_anonymization_key,
1629     NetLogWithSource net_log,
1630     std::optional<ResolveHostParameters> optional_parameters) {
1631   // TODO(crbug.com/40181080): Propagate scheme and make affect behavior.
1632   return CreateRequest(HostPortPair::FromSchemeHostPort(host),
1633                        network_anonymization_key, net_log, optional_parameters);
1634 }
1635 
1636 std::unique_ptr<HostResolver::ResolveHostRequest>
CreateRequest(const HostPortPair & host,const NetworkAnonymizationKey & network_anonymization_key,const NetLogWithSource & source_net_log,const std::optional<ResolveHostParameters> & optional_parameters)1637 HangingHostResolver::CreateRequest(
1638     const HostPortPair& host,
1639     const NetworkAnonymizationKey& network_anonymization_key,
1640     const NetLogWithSource& source_net_log,
1641     const std::optional<ResolveHostParameters>& optional_parameters) {
1642   last_host_ = host;
1643   last_network_anonymization_key_ = network_anonymization_key;
1644 
1645   if (shutting_down_)
1646     return CreateFailingRequest(ERR_CONTEXT_SHUT_DOWN);
1647 
1648   if (optional_parameters &&
1649       optional_parameters.value().source == HostResolverSource::LOCAL_ONLY) {
1650     return CreateFailingRequest(ERR_DNS_CACHE_MISS);
1651   }
1652 
1653   return std::make_unique<RequestImpl>(weak_ptr_factory_.GetWeakPtr());
1654 }
1655 
1656 std::unique_ptr<HostResolver::ServiceEndpointRequest>
CreateServiceEndpointRequest(Host host,NetworkAnonymizationKey network_anonymization_key,NetLogWithSource net_log,ResolveHostParameters parameters)1657 HangingHostResolver::CreateServiceEndpointRequest(
1658     Host host,
1659     NetworkAnonymizationKey network_anonymization_key,
1660     NetLogWithSource net_log,
1661     ResolveHostParameters parameters) {
1662   NOTIMPLEMENTED();
1663   return nullptr;
1664 }
1665 
1666 std::unique_ptr<HostResolver::ProbeRequest>
CreateDohProbeRequest()1667 HangingHostResolver::CreateDohProbeRequest() {
1668   if (shutting_down_)
1669     return CreateFailingProbeRequest(ERR_CONTEXT_SHUT_DOWN);
1670 
1671   return std::make_unique<RequestImpl>(weak_ptr_factory_.GetWeakPtr());
1672 }
1673 
SetRequestContext(URLRequestContext * url_request_context)1674 void HangingHostResolver::SetRequestContext(
1675     URLRequestContext* url_request_context) {}
1676 
1677 //-----------------------------------------------------------------------------
1678 
1679 ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc() = default;
1680 
ScopedDefaultHostResolverProc(HostResolverProc * proc)1681 ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc(
1682     HostResolverProc* proc) {
1683   Init(proc);
1684 }
1685 
~ScopedDefaultHostResolverProc()1686 ScopedDefaultHostResolverProc::~ScopedDefaultHostResolverProc() {
1687   HostResolverProc* old_proc =
1688       HostResolverProc::SetDefault(previous_proc_.get());
1689   // The lifetimes of multiple instances must be nested.
1690   CHECK_EQ(old_proc, current_proc_.get());
1691 }
1692 
Init(HostResolverProc * proc)1693 void ScopedDefaultHostResolverProc::Init(HostResolverProc* proc) {
1694   current_proc_ = proc;
1695   previous_proc_ = HostResolverProc::SetDefault(current_proc_.get());
1696   current_proc_->SetLastProc(previous_proc_);
1697 }
1698 
1699 }  // namespace net
1700