• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/test/test_doh_server.h"
6 
7 #include <string.h>
8 
9 #include <memory>
10 
11 #include "base/base64url.h"
12 #include "base/check.h"
13 #include "base/functional/bind.h"
14 #include "base/logging.h"
15 #include "base/memory/scoped_refptr.h"
16 #include "base/strings/string_number_conversions.h"
17 #include "base/strings/string_piece.h"
18 #include "base/synchronization/lock.h"
19 #include "net/base/io_buffer.h"
20 #include "net/base/url_util.h"
21 #include "net/dns/dns_names_util.h"
22 #include "net/dns/dns_query.h"
23 #include "net/dns/dns_response.h"
24 #include "net/dns/dns_test_util.h"
25 #include "net/dns/public/dns_protocol.h"
26 #include "net/http/http_status_code.h"
27 #include "net/test/embedded_test_server/embedded_test_server.h"
28 #include "net/test/embedded_test_server/http_request.h"
29 #include "net/test/embedded_test_server/http_response.h"
30 #include "url/gurl.h"
31 
32 namespace net {
33 
34 namespace {
35 
// URL path on which the DoH endpoint is served; requests to any other path
// are left for other handlers.
constexpr char kPath[] = "/dns-query";
37 
MakeHttpErrorResponse(HttpStatusCode status,base::StringPiece error)38 std::unique_ptr<test_server::HttpResponse> MakeHttpErrorResponse(
39     HttpStatusCode status,
40     base::StringPiece error) {
41   auto response = std::make_unique<test_server::BasicHttpResponse>();
42   response->set_code(status);
43   response->set_content(std::string(error));
44   response->set_content_type("text/plain;charset=utf-8");
45   return response;
46 }
47 
MakeHttpResponseFromDns(const DnsResponse & dns_response)48 std::unique_ptr<test_server::HttpResponse> MakeHttpResponseFromDns(
49     const DnsResponse& dns_response) {
50   if (!dns_response.IsValid()) {
51     return MakeHttpErrorResponse(HTTP_INTERNAL_SERVER_ERROR,
52                                  "error making DNS response");
53   }
54 
55   auto response = std::make_unique<test_server::BasicHttpResponse>();
56   response->set_code(HTTP_OK);
57   response->set_content(std::string(dns_response.io_buffer()->data(),
58                                     dns_response.io_buffer_size()));
59   response->set_content_type("application/dns-message");
60   return response;
61 }
62 
63 }  // namespace
64 
TestDohServer()65 TestDohServer::TestDohServer() {
66   server_.RegisterRequestHandler(base::BindRepeating(
67       &TestDohServer::HandleRequest, base::Unretained(this)));
68 }
69 
70 TestDohServer::~TestDohServer() = default;
71 
SetHostname(base::StringPiece name)72 void TestDohServer::SetHostname(base::StringPiece name) {
73   DCHECK(!server_.Started());
74   hostname_ = std::string(name);
75 }
76 
SetFailRequests(bool fail_requests)77 void TestDohServer::SetFailRequests(bool fail_requests) {
78   base::AutoLock lock(lock_);
79   fail_requests_ = fail_requests;
80 }
81 
AddAddressRecord(base::StringPiece name,const IPAddress & address,base::TimeDelta ttl)82 void TestDohServer::AddAddressRecord(base::StringPiece name,
83                                      const IPAddress& address,
84                                      base::TimeDelta ttl) {
85   AddRecord(BuildTestAddressRecord(std::string(name), address, ttl));
86 }
87 
AddRecord(const DnsResourceRecord & record)88 void TestDohServer::AddRecord(const DnsResourceRecord& record) {
89   base::AutoLock lock(lock_);
90   records_.insert(
91       std::make_pair(std::make_pair(record.name, record.type), record));
92 }
93 
Start()94 bool TestDohServer::Start() {
95   if (!InitializeAndListen()) {
96     return false;
97   }
98   StartAcceptingConnections();
99   return true;
100 }
101 
InitializeAndListen()102 bool TestDohServer::InitializeAndListen() {
103   if (hostname_) {
104     EmbeddedTestServer::ServerCertificateConfig cert_config;
105     cert_config.dns_names = {*hostname_};
106     server_.SetSSLConfig(cert_config);
107   } else {
108     // `CERT_OK` is valid for 127.0.0.1.
109     server_.SetSSLConfig(EmbeddedTestServer::CERT_OK);
110   }
111   return server_.InitializeAndListen();
112 }
113 
StartAcceptingConnections()114 void TestDohServer::StartAcceptingConnections() {
115   server_.StartAcceptingConnections();
116 }
117 
ShutdownAndWaitUntilComplete()118 bool TestDohServer::ShutdownAndWaitUntilComplete() {
119   return server_.ShutdownAndWaitUntilComplete();
120 }
121 
GetTemplate()122 std::string TestDohServer::GetTemplate() {
123   GURL url =
124       hostname_ ? server_.GetURL(*hostname_, kPath) : server_.GetURL(kPath);
125   return url.spec() + "{?dns}";
126 }
127 
GetPostOnlyTemplate()128 std::string TestDohServer::GetPostOnlyTemplate() {
129   GURL url =
130       hostname_ ? server_.GetURL(*hostname_, kPath) : server_.GetURL(kPath);
131   return url.spec();
132 }
133 
QueriesServed()134 int TestDohServer::QueriesServed() {
135   base::AutoLock lock(lock_);
136   return queries_served_;
137 }
138 
HandleRequest(const test_server::HttpRequest & request)139 std::unique_ptr<test_server::HttpResponse> TestDohServer::HandleRequest(
140     const test_server::HttpRequest& request) {
141   GURL request_url = request.GetURL();
142   if (request_url.path_piece() != kPath) {
143     return nullptr;
144   }
145 
146   base::AutoLock lock(lock_);
147   queries_served_++;
148 
149   if (fail_requests_) {
150     return MakeHttpErrorResponse(HTTP_NOT_FOUND, "failed request");
151   }
152 
153   // See RFC 8484, Section 4.1.
154   std::string query;
155   if (request.method == test_server::METHOD_GET) {
156     std::string query_b64;
157     if (!GetValueForKeyInQuery(request_url, "dns", &query_b64) ||
158         !base::Base64UrlDecode(
159             query_b64, base::Base64UrlDecodePolicy::IGNORE_PADDING, &query)) {
160       return MakeHttpErrorResponse(HTTP_BAD_REQUEST,
161                                    "could not decode query string");
162     }
163   } else if (request.method == test_server::METHOD_POST) {
164     auto content_type = request.headers.find("content-type");
165     if (content_type == request.headers.end() ||
166         content_type->second != "application/dns-message") {
167       return MakeHttpErrorResponse(HTTP_BAD_REQUEST,
168                                    "unsupported content type");
169     }
170     query = request.content;
171   } else {
172     return MakeHttpErrorResponse(HTTP_BAD_REQUEST, "invalid method");
173   }
174 
175   // Parse the DNS query.
176   auto query_buf = base::MakeRefCounted<IOBufferWithSize>(query.size());
177   memcpy(query_buf->data(), query.data(), query.size());
178   DnsQuery dns_query(std::move(query_buf));
179   if (!dns_query.Parse(query.size())) {
180     return MakeHttpErrorResponse(HTTP_BAD_REQUEST, "invalid DNS query");
181   }
182 
183   absl::optional<std::string> name = dns_names_util::NetworkToDottedName(
184       dns_query.qname(), /*require_complete=*/true);
185   if (!name) {
186     DnsResponse response(dns_query.id(), /*is_authoritative=*/false,
187                          /*answers=*/{}, /*authority_records=*/{},
188                          /*additional_records=*/{}, dns_query,
189                          dns_protocol::kRcodeFORMERR);
190     return MakeHttpResponseFromDns(response);
191   }
192 
193   auto range = records_.equal_range(std::make_pair(*name, dns_query.qtype()));
194   std::vector<DnsResourceRecord> answers;
195   for (auto i = range.first; i != range.second; ++i) {
196     answers.push_back(i->second);
197   }
198 
199   VLOG(1) << "Serving " << answers.size() << " records for " << *name
200           << ", qtype " << dns_query.qtype();
201 
202   // Note `answers` may be empty. NOERROR with no answers is how to express
203   // NODATA, so there is no need handle it specially.
204   //
205   // For now, this server does not support configuring additional records. When
206   // testing more complex HTTPS record cases, this will need to be extended.
207   //
208   // TODO(crbug.com/1251204): Add SOA records to test the default TTL.
209   DnsResponse response(dns_query.id(), /*is_authoritative=*/true,
210                        /*answers=*/answers, /*authority_records=*/{},
211                        /*additional_records=*/{}, dns_query);
212   return MakeHttpResponseFromDns(response);
213 }
214 
215 }  // namespace net
216