1 // Copyright 2021 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/test/test_doh_server.h"
6
7 #include <string.h>
8
9 #include <memory>
10 #include <string_view>
11
12 #include "base/base64url.h"
13 #include "base/check.h"
14 #include "base/functional/bind.h"
15 #include "base/logging.h"
16 #include "base/memory/scoped_refptr.h"
17 #include "base/ranges/algorithm.h"
18 #include "base/strings/string_number_conversions.h"
19 #include "base/synchronization/lock.h"
20 #include "base/time/time.h"
21 #include "net/base/io_buffer.h"
22 #include "net/base/url_util.h"
23 #include "net/dns/dns_names_util.h"
24 #include "net/dns/dns_query.h"
25 #include "net/dns/dns_response.h"
26 #include "net/dns/dns_test_util.h"
27 #include "net/dns/public/dns_protocol.h"
28 #include "net/http/http_status_code.h"
29 #include "net/test/embedded_test_server/embedded_test_server.h"
30 #include "net/test/embedded_test_server/http_request.h"
31 #include "net/test/embedded_test_server/http_response.h"
32 #include "url/gurl.h"
33
34 namespace net {
35
36 namespace {
37
// URL path of the DoH endpoint. HandleRequest() declines (returns nullptr for)
// requests whose path differs from this.
constexpr char kPath[] = "/dns-query";
39
MakeHttpErrorResponse(HttpStatusCode status,std::string_view error)40 std::unique_ptr<test_server::HttpResponse> MakeHttpErrorResponse(
41 HttpStatusCode status,
42 std::string_view error) {
43 auto response = std::make_unique<test_server::BasicHttpResponse>();
44 response->set_code(status);
45 response->set_content(std::string(error));
46 response->set_content_type("text/plain;charset=utf-8");
47 return response;
48 }
49
MakeHttpResponseFromDns(const DnsResponse & dns_response)50 std::unique_ptr<test_server::HttpResponse> MakeHttpResponseFromDns(
51 const DnsResponse& dns_response) {
52 if (!dns_response.IsValid()) {
53 return MakeHttpErrorResponse(HTTP_INTERNAL_SERVER_ERROR,
54 "error making DNS response");
55 }
56
57 auto response = std::make_unique<test_server::BasicHttpResponse>();
58 response->set_code(HTTP_OK);
59 response->set_content(std::string(dns_response.io_buffer()->data(),
60 dns_response.io_buffer_size()));
61 response->set_content_type("application/dns-message");
62 return response;
63 }
64
65 } // namespace
66
TestDohServer()67 TestDohServer::TestDohServer() {
68 server_.RegisterRequestHandler(base::BindRepeating(
69 &TestDohServer::HandleRequest, base::Unretained(this)));
70 }
71
72 TestDohServer::~TestDohServer() = default;
73
SetHostname(std::string_view name)74 void TestDohServer::SetHostname(std::string_view name) {
75 DCHECK(!server_.Started());
76 hostname_ = std::string(name);
77 }
78
SetFailRequests(bool fail_requests)79 void TestDohServer::SetFailRequests(bool fail_requests) {
80 base::AutoLock lock(lock_);
81 fail_requests_ = fail_requests;
82 }
83
AddAddressRecord(std::string_view name,const IPAddress & address,base::TimeDelta ttl)84 void TestDohServer::AddAddressRecord(std::string_view name,
85 const IPAddress& address,
86 base::TimeDelta ttl) {
87 AddRecord(BuildTestAddressRecord(std::string(name), address, ttl));
88 }
89
AddRecord(const DnsResourceRecord & record)90 void TestDohServer::AddRecord(const DnsResourceRecord& record) {
91 base::AutoLock lock(lock_);
92 records_.insert(
93 std::make_pair(std::make_pair(record.name, record.type), record));
94 }
95
Start()96 bool TestDohServer::Start() {
97 if (!InitializeAndListen()) {
98 return false;
99 }
100 StartAcceptingConnections();
101 return true;
102 }
103
InitializeAndListen()104 bool TestDohServer::InitializeAndListen() {
105 if (hostname_) {
106 EmbeddedTestServer::ServerCertificateConfig cert_config;
107 cert_config.dns_names = {*hostname_};
108 server_.SetSSLConfig(cert_config);
109 } else {
110 // `CERT_OK` is valid for 127.0.0.1.
111 server_.SetSSLConfig(EmbeddedTestServer::CERT_OK);
112 }
113 return server_.InitializeAndListen();
114 }
115
StartAcceptingConnections()116 void TestDohServer::StartAcceptingConnections() {
117 server_.StartAcceptingConnections();
118 }
119
ShutdownAndWaitUntilComplete()120 bool TestDohServer::ShutdownAndWaitUntilComplete() {
121 return server_.ShutdownAndWaitUntilComplete();
122 }
123
GetTemplate()124 std::string TestDohServer::GetTemplate() {
125 GURL url =
126 hostname_ ? server_.GetURL(*hostname_, kPath) : server_.GetURL(kPath);
127 return url.spec() + "{?dns}";
128 }
129
GetPostOnlyTemplate()130 std::string TestDohServer::GetPostOnlyTemplate() {
131 GURL url =
132 hostname_ ? server_.GetURL(*hostname_, kPath) : server_.GetURL(kPath);
133 return url.spec();
134 }
135
QueriesServed()136 int TestDohServer::QueriesServed() {
137 base::AutoLock lock(lock_);
138 return queries_served_;
139 }
140
QueriesServedForSubdomains(std::string_view domain)141 int TestDohServer::QueriesServedForSubdomains(std::string_view domain) {
142 CHECK(net::dns_names_util::IsValidDnsName(domain));
143 auto is_subdomain = [&domain](std::string_view candidate) {
144 return net::IsSubdomainOf(candidate, domain);
145 };
146 base::AutoLock lock(lock_);
147 return base::ranges::count_if(query_qnames_, is_subdomain);
148 }
149
HandleRequest(const test_server::HttpRequest & request)150 std::unique_ptr<test_server::HttpResponse> TestDohServer::HandleRequest(
151 const test_server::HttpRequest& request) {
152 GURL request_url = request.GetURL();
153 if (request_url.path_piece() != kPath) {
154 return nullptr;
155 }
156
157 base::AutoLock lock(lock_);
158 queries_served_++;
159
160 if (fail_requests_) {
161 return MakeHttpErrorResponse(HTTP_NOT_FOUND, "failed request");
162 }
163
164 // See RFC 8484, Section 4.1.
165 std::string query;
166 if (request.method == test_server::METHOD_GET) {
167 std::string query_b64;
168 if (!GetValueForKeyInQuery(request_url, "dns", &query_b64) ||
169 !base::Base64UrlDecode(
170 query_b64, base::Base64UrlDecodePolicy::IGNORE_PADDING, &query)) {
171 return MakeHttpErrorResponse(HTTP_BAD_REQUEST,
172 "could not decode query string");
173 }
174 } else if (request.method == test_server::METHOD_POST) {
175 auto content_type = request.headers.find("content-type");
176 if (content_type == request.headers.end() ||
177 content_type->second != "application/dns-message") {
178 return MakeHttpErrorResponse(HTTP_BAD_REQUEST,
179 "unsupported content type");
180 }
181 query = request.content;
182 } else {
183 return MakeHttpErrorResponse(HTTP_BAD_REQUEST, "invalid method");
184 }
185
186 // Parse the DNS query.
187 auto query_buf = base::MakeRefCounted<IOBufferWithSize>(query.size());
188 memcpy(query_buf->data(), query.data(), query.size());
189 DnsQuery dns_query(std::move(query_buf));
190 if (!dns_query.Parse(query.size())) {
191 return MakeHttpErrorResponse(HTTP_BAD_REQUEST, "invalid DNS query");
192 }
193
194 absl::optional<std::string> name = dns_names_util::NetworkToDottedName(
195 dns_query.qname(), /*require_complete=*/true);
196 if (!name) {
197 DnsResponse response(dns_query.id(), /*is_authoritative=*/false,
198 /*answers=*/{}, /*authority_records=*/{},
199 /*additional_records=*/{}, dns_query,
200 dns_protocol::kRcodeFORMERR);
201 return MakeHttpResponseFromDns(response);
202 }
203 query_qnames_.push_back(*name);
204
205 auto range = records_.equal_range(std::make_pair(*name, dns_query.qtype()));
206 std::vector<DnsResourceRecord> answers;
207 for (auto i = range.first; i != range.second; ++i) {
208 answers.push_back(i->second);
209 }
210
211 LOG(INFO) << "Serving " << answers.size() << " records for " << *name
212 << ", qtype " << dns_query.qtype();
213
214 // Note `answers` may be empty. NOERROR with no answers is how to express
215 // NODATA, so there is no need handle it specially.
216 //
217 // For now, this server does not support configuring additional records. When
218 // testing more complex HTTPS record cases, this will need to be extended.
219 //
220 // TODO(crbug.com/1251204): Add SOA records to test the default TTL.
221 DnsResponse response(dns_query.id(), /*is_authoritative=*/true,
222 /*answers=*/answers, /*authority_records=*/{},
223 /*additional_records=*/{}, dns_query);
224 return MakeHttpResponseFromDns(response);
225 }
226
227 } // namespace net
228