• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/ssl_server_socket_impl.h"
6 
7 #include <memory>
8 #include <utility>
9 
10 #include "base/functional/bind.h"
11 #include "base/functional/callback_helpers.h"
12 #include "base/logging.h"
13 #include "base/memory/raw_ptr.h"
14 #include "base/memory/weak_ptr.h"
15 #include "base/strings/string_util.h"
16 #include "crypto/openssl_util.h"
17 #include "crypto/rsa_private_key.h"
18 #include "net/base/completion_once_callback.h"
19 #include "net/base/net_errors.h"
20 #include "net/cert/cert_verify_result.h"
21 #include "net/cert/client_cert_verifier.h"
22 #include "net/cert/x509_util.h"
23 #include "net/log/net_log_event_type.h"
24 #include "net/log/net_log_with_source.h"
25 #include "net/socket/socket_bio_adapter.h"
26 #include "net/ssl/openssl_ssl_util.h"
27 #include "net/ssl/ssl_connection_status_flags.h"
28 #include "net/ssl/ssl_info.h"
29 #include "net/ssl/ssl_private_key.h"
30 #include "net/traffic_annotation/network_traffic_annotation.h"
31 #include "third_party/abseil-cpp/absl/types/optional.h"
32 #include "third_party/boringssl/src/include/openssl/bytestring.h"
33 #include "third_party/boringssl/src/include/openssl/err.h"
34 #include "third_party/boringssl/src/include/openssl/pool.h"
35 #include "third_party/boringssl/src/include/openssl/ssl.h"
36 
// Sets the next state of the handshake state machine run by
// DoHandshakeLoop(). Both the argument and the whole assignment are
// parenthesized so the expansion stays a single well-formed expression
// wherever the macro is used.
#define GotoState(s) (next_handshake_state_ = (s))
38 
39 namespace net {
40 
namespace {

// Initial value of |signature_result_| before any private-key signature has
// been requested. Any strictly positive value works: it only has to stay
// clear of net::OK and of every (negative) net::Error code.
constexpr int kSSLServerSocketNoPendingResult = 1;

}  // namespace
48 
// Server-side TLS socket created by SSLServerContextImpl. It wraps a transport
// StreamSocket and drives a BoringSSL SSL object in accept (server) state.
// Transport readiness is delivered through the SocketBIOAdapter::Delegate
// callbacks (OnReadReady / OnWriteReady).
class SSLServerContextImpl::SocketImpl : public SSLServerSocket,
                                         public SocketBIOAdapter::Delegate {
 public:
  SocketImpl(SSLServerContextImpl* context,
             std::unique_ptr<StreamSocket> socket);

  SocketImpl(const SocketImpl&) = delete;
  SocketImpl& operator=(const SocketImpl&) = delete;

  ~SocketImpl() override;

  // SSLServerSocket interface.
  int Handshake(CompletionOnceCallback callback) override;

  // SSLSocket interface.
  int ExportKeyingMaterial(base::StringPiece label,
                           bool has_context,
                           base::StringPiece context,
                           unsigned char* out,
                           unsigned int outlen) override;

  // Socket interface (via StreamSocket).
  int Read(IOBuffer* buf,
           int buf_len,
           CompletionOnceCallback callback) override;
  int ReadIfReady(IOBuffer* buf,
                  int buf_len,
                  CompletionOnceCallback callback) override;
  int CancelReadIfReady() override;
  int Write(IOBuffer* buf,
            int buf_len,
            CompletionOnceCallback callback,
            const NetworkTrafficAnnotationTag& traffic_annotation) override;
  int SetReceiveBufferSize(int32_t size) override;
  int SetSendBufferSize(int32_t size) override;

  // StreamSocket implementation.
  int Connect(CompletionOnceCallback callback) override;
  void Disconnect() override;
  bool IsConnected() const override;
  bool IsConnectedAndIdle() const override;
  int GetPeerAddress(IPEndPoint* address) const override;
  int GetLocalAddress(IPEndPoint* address) const override;
  const NetLogWithSource& NetLog() const override;
  bool WasEverUsed() const override;
  NextProto GetNegotiatedProtocol() const override;
  absl::optional<base::StringPiece> GetPeerApplicationSettings() const override;
  bool GetSSLInfo(SSLInfo* ssl_info) override;
  int64_t GetTotalReceivedBytes() const override;
  void ApplySocketTag(const SocketTag& tag) override;

  // Recovers the SocketImpl registered on |ssl| (Init() stores |this| as the
  // SSL's app data).
  static SocketImpl* FromSSL(SSL* ssl);

  static ssl_verify_result_t CertVerifyCallback(SSL* ssl, uint8_t* out_alert);
  ssl_verify_result_t CertVerifyCallbackImpl(uint8_t* out_alert);

  // BoringSSL private-key hooks, installed by Init() when the context signs
  // through |private_key_|. The static trampolines forward to the member
  // overloads below via FromSSL().
  static const SSL_PRIVATE_KEY_METHOD kPrivateKeyMethod;
  static ssl_private_key_result_t PrivateKeySignCallback(SSL* ssl,
                                                         uint8_t* out,
                                                         size_t* out_len,
                                                         size_t max_out,
                                                         uint16_t algorithm,
                                                         const uint8_t* in,
                                                         size_t in_len);
  static ssl_private_key_result_t PrivateKeyDecryptCallback(SSL* ssl,
                                                            uint8_t* out,
                                                            size_t* out_len,
                                                            size_t max_out,
                                                            const uint8_t* in,
                                                            size_t in_len);
  static ssl_private_key_result_t PrivateKeyCompleteCallback(SSL* ssl,
                                                             uint8_t* out,
                                                             size_t* out_len,
                                                             size_t max_out);

  // Member halves of the private-key hooks: Sign starts an asynchronous
  // SSLPrivateKey::Sign(), Complete hands its result back to BoringSSL, and
  // OnPrivateKeyComplete() receives the signature and resumes the handshake.
  ssl_private_key_result_t PrivateKeySignCallback(uint8_t* out,
                                                  size_t* out_len,
                                                  size_t max_out,
                                                  uint16_t algorithm,
                                                  const uint8_t* in,
                                                  size_t in_len);
  ssl_private_key_result_t PrivateKeyCompleteCallback(uint8_t* out,
                                                      size_t* out_len,
                                                      size_t max_out);
  void OnPrivateKeyComplete(Error error, const std::vector<uint8_t>& signature);

  // Picks the first protocol in the server's ALPN preference list that the
  // client also offered.
  static int ALPNSelectCallback(SSL* ssl,
                                const uint8_t** out,
                                uint8_t* out_len,
                                const uint8_t* in,
                                unsigned in_len,
                                void* arg);

  // Runs the test-only ClientHello hook from the server config, if set.
  static ssl_select_cert_result_t SelectCertificateCallback(
      const SSL_CLIENT_HELLO* client_hello);

  // SocketBIOAdapter::Delegate implementation.
  void OnReadReady() override;
  void OnWriteReady() override;

 private:
  // Handshake state machine states, advanced by DoHandshakeLoop().
  enum State {
    STATE_NONE,
    STATE_HANDSHAKE,
  };

  void OnHandshakeIOComplete(int result);

  [[nodiscard]] int DoPayloadRead(IOBuffer* buf, int buf_len);
  [[nodiscard]] int DoPayloadWrite();

  [[nodiscard]] int DoHandshakeLoop(int last_io_result);
  [[nodiscard]] int DoHandshake();
  void DoHandshakeCallback(int result);
  void DoReadCallback(int result);
  void DoWriteCallback(int result);

  [[nodiscard]] int Init();
  void ExtractClientCert();

  // The context that created this socket; supplies the SSL_CTX, certificate,
  // key material and SSLServerConfig.
  raw_ptr<SSLServerContextImpl, DanglingUntriaged> context_;

  NetLogWithSource net_log_;

  // Caller callbacks, stored only while the matching operation is pending
  // (i.e. returned ERR_IO_PENDING).
  CompletionOnceCallback user_handshake_callback_;
  CompletionOnceCallback user_read_callback_;
  CompletionOnceCallback user_write_callback_;

  // SSLPrivateKey signature. |signature_result_| starts as
  // kSSLServerSocketNoPendingResult (set by the constructor), is
  // ERR_IO_PENDING while Sign() is in flight, then holds the Error it
  // reported. |signature_| holds the bytes on success until BoringSSL
  // collects them.
  int signature_result_;
  std::vector<uint8_t> signature_;

  // Used by Read function. Set only while a Read() (not a bare ReadIfReady())
  // is pending, so OnReadReady() can retry into the caller's buffer.
  scoped_refptr<IOBuffer> user_read_buf_;
  int user_read_buf_len_ = 0;

  // Used by Write function. Held while a Write() is pending so that
  // OnWriteReady() can retry it.
  scoped_refptr<IOBuffer> user_write_buf_;
  int user_write_buf_len_ = 0;

  // BoringSSL connection object; created by Init().
  bssl::UniquePtr<SSL> ssl_;

  // Whether we received any data in early data. Set by DoPayloadRead() and
  // reported through GetSSLInfo().
  bool early_data_received_ = false;

  // StreamSocket for sending and receiving data.
  std::unique_ptr<StreamSocket> transport_socket_;
  std::unique_ptr<SocketBIOAdapter> transport_adapter_;

  // Certificate for the client, captured when the handshake succeeds; stays
  // null if the client sent none.
  scoped_refptr<X509Certificate> client_cert_;

  // Next state for DoHandshakeLoop(); STATE_NONE outside the handshake.
  State next_handshake_state_ = STATE_NONE;
  // True once DoHandshake() has finished successfully.
  bool completed_handshake_ = false;

  // ALPN result recorded by DoHandshake().
  NextProto negotiated_protocol_ = kProtoUnknown;

  // Weak pointers are bound into SSLPrivateKey::Sign() callbacks, so a late
  // signature is dropped if this socket is gone. Declared last so it is
  // destroyed first.
  base::WeakPtrFactory<SocketImpl> weak_factory_{this};
};
209 
SocketImpl(SSLServerContextImpl * context,std::unique_ptr<StreamSocket> transport_socket)210 SSLServerContextImpl::SocketImpl::SocketImpl(
211     SSLServerContextImpl* context,
212     std::unique_ptr<StreamSocket> transport_socket)
213     : context_(context),
214       signature_result_(kSSLServerSocketNoPendingResult),
215       transport_socket_(std::move(transport_socket)) {}
216 
~SocketImpl()217 SSLServerContextImpl::SocketImpl::~SocketImpl() {
218   if (ssl_) {
219     // Calling SSL_shutdown prevents the session from being marked as
220     // unresumable.
221     SSL_shutdown(ssl_.get());
222     ssl_.reset();
223   }
224 }
225 
// static
// BoringSSL private-key method table (sign, decrypt, complete). Init()
// installs it when the context has no |pkey_| and signs through
// |private_key_| instead. Each entry is a static trampoline that forwards to
// the owning SocketImpl via FromSSL().
const SSL_PRIVATE_KEY_METHOD
    SSLServerContextImpl::SocketImpl::kPrivateKeyMethod = {
        &SSLServerContextImpl::SocketImpl::PrivateKeySignCallback,
        &SSLServerContextImpl::SocketImpl::PrivateKeyDecryptCallback,
        &SSLServerContextImpl::SocketImpl::PrivateKeyCompleteCallback,
};
233 
234 // static
235 ssl_private_key_result_t
PrivateKeySignCallback(SSL * ssl,uint8_t * out,size_t * out_len,size_t max_out,uint16_t algorithm,const uint8_t * in,size_t in_len)236 SSLServerContextImpl::SocketImpl::PrivateKeySignCallback(SSL* ssl,
237                                                          uint8_t* out,
238                                                          size_t* out_len,
239                                                          size_t max_out,
240                                                          uint16_t algorithm,
241                                                          const uint8_t* in,
242                                                          size_t in_len) {
243   return FromSSL(ssl)->PrivateKeySignCallback(out, out_len, max_out, algorithm,
244                                               in, in_len);
245 }
246 
// static
// Decrypt hook of |kPrivateKeyMethod|. This socket does not support
// private-key decryption, so the hook always reports failure and ignores its
// arguments.
ssl_private_key_result_t
SSLServerContextImpl::SocketImpl::PrivateKeyDecryptCallback(SSL* ssl,
                                                            uint8_t* out,
                                                            size_t* out_len,
                                                            size_t max_out,
                                                            const uint8_t* in,
                                                            size_t in_len) {
  // Decrypt is not supported.
  return ssl_private_key_failure;
}
258 
259 // static
260 ssl_private_key_result_t
PrivateKeyCompleteCallback(SSL * ssl,uint8_t * out,size_t * out_len,size_t max_out)261 SSLServerContextImpl::SocketImpl::PrivateKeyCompleteCallback(SSL* ssl,
262                                                              uint8_t* out,
263                                                              size_t* out_len,
264                                                              size_t max_out) {
265   return FromSSL(ssl)->PrivateKeyCompleteCallback(out, out_len, max_out);
266 }
267 
268 ssl_private_key_result_t
PrivateKeySignCallback(uint8_t * out,size_t * out_len,size_t max_out,uint16_t algorithm,const uint8_t * in,size_t in_len)269 SSLServerContextImpl::SocketImpl::PrivateKeySignCallback(uint8_t* out,
270                                                          size_t* out_len,
271                                                          size_t max_out,
272                                                          uint16_t algorithm,
273                                                          const uint8_t* in,
274                                                          size_t in_len) {
275   DCHECK(context_);
276   DCHECK(context_->private_key_);
277   signature_result_ = ERR_IO_PENDING;
278   context_->private_key_->Sign(
279       algorithm, base::make_span(in, in_len),
280       base::BindOnce(&SSLServerContextImpl::SocketImpl::OnPrivateKeyComplete,
281                      weak_factory_.GetWeakPtr()));
282   return ssl_private_key_retry;
283 }
284 
285 ssl_private_key_result_t
PrivateKeyCompleteCallback(uint8_t * out,size_t * out_len,size_t max_out)286 SSLServerContextImpl::SocketImpl::PrivateKeyCompleteCallback(uint8_t* out,
287                                                              size_t* out_len,
288                                                              size_t max_out) {
289   if (signature_result_ == ERR_IO_PENDING)
290     return ssl_private_key_retry;
291   if (signature_result_ != OK) {
292     OpenSSLPutNetError(FROM_HERE, signature_result_);
293     return ssl_private_key_failure;
294   }
295   if (signature_.size() > max_out) {
296     OpenSSLPutNetError(FROM_HERE, ERR_SSL_CLIENT_AUTH_SIGNATURE_FAILED);
297     return ssl_private_key_failure;
298   }
299   memcpy(out, signature_.data(), signature_.size());
300   *out_len = signature_.size();
301   signature_.clear();
302   return ssl_private_key_success;
303 }
304 
OnPrivateKeyComplete(Error error,const std::vector<uint8_t> & signature)305 void SSLServerContextImpl::SocketImpl::OnPrivateKeyComplete(
306     Error error,
307     const std::vector<uint8_t>& signature) {
308   DCHECK_EQ(ERR_IO_PENDING, signature_result_);
309   DCHECK(signature_.empty());
310 
311   signature_result_ = error;
312   if (signature_result_ == OK)
313     signature_ = signature;
314   OnHandshakeIOComplete(ERR_IO_PENDING);
315 }
316 
317 // static
ALPNSelectCallback(SSL * ssl,const uint8_t ** out,uint8_t * out_len,const uint8_t * in,unsigned in_len,void * arg)318 int SSLServerContextImpl::SocketImpl::ALPNSelectCallback(SSL* ssl,
319                                                          const uint8_t** out,
320                                                          uint8_t* out_len,
321                                                          const uint8_t* in,
322                                                          unsigned in_len,
323                                                          void* arg) {
324   SSLServerContextImpl::SocketImpl* socket = FromSSL(ssl);
325 
326   // Iterate over the server protocols in preference order.
327   for (NextProto server_proto :
328        socket->context_->ssl_server_config_.alpn_protos) {
329     const char* server_proto_str = NextProtoToString(server_proto);
330 
331     // See if the client advertised the corresponding protocol.
332     CBS cbs;
333     CBS_init(&cbs, in, in_len);
334     while (CBS_len(&cbs) != 0) {
335       CBS client_proto;
336       if (!CBS_get_u8_length_prefixed(&cbs, &client_proto)) {
337         return SSL_TLSEXT_ERR_NOACK;
338       }
339       if (base::StringPiece(
340               reinterpret_cast<const char*>(CBS_data(&client_proto)),
341               CBS_len(&client_proto)) == server_proto_str) {
342         *out = CBS_data(&client_proto);
343         *out_len = CBS_len(&client_proto);
344 
345         const auto& application_settings =
346             socket->context_->ssl_server_config_.application_settings;
347         auto it = application_settings.find(server_proto);
348         if (it != application_settings.end()) {
349           const std::vector<uint8_t>& data = it->second;
350           SSL_add_application_settings(ssl, CBS_data(&client_proto),
351                                        CBS_len(&client_proto), data.data(),
352                                        data.size());
353         }
354         return SSL_TLSEXT_ERR_OK;
355       }
356     }
357   }
358   return SSL_TLSEXT_ERR_NOACK;
359 }
360 
361 ssl_select_cert_result_t
SelectCertificateCallback(const SSL_CLIENT_HELLO * client_hello)362 SSLServerContextImpl::SocketImpl::SelectCertificateCallback(
363     const SSL_CLIENT_HELLO* client_hello) {
364   SSLServerContextImpl::SocketImpl* socket = FromSSL(client_hello->ssl);
365   const SSLServerConfig& config = socket->context_->ssl_server_config_;
366   if (!config.client_hello_callback_for_testing.is_null() &&
367       !config.client_hello_callback_for_testing.Run(client_hello)) {
368     return ssl_select_cert_error;
369   }
370   return ssl_select_cert_success;
371 }
372 
Handshake(CompletionOnceCallback callback)373 int SSLServerContextImpl::SocketImpl::Handshake(
374     CompletionOnceCallback callback) {
375   net_log_.BeginEvent(NetLogEventType::SSL_SERVER_HANDSHAKE);
376 
377   // Set up new ssl object.
378   int rv = Init();
379   if (rv != OK) {
380     LOG(ERROR) << "Failed to initialize OpenSSL: rv=" << rv;
381     net_log_.EndEventWithNetErrorCode(NetLogEventType::SSL_SERVER_HANDSHAKE,
382                                       rv);
383     return rv;
384   }
385 
386   // Set SSL to server mode. Handshake happens in the loop below.
387   SSL_set_accept_state(ssl_.get());
388 
389   GotoState(STATE_HANDSHAKE);
390   rv = DoHandshakeLoop(OK);
391   if (rv == ERR_IO_PENDING) {
392     user_handshake_callback_ = std::move(callback);
393   } else {
394     net_log_.EndEventWithNetErrorCode(NetLogEventType::SSL_SERVER_HANDSHAKE,
395                                       rv);
396   }
397 
398   return rv > OK ? OK : rv;
399 }
400 
ExportKeyingMaterial(base::StringPiece label,bool has_context,base::StringPiece context,unsigned char * out,unsigned int outlen)401 int SSLServerContextImpl::SocketImpl::ExportKeyingMaterial(
402     base::StringPiece label,
403     bool has_context,
404     base::StringPiece context,
405     unsigned char* out,
406     unsigned int outlen) {
407   if (!IsConnected())
408     return ERR_SOCKET_NOT_CONNECTED;
409 
410   crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE);
411 
412   int rv = SSL_export_keying_material(
413       ssl_.get(), out, outlen, label.data(), label.size(),
414       reinterpret_cast<const unsigned char*>(context.data()), context.length(),
415       context.length() > 0);
416 
417   if (rv != 1) {
418     int ssl_error = SSL_get_error(ssl_.get(), rv);
419     LOG(ERROR) << "Failed to export keying material;"
420                << " returned " << rv << ", SSL error code " << ssl_error;
421     return MapOpenSSLError(ssl_error, err_tracer);
422   }
423   return OK;
424 }
425 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)426 int SSLServerContextImpl::SocketImpl::Read(IOBuffer* buf,
427                                            int buf_len,
428                                            CompletionOnceCallback callback) {
429   int rv = ReadIfReady(buf, buf_len, std::move(callback));
430   if (rv == ERR_IO_PENDING) {
431     user_read_buf_ = buf;
432     user_read_buf_len_ = buf_len;
433   }
434   return rv;
435 }
436 
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)437 int SSLServerContextImpl::SocketImpl::ReadIfReady(
438     IOBuffer* buf,
439     int buf_len,
440     CompletionOnceCallback callback) {
441   DCHECK(user_read_callback_.is_null());
442   DCHECK(user_handshake_callback_.is_null());
443   DCHECK(!user_read_buf_);
444   DCHECK(!callback.is_null());
445   DCHECK(completed_handshake_);
446 
447   int rv = DoPayloadRead(buf, buf_len);
448 
449   if (rv == ERR_IO_PENDING) {
450     user_read_callback_ = std::move(callback);
451   }
452 
453   return rv;
454 }
455 
CancelReadIfReady()456 int SSLServerContextImpl::SocketImpl::CancelReadIfReady() {
457   DCHECK(user_read_callback_);
458   DCHECK(!user_read_buf_);
459 
460   // Cancel |user_read_callback_|, because caller does not expect the callback
461   // to be invoked after they have canceled the ReadIfReady.
462   //
463   // We do not pass the signal on to |stream_socket_| or |transport_adapter_|.
464   // When it completes, it will signal OnReadReady(), which will notice there is
465   // no read operation to progress and skip it. Unlike with SSLClientSocket,
466   // SSL and transport reads are more aligned, but this avoids making
467   // assumptions or breaking the SocketBIOAdapter's state.
468   user_read_callback_.Reset();
469   return OK;
470 }
471 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)472 int SSLServerContextImpl::SocketImpl::Write(
473     IOBuffer* buf,
474     int buf_len,
475     CompletionOnceCallback callback,
476     const NetworkTrafficAnnotationTag& traffic_annotation) {
477   DCHECK(user_write_callback_.is_null());
478   DCHECK(!user_write_buf_);
479   DCHECK(!callback.is_null());
480 
481   user_write_buf_ = buf;
482   user_write_buf_len_ = buf_len;
483 
484   int rv = DoPayloadWrite();
485 
486   if (rv == ERR_IO_PENDING) {
487     user_write_callback_ = std::move(callback);
488   } else {
489     user_write_buf_ = nullptr;
490     user_write_buf_len_ = 0;
491   }
492   return rv;
493 }
494 
SetReceiveBufferSize(int32_t size)495 int SSLServerContextImpl::SocketImpl::SetReceiveBufferSize(int32_t size) {
496   return transport_socket_->SetReceiveBufferSize(size);
497 }
498 
SetSendBufferSize(int32_t size)499 int SSLServerContextImpl::SocketImpl::SetSendBufferSize(int32_t size) {
500   return transport_socket_->SetSendBufferSize(size);
501 }
502 
// Server sockets are established through Handshake(), not Connect(); this
// always fails with ERR_NOT_IMPLEMENTED.
int SSLServerContextImpl::SocketImpl::Connect(CompletionOnceCallback callback) {
  NOTIMPLEMENTED();
  return ERR_NOT_IMPLEMENTED;
}

// Disconnects the underlying transport. The SSL object is left in place and
// is released by the destructor.
void SSLServerContextImpl::SocketImpl::Disconnect() {
  transport_socket_->Disconnect();
}
511 
IsConnected() const512 bool SSLServerContextImpl::SocketImpl::IsConnected() const {
513   // TODO(wtc): Find out if we should check transport_socket_->IsConnected()
514   // as well.
515   return completed_handshake_;
516 }
517 
IsConnectedAndIdle() const518 bool SSLServerContextImpl::SocketImpl::IsConnectedAndIdle() const {
519   return completed_handshake_ && transport_socket_->IsConnectedAndIdle();
520 }
521 
GetPeerAddress(IPEndPoint * address) const522 int SSLServerContextImpl::SocketImpl::GetPeerAddress(
523     IPEndPoint* address) const {
524   if (!IsConnected())
525     return ERR_SOCKET_NOT_CONNECTED;
526   return transport_socket_->GetPeerAddress(address);
527 }
528 
GetLocalAddress(IPEndPoint * address) const529 int SSLServerContextImpl::SocketImpl::GetLocalAddress(
530     IPEndPoint* address) const {
531   if (!IsConnected())
532     return ERR_SOCKET_NOT_CONNECTED;
533   return transport_socket_->GetLocalAddress(address);
534 }
535 
// Returns the NetLog this socket uses for its handshake event and SSL errors.
const NetLogWithSource& SSLServerContextImpl::SocketImpl::NetLog() const {
  return net_log_;
}

// Forwards to the underlying transport socket.
bool SSLServerContextImpl::SocketImpl::WasEverUsed() const {
  return transport_socket_->WasEverUsed();
}

// Returns the ALPN protocol recorded by DoHandshake(). This remains
// kProtoUnknown (the initial value) until a non-empty ALPN selection has been
// seen.
NextProto SSLServerContextImpl::SocketImpl::GetNegotiatedProtocol() const {
  return negotiated_protocol_;
}
547 
548 absl::optional<base::StringPiece>
GetPeerApplicationSettings() const549 SSLServerContextImpl::SocketImpl::GetPeerApplicationSettings() const {
550   if (!SSL_has_application_settings(ssl_.get())) {
551     return absl::nullopt;
552   }
553 
554   const uint8_t* out_data;
555   size_t out_len;
556   SSL_get0_peer_application_settings(ssl_.get(), &out_data, &out_len);
557   return base::StringPiece{reinterpret_cast<const char*>(out_data), out_len};
558 }
559 
GetSSLInfo(SSLInfo * ssl_info)560 bool SSLServerContextImpl::SocketImpl::GetSSLInfo(SSLInfo* ssl_info) {
561   ssl_info->Reset();
562   if (!completed_handshake_)
563     return false;
564 
565   ssl_info->cert = client_cert_;
566 
567   const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl_.get());
568   CHECK(cipher);
569 
570   SSLConnectionStatusSetCipherSuite(SSL_CIPHER_get_protocol_id(cipher),
571                                     &ssl_info->connection_status);
572   SSLConnectionStatusSetVersion(GetNetSSLVersion(ssl_.get()),
573                                 &ssl_info->connection_status);
574 
575   ssl_info->early_data_received = early_data_received_;
576   ssl_info->encrypted_client_hello = SSL_ech_accepted(ssl_.get());
577   ssl_info->handshake_type = SSL_session_reused(ssl_.get())
578                                  ? SSLInfo::HANDSHAKE_RESUME
579                                  : SSLInfo::HANDSHAKE_FULL;
580 
581   return true;
582 }
583 
// Delegates to |transport_socket_|.
// NOTE(review): this presumably counts transport-level bytes rather than
// decrypted payload — confirm callers expect that.
int64_t SSLServerContextImpl::SocketImpl::GetTotalReceivedBytes() const {
  return transport_socket_->GetTotalReceivedBytes();
}

// Socket tagging is not implemented for this socket; |tag| is ignored.
void SSLServerContextImpl::SocketImpl::ApplySocketTag(const SocketTag& tag) {
  NOTIMPLEMENTED();
}
591 
OnReadReady()592 void SSLServerContextImpl::SocketImpl::OnReadReady() {
593   if (next_handshake_state_ == STATE_HANDSHAKE) {
594     // In handshake phase. The parameter to OnHandshakeIOComplete is unused.
595     OnHandshakeIOComplete(OK);
596     return;
597   }
598 
599   // BoringSSL does not support renegotiation as a server, so the only other
600   // operation blocked on Read is DoPayloadRead.
601   if (!user_read_buf_) {
602     if (!user_read_callback_.is_null()) {
603       DoReadCallback(OK);
604     }
605     return;
606   }
607 
608   int rv = DoPayloadRead(user_read_buf_.get(), user_read_buf_len_);
609   if (rv != ERR_IO_PENDING)
610     DoReadCallback(rv);
611 }
612 
OnWriteReady()613 void SSLServerContextImpl::SocketImpl::OnWriteReady() {
614   if (next_handshake_state_ == STATE_HANDSHAKE) {
615     // In handshake phase. The parameter to OnHandshakeIOComplete is unused.
616     OnHandshakeIOComplete(OK);
617     return;
618   }
619 
620   // BoringSSL does not support renegotiation as a server, so the only other
621   // operation blocked on Read is DoPayloadWrite.
622   if (!user_write_buf_)
623     return;
624 
625   int rv = DoPayloadWrite();
626   if (rv != ERR_IO_PENDING)
627     DoWriteCallback(rv);
628 }
629 
OnHandshakeIOComplete(int result)630 void SSLServerContextImpl::SocketImpl::OnHandshakeIOComplete(int result) {
631   int rv = DoHandshakeLoop(result);
632   if (rv == ERR_IO_PENDING)
633     return;
634 
635   net_log_.EndEventWithNetErrorCode(NetLogEventType::SSL_SERVER_HANDSHAKE, rv);
636   if (!user_handshake_callback_.is_null())
637     DoHandshakeCallback(rv);
638 }
639 
DoPayloadRead(IOBuffer * buf,int buf_len)640 int SSLServerContextImpl::SocketImpl::DoPayloadRead(IOBuffer* buf,
641                                                     int buf_len) {
642   DCHECK(completed_handshake_);
643   DCHECK_EQ(STATE_NONE, next_handshake_state_);
644   DCHECK(buf);
645   DCHECK_GT(buf_len, 0);
646 
647   crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE);
648   int rv = SSL_read(ssl_.get(), buf->data(), buf_len);
649   if (rv >= 0) {
650     if (SSL_in_early_data(ssl_.get()))
651       early_data_received_ = true;
652     return rv;
653   }
654   int ssl_error = SSL_get_error(ssl_.get(), rv);
655   OpenSSLErrorInfo error_info;
656   int net_error =
657       MapOpenSSLErrorWithDetails(ssl_error, err_tracer, &error_info);
658   if (net_error != ERR_IO_PENDING) {
659     NetLogOpenSSLError(net_log_, NetLogEventType::SSL_READ_ERROR, net_error,
660                        ssl_error, error_info);
661   }
662   return net_error;
663 }
664 
DoPayloadWrite()665 int SSLServerContextImpl::SocketImpl::DoPayloadWrite() {
666   DCHECK(completed_handshake_);
667   DCHECK_EQ(STATE_NONE, next_handshake_state_);
668   DCHECK(user_write_buf_);
669 
670   crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE);
671   int rv = SSL_write(ssl_.get(), user_write_buf_->data(), user_write_buf_len_);
672   if (rv >= 0)
673     return rv;
674   int ssl_error = SSL_get_error(ssl_.get(), rv);
675   OpenSSLErrorInfo error_info;
676   int net_error =
677       MapOpenSSLErrorWithDetails(ssl_error, err_tracer, &error_info);
678   if (net_error != ERR_IO_PENDING) {
679     NetLogOpenSSLError(net_log_, NetLogEventType::SSL_WRITE_ERROR, net_error,
680                        ssl_error, error_info);
681   }
682   return net_error;
683 }
684 
DoHandshakeLoop(int last_io_result)685 int SSLServerContextImpl::SocketImpl::DoHandshakeLoop(int last_io_result) {
686   int rv = last_io_result;
687   do {
688     // Default to STATE_NONE for next state.
689     // (This is a quirk carried over from the windows
690     // implementation.  It makes reading the logs a bit harder.)
691     // State handlers can and often do call GotoState just
692     // to stay in the current state.
693     State state = next_handshake_state_;
694     GotoState(STATE_NONE);
695     switch (state) {
696       case STATE_HANDSHAKE:
697         rv = DoHandshake();
698         break;
699       case STATE_NONE:
700       default:
701         rv = ERR_UNEXPECTED;
702         LOG(DFATAL) << "unexpected state " << state;
703         break;
704     }
705   } while (rv != ERR_IO_PENDING && next_handshake_state_ != STATE_NONE);
706   return rv;
707 }
708 
DoHandshake()709 int SSLServerContextImpl::SocketImpl::DoHandshake() {
710   crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE);
711   int net_error = OK;
712   int rv = SSL_do_handshake(ssl_.get());
713   if (rv == 1) {
714     const STACK_OF(CRYPTO_BUFFER)* certs =
715         SSL_get0_peer_certificates(ssl_.get());
716     if (certs) {
717       client_cert_ = x509_util::CreateX509CertificateFromBuffers(certs);
718       if (!client_cert_)
719         return ERR_SSL_CLIENT_AUTH_CERT_BAD_FORMAT;
720     }
721 
722     const uint8_t* alpn_proto = nullptr;
723     unsigned alpn_len = 0;
724     SSL_get0_alpn_selected(ssl_.get(), &alpn_proto, &alpn_len);
725     if (alpn_len > 0) {
726       base::StringPiece proto(reinterpret_cast<const char*>(alpn_proto),
727                               alpn_len);
728       negotiated_protocol_ = NextProtoFromString(proto);
729     }
730 
731     if (context_->ssl_server_config_.alert_after_handshake_for_testing) {
732       SSL_send_fatal_alert(ssl_.get(),
733                            context_->ssl_server_config_
734                                .alert_after_handshake_for_testing.value());
735       return ERR_FAILED;
736     }
737 
738     completed_handshake_ = true;
739   } else {
740     int ssl_error = SSL_get_error(ssl_.get(), rv);
741 
742     if (ssl_error == SSL_ERROR_WANT_PRIVATE_KEY_OPERATION) {
743       DCHECK(context_->private_key_);
744       GotoState(STATE_HANDSHAKE);
745       return ERR_IO_PENDING;
746     }
747 
748     OpenSSLErrorInfo error_info;
749     net_error = MapOpenSSLErrorWithDetails(ssl_error, err_tracer, &error_info);
750 
751     // SSL_R_CERTIFICATE_VERIFY_FAILED's mapping is different between client and
752     // server.
753     if (ERR_GET_LIB(error_info.error_code) == ERR_LIB_SSL &&
754         ERR_GET_REASON(error_info.error_code) ==
755             SSL_R_CERTIFICATE_VERIFY_FAILED) {
756       net_error = ERR_BAD_SSL_CLIENT_AUTH_CERT;
757     }
758 
759     // If not done, stay in this state
760     if (net_error == ERR_IO_PENDING) {
761       GotoState(STATE_HANDSHAKE);
762     } else {
763       LOG(ERROR) << "handshake failed; returned " << rv << ", SSL error code "
764                  << ssl_error << ", net_error " << net_error;
765       NetLogOpenSSLError(net_log_, NetLogEventType::SSL_HANDSHAKE_ERROR,
766                          net_error, ssl_error, error_info);
767     }
768   }
769   return net_error;
770 }
771 
DoHandshakeCallback(int rv)772 void SSLServerContextImpl::SocketImpl::DoHandshakeCallback(int rv) {
773   DCHECK_NE(rv, ERR_IO_PENDING);
774   std::move(user_handshake_callback_).Run(rv > OK ? OK : rv);
775 }
776 
DoReadCallback(int rv)777 void SSLServerContextImpl::SocketImpl::DoReadCallback(int rv) {
778   DCHECK(rv != ERR_IO_PENDING);
779   DCHECK(!user_read_callback_.is_null());
780 
781   user_read_buf_ = nullptr;
782   user_read_buf_len_ = 0;
783   std::move(user_read_callback_).Run(rv);
784 }
785 
DoWriteCallback(int rv)786 void SSLServerContextImpl::SocketImpl::DoWriteCallback(int rv) {
787   DCHECK(rv != ERR_IO_PENDING);
788   DCHECK(!user_write_callback_.is_null());
789 
790   user_write_buf_ = nullptr;
791   user_write_buf_len_ = 0;
792   std::move(user_write_callback_).Run(rv);
793 }
794 
Init()795 int SSLServerContextImpl::SocketImpl::Init() {
796   static const int kBufferSize = 17 * 1024;
797 
798   crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE);
799 
800   ssl_.reset(SSL_new(context_->ssl_ctx_.get()));
801   if (!ssl_ || !SSL_set_app_data(ssl_.get(), this)) {
802     return ERR_UNEXPECTED;
803   }
804 
805   SSL_set_shed_handshake_config(ssl_.get(), 1);
806 
807   // Set certificate and private key.
808   if (context_->pkey_) {
809     DCHECK(context_->cert_->cert_buffer());
810     if (!SetSSLChainAndKey(ssl_.get(), context_->cert_.get(),
811                            context_->pkey_.get(), nullptr)) {
812       return ERR_UNEXPECTED;
813     }
814   } else {
815     DCHECK(context_->private_key_);
816     if (!SetSSLChainAndKey(ssl_.get(), context_->cert_.get(), nullptr,
817                            &kPrivateKeyMethod)) {
818       return ERR_UNEXPECTED;
819     }
820     std::vector<uint16_t> preferences =
821         context_->private_key_->GetAlgorithmPreferences();
822     SSL_set_signing_algorithm_prefs(ssl_.get(), preferences.data(),
823                                     preferences.size());
824   }
825 
826   if (context_->ssl_server_config_.signature_algorithm_for_testing
827           .has_value()) {
828     uint16_t id = *context_->ssl_server_config_.signature_algorithm_for_testing;
829     CHECK(SSL_set_signing_algorithm_prefs(ssl_.get(), &id, 1));
830   }
831 
832   const std::vector<int>& curves =
833       context_->ssl_server_config_.curves_for_testing;
834   if (!curves.empty()) {
835     CHECK(SSL_set1_curves(ssl_.get(), curves.data(), curves.size()));
836   }
837 
838   transport_adapter_ = std::make_unique<SocketBIOAdapter>(
839       transport_socket_.get(), kBufferSize, kBufferSize, this);
840   BIO* transport_bio = transport_adapter_->bio();
841 
842   BIO_up_ref(transport_bio);  // SSL_set0_rbio takes ownership.
843   SSL_set0_rbio(ssl_.get(), transport_bio);
844 
845   BIO_up_ref(transport_bio);  // SSL_set0_wbio takes ownership.
846   SSL_set0_wbio(ssl_.get(), transport_bio);
847 
848   return OK;
849 }
850 
FromSSL(SSL * ssl)851 SSLServerContextImpl::SocketImpl* SSLServerContextImpl::SocketImpl::FromSSL(
852     SSL* ssl) {
853   SocketImpl* socket = reinterpret_cast<SocketImpl*>(SSL_get_app_data(ssl));
854   DCHECK(socket);
855   return socket;
856 }
857 
858 // static
CertVerifyCallback(SSL * ssl,uint8_t * out_alert)859 ssl_verify_result_t SSLServerContextImpl::SocketImpl::CertVerifyCallback(
860     SSL* ssl,
861     uint8_t* out_alert) {
862   return FromSSL(ssl)->CertVerifyCallbackImpl(out_alert);
863 }
864 
CertVerifyCallbackImpl(uint8_t * out_alert)865 ssl_verify_result_t SSLServerContextImpl::SocketImpl::CertVerifyCallbackImpl(
866     uint8_t* out_alert) {
867   ClientCertVerifier* verifier =
868       context_->ssl_server_config_.client_cert_verifier;
869   // If a verifier was not supplied, all certificates are accepted.
870   if (!verifier)
871     return ssl_verify_ok;
872 
873   scoped_refptr<X509Certificate> client_cert =
874       x509_util::CreateX509CertificateFromBuffers(
875           SSL_get0_peer_certificates(ssl_.get()));
876   if (!client_cert) {
877     *out_alert = SSL_AD_BAD_CERTIFICATE;
878     return ssl_verify_invalid;
879   }
880 
881   // TODO(davidben): Support asynchronous verifiers. http://crbug.com/347402
882   std::unique_ptr<ClientCertVerifier::Request> ignore_async;
883   int res = verifier->Verify(client_cert.get(), CompletionOnceCallback(),
884                              &ignore_async);
885   DCHECK_NE(res, ERR_IO_PENDING);
886 
887   if (res != OK) {
888     // TODO(davidben): Map from certificate verification failure to alert.
889     *out_alert = SSL_AD_CERTIFICATE_UNKNOWN;
890     return ssl_verify_invalid;
891   }
892   return ssl_verify_ok;
893 }
894 
CreateSSLServerContext(X509Certificate * certificate,EVP_PKEY * pkey,const SSLServerConfig & ssl_server_config)895 std::unique_ptr<SSLServerContext> CreateSSLServerContext(
896     X509Certificate* certificate,
897     EVP_PKEY* pkey,
898     const SSLServerConfig& ssl_server_config) {
899   return std::make_unique<SSLServerContextImpl>(certificate, pkey,
900                                                 ssl_server_config);
901 }
902 
CreateSSLServerContext(X509Certificate * certificate,const crypto::RSAPrivateKey & key,const SSLServerConfig & ssl_server_config)903 std::unique_ptr<SSLServerContext> CreateSSLServerContext(
904     X509Certificate* certificate,
905     const crypto::RSAPrivateKey& key,
906     const SSLServerConfig& ssl_server_config) {
907   return std::make_unique<SSLServerContextImpl>(certificate, key.key(),
908                                                 ssl_server_config);
909 }
910 
CreateSSLServerContext(X509Certificate * certificate,scoped_refptr<SSLPrivateKey> key,const SSLServerConfig & ssl_config)911 std::unique_ptr<SSLServerContext> CreateSSLServerContext(
912     X509Certificate* certificate,
913     scoped_refptr<SSLPrivateKey> key,
914     const SSLServerConfig& ssl_config) {
915   return std::make_unique<SSLServerContextImpl>(certificate, key, ssl_config);
916 }
917 
SSLServerContextImpl(X509Certificate * certificate,scoped_refptr<net::SSLPrivateKey> key,const SSLServerConfig & ssl_server_config)918 SSLServerContextImpl::SSLServerContextImpl(
919     X509Certificate* certificate,
920     scoped_refptr<net::SSLPrivateKey> key,
921     const SSLServerConfig& ssl_server_config)
922     : ssl_server_config_(ssl_server_config),
923       cert_(certificate),
924       private_key_(key) {
925   CHECK(private_key_);
926   Init();
927 }
928 
SSLServerContextImpl(X509Certificate * certificate,EVP_PKEY * pkey,const SSLServerConfig & ssl_server_config)929 SSLServerContextImpl::SSLServerContextImpl(
930     X509Certificate* certificate,
931     EVP_PKEY* pkey,
932     const SSLServerConfig& ssl_server_config)
933     : ssl_server_config_(ssl_server_config), cert_(certificate) {
934   CHECK(pkey);
935   pkey_ = bssl::UpRef(pkey);
936   Init();
937 }
938 
Init()939 void SSLServerContextImpl::Init() {
940   crypto::EnsureOpenSSLInit();
941   ssl_ctx_.reset(SSL_CTX_new(TLS_with_buffers_method()));
942   SSL_CTX_set_session_cache_mode(ssl_ctx_.get(), SSL_SESS_CACHE_SERVER);
943   uint8_t session_ctx_id = 0;
944   SSL_CTX_set_session_id_context(ssl_ctx_.get(), &session_ctx_id,
945                                  sizeof(session_ctx_id));
946   // Deduplicate all certificates minted from the SSL_CTX in memory.
947   SSL_CTX_set0_buffer_pool(ssl_ctx_.get(), x509_util::GetBufferPool());
948 
949   int verify_mode = 0;
950   switch (ssl_server_config_.client_cert_type) {
951     case SSLServerConfig::ClientCertType::REQUIRE_CLIENT_CERT:
952       verify_mode |= SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
953       [[fallthrough]];
954     case SSLServerConfig::ClientCertType::OPTIONAL_CLIENT_CERT:
955       verify_mode |= SSL_VERIFY_PEER;
956       SSL_CTX_set_custom_verify(ssl_ctx_.get(), verify_mode,
957                                 SocketImpl::CertVerifyCallback);
958       break;
959     case SSLServerConfig::ClientCertType::NO_CLIENT_CERT:
960       break;
961   }
962 
963   SSL_CTX_set_early_data_enabled(ssl_ctx_.get(),
964                                  ssl_server_config_.early_data_enabled);
965   // TLS versions before TLS 1.2 are no longer supported.
966   CHECK_LE(TLS1_2_VERSION, ssl_server_config_.version_min);
967   CHECK_LE(TLS1_2_VERSION, ssl_server_config_.version_max);
968   CHECK(SSL_CTX_set_min_proto_version(ssl_ctx_.get(),
969                                       ssl_server_config_.version_min));
970   CHECK(SSL_CTX_set_max_proto_version(ssl_ctx_.get(),
971                                       ssl_server_config_.version_max));
972 
973   // OpenSSL defaults some options to on, others to off. To avoid ambiguity,
974   // set everything we care about to an absolute value.
975   SslSetClearMask options;
976   options.ConfigureFlag(SSL_OP_NO_COMPRESSION, true);
977 
978   SSL_CTX_set_options(ssl_ctx_.get(), options.set_mask);
979   SSL_CTX_clear_options(ssl_ctx_.get(), options.clear_mask);
980 
981   // Same as above, this time for the SSL mode.
982   SslSetClearMask mode;
983 
984   mode.ConfigureFlag(SSL_MODE_RELEASE_BUFFERS, true);
985 
986   SSL_CTX_set_mode(ssl_ctx_.get(), mode.set_mask);
987   SSL_CTX_clear_mode(ssl_ctx_.get(), mode.clear_mask);
988 
989   if (ssl_server_config_.cipher_suite_for_testing.has_value()) {
990     const SSL_CIPHER* cipher =
991         SSL_get_cipher_by_value(*ssl_server_config_.cipher_suite_for_testing);
992     CHECK(cipher);
993     CHECK(SSL_CTX_set_strict_cipher_list(ssl_ctx_.get(),
994                                          SSL_CIPHER_get_name(cipher)));
995   } else {
996     // Use BoringSSL defaults, but disable 3DES and HMAC-SHA1 ciphers in ECDSA.
997     // These are the remaining CBC-mode ECDSA ciphers.
998     std::string command("ALL:!aPSK:!ECDSA+SHA1:!3DES");
999 
1000     // SSLPrivateKey only supports ECDHE-based ciphers because it lacks decrypt.
1001     if (ssl_server_config_.require_ecdhe || (!pkey_ && private_key_))
1002       command.append(":!kRSA");
1003 
1004     // Remove any disabled ciphers.
1005     for (uint16_t id : ssl_server_config_.disabled_cipher_suites) {
1006       const SSL_CIPHER* cipher = SSL_get_cipher_by_value(id);
1007       if (cipher) {
1008         command.append(":!");
1009         command.append(SSL_CIPHER_get_name(cipher));
1010       }
1011     }
1012 
1013     CHECK(SSL_CTX_set_strict_cipher_list(ssl_ctx_.get(), command.c_str()));
1014   }
1015 
1016   if (ssl_server_config_.client_cert_type !=
1017           SSLServerConfig::ClientCertType::NO_CLIENT_CERT &&
1018       !ssl_server_config_.cert_authorities.empty()) {
1019     bssl::UniquePtr<STACK_OF(CRYPTO_BUFFER)> stack(sk_CRYPTO_BUFFER_new_null());
1020     for (const auto& authority : ssl_server_config_.cert_authorities) {
1021       sk_CRYPTO_BUFFER_push(stack.get(),
1022                             x509_util::CreateCryptoBuffer(authority).release());
1023     }
1024     SSL_CTX_set0_client_CAs(ssl_ctx_.get(), stack.release());
1025   }
1026 
1027   SSL_CTX_set_alpn_select_cb(ssl_ctx_.get(), &SocketImpl::ALPNSelectCallback,
1028                              nullptr);
1029 
1030   if (!ssl_server_config_.ocsp_response.empty()) {
1031     SSL_CTX_set_ocsp_response(ssl_ctx_.get(),
1032                               ssl_server_config_.ocsp_response.data(),
1033                               ssl_server_config_.ocsp_response.size());
1034   }
1035 
1036   if (!ssl_server_config_.signed_cert_timestamp_list.empty()) {
1037     SSL_CTX_set_signed_cert_timestamp_list(
1038         ssl_ctx_.get(), ssl_server_config_.signed_cert_timestamp_list.data(),
1039         ssl_server_config_.signed_cert_timestamp_list.size());
1040   }
1041 
1042   if (ssl_server_config_.ech_keys) {
1043     CHECK(SSL_CTX_set1_ech_keys(ssl_ctx_.get(),
1044                                 ssl_server_config_.ech_keys.get()));
1045   }
1046 
1047   SSL_CTX_set_select_certificate_cb(ssl_ctx_.get(),
1048                                     &SocketImpl::SelectCertificateCallback);
1049 }
1050 
1051 SSLServerContextImpl::~SSLServerContextImpl() = default;
1052 
CreateSSLServerSocket(std::unique_ptr<StreamSocket> socket)1053 std::unique_ptr<SSLServerSocket> SSLServerContextImpl::CreateSSLServerSocket(
1054     std::unique_ptr<StreamSocket> socket) {
1055   return std::make_unique<SocketImpl>(this, std::move(socket));
1056 }
1057 
1058 }  // namespace net
1059