1 // Copyright (c) 2011 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/socket_test_util.h"
6
7 #include <algorithm>
8 #include <vector>
9
10
11 #include "base/basictypes.h"
12 #include "base/compiler_specific.h"
13 #include "base/message_loop.h"
14 #include "base/time.h"
15 #include "net/base/address_family.h"
16 #include "net/base/auth.h"
17 #include "net/base/host_resolver_proc.h"
18 #include "net/base/ssl_cert_request_info.h"
19 #include "net/base/ssl_info.h"
20 #include "net/http/http_network_session.h"
21 #include "net/http/http_request_headers.h"
22 #include "net/http/http_response_headers.h"
23 #include "net/socket/client_socket_pool_histograms.h"
24 #include "net/socket/socket.h"
25 #include "net/socket/ssl_host_info.h"
26 #include "testing/gtest/include/gtest/gtest.h"
27
// Logs |s| through DLOG(level), followed by the name of the enclosing function
// and "() ". Callers usually stream more text after it, e.g.
//   NET_TRACE(INFO, " *** ") << "Stage " << n;
#define NET_TRACE(level, s) DLOG(level) << s << __FUNCTION__ << "() "
29
30 namespace net {
31
32 namespace {
33
// Maps the low four bits of |value| to an uppercase hexadecimal digit.
inline char NybbleToHexChar(int value) {
  return "0123456789ABCDEF"[value & 0x0F];
}

// Hex digit for the high nybble of |x|, e.g. 0x3A -> '3'.
inline char AsciifyHigh(char x) {
  return NybbleToHexChar(static_cast<unsigned char>(x) >> 4);
}

// Hex digit for the low nybble of |x|, e.g. 0x3A -> 'A'.
inline char AsciifyLow(char x) {
  return NybbleToHexChar(static_cast<unsigned char>(x));
}

// Returns |x| unchanged when it is a non-negative printable character and '.'
// otherwise. The sign test comes first so isprint() never sees a negative
// value.
inline char Asciify(char x) {
  const bool printable = (x >= 0) && isprint(x);
  return printable ? x : '.';
}
49
DumpData(const char * data,int data_len)50 void DumpData(const char* data, int data_len) {
51 if (logging::LOG_INFO < logging::GetMinLogLevel())
52 return;
53 DVLOG(1) << "Length: " << data_len;
54 const char* pfx = "Data: ";
55 if (!data || (data_len <= 0)) {
56 DVLOG(1) << pfx << "<None>";
57 } else {
58 int i;
59 for (i = 0; i <= (data_len - 4); i += 4) {
60 DVLOG(1) << pfx
61 << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
62 << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
63 << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
64 << AsciifyHigh(data[i + 3]) << AsciifyLow(data[i + 3])
65 << " '"
66 << Asciify(data[i + 0])
67 << Asciify(data[i + 1])
68 << Asciify(data[i + 2])
69 << Asciify(data[i + 3])
70 << "'";
71 pfx = " ";
72 }
73 // Take care of any 'trailing' bytes, if data_len was not a multiple of 4.
74 switch (data_len - i) {
75 case 3:
76 DVLOG(1) << pfx
77 << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
78 << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
79 << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
80 << " '"
81 << Asciify(data[i + 0])
82 << Asciify(data[i + 1])
83 << Asciify(data[i + 2])
84 << " '";
85 break;
86 case 2:
87 DVLOG(1) << pfx
88 << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
89 << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
90 << " '"
91 << Asciify(data[i + 0])
92 << Asciify(data[i + 1])
93 << " '";
94 break;
95 case 1:
96 DVLOG(1) << pfx
97 << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
98 << " '"
99 << Asciify(data[i + 0])
100 << " '";
101 break;
102 }
103 }
104 }
105
DumpMockRead(const MockRead & r)106 void DumpMockRead(const MockRead& r) {
107 if (logging::LOG_INFO < logging::GetMinLogLevel())
108 return;
109 DVLOG(1) << "Async: " << r.async
110 << "\nResult: " << r.result;
111 DumpData(r.data, r.data_len);
112 const char* stop = (r.sequence_number & MockRead::STOPLOOP) ? " (STOP)" : "";
113 DVLOG(1) << "Stage: " << (r.sequence_number & ~MockRead::STOPLOOP) << stop
114 << "\nTime: " << r.time_stamp.ToInternalValue();
115 }
116
117 } // namespace
118
StaticSocketDataProvider()119 StaticSocketDataProvider::StaticSocketDataProvider()
120 : reads_(NULL),
121 read_index_(0),
122 read_count_(0),
123 writes_(NULL),
124 write_index_(0),
125 write_count_(0) {
126 }
127
StaticSocketDataProvider(MockRead * reads,size_t reads_count,MockWrite * writes,size_t writes_count)128 StaticSocketDataProvider::StaticSocketDataProvider(MockRead* reads,
129 size_t reads_count,
130 MockWrite* writes,
131 size_t writes_count)
132 : reads_(reads),
133 read_index_(0),
134 read_count_(reads_count),
135 writes_(writes),
136 write_index_(0),
137 write_count_(writes_count) {
138 }
139
~StaticSocketDataProvider()140 StaticSocketDataProvider::~StaticSocketDataProvider() {}
141
PeekRead() const142 const MockRead& StaticSocketDataProvider::PeekRead() const {
143 DCHECK(!at_read_eof());
144 return reads_[read_index_];
145 }
146
PeekWrite() const147 const MockWrite& StaticSocketDataProvider::PeekWrite() const {
148 DCHECK(!at_write_eof());
149 return writes_[write_index_];
150 }
151
PeekRead(size_t index) const152 const MockRead& StaticSocketDataProvider::PeekRead(size_t index) const {
153 DCHECK_LT(index, read_count_);
154 return reads_[index];
155 }
156
PeekWrite(size_t index) const157 const MockWrite& StaticSocketDataProvider::PeekWrite(size_t index) const {
158 DCHECK_LT(index, write_count_);
159 return writes_[index];
160 }
161
GetNextRead()162 MockRead StaticSocketDataProvider::GetNextRead() {
163 DCHECK(!at_read_eof());
164 reads_[read_index_].time_stamp = base::Time::Now();
165 return reads_[read_index_++];
166 }
167
OnWrite(const std::string & data)168 MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) {
169 if (!writes_) {
170 // Not using mock writes; succeed synchronously.
171 return MockWriteResult(false, data.length());
172 }
173
174 DCHECK(!at_write_eof());
175
176 // Check that what we are writing matches the expectation.
177 // Then give the mocked return value.
178 net::MockWrite* w = &writes_[write_index_++];
179 w->time_stamp = base::Time::Now();
180 int result = w->result;
181 if (w->data) {
182 // Note - we can simulate a partial write here. If the expected data
183 // is a match, but shorter than the write actually written, that is legal.
184 // Example:
185 // Application writes "foobarbaz" (9 bytes)
186 // Expected write was "foo" (3 bytes)
187 // This is a success, and we return 3 to the application.
188 std::string expected_data(w->data, w->data_len);
189 EXPECT_GE(data.length(), expected_data.length());
190 std::string actual_data(data.substr(0, w->data_len));
191 EXPECT_EQ(expected_data, actual_data);
192 if (expected_data != actual_data)
193 return MockWriteResult(false, net::ERR_UNEXPECTED);
194 if (result == net::OK)
195 result = w->data_len;
196 }
197 return MockWriteResult(w->async, result);
198 }
199
Reset()200 void StaticSocketDataProvider::Reset() {
201 read_index_ = 0;
202 write_index_ = 0;
203 }
204
DynamicSocketDataProvider()205 DynamicSocketDataProvider::DynamicSocketDataProvider()
206 : short_read_limit_(0),
207 allow_unconsumed_reads_(false) {
208 }
209
~DynamicSocketDataProvider()210 DynamicSocketDataProvider::~DynamicSocketDataProvider() {}
211
GetNextRead()212 MockRead DynamicSocketDataProvider::GetNextRead() {
213 if (reads_.empty())
214 return MockRead(false, ERR_UNEXPECTED);
215 MockRead result = reads_.front();
216 if (short_read_limit_ == 0 || result.data_len <= short_read_limit_) {
217 reads_.pop_front();
218 } else {
219 result.data_len = short_read_limit_;
220 reads_.front().data += result.data_len;
221 reads_.front().data_len -= result.data_len;
222 }
223 return result;
224 }
225
Reset()226 void DynamicSocketDataProvider::Reset() {
227 reads_.clear();
228 }
229
SimulateRead(const char * data,const size_t length)230 void DynamicSocketDataProvider::SimulateRead(const char* data,
231 const size_t length) {
232 if (!allow_unconsumed_reads_) {
233 EXPECT_TRUE(reads_.empty()) << "Unconsumed read: " << reads_.front().data;
234 }
235 reads_.push_back(MockRead(true, data, length));
236 }
237
SSLSocketDataProvider(bool async,int result)238 SSLSocketDataProvider::SSLSocketDataProvider(bool async, int result)
239 : connect(async, result),
240 next_proto_status(SSLClientSocket::kNextProtoUnsupported),
241 was_npn_negotiated(false),
242 cert_request_info(NULL) {
243 }
244
~SSLSocketDataProvider()245 SSLSocketDataProvider::~SSLSocketDataProvider() {
246 }
247
DelayedSocketData(int write_delay,MockRead * reads,size_t reads_count,MockWrite * writes,size_t writes_count)248 DelayedSocketData::DelayedSocketData(
249 int write_delay, MockRead* reads, size_t reads_count,
250 MockWrite* writes, size_t writes_count)
251 : StaticSocketDataProvider(reads, reads_count, writes, writes_count),
252 write_delay_(write_delay),
253 ALLOW_THIS_IN_INITIALIZER_LIST(factory_(this)) {
254 DCHECK_GE(write_delay_, 0);
255 }
256
DelayedSocketData(const MockConnect & connect,int write_delay,MockRead * reads,size_t reads_count,MockWrite * writes,size_t writes_count)257 DelayedSocketData::DelayedSocketData(
258 const MockConnect& connect, int write_delay, MockRead* reads,
259 size_t reads_count, MockWrite* writes, size_t writes_count)
260 : StaticSocketDataProvider(reads, reads_count, writes, writes_count),
261 write_delay_(write_delay),
262 ALLOW_THIS_IN_INITIALIZER_LIST(factory_(this)) {
263 DCHECK_GE(write_delay_, 0);
264 set_connect_data(connect);
265 }
266
~DelayedSocketData()267 DelayedSocketData::~DelayedSocketData() {
268 }
269
ForceNextRead()270 void DelayedSocketData::ForceNextRead() {
271 write_delay_ = 0;
272 CompleteRead();
273 }
274
GetNextRead()275 MockRead DelayedSocketData::GetNextRead() {
276 if (write_delay_ > 0)
277 return MockRead(true, ERR_IO_PENDING);
278 return StaticSocketDataProvider::GetNextRead();
279 }
280
OnWrite(const std::string & data)281 MockWriteResult DelayedSocketData::OnWrite(const std::string& data) {
282 MockWriteResult rv = StaticSocketDataProvider::OnWrite(data);
283 // Now that our write has completed, we can allow reads to continue.
284 if (!--write_delay_)
285 MessageLoop::current()->PostDelayedTask(FROM_HERE,
286 factory_.NewRunnableMethod(&DelayedSocketData::CompleteRead), 100);
287 return rv;
288 }
289
Reset()290 void DelayedSocketData::Reset() {
291 set_socket(NULL);
292 factory_.RevokeAll();
293 StaticSocketDataProvider::Reset();
294 }
295
CompleteRead()296 void DelayedSocketData::CompleteRead() {
297 if (socket())
298 socket()->OnReadComplete(GetNextRead());
299 }
300
OrderedSocketData(MockRead * reads,size_t reads_count,MockWrite * writes,size_t writes_count)301 OrderedSocketData::OrderedSocketData(
302 MockRead* reads, size_t reads_count, MockWrite* writes, size_t writes_count)
303 : StaticSocketDataProvider(reads, reads_count, writes, writes_count),
304 sequence_number_(0), loop_stop_stage_(0), callback_(NULL),
305 blocked_(false), ALLOW_THIS_IN_INITIALIZER_LIST(factory_(this)) {
306 }
307
OrderedSocketData(const MockConnect & connect,MockRead * reads,size_t reads_count,MockWrite * writes,size_t writes_count)308 OrderedSocketData::OrderedSocketData(
309 const MockConnect& connect,
310 MockRead* reads, size_t reads_count,
311 MockWrite* writes, size_t writes_count)
312 : StaticSocketDataProvider(reads, reads_count, writes, writes_count),
313 sequence_number_(0), loop_stop_stage_(0), callback_(NULL),
314 blocked_(false), ALLOW_THIS_IN_INITIALIZER_LIST(factory_(this)) {
315 set_connect_data(connect);
316 }
317
EndLoop()318 void OrderedSocketData::EndLoop() {
319 // If we've already stopped the loop, don't do it again until we've advanced
320 // to the next sequence_number.
321 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ << ": EndLoop()";
322 if (loop_stop_stage_ > 0) {
323 const MockRead& next_read = StaticSocketDataProvider::PeekRead();
324 if ((next_read.sequence_number & ~MockRead::STOPLOOP) >
325 loop_stop_stage_) {
326 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
327 << ": Clearing stop index";
328 loop_stop_stage_ = 0;
329 } else {
330 return;
331 }
332 }
333 // Record the sequence_number at which we stopped the loop.
334 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
335 << ": Posting Quit at read " << read_index();
336 loop_stop_stage_ = sequence_number_;
337 if (callback_)
338 callback_->RunWithParams(Tuple1<int>(ERR_IO_PENDING));
339 }
340
GetNextRead()341 MockRead OrderedSocketData::GetNextRead() {
342 factory_.RevokeAll();
343 blocked_ = false;
344 const MockRead& next_read = StaticSocketDataProvider::PeekRead();
345 if (next_read.sequence_number & MockRead::STOPLOOP)
346 EndLoop();
347 if ((next_read.sequence_number & ~MockRead::STOPLOOP) <=
348 sequence_number_++) {
349 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ - 1
350 << ": Read " << read_index();
351 DumpMockRead(next_read);
352 return StaticSocketDataProvider::GetNextRead();
353 }
354 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ - 1
355 << ": I/O Pending";
356 MockRead result = MockRead(true, ERR_IO_PENDING);
357 DumpMockRead(result);
358 blocked_ = true;
359 return result;
360 }
361
OnWrite(const std::string & data)362 MockWriteResult OrderedSocketData::OnWrite(const std::string& data) {
363 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
364 << ": Write " << write_index();
365 DumpMockRead(PeekWrite());
366 ++sequence_number_;
367 if (blocked_) {
368 // TODO(willchan): This 100ms delay seems to work around some weirdness. We
369 // should probably fix the weirdness. One example is in SpdyStream,
370 // DoSendRequest() will return ERR_IO_PENDING, and there's a race. If the
371 // SYN_REPLY causes OnResponseReceived() to get called before
372 // SpdyStream::ReadResponseHeaders() is called, we hit a NOTREACHED().
373 MessageLoop::current()->PostDelayedTask(
374 FROM_HERE,
375 factory_.NewRunnableMethod(&OrderedSocketData::CompleteRead), 100);
376 }
377 return StaticSocketDataProvider::OnWrite(data);
378 }
379
Reset()380 void OrderedSocketData::Reset() {
381 NET_TRACE(INFO, " *** ") << "Stage "
382 << sequence_number_ << ": Reset()";
383 sequence_number_ = 0;
384 loop_stop_stage_ = 0;
385 set_socket(NULL);
386 factory_.RevokeAll();
387 StaticSocketDataProvider::Reset();
388 }
389
CompleteRead()390 void OrderedSocketData::CompleteRead() {
391 if (socket()) {
392 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_;
393 socket()->OnReadComplete(GetNextRead());
394 }
395 }
396
~OrderedSocketData()397 OrderedSocketData::~OrderedSocketData() {}
398
DeterministicSocketData(MockRead * reads,size_t reads_count,MockWrite * writes,size_t writes_count)399 DeterministicSocketData::DeterministicSocketData(MockRead* reads,
400 size_t reads_count, MockWrite* writes, size_t writes_count)
401 : StaticSocketDataProvider(reads, reads_count, writes, writes_count),
402 sequence_number_(0),
403 current_read_(),
404 current_write_(),
405 stopping_sequence_number_(0),
406 stopped_(false),
407 print_debug_(false) {}
408
~DeterministicSocketData()409 DeterministicSocketData::~DeterministicSocketData() {}
410
Run()411 void DeterministicSocketData::Run() {
412 SetStopped(false);
413 int counter = 0;
414 // Continue to consume data until all data has run out, or the stopped_ flag
415 // has been set. Consuming data requires two separate operations -- running
416 // the tasks in the message loop, and explicitly invoking the read/write
417 // callbacks (simulating network I/O). We check our conditions between each,
418 // since they can change in either.
419 while ((!at_write_eof() || !at_read_eof()) && !stopped()) {
420 if (counter % 2 == 0)
421 MessageLoop::current()->RunAllPending();
422 if (counter % 2 == 1) {
423 InvokeCallbacks();
424 }
425 counter++;
426 }
427 // We're done consuming new data, but it is possible there are still some
428 // pending callbacks which we expect to complete before returning.
429 while (socket_ && (socket_->write_pending() || socket_->read_pending()) &&
430 !stopped()) {
431 InvokeCallbacks();
432 MessageLoop::current()->RunAllPending();
433 }
434 SetStopped(false);
435 }
436
RunFor(int steps)437 void DeterministicSocketData::RunFor(int steps) {
438 StopAfter(steps);
439 Run();
440 }
441
SetStop(int seq)442 void DeterministicSocketData::SetStop(int seq) {
443 DCHECK_LT(sequence_number_, seq);
444 stopping_sequence_number_ = seq;
445 stopped_ = false;
446 }
447
StopAfter(int seq)448 void DeterministicSocketData::StopAfter(int seq) {
449 SetStop(sequence_number_ + seq);
450 }
451
GetNextRead()452 MockRead DeterministicSocketData::GetNextRead() {
453 current_read_ = StaticSocketDataProvider::PeekRead();
454 EXPECT_LE(sequence_number_, current_read_.sequence_number);
455
456 // Synchronous read while stopped is an error
457 if (stopped() && !current_read_.async) {
458 LOG(ERROR) << "Unable to perform synchronous IO while stopped";
459 return MockRead(false, ERR_UNEXPECTED);
460 }
461
462 // Async read which will be called back in a future step.
463 if (sequence_number_ < current_read_.sequence_number) {
464 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
465 << ": I/O Pending";
466 MockRead result = MockRead(false, ERR_IO_PENDING);
467 if (!current_read_.async) {
468 LOG(ERROR) << "Unable to perform synchronous read: "
469 << current_read_.sequence_number
470 << " at stage: " << sequence_number_;
471 result = MockRead(false, ERR_UNEXPECTED);
472 }
473 if (print_debug_)
474 DumpMockRead(result);
475 return result;
476 }
477
478 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
479 << ": Read " << read_index();
480 if (print_debug_)
481 DumpMockRead(current_read_);
482
483 // Increment the sequence number if IO is complete
484 if (!current_read_.async)
485 NextStep();
486
487 DCHECK_NE(ERR_IO_PENDING, current_read_.result);
488 StaticSocketDataProvider::GetNextRead();
489
490 return current_read_;
491 }
492
OnWrite(const std::string & data)493 MockWriteResult DeterministicSocketData::OnWrite(const std::string& data) {
494 const MockWrite& next_write = StaticSocketDataProvider::PeekWrite();
495 current_write_ = next_write;
496
497 // Synchronous write while stopped is an error
498 if (stopped() && !next_write.async) {
499 LOG(ERROR) << "Unable to perform synchronous IO while stopped";
500 return MockWriteResult(false, ERR_UNEXPECTED);
501 }
502
503 // Async write which will be called back in a future step.
504 if (sequence_number_ < next_write.sequence_number) {
505 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
506 << ": I/O Pending";
507 if (!next_write.async) {
508 LOG(ERROR) << "Unable to perform synchronous write: "
509 << next_write.sequence_number << " at stage: " << sequence_number_;
510 return MockWriteResult(false, ERR_UNEXPECTED);
511 }
512 } else {
513 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
514 << ": Write " << write_index();
515 }
516
517 if (print_debug_)
518 DumpMockRead(next_write);
519
520 // Move to the next step if I/O is synchronous, since the operation will
521 // complete when this method returns.
522 if (!next_write.async)
523 NextStep();
524
525 // This is either a sync write for this step, or an async write.
526 return StaticSocketDataProvider::OnWrite(data);
527 }
528
Reset()529 void DeterministicSocketData::Reset() {
530 NET_TRACE(INFO, " *** ") << "Stage "
531 << sequence_number_ << ": Reset()";
532 sequence_number_ = 0;
533 StaticSocketDataProvider::Reset();
534 NOTREACHED();
535 }
536
InvokeCallbacks()537 void DeterministicSocketData::InvokeCallbacks() {
538 if (socket_ && socket_->write_pending() &&
539 (current_write().sequence_number == sequence_number())) {
540 socket_->CompleteWrite();
541 NextStep();
542 return;
543 }
544 if (socket_ && socket_->read_pending() &&
545 (current_read().sequence_number == sequence_number())) {
546 socket_->CompleteRead();
547 NextStep();
548 return;
549 }
550 }
551
NextStep()552 void DeterministicSocketData::NextStep() {
553 // Invariant: Can never move *past* the stopping step.
554 DCHECK_LT(sequence_number_, stopping_sequence_number_);
555 sequence_number_++;
556 if (sequence_number_ == stopping_sequence_number_)
557 SetStopped(true);
558 }
559
MockClientSocketFactory()560 MockClientSocketFactory::MockClientSocketFactory() {}
561
~MockClientSocketFactory()562 MockClientSocketFactory::~MockClientSocketFactory() {}
563
AddSocketDataProvider(SocketDataProvider * data)564 void MockClientSocketFactory::AddSocketDataProvider(
565 SocketDataProvider* data) {
566 mock_data_.Add(data);
567 }
568
AddSSLSocketDataProvider(SSLSocketDataProvider * data)569 void MockClientSocketFactory::AddSSLSocketDataProvider(
570 SSLSocketDataProvider* data) {
571 mock_ssl_data_.Add(data);
572 }
573
ResetNextMockIndexes()574 void MockClientSocketFactory::ResetNextMockIndexes() {
575 mock_data_.ResetNextIndex();
576 mock_ssl_data_.ResetNextIndex();
577 }
578
GetMockTCPClientSocket(size_t index) const579 MockTCPClientSocket* MockClientSocketFactory::GetMockTCPClientSocket(
580 size_t index) const {
581 DCHECK_LT(index, tcp_client_sockets_.size());
582 return tcp_client_sockets_[index];
583 }
584
GetMockSSLClientSocket(size_t index) const585 MockSSLClientSocket* MockClientSocketFactory::GetMockSSLClientSocket(
586 size_t index) const {
587 DCHECK_LT(index, ssl_client_sockets_.size());
588 return ssl_client_sockets_[index];
589 }
590
CreateTransportClientSocket(const AddressList & addresses,net::NetLog * net_log,const NetLog::Source & source)591 ClientSocket* MockClientSocketFactory::CreateTransportClientSocket(
592 const AddressList& addresses,
593 net::NetLog* net_log,
594 const NetLog::Source& source) {
595 SocketDataProvider* data_provider = mock_data_.GetNext();
596 MockTCPClientSocket* socket =
597 new MockTCPClientSocket(addresses, net_log, data_provider);
598 data_provider->set_socket(socket);
599 tcp_client_sockets_.push_back(socket);
600 return socket;
601 }
602
CreateSSLClientSocket(ClientSocketHandle * transport_socket,const HostPortPair & host_and_port,const SSLConfig & ssl_config,SSLHostInfo * ssl_host_info,CertVerifier * cert_verifier,DnsCertProvenanceChecker * dns_cert_checker)603 SSLClientSocket* MockClientSocketFactory::CreateSSLClientSocket(
604 ClientSocketHandle* transport_socket,
605 const HostPortPair& host_and_port,
606 const SSLConfig& ssl_config,
607 SSLHostInfo* ssl_host_info,
608 CertVerifier* cert_verifier,
609 DnsCertProvenanceChecker* dns_cert_checker) {
610 MockSSLClientSocket* socket =
611 new MockSSLClientSocket(transport_socket, host_and_port, ssl_config,
612 ssl_host_info, mock_ssl_data_.GetNext());
613 ssl_client_sockets_.push_back(socket);
614 return socket;
615 }
616
ClearSSLSessionCache()617 void MockClientSocketFactory::ClearSSLSessionCache() {
618 }
619
MockClientSocket(net::NetLog * net_log)620 MockClientSocket::MockClientSocket(net::NetLog* net_log)
621 : ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)),
622 connected_(false),
623 net_log_(NetLog::Source(), net_log) {
624 }
625
SetReceiveBufferSize(int32 size)626 bool MockClientSocket::SetReceiveBufferSize(int32 size) {
627 return true;
628 }
629
SetSendBufferSize(int32 size)630 bool MockClientSocket::SetSendBufferSize(int32 size) {
631 return true;
632 }
633
Disconnect()634 void MockClientSocket::Disconnect() {
635 connected_ = false;
636 }
637
IsConnected() const638 bool MockClientSocket::IsConnected() const {
639 return connected_;
640 }
641
IsConnectedAndIdle() const642 bool MockClientSocket::IsConnectedAndIdle() const {
643 return connected_;
644 }
645
GetPeerAddress(AddressList * address) const646 int MockClientSocket::GetPeerAddress(AddressList* address) const {
647 return net::SystemHostResolverProc("192.0.2.33", ADDRESS_FAMILY_UNSPECIFIED,
648 0, address, NULL);
649 }
650
GetLocalAddress(IPEndPoint * address) const651 int MockClientSocket::GetLocalAddress(IPEndPoint* address) const {
652 IPAddressNumber ip;
653 if (!ParseIPLiteralToNumber("192.0.2.33", &ip))
654 return ERR_FAILED;
655 *address = IPEndPoint(ip, 123);
656 return OK;
657 }
658
NetLog() const659 const BoundNetLog& MockClientSocket::NetLog() const {
660 return net_log_;
661 }
662
GetSSLInfo(net::SSLInfo * ssl_info)663 void MockClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) {
664 NOTREACHED();
665 }
666
GetSSLCertRequestInfo(net::SSLCertRequestInfo * cert_request_info)667 void MockClientSocket::GetSSLCertRequestInfo(
668 net::SSLCertRequestInfo* cert_request_info) {
669 }
670
671 SSLClientSocket::NextProtoStatus
GetNextProto(std::string * proto)672 MockClientSocket::GetNextProto(std::string* proto) {
673 proto->clear();
674 return SSLClientSocket::kNextProtoUnsupported;
675 }
676
~MockClientSocket()677 MockClientSocket::~MockClientSocket() {}
678
RunCallbackAsync(net::CompletionCallback * callback,int result)679 void MockClientSocket::RunCallbackAsync(net::CompletionCallback* callback,
680 int result) {
681 MessageLoop::current()->PostTask(FROM_HERE,
682 method_factory_.NewRunnableMethod(
683 &MockClientSocket::RunCallback, callback, result));
684 }
685
RunCallback(net::CompletionCallback * callback,int result)686 void MockClientSocket::RunCallback(net::CompletionCallback* callback,
687 int result) {
688 if (callback)
689 callback->Run(result);
690 }
691
MockTCPClientSocket(const net::AddressList & addresses,net::NetLog * net_log,net::SocketDataProvider * data)692 MockTCPClientSocket::MockTCPClientSocket(const net::AddressList& addresses,
693 net::NetLog* net_log,
694 net::SocketDataProvider* data)
695 : MockClientSocket(net_log),
696 addresses_(addresses),
697 data_(data),
698 read_offset_(0),
699 read_data_(false, net::ERR_UNEXPECTED),
700 need_read_data_(true),
701 peer_closed_connection_(false),
702 pending_buf_(NULL),
703 pending_buf_len_(0),
704 pending_callback_(NULL),
705 was_used_to_convey_data_(false) {
706 DCHECK(data_);
707 data_->Reset();
708 }
709
Read(net::IOBuffer * buf,int buf_len,net::CompletionCallback * callback)710 int MockTCPClientSocket::Read(net::IOBuffer* buf, int buf_len,
711 net::CompletionCallback* callback) {
712 if (!connected_)
713 return net::ERR_UNEXPECTED;
714
715 // If the buffer is already in use, a read is already in progress!
716 DCHECK(pending_buf_ == NULL);
717
718 // Store our async IO data.
719 pending_buf_ = buf;
720 pending_buf_len_ = buf_len;
721 pending_callback_ = callback;
722
723 if (need_read_data_) {
724 read_data_ = data_->GetNextRead();
725 if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) {
726 // This MockRead is just a marker to instruct us to set
727 // peer_closed_connection_. Skip it and get the next one.
728 read_data_ = data_->GetNextRead();
729 peer_closed_connection_ = true;
730 }
731 // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility
732 // to complete the async IO manually later (via OnReadComplete).
733 if (read_data_.result == ERR_IO_PENDING) {
734 DCHECK(callback); // We need to be using async IO in this case.
735 return ERR_IO_PENDING;
736 }
737 need_read_data_ = false;
738 }
739
740 return CompleteRead();
741 }
742
Write(net::IOBuffer * buf,int buf_len,net::CompletionCallback * callback)743 int MockTCPClientSocket::Write(net::IOBuffer* buf, int buf_len,
744 net::CompletionCallback* callback) {
745 DCHECK(buf);
746 DCHECK_GT(buf_len, 0);
747
748 if (!connected_)
749 return net::ERR_UNEXPECTED;
750
751 std::string data(buf->data(), buf_len);
752 net::MockWriteResult write_result = data_->OnWrite(data);
753
754 was_used_to_convey_data_ = true;
755
756 if (write_result.async) {
757 RunCallbackAsync(callback, write_result.result);
758 return net::ERR_IO_PENDING;
759 }
760
761 return write_result.result;
762 }
763
Connect(net::CompletionCallback * callback)764 int MockTCPClientSocket::Connect(net::CompletionCallback* callback) {
765 if (connected_)
766 return net::OK;
767 connected_ = true;
768 peer_closed_connection_ = false;
769 if (data_->connect_data().async) {
770 RunCallbackAsync(callback, data_->connect_data().result);
771 return net::ERR_IO_PENDING;
772 }
773 return data_->connect_data().result;
774 }
775
Disconnect()776 void MockTCPClientSocket::Disconnect() {
777 MockClientSocket::Disconnect();
778 pending_callback_ = NULL;
779 }
780
IsConnected() const781 bool MockTCPClientSocket::IsConnected() const {
782 return connected_ && !peer_closed_connection_;
783 }
784
IsConnectedAndIdle() const785 bool MockTCPClientSocket::IsConnectedAndIdle() const {
786 return IsConnected();
787 }
788
GetPeerAddress(AddressList * address) const789 int MockTCPClientSocket::GetPeerAddress(AddressList* address) const {
790 if (!IsConnected())
791 return ERR_SOCKET_NOT_CONNECTED;
792 return MockClientSocket::GetPeerAddress(address);
793 }
794
WasEverUsed() const795 bool MockTCPClientSocket::WasEverUsed() const {
796 return was_used_to_convey_data_;
797 }
798
UsingTCPFastOpen() const799 bool MockTCPClientSocket::UsingTCPFastOpen() const {
800 return false;
801 }
802
OnReadComplete(const MockRead & data)803 void MockTCPClientSocket::OnReadComplete(const MockRead& data) {
804 // There must be a read pending.
805 DCHECK(pending_buf_);
806 // You can't complete a read with another ERR_IO_PENDING status code.
807 DCHECK_NE(ERR_IO_PENDING, data.result);
808 // Since we've been waiting for data, need_read_data_ should be true.
809 DCHECK(need_read_data_);
810
811 read_data_ = data;
812 need_read_data_ = false;
813
814 // The caller is simulating that this IO completes right now. Don't
815 // let CompleteRead() schedule a callback.
816 read_data_.async = false;
817
818 net::CompletionCallback* callback = pending_callback_;
819 int rv = CompleteRead();
820 RunCallback(callback, rv);
821 }
822
CompleteRead()823 int MockTCPClientSocket::CompleteRead() {
824 DCHECK(pending_buf_);
825 DCHECK(pending_buf_len_ > 0);
826
827 was_used_to_convey_data_ = true;
828
829 // Save the pending async IO data and reset our |pending_| state.
830 net::IOBuffer* buf = pending_buf_;
831 int buf_len = pending_buf_len_;
832 net::CompletionCallback* callback = pending_callback_;
833 pending_buf_ = NULL;
834 pending_buf_len_ = 0;
835 pending_callback_ = NULL;
836
837 int result = read_data_.result;
838 DCHECK(result != ERR_IO_PENDING);
839
840 if (read_data_.data) {
841 if (read_data_.data_len - read_offset_ > 0) {
842 result = std::min(buf_len, read_data_.data_len - read_offset_);
843 memcpy(buf->data(), read_data_.data + read_offset_, result);
844 read_offset_ += result;
845 if (read_offset_ == read_data_.data_len) {
846 need_read_data_ = true;
847 read_offset_ = 0;
848 }
849 } else {
850 result = 0; // EOF
851 }
852 }
853
854 if (read_data_.async) {
855 DCHECK(callback);
856 RunCallbackAsync(callback, result);
857 return net::ERR_IO_PENDING;
858 }
859 return result;
860 }
861
DeterministicMockTCPClientSocket(net::NetLog * net_log,net::DeterministicSocketData * data)862 DeterministicMockTCPClientSocket::DeterministicMockTCPClientSocket(
863 net::NetLog* net_log, net::DeterministicSocketData* data)
864 : MockClientSocket(net_log),
865 write_pending_(false),
866 write_callback_(NULL),
867 write_result_(0),
868 read_data_(),
869 read_buf_(NULL),
870 read_buf_len_(0),
871 read_pending_(false),
872 read_callback_(NULL),
873 data_(data),
874 was_used_to_convey_data_(false) {}
875
~DeterministicMockTCPClientSocket()876 DeterministicMockTCPClientSocket::~DeterministicMockTCPClientSocket() {}
877
CompleteWrite()878 void DeterministicMockTCPClientSocket::CompleteWrite() {
879 was_used_to_convey_data_ = true;
880 write_pending_ = false;
881 write_callback_->Run(write_result_);
882 }
883
CompleteRead()884 int DeterministicMockTCPClientSocket::CompleteRead() {
885 DCHECK_GT(read_buf_len_, 0);
886 DCHECK_LE(read_data_.data_len, read_buf_len_);
887 DCHECK(read_buf_);
888
889 was_used_to_convey_data_ = true;
890
891 if (read_data_.result == ERR_IO_PENDING)
892 read_data_ = data_->GetNextRead();
893 DCHECK_NE(ERR_IO_PENDING, read_data_.result);
894 // If read_data_.async is true, we do not need to wait, since this is already
895 // the callback. Therefore we don't even bother to check it.
896 int result = read_data_.result;
897
898 if (read_data_.data_len > 0) {
899 DCHECK(read_data_.data);
900 result = std::min(read_buf_len_, read_data_.data_len);
901 memcpy(read_buf_->data(), read_data_.data, result);
902 }
903
904 if (read_pending_) {
905 read_pending_ = false;
906 read_callback_->Run(result);
907 }
908
909 return result;
910 }
911
Write(net::IOBuffer * buf,int buf_len,net::CompletionCallback * callback)912 int DeterministicMockTCPClientSocket::Write(
913 net::IOBuffer* buf, int buf_len, net::CompletionCallback* callback) {
914 DCHECK(buf);
915 DCHECK_GT(buf_len, 0);
916
917 if (!connected_)
918 return net::ERR_UNEXPECTED;
919
920 std::string data(buf->data(), buf_len);
921 net::MockWriteResult write_result = data_->OnWrite(data);
922
923 if (write_result.async) {
924 write_callback_ = callback;
925 write_result_ = write_result.result;
926 DCHECK(write_callback_ != NULL);
927 write_pending_ = true;
928 return net::ERR_IO_PENDING;
929 }
930
931 was_used_to_convey_data_ = true;
932 write_pending_ = false;
933 return write_result.result;
934 }
935
Read(net::IOBuffer * buf,int buf_len,net::CompletionCallback * callback)936 int DeterministicMockTCPClientSocket::Read(
937 net::IOBuffer* buf, int buf_len, net::CompletionCallback* callback) {
938 if (!connected_)
939 return net::ERR_UNEXPECTED;
940
941 read_data_ = data_->GetNextRead();
942 // The buffer should always be big enough to contain all the MockRead data. To
943 // use small buffers, split the data into multiple MockReads.
944 DCHECK_LE(read_data_.data_len, buf_len);
945
946 read_buf_ = buf;
947 read_buf_len_ = buf_len;
948 read_callback_ = callback;
949
950 if (read_data_.async || (read_data_.result == ERR_IO_PENDING)) {
951 read_pending_ = true;
952 DCHECK(read_callback_);
953 return ERR_IO_PENDING;
954 }
955
956 was_used_to_convey_data_ = true;
957 return CompleteRead();
958 }
959
960 // TODO(erikchen): Support connect sequencing.
Connect(net::CompletionCallback * callback)961 int DeterministicMockTCPClientSocket::Connect(
962 net::CompletionCallback* callback) {
963 if (connected_)
964 return net::OK;
965 connected_ = true;
966 if (data_->connect_data().async) {
967 RunCallbackAsync(callback, data_->connect_data().result);
968 return net::ERR_IO_PENDING;
969 }
970 return data_->connect_data().result;
971 }
972
Disconnect()973 void DeterministicMockTCPClientSocket::Disconnect() {
974 MockClientSocket::Disconnect();
975 }
976
IsConnected() const977 bool DeterministicMockTCPClientSocket::IsConnected() const {
978 return connected_;
979 }
980
IsConnectedAndIdle() const981 bool DeterministicMockTCPClientSocket::IsConnectedAndIdle() const {
982 return IsConnected();
983 }
984
WasEverUsed() const985 bool DeterministicMockTCPClientSocket::WasEverUsed() const {
986 return was_used_to_convey_data_;
987 }
988
UsingTCPFastOpen() const989 bool DeterministicMockTCPClientSocket::UsingTCPFastOpen() const {
990 return false;
991 }
992
OnReadComplete(const MockRead & data)993 void DeterministicMockTCPClientSocket::OnReadComplete(const MockRead& data) {}
994
995 class MockSSLClientSocket::ConnectCallback
996 : public net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback> {
997 public:
ConnectCallback(MockSSLClientSocket * ssl_client_socket,net::CompletionCallback * user_callback,int rv)998 ConnectCallback(MockSSLClientSocket *ssl_client_socket,
999 net::CompletionCallback* user_callback,
1000 int rv)
1001 : ALLOW_THIS_IN_INITIALIZER_LIST(
1002 net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback>(
1003 this, &ConnectCallback::Wrapper)),
1004 ssl_client_socket_(ssl_client_socket),
1005 user_callback_(user_callback),
1006 rv_(rv) {
1007 }
1008
1009 private:
Wrapper(int rv)1010 void Wrapper(int rv) {
1011 if (rv_ == net::OK)
1012 ssl_client_socket_->connected_ = true;
1013 user_callback_->Run(rv_);
1014 delete this;
1015 }
1016
1017 MockSSLClientSocket* ssl_client_socket_;
1018 net::CompletionCallback* user_callback_;
1019 int rv_;
1020 };
1021
MockSSLClientSocket(net::ClientSocketHandle * transport_socket,const HostPortPair & host_port_pair,const net::SSLConfig & ssl_config,SSLHostInfo * ssl_host_info,net::SSLSocketDataProvider * data)1022 MockSSLClientSocket::MockSSLClientSocket(
1023 net::ClientSocketHandle* transport_socket,
1024 const HostPortPair& host_port_pair,
1025 const net::SSLConfig& ssl_config,
1026 SSLHostInfo* ssl_host_info,
1027 net::SSLSocketDataProvider* data)
1028 : MockClientSocket(transport_socket->socket()->NetLog().net_log()),
1029 transport_(transport_socket),
1030 data_(data),
1031 is_npn_state_set_(false),
1032 new_npn_value_(false) {
1033 DCHECK(data_);
1034 delete ssl_host_info; // we take ownership but don't use it.
1035 }
1036
~MockSSLClientSocket()1037 MockSSLClientSocket::~MockSSLClientSocket() {
1038 Disconnect();
1039 }
1040
Read(net::IOBuffer * buf,int buf_len,net::CompletionCallback * callback)1041 int MockSSLClientSocket::Read(net::IOBuffer* buf, int buf_len,
1042 net::CompletionCallback* callback) {
1043 return transport_->socket()->Read(buf, buf_len, callback);
1044 }
1045
Write(net::IOBuffer * buf,int buf_len,net::CompletionCallback * callback)1046 int MockSSLClientSocket::Write(net::IOBuffer* buf, int buf_len,
1047 net::CompletionCallback* callback) {
1048 return transport_->socket()->Write(buf, buf_len, callback);
1049 }
1050
Connect(net::CompletionCallback * callback)1051 int MockSSLClientSocket::Connect(net::CompletionCallback* callback) {
1052 ConnectCallback* connect_callback = new ConnectCallback(
1053 this, callback, data_->connect.result);
1054 int rv = transport_->socket()->Connect(connect_callback);
1055 if (rv == net::OK) {
1056 delete connect_callback;
1057 if (data_->connect.result == net::OK)
1058 connected_ = true;
1059 if (data_->connect.async) {
1060 RunCallbackAsync(callback, data_->connect.result);
1061 return net::ERR_IO_PENDING;
1062 }
1063 return data_->connect.result;
1064 }
1065 return rv;
1066 }
1067
Disconnect()1068 void MockSSLClientSocket::Disconnect() {
1069 MockClientSocket::Disconnect();
1070 if (transport_->socket() != NULL)
1071 transport_->socket()->Disconnect();
1072 }
1073
IsConnected() const1074 bool MockSSLClientSocket::IsConnected() const {
1075 return transport_->socket()->IsConnected();
1076 }
1077
WasEverUsed() const1078 bool MockSSLClientSocket::WasEverUsed() const {
1079 return transport_->socket()->WasEverUsed();
1080 }
1081
UsingTCPFastOpen() const1082 bool MockSSLClientSocket::UsingTCPFastOpen() const {
1083 return transport_->socket()->UsingTCPFastOpen();
1084 }
1085
GetSSLInfo(net::SSLInfo * ssl_info)1086 void MockSSLClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) {
1087 ssl_info->Reset();
1088 ssl_info->cert = data_->cert_;
1089 }
1090
GetSSLCertRequestInfo(net::SSLCertRequestInfo * cert_request_info)1091 void MockSSLClientSocket::GetSSLCertRequestInfo(
1092 net::SSLCertRequestInfo* cert_request_info) {
1093 DCHECK(cert_request_info);
1094 if (data_->cert_request_info) {
1095 cert_request_info->host_and_port =
1096 data_->cert_request_info->host_and_port;
1097 cert_request_info->client_certs = data_->cert_request_info->client_certs;
1098 } else {
1099 cert_request_info->Reset();
1100 }
1101 }
1102
GetNextProto(std::string * proto)1103 SSLClientSocket::NextProtoStatus MockSSLClientSocket::GetNextProto(
1104 std::string* proto) {
1105 *proto = data_->next_proto;
1106 return data_->next_proto_status;
1107 }
1108
was_npn_negotiated() const1109 bool MockSSLClientSocket::was_npn_negotiated() const {
1110 if (is_npn_state_set_)
1111 return new_npn_value_;
1112 return data_->was_npn_negotiated;
1113 }
1114
set_was_npn_negotiated(bool negotiated)1115 bool MockSSLClientSocket::set_was_npn_negotiated(bool negotiated) {
1116 is_npn_state_set_ = true;
1117 return new_npn_value_ = negotiated;
1118 }
1119
OnReadComplete(const MockRead & data)1120 void MockSSLClientSocket::OnReadComplete(const MockRead& data) {
1121 NOTIMPLEMENTED();
1122 }
1123
TestSocketRequest(std::vector<TestSocketRequest * > * request_order,size_t * completion_count)1124 TestSocketRequest::TestSocketRequest(
1125 std::vector<TestSocketRequest*>* request_order,
1126 size_t* completion_count)
1127 : request_order_(request_order),
1128 completion_count_(completion_count) {
1129 DCHECK(request_order);
1130 DCHECK(completion_count);
1131 }
1132
~TestSocketRequest()1133 TestSocketRequest::~TestSocketRequest() {
1134 }
1135
WaitForResult()1136 int TestSocketRequest::WaitForResult() {
1137 return callback_.WaitForResult();
1138 }
1139
RunWithParams(const Tuple1<int> & params)1140 void TestSocketRequest::RunWithParams(const Tuple1<int>& params) {
1141 callback_.RunWithParams(params);
1142 (*completion_count_)++;
1143 request_order_->push_back(this);
1144 }
1145
1146 // static
1147 const int ClientSocketPoolTest::kIndexOutOfBounds = -1;
1148
1149 // static
1150 const int ClientSocketPoolTest::kRequestNotFound = -2;
1151
ClientSocketPoolTest()1152 ClientSocketPoolTest::ClientSocketPoolTest() : completion_count_(0) {}
~ClientSocketPoolTest()1153 ClientSocketPoolTest::~ClientSocketPoolTest() {}
1154
GetOrderOfRequest(size_t index) const1155 int ClientSocketPoolTest::GetOrderOfRequest(size_t index) const {
1156 index--;
1157 if (index >= requests_.size())
1158 return kIndexOutOfBounds;
1159
1160 for (size_t i = 0; i < request_order_.size(); i++)
1161 if (requests_[index] == request_order_[i])
1162 return i + 1;
1163
1164 return kRequestNotFound;
1165 }
1166
ReleaseOneConnection(KeepAlive keep_alive)1167 bool ClientSocketPoolTest::ReleaseOneConnection(KeepAlive keep_alive) {
1168 ScopedVector<TestSocketRequest>::iterator i;
1169 for (i = requests_.begin(); i != requests_.end(); ++i) {
1170 if ((*i)->handle()->is_initialized()) {
1171 if (keep_alive == NO_KEEP_ALIVE)
1172 (*i)->handle()->socket()->Disconnect();
1173 (*i)->handle()->Reset();
1174 MessageLoop::current()->RunAllPending();
1175 return true;
1176 }
1177 }
1178 return false;
1179 }
1180
ReleaseAllConnections(KeepAlive keep_alive)1181 void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) {
1182 bool released_one;
1183 do {
1184 released_one = ReleaseOneConnection(keep_alive);
1185 } while (released_one);
1186 }
1187
MockConnectJob(ClientSocket * socket,ClientSocketHandle * handle,CompletionCallback * callback)1188 MockTransportClientSocketPool::MockConnectJob::MockConnectJob(
1189 ClientSocket* socket,
1190 ClientSocketHandle* handle,
1191 CompletionCallback* callback)
1192 : socket_(socket),
1193 handle_(handle),
1194 user_callback_(callback),
1195 ALLOW_THIS_IN_INITIALIZER_LIST(
1196 connect_callback_(this, &MockConnectJob::OnConnect)) {
1197 }
1198
~MockConnectJob()1199 MockTransportClientSocketPool::MockConnectJob::~MockConnectJob() {}
1200
Connect()1201 int MockTransportClientSocketPool::MockConnectJob::Connect() {
1202 int rv = socket_->Connect(&connect_callback_);
1203 if (rv == OK) {
1204 user_callback_ = NULL;
1205 OnConnect(OK);
1206 }
1207 return rv;
1208 }
1209
CancelHandle(const ClientSocketHandle * handle)1210 bool MockTransportClientSocketPool::MockConnectJob::CancelHandle(
1211 const ClientSocketHandle* handle) {
1212 if (handle != handle_)
1213 return false;
1214 socket_.reset();
1215 handle_ = NULL;
1216 user_callback_ = NULL;
1217 return true;
1218 }
1219
OnConnect(int rv)1220 void MockTransportClientSocketPool::MockConnectJob::OnConnect(int rv) {
1221 if (!socket_.get())
1222 return;
1223 if (rv == OK) {
1224 handle_->set_socket(socket_.release());
1225 } else {
1226 socket_.reset();
1227 }
1228
1229 handle_ = NULL;
1230
1231 if (user_callback_) {
1232 CompletionCallback* callback = user_callback_;
1233 user_callback_ = NULL;
1234 callback->Run(rv);
1235 }
1236 }
1237
MockTransportClientSocketPool(int max_sockets,int max_sockets_per_group,ClientSocketPoolHistograms * histograms,ClientSocketFactory * socket_factory)1238 MockTransportClientSocketPool::MockTransportClientSocketPool(
1239 int max_sockets,
1240 int max_sockets_per_group,
1241 ClientSocketPoolHistograms* histograms,
1242 ClientSocketFactory* socket_factory)
1243 : TransportClientSocketPool(max_sockets, max_sockets_per_group, histograms,
1244 NULL, NULL, NULL),
1245 client_socket_factory_(socket_factory),
1246 release_count_(0),
1247 cancel_count_(0) {
1248 }
1249
~MockTransportClientSocketPool()1250 MockTransportClientSocketPool::~MockTransportClientSocketPool() {}
1251
RequestSocket(const std::string & group_name,const void * socket_params,RequestPriority priority,ClientSocketHandle * handle,CompletionCallback * callback,const BoundNetLog & net_log)1252 int MockTransportClientSocketPool::RequestSocket(const std::string& group_name,
1253 const void* socket_params,
1254 RequestPriority priority,
1255 ClientSocketHandle* handle,
1256 CompletionCallback* callback,
1257 const BoundNetLog& net_log) {
1258 ClientSocket* socket = client_socket_factory_->CreateTransportClientSocket(
1259 AddressList(), net_log.net_log(), net::NetLog::Source());
1260 MockConnectJob* job = new MockConnectJob(socket, handle, callback);
1261 job_list_.push_back(job);
1262 handle->set_pool_id(1);
1263 return job->Connect();
1264 }
1265
CancelRequest(const std::string & group_name,ClientSocketHandle * handle)1266 void MockTransportClientSocketPool::CancelRequest(const std::string& group_name,
1267 ClientSocketHandle* handle) {
1268 std::vector<MockConnectJob*>::iterator i;
1269 for (i = job_list_.begin(); i != job_list_.end(); ++i) {
1270 if ((*i)->CancelHandle(handle)) {
1271 cancel_count_++;
1272 break;
1273 }
1274 }
1275 }
1276
ReleaseSocket(const std::string & group_name,ClientSocket * socket,int id)1277 void MockTransportClientSocketPool::ReleaseSocket(const std::string& group_name,
1278 ClientSocket* socket, int id) {
1279 EXPECT_EQ(1, id);
1280 release_count_++;
1281 delete socket;
1282 }
1283
DeterministicMockClientSocketFactory()1284 DeterministicMockClientSocketFactory::DeterministicMockClientSocketFactory() {}
1285
~DeterministicMockClientSocketFactory()1286 DeterministicMockClientSocketFactory::~DeterministicMockClientSocketFactory() {}
1287
AddSocketDataProvider(DeterministicSocketData * data)1288 void DeterministicMockClientSocketFactory::AddSocketDataProvider(
1289 DeterministicSocketData* data) {
1290 mock_data_.Add(data);
1291 }
1292
AddSSLSocketDataProvider(SSLSocketDataProvider * data)1293 void DeterministicMockClientSocketFactory::AddSSLSocketDataProvider(
1294 SSLSocketDataProvider* data) {
1295 mock_ssl_data_.Add(data);
1296 }
1297
ResetNextMockIndexes()1298 void DeterministicMockClientSocketFactory::ResetNextMockIndexes() {
1299 mock_data_.ResetNextIndex();
1300 mock_ssl_data_.ResetNextIndex();
1301 }
1302
1303 MockSSLClientSocket* DeterministicMockClientSocketFactory::
GetMockSSLClientSocket(size_t index) const1304 GetMockSSLClientSocket(size_t index) const {
1305 DCHECK_LT(index, ssl_client_sockets_.size());
1306 return ssl_client_sockets_[index];
1307 }
1308
CreateTransportClientSocket(const AddressList & addresses,net::NetLog * net_log,const net::NetLog::Source & source)1309 ClientSocket* DeterministicMockClientSocketFactory::CreateTransportClientSocket(
1310 const AddressList& addresses,
1311 net::NetLog* net_log,
1312 const net::NetLog::Source& source) {
1313 DeterministicSocketData* data_provider = mock_data().GetNext();
1314 DeterministicMockTCPClientSocket* socket =
1315 new DeterministicMockTCPClientSocket(net_log, data_provider);
1316 data_provider->set_socket(socket->AsWeakPtr());
1317 tcp_client_sockets().push_back(socket);
1318 return socket;
1319 }
1320
CreateSSLClientSocket(ClientSocketHandle * transport_socket,const HostPortPair & host_and_port,const SSLConfig & ssl_config,SSLHostInfo * ssl_host_info,CertVerifier * cert_verifier,DnsCertProvenanceChecker * dns_cert_checker)1321 SSLClientSocket* DeterministicMockClientSocketFactory::CreateSSLClientSocket(
1322 ClientSocketHandle* transport_socket,
1323 const HostPortPair& host_and_port,
1324 const SSLConfig& ssl_config,
1325 SSLHostInfo* ssl_host_info,
1326 CertVerifier* cert_verifier,
1327 DnsCertProvenanceChecker* dns_cert_checker) {
1328 MockSSLClientSocket* socket =
1329 new MockSSLClientSocket(transport_socket, host_and_port, ssl_config,
1330 ssl_host_info, mock_ssl_data_.GetNext());
1331 ssl_client_sockets_.push_back(socket);
1332 return socket;
1333 }
1334
ClearSSLSessionCache()1335 void DeterministicMockClientSocketFactory::ClearSSLSessionCache() {
1336 }
1337
MockSOCKSClientSocketPool(int max_sockets,int max_sockets_per_group,ClientSocketPoolHistograms * histograms,TransportClientSocketPool * transport_pool)1338 MockSOCKSClientSocketPool::MockSOCKSClientSocketPool(
1339 int max_sockets,
1340 int max_sockets_per_group,
1341 ClientSocketPoolHistograms* histograms,
1342 TransportClientSocketPool* transport_pool)
1343 : SOCKSClientSocketPool(max_sockets, max_sockets_per_group, histograms,
1344 NULL, transport_pool, NULL),
1345 transport_pool_(transport_pool) {
1346 }
1347
~MockSOCKSClientSocketPool()1348 MockSOCKSClientSocketPool::~MockSOCKSClientSocketPool() {}
1349
RequestSocket(const std::string & group_name,const void * socket_params,RequestPriority priority,ClientSocketHandle * handle,CompletionCallback * callback,const BoundNetLog & net_log)1350 int MockSOCKSClientSocketPool::RequestSocket(const std::string& group_name,
1351 const void* socket_params,
1352 RequestPriority priority,
1353 ClientSocketHandle* handle,
1354 CompletionCallback* callback,
1355 const BoundNetLog& net_log) {
1356 return transport_pool_->RequestSocket(group_name,
1357 socket_params,
1358 priority,
1359 handle,
1360 callback,
1361 net_log);
1362 }
1363
CancelRequest(const std::string & group_name,ClientSocketHandle * handle)1364 void MockSOCKSClientSocketPool::CancelRequest(
1365 const std::string& group_name,
1366 ClientSocketHandle* handle) {
1367 return transport_pool_->CancelRequest(group_name, handle);
1368 }
1369
ReleaseSocket(const std::string & group_name,ClientSocket * socket,int id)1370 void MockSOCKSClientSocketPool::ReleaseSocket(const std::string& group_name,
1371 ClientSocket* socket, int id) {
1372 return transport_pool_->ReleaseSocket(group_name, socket, id);
1373 }
1374
1375 const char kSOCKS5GreetRequest[] = { 0x05, 0x01, 0x00 };
1376 const int kSOCKS5GreetRequestLength = arraysize(kSOCKS5GreetRequest);
1377
1378 const char kSOCKS5GreetResponse[] = { 0x05, 0x00 };
1379 const int kSOCKS5GreetResponseLength = arraysize(kSOCKS5GreetResponse);
1380
1381 const char kSOCKS5OkRequest[] =
1382 { 0x05, 0x01, 0x00, 0x03, 0x04, 'h', 'o', 's', 't', 0x00, 0x50 };
1383 const int kSOCKS5OkRequestLength = arraysize(kSOCKS5OkRequest);
1384
1385 const char kSOCKS5OkResponse[] =
1386 { 0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0x00, 0x50 };
1387 const int kSOCKS5OkResponseLength = arraysize(kSOCKS5OkResponse);
1388
1389 } // namespace net
1390