1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 // This test suite uses SSLClientSocket to test the implementation of
6 // SSLServerSocket. In order to establish connections between the sockets
7 // we need two additional classes:
8 // 1. FakeSocket
9 // Connects SSL socket to FakeDataChannel. This class is just a stub.
10 //
11 // 2. FakeDataChannel
12 // Implements the actual exchange of data between two FakeSockets.
13 //
14 // Implementations of these two classes are included in this file.
15
16 #include "net/socket/ssl_server_socket.h"
17
18 #include <stdint.h>
19 #include <stdlib.h>
20
21 #include <memory>
22 #include <utility>
23
24 #include "base/check.h"
25 #include "base/compiler_specific.h"
26 #include "base/containers/queue.h"
27 #include "base/files/file_path.h"
28 #include "base/files/file_util.h"
29 #include "base/functional/bind.h"
30 #include "base/functional/callback_helpers.h"
31 #include "base/location.h"
32 #include "base/memory/raw_ptr.h"
33 #include "base/memory/scoped_refptr.h"
34 #include "base/notreached.h"
35 #include "base/run_loop.h"
36 #include "base/task/single_thread_task_runner.h"
37 #include "base/test/bind.h"
38 #include "base/test/task_environment.h"
39 #include "build/build_config.h"
40 #include "crypto/rsa_private_key.h"
41 #include "net/base/address_list.h"
42 #include "net/base/completion_once_callback.h"
43 #include "net/base/host_port_pair.h"
44 #include "net/base/io_buffer.h"
45 #include "net/base/ip_address.h"
46 #include "net/base/ip_endpoint.h"
47 #include "net/base/net_errors.h"
48 #include "net/cert/cert_status_flags.h"
49 #include "net/cert/ct_policy_enforcer.h"
50 #include "net/cert/ct_policy_status.h"
51 #include "net/cert/mock_cert_verifier.h"
52 #include "net/cert/mock_client_cert_verifier.h"
53 #include "net/cert/signed_certificate_timestamp_and_status.h"
54 #include "net/cert/x509_certificate.h"
55 #include "net/http/transport_security_state.h"
56 #include "net/log/net_log_with_source.h"
57 #include "net/socket/client_socket_factory.h"
58 #include "net/socket/socket_test_util.h"
59 #include "net/socket/ssl_client_socket.h"
60 #include "net/socket/stream_socket.h"
61 #include "net/ssl/ssl_cert_request_info.h"
62 #include "net/ssl/ssl_cipher_suite_names.h"
63 #include "net/ssl/ssl_client_session_cache.h"
64 #include "net/ssl/ssl_connection_status_flags.h"
65 #include "net/ssl/ssl_info.h"
66 #include "net/ssl/ssl_private_key.h"
67 #include "net/ssl/ssl_server_config.h"
68 #include "net/ssl/test_ssl_config_service.h"
69 #include "net/ssl/test_ssl_private_key.h"
70 #include "net/test/cert_test_util.h"
71 #include "net/test/gtest_util.h"
72 #include "net/test/test_data_directory.h"
73 #include "net/test/test_with_task_environment.h"
74 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
75 #include "testing/gmock/include/gmock/gmock.h"
76 #include "testing/gtest/include/gtest/gtest.h"
77 #include "testing/platform_test.h"
78 #include "third_party/boringssl/src/include/openssl/evp.h"
79 #include "third_party/boringssl/src/include/openssl/ssl.h"
80
81 using net::test::IsError;
82 using net::test::IsOk;
83
84 namespace net {
85
86 namespace {
87
88 // Client certificates are disabled on iOS.
89 #if BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
90 const char kClientCertFileName[] = "client_1.pem";
91 const char kClientPrivateKeyFileName[] = "client_1.pk8";
92 const char kWrongClientCertFileName[] = "client_2.pem";
93 const char kWrongClientPrivateKeyFileName[] = "client_2.pk8";
94 #endif // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
95
// TLS cipher suite identifiers that use an ECDHE key exchange, listed in
// ascending numeric order.
constexpr uint16_t kEcdheCiphers[] = {
    0xc007,  // TLS_ECDHE_ECDSA_WITH_RC4_128_SHA
    0xc009,  // TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA
    0xc00a,  // TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA
    0xc011,  // TLS_ECDHE_RSA_WITH_RC4_128_SHA
    0xc013,  // TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA
    0xc014,  // TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA
    0xc02b,  // TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
    0xc02c,  // TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
    0xc02f,  // TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
    0xc030,  // TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384
    0xcca8,  // TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
    0xcca9,  // TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
};
110
111 class MockCTPolicyEnforcer : public CTPolicyEnforcer {
112 public:
113 MockCTPolicyEnforcer() = default;
114 ~MockCTPolicyEnforcer() override = default;
CheckCompliance(X509Certificate * cert,const ct::SCTList & verified_scts,const NetLogWithSource & net_log)115 ct::CTPolicyCompliance CheckCompliance(
116 X509Certificate* cert,
117 const ct::SCTList& verified_scts,
118 const NetLogWithSource& net_log) override {
119 return ct::CTPolicyCompliance::CT_POLICY_COMPLIES_VIA_SCTS;
120 }
121 };
122
123 class FakeDataChannel {
124 public:
125 FakeDataChannel() = default;
126
127 FakeDataChannel(const FakeDataChannel&) = delete;
128 FakeDataChannel& operator=(const FakeDataChannel&) = delete;
129
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)130 int Read(IOBuffer* buf, int buf_len, CompletionOnceCallback callback) {
131 DCHECK(read_callback_.is_null());
132 DCHECK(!read_buf_.get());
133 if (closed_)
134 return 0;
135 if (data_.empty()) {
136 read_callback_ = std::move(callback);
137 read_buf_ = buf;
138 read_buf_len_ = buf_len;
139 return ERR_IO_PENDING;
140 }
141 return PropagateData(buf, buf_len);
142 }
143
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)144 int Write(IOBuffer* buf,
145 int buf_len,
146 CompletionOnceCallback callback,
147 const NetworkTrafficAnnotationTag& traffic_annotation) {
148 DCHECK(write_callback_.is_null());
149 if (closed_) {
150 if (write_called_after_close_)
151 return ERR_CONNECTION_RESET;
152 write_called_after_close_ = true;
153 write_callback_ = std::move(callback);
154 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
155 FROM_HERE, base::BindOnce(&FakeDataChannel::DoWriteCallback,
156 weak_factory_.GetWeakPtr()));
157 return ERR_IO_PENDING;
158 }
159 // This function returns synchronously, so make a copy of the buffer.
160 data_.push(base::MakeRefCounted<DrainableIOBuffer>(
161 base::MakeRefCounted<StringIOBuffer>(std::string(buf->data(), buf_len)),
162 buf_len));
163 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
164 FROM_HERE, base::BindOnce(&FakeDataChannel::DoReadCallback,
165 weak_factory_.GetWeakPtr()));
166 return buf_len;
167 }
168
169 // Closes the FakeDataChannel. After Close() is called, Read() returns 0,
170 // indicating EOF, and Write() fails with ERR_CONNECTION_RESET. Note that
171 // after the FakeDataChannel is closed, the first Write() call completes
172 // asynchronously, which is necessary to reproduce bug 127822.
Close()173 void Close() {
174 closed_ = true;
175 if (!read_callback_.is_null()) {
176 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
177 FROM_HERE, base::BindOnce(&FakeDataChannel::DoReadCallback,
178 weak_factory_.GetWeakPtr()));
179 }
180 }
181
182 private:
DoReadCallback()183 void DoReadCallback() {
184 if (read_callback_.is_null())
185 return;
186
187 if (closed_) {
188 std::move(read_callback_).Run(ERR_CONNECTION_CLOSED);
189 return;
190 }
191
192 if (data_.empty())
193 return;
194
195 int copied = PropagateData(read_buf_, read_buf_len_);
196 read_buf_ = nullptr;
197 read_buf_len_ = 0;
198 std::move(read_callback_).Run(copied);
199 }
200
DoWriteCallback()201 void DoWriteCallback() {
202 if (write_callback_.is_null())
203 return;
204
205 std::move(write_callback_).Run(ERR_CONNECTION_RESET);
206 }
207
PropagateData(scoped_refptr<IOBuffer> read_buf,int read_buf_len)208 int PropagateData(scoped_refptr<IOBuffer> read_buf, int read_buf_len) {
209 scoped_refptr<DrainableIOBuffer> buf = data_.front();
210 int copied = std::min(buf->BytesRemaining(), read_buf_len);
211 memcpy(read_buf->data(), buf->data(), copied);
212 buf->DidConsume(copied);
213
214 if (!buf->BytesRemaining())
215 data_.pop();
216 return copied;
217 }
218
219 CompletionOnceCallback read_callback_;
220 scoped_refptr<IOBuffer> read_buf_;
221 int read_buf_len_ = 0;
222
223 CompletionOnceCallback write_callback_;
224
225 base::queue<scoped_refptr<DrainableIOBuffer>> data_;
226
227 // True if Close() has been called.
228 bool closed_ = false;
229
230 // Controls the completion of Write() after the FakeDataChannel is closed.
231 // After the FakeDataChannel is closed, the first Write() call completes
232 // asynchronously.
233 bool write_called_after_close_ = false;
234
235 base::WeakPtrFactory<FakeDataChannel> weak_factory_{this};
236 };
237
238 class FakeSocket : public StreamSocket {
239 public:
FakeSocket(FakeDataChannel * incoming_channel,FakeDataChannel * outgoing_channel)240 FakeSocket(FakeDataChannel* incoming_channel,
241 FakeDataChannel* outgoing_channel)
242 : incoming_(incoming_channel), outgoing_(outgoing_channel) {}
243
244 FakeSocket(const FakeSocket&) = delete;
245 FakeSocket& operator=(const FakeSocket&) = delete;
246
247 ~FakeSocket() override = default;
248
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)249 int Read(IOBuffer* buf,
250 int buf_len,
251 CompletionOnceCallback callback) override {
252 // Read random number of bytes.
253 buf_len = rand() % buf_len + 1;
254 return incoming_->Read(buf, buf_len, std::move(callback));
255 }
256
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)257 int Write(IOBuffer* buf,
258 int buf_len,
259 CompletionOnceCallback callback,
260 const NetworkTrafficAnnotationTag& traffic_annotation) override {
261 // Write random number of bytes.
262 buf_len = rand() % buf_len + 1;
263 return outgoing_->Write(buf, buf_len, std::move(callback),
264 TRAFFIC_ANNOTATION_FOR_TESTS);
265 }
266
SetReceiveBufferSize(int32_t size)267 int SetReceiveBufferSize(int32_t size) override { return OK; }
268
SetSendBufferSize(int32_t size)269 int SetSendBufferSize(int32_t size) override { return OK; }
270
Connect(CompletionOnceCallback callback)271 int Connect(CompletionOnceCallback callback) override { return OK; }
272
Disconnect()273 void Disconnect() override {
274 incoming_->Close();
275 outgoing_->Close();
276 }
277
IsConnected() const278 bool IsConnected() const override { return true; }
279
IsConnectedAndIdle() const280 bool IsConnectedAndIdle() const override { return true; }
281
GetPeerAddress(IPEndPoint * address) const282 int GetPeerAddress(IPEndPoint* address) const override {
283 *address = IPEndPoint(IPAddress::IPv4AllZeros(), 0 /*port*/);
284 return OK;
285 }
286
GetLocalAddress(IPEndPoint * address) const287 int GetLocalAddress(IPEndPoint* address) const override {
288 *address = IPEndPoint(IPAddress::IPv4AllZeros(), 0 /*port*/);
289 return OK;
290 }
291
NetLog() const292 const NetLogWithSource& NetLog() const override { return net_log_; }
293
WasEverUsed() const294 bool WasEverUsed() const override { return true; }
295
GetNegotiatedProtocol() const296 NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
297
GetSSLInfo(SSLInfo * ssl_info)298 bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
299
GetTotalReceivedBytes() const300 int64_t GetTotalReceivedBytes() const override {
301 NOTIMPLEMENTED();
302 return 0;
303 }
304
ApplySocketTag(const SocketTag & tag)305 void ApplySocketTag(const SocketTag& tag) override {}
306
307 private:
308 NetLogWithSource net_log_;
309 raw_ptr<FakeDataChannel> incoming_;
310 raw_ptr<FakeDataChannel> outgoing_;
311 };
312
313 } // namespace
314
315 // Verify the correctness of the test helper classes first.
TEST(FakeSocketTest,DataTransfer)316 TEST(FakeSocketTest, DataTransfer) {
317 base::test::TaskEnvironment task_environment;
318
319 // Establish channels between two sockets.
320 FakeDataChannel channel_1;
321 FakeDataChannel channel_2;
322 FakeSocket client(&channel_1, &channel_2);
323 FakeSocket server(&channel_2, &channel_1);
324
325 const char kTestData[] = "testing123";
326 const int kTestDataSize = strlen(kTestData);
327 const int kReadBufSize = 1024;
328 auto write_buf = base::MakeRefCounted<StringIOBuffer>(kTestData);
329 auto read_buf = base::MakeRefCounted<IOBufferWithSize>(kReadBufSize);
330
331 // Write then read.
332 int written =
333 server.Write(write_buf.get(), kTestDataSize, CompletionOnceCallback(),
334 TRAFFIC_ANNOTATION_FOR_TESTS);
335 EXPECT_GT(written, 0);
336 EXPECT_LE(written, kTestDataSize);
337
338 int read =
339 client.Read(read_buf.get(), kReadBufSize, CompletionOnceCallback());
340 EXPECT_GT(read, 0);
341 EXPECT_LE(read, written);
342 EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read));
343
344 // Read then write.
345 TestCompletionCallback callback;
346 EXPECT_EQ(ERR_IO_PENDING,
347 server.Read(read_buf.get(), kReadBufSize, callback.callback()));
348
349 written =
350 client.Write(write_buf.get(), kTestDataSize, CompletionOnceCallback(),
351 TRAFFIC_ANNOTATION_FOR_TESTS);
352 EXPECT_GT(written, 0);
353 EXPECT_LE(written, kTestDataSize);
354
355 read = callback.WaitForResult();
356 EXPECT_GT(read, 0);
357 EXPECT_LE(read, written);
358 EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read));
359 }
360
361 class SSLServerSocketTest : public PlatformTest, public WithTaskEnvironment {
362 public:
SSLServerSocketTest()363 SSLServerSocketTest()
364 : ssl_config_service_(
365 std::make_unique<TestSSLConfigService>(SSLContextConfig())),
366 cert_verifier_(std::make_unique<MockCertVerifier>()),
367 client_cert_verifier_(std::make_unique<MockClientCertVerifier>()),
368 transport_security_state_(std::make_unique<TransportSecurityState>()),
369 ct_policy_enforcer_(std::make_unique<MockCTPolicyEnforcer>()),
370 ssl_client_session_cache_(std::make_unique<SSLClientSessionCache>(
371 SSLClientSessionCache::Config())) {}
372
SetUp()373 void SetUp() override {
374 PlatformTest::SetUp();
375
376 cert_verifier_->set_default_result(ERR_CERT_AUTHORITY_INVALID);
377 client_cert_verifier_->set_default_result(ERR_CERT_AUTHORITY_INVALID);
378
379 server_cert_ =
380 ImportCertFromFile(GetTestCertsDirectory(), "unittest.selfsigned.der");
381 ASSERT_TRUE(server_cert_);
382 server_private_key_ = ReadTestKey("unittest.key.bin");
383 ASSERT_TRUE(server_private_key_);
384
385 std::unique_ptr<crypto::RSAPrivateKey> key =
386 ReadTestKey("unittest.key.bin");
387 ASSERT_TRUE(key);
388 server_ssl_private_key_ = WrapOpenSSLPrivateKey(bssl::UpRef(key->key()));
389
390 // Certificate provided by the host doesn't need authority.
391 client_ssl_config_.allowed_bad_certs.emplace_back(
392 server_cert_, CERT_STATUS_AUTHORITY_INVALID);
393
394 client_context_ = std::make_unique<SSLClientContext>(
395 ssl_config_service_.get(), cert_verifier_.get(),
396 transport_security_state_.get(), ct_policy_enforcer_.get(),
397 ssl_client_session_cache_.get(), nullptr);
398 }
399
400 protected:
CreateContext()401 void CreateContext() {
402 client_socket_.reset();
403 server_socket_.reset();
404 channel_1_.reset();
405 channel_2_.reset();
406 server_context_ = CreateSSLServerContext(
407 server_cert_.get(), *server_private_key_, server_ssl_config_);
408 }
409
CreateContextSSLPrivateKey()410 void CreateContextSSLPrivateKey() {
411 client_socket_.reset();
412 server_socket_.reset();
413 channel_1_.reset();
414 channel_2_.reset();
415 server_context_.reset();
416 server_context_ = CreateSSLServerContext(
417 server_cert_.get(), server_ssl_private_key_, server_ssl_config_);
418 }
419
GetHostAndPort()420 static HostPortPair GetHostAndPort() { return HostPortPair("unittest", 0); }
421
CreateSockets()422 void CreateSockets() {
423 client_socket_.reset();
424 server_socket_.reset();
425 channel_1_ = std::make_unique<FakeDataChannel>();
426 channel_2_ = std::make_unique<FakeDataChannel>();
427 std::unique_ptr<StreamSocket> client_connection =
428 std::make_unique<FakeSocket>(channel_1_.get(), channel_2_.get());
429 std::unique_ptr<StreamSocket> server_socket =
430 std::make_unique<FakeSocket>(channel_2_.get(), channel_1_.get());
431
432 client_socket_ = client_context_->CreateSSLClientSocket(
433 std::move(client_connection), GetHostAndPort(), client_ssl_config_);
434 ASSERT_TRUE(client_socket_);
435
436 server_socket_ =
437 server_context_->CreateSSLServerSocket(std::move(server_socket));
438 ASSERT_TRUE(server_socket_);
439 }
440
441 // Client certificates are disabled on iOS.
442 #if BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
ConfigureClientCertsForClient(const char * cert_file_name,const char * private_key_file_name)443 void ConfigureClientCertsForClient(const char* cert_file_name,
444 const char* private_key_file_name) {
445 scoped_refptr<X509Certificate> client_cert =
446 ImportCertFromFile(GetTestCertsDirectory(), cert_file_name);
447 ASSERT_TRUE(client_cert);
448
449 std::unique_ptr<crypto::RSAPrivateKey> key =
450 ReadTestKey(private_key_file_name);
451 ASSERT_TRUE(key);
452
453 client_context_->SetClientCertificate(
454 GetHostAndPort(), std::move(client_cert),
455 WrapOpenSSLPrivateKey(bssl::UpRef(key->key())));
456 }
457
ConfigureClientCertsForServer()458 void ConfigureClientCertsForServer() {
459 server_ssl_config_.client_cert_type =
460 SSLServerConfig::ClientCertType::REQUIRE_CLIENT_CERT;
461
462 // "CN=B CA" - DER encoded DN of the issuer of client_1.pem
463 static const uint8_t kClientCertCAName[] = {
464 0x30, 0x0f, 0x31, 0x0d, 0x30, 0x0b, 0x06, 0x03, 0x55,
465 0x04, 0x03, 0x0c, 0x04, 0x42, 0x20, 0x43, 0x41};
466 server_ssl_config_.cert_authorities.emplace_back(
467 std::begin(kClientCertCAName), std::end(kClientCertCAName));
468
469 scoped_refptr<X509Certificate> expected_client_cert(
470 ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName));
471 ASSERT_TRUE(expected_client_cert);
472
473 client_cert_verifier_->AddResultForCert(expected_client_cert.get(), OK);
474
475 server_ssl_config_.client_cert_verifier = client_cert_verifier_.get();
476 }
477 #endif // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
478
ReadTestKey(base::StringPiece name)479 std::unique_ptr<crypto::RSAPrivateKey> ReadTestKey(base::StringPiece name) {
480 base::FilePath certs_dir(GetTestCertsDirectory());
481 base::FilePath key_path = certs_dir.AppendASCII(name);
482 std::string key_string;
483 if (!base::ReadFileToString(key_path, &key_string))
484 return nullptr;
485 std::vector<uint8_t> key_vector(
486 reinterpret_cast<const uint8_t*>(key_string.data()),
487 reinterpret_cast<const uint8_t*>(key_string.data() +
488 key_string.length()));
489 std::unique_ptr<crypto::RSAPrivateKey> key(
490 crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
491 return key;
492 }
493
PumpServerToClient()494 void PumpServerToClient() {
495 const int kReadBufSize = 1024;
496 scoped_refptr<StringIOBuffer> write_buf =
497 base::MakeRefCounted<StringIOBuffer>("testing123");
498 scoped_refptr<DrainableIOBuffer> read_buf =
499 base::MakeRefCounted<DrainableIOBuffer>(
500 base::MakeRefCounted<IOBufferWithSize>(kReadBufSize), kReadBufSize);
501 TestCompletionCallback write_callback;
502 TestCompletionCallback read_callback;
503 int server_ret = server_socket_->Write(write_buf.get(), write_buf->size(),
504 write_callback.callback(),
505 TRAFFIC_ANNOTATION_FOR_TESTS);
506 EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
507 int client_ret = client_socket_->Read(
508 read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
509 EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
510
511 server_ret = write_callback.GetResult(server_ret);
512 EXPECT_GT(server_ret, 0);
513 client_ret = read_callback.GetResult(client_ret);
514 ASSERT_GT(client_ret, 0);
515 }
516
517 std::unique_ptr<FakeDataChannel> channel_1_;
518 std::unique_ptr<FakeDataChannel> channel_2_;
519 SSLConfig client_ssl_config_;
520 SSLServerConfig server_ssl_config_;
521 std::unique_ptr<TestSSLConfigService> ssl_config_service_;
522 std::unique_ptr<MockCertVerifier> cert_verifier_;
523 std::unique_ptr<MockClientCertVerifier> client_cert_verifier_;
524 std::unique_ptr<TransportSecurityState> transport_security_state_;
525 std::unique_ptr<MockCTPolicyEnforcer> ct_policy_enforcer_;
526 std::unique_ptr<SSLClientSessionCache> ssl_client_session_cache_;
527 std::unique_ptr<SSLClientContext> client_context_;
528 std::unique_ptr<SSLServerContext> server_context_;
529 std::unique_ptr<SSLClientSocket> client_socket_;
530 std::unique_ptr<SSLServerSocket> server_socket_;
531 std::unique_ptr<crypto::RSAPrivateKey> server_private_key_;
532 scoped_refptr<SSLPrivateKey> server_ssl_private_key_;
533 scoped_refptr<X509Certificate> server_cert_;
534 };
535
536 class SSLServerSocketReadTest : public SSLServerSocketTest,
537 public ::testing::WithParamInterface<bool> {
538 protected:
SSLServerSocketReadTest()539 SSLServerSocketReadTest() : read_if_ready_enabled_(GetParam()) {}
540
Read(StreamSocket * socket,IOBuffer * buf,int buf_len,CompletionOnceCallback callback)541 int Read(StreamSocket* socket,
542 IOBuffer* buf,
543 int buf_len,
544 CompletionOnceCallback callback) {
545 if (read_if_ready_enabled()) {
546 return socket->ReadIfReady(buf, buf_len, std::move(callback));
547 }
548 return socket->Read(buf, buf_len, std::move(callback));
549 }
550
read_if_ready_enabled() const551 bool read_if_ready_enabled() const { return read_if_ready_enabled_; }
552
553 private:
554 const bool read_if_ready_enabled_;
555 };
556
557 INSTANTIATE_TEST_SUITE_P(/* no prefix */,
558 SSLServerSocketReadTest,
559 ::testing::Bool());
560
561 // This test only executes creation of client and server sockets. This is to
562 // test that creation of sockets doesn't crash and have minimal code to run
563 // with memory leak/corruption checking tools.
TEST_F(SSLServerSocketTest,Initialize)564 TEST_F(SSLServerSocketTest, Initialize) {
565 ASSERT_NO_FATAL_FAILURE(CreateContext());
566 ASSERT_NO_FATAL_FAILURE(CreateSockets());
567 }
568
569 // This test executes Connect() on SSLClientSocket and Handshake() on
570 // SSLServerSocket to make sure handshaking between the two sockets is
571 // completed successfully.
TEST_F(SSLServerSocketTest,Handshake)572 TEST_F(SSLServerSocketTest, Handshake) {
573 ASSERT_NO_FATAL_FAILURE(CreateContext());
574 ASSERT_NO_FATAL_FAILURE(CreateSockets());
575
576 TestCompletionCallback handshake_callback;
577 int server_ret = server_socket_->Handshake(handshake_callback.callback());
578
579 TestCompletionCallback connect_callback;
580 int client_ret = client_socket_->Connect(connect_callback.callback());
581
582 client_ret = connect_callback.GetResult(client_ret);
583 server_ret = handshake_callback.GetResult(server_ret);
584
585 ASSERT_THAT(client_ret, IsOk());
586 ASSERT_THAT(server_ret, IsOk());
587
588 // Make sure the cert status is expected.
589 SSLInfo ssl_info;
590 ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
591 EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);
592
593 // The default cipher suite should be ECDHE and an AEAD.
594 uint16_t cipher_suite =
595 SSLConnectionStatusToCipherSuite(ssl_info.connection_status);
596 const char* key_exchange;
597 const char* cipher;
598 const char* mac;
599 bool is_aead;
600 bool is_tls13;
601 SSLCipherSuiteToStrings(&key_exchange, &cipher, &mac, &is_aead, &is_tls13,
602 cipher_suite);
603 EXPECT_TRUE(is_aead);
604 }
605
606 // This test makes sure the session cache is working.
TEST_F(SSLServerSocketTest,HandshakeCached)607 TEST_F(SSLServerSocketTest, HandshakeCached) {
608 ASSERT_NO_FATAL_FAILURE(CreateContext());
609 ASSERT_NO_FATAL_FAILURE(CreateSockets());
610
611 TestCompletionCallback handshake_callback;
612 int server_ret = server_socket_->Handshake(handshake_callback.callback());
613
614 TestCompletionCallback connect_callback;
615 int client_ret = client_socket_->Connect(connect_callback.callback());
616
617 client_ret = connect_callback.GetResult(client_ret);
618 server_ret = handshake_callback.GetResult(server_ret);
619
620 ASSERT_THAT(client_ret, IsOk());
621 ASSERT_THAT(server_ret, IsOk());
622
623 // Make sure the cert status is expected.
624 SSLInfo ssl_info;
625 ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
626 EXPECT_EQ(ssl_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
627 SSLInfo ssl_server_info;
628 ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info));
629 EXPECT_EQ(ssl_server_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
630
631 // Pump client read to get new session tickets.
632 PumpServerToClient();
633
634 // Make sure the second connection is cached.
635 ASSERT_NO_FATAL_FAILURE(CreateSockets());
636 TestCompletionCallback handshake_callback2;
637 int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());
638
639 TestCompletionCallback connect_callback2;
640 int client_ret2 = client_socket_->Connect(connect_callback2.callback());
641
642 client_ret2 = connect_callback2.GetResult(client_ret2);
643 server_ret2 = handshake_callback2.GetResult(server_ret2);
644
645 ASSERT_THAT(client_ret2, IsOk());
646 ASSERT_THAT(server_ret2, IsOk());
647
648 // Make sure the cert status is expected.
649 SSLInfo ssl_info2;
650 ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info2));
651 EXPECT_EQ(ssl_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME);
652 SSLInfo ssl_server_info2;
653 ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info2));
654 EXPECT_EQ(ssl_server_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME);
655 }
656
657 // This test makes sure the session cache separates out by server context.
TEST_F(SSLServerSocketTest,HandshakeCachedContextSwitch)658 TEST_F(SSLServerSocketTest, HandshakeCachedContextSwitch) {
659 ASSERT_NO_FATAL_FAILURE(CreateContext());
660 ASSERT_NO_FATAL_FAILURE(CreateSockets());
661
662 TestCompletionCallback handshake_callback;
663 int server_ret = server_socket_->Handshake(handshake_callback.callback());
664
665 TestCompletionCallback connect_callback;
666 int client_ret = client_socket_->Connect(connect_callback.callback());
667
668 client_ret = connect_callback.GetResult(client_ret);
669 server_ret = handshake_callback.GetResult(server_ret);
670
671 ASSERT_THAT(client_ret, IsOk());
672 ASSERT_THAT(server_ret, IsOk());
673
674 // Make sure the cert status is expected.
675 SSLInfo ssl_info;
676 ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
677 EXPECT_EQ(ssl_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
678 SSLInfo ssl_server_info;
679 ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info));
680 EXPECT_EQ(ssl_server_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
681
682 // Make sure the second connection is NOT cached when using a new context.
683 ASSERT_NO_FATAL_FAILURE(CreateContext());
684 ASSERT_NO_FATAL_FAILURE(CreateSockets());
685
686 TestCompletionCallback handshake_callback2;
687 int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());
688
689 TestCompletionCallback connect_callback2;
690 int client_ret2 = client_socket_->Connect(connect_callback2.callback());
691
692 client_ret2 = connect_callback2.GetResult(client_ret2);
693 server_ret2 = handshake_callback2.GetResult(server_ret2);
694
695 ASSERT_THAT(client_ret2, IsOk());
696 ASSERT_THAT(server_ret2, IsOk());
697
698 // Make sure the cert status is expected.
699 SSLInfo ssl_info2;
700 ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info2));
701 EXPECT_EQ(ssl_info2.handshake_type, SSLInfo::HANDSHAKE_FULL);
702 SSLInfo ssl_server_info2;
703 ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info2));
704 EXPECT_EQ(ssl_server_info2.handshake_type, SSLInfo::HANDSHAKE_FULL);
705 }
706
707 // Client certificates are disabled on iOS.
708 #if BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
709 // This test executes Connect() on SSLClientSocket and Handshake() on
710 // SSLServerSocket to make sure handshaking between the two sockets is
711 // completed successfully, using client certificate.
TEST_F(SSLServerSocketTest,HandshakeWithClientCert)712 TEST_F(SSLServerSocketTest, HandshakeWithClientCert) {
713 scoped_refptr<X509Certificate> client_cert =
714 ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
715 ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
716 kClientCertFileName, kClientPrivateKeyFileName));
717 ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
718 ASSERT_NO_FATAL_FAILURE(CreateContext());
719 ASSERT_NO_FATAL_FAILURE(CreateSockets());
720
721 TestCompletionCallback handshake_callback;
722 int server_ret = server_socket_->Handshake(handshake_callback.callback());
723
724 TestCompletionCallback connect_callback;
725 int client_ret = client_socket_->Connect(connect_callback.callback());
726
727 client_ret = connect_callback.GetResult(client_ret);
728 server_ret = handshake_callback.GetResult(server_ret);
729
730 ASSERT_THAT(client_ret, IsOk());
731 ASSERT_THAT(server_ret, IsOk());
732
733 // Make sure the cert status is expected.
734 SSLInfo ssl_info;
735 client_socket_->GetSSLInfo(&ssl_info);
736 EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);
737 server_socket_->GetSSLInfo(&ssl_info);
738 ASSERT_TRUE(ssl_info.cert.get());
739 EXPECT_TRUE(client_cert->EqualsExcludingChain(ssl_info.cert.get()));
740 }
741
742 // This test executes Connect() on SSLClientSocket and Handshake() twice on
743 // SSLServerSocket to make sure handshaking between the two sockets is
744 // completed successfully, using client certificate. The second connection is
745 // expected to succeed through the session cache.
TEST_F(SSLServerSocketTest,HandshakeWithClientCertCached)746 TEST_F(SSLServerSocketTest, HandshakeWithClientCertCached) {
747 scoped_refptr<X509Certificate> client_cert =
748 ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
749 ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
750 kClientCertFileName, kClientPrivateKeyFileName));
751 ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
752 ASSERT_NO_FATAL_FAILURE(CreateContext());
753 ASSERT_NO_FATAL_FAILURE(CreateSockets());
754
755 TestCompletionCallback handshake_callback;
756 int server_ret = server_socket_->Handshake(handshake_callback.callback());
757
758 TestCompletionCallback connect_callback;
759 int client_ret = client_socket_->Connect(connect_callback.callback());
760
761 client_ret = connect_callback.GetResult(client_ret);
762 server_ret = handshake_callback.GetResult(server_ret);
763
764 ASSERT_THAT(client_ret, IsOk());
765 ASSERT_THAT(server_ret, IsOk());
766
767 // Make sure the cert status is expected.
768 SSLInfo ssl_info;
769 ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
770 EXPECT_EQ(ssl_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
771 SSLInfo ssl_server_info;
772 ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info));
773 ASSERT_TRUE(ssl_server_info.cert.get());
774 EXPECT_TRUE(client_cert->EqualsExcludingChain(ssl_server_info.cert.get()));
775 EXPECT_EQ(ssl_server_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
776 // Pump client read to get new session tickets.
777 PumpServerToClient();
778 server_socket_->Disconnect();
779 client_socket_->Disconnect();
780
781 // Create the connection again.
782 ASSERT_NO_FATAL_FAILURE(CreateSockets());
783 TestCompletionCallback handshake_callback2;
784 int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());
785
786 TestCompletionCallback connect_callback2;
787 int client_ret2 = client_socket_->Connect(connect_callback2.callback());
788
789 client_ret2 = connect_callback2.GetResult(client_ret2);
790 server_ret2 = handshake_callback2.GetResult(server_ret2);
791
792 ASSERT_THAT(client_ret2, IsOk());
793 ASSERT_THAT(server_ret2, IsOk());
794
795 // Make sure the cert status is expected.
796 SSLInfo ssl_info2;
797 ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info2));
798 EXPECT_EQ(ssl_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME);
799 SSLInfo ssl_server_info2;
800 ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info2));
801 ASSERT_TRUE(ssl_server_info2.cert.get());
802 EXPECT_TRUE(client_cert->EqualsExcludingChain(ssl_server_info2.cert.get()));
803 EXPECT_EQ(ssl_server_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME);
804 }
805
// The server is configured for client certificates but the client supplies
// none. The client must fail with ERR_SSL_CLIENT_AUTH_CERT_NEEDED, the
// CertificateRequest it received must name the expected authority, and the
// server handshake must end with ERR_CONNECTION_CLOSED once the client
// disconnects.
TEST_F(SSLServerSocketTest, HandshakeWithClientCertRequiredNotSupplied) {
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  // The client socket keeps its default configuration, i.e. it does not send
  // a client certificate. Connect() therefore reports
  // ERR_SSL_CLIENT_AUTH_CERT_NEEDED, which lets the test look at the
  // cert_authorities requested by the server's CertificateRequest.
  TestCompletionCallback server_done;
  int server_rv = server_socket_->Handshake(server_done.callback());

  TestCompletionCallback client_done;
  int client_rv = client_socket_->Connect(client_done.callback());
  EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED, client_done.GetResult(client_rv));

  scoped_refptr<SSLCertRequestInfo> cert_request =
      base::MakeRefCounted<SSLCertRequestInfo>();
  client_socket_->GetSSLCertRequestInfo(cert_request.get());

  // The authority name carried by the CertificateRequest handshake message
  // must be an issuer of the test client certificate.
  scoped_refptr<X509Certificate> expected_cert =
      ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
  ASSERT_TRUE(expected_cert);
  EXPECT_TRUE(
      expected_cert->IsIssuedByEncoded(cert_request->cert_authorities));

  client_socket_->Disconnect();

  // The server only learns about the failure through the closed connection.
  EXPECT_THAT(server_done.GetResult(server_rv),
              IsError(ERR_CONNECTION_CLOSED));
}
839
// Runs the "client certificate required but not supplied" handshake twice.
// CreateSockets() is called again for the second attempt without recreating
// the context, and the second attempt must fail exactly like the first one,
// i.e. the failed handshake must not have been cached.
TEST_F(SSLServerSocketTest, HandshakeWithClientCertRequiredNotSuppliedCached) {
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  // First attempt. The client socket uses its default configuration and so
  // sends no client certificate; Connect() reports
  // ERR_SSL_CLIENT_AUTH_CERT_NEEDED and the requested cert_authorities from
  // the server's CertificateRequest become available for inspection.
  TestCompletionCallback first_server_done;
  int first_server_rv =
      server_socket_->Handshake(first_server_done.callback());

  TestCompletionCallback first_client_done;
  int first_client_rv = client_socket_->Connect(first_client_done.callback());
  EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED,
            first_client_done.GetResult(first_client_rv));

  scoped_refptr<SSLCertRequestInfo> first_request =
      base::MakeRefCounted<SSLCertRequestInfo>();
  client_socket_->GetSSLCertRequestInfo(first_request.get());

  // The authority name carried by the CertificateRequest handshake message
  // must be an issuer of the test client certificate.
  scoped_refptr<X509Certificate> expected_cert =
      ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
  ASSERT_TRUE(expected_cert);
  EXPECT_TRUE(
      expected_cert->IsIssuedByEncoded(first_request->cert_authorities));

  client_socket_->Disconnect();

  EXPECT_THAT(first_server_done.GetResult(first_server_rv),
              IsError(ERR_CONNECTION_CLOSED));
  server_socket_->Disconnect();

  // Second attempt: verify that the cache did not store the result of the
  // failed handshake above.
  ASSERT_NO_FATAL_FAILURE(CreateSockets());
  TestCompletionCallback second_server_done;
  int second_server_rv =
      server_socket_->Handshake(second_server_done.callback());

  TestCompletionCallback second_client_done;
  int second_client_rv =
      client_socket_->Connect(second_client_done.callback());
  EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED,
            second_client_done.GetResult(second_client_rv));

  scoped_refptr<SSLCertRequestInfo> second_request =
      base::MakeRefCounted<SSLCertRequestInfo>();
  client_socket_->GetSSLCertRequestInfo(second_request.get());

  // The CertificateRequest of the second attempt must again name the
  // expected authority.
  EXPECT_TRUE(
      expected_cert->IsIssuedByEncoded(second_request->cert_authorities));

  client_socket_->Disconnect();

  EXPECT_THAT(second_server_done.GetResult(second_server_rv),
              IsError(ERR_CONNECTION_CLOSED));
}
896
// The client presents the "wrong" certificate/key pair
// (kWrongClientCertFileName / kWrongClientPrivateKeyFileName). The server
// handshake must fail with ERR_BAD_SSL_CLIENT_AUTH_CERT, while the client only
// observes that error on its first Read() after a successful Connect().
TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSupplied) {
  scoped_refptr<X509Certificate> good_client_cert =
      ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
  ASSERT_TRUE(good_client_cert);

  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
      kWrongClientCertFileName, kWrongClientPrivateKeyFileName));
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback server_done;
  int server_rv = server_socket_->Handshake(server_done.callback());

  TestCompletionCallback client_done;
  int client_rv = client_socket_->Connect(client_done.callback());

  // With TLS 1.3 the client certificate error is not surfaced by Connect();
  // it only shows up once Read() is called.
  EXPECT_EQ(OK, client_done.GetResult(client_rv));
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, server_done.GetResult(server_rv));

  // Issue a client read so that the client certificate error is delivered.
  const int kReadBufSize = 1024;
  auto incoming = base::MakeRefCounted<DrainableIOBuffer>(
      base::MakeRefCounted<IOBufferWithSize>(kReadBufSize), kReadBufSize);
  TestCompletionCallback read_done;
  client_rv = client_socket_->Read(incoming.get(), incoming->BytesRemaining(),
                                   read_done.callback());
  client_rv = read_done.GetResult(client_rv);
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, client_rv);
}
930
// Same as HandshakeWithWrongClientCertSupplied, but with the client capped at
// TLS 1.2. In that case both Connect() and Handshake() are expected to report
// ERR_BAD_SSL_CLIENT_AUTH_CERT directly; no extra Read() is needed.
TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSuppliedTLS12) {
  scoped_refptr<X509Certificate> good_client_cert =
      ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
  ASSERT_TRUE(good_client_cert);

  // Limit the client to TLS 1.2 before the sockets are created.
  client_ssl_config_.version_max_override = SSL_PROTOCOL_VERSION_TLS1_2;
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
      kWrongClientCertFileName, kWrongClientPrivateKeyFileName));
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback server_done;
  int server_rv = server_socket_->Handshake(server_done.callback());

  TestCompletionCallback client_done;
  int client_rv = client_socket_->Connect(client_done.callback());

  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, client_done.GetResult(client_rv));
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, server_done.GetResult(server_rv));
}
954
// Runs the "wrong client certificate" handshake twice. CreateSockets() is
// called again for the second attempt without recreating the context, and the
// second attempt must fail exactly like the first one, i.e. the failed
// handshake must not have been cached.
TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSuppliedCached) {
  scoped_refptr<X509Certificate> good_client_cert =
      ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
  ASSERT_TRUE(good_client_cert);

  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
      kWrongClientCertFileName, kWrongClientPrivateKeyFileName));
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  // First attempt.
  TestCompletionCallback first_server_done;
  int first_server_rv =
      server_socket_->Handshake(first_server_done.callback());

  TestCompletionCallback first_client_done;
  int client_rv = client_socket_->Connect(first_client_done.callback());

  // With TLS 1.3 the client certificate error is not surfaced by Connect();
  // it only shows up once Read() is called.
  EXPECT_EQ(OK, first_client_done.GetResult(client_rv));
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
            first_server_done.GetResult(first_server_rv));

  // Issue a client read so that the client certificate error is delivered.
  const int kReadBufSize = 1024;
  auto incoming = base::MakeRefCounted<DrainableIOBuffer>(
      base::MakeRefCounted<IOBufferWithSize>(kReadBufSize), kReadBufSize);
  TestCompletionCallback read_done;
  client_rv = client_socket_->Read(incoming.get(), incoming->BytesRemaining(),
                                   read_done.callback());
  client_rv = read_done.GetResult(client_rv);
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, client_rv);

  client_socket_->Disconnect();
  server_socket_->Disconnect();

  // Second attempt: verify that the cache did not store the result of the
  // failed handshake above.
  ASSERT_NO_FATAL_FAILURE(CreateSockets());
  TestCompletionCallback second_server_done;
  int second_server_rv =
      server_socket_->Handshake(second_server_done.callback());

  TestCompletionCallback second_client_done;
  int second_client_rv =
      client_socket_->Connect(second_client_done.callback());

  // As before, under TLS 1.3 Connect() succeeds and the error arrives on
  // Read().
  EXPECT_EQ(OK, second_client_done.GetResult(second_client_rv));
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
            second_server_done.GetResult(second_server_rv));

  // Issue a client read (reusing the buffer and callback from the first
  // attempt) so that the client certificate error is delivered.
  client_rv = client_socket_->Read(incoming.get(), incoming->BytesRemaining(),
                                   read_done.callback());
  client_rv = read_done.GetResult(client_rv);
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, client_rv);
}
1010 #endif // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
1011
// Exchanges application data in both directions over an established
// connection: first the server writes and the client reads, then the server
// starts a read (through the fixture's Read() helper) before the client
// writes. In both directions the received bytes must match what was sent.
TEST_P(SSLServerSocketReadTest, DataTransfer) {
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  // Bring the connection up.
  TestCompletionCallback client_connected;
  int client_rv = client_socket_->Connect(client_connected.callback());
  ASSERT_TRUE(client_rv == OK || client_rv == ERR_IO_PENDING);

  TestCompletionCallback server_handshaken;
  int server_rv = server_socket_->Handshake(server_handshaken.callback());
  ASSERT_TRUE(server_rv == OK || server_rv == ERR_IO_PENDING);

  client_rv = client_connected.GetResult(client_rv);
  ASSERT_THAT(client_rv, IsOk());
  server_rv = server_handshaken.GetResult(server_rv);
  ASSERT_THAT(server_rv, IsOk());

  const int kReadBufSize = 1024;
  scoped_refptr<StringIOBuffer> outgoing =
      base::MakeRefCounted<StringIOBuffer>("testing123");
  auto incoming = base::MakeRefCounted<DrainableIOBuffer>(
      base::MakeRefCounted<IOBufferWithSize>(kReadBufSize), kReadBufSize);

  // Direction 1: the server writes, then the client reads.
  TestCompletionCallback write_done;
  TestCompletionCallback read_done;
  server_rv = server_socket_->Write(outgoing.get(), outgoing->size(),
                                    write_done.callback(),
                                    TRAFFIC_ANNOTATION_FOR_TESTS);
  EXPECT_TRUE(server_rv > 0 || server_rv == ERR_IO_PENDING);
  client_rv = client_socket_->Read(incoming.get(), incoming->BytesRemaining(),
                                   read_done.callback());
  EXPECT_TRUE(client_rv > 0 || client_rv == ERR_IO_PENDING);

  server_rv = write_done.GetResult(server_rv);
  EXPECT_GT(server_rv, 0);
  client_rv = read_done.GetResult(client_rv);
  ASSERT_GT(client_rv, 0);

  // Keep reading until the client has consumed as many bytes as were sent.
  incoming->DidConsume(client_rv);
  while (incoming->BytesConsumed() < outgoing->size()) {
    client_rv = client_socket_->Read(
        incoming.get(), incoming->BytesRemaining(), read_done.callback());
    EXPECT_TRUE(client_rv > 0 || client_rv == ERR_IO_PENDING);
    client_rv = read_done.GetResult(client_rv);
    ASSERT_GT(client_rv, 0);
    incoming->DidConsume(client_rv);
  }
  EXPECT_EQ(outgoing->size(), incoming->BytesConsumed());
  // Rewind so that data() points at the start of the received bytes.
  incoming->SetOffset(0);
  EXPECT_EQ(0, memcmp(outgoing->data(), incoming->data(), outgoing->size()));

  // Direction 2: the server starts reading, then the client writes.
  outgoing = base::MakeRefCounted<StringIOBuffer>("hello123");
  server_rv = Read(server_socket_.get(), incoming.get(),
                   incoming->BytesRemaining(), read_done.callback());
  EXPECT_EQ(server_rv, ERR_IO_PENDING);
  client_rv = client_socket_->Write(outgoing.get(), outgoing->size(),
                                    write_done.callback(),
                                    TRAFFIC_ANNOTATION_FOR_TESTS);
  EXPECT_TRUE(client_rv > 0 || client_rv == ERR_IO_PENDING);

  server_rv = read_done.GetResult(server_rv);
  if (!read_if_ready_enabled()) {
    ASSERT_GT(server_rv, 0);
    incoming->DidConsume(server_rv);
  } else {
    // ReadIfReady only signals that data is available; it does not consume
    // it. The data is consumed by the loop further down.
    ASSERT_EQ(server_rv, OK);
  }
  client_rv = write_done.GetResult(client_rv);
  EXPECT_GT(client_rv, 0);

  while (incoming->BytesConsumed() < outgoing->size()) {
    server_rv = Read(server_socket_.get(), incoming.get(),
                     incoming->BytesRemaining(), read_done.callback());
    // Everything was already written above, so both Read() and ReadIfReady()
    // are expected to return the data synchronously.
    ASSERT_GT(server_rv, 0);
    incoming->DidConsume(server_rv);
  }
  EXPECT_EQ(outgoing->size(), incoming->BytesConsumed());
  incoming->SetOffset(0);
  EXPECT_EQ(0, memcmp(outgoing->data(), incoming->data(), outgoing->size()));
}
1100
1101 // A regression test for bug 127822 (http://crbug.com/127822).
1102 // If the server closes the connection after the handshake is finished,
1103 // the client's Write() call should not cause an infinite loop.
1104 // NOTE: this is a test for SSLClientSocket rather than SSLServerSocket.
TEST_F(SSLServerSocketTest,ClientWriteAfterServerClose)1105 TEST_F(SSLServerSocketTest, ClientWriteAfterServerClose) {
1106 ASSERT_NO_FATAL_FAILURE(CreateContext());
1107 ASSERT_NO_FATAL_FAILURE(CreateSockets());
1108
1109 // Establish connection.
1110 TestCompletionCallback connect_callback;
1111 int client_ret = client_socket_->Connect(connect_callback.callback());
1112 ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
1113
1114 TestCompletionCallback handshake_callback;
1115 int server_ret = server_socket_->Handshake(handshake_callback.callback());
1116 ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
1117
1118 client_ret = connect_callback.GetResult(client_ret);
1119 ASSERT_THAT(client_ret, IsOk());
1120 server_ret = handshake_callback.GetResult(server_ret);
1121 ASSERT_THAT(server_ret, IsOk());
1122
1123 scoped_refptr<StringIOBuffer> write_buf =
1124 base::MakeRefCounted<StringIOBuffer>("testing123");
1125
1126 // The server closes the connection. The server needs to write some
1127 // data first so that the client's Read() calls from the transport
1128 // socket won't return ERR_IO_PENDING. This ensures that the client
1129 // will call Read() on the transport socket again.
1130 TestCompletionCallback write_callback;
1131 server_ret = server_socket_->Write(write_buf.get(), write_buf->size(),
1132 write_callback.callback(),
1133 TRAFFIC_ANNOTATION_FOR_TESTS);
1134 EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
1135
1136 server_ret = write_callback.GetResult(server_ret);
1137 EXPECT_GT(server_ret, 0);
1138
1139 server_socket_->Disconnect();
1140
1141 // The client writes some data. This should not cause an infinite loop.
1142 client_ret = client_socket_->Write(write_buf.get(), write_buf->size(),
1143 write_callback.callback(),
1144 TRAFFIC_ANNOTATION_FOR_TESTS);
1145 EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
1146
1147 client_ret = write_callback.GetResult(client_ret);
1148 EXPECT_GT(client_ret, 0);
1149
1150 base::RunLoop run_loop;
1151 base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask(
1152 FROM_HERE, run_loop.QuitClosure(), base::Milliseconds(10));
1153 run_loop.Run();
1154 }
1155
1156 // This test executes ExportKeyingMaterial() on the client and server sockets,
1157 // after connecting them, and verifies that the results match.
1158 // This test will fail if False Start is enabled (see crbug.com/90208).
TEST_F(SSLServerSocketTest,ExportKeyingMaterial)1159 TEST_F(SSLServerSocketTest, ExportKeyingMaterial) {
1160 ASSERT_NO_FATAL_FAILURE(CreateContext());
1161 ASSERT_NO_FATAL_FAILURE(CreateSockets());
1162
1163 TestCompletionCallback connect_callback;
1164 int client_ret = client_socket_->Connect(connect_callback.callback());
1165 ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
1166
1167 TestCompletionCallback handshake_callback;
1168 int server_ret = server_socket_->Handshake(handshake_callback.callback());
1169 ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
1170
1171 if (client_ret == ERR_IO_PENDING) {
1172 ASSERT_THAT(connect_callback.WaitForResult(), IsOk());
1173 }
1174 if (server_ret == ERR_IO_PENDING) {
1175 ASSERT_THAT(handshake_callback.WaitForResult(), IsOk());
1176 }
1177
1178 const int kKeyingMaterialSize = 32;
1179 const char kKeyingLabel[] = "EXPERIMENTAL-server-socket-test";
1180 const char kKeyingContext[] = "";
1181 unsigned char server_out[kKeyingMaterialSize];
1182 int rv = server_socket_->ExportKeyingMaterial(
1183 kKeyingLabel, false, kKeyingContext, server_out, sizeof(server_out));
1184 ASSERT_THAT(rv, IsOk());
1185
1186 unsigned char client_out[kKeyingMaterialSize];
1187 rv = client_socket_->ExportKeyingMaterial(kKeyingLabel, false, kKeyingContext,
1188 client_out, sizeof(client_out));
1189 ASSERT_THAT(rv, IsOk());
1190 EXPECT_EQ(0, memcmp(server_out, client_out, sizeof(server_out)));
1191
1192 const char kKeyingLabelBad[] = "EXPERIMENTAL-server-socket-test-bad";
1193 unsigned char client_bad[kKeyingMaterialSize];
1194 rv = client_socket_->ExportKeyingMaterial(
1195 kKeyingLabelBad, false, kKeyingContext, client_bad, sizeof(client_bad));
1196 ASSERT_EQ(rv, OK);
1197 EXPECT_NE(0, memcmp(server_out, client_bad, sizeof(server_out)));
1198 }
1199
1200 // Verifies that SSLConfig::require_ecdhe flags works properly.
TEST_F(SSLServerSocketTest,RequireEcdheFlag)1201 TEST_F(SSLServerSocketTest, RequireEcdheFlag) {
1202 // Disable all ECDHE suites on the client side.
1203 SSLContextConfig config;
1204 config.disabled_cipher_suites.assign(
1205 kEcdheCiphers, kEcdheCiphers + std::size(kEcdheCiphers));
1206
1207 // Legacy RSA key exchange ciphers only exist in TLS 1.2 and below.
1208 config.version_max = SSL_PROTOCOL_VERSION_TLS1_2;
1209 ssl_config_service_->UpdateSSLConfigAndNotify(config);
1210
1211 // Require ECDHE on the server.
1212 server_ssl_config_.require_ecdhe = true;
1213
1214 ASSERT_NO_FATAL_FAILURE(CreateContext());
1215 ASSERT_NO_FATAL_FAILURE(CreateSockets());
1216
1217 TestCompletionCallback connect_callback;
1218 int client_ret = client_socket_->Connect(connect_callback.callback());
1219
1220 TestCompletionCallback handshake_callback;
1221 int server_ret = server_socket_->Handshake(handshake_callback.callback());
1222
1223 client_ret = connect_callback.GetResult(client_ret);
1224 server_ret = handshake_callback.GetResult(server_ret);
1225
1226 ASSERT_THAT(client_ret, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
1227 ASSERT_THAT(server_ret, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
1228 }
1229
1230 // This test executes Connect() on SSLClientSocket and Handshake() on
1231 // SSLServerSocket to make sure handshaking between the two sockets is
1232 // completed successfully. The server key is represented by SSLPrivateKey.
TEST_F(SSLServerSocketTest,HandshakeServerSSLPrivateKey)1233 TEST_F(SSLServerSocketTest, HandshakeServerSSLPrivateKey) {
1234 ASSERT_NO_FATAL_FAILURE(CreateContextSSLPrivateKey());
1235 ASSERT_NO_FATAL_FAILURE(CreateSockets());
1236
1237 TestCompletionCallback handshake_callback;
1238 int server_ret = server_socket_->Handshake(handshake_callback.callback());
1239
1240 TestCompletionCallback connect_callback;
1241 int client_ret = client_socket_->Connect(connect_callback.callback());
1242
1243 client_ret = connect_callback.GetResult(client_ret);
1244 server_ret = handshake_callback.GetResult(server_ret);
1245
1246 ASSERT_THAT(client_ret, IsOk());
1247 ASSERT_THAT(server_ret, IsOk());
1248
1249 // Make sure the cert status is expected.
1250 SSLInfo ssl_info;
1251 ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
1252 EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);
1253
1254 // The default cipher suite should be ECDHE and an AEAD.
1255 uint16_t cipher_suite =
1256 SSLConnectionStatusToCipherSuite(ssl_info.connection_status);
1257 const char* key_exchange;
1258 const char* cipher;
1259 const char* mac;
1260 bool is_aead;
1261 bool is_tls13;
1262 SSLCipherSuiteToStrings(&key_exchange, &cipher, &mac, &is_aead, &is_tls13,
1263 cipher_suite);
1264 EXPECT_TRUE(is_aead);
1265 }
1266
1267 namespace {
1268
1269 // Helper that wraps an underlying SSLPrivateKey to allow the test to
1270 // do some work immediately before a `Sign()` operation is performed.
1271 class SSLPrivateKeyHook : public SSLPrivateKey {
1272 public:
SSLPrivateKeyHook(scoped_refptr<SSLPrivateKey> private_key,base::RepeatingClosure on_sign)1273 SSLPrivateKeyHook(scoped_refptr<SSLPrivateKey> private_key,
1274 base::RepeatingClosure on_sign)
1275 : private_key_(std::move(private_key)), on_sign_(std::move(on_sign)) {}
1276
1277 // SSLPrivateKey implementation.
GetProviderName()1278 std::string GetProviderName() override {
1279 return private_key_->GetProviderName();
1280 }
GetAlgorithmPreferences()1281 std::vector<uint16_t> GetAlgorithmPreferences() override {
1282 return private_key_->GetAlgorithmPreferences();
1283 }
Sign(uint16_t algorithm,base::span<const uint8_t> input,SignCallback callback)1284 void Sign(uint16_t algorithm,
1285 base::span<const uint8_t> input,
1286 SignCallback callback) override {
1287 on_sign_.Run();
1288 private_key_->Sign(algorithm, input, std::move(callback));
1289 }
1290
1291 private:
1292 ~SSLPrivateKeyHook() override = default;
1293
1294 const scoped_refptr<SSLPrivateKey> private_key_;
1295 const base::RepeatingClosure on_sign_;
1296 };
1297
1298 } // namespace
1299
1300 // Verifies that if the client disconnects while during private key signing then
1301 // the disconnection is correctly reported to the `Handshake()` completion
1302 // callback, with `ERR_CONNECTION_CLOSED`.
1303 // This is a regression test for crbug.com/1449461.
TEST_F(SSLServerSocketTest,HandshakeServerSSLPrivateKeyDisconnectDuringSigning_ReturnsError)1304 TEST_F(SSLServerSocketTest,
1305 HandshakeServerSSLPrivateKeyDisconnectDuringSigning_ReturnsError) {
1306 auto on_sign = base::BindLambdaForTesting([&]() {
1307 client_socket_->Disconnect();
1308 ASSERT_FALSE(client_socket_->IsConnected());
1309 });
1310 server_ssl_private_key_ = base::MakeRefCounted<SSLPrivateKeyHook>(
1311 std::move(server_ssl_private_key_), on_sign);
1312 ASSERT_NO_FATAL_FAILURE(CreateContextSSLPrivateKey());
1313 ASSERT_NO_FATAL_FAILURE(CreateSockets());
1314
1315 TestCompletionCallback handshake_callback;
1316 int server_ret = server_socket_->Handshake(handshake_callback.callback());
1317 ASSERT_EQ(server_ret, net::ERR_IO_PENDING);
1318
1319 TestCompletionCallback connect_callback;
1320 client_socket_->Connect(connect_callback.callback());
1321
1322 // If resuming the handshake after private-key signing is not handled
1323 // correctly as per crbug.com/1449461 then the test will hang and timeout
1324 // at this point, due to the server-side completion callback not being
1325 // correctly invoked.
1326 server_ret = handshake_callback.GetResult(server_ret);
1327 EXPECT_EQ(server_ret, net::ERR_CONNECTION_CLOSED);
1328 }
1329
1330 // Verifies that non-ECDHE ciphers are disabled when using SSLPrivateKey as the
1331 // server key.
TEST_F(SSLServerSocketTest,HandshakeServerSSLPrivateKeyRequireEcdhe)1332 TEST_F(SSLServerSocketTest, HandshakeServerSSLPrivateKeyRequireEcdhe) {
1333 // Disable all ECDHE suites on the client side.
1334 SSLContextConfig config;
1335 config.disabled_cipher_suites.assign(
1336 kEcdheCiphers, kEcdheCiphers + std::size(kEcdheCiphers));
1337 // TLS 1.3 always works with SSLPrivateKey.
1338 config.version_max = SSL_PROTOCOL_VERSION_TLS1_2;
1339 ssl_config_service_->UpdateSSLConfigAndNotify(config);
1340
1341 ASSERT_NO_FATAL_FAILURE(CreateContextSSLPrivateKey());
1342 ASSERT_NO_FATAL_FAILURE(CreateSockets());
1343
1344 TestCompletionCallback connect_callback;
1345 int client_ret = client_socket_->Connect(connect_callback.callback());
1346
1347 TestCompletionCallback handshake_callback;
1348 int server_ret = server_socket_->Handshake(handshake_callback.callback());
1349
1350 client_ret = connect_callback.GetResult(client_ret);
1351 server_ret = handshake_callback.GetResult(server_ret);
1352
1353 ASSERT_THAT(client_ret, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
1354 ASSERT_THAT(server_ret, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
1355 }
1356
1357 class SSLServerSocketAlpsTest
1358 : public SSLServerSocketTest,
1359 public ::testing::WithParamInterface<std::tuple<bool, bool>> {
1360 public:
SSLServerSocketAlpsTest()1361 SSLServerSocketAlpsTest()
1362 : client_alps_enabled_(std::get<0>(GetParam())),
1363 server_alps_enabled_(std::get<1>(GetParam())) {}
1364 ~SSLServerSocketAlpsTest() override = default;
1365 const bool client_alps_enabled_;
1366 const bool server_alps_enabled_;
1367 };
1368
1369 INSTANTIATE_TEST_SUITE_P(All,
1370 SSLServerSocketAlpsTest,
1371 ::testing::Combine(::testing::Bool(),
1372 ::testing::Bool()));
1373
// Both ends offer kProtoHTTP2 via ALPN; each end additionally configures
// application settings for it only when its ALPS parameter is set. After a
// successful handshake, each side must see the peer's settings if and only if
// ALPS was enabled on both sides.
TEST_P(SSLServerSocketAlpsTest, Alps) {
  const std::string server_data = "server sends some test data";
  const std::string client_data = "client also sends some data";

  // Server configuration.
  server_ssl_config_.alpn_protos = {kProtoHTTP2};
  if (server_alps_enabled_) {
    server_ssl_config_.application_settings[kProtoHTTP2] =
        std::vector<uint8_t>(server_data.begin(), server_data.end());
  }

  // Client configuration.
  client_ssl_config_.alpn_protos = {kProtoHTTP2};
  if (client_alps_enabled_) {
    client_ssl_config_.application_settings[kProtoHTTP2] =
        std::vector<uint8_t>(client_data.begin(), client_data.end());
  }

  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback server_done;
  int server_rv = server_socket_->Handshake(server_done.callback());

  TestCompletionCallback client_done;
  int client_rv = client_socket_->Connect(client_done.callback());

  client_rv = client_done.GetResult(client_rv);
  server_rv = server_done.GetResult(server_rv);

  ASSERT_THAT(client_rv, IsOk());
  ASSERT_THAT(server_rv, IsOk());

  const auto settings_seen_by_client =
      client_socket_->GetPeerApplicationSettings();
  const auto settings_seen_by_server =
      server_socket_->GetPeerApplicationSettings();

  // ALPS is negotiated only if it is enabled on both the client and the
  // server.
  const bool alps_expected = client_alps_enabled_ && server_alps_enabled_;
  if (!alps_expected) {
    EXPECT_FALSE(settings_seen_by_client.has_value());
    EXPECT_FALSE(settings_seen_by_server.has_value());
  } else {
    ASSERT_TRUE(settings_seen_by_client.has_value());
    EXPECT_EQ(server_data, settings_seen_by_client.value());
    ASSERT_TRUE(settings_seen_by_server.has_value());
    EXPECT_EQ(client_data, settings_seen_by_server.value());
  }
}
1421
1422 // Test that CancelReadIfReady works.
TEST_F(SSLServerSocketTest,CancelReadIfReady)1423 TEST_F(SSLServerSocketTest, CancelReadIfReady) {
1424 ASSERT_NO_FATAL_FAILURE(CreateContext());
1425 ASSERT_NO_FATAL_FAILURE(CreateSockets());
1426
1427 TestCompletionCallback connect_callback;
1428 int client_ret = client_socket_->Connect(connect_callback.callback());
1429 TestCompletionCallback handshake_callback;
1430 int server_ret = server_socket_->Handshake(handshake_callback.callback());
1431 ASSERT_THAT(connect_callback.GetResult(client_ret), IsOk());
1432 ASSERT_THAT(handshake_callback.GetResult(server_ret), IsOk());
1433
1434 // Attempt to read from the server socket. There will not be anything to read.
1435 // Cancel the read immediately afterwards.
1436 TestCompletionCallback read_callback;
1437 auto read_buf = base::MakeRefCounted<IOBufferWithSize>(1);
1438 int read_ret =
1439 server_socket_->ReadIfReady(read_buf.get(), 1, read_callback.callback());
1440 ASSERT_THAT(read_ret, IsError(ERR_IO_PENDING));
1441 ASSERT_THAT(server_socket_->CancelReadIfReady(), IsOk());
1442
1443 // After the client writes data, the server should still not pick up a result.
1444 auto write_buf = base::MakeRefCounted<StringIOBuffer>("a");
1445 TestCompletionCallback write_callback;
1446 ASSERT_EQ(write_callback.GetResult(client_socket_->Write(
1447 write_buf.get(), write_buf->size(), write_callback.callback(),
1448 TRAFFIC_ANNOTATION_FOR_TESTS)),
1449 write_buf->size());
1450 base::RunLoop().RunUntilIdle();
1451 EXPECT_FALSE(read_callback.have_result());
1452
1453 // After a canceled read, future reads are still possible.
1454 while (true) {
1455 TestCompletionCallback read_callback2;
1456 read_ret = server_socket_->ReadIfReady(read_buf.get(), 1,
1457 read_callback2.callback());
1458 if (read_ret != ERR_IO_PENDING) {
1459 break;
1460 }
1461 ASSERT_THAT(read_callback2.GetResult(read_ret), IsOk());
1462 }
1463 ASSERT_EQ(1, read_ret);
1464 EXPECT_EQ(read_buf->data()[0], 'a');
1465 }
1466
1467 } // namespace net
1468