1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #ifdef UNSAFE_BUFFERS_BUILD
6 // TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
7 #pragma allow_unsafe_buffers
8 #endif
9
10 #include "net/test/embedded_test_server/embedded_test_server.h"
11
12 #include <stdint.h>
13
14 #include <memory>
15 #include <optional>
16 #include <string_view>
17 #include <utility>
18
19 #include "base/files/file_path.h"
20 #include "base/functional/bind.h"
21 #include "base/functional/callback_forward.h"
22 #include "base/functional/callback_helpers.h"
23 #include "base/location.h"
24 #include "base/logging.h"
25 #include "base/message_loop/message_pump_type.h"
26 #include "base/path_service.h"
27 #include "base/process/process_metrics.h"
28 #include "base/run_loop.h"
29 #include "base/strings/string_number_conversions.h"
30 #include "base/strings/string_util.h"
31 #include "base/strings/stringprintf.h"
32 #include "base/task/current_thread.h"
33 #include "base/task/single_thread_task_executor.h"
34 #include "base/task/single_thread_task_runner.h"
35 #include "base/test/bind.h"
36 #include "base/threading/thread_restrictions.h"
37 #include "crypto/rsa_private_key.h"
38 #include "net/base/hex_utils.h"
39 #include "net/base/ip_address.h"
40 #include "net/base/ip_endpoint.h"
41 #include "net/base/net_errors.h"
42 #include "net/base/port_util.h"
43 #include "net/log/net_log_source.h"
44 #include "net/socket/next_proto.h"
45 #include "net/socket/ssl_server_socket.h"
46 #include "net/socket/stream_socket.h"
47 #include "net/socket/tcp_server_socket.h"
48 #include "net/spdy/spdy_test_util_common.h"
49 #include "net/ssl/ssl_info.h"
50 #include "net/ssl/ssl_server_config.h"
51 #include "net/test/cert_builder.h"
52 #include "net/test/cert_test_util.h"
53 #include "net/test/embedded_test_server/default_handlers.h"
54 #include "net/test/embedded_test_server/embedded_test_server_connection_listener.h"
55 #include "net/test/embedded_test_server/http_request.h"
56 #include "net/test/embedded_test_server/http_response.h"
57 #include "net/test/embedded_test_server/request_handler_util.h"
58 #include "net/test/key_util.h"
59 #include "net/test/revocation_builder.h"
60 #include "net/test/test_data_directory.h"
61 #include "net/third_party/quiche/src/quiche/http2/core/spdy_frame_builder.h"
62 #include "third_party/boringssl/src/pki/extended_key_usage.h"
63 #include "url/origin.h"
64
65 namespace net::test_server {
66
67 namespace {
68
ServeResponseForPath(const std::string & expected_path,HttpStatusCode status_code,const std::string & content_type,const std::string & content,const HttpRequest & request)69 std::unique_ptr<HttpResponse> ServeResponseForPath(
70 const std::string& expected_path,
71 HttpStatusCode status_code,
72 const std::string& content_type,
73 const std::string& content,
74 const HttpRequest& request) {
75 if (request.GetURL().path() != expected_path)
76 return nullptr;
77
78 auto http_response = std::make_unique<BasicHttpResponse>();
79 http_response->set_code(status_code);
80 http_response->set_content_type(content_type);
81 http_response->set_content(content);
82 return http_response;
83 }
84
85 // Serves response for |expected_path| or any subpath of it.
86 // |expected_path| should not include a trailing "/".
ServeResponseForSubPaths(const std::string & expected_path,HttpStatusCode status_code,const std::string & content_type,const std::string & content,const HttpRequest & request)87 std::unique_ptr<HttpResponse> ServeResponseForSubPaths(
88 const std::string& expected_path,
89 HttpStatusCode status_code,
90 const std::string& content_type,
91 const std::string& content,
92 const HttpRequest& request) {
93 if (request.GetURL().path() != expected_path &&
94 !request.GetURL().path().starts_with(expected_path + "/")) {
95 return nullptr;
96 }
97
98 auto http_response = std::make_unique<BasicHttpResponse>();
99 http_response->set_code(status_code);
100 http_response->set_content_type(content_type);
101 http_response->set_content(content);
102 return http_response;
103 }
104
MaybeCreateOCSPResponse(CertBuilder * target,const EmbeddedTestServer::OCSPConfig & config,std::string * out_response)105 bool MaybeCreateOCSPResponse(CertBuilder* target,
106 const EmbeddedTestServer::OCSPConfig& config,
107 std::string* out_response) {
108 using OCSPResponseType = EmbeddedTestServer::OCSPConfig::ResponseType;
109
110 if (!config.single_responses.empty() &&
111 config.response_type != OCSPResponseType::kSuccessful) {
112 // OCSPConfig contained single_responses for a non-successful response.
113 return false;
114 }
115
116 if (config.response_type == OCSPResponseType::kOff) {
117 *out_response = std::string();
118 return true;
119 }
120
121 if (!target) {
122 // OCSPConfig enabled but corresponding certificate is null.
123 return false;
124 }
125
126 switch (config.response_type) {
127 case OCSPResponseType::kOff:
128 return false;
129 case OCSPResponseType::kMalformedRequest:
130 *out_response = BuildOCSPResponseError(
131 bssl::OCSPResponse::ResponseStatus::MALFORMED_REQUEST);
132 return true;
133 case OCSPResponseType::kInternalError:
134 *out_response = BuildOCSPResponseError(
135 bssl::OCSPResponse::ResponseStatus::INTERNAL_ERROR);
136 return true;
137 case OCSPResponseType::kTryLater:
138 *out_response =
139 BuildOCSPResponseError(bssl::OCSPResponse::ResponseStatus::TRY_LATER);
140 return true;
141 case OCSPResponseType::kSigRequired:
142 *out_response = BuildOCSPResponseError(
143 bssl::OCSPResponse::ResponseStatus::SIG_REQUIRED);
144 return true;
145 case OCSPResponseType::kUnauthorized:
146 *out_response = BuildOCSPResponseError(
147 bssl::OCSPResponse::ResponseStatus::UNAUTHORIZED);
148 return true;
149 case OCSPResponseType::kInvalidResponse:
150 *out_response = "3";
151 return true;
152 case OCSPResponseType::kInvalidResponseData:
153 *out_response =
154 BuildOCSPResponseWithResponseData(target->issuer()->GetKey(),
155 // OCTET_STRING { "not ocsp data" }
156 "\x04\x0dnot ocsp data");
157 return true;
158 case OCSPResponseType::kSuccessful:
159 break;
160 }
161
162 base::Time now = base::Time::Now();
163 base::Time target_not_before, target_not_after;
164 if (!target->GetValidity(&target_not_before, &target_not_after))
165 return false;
166 base::Time produced_at;
167 using OCSPProduced = EmbeddedTestServer::OCSPConfig::Produced;
168 switch (config.produced) {
169 case OCSPProduced::kValid:
170 produced_at = std::max(now - base::Days(1), target_not_before);
171 break;
172 case OCSPProduced::kBeforeCert:
173 produced_at = target_not_before - base::Days(1);
174 break;
175 case OCSPProduced::kAfterCert:
176 produced_at = target_not_after + base::Days(1);
177 break;
178 }
179
180 std::vector<OCSPBuilderSingleResponse> responses;
181 for (const auto& config_response : config.single_responses) {
182 OCSPBuilderSingleResponse response;
183 response.serial = target->GetSerialNumber();
184 if (config_response.serial ==
185 EmbeddedTestServer::OCSPConfig::SingleResponse::Serial::kMismatch) {
186 response.serial ^= 1;
187 }
188 response.cert_status = config_response.cert_status;
189 // |revocation_time| is ignored if |cert_status| is not REVOKED.
190 response.revocation_time = now - base::Days(1000);
191
192 using OCSPDate = EmbeddedTestServer::OCSPConfig::SingleResponse::Date;
193 switch (config_response.ocsp_date) {
194 case OCSPDate::kValid:
195 response.this_update = now - base::Days(1);
196 response.next_update = response.this_update + base::Days(7);
197 break;
198 case OCSPDate::kOld:
199 response.this_update = now - base::Days(8);
200 response.next_update = response.this_update + base::Days(7);
201 break;
202 case OCSPDate::kEarly:
203 response.this_update = now + base::Days(1);
204 response.next_update = response.this_update + base::Days(7);
205 break;
206 case OCSPDate::kLong:
207 response.this_update = now - base::Days(365);
208 response.next_update = response.this_update + base::Days(366);
209 break;
210 case OCSPDate::kLonger:
211 response.this_update = now - base::Days(367);
212 response.next_update = response.this_update + base::Days(368);
213 break;
214 }
215
216 responses.push_back(response);
217 }
218 *out_response =
219 BuildOCSPResponse(target->issuer()->GetSubject(),
220 target->issuer()->GetKey(), produced_at, responses);
221 return true;
222 }
223
DispatchResponseToDelegate(std::unique_ptr<HttpResponse> response,base::WeakPtr<HttpResponseDelegate> delegate)224 void DispatchResponseToDelegate(std::unique_ptr<HttpResponse> response,
225 base::WeakPtr<HttpResponseDelegate> delegate) {
226 HttpResponse* const response_ptr = response.get();
227 delegate->AddResponse(std::move(response));
228 response_ptr->SendResponse(delegate);
229 }
230
231 } // namespace
232
EmbeddedTestServerHandle(EmbeddedTestServerHandle && other)233 EmbeddedTestServerHandle::EmbeddedTestServerHandle(
234 EmbeddedTestServerHandle&& other) {
235 operator=(std::move(other));
236 }
237
operator =(EmbeddedTestServerHandle && other)238 EmbeddedTestServerHandle& EmbeddedTestServerHandle::operator=(
239 EmbeddedTestServerHandle&& other) {
240 EmbeddedTestServerHandle temporary;
241 std::swap(other.test_server_, temporary.test_server_);
242 std::swap(temporary.test_server_, test_server_);
243 return *this;
244 }
245
EmbeddedTestServerHandle(EmbeddedTestServer * test_server)246 EmbeddedTestServerHandle::EmbeddedTestServerHandle(
247 EmbeddedTestServer* test_server)
248 : test_server_(test_server) {}
249
~EmbeddedTestServerHandle()250 EmbeddedTestServerHandle::~EmbeddedTestServerHandle() {
251 if (test_server_)
252 CHECK(test_server_->ShutdownAndWaitUntilComplete());
253 }
254
255 EmbeddedTestServer::OCSPConfig::OCSPConfig() = default;
OCSPConfig(ResponseType response_type)256 EmbeddedTestServer::OCSPConfig::OCSPConfig(ResponseType response_type)
257 : response_type(response_type) {}
OCSPConfig(std::vector<SingleResponse> single_responses,Produced produced)258 EmbeddedTestServer::OCSPConfig::OCSPConfig(
259 std::vector<SingleResponse> single_responses,
260 Produced produced)
261 : response_type(ResponseType::kSuccessful),
262 produced(produced),
263 single_responses(std::move(single_responses)) {}
264 EmbeddedTestServer::OCSPConfig::OCSPConfig(const OCSPConfig&) = default;
265 EmbeddedTestServer::OCSPConfig::OCSPConfig(OCSPConfig&&) = default;
266 EmbeddedTestServer::OCSPConfig::~OCSPConfig() = default;
267 EmbeddedTestServer::OCSPConfig& EmbeddedTestServer::OCSPConfig::operator=(
268 const OCSPConfig&) = default;
269 EmbeddedTestServer::OCSPConfig& EmbeddedTestServer::OCSPConfig::operator=(
270 OCSPConfig&&) = default;
271
272 EmbeddedTestServer::ServerCertificateConfig::ServerCertificateConfig() =
273 default;
274 EmbeddedTestServer::ServerCertificateConfig::ServerCertificateConfig(
275 const ServerCertificateConfig&) = default;
276 EmbeddedTestServer::ServerCertificateConfig::ServerCertificateConfig(
277 ServerCertificateConfig&&) = default;
278 EmbeddedTestServer::ServerCertificateConfig::~ServerCertificateConfig() =
279 default;
280 EmbeddedTestServer::ServerCertificateConfig&
281 EmbeddedTestServer::ServerCertificateConfig::operator=(
282 const ServerCertificateConfig&) = default;
283 EmbeddedTestServer::ServerCertificateConfig&
284 EmbeddedTestServer::ServerCertificateConfig::operator=(
285 ServerCertificateConfig&&) = default;
286
EmbeddedTestServer()287 EmbeddedTestServer::EmbeddedTestServer() : EmbeddedTestServer(TYPE_HTTP) {}
288
EmbeddedTestServer(Type type,HttpConnection::Protocol protocol)289 EmbeddedTestServer::EmbeddedTestServer(Type type,
290 HttpConnection::Protocol protocol)
291 : is_using_ssl_(type == TYPE_HTTPS), protocol_(protocol) {
292 DCHECK(thread_checker_.CalledOnValidThread());
293 // HTTP/2 is only valid by negotiation via TLS ALPN
294 DCHECK(protocol_ != HttpConnection::Protocol::kHttp2 || type == TYPE_HTTPS);
295
296 if (!is_using_ssl_)
297 return;
298 scoped_test_root_ = RegisterTestCerts();
299 }
300
~EmbeddedTestServer()301 EmbeddedTestServer::~EmbeddedTestServer() {
302 DCHECK(thread_checker_.CalledOnValidThread());
303
304 if (Started())
305 CHECK(ShutdownAndWaitUntilComplete());
306
307 {
308 base::ScopedAllowBaseSyncPrimitivesForTesting allow_wait_for_thread_join;
309 io_thread_.reset();
310 }
311 }
312
RegisterTestCerts()313 ScopedTestRoot EmbeddedTestServer::RegisterTestCerts() {
314 base::ScopedAllowBlockingForTesting allow_blocking;
315 auto root = ImportCertFromFile(GetRootCertPemPath());
316 if (!root)
317 return ScopedTestRoot();
318 return ScopedTestRoot(root);
319 }
320
SetConnectionListener(EmbeddedTestServerConnectionListener * listener)321 void EmbeddedTestServer::SetConnectionListener(
322 EmbeddedTestServerConnectionListener* listener) {
323 DCHECK(!io_thread_)
324 << "ConnectionListener must be set before starting the server.";
325 connection_listener_ = listener;
326 }
327
StartAndReturnHandle(int port)328 EmbeddedTestServerHandle EmbeddedTestServer::StartAndReturnHandle(int port) {
329 bool result = Start(port);
330 return result ? EmbeddedTestServerHandle(this) : EmbeddedTestServerHandle();
331 }
332
Start(int port,std::string_view address)333 bool EmbeddedTestServer::Start(int port, std::string_view address) {
334 bool success = InitializeAndListen(port, address);
335 if (success)
336 StartAcceptingConnections();
337 return success;
338 }
339
InitializeAndListen(int port,std::string_view address)340 bool EmbeddedTestServer::InitializeAndListen(int port,
341 std::string_view address) {
342 DCHECK(!Started());
343
344 const int max_tries = 5;
345 int num_tries = 0;
346 bool is_valid_port = false;
347
348 do {
349 if (++num_tries > max_tries) {
350 LOG(ERROR) << "Failed to listen on a valid port after " << max_tries
351 << " attempts.";
352 listen_socket_.reset();
353 return false;
354 }
355
356 listen_socket_ = std::make_unique<TCPServerSocket>(nullptr, NetLogSource());
357
358 int result =
359 listen_socket_->ListenWithAddressAndPort(address.data(), port, 10);
360 if (result) {
361 LOG(ERROR) << "Listen failed: " << ErrorToString(result);
362 listen_socket_.reset();
363 return false;
364 }
365
366 result = listen_socket_->GetLocalAddress(&local_endpoint_);
367 if (result != OK) {
368 LOG(ERROR) << "GetLocalAddress failed: " << ErrorToString(result);
369 listen_socket_.reset();
370 return false;
371 }
372
373 port_ = local_endpoint_.port();
374 is_valid_port |= net::IsPortAllowedForScheme(
375 port_, is_using_ssl_ ? url::kHttpsScheme : url::kHttpScheme);
376 } while (!is_valid_port);
377
378 if (is_using_ssl_) {
379 base_url_ = GURL("https://" + local_endpoint_.ToString());
380 if (cert_ == CERT_MISMATCHED_NAME || cert_ == CERT_COMMON_NAME_IS_DOMAIN) {
381 base_url_ = GURL(
382 base::StringPrintf("https://localhost:%d", local_endpoint_.port()));
383 }
384 } else {
385 base_url_ = GURL("http://" + local_endpoint_.ToString());
386 }
387
388 listen_socket_->DetachFromThread();
389
390 if (is_using_ssl_ && !InitializeSSLServerContext())
391 return false;
392
393 return true;
394 }
395
UsingStaticCert() const396 bool EmbeddedTestServer::UsingStaticCert() const {
397 return !GetCertificateName().empty();
398 }
399
InitializeCertAndKeyFromFile()400 bool EmbeddedTestServer::InitializeCertAndKeyFromFile() {
401 base::ScopedAllowBlockingForTesting allow_blocking;
402 base::FilePath certs_dir(GetTestCertsDirectory());
403 std::string cert_name = GetCertificateName();
404 if (cert_name.empty())
405 return false;
406
407 x509_cert_ = CreateCertificateChainFromFile(certs_dir, cert_name,
408 X509Certificate::FORMAT_AUTO);
409 if (!x509_cert_)
410 return false;
411
412 private_key_ =
413 key_util::LoadEVP_PKEYFromPEM(certs_dir.AppendASCII(cert_name));
414 return !!private_key_;
415 }
416
GenerateCertAndKey()417 bool EmbeddedTestServer::GenerateCertAndKey() {
418 // Create AIA server and start listening. Need to have the socket initialized
419 // so the URL can be put in the AIA records of the generated certs.
420 aia_http_server_ = std::make_unique<EmbeddedTestServer>(TYPE_HTTP);
421 if (!aia_http_server_->InitializeAndListen())
422 return false;
423
424 base::ScopedAllowBlockingForTesting allow_blocking;
425 base::FilePath certs_dir(GetTestCertsDirectory());
426 auto now = base::Time::Now();
427
428 std::unique_ptr<CertBuilder> root;
429 switch (cert_config_.root) {
430 case RootType::kTestRootCa:
431 root = CertBuilder::FromStaticCertFile(
432 certs_dir.AppendASCII("root_ca_cert.pem"));
433 break;
434 case RootType::kUniqueRoot:
435 root = std::make_unique<CertBuilder>(nullptr, nullptr);
436 root->SetValidity(now - base::Days(100), now + base::Days(1000));
437 root->SetBasicConstraints(/*is_ca=*/true, /*path_len=*/-1);
438 root->SetKeyUsages(
439 {bssl::KEY_USAGE_BIT_KEY_CERT_SIGN, bssl::KEY_USAGE_BIT_CRL_SIGN});
440 if (!cert_config_.root_dns_names.empty()) {
441 root->SetSubjectAltNames(cert_config_.root_dns_names, {});
442 }
443 break;
444 }
445
446 // Will be nullptr if cert_config_.intermediate == kNone.
447 std::unique_ptr<CertBuilder> intermediate;
448 std::unique_ptr<CertBuilder> leaf;
449
450 if (cert_config_.intermediate != IntermediateType::kNone) {
451 intermediate = std::make_unique<CertBuilder>(nullptr, root.get());
452 intermediate->SetValidity(now - base::Days(100), now + base::Days(1000));
453 intermediate->SetBasicConstraints(/*is_ca=*/true, /*path_len=*/-1);
454 intermediate->SetKeyUsages(
455 {bssl::KEY_USAGE_BIT_KEY_CERT_SIGN, bssl::KEY_USAGE_BIT_CRL_SIGN});
456
457 leaf = std::make_unique<CertBuilder>(nullptr, intermediate.get());
458 } else {
459 leaf = std::make_unique<CertBuilder>(nullptr, root.get());
460 }
461 std::vector<GURL> leaf_ca_issuers_urls;
462 std::vector<GURL> leaf_ocsp_urls;
463
464 leaf->SetValidity(now - base::Days(1), now + base::Days(20));
465 leaf->SetBasicConstraints(/*is_ca=*/cert_config_.leaf_is_ca, /*path_len=*/-1);
466 leaf->SetExtendedKeyUsages({bssl::der::Input(bssl::kServerAuth)});
467
468 if (!cert_config_.policy_oids.empty()) {
469 leaf->SetCertificatePolicies(cert_config_.policy_oids);
470 if (intermediate)
471 intermediate->SetCertificatePolicies(cert_config_.policy_oids);
472 }
473
474 if (!cert_config_.dns_names.empty() || !cert_config_.ip_addresses.empty()) {
475 leaf->SetSubjectAltNames(cert_config_.dns_names, cert_config_.ip_addresses);
476 } else {
477 leaf->SetSubjectAltNames({}, {net::IPAddress::IPv4Localhost()});
478 }
479
480 if (!cert_config_.key_usages.empty()) {
481 leaf->SetKeyUsages(cert_config_.key_usages);
482 } else {
483 leaf->SetKeyUsages({bssl::KEY_USAGE_BIT_DIGITAL_SIGNATURE});
484 }
485
486 if (!cert_config_.embedded_scts.empty()) {
487 leaf->SetSctConfig(cert_config_.embedded_scts);
488 }
489
490 const std::string leaf_serial_text =
491 base::NumberToString(leaf->GetSerialNumber());
492 const std::string intermediate_serial_text =
493 intermediate ? base::NumberToString(intermediate->GetSerialNumber()) : "";
494
495 std::string ocsp_response;
496 if (!MaybeCreateOCSPResponse(leaf.get(), cert_config_.ocsp_config,
497 &ocsp_response)) {
498 return false;
499 }
500 if (!ocsp_response.empty()) {
501 std::string ocsp_path = "/ocsp/" + leaf_serial_text;
502 leaf_ocsp_urls.push_back(aia_http_server_->GetURL(ocsp_path));
503 aia_http_server_->RegisterRequestHandler(
504 base::BindRepeating(ServeResponseForSubPaths, ocsp_path, HTTP_OK,
505 "application/ocsp-response", ocsp_response));
506 }
507
508 std::string stapled_ocsp_response;
509 if (!MaybeCreateOCSPResponse(leaf.get(), cert_config_.stapled_ocsp_config,
510 &stapled_ocsp_response)) {
511 return false;
512 }
513 if (!stapled_ocsp_response.empty()) {
514 ssl_config_.ocsp_response = std::vector<uint8_t>(
515 stapled_ocsp_response.begin(), stapled_ocsp_response.end());
516 }
517
518 std::string intermediate_ocsp_response;
519 if (!MaybeCreateOCSPResponse(intermediate.get(),
520 cert_config_.intermediate_ocsp_config,
521 &intermediate_ocsp_response)) {
522 return false;
523 }
524 if (!intermediate_ocsp_response.empty()) {
525 std::string intermediate_ocsp_path = "/ocsp/" + intermediate_serial_text;
526 intermediate->SetCaIssuersAndOCSPUrls(
527 {}, {aia_http_server_->GetURL(intermediate_ocsp_path)});
528 aia_http_server_->RegisterRequestHandler(base::BindRepeating(
529 ServeResponseForSubPaths, intermediate_ocsp_path, HTTP_OK,
530 "application/ocsp-response", intermediate_ocsp_response));
531 }
532
533 if (cert_config_.intermediate == IntermediateType::kByAIA) {
534 std::string ca_issuers_path = "/ca_issuers/" + intermediate_serial_text;
535 leaf_ca_issuers_urls.push_back(aia_http_server_->GetURL(ca_issuers_path));
536
537 // Setup AIA server to serve the intermediate referred to by the leaf.
538 aia_http_server_->RegisterRequestHandler(
539 base::BindRepeating(ServeResponseForPath, ca_issuers_path, HTTP_OK,
540 "application/pkix-cert", intermediate->GetDER()));
541 }
542
543 if (!leaf_ca_issuers_urls.empty() || !leaf_ocsp_urls.empty()) {
544 leaf->SetCaIssuersAndOCSPUrls(leaf_ca_issuers_urls, leaf_ocsp_urls);
545 }
546
547 if (cert_config_.intermediate == IntermediateType::kByAIA ||
548 cert_config_.intermediate == IntermediateType::kMissing) {
549 // Server certificate chain does not include the intermediate.
550 x509_cert_ = leaf->GetX509Certificate();
551 } else {
552 // Server certificate chain will include the intermediate, if there is one.
553 x509_cert_ = leaf->GetX509CertificateChain();
554 }
555
556 if (intermediate) {
557 intermediate_ = intermediate->GetX509Certificate();
558 }
559
560 root_ = root->GetX509Certificate();
561
562 private_key_ = bssl::UpRef(leaf->GetKey());
563
564 // If this server is already accepting connections but is being reconfigured,
565 // start the new AIA server now. Otherwise, wait until
566 // StartAcceptingConnections so that this server and the AIA server start at
567 // the same time. (If the test only called InitializeAndListen they expect no
568 // threads to be created yet.)
569 if (io_thread_)
570 aia_http_server_->StartAcceptingConnections();
571
572 return true;
573 }
574
// Loads or generates the server certificate and key, then builds |context_|
// from |x509_cert_|, |private_key_| and |ssl_config_|.
//
// For HTTP/2 servers, ALPN is restricted to h2. If |alps_accept_ch_| is
// non-empty, an HTTP/2 ACCEPT_CH frame is also serialized into the ALPS
// application settings. Returns false if the certificate/key step fails.
bool EmbeddedTestServer::InitializeSSLServerContext() {
  if (UsingStaticCert()) {
    if (!InitializeCertAndKeyFromFile())
      return false;
  } else {
    if (!GenerateCertAndKey())
      return false;
  }

  if (protocol_ == HttpConnection::Protocol::kHttp2) {
    ssl_config_.alpn_protos = {NextProto::kProtoHTTP2};
    if (!alps_accept_ch_.empty()) {
      base::StringPairs origin_accept_ch;
      size_t frame_size = spdy::kFrameHeaderSize;
      // Figure out size and generate origins. An empty hostname means the
      // server's own origin.
      for (const auto& pair : alps_accept_ch_) {
        std::string_view hostname = pair.first;
        std::string accept_ch = pair.second;

        GURL url = hostname.empty() ? GetURL("/") : GetURL(hostname, "/");
        std::string origin = url::Origin::Create(url).Serialize();

        frame_size += accept_ch.size() + origin.size() +
                      (sizeof(uint16_t) * 2);  // = Origin-Len + Value-Len

        origin_accept_ch.push_back({std::move(origin), std::move(accept_ch)});
      }

      // Serialize one entry per origin as: Origin-Len (uint16), Origin,
      // Value-Len (uint16), Value. This must match the size computed above;
      // the DCHECK_EQ below verifies that.
      spdy::SpdyFrameBuilder builder(frame_size);
      builder.BeginNewFrame(spdy::SpdyFrameType::ACCEPT_CH, 0, 0);
      for (const auto& pair : origin_accept_ch) {
        std::string_view origin = pair.first;
        std::string_view accept_ch = pair.second;

        builder.WriteUInt16(origin.size());
        builder.WriteBytes(origin.data(), origin.size());

        builder.WriteUInt16(accept_ch.size());
        builder.WriteBytes(accept_ch.data(), accept_ch.size());
      }

      spdy::SpdySerializedFrame serialized_frame = builder.take();
      DCHECK_EQ(frame_size, serialized_frame.size());

      // The serialized frame becomes the h2 ALPS payload.
      ssl_config_.application_settings[NextProto::kProtoHTTP2] =
          std::vector<uint8_t>(
              serialized_frame.data(),
              serialized_frame.data() + serialized_frame.size());

      ssl_config_.client_hello_callback_for_testing =
          base::BindRepeating([](const SSL_CLIENT_HELLO* client_hello) {
            // Configure the server to use the ALPS codepoint that the client
            // offered.
            const uint8_t* unused_extension_bytes;
            size_t unused_extension_len;
            int use_alps_new_codepoint = SSL_early_callback_ctx_extension_get(
                client_hello, TLSEXT_TYPE_application_settings,
                &unused_extension_bytes, &unused_extension_len);
            // Make sure we use the right ALPS codepoint.
            SSL_set_alps_use_new_codepoint(client_hello->ssl,
                                           use_alps_new_codepoint);
            return true;
          });
    }
  }

  // NOTE(review): the result of CreateSSLServerContext() is not checked here.
  context_ =
      CreateSSLServerContext(x509_cert_.get(), private_key_.get(), ssl_config_);
  return true;
}
645
646 EmbeddedTestServerHandle
StartAcceptingConnectionsAndReturnHandle()647 EmbeddedTestServer::StartAcceptingConnectionsAndReturnHandle() {
648 StartAcceptingConnections();
649 return EmbeddedTestServerHandle(this);
650 }
651
StartAcceptingConnections()652 void EmbeddedTestServer::StartAcceptingConnections() {
653 DCHECK(Started());
654 DCHECK(!io_thread_) << "Server must not be started while server is running";
655
656 if (aia_http_server_)
657 aia_http_server_->StartAcceptingConnections();
658
659 base::Thread::Options thread_options;
660 thread_options.message_pump_type = base::MessagePumpType::IO;
661 io_thread_ = std::make_unique<base::Thread>("EmbeddedTestServer IO Thread");
662 CHECK(io_thread_->StartWithOptions(std::move(thread_options)));
663 CHECK(io_thread_->WaitUntilThreadStarted());
664
665 io_thread_->task_runner()->PostTask(
666 FROM_HERE, base::BindOnce(&EmbeddedTestServer::DoAcceptLoop,
667 base::Unretained(this)));
668 }
669
ShutdownAndWaitUntilComplete()670 bool EmbeddedTestServer::ShutdownAndWaitUntilComplete() {
671 DCHECK(thread_checker_.CalledOnValidThread());
672
673 if (!io_thread_) {
674 // Can't stop a server that never started.
675 return true;
676 }
677
678 // Ensure that the AIA HTTP server is no longer Started().
679 bool aia_http_server_not_started = true;
680 if (aia_http_server_ && aia_http_server_->Started()) {
681 aia_http_server_not_started =
682 aia_http_server_->ShutdownAndWaitUntilComplete();
683 }
684
685 // Return false if either this or the AIA HTTP server are still Started().
686 return PostTaskToIOThreadAndWait(
687 base::BindOnce(&EmbeddedTestServer::ShutdownOnIOThread,
688 base::Unretained(this))) &&
689 aia_http_server_not_started;
690 }
691
692 // static
GetRootCertPemPath()693 base::FilePath EmbeddedTestServer::GetRootCertPemPath() {
694 return GetTestCertsDirectory().AppendASCII("root_ca_cert.pem");
695 }
696
// Tears down this server's IO-thread state. It must run on the IO thread, as
// the DCHECK below enforces. The steps are: invalidate outstanding weak
// pointers, run the registered shutdown closures, close the listening socket,
// and drop all tracked connections. The order of these steps is kept as-is.
void EmbeddedTestServer::ShutdownOnIOThread() {
  DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
  weak_factory_.InvalidateWeakPtrs();
  shutdown_closures_.Notify();
  listen_socket_.reset();
  connections_.clear();
}
704
GetConnectionForSocket(const StreamSocket * socket)705 HttpConnection* EmbeddedTestServer::GetConnectionForSocket(
706 const StreamSocket* socket) {
707 auto it = connections_.find(socket);
708 if (it != connections_.end()) {
709 return it->second.get();
710 }
711 return nullptr;
712 }
713
HandleRequest(base::WeakPtr<HttpResponseDelegate> delegate,std::unique_ptr<HttpRequest> request,const StreamSocket * socket)714 void EmbeddedTestServer::HandleRequest(
715 base::WeakPtr<HttpResponseDelegate> delegate,
716 std::unique_ptr<HttpRequest> request,
717 const StreamSocket* socket) {
718 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
719 request->base_url = base_url_;
720
721 for (const auto& monitor : request_monitors_)
722 monitor.Run(*request);
723
724 HttpConnection* connection = GetConnectionForSocket(socket);
725 CHECK(connection);
726
727 if (auth_handler_) {
728 auto auth_result = auth_handler_.Run(*request);
729 if (auth_result) {
730 DispatchResponseToDelegate(std::move(auth_result), delegate);
731 return;
732 }
733 }
734
735 for (const auto& upgrade_request_handler : upgrade_request_handlers_) {
736 auto upgrade_response = upgrade_request_handler.Run(*request, connection);
737 if (upgrade_response.has_value()) {
738 if (upgrade_response.value() == UpgradeResult::kUpgraded) {
739 connections_.erase(socket);
740 return;
741 }
742 } else {
743 CHECK(upgrade_response.error());
744 DispatchResponseToDelegate(std::move(upgrade_response.error()), delegate);
745 return;
746 }
747 }
748
749 std::unique_ptr<HttpResponse> response;
750
751 for (const auto& handler : request_handlers_) {
752 response = handler.Run(*request);
753 if (response)
754 break;
755 }
756
757 if (!response) {
758 for (const auto& handler : default_request_handlers_) {
759 response = handler.Run(*request);
760 if (response)
761 break;
762 }
763 }
764
765 if (!response) {
766 LOG(WARNING) << "Request not handled. Returning 404: "
767 << request->relative_url;
768 auto not_found_response = std::make_unique<BasicHttpResponse>();
769 not_found_response->set_code(HTTP_NOT_FOUND);
770 response = std::move(not_found_response);
771 }
772
773 DispatchResponseToDelegate(std::move(response), delegate);
774 }
775
GetURL(std::string_view relative_url) const776 GURL EmbeddedTestServer::GetURL(std::string_view relative_url) const {
777 DCHECK(Started()) << "You must start the server first.";
778 DCHECK(relative_url.starts_with("/")) << relative_url;
779 return base_url_.Resolve(relative_url);
780 }
781
GetURL(std::string_view hostname,std::string_view relative_url) const782 GURL EmbeddedTestServer::GetURL(std::string_view hostname,
783 std::string_view relative_url) const {
784 GURL local_url = GetURL(relative_url);
785 GURL::Replacements replace_host;
786 replace_host.SetHostStr(hostname);
787 return local_url.ReplaceComponents(replace_host);
788 }
789
GetOrigin(const std::optional<std::string> & hostname) const790 url::Origin EmbeddedTestServer::GetOrigin(
791 const std::optional<std::string>& hostname) const {
792 if (hostname)
793 return url::Origin::Create(GetURL(*hostname, "/"));
794 return url::Origin::Create(base_url_);
795 }
796
GetAddressList(AddressList * address_list) const797 bool EmbeddedTestServer::GetAddressList(AddressList* address_list) const {
798 *address_list = AddressList(local_endpoint_);
799 return true;
800 }
801
GetIPLiteralString() const802 std::string EmbeddedTestServer::GetIPLiteralString() const {
803 return local_endpoint_.address().ToString();
804 }
805
SetSSLConfigInternal(ServerCertificate cert,const ServerCertificateConfig * cert_config,const SSLServerConfig & ssl_config)806 void EmbeddedTestServer::SetSSLConfigInternal(
807 ServerCertificate cert,
808 const ServerCertificateConfig* cert_config,
809 const SSLServerConfig& ssl_config) {
810 DCHECK(!Started());
811 cert_ = cert;
812 DCHECK(!cert_config || cert == CERT_AUTO);
813 cert_config_ = cert_config ? *cert_config : ServerCertificateConfig();
814 x509_cert_ = nullptr;
815 private_key_ = nullptr;
816 ssl_config_ = ssl_config;
817 }
818
SetSSLConfig(ServerCertificate cert,const SSLServerConfig & ssl_config)819 void EmbeddedTestServer::SetSSLConfig(ServerCertificate cert,
820 const SSLServerConfig& ssl_config) {
821 SetSSLConfigInternal(cert, /*cert_config=*/nullptr, ssl_config);
822 }
823
SetSSLConfig(ServerCertificate cert)824 void EmbeddedTestServer::SetSSLConfig(ServerCertificate cert) {
825 SetSSLConfigInternal(cert, /*cert_config=*/nullptr, SSLServerConfig());
826 }
827
SetSSLConfig(const ServerCertificateConfig & cert_config,const SSLServerConfig & ssl_config)828 void EmbeddedTestServer::SetSSLConfig(
829 const ServerCertificateConfig& cert_config,
830 const SSLServerConfig& ssl_config) {
831 SetSSLConfigInternal(CERT_AUTO, &cert_config, ssl_config);
832 }
833
SetSSLConfig(const ServerCertificateConfig & cert_config)834 void EmbeddedTestServer::SetSSLConfig(
835 const ServerCertificateConfig& cert_config) {
836 SetSSLConfigInternal(CERT_AUTO, &cert_config, SSLServerConfig());
837 }
838
SetCertHostnames(std::vector<std::string> hostnames)839 void EmbeddedTestServer::SetCertHostnames(std::vector<std::string> hostnames) {
840 ServerCertificateConfig cert_config;
841 cert_config.dns_names = std::move(hostnames);
842 cert_config.ip_addresses = {net::IPAddress::IPv4Localhost()};
843 SetSSLConfig(cert_config);
844 }
845
ResetSSLConfigOnIOThread(ServerCertificate cert,const SSLServerConfig & ssl_config)846 bool EmbeddedTestServer::ResetSSLConfigOnIOThread(
847 ServerCertificate cert,
848 const SSLServerConfig& ssl_config) {
849 cert_ = cert;
850 cert_config_ = ServerCertificateConfig();
851 ssl_config_ = ssl_config;
852 connections_.clear();
853 return InitializeSSLServerContext();
854 }
855
ResetSSLConfig(ServerCertificate cert,const SSLServerConfig & ssl_config)856 bool EmbeddedTestServer::ResetSSLConfig(ServerCertificate cert,
857 const SSLServerConfig& ssl_config) {
858 return PostTaskToIOThreadAndWaitWithResult(
859 base::BindOnce(&EmbeddedTestServer::ResetSSLConfigOnIOThread,
860 base::Unretained(this), cert, ssl_config));
861 }
862
GetCertificateName() const863 std::string EmbeddedTestServer::GetCertificateName() const {
864 DCHECK(is_using_ssl_);
865 switch (cert_) {
866 case CERT_OK:
867 case CERT_MISMATCHED_NAME:
868 return "ok_cert.pem";
869 case CERT_COMMON_NAME_IS_DOMAIN:
870 return "localhost_cert.pem";
871 case CERT_EXPIRED:
872 return "expired_cert.pem";
873 case CERT_CHAIN_WRONG_ROOT:
874 // This chain uses its own dedicated test root certificate to avoid
875 // side-effects that may affect testing.
876 return "redundant-server-chain.pem";
877 case CERT_COMMON_NAME_ONLY:
878 return "common_name_only.pem";
879 case CERT_SHA1_LEAF:
880 return "sha1_leaf.pem";
881 case CERT_OK_BY_INTERMEDIATE:
882 return "ok_cert_by_intermediate.pem";
883 case CERT_BAD_VALIDITY:
884 return "bad_validity.pem";
885 case CERT_TEST_NAMES:
886 return "test_names.pem";
887 case CERT_KEY_USAGE_RSA_ENCIPHERMENT:
888 return "key_usage_rsa_keyencipherment.pem";
889 case CERT_KEY_USAGE_RSA_DIGITAL_SIGNATURE:
890 return "key_usage_rsa_digitalsignature.pem";
891 case CERT_AUTO:
892 return std::string();
893 }
894
895 return "ok_cert.pem";
896 }
897
GetCertificate()898 scoped_refptr<X509Certificate> EmbeddedTestServer::GetCertificate() {
899 DCHECK(is_using_ssl_);
900 if (!x509_cert_) {
901 // Some tests want to get the certificate before the server has been
902 // initialized, so load it now if necessary. This is only possible if using
903 // a static certificate.
904 // TODO(mattm): change contract to require initializing first in all cases,
905 // update callers.
906 CHECK(UsingStaticCert());
907 // TODO(mattm): change contract to return nullptr on error instead of
908 // CHECKing, update callers.
909 CHECK(InitializeCertAndKeyFromFile());
910 }
911 return x509_cert_;
912 }
913
GetGeneratedIntermediate()914 scoped_refptr<X509Certificate> EmbeddedTestServer::GetGeneratedIntermediate() {
915 DCHECK(is_using_ssl_);
916 DCHECK(!UsingStaticCert());
917 return intermediate_;
918 }
919
GetRoot()920 scoped_refptr<X509Certificate> EmbeddedTestServer::GetRoot() {
921 DCHECK(is_using_ssl_);
922 return root_;
923 }
924
ServeFilesFromDirectory(const base::FilePath & directory)925 void EmbeddedTestServer::ServeFilesFromDirectory(
926 const base::FilePath& directory) {
927 RegisterDefaultHandler(base::BindRepeating(&HandleFileRequest, directory));
928 }
929
ServeFilesFromSourceDirectory(std::string_view relative)930 void EmbeddedTestServer::ServeFilesFromSourceDirectory(
931 std::string_view relative) {
932 base::FilePath test_data_dir;
933 CHECK(base::PathService::Get(base::DIR_SRC_TEST_DATA_ROOT, &test_data_dir));
934 ServeFilesFromDirectory(test_data_dir.AppendASCII(relative));
935 }
936
ServeFilesFromSourceDirectory(const base::FilePath & relative)937 void EmbeddedTestServer::ServeFilesFromSourceDirectory(
938 const base::FilePath& relative) {
939 ServeFilesFromDirectory(GetFullPathFromSourceDirectory(relative));
940 }
941
AddDefaultHandlers(const base::FilePath & directory)942 void EmbeddedTestServer::AddDefaultHandlers(const base::FilePath& directory) {
943 ServeFilesFromSourceDirectory(directory);
944 AddDefaultHandlers();
945 }
946
AddDefaultHandlers()947 void EmbeddedTestServer::AddDefaultHandlers() {
948 RegisterDefaultHandlers(this);
949 }
950
// Returns |relative| appended to base::DIR_SRC_TEST_DATA_ROOT. CHECK-fails if
// that root cannot be resolved.
base::FilePath EmbeddedTestServer::GetFullPathFromSourceDirectory(
    const base::FilePath& relative) {
  base::FilePath source_root;
  CHECK(base::PathService::Get(base::DIR_SRC_TEST_DATA_ROOT, &source_root));
  base::FilePath full_path = source_root.Append(relative);
  return full_path;
}
957
RegisterAuthHandler(const HandleRequestCallback & callback)958 void EmbeddedTestServer::RegisterAuthHandler(
959 const HandleRequestCallback& callback) {
960 CHECK(!io_thread_)
961 << "Handlers must be registered before starting the server.";
962 if (auth_handler_) {
963 DVLOG(2) << "Overwriting existing Auth handler.";
964 }
965 auth_handler_ = callback;
966 }
967
RegisterUpgradeRequestHandler(const HandleUpgradeRequestCallback & callback)968 void EmbeddedTestServer::RegisterUpgradeRequestHandler(
969 const HandleUpgradeRequestCallback& callback) {
970 CHECK_NE(protocol_, HttpConnection::Protocol::kHttp2)
971 << "RegisterUpgradeRequestHandler() is not supported for HTTP/2 "
972 "connections";
973 CHECK(!io_thread_)
974 << "Handlers must be registered before starting the server.";
975 upgrade_request_handlers_.push_back(callback);
976 }
977
RegisterRequestHandler(const HandleRequestCallback & callback)978 void EmbeddedTestServer::RegisterRequestHandler(
979 const HandleRequestCallback& callback) {
980 DCHECK(!io_thread_)
981 << "Handlers must be registered before starting the server.";
982 request_handlers_.push_back(callback);
983 }
984
RegisterRequestMonitor(const MonitorRequestCallback & callback)985 void EmbeddedTestServer::RegisterRequestMonitor(
986 const MonitorRequestCallback& callback) {
987 DCHECK(!io_thread_)
988 << "Monitors must be registered before starting the server.";
989 request_monitors_.push_back(callback);
990 }
991
RegisterDefaultHandler(const HandleRequestCallback & callback)992 void EmbeddedTestServer::RegisterDefaultHandler(
993 const HandleRequestCallback& callback) {
994 DCHECK(!io_thread_)
995 << "Handlers must be registered before starting the server.";
996 default_request_handlers_.push_back(callback);
997 }
998
DoSSLUpgrade(std::unique_ptr<StreamSocket> connection)999 std::unique_ptr<SSLServerSocket> EmbeddedTestServer::DoSSLUpgrade(
1000 std::unique_ptr<StreamSocket> connection) {
1001 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
1002
1003 return context_->CreateSSLServerSocket(std::move(connection));
1004 }
1005
DoAcceptLoop()1006 void EmbeddedTestServer::DoAcceptLoop() {
1007 while (true) {
1008 int rv = listen_socket_->Accept(
1009 &accepted_socket_,
1010 base::BindOnce(&EmbeddedTestServer::OnAcceptCompleted,
1011 base::Unretained(this)));
1012 if (rv != OK)
1013 return;
1014
1015 HandleAcceptResult(std::move(accepted_socket_));
1016 }
1017 }
1018
FlushAllSocketsAndConnectionsOnUIThread()1019 bool EmbeddedTestServer::FlushAllSocketsAndConnectionsOnUIThread() {
1020 return PostTaskToIOThreadAndWait(
1021 base::BindOnce(&EmbeddedTestServer::FlushAllSocketsAndConnections,
1022 base::Unretained(this)));
1023 }
1024
FlushAllSocketsAndConnections()1025 void EmbeddedTestServer::FlushAllSocketsAndConnections() {
1026 connections_.clear();
1027 }
1028
SetAlpsAcceptCH(std::string hostname,std::string accept_ch)1029 void EmbeddedTestServer::SetAlpsAcceptCH(std::string hostname,
1030 std::string accept_ch) {
1031 alps_accept_ch_.insert_or_assign(std::move(hostname), std::move(accept_ch));
1032 }
1033
RegisterShutdownClosure(base::OnceClosure closure)1034 base::CallbackListSubscription EmbeddedTestServer::RegisterShutdownClosure(
1035 base::OnceClosure closure) {
1036 return shutdown_closures_.Add(std::move(closure));
1037 }
1038
OnAcceptCompleted(int rv)1039 void EmbeddedTestServer::OnAcceptCompleted(int rv) {
1040 DCHECK_NE(ERR_IO_PENDING, rv);
1041 HandleAcceptResult(std::move(accepted_socket_));
1042 DoAcceptLoop();
1043 }
1044
OnHandshakeDone(HttpConnection * connection,int rv)1045 void EmbeddedTestServer::OnHandshakeDone(HttpConnection* connection, int rv) {
1046 if (connection->Socket()->IsConnected()) {
1047 connection->OnSocketReady();
1048 } else {
1049 RemoveConnection(connection);
1050 }
1051 }
1052
HandleAcceptResult(std::unique_ptr<StreamSocket> socket_ptr)1053 void EmbeddedTestServer::HandleAcceptResult(
1054 std::unique_ptr<StreamSocket> socket_ptr) {
1055 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
1056 if (connection_listener_)
1057 socket_ptr = connection_listener_->AcceptedSocket(std::move(socket_ptr));
1058
1059 if (!is_using_ssl_) {
1060 AddConnection(std::move(socket_ptr))->OnSocketReady();
1061 return;
1062 }
1063
1064 socket_ptr = DoSSLUpgrade(std::move(socket_ptr));
1065
1066 StreamSocket* socket = socket_ptr.get();
1067 HttpConnection* connection = AddConnection(std::move(socket_ptr));
1068
1069 int rv = static_cast<SSLServerSocket*>(socket)->Handshake(
1070 base::BindOnce(&EmbeddedTestServer::OnHandshakeDone,
1071 base::Unretained(this), connection));
1072 if (rv != ERR_IO_PENDING)
1073 OnHandshakeDone(connection, rv);
1074 }
1075
AddConnection(std::unique_ptr<StreamSocket> socket_ptr)1076 HttpConnection* EmbeddedTestServer::AddConnection(
1077 std::unique_ptr<StreamSocket> socket_ptr) {
1078 StreamSocket* socket = socket_ptr.get();
1079 std::unique_ptr<HttpConnection> connection_ptr = HttpConnection::Create(
1080 std::move(socket_ptr), connection_listener_, this, protocol_);
1081 HttpConnection* connection = connection_ptr.get();
1082 connections_[socket] = std::move(connection_ptr);
1083
1084 return connection;
1085 }
1086
RemoveConnection(HttpConnection * connection,EmbeddedTestServerConnectionListener * listener)1087 void EmbeddedTestServer::RemoveConnection(
1088 HttpConnection* connection,
1089 EmbeddedTestServerConnectionListener* listener) {
1090 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
1091 DCHECK(connection);
1092 DCHECK_EQ(1u, connections_.count(connection->Socket()));
1093
1094 StreamSocket* raw_socket = connection->Socket();
1095 std::unique_ptr<StreamSocket> socket = connection->TakeSocket();
1096 connections_.erase(raw_socket);
1097
1098 if (listener && socket && socket->IsConnected())
1099 listener->OnResponseCompletedSuccessfully(std::move(socket));
1100 }
1101
PostTaskToIOThreadAndWait(base::OnceClosure closure)1102 bool EmbeddedTestServer::PostTaskToIOThreadAndWait(base::OnceClosure closure) {
1103 // Note that PostTaskAndReply below requires
1104 // base::SingleThreadTaskRunner::GetCurrentDefault() to return a task runner
1105 // for posting the reply task. However, in order to make EmbeddedTestServer
1106 // universally usable, it needs to cope with the situation where it's running
1107 // on a thread on which a task executor is not (yet) available or has been
1108 // destroyed already.
1109 //
1110 // To handle this situation, create temporary task executor to support the
1111 // PostTaskAndReply operation if the current thread has no task executor.
1112 // TODO(mattm): Is this still necessary/desirable? Try removing this and see
1113 // if anything breaks.
1114 std::unique_ptr<base::SingleThreadTaskExecutor> temporary_loop;
1115 if (!base::CurrentThread::Get())
1116 temporary_loop = std::make_unique<base::SingleThreadTaskExecutor>();
1117
1118 base::RunLoop run_loop;
1119 if (!io_thread_->task_runner()->PostTaskAndReply(
1120 FROM_HERE, std::move(closure), run_loop.QuitClosure())) {
1121 return false;
1122 }
1123 run_loop.Run();
1124
1125 return true;
1126 }
1127
PostTaskToIOThreadAndWaitWithResult(base::OnceCallback<bool ()> task)1128 bool EmbeddedTestServer::PostTaskToIOThreadAndWaitWithResult(
1129 base::OnceCallback<bool()> task) {
1130 // Note that PostTaskAndReply below requires
1131 // base::SingleThreadTaskRunner::GetCurrentDefault() to return a task runner
1132 // for posting the reply task. However, in order to make EmbeddedTestServer
1133 // universally usable, it needs to cope with the situation where it's running
1134 // on a thread on which a task executor is not (yet) available or has been
1135 // destroyed already.
1136 //
1137 // To handle this situation, create temporary task executor to support the
1138 // PostTaskAndReply operation if the current thread has no task executor.
1139 // TODO(mattm): Is this still necessary/desirable? Try removing this and see
1140 // if anything breaks.
1141 std::unique_ptr<base::SingleThreadTaskExecutor> temporary_loop;
1142 if (!base::CurrentThread::Get())
1143 temporary_loop = std::make_unique<base::SingleThreadTaskExecutor>();
1144
1145 base::RunLoop run_loop;
1146 bool task_result = false;
1147 if (!io_thread_->task_runner()->PostTaskAndReplyWithResult(
1148 FROM_HERE, std::move(task),
1149 base::BindOnce(base::BindLambdaForTesting([&](bool result) {
1150 task_result = result;
1151 run_loop.Quit();
1152 })))) {
1153 return false;
1154 }
1155 run_loop.Run();
1156
1157 return task_result;
1158 }
1159
1160 } // namespace net::test_server
1161