• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 // This test suite uses SSLClientSocket to test the implementation of
6 // SSLServerSocket. In order to establish connections between the sockets
7 // we need two additional classes:
8 // 1. FakeSocket
9 //    Connects SSL socket to FakeDataChannel. This class is just a stub.
10 //
11 // 2. FakeDataChannel
12 //    Implements the actual exchange of data between two FakeSockets.
13 //
14 // Implementations of these two classes are included in this file.
15 
16 #include "net/socket/ssl_server_socket.h"
17 
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "base/check.h"
#include "base/compiler_specific.h"
#include "base/containers/queue.h"
#include "base/files/file_path.h"
#include "base/files/file_util.h"
#include "base/functional/bind.h"
#include "base/functional/callback_helpers.h"
#include "base/location.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h"
#include "base/notreached.h"
#include "base/run_loop.h"
#include "base/strings/string_piece.h"
#include "base/task/single_thread_task_runner.h"
#include "base/test/task_environment.h"
#include "build/build_config.h"
#include "crypto/rsa_private_key.h"
#include "net/base/address_list.h"
#include "net/base/completion_once_callback.h"
#include "net/base/host_port_pair.h"
#include "net/base/io_buffer.h"
#include "net/base/ip_address.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
#include "net/base/test_completion_callback.h"
#include "net/cert/cert_status_flags.h"
#include "net/cert/ct_policy_enforcer.h"
#include "net/cert/ct_policy_status.h"
#include "net/cert/mock_cert_verifier.h"
#include "net/cert/mock_client_cert_verifier.h"
#include "net/cert/signed_certificate_timestamp_and_status.h"
#include "net/cert/x509_certificate.h"
#include "net/http/transport_security_state.h"
#include "net/log/net_log_with_source.h"
#include "net/socket/client_socket_factory.h"
#include "net/socket/socket_test_util.h"
#include "net/socket/ssl_client_socket.h"
#include "net/socket/stream_socket.h"
#include "net/ssl/ssl_cert_request_info.h"
#include "net/ssl/ssl_cipher_suite_names.h"
#include "net/ssl/ssl_client_session_cache.h"
#include "net/ssl/ssl_connection_status_flags.h"
#include "net/ssl/ssl_info.h"
#include "net/ssl/ssl_private_key.h"
#include "net/ssl/ssl_server_config.h"
#include "net/ssl/test_ssl_config_service.h"
#include "net/ssl/test_ssl_private_key.h"
#include "net/test/cert_test_util.h"
#include "net/test/gtest_util.h"
#include "net/test/test_data_directory.h"
#include "net/test/test_with_task_environment.h"
#include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "testing/platform_test.h"
#include "third_party/boringssl/src/include/openssl/evp.h"
#include "third_party/boringssl/src/include/openssl/ssl.h"
79 
80 using net::test::IsError;
81 using net::test::IsOk;
82 
83 namespace net {
84 
85 namespace {
86 
87 // Client certificates are disabled on iOS.
88 #if !BUILDFLAG(IS_IOS)
89 const char kClientCertFileName[] = "client_1.pem";
90 const char kClientPrivateKeyFileName[] = "client_1.pk8";
91 const char kWrongClientCertFileName[] = "client_2.pem";
92 const char kWrongClientPrivateKeyFileName[] = "client_2.pk8";
93 #endif  // !IS_IOS
94 
// TLS cipher suite code points for the ECDHE key-exchange suites, covering
// both ECDSA and RSA authentication and the RC4, AES-CBC, AES-GCM and
// ChaCha20-Poly1305 bulk ciphers. Order is kept as originally listed.
constexpr uint16_t kEcdheCiphers[] = {
    0xc007,  // ECDHE_ECDSA_WITH_RC4_128_SHA
    0xc009,  // ECDHE_ECDSA_WITH_AES_128_CBC_SHA
    0xc00a,  // ECDHE_ECDSA_WITH_AES_256_CBC_SHA
    0xc011,  // ECDHE_RSA_WITH_RC4_128_SHA
    0xc013,  // ECDHE_RSA_WITH_AES_128_CBC_SHA
    0xc014,  // ECDHE_RSA_WITH_AES_256_CBC_SHA
    0xc02b,  // ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
    0xc02c,  // ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
    0xc02f,  // ECDHE_RSA_WITH_AES_128_GCM_SHA256
    0xc030,  // ECDHE_RSA_WITH_AES_256_GCM_SHA384
    0xcca8,  // ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
    0xcca9,  // ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
};
109 
110 class MockCTPolicyEnforcer : public CTPolicyEnforcer {
111  public:
112   MockCTPolicyEnforcer() = default;
113   ~MockCTPolicyEnforcer() override = default;
CheckCompliance(X509Certificate * cert,const ct::SCTList & verified_scts,const NetLogWithSource & net_log)114   ct::CTPolicyCompliance CheckCompliance(
115       X509Certificate* cert,
116       const ct::SCTList& verified_scts,
117       const NetLogWithSource& net_log) override {
118     return ct::CTPolicyCompliance::CT_POLICY_COMPLIES_VIA_SCTS;
119   }
120 };
121 
122 class FakeDataChannel {
123  public:
124   FakeDataChannel() = default;
125 
126   FakeDataChannel(const FakeDataChannel&) = delete;
127   FakeDataChannel& operator=(const FakeDataChannel&) = delete;
128 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)129   int Read(IOBuffer* buf, int buf_len, CompletionOnceCallback callback) {
130     DCHECK(read_callback_.is_null());
131     DCHECK(!read_buf_.get());
132     if (closed_)
133       return 0;
134     if (data_.empty()) {
135       read_callback_ = std::move(callback);
136       read_buf_ = buf;
137       read_buf_len_ = buf_len;
138       return ERR_IO_PENDING;
139     }
140     return PropagateData(buf, buf_len);
141   }
142 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)143   int Write(IOBuffer* buf,
144             int buf_len,
145             CompletionOnceCallback callback,
146             const NetworkTrafficAnnotationTag& traffic_annotation) {
147     DCHECK(write_callback_.is_null());
148     if (closed_) {
149       if (write_called_after_close_)
150         return ERR_CONNECTION_RESET;
151       write_called_after_close_ = true;
152       write_callback_ = std::move(callback);
153       base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
154           FROM_HERE, base::BindOnce(&FakeDataChannel::DoWriteCallback,
155                                     weak_factory_.GetWeakPtr()));
156       return ERR_IO_PENDING;
157     }
158     // This function returns synchronously, so make a copy of the buffer.
159     data_.push(base::MakeRefCounted<DrainableIOBuffer>(
160         base::MakeRefCounted<StringIOBuffer>(std::string(buf->data(), buf_len)),
161         buf_len));
162     base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
163         FROM_HERE, base::BindOnce(&FakeDataChannel::DoReadCallback,
164                                   weak_factory_.GetWeakPtr()));
165     return buf_len;
166   }
167 
168   // Closes the FakeDataChannel. After Close() is called, Read() returns 0,
169   // indicating EOF, and Write() fails with ERR_CONNECTION_RESET. Note that
170   // after the FakeDataChannel is closed, the first Write() call completes
171   // asynchronously, which is necessary to reproduce bug 127822.
Close()172   void Close() {
173     closed_ = true;
174     if (!read_callback_.is_null()) {
175       base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
176           FROM_HERE, base::BindOnce(&FakeDataChannel::DoReadCallback,
177                                     weak_factory_.GetWeakPtr()));
178     }
179   }
180 
181  private:
DoReadCallback()182   void DoReadCallback() {
183     if (read_callback_.is_null())
184       return;
185 
186     if (closed_) {
187       std::move(read_callback_).Run(ERR_CONNECTION_CLOSED);
188       return;
189     }
190 
191     if (data_.empty())
192       return;
193 
194     int copied = PropagateData(read_buf_, read_buf_len_);
195     read_buf_ = nullptr;
196     read_buf_len_ = 0;
197     std::move(read_callback_).Run(copied);
198   }
199 
DoWriteCallback()200   void DoWriteCallback() {
201     if (write_callback_.is_null())
202       return;
203 
204     std::move(write_callback_).Run(ERR_CONNECTION_RESET);
205   }
206 
PropagateData(scoped_refptr<IOBuffer> read_buf,int read_buf_len)207   int PropagateData(scoped_refptr<IOBuffer> read_buf, int read_buf_len) {
208     scoped_refptr<DrainableIOBuffer> buf = data_.front();
209     int copied = std::min(buf->BytesRemaining(), read_buf_len);
210     memcpy(read_buf->data(), buf->data(), copied);
211     buf->DidConsume(copied);
212 
213     if (!buf->BytesRemaining())
214       data_.pop();
215     return copied;
216   }
217 
218   CompletionOnceCallback read_callback_;
219   scoped_refptr<IOBuffer> read_buf_;
220   int read_buf_len_ = 0;
221 
222   CompletionOnceCallback write_callback_;
223 
224   base::queue<scoped_refptr<DrainableIOBuffer>> data_;
225 
226   // True if Close() has been called.
227   bool closed_ = false;
228 
229   // Controls the completion of Write() after the FakeDataChannel is closed.
230   // After the FakeDataChannel is closed, the first Write() call completes
231   // asynchronously.
232   bool write_called_after_close_ = false;
233 
234   base::WeakPtrFactory<FakeDataChannel> weak_factory_{this};
235 };
236 
237 class FakeSocket : public StreamSocket {
238  public:
FakeSocket(FakeDataChannel * incoming_channel,FakeDataChannel * outgoing_channel)239   FakeSocket(FakeDataChannel* incoming_channel,
240              FakeDataChannel* outgoing_channel)
241       : incoming_(incoming_channel), outgoing_(outgoing_channel) {}
242 
243   FakeSocket(const FakeSocket&) = delete;
244   FakeSocket& operator=(const FakeSocket&) = delete;
245 
246   ~FakeSocket() override = default;
247 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)248   int Read(IOBuffer* buf,
249            int buf_len,
250            CompletionOnceCallback callback) override {
251     // Read random number of bytes.
252     buf_len = rand() % buf_len + 1;
253     return incoming_->Read(buf, buf_len, std::move(callback));
254   }
255 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)256   int Write(IOBuffer* buf,
257             int buf_len,
258             CompletionOnceCallback callback,
259             const NetworkTrafficAnnotationTag& traffic_annotation) override {
260     // Write random number of bytes.
261     buf_len = rand() % buf_len + 1;
262     return outgoing_->Write(buf, buf_len, std::move(callback),
263                             TRAFFIC_ANNOTATION_FOR_TESTS);
264   }
265 
SetReceiveBufferSize(int32_t size)266   int SetReceiveBufferSize(int32_t size) override { return OK; }
267 
SetSendBufferSize(int32_t size)268   int SetSendBufferSize(int32_t size) override { return OK; }
269 
Connect(CompletionOnceCallback callback)270   int Connect(CompletionOnceCallback callback) override { return OK; }
271 
Disconnect()272   void Disconnect() override {
273     incoming_->Close();
274     outgoing_->Close();
275   }
276 
IsConnected() const277   bool IsConnected() const override { return true; }
278 
IsConnectedAndIdle() const279   bool IsConnectedAndIdle() const override { return true; }
280 
GetPeerAddress(IPEndPoint * address) const281   int GetPeerAddress(IPEndPoint* address) const override {
282     *address = IPEndPoint(IPAddress::IPv4AllZeros(), 0 /*port*/);
283     return OK;
284   }
285 
GetLocalAddress(IPEndPoint * address) const286   int GetLocalAddress(IPEndPoint* address) const override {
287     *address = IPEndPoint(IPAddress::IPv4AllZeros(), 0 /*port*/);
288     return OK;
289   }
290 
NetLog() const291   const NetLogWithSource& NetLog() const override { return net_log_; }
292 
WasEverUsed() const293   bool WasEverUsed() const override { return true; }
294 
WasAlpnNegotiated() const295   bool WasAlpnNegotiated() const override { return false; }
296 
GetNegotiatedProtocol() const297   NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
298 
GetSSLInfo(SSLInfo * ssl_info)299   bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
300 
GetTotalReceivedBytes() const301   int64_t GetTotalReceivedBytes() const override {
302     NOTIMPLEMENTED();
303     return 0;
304   }
305 
ApplySocketTag(const SocketTag & tag)306   void ApplySocketTag(const SocketTag& tag) override {}
307 
308  private:
309   NetLogWithSource net_log_;
310   raw_ptr<FakeDataChannel> incoming_;
311   raw_ptr<FakeDataChannel> outgoing_;
312 };
313 
314 }  // namespace
315 
316 // Verify the correctness of the test helper classes first.
TEST(FakeSocketTest,DataTransfer)317 TEST(FakeSocketTest, DataTransfer) {
318   base::test::TaskEnvironment task_environment;
319 
320   // Establish channels between two sockets.
321   FakeDataChannel channel_1;
322   FakeDataChannel channel_2;
323   FakeSocket client(&channel_1, &channel_2);
324   FakeSocket server(&channel_2, &channel_1);
325 
326   const char kTestData[] = "testing123";
327   const int kTestDataSize = strlen(kTestData);
328   const int kReadBufSize = 1024;
329   scoped_refptr<IOBuffer> write_buf =
330       base::MakeRefCounted<StringIOBuffer>(kTestData);
331   scoped_refptr<IOBuffer> read_buf =
332       base::MakeRefCounted<IOBuffer>(kReadBufSize);
333 
334   // Write then read.
335   int written =
336       server.Write(write_buf.get(), kTestDataSize, CompletionOnceCallback(),
337                    TRAFFIC_ANNOTATION_FOR_TESTS);
338   EXPECT_GT(written, 0);
339   EXPECT_LE(written, kTestDataSize);
340 
341   int read =
342       client.Read(read_buf.get(), kReadBufSize, CompletionOnceCallback());
343   EXPECT_GT(read, 0);
344   EXPECT_LE(read, written);
345   EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read));
346 
347   // Read then write.
348   TestCompletionCallback callback;
349   EXPECT_EQ(ERR_IO_PENDING,
350             server.Read(read_buf.get(), kReadBufSize, callback.callback()));
351 
352   written =
353       client.Write(write_buf.get(), kTestDataSize, CompletionOnceCallback(),
354                    TRAFFIC_ANNOTATION_FOR_TESTS);
355   EXPECT_GT(written, 0);
356   EXPECT_LE(written, kTestDataSize);
357 
358   read = callback.WaitForResult();
359   EXPECT_GT(read, 0);
360   EXPECT_LE(read, written);
361   EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read));
362 }
363 
364 class SSLServerSocketTest : public PlatformTest, public WithTaskEnvironment {
365  public:
SSLServerSocketTest()366   SSLServerSocketTest()
367       : ssl_config_service_(
368             std::make_unique<TestSSLConfigService>(SSLContextConfig())),
369         cert_verifier_(std::make_unique<MockCertVerifier>()),
370         client_cert_verifier_(std::make_unique<MockClientCertVerifier>()),
371         transport_security_state_(std::make_unique<TransportSecurityState>()),
372         ct_policy_enforcer_(std::make_unique<MockCTPolicyEnforcer>()),
373         ssl_client_session_cache_(std::make_unique<SSLClientSessionCache>(
374             SSLClientSessionCache::Config())) {}
375 
SetUp()376   void SetUp() override {
377     PlatformTest::SetUp();
378 
379     cert_verifier_->set_default_result(ERR_CERT_AUTHORITY_INVALID);
380     client_cert_verifier_->set_default_result(ERR_CERT_AUTHORITY_INVALID);
381 
382     server_cert_ =
383         ImportCertFromFile(GetTestCertsDirectory(), "unittest.selfsigned.der");
384     ASSERT_TRUE(server_cert_);
385     server_private_key_ = ReadTestKey("unittest.key.bin");
386     ASSERT_TRUE(server_private_key_);
387 
388     std::unique_ptr<crypto::RSAPrivateKey> key =
389         ReadTestKey("unittest.key.bin");
390     ASSERT_TRUE(key);
391     server_ssl_private_key_ = WrapOpenSSLPrivateKey(bssl::UpRef(key->key()));
392 
393     // Certificate provided by the host doesn't need authority.
394     client_ssl_config_.allowed_bad_certs.emplace_back(
395         server_cert_, CERT_STATUS_AUTHORITY_INVALID);
396 
397     client_context_ = std::make_unique<SSLClientContext>(
398         ssl_config_service_.get(), cert_verifier_.get(),
399         transport_security_state_.get(), ct_policy_enforcer_.get(),
400         ssl_client_session_cache_.get(), nullptr);
401   }
402 
403  protected:
CreateContext()404   void CreateContext() {
405     client_socket_.reset();
406     server_socket_.reset();
407     channel_1_.reset();
408     channel_2_.reset();
409     server_context_ = CreateSSLServerContext(
410         server_cert_.get(), *server_private_key_, server_ssl_config_);
411   }
412 
CreateContextSSLPrivateKey()413   void CreateContextSSLPrivateKey() {
414     client_socket_.reset();
415     server_socket_.reset();
416     channel_1_.reset();
417     channel_2_.reset();
418     server_context_.reset();
419     server_context_ = CreateSSLServerContext(
420         server_cert_.get(), server_ssl_private_key_, server_ssl_config_);
421   }
422 
GetHostAndPort()423   static HostPortPair GetHostAndPort() { return HostPortPair("unittest", 0); }
424 
CreateSockets()425   void CreateSockets() {
426     client_socket_.reset();
427     server_socket_.reset();
428     channel_1_ = std::make_unique<FakeDataChannel>();
429     channel_2_ = std::make_unique<FakeDataChannel>();
430     std::unique_ptr<StreamSocket> client_connection =
431         std::make_unique<FakeSocket>(channel_1_.get(), channel_2_.get());
432     std::unique_ptr<StreamSocket> server_socket =
433         std::make_unique<FakeSocket>(channel_2_.get(), channel_1_.get());
434 
435     client_socket_ = client_context_->CreateSSLClientSocket(
436         std::move(client_connection), GetHostAndPort(), client_ssl_config_);
437     ASSERT_TRUE(client_socket_);
438 
439     server_socket_ =
440         server_context_->CreateSSLServerSocket(std::move(server_socket));
441     ASSERT_TRUE(server_socket_);
442   }
443 
444 // Client certificates are disabled on iOS.
445 #if !BUILDFLAG(IS_IOS)
ConfigureClientCertsForClient(const char * cert_file_name,const char * private_key_file_name)446   void ConfigureClientCertsForClient(const char* cert_file_name,
447                                      const char* private_key_file_name) {
448     scoped_refptr<X509Certificate> client_cert =
449         ImportCertFromFile(GetTestCertsDirectory(), cert_file_name);
450     ASSERT_TRUE(client_cert);
451 
452     std::unique_ptr<crypto::RSAPrivateKey> key =
453         ReadTestKey(private_key_file_name);
454     ASSERT_TRUE(key);
455 
456     client_context_->SetClientCertificate(
457         GetHostAndPort(), std::move(client_cert),
458         WrapOpenSSLPrivateKey(bssl::UpRef(key->key())));
459   }
460 
ConfigureClientCertsForServer()461   void ConfigureClientCertsForServer() {
462     server_ssl_config_.client_cert_type =
463         SSLServerConfig::ClientCertType::REQUIRE_CLIENT_CERT;
464 
465     // "CN=B CA" - DER encoded DN of the issuer of client_1.pem
466     static const uint8_t kClientCertCAName[] = {
467         0x30, 0x0f, 0x31, 0x0d, 0x30, 0x0b, 0x06, 0x03, 0x55,
468         0x04, 0x03, 0x0c, 0x04, 0x42, 0x20, 0x43, 0x41};
469     server_ssl_config_.cert_authorities.emplace_back(
470         std::begin(kClientCertCAName), std::end(kClientCertCAName));
471 
472     scoped_refptr<X509Certificate> expected_client_cert(
473         ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName));
474     ASSERT_TRUE(expected_client_cert);
475 
476     client_cert_verifier_->AddResultForCert(expected_client_cert.get(), OK);
477 
478     server_ssl_config_.client_cert_verifier = client_cert_verifier_.get();
479   }
480 #endif  // !IS_IOS
481 
ReadTestKey(base::StringPiece name)482   std::unique_ptr<crypto::RSAPrivateKey> ReadTestKey(base::StringPiece name) {
483     base::FilePath certs_dir(GetTestCertsDirectory());
484     base::FilePath key_path = certs_dir.AppendASCII(name);
485     std::string key_string;
486     if (!base::ReadFileToString(key_path, &key_string))
487       return nullptr;
488     std::vector<uint8_t> key_vector(
489         reinterpret_cast<const uint8_t*>(key_string.data()),
490         reinterpret_cast<const uint8_t*>(key_string.data() +
491                                          key_string.length()));
492     std::unique_ptr<crypto::RSAPrivateKey> key(
493         crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
494     return key;
495   }
496 
PumpServerToClient()497   void PumpServerToClient() {
498     const int kReadBufSize = 1024;
499     scoped_refptr<StringIOBuffer> write_buf =
500         base::MakeRefCounted<StringIOBuffer>("testing123");
501     scoped_refptr<DrainableIOBuffer> read_buf =
502         base::MakeRefCounted<DrainableIOBuffer>(
503             base::MakeRefCounted<IOBuffer>(kReadBufSize), kReadBufSize);
504     TestCompletionCallback write_callback;
505     TestCompletionCallback read_callback;
506     int server_ret = server_socket_->Write(write_buf.get(), write_buf->size(),
507                                            write_callback.callback(),
508                                            TRAFFIC_ANNOTATION_FOR_TESTS);
509     EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
510     int client_ret = client_socket_->Read(
511         read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
512     EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
513 
514     server_ret = write_callback.GetResult(server_ret);
515     EXPECT_GT(server_ret, 0);
516     client_ret = read_callback.GetResult(client_ret);
517     ASSERT_GT(client_ret, 0);
518   }
519 
520   std::unique_ptr<FakeDataChannel> channel_1_;
521   std::unique_ptr<FakeDataChannel> channel_2_;
522   SSLConfig client_ssl_config_;
523   SSLServerConfig server_ssl_config_;
524   std::unique_ptr<TestSSLConfigService> ssl_config_service_;
525   std::unique_ptr<MockCertVerifier> cert_verifier_;
526   std::unique_ptr<MockClientCertVerifier> client_cert_verifier_;
527   std::unique_ptr<TransportSecurityState> transport_security_state_;
528   std::unique_ptr<MockCTPolicyEnforcer> ct_policy_enforcer_;
529   std::unique_ptr<SSLClientSessionCache> ssl_client_session_cache_;
530   std::unique_ptr<SSLClientContext> client_context_;
531   std::unique_ptr<SSLServerContext> server_context_;
532   std::unique_ptr<SSLClientSocket> client_socket_;
533   std::unique_ptr<SSLServerSocket> server_socket_;
534   std::unique_ptr<crypto::RSAPrivateKey> server_private_key_;
535   scoped_refptr<SSLPrivateKey> server_ssl_private_key_;
536   scoped_refptr<X509Certificate> server_cert_;
537 };
538 
539 class SSLServerSocketReadTest : public SSLServerSocketTest,
540                                 public ::testing::WithParamInterface<bool> {
541  protected:
SSLServerSocketReadTest()542   SSLServerSocketReadTest() : read_if_ready_enabled_(GetParam()) {}
543 
Read(StreamSocket * socket,IOBuffer * buf,int buf_len,CompletionOnceCallback callback)544   int Read(StreamSocket* socket,
545            IOBuffer* buf,
546            int buf_len,
547            CompletionOnceCallback callback) {
548     if (read_if_ready_enabled()) {
549       return socket->ReadIfReady(buf, buf_len, std::move(callback));
550     }
551     return socket->Read(buf, buf_len, std::move(callback));
552   }
553 
read_if_ready_enabled() const554   bool read_if_ready_enabled() const { return read_if_ready_enabled_; }
555 
556  private:
557   const bool read_if_ready_enabled_;
558 };
559 
560 INSTANTIATE_TEST_SUITE_P(/* no prefix */,
561                          SSLServerSocketReadTest,
562                          ::testing::Bool());
563 
564 // This test only executes creation of client and server sockets. This is to
565 // test that creation of sockets doesn't crash and have minimal code to run
566 // with memory leak/corruption checking tools.
TEST_F(SSLServerSocketTest,Initialize)567 TEST_F(SSLServerSocketTest, Initialize) {
568   ASSERT_NO_FATAL_FAILURE(CreateContext());
569   ASSERT_NO_FATAL_FAILURE(CreateSockets());
570 }
571 
572 // This test executes Connect() on SSLClientSocket and Handshake() on
573 // SSLServerSocket to make sure handshaking between the two sockets is
574 // completed successfully.
TEST_F(SSLServerSocketTest,Handshake)575 TEST_F(SSLServerSocketTest, Handshake) {
576   ASSERT_NO_FATAL_FAILURE(CreateContext());
577   ASSERT_NO_FATAL_FAILURE(CreateSockets());
578 
579   TestCompletionCallback handshake_callback;
580   int server_ret = server_socket_->Handshake(handshake_callback.callback());
581 
582   TestCompletionCallback connect_callback;
583   int client_ret = client_socket_->Connect(connect_callback.callback());
584 
585   client_ret = connect_callback.GetResult(client_ret);
586   server_ret = handshake_callback.GetResult(server_ret);
587 
588   ASSERT_THAT(client_ret, IsOk());
589   ASSERT_THAT(server_ret, IsOk());
590 
591   // Make sure the cert status is expected.
592   SSLInfo ssl_info;
593   ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
594   EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);
595 
596   // The default cipher suite should be ECDHE and an AEAD.
597   uint16_t cipher_suite =
598       SSLConnectionStatusToCipherSuite(ssl_info.connection_status);
599   const char* key_exchange;
600   const char* cipher;
601   const char* mac;
602   bool is_aead;
603   bool is_tls13;
604   SSLCipherSuiteToStrings(&key_exchange, &cipher, &mac, &is_aead, &is_tls13,
605                           cipher_suite);
606   EXPECT_TRUE(is_aead);
607 }
608 
609 // This test makes sure the session cache is working.
TEST_F(SSLServerSocketTest,HandshakeCached)610 TEST_F(SSLServerSocketTest, HandshakeCached) {
611   ASSERT_NO_FATAL_FAILURE(CreateContext());
612   ASSERT_NO_FATAL_FAILURE(CreateSockets());
613 
614   TestCompletionCallback handshake_callback;
615   int server_ret = server_socket_->Handshake(handshake_callback.callback());
616 
617   TestCompletionCallback connect_callback;
618   int client_ret = client_socket_->Connect(connect_callback.callback());
619 
620   client_ret = connect_callback.GetResult(client_ret);
621   server_ret = handshake_callback.GetResult(server_ret);
622 
623   ASSERT_THAT(client_ret, IsOk());
624   ASSERT_THAT(server_ret, IsOk());
625 
626   // Make sure the cert status is expected.
627   SSLInfo ssl_info;
628   ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
629   EXPECT_EQ(ssl_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
630   SSLInfo ssl_server_info;
631   ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info));
632   EXPECT_EQ(ssl_server_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
633 
634   // Pump client read to get new session tickets.
635   PumpServerToClient();
636 
637   // Make sure the second connection is cached.
638   ASSERT_NO_FATAL_FAILURE(CreateSockets());
639   TestCompletionCallback handshake_callback2;
640   int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());
641 
642   TestCompletionCallback connect_callback2;
643   int client_ret2 = client_socket_->Connect(connect_callback2.callback());
644 
645   client_ret2 = connect_callback2.GetResult(client_ret2);
646   server_ret2 = handshake_callback2.GetResult(server_ret2);
647 
648   ASSERT_THAT(client_ret2, IsOk());
649   ASSERT_THAT(server_ret2, IsOk());
650 
651   // Make sure the cert status is expected.
652   SSLInfo ssl_info2;
653   ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info2));
654   EXPECT_EQ(ssl_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME);
655   SSLInfo ssl_server_info2;
656   ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info2));
657   EXPECT_EQ(ssl_server_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME);
658 }
659 
660 // This test makes sure the session cache separates out by server context.
TEST_F(SSLServerSocketTest,HandshakeCachedContextSwitch)661 TEST_F(SSLServerSocketTest, HandshakeCachedContextSwitch) {
662   ASSERT_NO_FATAL_FAILURE(CreateContext());
663   ASSERT_NO_FATAL_FAILURE(CreateSockets());
664 
665   TestCompletionCallback handshake_callback;
666   int server_ret = server_socket_->Handshake(handshake_callback.callback());
667 
668   TestCompletionCallback connect_callback;
669   int client_ret = client_socket_->Connect(connect_callback.callback());
670 
671   client_ret = connect_callback.GetResult(client_ret);
672   server_ret = handshake_callback.GetResult(server_ret);
673 
674   ASSERT_THAT(client_ret, IsOk());
675   ASSERT_THAT(server_ret, IsOk());
676 
677   // Make sure the cert status is expected.
678   SSLInfo ssl_info;
679   ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
680   EXPECT_EQ(ssl_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
681   SSLInfo ssl_server_info;
682   ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info));
683   EXPECT_EQ(ssl_server_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
684 
685   // Make sure the second connection is NOT cached when using a new context.
686   ASSERT_NO_FATAL_FAILURE(CreateContext());
687   ASSERT_NO_FATAL_FAILURE(CreateSockets());
688 
689   TestCompletionCallback handshake_callback2;
690   int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());
691 
692   TestCompletionCallback connect_callback2;
693   int client_ret2 = client_socket_->Connect(connect_callback2.callback());
694 
695   client_ret2 = connect_callback2.GetResult(client_ret2);
696   server_ret2 = handshake_callback2.GetResult(server_ret2);
697 
698   ASSERT_THAT(client_ret2, IsOk());
699   ASSERT_THAT(server_ret2, IsOk());
700 
701   // Make sure the cert status is expected.
702   SSLInfo ssl_info2;
703   ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info2));
704   EXPECT_EQ(ssl_info2.handshake_type, SSLInfo::HANDSHAKE_FULL);
705   SSLInfo ssl_server_info2;
706   ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info2));
707   EXPECT_EQ(ssl_server_info2.handshake_type, SSLInfo::HANDSHAKE_FULL);
708 }
709 
710 // Client certificates are disabled on iOS.
711 #if !BUILDFLAG(IS_IOS)
712 // This test executes Connect() on SSLClientSocket and Handshake() on
713 // SSLServerSocket to make sure handshaking between the two sockets is
714 // completed successfully, using client certificate.
TEST_F(SSLServerSocketTest,HandshakeWithClientCert)715 TEST_F(SSLServerSocketTest, HandshakeWithClientCert) {
716   scoped_refptr<X509Certificate> client_cert =
717       ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
718   ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
719       kClientCertFileName, kClientPrivateKeyFileName));
720   ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
721   ASSERT_NO_FATAL_FAILURE(CreateContext());
722   ASSERT_NO_FATAL_FAILURE(CreateSockets());
723 
724   TestCompletionCallback handshake_callback;
725   int server_ret = server_socket_->Handshake(handshake_callback.callback());
726 
727   TestCompletionCallback connect_callback;
728   int client_ret = client_socket_->Connect(connect_callback.callback());
729 
730   client_ret = connect_callback.GetResult(client_ret);
731   server_ret = handshake_callback.GetResult(server_ret);
732 
733   ASSERT_THAT(client_ret, IsOk());
734   ASSERT_THAT(server_ret, IsOk());
735 
736   // Make sure the cert status is expected.
737   SSLInfo ssl_info;
738   client_socket_->GetSSLInfo(&ssl_info);
739   EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);
740   server_socket_->GetSSLInfo(&ssl_info);
741   ASSERT_TRUE(ssl_info.cert.get());
742   EXPECT_TRUE(client_cert->EqualsExcludingChain(ssl_info.cert.get()));
743 }
744 
745 // This test executes Connect() on SSLClientSocket and Handshake() twice on
746 // SSLServerSocket to make sure handshaking between the two sockets is
747 // completed successfully, using client certificate. The second connection is
748 // expected to succeed through the session cache.
TEST_F(SSLServerSocketTest,HandshakeWithClientCertCached)749 TEST_F(SSLServerSocketTest, HandshakeWithClientCertCached) {
750   scoped_refptr<X509Certificate> client_cert =
751       ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
752   ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
753       kClientCertFileName, kClientPrivateKeyFileName));
754   ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
755   ASSERT_NO_FATAL_FAILURE(CreateContext());
756   ASSERT_NO_FATAL_FAILURE(CreateSockets());
757 
758   TestCompletionCallback handshake_callback;
759   int server_ret = server_socket_->Handshake(handshake_callback.callback());
760 
761   TestCompletionCallback connect_callback;
762   int client_ret = client_socket_->Connect(connect_callback.callback());
763 
764   client_ret = connect_callback.GetResult(client_ret);
765   server_ret = handshake_callback.GetResult(server_ret);
766 
767   ASSERT_THAT(client_ret, IsOk());
768   ASSERT_THAT(server_ret, IsOk());
769 
770   // Make sure the cert status is expected.
771   SSLInfo ssl_info;
772   ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
773   EXPECT_EQ(ssl_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
774   SSLInfo ssl_server_info;
775   ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info));
776   ASSERT_TRUE(ssl_server_info.cert.get());
777   EXPECT_TRUE(client_cert->EqualsExcludingChain(ssl_server_info.cert.get()));
778   EXPECT_EQ(ssl_server_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
779   // Pump client read to get new session tickets.
780   PumpServerToClient();
781   server_socket_->Disconnect();
782   client_socket_->Disconnect();
783 
784   // Create the connection again.
785   ASSERT_NO_FATAL_FAILURE(CreateSockets());
786   TestCompletionCallback handshake_callback2;
787   int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());
788 
789   TestCompletionCallback connect_callback2;
790   int client_ret2 = client_socket_->Connect(connect_callback2.callback());
791 
792   client_ret2 = connect_callback2.GetResult(client_ret2);
793   server_ret2 = handshake_callback2.GetResult(server_ret2);
794 
795   ASSERT_THAT(client_ret2, IsOk());
796   ASSERT_THAT(server_ret2, IsOk());
797 
798   // Make sure the cert status is expected.
799   SSLInfo ssl_info2;
800   ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info2));
801   EXPECT_EQ(ssl_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME);
802   SSLInfo ssl_server_info2;
803   ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info2));
804   ASSERT_TRUE(ssl_server_info2.cert.get());
805   EXPECT_TRUE(client_cert->EqualsExcludingChain(ssl_server_info2.cert.get()));
806   EXPECT_EQ(ssl_server_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME);
807 }
808 
// The server requests a client certificate, but the client keeps its default
// configuration and offers none. The client's Connect() is expected to fail
// with ERR_SSL_CLIENT_AUTH_CERT_NEEDED. The test then inspects the
// cert_authorities carried in the server's CertificateRequest.
TEST_F(SSLServerSocketTest, HandshakeWithClientCertRequiredNotSupplied) {
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());
  // Intentionally no ConfigureClientCertsForClient(): the client sends no cert.

  TestCompletionCallback server_cb;
  const int server_rv = server_socket_->Handshake(server_cb.callback());

  TestCompletionCallback client_cb;
  const int connect_rv = client_socket_->Connect(client_cb.callback());
  EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED, client_cb.GetResult(connect_rv));

  auto cert_request = base::MakeRefCounted<SSLCertRequestInfo>();
  client_socket_->GetSSLCertRequestInfo(cert_request.get());

  // The authority names received in the CertificateRequest should match the
  // issuer of the test client certificate.
  scoped_refptr<X509Certificate> client_cert =
      ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
  ASSERT_TRUE(client_cert);
  EXPECT_TRUE(client_cert->IsIssuedByEncoded(cert_request->cert_authorities));

  // Dropping the client aborts the server's still-pending handshake.
  client_socket_->Disconnect();

  EXPECT_THAT(server_cb.GetResult(server_rv), IsError(ERR_CONNECTION_CLOSED));
}
842 
TEST_F(SSLServerSocketTest,HandshakeWithClientCertRequiredNotSuppliedCached)843 TEST_F(SSLServerSocketTest, HandshakeWithClientCertRequiredNotSuppliedCached) {
844   ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
845   ASSERT_NO_FATAL_FAILURE(CreateContext());
846   ASSERT_NO_FATAL_FAILURE(CreateSockets());
847   // Use the default setting for the client socket, which is to not send
848   // a client certificate. This will cause the client to receive an
849   // ERR_SSL_CLIENT_AUTH_CERT_NEEDED error, and allow for inspecting the
850   // requested cert_authorities from the CertificateRequest sent by the
851   // server.
852 
853   TestCompletionCallback handshake_callback;
854   int server_ret = server_socket_->Handshake(handshake_callback.callback());
855 
856   TestCompletionCallback connect_callback;
857   EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED,
858             connect_callback.GetResult(
859                 client_socket_->Connect(connect_callback.callback())));
860 
861   auto request_info = base::MakeRefCounted<SSLCertRequestInfo>();
862   client_socket_->GetSSLCertRequestInfo(request_info.get());
863 
864   // Check that the authority name that arrived in the CertificateRequest
865   // handshake message is as expected.
866   scoped_refptr<X509Certificate> client_cert =
867       ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
868   ASSERT_TRUE(client_cert);
869   EXPECT_TRUE(client_cert->IsIssuedByEncoded(request_info->cert_authorities));
870 
871   client_socket_->Disconnect();
872 
873   EXPECT_THAT(handshake_callback.GetResult(server_ret),
874               IsError(ERR_CONNECTION_CLOSED));
875   server_socket_->Disconnect();
876 
877   // Below, check that the cache didn't store the result of a failed handshake.
878   ASSERT_NO_FATAL_FAILURE(CreateSockets());
879   TestCompletionCallback handshake_callback2;
880   int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());
881 
882   TestCompletionCallback connect_callback2;
883   EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED,
884             connect_callback2.GetResult(
885                 client_socket_->Connect(connect_callback2.callback())));
886 
887   auto request_info2 = base::MakeRefCounted<SSLCertRequestInfo>();
888   client_socket_->GetSSLCertRequestInfo(request_info2.get());
889 
890   // Check that the authority name that arrived in the CertificateRequest
891   // handshake message is as expected.
892   EXPECT_TRUE(client_cert->IsIssuedByEncoded(request_info2->cert_authorities));
893 
894   client_socket_->Disconnect();
895 
896   EXPECT_THAT(handshake_callback2.GetResult(server_ret2),
897               IsError(ERR_CONNECTION_CLOSED));
898 }
899 
// The client presents a certificate the server rejects. In TLS 1.3 the client
// does not see the client-certificate error from Connect(); it surfaces on the
// client's first Read() instead.
TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSupplied) {
  scoped_refptr<X509Certificate> client_cert =
      ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
  ASSERT_TRUE(client_cert);

  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
      kWrongClientCertFileName, kWrongClientPrivateKeyFileName));
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback server_cb;
  const int server_rv = server_socket_->Handshake(server_cb.callback());

  TestCompletionCallback client_cb;
  int client_rv = client_socket_->Connect(client_cb.callback());

  // Client-side handshake completes; only the server reports the bad cert.
  EXPECT_EQ(OK, client_cb.GetResult(client_rv));
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, server_cb.GetResult(server_rv));

  // A client Read() is what brings the server's rejection to the client.
  const int kReadBufSize = 1024;
  scoped_refptr<DrainableIOBuffer> read_buf =
      base::MakeRefCounted<DrainableIOBuffer>(
          base::MakeRefCounted<IOBuffer>(kReadBufSize), kReadBufSize);
  TestCompletionCallback read_cb;
  client_rv = client_socket_->Read(read_buf.get(), read_buf->BytesRemaining(),
                                   read_cb.callback());
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, read_cb.GetResult(client_rv));
}
933 
// Same scenario as HandshakeWithWrongClientCertSupplied, but the client is
// capped at TLS 1.2, where the rejected certificate fails Connect() itself
// on the client as well as Handshake() on the server.
TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSuppliedTLS12) {
  scoped_refptr<X509Certificate> client_cert =
      ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
  ASSERT_TRUE(client_cert);

  // Force TLS 1.2 before the client context/socket are created.
  client_ssl_config_.version_max_override = SSL_PROTOCOL_VERSION_TLS1_2;
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
      kWrongClientCertFileName, kWrongClientPrivateKeyFileName));
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback server_cb;
  const int server_rv = server_socket_->Handshake(server_cb.callback());

  TestCompletionCallback client_cb;
  const int client_rv = client_socket_->Connect(client_cb.callback());

  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, client_cb.GetResult(client_rv));
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, server_cb.GetResult(server_rv));
}
957 
TEST_F(SSLServerSocketTest,HandshakeWithWrongClientCertSuppliedCached)958 TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSuppliedCached) {
959   scoped_refptr<X509Certificate> client_cert =
960       ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
961   ASSERT_TRUE(client_cert);
962 
963   ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
964       kWrongClientCertFileName, kWrongClientPrivateKeyFileName));
965   ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
966   ASSERT_NO_FATAL_FAILURE(CreateContext());
967   ASSERT_NO_FATAL_FAILURE(CreateSockets());
968 
969   TestCompletionCallback handshake_callback;
970   int server_ret = server_socket_->Handshake(handshake_callback.callback());
971 
972   TestCompletionCallback connect_callback;
973   int client_ret = client_socket_->Connect(connect_callback.callback());
974 
975   // In TLS 1.3, the client cert error isn't exposed until Read is called.
976   EXPECT_EQ(OK, connect_callback.GetResult(client_ret));
977   EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
978             handshake_callback.GetResult(server_ret));
979 
980   // Pump client read to get client cert error.
981   const int kReadBufSize = 1024;
982   scoped_refptr<DrainableIOBuffer> read_buf =
983       base::MakeRefCounted<DrainableIOBuffer>(
984           base::MakeRefCounted<IOBuffer>(kReadBufSize), kReadBufSize);
985   TestCompletionCallback read_callback;
986   client_ret = client_socket_->Read(read_buf.get(), read_buf->BytesRemaining(),
987                                     read_callback.callback());
988   client_ret = read_callback.GetResult(client_ret);
989   EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, client_ret);
990 
991   client_socket_->Disconnect();
992   server_socket_->Disconnect();
993 
994   // Below, check that the cache didn't store the result of a failed handshake.
995   ASSERT_NO_FATAL_FAILURE(CreateSockets());
996   TestCompletionCallback handshake_callback2;
997   int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());
998 
999   TestCompletionCallback connect_callback2;
1000   int client_ret2 = client_socket_->Connect(connect_callback2.callback());
1001 
1002   // In TLS 1.3, the client cert error isn't exposed until Read is called.
1003   EXPECT_EQ(OK, connect_callback2.GetResult(client_ret2));
1004   EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
1005             handshake_callback2.GetResult(server_ret2));
1006 
1007   // Pump client read to get client cert error.
1008   client_ret = client_socket_->Read(read_buf.get(), read_buf->BytesRemaining(),
1009                                     read_callback.callback());
1010   client_ret = read_callback.GetResult(client_ret);
1011   EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, client_ret);
1012 }
1013 #endif  // !IS_IOS
1014 
TEST_P(SSLServerSocketReadTest,DataTransfer)1015 TEST_P(SSLServerSocketReadTest, DataTransfer) {
1016   ASSERT_NO_FATAL_FAILURE(CreateContext());
1017   ASSERT_NO_FATAL_FAILURE(CreateSockets());
1018 
1019   // Establish connection.
1020   TestCompletionCallback connect_callback;
1021   int client_ret = client_socket_->Connect(connect_callback.callback());
1022   ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
1023 
1024   TestCompletionCallback handshake_callback;
1025   int server_ret = server_socket_->Handshake(handshake_callback.callback());
1026   ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
1027 
1028   client_ret = connect_callback.GetResult(client_ret);
1029   ASSERT_THAT(client_ret, IsOk());
1030   server_ret = handshake_callback.GetResult(server_ret);
1031   ASSERT_THAT(server_ret, IsOk());
1032 
1033   const int kReadBufSize = 1024;
1034   scoped_refptr<StringIOBuffer> write_buf =
1035       base::MakeRefCounted<StringIOBuffer>("testing123");
1036   scoped_refptr<DrainableIOBuffer> read_buf =
1037       base::MakeRefCounted<DrainableIOBuffer>(
1038           base::MakeRefCounted<IOBuffer>(kReadBufSize), kReadBufSize);
1039 
1040   // Write then read.
1041   TestCompletionCallback write_callback;
1042   TestCompletionCallback read_callback;
1043   server_ret = server_socket_->Write(write_buf.get(), write_buf->size(),
1044                                      write_callback.callback(),
1045                                      TRAFFIC_ANNOTATION_FOR_TESTS);
1046   EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
1047   client_ret = client_socket_->Read(
1048       read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
1049   EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
1050 
1051   server_ret = write_callback.GetResult(server_ret);
1052   EXPECT_GT(server_ret, 0);
1053   client_ret = read_callback.GetResult(client_ret);
1054   ASSERT_GT(client_ret, 0);
1055 
1056   read_buf->DidConsume(client_ret);
1057   while (read_buf->BytesConsumed() < write_buf->size()) {
1058     client_ret = client_socket_->Read(
1059         read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
1060     EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
1061     client_ret = read_callback.GetResult(client_ret);
1062     ASSERT_GT(client_ret, 0);
1063     read_buf->DidConsume(client_ret);
1064   }
1065   EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed());
1066   read_buf->SetOffset(0);
1067   EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));
1068 
1069   // Read then write.
1070   write_buf = base::MakeRefCounted<StringIOBuffer>("hello123");
1071   server_ret = Read(server_socket_.get(), read_buf.get(),
1072                     read_buf->BytesRemaining(), read_callback.callback());
1073   EXPECT_EQ(server_ret, ERR_IO_PENDING);
1074   client_ret = client_socket_->Write(write_buf.get(), write_buf->size(),
1075                                      write_callback.callback(),
1076                                      TRAFFIC_ANNOTATION_FOR_TESTS);
1077   EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
1078 
1079   server_ret = read_callback.GetResult(server_ret);
1080   if (read_if_ready_enabled()) {
1081     // ReadIfReady signals the data is available but does not consume it.
1082     // The data is consumed later below.
1083     ASSERT_EQ(server_ret, OK);
1084   } else {
1085     ASSERT_GT(server_ret, 0);
1086     read_buf->DidConsume(server_ret);
1087   }
1088   client_ret = write_callback.GetResult(client_ret);
1089   EXPECT_GT(client_ret, 0);
1090 
1091   while (read_buf->BytesConsumed() < write_buf->size()) {
1092     server_ret = Read(server_socket_.get(), read_buf.get(),
1093                       read_buf->BytesRemaining(), read_callback.callback());
1094     // All the data was written above, so the data should be synchronously
1095     // available out of both Read() and ReadIfReady().
1096     ASSERT_GT(server_ret, 0);
1097     read_buf->DidConsume(server_ret);
1098   }
1099   EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed());
1100   read_buf->SetOffset(0);
1101   EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));
1102 }
1103 
1104 // A regression test for bug 127822 (http://crbug.com/127822).
1105 // If the server closes the connection after the handshake is finished,
1106 // the client's Write() call should not cause an infinite loop.
1107 // NOTE: this is a test for SSLClientSocket rather than SSLServerSocket.
TEST_F(SSLServerSocketTest,ClientWriteAfterServerClose)1108 TEST_F(SSLServerSocketTest, ClientWriteAfterServerClose) {
1109   ASSERT_NO_FATAL_FAILURE(CreateContext());
1110   ASSERT_NO_FATAL_FAILURE(CreateSockets());
1111 
1112   // Establish connection.
1113   TestCompletionCallback connect_callback;
1114   int client_ret = client_socket_->Connect(connect_callback.callback());
1115   ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
1116 
1117   TestCompletionCallback handshake_callback;
1118   int server_ret = server_socket_->Handshake(handshake_callback.callback());
1119   ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
1120 
1121   client_ret = connect_callback.GetResult(client_ret);
1122   ASSERT_THAT(client_ret, IsOk());
1123   server_ret = handshake_callback.GetResult(server_ret);
1124   ASSERT_THAT(server_ret, IsOk());
1125 
1126   scoped_refptr<StringIOBuffer> write_buf =
1127       base::MakeRefCounted<StringIOBuffer>("testing123");
1128 
1129   // The server closes the connection. The server needs to write some
1130   // data first so that the client's Read() calls from the transport
1131   // socket won't return ERR_IO_PENDING.  This ensures that the client
1132   // will call Read() on the transport socket again.
1133   TestCompletionCallback write_callback;
1134   server_ret = server_socket_->Write(write_buf.get(), write_buf->size(),
1135                                      write_callback.callback(),
1136                                      TRAFFIC_ANNOTATION_FOR_TESTS);
1137   EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
1138 
1139   server_ret = write_callback.GetResult(server_ret);
1140   EXPECT_GT(server_ret, 0);
1141 
1142   server_socket_->Disconnect();
1143 
1144   // The client writes some data. This should not cause an infinite loop.
1145   client_ret = client_socket_->Write(write_buf.get(), write_buf->size(),
1146                                      write_callback.callback(),
1147                                      TRAFFIC_ANNOTATION_FOR_TESTS);
1148   EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
1149 
1150   client_ret = write_callback.GetResult(client_ret);
1151   EXPECT_GT(client_ret, 0);
1152 
1153   base::RunLoop run_loop;
1154   base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask(
1155       FROM_HERE, run_loop.QuitClosure(), base::Milliseconds(10));
1156   run_loop.Run();
1157 }
1158 
1159 // This test executes ExportKeyingMaterial() on the client and server sockets,
1160 // after connecting them, and verifies that the results match.
1161 // This test will fail if False Start is enabled (see crbug.com/90208).
TEST_F(SSLServerSocketTest,ExportKeyingMaterial)1162 TEST_F(SSLServerSocketTest, ExportKeyingMaterial) {
1163   ASSERT_NO_FATAL_FAILURE(CreateContext());
1164   ASSERT_NO_FATAL_FAILURE(CreateSockets());
1165 
1166   TestCompletionCallback connect_callback;
1167   int client_ret = client_socket_->Connect(connect_callback.callback());
1168   ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
1169 
1170   TestCompletionCallback handshake_callback;
1171   int server_ret = server_socket_->Handshake(handshake_callback.callback());
1172   ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
1173 
1174   if (client_ret == ERR_IO_PENDING) {
1175     ASSERT_THAT(connect_callback.WaitForResult(), IsOk());
1176   }
1177   if (server_ret == ERR_IO_PENDING) {
1178     ASSERT_THAT(handshake_callback.WaitForResult(), IsOk());
1179   }
1180 
1181   const int kKeyingMaterialSize = 32;
1182   const char kKeyingLabel[] = "EXPERIMENTAL-server-socket-test";
1183   const char kKeyingContext[] = "";
1184   unsigned char server_out[kKeyingMaterialSize];
1185   int rv = server_socket_->ExportKeyingMaterial(
1186       kKeyingLabel, false, kKeyingContext, server_out, sizeof(server_out));
1187   ASSERT_THAT(rv, IsOk());
1188 
1189   unsigned char client_out[kKeyingMaterialSize];
1190   rv = client_socket_->ExportKeyingMaterial(kKeyingLabel, false, kKeyingContext,
1191                                             client_out, sizeof(client_out));
1192   ASSERT_THAT(rv, IsOk());
1193   EXPECT_EQ(0, memcmp(server_out, client_out, sizeof(server_out)));
1194 
1195   const char kKeyingLabelBad[] = "EXPERIMENTAL-server-socket-test-bad";
1196   unsigned char client_bad[kKeyingMaterialSize];
1197   rv = client_socket_->ExportKeyingMaterial(
1198       kKeyingLabelBad, false, kKeyingContext, client_bad, sizeof(client_bad));
1199   ASSERT_EQ(rv, OK);
1200   EXPECT_NE(0, memcmp(server_out, client_bad, sizeof(server_out)));
1201 }
1202 
1203 // Verifies that SSLConfig::require_ecdhe flags works properly.
TEST_F(SSLServerSocketTest,RequireEcdheFlag)1204 TEST_F(SSLServerSocketTest, RequireEcdheFlag) {
1205   // Disable all ECDHE suites on the client side.
1206   SSLContextConfig config;
1207   config.disabled_cipher_suites.assign(
1208       kEcdheCiphers, kEcdheCiphers + std::size(kEcdheCiphers));
1209 
1210   // Legacy RSA key exchange ciphers only exist in TLS 1.2 and below.
1211   config.version_max = SSL_PROTOCOL_VERSION_TLS1_2;
1212   ssl_config_service_->UpdateSSLConfigAndNotify(config);
1213 
1214   // Require ECDHE on the server.
1215   server_ssl_config_.require_ecdhe = true;
1216 
1217   ASSERT_NO_FATAL_FAILURE(CreateContext());
1218   ASSERT_NO_FATAL_FAILURE(CreateSockets());
1219 
1220   TestCompletionCallback connect_callback;
1221   int client_ret = client_socket_->Connect(connect_callback.callback());
1222 
1223   TestCompletionCallback handshake_callback;
1224   int server_ret = server_socket_->Handshake(handshake_callback.callback());
1225 
1226   client_ret = connect_callback.GetResult(client_ret);
1227   server_ret = handshake_callback.GetResult(server_ret);
1228 
1229   ASSERT_THAT(client_ret, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
1230   ASSERT_THAT(server_ret, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
1231 }
1232 
1233 // This test executes Connect() on SSLClientSocket and Handshake() on
1234 // SSLServerSocket to make sure handshaking between the two sockets is
1235 // completed successfully. The server key is represented by SSLPrivateKey.
TEST_F(SSLServerSocketTest,HandshakeServerSSLPrivateKey)1236 TEST_F(SSLServerSocketTest, HandshakeServerSSLPrivateKey) {
1237   ASSERT_NO_FATAL_FAILURE(CreateContextSSLPrivateKey());
1238   ASSERT_NO_FATAL_FAILURE(CreateSockets());
1239 
1240   TestCompletionCallback handshake_callback;
1241   int server_ret = server_socket_->Handshake(handshake_callback.callback());
1242 
1243   TestCompletionCallback connect_callback;
1244   int client_ret = client_socket_->Connect(connect_callback.callback());
1245 
1246   client_ret = connect_callback.GetResult(client_ret);
1247   server_ret = handshake_callback.GetResult(server_ret);
1248 
1249   ASSERT_THAT(client_ret, IsOk());
1250   ASSERT_THAT(server_ret, IsOk());
1251 
1252   // Make sure the cert status is expected.
1253   SSLInfo ssl_info;
1254   ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
1255   EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);
1256 
1257   // The default cipher suite should be ECDHE and an AEAD.
1258   uint16_t cipher_suite =
1259       SSLConnectionStatusToCipherSuite(ssl_info.connection_status);
1260   const char* key_exchange;
1261   const char* cipher;
1262   const char* mac;
1263   bool is_aead;
1264   bool is_tls13;
1265   SSLCipherSuiteToStrings(&key_exchange, &cipher, &mac, &is_aead, &is_tls13,
1266                           cipher_suite);
1267   EXPECT_TRUE(is_aead);
1268 }
1269 
1270 // Verifies that non-ECDHE ciphers are disabled when using SSLPrivateKey as the
1271 // server key.
TEST_F(SSLServerSocketTest,HandshakeServerSSLPrivateKeyRequireEcdhe)1272 TEST_F(SSLServerSocketTest, HandshakeServerSSLPrivateKeyRequireEcdhe) {
1273   // Disable all ECDHE suites on the client side.
1274   SSLContextConfig config;
1275   config.disabled_cipher_suites.assign(
1276       kEcdheCiphers, kEcdheCiphers + std::size(kEcdheCiphers));
1277   // TLS 1.3 always works with SSLPrivateKey.
1278   config.version_max = SSL_PROTOCOL_VERSION_TLS1_2;
1279   ssl_config_service_->UpdateSSLConfigAndNotify(config);
1280 
1281   ASSERT_NO_FATAL_FAILURE(CreateContextSSLPrivateKey());
1282   ASSERT_NO_FATAL_FAILURE(CreateSockets());
1283 
1284   TestCompletionCallback connect_callback;
1285   int client_ret = client_socket_->Connect(connect_callback.callback());
1286 
1287   TestCompletionCallback handshake_callback;
1288   int server_ret = server_socket_->Handshake(handshake_callback.callback());
1289 
1290   client_ret = connect_callback.GetResult(client_ret);
1291   server_ret = handshake_callback.GetResult(server_ret);
1292 
1293   ASSERT_THAT(client_ret, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
1294   ASSERT_THAT(server_ret, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
1295 }
1296 
1297 class SSLServerSocketAlpsTest
1298     : public SSLServerSocketTest,
1299       public ::testing::WithParamInterface<std::tuple<bool, bool>> {
1300  public:
SSLServerSocketAlpsTest()1301   SSLServerSocketAlpsTest()
1302       : client_alps_enabled_(std::get<0>(GetParam())),
1303         server_alps_enabled_(std::get<1>(GetParam())) {}
1304   ~SSLServerSocketAlpsTest() override = default;
1305   const bool client_alps_enabled_;
1306   const bool server_alps_enabled_;
1307 };
1308 
1309 INSTANTIATE_TEST_SUITE_P(All,
1310                          SSLServerSocketAlpsTest,
1311                          ::testing::Combine(::testing::Bool(),
1312                                             ::testing::Bool()));
1313 
// Negotiates HTTP/2 via ALPN with ALPS toggled independently on each side.
// Application settings must be exchanged only when both sides enable ALPS.
TEST_P(SSLServerSocketAlpsTest, Alps) {
  const std::string server_data = "server sends some test data";
  const std::string client_data = "client also sends some data";

  server_ssl_config_.alpn_protos = {kProtoHTTP2};
  if (server_alps_enabled_) {
    server_ssl_config_.application_settings[kProtoHTTP2] =
        std::vector<uint8_t>(server_data.begin(), server_data.end());
  }

  client_ssl_config_.alpn_protos = {kProtoHTTP2};
  if (client_alps_enabled_) {
    client_ssl_config_.application_settings[kProtoHTTP2] =
        std::vector<uint8_t>(client_data.begin(), client_data.end());
  }

  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback handshake_cb;
  int server_rv = server_socket_->Handshake(handshake_cb.callback());

  TestCompletionCallback connect_cb;
  int client_rv = client_socket_->Connect(connect_cb.callback());

  client_rv = connect_cb.GetResult(client_rv);
  server_rv = handshake_cb.GetResult(server_rv);

  ASSERT_THAT(client_rv, IsOk());
  ASSERT_THAT(server_rv, IsOk());

  const auto client_received = client_socket_->GetPeerApplicationSettings();
  const auto server_received = server_socket_->GetPeerApplicationSettings();

  const bool alps_negotiated = client_alps_enabled_ && server_alps_enabled_;
  if (!alps_negotiated) {
    // Either side opting out means neither peer receives settings.
    EXPECT_FALSE(client_received.has_value());
    EXPECT_FALSE(server_received.has_value());
    return;
  }

  // Each side receives exactly what the other configured.
  ASSERT_TRUE(client_received.has_value());
  EXPECT_EQ(server_data, client_received.value());
  ASSERT_TRUE(server_received.has_value());
  EXPECT_EQ(client_data, server_received.value());
}
1361 
1362 // Test that CancelReadIfReady works.
TEST_F(SSLServerSocketTest,CancelReadIfReady)1363 TEST_F(SSLServerSocketTest, CancelReadIfReady) {
1364   ASSERT_NO_FATAL_FAILURE(CreateContext());
1365   ASSERT_NO_FATAL_FAILURE(CreateSockets());
1366 
1367   TestCompletionCallback connect_callback;
1368   int client_ret = client_socket_->Connect(connect_callback.callback());
1369   TestCompletionCallback handshake_callback;
1370   int server_ret = server_socket_->Handshake(handshake_callback.callback());
1371   ASSERT_THAT(connect_callback.GetResult(client_ret), IsOk());
1372   ASSERT_THAT(handshake_callback.GetResult(server_ret), IsOk());
1373 
1374   // Attempt to read from the server socket. There will not be anything to read.
1375   // Cancel the read immediately afterwards.
1376   TestCompletionCallback read_callback;
1377   auto read_buf = base::MakeRefCounted<IOBuffer>(1);
1378   int read_ret =
1379       server_socket_->ReadIfReady(read_buf.get(), 1, read_callback.callback());
1380   ASSERT_THAT(read_ret, IsError(ERR_IO_PENDING));
1381   ASSERT_THAT(server_socket_->CancelReadIfReady(), IsOk());
1382 
1383   // After the client writes data, the server should still not pick up a result.
1384   auto write_buf = base::MakeRefCounted<StringIOBuffer>("a");
1385   TestCompletionCallback write_callback;
1386   ASSERT_EQ(write_callback.GetResult(client_socket_->Write(
1387                 write_buf.get(), write_buf->size(), write_callback.callback(),
1388                 TRAFFIC_ANNOTATION_FOR_TESTS)),
1389             write_buf->size());
1390   base::RunLoop().RunUntilIdle();
1391   EXPECT_FALSE(read_callback.have_result());
1392 
1393   // After a canceled read, future reads are still possible.
1394   while (true) {
1395     TestCompletionCallback read_callback2;
1396     read_ret = server_socket_->ReadIfReady(read_buf.get(), 1,
1397                                            read_callback2.callback());
1398     if (read_ret != ERR_IO_PENDING) {
1399       break;
1400     }
1401     ASSERT_THAT(read_callback2.GetResult(read_ret), IsOk());
1402   }
1403   ASSERT_EQ(1, read_ret);
1404   EXPECT_EQ(read_buf->data()[0], 'a');
1405 }
1406 
1407 }  // namespace net
1408