• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2011 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/socket_test_util.h"
6 
7 #include <algorithm>
8 #include <vector>
9 
10 
11 #include "base/basictypes.h"
12 #include "base/compiler_specific.h"
13 #include "base/message_loop.h"
14 #include "base/time.h"
15 #include "net/base/address_family.h"
16 #include "net/base/auth.h"
17 #include "net/base/host_resolver_proc.h"
18 #include "net/base/ssl_cert_request_info.h"
19 #include "net/base/ssl_info.h"
20 #include "net/http/http_network_session.h"
21 #include "net/http/http_request_headers.h"
22 #include "net/http/http_response_headers.h"
23 #include "net/socket/client_socket_pool_histograms.h"
24 #include "net/socket/socket.h"
25 #include "net/socket/ssl_host_info.h"
26 #include "testing/gtest/include/gtest/gtest.h"
27 
// Debug-only trace helper: streams |s|, then the enclosing function's name
// followed by "() ". Call sites append further text with operator<<.
#define NET_TRACE(level, s) \
    DLOG(level) << s << __FUNCTION__ << "() "
29 
30 namespace net {
31 
32 namespace {
33 
AsciifyHigh(char x)34 inline char AsciifyHigh(char x) {
35   char nybble = static_cast<char>((x >> 4) & 0x0F);
36   return nybble + ((nybble < 0x0A) ? '0' : 'A' - 10);
37 }
38 
AsciifyLow(char x)39 inline char AsciifyLow(char x) {
40   char nybble = static_cast<char>((x >> 0) & 0x0F);
41   return nybble + ((nybble < 0x0A) ? '0' : 'A' - 10);
42 }
43 
Asciify(char x)44 inline char Asciify(char x) {
45   if ((x < 0) || !isprint(x))
46     return '.';
47   return x;
48 }
49 
DumpData(const char * data,int data_len)50 void DumpData(const char* data, int data_len) {
51   if (logging::LOG_INFO < logging::GetMinLogLevel())
52     return;
53   DVLOG(1) << "Length:  " << data_len;
54   const char* pfx = "Data:    ";
55   if (!data || (data_len <= 0)) {
56     DVLOG(1) << pfx << "<None>";
57   } else {
58     int i;
59     for (i = 0; i <= (data_len - 4); i += 4) {
60       DVLOG(1) << pfx
61                << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
62                << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
63                << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
64                << AsciifyHigh(data[i + 3]) << AsciifyLow(data[i + 3])
65                << "  '"
66                << Asciify(data[i + 0])
67                << Asciify(data[i + 1])
68                << Asciify(data[i + 2])
69                << Asciify(data[i + 3])
70                << "'";
71       pfx = "         ";
72     }
73     // Take care of any 'trailing' bytes, if data_len was not a multiple of 4.
74     switch (data_len - i) {
75       case 3:
76         DVLOG(1) << pfx
77                  << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
78                  << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
79                  << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
80                  << "    '"
81                  << Asciify(data[i + 0])
82                  << Asciify(data[i + 1])
83                  << Asciify(data[i + 2])
84                  << " '";
85         break;
86       case 2:
87         DVLOG(1) << pfx
88                  << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
89                  << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
90                  << "      '"
91                  << Asciify(data[i + 0])
92                  << Asciify(data[i + 1])
93                  << "  '";
94         break;
95       case 1:
96         DVLOG(1) << pfx
97                  << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
98                  << "        '"
99                  << Asciify(data[i + 0])
100                  << "   '";
101         break;
102     }
103   }
104 }
105 
DumpMockRead(const MockRead & r)106 void DumpMockRead(const MockRead& r) {
107   if (logging::LOG_INFO < logging::GetMinLogLevel())
108     return;
109   DVLOG(1) << "Async:   " << r.async
110            << "\nResult:  " << r.result;
111   DumpData(r.data, r.data_len);
112   const char* stop = (r.sequence_number & MockRead::STOPLOOP) ? " (STOP)" : "";
113   DVLOG(1) << "Stage:   " << (r.sequence_number & ~MockRead::STOPLOOP) << stop
114            << "\nTime:    " << r.time_stamp.ToInternalValue();
115 }
116 
117 }  // namespace
118 
StaticSocketDataProvider()119 StaticSocketDataProvider::StaticSocketDataProvider()
120     : reads_(NULL),
121       read_index_(0),
122       read_count_(0),
123       writes_(NULL),
124       write_index_(0),
125       write_count_(0) {
126 }
127 
StaticSocketDataProvider(MockRead * reads,size_t reads_count,MockWrite * writes,size_t writes_count)128 StaticSocketDataProvider::StaticSocketDataProvider(MockRead* reads,
129                                                    size_t reads_count,
130                                                    MockWrite* writes,
131                                                    size_t writes_count)
132     : reads_(reads),
133       read_index_(0),
134       read_count_(reads_count),
135       writes_(writes),
136       write_index_(0),
137       write_count_(writes_count) {
138 }
139 
~StaticSocketDataProvider()140 StaticSocketDataProvider::~StaticSocketDataProvider() {}
141 
PeekRead() const142 const MockRead& StaticSocketDataProvider::PeekRead() const {
143   DCHECK(!at_read_eof());
144   return reads_[read_index_];
145 }
146 
PeekWrite() const147 const MockWrite& StaticSocketDataProvider::PeekWrite() const {
148   DCHECK(!at_write_eof());
149   return writes_[write_index_];
150 }
151 
PeekRead(size_t index) const152 const MockRead& StaticSocketDataProvider::PeekRead(size_t index) const {
153   DCHECK_LT(index, read_count_);
154   return reads_[index];
155 }
156 
PeekWrite(size_t index) const157 const MockWrite& StaticSocketDataProvider::PeekWrite(size_t index) const {
158   DCHECK_LT(index, write_count_);
159   return writes_[index];
160 }
161 
GetNextRead()162 MockRead StaticSocketDataProvider::GetNextRead() {
163   DCHECK(!at_read_eof());
164   reads_[read_index_].time_stamp = base::Time::Now();
165   return reads_[read_index_++];
166 }
167 
OnWrite(const std::string & data)168 MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) {
169   if (!writes_) {
170     // Not using mock writes; succeed synchronously.
171     return MockWriteResult(false, data.length());
172   }
173 
174   DCHECK(!at_write_eof());
175 
176   // Check that what we are writing matches the expectation.
177   // Then give the mocked return value.
178   net::MockWrite* w = &writes_[write_index_++];
179   w->time_stamp = base::Time::Now();
180   int result = w->result;
181   if (w->data) {
182     // Note - we can simulate a partial write here.  If the expected data
183     // is a match, but shorter than the write actually written, that is legal.
184     // Example:
185     //   Application writes "foobarbaz" (9 bytes)
186     //   Expected write was "foo" (3 bytes)
187     //   This is a success, and we return 3 to the application.
188     std::string expected_data(w->data, w->data_len);
189     EXPECT_GE(data.length(), expected_data.length());
190     std::string actual_data(data.substr(0, w->data_len));
191     EXPECT_EQ(expected_data, actual_data);
192     if (expected_data != actual_data)
193       return MockWriteResult(false, net::ERR_UNEXPECTED);
194     if (result == net::OK)
195       result = w->data_len;
196   }
197   return MockWriteResult(w->async, result);
198 }
199 
Reset()200 void StaticSocketDataProvider::Reset() {
201   read_index_ = 0;
202   write_index_ = 0;
203 }
204 
DynamicSocketDataProvider()205 DynamicSocketDataProvider::DynamicSocketDataProvider()
206     : short_read_limit_(0),
207       allow_unconsumed_reads_(false) {
208 }
209 
~DynamicSocketDataProvider()210 DynamicSocketDataProvider::~DynamicSocketDataProvider() {}
211 
GetNextRead()212 MockRead DynamicSocketDataProvider::GetNextRead() {
213   if (reads_.empty())
214     return MockRead(false, ERR_UNEXPECTED);
215   MockRead result = reads_.front();
216   if (short_read_limit_ == 0 || result.data_len <= short_read_limit_) {
217     reads_.pop_front();
218   } else {
219     result.data_len = short_read_limit_;
220     reads_.front().data += result.data_len;
221     reads_.front().data_len -= result.data_len;
222   }
223   return result;
224 }
225 
Reset()226 void DynamicSocketDataProvider::Reset() {
227   reads_.clear();
228 }
229 
SimulateRead(const char * data,const size_t length)230 void DynamicSocketDataProvider::SimulateRead(const char* data,
231                                              const size_t length) {
232   if (!allow_unconsumed_reads_) {
233     EXPECT_TRUE(reads_.empty()) << "Unconsumed read: " << reads_.front().data;
234   }
235   reads_.push_back(MockRead(true, data, length));
236 }
237 
SSLSocketDataProvider(bool async,int result)238 SSLSocketDataProvider::SSLSocketDataProvider(bool async, int result)
239     : connect(async, result),
240       next_proto_status(SSLClientSocket::kNextProtoUnsupported),
241       was_npn_negotiated(false),
242       cert_request_info(NULL) {
243 }
244 
~SSLSocketDataProvider()245 SSLSocketDataProvider::~SSLSocketDataProvider() {
246 }
247 
DelayedSocketData(int write_delay,MockRead * reads,size_t reads_count,MockWrite * writes,size_t writes_count)248 DelayedSocketData::DelayedSocketData(
249     int write_delay, MockRead* reads, size_t reads_count,
250     MockWrite* writes, size_t writes_count)
251     : StaticSocketDataProvider(reads, reads_count, writes, writes_count),
252       write_delay_(write_delay),
253       ALLOW_THIS_IN_INITIALIZER_LIST(factory_(this)) {
254   DCHECK_GE(write_delay_, 0);
255 }
256 
DelayedSocketData(const MockConnect & connect,int write_delay,MockRead * reads,size_t reads_count,MockWrite * writes,size_t writes_count)257 DelayedSocketData::DelayedSocketData(
258     const MockConnect& connect, int write_delay, MockRead* reads,
259     size_t reads_count, MockWrite* writes, size_t writes_count)
260     : StaticSocketDataProvider(reads, reads_count, writes, writes_count),
261       write_delay_(write_delay),
262       ALLOW_THIS_IN_INITIALIZER_LIST(factory_(this)) {
263   DCHECK_GE(write_delay_, 0);
264   set_connect_data(connect);
265 }
266 
~DelayedSocketData()267 DelayedSocketData::~DelayedSocketData() {
268 }
269 
ForceNextRead()270 void DelayedSocketData::ForceNextRead() {
271   write_delay_ = 0;
272   CompleteRead();
273 }
274 
GetNextRead()275 MockRead DelayedSocketData::GetNextRead() {
276   if (write_delay_ > 0)
277     return MockRead(true, ERR_IO_PENDING);
278   return StaticSocketDataProvider::GetNextRead();
279 }
280 
OnWrite(const std::string & data)281 MockWriteResult DelayedSocketData::OnWrite(const std::string& data) {
282   MockWriteResult rv = StaticSocketDataProvider::OnWrite(data);
283   // Now that our write has completed, we can allow reads to continue.
284   if (!--write_delay_)
285     MessageLoop::current()->PostDelayedTask(FROM_HERE,
286       factory_.NewRunnableMethod(&DelayedSocketData::CompleteRead), 100);
287   return rv;
288 }
289 
Reset()290 void DelayedSocketData::Reset() {
291   set_socket(NULL);
292   factory_.RevokeAll();
293   StaticSocketDataProvider::Reset();
294 }
295 
CompleteRead()296 void DelayedSocketData::CompleteRead() {
297   if (socket())
298     socket()->OnReadComplete(GetNextRead());
299 }
300 
OrderedSocketData(MockRead * reads,size_t reads_count,MockWrite * writes,size_t writes_count)301 OrderedSocketData::OrderedSocketData(
302     MockRead* reads, size_t reads_count, MockWrite* writes, size_t writes_count)
303     : StaticSocketDataProvider(reads, reads_count, writes, writes_count),
304       sequence_number_(0), loop_stop_stage_(0), callback_(NULL),
305       blocked_(false), ALLOW_THIS_IN_INITIALIZER_LIST(factory_(this)) {
306 }
307 
OrderedSocketData(const MockConnect & connect,MockRead * reads,size_t reads_count,MockWrite * writes,size_t writes_count)308 OrderedSocketData::OrderedSocketData(
309     const MockConnect& connect,
310     MockRead* reads, size_t reads_count,
311     MockWrite* writes, size_t writes_count)
312     : StaticSocketDataProvider(reads, reads_count, writes, writes_count),
313       sequence_number_(0), loop_stop_stage_(0), callback_(NULL),
314       blocked_(false), ALLOW_THIS_IN_INITIALIZER_LIST(factory_(this)) {
315   set_connect_data(connect);
316 }
317 
EndLoop()318 void OrderedSocketData::EndLoop() {
319   // If we've already stopped the loop, don't do it again until we've advanced
320   // to the next sequence_number.
321   NET_TRACE(INFO, "  *** ") << "Stage " << sequence_number_ << ": EndLoop()";
322   if (loop_stop_stage_ > 0) {
323     const MockRead& next_read = StaticSocketDataProvider::PeekRead();
324     if ((next_read.sequence_number & ~MockRead::STOPLOOP) >
325         loop_stop_stage_) {
326       NET_TRACE(INFO, "  *** ") << "Stage " << sequence_number_
327                                 << ": Clearing stop index";
328       loop_stop_stage_ = 0;
329     } else {
330       return;
331     }
332   }
333   // Record the sequence_number at which we stopped the loop.
334   NET_TRACE(INFO, "  *** ") << "Stage " << sequence_number_
335                             << ": Posting Quit at read " << read_index();
336   loop_stop_stage_ = sequence_number_;
337   if (callback_)
338     callback_->RunWithParams(Tuple1<int>(ERR_IO_PENDING));
339 }
340 
GetNextRead()341 MockRead OrderedSocketData::GetNextRead() {
342   factory_.RevokeAll();
343   blocked_ = false;
344   const MockRead& next_read = StaticSocketDataProvider::PeekRead();
345   if (next_read.sequence_number & MockRead::STOPLOOP)
346     EndLoop();
347   if ((next_read.sequence_number & ~MockRead::STOPLOOP) <=
348       sequence_number_++) {
349     NET_TRACE(INFO, "  *** ") << "Stage " << sequence_number_ - 1
350                               << ": Read " << read_index();
351     DumpMockRead(next_read);
352     return StaticSocketDataProvider::GetNextRead();
353   }
354   NET_TRACE(INFO, "  *** ") << "Stage " << sequence_number_ - 1
355                             << ": I/O Pending";
356   MockRead result = MockRead(true, ERR_IO_PENDING);
357   DumpMockRead(result);
358   blocked_ = true;
359   return result;
360 }
361 
OnWrite(const std::string & data)362 MockWriteResult OrderedSocketData::OnWrite(const std::string& data) {
363   NET_TRACE(INFO, "  *** ") << "Stage " << sequence_number_
364                             << ": Write " << write_index();
365   DumpMockRead(PeekWrite());
366   ++sequence_number_;
367   if (blocked_) {
368     // TODO(willchan): This 100ms delay seems to work around some weirdness.  We
369     // should probably fix the weirdness.  One example is in SpdyStream,
370     // DoSendRequest() will return ERR_IO_PENDING, and there's a race.  If the
371     // SYN_REPLY causes OnResponseReceived() to get called before
372     // SpdyStream::ReadResponseHeaders() is called, we hit a NOTREACHED().
373     MessageLoop::current()->PostDelayedTask(
374         FROM_HERE,
375         factory_.NewRunnableMethod(&OrderedSocketData::CompleteRead), 100);
376   }
377   return StaticSocketDataProvider::OnWrite(data);
378 }
379 
Reset()380 void OrderedSocketData::Reset() {
381   NET_TRACE(INFO, "  *** ") << "Stage "
382                             << sequence_number_ << ": Reset()";
383   sequence_number_ = 0;
384   loop_stop_stage_ = 0;
385   set_socket(NULL);
386   factory_.RevokeAll();
387   StaticSocketDataProvider::Reset();
388 }
389 
CompleteRead()390 void OrderedSocketData::CompleteRead() {
391   if (socket()) {
392     NET_TRACE(INFO, "  *** ") << "Stage " << sequence_number_;
393     socket()->OnReadComplete(GetNextRead());
394   }
395 }
396 
~OrderedSocketData()397 OrderedSocketData::~OrderedSocketData() {}
398 
DeterministicSocketData(MockRead * reads,size_t reads_count,MockWrite * writes,size_t writes_count)399 DeterministicSocketData::DeterministicSocketData(MockRead* reads,
400     size_t reads_count, MockWrite* writes, size_t writes_count)
401     : StaticSocketDataProvider(reads, reads_count, writes, writes_count),
402       sequence_number_(0),
403       current_read_(),
404       current_write_(),
405       stopping_sequence_number_(0),
406       stopped_(false),
407       print_debug_(false) {}
408 
~DeterministicSocketData()409 DeterministicSocketData::~DeterministicSocketData() {}
410 
Run()411 void DeterministicSocketData::Run() {
412   SetStopped(false);
413   int counter = 0;
414   // Continue to consume data until all data has run out, or the stopped_ flag
415   // has been set. Consuming data requires two separate operations -- running
416   // the tasks in the message loop, and explicitly invoking the read/write
417   // callbacks (simulating network I/O). We check our conditions between each,
418   // since they can change in either.
419   while ((!at_write_eof() || !at_read_eof()) && !stopped()) {
420     if (counter % 2 == 0)
421       MessageLoop::current()->RunAllPending();
422     if (counter % 2 == 1) {
423       InvokeCallbacks();
424     }
425     counter++;
426   }
427   // We're done consuming new data, but it is possible there are still some
428   // pending callbacks which we expect to complete before returning.
429   while (socket_ && (socket_->write_pending() || socket_->read_pending()) &&
430          !stopped()) {
431     InvokeCallbacks();
432     MessageLoop::current()->RunAllPending();
433   }
434   SetStopped(false);
435 }
436 
RunFor(int steps)437 void DeterministicSocketData::RunFor(int steps) {
438   StopAfter(steps);
439   Run();
440 }
441 
SetStop(int seq)442 void DeterministicSocketData::SetStop(int seq) {
443   DCHECK_LT(sequence_number_, seq);
444   stopping_sequence_number_ = seq;
445   stopped_ = false;
446 }
447 
StopAfter(int seq)448 void DeterministicSocketData::StopAfter(int seq) {
449   SetStop(sequence_number_ + seq);
450 }
451 
GetNextRead()452 MockRead DeterministicSocketData::GetNextRead() {
453   current_read_ = StaticSocketDataProvider::PeekRead();
454   EXPECT_LE(sequence_number_, current_read_.sequence_number);
455 
456   // Synchronous read while stopped is an error
457   if (stopped() && !current_read_.async) {
458     LOG(ERROR) << "Unable to perform synchronous IO while stopped";
459     return MockRead(false, ERR_UNEXPECTED);
460   }
461 
462   // Async read which will be called back in a future step.
463   if (sequence_number_ < current_read_.sequence_number) {
464     NET_TRACE(INFO, "  *** ") << "Stage " << sequence_number_
465                               << ": I/O Pending";
466     MockRead result = MockRead(false, ERR_IO_PENDING);
467     if (!current_read_.async) {
468       LOG(ERROR) << "Unable to perform synchronous read: "
469           << current_read_.sequence_number
470           << " at stage: " << sequence_number_;
471       result = MockRead(false, ERR_UNEXPECTED);
472     }
473     if (print_debug_)
474       DumpMockRead(result);
475     return result;
476   }
477 
478   NET_TRACE(INFO, "  *** ") << "Stage " << sequence_number_
479                             << ": Read " << read_index();
480   if (print_debug_)
481     DumpMockRead(current_read_);
482 
483   // Increment the sequence number if IO is complete
484   if (!current_read_.async)
485     NextStep();
486 
487   DCHECK_NE(ERR_IO_PENDING, current_read_.result);
488   StaticSocketDataProvider::GetNextRead();
489 
490   return current_read_;
491 }
492 
OnWrite(const std::string & data)493 MockWriteResult DeterministicSocketData::OnWrite(const std::string& data) {
494   const MockWrite& next_write = StaticSocketDataProvider::PeekWrite();
495   current_write_ = next_write;
496 
497   // Synchronous write while stopped is an error
498   if (stopped() && !next_write.async) {
499     LOG(ERROR) << "Unable to perform synchronous IO while stopped";
500     return MockWriteResult(false, ERR_UNEXPECTED);
501   }
502 
503   // Async write which will be called back in a future step.
504   if (sequence_number_ < next_write.sequence_number) {
505     NET_TRACE(INFO, "  *** ") << "Stage " << sequence_number_
506                               << ": I/O Pending";
507     if (!next_write.async) {
508       LOG(ERROR) << "Unable to perform synchronous write: "
509           << next_write.sequence_number << " at stage: " << sequence_number_;
510       return MockWriteResult(false, ERR_UNEXPECTED);
511     }
512   } else {
513     NET_TRACE(INFO, "  *** ") << "Stage " << sequence_number_
514                               << ": Write " << write_index();
515   }
516 
517   if (print_debug_)
518     DumpMockRead(next_write);
519 
520   // Move to the next step if I/O is synchronous, since the operation will
521   // complete when this method returns.
522   if (!next_write.async)
523     NextStep();
524 
525   // This is either a sync write for this step, or an async write.
526   return StaticSocketDataProvider::OnWrite(data);
527 }
528 
Reset()529 void DeterministicSocketData::Reset() {
530   NET_TRACE(INFO, "  *** ") << "Stage "
531                             << sequence_number_ << ": Reset()";
532   sequence_number_ = 0;
533   StaticSocketDataProvider::Reset();
534   NOTREACHED();
535 }
536 
InvokeCallbacks()537 void DeterministicSocketData::InvokeCallbacks() {
538   if (socket_ && socket_->write_pending() &&
539       (current_write().sequence_number == sequence_number())) {
540     socket_->CompleteWrite();
541     NextStep();
542     return;
543   }
544   if (socket_ && socket_->read_pending() &&
545       (current_read().sequence_number == sequence_number())) {
546     socket_->CompleteRead();
547     NextStep();
548     return;
549   }
550 }
551 
NextStep()552 void DeterministicSocketData::NextStep() {
553   // Invariant: Can never move *past* the stopping step.
554   DCHECK_LT(sequence_number_, stopping_sequence_number_);
555   sequence_number_++;
556   if (sequence_number_ == stopping_sequence_number_)
557     SetStopped(true);
558 }
559 
MockClientSocketFactory()560 MockClientSocketFactory::MockClientSocketFactory() {}
561 
~MockClientSocketFactory()562 MockClientSocketFactory::~MockClientSocketFactory() {}
563 
AddSocketDataProvider(SocketDataProvider * data)564 void MockClientSocketFactory::AddSocketDataProvider(
565     SocketDataProvider* data) {
566   mock_data_.Add(data);
567 }
568 
AddSSLSocketDataProvider(SSLSocketDataProvider * data)569 void MockClientSocketFactory::AddSSLSocketDataProvider(
570     SSLSocketDataProvider* data) {
571   mock_ssl_data_.Add(data);
572 }
573 
ResetNextMockIndexes()574 void MockClientSocketFactory::ResetNextMockIndexes() {
575   mock_data_.ResetNextIndex();
576   mock_ssl_data_.ResetNextIndex();
577 }
578 
GetMockTCPClientSocket(size_t index) const579 MockTCPClientSocket* MockClientSocketFactory::GetMockTCPClientSocket(
580     size_t index) const {
581   DCHECK_LT(index, tcp_client_sockets_.size());
582   return tcp_client_sockets_[index];
583 }
584 
GetMockSSLClientSocket(size_t index) const585 MockSSLClientSocket* MockClientSocketFactory::GetMockSSLClientSocket(
586     size_t index) const {
587   DCHECK_LT(index, ssl_client_sockets_.size());
588   return ssl_client_sockets_[index];
589 }
590 
CreateTransportClientSocket(const AddressList & addresses,net::NetLog * net_log,const NetLog::Source & source)591 ClientSocket* MockClientSocketFactory::CreateTransportClientSocket(
592     const AddressList& addresses,
593     net::NetLog* net_log,
594     const NetLog::Source& source) {
595   SocketDataProvider* data_provider = mock_data_.GetNext();
596   MockTCPClientSocket* socket =
597       new MockTCPClientSocket(addresses, net_log, data_provider);
598   data_provider->set_socket(socket);
599   tcp_client_sockets_.push_back(socket);
600   return socket;
601 }
602 
CreateSSLClientSocket(ClientSocketHandle * transport_socket,const HostPortPair & host_and_port,const SSLConfig & ssl_config,SSLHostInfo * ssl_host_info,CertVerifier * cert_verifier,DnsCertProvenanceChecker * dns_cert_checker)603 SSLClientSocket* MockClientSocketFactory::CreateSSLClientSocket(
604     ClientSocketHandle* transport_socket,
605     const HostPortPair& host_and_port,
606     const SSLConfig& ssl_config,
607     SSLHostInfo* ssl_host_info,
608     CertVerifier* cert_verifier,
609     DnsCertProvenanceChecker* dns_cert_checker) {
610   MockSSLClientSocket* socket =
611       new MockSSLClientSocket(transport_socket, host_and_port, ssl_config,
612                               ssl_host_info, mock_ssl_data_.GetNext());
613   ssl_client_sockets_.push_back(socket);
614   return socket;
615 }
616 
ClearSSLSessionCache()617 void MockClientSocketFactory::ClearSSLSessionCache() {
618 }
619 
MockClientSocket(net::NetLog * net_log)620 MockClientSocket::MockClientSocket(net::NetLog* net_log)
621     : ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)),
622       connected_(false),
623       net_log_(NetLog::Source(), net_log) {
624 }
625 
SetReceiveBufferSize(int32 size)626 bool MockClientSocket::SetReceiveBufferSize(int32 size) {
627   return true;
628 }
629 
SetSendBufferSize(int32 size)630 bool MockClientSocket::SetSendBufferSize(int32 size) {
631   return true;
632 }
633 
Disconnect()634 void MockClientSocket::Disconnect() {
635   connected_ = false;
636 }
637 
IsConnected() const638 bool MockClientSocket::IsConnected() const {
639   return connected_;
640 }
641 
IsConnectedAndIdle() const642 bool MockClientSocket::IsConnectedAndIdle() const {
643   return connected_;
644 }
645 
GetPeerAddress(AddressList * address) const646 int MockClientSocket::GetPeerAddress(AddressList* address) const {
647   return net::SystemHostResolverProc("192.0.2.33", ADDRESS_FAMILY_UNSPECIFIED,
648                                      0, address, NULL);
649 }
650 
GetLocalAddress(IPEndPoint * address) const651 int MockClientSocket::GetLocalAddress(IPEndPoint* address) const {
652   IPAddressNumber ip;
653   if (!ParseIPLiteralToNumber("192.0.2.33", &ip))
654     return ERR_FAILED;
655   *address = IPEndPoint(ip, 123);
656       return OK;
657 }
658 
NetLog() const659 const BoundNetLog& MockClientSocket::NetLog() const {
660   return net_log_;
661 }
662 
GetSSLInfo(net::SSLInfo * ssl_info)663 void MockClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) {
664   NOTREACHED();
665 }
666 
GetSSLCertRequestInfo(net::SSLCertRequestInfo * cert_request_info)667 void MockClientSocket::GetSSLCertRequestInfo(
668     net::SSLCertRequestInfo* cert_request_info) {
669 }
670 
671 SSLClientSocket::NextProtoStatus
GetNextProto(std::string * proto)672 MockClientSocket::GetNextProto(std::string* proto) {
673   proto->clear();
674   return SSLClientSocket::kNextProtoUnsupported;
675 }
676 
~MockClientSocket()677 MockClientSocket::~MockClientSocket() {}
678 
RunCallbackAsync(net::CompletionCallback * callback,int result)679 void MockClientSocket::RunCallbackAsync(net::CompletionCallback* callback,
680                                         int result) {
681   MessageLoop::current()->PostTask(FROM_HERE,
682       method_factory_.NewRunnableMethod(
683           &MockClientSocket::RunCallback, callback, result));
684 }
685 
RunCallback(net::CompletionCallback * callback,int result)686 void MockClientSocket::RunCallback(net::CompletionCallback* callback,
687                                    int result) {
688   if (callback)
689     callback->Run(result);
690 }
691 
MockTCPClientSocket(const net::AddressList & addresses,net::NetLog * net_log,net::SocketDataProvider * data)692 MockTCPClientSocket::MockTCPClientSocket(const net::AddressList& addresses,
693                                          net::NetLog* net_log,
694                                          net::SocketDataProvider* data)
695     : MockClientSocket(net_log),
696       addresses_(addresses),
697       data_(data),
698       read_offset_(0),
699       read_data_(false, net::ERR_UNEXPECTED),
700       need_read_data_(true),
701       peer_closed_connection_(false),
702       pending_buf_(NULL),
703       pending_buf_len_(0),
704       pending_callback_(NULL),
705       was_used_to_convey_data_(false) {
706   DCHECK(data_);
707   data_->Reset();
708 }
709 
Read(net::IOBuffer * buf,int buf_len,net::CompletionCallback * callback)710 int MockTCPClientSocket::Read(net::IOBuffer* buf, int buf_len,
711                               net::CompletionCallback* callback) {
712   if (!connected_)
713     return net::ERR_UNEXPECTED;
714 
715   // If the buffer is already in use, a read is already in progress!
716   DCHECK(pending_buf_ == NULL);
717 
718   // Store our async IO data.
719   pending_buf_ = buf;
720   pending_buf_len_ = buf_len;
721   pending_callback_ = callback;
722 
723   if (need_read_data_) {
724     read_data_ = data_->GetNextRead();
725     if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) {
726       // This MockRead is just a marker to instruct us to set
727       // peer_closed_connection_.  Skip it and get the next one.
728       read_data_ = data_->GetNextRead();
729       peer_closed_connection_ = true;
730     }
731     // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility
732     // to complete the async IO manually later (via OnReadComplete).
733     if (read_data_.result == ERR_IO_PENDING) {
734       DCHECK(callback);  // We need to be using async IO in this case.
735       return ERR_IO_PENDING;
736     }
737     need_read_data_ = false;
738   }
739 
740   return CompleteRead();
741 }
742 
Write(net::IOBuffer * buf,int buf_len,net::CompletionCallback * callback)743 int MockTCPClientSocket::Write(net::IOBuffer* buf, int buf_len,
744                                net::CompletionCallback* callback) {
745   DCHECK(buf);
746   DCHECK_GT(buf_len, 0);
747 
748   if (!connected_)
749     return net::ERR_UNEXPECTED;
750 
751   std::string data(buf->data(), buf_len);
752   net::MockWriteResult write_result = data_->OnWrite(data);
753 
754   was_used_to_convey_data_ = true;
755 
756   if (write_result.async) {
757     RunCallbackAsync(callback, write_result.result);
758     return net::ERR_IO_PENDING;
759   }
760 
761   return write_result.result;
762 }
763 
Connect(net::CompletionCallback * callback)764 int MockTCPClientSocket::Connect(net::CompletionCallback* callback) {
765   if (connected_)
766     return net::OK;
767   connected_ = true;
768   peer_closed_connection_ = false;
769   if (data_->connect_data().async) {
770     RunCallbackAsync(callback, data_->connect_data().result);
771     return net::ERR_IO_PENDING;
772   }
773   return data_->connect_data().result;
774 }
775 
Disconnect()776 void MockTCPClientSocket::Disconnect() {
777   MockClientSocket::Disconnect();
778   pending_callback_ = NULL;
779 }
780 
IsConnected() const781 bool MockTCPClientSocket::IsConnected() const {
782   return connected_ && !peer_closed_connection_;
783 }
784 
IsConnectedAndIdle() const785 bool MockTCPClientSocket::IsConnectedAndIdle() const {
786   return IsConnected();
787 }
788 
GetPeerAddress(AddressList * address) const789 int MockTCPClientSocket::GetPeerAddress(AddressList* address) const {
790   if (!IsConnected())
791     return ERR_SOCKET_NOT_CONNECTED;
792   return MockClientSocket::GetPeerAddress(address);
793 }
794 
WasEverUsed() const795 bool MockTCPClientSocket::WasEverUsed() const {
796   return was_used_to_convey_data_;
797 }
798 
UsingTCPFastOpen() const799 bool MockTCPClientSocket::UsingTCPFastOpen() const {
800   return false;
801 }
802 
OnReadComplete(const MockRead & data)803 void MockTCPClientSocket::OnReadComplete(const MockRead& data) {
804   // There must be a read pending.
805   DCHECK(pending_buf_);
806   // You can't complete a read with another ERR_IO_PENDING status code.
807   DCHECK_NE(ERR_IO_PENDING, data.result);
808   // Since we've been waiting for data, need_read_data_ should be true.
809   DCHECK(need_read_data_);
810 
811   read_data_ = data;
812   need_read_data_ = false;
813 
814   // The caller is simulating that this IO completes right now.  Don't
815   // let CompleteRead() schedule a callback.
816   read_data_.async = false;
817 
818   net::CompletionCallback* callback = pending_callback_;
819   int rv = CompleteRead();
820   RunCallback(callback, rv);
821 }
822 
CompleteRead()823 int MockTCPClientSocket::CompleteRead() {
824   DCHECK(pending_buf_);
825   DCHECK(pending_buf_len_ > 0);
826 
827   was_used_to_convey_data_ = true;
828 
829   // Save the pending async IO data and reset our |pending_| state.
830   net::IOBuffer* buf = pending_buf_;
831   int buf_len = pending_buf_len_;
832   net::CompletionCallback* callback = pending_callback_;
833   pending_buf_ = NULL;
834   pending_buf_len_ = 0;
835   pending_callback_ = NULL;
836 
837   int result = read_data_.result;
838   DCHECK(result != ERR_IO_PENDING);
839 
840   if (read_data_.data) {
841     if (read_data_.data_len - read_offset_ > 0) {
842       result = std::min(buf_len, read_data_.data_len - read_offset_);
843       memcpy(buf->data(), read_data_.data + read_offset_, result);
844       read_offset_ += result;
845       if (read_offset_ == read_data_.data_len) {
846         need_read_data_ = true;
847         read_offset_ = 0;
848       }
849     } else {
850       result = 0;  // EOF
851     }
852   }
853 
854   if (read_data_.async) {
855     DCHECK(callback);
856     RunCallbackAsync(callback, result);
857     return net::ERR_IO_PENDING;
858   }
859   return result;
860 }
861 
DeterministicMockTCPClientSocket(net::NetLog * net_log,net::DeterministicSocketData * data)862 DeterministicMockTCPClientSocket::DeterministicMockTCPClientSocket(
863     net::NetLog* net_log, net::DeterministicSocketData* data)
864     : MockClientSocket(net_log),
865       write_pending_(false),
866       write_callback_(NULL),
867       write_result_(0),
868       read_data_(),
869       read_buf_(NULL),
870       read_buf_len_(0),
871       read_pending_(false),
872       read_callback_(NULL),
873       data_(data),
874       was_used_to_convey_data_(false) {}
875 
~DeterministicMockTCPClientSocket()876 DeterministicMockTCPClientSocket::~DeterministicMockTCPClientSocket() {}
877 
CompleteWrite()878 void DeterministicMockTCPClientSocket::CompleteWrite() {
879   was_used_to_convey_data_ = true;
880   write_pending_ = false;
881   write_callback_->Run(write_result_);
882 }
883 
CompleteRead()884 int DeterministicMockTCPClientSocket::CompleteRead() {
885   DCHECK_GT(read_buf_len_, 0);
886   DCHECK_LE(read_data_.data_len, read_buf_len_);
887   DCHECK(read_buf_);
888 
889   was_used_to_convey_data_ = true;
890 
891   if (read_data_.result == ERR_IO_PENDING)
892     read_data_ = data_->GetNextRead();
893   DCHECK_NE(ERR_IO_PENDING, read_data_.result);
894   // If read_data_.async is true, we do not need to wait, since this is already
895   // the callback. Therefore we don't even bother to check it.
896   int result = read_data_.result;
897 
898   if (read_data_.data_len > 0) {
899     DCHECK(read_data_.data);
900     result = std::min(read_buf_len_, read_data_.data_len);
901     memcpy(read_buf_->data(), read_data_.data, result);
902   }
903 
904   if (read_pending_) {
905     read_pending_ = false;
906     read_callback_->Run(result);
907   }
908 
909   return result;
910 }
911 
Write(net::IOBuffer * buf,int buf_len,net::CompletionCallback * callback)912 int DeterministicMockTCPClientSocket::Write(
913     net::IOBuffer* buf, int buf_len, net::CompletionCallback* callback) {
914   DCHECK(buf);
915   DCHECK_GT(buf_len, 0);
916 
917   if (!connected_)
918     return net::ERR_UNEXPECTED;
919 
920   std::string data(buf->data(), buf_len);
921   net::MockWriteResult write_result = data_->OnWrite(data);
922 
923   if (write_result.async) {
924     write_callback_ = callback;
925     write_result_ = write_result.result;
926     DCHECK(write_callback_ != NULL);
927     write_pending_ = true;
928     return net::ERR_IO_PENDING;
929   }
930 
931   was_used_to_convey_data_ = true;
932   write_pending_ = false;
933   return write_result.result;
934 }
935 
Read(net::IOBuffer * buf,int buf_len,net::CompletionCallback * callback)936 int DeterministicMockTCPClientSocket::Read(
937     net::IOBuffer* buf, int buf_len, net::CompletionCallback* callback) {
938   if (!connected_)
939     return net::ERR_UNEXPECTED;
940 
941   read_data_ = data_->GetNextRead();
942   // The buffer should always be big enough to contain all the MockRead data. To
943   // use small buffers, split the data into multiple MockReads.
944   DCHECK_LE(read_data_.data_len, buf_len);
945 
946   read_buf_ = buf;
947   read_buf_len_ = buf_len;
948   read_callback_ = callback;
949 
950   if (read_data_.async || (read_data_.result == ERR_IO_PENDING)) {
951     read_pending_ = true;
952     DCHECK(read_callback_);
953     return ERR_IO_PENDING;
954   }
955 
956   was_used_to_convey_data_ = true;
957   return CompleteRead();
958 }
959 
960 // TODO(erikchen): Support connect sequencing.
Connect(net::CompletionCallback * callback)961 int DeterministicMockTCPClientSocket::Connect(
962     net::CompletionCallback* callback) {
963   if (connected_)
964     return net::OK;
965   connected_ = true;
966   if (data_->connect_data().async) {
967     RunCallbackAsync(callback, data_->connect_data().result);
968     return net::ERR_IO_PENDING;
969   }
970   return data_->connect_data().result;
971 }
972 
Disconnect()973 void DeterministicMockTCPClientSocket::Disconnect() {
974   MockClientSocket::Disconnect();
975 }
976 
IsConnected() const977 bool DeterministicMockTCPClientSocket::IsConnected() const {
978   return connected_;
979 }
980 
IsConnectedAndIdle() const981 bool DeterministicMockTCPClientSocket::IsConnectedAndIdle() const {
982   return IsConnected();
983 }
984 
WasEverUsed() const985 bool DeterministicMockTCPClientSocket::WasEverUsed() const {
986   return was_used_to_convey_data_;
987 }
988 
UsingTCPFastOpen() const989 bool DeterministicMockTCPClientSocket::UsingTCPFastOpen() const {
990   return false;
991 }
992 
OnReadComplete(const MockRead & data)993 void DeterministicMockTCPClientSocket::OnReadComplete(const MockRead& data) {}
994 
995 class MockSSLClientSocket::ConnectCallback
996     : public net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback> {
997  public:
ConnectCallback(MockSSLClientSocket * ssl_client_socket,net::CompletionCallback * user_callback,int rv)998   ConnectCallback(MockSSLClientSocket *ssl_client_socket,
999                   net::CompletionCallback* user_callback,
1000                   int rv)
1001       : ALLOW_THIS_IN_INITIALIZER_LIST(
1002           net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback>(
1003                 this, &ConnectCallback::Wrapper)),
1004         ssl_client_socket_(ssl_client_socket),
1005         user_callback_(user_callback),
1006         rv_(rv) {
1007   }
1008 
1009  private:
Wrapper(int rv)1010   void Wrapper(int rv) {
1011     if (rv_ == net::OK)
1012       ssl_client_socket_->connected_ = true;
1013     user_callback_->Run(rv_);
1014     delete this;
1015   }
1016 
1017   MockSSLClientSocket* ssl_client_socket_;
1018   net::CompletionCallback* user_callback_;
1019   int rv_;
1020 };
1021 
MockSSLClientSocket(net::ClientSocketHandle * transport_socket,const HostPortPair & host_port_pair,const net::SSLConfig & ssl_config,SSLHostInfo * ssl_host_info,net::SSLSocketDataProvider * data)1022 MockSSLClientSocket::MockSSLClientSocket(
1023     net::ClientSocketHandle* transport_socket,
1024     const HostPortPair& host_port_pair,
1025     const net::SSLConfig& ssl_config,
1026     SSLHostInfo* ssl_host_info,
1027     net::SSLSocketDataProvider* data)
1028     : MockClientSocket(transport_socket->socket()->NetLog().net_log()),
1029       transport_(transport_socket),
1030       data_(data),
1031       is_npn_state_set_(false),
1032       new_npn_value_(false) {
1033   DCHECK(data_);
1034   delete ssl_host_info;  // we take ownership but don't use it.
1035 }
1036 
~MockSSLClientSocket()1037 MockSSLClientSocket::~MockSSLClientSocket() {
1038   Disconnect();
1039 }
1040 
Read(net::IOBuffer * buf,int buf_len,net::CompletionCallback * callback)1041 int MockSSLClientSocket::Read(net::IOBuffer* buf, int buf_len,
1042                               net::CompletionCallback* callback) {
1043   return transport_->socket()->Read(buf, buf_len, callback);
1044 }
1045 
Write(net::IOBuffer * buf,int buf_len,net::CompletionCallback * callback)1046 int MockSSLClientSocket::Write(net::IOBuffer* buf, int buf_len,
1047                                net::CompletionCallback* callback) {
1048   return transport_->socket()->Write(buf, buf_len, callback);
1049 }
1050 
Connect(net::CompletionCallback * callback)1051 int MockSSLClientSocket::Connect(net::CompletionCallback* callback) {
1052   ConnectCallback* connect_callback = new ConnectCallback(
1053       this, callback, data_->connect.result);
1054   int rv = transport_->socket()->Connect(connect_callback);
1055   if (rv == net::OK) {
1056     delete connect_callback;
1057     if (data_->connect.result == net::OK)
1058       connected_ = true;
1059     if (data_->connect.async) {
1060       RunCallbackAsync(callback, data_->connect.result);
1061       return net::ERR_IO_PENDING;
1062     }
1063     return data_->connect.result;
1064   }
1065   return rv;
1066 }
1067 
Disconnect()1068 void MockSSLClientSocket::Disconnect() {
1069   MockClientSocket::Disconnect();
1070   if (transport_->socket() != NULL)
1071     transport_->socket()->Disconnect();
1072 }
1073 
IsConnected() const1074 bool MockSSLClientSocket::IsConnected() const {
1075   return transport_->socket()->IsConnected();
1076 }
1077 
WasEverUsed() const1078 bool MockSSLClientSocket::WasEverUsed() const {
1079   return transport_->socket()->WasEverUsed();
1080 }
1081 
UsingTCPFastOpen() const1082 bool MockSSLClientSocket::UsingTCPFastOpen() const {
1083   return transport_->socket()->UsingTCPFastOpen();
1084 }
1085 
GetSSLInfo(net::SSLInfo * ssl_info)1086 void MockSSLClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) {
1087   ssl_info->Reset();
1088   ssl_info->cert = data_->cert_;
1089 }
1090 
GetSSLCertRequestInfo(net::SSLCertRequestInfo * cert_request_info)1091 void MockSSLClientSocket::GetSSLCertRequestInfo(
1092     net::SSLCertRequestInfo* cert_request_info) {
1093   DCHECK(cert_request_info);
1094   if (data_->cert_request_info) {
1095     cert_request_info->host_and_port =
1096         data_->cert_request_info->host_and_port;
1097     cert_request_info->client_certs = data_->cert_request_info->client_certs;
1098   } else {
1099     cert_request_info->Reset();
1100   }
1101 }
1102 
GetNextProto(std::string * proto)1103 SSLClientSocket::NextProtoStatus MockSSLClientSocket::GetNextProto(
1104     std::string* proto) {
1105   *proto = data_->next_proto;
1106   return data_->next_proto_status;
1107 }
1108 
was_npn_negotiated() const1109 bool MockSSLClientSocket::was_npn_negotiated() const {
1110   if (is_npn_state_set_)
1111     return new_npn_value_;
1112   return data_->was_npn_negotiated;
1113 }
1114 
set_was_npn_negotiated(bool negotiated)1115 bool MockSSLClientSocket::set_was_npn_negotiated(bool negotiated) {
1116   is_npn_state_set_ = true;
1117   return new_npn_value_ = negotiated;
1118 }
1119 
OnReadComplete(const MockRead & data)1120 void MockSSLClientSocket::OnReadComplete(const MockRead& data) {
1121   NOTIMPLEMENTED();
1122 }
1123 
TestSocketRequest(std::vector<TestSocketRequest * > * request_order,size_t * completion_count)1124 TestSocketRequest::TestSocketRequest(
1125     std::vector<TestSocketRequest*>* request_order,
1126     size_t* completion_count)
1127     : request_order_(request_order),
1128       completion_count_(completion_count) {
1129   DCHECK(request_order);
1130   DCHECK(completion_count);
1131 }
1132 
~TestSocketRequest()1133 TestSocketRequest::~TestSocketRequest() {
1134 }
1135 
WaitForResult()1136 int TestSocketRequest::WaitForResult() {
1137   return callback_.WaitForResult();
1138 }
1139 
RunWithParams(const Tuple1<int> & params)1140 void TestSocketRequest::RunWithParams(const Tuple1<int>& params) {
1141   callback_.RunWithParams(params);
1142   (*completion_count_)++;
1143   request_order_->push_back(this);
1144 }
1145 
1146 // static
1147 const int ClientSocketPoolTest::kIndexOutOfBounds = -1;
1148 
1149 // static
1150 const int ClientSocketPoolTest::kRequestNotFound = -2;
1151 
ClientSocketPoolTest()1152 ClientSocketPoolTest::ClientSocketPoolTest() : completion_count_(0) {}
~ClientSocketPoolTest()1153 ClientSocketPoolTest::~ClientSocketPoolTest() {}
1154 
GetOrderOfRequest(size_t index) const1155 int ClientSocketPoolTest::GetOrderOfRequest(size_t index) const {
1156   index--;
1157   if (index >= requests_.size())
1158     return kIndexOutOfBounds;
1159 
1160   for (size_t i = 0; i < request_order_.size(); i++)
1161     if (requests_[index] == request_order_[i])
1162       return i + 1;
1163 
1164   return kRequestNotFound;
1165 }
1166 
ReleaseOneConnection(KeepAlive keep_alive)1167 bool ClientSocketPoolTest::ReleaseOneConnection(KeepAlive keep_alive) {
1168   ScopedVector<TestSocketRequest>::iterator i;
1169   for (i = requests_.begin(); i != requests_.end(); ++i) {
1170     if ((*i)->handle()->is_initialized()) {
1171       if (keep_alive == NO_KEEP_ALIVE)
1172         (*i)->handle()->socket()->Disconnect();
1173       (*i)->handle()->Reset();
1174       MessageLoop::current()->RunAllPending();
1175       return true;
1176     }
1177   }
1178   return false;
1179 }
1180 
ReleaseAllConnections(KeepAlive keep_alive)1181 void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) {
1182   bool released_one;
1183   do {
1184     released_one = ReleaseOneConnection(keep_alive);
1185   } while (released_one);
1186 }
1187 
MockConnectJob(ClientSocket * socket,ClientSocketHandle * handle,CompletionCallback * callback)1188 MockTransportClientSocketPool::MockConnectJob::MockConnectJob(
1189     ClientSocket* socket,
1190     ClientSocketHandle* handle,
1191     CompletionCallback* callback)
1192     : socket_(socket),
1193       handle_(handle),
1194       user_callback_(callback),
1195       ALLOW_THIS_IN_INITIALIZER_LIST(
1196           connect_callback_(this, &MockConnectJob::OnConnect)) {
1197 }
1198 
~MockConnectJob()1199 MockTransportClientSocketPool::MockConnectJob::~MockConnectJob() {}
1200 
Connect()1201 int MockTransportClientSocketPool::MockConnectJob::Connect() {
1202   int rv = socket_->Connect(&connect_callback_);
1203   if (rv == OK) {
1204     user_callback_ = NULL;
1205     OnConnect(OK);
1206   }
1207   return rv;
1208 }
1209 
CancelHandle(const ClientSocketHandle * handle)1210 bool MockTransportClientSocketPool::MockConnectJob::CancelHandle(
1211     const ClientSocketHandle* handle) {
1212   if (handle != handle_)
1213     return false;
1214   socket_.reset();
1215   handle_ = NULL;
1216   user_callback_ = NULL;
1217   return true;
1218 }
1219 
OnConnect(int rv)1220 void MockTransportClientSocketPool::MockConnectJob::OnConnect(int rv) {
1221   if (!socket_.get())
1222     return;
1223   if (rv == OK) {
1224     handle_->set_socket(socket_.release());
1225   } else {
1226     socket_.reset();
1227   }
1228 
1229   handle_ = NULL;
1230 
1231   if (user_callback_) {
1232     CompletionCallback* callback = user_callback_;
1233     user_callback_ = NULL;
1234     callback->Run(rv);
1235   }
1236 }
1237 
MockTransportClientSocketPool(int max_sockets,int max_sockets_per_group,ClientSocketPoolHistograms * histograms,ClientSocketFactory * socket_factory)1238 MockTransportClientSocketPool::MockTransportClientSocketPool(
1239     int max_sockets,
1240     int max_sockets_per_group,
1241     ClientSocketPoolHistograms* histograms,
1242     ClientSocketFactory* socket_factory)
1243     : TransportClientSocketPool(max_sockets, max_sockets_per_group, histograms,
1244                                 NULL, NULL, NULL),
1245       client_socket_factory_(socket_factory),
1246       release_count_(0),
1247       cancel_count_(0) {
1248 }
1249 
~MockTransportClientSocketPool()1250 MockTransportClientSocketPool::~MockTransportClientSocketPool() {}
1251 
RequestSocket(const std::string & group_name,const void * socket_params,RequestPriority priority,ClientSocketHandle * handle,CompletionCallback * callback,const BoundNetLog & net_log)1252 int MockTransportClientSocketPool::RequestSocket(const std::string& group_name,
1253                                            const void* socket_params,
1254                                            RequestPriority priority,
1255                                            ClientSocketHandle* handle,
1256                                            CompletionCallback* callback,
1257                                            const BoundNetLog& net_log) {
1258   ClientSocket* socket = client_socket_factory_->CreateTransportClientSocket(
1259       AddressList(), net_log.net_log(), net::NetLog::Source());
1260   MockConnectJob* job = new MockConnectJob(socket, handle, callback);
1261   job_list_.push_back(job);
1262   handle->set_pool_id(1);
1263   return job->Connect();
1264 }
1265 
CancelRequest(const std::string & group_name,ClientSocketHandle * handle)1266 void MockTransportClientSocketPool::CancelRequest(const std::string& group_name,
1267                                             ClientSocketHandle* handle) {
1268   std::vector<MockConnectJob*>::iterator i;
1269   for (i = job_list_.begin(); i != job_list_.end(); ++i) {
1270     if ((*i)->CancelHandle(handle)) {
1271       cancel_count_++;
1272       break;
1273     }
1274   }
1275 }
1276 
ReleaseSocket(const std::string & group_name,ClientSocket * socket,int id)1277 void MockTransportClientSocketPool::ReleaseSocket(const std::string& group_name,
1278                                             ClientSocket* socket, int id) {
1279   EXPECT_EQ(1, id);
1280   release_count_++;
1281   delete socket;
1282 }
1283 
DeterministicMockClientSocketFactory()1284 DeterministicMockClientSocketFactory::DeterministicMockClientSocketFactory() {}
1285 
~DeterministicMockClientSocketFactory()1286 DeterministicMockClientSocketFactory::~DeterministicMockClientSocketFactory() {}
1287 
AddSocketDataProvider(DeterministicSocketData * data)1288 void DeterministicMockClientSocketFactory::AddSocketDataProvider(
1289     DeterministicSocketData* data) {
1290   mock_data_.Add(data);
1291 }
1292 
AddSSLSocketDataProvider(SSLSocketDataProvider * data)1293 void DeterministicMockClientSocketFactory::AddSSLSocketDataProvider(
1294     SSLSocketDataProvider* data) {
1295   mock_ssl_data_.Add(data);
1296 }
1297 
ResetNextMockIndexes()1298 void DeterministicMockClientSocketFactory::ResetNextMockIndexes() {
1299   mock_data_.ResetNextIndex();
1300   mock_ssl_data_.ResetNextIndex();
1301 }
1302 
1303 MockSSLClientSocket* DeterministicMockClientSocketFactory::
GetMockSSLClientSocket(size_t index) const1304     GetMockSSLClientSocket(size_t index) const {
1305   DCHECK_LT(index, ssl_client_sockets_.size());
1306   return ssl_client_sockets_[index];
1307 }
1308 
CreateTransportClientSocket(const AddressList & addresses,net::NetLog * net_log,const net::NetLog::Source & source)1309 ClientSocket* DeterministicMockClientSocketFactory::CreateTransportClientSocket(
1310     const AddressList& addresses,
1311     net::NetLog* net_log,
1312     const net::NetLog::Source& source) {
1313   DeterministicSocketData* data_provider = mock_data().GetNext();
1314   DeterministicMockTCPClientSocket* socket =
1315       new DeterministicMockTCPClientSocket(net_log, data_provider);
1316   data_provider->set_socket(socket->AsWeakPtr());
1317   tcp_client_sockets().push_back(socket);
1318   return socket;
1319 }
1320 
CreateSSLClientSocket(ClientSocketHandle * transport_socket,const HostPortPair & host_and_port,const SSLConfig & ssl_config,SSLHostInfo * ssl_host_info,CertVerifier * cert_verifier,DnsCertProvenanceChecker * dns_cert_checker)1321 SSLClientSocket* DeterministicMockClientSocketFactory::CreateSSLClientSocket(
1322     ClientSocketHandle* transport_socket,
1323     const HostPortPair& host_and_port,
1324     const SSLConfig& ssl_config,
1325     SSLHostInfo* ssl_host_info,
1326     CertVerifier* cert_verifier,
1327     DnsCertProvenanceChecker* dns_cert_checker) {
1328   MockSSLClientSocket* socket =
1329       new MockSSLClientSocket(transport_socket, host_and_port, ssl_config,
1330                               ssl_host_info, mock_ssl_data_.GetNext());
1331   ssl_client_sockets_.push_back(socket);
1332   return socket;
1333 }
1334 
ClearSSLSessionCache()1335 void DeterministicMockClientSocketFactory::ClearSSLSessionCache() {
1336 }
1337 
MockSOCKSClientSocketPool(int max_sockets,int max_sockets_per_group,ClientSocketPoolHistograms * histograms,TransportClientSocketPool * transport_pool)1338 MockSOCKSClientSocketPool::MockSOCKSClientSocketPool(
1339     int max_sockets,
1340     int max_sockets_per_group,
1341     ClientSocketPoolHistograms* histograms,
1342     TransportClientSocketPool* transport_pool)
1343     : SOCKSClientSocketPool(max_sockets, max_sockets_per_group, histograms,
1344                             NULL, transport_pool, NULL),
1345       transport_pool_(transport_pool) {
1346 }
1347 
~MockSOCKSClientSocketPool()1348 MockSOCKSClientSocketPool::~MockSOCKSClientSocketPool() {}
1349 
RequestSocket(const std::string & group_name,const void * socket_params,RequestPriority priority,ClientSocketHandle * handle,CompletionCallback * callback,const BoundNetLog & net_log)1350 int MockSOCKSClientSocketPool::RequestSocket(const std::string& group_name,
1351                                              const void* socket_params,
1352                                              RequestPriority priority,
1353                                              ClientSocketHandle* handle,
1354                                              CompletionCallback* callback,
1355                                              const BoundNetLog& net_log) {
1356   return transport_pool_->RequestSocket(group_name,
1357                                         socket_params,
1358                                         priority,
1359                                         handle,
1360                                         callback,
1361                                         net_log);
1362 }
1363 
CancelRequest(const std::string & group_name,ClientSocketHandle * handle)1364 void MockSOCKSClientSocketPool::CancelRequest(
1365     const std::string& group_name,
1366     ClientSocketHandle* handle) {
1367   return transport_pool_->CancelRequest(group_name, handle);
1368 }
1369 
ReleaseSocket(const std::string & group_name,ClientSocket * socket,int id)1370 void MockSOCKSClientSocketPool::ReleaseSocket(const std::string& group_name,
1371                                               ClientSocket* socket, int id) {
1372   return transport_pool_->ReleaseSocket(group_name, socket, id);
1373 }
1374 
1375 const char kSOCKS5GreetRequest[] = { 0x05, 0x01, 0x00 };
1376 const int kSOCKS5GreetRequestLength = arraysize(kSOCKS5GreetRequest);
1377 
1378 const char kSOCKS5GreetResponse[] = { 0x05, 0x00 };
1379 const int kSOCKS5GreetResponseLength = arraysize(kSOCKS5GreetResponse);
1380 
1381 const char kSOCKS5OkRequest[] =
1382     { 0x05, 0x01, 0x00, 0x03, 0x04, 'h', 'o', 's', 't', 0x00, 0x50 };
1383 const int kSOCKS5OkRequestLength = arraysize(kSOCKS5OkRequest);
1384 
1385 const char kSOCKS5OkResponse[] =
1386     { 0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0x00, 0x50 };
1387 const int kSOCKS5OkResponseLength = arraysize(kSOCKS5OkResponse);
1388 
1389 }  // namespace net
1390