1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 // This test suite uses SSLClientSocket to test the implementation of
6 // SSLServerSocket. In order to establish connections between the sockets
7 // we need two additional classes:
8 // 1. FakeSocket
9 // Connects SSL socket to FakeDataChannel. This class is just a stub.
10 //
11 // 2. FakeDataChannel
12 // Implements the actual exchange of data between two FakeSockets.
13 //
14 // Implementations of these two classes are included in this file.
15
16 #include "net/socket/ssl_server_socket.h"
17
18 #include <stdint.h>
19 #include <stdlib.h>
20
21 #include <memory>
22 #include <utility>
23
24 #include "base/check.h"
25 #include "base/compiler_specific.h"
26 #include "base/containers/queue.h"
27 #include "base/files/file_path.h"
28 #include "base/files/file_util.h"
29 #include "base/functional/bind.h"
30 #include "base/functional/callback_helpers.h"
31 #include "base/location.h"
32 #include "base/memory/raw_ptr.h"
33 #include "base/memory/scoped_refptr.h"
34 #include "base/notreached.h"
35 #include "base/run_loop.h"
36 #include "base/task/single_thread_task_runner.h"
37 #include "base/test/task_environment.h"
38 #include "build/build_config.h"
39 #include "crypto/rsa_private_key.h"
40 #include "net/base/address_list.h"
41 #include "net/base/completion_once_callback.h"
42 #include "net/base/host_port_pair.h"
43 #include "net/base/io_buffer.h"
44 #include "net/base/ip_address.h"
45 #include "net/base/ip_endpoint.h"
46 #include "net/base/net_errors.h"
47 #include "net/cert/cert_status_flags.h"
48 #include "net/cert/ct_policy_enforcer.h"
49 #include "net/cert/ct_policy_status.h"
50 #include "net/cert/mock_cert_verifier.h"
51 #include "net/cert/mock_client_cert_verifier.h"
52 #include "net/cert/signed_certificate_timestamp_and_status.h"
53 #include "net/cert/x509_certificate.h"
54 #include "net/http/transport_security_state.h"
55 #include "net/log/net_log_with_source.h"
56 #include "net/socket/client_socket_factory.h"
57 #include "net/socket/socket_test_util.h"
58 #include "net/socket/ssl_client_socket.h"
59 #include "net/socket/stream_socket.h"
60 #include "net/ssl/ssl_cert_request_info.h"
61 #include "net/ssl/ssl_cipher_suite_names.h"
62 #include "net/ssl/ssl_client_session_cache.h"
63 #include "net/ssl/ssl_connection_status_flags.h"
64 #include "net/ssl/ssl_info.h"
65 #include "net/ssl/ssl_private_key.h"
66 #include "net/ssl/ssl_server_config.h"
67 #include "net/ssl/test_ssl_config_service.h"
68 #include "net/ssl/test_ssl_private_key.h"
69 #include "net/test/cert_test_util.h"
70 #include "net/test/gtest_util.h"
71 #include "net/test/test_data_directory.h"
72 #include "net/test/test_with_task_environment.h"
73 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
74 #include "testing/gmock/include/gmock/gmock.h"
75 #include "testing/gtest/include/gtest/gtest.h"
76 #include "testing/platform_test.h"
77 #include "third_party/boringssl/src/include/openssl/evp.h"
78 #include "third_party/boringssl/src/include/openssl/ssl.h"
79
80 using net::test::IsError;
81 using net::test::IsOk;
82
83 namespace net {
84
85 namespace {
86
87 // Client certificates are disabled on iOS.
88 #if !BUILDFLAG(IS_IOS)
89 const char kClientCertFileName[] = "client_1.pem";
90 const char kClientPrivateKeyFileName[] = "client_1.pk8";
91 const char kWrongClientCertFileName[] = "client_2.pem";
92 const char kWrongClientPrivateKeyFileName[] = "client_2.pk8";
93 #endif // !IS_IOS
94
95 const uint16_t kEcdheCiphers[] = {
96 0xc007, // ECDHE_ECDSA_WITH_RC4_128_SHA
97 0xc009, // ECDHE_ECDSA_WITH_AES_128_CBC_SHA
98 0xc00a, // ECDHE_ECDSA_WITH_AES_256_CBC_SHA
99 0xc011, // ECDHE_RSA_WITH_RC4_128_SHA
100 0xc013, // ECDHE_RSA_WITH_AES_128_CBC_SHA
101 0xc014, // ECDHE_RSA_WITH_AES_256_CBC_SHA
102 0xc02b, // ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
103 0xc02c, // ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
104 0xc02f, // ECDHE_RSA_WITH_AES_128_GCM_SHA256
105 0xc030, // ECDHE_RSA_WITH_AES_256_GCM_SHA384
106 0xcca8, // ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
107 0xcca9, // ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
108 };
109
110 class MockCTPolicyEnforcer : public CTPolicyEnforcer {
111 public:
112 MockCTPolicyEnforcer() = default;
113 ~MockCTPolicyEnforcer() override = default;
CheckCompliance(X509Certificate * cert,const ct::SCTList & verified_scts,const NetLogWithSource & net_log)114 ct::CTPolicyCompliance CheckCompliance(
115 X509Certificate* cert,
116 const ct::SCTList& verified_scts,
117 const NetLogWithSource& net_log) override {
118 return ct::CTPolicyCompliance::CT_POLICY_COMPLIES_VIA_SCTS;
119 }
120 };
121
122 class FakeDataChannel {
123 public:
124 FakeDataChannel() = default;
125
126 FakeDataChannel(const FakeDataChannel&) = delete;
127 FakeDataChannel& operator=(const FakeDataChannel&) = delete;
128
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)129 int Read(IOBuffer* buf, int buf_len, CompletionOnceCallback callback) {
130 DCHECK(read_callback_.is_null());
131 DCHECK(!read_buf_.get());
132 if (closed_)
133 return 0;
134 if (data_.empty()) {
135 read_callback_ = std::move(callback);
136 read_buf_ = buf;
137 read_buf_len_ = buf_len;
138 return ERR_IO_PENDING;
139 }
140 return PropagateData(buf, buf_len);
141 }
142
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)143 int Write(IOBuffer* buf,
144 int buf_len,
145 CompletionOnceCallback callback,
146 const NetworkTrafficAnnotationTag& traffic_annotation) {
147 DCHECK(write_callback_.is_null());
148 if (closed_) {
149 if (write_called_after_close_)
150 return ERR_CONNECTION_RESET;
151 write_called_after_close_ = true;
152 write_callback_ = std::move(callback);
153 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
154 FROM_HERE, base::BindOnce(&FakeDataChannel::DoWriteCallback,
155 weak_factory_.GetWeakPtr()));
156 return ERR_IO_PENDING;
157 }
158 // This function returns synchronously, so make a copy of the buffer.
159 data_.push(base::MakeRefCounted<DrainableIOBuffer>(
160 base::MakeRefCounted<StringIOBuffer>(std::string(buf->data(), buf_len)),
161 buf_len));
162 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
163 FROM_HERE, base::BindOnce(&FakeDataChannel::DoReadCallback,
164 weak_factory_.GetWeakPtr()));
165 return buf_len;
166 }
167
168 // Closes the FakeDataChannel. After Close() is called, Read() returns 0,
169 // indicating EOF, and Write() fails with ERR_CONNECTION_RESET. Note that
170 // after the FakeDataChannel is closed, the first Write() call completes
171 // asynchronously, which is necessary to reproduce bug 127822.
Close()172 void Close() {
173 closed_ = true;
174 if (!read_callback_.is_null()) {
175 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
176 FROM_HERE, base::BindOnce(&FakeDataChannel::DoReadCallback,
177 weak_factory_.GetWeakPtr()));
178 }
179 }
180
181 private:
DoReadCallback()182 void DoReadCallback() {
183 if (read_callback_.is_null())
184 return;
185
186 if (closed_) {
187 std::move(read_callback_).Run(ERR_CONNECTION_CLOSED);
188 return;
189 }
190
191 if (data_.empty())
192 return;
193
194 int copied = PropagateData(read_buf_, read_buf_len_);
195 read_buf_ = nullptr;
196 read_buf_len_ = 0;
197 std::move(read_callback_).Run(copied);
198 }
199
DoWriteCallback()200 void DoWriteCallback() {
201 if (write_callback_.is_null())
202 return;
203
204 std::move(write_callback_).Run(ERR_CONNECTION_RESET);
205 }
206
PropagateData(scoped_refptr<IOBuffer> read_buf,int read_buf_len)207 int PropagateData(scoped_refptr<IOBuffer> read_buf, int read_buf_len) {
208 scoped_refptr<DrainableIOBuffer> buf = data_.front();
209 int copied = std::min(buf->BytesRemaining(), read_buf_len);
210 memcpy(read_buf->data(), buf->data(), copied);
211 buf->DidConsume(copied);
212
213 if (!buf->BytesRemaining())
214 data_.pop();
215 return copied;
216 }
217
218 CompletionOnceCallback read_callback_;
219 scoped_refptr<IOBuffer> read_buf_;
220 int read_buf_len_ = 0;
221
222 CompletionOnceCallback write_callback_;
223
224 base::queue<scoped_refptr<DrainableIOBuffer>> data_;
225
226 // True if Close() has been called.
227 bool closed_ = false;
228
229 // Controls the completion of Write() after the FakeDataChannel is closed.
230 // After the FakeDataChannel is closed, the first Write() call completes
231 // asynchronously.
232 bool write_called_after_close_ = false;
233
234 base::WeakPtrFactory<FakeDataChannel> weak_factory_{this};
235 };
236
237 class FakeSocket : public StreamSocket {
238 public:
FakeSocket(FakeDataChannel * incoming_channel,FakeDataChannel * outgoing_channel)239 FakeSocket(FakeDataChannel* incoming_channel,
240 FakeDataChannel* outgoing_channel)
241 : incoming_(incoming_channel), outgoing_(outgoing_channel) {}
242
243 FakeSocket(const FakeSocket&) = delete;
244 FakeSocket& operator=(const FakeSocket&) = delete;
245
246 ~FakeSocket() override = default;
247
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)248 int Read(IOBuffer* buf,
249 int buf_len,
250 CompletionOnceCallback callback) override {
251 // Read random number of bytes.
252 buf_len = rand() % buf_len + 1;
253 return incoming_->Read(buf, buf_len, std::move(callback));
254 }
255
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)256 int Write(IOBuffer* buf,
257 int buf_len,
258 CompletionOnceCallback callback,
259 const NetworkTrafficAnnotationTag& traffic_annotation) override {
260 // Write random number of bytes.
261 buf_len = rand() % buf_len + 1;
262 return outgoing_->Write(buf, buf_len, std::move(callback),
263 TRAFFIC_ANNOTATION_FOR_TESTS);
264 }
265
SetReceiveBufferSize(int32_t size)266 int SetReceiveBufferSize(int32_t size) override { return OK; }
267
SetSendBufferSize(int32_t size)268 int SetSendBufferSize(int32_t size) override { return OK; }
269
Connect(CompletionOnceCallback callback)270 int Connect(CompletionOnceCallback callback) override { return OK; }
271
Disconnect()272 void Disconnect() override {
273 incoming_->Close();
274 outgoing_->Close();
275 }
276
IsConnected() const277 bool IsConnected() const override { return true; }
278
IsConnectedAndIdle() const279 bool IsConnectedAndIdle() const override { return true; }
280
GetPeerAddress(IPEndPoint * address) const281 int GetPeerAddress(IPEndPoint* address) const override {
282 *address = IPEndPoint(IPAddress::IPv4AllZeros(), 0 /*port*/);
283 return OK;
284 }
285
GetLocalAddress(IPEndPoint * address) const286 int GetLocalAddress(IPEndPoint* address) const override {
287 *address = IPEndPoint(IPAddress::IPv4AllZeros(), 0 /*port*/);
288 return OK;
289 }
290
NetLog() const291 const NetLogWithSource& NetLog() const override { return net_log_; }
292
WasEverUsed() const293 bool WasEverUsed() const override { return true; }
294
WasAlpnNegotiated() const295 bool WasAlpnNegotiated() const override { return false; }
296
GetNegotiatedProtocol() const297 NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
298
GetSSLInfo(SSLInfo * ssl_info)299 bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
300
GetTotalReceivedBytes() const301 int64_t GetTotalReceivedBytes() const override {
302 NOTIMPLEMENTED();
303 return 0;
304 }
305
ApplySocketTag(const SocketTag & tag)306 void ApplySocketTag(const SocketTag& tag) override {}
307
308 private:
309 NetLogWithSource net_log_;
310 raw_ptr<FakeDataChannel> incoming_;
311 raw_ptr<FakeDataChannel> outgoing_;
312 };
313
314 } // namespace
315
316 // Verify the correctness of the test helper classes first.
TEST(FakeSocketTest,DataTransfer)317 TEST(FakeSocketTest, DataTransfer) {
318 base::test::TaskEnvironment task_environment;
319
320 // Establish channels between two sockets.
321 FakeDataChannel channel_1;
322 FakeDataChannel channel_2;
323 FakeSocket client(&channel_1, &channel_2);
324 FakeSocket server(&channel_2, &channel_1);
325
326 const char kTestData[] = "testing123";
327 const int kTestDataSize = strlen(kTestData);
328 const int kReadBufSize = 1024;
329 scoped_refptr<IOBuffer> write_buf =
330 base::MakeRefCounted<StringIOBuffer>(kTestData);
331 scoped_refptr<IOBuffer> read_buf =
332 base::MakeRefCounted<IOBuffer>(kReadBufSize);
333
334 // Write then read.
335 int written =
336 server.Write(write_buf.get(), kTestDataSize, CompletionOnceCallback(),
337 TRAFFIC_ANNOTATION_FOR_TESTS);
338 EXPECT_GT(written, 0);
339 EXPECT_LE(written, kTestDataSize);
340
341 int read =
342 client.Read(read_buf.get(), kReadBufSize, CompletionOnceCallback());
343 EXPECT_GT(read, 0);
344 EXPECT_LE(read, written);
345 EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read));
346
347 // Read then write.
348 TestCompletionCallback callback;
349 EXPECT_EQ(ERR_IO_PENDING,
350 server.Read(read_buf.get(), kReadBufSize, callback.callback()));
351
352 written =
353 client.Write(write_buf.get(), kTestDataSize, CompletionOnceCallback(),
354 TRAFFIC_ANNOTATION_FOR_TESTS);
355 EXPECT_GT(written, 0);
356 EXPECT_LE(written, kTestDataSize);
357
358 read = callback.WaitForResult();
359 EXPECT_GT(read, 0);
360 EXPECT_LE(read, written);
361 EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read));
362 }
363
364 class SSLServerSocketTest : public PlatformTest, public WithTaskEnvironment {
365 public:
SSLServerSocketTest()366 SSLServerSocketTest()
367 : ssl_config_service_(
368 std::make_unique<TestSSLConfigService>(SSLContextConfig())),
369 cert_verifier_(std::make_unique<MockCertVerifier>()),
370 client_cert_verifier_(std::make_unique<MockClientCertVerifier>()),
371 transport_security_state_(std::make_unique<TransportSecurityState>()),
372 ct_policy_enforcer_(std::make_unique<MockCTPolicyEnforcer>()),
373 ssl_client_session_cache_(std::make_unique<SSLClientSessionCache>(
374 SSLClientSessionCache::Config())) {}
375
SetUp()376 void SetUp() override {
377 PlatformTest::SetUp();
378
379 cert_verifier_->set_default_result(ERR_CERT_AUTHORITY_INVALID);
380 client_cert_verifier_->set_default_result(ERR_CERT_AUTHORITY_INVALID);
381
382 server_cert_ =
383 ImportCertFromFile(GetTestCertsDirectory(), "unittest.selfsigned.der");
384 ASSERT_TRUE(server_cert_);
385 server_private_key_ = ReadTestKey("unittest.key.bin");
386 ASSERT_TRUE(server_private_key_);
387
388 std::unique_ptr<crypto::RSAPrivateKey> key =
389 ReadTestKey("unittest.key.bin");
390 ASSERT_TRUE(key);
391 server_ssl_private_key_ = WrapOpenSSLPrivateKey(bssl::UpRef(key->key()));
392
393 // Certificate provided by the host doesn't need authority.
394 client_ssl_config_.allowed_bad_certs.emplace_back(
395 server_cert_, CERT_STATUS_AUTHORITY_INVALID);
396
397 client_context_ = std::make_unique<SSLClientContext>(
398 ssl_config_service_.get(), cert_verifier_.get(),
399 transport_security_state_.get(), ct_policy_enforcer_.get(),
400 ssl_client_session_cache_.get(), nullptr);
401 }
402
403 protected:
CreateContext()404 void CreateContext() {
405 client_socket_.reset();
406 server_socket_.reset();
407 channel_1_.reset();
408 channel_2_.reset();
409 server_context_ = CreateSSLServerContext(
410 server_cert_.get(), *server_private_key_, server_ssl_config_);
411 }
412
CreateContextSSLPrivateKey()413 void CreateContextSSLPrivateKey() {
414 client_socket_.reset();
415 server_socket_.reset();
416 channel_1_.reset();
417 channel_2_.reset();
418 server_context_.reset();
419 server_context_ = CreateSSLServerContext(
420 server_cert_.get(), server_ssl_private_key_, server_ssl_config_);
421 }
422
GetHostAndPort()423 static HostPortPair GetHostAndPort() { return HostPortPair("unittest", 0); }
424
CreateSockets()425 void CreateSockets() {
426 client_socket_.reset();
427 server_socket_.reset();
428 channel_1_ = std::make_unique<FakeDataChannel>();
429 channel_2_ = std::make_unique<FakeDataChannel>();
430 std::unique_ptr<StreamSocket> client_connection =
431 std::make_unique<FakeSocket>(channel_1_.get(), channel_2_.get());
432 std::unique_ptr<StreamSocket> server_socket =
433 std::make_unique<FakeSocket>(channel_2_.get(), channel_1_.get());
434
435 client_socket_ = client_context_->CreateSSLClientSocket(
436 std::move(client_connection), GetHostAndPort(), client_ssl_config_);
437 ASSERT_TRUE(client_socket_);
438
439 server_socket_ =
440 server_context_->CreateSSLServerSocket(std::move(server_socket));
441 ASSERT_TRUE(server_socket_);
442 }
443
444 // Client certificates are disabled on iOS.
445 #if !BUILDFLAG(IS_IOS)
ConfigureClientCertsForClient(const char * cert_file_name,const char * private_key_file_name)446 void ConfigureClientCertsForClient(const char* cert_file_name,
447 const char* private_key_file_name) {
448 scoped_refptr<X509Certificate> client_cert =
449 ImportCertFromFile(GetTestCertsDirectory(), cert_file_name);
450 ASSERT_TRUE(client_cert);
451
452 std::unique_ptr<crypto::RSAPrivateKey> key =
453 ReadTestKey(private_key_file_name);
454 ASSERT_TRUE(key);
455
456 client_context_->SetClientCertificate(
457 GetHostAndPort(), std::move(client_cert),
458 WrapOpenSSLPrivateKey(bssl::UpRef(key->key())));
459 }
460
ConfigureClientCertsForServer()461 void ConfigureClientCertsForServer() {
462 server_ssl_config_.client_cert_type =
463 SSLServerConfig::ClientCertType::REQUIRE_CLIENT_CERT;
464
465 // "CN=B CA" - DER encoded DN of the issuer of client_1.pem
466 static const uint8_t kClientCertCAName[] = {
467 0x30, 0x0f, 0x31, 0x0d, 0x30, 0x0b, 0x06, 0x03, 0x55,
468 0x04, 0x03, 0x0c, 0x04, 0x42, 0x20, 0x43, 0x41};
469 server_ssl_config_.cert_authorities.emplace_back(
470 std::begin(kClientCertCAName), std::end(kClientCertCAName));
471
472 scoped_refptr<X509Certificate> expected_client_cert(
473 ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName));
474 ASSERT_TRUE(expected_client_cert);
475
476 client_cert_verifier_->AddResultForCert(expected_client_cert.get(), OK);
477
478 server_ssl_config_.client_cert_verifier = client_cert_verifier_.get();
479 }
480 #endif // !IS_IOS
481
ReadTestKey(base::StringPiece name)482 std::unique_ptr<crypto::RSAPrivateKey> ReadTestKey(base::StringPiece name) {
483 base::FilePath certs_dir(GetTestCertsDirectory());
484 base::FilePath key_path = certs_dir.AppendASCII(name);
485 std::string key_string;
486 if (!base::ReadFileToString(key_path, &key_string))
487 return nullptr;
488 std::vector<uint8_t> key_vector(
489 reinterpret_cast<const uint8_t*>(key_string.data()),
490 reinterpret_cast<const uint8_t*>(key_string.data() +
491 key_string.length()));
492 std::unique_ptr<crypto::RSAPrivateKey> key(
493 crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
494 return key;
495 }
496
PumpServerToClient()497 void PumpServerToClient() {
498 const int kReadBufSize = 1024;
499 scoped_refptr<StringIOBuffer> write_buf =
500 base::MakeRefCounted<StringIOBuffer>("testing123");
501 scoped_refptr<DrainableIOBuffer> read_buf =
502 base::MakeRefCounted<DrainableIOBuffer>(
503 base::MakeRefCounted<IOBuffer>(kReadBufSize), kReadBufSize);
504 TestCompletionCallback write_callback;
505 TestCompletionCallback read_callback;
506 int server_ret = server_socket_->Write(write_buf.get(), write_buf->size(),
507 write_callback.callback(),
508 TRAFFIC_ANNOTATION_FOR_TESTS);
509 EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
510 int client_ret = client_socket_->Read(
511 read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
512 EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
513
514 server_ret = write_callback.GetResult(server_ret);
515 EXPECT_GT(server_ret, 0);
516 client_ret = read_callback.GetResult(client_ret);
517 ASSERT_GT(client_ret, 0);
518 }
519
520 std::unique_ptr<FakeDataChannel> channel_1_;
521 std::unique_ptr<FakeDataChannel> channel_2_;
522 SSLConfig client_ssl_config_;
523 SSLServerConfig server_ssl_config_;
524 std::unique_ptr<TestSSLConfigService> ssl_config_service_;
525 std::unique_ptr<MockCertVerifier> cert_verifier_;
526 std::unique_ptr<MockClientCertVerifier> client_cert_verifier_;
527 std::unique_ptr<TransportSecurityState> transport_security_state_;
528 std::unique_ptr<MockCTPolicyEnforcer> ct_policy_enforcer_;
529 std::unique_ptr<SSLClientSessionCache> ssl_client_session_cache_;
530 std::unique_ptr<SSLClientContext> client_context_;
531 std::unique_ptr<SSLServerContext> server_context_;
532 std::unique_ptr<SSLClientSocket> client_socket_;
533 std::unique_ptr<SSLServerSocket> server_socket_;
534 std::unique_ptr<crypto::RSAPrivateKey> server_private_key_;
535 scoped_refptr<SSLPrivateKey> server_ssl_private_key_;
536 scoped_refptr<X509Certificate> server_cert_;
537 };
538
539 class SSLServerSocketReadTest : public SSLServerSocketTest,
540 public ::testing::WithParamInterface<bool> {
541 protected:
SSLServerSocketReadTest()542 SSLServerSocketReadTest() : read_if_ready_enabled_(GetParam()) {}
543
Read(StreamSocket * socket,IOBuffer * buf,int buf_len,CompletionOnceCallback callback)544 int Read(StreamSocket* socket,
545 IOBuffer* buf,
546 int buf_len,
547 CompletionOnceCallback callback) {
548 if (read_if_ready_enabled()) {
549 return socket->ReadIfReady(buf, buf_len, std::move(callback));
550 }
551 return socket->Read(buf, buf_len, std::move(callback));
552 }
553
read_if_ready_enabled() const554 bool read_if_ready_enabled() const { return read_if_ready_enabled_; }
555
556 private:
557 const bool read_if_ready_enabled_;
558 };
559
560 INSTANTIATE_TEST_SUITE_P(/* no prefix */,
561 SSLServerSocketReadTest,
562 ::testing::Bool());
563
564 // This test only executes creation of client and server sockets. This is to
565 // test that creation of sockets doesn't crash and have minimal code to run
566 // with memory leak/corruption checking tools.
TEST_F(SSLServerSocketTest,Initialize)567 TEST_F(SSLServerSocketTest, Initialize) {
568 ASSERT_NO_FATAL_FAILURE(CreateContext());
569 ASSERT_NO_FATAL_FAILURE(CreateSockets());
570 }
571
572 // This test executes Connect() on SSLClientSocket and Handshake() on
573 // SSLServerSocket to make sure handshaking between the two sockets is
574 // completed successfully.
TEST_F(SSLServerSocketTest,Handshake)575 TEST_F(SSLServerSocketTest, Handshake) {
576 ASSERT_NO_FATAL_FAILURE(CreateContext());
577 ASSERT_NO_FATAL_FAILURE(CreateSockets());
578
579 TestCompletionCallback handshake_callback;
580 int server_ret = server_socket_->Handshake(handshake_callback.callback());
581
582 TestCompletionCallback connect_callback;
583 int client_ret = client_socket_->Connect(connect_callback.callback());
584
585 client_ret = connect_callback.GetResult(client_ret);
586 server_ret = handshake_callback.GetResult(server_ret);
587
588 ASSERT_THAT(client_ret, IsOk());
589 ASSERT_THAT(server_ret, IsOk());
590
591 // Make sure the cert status is expected.
592 SSLInfo ssl_info;
593 ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
594 EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);
595
596 // The default cipher suite should be ECDHE and an AEAD.
597 uint16_t cipher_suite =
598 SSLConnectionStatusToCipherSuite(ssl_info.connection_status);
599 const char* key_exchange;
600 const char* cipher;
601 const char* mac;
602 bool is_aead;
603 bool is_tls13;
604 SSLCipherSuiteToStrings(&key_exchange, &cipher, &mac, &is_aead, &is_tls13,
605 cipher_suite);
606 EXPECT_TRUE(is_aead);
607 }
608
609 // This test makes sure the session cache is working.
TEST_F(SSLServerSocketTest,HandshakeCached)610 TEST_F(SSLServerSocketTest, HandshakeCached) {
611 ASSERT_NO_FATAL_FAILURE(CreateContext());
612 ASSERT_NO_FATAL_FAILURE(CreateSockets());
613
614 TestCompletionCallback handshake_callback;
615 int server_ret = server_socket_->Handshake(handshake_callback.callback());
616
617 TestCompletionCallback connect_callback;
618 int client_ret = client_socket_->Connect(connect_callback.callback());
619
620 client_ret = connect_callback.GetResult(client_ret);
621 server_ret = handshake_callback.GetResult(server_ret);
622
623 ASSERT_THAT(client_ret, IsOk());
624 ASSERT_THAT(server_ret, IsOk());
625
626 // Make sure the cert status is expected.
627 SSLInfo ssl_info;
628 ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
629 EXPECT_EQ(ssl_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
630 SSLInfo ssl_server_info;
631 ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info));
632 EXPECT_EQ(ssl_server_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
633
634 // Pump client read to get new session tickets.
635 PumpServerToClient();
636
637 // Make sure the second connection is cached.
638 ASSERT_NO_FATAL_FAILURE(CreateSockets());
639 TestCompletionCallback handshake_callback2;
640 int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());
641
642 TestCompletionCallback connect_callback2;
643 int client_ret2 = client_socket_->Connect(connect_callback2.callback());
644
645 client_ret2 = connect_callback2.GetResult(client_ret2);
646 server_ret2 = handshake_callback2.GetResult(server_ret2);
647
648 ASSERT_THAT(client_ret2, IsOk());
649 ASSERT_THAT(server_ret2, IsOk());
650
651 // Make sure the cert status is expected.
652 SSLInfo ssl_info2;
653 ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info2));
654 EXPECT_EQ(ssl_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME);
655 SSLInfo ssl_server_info2;
656 ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info2));
657 EXPECT_EQ(ssl_server_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME);
658 }
659
660 // This test makes sure the session cache separates out by server context.
TEST_F(SSLServerSocketTest,HandshakeCachedContextSwitch)661 TEST_F(SSLServerSocketTest, HandshakeCachedContextSwitch) {
662 ASSERT_NO_FATAL_FAILURE(CreateContext());
663 ASSERT_NO_FATAL_FAILURE(CreateSockets());
664
665 TestCompletionCallback handshake_callback;
666 int server_ret = server_socket_->Handshake(handshake_callback.callback());
667
668 TestCompletionCallback connect_callback;
669 int client_ret = client_socket_->Connect(connect_callback.callback());
670
671 client_ret = connect_callback.GetResult(client_ret);
672 server_ret = handshake_callback.GetResult(server_ret);
673
674 ASSERT_THAT(client_ret, IsOk());
675 ASSERT_THAT(server_ret, IsOk());
676
677 // Make sure the cert status is expected.
678 SSLInfo ssl_info;
679 ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
680 EXPECT_EQ(ssl_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
681 SSLInfo ssl_server_info;
682 ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info));
683 EXPECT_EQ(ssl_server_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
684
685 // Make sure the second connection is NOT cached when using a new context.
686 ASSERT_NO_FATAL_FAILURE(CreateContext());
687 ASSERT_NO_FATAL_FAILURE(CreateSockets());
688
689 TestCompletionCallback handshake_callback2;
690 int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());
691
692 TestCompletionCallback connect_callback2;
693 int client_ret2 = client_socket_->Connect(connect_callback2.callback());
694
695 client_ret2 = connect_callback2.GetResult(client_ret2);
696 server_ret2 = handshake_callback2.GetResult(server_ret2);
697
698 ASSERT_THAT(client_ret2, IsOk());
699 ASSERT_THAT(server_ret2, IsOk());
700
701 // Make sure the cert status is expected.
702 SSLInfo ssl_info2;
703 ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info2));
704 EXPECT_EQ(ssl_info2.handshake_type, SSLInfo::HANDSHAKE_FULL);
705 SSLInfo ssl_server_info2;
706 ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info2));
707 EXPECT_EQ(ssl_server_info2.handshake_type, SSLInfo::HANDSHAKE_FULL);
708 }
709
710 // Client certificates are disabled on iOS.
711 #if !BUILDFLAG(IS_IOS)
712 // This test executes Connect() on SSLClientSocket and Handshake() on
713 // SSLServerSocket to make sure handshaking between the two sockets is
714 // completed successfully, using client certificate.
TEST_F(SSLServerSocketTest,HandshakeWithClientCert)715 TEST_F(SSLServerSocketTest, HandshakeWithClientCert) {
716 scoped_refptr<X509Certificate> client_cert =
717 ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
718 ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
719 kClientCertFileName, kClientPrivateKeyFileName));
720 ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
721 ASSERT_NO_FATAL_FAILURE(CreateContext());
722 ASSERT_NO_FATAL_FAILURE(CreateSockets());
723
724 TestCompletionCallback handshake_callback;
725 int server_ret = server_socket_->Handshake(handshake_callback.callback());
726
727 TestCompletionCallback connect_callback;
728 int client_ret = client_socket_->Connect(connect_callback.callback());
729
730 client_ret = connect_callback.GetResult(client_ret);
731 server_ret = handshake_callback.GetResult(server_ret);
732
733 ASSERT_THAT(client_ret, IsOk());
734 ASSERT_THAT(server_ret, IsOk());
735
736 // Make sure the cert status is expected.
737 SSLInfo ssl_info;
738 client_socket_->GetSSLInfo(&ssl_info);
739 EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);
740 server_socket_->GetSSLInfo(&ssl_info);
741 ASSERT_TRUE(ssl_info.cert.get());
742 EXPECT_TRUE(client_cert->EqualsExcludingChain(ssl_info.cert.get()));
743 }
744
745 // This test executes Connect() on SSLClientSocket and Handshake() twice on
746 // SSLServerSocket to make sure handshaking between the two sockets is
747 // completed successfully, using client certificate. The second connection is
748 // expected to succeed through the session cache.
TEST_F(SSLServerSocketTest,HandshakeWithClientCertCached)749 TEST_F(SSLServerSocketTest, HandshakeWithClientCertCached) {
750 scoped_refptr<X509Certificate> client_cert =
751 ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
752 ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
753 kClientCertFileName, kClientPrivateKeyFileName));
754 ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
755 ASSERT_NO_FATAL_FAILURE(CreateContext());
756 ASSERT_NO_FATAL_FAILURE(CreateSockets());
757
758 TestCompletionCallback handshake_callback;
759 int server_ret = server_socket_->Handshake(handshake_callback.callback());
760
761 TestCompletionCallback connect_callback;
762 int client_ret = client_socket_->Connect(connect_callback.callback());
763
764 client_ret = connect_callback.GetResult(client_ret);
765 server_ret = handshake_callback.GetResult(server_ret);
766
767 ASSERT_THAT(client_ret, IsOk());
768 ASSERT_THAT(server_ret, IsOk());
769
770 // Make sure the cert status is expected.
771 SSLInfo ssl_info;
772 ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
773 EXPECT_EQ(ssl_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
774 SSLInfo ssl_server_info;
775 ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info));
776 ASSERT_TRUE(ssl_server_info.cert.get());
777 EXPECT_TRUE(client_cert->EqualsExcludingChain(ssl_server_info.cert.get()));
778 EXPECT_EQ(ssl_server_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
779 // Pump client read to get new session tickets.
780 PumpServerToClient();
781 server_socket_->Disconnect();
782 client_socket_->Disconnect();
783
784 // Create the connection again.
785 ASSERT_NO_FATAL_FAILURE(CreateSockets());
786 TestCompletionCallback handshake_callback2;
787 int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());
788
789 TestCompletionCallback connect_callback2;
790 int client_ret2 = client_socket_->Connect(connect_callback2.callback());
791
792 client_ret2 = connect_callback2.GetResult(client_ret2);
793 server_ret2 = handshake_callback2.GetResult(server_ret2);
794
795 ASSERT_THAT(client_ret2, IsOk());
796 ASSERT_THAT(server_ret2, IsOk());
797
798 // Make sure the cert status is expected.
799 SSLInfo ssl_info2;
800 ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info2));
801 EXPECT_EQ(ssl_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME);
802 SSLInfo ssl_server_info2;
803 ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info2));
804 ASSERT_TRUE(ssl_server_info2.cert.get());
805 EXPECT_TRUE(client_cert->EqualsExcludingChain(ssl_server_info2.cert.get()));
806 EXPECT_EQ(ssl_server_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME);
807 }
808
TEST_F(SSLServerSocketTest, HandshakeWithClientCertRequiredNotSupplied) {
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());
  // The client keeps its default configuration and therefore offers no
  // certificate. Its Connect() stops with ERR_SSL_CLIENT_AUTH_CERT_NEEDED,
  // which lets the test inspect the cert_authorities the server listed in
  // its CertificateRequest.

  TestCompletionCallback server_done;
  const int server_rv = server_socket_->Handshake(server_done.callback());

  TestCompletionCallback client_done;
  const int client_rv = client_socket_->Connect(client_done.callback());
  EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED, client_done.GetResult(client_rv));

  auto cert_request = base::MakeRefCounted<SSLCertRequestInfo>();
  client_socket_->GetSSLCertRequestInfo(cert_request.get());

  // client_1.pem is issued by the authority the server advertises, so it
  // must match the names received in the CertificateRequest.
  scoped_refptr<X509Certificate> expected_cert =
      ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
  ASSERT_TRUE(expected_cert);
  EXPECT_TRUE(expected_cert->IsIssuedByEncoded(cert_request->cert_authorities));

  // Dropping the client side must abort the server's pending handshake.
  client_socket_->Disconnect();

  EXPECT_THAT(server_done.GetResult(server_rv), IsError(ERR_CONNECTION_CLOSED));
}
842
TEST_F(SSLServerSocketTest, HandshakeWithClientCertRequiredNotSuppliedCached) {
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  // First attempt: the client keeps its default configuration, so it sends no
  // client certificate. Connect() is expected to fail with
  // ERR_SSL_CLIENT_AUTH_CERT_NEEDED, and the CertificateRequest received from
  // the server can then be inspected for its cert_authorities.
  TestCompletionCallback first_server_cb;
  const int first_server_rv =
      server_socket_->Handshake(first_server_cb.callback());

  TestCompletionCallback first_client_cb;
  EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED,
            first_client_cb.GetResult(
                client_socket_->Connect(first_client_cb.callback())));

  auto first_request = base::MakeRefCounted<SSLCertRequestInfo>();
  client_socket_->GetSSLCertRequestInfo(first_request.get());

  // The advertised authority must be the issuer of the test client cert.
  scoped_refptr<X509Certificate> client_cert =
      ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
  ASSERT_TRUE(client_cert);
  EXPECT_TRUE(client_cert->IsIssuedByEncoded(first_request->cert_authorities));

  client_socket_->Disconnect();
  EXPECT_THAT(first_server_cb.GetResult(first_server_rv),
              IsError(ERR_CONNECTION_CLOSED));
  server_socket_->Disconnect();

  // Second attempt on fresh sockets: it must behave exactly like the first,
  // i.e. the cache did not store the result of the failed handshake.
  ASSERT_NO_FATAL_FAILURE(CreateSockets());
  TestCompletionCallback second_server_cb;
  const int second_server_rv =
      server_socket_->Handshake(second_server_cb.callback());

  TestCompletionCallback second_client_cb;
  EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED,
            second_client_cb.GetResult(
                client_socket_->Connect(second_client_cb.callback())));

  auto second_request = base::MakeRefCounted<SSLCertRequestInfo>();
  client_socket_->GetSSLCertRequestInfo(second_request.get());
  EXPECT_TRUE(client_cert->IsIssuedByEncoded(second_request->cert_authorities));

  client_socket_->Disconnect();
  EXPECT_THAT(second_server_cb.GetResult(second_server_rv),
              IsError(ERR_CONNECTION_CLOSED));
}
899
// The client is configured with the "wrong" client certificate, so the server
// is expected to reject it with ERR_BAD_SSL_CLIENT_AUTH_CERT. With the default
// (TLS 1.3) configuration the client only observes the error on Read().
TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSupplied) {
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
      kWrongClientCertFileName, kWrongClientPrivateKeyFileName));
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback handshake_callback;
  int server_ret = server_socket_->Handshake(handshake_callback.callback());

  TestCompletionCallback connect_callback;
  int client_ret = client_socket_->Connect(connect_callback.callback());

  // In TLS 1.3, the client cert error isn't exposed until Read is called.
  EXPECT_EQ(OK, connect_callback.GetResult(client_ret));
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
            handshake_callback.GetResult(server_ret));

  // Pump client read to get client cert error.
  const int kReadBufSize = 1024;
  scoped_refptr<DrainableIOBuffer> read_buf =
      base::MakeRefCounted<DrainableIOBuffer>(
          base::MakeRefCounted<IOBuffer>(kReadBufSize), kReadBufSize);
  TestCompletionCallback read_callback;
  client_ret = client_socket_->Read(read_buf.get(), read_buf->BytesRemaining(),
                                    read_callback.callback());
  client_ret = read_callback.GetResult(client_ret);
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, client_ret);
}
933
// With the client capped at TLS 1.2, a rejected ("wrong") client certificate
// fails the handshake on both ends: the client's Connect() and the server's
// Handshake() each report ERR_BAD_SSL_CLIENT_AUTH_CERT.
TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSuppliedTLS12) {
  client_ssl_config_.version_max_override = SSL_PROTOCOL_VERSION_TLS1_2;
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
      kWrongClientCertFileName, kWrongClientPrivateKeyFileName));
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback handshake_callback;
  int server_ret = server_socket_->Handshake(handshake_callback.callback());

  TestCompletionCallback connect_callback;
  int client_ret = client_socket_->Connect(connect_callback.callback());

  // Unlike the TLS 1.3 case, no extra Read() is needed: the error is reported
  // by Connect() itself.
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
            connect_callback.GetResult(client_ret));
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
            handshake_callback.GetResult(server_ret));
}
957
// Runs the "wrong client certificate" handshake twice on fresh sockets that
// share one context, checking that the second attempt fails the same way, i.e.
// the result of the failed handshake was not cached.
TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSuppliedCached) {
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
      kWrongClientCertFileName, kWrongClientPrivateKeyFileName));
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback handshake_callback;
  int server_ret = server_socket_->Handshake(handshake_callback.callback());

  TestCompletionCallback connect_callback;
  int client_ret = client_socket_->Connect(connect_callback.callback());

  // In TLS 1.3, the client cert error isn't exposed until Read is called.
  EXPECT_EQ(OK, connect_callback.GetResult(client_ret));
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
            handshake_callback.GetResult(server_ret));

  // Pump client read to get client cert error.
  const int kReadBufSize = 1024;
  scoped_refptr<DrainableIOBuffer> read_buf =
      base::MakeRefCounted<DrainableIOBuffer>(
          base::MakeRefCounted<IOBuffer>(kReadBufSize), kReadBufSize);
  TestCompletionCallback read_callback;
  client_ret = client_socket_->Read(read_buf.get(), read_buf->BytesRemaining(),
                                    read_callback.callback());
  client_ret = read_callback.GetResult(client_ret);
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, client_ret);

  client_socket_->Disconnect();
  server_socket_->Disconnect();

  // Below, check that the cache didn't store the result of a failed handshake.
  ASSERT_NO_FATAL_FAILURE(CreateSockets());
  TestCompletionCallback handshake_callback2;
  int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());

  TestCompletionCallback connect_callback2;
  int client_ret2 = client_socket_->Connect(connect_callback2.callback());

  // In TLS 1.3, the client cert error isn't exposed until Read is called.
  EXPECT_EQ(OK, connect_callback2.GetResult(client_ret2));
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
            handshake_callback2.GetResult(server_ret2));

  // Pump client read on the second connection. read_buf was never consumed
  // above, so its full capacity is still available.
  TestCompletionCallback read_callback2;
  client_ret2 = client_socket_->Read(
      read_buf.get(), read_buf->BytesRemaining(), read_callback2.callback());
  client_ret2 = read_callback2.GetResult(client_ret2);
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, client_ret2);
}
1013 #endif // !IS_IOS
1014
// Transfers data in both directions over an established connection: first
// server-to-client with the write issued before the read, then
// client-to-server with the read issued before the write. Server reads go
// through the fixture's Read() helper; when read_if_ready_enabled(), the
// first pending read completes with OK and does not consume any data.
TEST_P(SSLServerSocketReadTest, DataTransfer) {
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  // Establish connection.
  TestCompletionCallback connect_callback;
  int client_ret = client_socket_->Connect(connect_callback.callback());
  ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);

  TestCompletionCallback handshake_callback;
  int server_ret = server_socket_->Handshake(handshake_callback.callback());
  ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);

  client_ret = connect_callback.GetResult(client_ret);
  ASSERT_THAT(client_ret, IsOk());
  server_ret = handshake_callback.GetResult(server_ret);
  ASSERT_THAT(server_ret, IsOk());

  const int kReadBufSize = 1024;
  scoped_refptr<StringIOBuffer> write_buf =
      base::MakeRefCounted<StringIOBuffer>("testing123");
  scoped_refptr<DrainableIOBuffer> read_buf =
      base::MakeRefCounted<DrainableIOBuffer>(
          base::MakeRefCounted<IOBuffer>(kReadBufSize), kReadBufSize);

  // Server -> client: write then read.
  TestCompletionCallback write_callback;
  TestCompletionCallback read_callback;
  server_ret = server_socket_->Write(write_buf.get(), write_buf->size(),
                                     write_callback.callback(),
                                     TRAFFIC_ANNOTATION_FOR_TESTS);
  EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
  client_ret = client_socket_->Read(
      read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
  EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);

  server_ret = write_callback.GetResult(server_ret);
  EXPECT_GT(server_ret, 0);
  client_ret = read_callback.GetResult(client_ret);
  ASSERT_GT(client_ret, 0);

  // A single Read() may return fewer bytes than were written; keep reading
  // until everything has arrived.
  read_buf->DidConsume(client_ret);
  while (read_buf->BytesConsumed() < write_buf->size()) {
    client_ret = client_socket_->Read(
        read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
    EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
    client_ret = read_callback.GetResult(client_ret);
    ASSERT_GT(client_ret, 0);
    read_buf->DidConsume(client_ret);
  }
  EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed());
  read_buf->SetOffset(0);
  EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));

  // Client -> server: read then write.
  write_buf = base::MakeRefCounted<StringIOBuffer>("hello123");
  server_ret = Read(server_socket_.get(), read_buf.get(),
                    read_buf->BytesRemaining(), read_callback.callback());
  // Nothing has been written by the client yet, so the read must be pending.
  EXPECT_THAT(server_ret, IsError(ERR_IO_PENDING));
  client_ret = client_socket_->Write(write_buf.get(), write_buf->size(),
                                     write_callback.callback(),
                                     TRAFFIC_ANNOTATION_FOR_TESTS);
  EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);

  server_ret = read_callback.GetResult(server_ret);
  if (read_if_ready_enabled()) {
    // ReadIfReady signals the data is available but does not consume it.
    // The data is consumed later below.
    ASSERT_THAT(server_ret, IsOk());
  } else {
    ASSERT_GT(server_ret, 0);
    read_buf->DidConsume(server_ret);
  }
  client_ret = write_callback.GetResult(client_ret);
  EXPECT_GT(client_ret, 0);

  while (read_buf->BytesConsumed() < write_buf->size()) {
    server_ret = Read(server_socket_.get(), read_buf.get(),
                      read_buf->BytesRemaining(), read_callback.callback());
    // All the data was written above, so the data should be synchronously
    // available out of both Read() and ReadIfReady().
    ASSERT_GT(server_ret, 0);
    read_buf->DidConsume(server_ret);
  }
  EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed());
  read_buf->SetOffset(0);
  EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));
}
1103
1104 // A regression test for bug 127822 (http://crbug.com/127822).
1105 // If the server closes the connection after the handshake is finished,
1106 // the client's Write() call should not cause an infinite loop.
1107 // NOTE: this is a test for SSLClientSocket rather than SSLServerSocket.
TEST_F(SSLServerSocketTest,ClientWriteAfterServerClose)1108 TEST_F(SSLServerSocketTest, ClientWriteAfterServerClose) {
1109 ASSERT_NO_FATAL_FAILURE(CreateContext());
1110 ASSERT_NO_FATAL_FAILURE(CreateSockets());
1111
1112 // Establish connection.
1113 TestCompletionCallback connect_callback;
1114 int client_ret = client_socket_->Connect(connect_callback.callback());
1115 ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
1116
1117 TestCompletionCallback handshake_callback;
1118 int server_ret = server_socket_->Handshake(handshake_callback.callback());
1119 ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
1120
1121 client_ret = connect_callback.GetResult(client_ret);
1122 ASSERT_THAT(client_ret, IsOk());
1123 server_ret = handshake_callback.GetResult(server_ret);
1124 ASSERT_THAT(server_ret, IsOk());
1125
1126 scoped_refptr<StringIOBuffer> write_buf =
1127 base::MakeRefCounted<StringIOBuffer>("testing123");
1128
1129 // The server closes the connection. The server needs to write some
1130 // data first so that the client's Read() calls from the transport
1131 // socket won't return ERR_IO_PENDING. This ensures that the client
1132 // will call Read() on the transport socket again.
1133 TestCompletionCallback write_callback;
1134 server_ret = server_socket_->Write(write_buf.get(), write_buf->size(),
1135 write_callback.callback(),
1136 TRAFFIC_ANNOTATION_FOR_TESTS);
1137 EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
1138
1139 server_ret = write_callback.GetResult(server_ret);
1140 EXPECT_GT(server_ret, 0);
1141
1142 server_socket_->Disconnect();
1143
1144 // The client writes some data. This should not cause an infinite loop.
1145 client_ret = client_socket_->Write(write_buf.get(), write_buf->size(),
1146 write_callback.callback(),
1147 TRAFFIC_ANNOTATION_FOR_TESTS);
1148 EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
1149
1150 client_ret = write_callback.GetResult(client_ret);
1151 EXPECT_GT(client_ret, 0);
1152
1153 base::RunLoop run_loop;
1154 base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask(
1155 FROM_HERE, run_loop.QuitClosure(), base::Milliseconds(10));
1156 run_loop.Run();
1157 }
1158
1159 // This test executes ExportKeyingMaterial() on the client and server sockets,
1160 // after connecting them, and verifies that the results match.
1161 // This test will fail if False Start is enabled (see crbug.com/90208).
TEST_F(SSLServerSocketTest,ExportKeyingMaterial)1162 TEST_F(SSLServerSocketTest, ExportKeyingMaterial) {
1163 ASSERT_NO_FATAL_FAILURE(CreateContext());
1164 ASSERT_NO_FATAL_FAILURE(CreateSockets());
1165
1166 TestCompletionCallback connect_callback;
1167 int client_ret = client_socket_->Connect(connect_callback.callback());
1168 ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
1169
1170 TestCompletionCallback handshake_callback;
1171 int server_ret = server_socket_->Handshake(handshake_callback.callback());
1172 ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
1173
1174 if (client_ret == ERR_IO_PENDING) {
1175 ASSERT_THAT(connect_callback.WaitForResult(), IsOk());
1176 }
1177 if (server_ret == ERR_IO_PENDING) {
1178 ASSERT_THAT(handshake_callback.WaitForResult(), IsOk());
1179 }
1180
1181 const int kKeyingMaterialSize = 32;
1182 const char kKeyingLabel[] = "EXPERIMENTAL-server-socket-test";
1183 const char kKeyingContext[] = "";
1184 unsigned char server_out[kKeyingMaterialSize];
1185 int rv = server_socket_->ExportKeyingMaterial(
1186 kKeyingLabel, false, kKeyingContext, server_out, sizeof(server_out));
1187 ASSERT_THAT(rv, IsOk());
1188
1189 unsigned char client_out[kKeyingMaterialSize];
1190 rv = client_socket_->ExportKeyingMaterial(kKeyingLabel, false, kKeyingContext,
1191 client_out, sizeof(client_out));
1192 ASSERT_THAT(rv, IsOk());
1193 EXPECT_EQ(0, memcmp(server_out, client_out, sizeof(server_out)));
1194
1195 const char kKeyingLabelBad[] = "EXPERIMENTAL-server-socket-test-bad";
1196 unsigned char client_bad[kKeyingMaterialSize];
1197 rv = client_socket_->ExportKeyingMaterial(
1198 kKeyingLabelBad, false, kKeyingContext, client_bad, sizeof(client_bad));
1199 ASSERT_EQ(rv, OK);
1200 EXPECT_NE(0, memcmp(server_out, client_bad, sizeof(server_out)));
1201 }
1202
1203 // Verifies that SSLConfig::require_ecdhe flags works properly.
TEST_F(SSLServerSocketTest,RequireEcdheFlag)1204 TEST_F(SSLServerSocketTest, RequireEcdheFlag) {
1205 // Disable all ECDHE suites on the client side.
1206 SSLContextConfig config;
1207 config.disabled_cipher_suites.assign(
1208 kEcdheCiphers, kEcdheCiphers + std::size(kEcdheCiphers));
1209
1210 // Legacy RSA key exchange ciphers only exist in TLS 1.2 and below.
1211 config.version_max = SSL_PROTOCOL_VERSION_TLS1_2;
1212 ssl_config_service_->UpdateSSLConfigAndNotify(config);
1213
1214 // Require ECDHE on the server.
1215 server_ssl_config_.require_ecdhe = true;
1216
1217 ASSERT_NO_FATAL_FAILURE(CreateContext());
1218 ASSERT_NO_FATAL_FAILURE(CreateSockets());
1219
1220 TestCompletionCallback connect_callback;
1221 int client_ret = client_socket_->Connect(connect_callback.callback());
1222
1223 TestCompletionCallback handshake_callback;
1224 int server_ret = server_socket_->Handshake(handshake_callback.callback());
1225
1226 client_ret = connect_callback.GetResult(client_ret);
1227 server_ret = handshake_callback.GetResult(server_ret);
1228
1229 ASSERT_THAT(client_ret, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
1230 ASSERT_THAT(server_ret, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
1231 }
1232
1233 // This test executes Connect() on SSLClientSocket and Handshake() on
1234 // SSLServerSocket to make sure handshaking between the two sockets is
1235 // completed successfully. The server key is represented by SSLPrivateKey.
TEST_F(SSLServerSocketTest,HandshakeServerSSLPrivateKey)1236 TEST_F(SSLServerSocketTest, HandshakeServerSSLPrivateKey) {
1237 ASSERT_NO_FATAL_FAILURE(CreateContextSSLPrivateKey());
1238 ASSERT_NO_FATAL_FAILURE(CreateSockets());
1239
1240 TestCompletionCallback handshake_callback;
1241 int server_ret = server_socket_->Handshake(handshake_callback.callback());
1242
1243 TestCompletionCallback connect_callback;
1244 int client_ret = client_socket_->Connect(connect_callback.callback());
1245
1246 client_ret = connect_callback.GetResult(client_ret);
1247 server_ret = handshake_callback.GetResult(server_ret);
1248
1249 ASSERT_THAT(client_ret, IsOk());
1250 ASSERT_THAT(server_ret, IsOk());
1251
1252 // Make sure the cert status is expected.
1253 SSLInfo ssl_info;
1254 ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
1255 EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);
1256
1257 // The default cipher suite should be ECDHE and an AEAD.
1258 uint16_t cipher_suite =
1259 SSLConnectionStatusToCipherSuite(ssl_info.connection_status);
1260 const char* key_exchange;
1261 const char* cipher;
1262 const char* mac;
1263 bool is_aead;
1264 bool is_tls13;
1265 SSLCipherSuiteToStrings(&key_exchange, &cipher, &mac, &is_aead, &is_tls13,
1266 cipher_suite);
1267 EXPECT_TRUE(is_aead);
1268 }
1269
1270 // Verifies that non-ECDHE ciphers are disabled when using SSLPrivateKey as the
1271 // server key.
TEST_F(SSLServerSocketTest,HandshakeServerSSLPrivateKeyRequireEcdhe)1272 TEST_F(SSLServerSocketTest, HandshakeServerSSLPrivateKeyRequireEcdhe) {
1273 // Disable all ECDHE suites on the client side.
1274 SSLContextConfig config;
1275 config.disabled_cipher_suites.assign(
1276 kEcdheCiphers, kEcdheCiphers + std::size(kEcdheCiphers));
1277 // TLS 1.3 always works with SSLPrivateKey.
1278 config.version_max = SSL_PROTOCOL_VERSION_TLS1_2;
1279 ssl_config_service_->UpdateSSLConfigAndNotify(config);
1280
1281 ASSERT_NO_FATAL_FAILURE(CreateContextSSLPrivateKey());
1282 ASSERT_NO_FATAL_FAILURE(CreateSockets());
1283
1284 TestCompletionCallback connect_callback;
1285 int client_ret = client_socket_->Connect(connect_callback.callback());
1286
1287 TestCompletionCallback handshake_callback;
1288 int server_ret = server_socket_->Handshake(handshake_callback.callback());
1289
1290 client_ret = connect_callback.GetResult(client_ret);
1291 server_ret = handshake_callback.GetResult(server_ret);
1292
1293 ASSERT_THAT(client_ret, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
1294 ASSERT_THAT(server_ret, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
1295 }
1296
1297 class SSLServerSocketAlpsTest
1298 : public SSLServerSocketTest,
1299 public ::testing::WithParamInterface<std::tuple<bool, bool>> {
1300 public:
SSLServerSocketAlpsTest()1301 SSLServerSocketAlpsTest()
1302 : client_alps_enabled_(std::get<0>(GetParam())),
1303 server_alps_enabled_(std::get<1>(GetParam())) {}
1304 ~SSLServerSocketAlpsTest() override = default;
1305 const bool client_alps_enabled_;
1306 const bool server_alps_enabled_;
1307 };
1308
1309 INSTANTIATE_TEST_SUITE_P(All,
1310 SSLServerSocketAlpsTest,
1311 ::testing::Combine(::testing::Bool(),
1312 ::testing::Bool()));
1313
// Checks that ALPS application settings reach the peer only when both the
// client and the server enable ALPS for the negotiated protocol (HTTP/2).
TEST_P(SSLServerSocketAlpsTest, Alps) {
  const std::string server_data = "server sends some test data";
  const std::string client_data = "client also sends some data";

  // Both sides offer only HTTP/2; each attaches ALPS data when enabled.
  server_ssl_config_.alpn_protos = {kProtoHTTP2};
  client_ssl_config_.alpn_protos = {kProtoHTTP2};
  if (server_alps_enabled_) {
    server_ssl_config_.application_settings[kProtoHTTP2] =
        std::vector<uint8_t>(server_data.begin(), server_data.end());
  }
  if (client_alps_enabled_) {
    client_ssl_config_.application_settings[kProtoHTTP2] =
        std::vector<uint8_t>(client_data.begin(), client_data.end());
  }

  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback server_cb;
  int server_rv = server_socket_->Handshake(server_cb.callback());
  TestCompletionCallback client_cb;
  int client_rv = client_socket_->Connect(client_cb.callback());

  client_rv = client_cb.GetResult(client_rv);
  server_rv = server_cb.GetResult(server_rv);
  ASSERT_THAT(client_rv, IsOk());
  ASSERT_THAT(server_rv, IsOk());

  const auto settings_seen_by_client =
      client_socket_->GetPeerApplicationSettings();
  const auto settings_seen_by_server =
      server_socket_->GetPeerApplicationSettings();

  const bool expect_alps = client_alps_enabled_ && server_alps_enabled_;
  if (!expect_alps) {
    // Without ALPS on both ends, neither side receives settings.
    EXPECT_FALSE(settings_seen_by_client.has_value());
    EXPECT_FALSE(settings_seen_by_server.has_value());
    return;
  }

  // Each side receives exactly what the other side configured.
  ASSERT_TRUE(settings_seen_by_client.has_value());
  EXPECT_EQ(server_data, *settings_seen_by_client);
  ASSERT_TRUE(settings_seen_by_server.has_value());
  EXPECT_EQ(client_data, *settings_seen_by_server);
}
1361
1362 // Test that CancelReadIfReady works.
TEST_F(SSLServerSocketTest,CancelReadIfReady)1363 TEST_F(SSLServerSocketTest, CancelReadIfReady) {
1364 ASSERT_NO_FATAL_FAILURE(CreateContext());
1365 ASSERT_NO_FATAL_FAILURE(CreateSockets());
1366
1367 TestCompletionCallback connect_callback;
1368 int client_ret = client_socket_->Connect(connect_callback.callback());
1369 TestCompletionCallback handshake_callback;
1370 int server_ret = server_socket_->Handshake(handshake_callback.callback());
1371 ASSERT_THAT(connect_callback.GetResult(client_ret), IsOk());
1372 ASSERT_THAT(handshake_callback.GetResult(server_ret), IsOk());
1373
1374 // Attempt to read from the server socket. There will not be anything to read.
1375 // Cancel the read immediately afterwards.
1376 TestCompletionCallback read_callback;
1377 auto read_buf = base::MakeRefCounted<IOBuffer>(1);
1378 int read_ret =
1379 server_socket_->ReadIfReady(read_buf.get(), 1, read_callback.callback());
1380 ASSERT_THAT(read_ret, IsError(ERR_IO_PENDING));
1381 ASSERT_THAT(server_socket_->CancelReadIfReady(), IsOk());
1382
1383 // After the client writes data, the server should still not pick up a result.
1384 auto write_buf = base::MakeRefCounted<StringIOBuffer>("a");
1385 TestCompletionCallback write_callback;
1386 ASSERT_EQ(write_callback.GetResult(client_socket_->Write(
1387 write_buf.get(), write_buf->size(), write_callback.callback(),
1388 TRAFFIC_ANNOTATION_FOR_TESTS)),
1389 write_buf->size());
1390 base::RunLoop().RunUntilIdle();
1391 EXPECT_FALSE(read_callback.have_result());
1392
1393 // After a canceled read, future reads are still possible.
1394 while (true) {
1395 TestCompletionCallback read_callback2;
1396 read_ret = server_socket_->ReadIfReady(read_buf.get(), 1,
1397 read_callback2.callback());
1398 if (read_ret != ERR_IO_PENDING) {
1399 break;
1400 }
1401 ASSERT_THAT(read_callback2.GetResult(read_ret), IsOk());
1402 }
1403 ASSERT_EQ(1, read_ret);
1404 EXPECT_EQ(read_buf->data()[0], 'a');
1405 }
1406
1407 } // namespace net
1408