1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #ifdef UNSAFE_BUFFERS_BUILD
6 // TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
7 #pragma allow_unsafe_buffers
8 #endif
9
10 #include "net/dns/dns_test_util.h"
11
12 #include <cstdint>
13 #include <optional>
14 #include <string>
15 #include <string_view>
16 #include <utility>
17 #include <vector>
18
19 #include "base/check.h"
20 #include "base/containers/span.h"
21 #include "base/functional/bind.h"
22 #include "base/location.h"
23 #include "base/numerics/byte_conversions.h"
24 #include "base/numerics/safe_conversions.h"
25 #include "base/ranges/algorithm.h"
26 #include "base/strings/strcat.h"
27 #include "base/sys_byteorder.h"
28 #include "base/task/single_thread_task_runner.h"
29 #include "base/test/test_timeouts.h"
30 #include "base/threading/thread_restrictions.h"
31 #include "base/time/time.h"
32 #include "base/types/optional_util.h"
33 #include "net/base/io_buffer.h"
34 #include "net/base/ip_address.h"
35 #include "net/base/ip_endpoint.h"
36 #include "net/base/net_errors.h"
37 #include "net/dns/address_sorter.h"
38 #include "net/dns/dns_hosts.h"
39 #include "net/dns/dns_names_util.h"
40 #include "net/dns/dns_query.h"
41 #include "net/dns/dns_session.h"
42 #include "net/dns/mock_host_resolver.h"
43 #include "net/dns/public/dns_over_https_server_config.h"
44 #include "net/dns/resolve_context.h"
45 #include "testing/gmock/include/gmock/gmock-matchers.h"
46 #include "testing/gtest/include/gtest/gtest.h"
47 #include "url/scheme_host_port.h"
48
49 namespace net {
50 namespace {
51
// Wire-format DNS message header used by CreateMalformedResponse(). It
// declares one question and one answer record; the answer record is never
// appended, which is what makes the resulting response malformed.
constexpr uint8_t kMalformedResponseHeader[] = {
    0x00, 0x14,  // ID: arbitrary (0x14, the id CreateMalformedResponse uses)
    0x81, 0x80,  // Flags: response, RD and RA set, rcode NOERROR
    0x00, 0x01,  // QDCOUNT: 1 question
    0x00, 0x01,  // ANCOUNT: 1 answer (declared, but absent)
    0x00, 0x00,  // NSCOUNT: 0 authority records
    0x00, 0x00,  // ARCOUNT: 0 additional records
};
61
62 // Create a response containing a valid question (as would normally be validated
63 // in DnsTransaction) but completely missing a header-declared answer.
CreateMalformedResponse(std::string hostname,uint16_t type)64 DnsResponse CreateMalformedResponse(std::string hostname, uint16_t type) {
65 std::optional<std::vector<uint8_t>> dns_name =
66 dns_names_util::DottedNameToNetwork(hostname);
67 CHECK(dns_name.has_value());
68 DnsQuery query(/*id=*/0x14, dns_name.value(), type);
69
70 // Build response to simulate the barebones validation DnsResponse applies to
71 // responses received from the network.
72 auto buffer = base::MakeRefCounted<IOBufferWithSize>(
73 sizeof(kMalformedResponseHeader) + query.question().size());
74 memcpy(buffer->data(), kMalformedResponseHeader,
75 sizeof(kMalformedResponseHeader));
76 memcpy(buffer->data() + sizeof(kMalformedResponseHeader),
77 query.question().data(), query.question().size());
78
79 DnsResponse response(buffer, buffer->size());
80 CHECK(response.InitParseWithoutQuery(buffer->size()));
81
82 return response;
83 }
84
85 class MockAddressSorter : public AddressSorter {
86 public:
87 ~MockAddressSorter() override = default;
Sort(const std::vector<IPEndPoint> & endpoints,CallbackType callback) const88 void Sort(const std::vector<IPEndPoint>& endpoints,
89 CallbackType callback) const override {
90 // Do nothing.
91 std::move(callback).Run(true, endpoints);
92 }
93 };
94
95 } // namespace
96
CreateValidDnsConfig()97 DnsConfig CreateValidDnsConfig() {
98 IPAddress dns_ip(192, 168, 1, 0);
99 DnsConfig config;
100 config.nameservers.emplace_back(dns_ip, dns_protocol::kDefaultPort);
101 config.doh_config =
102 *DnsOverHttpsConfig::FromString("https://dns.example.com/");
103 config.secure_dns_mode = SecureDnsMode::kOff;
104 EXPECT_TRUE(config.IsValid());
105 return config;
106 }
107
BuildTestDnsRecord(std::string name,uint16_t type,std::string rdata,base::TimeDelta ttl)108 DnsResourceRecord BuildTestDnsRecord(std::string name,
109 uint16_t type,
110 std::string rdata,
111 base::TimeDelta ttl) {
112 DCHECK(!name.empty());
113
114 DnsResourceRecord record;
115 record.name = std::move(name);
116 record.type = type;
117 record.klass = dns_protocol::kClassIN;
118 record.ttl = ttl.InSeconds();
119
120 if (!rdata.empty())
121 record.SetOwnedRdata(std::move(rdata));
122
123 return record;
124 }
125
BuildTestCnameRecord(std::string name,std::string_view canonical_name,base::TimeDelta ttl)126 DnsResourceRecord BuildTestCnameRecord(std::string name,
127 std::string_view canonical_name,
128 base::TimeDelta ttl) {
129 DCHECK(!name.empty());
130 DCHECK(!canonical_name.empty());
131
132 std::optional<std::vector<uint8_t>> rdata =
133 dns_names_util::DottedNameToNetwork(canonical_name);
134 CHECK(rdata.has_value());
135
136 return BuildTestDnsRecord(
137 std::move(name), dns_protocol::kTypeCNAME,
138 std::string(reinterpret_cast<char*>(rdata.value().data()),
139 rdata.value().size()),
140 ttl);
141 }
142
BuildTestAddressRecord(std::string name,const IPAddress & ip,base::TimeDelta ttl)143 DnsResourceRecord BuildTestAddressRecord(std::string name,
144 const IPAddress& ip,
145 base::TimeDelta ttl) {
146 DCHECK(!name.empty());
147 DCHECK(ip.IsValid());
148
149 return BuildTestDnsRecord(
150 std::move(name),
151 ip.IsIPv4() ? dns_protocol::kTypeA : dns_protocol::kTypeAAAA,
152 net::IPAddressToPackedString(ip), ttl);
153 }
154
BuildTestTextRecord(std::string name,std::vector<std::string> text_strings,base::TimeDelta ttl)155 DnsResourceRecord BuildTestTextRecord(std::string name,
156 std::vector<std::string> text_strings,
157 base::TimeDelta ttl) {
158 DCHECK(!text_strings.empty());
159
160 std::string rdata;
161 for (const std::string& text_string : text_strings) {
162 DCHECK(!text_string.empty());
163
164 rdata += base::checked_cast<unsigned char>(text_string.size());
165 rdata += text_string;
166 }
167
168 return BuildTestDnsRecord(std::move(name), dns_protocol::kTypeTXT,
169 std::move(rdata), ttl);
170 }
171
BuildTestHttpsAliasRecord(std::string name,std::string_view alias_name,base::TimeDelta ttl)172 DnsResourceRecord BuildTestHttpsAliasRecord(std::string name,
173 std::string_view alias_name,
174 base::TimeDelta ttl) {
175 DCHECK(!name.empty());
176
177 std::string rdata("\000\000", 2);
178
179 std::optional<std::vector<uint8_t>> alias_domain =
180 dns_names_util::DottedNameToNetwork(alias_name);
181 CHECK(alias_domain.has_value());
182 rdata.append(reinterpret_cast<char*>(alias_domain.value().data()),
183 alias_domain.value().size());
184
185 return BuildTestDnsRecord(std::move(name), dns_protocol::kTypeHttps,
186 std::move(rdata), ttl);
187 }
188
BuildTestHttpsServiceAlpnParam(const std::vector<std::string> & alpns)189 std::pair<uint16_t, std::string> BuildTestHttpsServiceAlpnParam(
190 const std::vector<std::string>& alpns) {
191 std::string param_value;
192
193 for (const std::string& alpn : alpns) {
194 CHECK(!alpn.empty());
195 param_value.append(
196 1, static_cast<char>(base::checked_cast<uint8_t>(alpn.size())));
197 param_value.append(alpn);
198 }
199
200 return std::pair(dns_protocol::kHttpsServiceParamKeyAlpn,
201 std::move(param_value));
202 }
203
BuildTestHttpsServiceEchConfigParam(base::span<const uint8_t> ech_config_list)204 std::pair<uint16_t, std::string> BuildTestHttpsServiceEchConfigParam(
205 base::span<const uint8_t> ech_config_list) {
206 return std::pair(
207 dns_protocol::kHttpsServiceParamKeyEchConfig,
208 std::string(reinterpret_cast<const char*>(ech_config_list.data()),
209 ech_config_list.size()));
210 }
211
BuildTestHttpsServiceMandatoryParam(std::vector<uint16_t> param_key_list)212 std::pair<uint16_t, std::string> BuildTestHttpsServiceMandatoryParam(
213 std::vector<uint16_t> param_key_list) {
214 base::ranges::sort(param_key_list);
215
216 std::string value;
217 for (uint16_t param_key : param_key_list) {
218 std::array<uint8_t, 2> num_buffer = base::U16ToBigEndian(param_key);
219 value.append(num_buffer.begin(), num_buffer.end());
220 }
221
222 return std::pair(dns_protocol::kHttpsServiceParamKeyMandatory,
223 std::move(value));
224 }
225
BuildTestHttpsServicePortParam(uint16_t port)226 std::pair<uint16_t, std::string> BuildTestHttpsServicePortParam(uint16_t port) {
227 std::array<uint8_t, 2> buffer = base::U16ToBigEndian(port);
228 return std::pair(dns_protocol::kHttpsServiceParamKeyPort,
229 std::string(buffer.begin(), buffer.end()));
230 }
231
BuildTestHttpsServiceRecord(std::string name,uint16_t priority,std::string_view service_name,const std::map<uint16_t,std::string> & params,base::TimeDelta ttl)232 DnsResourceRecord BuildTestHttpsServiceRecord(
233 std::string name,
234 uint16_t priority,
235 std::string_view service_name,
236 const std::map<uint16_t, std::string>& params,
237 base::TimeDelta ttl) {
238 DCHECK(!name.empty());
239 DCHECK_NE(priority, 0);
240
241 std::string rdata;
242
243 {
244 std::array<uint8_t, 2> buf = base::U16ToBigEndian(priority);
245 rdata.append(buf.begin(), buf.end());
246 }
247
248 std::optional<std::vector<uint8_t>> service_domain;
249 if (service_name == ".") {
250 // HTTPS records have special behavior for `service_name == "."` (that it
251 // will be treated as if the service name is the same as the record owner
252 // name), so allow such inputs despite normally being disallowed for
253 // Chrome-encoded DNS names.
254 service_domain = std::vector<uint8_t>{0};
255 } else {
256 service_domain = dns_names_util::DottedNameToNetwork(service_name);
257 }
258 CHECK(service_domain.has_value());
259 rdata.append(reinterpret_cast<char*>(service_domain.value().data()),
260 service_domain.value().size());
261
262 for (auto& param : params) {
263 {
264 std::array<uint8_t, 2> buf = base::U16ToBigEndian(param.first);
265 rdata.append(buf.begin(), buf.end());
266 }
267 {
268 std::array<uint8_t, 2> buf = base::U16ToBigEndian(
269 base::checked_cast<uint16_t>(param.second.size()));
270 rdata.append(buf.begin(), buf.end());
271 }
272 rdata.append(param.second);
273 }
274
275 return BuildTestDnsRecord(std::move(name), dns_protocol::kTypeHttps,
276 std::move(rdata), ttl);
277 }
278
BuildTestDnsResponse(std::string name,uint16_t type,const std::vector<DnsResourceRecord> & answers,const std::vector<DnsResourceRecord> & authority,const std::vector<DnsResourceRecord> & additional,uint8_t rcode)279 DnsResponse BuildTestDnsResponse(
280 std::string name,
281 uint16_t type,
282 const std::vector<DnsResourceRecord>& answers,
283 const std::vector<DnsResourceRecord>& authority,
284 const std::vector<DnsResourceRecord>& additional,
285 uint8_t rcode) {
286 DCHECK(!name.empty());
287
288 std::optional<std::vector<uint8_t>> dns_name =
289 dns_names_util::DottedNameToNetwork(name);
290 CHECK(dns_name.has_value());
291
292 std::optional<DnsQuery> query(std::in_place, 0, dns_name.value(), type);
293 return DnsResponse(0, true /* is_authoritative */, answers,
294 authority /* authority_records */,
295 additional /* additional_records */, query, rcode,
296 false /* validate_records */);
297 }
298
BuildTestDnsAddressResponse(std::string name,const IPAddress & ip,std::string answer_name)299 DnsResponse BuildTestDnsAddressResponse(std::string name,
300 const IPAddress& ip,
301 std::string answer_name) {
302 DCHECK(ip.IsValid());
303
304 if (answer_name.empty())
305 answer_name = name;
306
307 std::vector<DnsResourceRecord> answers = {
308 BuildTestAddressRecord(std::move(answer_name), ip)};
309
310 return BuildTestDnsResponse(
311 std::move(name),
312 ip.IsIPv4() ? dns_protocol::kTypeA : dns_protocol::kTypeAAAA, answers);
313 }
314
BuildTestDnsAddressResponseWithCname(std::string name,const IPAddress & ip,std::string cannonname,std::string answer_name)315 DnsResponse BuildTestDnsAddressResponseWithCname(std::string name,
316 const IPAddress& ip,
317 std::string cannonname,
318 std::string answer_name) {
319 DCHECK(ip.IsValid());
320 DCHECK(!cannonname.empty());
321
322 if (answer_name.empty())
323 answer_name = name;
324
325 std::optional<std::vector<uint8_t>> cname_rdata =
326 dns_names_util::DottedNameToNetwork(cannonname);
327 CHECK(cname_rdata.has_value());
328
329 std::vector<DnsResourceRecord> answers = {
330 BuildTestDnsRecord(
331 std::move(answer_name), dns_protocol::kTypeCNAME,
332 std::string(reinterpret_cast<char*>(cname_rdata.value().data()),
333 cname_rdata.value().size())),
334 BuildTestAddressRecord(std::move(cannonname), ip)};
335
336 return BuildTestDnsResponse(
337 std::move(name),
338 ip.IsIPv4() ? dns_protocol::kTypeA : dns_protocol::kTypeAAAA, answers);
339 }
340
BuildTestDnsTextResponse(std::string name,std::vector<std::vector<std::string>> text_records,std::string answer_name)341 DnsResponse BuildTestDnsTextResponse(
342 std::string name,
343 std::vector<std::vector<std::string>> text_records,
344 std::string answer_name) {
345 if (answer_name.empty())
346 answer_name = name;
347
348 std::vector<DnsResourceRecord> answers;
349 for (std::vector<std::string>& text_record : text_records) {
350 answers.push_back(BuildTestTextRecord(answer_name, std::move(text_record)));
351 }
352
353 return BuildTestDnsResponse(std::move(name), dns_protocol::kTypeTXT, answers);
354 }
355
BuildTestDnsPointerResponse(std::string name,std::vector<std::string> pointer_names,std::string answer_name)356 DnsResponse BuildTestDnsPointerResponse(std::string name,
357 std::vector<std::string> pointer_names,
358 std::string answer_name) {
359 if (answer_name.empty())
360 answer_name = name;
361
362 std::vector<DnsResourceRecord> answers;
363 for (std::string& pointer_name : pointer_names) {
364 std::optional<std::vector<uint8_t>> rdata =
365 dns_names_util::DottedNameToNetwork(pointer_name);
366 CHECK(rdata.has_value());
367
368 answers.push_back(BuildTestDnsRecord(
369 answer_name, dns_protocol::kTypePTR,
370 std::string(reinterpret_cast<char*>(rdata.value().data()),
371 rdata.value().size())));
372 }
373
374 return BuildTestDnsResponse(std::move(name), dns_protocol::kTypePTR, answers);
375 }
376
BuildTestDnsServiceResponse(std::string name,std::vector<TestServiceRecord> service_records,std::string answer_name)377 DnsResponse BuildTestDnsServiceResponse(
378 std::string name,
379 std::vector<TestServiceRecord> service_records,
380 std::string answer_name) {
381 if (answer_name.empty())
382 answer_name = name;
383
384 std::vector<DnsResourceRecord> answers;
385 for (TestServiceRecord& service_record : service_records) {
386 std::string rdata;
387 {
388 std::array<uint8_t, 2> buf =
389 base::U16ToBigEndian(service_record.priority);
390 rdata.append(buf.begin(), buf.end());
391 }
392 {
393 std::array<uint8_t, 2> buf = base::U16ToBigEndian(service_record.weight);
394 rdata.append(buf.begin(), buf.end());
395 }
396 {
397 std::array<uint8_t, 2> buf = base::U16ToBigEndian(service_record.port);
398 rdata.append(buf.begin(), buf.end());
399 }
400
401 std::optional<std::vector<uint8_t>> dns_name =
402 dns_names_util::DottedNameToNetwork(service_record.target);
403 CHECK(dns_name.has_value());
404 rdata.append(reinterpret_cast<char*>(dns_name.value().data()),
405 dns_name.value().size());
406
407 answers.push_back(BuildTestDnsRecord(answer_name, dns_protocol::kTypeSRV,
408 std::move(rdata), base::Hours(5)));
409 }
410
411 return BuildTestDnsResponse(std::move(name), dns_protocol::kTypeSRV, answers);
412 }
413
Result(ResultType type,std::optional<DnsResponse> response,std::optional<int> net_error)414 MockDnsClientRule::Result::Result(ResultType type,
415 std::optional<DnsResponse> response,
416 std::optional<int> net_error)
417 : type(type), response(std::move(response)), net_error(net_error) {}
418
Result(DnsResponse response)419 MockDnsClientRule::Result::Result(DnsResponse response)
420 : type(ResultType::kOk),
421 response(std::move(response)),
422 net_error(std::nullopt) {}
423
424 MockDnsClientRule::Result::Result(Result&&) = default;
425
426 MockDnsClientRule::Result& MockDnsClientRule::Result::operator=(Result&&) =
427 default;
428
429 MockDnsClientRule::Result::~Result() = default;
430
MockDnsClientRule(const std::string & prefix,uint16_t qtype,bool secure,Result result,bool delay,URLRequestContext * context)431 MockDnsClientRule::MockDnsClientRule(const std::string& prefix,
432 uint16_t qtype,
433 bool secure,
434 Result result,
435 bool delay,
436 URLRequestContext* context)
437 : result(std::move(result)),
438 prefix(prefix),
439 qtype(qtype),
440 secure(secure),
441 delay(delay),
442 context(context) {}
443
444 MockDnsClientRule::MockDnsClientRule(MockDnsClientRule&& rule) = default;
445
446 // A DnsTransaction which uses MockDnsClientRuleList to determine the response.
447 class MockDnsTransactionFactory::MockTransaction final : public DnsTransaction {
448 public:
MockTransaction(const MockDnsClientRuleList & rules,std::string hostname,uint16_t qtype,bool secure,bool force_doh_server_available,SecureDnsMode secure_dns_mode,ResolveContext * resolve_context,bool fast_timeout)449 MockTransaction(const MockDnsClientRuleList& rules,
450 std::string hostname,
451 uint16_t qtype,
452 bool secure,
453 bool force_doh_server_available,
454 SecureDnsMode secure_dns_mode,
455 ResolveContext* resolve_context,
456 bool fast_timeout)
457 : hostname_(std::move(hostname)), qtype_(qtype) {
458 // Do not allow matching any rules if transaction is secure and no DoH
459 // servers are available.
460 if (!secure || force_doh_server_available ||
461 resolve_context->NumAvailableDohServers(
462 resolve_context->current_session_for_testing()) > 0) {
463 // Find the relevant rule which matches |qtype|, |secure|, prefix of
464 // |hostname_|, and |url_request_context| (iff the rule context is not
465 // null).
466 for (const auto& rule : rules) {
467 const std::string& prefix = rule.prefix;
468 if ((rule.qtype == qtype) && (rule.secure == secure) &&
469 (hostname_.size() >= prefix.size()) &&
470 (hostname_.compare(0, prefix.size(), prefix) == 0) &&
471 (!rule.context ||
472 rule.context == resolve_context->url_request_context())) {
473 const MockDnsClientRule::Result* result = &rule.result;
474 result_ = MockDnsClientRule::Result(result->type);
475 result_.net_error = result->net_error;
476 delayed_ = rule.delay;
477
478 // Generate a DnsResponse when not provided with the rule.
479 std::vector<DnsResourceRecord> authority_records;
480 std::optional<std::vector<uint8_t>> dns_name =
481 dns_names_util::DottedNameToNetwork(hostname_);
482 CHECK(dns_name.has_value());
483 std::optional<DnsQuery> query(std::in_place, /*id=*/22,
484 dns_name.value(), qtype_);
485 switch (result->type) {
486 case MockDnsClientRule::ResultType::kNoDomain:
487 case MockDnsClientRule::ResultType::kEmpty:
488 DCHECK(!result->response); // Not expected to be provided.
489 authority_records = {BuildTestDnsRecord(
490 hostname_, dns_protocol::kTypeSOA, "fake rdata")};
491 result_.response = DnsResponse(
492 22 /* id */, false /* is_authoritative */,
493 std::vector<DnsResourceRecord>() /* answers */,
494 authority_records,
495 std::vector<DnsResourceRecord>() /* additional_records */,
496 query,
497 result->type == MockDnsClientRule::ResultType::kNoDomain
498 ? dns_protocol::kRcodeNXDOMAIN
499 : 0);
500 break;
501 case MockDnsClientRule::ResultType::kFail:
502 if (result->response)
503 SetResponse(result);
504 break;
505 case MockDnsClientRule::ResultType::kTimeout:
506 DCHECK(!result->response); // Not expected to be provided.
507 break;
508 case MockDnsClientRule::ResultType::kSlow:
509 if (!fast_timeout)
510 SetResponse(result);
511 break;
512 case MockDnsClientRule::ResultType::kOk:
513 SetResponse(result);
514 break;
515 case MockDnsClientRule::ResultType::kMalformed:
516 DCHECK(!result->response); // Not expected to be provided.
517 result_.response = CreateMalformedResponse(hostname_, qtype_);
518 break;
519 case MockDnsClientRule::ResultType::kUnexpected:
520 if (!delayed_) {
521 // Assume a delayed kUnexpected transaction is only an issue if
522 // allowed to complete.
523 ADD_FAILURE()
524 << "Unexpected DNS transaction created for hostname "
525 << hostname_;
526 }
527 break;
528 }
529
530 break;
531 }
532 }
533 }
534 }
535
GetHostname() const536 const std::string& GetHostname() const override { return hostname_; }
537
GetType() const538 uint16_t GetType() const override { return qtype_; }
539
Start(ResponseCallback callback)540 void Start(ResponseCallback callback) override {
541 CHECK(!callback.is_null());
542 CHECK(callback_.is_null());
543 EXPECT_FALSE(started_);
544
545 callback_ = std::move(callback);
546 started_ = true;
547 if (delayed_)
548 return;
549 // Using WeakPtr to cleanly cancel when transaction is destroyed.
550 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
551 FROM_HERE, base::BindOnce(&MockTransaction::Finish,
552 weak_ptr_factory_.GetWeakPtr()));
553 }
554
FinishDelayedTransaction()555 void FinishDelayedTransaction() {
556 EXPECT_TRUE(delayed_);
557 delayed_ = false;
558 Finish();
559 }
560
delayed() const561 bool delayed() const { return delayed_; }
562
AsWeakPtr()563 base::WeakPtr<MockTransaction> AsWeakPtr() {
564 return weak_ptr_factory_.GetWeakPtr();
565 }
566
567 private:
SetResponse(const MockDnsClientRule::Result * result)568 void SetResponse(const MockDnsClientRule::Result* result) {
569 if (result->response) {
570 // Copy response in case |result| is destroyed before the transaction
571 // completes.
572 auto buffer_copy = base::MakeRefCounted<IOBufferWithSize>(
573 result->response->io_buffer_size());
574 memcpy(buffer_copy->data(), result->response->io_buffer()->data(),
575 result->response->io_buffer_size());
576 result_.response = DnsResponse(std::move(buffer_copy),
577 result->response->io_buffer_size());
578 CHECK(result_.response->InitParseWithoutQuery(
579 result->response->io_buffer_size()));
580 } else {
581 // Generated response only available for address types.
582 DCHECK(qtype_ == dns_protocol::kTypeA ||
583 qtype_ == dns_protocol::kTypeAAAA);
584 result_.response = BuildTestDnsAddressResponse(
585 hostname_, qtype_ == dns_protocol::kTypeA
586 ? IPAddress::IPv4Localhost()
587 : IPAddress::IPv6Localhost());
588 }
589 }
590
Finish()591 void Finish() {
592 switch (result_.type) {
593 case MockDnsClientRule::ResultType::kNoDomain:
594 case MockDnsClientRule::ResultType::kFail: {
595 int error = result_.net_error.value_or(ERR_NAME_NOT_RESOLVED);
596 DCHECK_NE(error, OK);
597 std::move(callback_).Run(error, base::OptionalToPtr(result_.response));
598 break;
599 }
600 case MockDnsClientRule::ResultType::kEmpty:
601 case MockDnsClientRule::ResultType::kOk:
602 case MockDnsClientRule::ResultType::kMalformed:
603 DCHECK(!result_.net_error.has_value());
604 std::move(callback_).Run(OK, base::OptionalToPtr(result_.response));
605 break;
606 case MockDnsClientRule::ResultType::kTimeout:
607 DCHECK(!result_.net_error.has_value());
608 std::move(callback_).Run(ERR_DNS_TIMED_OUT, /*response=*/nullptr);
609 break;
610 case MockDnsClientRule::ResultType::kSlow:
611 if (result_.response) {
612 std::move(callback_).Run(
613 result_.net_error.value_or(OK),
614 result_.response ? &result_.response.value() : nullptr);
615 } else {
616 DCHECK(!result_.net_error.has_value());
617 std::move(callback_).Run(ERR_DNS_TIMED_OUT, /*response=*/nullptr);
618 }
619 break;
620 case MockDnsClientRule::ResultType::kUnexpected:
621 ADD_FAILURE() << "Unexpected DNS transaction completed for hostname "
622 << hostname_;
623 break;
624 }
625 }
626
SetRequestPriority(RequestPriority priority)627 void SetRequestPriority(RequestPriority priority) override {}
628
629 MockDnsClientRule::Result result_{MockDnsClientRule::ResultType::kFail};
630 const std::string hostname_;
631 const uint16_t qtype_;
632 ResponseCallback callback_;
633 bool started_ = false;
634 bool delayed_ = false;
635 base::WeakPtrFactory<MockTransaction> weak_ptr_factory_{this};
636 };
637
638 class MockDnsTransactionFactory::MockDohProbeRunner : public DnsProbeRunner {
639 public:
MockDohProbeRunner(base::WeakPtr<MockDnsTransactionFactory> factory)640 explicit MockDohProbeRunner(base::WeakPtr<MockDnsTransactionFactory> factory)
641 : factory_(std::move(factory)) {}
642
~MockDohProbeRunner()643 ~MockDohProbeRunner() override {
644 if (factory_)
645 factory_->running_doh_probe_runners_.erase(this);
646 }
647
Start(bool network_change)648 void Start(bool network_change) override {
649 DCHECK(factory_);
650 factory_->running_doh_probe_runners_.insert(this);
651 }
652
GetDelayUntilNextProbeForTest(size_t doh_server_index) const653 base::TimeDelta GetDelayUntilNextProbeForTest(
654 size_t doh_server_index) const override {
655 NOTREACHED();
656 }
657
658 private:
659 base::WeakPtr<MockDnsTransactionFactory> factory_;
660 };
661
MockDnsTransactionFactory(MockDnsClientRuleList rules)662 MockDnsTransactionFactory::MockDnsTransactionFactory(
663 MockDnsClientRuleList rules)
664 : rules_(std::move(rules)) {}
665
666 MockDnsTransactionFactory::~MockDnsTransactionFactory() = default;
667
CreateTransaction(std::string hostname,uint16_t qtype,const NetLogWithSource &,bool secure,SecureDnsMode secure_dns_mode,ResolveContext * resolve_context,bool fast_timeout)668 std::unique_ptr<DnsTransaction> MockDnsTransactionFactory::CreateTransaction(
669 std::string hostname,
670 uint16_t qtype,
671 const NetLogWithSource&,
672 bool secure,
673 SecureDnsMode secure_dns_mode,
674 ResolveContext* resolve_context,
675 bool fast_timeout) {
676 std::unique_ptr<MockTransaction> transaction =
677 std::make_unique<MockTransaction>(rules_, std::move(hostname), qtype,
678 secure, force_doh_server_available_,
679 secure_dns_mode, resolve_context,
680 fast_timeout);
681 if (transaction->delayed())
682 delayed_transactions_.push_back(transaction->AsWeakPtr());
683 return transaction;
684 }
685
CreateDohProbeRunner(ResolveContext * resolve_context)686 std::unique_ptr<DnsProbeRunner> MockDnsTransactionFactory::CreateDohProbeRunner(
687 ResolveContext* resolve_context) {
688 return std::make_unique<MockDohProbeRunner>(weak_ptr_factory_.GetWeakPtr());
689 }
690
AddEDNSOption(std::unique_ptr<OptRecordRdata::Opt> opt)691 void MockDnsTransactionFactory::AddEDNSOption(
692 std::unique_ptr<OptRecordRdata::Opt> opt) {}
693
GetSecureDnsModeForTest()694 SecureDnsMode MockDnsTransactionFactory::GetSecureDnsModeForTest() {
695 return SecureDnsMode::kAutomatic;
696 }
697
CompleteDelayedTransactions()698 void MockDnsTransactionFactory::CompleteDelayedTransactions() {
699 DelayedTransactionList old_delayed_transactions;
700 old_delayed_transactions.swap(delayed_transactions_);
701 for (auto& old_delayed_transaction : old_delayed_transactions) {
702 if (old_delayed_transaction.get())
703 old_delayed_transaction->FinishDelayedTransaction();
704 }
705 }
706
CompleteOneDelayedTransactionOfType(DnsQueryType type)707 bool MockDnsTransactionFactory::CompleteOneDelayedTransactionOfType(
708 DnsQueryType type) {
709 for (base::WeakPtr<MockTransaction>& t : delayed_transactions_) {
710 if (t && t->GetType() == DnsQueryTypeToQtype(type)) {
711 t->FinishDelayedTransaction();
712 t.reset();
713 return true;
714 }
715 }
716 return false;
717 }
718
MockDnsClient(DnsConfig config,MockDnsClientRuleList rules)719 MockDnsClient::MockDnsClient(DnsConfig config, MockDnsClientRuleList rules)
720 : config_(std::move(config)),
721 factory_(std::make_unique<MockDnsTransactionFactory>(std::move(rules))),
722 address_sorter_(std::make_unique<MockAddressSorter>()) {
723 effective_config_ = BuildEffectiveConfig();
724 session_ = BuildSession();
725 }
726
727 MockDnsClient::~MockDnsClient() = default;
728
CanUseSecureDnsTransactions() const729 bool MockDnsClient::CanUseSecureDnsTransactions() const {
730 const DnsConfig* config = GetEffectiveConfig();
731 return config && config->IsValid() && !config->doh_config.servers().empty();
732 }
733
CanUseInsecureDnsTransactions() const734 bool MockDnsClient::CanUseInsecureDnsTransactions() const {
735 const DnsConfig* config = GetEffectiveConfig();
736 return config && config->IsValid() && insecure_enabled_ &&
737 !config->dns_over_tls_active;
738 }
739
CanQueryAdditionalTypesViaInsecureDns() const740 bool MockDnsClient::CanQueryAdditionalTypesViaInsecureDns() const {
741 DCHECK(CanUseInsecureDnsTransactions());
742 return additional_types_enabled_;
743 }
744
SetInsecureEnabled(bool enabled,bool additional_types_enabled)745 void MockDnsClient::SetInsecureEnabled(bool enabled,
746 bool additional_types_enabled) {
747 insecure_enabled_ = enabled;
748 additional_types_enabled_ = additional_types_enabled;
749 }
750
FallbackFromSecureTransactionPreferred(ResolveContext * context) const751 bool MockDnsClient::FallbackFromSecureTransactionPreferred(
752 ResolveContext* context) const {
753 bool doh_server_available =
754 force_doh_server_available_ ||
755 context->NumAvailableDohServers(session_.get()) > 0;
756 return !CanUseSecureDnsTransactions() || !doh_server_available;
757 }
758
FallbackFromInsecureTransactionPreferred() const759 bool MockDnsClient::FallbackFromInsecureTransactionPreferred() const {
760 return !CanUseInsecureDnsTransactions() ||
761 fallback_failures_ >= max_fallback_failures_;
762 }
763
SetSystemConfig(std::optional<DnsConfig> system_config)764 bool MockDnsClient::SetSystemConfig(std::optional<DnsConfig> system_config) {
765 if (ignore_system_config_changes_)
766 return false;
767
768 std::optional<DnsConfig> before = effective_config_;
769 config_ = std::move(system_config);
770 effective_config_ = BuildEffectiveConfig();
771 session_ = BuildSession();
772 return before != effective_config_;
773 }
774
SetConfigOverrides(DnsConfigOverrides config_overrides)775 bool MockDnsClient::SetConfigOverrides(DnsConfigOverrides config_overrides) {
776 std::optional<DnsConfig> before = effective_config_;
777 overrides_ = std::move(config_overrides);
778 effective_config_ = BuildEffectiveConfig();
779 session_ = BuildSession();
780 return before != effective_config_;
781 }
782
ReplaceCurrentSession()783 void MockDnsClient::ReplaceCurrentSession() {
784 // Noop if no current effective config.
785 session_ = BuildSession();
786 }
787
GetCurrentSession()788 DnsSession* MockDnsClient::GetCurrentSession() {
789 return session_.get();
790 }
791
GetEffectiveConfig() const792 const DnsConfig* MockDnsClient::GetEffectiveConfig() const {
793 return effective_config_.has_value() ? &effective_config_.value() : nullptr;
794 }
795
GetDnsConfigAsValueForNetLog() const796 base::Value::Dict MockDnsClient::GetDnsConfigAsValueForNetLog() const {
797 // This is just a stub implementation that never produces a meaningful value.
798 return base::Value::Dict();
799 }
800
GetHosts() const801 const DnsHosts* MockDnsClient::GetHosts() const {
802 const DnsConfig* config = GetEffectiveConfig();
803 if (!config)
804 return nullptr;
805
806 return &config->hosts;
807 }
808
GetTransactionFactory()809 DnsTransactionFactory* MockDnsClient::GetTransactionFactory() {
810 return GetEffectiveConfig() ? factory_.get() : nullptr;
811 }
812
GetAddressSorter()813 AddressSorter* MockDnsClient::GetAddressSorter() {
814 return GetEffectiveConfig() ? address_sorter_.get() : nullptr;
815 }
816
IncrementInsecureFallbackFailures()817 void MockDnsClient::IncrementInsecureFallbackFailures() {
818 ++fallback_failures_;
819 }
820
ClearInsecureFallbackFailures()821 void MockDnsClient::ClearInsecureFallbackFailures() {
822 fallback_failures_ = 0;
823 }
824
GetSystemConfigForTesting() const825 std::optional<DnsConfig> MockDnsClient::GetSystemConfigForTesting() const {
826 return config_;
827 }
828
GetConfigOverridesForTesting() const829 DnsConfigOverrides MockDnsClient::GetConfigOverridesForTesting() const {
830 return overrides_;
831 }
832
SetTransactionFactoryForTesting(std::unique_ptr<DnsTransactionFactory> factory)833 void MockDnsClient::SetTransactionFactoryForTesting(
834 std::unique_ptr<DnsTransactionFactory> factory) {
835 NOTREACHED();
836 }
837
SetAddressSorterForTesting(std::unique_ptr<AddressSorter> address_sorter)838 void MockDnsClient::SetAddressSorterForTesting(
839 std::unique_ptr<AddressSorter> address_sorter) {
840 address_sorter_ = std::move(address_sorter);
841 }
842
GetPresetAddrs(const url::SchemeHostPort & endpoint) const843 std::optional<std::vector<IPEndPoint>> MockDnsClient::GetPresetAddrs(
844 const url::SchemeHostPort& endpoint) const {
845 EXPECT_THAT(preset_endpoint_, testing::Optional(endpoint));
846 return preset_addrs_;
847 }
848
CompleteDelayedTransactions()849 void MockDnsClient::CompleteDelayedTransactions() {
850 factory_->CompleteDelayedTransactions();
851 }
852
CompleteOneDelayedTransactionOfType(DnsQueryType type)853 bool MockDnsClient::CompleteOneDelayedTransactionOfType(DnsQueryType type) {
854 return factory_->CompleteOneDelayedTransactionOfType(type);
855 }
856
SetForceDohServerAvailable(bool available)857 void MockDnsClient::SetForceDohServerAvailable(bool available) {
858 force_doh_server_available_ = available;
859 factory_->set_force_doh_server_available(available);
860 }
861
BuildEffectiveConfig()862 std::optional<DnsConfig> MockDnsClient::BuildEffectiveConfig() {
863 if (overrides_.OverridesEverything())
864 return overrides_.ApplyOverrides(DnsConfig());
865 if (!config_ || !config_.value().IsValid())
866 return std::nullopt;
867
868 return overrides_.ApplyOverrides(config_.value());
869 }
870
BuildSession()871 scoped_refptr<DnsSession> MockDnsClient::BuildSession() {
872 if (!effective_config_)
873 return nullptr;
874
875 // Session not expected to be used for anything that will actually require
876 // random numbers.
877 auto null_random_callback =
878 base::BindRepeating([](int, int) -> int { base::ImmediateCrash(); });
879
880 return base::MakeRefCounted<DnsSession>(
881 effective_config_.value(), null_random_callback, nullptr /* net_log */);
882 }
883
MockHostResolverProc()884 MockHostResolverProc::MockHostResolverProc()
885 : HostResolverProc(nullptr),
886 requests_waiting_(&lock_),
887 slots_available_(&lock_) {}
888
889 MockHostResolverProc::~MockHostResolverProc() = default;
890
WaitFor(unsigned count)891 bool MockHostResolverProc::WaitFor(unsigned count) {
892 base::AutoLock lock(lock_);
893 base::Time start_time = base::Time::Now();
894 while (num_requests_waiting_ < count) {
895 requests_waiting_.TimedWait(TestTimeouts::action_timeout());
896 if (base::Time::Now() > start_time + TestTimeouts::action_timeout()) {
897 return false;
898 }
899 }
900 return true;
901 }
902
SignalMultiple(unsigned count)903 void MockHostResolverProc::SignalMultiple(unsigned count) {
904 base::AutoLock lock(lock_);
905 num_slots_available_ += count;
906 slots_available_.Broadcast();
907 }
908
SignalAll()909 void MockHostResolverProc::SignalAll() {
910 base::AutoLock lock(lock_);
911 num_slots_available_ = num_requests_waiting_;
912 slots_available_.Broadcast();
913 }
914
AddRule(const std::string & hostname,AddressFamily family,const AddressList & result,HostResolverFlags flags)915 void MockHostResolverProc::AddRule(const std::string& hostname,
916 AddressFamily family,
917 const AddressList& result,
918 HostResolverFlags flags) {
919 base::AutoLock lock(lock_);
920 rules_[ResolveKey(hostname, family, flags)] = result;
921 }
922
AddRule(const std::string & hostname,AddressFamily family,const std::string & ip_list,HostResolverFlags flags,const std::string & canonical_name)923 void MockHostResolverProc::AddRule(const std::string& hostname,
924 AddressFamily family,
925 const std::string& ip_list,
926 HostResolverFlags flags,
927 const std::string& canonical_name) {
928 AddressList result;
929 std::vector<std::string> dns_aliases;
930 if (canonical_name != "") {
931 dns_aliases = {canonical_name};
932 }
933 int rv = ParseAddressList(ip_list, &result.endpoints());
934 result.SetDnsAliases(dns_aliases);
935 DCHECK_EQ(OK, rv);
936 AddRule(hostname, family, result, flags);
937 }
938
AddRuleForAllFamilies(const std::string & hostname,const std::string & ip_list,HostResolverFlags flags,const std::string & canonical_name)939 void MockHostResolverProc::AddRuleForAllFamilies(
940 const std::string& hostname,
941 const std::string& ip_list,
942 HostResolverFlags flags,
943 const std::string& canonical_name) {
944 AddressList result;
945 std::vector<std::string> dns_aliases;
946 if (canonical_name != "") {
947 dns_aliases = {canonical_name};
948 }
949 int rv = ParseAddressList(ip_list, &result.endpoints());
950 result.SetDnsAliases(dns_aliases);
951 DCHECK_EQ(OK, rv);
952 AddRule(hostname, ADDRESS_FAMILY_UNSPECIFIED, result, flags);
953 AddRule(hostname, ADDRESS_FAMILY_IPV4, result, flags);
954 AddRule(hostname, ADDRESS_FAMILY_IPV6, result, flags);
955 }
956
Resolve(const std::string & hostname,AddressFamily address_family,HostResolverFlags host_resolver_flags,AddressList * addrlist,int * os_error)957 int MockHostResolverProc::Resolve(const std::string& hostname,
958 AddressFamily address_family,
959 HostResolverFlags host_resolver_flags,
960 AddressList* addrlist,
961 int* os_error) {
962 base::AutoLock lock(lock_);
963 capture_list_.emplace_back(hostname, address_family, host_resolver_flags);
964 ++num_requests_waiting_;
965 requests_waiting_.Broadcast();
966 {
967 base::ScopedAllowBaseSyncPrimitivesForTesting
968 scoped_allow_base_sync_primitives;
969 while (!num_slots_available_) {
970 slots_available_.Wait();
971 }
972 }
973 DCHECK_GT(num_requests_waiting_, 0u);
974 --num_slots_available_;
975 --num_requests_waiting_;
976 if (rules_.empty()) {
977 int rv = ParseAddressList("127.0.0.1", &addrlist->endpoints());
978 DCHECK_EQ(OK, rv);
979 return OK;
980 }
981 ResolveKey key(hostname, address_family, host_resolver_flags);
982 if (rules_.count(key) == 0) {
983 return ERR_NAME_NOT_RESOLVED;
984 }
985 *addrlist = rules_[key];
986 return OK;
987 }
988
GetCaptureList() const989 MockHostResolverProc::CaptureList MockHostResolverProc::GetCaptureList() const {
990 CaptureList copy;
991 {
992 base::AutoLock lock(lock_);
993 copy = capture_list_;
994 }
995 return copy;
996 }
997
ClearCaptureList()998 void MockHostResolverProc::ClearCaptureList() {
999 base::AutoLock lock(lock_);
1000 capture_list_.clear();
1001 }
1002
HasBlockedRequests() const1003 bool MockHostResolverProc::HasBlockedRequests() const {
1004 base::AutoLock lock(lock_);
1005 return num_requests_waiting_ > num_slots_available_;
1006 }
1007
1008 } // namespace net
1009