1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/test/embedded_test_server/embedded_test_server.h"
6
7 #include <stdint.h>
8
9 #include <memory>
10 #include <utility>
11
12 #include "base/files/file_path.h"
13 #include "base/functional/bind.h"
14 #include "base/functional/callback_forward.h"
15 #include "base/functional/callback_helpers.h"
16 #include "base/location.h"
17 #include "base/logging.h"
18 #include "base/message_loop/message_pump_type.h"
19 #include "base/path_service.h"
20 #include "base/process/process_metrics.h"
21 #include "base/run_loop.h"
22 #include "base/strings/string_number_conversions.h"
23 #include "base/strings/string_piece.h"
24 #include "base/strings/string_util.h"
25 #include "base/strings/stringprintf.h"
26 #include "base/task/current_thread.h"
27 #include "base/task/single_thread_task_executor.h"
28 #include "base/task/single_thread_task_runner.h"
29 #include "base/test/bind.h"
30 #include "base/threading/thread_restrictions.h"
31 #include "crypto/rsa_private_key.h"
32 #include "net/base/hex_utils.h"
33 #include "net/base/ip_endpoint.h"
34 #include "net/base/net_errors.h"
35 #include "net/base/port_util.h"
36 #include "net/log/net_log_source.h"
37 #include "net/socket/next_proto.h"
38 #include "net/socket/ssl_server_socket.h"
39 #include "net/socket/stream_socket.h"
40 #include "net/socket/tcp_server_socket.h"
41 #include "net/spdy/spdy_test_util_common.h"
42 #include "net/ssl/ssl_info.h"
43 #include "net/ssl/ssl_server_config.h"
44 #include "net/test/cert_builder.h"
45 #include "net/test/cert_test_util.h"
46 #include "net/test/embedded_test_server/default_handlers.h"
47 #include "net/test/embedded_test_server/embedded_test_server_connection_listener.h"
48 #include "net/test/embedded_test_server/http_request.h"
49 #include "net/test/embedded_test_server/http_response.h"
50 #include "net/test/embedded_test_server/request_handler_util.h"
51 #include "net/test/key_util.h"
52 #include "net/test/revocation_builder.h"
53 #include "net/test/test_data_directory.h"
54 #include "net/third_party/quiche/src/quiche/spdy/core/spdy_frame_builder.h"
55 #include "third_party/abseil-cpp/absl/types/optional.h"
56 #include "third_party/boringssl/src/pki/extended_key_usage.h"
57 #include "url/origin.h"
58
59 namespace net::test_server {
60
61 namespace {
62
ServeResponseForPath(const std::string & expected_path,HttpStatusCode status_code,const std::string & content_type,const std::string & content,const HttpRequest & request)63 std::unique_ptr<HttpResponse> ServeResponseForPath(
64 const std::string& expected_path,
65 HttpStatusCode status_code,
66 const std::string& content_type,
67 const std::string& content,
68 const HttpRequest& request) {
69 if (request.GetURL().path() != expected_path)
70 return nullptr;
71
72 auto http_response = std::make_unique<BasicHttpResponse>();
73 http_response->set_code(status_code);
74 http_response->set_content_type(content_type);
75 http_response->set_content(content);
76 return http_response;
77 }
78
79 // Serves response for |expected_path| or any subpath of it.
80 // |expected_path| should not include a trailing "/".
ServeResponseForSubPaths(const std::string & expected_path,HttpStatusCode status_code,const std::string & content_type,const std::string & content,const HttpRequest & request)81 std::unique_ptr<HttpResponse> ServeResponseForSubPaths(
82 const std::string& expected_path,
83 HttpStatusCode status_code,
84 const std::string& content_type,
85 const std::string& content,
86 const HttpRequest& request) {
87 if (request.GetURL().path() != expected_path &&
88 !request.GetURL().path().starts_with(expected_path + "/")) {
89 return nullptr;
90 }
91
92 auto http_response = std::make_unique<BasicHttpResponse>();
93 http_response->set_code(status_code);
94 http_response->set_content_type(content_type);
95 http_response->set_content(content);
96 return http_response;
97 }
98
MaybeCreateOCSPResponse(CertBuilder * target,const EmbeddedTestServer::OCSPConfig & config,std::string * out_response)99 bool MaybeCreateOCSPResponse(CertBuilder* target,
100 const EmbeddedTestServer::OCSPConfig& config,
101 std::string* out_response) {
102 using OCSPResponseType = EmbeddedTestServer::OCSPConfig::ResponseType;
103
104 if (!config.single_responses.empty() &&
105 config.response_type != OCSPResponseType::kSuccessful) {
106 // OCSPConfig contained single_responses for a non-successful response.
107 return false;
108 }
109
110 if (config.response_type == OCSPResponseType::kOff) {
111 *out_response = std::string();
112 return true;
113 }
114
115 if (!target) {
116 // OCSPConfig enabled but corresponding certificate is null.
117 return false;
118 }
119
120 switch (config.response_type) {
121 case OCSPResponseType::kOff:
122 return false;
123 case OCSPResponseType::kMalformedRequest:
124 *out_response = BuildOCSPResponseError(
125 bssl::OCSPResponse::ResponseStatus::MALFORMED_REQUEST);
126 return true;
127 case OCSPResponseType::kInternalError:
128 *out_response = BuildOCSPResponseError(
129 bssl::OCSPResponse::ResponseStatus::INTERNAL_ERROR);
130 return true;
131 case OCSPResponseType::kTryLater:
132 *out_response =
133 BuildOCSPResponseError(bssl::OCSPResponse::ResponseStatus::TRY_LATER);
134 return true;
135 case OCSPResponseType::kSigRequired:
136 *out_response = BuildOCSPResponseError(
137 bssl::OCSPResponse::ResponseStatus::SIG_REQUIRED);
138 return true;
139 case OCSPResponseType::kUnauthorized:
140 *out_response = BuildOCSPResponseError(
141 bssl::OCSPResponse::ResponseStatus::UNAUTHORIZED);
142 return true;
143 case OCSPResponseType::kInvalidResponse:
144 *out_response = "3";
145 return true;
146 case OCSPResponseType::kInvalidResponseData:
147 *out_response =
148 BuildOCSPResponseWithResponseData(target->issuer()->GetKey(),
149 // OCTET_STRING { "not ocsp data" }
150 "\x04\x0dnot ocsp data");
151 return true;
152 case OCSPResponseType::kSuccessful:
153 break;
154 }
155
156 base::Time now = base::Time::Now();
157 base::Time target_not_before, target_not_after;
158 if (!target->GetValidity(&target_not_before, &target_not_after))
159 return false;
160 base::Time produced_at;
161 using OCSPProduced = EmbeddedTestServer::OCSPConfig::Produced;
162 switch (config.produced) {
163 case OCSPProduced::kValid:
164 produced_at = std::max(now - base::Days(1), target_not_before);
165 break;
166 case OCSPProduced::kBeforeCert:
167 produced_at = target_not_before - base::Days(1);
168 break;
169 case OCSPProduced::kAfterCert:
170 produced_at = target_not_after + base::Days(1);
171 break;
172 }
173
174 std::vector<OCSPBuilderSingleResponse> responses;
175 for (const auto& config_response : config.single_responses) {
176 OCSPBuilderSingleResponse response;
177 response.serial = target->GetSerialNumber();
178 if (config_response.serial ==
179 EmbeddedTestServer::OCSPConfig::SingleResponse::Serial::kMismatch) {
180 response.serial ^= 1;
181 }
182 response.cert_status = config_response.cert_status;
183 // |revocation_time| is ignored if |cert_status| is not REVOKED.
184 response.revocation_time = now - base::Days(1000);
185
186 using OCSPDate = EmbeddedTestServer::OCSPConfig::SingleResponse::Date;
187 switch (config_response.ocsp_date) {
188 case OCSPDate::kValid:
189 response.this_update = now - base::Days(1);
190 response.next_update = response.this_update + base::Days(7);
191 break;
192 case OCSPDate::kOld:
193 response.this_update = now - base::Days(8);
194 response.next_update = response.this_update + base::Days(7);
195 break;
196 case OCSPDate::kEarly:
197 response.this_update = now + base::Days(1);
198 response.next_update = response.this_update + base::Days(7);
199 break;
200 case OCSPDate::kLong:
201 response.this_update = now - base::Days(365);
202 response.next_update = response.this_update + base::Days(366);
203 break;
204 case OCSPDate::kLonger:
205 response.this_update = now - base::Days(367);
206 response.next_update = response.this_update + base::Days(368);
207 break;
208 }
209
210 responses.push_back(response);
211 }
212 *out_response =
213 BuildOCSPResponse(target->issuer()->GetSubject(),
214 target->issuer()->GetKey(), produced_at, responses);
215 return true;
216 }
217
218 } // namespace
219
EmbeddedTestServerHandle(EmbeddedTestServerHandle && other)220 EmbeddedTestServerHandle::EmbeddedTestServerHandle(
221 EmbeddedTestServerHandle&& other) {
222 operator=(std::move(other));
223 }
224
operator =(EmbeddedTestServerHandle && other)225 EmbeddedTestServerHandle& EmbeddedTestServerHandle::operator=(
226 EmbeddedTestServerHandle&& other) {
227 EmbeddedTestServerHandle temporary;
228 std::swap(other.test_server_, temporary.test_server_);
229 std::swap(temporary.test_server_, test_server_);
230 return *this;
231 }
232
EmbeddedTestServerHandle(EmbeddedTestServer * test_server)233 EmbeddedTestServerHandle::EmbeddedTestServerHandle(
234 EmbeddedTestServer* test_server)
235 : test_server_(test_server) {}
236
~EmbeddedTestServerHandle()237 EmbeddedTestServerHandle::~EmbeddedTestServerHandle() {
238 if (test_server_)
239 CHECK(test_server_->ShutdownAndWaitUntilComplete());
240 }
241
242 EmbeddedTestServer::OCSPConfig::OCSPConfig() = default;
OCSPConfig(ResponseType response_type)243 EmbeddedTestServer::OCSPConfig::OCSPConfig(ResponseType response_type)
244 : response_type(response_type) {}
OCSPConfig(std::vector<SingleResponse> single_responses,Produced produced)245 EmbeddedTestServer::OCSPConfig::OCSPConfig(
246 std::vector<SingleResponse> single_responses,
247 Produced produced)
248 : response_type(ResponseType::kSuccessful),
249 produced(produced),
250 single_responses(std::move(single_responses)) {}
251 EmbeddedTestServer::OCSPConfig::OCSPConfig(const OCSPConfig&) = default;
252 EmbeddedTestServer::OCSPConfig::OCSPConfig(OCSPConfig&&) = default;
253 EmbeddedTestServer::OCSPConfig::~OCSPConfig() = default;
254 EmbeddedTestServer::OCSPConfig& EmbeddedTestServer::OCSPConfig::operator=(
255 const OCSPConfig&) = default;
256 EmbeddedTestServer::OCSPConfig& EmbeddedTestServer::OCSPConfig::operator=(
257 OCSPConfig&&) = default;
258
259 EmbeddedTestServer::ServerCertificateConfig::ServerCertificateConfig() =
260 default;
261 EmbeddedTestServer::ServerCertificateConfig::ServerCertificateConfig(
262 const ServerCertificateConfig&) = default;
263 EmbeddedTestServer::ServerCertificateConfig::ServerCertificateConfig(
264 ServerCertificateConfig&&) = default;
265 EmbeddedTestServer::ServerCertificateConfig::~ServerCertificateConfig() =
266 default;
267 EmbeddedTestServer::ServerCertificateConfig&
268 EmbeddedTestServer::ServerCertificateConfig::operator=(
269 const ServerCertificateConfig&) = default;
270 EmbeddedTestServer::ServerCertificateConfig&
271 EmbeddedTestServer::ServerCertificateConfig::operator=(
272 ServerCertificateConfig&&) = default;
273
EmbeddedTestServer()274 EmbeddedTestServer::EmbeddedTestServer() : EmbeddedTestServer(TYPE_HTTP) {}
275
EmbeddedTestServer(Type type,HttpConnection::Protocol protocol)276 EmbeddedTestServer::EmbeddedTestServer(Type type,
277 HttpConnection::Protocol protocol)
278 : is_using_ssl_(type == TYPE_HTTPS), protocol_(protocol) {
279 DCHECK(thread_checker_.CalledOnValidThread());
280 // HTTP/2 is only valid by negotiation via TLS ALPN
281 DCHECK(protocol_ != HttpConnection::Protocol::kHttp2 || type == TYPE_HTTPS);
282
283 if (!is_using_ssl_)
284 return;
285 scoped_test_root_ = RegisterTestCerts();
286 }
287
~EmbeddedTestServer()288 EmbeddedTestServer::~EmbeddedTestServer() {
289 DCHECK(thread_checker_.CalledOnValidThread());
290
291 if (Started())
292 CHECK(ShutdownAndWaitUntilComplete());
293
294 {
295 base::ScopedAllowBaseSyncPrimitivesForTesting allow_wait_for_thread_join;
296 io_thread_.reset();
297 }
298 }
299
RegisterTestCerts()300 ScopedTestRoot EmbeddedTestServer::RegisterTestCerts() {
301 base::ScopedAllowBlockingForTesting allow_blocking;
302 auto root = ImportCertFromFile(GetRootCertPemPath());
303 if (!root)
304 return ScopedTestRoot();
305 return ScopedTestRoot(root);
306 }
307
SetConnectionListener(EmbeddedTestServerConnectionListener * listener)308 void EmbeddedTestServer::SetConnectionListener(
309 EmbeddedTestServerConnectionListener* listener) {
310 DCHECK(!io_thread_)
311 << "ConnectionListener must be set before starting the server.";
312 connection_listener_ = listener;
313 }
314
StartAndReturnHandle(int port)315 EmbeddedTestServerHandle EmbeddedTestServer::StartAndReturnHandle(int port) {
316 bool result = Start(port);
317 return result ? EmbeddedTestServerHandle(this) : EmbeddedTestServerHandle();
318 }
319
Start(int port,base::StringPiece address)320 bool EmbeddedTestServer::Start(int port, base::StringPiece address) {
321 bool success = InitializeAndListen(port, address);
322 if (success)
323 StartAcceptingConnections();
324 return success;
325 }
326
InitializeAndListen(int port,base::StringPiece address)327 bool EmbeddedTestServer::InitializeAndListen(int port,
328 base::StringPiece address) {
329 DCHECK(!Started());
330
331 const int max_tries = 5;
332 int num_tries = 0;
333 bool is_valid_port = false;
334
335 do {
336 if (++num_tries > max_tries) {
337 LOG(ERROR) << "Failed to listen on a valid port after " << max_tries
338 << " attempts.";
339 listen_socket_.reset();
340 return false;
341 }
342
343 listen_socket_ = std::make_unique<TCPServerSocket>(nullptr, NetLogSource());
344
345 int result =
346 listen_socket_->ListenWithAddressAndPort(address.data(), port, 10);
347 if (result) {
348 LOG(ERROR) << "Listen failed: " << ErrorToString(result);
349 listen_socket_.reset();
350 return false;
351 }
352
353 result = listen_socket_->GetLocalAddress(&local_endpoint_);
354 if (result != OK) {
355 LOG(ERROR) << "GetLocalAddress failed: " << ErrorToString(result);
356 listen_socket_.reset();
357 return false;
358 }
359
360 port_ = local_endpoint_.port();
361 is_valid_port |= net::IsPortAllowedForScheme(
362 port_, is_using_ssl_ ? url::kHttpsScheme : url::kHttpScheme);
363 } while (!is_valid_port);
364
365 if (is_using_ssl_) {
366 base_url_ = GURL("https://" + local_endpoint_.ToString());
367 if (cert_ == CERT_MISMATCHED_NAME || cert_ == CERT_COMMON_NAME_IS_DOMAIN) {
368 base_url_ = GURL(
369 base::StringPrintf("https://localhost:%d", local_endpoint_.port()));
370 }
371 } else {
372 base_url_ = GURL("http://" + local_endpoint_.ToString());
373 }
374
375 listen_socket_->DetachFromThread();
376
377 if (is_using_ssl_ && !InitializeSSLServerContext())
378 return false;
379
380 return true;
381 }
382
UsingStaticCert() const383 bool EmbeddedTestServer::UsingStaticCert() const {
384 return !GetCertificateName().empty();
385 }
386
InitializeCertAndKeyFromFile()387 bool EmbeddedTestServer::InitializeCertAndKeyFromFile() {
388 base::ScopedAllowBlockingForTesting allow_blocking;
389 base::FilePath certs_dir(GetTestCertsDirectory());
390 std::string cert_name = GetCertificateName();
391 if (cert_name.empty())
392 return false;
393
394 x509_cert_ = CreateCertificateChainFromFile(certs_dir, cert_name,
395 X509Certificate::FORMAT_AUTO);
396 if (!x509_cert_)
397 return false;
398
399 private_key_ =
400 key_util::LoadEVP_PKEYFromPEM(certs_dir.AppendASCII(cert_name));
401 return !!private_key_;
402 }
403
GenerateCertAndKey()404 bool EmbeddedTestServer::GenerateCertAndKey() {
405 // Create AIA server and start listening. Need to have the socket initialized
406 // so the URL can be put in the AIA records of the generated certs.
407 aia_http_server_ = std::make_unique<EmbeddedTestServer>(TYPE_HTTP);
408 if (!aia_http_server_->InitializeAndListen())
409 return false;
410
411 base::ScopedAllowBlockingForTesting allow_blocking;
412 base::FilePath certs_dir(GetTestCertsDirectory());
413
414 std::unique_ptr<CertBuilder> static_root = CertBuilder::FromStaticCertFile(
415 certs_dir.AppendASCII("root_ca_cert.pem"));
416
417 auto now = base::Time::Now();
418 // Will be nullptr if cert_config_.intermediate == kNone.
419 std::unique_ptr<CertBuilder> intermediate;
420 std::unique_ptr<CertBuilder> leaf;
421
422 if (cert_config_.intermediate != IntermediateType::kNone) {
423 intermediate = CertBuilder::FromFile(
424 certs_dir.AppendASCII("intermediate_ca_cert.pem"), static_root.get());
425 if (!intermediate)
426 return false;
427 intermediate->SetValidity(now - base::Days(100), now + base::Days(1000));
428
429 leaf = CertBuilder::FromFile(certs_dir.AppendASCII("ok_cert.pem"),
430 intermediate.get());
431 } else {
432 leaf = CertBuilder::FromFile(certs_dir.AppendASCII("ok_cert.pem"),
433 static_root.get());
434 }
435 if (!leaf)
436 return false;
437
438 std::vector<GURL> leaf_ca_issuers_urls;
439 std::vector<GURL> leaf_ocsp_urls;
440
441 leaf->SetValidity(now - base::Days(1), now + base::Days(20));
442
443 if (!cert_config_.policy_oids.empty()) {
444 leaf->SetCertificatePolicies(cert_config_.policy_oids);
445 if (intermediate)
446 intermediate->SetCertificatePolicies(cert_config_.policy_oids);
447 }
448
449 if (!cert_config_.dns_names.empty() || !cert_config_.ip_addresses.empty()) {
450 leaf->SetSubjectAltNames(cert_config_.dns_names, cert_config_.ip_addresses);
451 }
452
453 if (!cert_config_.key_usages.empty()) {
454 leaf->SetKeyUsages(cert_config_.key_usages);
455 }
456
457 if (!cert_config_.embedded_scts.empty()) {
458 leaf->SetSctConfig(cert_config_.embedded_scts);
459 }
460
461 const std::string leaf_serial_text =
462 base::NumberToString(leaf->GetSerialNumber());
463 const std::string intermediate_serial_text =
464 intermediate ? base::NumberToString(intermediate->GetSerialNumber()) : "";
465
466 std::string ocsp_response;
467 if (!MaybeCreateOCSPResponse(leaf.get(), cert_config_.ocsp_config,
468 &ocsp_response)) {
469 return false;
470 }
471 if (!ocsp_response.empty()) {
472 std::string ocsp_path = "/ocsp/" + leaf_serial_text;
473 leaf_ocsp_urls.push_back(aia_http_server_->GetURL(ocsp_path));
474 aia_http_server_->RegisterRequestHandler(
475 base::BindRepeating(ServeResponseForSubPaths, ocsp_path, HTTP_OK,
476 "application/ocsp-response", ocsp_response));
477 }
478
479 std::string stapled_ocsp_response;
480 if (!MaybeCreateOCSPResponse(leaf.get(), cert_config_.stapled_ocsp_config,
481 &stapled_ocsp_response)) {
482 return false;
483 }
484 if (!stapled_ocsp_response.empty()) {
485 ssl_config_.ocsp_response = std::vector<uint8_t>(
486 stapled_ocsp_response.begin(), stapled_ocsp_response.end());
487 }
488
489 std::string intermediate_ocsp_response;
490 if (!MaybeCreateOCSPResponse(intermediate.get(),
491 cert_config_.intermediate_ocsp_config,
492 &intermediate_ocsp_response)) {
493 return false;
494 }
495 if (!intermediate_ocsp_response.empty()) {
496 std::string intermediate_ocsp_path = "/ocsp/" + intermediate_serial_text;
497 intermediate->SetCaIssuersAndOCSPUrls(
498 {}, {aia_http_server_->GetURL(intermediate_ocsp_path)});
499 aia_http_server_->RegisterRequestHandler(base::BindRepeating(
500 ServeResponseForSubPaths, intermediate_ocsp_path, HTTP_OK,
501 "application/ocsp-response", intermediate_ocsp_response));
502 }
503
504 if (cert_config_.intermediate == IntermediateType::kByAIA) {
505 std::string ca_issuers_path = "/ca_issuers/" + intermediate_serial_text;
506 leaf_ca_issuers_urls.push_back(aia_http_server_->GetURL(ca_issuers_path));
507
508 // Setup AIA server to serve the intermediate referred to by the leaf.
509 aia_http_server_->RegisterRequestHandler(
510 base::BindRepeating(ServeResponseForPath, ca_issuers_path, HTTP_OK,
511 "application/pkix-cert", intermediate->GetDER()));
512 }
513
514 if (!leaf_ca_issuers_urls.empty() || !leaf_ocsp_urls.empty()) {
515 leaf->SetCaIssuersAndOCSPUrls(leaf_ca_issuers_urls, leaf_ocsp_urls);
516 }
517
518 if (cert_config_.intermediate == IntermediateType::kByAIA) {
519 // Server certificate chain does not include the intermediate.
520 x509_cert_ = leaf->GetX509Certificate();
521 } else {
522 // Server certificate chain will include the intermediate, if there is one.
523 x509_cert_ = leaf->GetX509CertificateChain();
524 }
525
526 private_key_ = bssl::UpRef(leaf->GetKey());
527
528 // If this server is already accepting connections but is being reconfigured,
529 // start the new AIA server now. Otherwise, wait until
530 // StartAcceptingConnections so that this server and the AIA server start at
531 // the same time. (If the test only called InitializeAndListen they expect no
532 // threads to be created yet.)
533 if (io_thread_)
534 aia_http_server_->StartAcceptingConnections();
535
536 return true;
537 }
538
InitializeSSLServerContext()539 bool EmbeddedTestServer::InitializeSSLServerContext() {
540 if (UsingStaticCert()) {
541 if (!InitializeCertAndKeyFromFile())
542 return false;
543 } else {
544 if (!GenerateCertAndKey())
545 return false;
546 }
547
548 if (protocol_ == HttpConnection::Protocol::kHttp2) {
549 ssl_config_.alpn_protos = {NextProto::kProtoHTTP2};
550 if (!alps_accept_ch_.empty()) {
551 base::StringPairs origin_accept_ch;
552 size_t frame_size = spdy::kFrameHeaderSize;
553 // Figure out size and generate origins
554 for (const auto& pair : alps_accept_ch_) {
555 base::StringPiece hostname = pair.first;
556 std::string accept_ch = pair.second;
557
558 GURL url = hostname.empty() ? GetURL("/") : GetURL(hostname, "/");
559 std::string origin = url::Origin::Create(url).Serialize();
560
561 frame_size += accept_ch.size() + origin.size() +
562 (sizeof(uint16_t) * 2); // = Origin-Len + Value-Len
563
564 origin_accept_ch.push_back({std::move(origin), std::move(accept_ch)});
565 }
566
567 spdy::SpdyFrameBuilder builder(frame_size);
568 builder.BeginNewFrame(spdy::SpdyFrameType::ACCEPT_CH, 0, 0);
569 for (const auto& pair : origin_accept_ch) {
570 base::StringPiece origin = pair.first;
571 base::StringPiece accept_ch = pair.second;
572
573 builder.WriteUInt16(origin.size());
574 builder.WriteBytes(origin.data(), origin.size());
575
576 builder.WriteUInt16(accept_ch.size());
577 builder.WriteBytes(accept_ch.data(), accept_ch.size());
578 }
579
580 spdy::SpdySerializedFrame serialized_frame = builder.take();
581 DCHECK_EQ(frame_size, serialized_frame.size());
582
583 ssl_config_.application_settings[NextProto::kProtoHTTP2] =
584 std::vector<uint8_t>(
585 serialized_frame.data(),
586 serialized_frame.data() + serialized_frame.size());
587
588 ssl_config_.client_hello_callback_for_testing =
589 base::BindRepeating([](const SSL_CLIENT_HELLO* client_hello) {
590 // Configure the server to use the ALPS codepoint that the client
591 // offered.
592 const uint8_t* unused_extension_bytes;
593 size_t unused_extension_len;
594 int use_alps_new_codepoint = SSL_early_callback_ctx_extension_get(
595 client_hello, TLSEXT_TYPE_application_settings,
596 &unused_extension_bytes, &unused_extension_len);
597 // Make sure we use the right ALPS codepoint.
598 SSL_set_alps_use_new_codepoint(client_hello->ssl,
599 use_alps_new_codepoint);
600 return true;
601 });
602 }
603 }
604
605 context_ =
606 CreateSSLServerContext(x509_cert_.get(), private_key_.get(), ssl_config_);
607 return true;
608 }
609
610 EmbeddedTestServerHandle
StartAcceptingConnectionsAndReturnHandle()611 EmbeddedTestServer::StartAcceptingConnectionsAndReturnHandle() {
612 StartAcceptingConnections();
613 return EmbeddedTestServerHandle(this);
614 }
615
StartAcceptingConnections()616 void EmbeddedTestServer::StartAcceptingConnections() {
617 DCHECK(Started());
618 DCHECK(!io_thread_) << "Server must not be started while server is running";
619
620 if (aia_http_server_)
621 aia_http_server_->StartAcceptingConnections();
622
623 base::Thread::Options thread_options;
624 thread_options.message_pump_type = base::MessagePumpType::IO;
625 io_thread_ = std::make_unique<base::Thread>("EmbeddedTestServer IO Thread");
626 CHECK(io_thread_->StartWithOptions(std::move(thread_options)));
627 CHECK(io_thread_->WaitUntilThreadStarted());
628
629 io_thread_->task_runner()->PostTask(
630 FROM_HERE, base::BindOnce(&EmbeddedTestServer::DoAcceptLoop,
631 base::Unretained(this)));
632 }
633
ShutdownAndWaitUntilComplete()634 bool EmbeddedTestServer::ShutdownAndWaitUntilComplete() {
635 DCHECK(thread_checker_.CalledOnValidThread());
636
637 // Ensure that the AIA HTTP server is no longer Started().
638 bool aia_http_server_not_started = true;
639 if (aia_http_server_ && aia_http_server_->Started()) {
640 aia_http_server_not_started =
641 aia_http_server_->ShutdownAndWaitUntilComplete();
642 }
643
644 // Return false if either this or the AIA HTTP server are still Started().
645 return PostTaskToIOThreadAndWait(
646 base::BindOnce(&EmbeddedTestServer::ShutdownOnIOThread,
647 base::Unretained(this))) &&
648 aia_http_server_not_started;
649 }
650
651 // static
GetRootCertPemPath()652 base::FilePath EmbeddedTestServer::GetRootCertPemPath() {
653 return GetTestCertsDirectory().AppendASCII("root_ca_cert.pem");
654 }
655
ShutdownOnIOThread()656 void EmbeddedTestServer::ShutdownOnIOThread() {
657 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
658 weak_factory_.InvalidateWeakPtrs();
659 listen_socket_.reset();
660 connections_.clear();
661 }
662
HandleRequest(base::WeakPtr<HttpResponseDelegate> delegate,std::unique_ptr<HttpRequest> request)663 void EmbeddedTestServer::HandleRequest(
664 base::WeakPtr<HttpResponseDelegate> delegate,
665 std::unique_ptr<HttpRequest> request) {
666 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
667 request->base_url = base_url_;
668
669 for (const auto& monitor : request_monitors_)
670 monitor.Run(*request);
671
672 std::unique_ptr<HttpResponse> response;
673
674 for (const auto& handler : request_handlers_) {
675 response = handler.Run(*request);
676 if (response)
677 break;
678 }
679
680 if (!response) {
681 for (const auto& handler : default_request_handlers_) {
682 response = handler.Run(*request);
683 if (response)
684 break;
685 }
686 }
687
688 if (!response) {
689 LOG(WARNING) << "Request not handled. Returning 404: "
690 << request->relative_url;
691 auto not_found_response = std::make_unique<BasicHttpResponse>();
692 not_found_response->set_code(HTTP_NOT_FOUND);
693 response = std::move(not_found_response);
694 }
695
696 HttpResponse* const response_ptr = response.get();
697 delegate->AddResponse(std::move(response));
698 response_ptr->SendResponse(delegate);
699 }
700
GetURL(base::StringPiece relative_url) const701 GURL EmbeddedTestServer::GetURL(base::StringPiece relative_url) const {
702 DCHECK(Started()) << "You must start the server first.";
703 DCHECK(relative_url.starts_with("/")) << relative_url;
704 return base_url_.Resolve(relative_url);
705 }
706
GetURL(base::StringPiece hostname,base::StringPiece relative_url) const707 GURL EmbeddedTestServer::GetURL(base::StringPiece hostname,
708 base::StringPiece relative_url) const {
709 GURL local_url = GetURL(relative_url);
710 GURL::Replacements replace_host;
711 replace_host.SetHostStr(hostname);
712 return local_url.ReplaceComponents(replace_host);
713 }
714
GetOrigin(const absl::optional<std::string> & hostname) const715 url::Origin EmbeddedTestServer::GetOrigin(
716 const absl::optional<std::string>& hostname) const {
717 if (hostname)
718 return url::Origin::Create(GetURL(*hostname, "/"));
719 return url::Origin::Create(base_url_);
720 }
721
GetAddressList(AddressList * address_list) const722 bool EmbeddedTestServer::GetAddressList(AddressList* address_list) const {
723 *address_list = AddressList(local_endpoint_);
724 return true;
725 }
726
GetIPLiteralString() const727 std::string EmbeddedTestServer::GetIPLiteralString() const {
728 return local_endpoint_.address().ToString();
729 }
730
SetSSLConfigInternal(ServerCertificate cert,const ServerCertificateConfig * cert_config,const SSLServerConfig & ssl_config)731 void EmbeddedTestServer::SetSSLConfigInternal(
732 ServerCertificate cert,
733 const ServerCertificateConfig* cert_config,
734 const SSLServerConfig& ssl_config) {
735 DCHECK(!Started());
736 cert_ = cert;
737 DCHECK(!cert_config || cert == CERT_AUTO);
738 cert_config_ = cert_config ? *cert_config : ServerCertificateConfig();
739 x509_cert_ = nullptr;
740 private_key_ = nullptr;
741 ssl_config_ = ssl_config;
742 }
743
SetSSLConfig(ServerCertificate cert,const SSLServerConfig & ssl_config)744 void EmbeddedTestServer::SetSSLConfig(ServerCertificate cert,
745 const SSLServerConfig& ssl_config) {
746 SetSSLConfigInternal(cert, /*cert_config=*/nullptr, ssl_config);
747 }
748
SetSSLConfig(ServerCertificate cert)749 void EmbeddedTestServer::SetSSLConfig(ServerCertificate cert) {
750 SetSSLConfigInternal(cert, /*cert_config=*/nullptr, SSLServerConfig());
751 }
752
SetSSLConfig(const ServerCertificateConfig & cert_config,const SSLServerConfig & ssl_config)753 void EmbeddedTestServer::SetSSLConfig(
754 const ServerCertificateConfig& cert_config,
755 const SSLServerConfig& ssl_config) {
756 SetSSLConfigInternal(CERT_AUTO, &cert_config, ssl_config);
757 }
758
SetSSLConfig(const ServerCertificateConfig & cert_config)759 void EmbeddedTestServer::SetSSLConfig(
760 const ServerCertificateConfig& cert_config) {
761 SetSSLConfigInternal(CERT_AUTO, &cert_config, SSLServerConfig());
762 }
763
ResetSSLConfigOnIOThread(ServerCertificate cert,const SSLServerConfig & ssl_config)764 bool EmbeddedTestServer::ResetSSLConfigOnIOThread(
765 ServerCertificate cert,
766 const SSLServerConfig& ssl_config) {
767 cert_ = cert;
768 cert_config_ = ServerCertificateConfig();
769 ssl_config_ = ssl_config;
770 connections_.clear();
771 return InitializeSSLServerContext();
772 }
773
ResetSSLConfig(ServerCertificate cert,const SSLServerConfig & ssl_config)774 bool EmbeddedTestServer::ResetSSLConfig(ServerCertificate cert,
775 const SSLServerConfig& ssl_config) {
776 return PostTaskToIOThreadAndWaitWithResult(
777 base::BindOnce(&EmbeddedTestServer::ResetSSLConfigOnIOThread,
778 base::Unretained(this), cert, ssl_config));
779 }
780
GetCertificateName() const781 std::string EmbeddedTestServer::GetCertificateName() const {
782 DCHECK(is_using_ssl_);
783 switch (cert_) {
784 case CERT_OK:
785 case CERT_MISMATCHED_NAME:
786 return "ok_cert.pem";
787 case CERT_COMMON_NAME_IS_DOMAIN:
788 return "localhost_cert.pem";
789 case CERT_EXPIRED:
790 return "expired_cert.pem";
791 case CERT_CHAIN_WRONG_ROOT:
792 // This chain uses its own dedicated test root certificate to avoid
793 // side-effects that may affect testing.
794 return "redundant-server-chain.pem";
795 case CERT_COMMON_NAME_ONLY:
796 return "common_name_only.pem";
797 case CERT_SHA1_LEAF:
798 return "sha1_leaf.pem";
799 case CERT_OK_BY_INTERMEDIATE:
800 return "ok_cert_by_intermediate.pem";
801 case CERT_BAD_VALIDITY:
802 return "bad_validity.pem";
803 case CERT_TEST_NAMES:
804 return "test_names.pem";
805 case CERT_KEY_USAGE_RSA_ENCIPHERMENT:
806 return "key_usage_rsa_keyencipherment.pem";
807 case CERT_KEY_USAGE_RSA_DIGITAL_SIGNATURE:
808 return "key_usage_rsa_digitalsignature.pem";
809 case CERT_AUTO:
810 return std::string();
811 }
812
813 return "ok_cert.pem";
814 }
815
GetCertificate()816 scoped_refptr<X509Certificate> EmbeddedTestServer::GetCertificate() {
817 DCHECK(is_using_ssl_);
818 if (!x509_cert_) {
819 // Some tests want to get the certificate before the server has been
820 // initialized, so load it now if necessary. This is only possible if using
821 // a static certificate.
822 // TODO(mattm): change contract to require initializing first in all cases,
823 // update callers.
824 CHECK(UsingStaticCert());
825 // TODO(mattm): change contract to return nullptr on error instead of
826 // CHECKing, update callers.
827 CHECK(InitializeCertAndKeyFromFile());
828 }
829 return x509_cert_;
830 }
831
ServeFilesFromDirectory(const base::FilePath & directory)832 void EmbeddedTestServer::ServeFilesFromDirectory(
833 const base::FilePath& directory) {
834 RegisterDefaultHandler(base::BindRepeating(&HandleFileRequest, directory));
835 }
836
ServeFilesFromSourceDirectory(base::StringPiece relative)837 void EmbeddedTestServer::ServeFilesFromSourceDirectory(
838 base::StringPiece relative) {
839 base::FilePath test_data_dir;
840 CHECK(base::PathService::Get(base::DIR_SRC_TEST_DATA_ROOT, &test_data_dir));
841 ServeFilesFromDirectory(test_data_dir.AppendASCII(relative));
842 }
843
ServeFilesFromSourceDirectory(const base::FilePath & relative)844 void EmbeddedTestServer::ServeFilesFromSourceDirectory(
845 const base::FilePath& relative) {
846 ServeFilesFromDirectory(GetFullPathFromSourceDirectory(relative));
847 }
848
AddDefaultHandlers(const base::FilePath & directory)849 void EmbeddedTestServer::AddDefaultHandlers(const base::FilePath& directory) {
850 ServeFilesFromSourceDirectory(directory);
851 AddDefaultHandlers();
852 }
853
AddDefaultHandlers()854 void EmbeddedTestServer::AddDefaultHandlers() {
855 RegisterDefaultHandlers(this);
856 }
857
// Returns |relative| appended to base::DIR_SRC_TEST_DATA_ROOT. Failing to
// look up that root is fatal.
base::FilePath EmbeddedTestServer::GetFullPathFromSourceDirectory(
    const base::FilePath& relative) {
  base::FilePath source_root;
  CHECK(base::PathService::Get(base::DIR_SRC_TEST_DATA_ROOT, &source_root));

  base::FilePath full_path = source_root.Append(relative);
  return full_path;
}
864
RegisterRequestHandler(const HandleRequestCallback & callback)865 void EmbeddedTestServer::RegisterRequestHandler(
866 const HandleRequestCallback& callback) {
867 DCHECK(!io_thread_)
868 << "Handlers must be registered before starting the server.";
869 request_handlers_.push_back(callback);
870 }
871
RegisterRequestMonitor(const MonitorRequestCallback & callback)872 void EmbeddedTestServer::RegisterRequestMonitor(
873 const MonitorRequestCallback& callback) {
874 DCHECK(!io_thread_)
875 << "Monitors must be registered before starting the server.";
876 request_monitors_.push_back(callback);
877 }
878
RegisterDefaultHandler(const HandleRequestCallback & callback)879 void EmbeddedTestServer::RegisterDefaultHandler(
880 const HandleRequestCallback& callback) {
881 DCHECK(!io_thread_)
882 << "Handlers must be registered before starting the server.";
883 default_request_handlers_.push_back(callback);
884 }
885
DoSSLUpgrade(std::unique_ptr<StreamSocket> connection)886 std::unique_ptr<SSLServerSocket> EmbeddedTestServer::DoSSLUpgrade(
887 std::unique_ptr<StreamSocket> connection) {
888 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
889
890 return context_->CreateSSLServerSocket(std::move(connection));
891 }
892
DoAcceptLoop()893 void EmbeddedTestServer::DoAcceptLoop() {
894 while (true) {
895 int rv = listen_socket_->Accept(
896 &accepted_socket_,
897 base::BindOnce(&EmbeddedTestServer::OnAcceptCompleted,
898 base::Unretained(this)));
899 if (rv != OK)
900 return;
901
902 HandleAcceptResult(std::move(accepted_socket_));
903 }
904 }
905
FlushAllSocketsAndConnectionsOnUIThread()906 bool EmbeddedTestServer::FlushAllSocketsAndConnectionsOnUIThread() {
907 return PostTaskToIOThreadAndWait(
908 base::BindOnce(&EmbeddedTestServer::FlushAllSocketsAndConnections,
909 base::Unretained(this)));
910 }
911
FlushAllSocketsAndConnections()912 void EmbeddedTestServer::FlushAllSocketsAndConnections() {
913 connections_.clear();
914 }
915
SetAlpsAcceptCH(std::string hostname,std::string accept_ch)916 void EmbeddedTestServer::SetAlpsAcceptCH(std::string hostname,
917 std::string accept_ch) {
918 alps_accept_ch_.insert_or_assign(std::move(hostname), std::move(accept_ch));
919 }
920
OnAcceptCompleted(int rv)921 void EmbeddedTestServer::OnAcceptCompleted(int rv) {
922 DCHECK_NE(ERR_IO_PENDING, rv);
923 HandleAcceptResult(std::move(accepted_socket_));
924 DoAcceptLoop();
925 }
926
OnHandshakeDone(HttpConnection * connection,int rv)927 void EmbeddedTestServer::OnHandshakeDone(HttpConnection* connection, int rv) {
928 if (connection->Socket()->IsConnected()) {
929 connection->OnSocketReady();
930 } else {
931 RemoveConnection(connection);
932 }
933 }
934
HandleAcceptResult(std::unique_ptr<StreamSocket> socket_ptr)935 void EmbeddedTestServer::HandleAcceptResult(
936 std::unique_ptr<StreamSocket> socket_ptr) {
937 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
938 if (connection_listener_)
939 socket_ptr = connection_listener_->AcceptedSocket(std::move(socket_ptr));
940
941 if (!is_using_ssl_) {
942 AddConnection(std::move(socket_ptr))->OnSocketReady();
943 return;
944 }
945
946 socket_ptr = DoSSLUpgrade(std::move(socket_ptr));
947
948 StreamSocket* socket = socket_ptr.get();
949 HttpConnection* connection = AddConnection(std::move(socket_ptr));
950
951 int rv = static_cast<SSLServerSocket*>(socket)->Handshake(
952 base::BindOnce(&EmbeddedTestServer::OnHandshakeDone,
953 base::Unretained(this), connection));
954 if (rv != ERR_IO_PENDING)
955 OnHandshakeDone(connection, rv);
956 }
957
AddConnection(std::unique_ptr<StreamSocket> socket_ptr)958 HttpConnection* EmbeddedTestServer::AddConnection(
959 std::unique_ptr<StreamSocket> socket_ptr) {
960 StreamSocket* socket = socket_ptr.get();
961 std::unique_ptr<HttpConnection> connection_ptr = HttpConnection::Create(
962 std::move(socket_ptr), connection_listener_, this, protocol_);
963 HttpConnection* connection = connection_ptr.get();
964 connections_[socket] = std::move(connection_ptr);
965
966 return connection;
967 }
968
RemoveConnection(HttpConnection * connection,EmbeddedTestServerConnectionListener * listener)969 void EmbeddedTestServer::RemoveConnection(
970 HttpConnection* connection,
971 EmbeddedTestServerConnectionListener* listener) {
972 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
973 DCHECK(connection);
974 DCHECK_EQ(1u, connections_.count(connection->Socket()));
975
976 StreamSocket* raw_socket = connection->Socket();
977 std::unique_ptr<StreamSocket> socket = connection->TakeSocket();
978 connections_.erase(raw_socket);
979
980 if (listener && socket && socket->IsConnected())
981 listener->OnResponseCompletedSuccessfully(std::move(socket));
982 }
983
PostTaskToIOThreadAndWait(base::OnceClosure closure)984 bool EmbeddedTestServer::PostTaskToIOThreadAndWait(base::OnceClosure closure) {
985 // Note that PostTaskAndReply below requires
986 // base::SingleThreadTaskRunner::GetCurrentDefault() to return a task runner
987 // for posting the reply task. However, in order to make EmbeddedTestServer
988 // universally usable, it needs to cope with the situation where it's running
989 // on a thread on which a task executor is not (yet) available or has been
990 // destroyed already.
991 //
992 // To handle this situation, create temporary task executor to support the
993 // PostTaskAndReply operation if the current thread has no task executor.
994 // TODO(mattm): Is this still necessary/desirable? Try removing this and see
995 // if anything breaks.
996 std::unique_ptr<base::SingleThreadTaskExecutor> temporary_loop;
997 if (!base::CurrentThread::Get())
998 temporary_loop = std::make_unique<base::SingleThreadTaskExecutor>();
999
1000 base::RunLoop run_loop;
1001 if (!io_thread_->task_runner()->PostTaskAndReply(
1002 FROM_HERE, std::move(closure), run_loop.QuitClosure())) {
1003 return false;
1004 }
1005 run_loop.Run();
1006
1007 return true;
1008 }
1009
PostTaskToIOThreadAndWaitWithResult(base::OnceCallback<bool ()> task)1010 bool EmbeddedTestServer::PostTaskToIOThreadAndWaitWithResult(
1011 base::OnceCallback<bool()> task) {
1012 // Note that PostTaskAndReply below requires
1013 // base::SingleThreadTaskRunner::GetCurrentDefault() to return a task runner
1014 // for posting the reply task. However, in order to make EmbeddedTestServer
1015 // universally usable, it needs to cope with the situation where it's running
1016 // on a thread on which a task executor is not (yet) available or has been
1017 // destroyed already.
1018 //
1019 // To handle this situation, create temporary task executor to support the
1020 // PostTaskAndReply operation if the current thread has no task executor.
1021 // TODO(mattm): Is this still necessary/desirable? Try removing this and see
1022 // if anything breaks.
1023 std::unique_ptr<base::SingleThreadTaskExecutor> temporary_loop;
1024 if (!base::CurrentThread::Get())
1025 temporary_loop = std::make_unique<base::SingleThreadTaskExecutor>();
1026
1027 base::RunLoop run_loop;
1028 bool task_result = false;
1029 if (!io_thread_->task_runner()->PostTaskAndReplyWithResult(
1030 FROM_HERE, std::move(task),
1031 base::BindOnce(base::BindLambdaForTesting([&](bool result) {
1032 task_result = result;
1033 run_loop.Quit();
1034 })))) {
1035 return false;
1036 }
1037 run_loop.Run();
1038
1039 return task_result;
1040 }
1041
1042 } // namespace net::test_server
1043