• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2011 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef NET_SOCKET_SOCKET_TEST_UTIL_H_
6 #define NET_SOCKET_SOCKET_TEST_UTIL_H_
7 #pragma once
8 
9 #include <cstring>
10 #include <deque>
11 #include <string>
12 #include <vector>
13 
14 #include "base/basictypes.h"
15 #include "base/callback.h"
16 #include "base/logging.h"
17 #include "base/memory/scoped_ptr.h"
18 #include "base/memory/scoped_vector.h"
19 #include "base/memory/weak_ptr.h"
20 #include "base/string16.h"
21 #include "net/base/address_list.h"
22 #include "net/base/io_buffer.h"
23 #include "net/base/net_errors.h"
24 #include "net/base/net_log.h"
25 #include "net/base/ssl_config_service.h"
26 #include "net/base/test_completion_callback.h"
27 #include "net/http/http_auth_controller.h"
28 #include "net/http/http_proxy_client_socket_pool.h"
29 #include "net/socket/client_socket_factory.h"
30 #include "net/socket/client_socket_handle.h"
31 #include "net/socket/socks_client_socket_pool.h"
32 #include "net/socket/ssl_client_socket.h"
33 #include "net/socket/ssl_client_socket_pool.h"
34 #include "net/socket/transport_client_socket_pool.h"
35 #include "testing/gtest/include/gtest/gtest.h"
36 
37 namespace net {
38 
// Error code reserved for the socket test utilities in this file.  A MockRead
// whose |result| equals ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ is not a real
// read: it only marks that the peer will close the connection after the
// following MockRead.  Every other field of such a MockRead is ignored.
enum {
  ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ = -10000
};

class ClientSocket;
class MockClientSocket;
class SSLClientSocket;
class SSLHostInfo;
52 
53 struct MockConnect {
54   // Asynchronous connection success.
MockConnectMockConnect55   MockConnect() : async(true), result(OK) { }
MockConnectMockConnect56   MockConnect(bool a, int r) : async(a), result(r) { }
57 
58   bool async;
59   int result;
60 };
61 
62 struct MockRead {
63   // Flag to indicate that the message loop should be terminated.
64   enum {
65     STOPLOOP = 1 << 31
66   };
67 
68   // Default
MockReadMockRead69   MockRead() : async(false), result(0), data(NULL), data_len(0),
70       sequence_number(0), time_stamp(base::Time::Now()) {}
71 
72   // Read failure (no data).
MockReadMockRead73   MockRead(bool async, int result) : async(async) , result(result), data(NULL),
74       data_len(0), sequence_number(0), time_stamp(base::Time::Now()) { }
75 
76   // Read failure (no data), with sequence information.
MockReadMockRead77   MockRead(bool async, int result, int seq) : async(async) , result(result),
78       data(NULL), data_len(0), sequence_number(seq),
79       time_stamp(base::Time::Now()) { }
80 
81   // Asynchronous read success (inferred data length).
MockReadMockRead82   explicit MockRead(const char* data) : async(true),  result(0), data(data),
83       data_len(strlen(data)), sequence_number(0),
84       time_stamp(base::Time::Now()) { }
85 
86   // Read success (inferred data length).
MockReadMockRead87   MockRead(bool async, const char* data) : async(async), result(0), data(data),
88       data_len(strlen(data)), sequence_number(0),
89       time_stamp(base::Time::Now()) { }
90 
91   // Read success.
MockReadMockRead92   MockRead(bool async, const char* data, int data_len) : async(async),
93       result(0), data(data), data_len(data_len), sequence_number(0),
94       time_stamp(base::Time::Now()) { }
95 
96   // Read success (inferred data length) with sequence information.
MockReadMockRead97   MockRead(bool async, int seq, const char* data) : async(async),
98       result(0), data(data), data_len(strlen(data)), sequence_number(seq),
99       time_stamp(base::Time::Now()) { }
100 
101   // Read success with sequence information.
MockReadMockRead102   MockRead(bool async, const char* data, int data_len, int seq) : async(async),
103       result(0), data(data), data_len(data_len), sequence_number(seq),
104       time_stamp(base::Time::Now()) { }
105 
106   bool async;
107   int result;
108   const char* data;
109   int data_len;
110 
111   // For OrderedSocketData, which only allows reads to occur in a particular
112   // sequence.  If a read occurs before the given |sequence_number| is reached,
113   // an ERR_IO_PENDING is returned.
114   int sequence_number;      // The sequence number at which a read is allowed
115                             // to occur.
116   base::Time time_stamp;    // The time stamp at which the operation occurred.
117 };
118 
119 // MockWrite uses the same member fields as MockRead, but with different
120 // meanings. The expected input to MockTCPClientSocket::Write() is given
121 // by {data, data_len}, and the return value of Write() is controlled by
122 // {async, result}.
123 typedef MockRead MockWrite;
124 
// Scripted outcome of a mocked Write(): whether it completes asynchronously and
// the value it reports.
struct MockWriteResult {
  MockWriteResult(bool is_async, int rv) : async(is_async), result(rv) {}

  bool async;
  int result;
};
131 
132 // The SocketDataProvider is an interface used by the MockClientSocket
133 // for getting data about individual reads and writes on the socket.
134 class SocketDataProvider {
135  public:
SocketDataProvider()136   SocketDataProvider() : socket_(NULL) {}
137 
~SocketDataProvider()138   virtual ~SocketDataProvider() {}
139 
140   // Returns the buffer and result code for the next simulated read.
141   // If the |MockRead.result| is ERR_IO_PENDING, it informs the caller
142   // that it will be called via the MockClientSocket::OnReadComplete()
143   // function at a later time.
144   virtual MockRead GetNextRead() = 0;
145   virtual MockWriteResult OnWrite(const std::string& data) = 0;
146   virtual void Reset() = 0;
147 
148   // Accessor for the socket which is using the SocketDataProvider.
socket()149   MockClientSocket* socket() { return socket_; }
set_socket(MockClientSocket * socket)150   void set_socket(MockClientSocket* socket) { socket_ = socket; }
151 
connect_data()152   MockConnect connect_data() const { return connect_; }
set_connect_data(const MockConnect & connect)153   void set_connect_data(const MockConnect& connect) { connect_ = connect; }
154 
155  private:
156   MockConnect connect_;
157   MockClientSocket* socket_;
158 
159   DISALLOW_COPY_AND_ASSIGN(SocketDataProvider);
160 };
161 
162 // SocketDataProvider which responds based on static tables of mock reads and
163 // writes.
164 class StaticSocketDataProvider : public SocketDataProvider {
165  public:
166   StaticSocketDataProvider();
167   StaticSocketDataProvider(MockRead* reads, size_t reads_count,
168                            MockWrite* writes, size_t writes_count);
169   virtual ~StaticSocketDataProvider();
170 
171   // These functions get access to the next available read and write data.
172   const MockRead& PeekRead() const;
173   const MockWrite& PeekWrite() const;
174   // These functions get random access to the read and write data, for timing.
175   const MockRead& PeekRead(size_t index) const;
176   const MockWrite& PeekWrite(size_t index) const;
read_index()177   size_t read_index() const { return read_index_; }
write_index()178   size_t write_index() const { return write_index_; }
read_count()179   size_t read_count() const { return read_count_; }
write_count()180   size_t write_count() const { return write_count_; }
181 
at_read_eof()182   bool at_read_eof() const { return read_index_ >= read_count_; }
at_write_eof()183   bool at_write_eof() const { return write_index_ >= write_count_; }
184 
CompleteRead()185   virtual void CompleteRead() {}
186 
187   // SocketDataProvider methods:
188   virtual MockRead GetNextRead();
189   virtual MockWriteResult OnWrite(const std::string& data);
190   virtual void Reset();
191 
192  private:
193   MockRead* reads_;
194   size_t read_index_;
195   size_t read_count_;
196   MockWrite* writes_;
197   size_t write_index_;
198   size_t write_count_;
199 
200   DISALLOW_COPY_AND_ASSIGN(StaticSocketDataProvider);
201 };
202 
203 // SocketDataProvider which can make decisions about next mock reads based on
204 // received writes. It can also be used to enforce order of operations, for
205 // example that tested code must send the "Hello!" message before receiving
206 // response. This is useful for testing conversation-like protocols like FTP.
207 class DynamicSocketDataProvider : public SocketDataProvider {
208  public:
209   DynamicSocketDataProvider();
210   virtual ~DynamicSocketDataProvider();
211 
short_read_limit()212   int short_read_limit() const { return short_read_limit_; }
set_short_read_limit(int limit)213   void set_short_read_limit(int limit) { short_read_limit_ = limit; }
214 
allow_unconsumed_reads(bool allow)215   void allow_unconsumed_reads(bool allow) { allow_unconsumed_reads_ = allow; }
216 
217   // SocketDataProvider methods:
218   virtual MockRead GetNextRead();
219   virtual MockWriteResult OnWrite(const std::string& data) = 0;
220   virtual void Reset();
221 
222  protected:
223   // The next time there is a read from this socket, it will return |data|.
224   // Before calling SimulateRead next time, the previous data must be consumed.
225   void SimulateRead(const char* data, size_t length);
SimulateRead(const char * data)226   void SimulateRead(const char* data) {
227     SimulateRead(data, std::strlen(data));
228   }
229 
230  private:
231   std::deque<MockRead> reads_;
232 
233   // Max number of bytes we will read at a time. 0 means no limit.
234   int short_read_limit_;
235 
236   // If true, we'll not require the client to consume all data before we
237   // mock the next read.
238   bool allow_unconsumed_reads_;
239 
240   DISALLOW_COPY_AND_ASSIGN(DynamicSocketDataProvider);
241 };
242 
243 // SSLSocketDataProviders only need to keep track of the return code from calls
244 // to Connect().
245 struct SSLSocketDataProvider {
246   SSLSocketDataProvider(bool async, int result);
247   ~SSLSocketDataProvider();
248 
249   MockConnect connect;
250   SSLClientSocket::NextProtoStatus next_proto_status;
251   std::string next_proto;
252   bool was_npn_negotiated;
253   net::SSLCertRequestInfo* cert_request_info;
254   scoped_refptr<X509Certificate> cert_;
255 };
256 
257 // A DataProvider where the client must write a request before the reads (e.g.
258 // the response) will complete.
259 class DelayedSocketData : public StaticSocketDataProvider,
260                           public base::RefCounted<DelayedSocketData> {
261  public:
262   // |write_delay| the number of MockWrites to complete before allowing
263   //               a MockRead to complete.
264   // |reads| the list of MockRead completions.
265   // |writes| the list of MockWrite completions.
266   // Note: All MockReads and MockWrites must be async.
267   // Note: The MockRead and MockWrite lists musts end with a EOF
268   //       e.g. a MockRead(true, 0, 0);
269   DelayedSocketData(int write_delay,
270                     MockRead* reads, size_t reads_count,
271                     MockWrite* writes, size_t writes_count);
272 
273   // |connect| the result for the connect phase.
274   // |reads| the list of MockRead completions.
275   // |write_delay| the number of MockWrites to complete before allowing
276   //               a MockRead to complete.
277   // |writes| the list of MockWrite completions.
278   // Note: All MockReads and MockWrites must be async.
279   // Note: The MockRead and MockWrite lists musts end with a EOF
280   //       e.g. a MockRead(true, 0, 0);
281   DelayedSocketData(const MockConnect& connect, int write_delay,
282                     MockRead* reads, size_t reads_count,
283                     MockWrite* writes, size_t writes_count);
284   ~DelayedSocketData();
285 
286   void ForceNextRead();
287 
288   // StaticSocketDataProvider:
289   virtual MockRead GetNextRead();
290   virtual MockWriteResult OnWrite(const std::string& data);
291   virtual void Reset();
292   virtual void CompleteRead();
293 
294  private:
295   int write_delay_;
296   ScopedRunnableMethodFactory<DelayedSocketData> factory_;
297 };
298 
299 // A DataProvider where the reads are ordered.
300 // If a read is requested before its sequence number is reached, we return an
301 // ERR_IO_PENDING (that way we don't have to explicitly add a MockRead just to
302 // wait).
303 // The sequence number is incremented on every read and write operation.
304 // The message loop may be interrupted by setting the high bit of the sequence
305 // number in the MockRead's sequence number.  When that MockRead is reached,
306 // we post a Quit message to the loop.  This allows us to interrupt the reading
307 // of data before a complete message has arrived, and provides support for
308 // testing server push when the request is issued while the response is in the
309 // middle of being received.
310 class OrderedSocketData : public StaticSocketDataProvider,
311                           public base::RefCounted<OrderedSocketData> {
312  public:
313   // |reads| the list of MockRead completions.
314   // |writes| the list of MockWrite completions.
315   // Note: All MockReads and MockWrites must be async.
316   // Note: The MockRead and MockWrite lists musts end with a EOF
317   //       e.g. a MockRead(true, 0, 0);
318   OrderedSocketData(MockRead* reads, size_t reads_count,
319                     MockWrite* writes, size_t writes_count);
320 
321   // |connect| the result for the connect phase.
322   // |reads| the list of MockRead completions.
323   // |writes| the list of MockWrite completions.
324   // Note: All MockReads and MockWrites must be async.
325   // Note: The MockRead and MockWrite lists musts end with a EOF
326   //       e.g. a MockRead(true, 0, 0);
327   OrderedSocketData(const MockConnect& connect,
328                     MockRead* reads, size_t reads_count,
329                     MockWrite* writes, size_t writes_count);
330 
SetCompletionCallback(CompletionCallback * callback)331   void SetCompletionCallback(CompletionCallback* callback) {
332     callback_ = callback;
333   }
334 
335   // Posts a quit message to the current message loop, if one is running.
336   void EndLoop();
337 
338   // StaticSocketDataProvider:
339   virtual MockRead GetNextRead();
340   virtual MockWriteResult OnWrite(const std::string& data);
341   virtual void Reset();
342   virtual void CompleteRead();
343 
344  private:
345   friend class base::RefCounted<OrderedSocketData>;
346   virtual ~OrderedSocketData();
347 
348   int sequence_number_;
349   int loop_stop_stage_;
350   CompletionCallback* callback_;
351   bool blocked_;
352   ScopedRunnableMethodFactory<OrderedSocketData> factory_;
353 };
354 
355 class DeterministicMockTCPClientSocket;
356 
357 // This class gives the user full control over the network activity,
358 // specifically the timing of the COMPLETION of I/O operations.  Regardless of
359 // the order in which I/O operations are initiated, this class ensures that they
360 // complete in the correct order.
361 //
362 // Network activity is modeled as a sequence of numbered steps which is
363 // incremented whenever an I/O operation completes.  This can happen under two
364 // different circumstances:
365 //
366 // 1) Performing a synchronous I/O operation.  (Invoking Read() or Write()
367 //    when the corresponding MockRead or MockWrite is marked !async).
368 // 2) Running the Run() method of this class.  The run method will invoke
369 //    the current MessageLoop, running all pending events, and will then
370 //    invoke any pending IO callbacks.
371 //
372 // In addition, this class allows for I/O processing to "stop" at a specified
373 // step, by calling SetStop(int) or StopAfter(int).  Initiating an I/O operation
374 // by calling Read() or Write() while stopped is permitted if the operation is
375 // asynchronous.  It is an error to perform synchronous I/O while stopped.
376 //
377 // When creating the MockReads and MockWrites, note that the sequence number
378 // refers to the number of the step in which the I/O will complete.  In the
379 // case of synchronous I/O, this will be the same step as the I/O is initiated.
380 // However, in the case of asynchronous I/O, this I/O may be initiated in
381 // a much earlier step. Furthermore, when the a Read() or Write() is separated
382 // from its completion by other Read() or Writes()'s, it can not be marked
383 // synchronous.  If it is, ERR_UNUEXPECTED will be returned indicating that a
384 // synchronous Read() or Write() could not be completed synchronously because of
385 // the specific ordering constraints.
386 //
387 // Sequence numbers are preserved across both reads and writes. There should be
388 // no gaps in sequence numbers, and no repeated sequence numbers. i.e.
389 //  MockRead reads[] = {
390 //    MockRead(false, "first read", length, 0)   // sync
391 //    MockRead(true, "second read", length, 2)   // async
392 //  };
393 //  MockWrite writes[] = {
394 //    MockWrite(true, "first write", length, 1),    // async
395 //    MockWrite(false, "second write", length, 3),  // sync
396 //  };
397 //
398 // Example control flow:
399 // Read() is called.  The current step is 0.  The first available read is
400 // synchronous, so the call to Read() returns length.  The current step is
401 // now 1.  Next, Read() is called again.  The next available read can
402 // not be completed until step 2, so Read() returns ERR_IO_PENDING.  The current
403 // step is still 1.  Write is called().  The first available write is able to
404 // complete in this step, but is marked asynchronous.  Write() returns
405 // ERR_IO_PENDING.  The current step is still 1.  At this point RunFor(1) is
406 // called which will cause the write callback to be invoked, and will then
407 // stop.  The current state is now 2.  RunFor(1) is called again, which
408 // causes the read callback to be invoked, and will then stop.  Then current
409 // step is 2.  Write() is called again.  Then next available write is
410 // synchronous so the call to Write() returns length.
411 //
412 // For examples of how to use this class, see:
413 //   deterministic_socket_data_unittests.cc
414 class DeterministicSocketData : public StaticSocketDataProvider,
415     public base::RefCounted<DeterministicSocketData> {
416  public:
417   // |reads| the list of MockRead completions.
418   // |writes| the list of MockWrite completions.
419   DeterministicSocketData(MockRead* reads, size_t reads_count,
420                           MockWrite* writes, size_t writes_count);
421   virtual ~DeterministicSocketData();
422 
423   // Consume all the data up to the give stop point (via SetStop()).
424   void Run();
425 
426   // Set the stop point to be |steps| from now, and then invoke Run().
427   void RunFor(int steps);
428 
429   // Stop at step |seq|, which must be in the future.
430   virtual void SetStop(int seq);
431 
432   // Stop |seq| steps after the current step.
433   virtual void StopAfter(int seq);
stopped()434   bool stopped() const { return stopped_; }
SetStopped(bool val)435   void SetStopped(bool val) { stopped_ = val; }
current_read()436   MockRead& current_read() { return current_read_; }
current_write()437   MockRead& current_write() { return current_write_; }
sequence_number()438   int sequence_number() const { return sequence_number_; }
set_socket(base::WeakPtr<DeterministicMockTCPClientSocket> socket)439   void set_socket(base::WeakPtr<DeterministicMockTCPClientSocket> socket) {
440     socket_ = socket;
441   }
442 
443   // StaticSocketDataProvider:
444 
445   // When the socket calls Read(), that calls GetNextRead(), and expects either
446   // ERR_IO_PENDING or data.
447   virtual MockRead GetNextRead();
448 
449   // When the socket calls Write(), it always completes synchronously. OnWrite()
450   // checks to make sure the written data matches the expected data. The
451   // callback will not be invoked until its sequence number is reached.
452   virtual MockWriteResult OnWrite(const std::string& data);
453   virtual void Reset();
CompleteRead()454   virtual void CompleteRead() {}
455 
456  private:
457   // Invoke the read and write callbacks, if the timing is appropriate.
458   void InvokeCallbacks();
459 
460   void NextStep();
461 
462   int sequence_number_;
463   MockRead current_read_;
464   MockWrite current_write_;
465   int stopping_sequence_number_;
466   bool stopped_;
467   base::WeakPtr<DeterministicMockTCPClientSocket> socket_;
468   bool print_debug_;
469 };
470 
// Holds an array of data providers of type T.  As Mock{TCP,SSL}ClientSocket
// objects get instantiated, they take their data from the i'th element of this
// array, in the order the providers were added.
//
// The providers are not owned: only raw pointers are stored, and nothing here
// deletes them.
template<typename T>
class SocketDataProviderArray {
 public:
  SocketDataProviderArray() : next_index_(0) {
  }

  // Returns the next provider and advances the cursor.  Running past the end
  // means the test created more sockets than it supplied providers for.  This
  // is a CHECK rather than a DCHECK so that release-mode test runs fail loudly
  // instead of reading past the end of |data_providers_| (undefined behavior).
  T* GetNext() {
    CHECK_LT(next_index_, data_providers_.size());
    return data_providers_[next_index_++];
  }

  // Appends |data_provider|, which must be non-NULL.  Ownership is not taken.
  void Add(T* data_provider) {
    DCHECK(data_provider);
    data_providers_.push_back(data_provider);
  }

  // Makes the next GetNext() call start again from the first provider.
  void ResetNextIndex() {
    next_index_ = 0;
  }

 private:
  // Index of the next |data_providers_| element to use. Not an iterator
  // because those are invalidated on vector reallocation.
  size_t next_index_;

  // SocketDataProviders to be returned.
  std::vector<T*> data_providers_;
};
502 
503 class MockTCPClientSocket;
504 class MockSSLClientSocket;
505 
506 // ClientSocketFactory which contains arrays of sockets of each type.
507 // You should first fill the arrays using AddMock{SSL,}Socket. When the factory
508 // is asked to create a socket, it takes next entry from appropriate array.
509 // You can use ResetNextMockIndexes to reset that next entry index for all mock
510 // socket types.
511 class MockClientSocketFactory : public ClientSocketFactory {
512  public:
513   MockClientSocketFactory();
514   virtual ~MockClientSocketFactory();
515 
516   void AddSocketDataProvider(SocketDataProvider* socket);
517   void AddSSLSocketDataProvider(SSLSocketDataProvider* socket);
518   void ResetNextMockIndexes();
519 
520   // Return |index|-th MockTCPClientSocket (starting from 0) that the factory
521   // created.
522   MockTCPClientSocket* GetMockTCPClientSocket(size_t index) const;
523 
524   // Return |index|-th MockSSLClientSocket (starting from 0) that the factory
525   // created.
526   MockSSLClientSocket* GetMockSSLClientSocket(size_t index) const;
527 
mock_data()528   SocketDataProviderArray<SocketDataProvider>& mock_data() {
529     return mock_data_;
530   }
tcp_client_sockets()531   std::vector<MockTCPClientSocket*>& tcp_client_sockets() {
532     return tcp_client_sockets_;
533   }
534 
535   // ClientSocketFactory
536   virtual ClientSocket* CreateTransportClientSocket(
537       const AddressList& addresses,
538       NetLog* net_log,
539       const NetLog::Source& source);
540   virtual SSLClientSocket* CreateSSLClientSocket(
541       ClientSocketHandle* transport_socket,
542       const HostPortPair& host_and_port,
543       const SSLConfig& ssl_config,
544       SSLHostInfo* ssl_host_info,
545       CertVerifier* cert_verifier,
546       DnsCertProvenanceChecker* dns_cert_checker);
547   virtual void ClearSSLSessionCache();
548 
549  private:
550   SocketDataProviderArray<SocketDataProvider> mock_data_;
551   SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_;
552 
553   // Store pointers to handed out sockets in case the test wants to get them.
554   std::vector<MockTCPClientSocket*> tcp_client_sockets_;
555   std::vector<MockSSLClientSocket*> ssl_client_sockets_;
556 };
557 
558 class MockClientSocket : public net::SSLClientSocket {
559  public:
560   explicit MockClientSocket(net::NetLog* net_log);
561 
562   // If an async IO is pending because the SocketDataProvider returned
563   // ERR_IO_PENDING, then the MockClientSocket waits until this OnReadComplete
564   // is called to complete the asynchronous read operation.
565   // data.async is ignored, and this read is completed synchronously as
566   // part of this call.
567   virtual void OnReadComplete(const MockRead& data) = 0;
568 
569   // Socket methods:
570   virtual int Read(net::IOBuffer* buf, int buf_len,
571                    net::CompletionCallback* callback) = 0;
572   virtual int Write(net::IOBuffer* buf, int buf_len,
573                     net::CompletionCallback* callback) = 0;
574   virtual bool SetReceiveBufferSize(int32 size);
575   virtual bool SetSendBufferSize(int32 size);
576 
577   // ClientSocket methods:
578   virtual int Connect(net::CompletionCallback* callback) = 0;
579   virtual void Disconnect();
580   virtual bool IsConnected() const;
581   virtual bool IsConnectedAndIdle() const;
582   virtual int GetPeerAddress(AddressList* address) const;
583   virtual int GetLocalAddress(IPEndPoint* address) const;
584   virtual const BoundNetLog& NetLog() const;
SetSubresourceSpeculation()585   virtual void SetSubresourceSpeculation() {}
SetOmniboxSpeculation()586   virtual void SetOmniboxSpeculation() {}
587 
588   // SSLClientSocket methods:
589   virtual void GetSSLInfo(net::SSLInfo* ssl_info);
590   virtual void GetSSLCertRequestInfo(
591       net::SSLCertRequestInfo* cert_request_info);
592   virtual NextProtoStatus GetNextProto(std::string* proto);
593 
594  protected:
595   virtual ~MockClientSocket();
596   void RunCallbackAsync(net::CompletionCallback* callback, int result);
597   void RunCallback(net::CompletionCallback*, int result);
598 
599   ScopedRunnableMethodFactory<MockClientSocket> method_factory_;
600 
601   // True if Connect completed successfully and Disconnect hasn't been called.
602   bool connected_;
603 
604   net::BoundNetLog net_log_;
605 };
606 
607 class MockTCPClientSocket : public MockClientSocket {
608  public:
609   MockTCPClientSocket(const net::AddressList& addresses, net::NetLog* net_log,
610                       net::SocketDataProvider* socket);
611 
addresses()612   net::AddressList addresses() const { return addresses_; }
613 
614   // Socket methods:
615   virtual int Read(net::IOBuffer* buf, int buf_len,
616                    net::CompletionCallback* callback);
617   virtual int Write(net::IOBuffer* buf, int buf_len,
618                     net::CompletionCallback* callback);
619 
620   // ClientSocket methods:
621   virtual int Connect(net::CompletionCallback* callback);
622   virtual void Disconnect();
623   virtual bool IsConnected() const;
624   virtual bool IsConnectedAndIdle() const;
625   virtual int GetPeerAddress(AddressList* address) const;
626   virtual bool WasEverUsed() const;
627   virtual bool UsingTCPFastOpen() const;
628 
629   // MockClientSocket:
630   virtual void OnReadComplete(const MockRead& data);
631 
632  private:
633   int CompleteRead();
634 
635   net::AddressList addresses_;
636 
637   net::SocketDataProvider* data_;
638   int read_offset_;
639   net::MockRead read_data_;
640   bool need_read_data_;
641 
642   // True if the peer has closed the connection.  This allows us to simulate
643   // the recv(..., MSG_PEEK) call in the IsConnectedAndIdle method of the real
644   // TCPClientSocket.
645   bool peer_closed_connection_;
646 
647   // While an asynchronous IO is pending, we save our user-buffer state.
648   net::IOBuffer* pending_buf_;
649   int pending_buf_len_;
650   net::CompletionCallback* pending_callback_;
651   bool was_used_to_convey_data_;
652 };
653 
654 class DeterministicMockTCPClientSocket : public MockClientSocket,
655     public base::SupportsWeakPtr<DeterministicMockTCPClientSocket> {
656  public:
657   DeterministicMockTCPClientSocket(net::NetLog* net_log,
658       net::DeterministicSocketData* data);
659   virtual ~DeterministicMockTCPClientSocket();
660 
write_pending()661   bool write_pending() const { return write_pending_; }
read_pending()662   bool read_pending() const { return read_pending_; }
663 
664   void CompleteWrite();
665   int CompleteRead();
666 
667   // Socket:
668   virtual int Write(net::IOBuffer* buf, int buf_len,
669                     net::CompletionCallback* callback);
670   virtual int Read(net::IOBuffer* buf, int buf_len,
671                    net::CompletionCallback* callback);
672 
673   // ClientSocket:
674   virtual int Connect(net::CompletionCallback* callback);
675   virtual void Disconnect();
676   virtual bool IsConnected() const;
677   virtual bool IsConnectedAndIdle() const;
678   virtual bool WasEverUsed() const;
679   virtual bool UsingTCPFastOpen() const;
680 
681   // MockClientSocket:
682   virtual void OnReadComplete(const MockRead& data);
683 
684  private:
685   bool write_pending_;
686   net::CompletionCallback* write_callback_;
687   int write_result_;
688 
689   net::MockRead read_data_;
690 
691   net::IOBuffer* read_buf_;
692   int read_buf_len_;
693   bool read_pending_;
694   net::CompletionCallback* read_callback_;
695   net::DeterministicSocketData* data_;
696   bool was_used_to_convey_data_;
697 };
698 
699 class MockSSLClientSocket : public MockClientSocket {
700  public:
701   MockSSLClientSocket(
702       net::ClientSocketHandle* transport_socket,
703       const HostPortPair& host_and_port,
704       const net::SSLConfig& ssl_config,
705       SSLHostInfo* ssl_host_info,
706       net::SSLSocketDataProvider* socket);
707   virtual ~MockSSLClientSocket();
708 
709   // Socket methods:
710   virtual int Read(net::IOBuffer* buf, int buf_len,
711                    net::CompletionCallback* callback);
712   virtual int Write(net::IOBuffer* buf, int buf_len,
713                     net::CompletionCallback* callback);
714 
715   // ClientSocket methods:
716   virtual int Connect(net::CompletionCallback* callback);
717   virtual void Disconnect();
718   virtual bool IsConnected() const;
719   virtual bool WasEverUsed() const;
720   virtual bool UsingTCPFastOpen() const;
721 
722   // SSLClientSocket methods:
723   virtual void GetSSLInfo(net::SSLInfo* ssl_info);
724   virtual void GetSSLCertRequestInfo(
725       net::SSLCertRequestInfo* cert_request_info);
726   virtual NextProtoStatus GetNextProto(std::string* proto);
727   virtual bool was_npn_negotiated() const;
728   virtual bool set_was_npn_negotiated(bool negotiated);
729 
730   // This MockSocket does not implement the manual async IO feature.
731   virtual void OnReadComplete(const MockRead& data);
732 
733  private:
734   class ConnectCallback;
735 
736   scoped_ptr<ClientSocketHandle> transport_;
737   net::SSLSocketDataProvider* data_;
738   bool is_npn_state_set_;
739   bool new_npn_value_;
740   bool was_used_to_convey_data_;
741 };
742 
743 class TestSocketRequest : public CallbackRunner< Tuple1<int> > {
744  public:
745   TestSocketRequest(
746       std::vector<TestSocketRequest*>* request_order,
747       size_t* completion_count);
748   virtual ~TestSocketRequest();
749 
handle()750   ClientSocketHandle* handle() { return &handle_; }
751 
752   int WaitForResult();
753   virtual void RunWithParams(const Tuple1<int>& params);
754 
755  private:
756   ClientSocketHandle handle_;
757   std::vector<TestSocketRequest*>* request_order_;
758   size_t* completion_count_;
759   TestCompletionCallback callback_;
760 };
761 
762 class ClientSocketPoolTest {
763  public:
764   enum KeepAlive {
765     KEEP_ALIVE,
766 
767     // A socket will be disconnected in addition to handle being reset.
768     NO_KEEP_ALIVE,
769   };
770 
771   static const int kIndexOutOfBounds;
772   static const int kRequestNotFound;
773 
774   ClientSocketPoolTest();
775   ~ClientSocketPoolTest();
776 
777   template <typename PoolType, typename SocketParams>
StartRequestUsingPool(PoolType * socket_pool,const std::string & group_name,RequestPriority priority,const scoped_refptr<SocketParams> & socket_params)778   int StartRequestUsingPool(PoolType* socket_pool,
779                             const std::string& group_name,
780                             RequestPriority priority,
781                             const scoped_refptr<SocketParams>& socket_params) {
782     DCHECK(socket_pool);
783     TestSocketRequest* request = new TestSocketRequest(&request_order_,
784                                                        &completion_count_);
785     requests_.push_back(request);
786     int rv = request->handle()->Init(
787         group_name, socket_params, priority, request,
788         socket_pool, BoundNetLog());
789     if (rv != ERR_IO_PENDING)
790       request_order_.push_back(request);
791     return rv;
792   }
793 
794   // Provided there were n requests started, takes |index| in range 1..n
795   // and returns order in which that request completed, in range 1..n,
796   // or kIndexOutOfBounds if |index| is out of bounds, or kRequestNotFound
797   // if that request did not complete (for example was canceled).
798   int GetOrderOfRequest(size_t index) const;
799 
800   // Resets first initialized socket handle from |requests_|. If found such
801   // a handle, returns true.
802   bool ReleaseOneConnection(KeepAlive keep_alive);
803 
804   // Releases connections until there is nothing to release.
805   void ReleaseAllConnections(KeepAlive keep_alive);
806 
request(int i)807   TestSocketRequest* request(int i) { return requests_[i]; }
requests_size()808   size_t requests_size() const { return requests_.size(); }
requests()809   ScopedVector<TestSocketRequest>* requests() { return &requests_; }
completion_count()810   size_t completion_count() const { return completion_count_; }
811 
812  private:
813   ScopedVector<TestSocketRequest> requests_;
814   std::vector<TestSocketRequest*> request_order_;
815   size_t completion_count_;
816 };
817 
818 class MockTransportClientSocketPool : public TransportClientSocketPool {
819  public:
820   class MockConnectJob {
821    public:
822     MockConnectJob(ClientSocket* socket, ClientSocketHandle* handle,
823                    CompletionCallback* callback);
824     ~MockConnectJob();
825 
826     int Connect();
827     bool CancelHandle(const ClientSocketHandle* handle);
828 
829    private:
830     void OnConnect(int rv);
831 
832     scoped_ptr<ClientSocket> socket_;
833     ClientSocketHandle* handle_;
834     CompletionCallback* user_callback_;
835     CompletionCallbackImpl<MockConnectJob> connect_callback_;
836 
837     DISALLOW_COPY_AND_ASSIGN(MockConnectJob);
838   };
839 
840   MockTransportClientSocketPool(
841       int max_sockets,
842       int max_sockets_per_group,
843       ClientSocketPoolHistograms* histograms,
844       ClientSocketFactory* socket_factory);
845 
846   virtual ~MockTransportClientSocketPool();
847 
release_count()848   int release_count() const { return release_count_; }
cancel_count()849   int cancel_count() const { return cancel_count_; }
850 
851   // TransportClientSocketPool methods.
852   virtual int RequestSocket(const std::string& group_name,
853                             const void* socket_params,
854                             RequestPriority priority,
855                             ClientSocketHandle* handle,
856                             CompletionCallback* callback,
857                             const BoundNetLog& net_log);
858 
859   virtual void CancelRequest(const std::string& group_name,
860                              ClientSocketHandle* handle);
861   virtual void ReleaseSocket(const std::string& group_name,
862                              ClientSocket* socket, int id);
863 
864  private:
865   ClientSocketFactory* client_socket_factory_;
866   ScopedVector<MockConnectJob> job_list_;
867   int release_count_;
868   int cancel_count_;
869 
870   DISALLOW_COPY_AND_ASSIGN(MockTransportClientSocketPool);
871 };
872 
873 class DeterministicMockClientSocketFactory : public ClientSocketFactory {
874  public:
875   DeterministicMockClientSocketFactory();
876   virtual ~DeterministicMockClientSocketFactory();
877 
878   void AddSocketDataProvider(DeterministicSocketData* socket);
879   void AddSSLSocketDataProvider(SSLSocketDataProvider* socket);
880   void ResetNextMockIndexes();
881 
882   // Return |index|-th MockSSLClientSocket (starting from 0) that the factory
883   // created.
884   MockSSLClientSocket* GetMockSSLClientSocket(size_t index) const;
885 
mock_data()886   SocketDataProviderArray<DeterministicSocketData>& mock_data() {
887     return mock_data_;
888   }
tcp_client_sockets()889   std::vector<DeterministicMockTCPClientSocket*>& tcp_client_sockets() {
890     return tcp_client_sockets_;
891   }
892 
893   // ClientSocketFactory
894   virtual ClientSocket* CreateTransportClientSocket(
895       const AddressList& addresses,
896       NetLog* net_log,
897       const NetLog::Source& source);
898   virtual SSLClientSocket* CreateSSLClientSocket(
899       ClientSocketHandle* transport_socket,
900       const HostPortPair& host_and_port,
901       const SSLConfig& ssl_config,
902       SSLHostInfo* ssl_host_info,
903       CertVerifier* cert_verifier,
904       DnsCertProvenanceChecker* dns_cert_checker);
905   virtual void ClearSSLSessionCache();
906 
907  private:
908   SocketDataProviderArray<DeterministicSocketData> mock_data_;
909   SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_;
910 
911   // Store pointers to handed out sockets in case the test wants to get them.
912   std::vector<DeterministicMockTCPClientSocket*> tcp_client_sockets_;
913   std::vector<MockSSLClientSocket*> ssl_client_sockets_;
914 };
915 
916 class MockSOCKSClientSocketPool : public SOCKSClientSocketPool {
917  public:
918   MockSOCKSClientSocketPool(
919       int max_sockets,
920       int max_sockets_per_group,
921       ClientSocketPoolHistograms* histograms,
922       TransportClientSocketPool* transport_pool);
923 
924   virtual ~MockSOCKSClientSocketPool();
925 
926   // SOCKSClientSocketPool methods.
927   virtual int RequestSocket(const std::string& group_name,
928                             const void* socket_params,
929                             RequestPriority priority,
930                             ClientSocketHandle* handle,
931                             CompletionCallback* callback,
932                             const BoundNetLog& net_log);
933 
934   virtual void CancelRequest(const std::string& group_name,
935                              ClientSocketHandle* handle);
936   virtual void ReleaseSocket(const std::string& group_name,
937                              ClientSocket* socket, int id);
938 
939  private:
940   TransportClientSocketPool* const transport_pool_;
941 
942   DISALLOW_COPY_AND_ASSIGN(MockSOCKSClientSocketPool);
943 };
944 
// Canned byte sequences for a successful SOCKS v5 handshake.  Every byte
// array is declared next to its matching *Length constant.  Both are defined
// out of line.

// Greeting exchange: the request and the matching response.
extern const char kSOCKS5GreetRequest[];
extern const int kSOCKS5GreetRequestLength;
extern const char kSOCKS5GreetResponse[];
extern const int kSOCKS5GreetResponseLength;

// Successful ("Ok") exchange: the request and the matching response.
extern const char kSOCKS5OkRequest[];
extern const int kSOCKS5OkRequestLength;
extern const char kSOCKS5OkResponse[];
extern const int kSOCKS5OkResponseLength;
957 
958 }  // namespace net
959 
960 #endif  // NET_SOCKET_SOCKET_TEST_UTIL_H_
961