1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #ifdef UNSAFE_BUFFERS_BUILD
6 // TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
7 #pragma allow_unsafe_buffers
8 #endif
9
10 // This test suite uses SSLClientSocket to test the implementation of
11 // SSLServerSocket. In order to establish connections between the sockets
12 // we need two additional classes:
13 // 1. FakeSocket
14 // Connects SSL socket to FakeDataChannel. This class is just a stub.
15 //
16 // 2. FakeDataChannel
17 // Implements the actual exchange of data between two FakeSockets.
18 //
19 // Implementations of these two classes are included in this file.
20
21 #include "net/socket/ssl_server_socket.h"
22
23 #include <stdint.h>
24 #include <stdlib.h>
25
26 #include <memory>
27 #include <string_view>
28 #include <utility>
29
30 #include "base/check.h"
31 #include "base/compiler_specific.h"
32 #include "base/containers/queue.h"
33 #include "base/files/file_path.h"
34 #include "base/files/file_util.h"
35 #include "base/functional/bind.h"
36 #include "base/functional/callback_helpers.h"
37 #include "base/location.h"
38 #include "base/memory/raw_ptr.h"
39 #include "base/memory/scoped_refptr.h"
40 #include "base/notreached.h"
41 #include "base/run_loop.h"
42 #include "base/task/single_thread_task_runner.h"
43 #include "base/test/bind.h"
44 #include "base/test/task_environment.h"
45 #include "build/build_config.h"
46 #include "crypto/rsa_private_key.h"
47 #include "net/base/address_list.h"
48 #include "net/base/completion_once_callback.h"
49 #include "net/base/host_port_pair.h"
50 #include "net/base/io_buffer.h"
51 #include "net/base/ip_address.h"
52 #include "net/base/ip_endpoint.h"
53 #include "net/base/net_errors.h"
54 #include "net/cert/cert_status_flags.h"
55 #include "net/cert/mock_cert_verifier.h"
56 #include "net/cert/mock_client_cert_verifier.h"
57 #include "net/cert/signed_certificate_timestamp_and_status.h"
58 #include "net/cert/x509_certificate.h"
59 #include "net/http/transport_security_state.h"
60 #include "net/log/net_log_with_source.h"
61 #include "net/socket/client_socket_factory.h"
62 #include "net/socket/socket_test_util.h"
63 #include "net/socket/ssl_client_socket.h"
64 #include "net/socket/stream_socket.h"
65 #include "net/ssl/openssl_private_key.h"
66 #include "net/ssl/ssl_cert_request_info.h"
67 #include "net/ssl/ssl_cipher_suite_names.h"
68 #include "net/ssl/ssl_client_session_cache.h"
69 #include "net/ssl/ssl_connection_status_flags.h"
70 #include "net/ssl/ssl_info.h"
71 #include "net/ssl/ssl_private_key.h"
72 #include "net/ssl/ssl_server_config.h"
73 #include "net/ssl/test_ssl_config_service.h"
74 #include "net/test/cert_test_util.h"
75 #include "net/test/gtest_util.h"
76 #include "net/test/test_data_directory.h"
77 #include "net/test/test_with_task_environment.h"
78 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
79 #include "testing/gmock/include/gmock/gmock.h"
80 #include "testing/gtest/include/gtest/gtest.h"
81 #include "testing/platform_test.h"
82 #include "third_party/boringssl/src/include/openssl/evp.h"
83 #include "third_party/boringssl/src/include/openssl/ssl.h"
84
85 using net::test::IsError;
86 using net::test::IsOk;
87
88 namespace net {
89
90 namespace {
91
92 // Client certificates are disabled on iOS.
93 #if BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
94 const char kClientCertFileName[] = "client_1.pem";
95 const char kClientPrivateKeyFileName[] = "client_1.pk8";
96 const char kWrongClientCertFileName[] = "client_2.pem";
97 const char kWrongClientPrivateKeyFileName[] = "client_2.pk8";
98 #endif // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
99
// IANA cipher-suite code points for every ECDHE-based suite listed here,
// covering both ECDSA and RSA authentication.
constexpr uint16_t kEcdheCiphers[] = {
    0xc007,  // ECDHE_ECDSA_WITH_RC4_128_SHA
    0xc009,  // ECDHE_ECDSA_WITH_AES_128_CBC_SHA
    0xc00a,  // ECDHE_ECDSA_WITH_AES_256_CBC_SHA
    0xc011,  // ECDHE_RSA_WITH_RC4_128_SHA
    0xc013,  // ECDHE_RSA_WITH_AES_128_CBC_SHA
    0xc014,  // ECDHE_RSA_WITH_AES_256_CBC_SHA
    0xc02b,  // ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
    0xc02c,  // ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
    0xc02f,  // ECDHE_RSA_WITH_AES_128_GCM_SHA256
    0xc030,  // ECDHE_RSA_WITH_AES_256_GCM_SHA384
    0xcca8,  // ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
    0xcca9,  // ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
};
114
115 class FakeDataChannel {
116 public:
117 FakeDataChannel() = default;
118
119 FakeDataChannel(const FakeDataChannel&) = delete;
120 FakeDataChannel& operator=(const FakeDataChannel&) = delete;
121
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)122 int Read(IOBuffer* buf, int buf_len, CompletionOnceCallback callback) {
123 DCHECK(read_callback_.is_null());
124 DCHECK(!read_buf_.get());
125 if (closed_)
126 return 0;
127 if (data_.empty()) {
128 read_callback_ = std::move(callback);
129 read_buf_ = buf;
130 read_buf_len_ = buf_len;
131 return ERR_IO_PENDING;
132 }
133 return PropagateData(buf, buf_len);
134 }
135
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)136 int Write(IOBuffer* buf,
137 int buf_len,
138 CompletionOnceCallback callback,
139 const NetworkTrafficAnnotationTag& traffic_annotation) {
140 DCHECK(write_callback_.is_null());
141 if (closed_) {
142 if (write_called_after_close_)
143 return ERR_CONNECTION_RESET;
144 write_called_after_close_ = true;
145 write_callback_ = std::move(callback);
146 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
147 FROM_HERE, base::BindOnce(&FakeDataChannel::DoWriteCallback,
148 weak_factory_.GetWeakPtr()));
149 return ERR_IO_PENDING;
150 }
151 // This function returns synchronously, so make a copy of the buffer.
152 data_.push(base::MakeRefCounted<DrainableIOBuffer>(
153 base::MakeRefCounted<StringIOBuffer>(std::string(buf->data(), buf_len)),
154 buf_len));
155 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
156 FROM_HERE, base::BindOnce(&FakeDataChannel::DoReadCallback,
157 weak_factory_.GetWeakPtr()));
158 return buf_len;
159 }
160
161 // Closes the FakeDataChannel. After Close() is called, Read() returns 0,
162 // indicating EOF, and Write() fails with ERR_CONNECTION_RESET. Note that
163 // after the FakeDataChannel is closed, the first Write() call completes
164 // asynchronously, which is necessary to reproduce bug 127822.
Close()165 void Close() {
166 closed_ = true;
167 if (!read_callback_.is_null()) {
168 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
169 FROM_HERE, base::BindOnce(&FakeDataChannel::DoReadCallback,
170 weak_factory_.GetWeakPtr()));
171 }
172 }
173
174 private:
DoReadCallback()175 void DoReadCallback() {
176 if (read_callback_.is_null())
177 return;
178
179 if (closed_) {
180 std::move(read_callback_).Run(ERR_CONNECTION_CLOSED);
181 return;
182 }
183
184 if (data_.empty())
185 return;
186
187 int copied = PropagateData(read_buf_, read_buf_len_);
188 read_buf_ = nullptr;
189 read_buf_len_ = 0;
190 std::move(read_callback_).Run(copied);
191 }
192
DoWriteCallback()193 void DoWriteCallback() {
194 if (write_callback_.is_null())
195 return;
196
197 std::move(write_callback_).Run(ERR_CONNECTION_RESET);
198 }
199
PropagateData(scoped_refptr<IOBuffer> read_buf,int read_buf_len)200 int PropagateData(scoped_refptr<IOBuffer> read_buf, int read_buf_len) {
201 scoped_refptr<DrainableIOBuffer> buf = data_.front();
202 int copied = std::min(buf->BytesRemaining(), read_buf_len);
203 memcpy(read_buf->data(), buf->data(), copied);
204 buf->DidConsume(copied);
205
206 if (!buf->BytesRemaining())
207 data_.pop();
208 return copied;
209 }
210
211 CompletionOnceCallback read_callback_;
212 scoped_refptr<IOBuffer> read_buf_;
213 int read_buf_len_ = 0;
214
215 CompletionOnceCallback write_callback_;
216
217 base::queue<scoped_refptr<DrainableIOBuffer>> data_;
218
219 // True if Close() has been called.
220 bool closed_ = false;
221
222 // Controls the completion of Write() after the FakeDataChannel is closed.
223 // After the FakeDataChannel is closed, the first Write() call completes
224 // asynchronously.
225 bool write_called_after_close_ = false;
226
227 base::WeakPtrFactory<FakeDataChannel> weak_factory_{this};
228 };
229
230 class FakeSocket : public StreamSocket {
231 public:
FakeSocket(FakeDataChannel * incoming_channel,FakeDataChannel * outgoing_channel)232 FakeSocket(FakeDataChannel* incoming_channel,
233 FakeDataChannel* outgoing_channel)
234 : incoming_(incoming_channel), outgoing_(outgoing_channel) {}
235
236 FakeSocket(const FakeSocket&) = delete;
237 FakeSocket& operator=(const FakeSocket&) = delete;
238
239 ~FakeSocket() override = default;
240
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)241 int Read(IOBuffer* buf,
242 int buf_len,
243 CompletionOnceCallback callback) override {
244 // Read random number of bytes.
245 buf_len = rand() % buf_len + 1;
246 return incoming_->Read(buf, buf_len, std::move(callback));
247 }
248
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)249 int Write(IOBuffer* buf,
250 int buf_len,
251 CompletionOnceCallback callback,
252 const NetworkTrafficAnnotationTag& traffic_annotation) override {
253 // Write random number of bytes.
254 buf_len = rand() % buf_len + 1;
255 return outgoing_->Write(buf, buf_len, std::move(callback),
256 TRAFFIC_ANNOTATION_FOR_TESTS);
257 }
258
SetReceiveBufferSize(int32_t size)259 int SetReceiveBufferSize(int32_t size) override { return OK; }
260
SetSendBufferSize(int32_t size)261 int SetSendBufferSize(int32_t size) override { return OK; }
262
Connect(CompletionOnceCallback callback)263 int Connect(CompletionOnceCallback callback) override { return OK; }
264
Disconnect()265 void Disconnect() override {
266 incoming_->Close();
267 outgoing_->Close();
268 }
269
IsConnected() const270 bool IsConnected() const override { return true; }
271
IsConnectedAndIdle() const272 bool IsConnectedAndIdle() const override { return true; }
273
GetPeerAddress(IPEndPoint * address) const274 int GetPeerAddress(IPEndPoint* address) const override {
275 *address = IPEndPoint(IPAddress::IPv4AllZeros(), 0 /*port*/);
276 return OK;
277 }
278
GetLocalAddress(IPEndPoint * address) const279 int GetLocalAddress(IPEndPoint* address) const override {
280 *address = IPEndPoint(IPAddress::IPv4AllZeros(), 0 /*port*/);
281 return OK;
282 }
283
NetLog() const284 const NetLogWithSource& NetLog() const override { return net_log_; }
285
WasEverUsed() const286 bool WasEverUsed() const override { return true; }
287
GetNegotiatedProtocol() const288 NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
289
GetSSLInfo(SSLInfo * ssl_info)290 bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
291
GetTotalReceivedBytes() const292 int64_t GetTotalReceivedBytes() const override {
293 NOTIMPLEMENTED();
294 return 0;
295 }
296
ApplySocketTag(const SocketTag & tag)297 void ApplySocketTag(const SocketTag& tag) override {}
298
299 private:
300 NetLogWithSource net_log_;
301 raw_ptr<FakeDataChannel> incoming_;
302 raw_ptr<FakeDataChannel> outgoing_;
303 };
304
305 } // namespace
306
307 // Verify the correctness of the test helper classes first.
TEST(FakeSocketTest,DataTransfer)308 TEST(FakeSocketTest, DataTransfer) {
309 base::test::TaskEnvironment task_environment;
310
311 // Establish channels between two sockets.
312 FakeDataChannel channel_1;
313 FakeDataChannel channel_2;
314 FakeSocket client(&channel_1, &channel_2);
315 FakeSocket server(&channel_2, &channel_1);
316
317 const char kTestData[] = "testing123";
318 const int kTestDataSize = strlen(kTestData);
319 const int kReadBufSize = 1024;
320 auto write_buf = base::MakeRefCounted<StringIOBuffer>(kTestData);
321 auto read_buf = base::MakeRefCounted<IOBufferWithSize>(kReadBufSize);
322
323 // Write then read.
324 int written =
325 server.Write(write_buf.get(), kTestDataSize, CompletionOnceCallback(),
326 TRAFFIC_ANNOTATION_FOR_TESTS);
327 EXPECT_GT(written, 0);
328 EXPECT_LE(written, kTestDataSize);
329
330 int read =
331 client.Read(read_buf.get(), kReadBufSize, CompletionOnceCallback());
332 EXPECT_GT(read, 0);
333 EXPECT_LE(read, written);
334 EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read));
335
336 // Read then write.
337 TestCompletionCallback callback;
338 EXPECT_EQ(ERR_IO_PENDING,
339 server.Read(read_buf.get(), kReadBufSize, callback.callback()));
340
341 written =
342 client.Write(write_buf.get(), kTestDataSize, CompletionOnceCallback(),
343 TRAFFIC_ANNOTATION_FOR_TESTS);
344 EXPECT_GT(written, 0);
345 EXPECT_LE(written, kTestDataSize);
346
347 read = callback.WaitForResult();
348 EXPECT_GT(read, 0);
349 EXPECT_LE(read, written);
350 EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read));
351 }
352
353 class SSLServerSocketTest : public PlatformTest, public WithTaskEnvironment {
354 public:
SSLServerSocketTest()355 SSLServerSocketTest()
356 : ssl_config_service_(
357 std::make_unique<TestSSLConfigService>(SSLContextConfig())),
358 cert_verifier_(std::make_unique<MockCertVerifier>()),
359 client_cert_verifier_(std::make_unique<MockClientCertVerifier>()),
360 transport_security_state_(std::make_unique<TransportSecurityState>()),
361 ssl_client_session_cache_(std::make_unique<SSLClientSessionCache>(
362 SSLClientSessionCache::Config())) {}
363
SetUp()364 void SetUp() override {
365 PlatformTest::SetUp();
366
367 cert_verifier_->set_default_result(ERR_CERT_AUTHORITY_INVALID);
368 client_cert_verifier_->set_default_result(ERR_CERT_AUTHORITY_INVALID);
369
370 server_cert_ =
371 ImportCertFromFile(GetTestCertsDirectory(), "unittest.selfsigned.der");
372 ASSERT_TRUE(server_cert_);
373 server_private_key_ = ReadTestKey("unittest.key.bin");
374 ASSERT_TRUE(server_private_key_);
375
376 std::unique_ptr<crypto::RSAPrivateKey> key =
377 ReadTestKey("unittest.key.bin");
378 ASSERT_TRUE(key);
379 server_ssl_private_key_ = WrapOpenSSLPrivateKey(bssl::UpRef(key->key()));
380
381 // Certificate provided by the host doesn't need authority.
382 client_ssl_config_.allowed_bad_certs.emplace_back(
383 server_cert_, CERT_STATUS_AUTHORITY_INVALID);
384
385 client_context_ = std::make_unique<SSLClientContext>(
386 ssl_config_service_.get(), cert_verifier_.get(),
387 transport_security_state_.get(), ssl_client_session_cache_.get(),
388 nullptr);
389 }
390
391 protected:
CreateContext()392 void CreateContext() {
393 client_socket_.reset();
394 server_socket_.reset();
395 channel_1_.reset();
396 channel_2_.reset();
397 server_context_ = CreateSSLServerContext(
398 server_cert_.get(), *server_private_key_, server_ssl_config_);
399 }
400
CreateContextSSLPrivateKey()401 void CreateContextSSLPrivateKey() {
402 client_socket_.reset();
403 server_socket_.reset();
404 channel_1_.reset();
405 channel_2_.reset();
406 server_context_.reset();
407 server_context_ = CreateSSLServerContext(
408 server_cert_.get(), server_ssl_private_key_, server_ssl_config_);
409 }
410
GetHostAndPort()411 static HostPortPair GetHostAndPort() { return HostPortPair("unittest", 0); }
412
CreateSockets()413 void CreateSockets() {
414 client_socket_.reset();
415 server_socket_.reset();
416 channel_1_ = std::make_unique<FakeDataChannel>();
417 channel_2_ = std::make_unique<FakeDataChannel>();
418 std::unique_ptr<StreamSocket> client_connection =
419 std::make_unique<FakeSocket>(channel_1_.get(), channel_2_.get());
420 std::unique_ptr<StreamSocket> server_socket =
421 std::make_unique<FakeSocket>(channel_2_.get(), channel_1_.get());
422
423 client_socket_ = client_context_->CreateSSLClientSocket(
424 std::move(client_connection), GetHostAndPort(), client_ssl_config_);
425 ASSERT_TRUE(client_socket_);
426
427 server_socket_ =
428 server_context_->CreateSSLServerSocket(std::move(server_socket));
429 ASSERT_TRUE(server_socket_);
430 }
431
432 // Client certificates are disabled on iOS.
433 #if BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
ConfigureClientCertsForClient(const char * cert_file_name,const char * private_key_file_name)434 void ConfigureClientCertsForClient(const char* cert_file_name,
435 const char* private_key_file_name) {
436 scoped_refptr<X509Certificate> client_cert =
437 ImportCertFromFile(GetTestCertsDirectory(), cert_file_name);
438 ASSERT_TRUE(client_cert);
439
440 std::unique_ptr<crypto::RSAPrivateKey> key =
441 ReadTestKey(private_key_file_name);
442 ASSERT_TRUE(key);
443
444 client_context_->SetClientCertificate(
445 GetHostAndPort(), std::move(client_cert),
446 WrapOpenSSLPrivateKey(bssl::UpRef(key->key())));
447 }
448
ConfigureClientCertsForServer()449 void ConfigureClientCertsForServer() {
450 server_ssl_config_.client_cert_type =
451 SSLServerConfig::ClientCertType::REQUIRE_CLIENT_CERT;
452
453 // "CN=B CA" - DER encoded DN of the issuer of client_1.pem
454 static const uint8_t kClientCertCAName[] = {
455 0x30, 0x0f, 0x31, 0x0d, 0x30, 0x0b, 0x06, 0x03, 0x55,
456 0x04, 0x03, 0x0c, 0x04, 0x42, 0x20, 0x43, 0x41};
457 server_ssl_config_.cert_authorities.emplace_back(
458 std::begin(kClientCertCAName), std::end(kClientCertCAName));
459
460 scoped_refptr<X509Certificate> expected_client_cert(
461 ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName));
462 ASSERT_TRUE(expected_client_cert);
463
464 client_cert_verifier_->AddResultForCert(expected_client_cert.get(), OK);
465
466 server_ssl_config_.client_cert_verifier = client_cert_verifier_.get();
467 }
468 #endif // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
469
ReadTestKey(std::string_view name)470 std::unique_ptr<crypto::RSAPrivateKey> ReadTestKey(std::string_view name) {
471 base::FilePath certs_dir(GetTestCertsDirectory());
472 base::FilePath key_path = certs_dir.AppendASCII(name);
473 std::string key_string;
474 if (!base::ReadFileToString(key_path, &key_string))
475 return nullptr;
476 std::vector<uint8_t> key_vector(
477 reinterpret_cast<const uint8_t*>(key_string.data()),
478 reinterpret_cast<const uint8_t*>(key_string.data() +
479 key_string.length()));
480 std::unique_ptr<crypto::RSAPrivateKey> key(
481 crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
482 return key;
483 }
484
PumpServerToClient()485 void PumpServerToClient() {
486 const int kReadBufSize = 1024;
487 scoped_refptr<StringIOBuffer> write_buf =
488 base::MakeRefCounted<StringIOBuffer>("testing123");
489 scoped_refptr<DrainableIOBuffer> read_buf =
490 base::MakeRefCounted<DrainableIOBuffer>(
491 base::MakeRefCounted<IOBufferWithSize>(kReadBufSize), kReadBufSize);
492 TestCompletionCallback write_callback;
493 TestCompletionCallback read_callback;
494 int server_ret = server_socket_->Write(write_buf.get(), write_buf->size(),
495 write_callback.callback(),
496 TRAFFIC_ANNOTATION_FOR_TESTS);
497 EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
498 int client_ret = client_socket_->Read(
499 read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
500 EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
501
502 server_ret = write_callback.GetResult(server_ret);
503 EXPECT_GT(server_ret, 0);
504 client_ret = read_callback.GetResult(client_ret);
505 ASSERT_GT(client_ret, 0);
506 }
507
508 std::unique_ptr<FakeDataChannel> channel_1_;
509 std::unique_ptr<FakeDataChannel> channel_2_;
510 std::unique_ptr<TestSSLConfigService> ssl_config_service_;
511 std::unique_ptr<MockCertVerifier> cert_verifier_;
512 std::unique_ptr<MockClientCertVerifier> client_cert_verifier_;
513 SSLConfig client_ssl_config_;
514 // Note that this has a pointer to the `cert_verifier_`, so must be destroyed
515 // before that is.
516 SSLServerConfig server_ssl_config_;
517 std::unique_ptr<TransportSecurityState> transport_security_state_;
518 std::unique_ptr<SSLClientSessionCache> ssl_client_session_cache_;
519 std::unique_ptr<SSLClientContext> client_context_;
520 std::unique_ptr<SSLServerContext> server_context_;
521 std::unique_ptr<SSLClientSocket> client_socket_;
522 std::unique_ptr<SSLServerSocket> server_socket_;
523 std::unique_ptr<crypto::RSAPrivateKey> server_private_key_;
524 scoped_refptr<SSLPrivateKey> server_ssl_private_key_;
525 scoped_refptr<X509Certificate> server_cert_;
526 };
527
528 class SSLServerSocketReadTest : public SSLServerSocketTest,
529 public ::testing::WithParamInterface<bool> {
530 protected:
SSLServerSocketReadTest()531 SSLServerSocketReadTest() : read_if_ready_enabled_(GetParam()) {}
532
Read(StreamSocket * socket,IOBuffer * buf,int buf_len,CompletionOnceCallback callback)533 int Read(StreamSocket* socket,
534 IOBuffer* buf,
535 int buf_len,
536 CompletionOnceCallback callback) {
537 if (read_if_ready_enabled()) {
538 return socket->ReadIfReady(buf, buf_len, std::move(callback));
539 }
540 return socket->Read(buf, buf_len, std::move(callback));
541 }
542
read_if_ready_enabled() const543 bool read_if_ready_enabled() const { return read_if_ready_enabled_; }
544
545 private:
546 const bool read_if_ready_enabled_;
547 };
548
549 INSTANTIATE_TEST_SUITE_P(/* no prefix */,
550 SSLServerSocketReadTest,
551 ::testing::Bool());
552
553 // This test only executes creation of client and server sockets. This is to
554 // test that creation of sockets doesn't crash and have minimal code to run
555 // with memory leak/corruption checking tools.
TEST_F(SSLServerSocketTest,Initialize)556 TEST_F(SSLServerSocketTest, Initialize) {
557 ASSERT_NO_FATAL_FAILURE(CreateContext());
558 ASSERT_NO_FATAL_FAILURE(CreateSockets());
559 }
560
561 // This test executes Connect() on SSLClientSocket and Handshake() on
562 // SSLServerSocket to make sure handshaking between the two sockets is
563 // completed successfully.
TEST_F(SSLServerSocketTest,Handshake)564 TEST_F(SSLServerSocketTest, Handshake) {
565 ASSERT_NO_FATAL_FAILURE(CreateContext());
566 ASSERT_NO_FATAL_FAILURE(CreateSockets());
567
568 TestCompletionCallback handshake_callback;
569 int server_ret = server_socket_->Handshake(handshake_callback.callback());
570
571 TestCompletionCallback connect_callback;
572 int client_ret = client_socket_->Connect(connect_callback.callback());
573
574 client_ret = connect_callback.GetResult(client_ret);
575 server_ret = handshake_callback.GetResult(server_ret);
576
577 ASSERT_THAT(client_ret, IsOk());
578 ASSERT_THAT(server_ret, IsOk());
579
580 // Make sure the cert status is expected.
581 SSLInfo ssl_info;
582 ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
583 EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);
584
585 // The default cipher suite should be ECDHE and an AEAD.
586 uint16_t cipher_suite =
587 SSLConnectionStatusToCipherSuite(ssl_info.connection_status);
588 const char* key_exchange;
589 const char* cipher;
590 const char* mac;
591 bool is_aead;
592 bool is_tls13;
593 SSLCipherSuiteToStrings(&key_exchange, &cipher, &mac, &is_aead, &is_tls13,
594 cipher_suite);
595 EXPECT_TRUE(is_aead);
596 }
597
598 // This test makes sure the session cache is working.
TEST_F(SSLServerSocketTest,HandshakeCached)599 TEST_F(SSLServerSocketTest, HandshakeCached) {
600 ASSERT_NO_FATAL_FAILURE(CreateContext());
601 ASSERT_NO_FATAL_FAILURE(CreateSockets());
602
603 TestCompletionCallback handshake_callback;
604 int server_ret = server_socket_->Handshake(handshake_callback.callback());
605
606 TestCompletionCallback connect_callback;
607 int client_ret = client_socket_->Connect(connect_callback.callback());
608
609 client_ret = connect_callback.GetResult(client_ret);
610 server_ret = handshake_callback.GetResult(server_ret);
611
612 ASSERT_THAT(client_ret, IsOk());
613 ASSERT_THAT(server_ret, IsOk());
614
615 // Make sure the cert status is expected.
616 SSLInfo ssl_info;
617 ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
618 EXPECT_EQ(ssl_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
619 SSLInfo ssl_server_info;
620 ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info));
621 EXPECT_EQ(ssl_server_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
622
623 // Pump client read to get new session tickets.
624 PumpServerToClient();
625
626 // Make sure the second connection is cached.
627 ASSERT_NO_FATAL_FAILURE(CreateSockets());
628 TestCompletionCallback handshake_callback2;
629 int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());
630
631 TestCompletionCallback connect_callback2;
632 int client_ret2 = client_socket_->Connect(connect_callback2.callback());
633
634 client_ret2 = connect_callback2.GetResult(client_ret2);
635 server_ret2 = handshake_callback2.GetResult(server_ret2);
636
637 ASSERT_THAT(client_ret2, IsOk());
638 ASSERT_THAT(server_ret2, IsOk());
639
640 // Make sure the cert status is expected.
641 SSLInfo ssl_info2;
642 ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info2));
643 EXPECT_EQ(ssl_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME);
644 SSLInfo ssl_server_info2;
645 ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info2));
646 EXPECT_EQ(ssl_server_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME);
647 }
648
649 // This test makes sure the session cache separates out by server context.
TEST_F(SSLServerSocketTest,HandshakeCachedContextSwitch)650 TEST_F(SSLServerSocketTest, HandshakeCachedContextSwitch) {
651 ASSERT_NO_FATAL_FAILURE(CreateContext());
652 ASSERT_NO_FATAL_FAILURE(CreateSockets());
653
654 TestCompletionCallback handshake_callback;
655 int server_ret = server_socket_->Handshake(handshake_callback.callback());
656
657 TestCompletionCallback connect_callback;
658 int client_ret = client_socket_->Connect(connect_callback.callback());
659
660 client_ret = connect_callback.GetResult(client_ret);
661 server_ret = handshake_callback.GetResult(server_ret);
662
663 ASSERT_THAT(client_ret, IsOk());
664 ASSERT_THAT(server_ret, IsOk());
665
666 // Make sure the cert status is expected.
667 SSLInfo ssl_info;
668 ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
669 EXPECT_EQ(ssl_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
670 SSLInfo ssl_server_info;
671 ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info));
672 EXPECT_EQ(ssl_server_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
673
674 // Make sure the second connection is NOT cached when using a new context.
675 ASSERT_NO_FATAL_FAILURE(CreateContext());
676 ASSERT_NO_FATAL_FAILURE(CreateSockets());
677
678 TestCompletionCallback handshake_callback2;
679 int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());
680
681 TestCompletionCallback connect_callback2;
682 int client_ret2 = client_socket_->Connect(connect_callback2.callback());
683
684 client_ret2 = connect_callback2.GetResult(client_ret2);
685 server_ret2 = handshake_callback2.GetResult(server_ret2);
686
687 ASSERT_THAT(client_ret2, IsOk());
688 ASSERT_THAT(server_ret2, IsOk());
689
690 // Make sure the cert status is expected.
691 SSLInfo ssl_info2;
692 ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info2));
693 EXPECT_EQ(ssl_info2.handshake_type, SSLInfo::HANDSHAKE_FULL);
694 SSLInfo ssl_server_info2;
695 ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info2));
696 EXPECT_EQ(ssl_server_info2.handshake_type, SSLInfo::HANDSHAKE_FULL);
697 }
698
699 // Client certificates are disabled on iOS.
700 #if BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
701 // This test executes Connect() on SSLClientSocket and Handshake() on
702 // SSLServerSocket to make sure handshaking between the two sockets is
703 // completed successfully, using client certificate.
TEST_F(SSLServerSocketTest,HandshakeWithClientCert)704 TEST_F(SSLServerSocketTest, HandshakeWithClientCert) {
705 scoped_refptr<X509Certificate> client_cert =
706 ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
707 ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
708 kClientCertFileName, kClientPrivateKeyFileName));
709 ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
710 ASSERT_NO_FATAL_FAILURE(CreateContext());
711 ASSERT_NO_FATAL_FAILURE(CreateSockets());
712
713 TestCompletionCallback handshake_callback;
714 int server_ret = server_socket_->Handshake(handshake_callback.callback());
715
716 TestCompletionCallback connect_callback;
717 int client_ret = client_socket_->Connect(connect_callback.callback());
718
719 client_ret = connect_callback.GetResult(client_ret);
720 server_ret = handshake_callback.GetResult(server_ret);
721
722 ASSERT_THAT(client_ret, IsOk());
723 ASSERT_THAT(server_ret, IsOk());
724
725 // Make sure the cert status is expected.
726 SSLInfo ssl_info;
727 client_socket_->GetSSLInfo(&ssl_info);
728 EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);
729 server_socket_->GetSSLInfo(&ssl_info);
730 ASSERT_TRUE(ssl_info.cert.get());
731 EXPECT_TRUE(client_cert->EqualsExcludingChain(ssl_info.cert.get()));
732 }
733
734 // This test executes Connect() on SSLClientSocket and Handshake() twice on
735 // SSLServerSocket to make sure handshaking between the two sockets is
736 // completed successfully, using client certificate. The second connection is
737 // expected to succeed through the session cache.
TEST_F(SSLServerSocketTest,HandshakeWithClientCertCached)738 TEST_F(SSLServerSocketTest, HandshakeWithClientCertCached) {
739 scoped_refptr<X509Certificate> client_cert =
740 ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
741 ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
742 kClientCertFileName, kClientPrivateKeyFileName));
743 ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
744 ASSERT_NO_FATAL_FAILURE(CreateContext());
745 ASSERT_NO_FATAL_FAILURE(CreateSockets());
746
747 TestCompletionCallback handshake_callback;
748 int server_ret = server_socket_->Handshake(handshake_callback.callback());
749
750 TestCompletionCallback connect_callback;
751 int client_ret = client_socket_->Connect(connect_callback.callback());
752
753 client_ret = connect_callback.GetResult(client_ret);
754 server_ret = handshake_callback.GetResult(server_ret);
755
756 ASSERT_THAT(client_ret, IsOk());
757 ASSERT_THAT(server_ret, IsOk());
758
759 // Make sure the cert status is expected.
760 SSLInfo ssl_info;
761 ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
762 EXPECT_EQ(ssl_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
763 SSLInfo ssl_server_info;
764 ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info));
765 ASSERT_TRUE(ssl_server_info.cert.get());
766 EXPECT_TRUE(client_cert->EqualsExcludingChain(ssl_server_info.cert.get()));
767 EXPECT_EQ(ssl_server_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
768 // Pump client read to get new session tickets.
769 PumpServerToClient();
770 server_socket_->Disconnect();
771 client_socket_->Disconnect();
772
773 // Create the connection again.
774 ASSERT_NO_FATAL_FAILURE(CreateSockets());
775 TestCompletionCallback handshake_callback2;
776 int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());
777
778 TestCompletionCallback connect_callback2;
779 int client_ret2 = client_socket_->Connect(connect_callback2.callback());
780
781 client_ret2 = connect_callback2.GetResult(client_ret2);
782 server_ret2 = handshake_callback2.GetResult(server_ret2);
783
784 ASSERT_THAT(client_ret2, IsOk());
785 ASSERT_THAT(server_ret2, IsOk());
786
787 // Make sure the cert status is expected.
788 SSLInfo ssl_info2;
789 ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info2));
790 EXPECT_EQ(ssl_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME);
791 SSLInfo ssl_server_info2;
792 ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info2));
793 ASSERT_TRUE(ssl_server_info2.cert.get());
794 EXPECT_TRUE(client_cert->EqualsExcludingChain(ssl_server_info2.cert.get()));
795 EXPECT_EQ(ssl_server_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME);
796 }
797
// The server requires a client certificate but the client is not configured
// with one.
//
// Client side: Connect() fails with ERR_SSL_CLIENT_AUTH_CERT_NEEDED. The
// CertificateRequest it received can then be inspected.
//
// Server side: the pending handshake fails with ERR_CONNECTION_CLOSED once the
// client disconnects.
TEST_F(SSLServerSocketTest, HandshakeWithClientCertRequiredNotSupplied) {
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback handshake_callback;
  int server_ret = server_socket_->Handshake(handshake_callback.callback());

  TestCompletionCallback connect_callback;
  int connect_rv = client_socket_->Connect(connect_callback.callback());
  EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED,
            connect_callback.GetResult(connect_rv));

  auto request_info = base::MakeRefCounted<SSLCertRequestInfo>();
  client_socket_->GetSSLCertRequestInfo(request_info.get());

  // The authority advertised in the CertificateRequest must be the issuer of
  // the expected client certificate.
  scoped_refptr<X509Certificate> expected_cert =
      ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
  ASSERT_TRUE(expected_cert);
  EXPECT_TRUE(expected_cert->IsIssuedByEncoded(request_info->cert_authorities));

  client_socket_->Disconnect();

  int handshake_rv = handshake_callback.GetResult(server_ret);
  EXPECT_THAT(handshake_rv, IsError(ERR_CONNECTION_CLOSED));
}
831
TEST_F(SSLServerSocketTest, HandshakeWithClientCertRequiredNotSuppliedCached) {
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  // The client keeps its default configuration, so it offers no client
  // certificate. Its Connect() therefore fails with
  // ERR_SSL_CLIENT_AUTH_CERT_NEEDED, after which the CertificateRequest the
  // server sent can be inspected through GetSSLCertRequestInfo().

  // --- First connection ---
  TestCompletionCallback first_handshake_cb;
  const int first_pending_server_result =
      server_socket_->Handshake(first_handshake_cb.callback());

  TestCompletionCallback first_connect_cb;
  const int first_client_result = first_connect_cb.GetResult(
      client_socket_->Connect(first_connect_cb.callback()));
  EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED, first_client_result);

  auto first_request = base::MakeRefCounted<SSLCertRequestInfo>();
  client_socket_->GetSSLCertRequestInfo(first_request.get());

  // The certificate authorities listed in the CertificateRequest must include
  // the issuer of the expected client certificate.
  scoped_refptr<X509Certificate> expected_cert =
      ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
  ASSERT_TRUE(expected_cert);
  EXPECT_TRUE(
      expected_cert->IsIssuedByEncoded(first_request->cert_authorities));

  client_socket_->Disconnect();
  EXPECT_THAT(first_handshake_cb.GetResult(first_pending_server_result),
              IsError(ERR_CONNECTION_CLOSED));
  server_socket_->Disconnect();

  // --- Second connection ---
  // The failed first handshake must not have been cached, so the server has
  // to request a client certificate again.
  ASSERT_NO_FATAL_FAILURE(CreateSockets());
  TestCompletionCallback second_handshake_cb;
  const int second_pending_server_result =
      server_socket_->Handshake(second_handshake_cb.callback());

  TestCompletionCallback second_connect_cb;
  const int second_client_result = second_connect_cb.GetResult(
      client_socket_->Connect(second_connect_cb.callback()));
  EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED, second_client_result);

  auto second_request = base::MakeRefCounted<SSLCertRequestInfo>();
  client_socket_->GetSSLCertRequestInfo(second_request.get());
  EXPECT_TRUE(
      expected_cert->IsIssuedByEncoded(second_request->cert_authorities));

  client_socket_->Disconnect();
  EXPECT_THAT(second_handshake_cb.GetResult(second_pending_server_result),
              IsError(ERR_CONNECTION_CLOSED));
}
888
TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSupplied) {
  // The expected client certificate must be loadable, even though the client
  // below is configured with a different ("wrong") certificate.
  scoped_refptr<X509Certificate> expected_cert =
      ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
  ASSERT_TRUE(expected_cert);

  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
      kWrongClientCertFileName, kWrongClientPrivateKeyFileName));
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback server_handshake_cb;
  const int pending_server_result =
      server_socket_->Handshake(server_handshake_cb.callback());

  TestCompletionCallback client_connect_cb;
  const int pending_client_result =
      client_socket_->Connect(client_connect_cb.callback());

  // With TLS 1.3 the client's Connect() succeeds; the rejection of its
  // certificate only becomes visible once the client reads.
  EXPECT_EQ(OK, client_connect_cb.GetResult(pending_client_result));
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
            server_handshake_cb.GetResult(pending_server_result));

  // Reading on the client surfaces the client certificate error.
  const int kBufferSize = 1024;
  auto drain_buf = base::MakeRefCounted<DrainableIOBuffer>(
      base::MakeRefCounted<IOBufferWithSize>(kBufferSize), kBufferSize);
  TestCompletionCallback client_read_cb;
  const int read_result = client_read_cb.GetResult(client_socket_->Read(
      drain_buf.get(), drain_buf->BytesRemaining(), client_read_cb.callback()));
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, read_result);
}
922
TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSuppliedTLS12) {
  // The expected client certificate must be loadable, even though the client
  // below is configured with a different ("wrong") certificate.
  scoped_refptr<X509Certificate> expected_cert =
      ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
  ASSERT_TRUE(expected_cert);

  // Cap the client at TLS 1.2 so the certificate error is reported by the
  // handshake itself rather than deferred to the first read.
  client_ssl_config_.version_max_override = SSL_PROTOCOL_VERSION_TLS1_2;
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
      kWrongClientCertFileName, kWrongClientPrivateKeyFileName));
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback server_handshake_cb;
  const int pending_server_result =
      server_socket_->Handshake(server_handshake_cb.callback());

  TestCompletionCallback client_connect_cb;
  const int pending_client_result =
      client_socket_->Connect(client_connect_cb.callback());

  // Both ends observe the bad client certificate during the handshake.
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
            client_connect_cb.GetResult(pending_client_result));
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
            server_handshake_cb.GetResult(pending_server_result));
}
946
// Verifies that a handshake rejected because of a wrong client certificate is
// not cached: a second connection with the same (wrong) certificate must be
// rejected in exactly the same way as the first.
TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSuppliedCached) {
  scoped_refptr<X509Certificate> client_cert =
      ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
  ASSERT_TRUE(client_cert);

  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
      kWrongClientCertFileName, kWrongClientPrivateKeyFileName));
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback handshake_callback;
  int server_ret = server_socket_->Handshake(handshake_callback.callback());

  TestCompletionCallback connect_callback;
  int client_ret = client_socket_->Connect(connect_callback.callback());

  // In TLS 1.3, the client cert error isn't exposed until Read is called.
  EXPECT_EQ(OK, connect_callback.GetResult(client_ret));
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
            handshake_callback.GetResult(server_ret));

  // Pump client read to get client cert error.
  const int kReadBufSize = 1024;
  scoped_refptr<DrainableIOBuffer> read_buf =
      base::MakeRefCounted<DrainableIOBuffer>(
          base::MakeRefCounted<IOBufferWithSize>(kReadBufSize), kReadBufSize);
  TestCompletionCallback read_callback;
  client_ret = client_socket_->Read(read_buf.get(), read_buf->BytesRemaining(),
                                    read_callback.callback());
  client_ret = read_callback.GetResult(client_ret);
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, client_ret);

  client_socket_->Disconnect();
  server_socket_->Disconnect();

  // Below, check that the cache didn't store the result of a failed handshake.
  // Every piece of per-connection state (callbacks and result variables) is
  // fresh for the second connection, so nothing carries over from the first.
  ASSERT_NO_FATAL_FAILURE(CreateSockets());
  TestCompletionCallback handshake_callback2;
  int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());

  TestCompletionCallback connect_callback2;
  int client_ret2 = client_socket_->Connect(connect_callback2.callback());

  // In TLS 1.3, the client cert error isn't exposed until Read is called.
  EXPECT_EQ(OK, connect_callback2.GetResult(client_ret2));
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
            handshake_callback2.GetResult(server_ret2));

  // Pump client read to get client cert error. The first read failed without
  // consuming anything, so `read_buf` still has its full capacity available.
  TestCompletionCallback read_callback2;
  client_ret2 = client_socket_->Read(
      read_buf.get(), read_buf->BytesRemaining(), read_callback2.callback());
  client_ret2 = read_callback2.GetResult(client_ret2);
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, client_ret2);
}
1002 #endif // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
1003
// Exercises application data transfer in both directions over an established
// connection: first server -> client (server writes, client reads), then
// client -> server (server reads first, then the client writes). Server-side
// reads go through the fixture's Read() helper; when read_if_ready_enabled()
// is true, a completed read only signals readiness and the data is consumed
// by the follow-up reads.
TEST_P(SSLServerSocketReadTest, DataTransfer) {
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  // Establish connection.
  TestCompletionCallback connect_callback;
  int client_ret = client_socket_->Connect(connect_callback.callback());
  ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);

  TestCompletionCallback handshake_callback;
  int server_ret = server_socket_->Handshake(handshake_callback.callback());
  ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);

  client_ret = connect_callback.GetResult(client_ret);
  ASSERT_THAT(client_ret, IsOk());
  server_ret = handshake_callback.GetResult(server_ret);
  ASSERT_THAT(server_ret, IsOk());

  const int kReadBufSize = 1024;
  scoped_refptr<StringIOBuffer> write_buf =
      base::MakeRefCounted<StringIOBuffer>("testing123");
  scoped_refptr<DrainableIOBuffer> read_buf =
      base::MakeRefCounted<DrainableIOBuffer>(
          base::MakeRefCounted<IOBufferWithSize>(kReadBufSize), kReadBufSize);

  // Write then read.
  TestCompletionCallback write_callback;
  TestCompletionCallback read_callback;
  server_ret = server_socket_->Write(write_buf.get(), write_buf->size(),
                                     write_callback.callback(),
                                     TRAFFIC_ANNOTATION_FOR_TESTS);
  EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
  client_ret = client_socket_->Read(
      read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
  EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);

  server_ret = write_callback.GetResult(server_ret);
  EXPECT_GT(server_ret, 0);
  client_ret = read_callback.GetResult(client_ret);
  ASSERT_GT(client_ret, 0);

  read_buf->DidConsume(client_ret);
  // A single read may return only part of the data, so keep reading. Each
  // iteration ASSERTs a positive byte count, so the loop always progresses.
  while (read_buf->BytesConsumed() < write_buf->size()) {
    client_ret = client_socket_->Read(
        read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
    EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
    client_ret = read_callback.GetResult(client_ret);
    ASSERT_GT(client_ret, 0);
    read_buf->DidConsume(client_ret);
  }
  EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed());
  // Rewind so data() points at the first received byte; this also resets
  // BytesConsumed() to zero for the second phase below.
  read_buf->SetOffset(0);
  EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));

  // Read then write. The server read is issued before the client has written
  // anything, so it must not complete synchronously.
  write_buf = base::MakeRefCounted<StringIOBuffer>("hello123");
  server_ret = Read(server_socket_.get(), read_buf.get(),
                    read_buf->BytesRemaining(), read_callback.callback());
  EXPECT_THAT(server_ret, IsError(ERR_IO_PENDING));
  client_ret = client_socket_->Write(write_buf.get(), write_buf->size(),
                                     write_callback.callback(),
                                     TRAFFIC_ANNOTATION_FOR_TESTS);
  EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);

  server_ret = read_callback.GetResult(server_ret);
  if (read_if_ready_enabled()) {
    // ReadIfReady signals the data is available but does not consume it.
    // The data is consumed later below.
    ASSERT_THAT(server_ret, IsOk());
  } else {
    ASSERT_GT(server_ret, 0);
    read_buf->DidConsume(server_ret);
  }
  client_ret = write_callback.GetResult(client_ret);
  EXPECT_GT(client_ret, 0);

  while (read_buf->BytesConsumed() < write_buf->size()) {
    server_ret = Read(server_socket_.get(), read_buf.get(),
                      read_buf->BytesRemaining(), read_callback.callback());
    // All the data was written above, so the data should be synchronously
    // available out of both Read() and ReadIfReady().
    ASSERT_GT(server_ret, 0);
    read_buf->DidConsume(server_ret);
  }
  EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed());
  read_buf->SetOffset(0);
  EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));
}
1092
1093 // A regression test for bug 127822 (http://crbug.com/127822).
1094 // If the server closes the connection after the handshake is finished,
1095 // the client's Write() call should not cause an infinite loop.
1096 // NOTE: this is a test for SSLClientSocket rather than SSLServerSocket.
TEST_F(SSLServerSocketTest, ClientWriteAfterServerClose) {
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  // Bring up the TLS connection; either side may complete synchronously or
  // asynchronously.
  TestCompletionCallback client_connect_cb;
  int client_result = client_socket_->Connect(client_connect_cb.callback());
  ASSERT_TRUE(client_result == OK || client_result == ERR_IO_PENDING);

  TestCompletionCallback server_handshake_cb;
  int server_result = server_socket_->Handshake(server_handshake_cb.callback());
  ASSERT_TRUE(server_result == OK || server_result == ERR_IO_PENDING);

  client_result = client_connect_cb.GetResult(client_result);
  ASSERT_THAT(client_result, IsOk());
  server_result = server_handshake_cb.GetResult(server_result);
  ASSERT_THAT(server_result, IsOk());

  auto payload = base::MakeRefCounted<StringIOBuffer>("testing123");

  // Before closing, the server writes some data so that the client's reads
  // from the transport socket do not return ERR_IO_PENDING. That forces the
  // client to call Read() on the transport socket again.
  TestCompletionCallback write_cb;
  server_result = server_socket_->Write(payload.get(), payload->size(),
                                        write_cb.callback(),
                                        TRAFFIC_ANNOTATION_FOR_TESTS);
  EXPECT_TRUE(server_result > 0 || server_result == ERR_IO_PENDING);
  server_result = write_cb.GetResult(server_result);
  EXPECT_GT(server_result, 0);

  server_socket_->Disconnect();

  // Writing from the client after the server has gone must still complete
  // (the regression was an infinite loop here).
  client_result = client_socket_->Write(payload.get(), payload->size(),
                                        write_cb.callback(),
                                        TRAFFIC_ANNOTATION_FOR_TESTS);
  EXPECT_TRUE(client_result > 0 || client_result == ERR_IO_PENDING);
  client_result = write_cb.GetResult(client_result);
  EXPECT_GT(client_result, 0);

  // Spin the message loop for 10ms so any follow-up tasks get a chance to run.
  base::RunLoop settle_loop;
  auto task_runner = base::SingleThreadTaskRunner::GetCurrentDefault();
  task_runner->PostDelayedTask(FROM_HERE, settle_loop.QuitClosure(),
                               base::Milliseconds(10));
  settle_loop.Run();
}
1147
1148 // This test executes ExportKeyingMaterial() on the client and server sockets,
1149 // after connecting them, and verifies that the results match.
1150 // This test will fail if False Start is enabled (see crbug.com/90208).
TEST_F(SSLServerSocketTest, ExportKeyingMaterial) {
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback connect_callback;
  int client_ret = client_socket_->Connect(connect_callback.callback());
  ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);

  TestCompletionCallback handshake_callback;
  int server_ret = server_socket_->Handshake(handshake_callback.callback());
  ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);

  if (client_ret == ERR_IO_PENDING) {
    ASSERT_THAT(connect_callback.WaitForResult(), IsOk());
  }
  if (server_ret == ERR_IO_PENDING) {
    ASSERT_THAT(handshake_callback.WaitForResult(), IsOk());
  }

  const int kKeyingMaterialSize = 32;
  const char kKeyingLabel[] = "EXPERIMENTAL-server-socket-test";

  // Both endpoints export with the same label and no context, so they must
  // derive identical keying material.
  std::array<uint8_t, kKeyingMaterialSize> server_out;
  int rv = server_socket_->ExportKeyingMaterial(kKeyingLabel, std::nullopt,
                                                server_out);
  ASSERT_THAT(rv, IsOk());

  std::array<uint8_t, kKeyingMaterialSize> client_out;
  rv = client_socket_->ExportKeyingMaterial(kKeyingLabel, std::nullopt,
                                            client_out);
  ASSERT_THAT(rv, IsOk());
  EXPECT_EQ(server_out, client_out);

  // A different label must still export successfully but yield different
  // material.
  const char kKeyingLabelBad[] = "EXPERIMENTAL-server-socket-test-bad";
  std::array<uint8_t, kKeyingMaterialSize> client_bad;
  rv = client_socket_->ExportKeyingMaterial(kKeyingLabelBad, std::nullopt,
                                            client_bad);
  ASSERT_THAT(rv, IsOk());
  EXPECT_NE(server_out, client_bad);
}
1190
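// Verifies that the server-side `require_ecdhe` flag works properly: a client
// with every ECDHE cipher suite disabled must fail to negotiate with a server
// that requires ECDHE.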
TEST_F(SSLServerSocketTest, RequireEcdheFlag) {
  // Client side: turn off every ECDHE cipher suite and stay at TLS 1.2 or
  // below, because the legacy RSA key exchange suites only exist there.
  SSLContextConfig client_context_config;
  client_context_config.disabled_cipher_suites.assign(std::begin(kEcdheCiphers),
                                                      std::end(kEcdheCiphers));
  client_context_config.version_max = SSL_PROTOCOL_VERSION_TLS1_2;
  ssl_config_service_->UpdateSSLConfigAndNotify(client_context_config);

  // Server side: insist on ECDHE.
  server_ssl_config_.require_ecdhe = true;

  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback client_connect_cb;
  const int pending_client_result =
      client_socket_->Connect(client_connect_cb.callback());

  TestCompletionCallback server_handshake_cb;
  const int pending_server_result =
      server_socket_->Handshake(server_handshake_cb.callback());

  const int client_result = client_connect_cb.GetResult(pending_client_result);
  const int server_result =
      server_handshake_cb.GetResult(pending_server_result);

  // With no cipher suite in common, both ends must report the mismatch.
  ASSERT_THAT(client_result, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
  ASSERT_THAT(server_result, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
}
1220
1221 // This test executes Connect() on SSLClientSocket and Handshake() on
1222 // SSLServerSocket to make sure handshaking between the two sockets is
1223 // completed successfully. The server key is represented by SSLPrivateKey.
TEST_F(SSLServerSocketTest, HandshakeServerSSLPrivateKey) {
  ASSERT_NO_FATAL_FAILURE(CreateContextSSLPrivateKey());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback handshake_callback;
  int server_ret = server_socket_->Handshake(handshake_callback.callback());

  TestCompletionCallback connect_callback;
  int client_ret = client_socket_->Connect(connect_callback.callback());

  client_ret = connect_callback.GetResult(client_ret);
  server_ret = handshake_callback.GetResult(server_ret);

  ASSERT_THAT(client_ret, IsOk());
  ASSERT_THAT(server_ret, IsOk());

  // Make sure the cert status is expected.
  SSLInfo ssl_info;
  ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
  EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);

  // The negotiated default cipher suite must be an AEAD. Only the AEAD
  // property is asserted here; the key exchange is not checked.
  uint16_t cipher_suite =
      SSLConnectionStatusToCipherSuite(ssl_info.connection_status);
  // Initialize the out-parameters so they never hold indeterminate values,
  // regardless of which of them SSLCipherSuiteToStrings() writes.
  const char* key_exchange = nullptr;
  const char* cipher = nullptr;
  const char* mac = nullptr;
  bool is_aead = false;
  bool is_tls13 = false;
  SSLCipherSuiteToStrings(&key_exchange, &cipher, &mac, &is_aead, &is_tls13,
                          cipher_suite);
  EXPECT_TRUE(is_aead);
}
1257
1258 namespace {
1259
1260 // Helper that wraps an underlying SSLPrivateKey to allow the test to
1261 // do some work immediately before a `Sign()` operation is performed.
1262 class SSLPrivateKeyHook : public SSLPrivateKey {
1263 public:
SSLPrivateKeyHook(scoped_refptr<SSLPrivateKey> private_key,base::RepeatingClosure on_sign)1264 SSLPrivateKeyHook(scoped_refptr<SSLPrivateKey> private_key,
1265 base::RepeatingClosure on_sign)
1266 : private_key_(std::move(private_key)), on_sign_(std::move(on_sign)) {}
1267
1268 // SSLPrivateKey implementation.
GetProviderName()1269 std::string GetProviderName() override {
1270 return private_key_->GetProviderName();
1271 }
GetAlgorithmPreferences()1272 std::vector<uint16_t> GetAlgorithmPreferences() override {
1273 return private_key_->GetAlgorithmPreferences();
1274 }
Sign(uint16_t algorithm,base::span<const uint8_t> input,SignCallback callback)1275 void Sign(uint16_t algorithm,
1276 base::span<const uint8_t> input,
1277 SignCallback callback) override {
1278 on_sign_.Run();
1279 private_key_->Sign(algorithm, input, std::move(callback));
1280 }
1281
1282 private:
1283 ~SSLPrivateKeyHook() override = default;
1284
1285 const scoped_refptr<SSLPrivateKey> private_key_;
1286 const base::RepeatingClosure on_sign_;
1287 };
1288
1289 } // namespace
1290
// Verifies that if the client disconnects during private key signing, the
// disconnection is correctly reported to the `Handshake()` completion
// callback with `ERR_CONNECTION_CLOSED`.
// This is a regression test for crbug.com/1449461.
TEST_F(SSLServerSocketTest,
       HandshakeServerSSLPrivateKeyDisconnectDuringSigning_ReturnsError) {
  // Hook that drops the client connection at the moment the server starts
  // signing with its private key.
  auto disconnect_client = base::BindLambdaForTesting([&]() {
    client_socket_->Disconnect();
    ASSERT_FALSE(client_socket_->IsConnected());
  });
  auto hooked_key = base::MakeRefCounted<SSLPrivateKeyHook>(
      std::move(server_ssl_private_key_), disconnect_client);
  server_ssl_private_key_ = std::move(hooked_key);

  ASSERT_NO_FATAL_FAILURE(CreateContextSSLPrivateKey());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback server_handshake_cb;
  const int pending_server_result =
      server_socket_->Handshake(server_handshake_cb.callback());
  ASSERT_EQ(pending_server_result, ERR_IO_PENDING);

  TestCompletionCallback client_connect_cb;
  client_socket_->Connect(client_connect_cb.callback());

  // Without the fix for crbug.com/1449461 the server never resumes its
  // handshake after signing, its completion callback is never run, and the
  // test hangs here until it times out.
  const int server_result = server_handshake_cb.GetResult(pending_server_result);
  EXPECT_EQ(server_result, ERR_CONNECTION_CLOSED);
}
1320
1321 // Verifies that non-ECDHE ciphers are disabled when using SSLPrivateKey as the
1322 // server key.
TEST_F(SSLServerSocketTest, HandshakeServerSSLPrivateKeyRequireEcdhe) {
  // Client side: turn off every ECDHE cipher suite. The version is capped at
  // TLS 1.2 because TLS 1.3 always works with SSLPrivateKey.
  SSLContextConfig client_context_config;
  client_context_config.disabled_cipher_suites.assign(std::begin(kEcdheCiphers),
                                                      std::end(kEcdheCiphers));
  client_context_config.version_max = SSL_PROTOCOL_VERSION_TLS1_2;
  ssl_config_service_->UpdateSSLConfigAndNotify(client_context_config);

  ASSERT_NO_FATAL_FAILURE(CreateContextSSLPrivateKey());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback client_connect_cb;
  const int pending_client_result =
      client_socket_->Connect(client_connect_cb.callback());

  TestCompletionCallback server_handshake_cb;
  const int pending_server_result =
      server_socket_->Handshake(server_handshake_cb.callback());

  const int client_result = client_connect_cb.GetResult(pending_client_result);
  const int server_result =
      server_handshake_cb.GetResult(pending_server_result);

  // With an SSLPrivateKey server key, no non-ECDHE suite is usable, so the
  // handshake must fail on both ends.
  ASSERT_THAT(client_result, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
  ASSERT_THAT(server_result, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
}
1347
1348 class SSLServerSocketAlpsTest
1349 : public SSLServerSocketTest,
1350 public ::testing::WithParamInterface<std::tuple<bool, bool>> {
1351 public:
SSLServerSocketAlpsTest()1352 SSLServerSocketAlpsTest()
1353 : client_alps_enabled_(std::get<0>(GetParam())),
1354 server_alps_enabled_(std::get<1>(GetParam())) {}
1355 ~SSLServerSocketAlpsTest() override = default;
1356 const bool client_alps_enabled_;
1357 const bool server_alps_enabled_;
1358 };
1359
1360 INSTANTIATE_TEST_SUITE_P(All,
1361 SSLServerSocketAlpsTest,
1362 ::testing::Combine(::testing::Bool(),
1363 ::testing::Bool()));
1364
TEST_P(SSLServerSocketAlpsTest, Alps) {
  const std::string server_data = "server sends some test data";
  const std::string client_data = "client also sends some data";

  // Copies the characters of `text` into a byte vector for the ALPS config.
  auto to_bytes = [](const std::string& text) {
    return std::vector<uint8_t>(text.begin(), text.end());
  };

  // Both sides negotiate HTTP/2; each attaches ALPS data only if enabled.
  server_ssl_config_.alpn_protos = {kProtoHTTP2};
  if (server_alps_enabled_) {
    server_ssl_config_.application_settings[kProtoHTTP2] =
        to_bytes(server_data);
  }

  client_ssl_config_.alpn_protos = {kProtoHTTP2};
  if (client_alps_enabled_) {
    client_ssl_config_.application_settings[kProtoHTTP2] =
        to_bytes(client_data);
  }

  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback server_handshake_cb;
  const int pending_server_result =
      server_socket_->Handshake(server_handshake_cb.callback());

  TestCompletionCallback client_connect_cb;
  const int pending_client_result =
      client_socket_->Connect(client_connect_cb.callback());

  const int client_result = client_connect_cb.GetResult(pending_client_result);
  const int server_result =
      server_handshake_cb.GetResult(pending_server_result);

  ASSERT_THAT(client_result, IsOk());
  ASSERT_THAT(server_result, IsOk());

  // ALPS data is exchanged only when both client and server enable it.
  const auto settings_seen_by_client =
      client_socket_->GetPeerApplicationSettings();
  const auto settings_seen_by_server =
      server_socket_->GetPeerApplicationSettings();

  const bool alps_negotiated = client_alps_enabled_ && server_alps_enabled_;
  if (!alps_negotiated) {
    EXPECT_FALSE(settings_seen_by_client.has_value());
    EXPECT_FALSE(settings_seen_by_server.has_value());
    return;
  }

  ASSERT_TRUE(settings_seen_by_client.has_value());
  EXPECT_EQ(server_data, settings_seen_by_client.value());
  ASSERT_TRUE(settings_seen_by_server.has_value());
  EXPECT_EQ(client_data, settings_seen_by_server.value());
}
1412
1413 // Test that CancelReadIfReady works.
TEST_F(SSLServerSocketTest, CancelReadIfReady) {
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback client_connect_cb;
  const int pending_client_result =
      client_socket_->Connect(client_connect_cb.callback());
  TestCompletionCallback server_handshake_cb;
  const int pending_server_result =
      server_socket_->Handshake(server_handshake_cb.callback());
  ASSERT_THAT(client_connect_cb.GetResult(pending_client_result), IsOk());
  ASSERT_THAT(server_handshake_cb.GetResult(pending_server_result), IsOk());

  // Nothing has been sent yet, so a ReadIfReady() on the server is pending.
  // Cancel it straight away.
  TestCompletionCallback cancelled_read_cb;
  auto one_byte_buf = base::MakeRefCounted<IOBufferWithSize>(1);
  int read_result = server_socket_->ReadIfReady(one_byte_buf.get(), 1,
                                                cancelled_read_cb.callback());
  ASSERT_THAT(read_result, IsError(ERR_IO_PENDING));
  ASSERT_THAT(server_socket_->CancelReadIfReady(), IsOk());

  // Data sent by the client afterwards must not complete the cancelled read.
  auto payload = base::MakeRefCounted<StringIOBuffer>("a");
  TestCompletionCallback client_write_cb;
  const int write_result = client_write_cb.GetResult(
      client_socket_->Write(payload.get(), payload->size(),
                            client_write_cb.callback(),
                            TRAFFIC_ANNOTATION_FOR_TESTS));
  ASSERT_EQ(write_result, payload->size());
  base::RunLoop().RunUntilIdle();
  EXPECT_FALSE(cancelled_read_cb.have_result());

  // Reads issued after the cancellation still work. A pending ReadIfReady()
  // completes with OK to signal readiness, after which it is retried to
  // actually fetch the byte.
  for (;;) {
    TestCompletionCallback retry_read_cb;
    read_result = server_socket_->ReadIfReady(one_byte_buf.get(), 1,
                                              retry_read_cb.callback());
    if (read_result == ERR_IO_PENDING) {
      ASSERT_THAT(retry_read_cb.GetResult(read_result), IsOk());
      continue;
    }
    break;
  }
  ASSERT_EQ(1, read_result);
  EXPECT_EQ(one_byte_buf->data()[0], 'a');
}
1457
1458 } // namespace net
1459