• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2009 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/socket_test_util.h"
6 
7 #include <algorithm>
8 
9 #include "base/basictypes.h"
10 #include "base/compiler_specific.h"
11 #include "base/message_loop.h"
12 #include "net/base/ssl_info.h"
13 #include "net/socket/socket.h"
14 #include "testing/gtest/include/gtest/gtest.h"
15 
16 namespace net {
17 
MockClientSocket()18 MockClientSocket::MockClientSocket()
19     : ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)),
20       connected_(false) {
21 }
22 
GetSSLInfo(net::SSLInfo * ssl_info)23 void MockClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) {
24   NOTREACHED();
25 }
26 
GetSSLCertRequestInfo(net::SSLCertRequestInfo * cert_request_info)27 void MockClientSocket::GetSSLCertRequestInfo(
28     net::SSLCertRequestInfo* cert_request_info) {
29   NOTREACHED();
30 }
31 
32 SSLClientSocket::NextProtoStatus
GetNextProto(std::string * proto)33 MockClientSocket::GetNextProto(std::string* proto) {
34   proto->clear();
35   return SSLClientSocket::kNextProtoUnsupported;
36 }
37 
Disconnect()38 void MockClientSocket::Disconnect() {
39   connected_ = false;
40 }
41 
IsConnected() const42 bool MockClientSocket::IsConnected() const {
43   return connected_;
44 }
45 
IsConnectedAndIdle() const46 bool MockClientSocket::IsConnectedAndIdle() const {
47   return connected_;
48 }
49 
GetPeerName(struct sockaddr * name,socklen_t * namelen)50 int MockClientSocket::GetPeerName(struct sockaddr* name, socklen_t* namelen) {
51   memset(reinterpret_cast<char *>(name), 0, *namelen);
52   return net::OK;
53 }
54 
RunCallbackAsync(net::CompletionCallback * callback,int result)55 void MockClientSocket::RunCallbackAsync(net::CompletionCallback* callback,
56                                         int result) {
57   MessageLoop::current()->PostTask(FROM_HERE,
58       method_factory_.NewRunnableMethod(
59           &MockClientSocket::RunCallback, callback, result));
60 }
61 
RunCallback(net::CompletionCallback * callback,int result)62 void MockClientSocket::RunCallback(net::CompletionCallback* callback,
63                                    int result) {
64   if (callback)
65     callback->Run(result);
66 }
67 
MockTCPClientSocket(const net::AddressList & addresses,net::SocketDataProvider * data)68 MockTCPClientSocket::MockTCPClientSocket(const net::AddressList& addresses,
69                                          net::SocketDataProvider* data)
70     : addresses_(addresses),
71       data_(data),
72       read_offset_(0),
73       read_data_(false, net::ERR_UNEXPECTED),
74       need_read_data_(true),
75       peer_closed_connection_(false),
76       pending_buf_(NULL),
77       pending_buf_len_(0),
78       pending_callback_(NULL) {
79   DCHECK(data_);
80   data_->Reset();
81 }
82 
Connect(net::CompletionCallback * callback,LoadLog * load_log)83 int MockTCPClientSocket::Connect(net::CompletionCallback* callback,
84                                  LoadLog* load_log) {
85   if (connected_)
86     return net::OK;
87   connected_ = true;
88   if (data_->connect_data().async) {
89     RunCallbackAsync(callback, data_->connect_data().result);
90     return net::ERR_IO_PENDING;
91   }
92   return data_->connect_data().result;
93 }
94 
IsConnected() const95 bool MockTCPClientSocket::IsConnected() const {
96   return connected_ && !peer_closed_connection_;
97 }
98 
Read(net::IOBuffer * buf,int buf_len,net::CompletionCallback * callback)99 int MockTCPClientSocket::Read(net::IOBuffer* buf, int buf_len,
100                               net::CompletionCallback* callback) {
101   if (!connected_)
102     return net::ERR_UNEXPECTED;
103 
104   // If the buffer is already in use, a read is already in progress!
105   DCHECK(pending_buf_ == NULL);
106 
107   // Store our async IO data.
108   pending_buf_ = buf;
109   pending_buf_len_ = buf_len;
110   pending_callback_ = callback;
111 
112   if (need_read_data_) {
113     read_data_ = data_->GetNextRead();
114     if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) {
115       // This MockRead is just a marker to instruct us to set
116       // peer_closed_connection_.  Skip it and get the next one.
117       read_data_ = data_->GetNextRead();
118       peer_closed_connection_ = true;
119     }
120     // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility
121     // to complete the async IO manually later (via OnReadComplete).
122     if (read_data_.result == ERR_IO_PENDING) {
123       DCHECK(callback);  // We need to be using async IO in this case.
124       return ERR_IO_PENDING;
125     }
126     need_read_data_ = false;
127   }
128 
129   return CompleteRead();
130 }
131 
Write(net::IOBuffer * buf,int buf_len,net::CompletionCallback * callback)132 int MockTCPClientSocket::Write(net::IOBuffer* buf, int buf_len,
133                                net::CompletionCallback* callback) {
134   DCHECK(buf);
135   DCHECK_GT(buf_len, 0);
136 
137   if (!connected_)
138     return net::ERR_UNEXPECTED;
139 
140   std::string data(buf->data(), buf_len);
141   net::MockWriteResult write_result = data_->OnWrite(data);
142 
143   if (write_result.async) {
144     RunCallbackAsync(callback, write_result.result);
145     return net::ERR_IO_PENDING;
146   }
147   return write_result.result;
148 }
149 
OnReadComplete(const MockRead & data)150 void MockTCPClientSocket::OnReadComplete(const MockRead& data) {
151   // There must be a read pending.
152   DCHECK(pending_buf_);
153   // You can't complete a read with another ERR_IO_PENDING status code.
154   DCHECK_NE(ERR_IO_PENDING, data.result);
155   // Since we've been waiting for data, need_read_data_ should be true.
156   DCHECK(need_read_data_);
157 
158   read_data_ = data;
159   need_read_data_ = false;
160 
161   // The caller is simulating that this IO completes right now.  Don't
162   // let CompleteRead() schedule a callback.
163   read_data_.async = false;
164 
165   net::CompletionCallback* callback = pending_callback_;
166   int rv = CompleteRead();
167   RunCallback(callback, rv);
168 }
169 
CompleteRead()170 int MockTCPClientSocket::CompleteRead() {
171   DCHECK(pending_buf_);
172   DCHECK(pending_buf_len_ > 0);
173 
174   // Save the pending async IO data and reset our |pending_| state.
175   net::IOBuffer* buf = pending_buf_;
176   int buf_len = pending_buf_len_;
177   net::CompletionCallback* callback = pending_callback_;
178   pending_buf_ = NULL;
179   pending_buf_len_ = 0;
180   pending_callback_ = NULL;
181 
182   int result = read_data_.result;
183   DCHECK(result != ERR_IO_PENDING);
184 
185   if (read_data_.data) {
186     if (read_data_.data_len - read_offset_ > 0) {
187       result = std::min(buf_len, read_data_.data_len - read_offset_);
188       memcpy(buf->data(), read_data_.data + read_offset_, result);
189       read_offset_ += result;
190       if (read_offset_ == read_data_.data_len) {
191         need_read_data_ = true;
192         read_offset_ = 0;
193       }
194     } else {
195       result = 0;  // EOF
196     }
197   }
198 
199   if (read_data_.async) {
200     DCHECK(callback);
201     RunCallbackAsync(callback, result);
202     return net::ERR_IO_PENDING;
203   }
204   return result;
205 }
206 
207 class MockSSLClientSocket::ConnectCallback :
208     public net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback> {
209  public:
ConnectCallback(MockSSLClientSocket * ssl_client_socket,net::CompletionCallback * user_callback,int rv)210   ConnectCallback(MockSSLClientSocket *ssl_client_socket,
211                   net::CompletionCallback* user_callback,
212                   int rv)
213       : ALLOW_THIS_IN_INITIALIZER_LIST(
214           net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback>(
215                 this, &ConnectCallback::Wrapper)),
216         ssl_client_socket_(ssl_client_socket),
217         user_callback_(user_callback),
218         rv_(rv) {
219   }
220 
221  private:
Wrapper(int rv)222   void Wrapper(int rv) {
223     if (rv_ == net::OK)
224       ssl_client_socket_->connected_ = true;
225     user_callback_->Run(rv_);
226     delete this;
227   }
228 
229   MockSSLClientSocket* ssl_client_socket_;
230   net::CompletionCallback* user_callback_;
231   int rv_;
232 };
233 
MockSSLClientSocket(net::ClientSocket * transport_socket,const std::string & hostname,const net::SSLConfig & ssl_config,net::SSLSocketDataProvider * data)234 MockSSLClientSocket::MockSSLClientSocket(
235     net::ClientSocket* transport_socket,
236     const std::string& hostname,
237     const net::SSLConfig& ssl_config,
238     net::SSLSocketDataProvider* data)
239     : transport_(transport_socket),
240       data_(data) {
241   DCHECK(data_);
242 }
243 
~MockSSLClientSocket()244 MockSSLClientSocket::~MockSSLClientSocket() {
245   Disconnect();
246 }
247 
GetSSLInfo(net::SSLInfo * ssl_info)248 void MockSSLClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) {
249   ssl_info->Reset();
250 }
251 
Connect(net::CompletionCallback * callback,LoadLog * load_log)252 int MockSSLClientSocket::Connect(net::CompletionCallback* callback,
253                                  LoadLog* load_log) {
254   ConnectCallback* connect_callback = new ConnectCallback(
255       this, callback, data_->connect.result);
256   int rv = transport_->Connect(connect_callback, load_log);
257   if (rv == net::OK) {
258     delete connect_callback;
259     if (data_->connect.async) {
260       RunCallbackAsync(callback, data_->connect.result);
261       return net::ERR_IO_PENDING;
262     }
263     if (data_->connect.result == net::OK)
264       connected_ = true;
265     return data_->connect.result;
266   }
267   return rv;
268 }
269 
Disconnect()270 void MockSSLClientSocket::Disconnect() {
271   MockClientSocket::Disconnect();
272   if (transport_ != NULL)
273     transport_->Disconnect();
274 }
275 
Read(net::IOBuffer * buf,int buf_len,net::CompletionCallback * callback)276 int MockSSLClientSocket::Read(net::IOBuffer* buf, int buf_len,
277                               net::CompletionCallback* callback) {
278   return transport_->Read(buf, buf_len, callback);
279 }
280 
Write(net::IOBuffer * buf,int buf_len,net::CompletionCallback * callback)281 int MockSSLClientSocket::Write(net::IOBuffer* buf, int buf_len,
282                                net::CompletionCallback* callback) {
283   return transport_->Write(buf, buf_len, callback);
284 }
285 
GetNextRead()286 MockRead StaticSocketDataProvider::GetNextRead() {
287   MockRead rv = reads_[read_index_];
288   if (reads_[read_index_].result != OK ||
289       reads_[read_index_].data_len != 0)
290     read_index_++;  // Don't advance past an EOF.
291   return rv;
292 }
293 
OnWrite(const std::string & data)294 MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) {
295   if (!writes_) {
296     // Not using mock writes; succeed synchronously.
297     return MockWriteResult(false, data.length());
298   }
299 
300   // Check that what we are writing matches the expectation.
301   // Then give the mocked return value.
302   net::MockWrite* w = &writes_[write_index_++];
303   int result = w->result;
304   if (w->data) {
305     // Note - we can simulate a partial write here.  If the expected data
306     // is a match, but shorter than the write actually written, that is legal.
307     // Example:
308     //   Application writes "foobarbaz" (9 bytes)
309     //   Expected write was "foo" (3 bytes)
310     //   This is a success, and we return 3 to the application.
311     std::string expected_data(w->data, w->data_len);
312     EXPECT_GE(data.length(), expected_data.length());
313     std::string actual_data(data.substr(0, w->data_len));
314     EXPECT_EQ(expected_data, actual_data);
315     if (expected_data != actual_data)
316       return MockWriteResult(false, net::ERR_UNEXPECTED);
317     if (result == net::OK)
318       result = w->data_len;
319   }
320   return MockWriteResult(w->async, result);
321 }
322 
Reset()323 void StaticSocketDataProvider::Reset() {
324   read_index_ = 0;
325   write_index_ = 0;
326 }
327 
DynamicSocketDataProvider()328 DynamicSocketDataProvider::DynamicSocketDataProvider()
329     : short_read_limit_(0),
330       allow_unconsumed_reads_(false) {
331 }
332 
GetNextRead()333 MockRead DynamicSocketDataProvider::GetNextRead() {
334   if (reads_.empty())
335     return MockRead(false, ERR_UNEXPECTED);
336   MockRead result = reads_.front();
337   if (short_read_limit_ == 0 || result.data_len <= short_read_limit_) {
338     reads_.pop_front();
339   } else {
340     result.data_len = short_read_limit_;
341     reads_.front().data += result.data_len;
342     reads_.front().data_len -= result.data_len;
343   }
344   return result;
345 }
346 
Reset()347 void DynamicSocketDataProvider::Reset() {
348   reads_.clear();
349 }
350 
SimulateRead(const char * data)351 void DynamicSocketDataProvider::SimulateRead(const char* data) {
352   if (!allow_unconsumed_reads_) {
353     EXPECT_TRUE(reads_.empty()) << "Unconsumed read: " << reads_.front().data;
354   }
355   reads_.push_back(MockRead(data));
356 }
357 
AddSocketDataProvider(SocketDataProvider * data)358 void MockClientSocketFactory::AddSocketDataProvider(
359     SocketDataProvider* data) {
360   mock_data_.Add(data);
361 }
362 
AddSSLSocketDataProvider(SSLSocketDataProvider * data)363 void MockClientSocketFactory::AddSSLSocketDataProvider(
364     SSLSocketDataProvider* data) {
365   mock_ssl_data_.Add(data);
366 }
367 
ResetNextMockIndexes()368 void MockClientSocketFactory::ResetNextMockIndexes() {
369   mock_data_.ResetNextIndex();
370   mock_ssl_data_.ResetNextIndex();
371 }
372 
GetMockTCPClientSocket(int index) const373 MockTCPClientSocket* MockClientSocketFactory::GetMockTCPClientSocket(
374     int index) const {
375   return tcp_client_sockets_[index];
376 }
377 
GetMockSSLClientSocket(int index) const378 MockSSLClientSocket* MockClientSocketFactory::GetMockSSLClientSocket(
379     int index) const {
380   return ssl_client_sockets_[index];
381 }
382 
CreateTCPClientSocket(const AddressList & addresses)383 ClientSocket* MockClientSocketFactory::CreateTCPClientSocket(
384     const AddressList& addresses) {
385   SocketDataProvider* data_provider = mock_data_.GetNext();
386   MockTCPClientSocket* socket =
387       new MockTCPClientSocket(addresses, data_provider);
388   data_provider->set_socket(socket);
389   tcp_client_sockets_.push_back(socket);
390   return socket;
391 }
392 
CreateSSLClientSocket(ClientSocket * transport_socket,const std::string & hostname,const SSLConfig & ssl_config)393 SSLClientSocket* MockClientSocketFactory::CreateSSLClientSocket(
394     ClientSocket* transport_socket,
395     const std::string& hostname,
396     const SSLConfig& ssl_config) {
397   MockSSLClientSocket* socket =
398       new MockSSLClientSocket(transport_socket, hostname, ssl_config,
399                               mock_ssl_data_.GetNext());
400   ssl_client_sockets_.push_back(socket);
401   return socket;
402 }
403 
WaitForResult()404 int TestSocketRequest::WaitForResult() {
405   return callback_.WaitForResult();
406 }
407 
RunWithParams(const Tuple1<int> & params)408 void TestSocketRequest::RunWithParams(const Tuple1<int>& params) {
409   callback_.RunWithParams(params);
410   (*completion_count_)++;
411   request_order_->push_back(this);
412 }
413 
414 // static
415 const int ClientSocketPoolTest::kIndexOutOfBounds = -1;
416 
417 // static
418 const int ClientSocketPoolTest::kRequestNotFound = -2;
419 
SetUp()420 void ClientSocketPoolTest::SetUp() {
421   completion_count_ = 0;
422 }
423 
TearDown()424 void ClientSocketPoolTest::TearDown() {
425   // The tests often call Reset() on handles at the end which may post
426   // DoReleaseSocket() tasks.
427   // Pending tasks created by client_socket_pool_base_unittest.cc are
428   // posted two milliseconds into the future and thus won't become
429   // scheduled until that time.
430   // We wait a few milliseconds to make sure that all such future tasks
431   // are ready to run, before calling RunAllPending(). This will work
432   // correctly even if Sleep() finishes late (and it should never finish
433   // early), as all we have to ensure is that actual wall-time has progressed
434   // past the scheduled starting time of the pending task.
435   PlatformThread::Sleep(10);
436   MessageLoop::current()->RunAllPending();
437 }
438 
GetOrderOfRequest(size_t index)439 int ClientSocketPoolTest::GetOrderOfRequest(size_t index) {
440   index--;
441   if (index >= requests_.size())
442     return kIndexOutOfBounds;
443 
444   for (size_t i = 0; i < request_order_.size(); i++)
445     if (requests_[index] == request_order_[i])
446       return i + 1;
447 
448   return kRequestNotFound;
449 }
450 
ReleaseOneConnection(KeepAlive keep_alive)451 bool ClientSocketPoolTest::ReleaseOneConnection(KeepAlive keep_alive) {
452   ScopedVector<TestSocketRequest>::iterator i;
453   for (i = requests_.begin(); i != requests_.end(); ++i) {
454     if ((*i)->handle()->is_initialized()) {
455       if (keep_alive == NO_KEEP_ALIVE)
456         (*i)->handle()->socket()->Disconnect();
457       (*i)->handle()->Reset();
458       MessageLoop::current()->RunAllPending();
459       return true;
460     }
461   }
462   return false;
463 }
464 
ReleaseAllConnections(KeepAlive keep_alive)465 void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) {
466   bool released_one;
467   do {
468     released_one = ReleaseOneConnection(keep_alive);
469   } while (released_one);
470 }
471 
472 }  // namespace net
473