• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifdef UNSAFE_BUFFERS_BUILD
6 // TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
7 #pragma allow_unsafe_buffers
8 #endif
9 
10 // This test suite uses SSLClientSocket to test the implementation of
11 // SSLServerSocket. In order to establish connections between the sockets
12 // we need two additional classes:
13 // 1. FakeSocket
14 //    Connects SSL socket to FakeDataChannel. This class is just a stub.
15 //
16 // 2. FakeDataChannel
17 //    Implements the actual exchange of data between two FakeSockets.
18 //
19 // Implementations of these two classes are included in this file.
20 
21 #include "net/socket/ssl_server_socket.h"
22 
23 #include <stdint.h>
24 #include <stdlib.h>
25 
26 #include <memory>
27 #include <string_view>
28 #include <utility>
29 
30 #include "base/check.h"
31 #include "base/compiler_specific.h"
32 #include "base/containers/queue.h"
33 #include "base/files/file_path.h"
34 #include "base/files/file_util.h"
35 #include "base/functional/bind.h"
36 #include "base/functional/callback_helpers.h"
37 #include "base/location.h"
38 #include "base/memory/raw_ptr.h"
39 #include "base/memory/scoped_refptr.h"
40 #include "base/notreached.h"
41 #include "base/run_loop.h"
42 #include "base/task/single_thread_task_runner.h"
43 #include "base/test/bind.h"
44 #include "base/test/task_environment.h"
45 #include "build/build_config.h"
46 #include "crypto/rsa_private_key.h"
47 #include "net/base/address_list.h"
48 #include "net/base/completion_once_callback.h"
49 #include "net/base/host_port_pair.h"
50 #include "net/base/io_buffer.h"
51 #include "net/base/ip_address.h"
52 #include "net/base/ip_endpoint.h"
53 #include "net/base/net_errors.h"
54 #include "net/cert/cert_status_flags.h"
55 #include "net/cert/mock_cert_verifier.h"
56 #include "net/cert/mock_client_cert_verifier.h"
57 #include "net/cert/signed_certificate_timestamp_and_status.h"
58 #include "net/cert/x509_certificate.h"
59 #include "net/http/transport_security_state.h"
60 #include "net/log/net_log_with_source.h"
61 #include "net/socket/client_socket_factory.h"
62 #include "net/socket/socket_test_util.h"
63 #include "net/socket/ssl_client_socket.h"
64 #include "net/socket/stream_socket.h"
65 #include "net/ssl/openssl_private_key.h"
66 #include "net/ssl/ssl_cert_request_info.h"
67 #include "net/ssl/ssl_cipher_suite_names.h"
68 #include "net/ssl/ssl_client_session_cache.h"
69 #include "net/ssl/ssl_connection_status_flags.h"
70 #include "net/ssl/ssl_info.h"
71 #include "net/ssl/ssl_private_key.h"
72 #include "net/ssl/ssl_server_config.h"
73 #include "net/ssl/test_ssl_config_service.h"
74 #include "net/test/cert_test_util.h"
75 #include "net/test/gtest_util.h"
76 #include "net/test/test_data_directory.h"
77 #include "net/test/test_with_task_environment.h"
78 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
79 #include "testing/gmock/include/gmock/gmock.h"
80 #include "testing/gtest/include/gtest/gtest.h"
81 #include "testing/platform_test.h"
82 #include "third_party/boringssl/src/include/openssl/evp.h"
83 #include "third_party/boringssl/src/include/openssl/ssl.h"
84 
85 using net::test::IsError;
86 using net::test::IsOk;
87 
88 namespace net {
89 
90 namespace {
91 
92 // Client certificates are disabled on iOS.
93 #if BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
// File names, relative to GetTestCertsDirectory(), of the client certificate
// and PKCS#8 key that ConfigureClientCertsForServer() tells the server to
// accept, plus a second ("wrong") certificate/key pair that it is not
// configured to accept.
constexpr char kClientCertFileName[] = "client_1.pem";
constexpr char kClientPrivateKeyFileName[] = "client_1.pk8";
constexpr char kWrongClientCertFileName[] = "client_2.pem";
constexpr char kWrongClientPrivateKeyFileName[] = "client_2.pk8";
98 #endif  // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
99 
// TLS cipher-suite code points whose key exchange is ECDHE, with either ECDSA
// or RSA authentication. The order below is significant only in that it is
// preserved for any caller that iterates the table.
constexpr uint16_t kEcdheCiphers[] = {
    0xC007,  // ECDHE_ECDSA_WITH_RC4_128_SHA
    0xC009,  // ECDHE_ECDSA_WITH_AES_128_CBC_SHA
    0xC00A,  // ECDHE_ECDSA_WITH_AES_256_CBC_SHA
    0xC011,  // ECDHE_RSA_WITH_RC4_128_SHA
    0xC013,  // ECDHE_RSA_WITH_AES_128_CBC_SHA
    0xC014,  // ECDHE_RSA_WITH_AES_256_CBC_SHA
    0xC02B,  // ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
    0xC02C,  // ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
    0xC02F,  // ECDHE_RSA_WITH_AES_128_GCM_SHA256
    0xC030,  // ECDHE_RSA_WITH_AES_256_GCM_SHA384
    0xCCA8,  // ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
    0xCCA9,  // ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
};
114 
115 class FakeDataChannel {
116  public:
117   FakeDataChannel() = default;
118 
119   FakeDataChannel(const FakeDataChannel&) = delete;
120   FakeDataChannel& operator=(const FakeDataChannel&) = delete;
121 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)122   int Read(IOBuffer* buf, int buf_len, CompletionOnceCallback callback) {
123     DCHECK(read_callback_.is_null());
124     DCHECK(!read_buf_.get());
125     if (closed_)
126       return 0;
127     if (data_.empty()) {
128       read_callback_ = std::move(callback);
129       read_buf_ = buf;
130       read_buf_len_ = buf_len;
131       return ERR_IO_PENDING;
132     }
133     return PropagateData(buf, buf_len);
134   }
135 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)136   int Write(IOBuffer* buf,
137             int buf_len,
138             CompletionOnceCallback callback,
139             const NetworkTrafficAnnotationTag& traffic_annotation) {
140     DCHECK(write_callback_.is_null());
141     if (closed_) {
142       if (write_called_after_close_)
143         return ERR_CONNECTION_RESET;
144       write_called_after_close_ = true;
145       write_callback_ = std::move(callback);
146       base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
147           FROM_HERE, base::BindOnce(&FakeDataChannel::DoWriteCallback,
148                                     weak_factory_.GetWeakPtr()));
149       return ERR_IO_PENDING;
150     }
151     // This function returns synchronously, so make a copy of the buffer.
152     data_.push(base::MakeRefCounted<DrainableIOBuffer>(
153         base::MakeRefCounted<StringIOBuffer>(std::string(buf->data(), buf_len)),
154         buf_len));
155     base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
156         FROM_HERE, base::BindOnce(&FakeDataChannel::DoReadCallback,
157                                   weak_factory_.GetWeakPtr()));
158     return buf_len;
159   }
160 
161   // Closes the FakeDataChannel. After Close() is called, Read() returns 0,
162   // indicating EOF, and Write() fails with ERR_CONNECTION_RESET. Note that
163   // after the FakeDataChannel is closed, the first Write() call completes
164   // asynchronously, which is necessary to reproduce bug 127822.
Close()165   void Close() {
166     closed_ = true;
167     if (!read_callback_.is_null()) {
168       base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
169           FROM_HERE, base::BindOnce(&FakeDataChannel::DoReadCallback,
170                                     weak_factory_.GetWeakPtr()));
171     }
172   }
173 
174  private:
DoReadCallback()175   void DoReadCallback() {
176     if (read_callback_.is_null())
177       return;
178 
179     if (closed_) {
180       std::move(read_callback_).Run(ERR_CONNECTION_CLOSED);
181       return;
182     }
183 
184     if (data_.empty())
185       return;
186 
187     int copied = PropagateData(read_buf_, read_buf_len_);
188     read_buf_ = nullptr;
189     read_buf_len_ = 0;
190     std::move(read_callback_).Run(copied);
191   }
192 
DoWriteCallback()193   void DoWriteCallback() {
194     if (write_callback_.is_null())
195       return;
196 
197     std::move(write_callback_).Run(ERR_CONNECTION_RESET);
198   }
199 
PropagateData(scoped_refptr<IOBuffer> read_buf,int read_buf_len)200   int PropagateData(scoped_refptr<IOBuffer> read_buf, int read_buf_len) {
201     scoped_refptr<DrainableIOBuffer> buf = data_.front();
202     int copied = std::min(buf->BytesRemaining(), read_buf_len);
203     memcpy(read_buf->data(), buf->data(), copied);
204     buf->DidConsume(copied);
205 
206     if (!buf->BytesRemaining())
207       data_.pop();
208     return copied;
209   }
210 
211   CompletionOnceCallback read_callback_;
212   scoped_refptr<IOBuffer> read_buf_;
213   int read_buf_len_ = 0;
214 
215   CompletionOnceCallback write_callback_;
216 
217   base::queue<scoped_refptr<DrainableIOBuffer>> data_;
218 
219   // True if Close() has been called.
220   bool closed_ = false;
221 
222   // Controls the completion of Write() after the FakeDataChannel is closed.
223   // After the FakeDataChannel is closed, the first Write() call completes
224   // asynchronously.
225   bool write_called_after_close_ = false;
226 
227   base::WeakPtrFactory<FakeDataChannel> weak_factory_{this};
228 };
229 
230 class FakeSocket : public StreamSocket {
231  public:
FakeSocket(FakeDataChannel * incoming_channel,FakeDataChannel * outgoing_channel)232   FakeSocket(FakeDataChannel* incoming_channel,
233              FakeDataChannel* outgoing_channel)
234       : incoming_(incoming_channel), outgoing_(outgoing_channel) {}
235 
236   FakeSocket(const FakeSocket&) = delete;
237   FakeSocket& operator=(const FakeSocket&) = delete;
238 
239   ~FakeSocket() override = default;
240 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)241   int Read(IOBuffer* buf,
242            int buf_len,
243            CompletionOnceCallback callback) override {
244     // Read random number of bytes.
245     buf_len = rand() % buf_len + 1;
246     return incoming_->Read(buf, buf_len, std::move(callback));
247   }
248 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)249   int Write(IOBuffer* buf,
250             int buf_len,
251             CompletionOnceCallback callback,
252             const NetworkTrafficAnnotationTag& traffic_annotation) override {
253     // Write random number of bytes.
254     buf_len = rand() % buf_len + 1;
255     return outgoing_->Write(buf, buf_len, std::move(callback),
256                             TRAFFIC_ANNOTATION_FOR_TESTS);
257   }
258 
SetReceiveBufferSize(int32_t size)259   int SetReceiveBufferSize(int32_t size) override { return OK; }
260 
SetSendBufferSize(int32_t size)261   int SetSendBufferSize(int32_t size) override { return OK; }
262 
Connect(CompletionOnceCallback callback)263   int Connect(CompletionOnceCallback callback) override { return OK; }
264 
Disconnect()265   void Disconnect() override {
266     incoming_->Close();
267     outgoing_->Close();
268   }
269 
IsConnected() const270   bool IsConnected() const override { return true; }
271 
IsConnectedAndIdle() const272   bool IsConnectedAndIdle() const override { return true; }
273 
GetPeerAddress(IPEndPoint * address) const274   int GetPeerAddress(IPEndPoint* address) const override {
275     *address = IPEndPoint(IPAddress::IPv4AllZeros(), 0 /*port*/);
276     return OK;
277   }
278 
GetLocalAddress(IPEndPoint * address) const279   int GetLocalAddress(IPEndPoint* address) const override {
280     *address = IPEndPoint(IPAddress::IPv4AllZeros(), 0 /*port*/);
281     return OK;
282   }
283 
NetLog() const284   const NetLogWithSource& NetLog() const override { return net_log_; }
285 
WasEverUsed() const286   bool WasEverUsed() const override { return true; }
287 
GetNegotiatedProtocol() const288   NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
289 
GetSSLInfo(SSLInfo * ssl_info)290   bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
291 
GetTotalReceivedBytes() const292   int64_t GetTotalReceivedBytes() const override {
293     NOTIMPLEMENTED();
294     return 0;
295   }
296 
ApplySocketTag(const SocketTag & tag)297   void ApplySocketTag(const SocketTag& tag) override {}
298 
299  private:
300   NetLogWithSource net_log_;
301   raw_ptr<FakeDataChannel> incoming_;
302   raw_ptr<FakeDataChannel> outgoing_;
303 };
304 
305 }  // namespace
306 
307 // Verify the correctness of the test helper classes first.
TEST(FakeSocketTest,DataTransfer)308 TEST(FakeSocketTest, DataTransfer) {
309   base::test::TaskEnvironment task_environment;
310 
311   // Establish channels between two sockets.
312   FakeDataChannel channel_1;
313   FakeDataChannel channel_2;
314   FakeSocket client(&channel_1, &channel_2);
315   FakeSocket server(&channel_2, &channel_1);
316 
317   const char kTestData[] = "testing123";
318   const int kTestDataSize = strlen(kTestData);
319   const int kReadBufSize = 1024;
320   auto write_buf = base::MakeRefCounted<StringIOBuffer>(kTestData);
321   auto read_buf = base::MakeRefCounted<IOBufferWithSize>(kReadBufSize);
322 
323   // Write then read.
324   int written =
325       server.Write(write_buf.get(), kTestDataSize, CompletionOnceCallback(),
326                    TRAFFIC_ANNOTATION_FOR_TESTS);
327   EXPECT_GT(written, 0);
328   EXPECT_LE(written, kTestDataSize);
329 
330   int read =
331       client.Read(read_buf.get(), kReadBufSize, CompletionOnceCallback());
332   EXPECT_GT(read, 0);
333   EXPECT_LE(read, written);
334   EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read));
335 
336   // Read then write.
337   TestCompletionCallback callback;
338   EXPECT_EQ(ERR_IO_PENDING,
339             server.Read(read_buf.get(), kReadBufSize, callback.callback()));
340 
341   written =
342       client.Write(write_buf.get(), kTestDataSize, CompletionOnceCallback(),
343                    TRAFFIC_ANNOTATION_FOR_TESTS);
344   EXPECT_GT(written, 0);
345   EXPECT_LE(written, kTestDataSize);
346 
347   read = callback.WaitForResult();
348   EXPECT_GT(read, 0);
349   EXPECT_LE(read, written);
350   EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read));
351 }
352 
353 class SSLServerSocketTest : public PlatformTest, public WithTaskEnvironment {
354  public:
SSLServerSocketTest()355   SSLServerSocketTest()
356       : ssl_config_service_(
357             std::make_unique<TestSSLConfigService>(SSLContextConfig())),
358         cert_verifier_(std::make_unique<MockCertVerifier>()),
359         client_cert_verifier_(std::make_unique<MockClientCertVerifier>()),
360         transport_security_state_(std::make_unique<TransportSecurityState>()),
361         ssl_client_session_cache_(std::make_unique<SSLClientSessionCache>(
362             SSLClientSessionCache::Config())) {}
363 
SetUp()364   void SetUp() override {
365     PlatformTest::SetUp();
366 
367     cert_verifier_->set_default_result(ERR_CERT_AUTHORITY_INVALID);
368     client_cert_verifier_->set_default_result(ERR_CERT_AUTHORITY_INVALID);
369 
370     server_cert_ =
371         ImportCertFromFile(GetTestCertsDirectory(), "unittest.selfsigned.der");
372     ASSERT_TRUE(server_cert_);
373     server_private_key_ = ReadTestKey("unittest.key.bin");
374     ASSERT_TRUE(server_private_key_);
375 
376     std::unique_ptr<crypto::RSAPrivateKey> key =
377         ReadTestKey("unittest.key.bin");
378     ASSERT_TRUE(key);
379     server_ssl_private_key_ = WrapOpenSSLPrivateKey(bssl::UpRef(key->key()));
380 
381     // Certificate provided by the host doesn't need authority.
382     client_ssl_config_.allowed_bad_certs.emplace_back(
383         server_cert_, CERT_STATUS_AUTHORITY_INVALID);
384 
385     client_context_ = std::make_unique<SSLClientContext>(
386         ssl_config_service_.get(), cert_verifier_.get(),
387         transport_security_state_.get(), ssl_client_session_cache_.get(),
388         nullptr);
389   }
390 
391  protected:
CreateContext()392   void CreateContext() {
393     client_socket_.reset();
394     server_socket_.reset();
395     channel_1_.reset();
396     channel_2_.reset();
397     server_context_ = CreateSSLServerContext(
398         server_cert_.get(), *server_private_key_, server_ssl_config_);
399   }
400 
CreateContextSSLPrivateKey()401   void CreateContextSSLPrivateKey() {
402     client_socket_.reset();
403     server_socket_.reset();
404     channel_1_.reset();
405     channel_2_.reset();
406     server_context_.reset();
407     server_context_ = CreateSSLServerContext(
408         server_cert_.get(), server_ssl_private_key_, server_ssl_config_);
409   }
410 
GetHostAndPort()411   static HostPortPair GetHostAndPort() { return HostPortPair("unittest", 0); }
412 
CreateSockets()413   void CreateSockets() {
414     client_socket_.reset();
415     server_socket_.reset();
416     channel_1_ = std::make_unique<FakeDataChannel>();
417     channel_2_ = std::make_unique<FakeDataChannel>();
418     std::unique_ptr<StreamSocket> client_connection =
419         std::make_unique<FakeSocket>(channel_1_.get(), channel_2_.get());
420     std::unique_ptr<StreamSocket> server_socket =
421         std::make_unique<FakeSocket>(channel_2_.get(), channel_1_.get());
422 
423     client_socket_ = client_context_->CreateSSLClientSocket(
424         std::move(client_connection), GetHostAndPort(), client_ssl_config_);
425     ASSERT_TRUE(client_socket_);
426 
427     server_socket_ =
428         server_context_->CreateSSLServerSocket(std::move(server_socket));
429     ASSERT_TRUE(server_socket_);
430   }
431 
432 // Client certificates are disabled on iOS.
433 #if BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
ConfigureClientCertsForClient(const char * cert_file_name,const char * private_key_file_name)434   void ConfigureClientCertsForClient(const char* cert_file_name,
435                                      const char* private_key_file_name) {
436     scoped_refptr<X509Certificate> client_cert =
437         ImportCertFromFile(GetTestCertsDirectory(), cert_file_name);
438     ASSERT_TRUE(client_cert);
439 
440     std::unique_ptr<crypto::RSAPrivateKey> key =
441         ReadTestKey(private_key_file_name);
442     ASSERT_TRUE(key);
443 
444     client_context_->SetClientCertificate(
445         GetHostAndPort(), std::move(client_cert),
446         WrapOpenSSLPrivateKey(bssl::UpRef(key->key())));
447   }
448 
ConfigureClientCertsForServer()449   void ConfigureClientCertsForServer() {
450     server_ssl_config_.client_cert_type =
451         SSLServerConfig::ClientCertType::REQUIRE_CLIENT_CERT;
452 
453     // "CN=B CA" - DER encoded DN of the issuer of client_1.pem
454     static const uint8_t kClientCertCAName[] = {
455         0x30, 0x0f, 0x31, 0x0d, 0x30, 0x0b, 0x06, 0x03, 0x55,
456         0x04, 0x03, 0x0c, 0x04, 0x42, 0x20, 0x43, 0x41};
457     server_ssl_config_.cert_authorities.emplace_back(
458         std::begin(kClientCertCAName), std::end(kClientCertCAName));
459 
460     scoped_refptr<X509Certificate> expected_client_cert(
461         ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName));
462     ASSERT_TRUE(expected_client_cert);
463 
464     client_cert_verifier_->AddResultForCert(expected_client_cert.get(), OK);
465 
466     server_ssl_config_.client_cert_verifier = client_cert_verifier_.get();
467   }
468 #endif  // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
469 
ReadTestKey(std::string_view name)470   std::unique_ptr<crypto::RSAPrivateKey> ReadTestKey(std::string_view name) {
471     base::FilePath certs_dir(GetTestCertsDirectory());
472     base::FilePath key_path = certs_dir.AppendASCII(name);
473     std::string key_string;
474     if (!base::ReadFileToString(key_path, &key_string))
475       return nullptr;
476     std::vector<uint8_t> key_vector(
477         reinterpret_cast<const uint8_t*>(key_string.data()),
478         reinterpret_cast<const uint8_t*>(key_string.data() +
479                                          key_string.length()));
480     std::unique_ptr<crypto::RSAPrivateKey> key(
481         crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
482     return key;
483   }
484 
PumpServerToClient()485   void PumpServerToClient() {
486     const int kReadBufSize = 1024;
487     scoped_refptr<StringIOBuffer> write_buf =
488         base::MakeRefCounted<StringIOBuffer>("testing123");
489     scoped_refptr<DrainableIOBuffer> read_buf =
490         base::MakeRefCounted<DrainableIOBuffer>(
491             base::MakeRefCounted<IOBufferWithSize>(kReadBufSize), kReadBufSize);
492     TestCompletionCallback write_callback;
493     TestCompletionCallback read_callback;
494     int server_ret = server_socket_->Write(write_buf.get(), write_buf->size(),
495                                            write_callback.callback(),
496                                            TRAFFIC_ANNOTATION_FOR_TESTS);
497     EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
498     int client_ret = client_socket_->Read(
499         read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
500     EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
501 
502     server_ret = write_callback.GetResult(server_ret);
503     EXPECT_GT(server_ret, 0);
504     client_ret = read_callback.GetResult(client_ret);
505     ASSERT_GT(client_ret, 0);
506   }
507 
508   std::unique_ptr<FakeDataChannel> channel_1_;
509   std::unique_ptr<FakeDataChannel> channel_2_;
510   std::unique_ptr<TestSSLConfigService> ssl_config_service_;
511   std::unique_ptr<MockCertVerifier> cert_verifier_;
512   std::unique_ptr<MockClientCertVerifier> client_cert_verifier_;
513   SSLConfig client_ssl_config_;
514   // Note that this has a pointer to the `cert_verifier_`, so must be destroyed
515   // before that is.
516   SSLServerConfig server_ssl_config_;
517   std::unique_ptr<TransportSecurityState> transport_security_state_;
518   std::unique_ptr<SSLClientSessionCache> ssl_client_session_cache_;
519   std::unique_ptr<SSLClientContext> client_context_;
520   std::unique_ptr<SSLServerContext> server_context_;
521   std::unique_ptr<SSLClientSocket> client_socket_;
522   std::unique_ptr<SSLServerSocket> server_socket_;
523   std::unique_ptr<crypto::RSAPrivateKey> server_private_key_;
524   scoped_refptr<SSLPrivateKey> server_ssl_private_key_;
525   scoped_refptr<X509Certificate> server_cert_;
526 };
527 
528 class SSLServerSocketReadTest : public SSLServerSocketTest,
529                                 public ::testing::WithParamInterface<bool> {
530  protected:
SSLServerSocketReadTest()531   SSLServerSocketReadTest() : read_if_ready_enabled_(GetParam()) {}
532 
Read(StreamSocket * socket,IOBuffer * buf,int buf_len,CompletionOnceCallback callback)533   int Read(StreamSocket* socket,
534            IOBuffer* buf,
535            int buf_len,
536            CompletionOnceCallback callback) {
537     if (read_if_ready_enabled()) {
538       return socket->ReadIfReady(buf, buf_len, std::move(callback));
539     }
540     return socket->Read(buf, buf_len, std::move(callback));
541   }
542 
read_if_ready_enabled() const543   bool read_if_ready_enabled() const { return read_if_ready_enabled_; }
544 
545  private:
546   const bool read_if_ready_enabled_;
547 };
548 
549 INSTANTIATE_TEST_SUITE_P(/* no prefix */,
550                          SSLServerSocketReadTest,
551                          ::testing::Bool());
552 
553 // This test only executes creation of client and server sockets. This is to
554 // test that creation of sockets doesn't crash and have minimal code to run
555 // with memory leak/corruption checking tools.
TEST_F(SSLServerSocketTest,Initialize)556 TEST_F(SSLServerSocketTest, Initialize) {
557   ASSERT_NO_FATAL_FAILURE(CreateContext());
558   ASSERT_NO_FATAL_FAILURE(CreateSockets());
559 }
560 
561 // This test executes Connect() on SSLClientSocket and Handshake() on
562 // SSLServerSocket to make sure handshaking between the two sockets is
563 // completed successfully.
TEST_F(SSLServerSocketTest,Handshake)564 TEST_F(SSLServerSocketTest, Handshake) {
565   ASSERT_NO_FATAL_FAILURE(CreateContext());
566   ASSERT_NO_FATAL_FAILURE(CreateSockets());
567 
568   TestCompletionCallback handshake_callback;
569   int server_ret = server_socket_->Handshake(handshake_callback.callback());
570 
571   TestCompletionCallback connect_callback;
572   int client_ret = client_socket_->Connect(connect_callback.callback());
573 
574   client_ret = connect_callback.GetResult(client_ret);
575   server_ret = handshake_callback.GetResult(server_ret);
576 
577   ASSERT_THAT(client_ret, IsOk());
578   ASSERT_THAT(server_ret, IsOk());
579 
580   // Make sure the cert status is expected.
581   SSLInfo ssl_info;
582   ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
583   EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);
584 
585   // The default cipher suite should be ECDHE and an AEAD.
586   uint16_t cipher_suite =
587       SSLConnectionStatusToCipherSuite(ssl_info.connection_status);
588   const char* key_exchange;
589   const char* cipher;
590   const char* mac;
591   bool is_aead;
592   bool is_tls13;
593   SSLCipherSuiteToStrings(&key_exchange, &cipher, &mac, &is_aead, &is_tls13,
594                           cipher_suite);
595   EXPECT_TRUE(is_aead);
596 }
597 
598 // This test makes sure the session cache is working.
TEST_F(SSLServerSocketTest,HandshakeCached)599 TEST_F(SSLServerSocketTest, HandshakeCached) {
600   ASSERT_NO_FATAL_FAILURE(CreateContext());
601   ASSERT_NO_FATAL_FAILURE(CreateSockets());
602 
603   TestCompletionCallback handshake_callback;
604   int server_ret = server_socket_->Handshake(handshake_callback.callback());
605 
606   TestCompletionCallback connect_callback;
607   int client_ret = client_socket_->Connect(connect_callback.callback());
608 
609   client_ret = connect_callback.GetResult(client_ret);
610   server_ret = handshake_callback.GetResult(server_ret);
611 
612   ASSERT_THAT(client_ret, IsOk());
613   ASSERT_THAT(server_ret, IsOk());
614 
615   // Make sure the cert status is expected.
616   SSLInfo ssl_info;
617   ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
618   EXPECT_EQ(ssl_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
619   SSLInfo ssl_server_info;
620   ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info));
621   EXPECT_EQ(ssl_server_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
622 
623   // Pump client read to get new session tickets.
624   PumpServerToClient();
625 
626   // Make sure the second connection is cached.
627   ASSERT_NO_FATAL_FAILURE(CreateSockets());
628   TestCompletionCallback handshake_callback2;
629   int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());
630 
631   TestCompletionCallback connect_callback2;
632   int client_ret2 = client_socket_->Connect(connect_callback2.callback());
633 
634   client_ret2 = connect_callback2.GetResult(client_ret2);
635   server_ret2 = handshake_callback2.GetResult(server_ret2);
636 
637   ASSERT_THAT(client_ret2, IsOk());
638   ASSERT_THAT(server_ret2, IsOk());
639 
640   // Make sure the cert status is expected.
641   SSLInfo ssl_info2;
642   ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info2));
643   EXPECT_EQ(ssl_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME);
644   SSLInfo ssl_server_info2;
645   ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info2));
646   EXPECT_EQ(ssl_server_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME);
647 }
648 
649 // This test makes sure the session cache separates out by server context.
TEST_F(SSLServerSocketTest,HandshakeCachedContextSwitch)650 TEST_F(SSLServerSocketTest, HandshakeCachedContextSwitch) {
651   ASSERT_NO_FATAL_FAILURE(CreateContext());
652   ASSERT_NO_FATAL_FAILURE(CreateSockets());
653 
654   TestCompletionCallback handshake_callback;
655   int server_ret = server_socket_->Handshake(handshake_callback.callback());
656 
657   TestCompletionCallback connect_callback;
658   int client_ret = client_socket_->Connect(connect_callback.callback());
659 
660   client_ret = connect_callback.GetResult(client_ret);
661   server_ret = handshake_callback.GetResult(server_ret);
662 
663   ASSERT_THAT(client_ret, IsOk());
664   ASSERT_THAT(server_ret, IsOk());
665 
666   // Make sure the cert status is expected.
667   SSLInfo ssl_info;
668   ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
669   EXPECT_EQ(ssl_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
670   SSLInfo ssl_server_info;
671   ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info));
672   EXPECT_EQ(ssl_server_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
673 
674   // Make sure the second connection is NOT cached when using a new context.
675   ASSERT_NO_FATAL_FAILURE(CreateContext());
676   ASSERT_NO_FATAL_FAILURE(CreateSockets());
677 
678   TestCompletionCallback handshake_callback2;
679   int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());
680 
681   TestCompletionCallback connect_callback2;
682   int client_ret2 = client_socket_->Connect(connect_callback2.callback());
683 
684   client_ret2 = connect_callback2.GetResult(client_ret2);
685   server_ret2 = handshake_callback2.GetResult(server_ret2);
686 
687   ASSERT_THAT(client_ret2, IsOk());
688   ASSERT_THAT(server_ret2, IsOk());
689 
690   // Make sure the cert status is expected.
691   SSLInfo ssl_info2;
692   ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info2));
693   EXPECT_EQ(ssl_info2.handshake_type, SSLInfo::HANDSHAKE_FULL);
694   SSLInfo ssl_server_info2;
695   ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info2));
696   EXPECT_EQ(ssl_server_info2.handshake_type, SSLInfo::HANDSHAKE_FULL);
697 }
698 
699 // Client certificates are disabled on iOS.
700 #if BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
701 // This test executes Connect() on SSLClientSocket and Handshake() on
702 // SSLServerSocket to make sure handshaking between the two sockets is
703 // completed successfully, using client certificate.
TEST_F(SSLServerSocketTest,HandshakeWithClientCert)704 TEST_F(SSLServerSocketTest, HandshakeWithClientCert) {
705   scoped_refptr<X509Certificate> client_cert =
706       ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
707   ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
708       kClientCertFileName, kClientPrivateKeyFileName));
709   ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
710   ASSERT_NO_FATAL_FAILURE(CreateContext());
711   ASSERT_NO_FATAL_FAILURE(CreateSockets());
712 
713   TestCompletionCallback handshake_callback;
714   int server_ret = server_socket_->Handshake(handshake_callback.callback());
715 
716   TestCompletionCallback connect_callback;
717   int client_ret = client_socket_->Connect(connect_callback.callback());
718 
719   client_ret = connect_callback.GetResult(client_ret);
720   server_ret = handshake_callback.GetResult(server_ret);
721 
722   ASSERT_THAT(client_ret, IsOk());
723   ASSERT_THAT(server_ret, IsOk());
724 
725   // Make sure the cert status is expected.
726   SSLInfo ssl_info;
727   client_socket_->GetSSLInfo(&ssl_info);
728   EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);
729   server_socket_->GetSSLInfo(&ssl_info);
730   ASSERT_TRUE(ssl_info.cert.get());
731   EXPECT_TRUE(client_cert->EqualsExcludingChain(ssl_info.cert.get()));
732 }
733 
734 // This test executes Connect() on SSLClientSocket and Handshake() twice on
735 // SSLServerSocket to make sure handshaking between the two sockets is
736 // completed successfully, using client certificate. The second connection is
737 // expected to succeed through the session cache.
TEST_F(SSLServerSocketTest,HandshakeWithClientCertCached)738 TEST_F(SSLServerSocketTest, HandshakeWithClientCertCached) {
739   scoped_refptr<X509Certificate> client_cert =
740       ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
741   ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
742       kClientCertFileName, kClientPrivateKeyFileName));
743   ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
744   ASSERT_NO_FATAL_FAILURE(CreateContext());
745   ASSERT_NO_FATAL_FAILURE(CreateSockets());
746 
747   TestCompletionCallback handshake_callback;
748   int server_ret = server_socket_->Handshake(handshake_callback.callback());
749 
750   TestCompletionCallback connect_callback;
751   int client_ret = client_socket_->Connect(connect_callback.callback());
752 
753   client_ret = connect_callback.GetResult(client_ret);
754   server_ret = handshake_callback.GetResult(server_ret);
755 
756   ASSERT_THAT(client_ret, IsOk());
757   ASSERT_THAT(server_ret, IsOk());
758 
759   // Make sure the cert status is expected.
760   SSLInfo ssl_info;
761   ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
762   EXPECT_EQ(ssl_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
763   SSLInfo ssl_server_info;
764   ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info));
765   ASSERT_TRUE(ssl_server_info.cert.get());
766   EXPECT_TRUE(client_cert->EqualsExcludingChain(ssl_server_info.cert.get()));
767   EXPECT_EQ(ssl_server_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
768   // Pump client read to get new session tickets.
769   PumpServerToClient();
770   server_socket_->Disconnect();
771   client_socket_->Disconnect();
772 
773   // Create the connection again.
774   ASSERT_NO_FATAL_FAILURE(CreateSockets());
775   TestCompletionCallback handshake_callback2;
776   int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());
777 
778   TestCompletionCallback connect_callback2;
779   int client_ret2 = client_socket_->Connect(connect_callback2.callback());
780 
781   client_ret2 = connect_callback2.GetResult(client_ret2);
782   server_ret2 = handshake_callback2.GetResult(server_ret2);
783 
784   ASSERT_THAT(client_ret2, IsOk());
785   ASSERT_THAT(server_ret2, IsOk());
786 
787   // Make sure the cert status is expected.
788   SSLInfo ssl_info2;
789   ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info2));
790   EXPECT_EQ(ssl_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME);
791   SSLInfo ssl_server_info2;
792   ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info2));
793   ASSERT_TRUE(ssl_server_info2.cert.get());
794   EXPECT_TRUE(client_cert->EqualsExcludingChain(ssl_server_info2.cert.get()));
795   EXPECT_EQ(ssl_server_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME);
796 }
797 
// The server requires a client certificate but the client (left at its default
// configuration) does not send one.
TEST_F(SSLServerSocketTest, HandshakeWithClientCertRequiredNotSupplied) {
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());
  // The client socket keeps its default settings, so it offers no client
  // certificate. Connect() therefore stops with
  // ERR_SSL_CLIENT_AUTH_CERT_NEEDED, which lets the test look at the
  // cert_authorities carried in the server's CertificateRequest.

  TestCompletionCallback server_cb;
  const int pending_server_rv = server_socket_->Handshake(server_cb.callback());

  TestCompletionCallback client_cb;
  const int client_rv = client_socket_->Connect(client_cb.callback());
  EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED, client_cb.GetResult(client_rv));

  auto cert_request = base::MakeRefCounted<SSLCertRequestInfo>();
  client_socket_->GetSSLCertRequestInfo(cert_request.get());

  // The authorities named in the CertificateRequest must include the issuer
  // of the test client certificate.
  scoped_refptr<X509Certificate> issued_cert =
      ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
  ASSERT_TRUE(issued_cert);
  EXPECT_TRUE(issued_cert->IsIssuedByEncoded(cert_request->cert_authorities));

  // Dropping the client makes the still-pending server handshake fail.
  client_socket_->Disconnect();
  EXPECT_THAT(server_cb.GetResult(pending_server_rv),
              IsError(ERR_CONNECTION_CLOSED));
}
831 
TEST_F(SSLServerSocketTest,HandshakeWithClientCertRequiredNotSuppliedCached)832 TEST_F(SSLServerSocketTest, HandshakeWithClientCertRequiredNotSuppliedCached) {
833   ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
834   ASSERT_NO_FATAL_FAILURE(CreateContext());
835   ASSERT_NO_FATAL_FAILURE(CreateSockets());
836   // Use the default setting for the client socket, which is to not send
837   // a client certificate. This will cause the client to receive an
838   // ERR_SSL_CLIENT_AUTH_CERT_NEEDED error, and allow for inspecting the
839   // requested cert_authorities from the CertificateRequest sent by the
840   // server.
841 
842   TestCompletionCallback handshake_callback;
843   int server_ret = server_socket_->Handshake(handshake_callback.callback());
844 
845   TestCompletionCallback connect_callback;
846   EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED,
847             connect_callback.GetResult(
848                 client_socket_->Connect(connect_callback.callback())));
849 
850   auto request_info = base::MakeRefCounted<SSLCertRequestInfo>();
851   client_socket_->GetSSLCertRequestInfo(request_info.get());
852 
853   // Check that the authority name that arrived in the CertificateRequest
854   // handshake message is as expected.
855   scoped_refptr<X509Certificate> client_cert =
856       ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
857   ASSERT_TRUE(client_cert);
858   EXPECT_TRUE(client_cert->IsIssuedByEncoded(request_info->cert_authorities));
859 
860   client_socket_->Disconnect();
861 
862   EXPECT_THAT(handshake_callback.GetResult(server_ret),
863               IsError(ERR_CONNECTION_CLOSED));
864   server_socket_->Disconnect();
865 
866   // Below, check that the cache didn't store the result of a failed handshake.
867   ASSERT_NO_FATAL_FAILURE(CreateSockets());
868   TestCompletionCallback handshake_callback2;
869   int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());
870 
871   TestCompletionCallback connect_callback2;
872   EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED,
873             connect_callback2.GetResult(
874                 client_socket_->Connect(connect_callback2.callback())));
875 
876   auto request_info2 = base::MakeRefCounted<SSLCertRequestInfo>();
877   client_socket_->GetSSLCertRequestInfo(request_info2.get());
878 
879   // Check that the authority name that arrived in the CertificateRequest
880   // handshake message is as expected.
881   EXPECT_TRUE(client_cert->IsIssuedByEncoded(request_info2->cert_authorities));
882 
883   client_socket_->Disconnect();
884 
885   EXPECT_THAT(handshake_callback2.GetResult(server_ret2),
886               IsError(ERR_CONNECTION_CLOSED));
887 }
888 
// The client presents the "wrong" certificate/key pair, and the server rejects
// it with ERR_BAD_SSL_CLIENT_AUTH_CERT.
TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSupplied) {
  scoped_refptr<X509Certificate> expected_cert =
      ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
  ASSERT_TRUE(expected_cert);

  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
      kWrongClientCertFileName, kWrongClientPrivateKeyFileName));
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback server_cb;
  const int server_rv = server_socket_->Handshake(server_cb.callback());

  TestCompletionCallback client_cb;
  int client_rv = client_socket_->Connect(client_cb.callback());

  // In TLS 1.3, the client cert error isn't exposed until Read is called.
  EXPECT_EQ(OK, client_cb.GetResult(client_rv));
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, server_cb.GetResult(server_rv));

  // Pump client read to get client cert error.
  constexpr int kBufferSize = 1024;
  scoped_refptr<DrainableIOBuffer> drain_buf =
      base::MakeRefCounted<DrainableIOBuffer>(
          base::MakeRefCounted<IOBufferWithSize>(kBufferSize), kBufferSize);
  TestCompletionCallback read_cb;
  client_rv = read_cb.GetResult(client_socket_->Read(
      drain_buf.get(), drain_buf->BytesRemaining(), read_cb.callback()));
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, client_rv);
}
922 
// Like HandshakeWithWrongClientCertSupplied, but the client is capped at
// TLS 1.2. Here the rejection is reported by Connect() itself rather than by a
// later Read().
TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSuppliedTLS12) {
  scoped_refptr<X509Certificate> expected_cert =
      ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
  ASSERT_TRUE(expected_cert);

  client_ssl_config_.version_max_override = SSL_PROTOCOL_VERSION_TLS1_2;
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
      kWrongClientCertFileName, kWrongClientPrivateKeyFileName));
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback server_cb;
  const int server_rv = server_socket_->Handshake(server_cb.callback());

  TestCompletionCallback client_cb;
  const int client_rv = client_socket_->Connect(client_cb.callback());

  // Both ends observe the bad client certificate during the handshake.
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, client_cb.GetResult(client_rv));
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, server_cb.GetResult(server_rv));
}
946 
TEST_F(SSLServerSocketTest,HandshakeWithWrongClientCertSuppliedCached)947 TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSuppliedCached) {
948   scoped_refptr<X509Certificate> client_cert =
949       ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
950   ASSERT_TRUE(client_cert);
951 
952   ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
953       kWrongClientCertFileName, kWrongClientPrivateKeyFileName));
954   ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
955   ASSERT_NO_FATAL_FAILURE(CreateContext());
956   ASSERT_NO_FATAL_FAILURE(CreateSockets());
957 
958   TestCompletionCallback handshake_callback;
959   int server_ret = server_socket_->Handshake(handshake_callback.callback());
960 
961   TestCompletionCallback connect_callback;
962   int client_ret = client_socket_->Connect(connect_callback.callback());
963 
964   // In TLS 1.3, the client cert error isn't exposed until Read is called.
965   EXPECT_EQ(OK, connect_callback.GetResult(client_ret));
966   EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
967             handshake_callback.GetResult(server_ret));
968 
969   // Pump client read to get client cert error.
970   const int kReadBufSize = 1024;
971   scoped_refptr<DrainableIOBuffer> read_buf =
972       base::MakeRefCounted<DrainableIOBuffer>(
973           base::MakeRefCounted<IOBufferWithSize>(kReadBufSize), kReadBufSize);
974   TestCompletionCallback read_callback;
975   client_ret = client_socket_->Read(read_buf.get(), read_buf->BytesRemaining(),
976                                     read_callback.callback());
977   client_ret = read_callback.GetResult(client_ret);
978   EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, client_ret);
979 
980   client_socket_->Disconnect();
981   server_socket_->Disconnect();
982 
983   // Below, check that the cache didn't store the result of a failed handshake.
984   ASSERT_NO_FATAL_FAILURE(CreateSockets());
985   TestCompletionCallback handshake_callback2;
986   int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());
987 
988   TestCompletionCallback connect_callback2;
989   int client_ret2 = client_socket_->Connect(connect_callback2.callback());
990 
991   // In TLS 1.3, the client cert error isn't exposed until Read is called.
992   EXPECT_EQ(OK, connect_callback2.GetResult(client_ret2));
993   EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
994             handshake_callback2.GetResult(server_ret2));
995 
996   // Pump client read to get client cert error.
997   client_ret = client_socket_->Read(read_buf.get(), read_buf->BytesRemaining(),
998                                     read_callback.callback());
999   client_ret = read_callback.GetResult(client_ret);
1000   EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, client_ret);
1001 }
1002 #endif  // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
1003 
TEST_P(SSLServerSocketReadTest,DataTransfer)1004 TEST_P(SSLServerSocketReadTest, DataTransfer) {
1005   ASSERT_NO_FATAL_FAILURE(CreateContext());
1006   ASSERT_NO_FATAL_FAILURE(CreateSockets());
1007 
1008   // Establish connection.
1009   TestCompletionCallback connect_callback;
1010   int client_ret = client_socket_->Connect(connect_callback.callback());
1011   ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
1012 
1013   TestCompletionCallback handshake_callback;
1014   int server_ret = server_socket_->Handshake(handshake_callback.callback());
1015   ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
1016 
1017   client_ret = connect_callback.GetResult(client_ret);
1018   ASSERT_THAT(client_ret, IsOk());
1019   server_ret = handshake_callback.GetResult(server_ret);
1020   ASSERT_THAT(server_ret, IsOk());
1021 
1022   const int kReadBufSize = 1024;
1023   scoped_refptr<StringIOBuffer> write_buf =
1024       base::MakeRefCounted<StringIOBuffer>("testing123");
1025   scoped_refptr<DrainableIOBuffer> read_buf =
1026       base::MakeRefCounted<DrainableIOBuffer>(
1027           base::MakeRefCounted<IOBufferWithSize>(kReadBufSize), kReadBufSize);
1028 
1029   // Write then read.
1030   TestCompletionCallback write_callback;
1031   TestCompletionCallback read_callback;
1032   server_ret = server_socket_->Write(write_buf.get(), write_buf->size(),
1033                                      write_callback.callback(),
1034                                      TRAFFIC_ANNOTATION_FOR_TESTS);
1035   EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
1036   client_ret = client_socket_->Read(
1037       read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
1038   EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
1039 
1040   server_ret = write_callback.GetResult(server_ret);
1041   EXPECT_GT(server_ret, 0);
1042   client_ret = read_callback.GetResult(client_ret);
1043   ASSERT_GT(client_ret, 0);
1044 
1045   read_buf->DidConsume(client_ret);
1046   while (read_buf->BytesConsumed() < write_buf->size()) {
1047     client_ret = client_socket_->Read(
1048         read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
1049     EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
1050     client_ret = read_callback.GetResult(client_ret);
1051     ASSERT_GT(client_ret, 0);
1052     read_buf->DidConsume(client_ret);
1053   }
1054   EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed());
1055   read_buf->SetOffset(0);
1056   EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));
1057 
1058   // Read then write.
1059   write_buf = base::MakeRefCounted<StringIOBuffer>("hello123");
1060   server_ret = Read(server_socket_.get(), read_buf.get(),
1061                     read_buf->BytesRemaining(), read_callback.callback());
1062   EXPECT_EQ(server_ret, ERR_IO_PENDING);
1063   client_ret = client_socket_->Write(write_buf.get(), write_buf->size(),
1064                                      write_callback.callback(),
1065                                      TRAFFIC_ANNOTATION_FOR_TESTS);
1066   EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
1067 
1068   server_ret = read_callback.GetResult(server_ret);
1069   if (read_if_ready_enabled()) {
1070     // ReadIfReady signals the data is available but does not consume it.
1071     // The data is consumed later below.
1072     ASSERT_EQ(server_ret, OK);
1073   } else {
1074     ASSERT_GT(server_ret, 0);
1075     read_buf->DidConsume(server_ret);
1076   }
1077   client_ret = write_callback.GetResult(client_ret);
1078   EXPECT_GT(client_ret, 0);
1079 
1080   while (read_buf->BytesConsumed() < write_buf->size()) {
1081     server_ret = Read(server_socket_.get(), read_buf.get(),
1082                       read_buf->BytesRemaining(), read_callback.callback());
1083     // All the data was written above, so the data should be synchronously
1084     // available out of both Read() and ReadIfReady().
1085     ASSERT_GT(server_ret, 0);
1086     read_buf->DidConsume(server_ret);
1087   }
1088   EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed());
1089   read_buf->SetOffset(0);
1090   EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));
1091 }
1092 
1093 // A regression test for bug 127822 (http://crbug.com/127822).
1094 // If the server closes the connection after the handshake is finished,
1095 // the client's Write() call should not cause an infinite loop.
1096 // NOTE: this is a test for SSLClientSocket rather than SSLServerSocket.
TEST_F(SSLServerSocketTest,ClientWriteAfterServerClose)1097 TEST_F(SSLServerSocketTest, ClientWriteAfterServerClose) {
1098   ASSERT_NO_FATAL_FAILURE(CreateContext());
1099   ASSERT_NO_FATAL_FAILURE(CreateSockets());
1100 
1101   // Establish connection.
1102   TestCompletionCallback connect_callback;
1103   int client_ret = client_socket_->Connect(connect_callback.callback());
1104   ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
1105 
1106   TestCompletionCallback handshake_callback;
1107   int server_ret = server_socket_->Handshake(handshake_callback.callback());
1108   ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
1109 
1110   client_ret = connect_callback.GetResult(client_ret);
1111   ASSERT_THAT(client_ret, IsOk());
1112   server_ret = handshake_callback.GetResult(server_ret);
1113   ASSERT_THAT(server_ret, IsOk());
1114 
1115   scoped_refptr<StringIOBuffer> write_buf =
1116       base::MakeRefCounted<StringIOBuffer>("testing123");
1117 
1118   // The server closes the connection. The server needs to write some
1119   // data first so that the client's Read() calls from the transport
1120   // socket won't return ERR_IO_PENDING.  This ensures that the client
1121   // will call Read() on the transport socket again.
1122   TestCompletionCallback write_callback;
1123   server_ret = server_socket_->Write(write_buf.get(), write_buf->size(),
1124                                      write_callback.callback(),
1125                                      TRAFFIC_ANNOTATION_FOR_TESTS);
1126   EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
1127 
1128   server_ret = write_callback.GetResult(server_ret);
1129   EXPECT_GT(server_ret, 0);
1130 
1131   server_socket_->Disconnect();
1132 
1133   // The client writes some data. This should not cause an infinite loop.
1134   client_ret = client_socket_->Write(write_buf.get(), write_buf->size(),
1135                                      write_callback.callback(),
1136                                      TRAFFIC_ANNOTATION_FOR_TESTS);
1137   EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
1138 
1139   client_ret = write_callback.GetResult(client_ret);
1140   EXPECT_GT(client_ret, 0);
1141 
1142   base::RunLoop run_loop;
1143   base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask(
1144       FROM_HERE, run_loop.QuitClosure(), base::Milliseconds(10));
1145   run_loop.Run();
1146 }
1147 
1148 // This test executes ExportKeyingMaterial() on the client and server sockets,
1149 // after connecting them, and verifies that the results match.
1150 // This test will fail if False Start is enabled (see crbug.com/90208).
TEST_F(SSLServerSocketTest,ExportKeyingMaterial)1151 TEST_F(SSLServerSocketTest, ExportKeyingMaterial) {
1152   ASSERT_NO_FATAL_FAILURE(CreateContext());
1153   ASSERT_NO_FATAL_FAILURE(CreateSockets());
1154 
1155   TestCompletionCallback connect_callback;
1156   int client_ret = client_socket_->Connect(connect_callback.callback());
1157   ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
1158 
1159   TestCompletionCallback handshake_callback;
1160   int server_ret = server_socket_->Handshake(handshake_callback.callback());
1161   ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
1162 
1163   if (client_ret == ERR_IO_PENDING) {
1164     ASSERT_THAT(connect_callback.WaitForResult(), IsOk());
1165   }
1166   if (server_ret == ERR_IO_PENDING) {
1167     ASSERT_THAT(handshake_callback.WaitForResult(), IsOk());
1168   }
1169 
1170   const int kKeyingMaterialSize = 32;
1171   const char kKeyingLabel[] = "EXPERIMENTAL-server-socket-test";
1172   std::array<uint8_t, kKeyingMaterialSize> server_out;
1173   int rv = server_socket_->ExportKeyingMaterial(kKeyingLabel, std::nullopt,
1174                                                 server_out);
1175   ASSERT_THAT(rv, IsOk());
1176 
1177   std::array<uint8_t, kKeyingMaterialSize> client_out;
1178   rv = client_socket_->ExportKeyingMaterial(kKeyingLabel, std::nullopt,
1179                                             client_out);
1180   ASSERT_THAT(rv, IsOk());
1181   EXPECT_EQ(server_out, client_out);
1182 
1183   const char kKeyingLabelBad[] = "EXPERIMENTAL-server-socket-test-bad";
1184   std::array<uint8_t, kKeyingMaterialSize> client_bad;
1185   rv = client_socket_->ExportKeyingMaterial(kKeyingLabelBad, std::nullopt,
1186                                             client_bad);
1187   ASSERT_EQ(rv, OK);
1188   EXPECT_NE(server_out, client_bad);
1189 }
1190 
1191 // Verifies that SSLConfig::require_ecdhe flags works properly.
TEST_F(SSLServerSocketTest,RequireEcdheFlag)1192 TEST_F(SSLServerSocketTest, RequireEcdheFlag) {
1193   // Disable all ECDHE suites on the client side.
1194   SSLContextConfig config;
1195   config.disabled_cipher_suites.assign(
1196       kEcdheCiphers, kEcdheCiphers + std::size(kEcdheCiphers));
1197 
1198   // Legacy RSA key exchange ciphers only exist in TLS 1.2 and below.
1199   config.version_max = SSL_PROTOCOL_VERSION_TLS1_2;
1200   ssl_config_service_->UpdateSSLConfigAndNotify(config);
1201 
1202   // Require ECDHE on the server.
1203   server_ssl_config_.require_ecdhe = true;
1204 
1205   ASSERT_NO_FATAL_FAILURE(CreateContext());
1206   ASSERT_NO_FATAL_FAILURE(CreateSockets());
1207 
1208   TestCompletionCallback connect_callback;
1209   int client_ret = client_socket_->Connect(connect_callback.callback());
1210 
1211   TestCompletionCallback handshake_callback;
1212   int server_ret = server_socket_->Handshake(handshake_callback.callback());
1213 
1214   client_ret = connect_callback.GetResult(client_ret);
1215   server_ret = handshake_callback.GetResult(server_ret);
1216 
1217   ASSERT_THAT(client_ret, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
1218   ASSERT_THAT(server_ret, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
1219 }
1220 
1221 // This test executes Connect() on SSLClientSocket and Handshake() on
1222 // SSLServerSocket to make sure handshaking between the two sockets is
1223 // completed successfully. The server key is represented by SSLPrivateKey.
TEST_F(SSLServerSocketTest,HandshakeServerSSLPrivateKey)1224 TEST_F(SSLServerSocketTest, HandshakeServerSSLPrivateKey) {
1225   ASSERT_NO_FATAL_FAILURE(CreateContextSSLPrivateKey());
1226   ASSERT_NO_FATAL_FAILURE(CreateSockets());
1227 
1228   TestCompletionCallback handshake_callback;
1229   int server_ret = server_socket_->Handshake(handshake_callback.callback());
1230 
1231   TestCompletionCallback connect_callback;
1232   int client_ret = client_socket_->Connect(connect_callback.callback());
1233 
1234   client_ret = connect_callback.GetResult(client_ret);
1235   server_ret = handshake_callback.GetResult(server_ret);
1236 
1237   ASSERT_THAT(client_ret, IsOk());
1238   ASSERT_THAT(server_ret, IsOk());
1239 
1240   // Make sure the cert status is expected.
1241   SSLInfo ssl_info;
1242   ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
1243   EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);
1244 
1245   // The default cipher suite should be ECDHE and an AEAD.
1246   uint16_t cipher_suite =
1247       SSLConnectionStatusToCipherSuite(ssl_info.connection_status);
1248   const char* key_exchange;
1249   const char* cipher;
1250   const char* mac;
1251   bool is_aead;
1252   bool is_tls13;
1253   SSLCipherSuiteToStrings(&key_exchange, &cipher, &mac, &is_aead, &is_tls13,
1254                           cipher_suite);
1255   EXPECT_TRUE(is_aead);
1256 }
1257 
1258 namespace {
1259 
1260 // Helper that wraps an underlying SSLPrivateKey to allow the test to
1261 // do some work immediately before a `Sign()` operation is performed.
1262 class SSLPrivateKeyHook : public SSLPrivateKey {
1263  public:
SSLPrivateKeyHook(scoped_refptr<SSLPrivateKey> private_key,base::RepeatingClosure on_sign)1264   SSLPrivateKeyHook(scoped_refptr<SSLPrivateKey> private_key,
1265                     base::RepeatingClosure on_sign)
1266       : private_key_(std::move(private_key)), on_sign_(std::move(on_sign)) {}
1267 
1268   // SSLPrivateKey implementation.
GetProviderName()1269   std::string GetProviderName() override {
1270     return private_key_->GetProviderName();
1271   }
GetAlgorithmPreferences()1272   std::vector<uint16_t> GetAlgorithmPreferences() override {
1273     return private_key_->GetAlgorithmPreferences();
1274   }
Sign(uint16_t algorithm,base::span<const uint8_t> input,SignCallback callback)1275   void Sign(uint16_t algorithm,
1276             base::span<const uint8_t> input,
1277             SignCallback callback) override {
1278     on_sign_.Run();
1279     private_key_->Sign(algorithm, input, std::move(callback));
1280   }
1281 
1282  private:
1283   ~SSLPrivateKeyHook() override = default;
1284 
1285   const scoped_refptr<SSLPrivateKey> private_key_;
1286   const base::RepeatingClosure on_sign_;
1287 };
1288 
1289 }  // namespace
1290 
1291 // Verifies that if the client disconnects while during private key signing then
1292 // the disconnection is correctly reported to the `Handshake()` completion
1293 // callback, with `ERR_CONNECTION_CLOSED`.
1294 // This is a regression test for crbug.com/1449461.
TEST_F(SSLServerSocketTest,HandshakeServerSSLPrivateKeyDisconnectDuringSigning_ReturnsError)1295 TEST_F(SSLServerSocketTest,
1296        HandshakeServerSSLPrivateKeyDisconnectDuringSigning_ReturnsError) {
1297   auto on_sign = base::BindLambdaForTesting([&]() {
1298     client_socket_->Disconnect();
1299     ASSERT_FALSE(client_socket_->IsConnected());
1300   });
1301   server_ssl_private_key_ = base::MakeRefCounted<SSLPrivateKeyHook>(
1302       std::move(server_ssl_private_key_), on_sign);
1303   ASSERT_NO_FATAL_FAILURE(CreateContextSSLPrivateKey());
1304   ASSERT_NO_FATAL_FAILURE(CreateSockets());
1305 
1306   TestCompletionCallback handshake_callback;
1307   int server_ret = server_socket_->Handshake(handshake_callback.callback());
1308   ASSERT_EQ(server_ret, net::ERR_IO_PENDING);
1309 
1310   TestCompletionCallback connect_callback;
1311   client_socket_->Connect(connect_callback.callback());
1312 
1313   // If resuming the handshake after private-key signing is not handled
1314   // correctly as per crbug.com/1449461 then the test will hang and timeout
1315   // at this point, due to the server-side completion callback not being
1316   // correctly invoked.
1317   server_ret = handshake_callback.GetResult(server_ret);
1318   EXPECT_EQ(server_ret, net::ERR_CONNECTION_CLOSED);
1319 }
1320 
1321 // Verifies that non-ECDHE ciphers are disabled when using SSLPrivateKey as the
1322 // server key.
TEST_F(SSLServerSocketTest,HandshakeServerSSLPrivateKeyRequireEcdhe)1323 TEST_F(SSLServerSocketTest, HandshakeServerSSLPrivateKeyRequireEcdhe) {
1324   // Disable all ECDHE suites on the client side.
1325   SSLContextConfig config;
1326   config.disabled_cipher_suites.assign(
1327       kEcdheCiphers, kEcdheCiphers + std::size(kEcdheCiphers));
1328   // TLS 1.3 always works with SSLPrivateKey.
1329   config.version_max = SSL_PROTOCOL_VERSION_TLS1_2;
1330   ssl_config_service_->UpdateSSLConfigAndNotify(config);
1331 
1332   ASSERT_NO_FATAL_FAILURE(CreateContextSSLPrivateKey());
1333   ASSERT_NO_FATAL_FAILURE(CreateSockets());
1334 
1335   TestCompletionCallback connect_callback;
1336   int client_ret = client_socket_->Connect(connect_callback.callback());
1337 
1338   TestCompletionCallback handshake_callback;
1339   int server_ret = server_socket_->Handshake(handshake_callback.callback());
1340 
1341   client_ret = connect_callback.GetResult(client_ret);
1342   server_ret = handshake_callback.GetResult(server_ret);
1343 
1344   ASSERT_THAT(client_ret, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
1345   ASSERT_THAT(server_ret, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
1346 }
1347 
1348 class SSLServerSocketAlpsTest
1349     : public SSLServerSocketTest,
1350       public ::testing::WithParamInterface<std::tuple<bool, bool>> {
1351  public:
SSLServerSocketAlpsTest()1352   SSLServerSocketAlpsTest()
1353       : client_alps_enabled_(std::get<0>(GetParam())),
1354         server_alps_enabled_(std::get<1>(GetParam())) {}
1355   ~SSLServerSocketAlpsTest() override = default;
1356   const bool client_alps_enabled_;
1357   const bool server_alps_enabled_;
1358 };
1359 
1360 INSTANTIATE_TEST_SUITE_P(All,
1361                          SSLServerSocketAlpsTest,
1362                          ::testing::Combine(::testing::Bool(),
1363                                             ::testing::Bool()));
1364 
// Both peers negotiate HTTP/2. Each side attaches ALPS application settings
// only when its parameter enables it, and data should flow only when both do.
TEST_P(SSLServerSocketAlpsTest, Alps) {
  const std::string server_payload = "server sends some test data";
  const std::string client_payload = "client also sends some data";

  server_ssl_config_.alpn_protos = {kProtoHTTP2};
  if (server_alps_enabled_) {
    server_ssl_config_.application_settings[kProtoHTTP2] =
        std::vector<uint8_t>(server_payload.begin(), server_payload.end());
  }

  client_ssl_config_.alpn_protos = {kProtoHTTP2};
  if (client_alps_enabled_) {
    client_ssl_config_.application_settings[kProtoHTTP2] =
        std::vector<uint8_t>(client_payload.begin(), client_payload.end());
  }

  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback server_cb;
  int server_rv = server_socket_->Handshake(server_cb.callback());

  TestCompletionCallback client_cb;
  int client_rv = client_socket_->Connect(client_cb.callback());

  client_rv = client_cb.GetResult(client_rv);
  server_rv = server_cb.GetResult(server_rv);

  ASSERT_THAT(client_rv, IsOk());
  ASSERT_THAT(server_rv, IsOk());

  const auto seen_by_client = client_socket_->GetPeerApplicationSettings();
  const auto seen_by_server = server_socket_->GetPeerApplicationSettings();

  // ALPS is negotiated only if ALPS is enabled both on client and server.
  const bool alps_negotiated = client_alps_enabled_ && server_alps_enabled_;
  if (!alps_negotiated) {
    EXPECT_FALSE(seen_by_client.has_value());
    EXPECT_FALSE(seen_by_server.has_value());
    return;
  }

  ASSERT_TRUE(seen_by_client.has_value());
  EXPECT_EQ(server_payload, seen_by_client.value());
  ASSERT_TRUE(seen_by_server.has_value());
  EXPECT_EQ(client_payload, seen_by_server.value());
}
1412 
1413 // Test that CancelReadIfReady works.
TEST_F(SSLServerSocketTest,CancelReadIfReady)1414 TEST_F(SSLServerSocketTest, CancelReadIfReady) {
1415   ASSERT_NO_FATAL_FAILURE(CreateContext());
1416   ASSERT_NO_FATAL_FAILURE(CreateSockets());
1417 
1418   TestCompletionCallback connect_callback;
1419   int client_ret = client_socket_->Connect(connect_callback.callback());
1420   TestCompletionCallback handshake_callback;
1421   int server_ret = server_socket_->Handshake(handshake_callback.callback());
1422   ASSERT_THAT(connect_callback.GetResult(client_ret), IsOk());
1423   ASSERT_THAT(handshake_callback.GetResult(server_ret), IsOk());
1424 
1425   // Attempt to read from the server socket. There will not be anything to read.
1426   // Cancel the read immediately afterwards.
1427   TestCompletionCallback read_callback;
1428   auto read_buf = base::MakeRefCounted<IOBufferWithSize>(1);
1429   int read_ret =
1430       server_socket_->ReadIfReady(read_buf.get(), 1, read_callback.callback());
1431   ASSERT_THAT(read_ret, IsError(ERR_IO_PENDING));
1432   ASSERT_THAT(server_socket_->CancelReadIfReady(), IsOk());
1433 
1434   // After the client writes data, the server should still not pick up a result.
1435   auto write_buf = base::MakeRefCounted<StringIOBuffer>("a");
1436   TestCompletionCallback write_callback;
1437   ASSERT_EQ(write_callback.GetResult(client_socket_->Write(
1438                 write_buf.get(), write_buf->size(), write_callback.callback(),
1439                 TRAFFIC_ANNOTATION_FOR_TESTS)),
1440             write_buf->size());
1441   base::RunLoop().RunUntilIdle();
1442   EXPECT_FALSE(read_callback.have_result());
1443 
1444   // After a canceled read, future reads are still possible.
1445   while (true) {
1446     TestCompletionCallback read_callback2;
1447     read_ret = server_socket_->ReadIfReady(read_buf.get(), 1,
1448                                            read_callback2.callback());
1449     if (read_ret != ERR_IO_PENDING) {
1450       break;
1451     }
1452     ASSERT_THAT(read_callback2.GetResult(read_ret), IsOk());
1453   }
1454   ASSERT_EQ(1, read_ret);
1455   EXPECT_EQ(read_buf->data()[0], 'a');
1456 }
1457 
1458 }  // namespace net
1459