• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifdef UNSAFE_BUFFERS_BUILD
6 // TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
7 #pragma allow_unsafe_buffers
8 #endif
9 
10 #include "net/test/embedded_test_server/embedded_test_server.h"
11 
12 #include <stdint.h>
13 
14 #include <memory>
15 #include <optional>
16 #include <string_view>
17 #include <utility>
18 
19 #include "base/files/file_path.h"
20 #include "base/functional/bind.h"
21 #include "base/functional/callback_forward.h"
22 #include "base/functional/callback_helpers.h"
23 #include "base/location.h"
24 #include "base/logging.h"
25 #include "base/message_loop/message_pump_type.h"
26 #include "base/path_service.h"
27 #include "base/process/process_metrics.h"
28 #include "base/run_loop.h"
29 #include "base/strings/string_number_conversions.h"
30 #include "base/strings/string_util.h"
31 #include "base/strings/stringprintf.h"
32 #include "base/task/current_thread.h"
33 #include "base/task/single_thread_task_executor.h"
34 #include "base/task/single_thread_task_runner.h"
35 #include "base/test/bind.h"
36 #include "base/threading/thread_restrictions.h"
37 #include "crypto/rsa_private_key.h"
38 #include "net/base/hex_utils.h"
39 #include "net/base/ip_address.h"
40 #include "net/base/ip_endpoint.h"
41 #include "net/base/net_errors.h"
42 #include "net/base/port_util.h"
43 #include "net/log/net_log_source.h"
44 #include "net/socket/next_proto.h"
45 #include "net/socket/ssl_server_socket.h"
46 #include "net/socket/stream_socket.h"
47 #include "net/socket/tcp_server_socket.h"
48 #include "net/spdy/spdy_test_util_common.h"
49 #include "net/ssl/ssl_info.h"
50 #include "net/ssl/ssl_server_config.h"
51 #include "net/test/cert_builder.h"
52 #include "net/test/cert_test_util.h"
53 #include "net/test/embedded_test_server/default_handlers.h"
54 #include "net/test/embedded_test_server/embedded_test_server_connection_listener.h"
55 #include "net/test/embedded_test_server/http_request.h"
56 #include "net/test/embedded_test_server/http_response.h"
57 #include "net/test/embedded_test_server/request_handler_util.h"
58 #include "net/test/key_util.h"
59 #include "net/test/revocation_builder.h"
60 #include "net/test/test_data_directory.h"
61 #include "net/third_party/quiche/src/quiche/http2/core/spdy_frame_builder.h"
62 #include "third_party/boringssl/src/pki/extended_key_usage.h"
63 #include "url/origin.h"
64 
65 namespace net::test_server {
66 
67 namespace {
68 
ServeResponseForPath(const std::string & expected_path,HttpStatusCode status_code,const std::string & content_type,const std::string & content,const HttpRequest & request)69 std::unique_ptr<HttpResponse> ServeResponseForPath(
70     const std::string& expected_path,
71     HttpStatusCode status_code,
72     const std::string& content_type,
73     const std::string& content,
74     const HttpRequest& request) {
75   if (request.GetURL().path() != expected_path)
76     return nullptr;
77 
78   auto http_response = std::make_unique<BasicHttpResponse>();
79   http_response->set_code(status_code);
80   http_response->set_content_type(content_type);
81   http_response->set_content(content);
82   return http_response;
83 }
84 
85 // Serves response for |expected_path| or any subpath of it.
86 // |expected_path| should not include a trailing "/".
ServeResponseForSubPaths(const std::string & expected_path,HttpStatusCode status_code,const std::string & content_type,const std::string & content,const HttpRequest & request)87 std::unique_ptr<HttpResponse> ServeResponseForSubPaths(
88     const std::string& expected_path,
89     HttpStatusCode status_code,
90     const std::string& content_type,
91     const std::string& content,
92     const HttpRequest& request) {
93   if (request.GetURL().path() != expected_path &&
94       !request.GetURL().path().starts_with(expected_path + "/")) {
95     return nullptr;
96   }
97 
98   auto http_response = std::make_unique<BasicHttpResponse>();
99   http_response->set_code(status_code);
100   http_response->set_content_type(content_type);
101   http_response->set_content(content);
102   return http_response;
103 }
104 
// Builds the OCSP response described by |config| for the certificate being
// built by |target| and stores it in |*out_response|.
//
// Returns true on success; an empty |*out_response| means OCSP is turned off
// (ResponseType::kOff). Returns false when |config| is inconsistent
// (single_responses supplied for a non-successful response type), when OCSP
// is enabled but |target| is null, or when |target|'s validity period cannot
// be read.
bool MaybeCreateOCSPResponse(CertBuilder* target,
                             const EmbeddedTestServer::OCSPConfig& config,
                             std::string* out_response) {
  using OCSPResponseType = EmbeddedTestServer::OCSPConfig::ResponseType;

  if (!config.single_responses.empty() &&
      config.response_type != OCSPResponseType::kSuccessful) {
    // OCSPConfig contained single_responses for a non-successful response.
    return false;
  }

  if (config.response_type == OCSPResponseType::kOff) {
    *out_response = std::string();
    return true;
  }

  if (!target) {
    // OCSPConfig enabled but corresponding certificate is null.
    return false;
  }

  switch (config.response_type) {
    case OCSPResponseType::kOff:
      // Already handled above; listed only to keep the switch exhaustive.
      return false;
    case OCSPResponseType::kMalformedRequest:
      *out_response = BuildOCSPResponseError(
          bssl::OCSPResponse::ResponseStatus::MALFORMED_REQUEST);
      return true;
    case OCSPResponseType::kInternalError:
      *out_response = BuildOCSPResponseError(
          bssl::OCSPResponse::ResponseStatus::INTERNAL_ERROR);
      return true;
    case OCSPResponseType::kTryLater:
      *out_response =
          BuildOCSPResponseError(bssl::OCSPResponse::ResponseStatus::TRY_LATER);
      return true;
    case OCSPResponseType::kSigRequired:
      *out_response = BuildOCSPResponseError(
          bssl::OCSPResponse::ResponseStatus::SIG_REQUIRED);
      return true;
    case OCSPResponseType::kUnauthorized:
      *out_response = BuildOCSPResponseError(
          bssl::OCSPResponse::ResponseStatus::UNAUTHORIZED);
      return true;
    case OCSPResponseType::kInvalidResponse:
      // Deliberately not a well-formed OCSPResponse encoding.
      *out_response = "3";
      return true;
    case OCSPResponseType::kInvalidResponseData:
      // A response envelope signed by the issuer whose ResponseData is not
      // OCSP data.
      *out_response =
          BuildOCSPResponseWithResponseData(target->issuer()->GetKey(),
                                            // OCTET_STRING { "not ocsp data" }
                                            "\x04\x0dnot ocsp data");
      return true;
    case OCSPResponseType::kSuccessful:
      // Built below from |config.single_responses|.
      break;
  }

  base::Time now = base::Time::Now();
  base::Time target_not_before, target_not_after;
  if (!target->GetValidity(&target_not_before, &target_not_after))
    return false;
  // producedAt is chosen relative to |target|'s validity window.
  base::Time produced_at;
  using OCSPProduced = EmbeddedTestServer::OCSPConfig::Produced;
  switch (config.produced) {
    case OCSPProduced::kValid:
      produced_at = std::max(now - base::Days(1), target_not_before);
      break;
    case OCSPProduced::kBeforeCert:
      produced_at = target_not_before - base::Days(1);
      break;
    case OCSPProduced::kAfterCert:
      produced_at = target_not_after + base::Days(1);
      break;
  }

  // One entry per configured SingleResponse. With no configured entries an
  // empty list is passed to BuildOCSPResponse.
  std::vector<OCSPBuilderSingleResponse> responses;
  for (const auto& config_response : config.single_responses) {
    OCSPBuilderSingleResponse response;
    response.serial = target->GetSerialNumber();
    if (config_response.serial ==
        EmbeddedTestServer::OCSPConfig::SingleResponse::Serial::kMismatch) {
      // Flip the low bit so the serial no longer matches |target|.
      response.serial ^= 1;
    }
    response.cert_status = config_response.cert_status;
    // |revocation_time| is ignored if |cert_status| is not REVOKED.
    response.revocation_time = now - base::Days(1000);

    using OCSPDate = EmbeddedTestServer::OCSPConfig::SingleResponse::Date;
    switch (config_response.ocsp_date) {
      case OCSPDate::kValid:
        // thisUpdate yesterday, nextUpdate six days from now.
        response.this_update = now - base::Days(1);
        response.next_update = response.this_update + base::Days(7);
        break;
      case OCSPDate::kOld:
        // nextUpdate is already one day in the past.
        response.this_update = now - base::Days(8);
        response.next_update = response.this_update + base::Days(7);
        break;
      case OCSPDate::kEarly:
        // thisUpdate is one day in the future.
        response.this_update = now + base::Days(1);
        response.next_update = response.this_update + base::Days(7);
        break;
      case OCSPDate::kLong:
        // thisUpdate..nextUpdate spans 366 days.
        response.this_update = now - base::Days(365);
        response.next_update = response.this_update + base::Days(366);
        break;
      case OCSPDate::kLonger:
        // thisUpdate..nextUpdate spans 368 days.
        response.this_update = now - base::Days(367);
        response.next_update = response.this_update + base::Days(368);
        break;
    }

    responses.push_back(response);
  }
  *out_response =
      BuildOCSPResponse(target->issuer()->GetSubject(),
                        target->issuer()->GetKey(), produced_at, responses);
  return true;
}
223 
DispatchResponseToDelegate(std::unique_ptr<HttpResponse> response,base::WeakPtr<HttpResponseDelegate> delegate)224 void DispatchResponseToDelegate(std::unique_ptr<HttpResponse> response,
225                                 base::WeakPtr<HttpResponseDelegate> delegate) {
226   HttpResponse* const response_ptr = response.get();
227   delegate->AddResponse(std::move(response));
228   response_ptr->SendResponse(delegate);
229 }
230 
231 }  // namespace
232 
EmbeddedTestServerHandle(EmbeddedTestServerHandle && other)233 EmbeddedTestServerHandle::EmbeddedTestServerHandle(
234     EmbeddedTestServerHandle&& other) {
235   operator=(std::move(other));
236 }
237 
operator =(EmbeddedTestServerHandle && other)238 EmbeddedTestServerHandle& EmbeddedTestServerHandle::operator=(
239     EmbeddedTestServerHandle&& other) {
240   EmbeddedTestServerHandle temporary;
241   std::swap(other.test_server_, temporary.test_server_);
242   std::swap(temporary.test_server_, test_server_);
243   return *this;
244 }
245 
EmbeddedTestServerHandle(EmbeddedTestServer * test_server)246 EmbeddedTestServerHandle::EmbeddedTestServerHandle(
247     EmbeddedTestServer* test_server)
248     : test_server_(test_server) {}
249 
~EmbeddedTestServerHandle()250 EmbeddedTestServerHandle::~EmbeddedTestServerHandle() {
251   if (test_server_)
252     CHECK(test_server_->ShutdownAndWaitUntilComplete());
253 }
254 
255 EmbeddedTestServer::OCSPConfig::OCSPConfig() = default;
OCSPConfig(ResponseType response_type)256 EmbeddedTestServer::OCSPConfig::OCSPConfig(ResponseType response_type)
257     : response_type(response_type) {}
OCSPConfig(std::vector<SingleResponse> single_responses,Produced produced)258 EmbeddedTestServer::OCSPConfig::OCSPConfig(
259     std::vector<SingleResponse> single_responses,
260     Produced produced)
261     : response_type(ResponseType::kSuccessful),
262       produced(produced),
263       single_responses(std::move(single_responses)) {}
264 EmbeddedTestServer::OCSPConfig::OCSPConfig(const OCSPConfig&) = default;
265 EmbeddedTestServer::OCSPConfig::OCSPConfig(OCSPConfig&&) = default;
266 EmbeddedTestServer::OCSPConfig::~OCSPConfig() = default;
267 EmbeddedTestServer::OCSPConfig& EmbeddedTestServer::OCSPConfig::operator=(
268     const OCSPConfig&) = default;
269 EmbeddedTestServer::OCSPConfig& EmbeddedTestServer::OCSPConfig::operator=(
270     OCSPConfig&&) = default;
271 
272 EmbeddedTestServer::ServerCertificateConfig::ServerCertificateConfig() =
273     default;
274 EmbeddedTestServer::ServerCertificateConfig::ServerCertificateConfig(
275     const ServerCertificateConfig&) = default;
276 EmbeddedTestServer::ServerCertificateConfig::ServerCertificateConfig(
277     ServerCertificateConfig&&) = default;
278 EmbeddedTestServer::ServerCertificateConfig::~ServerCertificateConfig() =
279     default;
280 EmbeddedTestServer::ServerCertificateConfig&
281 EmbeddedTestServer::ServerCertificateConfig::operator=(
282     const ServerCertificateConfig&) = default;
283 EmbeddedTestServer::ServerCertificateConfig&
284 EmbeddedTestServer::ServerCertificateConfig::operator=(
285     ServerCertificateConfig&&) = default;
286 
EmbeddedTestServer()287 EmbeddedTestServer::EmbeddedTestServer() : EmbeddedTestServer(TYPE_HTTP) {}
288 
EmbeddedTestServer(Type type,HttpConnection::Protocol protocol)289 EmbeddedTestServer::EmbeddedTestServer(Type type,
290                                        HttpConnection::Protocol protocol)
291     : is_using_ssl_(type == TYPE_HTTPS), protocol_(protocol) {
292   DCHECK(thread_checker_.CalledOnValidThread());
293   // HTTP/2 is only valid by negotiation via TLS ALPN
294   DCHECK(protocol_ != HttpConnection::Protocol::kHttp2 || type == TYPE_HTTPS);
295 
296   if (!is_using_ssl_)
297     return;
298   scoped_test_root_ = RegisterTestCerts();
299 }
300 
~EmbeddedTestServer()301 EmbeddedTestServer::~EmbeddedTestServer() {
302   DCHECK(thread_checker_.CalledOnValidThread());
303 
304   if (Started())
305     CHECK(ShutdownAndWaitUntilComplete());
306 
307   {
308     base::ScopedAllowBaseSyncPrimitivesForTesting allow_wait_for_thread_join;
309     io_thread_.reset();
310   }
311 }
312 
RegisterTestCerts()313 ScopedTestRoot EmbeddedTestServer::RegisterTestCerts() {
314   base::ScopedAllowBlockingForTesting allow_blocking;
315   auto root = ImportCertFromFile(GetRootCertPemPath());
316   if (!root)
317     return ScopedTestRoot();
318   return ScopedTestRoot(root);
319 }
320 
SetConnectionListener(EmbeddedTestServerConnectionListener * listener)321 void EmbeddedTestServer::SetConnectionListener(
322     EmbeddedTestServerConnectionListener* listener) {
323   DCHECK(!io_thread_)
324       << "ConnectionListener must be set before starting the server.";
325   connection_listener_ = listener;
326 }
327 
StartAndReturnHandle(int port)328 EmbeddedTestServerHandle EmbeddedTestServer::StartAndReturnHandle(int port) {
329   bool result = Start(port);
330   return result ? EmbeddedTestServerHandle(this) : EmbeddedTestServerHandle();
331 }
332 
Start(int port,std::string_view address)333 bool EmbeddedTestServer::Start(int port, std::string_view address) {
334   bool success = InitializeAndListen(port, address);
335   if (success)
336     StartAcceptingConnections();
337   return success;
338 }
339 
InitializeAndListen(int port,std::string_view address)340 bool EmbeddedTestServer::InitializeAndListen(int port,
341                                              std::string_view address) {
342   DCHECK(!Started());
343 
344   const int max_tries = 5;
345   int num_tries = 0;
346   bool is_valid_port = false;
347 
348   do {
349     if (++num_tries > max_tries) {
350       LOG(ERROR) << "Failed to listen on a valid port after " << max_tries
351                  << " attempts.";
352       listen_socket_.reset();
353       return false;
354     }
355 
356     listen_socket_ = std::make_unique<TCPServerSocket>(nullptr, NetLogSource());
357 
358     int result =
359         listen_socket_->ListenWithAddressAndPort(address.data(), port, 10);
360     if (result) {
361       LOG(ERROR) << "Listen failed: " << ErrorToString(result);
362       listen_socket_.reset();
363       return false;
364     }
365 
366     result = listen_socket_->GetLocalAddress(&local_endpoint_);
367     if (result != OK) {
368       LOG(ERROR) << "GetLocalAddress failed: " << ErrorToString(result);
369       listen_socket_.reset();
370       return false;
371     }
372 
373     port_ = local_endpoint_.port();
374     is_valid_port |= net::IsPortAllowedForScheme(
375         port_, is_using_ssl_ ? url::kHttpsScheme : url::kHttpScheme);
376   } while (!is_valid_port);
377 
378   if (is_using_ssl_) {
379     base_url_ = GURL("https://" + local_endpoint_.ToString());
380     if (cert_ == CERT_MISMATCHED_NAME || cert_ == CERT_COMMON_NAME_IS_DOMAIN) {
381       base_url_ = GURL(
382           base::StringPrintf("https://localhost:%d", local_endpoint_.port()));
383     }
384   } else {
385     base_url_ = GURL("http://" + local_endpoint_.ToString());
386   }
387 
388   listen_socket_->DetachFromThread();
389 
390   if (is_using_ssl_ && !InitializeSSLServerContext())
391     return false;
392 
393   return true;
394 }
395 
UsingStaticCert() const396 bool EmbeddedTestServer::UsingStaticCert() const {
397   return !GetCertificateName().empty();
398 }
399 
InitializeCertAndKeyFromFile()400 bool EmbeddedTestServer::InitializeCertAndKeyFromFile() {
401   base::ScopedAllowBlockingForTesting allow_blocking;
402   base::FilePath certs_dir(GetTestCertsDirectory());
403   std::string cert_name = GetCertificateName();
404   if (cert_name.empty())
405     return false;
406 
407   x509_cert_ = CreateCertificateChainFromFile(certs_dir, cert_name,
408                                               X509Certificate::FORMAT_AUTO);
409   if (!x509_cert_)
410     return false;
411 
412   private_key_ =
413       key_util::LoadEVP_PKEYFromPEM(certs_dir.AppendASCII(cert_name));
414   return !!private_key_;
415 }
416 
// Generates a certificate chain (root, optional intermediate, leaf) and the
// leaf's private key according to |cert_config_|, storing the results in
// |x509_cert_|, |intermediate_|, |root_| and |private_key_|.
//
// Also creates |aia_http_server_|, a plain HTTP server that serves the OCSP
// responses and (for IntermediateType::kByAIA) the intermediate certificate
// referenced from the generated certificates' AIA extensions. Returns false if
// that server cannot listen or if any OCSP config is invalid.
bool EmbeddedTestServer::GenerateCertAndKey() {
  // Create AIA server and start listening. Need to have the socket initialized
  // so the URL can be put in the AIA records of the generated certs.
  aia_http_server_ = std::make_unique<EmbeddedTestServer>(TYPE_HTTP);
  if (!aia_http_server_->InitializeAndListen())
    return false;

  base::ScopedAllowBlockingForTesting allow_blocking;
  base::FilePath certs_dir(GetTestCertsDirectory());
  auto now = base::Time::Now();

  // Root: either the static test root loaded from disk, or a new CA built
  // here with no issuer.
  std::unique_ptr<CertBuilder> root;
  switch (cert_config_.root) {
    case RootType::kTestRootCa:
      root = CertBuilder::FromStaticCertFile(
          certs_dir.AppendASCII("root_ca_cert.pem"));
      break;
    case RootType::kUniqueRoot:
      root = std::make_unique<CertBuilder>(nullptr, nullptr);
      root->SetValidity(now - base::Days(100), now + base::Days(1000));
      root->SetBasicConstraints(/*is_ca=*/true, /*path_len=*/-1);
      root->SetKeyUsages(
          {bssl::KEY_USAGE_BIT_KEY_CERT_SIGN, bssl::KEY_USAGE_BIT_CRL_SIGN});
      if (!cert_config_.root_dns_names.empty()) {
        root->SetSubjectAltNames(cert_config_.root_dns_names, {});
      }
      break;
  }

  // Will be nullptr if cert_config_.intermediate == kNone.
  std::unique_ptr<CertBuilder> intermediate;
  std::unique_ptr<CertBuilder> leaf;

  // The leaf is issued by the intermediate when there is one, otherwise
  // directly by the root.
  if (cert_config_.intermediate != IntermediateType::kNone) {
    intermediate = std::make_unique<CertBuilder>(nullptr, root.get());
    intermediate->SetValidity(now - base::Days(100), now + base::Days(1000));
    intermediate->SetBasicConstraints(/*is_ca=*/true, /*path_len=*/-1);
    intermediate->SetKeyUsages(
        {bssl::KEY_USAGE_BIT_KEY_CERT_SIGN, bssl::KEY_USAGE_BIT_CRL_SIGN});

    leaf = std::make_unique<CertBuilder>(nullptr, intermediate.get());
  } else {
    leaf = std::make_unique<CertBuilder>(nullptr, root.get());
  }
  std::vector<GURL> leaf_ca_issuers_urls;
  std::vector<GURL> leaf_ocsp_urls;

  leaf->SetValidity(now - base::Days(1), now + base::Days(20));
  leaf->SetBasicConstraints(/*is_ca=*/cert_config_.leaf_is_ca, /*path_len=*/-1);
  leaf->SetExtendedKeyUsages({bssl::der::Input(bssl::kServerAuth)});

  if (!cert_config_.policy_oids.empty()) {
    leaf->SetCertificatePolicies(cert_config_.policy_oids);
    if (intermediate)
      intermediate->SetCertificatePolicies(cert_config_.policy_oids);
  }

  // Without configured SANs, the leaf gets a single IP SAN for IPv4
  // localhost.
  if (!cert_config_.dns_names.empty() || !cert_config_.ip_addresses.empty()) {
    leaf->SetSubjectAltNames(cert_config_.dns_names, cert_config_.ip_addresses);
  } else {
    leaf->SetSubjectAltNames({}, {net::IPAddress::IPv4Localhost()});
  }

  if (!cert_config_.key_usages.empty()) {
    leaf->SetKeyUsages(cert_config_.key_usages);
  } else {
    leaf->SetKeyUsages({bssl::KEY_USAGE_BIT_DIGITAL_SIGNATURE});
  }

  if (!cert_config_.embedded_scts.empty()) {
    leaf->SetSctConfig(cert_config_.embedded_scts);
  }

  // Serial numbers make the AIA server paths distinct per certificate.
  const std::string leaf_serial_text =
      base::NumberToString(leaf->GetSerialNumber());
  const std::string intermediate_serial_text =
      intermediate ? base::NumberToString(intermediate->GetSerialNumber()) : "";

  // Leaf OCSP: served by the AIA server and advertised in the leaf's AIA.
  std::string ocsp_response;
  if (!MaybeCreateOCSPResponse(leaf.get(), cert_config_.ocsp_config,
                               &ocsp_response)) {
    return false;
  }
  if (!ocsp_response.empty()) {
    std::string ocsp_path = "/ocsp/" + leaf_serial_text;
    leaf_ocsp_urls.push_back(aia_http_server_->GetURL(ocsp_path));
    aia_http_server_->RegisterRequestHandler(
        base::BindRepeating(ServeResponseForSubPaths, ocsp_path, HTTP_OK,
                            "application/ocsp-response", ocsp_response));
  }

  // Stapled OCSP: stored in |ssl_config_.ocsp_response| rather than served by
  // the AIA server.
  std::string stapled_ocsp_response;
  if (!MaybeCreateOCSPResponse(leaf.get(), cert_config_.stapled_ocsp_config,
                               &stapled_ocsp_response)) {
    return false;
  }
  if (!stapled_ocsp_response.empty()) {
    ssl_config_.ocsp_response = std::vector<uint8_t>(
        stapled_ocsp_response.begin(), stapled_ocsp_response.end());
  }

  // Intermediate OCSP. MaybeCreateOCSPResponse fails for a null target unless
  // the config is kOff (which yields an empty response), so |intermediate| is
  // non-null whenever |intermediate_ocsp_response| is non-empty.
  std::string intermediate_ocsp_response;
  if (!MaybeCreateOCSPResponse(intermediate.get(),
                               cert_config_.intermediate_ocsp_config,
                               &intermediate_ocsp_response)) {
    return false;
  }
  if (!intermediate_ocsp_response.empty()) {
    std::string intermediate_ocsp_path = "/ocsp/" + intermediate_serial_text;
    intermediate->SetCaIssuersAndOCSPUrls(
        {}, {aia_http_server_->GetURL(intermediate_ocsp_path)});
    aia_http_server_->RegisterRequestHandler(base::BindRepeating(
        ServeResponseForSubPaths, intermediate_ocsp_path, HTTP_OK,
        "application/ocsp-response", intermediate_ocsp_response));
  }

  // kByAIA: the leaf's CA Issuers URL points at the AIA server, which serves
  // the intermediate's DER. This runs after the intermediate's own AIA URLs
  // were set above, so the served DER includes them.
  if (cert_config_.intermediate == IntermediateType::kByAIA) {
    std::string ca_issuers_path = "/ca_issuers/" + intermediate_serial_text;
    leaf_ca_issuers_urls.push_back(aia_http_server_->GetURL(ca_issuers_path));

    // Setup AIA server to serve the intermediate referred to by the leaf.
    aia_http_server_->RegisterRequestHandler(
        base::BindRepeating(ServeResponseForPath, ca_issuers_path, HTTP_OK,
                            "application/pkix-cert", intermediate->GetDER()));
  }

  if (!leaf_ca_issuers_urls.empty() || !leaf_ocsp_urls.empty()) {
    leaf->SetCaIssuersAndOCSPUrls(leaf_ca_issuers_urls, leaf_ocsp_urls);
  }

  if (cert_config_.intermediate == IntermediateType::kByAIA ||
      cert_config_.intermediate == IntermediateType::kMissing) {
    // Server certificate chain does not include the intermediate.
    x509_cert_ = leaf->GetX509Certificate();
  } else {
    // Server certificate chain will include the intermediate, if there is one.
    x509_cert_ = leaf->GetX509CertificateChain();
  }

  if (intermediate) {
    intermediate_ = intermediate->GetX509Certificate();
  }

  root_ = root->GetX509Certificate();

  private_key_ = bssl::UpRef(leaf->GetKey());

  // If this server is already accepting connections but is being reconfigured,
  // start the new AIA server now. Otherwise, wait until
  // StartAcceptingConnections so that this server and the AIA server start at
  // the same time. (If the test only called InitializeAndListen they expect no
  // threads to be created yet.)
  if (io_thread_)
    aia_http_server_->StartAcceptingConnections();

  return true;
}
574 
InitializeSSLServerContext()575 bool EmbeddedTestServer::InitializeSSLServerContext() {
576   if (UsingStaticCert()) {
577     if (!InitializeCertAndKeyFromFile())
578       return false;
579   } else {
580     if (!GenerateCertAndKey())
581       return false;
582   }
583 
584   if (protocol_ == HttpConnection::Protocol::kHttp2) {
585     ssl_config_.alpn_protos = {NextProto::kProtoHTTP2};
586     if (!alps_accept_ch_.empty()) {
587       base::StringPairs origin_accept_ch;
588       size_t frame_size = spdy::kFrameHeaderSize;
589       // Figure out size and generate origins
590       for (const auto& pair : alps_accept_ch_) {
591         std::string_view hostname = pair.first;
592         std::string accept_ch = pair.second;
593 
594         GURL url = hostname.empty() ? GetURL("/") : GetURL(hostname, "/");
595         std::string origin = url::Origin::Create(url).Serialize();
596 
597         frame_size += accept_ch.size() + origin.size() +
598                       (sizeof(uint16_t) * 2);  // = Origin-Len + Value-Len
599 
600         origin_accept_ch.push_back({std::move(origin), std::move(accept_ch)});
601       }
602 
603       spdy::SpdyFrameBuilder builder(frame_size);
604       builder.BeginNewFrame(spdy::SpdyFrameType::ACCEPT_CH, 0, 0);
605       for (const auto& pair : origin_accept_ch) {
606         std::string_view origin = pair.first;
607         std::string_view accept_ch = pair.second;
608 
609         builder.WriteUInt16(origin.size());
610         builder.WriteBytes(origin.data(), origin.size());
611 
612         builder.WriteUInt16(accept_ch.size());
613         builder.WriteBytes(accept_ch.data(), accept_ch.size());
614       }
615 
616       spdy::SpdySerializedFrame serialized_frame = builder.take();
617       DCHECK_EQ(frame_size, serialized_frame.size());
618 
619       ssl_config_.application_settings[NextProto::kProtoHTTP2] =
620           std::vector<uint8_t>(
621               serialized_frame.data(),
622               serialized_frame.data() + serialized_frame.size());
623 
624       ssl_config_.client_hello_callback_for_testing =
625           base::BindRepeating([](const SSL_CLIENT_HELLO* client_hello) {
626             // Configure the server to use the ALPS codepoint that the client
627             // offered.
628             const uint8_t* unused_extension_bytes;
629             size_t unused_extension_len;
630             int use_alps_new_codepoint = SSL_early_callback_ctx_extension_get(
631                 client_hello, TLSEXT_TYPE_application_settings,
632                 &unused_extension_bytes, &unused_extension_len);
633             // Make sure we use the right ALPS codepoint.
634             SSL_set_alps_use_new_codepoint(client_hello->ssl,
635                                            use_alps_new_codepoint);
636             return true;
637           });
638     }
639   }
640 
641   context_ =
642       CreateSSLServerContext(x509_cert_.get(), private_key_.get(), ssl_config_);
643   return true;
644 }
645 
646 EmbeddedTestServerHandle
StartAcceptingConnectionsAndReturnHandle()647 EmbeddedTestServer::StartAcceptingConnectionsAndReturnHandle() {
648   StartAcceptingConnections();
649   return EmbeddedTestServerHandle(this);
650 }
651 
StartAcceptingConnections()652 void EmbeddedTestServer::StartAcceptingConnections() {
653   DCHECK(Started());
654   DCHECK(!io_thread_) << "Server must not be started while server is running";
655 
656   if (aia_http_server_)
657     aia_http_server_->StartAcceptingConnections();
658 
659   base::Thread::Options thread_options;
660   thread_options.message_pump_type = base::MessagePumpType::IO;
661   io_thread_ = std::make_unique<base::Thread>("EmbeddedTestServer IO Thread");
662   CHECK(io_thread_->StartWithOptions(std::move(thread_options)));
663   CHECK(io_thread_->WaitUntilThreadStarted());
664 
665   io_thread_->task_runner()->PostTask(
666       FROM_HERE, base::BindOnce(&EmbeddedTestServer::DoAcceptLoop,
667                                 base::Unretained(this)));
668 }
669 
ShutdownAndWaitUntilComplete()670 bool EmbeddedTestServer::ShutdownAndWaitUntilComplete() {
671   DCHECK(thread_checker_.CalledOnValidThread());
672 
673   if (!io_thread_) {
674     // Can't stop a server that never started.
675     return true;
676   }
677 
678   // Ensure that the AIA HTTP server is no longer Started().
679   bool aia_http_server_not_started = true;
680   if (aia_http_server_ && aia_http_server_->Started()) {
681     aia_http_server_not_started =
682         aia_http_server_->ShutdownAndWaitUntilComplete();
683   }
684 
685   // Return false if either this or the AIA HTTP server are still Started().
686   return PostTaskToIOThreadAndWait(
687              base::BindOnce(&EmbeddedTestServer::ShutdownOnIOThread,
688                             base::Unretained(this))) &&
689          aia_http_server_not_started;
690 }
691 
692 // static
GetRootCertPemPath()693 base::FilePath EmbeddedTestServer::GetRootCertPemPath() {
694   return GetTestCertsDirectory().AppendASCII("root_ca_cert.pem");
695 }
696 
// Runs on the IO thread. Invalidates weak pointers vended by |weak_factory_|,
// notifies |shutdown_closures_|, then closes the listening socket and drops
// every tracked connection, in that order.
void EmbeddedTestServer::ShutdownOnIOThread() {
  DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
  weak_factory_.InvalidateWeakPtrs();
  shutdown_closures_.Notify();
  listen_socket_.reset();
  connections_.clear();
}
704 
GetConnectionForSocket(const StreamSocket * socket)705 HttpConnection* EmbeddedTestServer::GetConnectionForSocket(
706     const StreamSocket* socket) {
707   auto it = connections_.find(socket);
708   if (it != connections_.end()) {
709     return it->second.get();
710   }
711   return nullptr;
712 }
713 
// Handles a parsed |request| that arrived on |socket| and dispatches the
// response to |delegate|. Runs on the IO thread.
//
// Processing order:
//   1. Every request monitor observes the request.
//   2. The auth handler, if set, may answer with a response.
//   3. Each upgrade handler may take over the connection or fail with an
//      error response.
//   4. Registered request handlers are tried in order, then the default
//      handlers.
// If nothing produces a response, a 404 is sent.
void EmbeddedTestServer::HandleRequest(
    base::WeakPtr<HttpResponseDelegate> delegate,
    std::unique_ptr<HttpRequest> request,
    const StreamSocket* socket) {
  DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
  request->base_url = base_url_;

  for (const auto& monitor : request_monitors_)
    monitor.Run(*request);

  HttpConnection* connection = GetConnectionForSocket(socket);
  CHECK(connection);

  if (auth_handler_) {
    auto auth_result = auth_handler_.Run(*request);
    if (auth_result) {
      DispatchResponseToDelegate(std::move(auth_result), delegate);
      return;
    }
  }

  for (const auto& upgrade_request_handler : upgrade_request_handlers_) {
    auto upgrade_response = upgrade_request_handler.Run(*request, connection);
    if (upgrade_response.has_value()) {
      // kUpgraded: stop tracking the connection and send nothing from here.
      // Any other value falls through to the next upgrade handler.
      if (upgrade_response.value() == UpgradeResult::kUpgraded) {
        connections_.erase(socket);
        return;
      }
    } else {
      // The handler failed and supplied a (non-null) response to send.
      CHECK(upgrade_response.error());
      DispatchResponseToDelegate(std::move(upgrade_response.error()), delegate);
      return;
    }
  }

  std::unique_ptr<HttpResponse> response;

  // First non-null response from the registered handlers wins.
  for (const auto& handler : request_handlers_) {
    response = handler.Run(*request);
    if (response)
      break;
  }

  if (!response) {
    for (const auto& handler : default_request_handlers_) {
      response = handler.Run(*request);
      if (response)
        break;
    }
  }

  if (!response) {
    LOG(WARNING) << "Request not handled. Returning 404: "
                 << request->relative_url;
    auto not_found_response = std::make_unique<BasicHttpResponse>();
    not_found_response->set_code(HTTP_NOT_FOUND);
    response = std::move(not_found_response);
  }

  DispatchResponseToDelegate(std::move(response), delegate);
}
775 
GetURL(std::string_view relative_url) const776 GURL EmbeddedTestServer::GetURL(std::string_view relative_url) const {
777   DCHECK(Started()) << "You must start the server first.";
778   DCHECK(relative_url.starts_with("/")) << relative_url;
779   return base_url_.Resolve(relative_url);
780 }
781 
GetURL(std::string_view hostname,std::string_view relative_url) const782 GURL EmbeddedTestServer::GetURL(std::string_view hostname,
783                                 std::string_view relative_url) const {
784   GURL local_url = GetURL(relative_url);
785   GURL::Replacements replace_host;
786   replace_host.SetHostStr(hostname);
787   return local_url.ReplaceComponents(replace_host);
788 }
789 
GetOrigin(const std::optional<std::string> & hostname) const790 url::Origin EmbeddedTestServer::GetOrigin(
791     const std::optional<std::string>& hostname) const {
792   if (hostname)
793     return url::Origin::Create(GetURL(*hostname, "/"));
794   return url::Origin::Create(base_url_);
795 }
796 
GetAddressList(AddressList * address_list) const797 bool EmbeddedTestServer::GetAddressList(AddressList* address_list) const {
798   *address_list = AddressList(local_endpoint_);
799   return true;
800 }
801 
GetIPLiteralString() const802 std::string EmbeddedTestServer::GetIPLiteralString() const {
803   return local_endpoint_.address().ToString();
804 }
805 
SetSSLConfigInternal(ServerCertificate cert,const ServerCertificateConfig * cert_config,const SSLServerConfig & ssl_config)806 void EmbeddedTestServer::SetSSLConfigInternal(
807     ServerCertificate cert,
808     const ServerCertificateConfig* cert_config,
809     const SSLServerConfig& ssl_config) {
810   DCHECK(!Started());
811   cert_ = cert;
812   DCHECK(!cert_config || cert == CERT_AUTO);
813   cert_config_ = cert_config ? *cert_config : ServerCertificateConfig();
814   x509_cert_ = nullptr;
815   private_key_ = nullptr;
816   ssl_config_ = ssl_config;
817 }
818 
SetSSLConfig(ServerCertificate cert,const SSLServerConfig & ssl_config)819 void EmbeddedTestServer::SetSSLConfig(ServerCertificate cert,
820                                       const SSLServerConfig& ssl_config) {
821   SetSSLConfigInternal(cert, /*cert_config=*/nullptr, ssl_config);
822 }
823 
SetSSLConfig(ServerCertificate cert)824 void EmbeddedTestServer::SetSSLConfig(ServerCertificate cert) {
825   SetSSLConfigInternal(cert, /*cert_config=*/nullptr, SSLServerConfig());
826 }
827 
SetSSLConfig(const ServerCertificateConfig & cert_config,const SSLServerConfig & ssl_config)828 void EmbeddedTestServer::SetSSLConfig(
829     const ServerCertificateConfig& cert_config,
830     const SSLServerConfig& ssl_config) {
831   SetSSLConfigInternal(CERT_AUTO, &cert_config, ssl_config);
832 }
833 
SetSSLConfig(const ServerCertificateConfig & cert_config)834 void EmbeddedTestServer::SetSSLConfig(
835     const ServerCertificateConfig& cert_config) {
836   SetSSLConfigInternal(CERT_AUTO, &cert_config, SSLServerConfig());
837 }
838 
SetCertHostnames(std::vector<std::string> hostnames)839 void EmbeddedTestServer::SetCertHostnames(std::vector<std::string> hostnames) {
840   ServerCertificateConfig cert_config;
841   cert_config.dns_names = std::move(hostnames);
842   cert_config.ip_addresses = {net::IPAddress::IPv4Localhost()};
843   SetSSLConfig(cert_config);
844 }
845 
ResetSSLConfigOnIOThread(ServerCertificate cert,const SSLServerConfig & ssl_config)846 bool EmbeddedTestServer::ResetSSLConfigOnIOThread(
847     ServerCertificate cert,
848     const SSLServerConfig& ssl_config) {
849   cert_ = cert;
850   cert_config_ = ServerCertificateConfig();
851   ssl_config_ = ssl_config;
852   connections_.clear();
853   return InitializeSSLServerContext();
854 }
855 
ResetSSLConfig(ServerCertificate cert,const SSLServerConfig & ssl_config)856 bool EmbeddedTestServer::ResetSSLConfig(ServerCertificate cert,
857                                         const SSLServerConfig& ssl_config) {
858   return PostTaskToIOThreadAndWaitWithResult(
859       base::BindOnce(&EmbeddedTestServer::ResetSSLConfigOnIOThread,
860                      base::Unretained(this), cert, ssl_config));
861 }
862 
GetCertificateName() const863 std::string EmbeddedTestServer::GetCertificateName() const {
864   DCHECK(is_using_ssl_);
865   switch (cert_) {
866     case CERT_OK:
867     case CERT_MISMATCHED_NAME:
868       return "ok_cert.pem";
869     case CERT_COMMON_NAME_IS_DOMAIN:
870       return "localhost_cert.pem";
871     case CERT_EXPIRED:
872       return "expired_cert.pem";
873     case CERT_CHAIN_WRONG_ROOT:
874       // This chain uses its own dedicated test root certificate to avoid
875       // side-effects that may affect testing.
876       return "redundant-server-chain.pem";
877     case CERT_COMMON_NAME_ONLY:
878       return "common_name_only.pem";
879     case CERT_SHA1_LEAF:
880       return "sha1_leaf.pem";
881     case CERT_OK_BY_INTERMEDIATE:
882       return "ok_cert_by_intermediate.pem";
883     case CERT_BAD_VALIDITY:
884       return "bad_validity.pem";
885     case CERT_TEST_NAMES:
886       return "test_names.pem";
887     case CERT_KEY_USAGE_RSA_ENCIPHERMENT:
888       return "key_usage_rsa_keyencipherment.pem";
889     case CERT_KEY_USAGE_RSA_DIGITAL_SIGNATURE:
890       return "key_usage_rsa_digitalsignature.pem";
891     case CERT_AUTO:
892       return std::string();
893   }
894 
895   return "ok_cert.pem";
896 }
897 
GetCertificate()898 scoped_refptr<X509Certificate> EmbeddedTestServer::GetCertificate() {
899   DCHECK(is_using_ssl_);
900   if (!x509_cert_) {
901     // Some tests want to get the certificate before the server has been
902     // initialized, so load it now if necessary. This is only possible if using
903     // a static certificate.
904     // TODO(mattm): change contract to require initializing first in all cases,
905     // update callers.
906     CHECK(UsingStaticCert());
907     // TODO(mattm): change contract to return nullptr on error instead of
908     // CHECKing, update callers.
909     CHECK(InitializeCertAndKeyFromFile());
910   }
911   return x509_cert_;
912 }
913 
GetGeneratedIntermediate()914 scoped_refptr<X509Certificate> EmbeddedTestServer::GetGeneratedIntermediate() {
915   DCHECK(is_using_ssl_);
916   DCHECK(!UsingStaticCert());
917   return intermediate_;
918 }
919 
GetRoot()920 scoped_refptr<X509Certificate> EmbeddedTestServer::GetRoot() {
921   DCHECK(is_using_ssl_);
922   return root_;
923 }
924 
ServeFilesFromDirectory(const base::FilePath & directory)925 void EmbeddedTestServer::ServeFilesFromDirectory(
926     const base::FilePath& directory) {
927   RegisterDefaultHandler(base::BindRepeating(&HandleFileRequest, directory));
928 }
929 
ServeFilesFromSourceDirectory(std::string_view relative)930 void EmbeddedTestServer::ServeFilesFromSourceDirectory(
931     std::string_view relative) {
932   base::FilePath test_data_dir;
933   CHECK(base::PathService::Get(base::DIR_SRC_TEST_DATA_ROOT, &test_data_dir));
934   ServeFilesFromDirectory(test_data_dir.AppendASCII(relative));
935 }
936 
ServeFilesFromSourceDirectory(const base::FilePath & relative)937 void EmbeddedTestServer::ServeFilesFromSourceDirectory(
938     const base::FilePath& relative) {
939   ServeFilesFromDirectory(GetFullPathFromSourceDirectory(relative));
940 }
941 
AddDefaultHandlers(const base::FilePath & directory)942 void EmbeddedTestServer::AddDefaultHandlers(const base::FilePath& directory) {
943   ServeFilesFromSourceDirectory(directory);
944   AddDefaultHandlers();
945 }
946 
AddDefaultHandlers()947 void EmbeddedTestServer::AddDefaultHandlers() {
948   RegisterDefaultHandlers(this);
949 }
950 
// Returns |relative| appended to the DIR_SRC_TEST_DATA_ROOT directory.
// CHECK-fails if that directory cannot be resolved.
base::FilePath EmbeddedTestServer::GetFullPathFromSourceDirectory(
    const base::FilePath& relative) {
  base::FilePath src_test_data_root;
  const bool found_root =
      base::PathService::Get(base::DIR_SRC_TEST_DATA_ROOT, &src_test_data_root);
  CHECK(found_root);
  return src_test_data_root.Append(relative);
}
957 
RegisterAuthHandler(const HandleRequestCallback & callback)958 void EmbeddedTestServer::RegisterAuthHandler(
959     const HandleRequestCallback& callback) {
960   CHECK(!io_thread_)
961       << "Handlers must be registered before starting the server.";
962   if (auth_handler_) {
963     DVLOG(2) << "Overwriting existing Auth handler.";
964   }
965   auth_handler_ = callback;
966 }
967 
RegisterUpgradeRequestHandler(const HandleUpgradeRequestCallback & callback)968 void EmbeddedTestServer::RegisterUpgradeRequestHandler(
969     const HandleUpgradeRequestCallback& callback) {
970   CHECK_NE(protocol_, HttpConnection::Protocol::kHttp2)
971       << "RegisterUpgradeRequestHandler() is not supported for HTTP/2 "
972          "connections";
973   CHECK(!io_thread_)
974       << "Handlers must be registered before starting the server.";
975   upgrade_request_handlers_.push_back(callback);
976 }
977 
RegisterRequestHandler(const HandleRequestCallback & callback)978 void EmbeddedTestServer::RegisterRequestHandler(
979     const HandleRequestCallback& callback) {
980   DCHECK(!io_thread_)
981       << "Handlers must be registered before starting the server.";
982   request_handlers_.push_back(callback);
983 }
984 
RegisterRequestMonitor(const MonitorRequestCallback & callback)985 void EmbeddedTestServer::RegisterRequestMonitor(
986     const MonitorRequestCallback& callback) {
987   DCHECK(!io_thread_)
988       << "Monitors must be registered before starting the server.";
989   request_monitors_.push_back(callback);
990 }
991 
RegisterDefaultHandler(const HandleRequestCallback & callback)992 void EmbeddedTestServer::RegisterDefaultHandler(
993     const HandleRequestCallback& callback) {
994   DCHECK(!io_thread_)
995       << "Handlers must be registered before starting the server.";
996   default_request_handlers_.push_back(callback);
997 }
998 
DoSSLUpgrade(std::unique_ptr<StreamSocket> connection)999 std::unique_ptr<SSLServerSocket> EmbeddedTestServer::DoSSLUpgrade(
1000     std::unique_ptr<StreamSocket> connection) {
1001   DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
1002 
1003   return context_->CreateSSLServerSocket(std::move(connection));
1004 }
1005 
DoAcceptLoop()1006 void EmbeddedTestServer::DoAcceptLoop() {
1007   while (true) {
1008     int rv = listen_socket_->Accept(
1009         &accepted_socket_,
1010         base::BindOnce(&EmbeddedTestServer::OnAcceptCompleted,
1011                        base::Unretained(this)));
1012     if (rv != OK)
1013       return;
1014 
1015     HandleAcceptResult(std::move(accepted_socket_));
1016   }
1017 }
1018 
FlushAllSocketsAndConnectionsOnUIThread()1019 bool EmbeddedTestServer::FlushAllSocketsAndConnectionsOnUIThread() {
1020   return PostTaskToIOThreadAndWait(
1021       base::BindOnce(&EmbeddedTestServer::FlushAllSocketsAndConnections,
1022                      base::Unretained(this)));
1023 }
1024 
FlushAllSocketsAndConnections()1025 void EmbeddedTestServer::FlushAllSocketsAndConnections() {
1026   connections_.clear();
1027 }
1028 
SetAlpsAcceptCH(std::string hostname,std::string accept_ch)1029 void EmbeddedTestServer::SetAlpsAcceptCH(std::string hostname,
1030                                          std::string accept_ch) {
1031   alps_accept_ch_.insert_or_assign(std::move(hostname), std::move(accept_ch));
1032 }
1033 
RegisterShutdownClosure(base::OnceClosure closure)1034 base::CallbackListSubscription EmbeddedTestServer::RegisterShutdownClosure(
1035     base::OnceClosure closure) {
1036   return shutdown_closures_.Add(std::move(closure));
1037 }
1038 
OnAcceptCompleted(int rv)1039 void EmbeddedTestServer::OnAcceptCompleted(int rv) {
1040   DCHECK_NE(ERR_IO_PENDING, rv);
1041   HandleAcceptResult(std::move(accepted_socket_));
1042   DoAcceptLoop();
1043 }
1044 
OnHandshakeDone(HttpConnection * connection,int rv)1045 void EmbeddedTestServer::OnHandshakeDone(HttpConnection* connection, int rv) {
1046   if (connection->Socket()->IsConnected()) {
1047     connection->OnSocketReady();
1048   } else {
1049     RemoveConnection(connection);
1050   }
1051 }
1052 
HandleAcceptResult(std::unique_ptr<StreamSocket> socket_ptr)1053 void EmbeddedTestServer::HandleAcceptResult(
1054     std::unique_ptr<StreamSocket> socket_ptr) {
1055   DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
1056   if (connection_listener_)
1057     socket_ptr = connection_listener_->AcceptedSocket(std::move(socket_ptr));
1058 
1059   if (!is_using_ssl_) {
1060     AddConnection(std::move(socket_ptr))->OnSocketReady();
1061     return;
1062   }
1063 
1064   socket_ptr = DoSSLUpgrade(std::move(socket_ptr));
1065 
1066   StreamSocket* socket = socket_ptr.get();
1067   HttpConnection* connection = AddConnection(std::move(socket_ptr));
1068 
1069   int rv = static_cast<SSLServerSocket*>(socket)->Handshake(
1070       base::BindOnce(&EmbeddedTestServer::OnHandshakeDone,
1071                      base::Unretained(this), connection));
1072   if (rv != ERR_IO_PENDING)
1073     OnHandshakeDone(connection, rv);
1074 }
1075 
AddConnection(std::unique_ptr<StreamSocket> socket_ptr)1076 HttpConnection* EmbeddedTestServer::AddConnection(
1077     std::unique_ptr<StreamSocket> socket_ptr) {
1078   StreamSocket* socket = socket_ptr.get();
1079   std::unique_ptr<HttpConnection> connection_ptr = HttpConnection::Create(
1080       std::move(socket_ptr), connection_listener_, this, protocol_);
1081   HttpConnection* connection = connection_ptr.get();
1082   connections_[socket] = std::move(connection_ptr);
1083 
1084   return connection;
1085 }
1086 
RemoveConnection(HttpConnection * connection,EmbeddedTestServerConnectionListener * listener)1087 void EmbeddedTestServer::RemoveConnection(
1088     HttpConnection* connection,
1089     EmbeddedTestServerConnectionListener* listener) {
1090   DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
1091   DCHECK(connection);
1092   DCHECK_EQ(1u, connections_.count(connection->Socket()));
1093 
1094   StreamSocket* raw_socket = connection->Socket();
1095   std::unique_ptr<StreamSocket> socket = connection->TakeSocket();
1096   connections_.erase(raw_socket);
1097 
1098   if (listener && socket && socket->IsConnected())
1099     listener->OnResponseCompletedSuccessfully(std::move(socket));
1100 }
1101 
PostTaskToIOThreadAndWait(base::OnceClosure closure)1102 bool EmbeddedTestServer::PostTaskToIOThreadAndWait(base::OnceClosure closure) {
1103   // Note that PostTaskAndReply below requires
1104   // base::SingleThreadTaskRunner::GetCurrentDefault() to return a task runner
1105   // for posting the reply task. However, in order to make EmbeddedTestServer
1106   // universally usable, it needs to cope with the situation where it's running
1107   // on a thread on which a task executor is not (yet) available or has been
1108   // destroyed already.
1109   //
1110   // To handle this situation, create temporary task executor to support the
1111   // PostTaskAndReply operation if the current thread has no task executor.
1112   // TODO(mattm): Is this still necessary/desirable? Try removing this and see
1113   // if anything breaks.
1114   std::unique_ptr<base::SingleThreadTaskExecutor> temporary_loop;
1115   if (!base::CurrentThread::Get())
1116     temporary_loop = std::make_unique<base::SingleThreadTaskExecutor>();
1117 
1118   base::RunLoop run_loop;
1119   if (!io_thread_->task_runner()->PostTaskAndReply(
1120           FROM_HERE, std::move(closure), run_loop.QuitClosure())) {
1121     return false;
1122   }
1123   run_loop.Run();
1124 
1125   return true;
1126 }
1127 
PostTaskToIOThreadAndWaitWithResult(base::OnceCallback<bool ()> task)1128 bool EmbeddedTestServer::PostTaskToIOThreadAndWaitWithResult(
1129     base::OnceCallback<bool()> task) {
1130   // Note that PostTaskAndReply below requires
1131   // base::SingleThreadTaskRunner::GetCurrentDefault() to return a task runner
1132   // for posting the reply task. However, in order to make EmbeddedTestServer
1133   // universally usable, it needs to cope with the situation where it's running
1134   // on a thread on which a task executor is not (yet) available or has been
1135   // destroyed already.
1136   //
1137   // To handle this situation, create temporary task executor to support the
1138   // PostTaskAndReply operation if the current thread has no task executor.
1139   // TODO(mattm): Is this still necessary/desirable? Try removing this and see
1140   // if anything breaks.
1141   std::unique_ptr<base::SingleThreadTaskExecutor> temporary_loop;
1142   if (!base::CurrentThread::Get())
1143     temporary_loop = std::make_unique<base::SingleThreadTaskExecutor>();
1144 
1145   base::RunLoop run_loop;
1146   bool task_result = false;
1147   if (!io_thread_->task_runner()->PostTaskAndReplyWithResult(
1148           FROM_HERE, std::move(task),
1149           base::BindOnce(base::BindLambdaForTesting([&](bool result) {
1150             task_result = result;
1151             run_loop.Quit();
1152           })))) {
1153     return false;
1154   }
1155   run_loop.Run();
1156 
1157   return task_result;
1158 }
1159 
1160 }  // namespace net::test_server
1161