• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2009 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef NET_SOCKET_SOCKET_TEST_UTIL_H_
6 #define NET_SOCKET_SOCKET_TEST_UTIL_H_
7 
8 #include <deque>
9 #include <string>
10 #include <vector>
11 
12 #include "base/basictypes.h"
13 #include "base/logging.h"
14 #include "base/scoped_ptr.h"
15 #include "base/scoped_vector.h"
16 #include "net/base/address_list.h"
17 #include "net/base/io_buffer.h"
18 #include "net/base/net_errors.h"
19 #include "net/base/ssl_config_service.h"
20 #include "net/base/test_completion_callback.h"
21 #include "net/socket/client_socket_factory.h"
22 #include "net/socket/client_socket_handle.h"
23 #include "net/socket/ssl_client_socket.h"
24 #include "testing/gtest/include/gtest/gtest.h"
25 
26 namespace net {
27 
28 enum {
29   // A private network error code used by the socket test utility classes.
30   // If the |result| member of a MockRead is
31   // ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ, that MockRead is just a
32   // marker that indicates the peer will close the connection after the next
33   // MockRead.  The other members of that MockRead are ignored.
34   ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ = -10000,
35 };
36 
37 class ClientSocket;
38 class LoadLog;
39 class MockClientSocket;
40 class SSLClientSocket;
41 
42 struct MockConnect {
43   // Asynchronous connection success.
MockConnectMockConnect44   MockConnect() : async(true), result(OK) { }
MockConnectMockConnect45   MockConnect(bool a, int r) : async(a), result(r) { }
46 
47   bool async;
48   int result;
49 };
50 
51 struct MockRead {
52   // Default
MockReadMockRead53   MockRead() : async(false), result(0), data(NULL), data_len(0) {}
54 
55   // Read failure (no data).
MockReadMockRead56   MockRead(bool async, int result) : async(async) , result(result), data(NULL),
57       data_len(0) { }
58 
59   // Asynchronous read success (inferred data length).
MockReadMockRead60   explicit MockRead(const char* data) : async(true),  result(0), data(data),
61       data_len(strlen(data)) { }
62 
63   // Read success (inferred data length).
MockReadMockRead64   MockRead(bool async, const char* data) : async(async), result(0), data(data),
65       data_len(strlen(data)) { }
66 
67   // Read success.
MockReadMockRead68   MockRead(bool async, const char* data, int data_len) : async(async),
69       result(0), data(data), data_len(data_len) { }
70 
71   bool async;
72   int result;
73   const char* data;
74   int data_len;
75 };
76 
77 // MockWrite uses the same member fields as MockRead, but with different
78 // meanings. The expected input to MockTCPClientSocket::Write() is given
79 // by {data, data_len}, and the return value of Write() is controlled by
80 // {async, result}.
81 typedef MockRead MockWrite;
82 
83 struct MockWriteResult {
MockWriteResultMockWriteResult84   MockWriteResult(bool async, int result) : async(async), result(result) {}
85 
86   bool async;
87   int result;
88 };
89 
90 // The SocketDataProvider is an interface used by the MockClientSocket
91 // for getting data about individual reads and writes on the socket.
92 class SocketDataProvider {
93  public:
SocketDataProvider()94   SocketDataProvider() : socket_(NULL) {}
95 
~SocketDataProvider()96   virtual ~SocketDataProvider() {}
97 
98   // Returns the buffer and result code for the next simulated read.
99   // If the |MockRead.result| is ERR_IO_PENDING, it informs the caller
100   // that it will be called via the MockClientSocket::OnReadComplete()
101   // function at a later time.
102   virtual MockRead GetNextRead() = 0;
103   virtual MockWriteResult OnWrite(const std::string& data) = 0;
104   virtual void Reset() = 0;
105 
106   // Accessor for the socket which is using the SocketDataProvider.
socket()107   MockClientSocket* socket() { return socket_; }
set_socket(MockClientSocket * socket)108   void set_socket(MockClientSocket* socket) { socket_ = socket; }
109 
connect_data()110   MockConnect connect_data() const { return connect_; }
set_connect_data(const MockConnect & connect)111   void set_connect_data(const MockConnect& connect) { connect_ = connect; }
112 
113  private:
114   MockConnect connect_;
115   MockClientSocket* socket_;
116 
117   DISALLOW_COPY_AND_ASSIGN(SocketDataProvider);
118 };
119 
120 // SocketDataProvider which responds based on static tables of mock reads and
121 // writes.
122 class StaticSocketDataProvider : public SocketDataProvider {
123  public:
StaticSocketDataProvider()124   StaticSocketDataProvider() : reads_(NULL), read_index_(0),
125       writes_(NULL), write_index_(0) {}
StaticSocketDataProvider(MockRead * r,MockWrite * w)126   StaticSocketDataProvider(MockRead* r, MockWrite* w) : reads_(r),
127       read_index_(0), writes_(w), write_index_(0) {}
128 
129   // SocketDataProvider methods:
130   virtual MockRead GetNextRead();
131   virtual MockWriteResult OnWrite(const std::string& data);
132   virtual void Reset();
133 
134   // If the test wishes to verify that all data is consumed, it can include
135   // a EOF MockRead or MockWrite, which is a zero-length Read or Write.
136   // The test can then call at_read_eof() or at_write_eof() to verify that
137   // all data has been consumed.
at_read_eof()138   bool at_read_eof() const { return reads_[read_index_].data_len == 0; }
at_write_eof()139   bool at_write_eof() const { return writes_[write_index_].data_len == 0; }
140 
141  private:
142   MockRead* reads_;
143   int read_index_;
144   MockWrite* writes_;
145   int write_index_;
146 
147   DISALLOW_COPY_AND_ASSIGN(StaticSocketDataProvider);
148 };
149 
150 // SocketDataProvider which can make decisions about next mock reads based on
151 // received writes. It can also be used to enforce order of operations, for
152 // example that tested code must send the "Hello!" message before receiving
153 // response. This is useful for testing conversation-like protocols like FTP.
154 class DynamicSocketDataProvider : public SocketDataProvider {
155  public:
156   DynamicSocketDataProvider();
157 
158   // SocketDataProvider methods:
159   virtual MockRead GetNextRead();
160   virtual MockWriteResult OnWrite(const std::string& data) = 0;
161   virtual void Reset();
162 
short_read_limit()163   int short_read_limit() const { return short_read_limit_; }
set_short_read_limit(int limit)164   void set_short_read_limit(int limit) { short_read_limit_ = limit; }
165 
allow_unconsumed_reads(bool allow)166   void allow_unconsumed_reads(bool allow) { allow_unconsumed_reads_ = allow; }
167 
168  protected:
169   // The next time there is a read from this socket, it will return |data|.
170   // Before calling SimulateRead next time, the previous data must be consumed.
171   void SimulateRead(const char* data);
172 
173  private:
174   std::deque<MockRead> reads_;
175 
176   // Max number of bytes we will read at a time. 0 means no limit.
177   int short_read_limit_;
178 
179   // If true, we'll not require the client to consume all data before we
180   // mock the next read.
181   bool allow_unconsumed_reads_;
182 
183   DISALLOW_COPY_AND_ASSIGN(DynamicSocketDataProvider);
184 };
185 
186 // SSLSocketDataProviders only need to keep track of the return code from calls
187 // to Connect().
188 struct SSLSocketDataProvider {
SSLSocketDataProviderSSLSocketDataProvider189   SSLSocketDataProvider(bool async, int result) : connect(async, result) { }
190 
191   MockConnect connect;
192 };
193 
194 // Holds an array of SocketDataProvider elements.  As Mock{TCP,SSL}ClientSocket
195 // objects get instantiated, they take their data from the i'th element of this
196 // array.
197 template<typename T>
198 class SocketDataProviderArray {
199  public:
SocketDataProviderArray()200   SocketDataProviderArray() : next_index_(0) {
201   }
202 
GetNext()203   T* GetNext() {
204     DCHECK(next_index_ < data_providers_.size());
205     return data_providers_[next_index_++];
206   }
207 
Add(T * data_provider)208   void Add(T* data_provider) {
209     DCHECK(data_provider);
210     data_providers_.push_back(data_provider);
211   }
212 
ResetNextIndex()213   void ResetNextIndex() {
214     next_index_ = 0;
215   }
216 
217  private:
218   // Index of the next |data_providers_| element to use. Not an iterator
219   // because those are invalidated on vector reallocation.
220   size_t next_index_;
221 
222   // SocketDataProviders to be returned.
223   std::vector<T*> data_providers_;
224 };
225 
226 class MockTCPClientSocket;
227 class MockSSLClientSocket;
228 
229 // ClientSocketFactory which contains arrays of sockets of each type.
230 // You should first fill the arrays using AddMock{SSL,}Socket. When the factory
231 // is asked to create a socket, it takes next entry from appropriate array.
232 // You can use ResetNextMockIndexes to reset that next entry index for all mock
233 // socket types.
234 class MockClientSocketFactory : public ClientSocketFactory {
235  public:
236   void AddSocketDataProvider(SocketDataProvider* socket);
237   void AddSSLSocketDataProvider(SSLSocketDataProvider* socket);
238   void ResetNextMockIndexes();
239 
240   // Return |index|-th MockTCPClientSocket (starting from 0) that the factory
241   // created.
242   MockTCPClientSocket* GetMockTCPClientSocket(int index) const;
243 
244   // Return |index|-th MockSSLClientSocket (starting from 0) that the factory
245   // created.
246   MockSSLClientSocket* GetMockSSLClientSocket(int index) const;
247 
248   // ClientSocketFactory
249   virtual ClientSocket* CreateTCPClientSocket(const AddressList& addresses);
250   virtual SSLClientSocket* CreateSSLClientSocket(
251       ClientSocket* transport_socket,
252       const std::string& hostname,
253       const SSLConfig& ssl_config);
254 
255  private:
256   SocketDataProviderArray<SocketDataProvider> mock_data_;
257   SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_;
258 
259   // Store pointers to handed out sockets in case the test wants to get them.
260   std::vector<MockTCPClientSocket*> tcp_client_sockets_;
261   std::vector<MockSSLClientSocket*> ssl_client_sockets_;
262 };
263 
264 class MockClientSocket : public net::SSLClientSocket {
265  public:
266   MockClientSocket();
267 
268   // ClientSocket methods:
269   virtual int Connect(net::CompletionCallback* callback, LoadLog* load_log) = 0;
270   virtual void Disconnect();
271   virtual bool IsConnected() const;
272   virtual bool IsConnectedAndIdle() const;
273   virtual int GetPeerName(struct sockaddr* name, socklen_t* namelen);
274 
275   // SSLClientSocket methods:
276   virtual void GetSSLInfo(net::SSLInfo* ssl_info);
277   virtual void GetSSLCertRequestInfo(
278       net::SSLCertRequestInfo* cert_request_info);
279   virtual NextProtoStatus GetNextProto(std::string* proto);
280 
281   // Socket methods:
282   virtual int Read(net::IOBuffer* buf, int buf_len,
283                    net::CompletionCallback* callback) = 0;
284   virtual int Write(net::IOBuffer* buf, int buf_len,
285                     net::CompletionCallback* callback) = 0;
SetReceiveBufferSize(int32 size)286   virtual bool SetReceiveBufferSize(int32 size) { return true; }
SetSendBufferSize(int32 size)287   virtual bool SetSendBufferSize(int32 size) { return true; }
288 
289   // If an async IO is pending because the SocketDataProvider returned
290   // ERR_IO_PENDING, then the MockClientSocket waits until this OnReadComplete
291   // is called to complete the asynchronous read operation.
292   // data.async is ignored, and this read is completed synchronously as
293   // part of this call.
294   virtual void OnReadComplete(const MockRead& data) = 0;
295 
296  protected:
297   void RunCallbackAsync(net::CompletionCallback* callback, int result);
298   void RunCallback(net::CompletionCallback*, int result);
299 
300   ScopedRunnableMethodFactory<MockClientSocket> method_factory_;
301 
302   // True if Connect completed successfully and Disconnect hasn't been called.
303   bool connected_;
304 };
305 
306 class MockTCPClientSocket : public MockClientSocket {
307  public:
308   MockTCPClientSocket(const net::AddressList& addresses,
309                       net::SocketDataProvider* socket);
310 
311   // ClientSocket methods:
312   virtual int Connect(net::CompletionCallback* callback,
313                       LoadLog* load_log);
314   virtual bool IsConnected() const;
IsConnectedAndIdle()315   virtual bool IsConnectedAndIdle() const { return IsConnected(); }
316 
317   // Socket methods:
318   virtual int Read(net::IOBuffer* buf, int buf_len,
319                    net::CompletionCallback* callback);
320   virtual int Write(net::IOBuffer* buf, int buf_len,
321                     net::CompletionCallback* callback);
322 
323   virtual void OnReadComplete(const MockRead& data);
324 
addresses()325   net::AddressList addresses() const { return addresses_; }
326 
327  private:
328   int CompleteRead();
329 
330   net::AddressList addresses_;
331 
332   net::SocketDataProvider* data_;
333   int read_offset_;
334   net::MockRead read_data_;
335   bool need_read_data_;
336 
337   // True if the peer has closed the connection.  This allows us to simulate
338   // the recv(..., MSG_PEEK) call in the IsConnectedAndIdle method of the real
339   // TCPClientSocket.
340   bool peer_closed_connection_;
341 
342   // While an asynchronous IO is pending, we save our user-buffer state.
343   net::IOBuffer* pending_buf_;
344   int pending_buf_len_;
345   net::CompletionCallback* pending_callback_;
346 };
347 
348 class MockSSLClientSocket : public MockClientSocket {
349  public:
350   MockSSLClientSocket(
351       net::ClientSocket* transport_socket,
352       const std::string& hostname,
353       const net::SSLConfig& ssl_config,
354       net::SSLSocketDataProvider* socket);
355   ~MockSSLClientSocket();
356 
357   virtual void GetSSLInfo(net::SSLInfo* ssl_info);
358 
359   virtual int Connect(net::CompletionCallback* callback, LoadLog* load_log);
360   virtual void Disconnect();
361 
362   // Socket methods:
363   virtual int Read(net::IOBuffer* buf, int buf_len,
364                    net::CompletionCallback* callback);
365   virtual int Write(net::IOBuffer* buf, int buf_len,
366                     net::CompletionCallback* callback);
367 
368   // This MockSocket does not implement the manual async IO feature.
OnReadComplete(const MockRead & data)369   virtual void OnReadComplete(const MockRead& data) { NOTIMPLEMENTED(); }
370 
371  private:
372   class ConnectCallback;
373 
374   scoped_ptr<ClientSocket> transport_;
375   net::SSLSocketDataProvider* data_;
376 };
377 
378 class TestSocketRequest : public CallbackRunner< Tuple1<int> > {
379  public:
TestSocketRequest(std::vector<TestSocketRequest * > * request_order,size_t * completion_count)380   TestSocketRequest(
381       std::vector<TestSocketRequest*>* request_order,
382       size_t* completion_count)
383       : request_order_(request_order),
384         completion_count_(completion_count) {
385     DCHECK(request_order);
386     DCHECK(completion_count);
387   }
388 
handle()389   ClientSocketHandle* handle() { return &handle_; }
390 
391   int WaitForResult();
392   virtual void RunWithParams(const Tuple1<int>& params);
393 
394  private:
395   ClientSocketHandle handle_;
396   std::vector<TestSocketRequest*>* request_order_;
397   size_t* completion_count_;
398   TestCompletionCallback callback_;
399 };
400 
401 class ClientSocketPoolTest : public testing::Test {
402  protected:
403   enum KeepAlive {
404     KEEP_ALIVE,
405 
406     // A socket will be disconnected in addition to handle being reset.
407     NO_KEEP_ALIVE,
408   };
409 
410   static const int kIndexOutOfBounds;
411   static const int kRequestNotFound;
412 
413   virtual void SetUp();
414   virtual void TearDown();
415 
416   template <typename PoolType, typename SocketParams>
StartRequestUsingPool(PoolType * socket_pool,const std::string & group_name,RequestPriority priority,const SocketParams & socket_params)417   int StartRequestUsingPool(PoolType* socket_pool,
418                             const std::string& group_name,
419                             RequestPriority priority,
420                             const SocketParams& socket_params) {
421     DCHECK(socket_pool);
422     TestSocketRequest* request = new TestSocketRequest(&request_order_,
423                                                        &completion_count_);
424     requests_.push_back(request);
425     int rv = request->handle()->Init(
426         group_name, socket_params, priority, request,
427         socket_pool, NULL);
428     if (rv != ERR_IO_PENDING)
429       request_order_.push_back(request);
430     return rv;
431   }
432 
433   // Provided there were n requests started, takes |index| in range 1..n
434   // and returns order in which that request completed, in range 1..n,
435   // or kIndexOutOfBounds if |index| is out of bounds, or kRequestNotFound
436   // if that request did not complete (for example was canceled).
437   int GetOrderOfRequest(size_t index);
438 
439   // Resets first initialized socket handle from |requests_|. If found such
440   // a handle, returns true.
441   bool ReleaseOneConnection(KeepAlive keep_alive);
442 
443   // Releases connections until there is nothing to release.
444   void ReleaseAllConnections(KeepAlive keep_alive);
445 
446   ScopedVector<TestSocketRequest> requests_;
447   std::vector<TestSocketRequest*> request_order_;
448   size_t completion_count_;
449 };
450 
451 }  // namespace net
452 
453 #endif  // NET_SOCKET_SOCKET_TEST_UTIL_H_
454