• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2010 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include <string>
6 #include <vector>
7 
8 #include "base/callback.h"
9 #include "base/utf_string_conversions.h"
10 #include "net/base/auth.h"
11 #include "net/base/mock_host_resolver.h"
12 #include "net/base/net_log.h"
13 #include "net/base/net_log_unittest.h"
14 #include "net/base/test_completion_callback.h"
15 #include "net/socket/socket_test_util.h"
16 #include "net/socket_stream/socket_stream.h"
17 #include "net/url_request/url_request_test_util.h"
18 #include "testing/gtest/include/gtest/gtest.h"
19 #include "testing/platform_test.h"
20 
21 struct SocketStreamEvent {
22   enum EventType {
23     EVENT_CONNECTED, EVENT_SENT_DATA, EVENT_RECEIVED_DATA, EVENT_CLOSE,
24     EVENT_AUTH_REQUIRED,
25   };
26 
SocketStreamEventSocketStreamEvent27   SocketStreamEvent(EventType type, net::SocketStream* socket_stream,
28                     int num, const std::string& str,
29                     net::AuthChallengeInfo* auth_challenge_info)
30       : event_type(type), socket(socket_stream), number(num), data(str),
31         auth_info(auth_challenge_info) {}
32 
33   EventType event_type;
34   net::SocketStream* socket;
35   int number;
36   std::string data;
37   scoped_refptr<net::AuthChallengeInfo> auth_info;
38 };
39 
40 class SocketStreamEventRecorder : public net::SocketStream::Delegate {
41  public:
SocketStreamEventRecorder(net::CompletionCallback * callback)42   explicit SocketStreamEventRecorder(net::CompletionCallback* callback)
43       : on_connected_(NULL),
44         on_sent_data_(NULL),
45         on_received_data_(NULL),
46         on_close_(NULL),
47         on_auth_required_(NULL),
48         callback_(callback) {}
~SocketStreamEventRecorder()49   virtual ~SocketStreamEventRecorder() {
50     delete on_connected_;
51     delete on_sent_data_;
52     delete on_received_data_;
53     delete on_close_;
54     delete on_auth_required_;
55   }
56 
SetOnConnected(Callback1<SocketStreamEvent * >::Type * callback)57   void SetOnConnected(Callback1<SocketStreamEvent*>::Type* callback) {
58     on_connected_ = callback;
59   }
SetOnSentData(Callback1<SocketStreamEvent * >::Type * callback)60   void SetOnSentData(Callback1<SocketStreamEvent*>::Type* callback) {
61     on_sent_data_ = callback;
62   }
SetOnReceivedData(Callback1<SocketStreamEvent * >::Type * callback)63   void SetOnReceivedData(Callback1<SocketStreamEvent*>::Type* callback) {
64     on_received_data_ = callback;
65   }
SetOnClose(Callback1<SocketStreamEvent * >::Type * callback)66   void SetOnClose(Callback1<SocketStreamEvent*>::Type* callback) {
67     on_close_ = callback;
68   }
SetOnAuthRequired(Callback1<SocketStreamEvent * >::Type * callback)69   void SetOnAuthRequired(Callback1<SocketStreamEvent*>::Type* callback) {
70     on_auth_required_ = callback;
71   }
72 
OnConnected(net::SocketStream * socket,int num_pending_send_allowed)73   virtual void OnConnected(net::SocketStream* socket,
74                            int num_pending_send_allowed) {
75     events_.push_back(
76         SocketStreamEvent(SocketStreamEvent::EVENT_CONNECTED,
77                           socket, num_pending_send_allowed, std::string(),
78                           NULL));
79     if (on_connected_)
80       on_connected_->Run(&events_.back());
81   }
OnSentData(net::SocketStream * socket,int amount_sent)82   virtual void OnSentData(net::SocketStream* socket,
83                           int amount_sent) {
84     events_.push_back(
85         SocketStreamEvent(SocketStreamEvent::EVENT_SENT_DATA,
86                           socket, amount_sent, std::string(), NULL));
87     if (on_sent_data_)
88       on_sent_data_->Run(&events_.back());
89   }
OnReceivedData(net::SocketStream * socket,const char * data,int len)90   virtual void OnReceivedData(net::SocketStream* socket,
91                               const char* data, int len) {
92     events_.push_back(
93         SocketStreamEvent(SocketStreamEvent::EVENT_RECEIVED_DATA,
94                           socket, len, std::string(data, len), NULL));
95     if (on_received_data_)
96       on_received_data_->Run(&events_.back());
97   }
OnClose(net::SocketStream * socket)98   virtual void OnClose(net::SocketStream* socket) {
99     events_.push_back(
100         SocketStreamEvent(SocketStreamEvent::EVENT_CLOSE,
101                           socket, 0, std::string(), NULL));
102     if (on_close_)
103       on_close_->Run(&events_.back());
104     if (callback_)
105       callback_->Run(net::OK);
106   }
OnAuthRequired(net::SocketStream * socket,net::AuthChallengeInfo * auth_info)107   virtual void OnAuthRequired(net::SocketStream* socket,
108                               net::AuthChallengeInfo* auth_info) {
109     events_.push_back(
110         SocketStreamEvent(SocketStreamEvent::EVENT_AUTH_REQUIRED,
111                           socket, 0, std::string(), auth_info));
112     if (on_auth_required_)
113       on_auth_required_->Run(&events_.back());
114   }
115 
DoClose(SocketStreamEvent * event)116   void DoClose(SocketStreamEvent* event) {
117     event->socket->Close();
118   }
DoRestartWithAuth(SocketStreamEvent * event)119   void DoRestartWithAuth(SocketStreamEvent* event) {
120     VLOG(1) << "RestartWithAuth username=" << username_
121             << " password=" << password_;
122     event->socket->RestartWithAuth(username_, password_);
123   }
SetAuthInfo(const string16 & username,const string16 & password)124   void SetAuthInfo(const string16& username,
125                    const string16& password) {
126     username_ = username;
127     password_ = password;
128   }
129 
GetSeenEvents() const130   const std::vector<SocketStreamEvent>& GetSeenEvents() const {
131     return events_;
132   }
133 
134  private:
135   std::vector<SocketStreamEvent> events_;
136   Callback1<SocketStreamEvent*>::Type* on_connected_;
137   Callback1<SocketStreamEvent*>::Type* on_sent_data_;
138   Callback1<SocketStreamEvent*>::Type* on_received_data_;
139   Callback1<SocketStreamEvent*>::Type* on_close_;
140   Callback1<SocketStreamEvent*>::Type* on_auth_required_;
141   net::CompletionCallback* callback_;
142 
143   string16 username_;
144   string16 password_;
145 
146   DISALLOW_COPY_AND_ASSIGN(SocketStreamEventRecorder);
147 };
148 
149 namespace net {
150 
151 class SocketStreamTest : public PlatformTest {
152  public:
~SocketStreamTest()153   virtual ~SocketStreamTest() {}
SetUp()154   virtual void SetUp() {
155     mock_socket_factory_.reset();
156     handshake_request_ = kWebSocketHandshakeRequest;
157     handshake_response_ = kWebSocketHandshakeResponse;
158   }
TearDown()159   virtual void TearDown() {
160     mock_socket_factory_.reset();
161   }
162 
SetWebSocketHandshakeMessage(const char * request,const char * response)163   virtual void SetWebSocketHandshakeMessage(
164       const char* request, const char* response) {
165     handshake_request_ = request;
166     handshake_response_ = response;
167   }
AddWebSocketMessage(const std::string & message)168   virtual void AddWebSocketMessage(const std::string& message) {
169     messages_.push_back(message);
170   }
171 
GetMockClientSocketFactory()172   virtual MockClientSocketFactory* GetMockClientSocketFactory() {
173     mock_socket_factory_.reset(new MockClientSocketFactory);
174     return mock_socket_factory_.get();
175   }
176 
DoSendWebSocketHandshake(SocketStreamEvent * event)177   virtual void DoSendWebSocketHandshake(SocketStreamEvent* event) {
178     event->socket->SendData(
179         handshake_request_.data(), handshake_request_.size());
180   }
181 
DoCloseFlushPendingWriteTest(SocketStreamEvent * event)182   virtual void DoCloseFlushPendingWriteTest(SocketStreamEvent* event) {
183     // handshake response received.
184     for (size_t i = 0; i < messages_.size(); i++) {
185       std::vector<char> frame;
186       frame.push_back('\0');
187       frame.insert(frame.end(), messages_[i].begin(), messages_[i].end());
188       frame.push_back('\xff');
189       EXPECT_TRUE(event->socket->SendData(&frame[0], frame.size()));
190     }
191     // Actual ClientSocket close must happen after all frames queued by
192     // SendData above are sent out.
193     event->socket->Close();
194   }
195 
196   static const char* kWebSocketHandshakeRequest;
197   static const char* kWebSocketHandshakeResponse;
198 
199  private:
200   std::string handshake_request_;
201   std::string handshake_response_;
202   std::vector<std::string> messages_;
203 
204   scoped_ptr<MockClientSocketFactory> mock_socket_factory_;
205 };
206 
207 const char* SocketStreamTest::kWebSocketHandshakeRequest =
208     "GET /demo HTTP/1.1\r\n"
209     "Host: example.com\r\n"
210     "Connection: Upgrade\r\n"
211     "Sec-WebSocket-Key2: 12998 5 Y3 1  .P00\r\n"
212     "Sec-WebSocket-Protocol: sample\r\n"
213     "Upgrade: WebSocket\r\n"
214     "Sec-WebSocket-Key1: 4 @1  46546xW%0l 1 5\r\n"
215     "Origin: http://example.com\r\n"
216     "\r\n"
217     "^n:ds[4U";
218 
219 const char* SocketStreamTest::kWebSocketHandshakeResponse =
220     "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
221     "Upgrade: WebSocket\r\n"
222     "Connection: Upgrade\r\n"
223     "Sec-WebSocket-Origin: http://example.com\r\n"
224     "Sec-WebSocket-Location: ws://example.com/demo\r\n"
225     "Sec-WebSocket-Protocol: sample\r\n"
226     "\r\n"
227     "8jKS'y:G*Co,Wxa-";
228 
// Closing the stream right after queuing frames must not drop them: both
// frames have to be written before the close is reported.
TEST_F(SocketStreamTest, CloseFlushPendingWrite) {
  TestCompletionCallback callback;

  scoped_ptr<SocketStreamEventRecorder> delegate(
      new SocketStreamEventRecorder(&callback));
  // Necessary for NewCallback.
  SocketStreamTest* test = this;
  // Send the handshake once connected. When the handshake response arrives,
  // queue both messages and close.
  delegate->SetOnConnected(NewCallback(
      test, &SocketStreamTest::DoSendWebSocketHandshake));
  delegate->SetOnReceivedData(NewCallback(
      test, &SocketStreamTest::DoCloseFlushPendingWriteTest));

  MockHostResolver host_resolver;

  scoped_refptr<SocketStream> socket_stream(
      new SocketStream(GURL("ws://example.com/demo"), delegate.get()));

  socket_stream->set_context(new TestURLRequestContext());
  socket_stream->SetHostResolver(&host_resolver);

  // Each message frame is 0x00 + "messageN" (8 bytes) + 0xFF = 10 bytes.
  MockWrite data_writes[] = {
    MockWrite(SocketStreamTest::kWebSocketHandshakeRequest),
    MockWrite(true, "\0message1\xff", 10),
    MockWrite(true, "\0message2\xff", 10)
  };
  MockRead data_reads[] = {
    MockRead(SocketStreamTest::kWebSocketHandshakeResponse),
    // Server doesn't close the connection after handshake.
    MockRead(true, ERR_IO_PENDING)
  };
  AddWebSocketMessage("message1");
  AddWebSocketMessage("message2");

  scoped_refptr<DelayedSocketData> data_provider(
      new DelayedSocketData(1,
                            data_reads, arraysize(data_reads),
                            data_writes, arraysize(data_writes)));

  MockClientSocketFactory* mock_socket_factory =
      GetMockClientSocketFactory();
  mock_socket_factory->AddSocketDataProvider(data_provider.get());

  socket_stream->SetClientSocketFactory(mock_socket_factory);

  socket_stream->Connect();

  // The recorder runs |callback| with OK from OnClose().
  EXPECT_EQ(OK, callback.WaitForResult());

  const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
  // Fatal check: a short event list must stop the test before the indexing
  // below reads past the end of |events|.
  ASSERT_EQ(6U, events.size());

  EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[0].event_type);
  EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[1].event_type);
  EXPECT_EQ(SocketStreamEvent::EVENT_RECEIVED_DATA, events[2].event_type);
  EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[3].event_type);
  EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[4].event_type);
  EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[5].event_type);
}
287 
TEST_F(SocketStreamTest,BasicAuthProxy)288 TEST_F(SocketStreamTest, BasicAuthProxy) {
289   MockClientSocketFactory mock_socket_factory;
290   MockWrite data_writes1[] = {
291     MockWrite("CONNECT example.com:80 HTTP/1.1\r\n"
292               "Host: example.com\r\n"
293               "Proxy-Connection: keep-alive\r\n\r\n"),
294   };
295   MockRead data_reads1[] = {
296     MockRead("HTTP/1.1 407 Proxy Authentication Required\r\n"),
297     MockRead("Proxy-Authenticate: Basic realm=\"MyRealm1\"\r\n"),
298     MockRead("\r\n"),
299   };
300   StaticSocketDataProvider data1(data_reads1, arraysize(data_reads1),
301                                  data_writes1, arraysize(data_writes1));
302   mock_socket_factory.AddSocketDataProvider(&data1);
303 
304   MockWrite data_writes2[] = {
305     MockWrite("CONNECT example.com:80 HTTP/1.1\r\n"
306               "Host: example.com\r\n"
307               "Proxy-Connection: keep-alive\r\n"
308               "Proxy-Authorization: Basic Zm9vOmJhcg==\r\n\r\n"),
309   };
310   MockRead data_reads2[] = {
311     MockRead("HTTP/1.1 200 Connection Established\r\n"),
312     MockRead("Proxy-agent: Apache/2.2.8\r\n"),
313     MockRead("\r\n"),
314     // SocketStream::DoClose is run asynchronously.  Socket can be read after
315     // "\r\n".  We have to give ERR_IO_PENDING to SocketStream then to indicate
316     // server doesn't close the connection.
317     MockRead(true, ERR_IO_PENDING)
318   };
319   StaticSocketDataProvider data2(data_reads2, arraysize(data_reads2),
320                                  data_writes2, arraysize(data_writes2));
321   mock_socket_factory.AddSocketDataProvider(&data2);
322 
323   TestCompletionCallback callback;
324 
325   scoped_ptr<SocketStreamEventRecorder> delegate(
326       new SocketStreamEventRecorder(&callback));
327   delegate->SetOnConnected(NewCallback(delegate.get(),
328                                        &SocketStreamEventRecorder::DoClose));
329   delegate->SetAuthInfo(ASCIIToUTF16("foo"), ASCIIToUTF16("bar"));
330   delegate->SetOnAuthRequired(
331       NewCallback(delegate.get(),
332                   &SocketStreamEventRecorder::DoRestartWithAuth));
333 
334   scoped_refptr<SocketStream> socket_stream(
335       new SocketStream(GURL("ws://example.com/demo"), delegate.get()));
336 
337   socket_stream->set_context(new TestURLRequestContext("myproxy:70"));
338   MockHostResolver host_resolver;
339   socket_stream->SetHostResolver(&host_resolver);
340   socket_stream->SetClientSocketFactory(&mock_socket_factory);
341 
342   socket_stream->Connect();
343 
344   callback.WaitForResult();
345 
346   const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
347   EXPECT_EQ(3U, events.size());
348 
349   EXPECT_EQ(SocketStreamEvent::EVENT_AUTH_REQUIRED, events[0].event_type);
350   EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[1].event_type);
351   EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[2].event_type);
352 
353   // TODO(eroman): Add back NetLogTest here...
354 }
355 
356 }  // namespace net
357