• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifdef UNSAFE_BUFFERS_BUILD
6 // TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
7 #pragma allow_unsafe_buffers
8 #endif
9 
10 #include "net/dns/dns_test_util.h"
11 
12 #include <cstdint>
13 #include <optional>
14 #include <string>
15 #include <string_view>
16 #include <utility>
17 #include <vector>
18 
19 #include "base/check.h"
20 #include "base/containers/span.h"
21 #include "base/functional/bind.h"
22 #include "base/location.h"
23 #include "base/numerics/byte_conversions.h"
24 #include "base/numerics/safe_conversions.h"
25 #include "base/ranges/algorithm.h"
26 #include "base/strings/strcat.h"
27 #include "base/sys_byteorder.h"
28 #include "base/task/single_thread_task_runner.h"
29 #include "base/test/test_timeouts.h"
30 #include "base/threading/thread_restrictions.h"
31 #include "base/time/time.h"
32 #include "base/types/optional_util.h"
33 #include "net/base/io_buffer.h"
34 #include "net/base/ip_address.h"
35 #include "net/base/ip_endpoint.h"
36 #include "net/base/net_errors.h"
37 #include "net/dns/address_sorter.h"
38 #include "net/dns/dns_hosts.h"
39 #include "net/dns/dns_names_util.h"
40 #include "net/dns/dns_query.h"
41 #include "net/dns/dns_session.h"
42 #include "net/dns/mock_host_resolver.h"
43 #include "net/dns/public/dns_over_https_server_config.h"
44 #include "net/dns/resolve_context.h"
45 #include "testing/gmock/include/gmock/gmock-matchers.h"
46 #include "testing/gtest/include/gtest/gtest.h"
47 #include "url/scheme_host_port.h"
48 
49 namespace net {
50 namespace {
51 
// Raw 12-byte DNS message header that announces one question and one answer
// record.
constexpr uint8_t kMalformedResponseHeader[] = {
    0x00, 0x14,  // ID: arbitrary
    0x81, 0x80,  // Flags: standard query response, RA, no error
    0x00, 0x01,  // Question count: 1
    0x00, 0x01,  // Answer RR count: 1
    0x00, 0x00,  // Authority RR count: 0
    0x00, 0x00,  // Additional RR count: 0
};
61 
62 // Create a response containing a valid question (as would normally be validated
63 // in DnsTransaction) but completely missing a header-declared answer.
CreateMalformedResponse(std::string hostname,uint16_t type)64 DnsResponse CreateMalformedResponse(std::string hostname, uint16_t type) {
65   std::optional<std::vector<uint8_t>> dns_name =
66       dns_names_util::DottedNameToNetwork(hostname);
67   CHECK(dns_name.has_value());
68   DnsQuery query(/*id=*/0x14, dns_name.value(), type);
69 
70   // Build response to simulate the barebones validation DnsResponse applies to
71   // responses received from the network.
72   auto buffer = base::MakeRefCounted<IOBufferWithSize>(
73       sizeof(kMalformedResponseHeader) + query.question().size());
74   memcpy(buffer->data(), kMalformedResponseHeader,
75          sizeof(kMalformedResponseHeader));
76   memcpy(buffer->data() + sizeof(kMalformedResponseHeader),
77          query.question().data(), query.question().size());
78 
79   DnsResponse response(buffer, buffer->size());
80   CHECK(response.InitParseWithoutQuery(buffer->size()));
81 
82   return response;
83 }
84 
85 class MockAddressSorter : public AddressSorter {
86  public:
87   ~MockAddressSorter() override = default;
Sort(const std::vector<IPEndPoint> & endpoints,CallbackType callback) const88   void Sort(const std::vector<IPEndPoint>& endpoints,
89             CallbackType callback) const override {
90     // Do nothing.
91     std::move(callback).Run(true, endpoints);
92   }
93 };
94 
95 }  // namespace
96 
CreateValidDnsConfig()97 DnsConfig CreateValidDnsConfig() {
98   IPAddress dns_ip(192, 168, 1, 0);
99   DnsConfig config;
100   config.nameservers.emplace_back(dns_ip, dns_protocol::kDefaultPort);
101   config.doh_config =
102       *DnsOverHttpsConfig::FromString("https://dns.example.com/");
103   config.secure_dns_mode = SecureDnsMode::kOff;
104   EXPECT_TRUE(config.IsValid());
105   return config;
106 }
107 
BuildTestDnsRecord(std::string name,uint16_t type,std::string rdata,base::TimeDelta ttl)108 DnsResourceRecord BuildTestDnsRecord(std::string name,
109                                      uint16_t type,
110                                      std::string rdata,
111                                      base::TimeDelta ttl) {
112   DCHECK(!name.empty());
113 
114   DnsResourceRecord record;
115   record.name = std::move(name);
116   record.type = type;
117   record.klass = dns_protocol::kClassIN;
118   record.ttl = ttl.InSeconds();
119 
120   if (!rdata.empty())
121     record.SetOwnedRdata(std::move(rdata));
122 
123   return record;
124 }
125 
BuildTestCnameRecord(std::string name,std::string_view canonical_name,base::TimeDelta ttl)126 DnsResourceRecord BuildTestCnameRecord(std::string name,
127                                        std::string_view canonical_name,
128                                        base::TimeDelta ttl) {
129   DCHECK(!name.empty());
130   DCHECK(!canonical_name.empty());
131 
132   std::optional<std::vector<uint8_t>> rdata =
133       dns_names_util::DottedNameToNetwork(canonical_name);
134   CHECK(rdata.has_value());
135 
136   return BuildTestDnsRecord(
137       std::move(name), dns_protocol::kTypeCNAME,
138       std::string(reinterpret_cast<char*>(rdata.value().data()),
139                   rdata.value().size()),
140       ttl);
141 }
142 
BuildTestAddressRecord(std::string name,const IPAddress & ip,base::TimeDelta ttl)143 DnsResourceRecord BuildTestAddressRecord(std::string name,
144                                          const IPAddress& ip,
145                                          base::TimeDelta ttl) {
146   DCHECK(!name.empty());
147   DCHECK(ip.IsValid());
148 
149   return BuildTestDnsRecord(
150       std::move(name),
151       ip.IsIPv4() ? dns_protocol::kTypeA : dns_protocol::kTypeAAAA,
152       net::IPAddressToPackedString(ip), ttl);
153 }
154 
BuildTestTextRecord(std::string name,std::vector<std::string> text_strings,base::TimeDelta ttl)155 DnsResourceRecord BuildTestTextRecord(std::string name,
156                                       std::vector<std::string> text_strings,
157                                       base::TimeDelta ttl) {
158   DCHECK(!text_strings.empty());
159 
160   std::string rdata;
161   for (const std::string& text_string : text_strings) {
162     DCHECK(!text_string.empty());
163 
164     rdata += base::checked_cast<unsigned char>(text_string.size());
165     rdata += text_string;
166   }
167 
168   return BuildTestDnsRecord(std::move(name), dns_protocol::kTypeTXT,
169                             std::move(rdata), ttl);
170 }
171 
BuildTestHttpsAliasRecord(std::string name,std::string_view alias_name,base::TimeDelta ttl)172 DnsResourceRecord BuildTestHttpsAliasRecord(std::string name,
173                                             std::string_view alias_name,
174                                             base::TimeDelta ttl) {
175   DCHECK(!name.empty());
176 
177   std::string rdata("\000\000", 2);
178 
179   std::optional<std::vector<uint8_t>> alias_domain =
180       dns_names_util::DottedNameToNetwork(alias_name);
181   CHECK(alias_domain.has_value());
182   rdata.append(reinterpret_cast<char*>(alias_domain.value().data()),
183                alias_domain.value().size());
184 
185   return BuildTestDnsRecord(std::move(name), dns_protocol::kTypeHttps,
186                             std::move(rdata), ttl);
187 }
188 
BuildTestHttpsServiceAlpnParam(const std::vector<std::string> & alpns)189 std::pair<uint16_t, std::string> BuildTestHttpsServiceAlpnParam(
190     const std::vector<std::string>& alpns) {
191   std::string param_value;
192 
193   for (const std::string& alpn : alpns) {
194     CHECK(!alpn.empty());
195     param_value.append(
196         1, static_cast<char>(base::checked_cast<uint8_t>(alpn.size())));
197     param_value.append(alpn);
198   }
199 
200   return std::pair(dns_protocol::kHttpsServiceParamKeyAlpn,
201                    std::move(param_value));
202 }
203 
BuildTestHttpsServiceEchConfigParam(base::span<const uint8_t> ech_config_list)204 std::pair<uint16_t, std::string> BuildTestHttpsServiceEchConfigParam(
205     base::span<const uint8_t> ech_config_list) {
206   return std::pair(
207       dns_protocol::kHttpsServiceParamKeyEchConfig,
208       std::string(reinterpret_cast<const char*>(ech_config_list.data()),
209                   ech_config_list.size()));
210 }
211 
BuildTestHttpsServiceMandatoryParam(std::vector<uint16_t> param_key_list)212 std::pair<uint16_t, std::string> BuildTestHttpsServiceMandatoryParam(
213     std::vector<uint16_t> param_key_list) {
214   base::ranges::sort(param_key_list);
215 
216   std::string value;
217   for (uint16_t param_key : param_key_list) {
218     std::array<uint8_t, 2> num_buffer = base::U16ToBigEndian(param_key);
219     value.append(num_buffer.begin(), num_buffer.end());
220   }
221 
222   return std::pair(dns_protocol::kHttpsServiceParamKeyMandatory,
223                    std::move(value));
224 }
225 
BuildTestHttpsServicePortParam(uint16_t port)226 std::pair<uint16_t, std::string> BuildTestHttpsServicePortParam(uint16_t port) {
227   std::array<uint8_t, 2> buffer = base::U16ToBigEndian(port);
228   return std::pair(dns_protocol::kHttpsServiceParamKeyPort,
229                    std::string(buffer.begin(), buffer.end()));
230 }
231 
BuildTestHttpsServiceRecord(std::string name,uint16_t priority,std::string_view service_name,const std::map<uint16_t,std::string> & params,base::TimeDelta ttl)232 DnsResourceRecord BuildTestHttpsServiceRecord(
233     std::string name,
234     uint16_t priority,
235     std::string_view service_name,
236     const std::map<uint16_t, std::string>& params,
237     base::TimeDelta ttl) {
238   DCHECK(!name.empty());
239   DCHECK_NE(priority, 0);
240 
241   std::string rdata;
242 
243   {
244     std::array<uint8_t, 2> buf = base::U16ToBigEndian(priority);
245     rdata.append(buf.begin(), buf.end());
246   }
247 
248   std::optional<std::vector<uint8_t>> service_domain;
249   if (service_name == ".") {
250     // HTTPS records have special behavior for `service_name == "."` (that it
251     // will be treated as if the service name is the same as the record owner
252     // name), so allow such inputs despite normally being disallowed for
253     // Chrome-encoded DNS names.
254     service_domain = std::vector<uint8_t>{0};
255   } else {
256     service_domain = dns_names_util::DottedNameToNetwork(service_name);
257   }
258   CHECK(service_domain.has_value());
259   rdata.append(reinterpret_cast<char*>(service_domain.value().data()),
260                service_domain.value().size());
261 
262   for (auto& param : params) {
263     {
264       std::array<uint8_t, 2> buf = base::U16ToBigEndian(param.first);
265       rdata.append(buf.begin(), buf.end());
266     }
267     {
268       std::array<uint8_t, 2> buf = base::U16ToBigEndian(
269           base::checked_cast<uint16_t>(param.second.size()));
270       rdata.append(buf.begin(), buf.end());
271     }
272     rdata.append(param.second);
273   }
274 
275   return BuildTestDnsRecord(std::move(name), dns_protocol::kTypeHttps,
276                             std::move(rdata), ttl);
277 }
278 
BuildTestDnsResponse(std::string name,uint16_t type,const std::vector<DnsResourceRecord> & answers,const std::vector<DnsResourceRecord> & authority,const std::vector<DnsResourceRecord> & additional,uint8_t rcode)279 DnsResponse BuildTestDnsResponse(
280     std::string name,
281     uint16_t type,
282     const std::vector<DnsResourceRecord>& answers,
283     const std::vector<DnsResourceRecord>& authority,
284     const std::vector<DnsResourceRecord>& additional,
285     uint8_t rcode) {
286   DCHECK(!name.empty());
287 
288   std::optional<std::vector<uint8_t>> dns_name =
289       dns_names_util::DottedNameToNetwork(name);
290   CHECK(dns_name.has_value());
291 
292   std::optional<DnsQuery> query(std::in_place, 0, dns_name.value(), type);
293   return DnsResponse(0, true /* is_authoritative */, answers,
294                      authority /* authority_records */,
295                      additional /* additional_records */, query, rcode,
296                      false /* validate_records */);
297 }
298 
BuildTestDnsAddressResponse(std::string name,const IPAddress & ip,std::string answer_name)299 DnsResponse BuildTestDnsAddressResponse(std::string name,
300                                         const IPAddress& ip,
301                                         std::string answer_name) {
302   DCHECK(ip.IsValid());
303 
304   if (answer_name.empty())
305     answer_name = name;
306 
307   std::vector<DnsResourceRecord> answers = {
308       BuildTestAddressRecord(std::move(answer_name), ip)};
309 
310   return BuildTestDnsResponse(
311       std::move(name),
312       ip.IsIPv4() ? dns_protocol::kTypeA : dns_protocol::kTypeAAAA, answers);
313 }
314 
BuildTestDnsAddressResponseWithCname(std::string name,const IPAddress & ip,std::string cannonname,std::string answer_name)315 DnsResponse BuildTestDnsAddressResponseWithCname(std::string name,
316                                                  const IPAddress& ip,
317                                                  std::string cannonname,
318                                                  std::string answer_name) {
319   DCHECK(ip.IsValid());
320   DCHECK(!cannonname.empty());
321 
322   if (answer_name.empty())
323     answer_name = name;
324 
325   std::optional<std::vector<uint8_t>> cname_rdata =
326       dns_names_util::DottedNameToNetwork(cannonname);
327   CHECK(cname_rdata.has_value());
328 
329   std::vector<DnsResourceRecord> answers = {
330       BuildTestDnsRecord(
331           std::move(answer_name), dns_protocol::kTypeCNAME,
332           std::string(reinterpret_cast<char*>(cname_rdata.value().data()),
333                       cname_rdata.value().size())),
334       BuildTestAddressRecord(std::move(cannonname), ip)};
335 
336   return BuildTestDnsResponse(
337       std::move(name),
338       ip.IsIPv4() ? dns_protocol::kTypeA : dns_protocol::kTypeAAAA, answers);
339 }
340 
BuildTestDnsTextResponse(std::string name,std::vector<std::vector<std::string>> text_records,std::string answer_name)341 DnsResponse BuildTestDnsTextResponse(
342     std::string name,
343     std::vector<std::vector<std::string>> text_records,
344     std::string answer_name) {
345   if (answer_name.empty())
346     answer_name = name;
347 
348   std::vector<DnsResourceRecord> answers;
349   for (std::vector<std::string>& text_record : text_records) {
350     answers.push_back(BuildTestTextRecord(answer_name, std::move(text_record)));
351   }
352 
353   return BuildTestDnsResponse(std::move(name), dns_protocol::kTypeTXT, answers);
354 }
355 
BuildTestDnsPointerResponse(std::string name,std::vector<std::string> pointer_names,std::string answer_name)356 DnsResponse BuildTestDnsPointerResponse(std::string name,
357                                         std::vector<std::string> pointer_names,
358                                         std::string answer_name) {
359   if (answer_name.empty())
360     answer_name = name;
361 
362   std::vector<DnsResourceRecord> answers;
363   for (std::string& pointer_name : pointer_names) {
364     std::optional<std::vector<uint8_t>> rdata =
365         dns_names_util::DottedNameToNetwork(pointer_name);
366     CHECK(rdata.has_value());
367 
368     answers.push_back(BuildTestDnsRecord(
369         answer_name, dns_protocol::kTypePTR,
370         std::string(reinterpret_cast<char*>(rdata.value().data()),
371                     rdata.value().size())));
372   }
373 
374   return BuildTestDnsResponse(std::move(name), dns_protocol::kTypePTR, answers);
375 }
376 
BuildTestDnsServiceResponse(std::string name,std::vector<TestServiceRecord> service_records,std::string answer_name)377 DnsResponse BuildTestDnsServiceResponse(
378     std::string name,
379     std::vector<TestServiceRecord> service_records,
380     std::string answer_name) {
381   if (answer_name.empty())
382     answer_name = name;
383 
384   std::vector<DnsResourceRecord> answers;
385   for (TestServiceRecord& service_record : service_records) {
386     std::string rdata;
387     {
388       std::array<uint8_t, 2> buf =
389           base::U16ToBigEndian(service_record.priority);
390       rdata.append(buf.begin(), buf.end());
391     }
392     {
393       std::array<uint8_t, 2> buf = base::U16ToBigEndian(service_record.weight);
394       rdata.append(buf.begin(), buf.end());
395     }
396     {
397       std::array<uint8_t, 2> buf = base::U16ToBigEndian(service_record.port);
398       rdata.append(buf.begin(), buf.end());
399     }
400 
401     std::optional<std::vector<uint8_t>> dns_name =
402         dns_names_util::DottedNameToNetwork(service_record.target);
403     CHECK(dns_name.has_value());
404     rdata.append(reinterpret_cast<char*>(dns_name.value().data()),
405                  dns_name.value().size());
406 
407     answers.push_back(BuildTestDnsRecord(answer_name, dns_protocol::kTypeSRV,
408                                          std::move(rdata), base::Hours(5)));
409   }
410 
411   return BuildTestDnsResponse(std::move(name), dns_protocol::kTypeSRV, answers);
412 }
413 
Result(ResultType type,std::optional<DnsResponse> response,std::optional<int> net_error)414 MockDnsClientRule::Result::Result(ResultType type,
415                                   std::optional<DnsResponse> response,
416                                   std::optional<int> net_error)
417     : type(type), response(std::move(response)), net_error(net_error) {}
418 
Result(DnsResponse response)419 MockDnsClientRule::Result::Result(DnsResponse response)
420     : type(ResultType::kOk),
421       response(std::move(response)),
422       net_error(std::nullopt) {}
423 
424 MockDnsClientRule::Result::Result(Result&&) = default;
425 
426 MockDnsClientRule::Result& MockDnsClientRule::Result::operator=(Result&&) =
427     default;
428 
429 MockDnsClientRule::Result::~Result() = default;
430 
MockDnsClientRule(const std::string & prefix,uint16_t qtype,bool secure,Result result,bool delay,URLRequestContext * context)431 MockDnsClientRule::MockDnsClientRule(const std::string& prefix,
432                                      uint16_t qtype,
433                                      bool secure,
434                                      Result result,
435                                      bool delay,
436                                      URLRequestContext* context)
437     : result(std::move(result)),
438       prefix(prefix),
439       qtype(qtype),
440       secure(secure),
441       delay(delay),
442       context(context) {}
443 
444 MockDnsClientRule::MockDnsClientRule(MockDnsClientRule&& rule) = default;
445 
446 // A DnsTransaction which uses MockDnsClientRuleList to determine the response.
447 class MockDnsTransactionFactory::MockTransaction final : public DnsTransaction {
448  public:
MockTransaction(const MockDnsClientRuleList & rules,std::string hostname,uint16_t qtype,bool secure,bool force_doh_server_available,SecureDnsMode secure_dns_mode,ResolveContext * resolve_context,bool fast_timeout)449   MockTransaction(const MockDnsClientRuleList& rules,
450                   std::string hostname,
451                   uint16_t qtype,
452                   bool secure,
453                   bool force_doh_server_available,
454                   SecureDnsMode secure_dns_mode,
455                   ResolveContext* resolve_context,
456                   bool fast_timeout)
457       : hostname_(std::move(hostname)), qtype_(qtype) {
458     // Do not allow matching any rules if transaction is secure and no DoH
459     // servers are available.
460     if (!secure || force_doh_server_available ||
461         resolve_context->NumAvailableDohServers(
462             resolve_context->current_session_for_testing()) > 0) {
463       // Find the relevant rule which matches |qtype|, |secure|, prefix of
464       // |hostname_|, and |url_request_context| (iff the rule context is not
465       // null).
466       for (const auto& rule : rules) {
467         const std::string& prefix = rule.prefix;
468         if ((rule.qtype == qtype) && (rule.secure == secure) &&
469             (hostname_.size() >= prefix.size()) &&
470             (hostname_.compare(0, prefix.size(), prefix) == 0) &&
471             (!rule.context ||
472              rule.context == resolve_context->url_request_context())) {
473           const MockDnsClientRule::Result* result = &rule.result;
474           result_ = MockDnsClientRule::Result(result->type);
475           result_.net_error = result->net_error;
476           delayed_ = rule.delay;
477 
478           // Generate a DnsResponse when not provided with the rule.
479           std::vector<DnsResourceRecord> authority_records;
480           std::optional<std::vector<uint8_t>> dns_name =
481               dns_names_util::DottedNameToNetwork(hostname_);
482           CHECK(dns_name.has_value());
483           std::optional<DnsQuery> query(std::in_place, /*id=*/22,
484                                         dns_name.value(), qtype_);
485           switch (result->type) {
486             case MockDnsClientRule::ResultType::kNoDomain:
487             case MockDnsClientRule::ResultType::kEmpty:
488               DCHECK(!result->response);  // Not expected to be provided.
489               authority_records = {BuildTestDnsRecord(
490                   hostname_, dns_protocol::kTypeSOA, "fake rdata")};
491               result_.response = DnsResponse(
492                   22 /* id */, false /* is_authoritative */,
493                   std::vector<DnsResourceRecord>() /* answers */,
494                   authority_records,
495                   std::vector<DnsResourceRecord>() /* additional_records */,
496                   query,
497                   result->type == MockDnsClientRule::ResultType::kNoDomain
498                       ? dns_protocol::kRcodeNXDOMAIN
499                       : 0);
500               break;
501             case MockDnsClientRule::ResultType::kFail:
502               if (result->response)
503                 SetResponse(result);
504               break;
505             case MockDnsClientRule::ResultType::kTimeout:
506               DCHECK(!result->response);  // Not expected to be provided.
507               break;
508             case MockDnsClientRule::ResultType::kSlow:
509               if (!fast_timeout)
510                 SetResponse(result);
511               break;
512             case MockDnsClientRule::ResultType::kOk:
513               SetResponse(result);
514               break;
515             case MockDnsClientRule::ResultType::kMalformed:
516               DCHECK(!result->response);  // Not expected to be provided.
517               result_.response = CreateMalformedResponse(hostname_, qtype_);
518               break;
519             case MockDnsClientRule::ResultType::kUnexpected:
520               if (!delayed_) {
521                 // Assume a delayed kUnexpected transaction is only an issue if
522                 // allowed to complete.
523                 ADD_FAILURE()
524                     << "Unexpected DNS transaction created for hostname "
525                     << hostname_;
526               }
527               break;
528           }
529 
530           break;
531         }
532       }
533     }
534   }
535 
GetHostname() const536   const std::string& GetHostname() const override { return hostname_; }
537 
GetType() const538   uint16_t GetType() const override { return qtype_; }
539 
Start(ResponseCallback callback)540   void Start(ResponseCallback callback) override {
541     CHECK(!callback.is_null());
542     CHECK(callback_.is_null());
543     EXPECT_FALSE(started_);
544 
545     callback_ = std::move(callback);
546     started_ = true;
547     if (delayed_)
548       return;
549     // Using WeakPtr to cleanly cancel when transaction is destroyed.
550     base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
551         FROM_HERE, base::BindOnce(&MockTransaction::Finish,
552                                   weak_ptr_factory_.GetWeakPtr()));
553   }
554 
FinishDelayedTransaction()555   void FinishDelayedTransaction() {
556     EXPECT_TRUE(delayed_);
557     delayed_ = false;
558     Finish();
559   }
560 
delayed() const561   bool delayed() const { return delayed_; }
562 
AsWeakPtr()563   base::WeakPtr<MockTransaction> AsWeakPtr() {
564     return weak_ptr_factory_.GetWeakPtr();
565   }
566 
567  private:
SetResponse(const MockDnsClientRule::Result * result)568   void SetResponse(const MockDnsClientRule::Result* result) {
569     if (result->response) {
570       // Copy response in case |result| is destroyed before the transaction
571       // completes.
572       auto buffer_copy = base::MakeRefCounted<IOBufferWithSize>(
573           result->response->io_buffer_size());
574       memcpy(buffer_copy->data(), result->response->io_buffer()->data(),
575              result->response->io_buffer_size());
576       result_.response = DnsResponse(std::move(buffer_copy),
577                                      result->response->io_buffer_size());
578       CHECK(result_.response->InitParseWithoutQuery(
579           result->response->io_buffer_size()));
580     } else {
581       // Generated response only available for address types.
582       DCHECK(qtype_ == dns_protocol::kTypeA ||
583              qtype_ == dns_protocol::kTypeAAAA);
584       result_.response = BuildTestDnsAddressResponse(
585           hostname_, qtype_ == dns_protocol::kTypeA
586                          ? IPAddress::IPv4Localhost()
587                          : IPAddress::IPv6Localhost());
588     }
589   }
590 
Finish()591   void Finish() {
592     switch (result_.type) {
593       case MockDnsClientRule::ResultType::kNoDomain:
594       case MockDnsClientRule::ResultType::kFail: {
595         int error = result_.net_error.value_or(ERR_NAME_NOT_RESOLVED);
596         DCHECK_NE(error, OK);
597         std::move(callback_).Run(error, base::OptionalToPtr(result_.response));
598         break;
599       }
600       case MockDnsClientRule::ResultType::kEmpty:
601       case MockDnsClientRule::ResultType::kOk:
602       case MockDnsClientRule::ResultType::kMalformed:
603         DCHECK(!result_.net_error.has_value());
604         std::move(callback_).Run(OK, base::OptionalToPtr(result_.response));
605         break;
606       case MockDnsClientRule::ResultType::kTimeout:
607         DCHECK(!result_.net_error.has_value());
608         std::move(callback_).Run(ERR_DNS_TIMED_OUT, /*response=*/nullptr);
609         break;
610       case MockDnsClientRule::ResultType::kSlow:
611         if (result_.response) {
612           std::move(callback_).Run(
613               result_.net_error.value_or(OK),
614               result_.response ? &result_.response.value() : nullptr);
615         } else {
616           DCHECK(!result_.net_error.has_value());
617           std::move(callback_).Run(ERR_DNS_TIMED_OUT, /*response=*/nullptr);
618         }
619         break;
620       case MockDnsClientRule::ResultType::kUnexpected:
621         ADD_FAILURE() << "Unexpected DNS transaction completed for hostname "
622                       << hostname_;
623         break;
624     }
625   }
626 
SetRequestPriority(RequestPriority priority)627   void SetRequestPriority(RequestPriority priority) override {}
628 
629   MockDnsClientRule::Result result_{MockDnsClientRule::ResultType::kFail};
630   const std::string hostname_;
631   const uint16_t qtype_;
632   ResponseCallback callback_;
633   bool started_ = false;
634   bool delayed_ = false;
635   base::WeakPtrFactory<MockTransaction> weak_ptr_factory_{this};
636 };
637 
638 class MockDnsTransactionFactory::MockDohProbeRunner : public DnsProbeRunner {
639  public:
MockDohProbeRunner(base::WeakPtr<MockDnsTransactionFactory> factory)640   explicit MockDohProbeRunner(base::WeakPtr<MockDnsTransactionFactory> factory)
641       : factory_(std::move(factory)) {}
642 
~MockDohProbeRunner()643   ~MockDohProbeRunner() override {
644     if (factory_)
645       factory_->running_doh_probe_runners_.erase(this);
646   }
647 
Start(bool network_change)648   void Start(bool network_change) override {
649     DCHECK(factory_);
650     factory_->running_doh_probe_runners_.insert(this);
651   }
652 
GetDelayUntilNextProbeForTest(size_t doh_server_index) const653   base::TimeDelta GetDelayUntilNextProbeForTest(
654       size_t doh_server_index) const override {
655     NOTREACHED();
656   }
657 
658  private:
659   base::WeakPtr<MockDnsTransactionFactory> factory_;
660 };
661 
MockDnsTransactionFactory(MockDnsClientRuleList rules)662 MockDnsTransactionFactory::MockDnsTransactionFactory(
663     MockDnsClientRuleList rules)
664     : rules_(std::move(rules)) {}
665 
666 MockDnsTransactionFactory::~MockDnsTransactionFactory() = default;
667 
CreateTransaction(std::string hostname,uint16_t qtype,const NetLogWithSource &,bool secure,SecureDnsMode secure_dns_mode,ResolveContext * resolve_context,bool fast_timeout)668 std::unique_ptr<DnsTransaction> MockDnsTransactionFactory::CreateTransaction(
669     std::string hostname,
670     uint16_t qtype,
671     const NetLogWithSource&,
672     bool secure,
673     SecureDnsMode secure_dns_mode,
674     ResolveContext* resolve_context,
675     bool fast_timeout) {
676   std::unique_ptr<MockTransaction> transaction =
677       std::make_unique<MockTransaction>(rules_, std::move(hostname), qtype,
678                                         secure, force_doh_server_available_,
679                                         secure_dns_mode, resolve_context,
680                                         fast_timeout);
681   if (transaction->delayed())
682     delayed_transactions_.push_back(transaction->AsWeakPtr());
683   return transaction;
684 }
685 
CreateDohProbeRunner(ResolveContext * resolve_context)686 std::unique_ptr<DnsProbeRunner> MockDnsTransactionFactory::CreateDohProbeRunner(
687     ResolveContext* resolve_context) {
688   return std::make_unique<MockDohProbeRunner>(weak_ptr_factory_.GetWeakPtr());
689 }
690 
AddEDNSOption(std::unique_ptr<OptRecordRdata::Opt> opt)691 void MockDnsTransactionFactory::AddEDNSOption(
692     std::unique_ptr<OptRecordRdata::Opt> opt) {}
693 
GetSecureDnsModeForTest()694 SecureDnsMode MockDnsTransactionFactory::GetSecureDnsModeForTest() {
695   return SecureDnsMode::kAutomatic;
696 }
697 
CompleteDelayedTransactions()698 void MockDnsTransactionFactory::CompleteDelayedTransactions() {
699   DelayedTransactionList old_delayed_transactions;
700   old_delayed_transactions.swap(delayed_transactions_);
701   for (auto& old_delayed_transaction : old_delayed_transactions) {
702     if (old_delayed_transaction.get())
703       old_delayed_transaction->FinishDelayedTransaction();
704   }
705 }
706 
CompleteOneDelayedTransactionOfType(DnsQueryType type)707 bool MockDnsTransactionFactory::CompleteOneDelayedTransactionOfType(
708     DnsQueryType type) {
709   for (base::WeakPtr<MockTransaction>& t : delayed_transactions_) {
710     if (t && t->GetType() == DnsQueryTypeToQtype(type)) {
711       t->FinishDelayedTransaction();
712       t.reset();
713       return true;
714     }
715   }
716   return false;
717 }
718 
MockDnsClient(DnsConfig config,MockDnsClientRuleList rules)719 MockDnsClient::MockDnsClient(DnsConfig config, MockDnsClientRuleList rules)
720     : config_(std::move(config)),
721       factory_(std::make_unique<MockDnsTransactionFactory>(std::move(rules))),
722       address_sorter_(std::make_unique<MockAddressSorter>()) {
723   effective_config_ = BuildEffectiveConfig();
724   session_ = BuildSession();
725 }
726 
727 MockDnsClient::~MockDnsClient() = default;
728 
CanUseSecureDnsTransactions() const729 bool MockDnsClient::CanUseSecureDnsTransactions() const {
730   const DnsConfig* config = GetEffectiveConfig();
731   return config && config->IsValid() && !config->doh_config.servers().empty();
732 }
733 
CanUseInsecureDnsTransactions() const734 bool MockDnsClient::CanUseInsecureDnsTransactions() const {
735   const DnsConfig* config = GetEffectiveConfig();
736   return config && config->IsValid() && insecure_enabled_ &&
737          !config->dns_over_tls_active;
738 }
739 
CanQueryAdditionalTypesViaInsecureDns() const740 bool MockDnsClient::CanQueryAdditionalTypesViaInsecureDns() const {
741   DCHECK(CanUseInsecureDnsTransactions());
742   return additional_types_enabled_;
743 }
744 
SetInsecureEnabled(bool enabled,bool additional_types_enabled)745 void MockDnsClient::SetInsecureEnabled(bool enabled,
746                                        bool additional_types_enabled) {
747   insecure_enabled_ = enabled;
748   additional_types_enabled_ = additional_types_enabled;
749 }
750 
FallbackFromSecureTransactionPreferred(ResolveContext * context) const751 bool MockDnsClient::FallbackFromSecureTransactionPreferred(
752     ResolveContext* context) const {
753   bool doh_server_available =
754       force_doh_server_available_ ||
755       context->NumAvailableDohServers(session_.get()) > 0;
756   return !CanUseSecureDnsTransactions() || !doh_server_available;
757 }
758 
FallbackFromInsecureTransactionPreferred() const759 bool MockDnsClient::FallbackFromInsecureTransactionPreferred() const {
760   return !CanUseInsecureDnsTransactions() ||
761          fallback_failures_ >= max_fallback_failures_;
762 }
763 
SetSystemConfig(std::optional<DnsConfig> system_config)764 bool MockDnsClient::SetSystemConfig(std::optional<DnsConfig> system_config) {
765   if (ignore_system_config_changes_)
766     return false;
767 
768   std::optional<DnsConfig> before = effective_config_;
769   config_ = std::move(system_config);
770   effective_config_ = BuildEffectiveConfig();
771   session_ = BuildSession();
772   return before != effective_config_;
773 }
774 
SetConfigOverrides(DnsConfigOverrides config_overrides)775 bool MockDnsClient::SetConfigOverrides(DnsConfigOverrides config_overrides) {
776   std::optional<DnsConfig> before = effective_config_;
777   overrides_ = std::move(config_overrides);
778   effective_config_ = BuildEffectiveConfig();
779   session_ = BuildSession();
780   return before != effective_config_;
781 }
782 
ReplaceCurrentSession()783 void MockDnsClient::ReplaceCurrentSession() {
784   // Noop if no current effective config.
785   session_ = BuildSession();
786 }
787 
GetCurrentSession()788 DnsSession* MockDnsClient::GetCurrentSession() {
789   return session_.get();
790 }
791 
GetEffectiveConfig() const792 const DnsConfig* MockDnsClient::GetEffectiveConfig() const {
793   return effective_config_.has_value() ? &effective_config_.value() : nullptr;
794 }
795 
GetDnsConfigAsValueForNetLog() const796 base::Value::Dict MockDnsClient::GetDnsConfigAsValueForNetLog() const {
797   // This is just a stub implementation that never produces a meaningful value.
798   return base::Value::Dict();
799 }
800 
GetHosts() const801 const DnsHosts* MockDnsClient::GetHosts() const {
802   const DnsConfig* config = GetEffectiveConfig();
803   if (!config)
804     return nullptr;
805 
806   return &config->hosts;
807 }
808 
GetTransactionFactory()809 DnsTransactionFactory* MockDnsClient::GetTransactionFactory() {
810   return GetEffectiveConfig() ? factory_.get() : nullptr;
811 }
812 
GetAddressSorter()813 AddressSorter* MockDnsClient::GetAddressSorter() {
814   return GetEffectiveConfig() ? address_sorter_.get() : nullptr;
815 }
816 
IncrementInsecureFallbackFailures()817 void MockDnsClient::IncrementInsecureFallbackFailures() {
818   ++fallback_failures_;
819 }
820 
ClearInsecureFallbackFailures()821 void MockDnsClient::ClearInsecureFallbackFailures() {
822   fallback_failures_ = 0;
823 }
824 
GetSystemConfigForTesting() const825 std::optional<DnsConfig> MockDnsClient::GetSystemConfigForTesting() const {
826   return config_;
827 }
828 
GetConfigOverridesForTesting() const829 DnsConfigOverrides MockDnsClient::GetConfigOverridesForTesting() const {
830   return overrides_;
831 }
832 
SetTransactionFactoryForTesting(std::unique_ptr<DnsTransactionFactory> factory)833 void MockDnsClient::SetTransactionFactoryForTesting(
834     std::unique_ptr<DnsTransactionFactory> factory) {
835   NOTREACHED();
836 }
837 
SetAddressSorterForTesting(std::unique_ptr<AddressSorter> address_sorter)838 void MockDnsClient::SetAddressSorterForTesting(
839     std::unique_ptr<AddressSorter> address_sorter) {
840   address_sorter_ = std::move(address_sorter);
841 }
842 
GetPresetAddrs(const url::SchemeHostPort & endpoint) const843 std::optional<std::vector<IPEndPoint>> MockDnsClient::GetPresetAddrs(
844     const url::SchemeHostPort& endpoint) const {
845   EXPECT_THAT(preset_endpoint_, testing::Optional(endpoint));
846   return preset_addrs_;
847 }
848 
CompleteDelayedTransactions()849 void MockDnsClient::CompleteDelayedTransactions() {
850   factory_->CompleteDelayedTransactions();
851 }
852 
CompleteOneDelayedTransactionOfType(DnsQueryType type)853 bool MockDnsClient::CompleteOneDelayedTransactionOfType(DnsQueryType type) {
854   return factory_->CompleteOneDelayedTransactionOfType(type);
855 }
856 
SetForceDohServerAvailable(bool available)857 void MockDnsClient::SetForceDohServerAvailable(bool available) {
858   force_doh_server_available_ = available;
859   factory_->set_force_doh_server_available(available);
860 }
861 
BuildEffectiveConfig()862 std::optional<DnsConfig> MockDnsClient::BuildEffectiveConfig() {
863   if (overrides_.OverridesEverything())
864     return overrides_.ApplyOverrides(DnsConfig());
865   if (!config_ || !config_.value().IsValid())
866     return std::nullopt;
867 
868   return overrides_.ApplyOverrides(config_.value());
869 }
870 
BuildSession()871 scoped_refptr<DnsSession> MockDnsClient::BuildSession() {
872   if (!effective_config_)
873     return nullptr;
874 
875   // Session not expected to be used for anything that will actually require
876   // random numbers.
877   auto null_random_callback =
878       base::BindRepeating([](int, int) -> int { base::ImmediateCrash(); });
879 
880   return base::MakeRefCounted<DnsSession>(
881       effective_config_.value(), null_random_callback, nullptr /* net_log */);
882 }
883 
MockHostResolverProc()884 MockHostResolverProc::MockHostResolverProc()
885     : HostResolverProc(nullptr),
886       requests_waiting_(&lock_),
887       slots_available_(&lock_) {}
888 
889 MockHostResolverProc::~MockHostResolverProc() = default;
890 
WaitFor(unsigned count)891 bool MockHostResolverProc::WaitFor(unsigned count) {
892   base::AutoLock lock(lock_);
893   base::Time start_time = base::Time::Now();
894   while (num_requests_waiting_ < count) {
895     requests_waiting_.TimedWait(TestTimeouts::action_timeout());
896     if (base::Time::Now() > start_time + TestTimeouts::action_timeout()) {
897       return false;
898     }
899   }
900   return true;
901 }
902 
SignalMultiple(unsigned count)903 void MockHostResolverProc::SignalMultiple(unsigned count) {
904   base::AutoLock lock(lock_);
905   num_slots_available_ += count;
906   slots_available_.Broadcast();
907 }
908 
SignalAll()909 void MockHostResolverProc::SignalAll() {
910   base::AutoLock lock(lock_);
911   num_slots_available_ = num_requests_waiting_;
912   slots_available_.Broadcast();
913 }
914 
AddRule(const std::string & hostname,AddressFamily family,const AddressList & result,HostResolverFlags flags)915 void MockHostResolverProc::AddRule(const std::string& hostname,
916                                    AddressFamily family,
917                                    const AddressList& result,
918                                    HostResolverFlags flags) {
919   base::AutoLock lock(lock_);
920   rules_[ResolveKey(hostname, family, flags)] = result;
921 }
922 
AddRule(const std::string & hostname,AddressFamily family,const std::string & ip_list,HostResolverFlags flags,const std::string & canonical_name)923 void MockHostResolverProc::AddRule(const std::string& hostname,
924                                    AddressFamily family,
925                                    const std::string& ip_list,
926                                    HostResolverFlags flags,
927                                    const std::string& canonical_name) {
928   AddressList result;
929   std::vector<std::string> dns_aliases;
930   if (canonical_name != "") {
931     dns_aliases = {canonical_name};
932   }
933   int rv = ParseAddressList(ip_list, &result.endpoints());
934   result.SetDnsAliases(dns_aliases);
935   DCHECK_EQ(OK, rv);
936   AddRule(hostname, family, result, flags);
937 }
938 
AddRuleForAllFamilies(const std::string & hostname,const std::string & ip_list,HostResolverFlags flags,const std::string & canonical_name)939 void MockHostResolverProc::AddRuleForAllFamilies(
940     const std::string& hostname,
941     const std::string& ip_list,
942     HostResolverFlags flags,
943     const std::string& canonical_name) {
944   AddressList result;
945   std::vector<std::string> dns_aliases;
946   if (canonical_name != "") {
947     dns_aliases = {canonical_name};
948   }
949   int rv = ParseAddressList(ip_list, &result.endpoints());
950   result.SetDnsAliases(dns_aliases);
951   DCHECK_EQ(OK, rv);
952   AddRule(hostname, ADDRESS_FAMILY_UNSPECIFIED, result, flags);
953   AddRule(hostname, ADDRESS_FAMILY_IPV4, result, flags);
954   AddRule(hostname, ADDRESS_FAMILY_IPV6, result, flags);
955 }
956 
Resolve(const std::string & hostname,AddressFamily address_family,HostResolverFlags host_resolver_flags,AddressList * addrlist,int * os_error)957 int MockHostResolverProc::Resolve(const std::string& hostname,
958                                   AddressFamily address_family,
959                                   HostResolverFlags host_resolver_flags,
960                                   AddressList* addrlist,
961                                   int* os_error) {
962   base::AutoLock lock(lock_);
963   capture_list_.emplace_back(hostname, address_family, host_resolver_flags);
964   ++num_requests_waiting_;
965   requests_waiting_.Broadcast();
966   {
967     base::ScopedAllowBaseSyncPrimitivesForTesting
968         scoped_allow_base_sync_primitives;
969     while (!num_slots_available_) {
970       slots_available_.Wait();
971     }
972   }
973   DCHECK_GT(num_requests_waiting_, 0u);
974   --num_slots_available_;
975   --num_requests_waiting_;
976   if (rules_.empty()) {
977     int rv = ParseAddressList("127.0.0.1", &addrlist->endpoints());
978     DCHECK_EQ(OK, rv);
979     return OK;
980   }
981   ResolveKey key(hostname, address_family, host_resolver_flags);
982   if (rules_.count(key) == 0) {
983     return ERR_NAME_NOT_RESOLVED;
984   }
985   *addrlist = rules_[key];
986   return OK;
987 }
988 
GetCaptureList() const989 MockHostResolverProc::CaptureList MockHostResolverProc::GetCaptureList() const {
990   CaptureList copy;
991   {
992     base::AutoLock lock(lock_);
993     copy = capture_list_;
994   }
995   return copy;
996 }
997 
ClearCaptureList()998 void MockHostResolverProc::ClearCaptureList() {
999   base::AutoLock lock(lock_);
1000   capture_list_.clear();
1001 }
1002 
HasBlockedRequests() const1003 bool MockHostResolverProc::HasBlockedRequests() const {
1004   base::AutoLock lock(lock_);
1005   return num_requests_waiting_ > num_slots_available_;
1006 }
1007 
1008 }  // namespace net
1009