• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/ssl_server_socket_nss.h"
6 
7 #if defined(OS_WIN)
8 #include <winsock2.h>
9 #endif
10 
11 #if defined(USE_SYSTEM_SSL)
12 #include <dlfcn.h>
13 #endif
14 #if defined(OS_MACOSX)
15 #include <Security/Security.h>
16 #endif
17 #include <certdb.h>
18 #include <cryptohi.h>
19 #include <hasht.h>
20 #include <keyhi.h>
21 #include <nspr.h>
22 #include <nss.h>
23 #include <pk11pub.h>
24 #include <secerr.h>
25 #include <sechash.h>
26 #include <ssl.h>
27 #include <sslerr.h>
28 #include <sslproto.h>
29 
30 #include <limits>
31 
32 #include "base/callback_helpers.h"
33 #include "base/lazy_instance.h"
34 #include "base/memory/ref_counted.h"
35 #include "crypto/rsa_private_key.h"
36 #include "crypto/nss_util_internal.h"
37 #include "net/base/io_buffer.h"
38 #include "net/base/net_errors.h"
39 #include "net/base/net_log.h"
40 #include "net/socket/nss_ssl_util.h"
41 
// A single SSL record carries at most 16KB of plaintext. The record layer is
// permitted up to 2K + 5 bytes of overhead on top of that, but in practice the
// overhead stays well under 1KB, so a 17KB buffer can hold a whole record.
static const int kRecvBufferSize = 17 * 1024;
static const int kSendBufferSize = 17 * 1024;

// Shorthand used by the handshake state machine to select its next state.
#define GotoState(s) (next_handshake_state_ = (s))
50 
51 namespace net {
52 
53 namespace {
54 
55 bool g_nss_server_sockets_init = false;
56 
57 class NSSSSLServerInitSingleton {
58  public:
NSSSSLServerInitSingleton()59   NSSSSLServerInitSingleton() {
60     EnsureNSSSSLInit();
61 
62     SSL_ConfigServerSessionIDCache(1024, 5, 5, NULL);
63     g_nss_server_sockets_init = true;
64   }
65 
~NSSSSLServerInitSingleton()66   ~NSSSSLServerInitSingleton() {
67     SSL_ShutdownServerSessionIDCache();
68     g_nss_server_sockets_init = false;
69   }
70 };
71 
72 static base::LazyInstance<NSSSSLServerInitSingleton>
73     g_nss_ssl_server_init_singleton = LAZY_INSTANCE_INITIALIZER;
74 
75 }  // namespace
76 
EnableSSLServerSockets()77 void EnableSSLServerSockets() {
78   g_nss_ssl_server_init_singleton.Get();
79 }
80 
CreateSSLServerSocket(scoped_ptr<StreamSocket> socket,X509Certificate * cert,crypto::RSAPrivateKey * key,const SSLConfig & ssl_config)81 scoped_ptr<SSLServerSocket> CreateSSLServerSocket(
82     scoped_ptr<StreamSocket> socket,
83     X509Certificate* cert,
84     crypto::RSAPrivateKey* key,
85     const SSLConfig& ssl_config) {
86   DCHECK(g_nss_server_sockets_init) << "EnableSSLServerSockets() has not been"
87                                     << "called yet!";
88 
89   return scoped_ptr<SSLServerSocket>(
90       new SSLServerSocketNSS(socket.Pass(), cert, key, ssl_config));
91 }
92 
SSLServerSocketNSS(scoped_ptr<StreamSocket> transport_socket,scoped_refptr<X509Certificate> cert,crypto::RSAPrivateKey * key,const SSLConfig & ssl_config)93 SSLServerSocketNSS::SSLServerSocketNSS(
94     scoped_ptr<StreamSocket> transport_socket,
95     scoped_refptr<X509Certificate> cert,
96     crypto::RSAPrivateKey* key,
97     const SSLConfig& ssl_config)
98     : transport_send_busy_(false),
99       transport_recv_busy_(false),
100       user_read_buf_len_(0),
101       user_write_buf_len_(0),
102       nss_fd_(NULL),
103       nss_bufs_(NULL),
104       transport_socket_(transport_socket.Pass()),
105       ssl_config_(ssl_config),
106       cert_(cert),
107       next_handshake_state_(STATE_NONE),
108       completed_handshake_(false) {
109   // TODO(hclam): Need a better way to clone a key.
110   std::vector<uint8> key_bytes;
111   CHECK(key->ExportPrivateKey(&key_bytes));
112   key_.reset(crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_bytes));
113   CHECK(key_.get());
114 }
115 
~SSLServerSocketNSS()116 SSLServerSocketNSS::~SSLServerSocketNSS() {
117   if (nss_fd_ != NULL) {
118     PR_Close(nss_fd_);
119     nss_fd_ = NULL;
120   }
121 }
122 
Handshake(const CompletionCallback & callback)123 int SSLServerSocketNSS::Handshake(const CompletionCallback& callback) {
124   net_log_.BeginEvent(NetLog::TYPE_SSL_SERVER_HANDSHAKE);
125 
126   int rv = Init();
127   if (rv != OK) {
128     LOG(ERROR) << "Failed to initialize NSS";
129     net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_SERVER_HANDSHAKE, rv);
130     return rv;
131   }
132 
133   rv = InitializeSSLOptions();
134   if (rv != OK) {
135     LOG(ERROR) << "Failed to initialize SSL options";
136     net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_SERVER_HANDSHAKE, rv);
137     return rv;
138   }
139 
140   // Set peer address. TODO(hclam): This should be in a separate method.
141   PRNetAddr peername;
142   memset(&peername, 0, sizeof(peername));
143   peername.raw.family = AF_INET;
144   memio_SetPeerName(nss_fd_, &peername);
145 
146   GotoState(STATE_HANDSHAKE);
147   rv = DoHandshakeLoop(OK);
148   if (rv == ERR_IO_PENDING) {
149     user_handshake_callback_ = callback;
150   } else {
151     net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_SERVER_HANDSHAKE, rv);
152   }
153 
154   return rv > OK ? OK : rv;
155 }
156 
ExportKeyingMaterial(const base::StringPiece & label,bool has_context,const base::StringPiece & context,unsigned char * out,unsigned int outlen)157 int SSLServerSocketNSS::ExportKeyingMaterial(const base::StringPiece& label,
158                                              bool has_context,
159                                              const base::StringPiece& context,
160                                              unsigned char* out,
161                                              unsigned int outlen) {
162   if (!IsConnected())
163     return ERR_SOCKET_NOT_CONNECTED;
164   SECStatus result = SSL_ExportKeyingMaterial(
165       nss_fd_, label.data(), label.size(), has_context,
166       reinterpret_cast<const unsigned char*>(context.data()),
167       context.length(), out, outlen);
168   if (result != SECSuccess) {
169     LogFailedNSSFunction(net_log_, "SSL_ExportKeyingMaterial", "");
170     return MapNSSError(PORT_GetError());
171   }
172   return OK;
173 }
174 
GetTLSUniqueChannelBinding(std::string * out)175 int SSLServerSocketNSS::GetTLSUniqueChannelBinding(std::string* out) {
176   if (!IsConnected())
177     return ERR_SOCKET_NOT_CONNECTED;
178   unsigned char buf[64];
179   unsigned int len;
180   SECStatus result = SSL_GetChannelBinding(nss_fd_,
181                                            SSL_CHANNEL_BINDING_TLS_UNIQUE,
182                                            buf, &len, arraysize(buf));
183   if (result != SECSuccess) {
184     LogFailedNSSFunction(net_log_, "SSL_GetChannelBinding", "");
185     return MapNSSError(PORT_GetError());
186   }
187   out->assign(reinterpret_cast<char*>(buf), len);
188   return OK;
189 }
190 
Connect(const CompletionCallback & callback)191 int SSLServerSocketNSS::Connect(const CompletionCallback& callback) {
192   NOTIMPLEMENTED();
193   return ERR_NOT_IMPLEMENTED;
194 }
195 
Read(IOBuffer * buf,int buf_len,const CompletionCallback & callback)196 int SSLServerSocketNSS::Read(IOBuffer* buf, int buf_len,
197                              const CompletionCallback& callback) {
198   DCHECK(user_read_callback_.is_null());
199   DCHECK(user_handshake_callback_.is_null());
200   DCHECK(!user_read_buf_.get());
201   DCHECK(nss_bufs_);
202   DCHECK(!callback.is_null());
203 
204   user_read_buf_ = buf;
205   user_read_buf_len_ = buf_len;
206 
207   DCHECK(completed_handshake_);
208 
209   int rv = DoReadLoop(OK);
210 
211   if (rv == ERR_IO_PENDING) {
212     user_read_callback_ = callback;
213   } else {
214     user_read_buf_ = NULL;
215     user_read_buf_len_ = 0;
216   }
217   return rv;
218 }
219 
Write(IOBuffer * buf,int buf_len,const CompletionCallback & callback)220 int SSLServerSocketNSS::Write(IOBuffer* buf, int buf_len,
221                               const CompletionCallback& callback) {
222   DCHECK(user_write_callback_.is_null());
223   DCHECK(!user_write_buf_.get());
224   DCHECK(nss_bufs_);
225   DCHECK(!callback.is_null());
226 
227   user_write_buf_ = buf;
228   user_write_buf_len_ = buf_len;
229 
230   int rv = DoWriteLoop(OK);
231 
232   if (rv == ERR_IO_PENDING) {
233     user_write_callback_ = callback;
234   } else {
235     user_write_buf_ = NULL;
236     user_write_buf_len_ = 0;
237   }
238   return rv;
239 }
240 
SetReceiveBufferSize(int32 size)241 int SSLServerSocketNSS::SetReceiveBufferSize(int32 size) {
242   return transport_socket_->SetReceiveBufferSize(size);
243 }
244 
SetSendBufferSize(int32 size)245 int SSLServerSocketNSS::SetSendBufferSize(int32 size) {
246   return transport_socket_->SetSendBufferSize(size);
247 }
248 
IsConnected() const249 bool SSLServerSocketNSS::IsConnected() const {
250   // TODO(wtc): Find out if we should check transport_socket_->IsConnected()
251   // as well.
252   return completed_handshake_;
253 }
254 
Disconnect()255 void SSLServerSocketNSS::Disconnect() {
256   transport_socket_->Disconnect();
257 }
258 
IsConnectedAndIdle() const259 bool SSLServerSocketNSS::IsConnectedAndIdle() const {
260   return completed_handshake_ && transport_socket_->IsConnectedAndIdle();
261 }
262 
GetPeerAddress(IPEndPoint * address) const263 int SSLServerSocketNSS::GetPeerAddress(IPEndPoint* address) const {
264   if (!IsConnected())
265     return ERR_SOCKET_NOT_CONNECTED;
266   return transport_socket_->GetPeerAddress(address);
267 }
268 
GetLocalAddress(IPEndPoint * address) const269 int SSLServerSocketNSS::GetLocalAddress(IPEndPoint* address) const {
270   if (!IsConnected())
271     return ERR_SOCKET_NOT_CONNECTED;
272   return transport_socket_->GetLocalAddress(address);
273 }
274 
NetLog() const275 const BoundNetLog& SSLServerSocketNSS::NetLog() const {
276   return net_log_;
277 }
278 
SetSubresourceSpeculation()279 void SSLServerSocketNSS::SetSubresourceSpeculation() {
280   transport_socket_->SetSubresourceSpeculation();
281 }
282 
SetOmniboxSpeculation()283 void SSLServerSocketNSS::SetOmniboxSpeculation() {
284   transport_socket_->SetOmniboxSpeculation();
285 }
286 
WasEverUsed() const287 bool SSLServerSocketNSS::WasEverUsed() const {
288   return transport_socket_->WasEverUsed();
289 }
290 
UsingTCPFastOpen() const291 bool SSLServerSocketNSS::UsingTCPFastOpen() const {
292   return transport_socket_->UsingTCPFastOpen();
293 }
294 
WasNpnNegotiated() const295 bool SSLServerSocketNSS::WasNpnNegotiated() const {
296   NOTIMPLEMENTED();
297   return false;
298 }
299 
GetNegotiatedProtocol() const300 NextProto SSLServerSocketNSS::GetNegotiatedProtocol() const {
301   // NPN is not supported by this class.
302   return kProtoUnknown;
303 }
304 
GetSSLInfo(SSLInfo * ssl_info)305 bool SSLServerSocketNSS::GetSSLInfo(SSLInfo* ssl_info) {
306   NOTIMPLEMENTED();
307   return false;
308 }
309 
InitializeSSLOptions()310 int SSLServerSocketNSS::InitializeSSLOptions() {
311   // Transport connected, now hook it up to nss
312   nss_fd_ = memio_CreateIOLayer(kRecvBufferSize, kSendBufferSize);
313   if (nss_fd_ == NULL) {
314     return ERR_OUT_OF_MEMORY;  // TODO(port): map NSPR error code.
315   }
316 
317   // Grab pointer to buffers
318   nss_bufs_ = memio_GetSecret(nss_fd_);
319 
320   /* Create SSL state machine */
321   /* Push SSL onto our fake I/O socket */
322   nss_fd_ = SSL_ImportFD(NULL, nss_fd_);
323   if (nss_fd_ == NULL) {
324     LogFailedNSSFunction(net_log_, "SSL_ImportFD", "");
325     return ERR_OUT_OF_MEMORY;  // TODO(port): map NSPR/NSS error code.
326   }
327   // TODO(port): set more ssl options!  Check errors!
328 
329   int rv;
330 
331   rv = SSL_OptionSet(nss_fd_, SSL_SECURITY, PR_TRUE);
332   if (rv != SECSuccess) {
333     LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_SECURITY");
334     return ERR_UNEXPECTED;
335   }
336 
337   rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_SSL2, PR_FALSE);
338   if (rv != SECSuccess) {
339     LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_ENABLE_SSL2");
340     return ERR_UNEXPECTED;
341   }
342 
343   SSLVersionRange version_range;
344   version_range.min = ssl_config_.version_min;
345   version_range.max = ssl_config_.version_max;
346   rv = SSL_VersionRangeSet(nss_fd_, &version_range);
347   if (rv != SECSuccess) {
348     LogFailedNSSFunction(net_log_, "SSL_VersionRangeSet", "");
349     return ERR_NO_SSL_VERSIONS_ENABLED;
350   }
351 
352   if (ssl_config_.require_forward_secrecy) {
353     const PRUint16* const ssl_ciphers = SSL_GetImplementedCiphers();
354     const PRUint16 num_ciphers = SSL_GetNumImplementedCiphers();
355 
356     // Require forward security by iterating over the cipher suites and
357     // disabling all those that don't use ECDHE.
358     for (unsigned i = 0; i < num_ciphers; i++) {
359       SSLCipherSuiteInfo info;
360       if (SSL_GetCipherSuiteInfo(ssl_ciphers[i], &info, sizeof(info)) ==
361           SECSuccess) {
362         if (strcmp(info.keaTypeName, "ECDHE") != 0) {
363           SSL_CipherPrefSet(nss_fd_, ssl_ciphers[i], PR_FALSE);
364         }
365       }
366     }
367   }
368 
369   for (std::vector<uint16>::const_iterator it =
370            ssl_config_.disabled_cipher_suites.begin();
371        it != ssl_config_.disabled_cipher_suites.end(); ++it) {
372     // This will fail if the specified cipher is not implemented by NSS, but
373     // the failure is harmless.
374     SSL_CipherPrefSet(nss_fd_, *it, PR_FALSE);
375   }
376 
377   // Server socket doesn't need session tickets.
378   rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_SESSION_TICKETS, PR_FALSE);
379   if (rv != SECSuccess) {
380     LogFailedNSSFunction(
381         net_log_, "SSL_OptionSet", "SSL_ENABLE_SESSION_TICKETS");
382   }
383 
384   // Doing this will force PR_Accept perform handshake as server.
385   rv = SSL_OptionSet(nss_fd_, SSL_HANDSHAKE_AS_CLIENT, PR_FALSE);
386   if (rv != SECSuccess) {
387     LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_HANDSHAKE_AS_CLIENT");
388     return ERR_UNEXPECTED;
389   }
390 
391   rv = SSL_OptionSet(nss_fd_, SSL_HANDSHAKE_AS_SERVER, PR_TRUE);
392   if (rv != SECSuccess) {
393     LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_HANDSHAKE_AS_SERVER");
394     return ERR_UNEXPECTED;
395   }
396 
397   rv = SSL_OptionSet(nss_fd_, SSL_REQUEST_CERTIFICATE, PR_FALSE);
398   if (rv != SECSuccess) {
399     LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_REQUEST_CERTIFICATE");
400     return ERR_UNEXPECTED;
401   }
402 
403   rv = SSL_OptionSet(nss_fd_, SSL_REQUIRE_CERTIFICATE, PR_FALSE);
404   if (rv != SECSuccess) {
405     LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_REQUIRE_CERTIFICATE");
406     return ERR_UNEXPECTED;
407   }
408 
409   rv = SSL_AuthCertificateHook(nss_fd_, OwnAuthCertHandler, this);
410   if (rv != SECSuccess) {
411     LogFailedNSSFunction(net_log_, "SSL_AuthCertificateHook", "");
412     return ERR_UNEXPECTED;
413   }
414 
415   rv = SSL_HandshakeCallback(nss_fd_, HandshakeCallback, this);
416   if (rv != SECSuccess) {
417     LogFailedNSSFunction(net_log_, "SSL_HandshakeCallback", "");
418     return ERR_UNEXPECTED;
419   }
420 
421   // Get a certificate of CERTCertificate structure.
422   std::string der_string;
423   if (!X509Certificate::GetDEREncoded(cert_->os_cert_handle(), &der_string))
424     return ERR_UNEXPECTED;
425 
426   SECItem der_cert;
427   der_cert.data = reinterpret_cast<unsigned char*>(const_cast<char*>(
428       der_string.data()));
429   der_cert.len  = der_string.length();
430   der_cert.type = siDERCertBuffer;
431 
432   // Parse into a CERTCertificate structure.
433   CERTCertificate* cert = CERT_NewTempCertificate(
434       CERT_GetDefaultCertDB(), &der_cert, NULL, PR_FALSE, PR_TRUE);
435   if (!cert) {
436     LogFailedNSSFunction(net_log_, "CERT_NewTempCertificate", "");
437     return MapNSSError(PORT_GetError());
438   }
439 
440   // Get a key of SECKEYPrivateKey* structure.
441   std::vector<uint8> key_vector;
442   if (!key_->ExportPrivateKey(&key_vector)) {
443     CERT_DestroyCertificate(cert);
444     return ERR_UNEXPECTED;
445   }
446 
447   SECKEYPrivateKeyStr* private_key = NULL;
448   PK11SlotInfo* slot = PK11_GetInternalSlot();
449   if (!slot) {
450     CERT_DestroyCertificate(cert);
451     return ERR_UNEXPECTED;
452   }
453 
454   SECItem der_private_key_info;
455   der_private_key_info.data =
456       const_cast<unsigned char*>(&key_vector.front());
457   der_private_key_info.len = key_vector.size();
458   // The server's RSA private key must be imported into NSS with the
459   // following key usage bits:
460   // - KU_KEY_ENCIPHERMENT, required for the RSA key exchange algorithm.
461   // - KU_DIGITAL_SIGNATURE, required for the DHE_RSA and ECDHE_RSA key
462   //   exchange algorithms.
463   const unsigned int key_usage = KU_KEY_ENCIPHERMENT | KU_DIGITAL_SIGNATURE;
464   rv =  PK11_ImportDERPrivateKeyInfoAndReturnKey(
465       slot, &der_private_key_info, NULL, NULL, PR_FALSE, PR_FALSE,
466       key_usage, &private_key, NULL);
467   PK11_FreeSlot(slot);
468   if (rv != SECSuccess) {
469     CERT_DestroyCertificate(cert);
470     return ERR_UNEXPECTED;
471   }
472 
473   // Assign server certificate and private key.
474   SSLKEAType cert_kea = NSS_FindCertKEAType(cert);
475   rv = SSL_ConfigSecureServer(nss_fd_, cert, private_key, cert_kea);
476   CERT_DestroyCertificate(cert);
477   SECKEY_DestroyPrivateKey(private_key);
478 
479   if (rv != SECSuccess) {
480     PRErrorCode prerr = PR_GetError();
481     LOG(ERROR) << "Failed to config SSL server: " << prerr;
482     LogFailedNSSFunction(net_log_, "SSL_ConfigureSecureServer", "");
483     return ERR_UNEXPECTED;
484   }
485 
486   // Tell SSL we're a server; needed if not letting NSPR do socket I/O
487   rv = SSL_ResetHandshake(nss_fd_, PR_TRUE);
488   if (rv != SECSuccess) {
489     LogFailedNSSFunction(net_log_, "SSL_ResetHandshake", "");
490     return ERR_UNEXPECTED;
491   }
492 
493   return OK;
494 }
495 
OnSendComplete(int result)496 void SSLServerSocketNSS::OnSendComplete(int result) {
497   if (next_handshake_state_ == STATE_HANDSHAKE) {
498     // In handshake phase.
499     OnHandshakeIOComplete(result);
500     return;
501   }
502 
503   // TODO(byungchul): This state machine is not correct. Copy the state machine
504   // of SSLClientSocketNSS::OnSendComplete() which handles it better.
505   if (!completed_handshake_)
506     return;
507 
508   if (user_write_buf_.get()) {
509     int rv = DoWriteLoop(result);
510     if (rv != ERR_IO_PENDING)
511       DoWriteCallback(rv);
512   } else {
513     // Ensure that any queued ciphertext is flushed.
514     DoTransportIO();
515   }
516 }
517 
OnRecvComplete(int result)518 void SSLServerSocketNSS::OnRecvComplete(int result) {
519   if (next_handshake_state_ == STATE_HANDSHAKE) {
520     // In handshake phase.
521     OnHandshakeIOComplete(result);
522     return;
523   }
524 
525   // Network layer received some data, check if client requested to read
526   // decrypted data.
527   if (!user_read_buf_.get() || !completed_handshake_)
528     return;
529 
530   int rv = DoReadLoop(result);
531   if (rv != ERR_IO_PENDING)
532     DoReadCallback(rv);
533 }
534 
OnHandshakeIOComplete(int result)535 void SSLServerSocketNSS::OnHandshakeIOComplete(int result) {
536   int rv = DoHandshakeLoop(result);
537   if (rv == ERR_IO_PENDING)
538     return;
539 
540   net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_SERVER_HANDSHAKE, rv);
541   if (!user_handshake_callback_.is_null())
542     DoHandshakeCallback(rv);
543 }
544 
545 // Return 0 for EOF,
546 // > 0 for bytes transferred immediately,
547 // < 0 for error (or the non-error ERR_IO_PENDING).
BufferSend(void)548 int SSLServerSocketNSS::BufferSend(void) {
549   if (transport_send_busy_)
550     return ERR_IO_PENDING;
551 
552   const char* buf1;
553   const char* buf2;
554   unsigned int len1, len2;
555   if (memio_GetWriteParams(nss_bufs_, &buf1, &len1, &buf2, &len2)) {
556     // The error code itself is ignored, so just return ERR_ABORTED.
557     return ERR_ABORTED;
558   }
559   const unsigned int len = len1 + len2;
560 
561   int rv = 0;
562   if (len) {
563     scoped_refptr<IOBuffer> send_buffer(new IOBuffer(len));
564     memcpy(send_buffer->data(), buf1, len1);
565     memcpy(send_buffer->data() + len1, buf2, len2);
566     rv = transport_socket_->Write(
567         send_buffer.get(),
568         len,
569         base::Bind(&SSLServerSocketNSS::BufferSendComplete,
570                    base::Unretained(this)));
571     if (rv == ERR_IO_PENDING) {
572       transport_send_busy_ = true;
573     } else {
574       memio_PutWriteResult(nss_bufs_, MapErrorToNSS(rv));
575     }
576   }
577 
578   return rv;
579 }
580 
BufferSendComplete(int result)581 void SSLServerSocketNSS::BufferSendComplete(int result) {
582   memio_PutWriteResult(nss_bufs_, MapErrorToNSS(result));
583   transport_send_busy_ = false;
584   OnSendComplete(result);
585 }
586 
BufferRecv(void)587 int SSLServerSocketNSS::BufferRecv(void) {
588   if (transport_recv_busy_) return ERR_IO_PENDING;
589 
590   char* buf;
591   int nb = memio_GetReadParams(nss_bufs_, &buf);
592   int rv;
593   if (!nb) {
594     // buffer too full to read into, so no I/O possible at moment
595     rv = ERR_IO_PENDING;
596   } else {
597     recv_buffer_ = new IOBuffer(nb);
598     rv = transport_socket_->Read(
599         recv_buffer_.get(),
600         nb,
601         base::Bind(&SSLServerSocketNSS::BufferRecvComplete,
602                    base::Unretained(this)));
603     if (rv == ERR_IO_PENDING) {
604       transport_recv_busy_ = true;
605     } else {
606       if (rv > 0)
607         memcpy(buf, recv_buffer_->data(), rv);
608       memio_PutReadResult(nss_bufs_, MapErrorToNSS(rv));
609       recv_buffer_ = NULL;
610     }
611   }
612   return rv;
613 }
614 
BufferRecvComplete(int result)615 void SSLServerSocketNSS::BufferRecvComplete(int result) {
616   if (result > 0) {
617     char* buf;
618     memio_GetReadParams(nss_bufs_, &buf);
619     memcpy(buf, recv_buffer_->data(), result);
620   }
621   recv_buffer_ = NULL;
622   memio_PutReadResult(nss_bufs_, MapErrorToNSS(result));
623   transport_recv_busy_ = false;
624   OnRecvComplete(result);
625 }
626 
627 // Do as much network I/O as possible between the buffer and the
628 // transport socket. Return true if some I/O performed, false
629 // otherwise (error or ERR_IO_PENDING).
DoTransportIO()630 bool SSLServerSocketNSS::DoTransportIO() {
631   bool network_moved = false;
632   if (nss_bufs_ != NULL) {
633     int rv;
634     // Read and write as much data as we can. The loop is neccessary
635     // because Write() may return synchronously.
636     do {
637       rv = BufferSend();
638       if (rv > 0)
639         network_moved = true;
640     } while (rv > 0);
641     if (BufferRecv() >= 0)
642       network_moved = true;
643   }
644   return network_moved;
645 }
646 
DoPayloadRead()647 int SSLServerSocketNSS::DoPayloadRead() {
648   DCHECK(user_read_buf_.get());
649   DCHECK_GT(user_read_buf_len_, 0);
650   int rv = PR_Read(nss_fd_, user_read_buf_->data(), user_read_buf_len_);
651   if (rv >= 0)
652     return rv;
653   PRErrorCode prerr = PR_GetError();
654   if (prerr == PR_WOULD_BLOCK_ERROR) {
655     return ERR_IO_PENDING;
656   }
657   rv = MapNSSError(prerr);
658   net_log_.AddEvent(NetLog::TYPE_SSL_READ_ERROR,
659                     CreateNetLogSSLErrorCallback(rv, prerr));
660   return rv;
661 }
662 
DoPayloadWrite()663 int SSLServerSocketNSS::DoPayloadWrite() {
664   DCHECK(user_write_buf_.get());
665   int rv = PR_Write(nss_fd_, user_write_buf_->data(), user_write_buf_len_);
666   if (rv >= 0)
667     return rv;
668   PRErrorCode prerr = PR_GetError();
669   if (prerr == PR_WOULD_BLOCK_ERROR) {
670     return ERR_IO_PENDING;
671   }
672   rv = MapNSSError(prerr);
673   net_log_.AddEvent(NetLog::TYPE_SSL_WRITE_ERROR,
674                     CreateNetLogSSLErrorCallback(rv, prerr));
675   return rv;
676 }
677 
DoHandshakeLoop(int last_io_result)678 int SSLServerSocketNSS::DoHandshakeLoop(int last_io_result) {
679   int rv = last_io_result;
680   do {
681     // Default to STATE_NONE for next state.
682     // (This is a quirk carried over from the windows
683     // implementation.  It makes reading the logs a bit harder.)
684     // State handlers can and often do call GotoState just
685     // to stay in the current state.
686     State state = next_handshake_state_;
687     GotoState(STATE_NONE);
688     switch (state) {
689       case STATE_HANDSHAKE:
690         rv = DoHandshake();
691         break;
692       case STATE_NONE:
693       default:
694         rv = ERR_UNEXPECTED;
695         LOG(DFATAL) << "unexpected state " << state;
696         break;
697     }
698 
699     // Do the actual network I/O
700     bool network_moved = DoTransportIO();
701     if (network_moved && next_handshake_state_ == STATE_HANDSHAKE) {
702       // In general we exit the loop if rv is ERR_IO_PENDING.  In this
703       // special case we keep looping even if rv is ERR_IO_PENDING because
704       // the transport IO may allow DoHandshake to make progress.
705       rv = OK;  // This causes us to stay in the loop.
706     }
707   } while (rv != ERR_IO_PENDING && next_handshake_state_ != STATE_NONE);
708   return rv;
709 }
710 
DoReadLoop(int result)711 int SSLServerSocketNSS::DoReadLoop(int result) {
712   DCHECK(completed_handshake_);
713   DCHECK(next_handshake_state_ == STATE_NONE);
714 
715   if (result < 0)
716     return result;
717 
718   if (!nss_bufs_) {
719     LOG(DFATAL) << "!nss_bufs_";
720     int rv = ERR_UNEXPECTED;
721     net_log_.AddEvent(NetLog::TYPE_SSL_READ_ERROR,
722                       CreateNetLogSSLErrorCallback(rv, 0));
723     return rv;
724   }
725 
726   bool network_moved;
727   int rv;
728   do {
729     rv = DoPayloadRead();
730     network_moved = DoTransportIO();
731   } while (rv == ERR_IO_PENDING && network_moved);
732   return rv;
733 }
734 
DoWriteLoop(int result)735 int SSLServerSocketNSS::DoWriteLoop(int result) {
736   DCHECK(completed_handshake_);
737   DCHECK_EQ(next_handshake_state_, STATE_NONE);
738 
739   if (result < 0)
740     return result;
741 
742   if (!nss_bufs_) {
743     LOG(DFATAL) << "!nss_bufs_";
744     int rv = ERR_UNEXPECTED;
745     net_log_.AddEvent(NetLog::TYPE_SSL_WRITE_ERROR,
746                       CreateNetLogSSLErrorCallback(rv, 0));
747     return rv;
748   }
749 
750   bool network_moved;
751   int rv;
752   do {
753     rv = DoPayloadWrite();
754     network_moved = DoTransportIO();
755   } while (rv == ERR_IO_PENDING && network_moved);
756   return rv;
757 }
758 
DoHandshake()759 int SSLServerSocketNSS::DoHandshake() {
760   int net_error = OK;
761   SECStatus rv = SSL_ForceHandshake(nss_fd_);
762 
763   if (rv == SECSuccess) {
764     completed_handshake_ = true;
765   } else {
766     PRErrorCode prerr = PR_GetError();
767     net_error = MapNSSError(prerr);
768 
769     // If not done, stay in this state
770     if (net_error == ERR_IO_PENDING) {
771       GotoState(STATE_HANDSHAKE);
772     } else {
773       LOG(ERROR) << "handshake failed; NSS error code " << prerr
774                  << ", net_error " << net_error;
775       net_log_.AddEvent(NetLog::TYPE_SSL_HANDSHAKE_ERROR,
776                         CreateNetLogSSLErrorCallback(net_error, prerr));
777     }
778   }
779   return net_error;
780 }
781 
DoHandshakeCallback(int rv)782 void SSLServerSocketNSS::DoHandshakeCallback(int rv) {
783   DCHECK_NE(rv, ERR_IO_PENDING);
784   ResetAndReturn(&user_handshake_callback_).Run(rv > OK ? OK : rv);
785 }
786 
DoReadCallback(int rv)787 void SSLServerSocketNSS::DoReadCallback(int rv) {
788   DCHECK(rv != ERR_IO_PENDING);
789   DCHECK(!user_read_callback_.is_null());
790 
791   user_read_buf_ = NULL;
792   user_read_buf_len_ = 0;
793   ResetAndReturn(&user_read_callback_).Run(rv);
794 }
795 
DoWriteCallback(int rv)796 void SSLServerSocketNSS::DoWriteCallback(int rv) {
797   DCHECK(rv != ERR_IO_PENDING);
798   DCHECK(!user_write_callback_.is_null());
799 
800   user_write_buf_ = NULL;
801   user_write_buf_len_ = 0;
802   ResetAndReturn(&user_write_callback_).Run(rv);
803 }
804 
805 // static
806 // NSS calls this if an incoming certificate needs to be verified.
807 // Do nothing but return SECSuccess.
808 // This is called only in full handshake mode.
809 // Peer certificate is retrieved in HandshakeCallback() later, which is called
810 // in full handshake mode or in resumption handshake mode.
OwnAuthCertHandler(void * arg,PRFileDesc * socket,PRBool checksig,PRBool is_server)811 SECStatus SSLServerSocketNSS::OwnAuthCertHandler(void* arg,
812                                                  PRFileDesc* socket,
813                                                  PRBool checksig,
814                                                  PRBool is_server) {
815   // TODO(hclam): Implement.
816   // Tell NSS to not verify the certificate.
817   return SECSuccess;
818 }
819 
820 // static
821 // NSS calls this when handshake is completed.
822 // After the SSL handshake is finished we need to verify the certificate.
HandshakeCallback(PRFileDesc * socket,void * arg)823 void SSLServerSocketNSS::HandshakeCallback(PRFileDesc* socket,
824                                            void* arg) {
825   // TODO(hclam): Implement.
826 }
827 
Init()828 int SSLServerSocketNSS::Init() {
829   // Initialize the NSS SSL library in a threadsafe way.  This also
830   // initializes the NSS base library.
831   EnsureNSSSSLInit();
832   if (!NSS_IsInitialized())
833     return ERR_UNEXPECTED;
834 
835   EnableSSLServerSockets();
836   return OK;
837 }
838 
839 }  // namespace net
840