1 // Copyright (c) 2009 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/socket_test_util.h"
6
7 #include <algorithm>
8
9 #include "base/basictypes.h"
10 #include "base/compiler_specific.h"
11 #include "base/message_loop.h"
12 #include "net/base/ssl_info.h"
13 #include "net/socket/socket.h"
14 #include "testing/gtest/include/gtest/gtest.h"
15
16 namespace net {
17
MockClientSocket()18 MockClientSocket::MockClientSocket()
19 : ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)),
20 connected_(false) {
21 }
22
GetSSLInfo(net::SSLInfo * ssl_info)23 void MockClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) {
24 NOTREACHED();
25 }
26
GetSSLCertRequestInfo(net::SSLCertRequestInfo * cert_request_info)27 void MockClientSocket::GetSSLCertRequestInfo(
28 net::SSLCertRequestInfo* cert_request_info) {
29 NOTREACHED();
30 }
31
32 SSLClientSocket::NextProtoStatus
GetNextProto(std::string * proto)33 MockClientSocket::GetNextProto(std::string* proto) {
34 proto->clear();
35 return SSLClientSocket::kNextProtoUnsupported;
36 }
37
Disconnect()38 void MockClientSocket::Disconnect() {
39 connected_ = false;
40 }
41
IsConnected() const42 bool MockClientSocket::IsConnected() const {
43 return connected_;
44 }
45
IsConnectedAndIdle() const46 bool MockClientSocket::IsConnectedAndIdle() const {
47 return connected_;
48 }
49
GetPeerName(struct sockaddr * name,socklen_t * namelen)50 int MockClientSocket::GetPeerName(struct sockaddr* name, socklen_t* namelen) {
51 memset(reinterpret_cast<char *>(name), 0, *namelen);
52 return net::OK;
53 }
54
RunCallbackAsync(net::CompletionCallback * callback,int result)55 void MockClientSocket::RunCallbackAsync(net::CompletionCallback* callback,
56 int result) {
57 MessageLoop::current()->PostTask(FROM_HERE,
58 method_factory_.NewRunnableMethod(
59 &MockClientSocket::RunCallback, callback, result));
60 }
61
RunCallback(net::CompletionCallback * callback,int result)62 void MockClientSocket::RunCallback(net::CompletionCallback* callback,
63 int result) {
64 if (callback)
65 callback->Run(result);
66 }
67
MockTCPClientSocket(const net::AddressList & addresses,net::SocketDataProvider * data)68 MockTCPClientSocket::MockTCPClientSocket(const net::AddressList& addresses,
69 net::SocketDataProvider* data)
70 : addresses_(addresses),
71 data_(data),
72 read_offset_(0),
73 read_data_(false, net::ERR_UNEXPECTED),
74 need_read_data_(true),
75 peer_closed_connection_(false),
76 pending_buf_(NULL),
77 pending_buf_len_(0),
78 pending_callback_(NULL) {
79 DCHECK(data_);
80 data_->Reset();
81 }
82
Connect(net::CompletionCallback * callback,LoadLog * load_log)83 int MockTCPClientSocket::Connect(net::CompletionCallback* callback,
84 LoadLog* load_log) {
85 if (connected_)
86 return net::OK;
87 connected_ = true;
88 if (data_->connect_data().async) {
89 RunCallbackAsync(callback, data_->connect_data().result);
90 return net::ERR_IO_PENDING;
91 }
92 return data_->connect_data().result;
93 }
94
IsConnected() const95 bool MockTCPClientSocket::IsConnected() const {
96 return connected_ && !peer_closed_connection_;
97 }
98
Read(net::IOBuffer * buf,int buf_len,net::CompletionCallback * callback)99 int MockTCPClientSocket::Read(net::IOBuffer* buf, int buf_len,
100 net::CompletionCallback* callback) {
101 if (!connected_)
102 return net::ERR_UNEXPECTED;
103
104 // If the buffer is already in use, a read is already in progress!
105 DCHECK(pending_buf_ == NULL);
106
107 // Store our async IO data.
108 pending_buf_ = buf;
109 pending_buf_len_ = buf_len;
110 pending_callback_ = callback;
111
112 if (need_read_data_) {
113 read_data_ = data_->GetNextRead();
114 if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) {
115 // This MockRead is just a marker to instruct us to set
116 // peer_closed_connection_. Skip it and get the next one.
117 read_data_ = data_->GetNextRead();
118 peer_closed_connection_ = true;
119 }
120 // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility
121 // to complete the async IO manually later (via OnReadComplete).
122 if (read_data_.result == ERR_IO_PENDING) {
123 DCHECK(callback); // We need to be using async IO in this case.
124 return ERR_IO_PENDING;
125 }
126 need_read_data_ = false;
127 }
128
129 return CompleteRead();
130 }
131
Write(net::IOBuffer * buf,int buf_len,net::CompletionCallback * callback)132 int MockTCPClientSocket::Write(net::IOBuffer* buf, int buf_len,
133 net::CompletionCallback* callback) {
134 DCHECK(buf);
135 DCHECK_GT(buf_len, 0);
136
137 if (!connected_)
138 return net::ERR_UNEXPECTED;
139
140 std::string data(buf->data(), buf_len);
141 net::MockWriteResult write_result = data_->OnWrite(data);
142
143 if (write_result.async) {
144 RunCallbackAsync(callback, write_result.result);
145 return net::ERR_IO_PENDING;
146 }
147 return write_result.result;
148 }
149
OnReadComplete(const MockRead & data)150 void MockTCPClientSocket::OnReadComplete(const MockRead& data) {
151 // There must be a read pending.
152 DCHECK(pending_buf_);
153 // You can't complete a read with another ERR_IO_PENDING status code.
154 DCHECK_NE(ERR_IO_PENDING, data.result);
155 // Since we've been waiting for data, need_read_data_ should be true.
156 DCHECK(need_read_data_);
157
158 read_data_ = data;
159 need_read_data_ = false;
160
161 // The caller is simulating that this IO completes right now. Don't
162 // let CompleteRead() schedule a callback.
163 read_data_.async = false;
164
165 net::CompletionCallback* callback = pending_callback_;
166 int rv = CompleteRead();
167 RunCallback(callback, rv);
168 }
169
CompleteRead()170 int MockTCPClientSocket::CompleteRead() {
171 DCHECK(pending_buf_);
172 DCHECK(pending_buf_len_ > 0);
173
174 // Save the pending async IO data and reset our |pending_| state.
175 net::IOBuffer* buf = pending_buf_;
176 int buf_len = pending_buf_len_;
177 net::CompletionCallback* callback = pending_callback_;
178 pending_buf_ = NULL;
179 pending_buf_len_ = 0;
180 pending_callback_ = NULL;
181
182 int result = read_data_.result;
183 DCHECK(result != ERR_IO_PENDING);
184
185 if (read_data_.data) {
186 if (read_data_.data_len - read_offset_ > 0) {
187 result = std::min(buf_len, read_data_.data_len - read_offset_);
188 memcpy(buf->data(), read_data_.data + read_offset_, result);
189 read_offset_ += result;
190 if (read_offset_ == read_data_.data_len) {
191 need_read_data_ = true;
192 read_offset_ = 0;
193 }
194 } else {
195 result = 0; // EOF
196 }
197 }
198
199 if (read_data_.async) {
200 DCHECK(callback);
201 RunCallbackAsync(callback, result);
202 return net::ERR_IO_PENDING;
203 }
204 return result;
205 }
206
207 class MockSSLClientSocket::ConnectCallback :
208 public net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback> {
209 public:
ConnectCallback(MockSSLClientSocket * ssl_client_socket,net::CompletionCallback * user_callback,int rv)210 ConnectCallback(MockSSLClientSocket *ssl_client_socket,
211 net::CompletionCallback* user_callback,
212 int rv)
213 : ALLOW_THIS_IN_INITIALIZER_LIST(
214 net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback>(
215 this, &ConnectCallback::Wrapper)),
216 ssl_client_socket_(ssl_client_socket),
217 user_callback_(user_callback),
218 rv_(rv) {
219 }
220
221 private:
Wrapper(int rv)222 void Wrapper(int rv) {
223 if (rv_ == net::OK)
224 ssl_client_socket_->connected_ = true;
225 user_callback_->Run(rv_);
226 delete this;
227 }
228
229 MockSSLClientSocket* ssl_client_socket_;
230 net::CompletionCallback* user_callback_;
231 int rv_;
232 };
233
MockSSLClientSocket(net::ClientSocket * transport_socket,const std::string & hostname,const net::SSLConfig & ssl_config,net::SSLSocketDataProvider * data)234 MockSSLClientSocket::MockSSLClientSocket(
235 net::ClientSocket* transport_socket,
236 const std::string& hostname,
237 const net::SSLConfig& ssl_config,
238 net::SSLSocketDataProvider* data)
239 : transport_(transport_socket),
240 data_(data) {
241 DCHECK(data_);
242 }
243
~MockSSLClientSocket()244 MockSSLClientSocket::~MockSSLClientSocket() {
245 Disconnect();
246 }
247
GetSSLInfo(net::SSLInfo * ssl_info)248 void MockSSLClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) {
249 ssl_info->Reset();
250 }
251
Connect(net::CompletionCallback * callback,LoadLog * load_log)252 int MockSSLClientSocket::Connect(net::CompletionCallback* callback,
253 LoadLog* load_log) {
254 ConnectCallback* connect_callback = new ConnectCallback(
255 this, callback, data_->connect.result);
256 int rv = transport_->Connect(connect_callback, load_log);
257 if (rv == net::OK) {
258 delete connect_callback;
259 if (data_->connect.async) {
260 RunCallbackAsync(callback, data_->connect.result);
261 return net::ERR_IO_PENDING;
262 }
263 if (data_->connect.result == net::OK)
264 connected_ = true;
265 return data_->connect.result;
266 }
267 return rv;
268 }
269
Disconnect()270 void MockSSLClientSocket::Disconnect() {
271 MockClientSocket::Disconnect();
272 if (transport_ != NULL)
273 transport_->Disconnect();
274 }
275
Read(net::IOBuffer * buf,int buf_len,net::CompletionCallback * callback)276 int MockSSLClientSocket::Read(net::IOBuffer* buf, int buf_len,
277 net::CompletionCallback* callback) {
278 return transport_->Read(buf, buf_len, callback);
279 }
280
Write(net::IOBuffer * buf,int buf_len,net::CompletionCallback * callback)281 int MockSSLClientSocket::Write(net::IOBuffer* buf, int buf_len,
282 net::CompletionCallback* callback) {
283 return transport_->Write(buf, buf_len, callback);
284 }
285
GetNextRead()286 MockRead StaticSocketDataProvider::GetNextRead() {
287 MockRead rv = reads_[read_index_];
288 if (reads_[read_index_].result != OK ||
289 reads_[read_index_].data_len != 0)
290 read_index_++; // Don't advance past an EOF.
291 return rv;
292 }
293
OnWrite(const std::string & data)294 MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) {
295 if (!writes_) {
296 // Not using mock writes; succeed synchronously.
297 return MockWriteResult(false, data.length());
298 }
299
300 // Check that what we are writing matches the expectation.
301 // Then give the mocked return value.
302 net::MockWrite* w = &writes_[write_index_++];
303 int result = w->result;
304 if (w->data) {
305 // Note - we can simulate a partial write here. If the expected data
306 // is a match, but shorter than the write actually written, that is legal.
307 // Example:
308 // Application writes "foobarbaz" (9 bytes)
309 // Expected write was "foo" (3 bytes)
310 // This is a success, and we return 3 to the application.
311 std::string expected_data(w->data, w->data_len);
312 EXPECT_GE(data.length(), expected_data.length());
313 std::string actual_data(data.substr(0, w->data_len));
314 EXPECT_EQ(expected_data, actual_data);
315 if (expected_data != actual_data)
316 return MockWriteResult(false, net::ERR_UNEXPECTED);
317 if (result == net::OK)
318 result = w->data_len;
319 }
320 return MockWriteResult(w->async, result);
321 }
322
Reset()323 void StaticSocketDataProvider::Reset() {
324 read_index_ = 0;
325 write_index_ = 0;
326 }
327
DynamicSocketDataProvider()328 DynamicSocketDataProvider::DynamicSocketDataProvider()
329 : short_read_limit_(0),
330 allow_unconsumed_reads_(false) {
331 }
332
GetNextRead()333 MockRead DynamicSocketDataProvider::GetNextRead() {
334 if (reads_.empty())
335 return MockRead(false, ERR_UNEXPECTED);
336 MockRead result = reads_.front();
337 if (short_read_limit_ == 0 || result.data_len <= short_read_limit_) {
338 reads_.pop_front();
339 } else {
340 result.data_len = short_read_limit_;
341 reads_.front().data += result.data_len;
342 reads_.front().data_len -= result.data_len;
343 }
344 return result;
345 }
346
Reset()347 void DynamicSocketDataProvider::Reset() {
348 reads_.clear();
349 }
350
SimulateRead(const char * data)351 void DynamicSocketDataProvider::SimulateRead(const char* data) {
352 if (!allow_unconsumed_reads_) {
353 EXPECT_TRUE(reads_.empty()) << "Unconsumed read: " << reads_.front().data;
354 }
355 reads_.push_back(MockRead(data));
356 }
357
AddSocketDataProvider(SocketDataProvider * data)358 void MockClientSocketFactory::AddSocketDataProvider(
359 SocketDataProvider* data) {
360 mock_data_.Add(data);
361 }
362
AddSSLSocketDataProvider(SSLSocketDataProvider * data)363 void MockClientSocketFactory::AddSSLSocketDataProvider(
364 SSLSocketDataProvider* data) {
365 mock_ssl_data_.Add(data);
366 }
367
ResetNextMockIndexes()368 void MockClientSocketFactory::ResetNextMockIndexes() {
369 mock_data_.ResetNextIndex();
370 mock_ssl_data_.ResetNextIndex();
371 }
372
GetMockTCPClientSocket(int index) const373 MockTCPClientSocket* MockClientSocketFactory::GetMockTCPClientSocket(
374 int index) const {
375 return tcp_client_sockets_[index];
376 }
377
GetMockSSLClientSocket(int index) const378 MockSSLClientSocket* MockClientSocketFactory::GetMockSSLClientSocket(
379 int index) const {
380 return ssl_client_sockets_[index];
381 }
382
CreateTCPClientSocket(const AddressList & addresses)383 ClientSocket* MockClientSocketFactory::CreateTCPClientSocket(
384 const AddressList& addresses) {
385 SocketDataProvider* data_provider = mock_data_.GetNext();
386 MockTCPClientSocket* socket =
387 new MockTCPClientSocket(addresses, data_provider);
388 data_provider->set_socket(socket);
389 tcp_client_sockets_.push_back(socket);
390 return socket;
391 }
392
CreateSSLClientSocket(ClientSocket * transport_socket,const std::string & hostname,const SSLConfig & ssl_config)393 SSLClientSocket* MockClientSocketFactory::CreateSSLClientSocket(
394 ClientSocket* transport_socket,
395 const std::string& hostname,
396 const SSLConfig& ssl_config) {
397 MockSSLClientSocket* socket =
398 new MockSSLClientSocket(transport_socket, hostname, ssl_config,
399 mock_ssl_data_.GetNext());
400 ssl_client_sockets_.push_back(socket);
401 return socket;
402 }
403
WaitForResult()404 int TestSocketRequest::WaitForResult() {
405 return callback_.WaitForResult();
406 }
407
RunWithParams(const Tuple1<int> & params)408 void TestSocketRequest::RunWithParams(const Tuple1<int>& params) {
409 callback_.RunWithParams(params);
410 (*completion_count_)++;
411 request_order_->push_back(this);
412 }
413
414 // static
415 const int ClientSocketPoolTest::kIndexOutOfBounds = -1;
416
417 // static
418 const int ClientSocketPoolTest::kRequestNotFound = -2;
419
SetUp()420 void ClientSocketPoolTest::SetUp() {
421 completion_count_ = 0;
422 }
423
TearDown()424 void ClientSocketPoolTest::TearDown() {
425 // The tests often call Reset() on handles at the end which may post
426 // DoReleaseSocket() tasks.
427 // Pending tasks created by client_socket_pool_base_unittest.cc are
428 // posted two milliseconds into the future and thus won't become
429 // scheduled until that time.
430 // We wait a few milliseconds to make sure that all such future tasks
431 // are ready to run, before calling RunAllPending(). This will work
432 // correctly even if Sleep() finishes late (and it should never finish
433 // early), as all we have to ensure is that actual wall-time has progressed
434 // past the scheduled starting time of the pending task.
435 PlatformThread::Sleep(10);
436 MessageLoop::current()->RunAllPending();
437 }
438
GetOrderOfRequest(size_t index)439 int ClientSocketPoolTest::GetOrderOfRequest(size_t index) {
440 index--;
441 if (index >= requests_.size())
442 return kIndexOutOfBounds;
443
444 for (size_t i = 0; i < request_order_.size(); i++)
445 if (requests_[index] == request_order_[i])
446 return i + 1;
447
448 return kRequestNotFound;
449 }
450
ReleaseOneConnection(KeepAlive keep_alive)451 bool ClientSocketPoolTest::ReleaseOneConnection(KeepAlive keep_alive) {
452 ScopedVector<TestSocketRequest>::iterator i;
453 for (i = requests_.begin(); i != requests_.end(); ++i) {
454 if ((*i)->handle()->is_initialized()) {
455 if (keep_alive == NO_KEEP_ALIVE)
456 (*i)->handle()->socket()->Disconnect();
457 (*i)->handle()->Reset();
458 MessageLoop::current()->RunAllPending();
459 return true;
460 }
461 }
462 return false;
463 }
464
ReleaseAllConnections(KeepAlive keep_alive)465 void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) {
466 bool released_one;
467 do {
468 released_one = ReleaseOneConnection(keep_alive);
469 } while (released_one);
470 }
471
472 } // namespace net
473