• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/test/embedded_test_server/embedded_test_server.h"
6 
7 #include <stdint.h>
8 
9 #include <memory>
10 #include <utility>
11 
12 #include "base/files/file_path.h"
13 #include "base/functional/bind.h"
14 #include "base/functional/callback_forward.h"
15 #include "base/functional/callback_helpers.h"
16 #include "base/location.h"
17 #include "base/logging.h"
18 #include "base/message_loop/message_pump_type.h"
19 #include "base/path_service.h"
20 #include "base/process/process_metrics.h"
21 #include "base/run_loop.h"
22 #include "base/strings/string_number_conversions.h"
23 #include "base/strings/string_piece.h"
24 #include "base/strings/string_util.h"
25 #include "base/strings/stringprintf.h"
26 #include "base/task/current_thread.h"
27 #include "base/task/single_thread_task_executor.h"
28 #include "base/task/single_thread_task_runner.h"
29 #include "base/test/bind.h"
30 #include "base/threading/thread_restrictions.h"
31 #include "crypto/rsa_private_key.h"
32 #include "net/base/hex_utils.h"
33 #include "net/base/ip_endpoint.h"
34 #include "net/base/net_errors.h"
35 #include "net/base/port_util.h"
36 #include "net/cert/pki/extended_key_usage.h"
37 #include "net/log/net_log_source.h"
38 #include "net/socket/next_proto.h"
39 #include "net/socket/ssl_server_socket.h"
40 #include "net/socket/stream_socket.h"
41 #include "net/socket/tcp_server_socket.h"
42 #include "net/spdy/spdy_test_util_common.h"
43 #include "net/ssl/ssl_info.h"
44 #include "net/ssl/ssl_server_config.h"
45 #include "net/test/cert_builder.h"
46 #include "net/test/cert_test_util.h"
47 #include "net/test/embedded_test_server/default_handlers.h"
48 #include "net/test/embedded_test_server/embedded_test_server_connection_listener.h"
49 #include "net/test/embedded_test_server/http_request.h"
50 #include "net/test/embedded_test_server/http_response.h"
51 #include "net/test/embedded_test_server/request_handler_util.h"
52 #include "net/test/key_util.h"
53 #include "net/test/revocation_builder.h"
54 #include "net/test/test_data_directory.h"
55 #include "net/third_party/quiche/src/quiche/spdy/core/spdy_frame_builder.h"
56 #include "third_party/abseil-cpp/absl/types/optional.h"
57 #include "url/origin.h"
58 
59 namespace net::test_server {
60 
61 namespace {
62 
ServeResponseForPath(const std::string & expected_path,HttpStatusCode status_code,const std::string & content_type,const std::string & content,const HttpRequest & request)63 std::unique_ptr<HttpResponse> ServeResponseForPath(
64     const std::string& expected_path,
65     HttpStatusCode status_code,
66     const std::string& content_type,
67     const std::string& content,
68     const HttpRequest& request) {
69   if (request.GetURL().path() != expected_path)
70     return nullptr;
71 
72   auto http_response = std::make_unique<BasicHttpResponse>();
73   http_response->set_code(status_code);
74   http_response->set_content_type(content_type);
75   http_response->set_content(content);
76   return http_response;
77 }
78 
79 // Serves response for |expected_path| or any subpath of it.
80 // |expected_path| should not include a trailing "/".
ServeResponseForSubPaths(const std::string & expected_path,HttpStatusCode status_code,const std::string & content_type,const std::string & content,const HttpRequest & request)81 std::unique_ptr<HttpResponse> ServeResponseForSubPaths(
82     const std::string& expected_path,
83     HttpStatusCode status_code,
84     const std::string& content_type,
85     const std::string& content,
86     const HttpRequest& request) {
87   if (request.GetURL().path() != expected_path &&
88       !base::StartsWith(request.GetURL().path(), expected_path + "/",
89                         base::CompareCase::SENSITIVE)) {
90     return nullptr;
91   }
92 
93   auto http_response = std::make_unique<BasicHttpResponse>();
94   http_response->set_code(status_code);
95   http_response->set_content_type(content_type);
96   http_response->set_content(content);
97   return http_response;
98 }
99 
MaybeCreateOCSPResponse(CertBuilder * target,const EmbeddedTestServer::OCSPConfig & config,std::string * out_response)100 bool MaybeCreateOCSPResponse(CertBuilder* target,
101                              const EmbeddedTestServer::OCSPConfig& config,
102                              std::string* out_response) {
103   using OCSPResponseType = EmbeddedTestServer::OCSPConfig::ResponseType;
104 
105   if (!config.single_responses.empty() &&
106       config.response_type != OCSPResponseType::kSuccessful) {
107     // OCSPConfig contained single_responses for a non-successful response.
108     return false;
109   }
110 
111   if (config.response_type == OCSPResponseType::kOff) {
112     *out_response = std::string();
113     return true;
114   }
115 
116   if (!target) {
117     // OCSPConfig enabled but corresponding certificate is null.
118     return false;
119   }
120 
121   switch (config.response_type) {
122     case OCSPResponseType::kOff:
123       return false;
124     case OCSPResponseType::kMalformedRequest:
125       *out_response = BuildOCSPResponseError(
126           OCSPResponse::ResponseStatus::MALFORMED_REQUEST);
127       return true;
128     case OCSPResponseType::kInternalError:
129       *out_response =
130           BuildOCSPResponseError(OCSPResponse::ResponseStatus::INTERNAL_ERROR);
131       return true;
132     case OCSPResponseType::kTryLater:
133       *out_response =
134           BuildOCSPResponseError(OCSPResponse::ResponseStatus::TRY_LATER);
135       return true;
136     case OCSPResponseType::kSigRequired:
137       *out_response =
138           BuildOCSPResponseError(OCSPResponse::ResponseStatus::SIG_REQUIRED);
139       return true;
140     case OCSPResponseType::kUnauthorized:
141       *out_response =
142           BuildOCSPResponseError(OCSPResponse::ResponseStatus::UNAUTHORIZED);
143       return true;
144     case OCSPResponseType::kInvalidResponse:
145       *out_response = "3";
146       return true;
147     case OCSPResponseType::kInvalidResponseData:
148       *out_response =
149           BuildOCSPResponseWithResponseData(target->issuer()->GetKey(),
150                                             // OCTET_STRING { "not ocsp data" }
151                                             "\x04\x0dnot ocsp data");
152       return true;
153     case OCSPResponseType::kSuccessful:
154       break;
155   }
156 
157   base::Time now = base::Time::Now();
158   base::Time target_not_before, target_not_after;
159   if (!target->GetValidity(&target_not_before, &target_not_after))
160     return false;
161   base::Time produced_at;
162   using OCSPProduced = EmbeddedTestServer::OCSPConfig::Produced;
163   switch (config.produced) {
164     case OCSPProduced::kValid:
165       produced_at = std::max(now - base::Days(1), target_not_before);
166       break;
167     case OCSPProduced::kBeforeCert:
168       produced_at = target_not_before - base::Days(1);
169       break;
170     case OCSPProduced::kAfterCert:
171       produced_at = target_not_after + base::Days(1);
172       break;
173   }
174 
175   std::vector<OCSPBuilderSingleResponse> responses;
176   for (const auto& config_response : config.single_responses) {
177     OCSPBuilderSingleResponse response;
178     response.serial = target->GetSerialNumber();
179     if (config_response.serial ==
180         EmbeddedTestServer::OCSPConfig::SingleResponse::Serial::kMismatch) {
181       response.serial ^= 1;
182     }
183     response.cert_status = config_response.cert_status;
184     // |revocation_time| is ignored if |cert_status| is not REVOKED.
185     response.revocation_time = now - base::Days(1000);
186 
187     using OCSPDate = EmbeddedTestServer::OCSPConfig::SingleResponse::Date;
188     switch (config_response.ocsp_date) {
189       case OCSPDate::kValid:
190         response.this_update = now - base::Days(1);
191         response.next_update = response.this_update + base::Days(7);
192         break;
193       case OCSPDate::kOld:
194         response.this_update = now - base::Days(8);
195         response.next_update = response.this_update + base::Days(7);
196         break;
197       case OCSPDate::kEarly:
198         response.this_update = now + base::Days(1);
199         response.next_update = response.this_update + base::Days(7);
200         break;
201       case OCSPDate::kLong:
202         response.this_update = now - base::Days(365);
203         response.next_update = response.this_update + base::Days(366);
204         break;
205       case OCSPDate::kLonger:
206         response.this_update = now - base::Days(367);
207         response.next_update = response.this_update + base::Days(368);
208         break;
209     }
210 
211     responses.push_back(response);
212   }
213   *out_response =
214       BuildOCSPResponse(target->issuer()->GetSubject(),
215                         target->issuer()->GetKey(), produced_at, responses);
216   return true;
217 }
218 
219 }  // namespace
220 
EmbeddedTestServerHandle(EmbeddedTestServerHandle && other)221 EmbeddedTestServerHandle::EmbeddedTestServerHandle(
222     EmbeddedTestServerHandle&& other) {
223   operator=(std::move(other));
224 }
225 
operator =(EmbeddedTestServerHandle && other)226 EmbeddedTestServerHandle& EmbeddedTestServerHandle::operator=(
227     EmbeddedTestServerHandle&& other) {
228   EmbeddedTestServerHandle temporary;
229   std::swap(other.test_server_, temporary.test_server_);
230   std::swap(temporary.test_server_, test_server_);
231   return *this;
232 }
233 
EmbeddedTestServerHandle(EmbeddedTestServer * test_server)234 EmbeddedTestServerHandle::EmbeddedTestServerHandle(
235     EmbeddedTestServer* test_server)
236     : test_server_(test_server) {}
237 
~EmbeddedTestServerHandle()238 EmbeddedTestServerHandle::~EmbeddedTestServerHandle() {
239   if (test_server_)
240     CHECK(test_server_->ShutdownAndWaitUntilComplete());
241 }
242 
243 EmbeddedTestServer::OCSPConfig::OCSPConfig() = default;
OCSPConfig(ResponseType response_type)244 EmbeddedTestServer::OCSPConfig::OCSPConfig(ResponseType response_type)
245     : response_type(response_type) {}
OCSPConfig(std::vector<SingleResponse> single_responses,Produced produced)246 EmbeddedTestServer::OCSPConfig::OCSPConfig(
247     std::vector<SingleResponse> single_responses,
248     Produced produced)
249     : response_type(ResponseType::kSuccessful),
250       produced(produced),
251       single_responses(std::move(single_responses)) {}
252 EmbeddedTestServer::OCSPConfig::OCSPConfig(const OCSPConfig&) = default;
253 EmbeddedTestServer::OCSPConfig::OCSPConfig(OCSPConfig&&) = default;
254 EmbeddedTestServer::OCSPConfig::~OCSPConfig() = default;
255 EmbeddedTestServer::OCSPConfig& EmbeddedTestServer::OCSPConfig::operator=(
256     const OCSPConfig&) = default;
257 EmbeddedTestServer::OCSPConfig& EmbeddedTestServer::OCSPConfig::operator=(
258     OCSPConfig&&) = default;
259 
260 EmbeddedTestServer::ServerCertificateConfig::ServerCertificateConfig() =
261     default;
262 EmbeddedTestServer::ServerCertificateConfig::ServerCertificateConfig(
263     const ServerCertificateConfig&) = default;
264 EmbeddedTestServer::ServerCertificateConfig::ServerCertificateConfig(
265     ServerCertificateConfig&&) = default;
266 EmbeddedTestServer::ServerCertificateConfig::~ServerCertificateConfig() =
267     default;
268 EmbeddedTestServer::ServerCertificateConfig&
269 EmbeddedTestServer::ServerCertificateConfig::operator=(
270     const ServerCertificateConfig&) = default;
271 EmbeddedTestServer::ServerCertificateConfig&
272 EmbeddedTestServer::ServerCertificateConfig::operator=(
273     ServerCertificateConfig&&) = default;
274 
EmbeddedTestServer()275 EmbeddedTestServer::EmbeddedTestServer() : EmbeddedTestServer(TYPE_HTTP) {}
276 
EmbeddedTestServer(Type type,HttpConnection::Protocol protocol)277 EmbeddedTestServer::EmbeddedTestServer(Type type,
278                                        HttpConnection::Protocol protocol)
279     : is_using_ssl_(type == TYPE_HTTPS), protocol_(protocol) {
280   DCHECK(thread_checker_.CalledOnValidThread());
281   // HTTP/2 is only valid by negotiation via TLS ALPN
282   DCHECK(protocol_ != HttpConnection::Protocol::kHttp2 || type == TYPE_HTTPS);
283 
284   if (!is_using_ssl_)
285     return;
286   scoped_test_root_ = RegisterTestCerts();
287 }
288 
~EmbeddedTestServer()289 EmbeddedTestServer::~EmbeddedTestServer() {
290   DCHECK(thread_checker_.CalledOnValidThread());
291 
292   if (Started())
293     CHECK(ShutdownAndWaitUntilComplete());
294 
295   {
296     base::ScopedAllowBaseSyncPrimitivesForTesting allow_wait_for_thread_join;
297     io_thread_.reset();
298   }
299 }
300 
RegisterTestCerts()301 ScopedTestRoot EmbeddedTestServer::RegisterTestCerts() {
302   base::ScopedAllowBlockingForTesting allow_blocking;
303   auto root = ImportCertFromFile(GetRootCertPemPath());
304   if (!root)
305     return ScopedTestRoot();
306   return ScopedTestRoot(root.get());
307 }
308 
SetConnectionListener(EmbeddedTestServerConnectionListener * listener)309 void EmbeddedTestServer::SetConnectionListener(
310     EmbeddedTestServerConnectionListener* listener) {
311   DCHECK(!io_thread_)
312       << "ConnectionListener must be set before starting the server.";
313   connection_listener_ = listener;
314 }
315 
StartAndReturnHandle(int port)316 EmbeddedTestServerHandle EmbeddedTestServer::StartAndReturnHandle(int port) {
317   bool result = Start(port);
318   return result ? EmbeddedTestServerHandle(this) : EmbeddedTestServerHandle();
319 }
320 
Start(int port)321 bool EmbeddedTestServer::Start(int port) {
322   bool success = InitializeAndListen(port);
323   if (success)
324     StartAcceptingConnections();
325   return success;
326 }
327 
InitializeAndListen(int port)328 bool EmbeddedTestServer::InitializeAndListen(int port) {
329   DCHECK(!Started());
330 
331   const int max_tries = 5;
332   int num_tries = 0;
333   bool is_valid_port = false;
334 
335   do {
336     if (++num_tries > max_tries) {
337       LOG(ERROR) << "Failed to listen on a valid port after " << max_tries
338                  << " attempts.";
339       listen_socket_.reset();
340       return false;
341     }
342 
343     listen_socket_ = std::make_unique<TCPServerSocket>(nullptr, NetLogSource());
344 
345     int result =
346         listen_socket_->ListenWithAddressAndPort("127.0.0.1", port, 10);
347     if (result) {
348       LOG(ERROR) << "Listen failed: " << ErrorToString(result);
349       listen_socket_.reset();
350       return false;
351     }
352 
353     result = listen_socket_->GetLocalAddress(&local_endpoint_);
354     if (result != OK) {
355       LOG(ERROR) << "GetLocalAddress failed: " << ErrorToString(result);
356       listen_socket_.reset();
357       return false;
358     }
359 
360     port_ = local_endpoint_.port();
361     is_valid_port |= net::IsPortAllowedForScheme(
362         port_, is_using_ssl_ ? url::kHttpsScheme : url::kHttpScheme);
363   } while (!is_valid_port);
364 
365   if (is_using_ssl_) {
366     base_url_ = GURL("https://" + local_endpoint_.ToString());
367     if (cert_ == CERT_MISMATCHED_NAME || cert_ == CERT_COMMON_NAME_IS_DOMAIN) {
368       base_url_ = GURL(
369           base::StringPrintf("https://localhost:%d", local_endpoint_.port()));
370     }
371   } else {
372     base_url_ = GURL("http://" + local_endpoint_.ToString());
373   }
374 
375   listen_socket_->DetachFromThread();
376 
377   if (is_using_ssl_ && !InitializeSSLServerContext())
378     return false;
379 
380   return true;
381 }
382 
UsingStaticCert() const383 bool EmbeddedTestServer::UsingStaticCert() const {
384   return !GetCertificateName().empty();
385 }
386 
InitializeCertAndKeyFromFile()387 bool EmbeddedTestServer::InitializeCertAndKeyFromFile() {
388   base::ScopedAllowBlockingForTesting allow_blocking;
389   base::FilePath certs_dir(GetTestCertsDirectory());
390   std::string cert_name = GetCertificateName();
391   if (cert_name.empty())
392     return false;
393 
394   x509_cert_ = CreateCertificateChainFromFile(certs_dir, cert_name,
395                                               X509Certificate::FORMAT_AUTO);
396   if (!x509_cert_)
397     return false;
398 
399   private_key_ =
400       key_util::LoadEVP_PKEYFromPEM(certs_dir.AppendASCII(cert_name));
401   return !!private_key_;
402 }
403 
GenerateCertAndKey()404 bool EmbeddedTestServer::GenerateCertAndKey() {
405   // Create AIA server and start listening. Need to have the socket initialized
406   // so the URL can be put in the AIA records of the generated certs.
407   aia_http_server_ = std::make_unique<EmbeddedTestServer>(TYPE_HTTP);
408   if (!aia_http_server_->InitializeAndListen())
409     return false;
410 
411   base::ScopedAllowBlockingForTesting allow_blocking;
412   base::FilePath certs_dir(GetTestCertsDirectory());
413 
414   std::unique_ptr<CertBuilder> static_root = CertBuilder::FromStaticCertFile(
415       certs_dir.AppendASCII("root_ca_cert.pem"));
416 
417   auto now = base::Time::Now();
418   // Will be nullptr if cert_config_.intermediate == kNone.
419   std::unique_ptr<CertBuilder> intermediate;
420   std::unique_ptr<CertBuilder> leaf;
421 
422   if (cert_config_.intermediate != IntermediateType::kNone) {
423     intermediate = CertBuilder::FromFile(
424         certs_dir.AppendASCII("intermediate_ca_cert.pem"), static_root.get());
425     if (!intermediate)
426       return false;
427     intermediate->SetValidity(now - base::Days(100), now + base::Days(1000));
428 
429     leaf = CertBuilder::FromFile(certs_dir.AppendASCII("ok_cert.pem"),
430                                  intermediate.get());
431     // Workaround for weird CertVerifyProcWin issue where if too many
432     // intermediates with the same key are fetched by AIA any further
433     // verifications using that key will fail. See
434     // https://crbug.com/1328060. Since generating ECDSA keys is cheap, just do
435     // this on all configurations rather than restricting to Windows, though
436     // this hack can be removed once we delete CertVerifyProcWin.
437     if (cert_config_.intermediate == IntermediateType::kByAIA) {
438       intermediate->GenerateECKey();
439       leaf->SetSignatureAlgorithm(SignatureAlgorithm::kEcdsaSha256);
440     }
441   } else {
442     leaf = CertBuilder::FromFile(certs_dir.AppendASCII("ok_cert.pem"),
443                                  static_root.get());
444   }
445   if (!leaf)
446     return false;
447 
448   std::vector<GURL> leaf_ca_issuers_urls;
449   std::vector<GURL> leaf_ocsp_urls;
450 
451   leaf->SetValidity(now - base::Days(1), now + base::Days(20));
452 
453   if (!cert_config_.policy_oids.empty()) {
454     leaf->SetCertificatePolicies(cert_config_.policy_oids);
455     if (intermediate)
456       intermediate->SetCertificatePolicies(cert_config_.policy_oids);
457   }
458 
459   if (!cert_config_.dns_names.empty() || !cert_config_.ip_addresses.empty()) {
460     leaf->SetSubjectAltNames(cert_config_.dns_names, cert_config_.ip_addresses);
461   }
462 
463   const std::string leaf_serial_text =
464       base::NumberToString(leaf->GetSerialNumber());
465   const std::string intermediate_serial_text =
466       intermediate ? base::NumberToString(intermediate->GetSerialNumber()) : "";
467 
468   std::string ocsp_response;
469   if (!MaybeCreateOCSPResponse(leaf.get(), cert_config_.ocsp_config,
470                                &ocsp_response)) {
471     return false;
472   }
473   if (!ocsp_response.empty()) {
474     std::string ocsp_path = "/ocsp/" + leaf_serial_text;
475     leaf_ocsp_urls.push_back(aia_http_server_->GetURL(ocsp_path));
476     aia_http_server_->RegisterRequestHandler(
477         base::BindRepeating(ServeResponseForSubPaths, ocsp_path, HTTP_OK,
478                             "application/ocsp-response", ocsp_response));
479   }
480 
481   std::string stapled_ocsp_response;
482   if (!MaybeCreateOCSPResponse(leaf.get(), cert_config_.stapled_ocsp_config,
483                                &stapled_ocsp_response)) {
484     return false;
485   }
486   if (!stapled_ocsp_response.empty()) {
487     ssl_config_.ocsp_response = std::vector<uint8_t>(
488         stapled_ocsp_response.begin(), stapled_ocsp_response.end());
489   }
490 
491   std::string intermediate_ocsp_response;
492   if (!MaybeCreateOCSPResponse(intermediate.get(),
493                                cert_config_.intermediate_ocsp_config,
494                                &intermediate_ocsp_response)) {
495     return false;
496   }
497   if (!intermediate_ocsp_response.empty()) {
498     std::string intermediate_ocsp_path = "/ocsp/" + intermediate_serial_text;
499     intermediate->SetCaIssuersAndOCSPUrls(
500         {}, {aia_http_server_->GetURL(intermediate_ocsp_path)});
501     aia_http_server_->RegisterRequestHandler(base::BindRepeating(
502         ServeResponseForSubPaths, intermediate_ocsp_path, HTTP_OK,
503         "application/ocsp-response", intermediate_ocsp_response));
504   }
505 
506   if (cert_config_.intermediate == IntermediateType::kByAIA) {
507     std::string ca_issuers_path = "/ca_issuers/" + intermediate_serial_text;
508     leaf_ca_issuers_urls.push_back(aia_http_server_->GetURL(ca_issuers_path));
509 
510     // Setup AIA server to serve the intermediate referred to by the leaf.
511     aia_http_server_->RegisterRequestHandler(
512         base::BindRepeating(ServeResponseForPath, ca_issuers_path, HTTP_OK,
513                             "application/pkix-cert", intermediate->GetDER()));
514   }
515 
516   if (!leaf_ca_issuers_urls.empty() || !leaf_ocsp_urls.empty()) {
517     leaf->SetCaIssuersAndOCSPUrls(leaf_ca_issuers_urls, leaf_ocsp_urls);
518   }
519 
520   if (cert_config_.intermediate == IntermediateType::kByAIA) {
521     // Server certificate chain does not include the intermediate.
522     x509_cert_ = leaf->GetX509Certificate();
523   } else {
524     // Server certificate chain will include the intermediate, if there is one.
525     x509_cert_ = leaf->GetX509CertificateChain();
526   }
527 
528   private_key_ = bssl::UpRef(leaf->GetKey());
529 
530   // If this server is already accepting connections but is being reconfigured,
531   // start the new AIA server now. Otherwise, wait until
532   // StartAcceptingConnections so that this server and the AIA server start at
533   // the same time. (If the test only called InitializeAndListen they expect no
534   // threads to be created yet.)
535   if (io_thread_)
536     aia_http_server_->StartAcceptingConnections();
537 
538   return true;
539 }
540 
InitializeSSLServerContext()541 bool EmbeddedTestServer::InitializeSSLServerContext() {
542   if (UsingStaticCert()) {
543     if (!InitializeCertAndKeyFromFile())
544       return false;
545   } else {
546     if (!GenerateCertAndKey())
547       return false;
548   }
549 
550   if (protocol_ == HttpConnection::Protocol::kHttp2) {
551     ssl_config_.alpn_protos = {NextProto::kProtoHTTP2};
552     if (!alps_accept_ch_.empty()) {
553       base::StringPairs origin_accept_ch;
554       size_t frame_size = spdy::kFrameHeaderSize;
555       // Figure out size and generate origins
556       for (const auto& pair : alps_accept_ch_) {
557         base::StringPiece hostname = pair.first;
558         std::string accept_ch = pair.second;
559 
560         GURL url = hostname.empty() ? GetURL("/") : GetURL(hostname, "/");
561         std::string origin = url::Origin::Create(url).Serialize();
562 
563         frame_size += accept_ch.size() + origin.size() +
564                       (sizeof(uint16_t) * 2);  // = Origin-Len + Value-Len
565 
566         origin_accept_ch.push_back({std::move(origin), std::move(accept_ch)});
567       }
568 
569       spdy::SpdyFrameBuilder builder(frame_size);
570       builder.BeginNewFrame(spdy::SpdyFrameType::ACCEPT_CH, 0, 0);
571       for (const auto& pair : origin_accept_ch) {
572         base::StringPiece origin = pair.first;
573         base::StringPiece accept_ch = pair.second;
574 
575         builder.WriteUInt16(origin.size());
576         builder.WriteBytes(origin.data(), origin.size());
577 
578         builder.WriteUInt16(accept_ch.size());
579         builder.WriteBytes(accept_ch.data(), accept_ch.size());
580       }
581 
582       spdy::SpdySerializedFrame serialized_frame = builder.take();
583       DCHECK_EQ(frame_size, serialized_frame.size());
584 
585       ssl_config_.application_settings[NextProto::kProtoHTTP2] =
586           std::vector<uint8_t>(
587               serialized_frame.data(),
588               serialized_frame.data() + serialized_frame.size());
589     }
590   }
591 
592   context_ =
593       CreateSSLServerContext(x509_cert_.get(), private_key_.get(), ssl_config_);
594   return true;
595 }
596 
597 EmbeddedTestServerHandle
StartAcceptingConnectionsAndReturnHandle()598 EmbeddedTestServer::StartAcceptingConnectionsAndReturnHandle() {
599   return EmbeddedTestServerHandle(this);
600 }
601 
StartAcceptingConnections()602 void EmbeddedTestServer::StartAcceptingConnections() {
603   DCHECK(Started());
604   DCHECK(!io_thread_) << "Server must not be started while server is running";
605 
606   if (aia_http_server_)
607     aia_http_server_->StartAcceptingConnections();
608 
609   base::Thread::Options thread_options;
610   thread_options.message_pump_type = base::MessagePumpType::IO;
611   io_thread_ = std::make_unique<base::Thread>("EmbeddedTestServer IO Thread");
612   CHECK(io_thread_->StartWithOptions(std::move(thread_options)));
613   CHECK(io_thread_->WaitUntilThreadStarted());
614 
615   io_thread_->task_runner()->PostTask(
616       FROM_HERE, base::BindOnce(&EmbeddedTestServer::DoAcceptLoop,
617                                 base::Unretained(this)));
618 }
619 
ShutdownAndWaitUntilComplete()620 bool EmbeddedTestServer::ShutdownAndWaitUntilComplete() {
621   DCHECK(thread_checker_.CalledOnValidThread());
622 
623   // Ensure that the AIA HTTP server is no longer Started().
624   bool aia_http_server_not_started = true;
625   if (aia_http_server_ && aia_http_server_->Started()) {
626     aia_http_server_not_started =
627         aia_http_server_->ShutdownAndWaitUntilComplete();
628   }
629 
630   // Return false if either this or the AIA HTTP server are still Started().
631   return PostTaskToIOThreadAndWait(
632              base::BindOnce(&EmbeddedTestServer::ShutdownOnIOThread,
633                             base::Unretained(this))) &&
634          aia_http_server_not_started;
635 }
636 
637 // static
GetRootCertPemPath()638 base::FilePath EmbeddedTestServer::GetRootCertPemPath() {
639   return GetTestCertsDirectory().AppendASCII("root_ca_cert.pem");
640 }
641 
ShutdownOnIOThread()642 void EmbeddedTestServer::ShutdownOnIOThread() {
643   DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
644   weak_factory_.InvalidateWeakPtrs();
645   listen_socket_.reset();
646   connections_.clear();
647 }
648 
HandleRequest(base::WeakPtr<HttpResponseDelegate> delegate,std::unique_ptr<HttpRequest> request)649 void EmbeddedTestServer::HandleRequest(
650     base::WeakPtr<HttpResponseDelegate> delegate,
651     std::unique_ptr<HttpRequest> request) {
652   DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
653   request->base_url = base_url_;
654 
655   for (const auto& monitor : request_monitors_)
656     monitor.Run(*request);
657 
658   std::unique_ptr<HttpResponse> response;
659 
660   for (const auto& handler : request_handlers_) {
661     response = handler.Run(*request);
662     if (response)
663       break;
664   }
665 
666   if (!response) {
667     for (const auto& handler : default_request_handlers_) {
668       response = handler.Run(*request);
669       if (response)
670         break;
671     }
672   }
673 
674   if (!response) {
675     LOG(WARNING) << "Request not handled. Returning 404: "
676                  << request->relative_url;
677     auto not_found_response = std::make_unique<BasicHttpResponse>();
678     not_found_response->set_code(HTTP_NOT_FOUND);
679     response = std::move(not_found_response);
680   }
681 
682   HttpResponse* const response_ptr = response.get();
683   delegate->AddResponse(std::move(response));
684   response_ptr->SendResponse(delegate);
685 }
686 
GetURL(base::StringPiece relative_url) const687 GURL EmbeddedTestServer::GetURL(base::StringPiece relative_url) const {
688   DCHECK(Started()) << "You must start the server first.";
689   DCHECK(base::StartsWith(relative_url, "/", base::CompareCase::SENSITIVE))
690       << relative_url;
691   return base_url_.Resolve(relative_url);
692 }
693 
GetURL(base::StringPiece hostname,base::StringPiece relative_url) const694 GURL EmbeddedTestServer::GetURL(base::StringPiece hostname,
695                                 base::StringPiece relative_url) const {
696   GURL local_url = GetURL(relative_url);
697   GURL::Replacements replace_host;
698   replace_host.SetHostStr(hostname);
699   return local_url.ReplaceComponents(replace_host);
700 }
701 
GetOrigin(const absl::optional<std::string> & hostname) const702 url::Origin EmbeddedTestServer::GetOrigin(
703     const absl::optional<std::string>& hostname) const {
704   if (hostname)
705     return url::Origin::Create(GetURL(*hostname, "/"));
706   return url::Origin::Create(base_url_);
707 }
708 
GetAddressList(AddressList * address_list) const709 bool EmbeddedTestServer::GetAddressList(AddressList* address_list) const {
710   *address_list = AddressList(local_endpoint_);
711   return true;
712 }
713 
GetIPLiteralString() const714 std::string EmbeddedTestServer::GetIPLiteralString() const {
715   return local_endpoint_.address().ToString();
716 }
717 
SetSSLConfigInternal(ServerCertificate cert,const ServerCertificateConfig * cert_config,const SSLServerConfig & ssl_config)718 void EmbeddedTestServer::SetSSLConfigInternal(
719     ServerCertificate cert,
720     const ServerCertificateConfig* cert_config,
721     const SSLServerConfig& ssl_config) {
722   DCHECK(!Started());
723   cert_ = cert;
724   DCHECK(!cert_config || cert == CERT_AUTO);
725   cert_config_ = cert_config ? *cert_config : ServerCertificateConfig();
726   x509_cert_ = nullptr;
727   private_key_ = nullptr;
728   ssl_config_ = ssl_config;
729 }
730 
SetSSLConfig(ServerCertificate cert,const SSLServerConfig & ssl_config)731 void EmbeddedTestServer::SetSSLConfig(ServerCertificate cert,
732                                       const SSLServerConfig& ssl_config) {
733   SetSSLConfigInternal(cert, /*cert_config=*/nullptr, ssl_config);
734 }
735 
SetSSLConfig(ServerCertificate cert)736 void EmbeddedTestServer::SetSSLConfig(ServerCertificate cert) {
737   SetSSLConfigInternal(cert, /*cert_config=*/nullptr, SSLServerConfig());
738 }
739 
SetSSLConfig(const ServerCertificateConfig & cert_config,const SSLServerConfig & ssl_config)740 void EmbeddedTestServer::SetSSLConfig(
741     const ServerCertificateConfig& cert_config,
742     const SSLServerConfig& ssl_config) {
743   SetSSLConfigInternal(CERT_AUTO, &cert_config, ssl_config);
744 }
745 
SetSSLConfig(const ServerCertificateConfig & cert_config)746 void EmbeddedTestServer::SetSSLConfig(
747     const ServerCertificateConfig& cert_config) {
748   SetSSLConfigInternal(CERT_AUTO, &cert_config, SSLServerConfig());
749 }
750 
ResetSSLConfigOnIOThread(ServerCertificate cert,const SSLServerConfig & ssl_config)751 bool EmbeddedTestServer::ResetSSLConfigOnIOThread(
752     ServerCertificate cert,
753     const SSLServerConfig& ssl_config) {
754   cert_ = cert;
755   cert_config_ = ServerCertificateConfig();
756   ssl_config_ = ssl_config;
757   connections_.clear();
758   return InitializeSSLServerContext();
759 }
760 
ResetSSLConfig(ServerCertificate cert,const SSLServerConfig & ssl_config)761 bool EmbeddedTestServer::ResetSSLConfig(ServerCertificate cert,
762                                         const SSLServerConfig& ssl_config) {
763   return PostTaskToIOThreadAndWaitWithResult(
764       base::BindOnce(&EmbeddedTestServer::ResetSSLConfigOnIOThread,
765                      base::Unretained(this), cert, ssl_config));
766 }
767 
GetCertificateName() const768 std::string EmbeddedTestServer::GetCertificateName() const {
769   DCHECK(is_using_ssl_);
770   switch (cert_) {
771     case CERT_OK:
772     case CERT_MISMATCHED_NAME:
773       return "ok_cert.pem";
774     case CERT_COMMON_NAME_IS_DOMAIN:
775       return "localhost_cert.pem";
776     case CERT_EXPIRED:
777       return "expired_cert.pem";
778     case CERT_CHAIN_WRONG_ROOT:
779       // This chain uses its own dedicated test root certificate to avoid
780       // side-effects that may affect testing.
781       return "redundant-server-chain.pem";
782     case CERT_COMMON_NAME_ONLY:
783       return "common_name_only.pem";
784     case CERT_SHA1_LEAF:
785       return "sha1_leaf.pem";
786     case CERT_OK_BY_INTERMEDIATE:
787       return "ok_cert_by_intermediate.pem";
788     case CERT_BAD_VALIDITY:
789       return "bad_validity.pem";
790     case CERT_TEST_NAMES:
791       return "test_names.pem";
792     case CERT_KEY_USAGE_RSA_ENCIPHERMENT:
793       return "key_usage_rsa_keyencipherment.pem";
794     case CERT_KEY_USAGE_RSA_DIGITAL_SIGNATURE:
795       return "key_usage_rsa_digitalsignature.pem";
796     case CERT_AUTO:
797       return std::string();
798   }
799 
800   return "ok_cert.pem";
801 }
802 
GetCertificate()803 scoped_refptr<X509Certificate> EmbeddedTestServer::GetCertificate() {
804   DCHECK(is_using_ssl_);
805   if (!x509_cert_) {
806     // Some tests want to get the certificate before the server has been
807     // initialized, so load it now if necessary. This is only possible if using
808     // a static certificate.
809     // TODO(mattm): change contract to require initializing first in all cases,
810     // update callers.
811     CHECK(UsingStaticCert());
812     // TODO(mattm): change contract to return nullptr on error instead of
813     // CHECKing, update callers.
814     CHECK(InitializeCertAndKeyFromFile());
815   }
816   return x509_cert_;
817 }
818 
ServeFilesFromDirectory(const base::FilePath & directory)819 void EmbeddedTestServer::ServeFilesFromDirectory(
820     const base::FilePath& directory) {
821   RegisterDefaultHandler(base::BindRepeating(&HandleFileRequest, directory));
822 }
823 
ServeFilesFromSourceDirectory(base::StringPiece relative)824 void EmbeddedTestServer::ServeFilesFromSourceDirectory(
825     base::StringPiece relative) {
826   base::FilePath test_data_dir;
827   CHECK(base::PathService::Get(base::DIR_SOURCE_ROOT, &test_data_dir));
828   ServeFilesFromDirectory(test_data_dir.AppendASCII(relative));
829 }
830 
ServeFilesFromSourceDirectory(const base::FilePath & relative)831 void EmbeddedTestServer::ServeFilesFromSourceDirectory(
832     const base::FilePath& relative) {
833   ServeFilesFromDirectory(GetFullPathFromSourceDirectory(relative));
834 }
835 
AddDefaultHandlers(const base::FilePath & directory)836 void EmbeddedTestServer::AddDefaultHandlers(const base::FilePath& directory) {
837   ServeFilesFromSourceDirectory(directory);
838   AddDefaultHandlers();
839 }
840 
AddDefaultHandlers()841 void EmbeddedTestServer::AddDefaultHandlers() {
842   RegisterDefaultHandlers(this);
843 }
844 
// Returns |relative| appended to base::DIR_SOURCE_ROOT. CHECK-fails if the
// source root cannot be determined.
base::FilePath EmbeddedTestServer::GetFullPathFromSourceDirectory(
    const base::FilePath& relative) {
  base::FilePath source_root;
  const bool found_root =
      base::PathService::Get(base::DIR_SOURCE_ROOT, &source_root);
  CHECK(found_root);
  return source_root.Append(relative);
}
851 
RegisterRequestHandler(const HandleRequestCallback & callback)852 void EmbeddedTestServer::RegisterRequestHandler(
853     const HandleRequestCallback& callback) {
854   DCHECK(!io_thread_)
855       << "Handlers must be registered before starting the server.";
856   request_handlers_.push_back(callback);
857 }
858 
RegisterRequestMonitor(const MonitorRequestCallback & callback)859 void EmbeddedTestServer::RegisterRequestMonitor(
860     const MonitorRequestCallback& callback) {
861   DCHECK(!io_thread_)
862       << "Monitors must be registered before starting the server.";
863   request_monitors_.push_back(callback);
864 }
865 
RegisterDefaultHandler(const HandleRequestCallback & callback)866 void EmbeddedTestServer::RegisterDefaultHandler(
867     const HandleRequestCallback& callback) {
868   DCHECK(!io_thread_)
869       << "Handlers must be registered before starting the server.";
870   default_request_handlers_.push_back(callback);
871 }
872 
DoSSLUpgrade(std::unique_ptr<StreamSocket> connection)873 std::unique_ptr<SSLServerSocket> EmbeddedTestServer::DoSSLUpgrade(
874     std::unique_ptr<StreamSocket> connection) {
875   DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
876 
877   return context_->CreateSSLServerSocket(std::move(connection));
878 }
879 
DoAcceptLoop()880 void EmbeddedTestServer::DoAcceptLoop() {
881   while (true) {
882     int rv = listen_socket_->Accept(
883         &accepted_socket_,
884         base::BindOnce(&EmbeddedTestServer::OnAcceptCompleted,
885                        base::Unretained(this)));
886     if (rv != OK)
887       return;
888 
889     HandleAcceptResult(std::move(accepted_socket_));
890   }
891 }
892 
FlushAllSocketsAndConnectionsOnUIThread()893 bool EmbeddedTestServer::FlushAllSocketsAndConnectionsOnUIThread() {
894   return PostTaskToIOThreadAndWait(
895       base::BindOnce(&EmbeddedTestServer::FlushAllSocketsAndConnections,
896                      base::Unretained(this)));
897 }
898 
FlushAllSocketsAndConnections()899 void EmbeddedTestServer::FlushAllSocketsAndConnections() {
900   connections_.clear();
901 }
902 
SetAlpsAcceptCH(std::string hostname,std::string accept_ch)903 void EmbeddedTestServer::SetAlpsAcceptCH(std::string hostname,
904                                          std::string accept_ch) {
905   alps_accept_ch_.insert_or_assign(std::move(hostname), std::move(accept_ch));
906 }
907 
OnAcceptCompleted(int rv)908 void EmbeddedTestServer::OnAcceptCompleted(int rv) {
909   DCHECK_NE(ERR_IO_PENDING, rv);
910   HandleAcceptResult(std::move(accepted_socket_));
911   DoAcceptLoop();
912 }
913 
OnHandshakeDone(HttpConnection * connection,int rv)914 void EmbeddedTestServer::OnHandshakeDone(HttpConnection* connection, int rv) {
915   if (connection->Socket()->IsConnected()) {
916     connection->OnSocketReady();
917   } else {
918     RemoveConnection(connection);
919   }
920 }
921 
HandleAcceptResult(std::unique_ptr<StreamSocket> socket_ptr)922 void EmbeddedTestServer::HandleAcceptResult(
923     std::unique_ptr<StreamSocket> socket_ptr) {
924   DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
925   if (connection_listener_)
926     socket_ptr = connection_listener_->AcceptedSocket(std::move(socket_ptr));
927 
928   if (!is_using_ssl_) {
929     AddConnection(std::move(socket_ptr))->OnSocketReady();
930     return;
931   }
932 
933   socket_ptr = DoSSLUpgrade(std::move(socket_ptr));
934 
935   StreamSocket* socket = socket_ptr.get();
936   HttpConnection* connection = AddConnection(std::move(socket_ptr));
937 
938   int rv = static_cast<SSLServerSocket*>(socket)->Handshake(
939       base::BindOnce(&EmbeddedTestServer::OnHandshakeDone,
940                      base::Unretained(this), connection));
941   if (rv != ERR_IO_PENDING)
942     OnHandshakeDone(connection, rv);
943 }
944 
AddConnection(std::unique_ptr<StreamSocket> socket_ptr)945 HttpConnection* EmbeddedTestServer::AddConnection(
946     std::unique_ptr<StreamSocket> socket_ptr) {
947   StreamSocket* socket = socket_ptr.get();
948   std::unique_ptr<HttpConnection> connection_ptr = HttpConnection::Create(
949       std::move(socket_ptr), connection_listener_, this, protocol_);
950   HttpConnection* connection = connection_ptr.get();
951   connections_[socket] = std::move(connection_ptr);
952 
953   return connection;
954 }
955 
RemoveConnection(HttpConnection * connection,EmbeddedTestServerConnectionListener * listener)956 void EmbeddedTestServer::RemoveConnection(
957     HttpConnection* connection,
958     EmbeddedTestServerConnectionListener* listener) {
959   DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
960   DCHECK(connection);
961   DCHECK_EQ(1u, connections_.count(connection->Socket()));
962 
963   StreamSocket* raw_socket = connection->Socket();
964   std::unique_ptr<StreamSocket> socket = connection->TakeSocket();
965   connections_.erase(raw_socket);
966 
967   if (listener && socket && socket->IsConnected())
968     listener->OnResponseCompletedSuccessfully(std::move(socket));
969 }
970 
PostTaskToIOThreadAndWait(base::OnceClosure closure)971 bool EmbeddedTestServer::PostTaskToIOThreadAndWait(base::OnceClosure closure) {
972   // Note that PostTaskAndReply below requires
973   // base::SingleThreadTaskRunner::GetCurrentDefault() to return a task runner
974   // for posting the reply task. However, in order to make EmbeddedTestServer
975   // universally usable, it needs to cope with the situation where it's running
976   // on a thread on which a task executor is not (yet) available or has been
977   // destroyed already.
978   //
979   // To handle this situation, create temporary task executor to support the
980   // PostTaskAndReply operation if the current thread has no task executor.
981   // TODO(mattm): Is this still necessary/desirable? Try removing this and see
982   // if anything breaks.
983   std::unique_ptr<base::SingleThreadTaskExecutor> temporary_loop;
984   if (!base::CurrentThread::Get())
985     temporary_loop = std::make_unique<base::SingleThreadTaskExecutor>();
986 
987   base::RunLoop run_loop;
988   if (!io_thread_->task_runner()->PostTaskAndReply(
989           FROM_HERE, std::move(closure), run_loop.QuitClosure())) {
990     return false;
991   }
992   run_loop.Run();
993 
994   return true;
995 }
996 
PostTaskToIOThreadAndWaitWithResult(base::OnceCallback<bool ()> task)997 bool EmbeddedTestServer::PostTaskToIOThreadAndWaitWithResult(
998     base::OnceCallback<bool()> task) {
999   // Note that PostTaskAndReply below requires
1000   // base::SingleThreadTaskRunner::GetCurrentDefault() to return a task runner
1001   // for posting the reply task. However, in order to make EmbeddedTestServer
1002   // universally usable, it needs to cope with the situation where it's running
1003   // on a thread on which a task executor is not (yet) available or has been
1004   // destroyed already.
1005   //
1006   // To handle this situation, create temporary task executor to support the
1007   // PostTaskAndReply operation if the current thread has no task executor.
1008   // TODO(mattm): Is this still necessary/desirable? Try removing this and see
1009   // if anything breaks.
1010   std::unique_ptr<base::SingleThreadTaskExecutor> temporary_loop;
1011   if (!base::CurrentThread::Get())
1012     temporary_loop = std::make_unique<base::SingleThreadTaskExecutor>();
1013 
1014   base::RunLoop run_loop;
1015   bool task_result = false;
1016   if (!io_thread_->task_runner()->PostTaskAndReplyWithResult(
1017           FROM_HERE, std::move(task),
1018           base::BindOnce(base::BindLambdaForTesting([&](bool result) {
1019             task_result = result;
1020             run_loop.Quit();
1021           })))) {
1022     return false;
1023   }
1024   run_loop.Run();
1025 
1026   return task_result;
1027 }
1028 
1029 }  // namespace net::test_server
1030