• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2011 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
// This test suite uses SSLClientSocket to test the implementation of
// SSLServerSocket. To establish connections between the two sockets, we need
// two additional classes:
// 1. FakeSocket
//    Connects an SSL socket to a pair of FakeDataChannels (one for incoming
//    and one for outgoing data). Apart from reads and writes, this class is
//    just a stub.
//
// 2. FakeDataChannel
//    Implements the actual exchange of data between two FakeSockets.
//
// Implementations of these two classes are included in this file.
15 
16 #include "net/socket/ssl_server_socket.h"
17 
18 #include <queue>
19 
20 #include "base/file_path.h"
21 #include "base/file_util.h"
22 #include "base/path_service.h"
23 #include "crypto/nss_util.h"
24 #include "crypto/rsa_private_key.h"
25 #include "net/base/address_list.h"
26 #include "net/base/cert_status_flags.h"
27 #include "net/base/cert_verifier.h"
28 #include "net/base/host_port_pair.h"
29 #include "net/base/io_buffer.h"
30 #include "net/base/ip_endpoint.h"
31 #include "net/base/net_errors.h"
32 #include "net/base/net_log.h"
33 #include "net/base/ssl_config_service.h"
34 #include "net/base/x509_certificate.h"
35 #include "net/socket/client_socket.h"
36 #include "net/socket/client_socket_factory.h"
37 #include "net/socket/socket_test_util.h"
38 #include "net/socket/ssl_client_socket.h"
39 #include "testing/gtest/include/gtest/gtest.h"
40 #include "testing/platform_test.h"
41 
42 namespace net {
43 
44 namespace {
45 
46 class FakeDataChannel {
47  public:
FakeDataChannel()48   FakeDataChannel() : read_callback_(NULL), read_buf_len_(0) {
49   }
50 
Read(IOBuffer * buf,int buf_len,CompletionCallback * callback)51   virtual int Read(IOBuffer* buf, int buf_len,
52                    CompletionCallback* callback) {
53     if (data_.empty()) {
54       read_callback_ = callback;
55       read_buf_ = buf;
56       read_buf_len_ = buf_len;
57       return net::ERR_IO_PENDING;
58     }
59     return PropogateData(buf, buf_len);
60   }
61 
Write(IOBuffer * buf,int buf_len,CompletionCallback * callback)62   virtual int Write(IOBuffer* buf, int buf_len,
63                     CompletionCallback* callback) {
64     data_.push(new net::DrainableIOBuffer(buf, buf_len));
65     DoReadCallback();
66     return buf_len;
67   }
68 
69  private:
DoReadCallback()70   void DoReadCallback() {
71     if (!read_callback_)
72       return;
73 
74     int copied = PropogateData(read_buf_, read_buf_len_);
75     net::CompletionCallback* callback = read_callback_;
76     read_callback_ = NULL;
77     read_buf_ = NULL;
78     read_buf_len_ = 0;
79     callback->Run(copied);
80   }
81 
PropogateData(scoped_refptr<net::IOBuffer> read_buf,int read_buf_len)82   int PropogateData(scoped_refptr<net::IOBuffer> read_buf, int read_buf_len) {
83     scoped_refptr<net::DrainableIOBuffer> buf = data_.front();
84     int copied = std::min(buf->BytesRemaining(), read_buf_len);
85     memcpy(read_buf->data(), buf->data(), copied);
86     buf->DidConsume(copied);
87 
88     if (!buf->BytesRemaining())
89       data_.pop();
90     return copied;
91   }
92 
93   net::CompletionCallback* read_callback_;
94   scoped_refptr<net::IOBuffer> read_buf_;
95   int read_buf_len_;
96 
97   std::queue<scoped_refptr<net::DrainableIOBuffer> > data_;
98 
99   DISALLOW_COPY_AND_ASSIGN(FakeDataChannel);
100 };
101 
102 class FakeSocket : public ClientSocket {
103  public:
FakeSocket(FakeDataChannel * incoming_channel,FakeDataChannel * outgoing_channel)104   FakeSocket(FakeDataChannel* incoming_channel,
105              FakeDataChannel* outgoing_channel)
106       : incoming_(incoming_channel),
107         outgoing_(outgoing_channel) {
108   }
109 
~FakeSocket()110   virtual ~FakeSocket() {
111 
112   }
113 
Read(IOBuffer * buf,int buf_len,CompletionCallback * callback)114   virtual int Read(IOBuffer* buf, int buf_len,
115                    CompletionCallback* callback) {
116     return incoming_->Read(buf, buf_len, callback);
117   }
118 
Write(IOBuffer * buf,int buf_len,CompletionCallback * callback)119   virtual int Write(IOBuffer* buf, int buf_len,
120                     CompletionCallback* callback) {
121     return outgoing_->Write(buf, buf_len, callback);
122   }
123 
SetReceiveBufferSize(int32 size)124   virtual bool SetReceiveBufferSize(int32 size) {
125     return true;
126   }
127 
SetSendBufferSize(int32 size)128   virtual bool SetSendBufferSize(int32 size) {
129     return true;
130   }
131 
Connect(CompletionCallback * callback)132   virtual int Connect(CompletionCallback* callback) {
133     return net::OK;
134   }
135 
Disconnect()136   virtual void Disconnect() {}
137 
IsConnected() const138   virtual bool IsConnected() const {
139     return true;
140   }
141 
IsConnectedAndIdle() const142   virtual bool IsConnectedAndIdle() const {
143     return true;
144   }
145 
GetPeerAddress(AddressList * address) const146   virtual int GetPeerAddress(AddressList* address) const {
147     net::IPAddressNumber ip_address(4);
148     *address = net::AddressList(ip_address, 0, false);
149     return net::OK;
150   }
151 
GetLocalAddress(IPEndPoint * address) const152   virtual int GetLocalAddress(IPEndPoint* address) const {
153     net::IPAddressNumber ip_address(4);
154     *address = net::IPEndPoint(ip_address, 0);
155     return net::OK;
156   }
157 
NetLog() const158   virtual const BoundNetLog& NetLog() const {
159     return net_log_;
160   }
161 
SetSubresourceSpeculation()162   virtual void SetSubresourceSpeculation() {}
SetOmniboxSpeculation()163   virtual void SetOmniboxSpeculation() {}
164 
WasEverUsed() const165   virtual bool WasEverUsed() const {
166     return true;
167   }
168 
UsingTCPFastOpen() const169   virtual bool UsingTCPFastOpen() const {
170     return false;
171   }
172 
173  private:
174   net::BoundNetLog net_log_;
175   FakeDataChannel* incoming_;
176   FakeDataChannel* outgoing_;
177 
178   DISALLOW_COPY_AND_ASSIGN(FakeSocket);
179 };
180 
181 }  // namespace
182 
183 // Verify the correctness of the test helper classes first.
TEST(FakeSocketTest,DataTransfer)184 TEST(FakeSocketTest, DataTransfer) {
185   // Establish channels between two sockets.
186   FakeDataChannel channel_1;
187   FakeDataChannel channel_2;
188   FakeSocket client(&channel_1, &channel_2);
189   FakeSocket server(&channel_2, &channel_1);
190 
191   const char kTestData[] = "testing123";
192   const int kTestDataSize = strlen(kTestData);
193   const int kReadBufSize = 1024;
194   scoped_refptr<net::IOBuffer> write_buf = new net::StringIOBuffer(kTestData);
195   scoped_refptr<net::IOBuffer> read_buf = new net::IOBuffer(kReadBufSize);
196 
197   // Write then read.
198   EXPECT_EQ(kTestDataSize, server.Write(write_buf, kTestDataSize, NULL));
199   EXPECT_EQ(kTestDataSize, client.Read(read_buf, kReadBufSize, NULL));
200   EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), kTestDataSize));
201 
202   // Read then write.
203   TestCompletionCallback callback;
204   EXPECT_EQ(net::ERR_IO_PENDING,
205             server.Read(read_buf, kReadBufSize, &callback));
206   EXPECT_EQ(kTestDataSize, client.Write(write_buf, kTestDataSize, NULL));
207   EXPECT_EQ(kTestDataSize, callback.WaitForResult());
208   EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), kTestDataSize));
209 }
210 
211 class SSLServerSocketTest : public PlatformTest {
212  public:
SSLServerSocketTest()213   SSLServerSocketTest()
214       : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()) {
215   }
216 
217  protected:
Initialize()218   void Initialize() {
219     FakeSocket* fake_client_socket = new FakeSocket(&channel_1_, &channel_2_);
220     FakeSocket* fake_server_socket = new FakeSocket(&channel_2_, &channel_1_);
221 
222     FilePath certs_dir;
223     PathService::Get(base::DIR_SOURCE_ROOT, &certs_dir);
224     certs_dir = certs_dir.AppendASCII("net");
225     certs_dir = certs_dir.AppendASCII("data");
226     certs_dir = certs_dir.AppendASCII("ssl");
227     certs_dir = certs_dir.AppendASCII("certificates");
228 
229     FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der");
230     std::string cert_der;
231     ASSERT_TRUE(file_util::ReadFileToString(cert_path, &cert_der));
232 
233     scoped_refptr<net::X509Certificate> cert =
234         X509Certificate::CreateFromBytes(cert_der.data(), cert_der.size());
235 
236     FilePath key_path = certs_dir.AppendASCII("unittest.key.bin");
237     std::string key_string;
238     ASSERT_TRUE(file_util::ReadFileToString(key_path, &key_string));
239     std::vector<uint8> key_vector(
240         reinterpret_cast<const uint8*>(key_string.data()),
241         reinterpret_cast<const uint8*>(key_string.data() +
242                                        key_string.length()));
243 
244     scoped_ptr<crypto::RSAPrivateKey> private_key(
245         crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
246 
247     net::SSLConfig ssl_config;
248     ssl_config.false_start_enabled = false;
249     ssl_config.ssl3_enabled = true;
250     ssl_config.tls1_enabled = true;
251 
252     // Certificate provided by the host doesn't need authority.
253     net::SSLConfig::CertAndStatus cert_and_status;
254     cert_and_status.cert_status = net::CERT_STATUS_AUTHORITY_INVALID;
255     cert_and_status.cert = cert;
256     ssl_config.allowed_bad_certs.push_back(cert_and_status);
257 
258     net::HostPortPair host_and_pair("unittest", 0);
259     client_socket_.reset(
260         socket_factory_->CreateSSLClientSocket(
261             fake_client_socket, host_and_pair, ssl_config, NULL,
262             &cert_verifier_));
263     server_socket_.reset(net::CreateSSLServerSocket(fake_server_socket,
264                                                     cert, private_key.get(),
265                                                     net::SSLConfig()));
266   }
267 
268   FakeDataChannel channel_1_;
269   FakeDataChannel channel_2_;
270   scoped_ptr<net::SSLClientSocket> client_socket_;
271   scoped_ptr<net::SSLServerSocket> server_socket_;
272   net::ClientSocketFactory* socket_factory_;
273   net::CertVerifier cert_verifier_;
274 };
275 
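// SSLServerSocket is only implemented using NSS. The tests below are limited
// to the builds where that implementation is compiled in: USE_NSS builds,
// Windows and Mac.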
277 #if defined(USE_NSS) || defined(OS_WIN) || defined(OS_MACOSX)
278 
279 // This test only executes creation of client and server sockets. This is to
280 // test that creation of sockets doesn't crash and have minimal code to run
281 // under valgrind in order to help debugging memory problems.
TEST_F(SSLServerSocketTest,Initialize)282 TEST_F(SSLServerSocketTest, Initialize) {
283   Initialize();
284 }
285 
286 // This test executes Connect() of SSLClientSocket and Accept() of
287 // SSLServerSocket to make sure handshaking between the two sockets are
288 // completed successfully.
TEST_F(SSLServerSocketTest,Handshake)289 TEST_F(SSLServerSocketTest, Handshake) {
290   Initialize();
291 
292   TestCompletionCallback connect_callback;
293   TestCompletionCallback accept_callback;
294 
295   int server_ret = server_socket_->Accept(&accept_callback);
296   EXPECT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING);
297 
298   int client_ret = client_socket_->Connect(&connect_callback);
299   EXPECT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING);
300 
301   if (client_ret == net::ERR_IO_PENDING) {
302     EXPECT_EQ(net::OK, connect_callback.WaitForResult());
303   }
304   if (server_ret == net::ERR_IO_PENDING) {
305     EXPECT_EQ(net::OK, accept_callback.WaitForResult());
306   }
307 }
308 
TEST_F(SSLServerSocketTest,DataTransfer)309 TEST_F(SSLServerSocketTest, DataTransfer) {
310   Initialize();
311 
312   TestCompletionCallback connect_callback;
313   TestCompletionCallback accept_callback;
314 
315   // Establish connection.
316   int client_ret = client_socket_->Connect(&connect_callback);
317   ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING);
318 
319   int server_ret = server_socket_->Accept(&accept_callback);
320   ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING);
321 
322   if (client_ret == net::ERR_IO_PENDING) {
323     ASSERT_EQ(net::OK, connect_callback.WaitForResult());
324   }
325   if (server_ret == net::ERR_IO_PENDING) {
326     ASSERT_EQ(net::OK, accept_callback.WaitForResult());
327   }
328 
329   const int kReadBufSize = 1024;
330   scoped_refptr<net::StringIOBuffer> write_buf =
331       new net::StringIOBuffer("testing123");
332   scoped_refptr<net::IOBuffer> read_buf = new net::IOBuffer(kReadBufSize);
333 
334   // Write then read.
335   TestCompletionCallback write_callback;
336   TestCompletionCallback read_callback;
337   server_ret = server_socket_->Write(write_buf, write_buf->size(),
338                                      &write_callback);
339   EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING);
340   client_ret = client_socket_->Read(read_buf, kReadBufSize, &read_callback);
341   EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING);
342 
343   if (server_ret == net::ERR_IO_PENDING) {
344     EXPECT_GT(write_callback.WaitForResult(), 0);
345   }
346   if (client_ret == net::ERR_IO_PENDING) {
347     EXPECT_GT(read_callback.WaitForResult(), 0);
348   }
349   EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));
350 
351   // Read then write.
352   write_buf = new net::StringIOBuffer("hello123");
353   server_ret = server_socket_->Read(read_buf, kReadBufSize, &read_callback);
354   EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING);
355   client_ret = client_socket_->Write(write_buf, write_buf->size(),
356                                      &write_callback);
357   EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING);
358 
359   if (server_ret == net::ERR_IO_PENDING) {
360     EXPECT_GT(read_callback.WaitForResult(), 0);
361   }
362   if (client_ret == net::ERR_IO_PENDING) {
363     EXPECT_GT(write_callback.WaitForResult(), 0);
364   }
365   EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));
366 }
367 #endif
368 
369 }  // namespace net
370