1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/dns/mock_host_resolver.h"
6
7 #include <stdint.h>
8
#include <iterator>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
13
14 #include "base/check_op.h"
15 #include "base/functional/bind.h"
16 #include "base/functional/callback_helpers.h"
17 #include "base/location.h"
18 #include "base/logging.h"
19 #include "base/memory/ptr_util.h"
20 #include "base/memory/raw_ptr.h"
21 #include "base/memory/ref_counted.h"
22 #include "base/no_destructor.h"
23 #include "base/notreached.h"
24 #include "base/strings/pattern.h"
25 #include "base/strings/string_piece.h"
26 #include "base/strings/string_split.h"
27 #include "base/strings/string_util.h"
28 #include "base/task/single_thread_task_runner.h"
29 #include "base/threading/platform_thread.h"
30 #include "base/time/default_tick_clock.h"
31 #include "base/time/tick_clock.h"
32 #include "base/time/time.h"
33 #include "base/types/optional_util.h"
34 #include "build/build_config.h"
35 #include "net/base/address_family.h"
36 #include "net/base/address_list.h"
37 #include "net/base/host_port_pair.h"
38 #include "net/base/ip_address.h"
39 #include "net/base/ip_endpoint.h"
40 #include "net/base/net_errors.h"
41 #include "net/base/net_export.h"
42 #include "net/base/network_anonymization_key.h"
43 #include "net/base/test_completion_callback.h"
44 #include "net/dns/dns_alias_utility.h"
45 #include "net/dns/dns_names_util.h"
46 #include "net/dns/dns_util.h"
47 #include "net/dns/host_cache.h"
48 #include "net/dns/host_resolver.h"
49 #include "net/dns/host_resolver_manager.h"
50 #include "net/dns/host_resolver_system_task.h"
51 #include "net/dns/https_record_rdata.h"
52 #include "net/dns/public/dns_query_type.h"
53 #include "net/dns/public/host_resolver_results.h"
54 #include "net/dns/public/host_resolver_source.h"
55 #include "net/dns/public/mdns_listener_update_type.h"
56 #include "net/dns/public/resolve_error_info.h"
57 #include "net/dns/public/secure_dns_policy.h"
58 #include "net/log/net_log_with_source.h"
59 #include "net/url_request/url_request_context.h"
60 #include "third_party/abseil-cpp/absl/types/optional.h"
61 #include "third_party/abseil-cpp/absl/types/variant.h"
62 #include "url/scheme_host_port.h"
63
64 #if BUILDFLAG(IS_WIN)
65 #include "net/base/winsock_init.h"
66 #endif
67
68 namespace net {
69
70 namespace {
71
// Maximum number of entries kept by the MockCachingHostResolver's cache.
constexpr unsigned kMaxCacheEntries = 100;
// Lifetime, in seconds, of cached successful resolutions. Failures are never
// cached.
constexpr unsigned kCacheEntryTTLSeconds = 60;
76
GetCacheHost(const HostResolver::Host & endpoint)77 absl::variant<url::SchemeHostPort, std::string> GetCacheHost(
78 const HostResolver::Host& endpoint) {
79 if (endpoint.HasScheme()) {
80 return endpoint.AsSchemeHostPort();
81 }
82
83 return endpoint.GetHostname();
84 }
85
CreateCacheEntry(base::StringPiece canonical_name,const std::vector<HostResolverEndpointResult> & endpoint_results,const std::set<std::string> & aliases)86 absl::optional<HostCache::Entry> CreateCacheEntry(
87 base::StringPiece canonical_name,
88 const std::vector<HostResolverEndpointResult>& endpoint_results,
89 const std::set<std::string>& aliases) {
90 absl::optional<std::vector<net::IPEndPoint>> ip_endpoints;
91 std::multimap<HttpsRecordPriority, ConnectionEndpointMetadata>
92 endpoint_metadatas;
93 for (const auto& endpoint_result : endpoint_results) {
94 if (!ip_endpoints) {
95 ip_endpoints = endpoint_result.ip_endpoints;
96 } else {
97 // TODO(crbug.com/1264933): Support caching different IP endpoints
98 // resutls.
99 CHECK(*ip_endpoints == endpoint_result.ip_endpoints)
100 << "Currently caching MockHostResolver only supports same IP "
101 "endpoints results.";
102 }
103
104 if (!endpoint_result.metadata.supported_protocol_alpns.empty()) {
105 endpoint_metadatas.emplace(/*priority=*/1, endpoint_result.metadata);
106 }
107 }
108 DCHECK(ip_endpoints);
109 auto endpoint_entry = HostCache::Entry(OK, *ip_endpoints, aliases,
110 HostCache::Entry::SOURCE_UNKNOWN);
111 endpoint_entry.set_canonical_names(std::set{std::string(canonical_name)});
112 if (endpoint_metadatas.empty()) {
113 return endpoint_entry;
114 }
115 return HostCache::Entry::MergeEntries(
116 HostCache::Entry(OK, std::move(endpoint_metadatas),
117 HostCache::Entry::SOURCE_UNKNOWN),
118 endpoint_entry);
119 }
120 } // namespace
121
ParseAddressList(base::StringPiece host_list,std::vector<net::IPEndPoint> * ip_endpoints)122 int ParseAddressList(base::StringPiece host_list,
123 std::vector<net::IPEndPoint>* ip_endpoints) {
124 ip_endpoints->clear();
125 for (base::StringPiece address : base::SplitStringPiece(
126 host_list, ",", base::TRIM_WHITESPACE, base::SPLIT_WANT_ALL)) {
127 IPAddress ip_address;
128 if (!ip_address.AssignFromIPLiteral(address)) {
129 LOG(WARNING) << "Not a supported IP literal: " << address;
130 return ERR_UNEXPECTED;
131 }
132 ip_endpoints->push_back(IPEndPoint(ip_address, 0));
133 }
134 return OK;
135 }
136
137 class MockHostResolverBase::RequestImpl
138 : public HostResolver::ResolveHostRequest {
139 public:
RequestImpl(Host request_endpoint,const NetworkAnonymizationKey & network_anonymization_key,const absl::optional<ResolveHostParameters> & optional_parameters,base::WeakPtr<MockHostResolverBase> resolver)140 RequestImpl(Host request_endpoint,
141 const NetworkAnonymizationKey& network_anonymization_key,
142 const absl::optional<ResolveHostParameters>& optional_parameters,
143 base::WeakPtr<MockHostResolverBase> resolver)
144 : request_endpoint_(std::move(request_endpoint)),
145 network_anonymization_key_(network_anonymization_key),
146 parameters_(optional_parameters ? optional_parameters.value()
147 : ResolveHostParameters()),
148 priority_(parameters_.initial_priority),
149 host_resolver_flags_(ParametersToHostResolverFlags(parameters_)),
150 resolve_error_info_(ResolveErrorInfo(ERR_IO_PENDING)),
151 resolver_(resolver) {}
152
153 RequestImpl(const RequestImpl&) = delete;
154 RequestImpl& operator=(const RequestImpl&) = delete;
155
~RequestImpl()156 ~RequestImpl() override {
157 if (id_ > 0) {
158 if (resolver_)
159 resolver_->DetachRequest(id_);
160 id_ = 0;
161 resolver_ = nullptr;
162 }
163 }
164
DetachFromResolver()165 void DetachFromResolver() {
166 id_ = 0;
167 resolver_ = nullptr;
168 }
169
Start(CompletionOnceCallback callback)170 int Start(CompletionOnceCallback callback) override {
171 DCHECK(callback);
172 // Start() may only be called once per request.
173 DCHECK_EQ(0u, id_);
174 DCHECK(!complete_);
175 DCHECK(!callback_);
176 // Parent HostResolver must still be alive to call Start().
177 DCHECK(resolver_);
178
179 int rv = resolver_->Resolve(this);
180 DCHECK(!complete_);
181 if (rv == ERR_IO_PENDING) {
182 DCHECK_GT(id_, 0u);
183 callback_ = std::move(callback);
184 } else {
185 DCHECK_EQ(0u, id_);
186 complete_ = true;
187 }
188
189 return rv;
190 }
191
GetAddressResults() const192 const AddressList* GetAddressResults() const override {
193 DCHECK(complete_);
194 return base::OptionalToPtr(address_results_);
195 }
196
GetEndpointResults() const197 const std::vector<HostResolverEndpointResult>* GetEndpointResults()
198 const override {
199 DCHECK(complete_);
200 return base::OptionalToPtr(endpoint_results_);
201 }
202
GetTextResults() const203 const std::vector<std::string>* GetTextResults() const override {
204 DCHECK(complete_);
205 static const base::NoDestructor<std::vector<std::string>> empty_result;
206 return empty_result.get();
207 }
208
GetHostnameResults() const209 const std::vector<HostPortPair>* GetHostnameResults() const override {
210 DCHECK(complete_);
211 static const base::NoDestructor<std::vector<HostPortPair>> empty_result;
212 return empty_result.get();
213 }
214
GetDnsAliasResults() const215 const std::set<std::string>* GetDnsAliasResults() const override {
216 DCHECK(complete_);
217 return base::OptionalToPtr(fixed_up_dns_alias_results_);
218 }
219
GetResolveErrorInfo() const220 net::ResolveErrorInfo GetResolveErrorInfo() const override {
221 DCHECK(complete_);
222 return resolve_error_info_;
223 }
224
GetStaleInfo() const225 const absl::optional<HostCache::EntryStaleness>& GetStaleInfo()
226 const override {
227 DCHECK(complete_);
228 return staleness_;
229 }
230
ChangeRequestPriority(RequestPriority priority)231 void ChangeRequestPriority(RequestPriority priority) override {
232 priority_ = priority;
233 }
234
SetError(int error)235 void SetError(int error) {
236 // Should only be called before request is marked completed.
237 DCHECK(!complete_);
238 resolve_error_info_ = ResolveErrorInfo(error);
239 }
240
241 // Sets `endpoint_results_`, `fixed_up_dns_alias_results_`,
242 // `address_results_` and `staleness_` after fixing them up.
243 // Also sets `error` to OK.
SetEndpointResults(std::vector<HostResolverEndpointResult> endpoint_results,std::set<std::string> aliases,absl::optional<HostCache::EntryStaleness> staleness)244 void SetEndpointResults(
245 std::vector<HostResolverEndpointResult> endpoint_results,
246 std::set<std::string> aliases,
247 absl::optional<HostCache::EntryStaleness> staleness) {
248 DCHECK(!complete_);
249 DCHECK(!endpoint_results_);
250 DCHECK(!parameters_.is_speculative);
251
252 endpoint_results_ = std::move(endpoint_results);
253 for (auto& result : *endpoint_results_) {
254 result.ip_endpoints = FixupEndPoints(result.ip_endpoints);
255 }
256
257 fixed_up_dns_alias_results_ = FixupAliases(aliases);
258
259 // `HostResolver` implementations are expected to provide an `AddressList`
260 // result whenever `HostResolverEndpointResult` is also available.
261 address_results_ = EndpointResultToAddressList(
262 *endpoint_results_, *fixed_up_dns_alias_results_);
263
264 staleness_ = std::move(staleness);
265
266 SetError(OK);
267 }
268
OnAsyncCompleted(size_t id,int error)269 void OnAsyncCompleted(size_t id, int error) {
270 DCHECK_EQ(id_, id);
271 id_ = 0;
272
273 // Check that error information has been set and that the top-level error
274 // code is valid.
275 DCHECK(resolve_error_info_.error != ERR_IO_PENDING);
276 DCHECK(error == OK || error == ERR_NAME_NOT_RESOLVED ||
277 error == ERR_DNS_NAME_HTTPS_ONLY);
278
279 DCHECK(!complete_);
280 complete_ = true;
281
282 DCHECK(callback_);
283 std::move(callback_).Run(error);
284 }
285
request_endpoint() const286 const Host& request_endpoint() const { return request_endpoint_; }
287
network_anonymization_key() const288 const NetworkAnonymizationKey& network_anonymization_key() const {
289 return network_anonymization_key_;
290 }
291
parameters() const292 const ResolveHostParameters& parameters() const { return parameters_; }
293
host_resolver_flags() const294 int host_resolver_flags() const { return host_resolver_flags_; }
295
id()296 size_t id() { return id_; }
297
priority() const298 RequestPriority priority() const { return priority_; }
299
set_id(size_t id)300 void set_id(size_t id) {
301 DCHECK_GT(id, 0u);
302 DCHECK_EQ(0u, id_);
303
304 id_ = id;
305 }
306
complete()307 bool complete() { return complete_; }
308
309 // Similar get GetAddressResults() and GetResolveErrorInfo(), but only exposed
310 // through the HostResolver::ResolveHostRequest interface, and don't have the
311 // DCHECKs that `complete_` is true.
address_results() const312 const absl::optional<AddressList>& address_results() const {
313 return address_results_;
314 }
resolve_error_info() const315 ResolveErrorInfo resolve_error_info() const { return resolve_error_info_; }
316
317 private:
FixupEndPoints(const std::vector<IPEndPoint> & endpoints)318 std::vector<IPEndPoint> FixupEndPoints(
319 const std::vector<IPEndPoint>& endpoints) {
320 std::vector<IPEndPoint> corrected;
321 for (const IPEndPoint& endpoint : endpoints) {
322 DCHECK_NE(endpoint.GetFamily(), ADDRESS_FAMILY_UNSPECIFIED);
323 if (parameters_.dns_query_type == DnsQueryType::UNSPECIFIED ||
324 parameters_.dns_query_type ==
325 AddressFamilyToDnsQueryType(endpoint.GetFamily())) {
326 if (endpoint.port() == 0) {
327 corrected.emplace_back(endpoint.address(),
328 request_endpoint_.GetPort());
329 } else {
330 corrected.push_back(endpoint);
331 }
332 }
333 }
334 return corrected;
335 }
FixupAliases(const std::set<std::string> aliases)336 std::set<std::string> FixupAliases(const std::set<std::string> aliases) {
337 if (aliases.empty())
338 return std::set<std::string>{
339 std::string(request_endpoint_.GetHostnameWithoutBrackets())};
340 return aliases;
341 }
342
343 const Host request_endpoint_;
344 const NetworkAnonymizationKey network_anonymization_key_;
345 const ResolveHostParameters parameters_;
346 RequestPriority priority_;
347 int host_resolver_flags_;
348
349 absl::optional<AddressList> address_results_;
350 absl::optional<std::vector<HostResolverEndpointResult>> endpoint_results_;
351 absl::optional<std::set<std::string>> fixed_up_dns_alias_results_;
352 absl::optional<HostCache::EntryStaleness> staleness_;
353 ResolveErrorInfo resolve_error_info_;
354
355 // Used while stored with the resolver for async resolution. Otherwise 0.
356 size_t id_ = 0;
357
358 CompletionOnceCallback callback_;
359 // Use a WeakPtr as the resolver may be destroyed while there are still
360 // outstanding request objects.
361 base::WeakPtr<MockHostResolverBase> resolver_;
362 bool complete_ = false;
363 };
364
365 class MockHostResolverBase::ProbeRequestImpl
366 : public HostResolver::ProbeRequest {
367 public:
ProbeRequestImpl(base::WeakPtr<MockHostResolverBase> resolver)368 explicit ProbeRequestImpl(base::WeakPtr<MockHostResolverBase> resolver)
369 : resolver_(std::move(resolver)) {}
370
371 ProbeRequestImpl(const ProbeRequestImpl&) = delete;
372 ProbeRequestImpl& operator=(const ProbeRequestImpl&) = delete;
373
~ProbeRequestImpl()374 ~ProbeRequestImpl() override {
375 if (resolver_) {
376 resolver_->state_->ClearDohProbeRequestIfMatching(this);
377 }
378 }
379
Start()380 int Start() override {
381 DCHECK(resolver_);
382 resolver_->state_->set_doh_probe_request(this);
383
384 return ERR_IO_PENDING;
385 }
386
387 private:
388 base::WeakPtr<MockHostResolverBase> resolver_;
389 };
390
391 class MockHostResolverBase::MdnsListenerImpl
392 : public HostResolver::MdnsListener {
393 public:
MdnsListenerImpl(const HostPortPair & host,DnsQueryType query_type,base::WeakPtr<MockHostResolverBase> resolver)394 MdnsListenerImpl(const HostPortPair& host,
395 DnsQueryType query_type,
396 base::WeakPtr<MockHostResolverBase> resolver)
397 : host_(host), query_type_(query_type), resolver_(resolver) {
398 DCHECK_NE(DnsQueryType::UNSPECIFIED, query_type_);
399 DCHECK(resolver_);
400 }
401
~MdnsListenerImpl()402 ~MdnsListenerImpl() override {
403 if (resolver_)
404 resolver_->RemoveCancelledListener(this);
405 }
406
Start(Delegate * delegate)407 int Start(Delegate* delegate) override {
408 DCHECK(delegate);
409 DCHECK(!delegate_);
410 DCHECK(resolver_);
411
412 delegate_ = delegate;
413 resolver_->AddListener(this);
414
415 return OK;
416 }
417
TriggerAddressResult(MdnsListenerUpdateType update_type,IPEndPoint address)418 void TriggerAddressResult(MdnsListenerUpdateType update_type,
419 IPEndPoint address) {
420 delegate_->OnAddressResult(update_type, query_type_, std::move(address));
421 }
422
TriggerTextResult(MdnsListenerUpdateType update_type,std::vector<std::string> text_records)423 void TriggerTextResult(MdnsListenerUpdateType update_type,
424 std::vector<std::string> text_records) {
425 delegate_->OnTextResult(update_type, query_type_, std::move(text_records));
426 }
427
TriggerHostnameResult(MdnsListenerUpdateType update_type,HostPortPair host)428 void TriggerHostnameResult(MdnsListenerUpdateType update_type,
429 HostPortPair host) {
430 delegate_->OnHostnameResult(update_type, query_type_, std::move(host));
431 }
432
TriggerUnhandledResult(MdnsListenerUpdateType update_type)433 void TriggerUnhandledResult(MdnsListenerUpdateType update_type) {
434 delegate_->OnUnhandledResult(update_type, query_type_);
435 }
436
host() const437 const HostPortPair& host() const { return host_; }
query_type() const438 DnsQueryType query_type() const { return query_type_; }
439
440 private:
441 const HostPortPair host_;
442 const DnsQueryType query_type_;
443
444 raw_ptr<Delegate> delegate_ = nullptr;
445
446 // Use a WeakPtr as the resolver may be destroyed while there are still
447 // outstanding listener objects.
448 base::WeakPtr<MockHostResolverBase> resolver_;
449 };
450
451 MockHostResolverBase::RuleResolver::RuleKey::RuleKey() = default;
452
453 MockHostResolverBase::RuleResolver::RuleKey::~RuleKey() = default;
454
455 MockHostResolverBase::RuleResolver::RuleKey::RuleKey(const RuleKey&) = default;
456
457 MockHostResolverBase::RuleResolver::RuleKey&
458 MockHostResolverBase::RuleResolver::RuleKey::operator=(const RuleKey&) =
459 default;
460
461 MockHostResolverBase::RuleResolver::RuleKey::RuleKey(RuleKey&&) = default;
462
463 MockHostResolverBase::RuleResolver::RuleKey&
464 MockHostResolverBase::RuleResolver::RuleKey::operator=(RuleKey&&) = default;
465
466 MockHostResolverBase::RuleResolver::RuleResult::RuleResult() = default;
467
RuleResult(std::vector<HostResolverEndpointResult> endpoints,std::set<std::string> aliases)468 MockHostResolverBase::RuleResolver::RuleResult::RuleResult(
469 std::vector<HostResolverEndpointResult> endpoints,
470 std::set<std::string> aliases)
471 : endpoints(std::move(endpoints)), aliases(std::move(aliases)) {}
472
473 MockHostResolverBase::RuleResolver::RuleResult::~RuleResult() = default;
474
475 MockHostResolverBase::RuleResolver::RuleResult::RuleResult(const RuleResult&) =
476 default;
477
478 MockHostResolverBase::RuleResolver::RuleResult&
479 MockHostResolverBase::RuleResolver::RuleResult::operator=(const RuleResult&) =
480 default;
481
482 MockHostResolverBase::RuleResolver::RuleResult::RuleResult(RuleResult&&) =
483 default;
484
485 MockHostResolverBase::RuleResolver::RuleResult&
486 MockHostResolverBase::RuleResolver::RuleResult::operator=(RuleResult&&) =
487 default;
488
RuleResolver(absl::optional<RuleResultOrError> default_result)489 MockHostResolverBase::RuleResolver::RuleResolver(
490 absl::optional<RuleResultOrError> default_result)
491 : default_result_(std::move(default_result)) {}
492
493 MockHostResolverBase::RuleResolver::~RuleResolver() = default;
494
495 MockHostResolverBase::RuleResolver::RuleResolver(const RuleResolver&) = default;
496
497 MockHostResolverBase::RuleResolver&
498 MockHostResolverBase::RuleResolver::operator=(const RuleResolver&) = default;
499
500 MockHostResolverBase::RuleResolver::RuleResolver(RuleResolver&&) = default;
501
502 MockHostResolverBase::RuleResolver&
503 MockHostResolverBase::RuleResolver::operator=(RuleResolver&&) = default;
504
505 const MockHostResolverBase::RuleResolver::RuleResultOrError&
Resolve(const Host & request_endpoint,DnsQueryTypeSet request_types,HostResolverSource request_source) const506 MockHostResolverBase::RuleResolver::Resolve(
507 const Host& request_endpoint,
508 DnsQueryTypeSet request_types,
509 HostResolverSource request_source) const {
510 for (const auto& rule : rules_) {
511 const RuleKey& key = rule.first;
512 const RuleResultOrError& result = rule.second;
513
514 if (absl::holds_alternative<RuleKey::NoScheme>(key.scheme) &&
515 request_endpoint.HasScheme()) {
516 continue;
517 }
518
519 if (key.port.has_value() &&
520 key.port.value() != request_endpoint.GetPort()) {
521 continue;
522 }
523
524 DCHECK(!key.query_type.has_value() ||
525 key.query_type.value() != DnsQueryType::UNSPECIFIED);
526 if (key.query_type.has_value() &&
527 !request_types.Has(key.query_type.value())) {
528 continue;
529 }
530
531 if (key.query_source.has_value() &&
532 request_source != key.query_source.value()) {
533 continue;
534 }
535
536 if (absl::holds_alternative<RuleKey::Scheme>(key.scheme) &&
537 (!request_endpoint.HasScheme() ||
538 request_endpoint.GetScheme() !=
539 absl::get<RuleKey::Scheme>(key.scheme))) {
540 continue;
541 }
542
543 if (!base::MatchPattern(request_endpoint.GetHostnameWithoutBrackets(),
544 key.hostname_pattern)) {
545 continue;
546 }
547
548 return result;
549 }
550
551 if (default_result_)
552 return default_result_.value();
553
554 NOTREACHED() << "Request " << request_endpoint.GetHostname()
555 << " did not match any MockHostResolver rules.";
556 static const RuleResultOrError kUnexpected = ERR_UNEXPECTED;
557 return kUnexpected;
558 }
559
ClearRules()560 void MockHostResolverBase::RuleResolver::ClearRules() {
561 rules_.clear();
562 }
563
564 // static
565 MockHostResolverBase::RuleResolver::RuleResultOrError
GetLocalhostResult()566 MockHostResolverBase::RuleResolver::GetLocalhostResult() {
567 HostResolverEndpointResult endpoint;
568 endpoint.ip_endpoints = {IPEndPoint(IPAddress::IPv4Localhost(), /*port=*/0)};
569 return RuleResult(std::vector{endpoint});
570 }
571
AddRule(RuleKey key,RuleResultOrError result)572 void MockHostResolverBase::RuleResolver::AddRule(RuleKey key,
573 RuleResultOrError result) {
574 // Literals are always resolved to themselves by MockHostResolverBase,
575 // consequently we do not support remapping them.
576 IPAddress ip_address;
577 DCHECK(!ip_address.AssignFromIPLiteral(key.hostname_pattern));
578
579 CHECK(rules_.emplace(std::move(key), std::move(result)).second)
580 << "Duplicate rule key";
581 }
582
AddRule(RuleKey key,base::StringPiece ip_literal)583 void MockHostResolverBase::RuleResolver::AddRule(RuleKey key,
584 base::StringPiece ip_literal) {
585 std::vector<HostResolverEndpointResult> endpoints;
586 endpoints.emplace_back();
587 CHECK_EQ(ParseAddressList(ip_literal, &endpoints[0].ip_endpoints), OK);
588 AddRule(std::move(key), RuleResult(std::move(endpoints)));
589 }
590
AddRule(base::StringPiece hostname_pattern,RuleResultOrError result)591 void MockHostResolverBase::RuleResolver::AddRule(
592 base::StringPiece hostname_pattern,
593 RuleResultOrError result) {
594 RuleKey key;
595 key.hostname_pattern = std::string(hostname_pattern);
596 AddRule(std::move(key), std::move(result));
597 }
598
AddRule(base::StringPiece hostname_pattern,base::StringPiece ip_literal)599 void MockHostResolverBase::RuleResolver::AddRule(
600 base::StringPiece hostname_pattern,
601 base::StringPiece ip_literal) {
602 std::vector<HostResolverEndpointResult> endpoints;
603 endpoints.emplace_back();
604 CHECK_EQ(ParseAddressList(ip_literal, &endpoints[0].ip_endpoints), OK);
605 AddRule(hostname_pattern, RuleResult(std::move(endpoints)));
606 }
607
AddRule(base::StringPiece hostname_pattern,Error error)608 void MockHostResolverBase::RuleResolver::AddRule(
609 base::StringPiece hostname_pattern,
610 Error error) {
611 RuleKey key;
612 key.hostname_pattern = std::string(hostname_pattern);
613
614 AddRule(std::move(key), error);
615 }
616
AddIPLiteralRule(base::StringPiece hostname_pattern,base::StringPiece ip_literal,base::StringPiece canonical_name)617 void MockHostResolverBase::RuleResolver::AddIPLiteralRule(
618 base::StringPiece hostname_pattern,
619 base::StringPiece ip_literal,
620 base::StringPiece canonical_name) {
621 RuleKey key;
622 key.hostname_pattern = std::string(hostname_pattern);
623
624 std::set<std::string> aliases;
625 if (!canonical_name.empty())
626 aliases.emplace(canonical_name);
627
628 std::vector<HostResolverEndpointResult> endpoints;
629 endpoints.emplace_back();
630 CHECK_EQ(ParseAddressList(ip_literal, &endpoints[0].ip_endpoints), OK);
631 AddRule(std::move(key), RuleResult(std::move(endpoints), std::move(aliases)));
632 }
633
AddIPLiteralRuleWithDnsAliases(base::StringPiece hostname_pattern,base::StringPiece ip_literal,std::vector<std::string> dns_aliases)634 void MockHostResolverBase::RuleResolver::AddIPLiteralRuleWithDnsAliases(
635 base::StringPiece hostname_pattern,
636 base::StringPiece ip_literal,
637 std::vector<std::string> dns_aliases) {
638 std::vector<HostResolverEndpointResult> endpoints;
639 endpoints.emplace_back();
640 CHECK_EQ(ParseAddressList(ip_literal, &endpoints[0].ip_endpoints), OK);
641 AddRule(hostname_pattern,
642 RuleResult(
643 std::move(endpoints),
644 std::set<std::string>(dns_aliases.begin(), dns_aliases.end())));
645 }
646
AddIPLiteralRuleWithDnsAliases(base::StringPiece hostname_pattern,base::StringPiece ip_literal,std::set<std::string> dns_aliases)647 void MockHostResolverBase::RuleResolver::AddIPLiteralRuleWithDnsAliases(
648 base::StringPiece hostname_pattern,
649 base::StringPiece ip_literal,
650 std::set<std::string> dns_aliases) {
651 std::vector<std::string> aliases_vector;
652 base::ranges::move(dns_aliases, std::back_inserter(aliases_vector));
653
654 AddIPLiteralRuleWithDnsAliases(hostname_pattern, ip_literal,
655 std::move(aliases_vector));
656 }
657
AddSimulatedFailure(base::StringPiece hostname_pattern)658 void MockHostResolverBase::RuleResolver::AddSimulatedFailure(
659 base::StringPiece hostname_pattern) {
660 AddRule(hostname_pattern, ERR_NAME_NOT_RESOLVED);
661 }
662
AddSimulatedTimeoutFailure(base::StringPiece hostname_pattern)663 void MockHostResolverBase::RuleResolver::AddSimulatedTimeoutFailure(
664 base::StringPiece hostname_pattern) {
665 AddRule(hostname_pattern, ERR_DNS_TIMED_OUT);
666 }
667
AddRuleWithFlags(base::StringPiece host_pattern,base::StringPiece ip_literal,HostResolverFlags,std::vector<std::string> dns_aliases)668 void MockHostResolverBase::RuleResolver::AddRuleWithFlags(
669 base::StringPiece host_pattern,
670 base::StringPiece ip_literal,
671 HostResolverFlags /*flags*/,
672 std::vector<std::string> dns_aliases) {
673 std::vector<HostResolverEndpointResult> endpoints;
674 endpoints.emplace_back();
675 CHECK_EQ(ParseAddressList(ip_literal, &endpoints[0].ip_endpoints), OK);
676 AddRule(host_pattern, RuleResult(std::move(endpoints),
677 std::set<std::string>(dns_aliases.begin(),
678 dns_aliases.end())));
679 }
680
681 MockHostResolverBase::State::State() = default;
682 MockHostResolverBase::State::~State() = default;
683
~MockHostResolverBase()684 MockHostResolverBase::~MockHostResolverBase() {
685 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
686
687 // Sanity check that pending requests are always cleaned up, by waiting for
688 // completion, manually cancelling, or calling OnShutdown().
689 DCHECK(!state_->has_pending_requests());
690 }
691
OnShutdown()692 void MockHostResolverBase::OnShutdown() {
693 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
694
695 // Cancel all pending requests.
696 for (auto& request : state_->mutable_requests()) {
697 request.second->DetachFromResolver();
698 }
699 state_->mutable_requests().clear();
700
701 // Prevent future requests by clearing resolution rules and the cache.
702 rule_resolver_.ClearRules();
703 cache_ = nullptr;
704
705 state_->ClearDohProbeRequest();
706 }
707
708 std::unique_ptr<HostResolver::ResolveHostRequest>
CreateRequest(url::SchemeHostPort host,NetworkAnonymizationKey network_anonymization_key,NetLogWithSource net_log,absl::optional<ResolveHostParameters> optional_parameters)709 MockHostResolverBase::CreateRequest(
710 url::SchemeHostPort host,
711 NetworkAnonymizationKey network_anonymization_key,
712 NetLogWithSource net_log,
713 absl::optional<ResolveHostParameters> optional_parameters) {
714 return std::make_unique<RequestImpl>(Host(std::move(host)),
715 network_anonymization_key,
716 optional_parameters, AsWeakPtr());
717 }
718
719 std::unique_ptr<HostResolver::ResolveHostRequest>
CreateRequest(const HostPortPair & host,const NetworkAnonymizationKey & network_anonymization_key,const NetLogWithSource & source_net_log,const absl::optional<ResolveHostParameters> & optional_parameters)720 MockHostResolverBase::CreateRequest(
721 const HostPortPair& host,
722 const NetworkAnonymizationKey& network_anonymization_key,
723 const NetLogWithSource& source_net_log,
724 const absl::optional<ResolveHostParameters>& optional_parameters) {
725 return std::make_unique<RequestImpl>(Host(host), network_anonymization_key,
726 optional_parameters, AsWeakPtr());
727 }
728
729 std::unique_ptr<HostResolver::ProbeRequest>
CreateDohProbeRequest()730 MockHostResolverBase::CreateDohProbeRequest() {
731 return std::make_unique<ProbeRequestImpl>(AsWeakPtr());
732 }
733
734 std::unique_ptr<HostResolver::MdnsListener>
CreateMdnsListener(const HostPortPair & host,DnsQueryType query_type)735 MockHostResolverBase::CreateMdnsListener(const HostPortPair& host,
736 DnsQueryType query_type) {
737 return std::make_unique<MdnsListenerImpl>(host, query_type, AsWeakPtr());
738 }
739
GetHostCache()740 HostCache* MockHostResolverBase::GetHostCache() {
741 return cache_.get();
742 }
743
LoadIntoCache(absl::variant<url::SchemeHostPort,HostPortPair> endpoint,const NetworkAnonymizationKey & network_anonymization_key,const absl::optional<ResolveHostParameters> & optional_parameters)744 int MockHostResolverBase::LoadIntoCache(
745 absl::variant<url::SchemeHostPort, HostPortPair> endpoint,
746 const NetworkAnonymizationKey& network_anonymization_key,
747 const absl::optional<ResolveHostParameters>& optional_parameters) {
748 return LoadIntoCache(Host(std::move(endpoint)), network_anonymization_key,
749 optional_parameters);
750 }
751
LoadIntoCache(const Host & endpoint,const NetworkAnonymizationKey & network_anonymization_key,const absl::optional<ResolveHostParameters> & optional_parameters)752 int MockHostResolverBase::LoadIntoCache(
753 const Host& endpoint,
754 const NetworkAnonymizationKey& network_anonymization_key,
755 const absl::optional<ResolveHostParameters>& optional_parameters) {
756 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
757 DCHECK(cache_);
758
759 ResolveHostParameters parameters =
760 optional_parameters.value_or(ResolveHostParameters());
761
762 std::vector<HostResolverEndpointResult> endpoints;
763 std::set<std::string> aliases;
764 absl::optional<HostCache::EntryStaleness> stale_info;
765 int rv = ResolveFromIPLiteralOrCache(
766 endpoint, network_anonymization_key, parameters.dns_query_type,
767 ParametersToHostResolverFlags(parameters), parameters.source,
768 parameters.cache_usage, &endpoints, &aliases, &stale_info);
769 if (rv != ERR_DNS_CACHE_MISS) {
770 // Request already in cache (or IP literal). No need to load it.
771 return rv;
772 }
773
774 // Just like the real resolver, refuse to do anything with invalid
775 // hostnames.
776 if (!dns_names_util::IsValidDnsName(endpoint.GetHostnameWithoutBrackets()))
777 return ERR_NAME_NOT_RESOLVED;
778
779 RequestImpl request(endpoint, network_anonymization_key, optional_parameters,
780 AsWeakPtr());
781 return DoSynchronousResolution(request);
782 }
783
ResolveAllPending()784 void MockHostResolverBase::ResolveAllPending() {
785 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
786 DCHECK(ondemand_mode_);
787 for (auto& [id, request] : state_->mutable_requests()) {
788 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
789 FROM_HERE,
790 base::BindOnce(&MockHostResolverBase::ResolveNow, AsWeakPtr(), id));
791 }
792 }
793
last_id()794 size_t MockHostResolverBase::last_id() {
795 if (!has_pending_requests())
796 return 0;
797 return state_->mutable_requests().rbegin()->first;
798 }
799
ResolveNow(size_t id)800 void MockHostResolverBase::ResolveNow(size_t id) {
801 auto it = state_->mutable_requests().find(id);
802 if (it == state_->mutable_requests().end())
803 return; // was canceled
804
805 RequestImpl* req = it->second;
806 state_->mutable_requests().erase(it);
807
808 int error = DoSynchronousResolution(*req);
809 req->OnAsyncCompleted(id, error);
810 }
811
DetachRequest(size_t id)812 void MockHostResolverBase::DetachRequest(size_t id) {
813 auto it = state_->mutable_requests().find(id);
814 CHECK(it != state_->mutable_requests().end());
815 state_->mutable_requests().erase(it);
816 }
817
request_host(size_t id)818 base::StringPiece MockHostResolverBase::request_host(size_t id) {
819 DCHECK(request(id));
820 return request(id)->request_endpoint().GetHostnameWithoutBrackets();
821 }
822
request_priority(size_t id)823 RequestPriority MockHostResolverBase::request_priority(size_t id) {
824 DCHECK(request(id));
825 return request(id)->priority();
826 }
827
828 const NetworkAnonymizationKey&
request_network_anonymization_key(size_t id)829 MockHostResolverBase::request_network_anonymization_key(size_t id) {
830 DCHECK(request(id));
831 return request(id)->network_anonymization_key();
832 }
833
ResolveOnlyRequestNow()834 void MockHostResolverBase::ResolveOnlyRequestNow() {
835 DCHECK_EQ(1u, state_->mutable_requests().size());
836 ResolveNow(state_->mutable_requests().begin()->first);
837 }
838
TriggerMdnsListeners(const HostPortPair & host,DnsQueryType query_type,MdnsListenerUpdateType update_type,const IPEndPoint & address_result)839 void MockHostResolverBase::TriggerMdnsListeners(
840 const HostPortPair& host,
841 DnsQueryType query_type,
842 MdnsListenerUpdateType update_type,
843 const IPEndPoint& address_result) {
844 for (auto* listener : listeners_) {
845 if (listener->host() == host && listener->query_type() == query_type)
846 listener->TriggerAddressResult(update_type, address_result);
847 }
848 }
849
TriggerMdnsListeners(const HostPortPair & host,DnsQueryType query_type,MdnsListenerUpdateType update_type,const std::vector<std::string> & text_result)850 void MockHostResolverBase::TriggerMdnsListeners(
851 const HostPortPair& host,
852 DnsQueryType query_type,
853 MdnsListenerUpdateType update_type,
854 const std::vector<std::string>& text_result) {
855 for (auto* listener : listeners_) {
856 if (listener->host() == host && listener->query_type() == query_type)
857 listener->TriggerTextResult(update_type, text_result);
858 }
859 }
860
TriggerMdnsListeners(const HostPortPair & host,DnsQueryType query_type,MdnsListenerUpdateType update_type,const HostPortPair & host_result)861 void MockHostResolverBase::TriggerMdnsListeners(
862 const HostPortPair& host,
863 DnsQueryType query_type,
864 MdnsListenerUpdateType update_type,
865 const HostPortPair& host_result) {
866 for (auto* listener : listeners_) {
867 if (listener->host() == host && listener->query_type() == query_type)
868 listener->TriggerHostnameResult(update_type, host_result);
869 }
870 }
871
TriggerMdnsListeners(const HostPortPair & host,DnsQueryType query_type,MdnsListenerUpdateType update_type)872 void MockHostResolverBase::TriggerMdnsListeners(
873 const HostPortPair& host,
874 DnsQueryType query_type,
875 MdnsListenerUpdateType update_type) {
876 for (auto* listener : listeners_) {
877 if (listener->host() == host && listener->query_type() == query_type)
878 listener->TriggerUnhandledResult(update_type);
879 }
880 }
881
request(size_t id)882 MockHostResolverBase::RequestImpl* MockHostResolverBase::request(size_t id) {
883 RequestMap::iterator request = state_->mutable_requests().find(id);
884 CHECK(request != state_->mutable_requests().end());
885 CHECK_EQ(request->second->id(), id);
886 return (*request).second;
887 }
888
889 // start id from 1 to distinguish from NULL RequestHandle
MockHostResolverBase(bool use_caching,int cache_invalidation_num,RuleResolver rule_resolver)890 MockHostResolverBase::MockHostResolverBase(bool use_caching,
891 int cache_invalidation_num,
892 RuleResolver rule_resolver)
893 : rule_resolver_(std::move(rule_resolver)),
894 initial_cache_invalidation_num_(cache_invalidation_num),
895 tick_clock_(base::DefaultTickClock::GetInstance()),
896 state_(base::MakeRefCounted<State>()) {
897 if (use_caching)
898 cache_ = std::make_unique<HostCache>(kMaxCacheEntries);
899 else
900 DCHECK_GE(0, cache_invalidation_num);
901 }
902
Resolve(RequestImpl * request)903 int MockHostResolverBase::Resolve(RequestImpl* request) {
904 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
905
906 last_request_priority_ = request->parameters().initial_priority;
907 last_request_network_anonymization_key_ =
908 request->network_anonymization_key();
909 last_secure_dns_policy_ = request->parameters().secure_dns_policy;
910 state_->IncrementNumResolve();
911 std::vector<HostResolverEndpointResult> endpoints;
912 std::set<std::string> aliases;
913 absl::optional<HostCache::EntryStaleness> stale_info;
914 // TODO(crbug.com/1264933): Allow caching `ConnectionEndpoint` results.
915 int rv = ResolveFromIPLiteralOrCache(
916 request->request_endpoint(), request->network_anonymization_key(),
917 request->parameters().dns_query_type, request->host_resolver_flags(),
918 request->parameters().source, request->parameters().cache_usage,
919 &endpoints, &aliases, &stale_info);
920
921 if (rv == OK && !request->parameters().is_speculative) {
922 request->SetEndpointResults(std::move(endpoints), std::move(aliases),
923 std::move(stale_info));
924 } else {
925 request->SetError(rv);
926 }
927
928 if (rv != ERR_DNS_CACHE_MISS ||
929 request->parameters().source == HostResolverSource::LOCAL_ONLY) {
930 return SquashErrorCode(rv);
931 }
932
933 // Just like the real resolver, refuse to do anything with invalid
934 // hostnames.
935 if (!dns_names_util::IsValidDnsName(
936 request->request_endpoint().GetHostnameWithoutBrackets())) {
937 request->SetError(ERR_NAME_NOT_RESOLVED);
938 return ERR_NAME_NOT_RESOLVED;
939 }
940
941 if (synchronous_mode_)
942 return DoSynchronousResolution(*request);
943
944 // Store the request for asynchronous resolution
945 size_t id = next_request_id_++;
946 request->set_id(id);
947 state_->mutable_requests()[id] = request;
948
949 if (!ondemand_mode_) {
950 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
951 FROM_HERE,
952 base::BindOnce(&MockHostResolverBase::ResolveNow, AsWeakPtr(), id));
953 }
954
955 return ERR_IO_PENDING;
956 }
957
// Tries to answer a request without consulting the rules:
//  1. An IP-literal hostname resolves to itself. It fails with
//     ERR_NAME_NOT_RESOLVED if its family conflicts with a specific
//     `dns_query_type`. With HOST_RESOLVER_CANONNAME, the literal's string
//     form becomes the only alias.
//  2. "localhost" and similar names resolve immediately for address query
//     types.
//  3. Otherwise, if a cache exists and `cache_usage` allows it, the cache is
//     consulted. LOCAL_ONLY requests search entries stored under ANY. Stale
//     entries are accepted only for STALE_ALLOWED. A hit returns the cached
//     error code.
// Returns ERR_DNS_CACHE_MISS when none of the above produces an answer. The
// out-params are always reset first and are populated only when OK is
// returned.
int MockHostResolverBase::ResolveFromIPLiteralOrCache(
    const Host& endpoint,
    const NetworkAnonymizationKey& network_anonymization_key,
    DnsQueryType dns_query_type,
    HostResolverFlags flags,
    HostResolverSource source,
    HostResolver::ResolveHostParameters::CacheUsage cache_usage,
    std::vector<HostResolverEndpointResult>* out_endpoints,
    std::set<std::string>* out_aliases,
    absl::optional<HostCache::EntryStaleness>* out_stale_info) {
  DCHECK(out_endpoints);
  DCHECK(out_aliases);
  DCHECK(out_stale_info);
  out_endpoints->clear();
  out_aliases->clear();
  *out_stale_info = absl::nullopt;

  IPAddress ip_address;
  if (ip_address.AssignFromIPLiteral(endpoint.GetHostnameWithoutBrackets())) {
    const DnsQueryType desired_address_query =
        AddressFamilyToDnsQueryType(GetAddressFamily(ip_address));
    DCHECK_NE(desired_address_query, DnsQueryType::UNSPECIFIED);

    // This matches the behavior of HostResolverImpl: a literal whose family
    // conflicts with an explicitly requested address type does not resolve.
    if (dns_query_type != DnsQueryType::UNSPECIFIED &&
        dns_query_type != desired_address_query) {
      return ERR_NAME_NOT_RESOLVED;
    }

    *out_endpoints = std::vector<HostResolverEndpointResult>(1);
    (*out_endpoints)[0].ip_endpoints.emplace_back(ip_address,
                                                  endpoint.GetPort());
    if (flags & HOST_RESOLVER_CANONNAME)
      *out_aliases = {ip_address.ToString()};
    return OK;
  }

  std::vector<IPEndPoint> localhost_endpoints;
  // Immediately resolve any "localhost" or recognized similar names.
  if (IsAddressType(dns_query_type) &&
      ResolveLocalHostname(endpoint.GetHostnameWithoutBrackets(),
                           &localhost_endpoints)) {
    *out_endpoints = std::vector<HostResolverEndpointResult>(1);
    (*out_endpoints)[0].ip_endpoints = localhost_endpoints;
    return OK;
  }
  int rv = ERR_DNS_CACHE_MISS;
  bool cache_allowed =
      cache_usage == HostResolver::ResolveHostParameters::CacheUsage::ALLOWED ||
      cache_usage ==
          HostResolver::ResolveHostParameters::CacheUsage::STALE_ALLOWED;
  if (cache_.get() && cache_allowed) {
    // Local-only requests search the cache for non-local-only results.
    HostResolverSource effective_source =
        source == HostResolverSource::LOCAL_ONLY ? HostResolverSource::ANY
                                                 : source;
    HostCache::Key key(GetCacheHost(endpoint), dns_query_type, flags,
                       effective_source, network_anonymization_key);
    const std::pair<const HostCache::Key, HostCache::Entry>* cache_result;
    HostCache::EntryStaleness stale_info = HostCache::kNotStale;
    if (cache_usage ==
        HostResolver::ResolveHostParameters::CacheUsage::STALE_ALLOWED) {
      cache_result = cache_->LookupStale(key, tick_clock_->NowTicks(),
                                         &stale_info, true /* ignore_secure */);
    } else {
      cache_result = cache_->Lookup(key, tick_clock_->NowTicks(),
                                    true /* ignore_secure */);
    }
    if (cache_result) {
      // A cached entry may record a failure; its error is returned as-is.
      rv = cache_result->second.error();
      if (rv == OK) {
        *out_endpoints = cache_result->second.GetEndpoints();

        *out_aliases = cache_result->second.aliases();
        *out_stale_info = std::move(stale_info);
      }

      // Entries stored by DoSynchronousResolution() may carry a hit budget
      // (initial_cache_invalidation_num_). Count this hit down. When the
      // budget reaches zero, re-store the entry with a zero TTL so it expires.
      auto cache_invalidation_iterator = cache_invalidation_nums_.find(key);
      if (cache_invalidation_iterator != cache_invalidation_nums_.end()) {
        DCHECK_LE(1, cache_invalidation_iterator->second);
        cache_invalidation_iterator->second--;
        if (cache_invalidation_iterator->second == 0) {
          // Copy first: `cache_result` points into the cache, which Set()
          // overwrites for this same key.
          HostCache::Entry new_entry(cache_result->second);
          cache_->Set(key, new_entry, tick_clock_->NowTicks(),
                      base::TimeDelta());
          cache_invalidation_nums_.erase(cache_invalidation_iterator);
        }
      }
    }
  }
  return rv;
}
1050
DoSynchronousResolution(RequestImpl & request)1051 int MockHostResolverBase::DoSynchronousResolution(RequestImpl& request) {
1052 state_->IncrementNumNonLocalResolves();
1053
1054 const RuleResolver::RuleResultOrError& result = rule_resolver_.Resolve(
1055 request.request_endpoint(), {request.parameters().dns_query_type},
1056 request.parameters().source);
1057
1058 int error = ERR_UNEXPECTED;
1059 absl::optional<HostCache::Entry> cache_entry;
1060 if (absl::holds_alternative<RuleResolver::RuleResult>(result)) {
1061 const auto& rule_result = absl::get<RuleResolver::RuleResult>(result);
1062 const auto& endpoint_results = rule_result.endpoints;
1063 const auto& aliases = rule_result.aliases;
1064 request.SetEndpointResults(endpoint_results, aliases,
1065 /*staleness=*/absl::nullopt);
1066 // TODO(crbug.com/1264933): Change `error` on empty results?
1067 error = OK;
1068 if (cache_.get()) {
1069 cache_entry = CreateCacheEntry(request.request_endpoint().GetHostname(),
1070 endpoint_results, aliases);
1071 }
1072 } else {
1073 DCHECK(absl::holds_alternative<RuleResolver::ErrorResult>(result));
1074 error = absl::get<RuleResolver::ErrorResult>(result);
1075 request.SetError(error);
1076 if (cache_.get()) {
1077 cache_entry.emplace(error, HostCache::Entry::SOURCE_UNKNOWN);
1078 }
1079 }
1080 if (cache_.get() && cache_entry.has_value()) {
1081 HostCache::Key key(
1082 GetCacheHost(request.request_endpoint()),
1083 request.parameters().dns_query_type, request.host_resolver_flags(),
1084 request.parameters().source, request.network_anonymization_key());
1085 // Storing a failure with TTL 0 so that it overwrites previous value.
1086 base::TimeDelta ttl;
1087 if (error == OK) {
1088 ttl = base::Seconds(kCacheEntryTTLSeconds);
1089 if (initial_cache_invalidation_num_ > 0)
1090 cache_invalidation_nums_[key] = initial_cache_invalidation_num_;
1091 }
1092 cache_->Set(key, cache_entry.value(), tick_clock_->NowTicks(), ttl);
1093 }
1094
1095 return SquashErrorCode(error);
1096 }
1097
AddListener(MdnsListenerImpl * listener)1098 void MockHostResolverBase::AddListener(MdnsListenerImpl* listener) {
1099 listeners_.insert(listener);
1100 }
1101
// Unregisters `listener` so it no longer receives TriggerMdnsListeners()
// updates.
void MockHostResolverBase::RemoveCancelledListener(MdnsListenerImpl* listener) {
  listeners_.erase(listener);
}
1105
// Stores the configuration applied to every resolver this factory creates.
// Each resolver receives its own copy of `rules`, plus the caching settings.
MockHostResolverFactory::MockHostResolverFactory(
    MockHostResolverBase::RuleResolver rules,
    bool use_caching,
    int cache_invalidation_num)
    : rules_(std::move(rules)),
      use_caching_(use_caching),
      cache_invalidation_num_(cache_invalidation_num) {}

MockHostResolverFactory::~MockHostResolverFactory() = default;
1115
CreateResolver(HostResolverManager * manager,base::StringPiece host_mapping_rules,bool enable_caching)1116 std::unique_ptr<HostResolver> MockHostResolverFactory::CreateResolver(
1117 HostResolverManager* manager,
1118 base::StringPiece host_mapping_rules,
1119 bool enable_caching) {
1120 DCHECK(host_mapping_rules.empty());
1121
1122 // Explicit new to access private constructor.
1123 auto resolver = base::WrapUnique(new MockHostResolverBase(
1124 enable_caching && use_caching_, cache_invalidation_num_, rules_));
1125 return resolver;
1126 }
1127
// Standalone resolvers are created exactly like CreateResolver() ones, with no
// manager. `net_log` and `options` are ignored.
std::unique_ptr<HostResolver> MockHostResolverFactory::CreateStandaloneResolver(
    NetLog* net_log,
    const HostResolver::ManagerOptions& options,
    base::StringPiece host_mapping_rules,
    bool enable_caching) {
  return CreateResolver(nullptr, host_mapping_rules, enable_caching);
}
1135
1136 //-----------------------------------------------------------------------------
1137
Rule(ResolverType resolver_type,base::StringPiece host_pattern,AddressFamily address_family,HostResolverFlags host_resolver_flags,base::StringPiece replacement,std::vector<std::string> dns_aliases,int latency_ms)1138 RuleBasedHostResolverProc::Rule::Rule(ResolverType resolver_type,
1139 base::StringPiece host_pattern,
1140 AddressFamily address_family,
1141 HostResolverFlags host_resolver_flags,
1142 base::StringPiece replacement,
1143 std::vector<std::string> dns_aliases,
1144 int latency_ms)
1145 : resolver_type(resolver_type),
1146 host_pattern(host_pattern),
1147 address_family(address_family),
1148 host_resolver_flags(host_resolver_flags),
1149 replacement(replacement),
1150 dns_aliases(std::move(dns_aliases)),
1151 latency_ms(latency_ms) {
1152 DCHECK(this->dns_aliases != std::vector<std::string>({""}));
1153 }
1154
// Rules are copied when added (AddRuleInternal()) and when returned from
// GetRules().
RuleBasedHostResolverProc::Rule::Rule(const Rule& other) = default;

RuleBasedHostResolverProc::Rule::~Rule() = default;
1158
// `previous` and `allow_fallback` are handed to HostResolverProc. Resolve()
// falls back to ResolveUsingPrevious() when no rule matches a host.
RuleBasedHostResolverProc::RuleBasedHostResolverProc(
    scoped_refptr<HostResolverProc> previous,
    bool allow_fallback)
    : HostResolverProc(std::move(previous), allow_fallback) {}
1163
AddRule(base::StringPiece host_pattern,base::StringPiece replacement)1164 void RuleBasedHostResolverProc::AddRule(base::StringPiece host_pattern,
1165 base::StringPiece replacement) {
1166 AddRuleForAddressFamily(host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
1167 replacement);
1168 }
1169
AddRuleForAddressFamily(base::StringPiece host_pattern,AddressFamily address_family,base::StringPiece replacement)1170 void RuleBasedHostResolverProc::AddRuleForAddressFamily(
1171 base::StringPiece host_pattern,
1172 AddressFamily address_family,
1173 base::StringPiece replacement) {
1174 DCHECK(!replacement.empty());
1175 HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY;
1176 Rule rule(Rule::kResolverTypeSystem, host_pattern, address_family, flags,
1177 replacement, {} /* dns_aliases */, 0);
1178 AddRuleInternal(rule);
1179 }
1180
AddRuleWithFlags(base::StringPiece host_pattern,base::StringPiece replacement,HostResolverFlags flags,std::vector<std::string> dns_aliases)1181 void RuleBasedHostResolverProc::AddRuleWithFlags(
1182 base::StringPiece host_pattern,
1183 base::StringPiece replacement,
1184 HostResolverFlags flags,
1185 std::vector<std::string> dns_aliases) {
1186 DCHECK(!replacement.empty());
1187 Rule rule(Rule::kResolverTypeSystem, host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
1188 flags, replacement, std::move(dns_aliases), 0);
1189 AddRuleInternal(rule);
1190 }
1191
AddIPLiteralRule(base::StringPiece host_pattern,base::StringPiece ip_literal,base::StringPiece canonical_name)1192 void RuleBasedHostResolverProc::AddIPLiteralRule(
1193 base::StringPiece host_pattern,
1194 base::StringPiece ip_literal,
1195 base::StringPiece canonical_name) {
1196 // Literals are always resolved to themselves by HostResolverImpl,
1197 // consequently we do not support remapping them.
1198 IPAddress ip_address;
1199 DCHECK(!ip_address.AssignFromIPLiteral(host_pattern));
1200 HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY;
1201 std::vector<std::string> aliases;
1202 if (!canonical_name.empty()) {
1203 flags |= HOST_RESOLVER_CANONNAME;
1204 aliases.emplace_back(canonical_name);
1205 }
1206
1207 Rule rule(Rule::kResolverTypeIPLiteral, host_pattern,
1208 ADDRESS_FAMILY_UNSPECIFIED, flags, ip_literal, std::move(aliases),
1209 0);
1210 AddRuleInternal(rule);
1211 }
1212
AddIPLiteralRuleWithDnsAliases(base::StringPiece host_pattern,base::StringPiece ip_literal,std::vector<std::string> dns_aliases)1213 void RuleBasedHostResolverProc::AddIPLiteralRuleWithDnsAliases(
1214 base::StringPiece host_pattern,
1215 base::StringPiece ip_literal,
1216 std::vector<std::string> dns_aliases) {
1217 // Literals are always resolved to themselves by HostResolverImpl,
1218 // consequently we do not support remapping them.
1219 IPAddress ip_address;
1220 DCHECK(!ip_address.AssignFromIPLiteral(host_pattern));
1221 HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY;
1222 if (!dns_aliases.empty())
1223 flags |= HOST_RESOLVER_CANONNAME;
1224
1225 Rule rule(Rule::kResolverTypeIPLiteral, host_pattern,
1226 ADDRESS_FAMILY_UNSPECIFIED, flags, ip_literal,
1227 std::move(dns_aliases), 0);
1228 AddRuleInternal(rule);
1229 }
1230
AddRuleWithLatency(base::StringPiece host_pattern,base::StringPiece replacement,int latency_ms)1231 void RuleBasedHostResolverProc::AddRuleWithLatency(
1232 base::StringPiece host_pattern,
1233 base::StringPiece replacement,
1234 int latency_ms) {
1235 DCHECK(!replacement.empty());
1236 HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY;
1237 Rule rule(Rule::kResolverTypeSystem, host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
1238 flags, replacement, /*dns_aliases=*/{}, latency_ms);
1239 AddRuleInternal(rule);
1240 }
1241
AllowDirectLookup(base::StringPiece host_pattern)1242 void RuleBasedHostResolverProc::AllowDirectLookup(
1243 base::StringPiece host_pattern) {
1244 HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY;
1245 Rule rule(Rule::kResolverTypeSystem, host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
1246 flags, std::string(), /*dns_aliases=*/{}, 0);
1247 AddRuleInternal(rule);
1248 }
1249
AddSimulatedFailure(base::StringPiece host_pattern,HostResolverFlags flags)1250 void RuleBasedHostResolverProc::AddSimulatedFailure(
1251 base::StringPiece host_pattern,
1252 HostResolverFlags flags) {
1253 Rule rule(Rule::kResolverTypeFail, host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
1254 flags, std::string(), /*dns_aliases=*/{}, 0);
1255 AddRuleInternal(rule);
1256 }
1257
AddSimulatedTimeoutFailure(base::StringPiece host_pattern,HostResolverFlags flags)1258 void RuleBasedHostResolverProc::AddSimulatedTimeoutFailure(
1259 base::StringPiece host_pattern,
1260 HostResolverFlags flags) {
1261 Rule rule(Rule::kResolverTypeFailTimeout, host_pattern,
1262 ADDRESS_FAMILY_UNSPECIFIED, flags, std::string(),
1263 /*dns_aliases=*/{}, 0);
1264 AddRuleInternal(rule);
1265 }
1266
ClearRules()1267 void RuleBasedHostResolverProc::ClearRules() {
1268 CHECK(modifications_allowed_);
1269 base::AutoLock lock(rule_lock_);
1270 rules_.clear();
1271 }
1272
// After this call, adding rules (AddRuleInternal()) or clearing them
// (ClearRules()) CHECK-fails.
void RuleBasedHostResolverProc::DisableModifications() {
  modifications_allowed_ = false;
}
1276
GetRules()1277 RuleBasedHostResolverProc::RuleList RuleBasedHostResolverProc::GetRules() {
1278 RuleList rv;
1279 {
1280 base::AutoLock lock(rule_lock_);
1281 rv = rules_;
1282 }
1283 return rv;
1284 }
1285
// Returns how many times Resolve() matched a rule whose pattern is exactly
// `host_pattern`.
// NOTE(review): operator[] default-inserts a zero count for patterns that
// have never matched; a find() would avoid that if the map's key type
// supports lookup by StringPiece.
size_t RuleBasedHostResolverProc::NumResolvesForHostPattern(
    base::StringPiece host_pattern) {
  base::AutoLock lock(rule_lock_);
  return num_resolves_per_host_pattern_[host_pattern];
}
1291
Resolve(const std::string & host,AddressFamily address_family,HostResolverFlags host_resolver_flags,AddressList * addrlist,int * os_error)1292 int RuleBasedHostResolverProc::Resolve(const std::string& host,
1293 AddressFamily address_family,
1294 HostResolverFlags host_resolver_flags,
1295 AddressList* addrlist,
1296 int* os_error) {
1297 base::AutoLock lock(rule_lock_);
1298 RuleList::iterator r;
1299 for (r = rules_.begin(); r != rules_.end(); ++r) {
1300 bool matches_address_family =
1301 r->address_family == ADDRESS_FAMILY_UNSPECIFIED ||
1302 r->address_family == address_family;
1303 // Ignore HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6, since it should
1304 // have no impact on whether a rule matches.
1305 HostResolverFlags flags =
1306 host_resolver_flags & ~HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
1307 // Flags match if all of the bitflags in host_resolver_flags are enabled
1308 // in the rule's host_resolver_flags. However, the rule may have additional
1309 // flags specified, in which case the flags should still be considered a
1310 // match.
1311 bool matches_flags = (r->host_resolver_flags & flags) == flags;
1312 if (matches_flags && matches_address_family &&
1313 base::MatchPattern(host, r->host_pattern)) {
1314 num_resolves_per_host_pattern_[r->host_pattern]++;
1315
1316 if (r->latency_ms != 0) {
1317 base::PlatformThread::Sleep(base::Milliseconds(r->latency_ms));
1318 }
1319
1320 // Remap to a new host.
1321 const std::string& effective_host =
1322 r->replacement.empty() ? host : r->replacement;
1323
1324 // Apply the resolving function to the remapped hostname.
1325 switch (r->resolver_type) {
1326 case Rule::kResolverTypeFail:
1327 return ERR_NAME_NOT_RESOLVED;
1328 case Rule::kResolverTypeFailTimeout:
1329 return ERR_DNS_TIMED_OUT;
1330 case Rule::kResolverTypeSystem:
1331 EnsureSystemHostResolverCallReady();
1332 return SystemHostResolverCall(effective_host, address_family,
1333 host_resolver_flags, addrlist,
1334 os_error);
1335 case Rule::kResolverTypeIPLiteral: {
1336 AddressList raw_addr_list;
1337 std::vector<std::string> aliases;
1338 aliases = (!r->dns_aliases.empty())
1339 ? r->dns_aliases
1340 : std::vector<std::string>({host});
1341 std::vector<net::IPEndPoint> ip_endpoints;
1342 int result = ParseAddressList(effective_host, &ip_endpoints);
1343 // Filter out addresses with the wrong family.
1344 *addrlist = AddressList();
1345 for (const auto& address : ip_endpoints) {
1346 if (address_family == ADDRESS_FAMILY_UNSPECIFIED ||
1347 address_family == address.GetFamily()) {
1348 addrlist->push_back(address);
1349 }
1350 }
1351 addrlist->SetDnsAliases(aliases);
1352
1353 if (result == OK && addrlist->empty())
1354 return ERR_NAME_NOT_RESOLVED;
1355 return result;
1356 }
1357 default:
1358 NOTREACHED();
1359 return ERR_UNEXPECTED;
1360 }
1361 }
1362 }
1363
1364 return ResolveUsingPrevious(host, address_family, host_resolver_flags,
1365 addrlist, os_error);
1366 }
1367
// Instances are ref-counted (created via base::MakeRefCounted) and destroyed
// when the last reference is released.
RuleBasedHostResolverProc::~RuleBasedHostResolverProc() = default;
1369
AddRuleInternal(const Rule & rule)1370 void RuleBasedHostResolverProc::AddRuleInternal(const Rule& rule) {
1371 Rule fixed_rule = rule;
1372 // SystemResolverProc expects valid DNS addresses.
1373 // So for kResolverTypeSystem rules:
1374 // * CHECK that replacement is empty (empty domain names mean use a direct
1375 // lookup) or a valid DNS name (which includes IP addresses).
1376 // * If the replacement is an IP address, switch to an IP literal rule.
1377 if (fixed_rule.resolver_type == Rule::kResolverTypeSystem) {
1378 CHECK(fixed_rule.replacement.empty() ||
1379 dns_names_util::IsValidDnsName(fixed_rule.replacement));
1380
1381 IPAddress ip_address;
1382 bool valid_address = ip_address.AssignFromIPLiteral(fixed_rule.replacement);
1383 if (valid_address) {
1384 fixed_rule.resolver_type = Rule::kResolverTypeIPLiteral;
1385 }
1386 }
1387
1388 CHECK(modifications_allowed_);
1389 base::AutoLock lock(rule_lock_);
1390 rules_.push_back(fixed_rule);
1391 }
1392
CreateCatchAllHostResolverProc()1393 scoped_refptr<RuleBasedHostResolverProc> CreateCatchAllHostResolverProc() {
1394 auto catchall =
1395 base::MakeRefCounted<RuleBasedHostResolverProc>(/*previous=*/nullptr,
1396 /*allow_fallback=*/false);
1397 // Note that IPv6 lookups fail.
1398 catchall->AddIPLiteralRule("*", "127.0.0.1", "localhost");
1399
1400 // Next add a rules-based layer that the test controls.
1401 return base::MakeRefCounted<RuleBasedHostResolverProc>(
1402 std::move(catchall), /*allow_fallback=*/false);
1403 }
1404
1405 //-----------------------------------------------------------------------------
1406
1407 // Implementation of ResolveHostRequest that tracks cancellations when the
1408 // request is destroyed after being started.
1409 class HangingHostResolver::RequestImpl
1410 : public HostResolver::ResolveHostRequest,
1411 public HostResolver::ProbeRequest {
1412 public:
RequestImpl(base::WeakPtr<HangingHostResolver> resolver)1413 explicit RequestImpl(base::WeakPtr<HangingHostResolver> resolver)
1414 : resolver_(resolver) {}
1415
1416 RequestImpl(const RequestImpl&) = delete;
1417 RequestImpl& operator=(const RequestImpl&) = delete;
1418
~RequestImpl()1419 ~RequestImpl() override {
1420 if (is_running_ && resolver_)
1421 resolver_->state_->IncrementNumCancellations();
1422 }
1423
Start(CompletionOnceCallback callback)1424 int Start(CompletionOnceCallback callback) override { return Start(); }
1425
Start()1426 int Start() override {
1427 DCHECK(resolver_);
1428 is_running_ = true;
1429 return ERR_IO_PENDING;
1430 }
1431
GetAddressResults() const1432 const AddressList* GetAddressResults() const override {
1433 base::ImmediateCrash();
1434 }
1435
GetEndpointResults() const1436 const std::vector<HostResolverEndpointResult>* GetEndpointResults()
1437 const override {
1438 base::ImmediateCrash();
1439 }
1440
GetTextResults() const1441 const std::vector<std::string>* GetTextResults() const override {
1442 base::ImmediateCrash();
1443 }
1444
GetHostnameResults() const1445 const std::vector<HostPortPair>* GetHostnameResults() const override {
1446 base::ImmediateCrash();
1447 }
1448
GetDnsAliasResults() const1449 const std::set<std::string>* GetDnsAliasResults() const override {
1450 base::ImmediateCrash();
1451 }
1452
GetResolveErrorInfo() const1453 net::ResolveErrorInfo GetResolveErrorInfo() const override {
1454 base::ImmediateCrash();
1455 }
1456
GetStaleInfo() const1457 const absl::optional<HostCache::EntryStaleness>& GetStaleInfo()
1458 const override {
1459 base::ImmediateCrash();
1460 }
1461
ChangeRequestPriority(RequestPriority priority)1462 void ChangeRequestPriority(RequestPriority priority) override {}
1463
1464 private:
1465 // Use a WeakPtr as the resolver may be destroyed while there are still
1466 // outstanding request objects.
1467 base::WeakPtr<HangingHostResolver> resolver_;
1468 bool is_running_ = false;
1469 };
1470
HangingHostResolver::State::State() = default;
HangingHostResolver::State::~State() = default;

// `state_` is a ref-counted holder for counters such as the cancellation
// count that RequestImpl increments.
HangingHostResolver::HangingHostResolver()
    : state_(base::MakeRefCounted<State>()) {}

// Outstanding RequestImpl objects hold WeakPtrs to this resolver and check
// them before touching it.
HangingHostResolver::~HangingHostResolver() = default;
1478
// After shutdown, new requests and DoH probes fail with ERR_CONTEXT_SHUT_DOWN.
void HangingHostResolver::OnShutdown() {
  shutting_down_ = true;
}
1482
// Forwards to the HostPortPair overload. The scheme is currently dropped.
std::unique_ptr<HostResolver::ResolveHostRequest>
HangingHostResolver::CreateRequest(
    url::SchemeHostPort host,
    NetworkAnonymizationKey network_anonymization_key,
    NetLogWithSource net_log,
    absl::optional<ResolveHostParameters> optional_parameters) {
  // TODO(crbug.com/1206799): Propagate scheme and make affect behavior.
  return CreateRequest(HostPortPair::FromSchemeHostPort(host),
                       network_anonymization_key, net_log, optional_parameters);
}
1493
1494 std::unique_ptr<HostResolver::ResolveHostRequest>
CreateRequest(const HostPortPair & host,const NetworkAnonymizationKey & network_anonymization_key,const NetLogWithSource & source_net_log,const absl::optional<ResolveHostParameters> & optional_parameters)1495 HangingHostResolver::CreateRequest(
1496 const HostPortPair& host,
1497 const NetworkAnonymizationKey& network_anonymization_key,
1498 const NetLogWithSource& source_net_log,
1499 const absl::optional<ResolveHostParameters>& optional_parameters) {
1500 last_host_ = host;
1501 last_network_anonymization_key_ = network_anonymization_key;
1502
1503 if (shutting_down_)
1504 return CreateFailingRequest(ERR_CONTEXT_SHUT_DOWN);
1505
1506 if (optional_parameters &&
1507 optional_parameters.value().source == HostResolverSource::LOCAL_ONLY) {
1508 return CreateFailingRequest(ERR_DNS_CACHE_MISS);
1509 }
1510
1511 return std::make_unique<RequestImpl>(weak_ptr_factory_.GetWeakPtr());
1512 }
1513
1514 std::unique_ptr<HostResolver::ProbeRequest>
CreateDohProbeRequest()1515 HangingHostResolver::CreateDohProbeRequest() {
1516 if (shutting_down_)
1517 return CreateFailingProbeRequest(ERR_CONTEXT_SHUT_DOWN);
1518
1519 return std::make_unique<RequestImpl>(weak_ptr_factory_.GetWeakPtr());
1520 }
1521
// No-op: this resolver never completes any request, so the context is unused.
void HangingHostResolver::SetRequestContext(
    URLRequestContext* url_request_context) {}
1524
1525 //-----------------------------------------------------------------------------
1526
// A default-constructed instance installs nothing until Init() is called.
ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc() = default;

// Installs `proc` as the default HostResolverProc (see Init()) until this
// object is destroyed.
ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc(
    HostResolverProc* proc) {
  Init(proc);
}
1533
// Restores the default proc that was installed before Init().
ScopedDefaultHostResolverProc::~ScopedDefaultHostResolverProc() {
  // SetDefault() returns the proc being replaced. If it is not ours, some other
  // scoped instance was not destroyed in LIFO order.
  HostResolverProc* old_proc =
      HostResolverProc::SetDefault(previous_proc_.get());
  // The lifetimes of multiple instances must be nested.
  CHECK_EQ(old_proc, current_proc_.get());
}
1540
// Installs `proc` (which must be non-null; it is dereferenced) as the default
// HostResolverProc. The previously installed default is remembered and chained
// behind `proc` via SetLastProc(), presumably so lookups `proc` does not
// handle can fall through to it.
void ScopedDefaultHostResolverProc::Init(HostResolverProc* proc) {
  current_proc_ = proc;
  previous_proc_ = HostResolverProc::SetDefault(current_proc_.get());
  current_proc_->SetLastProc(previous_proc_);
}
1546
1547 } // namespace net
1548