1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/dns/dns_test_util.h"
6 
7 #include <cstdint>
8 #include <string>
9 #include <utility>
10 #include <vector>
11 
12 #include "base/big_endian.h"
13 #include "base/check.h"
14 #include "base/functional/bind.h"
15 #include "base/location.h"
16 #include "base/numerics/safe_conversions.h"
17 #include "base/ranges/algorithm.h"
18 #include "base/strings/strcat.h"
19 #include "base/sys_byteorder.h"
20 #include "base/task/single_thread_task_runner.h"
21 #include "base/time/time.h"
22 #include "base/types/optional_util.h"
23 #include "net/base/io_buffer.h"
24 #include "net/base/ip_address.h"
25 #include "net/base/ip_endpoint.h"
26 #include "net/base/net_errors.h"
27 #include "net/dns/address_sorter.h"
28 #include "net/dns/dns_hosts.h"
29 #include "net/dns/dns_names_util.h"
30 #include "net/dns/dns_query.h"
31 #include "net/dns/dns_session.h"
32 #include "net/dns/public/dns_over_https_server_config.h"
33 #include "net/dns/resolve_context.h"
34 #include "testing/gmock/include/gmock/gmock-matchers.h"
35 #include "testing/gtest/include/gtest/gtest.h"
36 #include "third_party/abseil-cpp/absl/types/optional.h"
37 #include "url/scheme_host_port.h"
38 
39 namespace net {
40 namespace {
41 
42 const uint8_t kMalformedResponseHeader[] = {
43     // Header
44     0x00, 0x14,  // Arbitrary ID
45     0x81, 0x80,  // Standard query response, RA, no error
46     0x00, 0x01,  // 1 question
47     0x00, 0x01,  // 1 RR (answers)
48     0x00, 0x00,  // 0 authority RRs
49     0x00, 0x00,  // 0 additional RRs
50 };
51 
52 // Create a response containing a valid question (as would normally be validated
53 // in DnsTransaction) but completely missing a header-declared answer.
CreateMalformedResponse(std::string hostname,uint16_t type)54 DnsResponse CreateMalformedResponse(std::string hostname, uint16_t type) {
55   absl::optional<std::vector<uint8_t>> dns_name =
56       dns_names_util::DottedNameToNetwork(hostname);
57   CHECK(dns_name.has_value());
58   DnsQuery query(/*id=*/0x14, dns_name.value(), type);
59 
60   // Build response to simulate the barebones validation DnsResponse applies to
61   // responses received from the network.
62   auto buffer = base::MakeRefCounted<IOBufferWithSize>(
63       sizeof(kMalformedResponseHeader) + query.question().size());
64   memcpy(buffer->data(), kMalformedResponseHeader,
65          sizeof(kMalformedResponseHeader));
66   memcpy(buffer->data() + sizeof(kMalformedResponseHeader),
67          query.question().data(), query.question().size());
68 
69   DnsResponse response(buffer, buffer->size());
70   CHECK(response.InitParseWithoutQuery(buffer->size()));
71 
72   return response;
73 }
74 
75 class MockAddressSorter : public AddressSorter {
76  public:
77   ~MockAddressSorter() override = default;
Sort(const std::vector<IPEndPoint> & endpoints,CallbackType callback) const78   void Sort(const std::vector<IPEndPoint>& endpoints,
79             CallbackType callback) const override {
80     // Do nothing.
81     std::move(callback).Run(true, endpoints);
82   }
83 };
84 
85 }  // namespace
86 
BuildTestDnsRecord(std::string name,uint16_t type,std::string rdata,base::TimeDelta ttl)87 DnsResourceRecord BuildTestDnsRecord(std::string name,
88                                      uint16_t type,
89                                      std::string rdata,
90                                      base::TimeDelta ttl) {
91   DCHECK(!name.empty());
92 
93   DnsResourceRecord record;
94   record.name = std::move(name);
95   record.type = type;
96   record.klass = dns_protocol::kClassIN;
97   record.ttl = ttl.InSeconds();
98 
99   if (!rdata.empty())
100     record.SetOwnedRdata(std::move(rdata));
101 
102   return record;
103 }
104 
BuildTestCnameRecord(std::string name,base::StringPiece canonical_name,base::TimeDelta ttl)105 DnsResourceRecord BuildTestCnameRecord(std::string name,
106                                        base::StringPiece canonical_name,
107                                        base::TimeDelta ttl) {
108   DCHECK(!name.empty());
109   DCHECK(!canonical_name.empty());
110 
111   absl::optional<std::vector<uint8_t>> rdata =
112       dns_names_util::DottedNameToNetwork(canonical_name);
113   CHECK(rdata.has_value());
114 
115   return BuildTestDnsRecord(
116       std::move(name), dns_protocol::kTypeCNAME,
117       std::string(reinterpret_cast<char*>(rdata.value().data()),
118                   rdata.value().size()),
119       ttl);
120 }
121 
BuildTestAddressRecord(std::string name,const IPAddress & ip,base::TimeDelta ttl)122 DnsResourceRecord BuildTestAddressRecord(std::string name,
123                                          const IPAddress& ip,
124                                          base::TimeDelta ttl) {
125   DCHECK(!name.empty());
126   DCHECK(ip.IsValid());
127 
128   return BuildTestDnsRecord(
129       std::move(name),
130       ip.IsIPv4() ? dns_protocol::kTypeA : dns_protocol::kTypeAAAA,
131       net::IPAddressToPackedString(ip), ttl);
132 }
133 
BuildTestTextRecord(std::string name,std::vector<std::string> text_strings,base::TimeDelta ttl)134 DnsResourceRecord BuildTestTextRecord(std::string name,
135                                       std::vector<std::string> text_strings,
136                                       base::TimeDelta ttl) {
137   DCHECK(!text_strings.empty());
138 
139   std::string rdata;
140   for (const std::string& text_string : text_strings) {
141     DCHECK(!text_string.empty());
142 
143     rdata += base::checked_cast<unsigned char>(text_string.size());
144     rdata += text_string;
145   }
146 
147   return BuildTestDnsRecord(std::move(name), dns_protocol::kTypeTXT,
148                             std::move(rdata), ttl);
149 }
150 
BuildTestHttpsAliasRecord(std::string name,base::StringPiece alias_name,base::TimeDelta ttl)151 DnsResourceRecord BuildTestHttpsAliasRecord(std::string name,
152                                             base::StringPiece alias_name,
153                                             base::TimeDelta ttl) {
154   DCHECK(!name.empty());
155 
156   std::string rdata("\000\000", 2);
157 
158   absl::optional<std::vector<uint8_t>> alias_domain =
159       dns_names_util::DottedNameToNetwork(alias_name);
160   CHECK(alias_domain.has_value());
161   rdata.append(reinterpret_cast<char*>(alias_domain.value().data()),
162                alias_domain.value().size());
163 
164   return BuildTestDnsRecord(std::move(name), dns_protocol::kTypeHttps,
165                             std::move(rdata), ttl);
166 }
167 
BuildTestHttpsServiceAlpnParam(const std::vector<std::string> & alpns)168 std::pair<uint16_t, std::string> BuildTestHttpsServiceAlpnParam(
169     const std::vector<std::string>& alpns) {
170   std::string param_value;
171 
172   for (const std::string& alpn : alpns) {
173     CHECK(!alpn.empty());
174     param_value.append(
175         1, static_cast<char>(base::checked_cast<uint8_t>(alpn.size())));
176     param_value.append(alpn);
177   }
178 
179   return std::make_pair(dns_protocol::kHttpsServiceParamKeyAlpn,
180                         std::move(param_value));
181 }
182 
BuildTestHttpsServiceEchConfigParam(base::span<const uint8_t> ech_config_list)183 std::pair<uint16_t, std::string> BuildTestHttpsServiceEchConfigParam(
184     base::span<const uint8_t> ech_config_list) {
185   return std::make_pair(
186       dns_protocol::kHttpsServiceParamKeyEchConfig,
187       std::string(reinterpret_cast<const char*>(ech_config_list.data()),
188                   ech_config_list.size()));
189 }
190 
BuildTestHttpsServiceMandatoryParam(std::vector<uint16_t> param_key_list)191 std::pair<uint16_t, std::string> BuildTestHttpsServiceMandatoryParam(
192     std::vector<uint16_t> param_key_list) {
193   base::ranges::sort(param_key_list);
194 
195   std::string value;
196   for (uint16_t param_key : param_key_list) {
197     char num_buffer[2];
198     base::WriteBigEndian(num_buffer, param_key);
199     value.append(num_buffer, 2);
200   }
201 
202   return std::make_pair(dns_protocol::kHttpsServiceParamKeyMandatory,
203                         std::move(value));
204 }
205 
BuildTestHttpsServicePortParam(uint16_t port)206 std::pair<uint16_t, std::string> BuildTestHttpsServicePortParam(uint16_t port) {
207   char buffer[2];
208   base::WriteBigEndian(buffer, port);
209 
210   return std::make_pair(dns_protocol::kHttpsServiceParamKeyPort,
211                         std::string(buffer, 2));
212 }
213 
BuildTestHttpsServiceRecord(std::string name,uint16_t priority,base::StringPiece service_name,const std::map<uint16_t,std::string> & params,base::TimeDelta ttl)214 DnsResourceRecord BuildTestHttpsServiceRecord(
215     std::string name,
216     uint16_t priority,
217     base::StringPiece service_name,
218     const std::map<uint16_t, std::string>& params,
219     base::TimeDelta ttl) {
220   DCHECK(!name.empty());
221   DCHECK_NE(priority, 0);
222 
223   std::string rdata;
224 
225   char num_buffer[2];
226   base::WriteBigEndian(num_buffer, priority);
227   rdata.append(num_buffer, 2);
228 
229   absl::optional<std::vector<uint8_t>> service_domain;
230   if (service_name == ".") {
231     // HTTPS records have special behavior for `service_name == "."` (that it
232     // will be treated as if the service name is the same as the record owner
233     // name), so allow such inputs despite normally being disallowed for
234     // Chrome-encoded DNS names.
235     service_domain = std::vector<uint8_t>{0};
236   } else {
237     service_domain = dns_names_util::DottedNameToNetwork(service_name);
238   }
239   CHECK(service_domain.has_value());
240   rdata.append(reinterpret_cast<char*>(service_domain.value().data()),
241                service_domain.value().size());
242 
243   for (auto& param : params) {
244     base::WriteBigEndian(num_buffer, param.first);
245     rdata.append(num_buffer, 2);
246 
247     base::WriteBigEndian(num_buffer,
248                          base::checked_cast<uint16_t>(param.second.size()));
249     rdata.append(num_buffer, 2);
250 
251     rdata.append(param.second);
252   }
253 
254   return BuildTestDnsRecord(std::move(name), dns_protocol::kTypeHttps,
255                             std::move(rdata), ttl);
256 }
257 
BuildTestDnsResponse(std::string name,uint16_t type,const std::vector<DnsResourceRecord> & answers,const std::vector<DnsResourceRecord> & authority,const std::vector<DnsResourceRecord> & additional,uint8_t rcode)258 DnsResponse BuildTestDnsResponse(
259     std::string name,
260     uint16_t type,
261     const std::vector<DnsResourceRecord>& answers,
262     const std::vector<DnsResourceRecord>& authority,
263     const std::vector<DnsResourceRecord>& additional,
264     uint8_t rcode) {
265   DCHECK(!name.empty());
266 
267   absl::optional<std::vector<uint8_t>> dns_name =
268       dns_names_util::DottedNameToNetwork(name);
269   CHECK(dns_name.has_value());
270 
271   absl::optional<DnsQuery> query(absl::in_place, 0, dns_name.value(), type);
272   return DnsResponse(0, true /* is_authoritative */, answers,
273                      authority /* authority_records */,
274                      additional /* additional_records */, query, rcode,
275                      false /* validate_records */);
276 }
277 
BuildTestDnsAddressResponse(std::string name,const IPAddress & ip,std::string answer_name)278 DnsResponse BuildTestDnsAddressResponse(std::string name,
279                                         const IPAddress& ip,
280                                         std::string answer_name) {
281   DCHECK(ip.IsValid());
282 
283   if (answer_name.empty())
284     answer_name = name;
285 
286   std::vector<DnsResourceRecord> answers = {
287       BuildTestAddressRecord(std::move(answer_name), ip)};
288 
289   return BuildTestDnsResponse(
290       std::move(name),
291       ip.IsIPv4() ? dns_protocol::kTypeA : dns_protocol::kTypeAAAA, answers);
292 }
293 
BuildTestDnsAddressResponseWithCname(std::string name,const IPAddress & ip,std::string cannonname,std::string answer_name)294 DnsResponse BuildTestDnsAddressResponseWithCname(std::string name,
295                                                  const IPAddress& ip,
296                                                  std::string cannonname,
297                                                  std::string answer_name) {
298   DCHECK(ip.IsValid());
299   DCHECK(!cannonname.empty());
300 
301   if (answer_name.empty())
302     answer_name = name;
303 
304   absl::optional<std::vector<uint8_t>> cname_rdata =
305       dns_names_util::DottedNameToNetwork(cannonname);
306   CHECK(cname_rdata.has_value());
307 
308   std::vector<DnsResourceRecord> answers = {
309       BuildTestDnsRecord(
310           std::move(answer_name), dns_protocol::kTypeCNAME,
311           std::string(reinterpret_cast<char*>(cname_rdata.value().data()),
312                       cname_rdata.value().size())),
313       BuildTestAddressRecord(std::move(cannonname), ip)};
314 
315   return BuildTestDnsResponse(
316       std::move(name),
317       ip.IsIPv4() ? dns_protocol::kTypeA : dns_protocol::kTypeAAAA, answers);
318 }
319 
BuildTestDnsTextResponse(std::string name,std::vector<std::vector<std::string>> text_records,std::string answer_name)320 DnsResponse BuildTestDnsTextResponse(
321     std::string name,
322     std::vector<std::vector<std::string>> text_records,
323     std::string answer_name) {
324   if (answer_name.empty())
325     answer_name = name;
326 
327   std::vector<DnsResourceRecord> answers;
328   for (std::vector<std::string>& text_record : text_records) {
329     answers.push_back(BuildTestTextRecord(answer_name, std::move(text_record)));
330   }
331 
332   return BuildTestDnsResponse(std::move(name), dns_protocol::kTypeTXT, answers);
333 }
334 
BuildTestDnsPointerResponse(std::string name,std::vector<std::string> pointer_names,std::string answer_name)335 DnsResponse BuildTestDnsPointerResponse(std::string name,
336                                         std::vector<std::string> pointer_names,
337                                         std::string answer_name) {
338   if (answer_name.empty())
339     answer_name = name;
340 
341   std::vector<DnsResourceRecord> answers;
342   for (std::string& pointer_name : pointer_names) {
343     absl::optional<std::vector<uint8_t>> rdata =
344         dns_names_util::DottedNameToNetwork(pointer_name);
345     CHECK(rdata.has_value());
346 
347     answers.push_back(BuildTestDnsRecord(
348         answer_name, dns_protocol::kTypePTR,
349         std::string(reinterpret_cast<char*>(rdata.value().data()),
350                     rdata.value().size())));
351   }
352 
353   return BuildTestDnsResponse(std::move(name), dns_protocol::kTypePTR, answers);
354 }
355 
BuildTestDnsServiceResponse(std::string name,std::vector<TestServiceRecord> service_records,std::string answer_name)356 DnsResponse BuildTestDnsServiceResponse(
357     std::string name,
358     std::vector<TestServiceRecord> service_records,
359     std::string answer_name) {
360   if (answer_name.empty())
361     answer_name = name;
362 
363   std::vector<DnsResourceRecord> answers;
364   for (TestServiceRecord& service_record : service_records) {
365     std::string rdata;
366     char num_buffer[2];
367     base::WriteBigEndian(num_buffer, service_record.priority);
368     rdata.append(num_buffer, 2);
369     base::WriteBigEndian(num_buffer, service_record.weight);
370     rdata.append(num_buffer, 2);
371     base::WriteBigEndian(num_buffer, service_record.port);
372     rdata.append(num_buffer, 2);
373 
374     absl::optional<std::vector<uint8_t>> dns_name =
375         dns_names_util::DottedNameToNetwork(service_record.target);
376     CHECK(dns_name.has_value());
377     rdata.append(reinterpret_cast<char*>(dns_name.value().data()),
378                  dns_name.value().size());
379 
380     answers.push_back(BuildTestDnsRecord(answer_name, dns_protocol::kTypeSRV,
381                                          std::move(rdata), base::Hours(5)));
382   }
383 
384   return BuildTestDnsResponse(std::move(name), dns_protocol::kTypeSRV, answers);
385 }
386 
Result(ResultType type,absl::optional<DnsResponse> response,absl::optional<int> net_error)387 MockDnsClientRule::Result::Result(ResultType type,
388                                   absl::optional<DnsResponse> response,
389                                   absl::optional<int> net_error)
390     : type(type), response(std::move(response)), net_error(net_error) {}
391 
Result(DnsResponse response)392 MockDnsClientRule::Result::Result(DnsResponse response)
393     : type(ResultType::kOk),
394       response(std::move(response)),
395       net_error(absl::nullopt) {}
396 
397 MockDnsClientRule::Result::Result(Result&&) = default;
398 
399 MockDnsClientRule::Result& MockDnsClientRule::Result::operator=(Result&&) =
400     default;
401 
402 MockDnsClientRule::Result::~Result() = default;
403 
MockDnsClientRule(const std::string & prefix,uint16_t qtype,bool secure,Result result,bool delay,URLRequestContext * context)404 MockDnsClientRule::MockDnsClientRule(const std::string& prefix,
405                                      uint16_t qtype,
406                                      bool secure,
407                                      Result result,
408                                      bool delay,
409                                      URLRequestContext* context)
410     : result(std::move(result)),
411       prefix(prefix),
412       qtype(qtype),
413       secure(secure),
414       delay(delay),
415       context(context) {}
416 
417 MockDnsClientRule::MockDnsClientRule(MockDnsClientRule&& rule) = default;
418 
419 // A DnsTransaction which uses MockDnsClientRuleList to determine the response.
420 class MockDnsTransactionFactory::MockTransaction
421     : public DnsTransaction,
422       public base::SupportsWeakPtr<MockTransaction> {
423  public:
MockTransaction(const MockDnsClientRuleList & rules,std::string hostname,uint16_t qtype,bool secure,bool force_doh_server_available,SecureDnsMode secure_dns_mode,ResolveContext * resolve_context,bool fast_timeout)424   MockTransaction(const MockDnsClientRuleList& rules,
425                   std::string hostname,
426                   uint16_t qtype,
427                   bool secure,
428                   bool force_doh_server_available,
429                   SecureDnsMode secure_dns_mode,
430                   ResolveContext* resolve_context,
431                   bool fast_timeout)
432       : hostname_(std::move(hostname)), qtype_(qtype) {
433     // Do not allow matching any rules if transaction is secure and no DoH
434     // servers are available.
435     if (!secure || force_doh_server_available ||
436         resolve_context->NumAvailableDohServers(
437             resolve_context->current_session_for_testing()) > 0) {
438       // Find the relevant rule which matches |qtype|, |secure|, prefix of
439       // |hostname_|, and |url_request_context| (iff the rule context is not
440       // null).
441       for (const auto& rule : rules) {
442         const std::string& prefix = rule.prefix;
443         if ((rule.qtype == qtype) && (rule.secure == secure) &&
444             (hostname_.size() >= prefix.size()) &&
445             (hostname_.compare(0, prefix.size(), prefix) == 0) &&
446             (!rule.context ||
447              rule.context == resolve_context->url_request_context())) {
448           const MockDnsClientRule::Result* result = &rule.result;
449           result_ = MockDnsClientRule::Result(result->type);
450           result_.net_error = result->net_error;
451           delayed_ = rule.delay;
452 
453           // Generate a DnsResponse when not provided with the rule.
454           std::vector<DnsResourceRecord> authority_records;
455           absl::optional<std::vector<uint8_t>> dns_name =
456               dns_names_util::DottedNameToNetwork(hostname_);
457           CHECK(dns_name.has_value());
458           absl::optional<DnsQuery> query(absl::in_place, /*id=*/22,
459                                          dns_name.value(), qtype_);
460           switch (result->type) {
461             case MockDnsClientRule::ResultType::kNoDomain:
462             case MockDnsClientRule::ResultType::kEmpty:
463               DCHECK(!result->response);  // Not expected to be provided.
464               authority_records = {BuildTestDnsRecord(
465                   hostname_, dns_protocol::kTypeSOA, "fake rdata")};
466               result_.response = DnsResponse(
467                   22 /* id */, false /* is_authoritative */,
468                   std::vector<DnsResourceRecord>() /* answers */,
469                   authority_records,
470                   std::vector<DnsResourceRecord>() /* additional_records */,
471                   query,
472                   result->type == MockDnsClientRule::ResultType::kNoDomain
473                       ? dns_protocol::kRcodeNXDOMAIN
474                       : 0);
475               break;
476             case MockDnsClientRule::ResultType::kFail:
477               if (result->response)
478                 SetResponse(result);
479               break;
480             case MockDnsClientRule::ResultType::kTimeout:
481               DCHECK(!result->response);  // Not expected to be provided.
482               break;
483             case MockDnsClientRule::ResultType::kSlow:
484               if (!fast_timeout)
485                 SetResponse(result);
486               break;
487             case MockDnsClientRule::ResultType::kOk:
488               SetResponse(result);
489               break;
490             case MockDnsClientRule::ResultType::kMalformed:
491               DCHECK(!result->response);  // Not expected to be provided.
492               result_.response = CreateMalformedResponse(hostname_, qtype_);
493               break;
494             case MockDnsClientRule::ResultType::kUnexpected:
495               if (!delayed_) {
496                 // Assume a delayed kUnexpected transaction is only an issue if
497                 // allowed to complete.
498                 ADD_FAILURE()
499                     << "Unexpected DNS transaction created for hostname "
500                     << hostname_;
501               }
502               break;
503           }
504 
505           break;
506         }
507       }
508     }
509   }
510 
GetHostname() const511   const std::string& GetHostname() const override { return hostname_; }
512 
GetType() const513   uint16_t GetType() const override { return qtype_; }
514 
Start(ResponseCallback callback)515   void Start(ResponseCallback callback) override {
516     CHECK(!callback.is_null());
517     CHECK(callback_.is_null());
518     EXPECT_FALSE(started_);
519 
520     callback_ = std::move(callback);
521     started_ = true;
522     if (delayed_)
523       return;
524     // Using WeakPtr to cleanly cancel when transaction is destroyed.
525     base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
526         FROM_HERE, base::BindOnce(&MockTransaction::Finish, AsWeakPtr()));
527   }
528 
FinishDelayedTransaction()529   void FinishDelayedTransaction() {
530     EXPECT_TRUE(delayed_);
531     delayed_ = false;
532     Finish();
533   }
534 
delayed() const535   bool delayed() const { return delayed_; }
536 
537  private:
SetResponse(const MockDnsClientRule::Result * result)538   void SetResponse(const MockDnsClientRule::Result* result) {
539     if (result->response) {
540       // Copy response in case |result| is destroyed before the transaction
541       // completes.
542       auto buffer_copy =
543           base::MakeRefCounted<IOBuffer>(result->response->io_buffer_size());
544       memcpy(buffer_copy->data(), result->response->io_buffer()->data(),
545              result->response->io_buffer_size());
546       result_.response = DnsResponse(std::move(buffer_copy),
547                                      result->response->io_buffer_size());
548       CHECK(result_.response->InitParseWithoutQuery(
549           result->response->io_buffer_size()));
550     } else {
551       // Generated response only available for address types.
552       DCHECK(qtype_ == dns_protocol::kTypeA ||
553              qtype_ == dns_protocol::kTypeAAAA);
554       result_.response = BuildTestDnsAddressResponse(
555           hostname_, qtype_ == dns_protocol::kTypeA
556                          ? IPAddress::IPv4Localhost()
557                          : IPAddress::IPv6Localhost());
558     }
559   }
560 
Finish()561   void Finish() {
562     switch (result_.type) {
563       case MockDnsClientRule::ResultType::kNoDomain:
564       case MockDnsClientRule::ResultType::kFail: {
565         int error = result_.net_error.value_or(ERR_NAME_NOT_RESOLVED);
566         DCHECK_NE(error, OK);
567         std::move(callback_).Run(error, base::OptionalToPtr(result_.response));
568         break;
569       }
570       case MockDnsClientRule::ResultType::kEmpty:
571       case MockDnsClientRule::ResultType::kOk:
572       case MockDnsClientRule::ResultType::kMalformed:
573         DCHECK(!result_.net_error.has_value());
574         std::move(callback_).Run(OK, base::OptionalToPtr(result_.response));
575         break;
576       case MockDnsClientRule::ResultType::kTimeout:
577         DCHECK(!result_.net_error.has_value());
578         std::move(callback_).Run(ERR_DNS_TIMED_OUT, /*response=*/nullptr);
579         break;
580       case MockDnsClientRule::ResultType::kSlow:
581         if (result_.response) {
582           std::move(callback_).Run(
583               result_.net_error.value_or(OK),
584               result_.response ? &result_.response.value() : nullptr);
585         } else {
586           DCHECK(!result_.net_error.has_value());
587           std::move(callback_).Run(ERR_DNS_TIMED_OUT, /*response=*/nullptr);
588         }
589         break;
590       case MockDnsClientRule::ResultType::kUnexpected:
591         ADD_FAILURE() << "Unexpected DNS transaction completed for hostname "
592                       << hostname_;
593         break;
594     }
595   }
596 
SetRequestPriority(RequestPriority priority)597   void SetRequestPriority(RequestPriority priority) override {}
598 
599   MockDnsClientRule::Result result_{MockDnsClientRule::ResultType::kFail};
600   const std::string hostname_;
601   const uint16_t qtype_;
602   ResponseCallback callback_;
603   bool started_ = false;
604   bool delayed_ = false;
605 };
606 
607 class MockDnsTransactionFactory::MockDohProbeRunner : public DnsProbeRunner {
608  public:
MockDohProbeRunner(base::WeakPtr<MockDnsTransactionFactory> factory)609   explicit MockDohProbeRunner(base::WeakPtr<MockDnsTransactionFactory> factory)
610       : factory_(std::move(factory)) {}
611 
~MockDohProbeRunner()612   ~MockDohProbeRunner() override {
613     if (factory_)
614       factory_->running_doh_probe_runners_.erase(this);
615   }
616 
Start(bool network_change)617   void Start(bool network_change) override {
618     DCHECK(factory_);
619     factory_->running_doh_probe_runners_.insert(this);
620   }
621 
GetDelayUntilNextProbeForTest(size_t doh_server_index) const622   base::TimeDelta GetDelayUntilNextProbeForTest(
623       size_t doh_server_index) const override {
624     NOTREACHED();
625     return base::TimeDelta();
626   }
627 
628  private:
629   base::WeakPtr<MockDnsTransactionFactory> factory_;
630 };
631 
MockDnsTransactionFactory(MockDnsClientRuleList rules)632 MockDnsTransactionFactory::MockDnsTransactionFactory(
633     MockDnsClientRuleList rules)
634     : rules_(std::move(rules)) {}
635 
636 MockDnsTransactionFactory::~MockDnsTransactionFactory() = default;
637 
CreateTransaction(std::string hostname,uint16_t qtype,const NetLogWithSource &,bool secure,SecureDnsMode secure_dns_mode,ResolveContext * resolve_context,bool fast_timeout)638 std::unique_ptr<DnsTransaction> MockDnsTransactionFactory::CreateTransaction(
639     std::string hostname,
640     uint16_t qtype,
641     const NetLogWithSource&,
642     bool secure,
643     SecureDnsMode secure_dns_mode,
644     ResolveContext* resolve_context,
645     bool fast_timeout) {
646   std::unique_ptr<MockTransaction> transaction =
647       std::make_unique<MockTransaction>(rules_, std::move(hostname), qtype,
648                                         secure, force_doh_server_available_,
649                                         secure_dns_mode, resolve_context,
650                                         fast_timeout);
651   if (transaction->delayed())
652     delayed_transactions_.push_back(transaction->AsWeakPtr());
653   return transaction;
654 }
655 
CreateDohProbeRunner(ResolveContext * resolve_context)656 std::unique_ptr<DnsProbeRunner> MockDnsTransactionFactory::CreateDohProbeRunner(
657     ResolveContext* resolve_context) {
658   return std::make_unique<MockDohProbeRunner>(weak_ptr_factory_.GetWeakPtr());
659 }
660 
AddEDNSOption(std::unique_ptr<OptRecordRdata::Opt> opt)661 void MockDnsTransactionFactory::AddEDNSOption(
662     std::unique_ptr<OptRecordRdata::Opt> opt) {}
663 
GetSecureDnsModeForTest()664 SecureDnsMode MockDnsTransactionFactory::GetSecureDnsModeForTest() {
665   return SecureDnsMode::kAutomatic;
666 }
667 
CompleteDelayedTransactions()668 void MockDnsTransactionFactory::CompleteDelayedTransactions() {
669   DelayedTransactionList old_delayed_transactions;
670   old_delayed_transactions.swap(delayed_transactions_);
671   for (auto& old_delayed_transaction : old_delayed_transactions) {
672     if (old_delayed_transaction.get())
673       old_delayed_transaction->FinishDelayedTransaction();
674   }
675 }
676 
CompleteOneDelayedTransactionOfType(DnsQueryType type)677 bool MockDnsTransactionFactory::CompleteOneDelayedTransactionOfType(
678     DnsQueryType type) {
679   for (base::WeakPtr<MockTransaction>& t : delayed_transactions_) {
680     if (t && t->GetType() == DnsQueryTypeToQtype(type)) {
681       t->FinishDelayedTransaction();
682       t.reset();
683       return true;
684     }
685   }
686   return false;
687 }
688 
MockDnsClient(DnsConfig config,MockDnsClientRuleList rules)689 MockDnsClient::MockDnsClient(DnsConfig config, MockDnsClientRuleList rules)
690     : config_(std::move(config)),
691       factory_(std::make_unique<MockDnsTransactionFactory>(std::move(rules))),
692       address_sorter_(std::make_unique<MockAddressSorter>()) {
693   effective_config_ = BuildEffectiveConfig();
694   session_ = BuildSession();
695 }
696 
697 MockDnsClient::~MockDnsClient() = default;
698 
CanUseSecureDnsTransactions() const699 bool MockDnsClient::CanUseSecureDnsTransactions() const {
700   const DnsConfig* config = GetEffectiveConfig();
701   return config && config->IsValid() && !config->doh_config.servers().empty();
702 }
703 
CanUseInsecureDnsTransactions() const704 bool MockDnsClient::CanUseInsecureDnsTransactions() const {
705   const DnsConfig* config = GetEffectiveConfig();
706   return config && config->IsValid() && insecure_enabled_ &&
707          !config->dns_over_tls_active;
708 }
709 
CanQueryAdditionalTypesViaInsecureDns() const710 bool MockDnsClient::CanQueryAdditionalTypesViaInsecureDns() const {
711   DCHECK(CanUseInsecureDnsTransactions());
712   return additional_types_enabled_;
713 }
714 
SetInsecureEnabled(bool enabled,bool additional_types_enabled)715 void MockDnsClient::SetInsecureEnabled(bool enabled,
716                                        bool additional_types_enabled) {
717   insecure_enabled_ = enabled;
718   additional_types_enabled_ = additional_types_enabled;
719 }
720 
FallbackFromSecureTransactionPreferred(ResolveContext * context) const721 bool MockDnsClient::FallbackFromSecureTransactionPreferred(
722     ResolveContext* context) const {
723   bool doh_server_available =
724       force_doh_server_available_ ||
725       context->NumAvailableDohServers(session_.get()) > 0;
726   return !CanUseSecureDnsTransactions() || !doh_server_available;
727 }
728 
FallbackFromInsecureTransactionPreferred() const729 bool MockDnsClient::FallbackFromInsecureTransactionPreferred() const {
730   return !CanUseInsecureDnsTransactions() ||
731          fallback_failures_ >= max_fallback_failures_;
732 }
733 
SetSystemConfig(absl::optional<DnsConfig> system_config)734 bool MockDnsClient::SetSystemConfig(absl::optional<DnsConfig> system_config) {
735   if (ignore_system_config_changes_)
736     return false;
737 
738   absl::optional<DnsConfig> before = effective_config_;
739   config_ = std::move(system_config);
740   effective_config_ = BuildEffectiveConfig();
741   session_ = BuildSession();
742   return before != effective_config_;
743 }
744 
SetConfigOverrides(DnsConfigOverrides config_overrides)745 bool MockDnsClient::SetConfigOverrides(DnsConfigOverrides config_overrides) {
746   absl::optional<DnsConfig> before = effective_config_;
747   overrides_ = std::move(config_overrides);
748   effective_config_ = BuildEffectiveConfig();
749   session_ = BuildSession();
750   return before != effective_config_;
751 }
752 
ReplaceCurrentSession()753 void MockDnsClient::ReplaceCurrentSession() {
754   // Noop if no current effective config.
755   session_ = BuildSession();
756 }
757 
GetCurrentSession()758 DnsSession* MockDnsClient::GetCurrentSession() {
759   return session_.get();
760 }
761 
GetEffectiveConfig() const762 const DnsConfig* MockDnsClient::GetEffectiveConfig() const {
763   return effective_config_.has_value() ? &effective_config_.value() : nullptr;
764 }
765 
GetDnsConfigAsValueForNetLog() const766 base::Value::Dict MockDnsClient::GetDnsConfigAsValueForNetLog() const {
767   // This is just a stub implementation that never produces a meaningful value.
768   return base::Value::Dict();
769 }
770 
GetHosts() const771 const DnsHosts* MockDnsClient::GetHosts() const {
772   const DnsConfig* config = GetEffectiveConfig();
773   if (!config)
774     return nullptr;
775 
776   return &config->hosts;
777 }
778 
GetTransactionFactory()779 DnsTransactionFactory* MockDnsClient::GetTransactionFactory() {
780   return GetEffectiveConfig() ? factory_.get() : nullptr;
781 }
782 
GetAddressSorter()783 AddressSorter* MockDnsClient::GetAddressSorter() {
784   return GetEffectiveConfig() ? address_sorter_.get() : nullptr;
785 }
786 
IncrementInsecureFallbackFailures()787 void MockDnsClient::IncrementInsecureFallbackFailures() {
788   ++fallback_failures_;
789 }
790 
ClearInsecureFallbackFailures()791 void MockDnsClient::ClearInsecureFallbackFailures() {
792   fallback_failures_ = 0;
793 }
794 
GetSystemConfigForTesting() const795 absl::optional<DnsConfig> MockDnsClient::GetSystemConfigForTesting() const {
796   return config_;
797 }
798 
GetConfigOverridesForTesting() const799 DnsConfigOverrides MockDnsClient::GetConfigOverridesForTesting() const {
800   return overrides_;
801 }
802 
// Not supported by the mock. It always serves transactions from the
// MockDnsTransactionFactory created in its constructor, so reaching this is a
// test-setup error. |factory| is intentionally unused.
void MockDnsClient::SetTransactionFactoryForTesting(
    std::unique_ptr<DnsTransactionFactory> factory) {
  NOTREACHED();
}
807 
GetPresetAddrs(const url::SchemeHostPort & endpoint) const808 absl::optional<std::vector<IPEndPoint>> MockDnsClient::GetPresetAddrs(
809     const url::SchemeHostPort& endpoint) const {
810   EXPECT_THAT(preset_endpoint_, testing::Optional(endpoint));
811   return preset_addrs_;
812 }
813 
CompleteDelayedTransactions()814 void MockDnsClient::CompleteDelayedTransactions() {
815   factory_->CompleteDelayedTransactions();
816 }
817 
CompleteOneDelayedTransactionOfType(DnsQueryType type)818 bool MockDnsClient::CompleteOneDelayedTransactionOfType(DnsQueryType type) {
819   return factory_->CompleteOneDelayedTransactionOfType(type);
820 }
821 
SetForceDohServerAvailable(bool available)822 void MockDnsClient::SetForceDohServerAvailable(bool available) {
823   force_doh_server_available_ = available;
824   factory_->set_force_doh_server_available(available);
825 }
826 
BuildEffectiveConfig()827 absl::optional<DnsConfig> MockDnsClient::BuildEffectiveConfig() {
828   if (overrides_.OverridesEverything())
829     return overrides_.ApplyOverrides(DnsConfig());
830   if (!config_ || !config_.value().IsValid())
831     return absl::nullopt;
832 
833   return overrides_.ApplyOverrides(config_.value());
834 }
835 
BuildSession()836 scoped_refptr<DnsSession> MockDnsClient::BuildSession() {
837   if (!effective_config_)
838     return nullptr;
839 
840   // Session not expected to be used for anything that will actually require
841   // random numbers.
842   auto null_random_callback =
843       base::BindRepeating([](int, int) -> int { base::ImmediateCrash(); });
844 
845   return base::MakeRefCounted<DnsSession>(
846       effective_config_.value(), null_random_callback, nullptr /* net_log */);
847 }
848 
849 }  // namespace net
850