1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/dns/dns_test_util.h"
6
7 #include <cstdint>
8 #include <string>
9 #include <utility>
10 #include <vector>
11
12 #include "base/big_endian.h"
13 #include "base/check.h"
14 #include "base/functional/bind.h"
15 #include "base/location.h"
16 #include "base/numerics/safe_conversions.h"
17 #include "base/ranges/algorithm.h"
18 #include "base/strings/strcat.h"
19 #include "base/sys_byteorder.h"
20 #include "base/task/single_thread_task_runner.h"
21 #include "base/time/time.h"
22 #include "base/types/optional_util.h"
23 #include "net/base/io_buffer.h"
24 #include "net/base/ip_address.h"
25 #include "net/base/ip_endpoint.h"
26 #include "net/base/net_errors.h"
27 #include "net/dns/address_sorter.h"
28 #include "net/dns/dns_hosts.h"
29 #include "net/dns/dns_names_util.h"
30 #include "net/dns/dns_query.h"
31 #include "net/dns/dns_session.h"
32 #include "net/dns/public/dns_over_https_server_config.h"
33 #include "net/dns/resolve_context.h"
34 #include "testing/gmock/include/gmock/gmock-matchers.h"
35 #include "testing/gtest/include/gtest/gtest.h"
36 #include "third_party/abseil-cpp/absl/types/optional.h"
37 #include "url/scheme_host_port.h"
38
39 namespace net {
40 namespace {
41
// Wire-format DNS header used by CreateMalformedResponse(): a successful
// response with ID 0x14 that declares one question and one answer record.
// CreateMalformedResponse() appends only the question, so the declared answer
// is missing from the message.
constexpr uint8_t kMalformedResponseHeader[] = {
    // Header
    0x00, 0x14,  // Arbitrary ID
    0x81, 0x80,  // Standard query response, RA, no error
    0x00, 0x01,  // 1 question
    0x00, 0x01,  // 1 RR (answers)
    0x00, 0x00,  // 0 authority RRs
    0x00, 0x00,  // 0 additional RRs
};
51
52 // Create a response containing a valid question (as would normally be validated
53 // in DnsTransaction) but completely missing a header-declared answer.
CreateMalformedResponse(std::string hostname,uint16_t type)54 DnsResponse CreateMalformedResponse(std::string hostname, uint16_t type) {
55 absl::optional<std::vector<uint8_t>> dns_name =
56 dns_names_util::DottedNameToNetwork(hostname);
57 CHECK(dns_name.has_value());
58 DnsQuery query(/*id=*/0x14, dns_name.value(), type);
59
60 // Build response to simulate the barebones validation DnsResponse applies to
61 // responses received from the network.
62 auto buffer = base::MakeRefCounted<IOBufferWithSize>(
63 sizeof(kMalformedResponseHeader) + query.question().size());
64 memcpy(buffer->data(), kMalformedResponseHeader,
65 sizeof(kMalformedResponseHeader));
66 memcpy(buffer->data() + sizeof(kMalformedResponseHeader),
67 query.question().data(), query.question().size());
68
69 DnsResponse response(buffer, buffer->size());
70 CHECK(response.InitParseWithoutQuery(buffer->size()));
71
72 return response;
73 }
74
75 class MockAddressSorter : public AddressSorter {
76 public:
77 ~MockAddressSorter() override = default;
Sort(const std::vector<IPEndPoint> & endpoints,CallbackType callback) const78 void Sort(const std::vector<IPEndPoint>& endpoints,
79 CallbackType callback) const override {
80 // Do nothing.
81 std::move(callback).Run(true, endpoints);
82 }
83 };
84
85 } // namespace
86
BuildTestDnsRecord(std::string name,uint16_t type,std::string rdata,base::TimeDelta ttl)87 DnsResourceRecord BuildTestDnsRecord(std::string name,
88 uint16_t type,
89 std::string rdata,
90 base::TimeDelta ttl) {
91 DCHECK(!name.empty());
92
93 DnsResourceRecord record;
94 record.name = std::move(name);
95 record.type = type;
96 record.klass = dns_protocol::kClassIN;
97 record.ttl = ttl.InSeconds();
98
99 if (!rdata.empty())
100 record.SetOwnedRdata(std::move(rdata));
101
102 return record;
103 }
104
BuildTestCnameRecord(std::string name,base::StringPiece canonical_name,base::TimeDelta ttl)105 DnsResourceRecord BuildTestCnameRecord(std::string name,
106 base::StringPiece canonical_name,
107 base::TimeDelta ttl) {
108 DCHECK(!name.empty());
109 DCHECK(!canonical_name.empty());
110
111 absl::optional<std::vector<uint8_t>> rdata =
112 dns_names_util::DottedNameToNetwork(canonical_name);
113 CHECK(rdata.has_value());
114
115 return BuildTestDnsRecord(
116 std::move(name), dns_protocol::kTypeCNAME,
117 std::string(reinterpret_cast<char*>(rdata.value().data()),
118 rdata.value().size()),
119 ttl);
120 }
121
BuildTestAddressRecord(std::string name,const IPAddress & ip,base::TimeDelta ttl)122 DnsResourceRecord BuildTestAddressRecord(std::string name,
123 const IPAddress& ip,
124 base::TimeDelta ttl) {
125 DCHECK(!name.empty());
126 DCHECK(ip.IsValid());
127
128 return BuildTestDnsRecord(
129 std::move(name),
130 ip.IsIPv4() ? dns_protocol::kTypeA : dns_protocol::kTypeAAAA,
131 net::IPAddressToPackedString(ip), ttl);
132 }
133
BuildTestTextRecord(std::string name,std::vector<std::string> text_strings,base::TimeDelta ttl)134 DnsResourceRecord BuildTestTextRecord(std::string name,
135 std::vector<std::string> text_strings,
136 base::TimeDelta ttl) {
137 DCHECK(!text_strings.empty());
138
139 std::string rdata;
140 for (const std::string& text_string : text_strings) {
141 DCHECK(!text_string.empty());
142
143 rdata += base::checked_cast<unsigned char>(text_string.size());
144 rdata += text_string;
145 }
146
147 return BuildTestDnsRecord(std::move(name), dns_protocol::kTypeTXT,
148 std::move(rdata), ttl);
149 }
150
BuildTestHttpsAliasRecord(std::string name,base::StringPiece alias_name,base::TimeDelta ttl)151 DnsResourceRecord BuildTestHttpsAliasRecord(std::string name,
152 base::StringPiece alias_name,
153 base::TimeDelta ttl) {
154 DCHECK(!name.empty());
155
156 std::string rdata("\000\000", 2);
157
158 absl::optional<std::vector<uint8_t>> alias_domain =
159 dns_names_util::DottedNameToNetwork(alias_name);
160 CHECK(alias_domain.has_value());
161 rdata.append(reinterpret_cast<char*>(alias_domain.value().data()),
162 alias_domain.value().size());
163
164 return BuildTestDnsRecord(std::move(name), dns_protocol::kTypeHttps,
165 std::move(rdata), ttl);
166 }
167
BuildTestHttpsServiceAlpnParam(const std::vector<std::string> & alpns)168 std::pair<uint16_t, std::string> BuildTestHttpsServiceAlpnParam(
169 const std::vector<std::string>& alpns) {
170 std::string param_value;
171
172 for (const std::string& alpn : alpns) {
173 CHECK(!alpn.empty());
174 param_value.append(
175 1, static_cast<char>(base::checked_cast<uint8_t>(alpn.size())));
176 param_value.append(alpn);
177 }
178
179 return std::make_pair(dns_protocol::kHttpsServiceParamKeyAlpn,
180 std::move(param_value));
181 }
182
BuildTestHttpsServiceEchConfigParam(base::span<const uint8_t> ech_config_list)183 std::pair<uint16_t, std::string> BuildTestHttpsServiceEchConfigParam(
184 base::span<const uint8_t> ech_config_list) {
185 return std::make_pair(
186 dns_protocol::kHttpsServiceParamKeyEchConfig,
187 std::string(reinterpret_cast<const char*>(ech_config_list.data()),
188 ech_config_list.size()));
189 }
190
BuildTestHttpsServiceMandatoryParam(std::vector<uint16_t> param_key_list)191 std::pair<uint16_t, std::string> BuildTestHttpsServiceMandatoryParam(
192 std::vector<uint16_t> param_key_list) {
193 base::ranges::sort(param_key_list);
194
195 std::string value;
196 for (uint16_t param_key : param_key_list) {
197 char num_buffer[2];
198 base::WriteBigEndian(num_buffer, param_key);
199 value.append(num_buffer, 2);
200 }
201
202 return std::make_pair(dns_protocol::kHttpsServiceParamKeyMandatory,
203 std::move(value));
204 }
205
BuildTestHttpsServicePortParam(uint16_t port)206 std::pair<uint16_t, std::string> BuildTestHttpsServicePortParam(uint16_t port) {
207 char buffer[2];
208 base::WriteBigEndian(buffer, port);
209
210 return std::make_pair(dns_protocol::kHttpsServiceParamKeyPort,
211 std::string(buffer, 2));
212 }
213
BuildTestHttpsServiceRecord(std::string name,uint16_t priority,base::StringPiece service_name,const std::map<uint16_t,std::string> & params,base::TimeDelta ttl)214 DnsResourceRecord BuildTestHttpsServiceRecord(
215 std::string name,
216 uint16_t priority,
217 base::StringPiece service_name,
218 const std::map<uint16_t, std::string>& params,
219 base::TimeDelta ttl) {
220 DCHECK(!name.empty());
221 DCHECK_NE(priority, 0);
222
223 std::string rdata;
224
225 char num_buffer[2];
226 base::WriteBigEndian(num_buffer, priority);
227 rdata.append(num_buffer, 2);
228
229 absl::optional<std::vector<uint8_t>> service_domain;
230 if (service_name == ".") {
231 // HTTPS records have special behavior for `service_name == "."` (that it
232 // will be treated as if the service name is the same as the record owner
233 // name), so allow such inputs despite normally being disallowed for
234 // Chrome-encoded DNS names.
235 service_domain = std::vector<uint8_t>{0};
236 } else {
237 service_domain = dns_names_util::DottedNameToNetwork(service_name);
238 }
239 CHECK(service_domain.has_value());
240 rdata.append(reinterpret_cast<char*>(service_domain.value().data()),
241 service_domain.value().size());
242
243 for (auto& param : params) {
244 base::WriteBigEndian(num_buffer, param.first);
245 rdata.append(num_buffer, 2);
246
247 base::WriteBigEndian(num_buffer,
248 base::checked_cast<uint16_t>(param.second.size()));
249 rdata.append(num_buffer, 2);
250
251 rdata.append(param.second);
252 }
253
254 return BuildTestDnsRecord(std::move(name), dns_protocol::kTypeHttps,
255 std::move(rdata), ttl);
256 }
257
BuildTestDnsResponse(std::string name,uint16_t type,const std::vector<DnsResourceRecord> & answers,const std::vector<DnsResourceRecord> & authority,const std::vector<DnsResourceRecord> & additional,uint8_t rcode)258 DnsResponse BuildTestDnsResponse(
259 std::string name,
260 uint16_t type,
261 const std::vector<DnsResourceRecord>& answers,
262 const std::vector<DnsResourceRecord>& authority,
263 const std::vector<DnsResourceRecord>& additional,
264 uint8_t rcode) {
265 DCHECK(!name.empty());
266
267 absl::optional<std::vector<uint8_t>> dns_name =
268 dns_names_util::DottedNameToNetwork(name);
269 CHECK(dns_name.has_value());
270
271 absl::optional<DnsQuery> query(absl::in_place, 0, dns_name.value(), type);
272 return DnsResponse(0, true /* is_authoritative */, answers,
273 authority /* authority_records */,
274 additional /* additional_records */, query, rcode,
275 false /* validate_records */);
276 }
277
BuildTestDnsAddressResponse(std::string name,const IPAddress & ip,std::string answer_name)278 DnsResponse BuildTestDnsAddressResponse(std::string name,
279 const IPAddress& ip,
280 std::string answer_name) {
281 DCHECK(ip.IsValid());
282
283 if (answer_name.empty())
284 answer_name = name;
285
286 std::vector<DnsResourceRecord> answers = {
287 BuildTestAddressRecord(std::move(answer_name), ip)};
288
289 return BuildTestDnsResponse(
290 std::move(name),
291 ip.IsIPv4() ? dns_protocol::kTypeA : dns_protocol::kTypeAAAA, answers);
292 }
293
BuildTestDnsAddressResponseWithCname(std::string name,const IPAddress & ip,std::string cannonname,std::string answer_name)294 DnsResponse BuildTestDnsAddressResponseWithCname(std::string name,
295 const IPAddress& ip,
296 std::string cannonname,
297 std::string answer_name) {
298 DCHECK(ip.IsValid());
299 DCHECK(!cannonname.empty());
300
301 if (answer_name.empty())
302 answer_name = name;
303
304 absl::optional<std::vector<uint8_t>> cname_rdata =
305 dns_names_util::DottedNameToNetwork(cannonname);
306 CHECK(cname_rdata.has_value());
307
308 std::vector<DnsResourceRecord> answers = {
309 BuildTestDnsRecord(
310 std::move(answer_name), dns_protocol::kTypeCNAME,
311 std::string(reinterpret_cast<char*>(cname_rdata.value().data()),
312 cname_rdata.value().size())),
313 BuildTestAddressRecord(std::move(cannonname), ip)};
314
315 return BuildTestDnsResponse(
316 std::move(name),
317 ip.IsIPv4() ? dns_protocol::kTypeA : dns_protocol::kTypeAAAA, answers);
318 }
319
BuildTestDnsTextResponse(std::string name,std::vector<std::vector<std::string>> text_records,std::string answer_name)320 DnsResponse BuildTestDnsTextResponse(
321 std::string name,
322 std::vector<std::vector<std::string>> text_records,
323 std::string answer_name) {
324 if (answer_name.empty())
325 answer_name = name;
326
327 std::vector<DnsResourceRecord> answers;
328 for (std::vector<std::string>& text_record : text_records) {
329 answers.push_back(BuildTestTextRecord(answer_name, std::move(text_record)));
330 }
331
332 return BuildTestDnsResponse(std::move(name), dns_protocol::kTypeTXT, answers);
333 }
334
BuildTestDnsPointerResponse(std::string name,std::vector<std::string> pointer_names,std::string answer_name)335 DnsResponse BuildTestDnsPointerResponse(std::string name,
336 std::vector<std::string> pointer_names,
337 std::string answer_name) {
338 if (answer_name.empty())
339 answer_name = name;
340
341 std::vector<DnsResourceRecord> answers;
342 for (std::string& pointer_name : pointer_names) {
343 absl::optional<std::vector<uint8_t>> rdata =
344 dns_names_util::DottedNameToNetwork(pointer_name);
345 CHECK(rdata.has_value());
346
347 answers.push_back(BuildTestDnsRecord(
348 answer_name, dns_protocol::kTypePTR,
349 std::string(reinterpret_cast<char*>(rdata.value().data()),
350 rdata.value().size())));
351 }
352
353 return BuildTestDnsResponse(std::move(name), dns_protocol::kTypePTR, answers);
354 }
355
BuildTestDnsServiceResponse(std::string name,std::vector<TestServiceRecord> service_records,std::string answer_name)356 DnsResponse BuildTestDnsServiceResponse(
357 std::string name,
358 std::vector<TestServiceRecord> service_records,
359 std::string answer_name) {
360 if (answer_name.empty())
361 answer_name = name;
362
363 std::vector<DnsResourceRecord> answers;
364 for (TestServiceRecord& service_record : service_records) {
365 std::string rdata;
366 char num_buffer[2];
367 base::WriteBigEndian(num_buffer, service_record.priority);
368 rdata.append(num_buffer, 2);
369 base::WriteBigEndian(num_buffer, service_record.weight);
370 rdata.append(num_buffer, 2);
371 base::WriteBigEndian(num_buffer, service_record.port);
372 rdata.append(num_buffer, 2);
373
374 absl::optional<std::vector<uint8_t>> dns_name =
375 dns_names_util::DottedNameToNetwork(service_record.target);
376 CHECK(dns_name.has_value());
377 rdata.append(reinterpret_cast<char*>(dns_name.value().data()),
378 dns_name.value().size());
379
380 answers.push_back(BuildTestDnsRecord(answer_name, dns_protocol::kTypeSRV,
381 std::move(rdata), base::Hours(5)));
382 }
383
384 return BuildTestDnsResponse(std::move(name), dns_protocol::kTypeSRV, answers);
385 }
386
Result(ResultType type,absl::optional<DnsResponse> response,absl::optional<int> net_error)387 MockDnsClientRule::Result::Result(ResultType type,
388 absl::optional<DnsResponse> response,
389 absl::optional<int> net_error)
390 : type(type), response(std::move(response)), net_error(net_error) {}
391
Result(DnsResponse response)392 MockDnsClientRule::Result::Result(DnsResponse response)
393 : type(ResultType::kOk),
394 response(std::move(response)),
395 net_error(absl::nullopt) {}
396
397 MockDnsClientRule::Result::Result(Result&&) = default;
398
399 MockDnsClientRule::Result& MockDnsClientRule::Result::operator=(Result&&) =
400 default;
401
402 MockDnsClientRule::Result::~Result() = default;
403
MockDnsClientRule(const std::string & prefix,uint16_t qtype,bool secure,Result result,bool delay,URLRequestContext * context)404 MockDnsClientRule::MockDnsClientRule(const std::string& prefix,
405 uint16_t qtype,
406 bool secure,
407 Result result,
408 bool delay,
409 URLRequestContext* context)
410 : result(std::move(result)),
411 prefix(prefix),
412 qtype(qtype),
413 secure(secure),
414 delay(delay),
415 context(context) {}
416
417 MockDnsClientRule::MockDnsClientRule(MockDnsClientRule&& rule) = default;
418
419 // A DnsTransaction which uses MockDnsClientRuleList to determine the response.
420 class MockDnsTransactionFactory::MockTransaction
421 : public DnsTransaction,
422 public base::SupportsWeakPtr<MockTransaction> {
423 public:
MockTransaction(const MockDnsClientRuleList & rules,std::string hostname,uint16_t qtype,bool secure,bool force_doh_server_available,SecureDnsMode secure_dns_mode,ResolveContext * resolve_context,bool fast_timeout)424 MockTransaction(const MockDnsClientRuleList& rules,
425 std::string hostname,
426 uint16_t qtype,
427 bool secure,
428 bool force_doh_server_available,
429 SecureDnsMode secure_dns_mode,
430 ResolveContext* resolve_context,
431 bool fast_timeout)
432 : hostname_(std::move(hostname)), qtype_(qtype) {
433 // Do not allow matching any rules if transaction is secure and no DoH
434 // servers are available.
435 if (!secure || force_doh_server_available ||
436 resolve_context->NumAvailableDohServers(
437 resolve_context->current_session_for_testing()) > 0) {
438 // Find the relevant rule which matches |qtype|, |secure|, prefix of
439 // |hostname_|, and |url_request_context| (iff the rule context is not
440 // null).
441 for (const auto& rule : rules) {
442 const std::string& prefix = rule.prefix;
443 if ((rule.qtype == qtype) && (rule.secure == secure) &&
444 (hostname_.size() >= prefix.size()) &&
445 (hostname_.compare(0, prefix.size(), prefix) == 0) &&
446 (!rule.context ||
447 rule.context == resolve_context->url_request_context())) {
448 const MockDnsClientRule::Result* result = &rule.result;
449 result_ = MockDnsClientRule::Result(result->type);
450 result_.net_error = result->net_error;
451 delayed_ = rule.delay;
452
453 // Generate a DnsResponse when not provided with the rule.
454 std::vector<DnsResourceRecord> authority_records;
455 absl::optional<std::vector<uint8_t>> dns_name =
456 dns_names_util::DottedNameToNetwork(hostname_);
457 CHECK(dns_name.has_value());
458 absl::optional<DnsQuery> query(absl::in_place, /*id=*/22,
459 dns_name.value(), qtype_);
460 switch (result->type) {
461 case MockDnsClientRule::ResultType::kNoDomain:
462 case MockDnsClientRule::ResultType::kEmpty:
463 DCHECK(!result->response); // Not expected to be provided.
464 authority_records = {BuildTestDnsRecord(
465 hostname_, dns_protocol::kTypeSOA, "fake rdata")};
466 result_.response = DnsResponse(
467 22 /* id */, false /* is_authoritative */,
468 std::vector<DnsResourceRecord>() /* answers */,
469 authority_records,
470 std::vector<DnsResourceRecord>() /* additional_records */,
471 query,
472 result->type == MockDnsClientRule::ResultType::kNoDomain
473 ? dns_protocol::kRcodeNXDOMAIN
474 : 0);
475 break;
476 case MockDnsClientRule::ResultType::kFail:
477 if (result->response)
478 SetResponse(result);
479 break;
480 case MockDnsClientRule::ResultType::kTimeout:
481 DCHECK(!result->response); // Not expected to be provided.
482 break;
483 case MockDnsClientRule::ResultType::kSlow:
484 if (!fast_timeout)
485 SetResponse(result);
486 break;
487 case MockDnsClientRule::ResultType::kOk:
488 SetResponse(result);
489 break;
490 case MockDnsClientRule::ResultType::kMalformed:
491 DCHECK(!result->response); // Not expected to be provided.
492 result_.response = CreateMalformedResponse(hostname_, qtype_);
493 break;
494 case MockDnsClientRule::ResultType::kUnexpected:
495 if (!delayed_) {
496 // Assume a delayed kUnexpected transaction is only an issue if
497 // allowed to complete.
498 ADD_FAILURE()
499 << "Unexpected DNS transaction created for hostname "
500 << hostname_;
501 }
502 break;
503 }
504
505 break;
506 }
507 }
508 }
509 }
510
GetHostname() const511 const std::string& GetHostname() const override { return hostname_; }
512
GetType() const513 uint16_t GetType() const override { return qtype_; }
514
Start(ResponseCallback callback)515 void Start(ResponseCallback callback) override {
516 CHECK(!callback.is_null());
517 CHECK(callback_.is_null());
518 EXPECT_FALSE(started_);
519
520 callback_ = std::move(callback);
521 started_ = true;
522 if (delayed_)
523 return;
524 // Using WeakPtr to cleanly cancel when transaction is destroyed.
525 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
526 FROM_HERE, base::BindOnce(&MockTransaction::Finish, AsWeakPtr()));
527 }
528
FinishDelayedTransaction()529 void FinishDelayedTransaction() {
530 EXPECT_TRUE(delayed_);
531 delayed_ = false;
532 Finish();
533 }
534
delayed() const535 bool delayed() const { return delayed_; }
536
537 private:
SetResponse(const MockDnsClientRule::Result * result)538 void SetResponse(const MockDnsClientRule::Result* result) {
539 if (result->response) {
540 // Copy response in case |result| is destroyed before the transaction
541 // completes.
542 auto buffer_copy =
543 base::MakeRefCounted<IOBuffer>(result->response->io_buffer_size());
544 memcpy(buffer_copy->data(), result->response->io_buffer()->data(),
545 result->response->io_buffer_size());
546 result_.response = DnsResponse(std::move(buffer_copy),
547 result->response->io_buffer_size());
548 CHECK(result_.response->InitParseWithoutQuery(
549 result->response->io_buffer_size()));
550 } else {
551 // Generated response only available for address types.
552 DCHECK(qtype_ == dns_protocol::kTypeA ||
553 qtype_ == dns_protocol::kTypeAAAA);
554 result_.response = BuildTestDnsAddressResponse(
555 hostname_, qtype_ == dns_protocol::kTypeA
556 ? IPAddress::IPv4Localhost()
557 : IPAddress::IPv6Localhost());
558 }
559 }
560
Finish()561 void Finish() {
562 switch (result_.type) {
563 case MockDnsClientRule::ResultType::kNoDomain:
564 case MockDnsClientRule::ResultType::kFail: {
565 int error = result_.net_error.value_or(ERR_NAME_NOT_RESOLVED);
566 DCHECK_NE(error, OK);
567 std::move(callback_).Run(error, base::OptionalToPtr(result_.response));
568 break;
569 }
570 case MockDnsClientRule::ResultType::kEmpty:
571 case MockDnsClientRule::ResultType::kOk:
572 case MockDnsClientRule::ResultType::kMalformed:
573 DCHECK(!result_.net_error.has_value());
574 std::move(callback_).Run(OK, base::OptionalToPtr(result_.response));
575 break;
576 case MockDnsClientRule::ResultType::kTimeout:
577 DCHECK(!result_.net_error.has_value());
578 std::move(callback_).Run(ERR_DNS_TIMED_OUT, /*response=*/nullptr);
579 break;
580 case MockDnsClientRule::ResultType::kSlow:
581 if (result_.response) {
582 std::move(callback_).Run(
583 result_.net_error.value_or(OK),
584 result_.response ? &result_.response.value() : nullptr);
585 } else {
586 DCHECK(!result_.net_error.has_value());
587 std::move(callback_).Run(ERR_DNS_TIMED_OUT, /*response=*/nullptr);
588 }
589 break;
590 case MockDnsClientRule::ResultType::kUnexpected:
591 ADD_FAILURE() << "Unexpected DNS transaction completed for hostname "
592 << hostname_;
593 break;
594 }
595 }
596
SetRequestPriority(RequestPriority priority)597 void SetRequestPriority(RequestPriority priority) override {}
598
599 MockDnsClientRule::Result result_{MockDnsClientRule::ResultType::kFail};
600 const std::string hostname_;
601 const uint16_t qtype_;
602 ResponseCallback callback_;
603 bool started_ = false;
604 bool delayed_ = false;
605 };
606
607 class MockDnsTransactionFactory::MockDohProbeRunner : public DnsProbeRunner {
608 public:
MockDohProbeRunner(base::WeakPtr<MockDnsTransactionFactory> factory)609 explicit MockDohProbeRunner(base::WeakPtr<MockDnsTransactionFactory> factory)
610 : factory_(std::move(factory)) {}
611
~MockDohProbeRunner()612 ~MockDohProbeRunner() override {
613 if (factory_)
614 factory_->running_doh_probe_runners_.erase(this);
615 }
616
Start(bool network_change)617 void Start(bool network_change) override {
618 DCHECK(factory_);
619 factory_->running_doh_probe_runners_.insert(this);
620 }
621
GetDelayUntilNextProbeForTest(size_t doh_server_index) const622 base::TimeDelta GetDelayUntilNextProbeForTest(
623 size_t doh_server_index) const override {
624 NOTREACHED();
625 return base::TimeDelta();
626 }
627
628 private:
629 base::WeakPtr<MockDnsTransactionFactory> factory_;
630 };
631
MockDnsTransactionFactory(MockDnsClientRuleList rules)632 MockDnsTransactionFactory::MockDnsTransactionFactory(
633 MockDnsClientRuleList rules)
634 : rules_(std::move(rules)) {}
635
636 MockDnsTransactionFactory::~MockDnsTransactionFactory() = default;
637
CreateTransaction(std::string hostname,uint16_t qtype,const NetLogWithSource &,bool secure,SecureDnsMode secure_dns_mode,ResolveContext * resolve_context,bool fast_timeout)638 std::unique_ptr<DnsTransaction> MockDnsTransactionFactory::CreateTransaction(
639 std::string hostname,
640 uint16_t qtype,
641 const NetLogWithSource&,
642 bool secure,
643 SecureDnsMode secure_dns_mode,
644 ResolveContext* resolve_context,
645 bool fast_timeout) {
646 std::unique_ptr<MockTransaction> transaction =
647 std::make_unique<MockTransaction>(rules_, std::move(hostname), qtype,
648 secure, force_doh_server_available_,
649 secure_dns_mode, resolve_context,
650 fast_timeout);
651 if (transaction->delayed())
652 delayed_transactions_.push_back(transaction->AsWeakPtr());
653 return transaction;
654 }
655
CreateDohProbeRunner(ResolveContext * resolve_context)656 std::unique_ptr<DnsProbeRunner> MockDnsTransactionFactory::CreateDohProbeRunner(
657 ResolveContext* resolve_context) {
658 return std::make_unique<MockDohProbeRunner>(weak_ptr_factory_.GetWeakPtr());
659 }
660
AddEDNSOption(std::unique_ptr<OptRecordRdata::Opt> opt)661 void MockDnsTransactionFactory::AddEDNSOption(
662 std::unique_ptr<OptRecordRdata::Opt> opt) {}
663
GetSecureDnsModeForTest()664 SecureDnsMode MockDnsTransactionFactory::GetSecureDnsModeForTest() {
665 return SecureDnsMode::kAutomatic;
666 }
667
CompleteDelayedTransactions()668 void MockDnsTransactionFactory::CompleteDelayedTransactions() {
669 DelayedTransactionList old_delayed_transactions;
670 old_delayed_transactions.swap(delayed_transactions_);
671 for (auto& old_delayed_transaction : old_delayed_transactions) {
672 if (old_delayed_transaction.get())
673 old_delayed_transaction->FinishDelayedTransaction();
674 }
675 }
676
CompleteOneDelayedTransactionOfType(DnsQueryType type)677 bool MockDnsTransactionFactory::CompleteOneDelayedTransactionOfType(
678 DnsQueryType type) {
679 for (base::WeakPtr<MockTransaction>& t : delayed_transactions_) {
680 if (t && t->GetType() == DnsQueryTypeToQtype(type)) {
681 t->FinishDelayedTransaction();
682 t.reset();
683 return true;
684 }
685 }
686 return false;
687 }
688
MockDnsClient(DnsConfig config,MockDnsClientRuleList rules)689 MockDnsClient::MockDnsClient(DnsConfig config, MockDnsClientRuleList rules)
690 : config_(std::move(config)),
691 factory_(std::make_unique<MockDnsTransactionFactory>(std::move(rules))),
692 address_sorter_(std::make_unique<MockAddressSorter>()) {
693 effective_config_ = BuildEffectiveConfig();
694 session_ = BuildSession();
695 }
696
697 MockDnsClient::~MockDnsClient() = default;
698
CanUseSecureDnsTransactions() const699 bool MockDnsClient::CanUseSecureDnsTransactions() const {
700 const DnsConfig* config = GetEffectiveConfig();
701 return config && config->IsValid() && !config->doh_config.servers().empty();
702 }
703
CanUseInsecureDnsTransactions() const704 bool MockDnsClient::CanUseInsecureDnsTransactions() const {
705 const DnsConfig* config = GetEffectiveConfig();
706 return config && config->IsValid() && insecure_enabled_ &&
707 !config->dns_over_tls_active;
708 }
709
CanQueryAdditionalTypesViaInsecureDns() const710 bool MockDnsClient::CanQueryAdditionalTypesViaInsecureDns() const {
711 DCHECK(CanUseInsecureDnsTransactions());
712 return additional_types_enabled_;
713 }
714
SetInsecureEnabled(bool enabled,bool additional_types_enabled)715 void MockDnsClient::SetInsecureEnabled(bool enabled,
716 bool additional_types_enabled) {
717 insecure_enabled_ = enabled;
718 additional_types_enabled_ = additional_types_enabled;
719 }
720
FallbackFromSecureTransactionPreferred(ResolveContext * context) const721 bool MockDnsClient::FallbackFromSecureTransactionPreferred(
722 ResolveContext* context) const {
723 bool doh_server_available =
724 force_doh_server_available_ ||
725 context->NumAvailableDohServers(session_.get()) > 0;
726 return !CanUseSecureDnsTransactions() || !doh_server_available;
727 }
728
FallbackFromInsecureTransactionPreferred() const729 bool MockDnsClient::FallbackFromInsecureTransactionPreferred() const {
730 return !CanUseInsecureDnsTransactions() ||
731 fallback_failures_ >= max_fallback_failures_;
732 }
733
SetSystemConfig(absl::optional<DnsConfig> system_config)734 bool MockDnsClient::SetSystemConfig(absl::optional<DnsConfig> system_config) {
735 if (ignore_system_config_changes_)
736 return false;
737
738 absl::optional<DnsConfig> before = effective_config_;
739 config_ = std::move(system_config);
740 effective_config_ = BuildEffectiveConfig();
741 session_ = BuildSession();
742 return before != effective_config_;
743 }
744
SetConfigOverrides(DnsConfigOverrides config_overrides)745 bool MockDnsClient::SetConfigOverrides(DnsConfigOverrides config_overrides) {
746 absl::optional<DnsConfig> before = effective_config_;
747 overrides_ = std::move(config_overrides);
748 effective_config_ = BuildEffectiveConfig();
749 session_ = BuildSession();
750 return before != effective_config_;
751 }
752
ReplaceCurrentSession()753 void MockDnsClient::ReplaceCurrentSession() {
754 // Noop if no current effective config.
755 session_ = BuildSession();
756 }
757
GetCurrentSession()758 DnsSession* MockDnsClient::GetCurrentSession() {
759 return session_.get();
760 }
761
GetEffectiveConfig() const762 const DnsConfig* MockDnsClient::GetEffectiveConfig() const {
763 return effective_config_.has_value() ? &effective_config_.value() : nullptr;
764 }
765
GetDnsConfigAsValueForNetLog() const766 base::Value::Dict MockDnsClient::GetDnsConfigAsValueForNetLog() const {
767 // This is just a stub implementation that never produces a meaningful value.
768 return base::Value::Dict();
769 }
770
GetHosts() const771 const DnsHosts* MockDnsClient::GetHosts() const {
772 const DnsConfig* config = GetEffectiveConfig();
773 if (!config)
774 return nullptr;
775
776 return &config->hosts;
777 }
778
GetTransactionFactory()779 DnsTransactionFactory* MockDnsClient::GetTransactionFactory() {
780 return GetEffectiveConfig() ? factory_.get() : nullptr;
781 }
782
GetAddressSorter()783 AddressSorter* MockDnsClient::GetAddressSorter() {
784 return GetEffectiveConfig() ? address_sorter_.get() : nullptr;
785 }
786
IncrementInsecureFallbackFailures()787 void MockDnsClient::IncrementInsecureFallbackFailures() {
788 ++fallback_failures_;
789 }
790
ClearInsecureFallbackFailures()791 void MockDnsClient::ClearInsecureFallbackFailures() {
792 fallback_failures_ = 0;
793 }
794
GetSystemConfigForTesting() const795 absl::optional<DnsConfig> MockDnsClient::GetSystemConfigForTesting() const {
796 return config_;
797 }
798
GetConfigOverridesForTesting() const799 DnsConfigOverrides MockDnsClient::GetConfigOverridesForTesting() const {
800 return overrides_;
801 }
802
SetTransactionFactoryForTesting(std::unique_ptr<DnsTransactionFactory> factory)803 void MockDnsClient::SetTransactionFactoryForTesting(
804 std::unique_ptr<DnsTransactionFactory> factory) {
805 NOTREACHED();
806 }
807
GetPresetAddrs(const url::SchemeHostPort & endpoint) const808 absl::optional<std::vector<IPEndPoint>> MockDnsClient::GetPresetAddrs(
809 const url::SchemeHostPort& endpoint) const {
810 EXPECT_THAT(preset_endpoint_, testing::Optional(endpoint));
811 return preset_addrs_;
812 }
813
CompleteDelayedTransactions()814 void MockDnsClient::CompleteDelayedTransactions() {
815 factory_->CompleteDelayedTransactions();
816 }
817
CompleteOneDelayedTransactionOfType(DnsQueryType type)818 bool MockDnsClient::CompleteOneDelayedTransactionOfType(DnsQueryType type) {
819 return factory_->CompleteOneDelayedTransactionOfType(type);
820 }
821
SetForceDohServerAvailable(bool available)822 void MockDnsClient::SetForceDohServerAvailable(bool available) {
823 force_doh_server_available_ = available;
824 factory_->set_force_doh_server_available(available);
825 }
826
BuildEffectiveConfig()827 absl::optional<DnsConfig> MockDnsClient::BuildEffectiveConfig() {
828 if (overrides_.OverridesEverything())
829 return overrides_.ApplyOverrides(DnsConfig());
830 if (!config_ || !config_.value().IsValid())
831 return absl::nullopt;
832
833 return overrides_.ApplyOverrides(config_.value());
834 }
835
BuildSession()836 scoped_refptr<DnsSession> MockDnsClient::BuildSession() {
837 if (!effective_config_)
838 return nullptr;
839
840 // Session not expected to be used for anything that will actually require
841 // random numbers.
842 auto null_random_callback =
843 base::BindRepeating([](int, int) -> int { base::ImmediateCrash(); });
844
845 return base::MakeRefCounted<DnsSession>(
846 effective_config_.value(), null_random_callback, nullptr /* net_log */);
847 }
848
849 } // namespace net
850