1 // Copyright (c) 2010 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include <string>
6 #include <vector>
7
8 #include "base/callback.h"
9 #include "base/utf_string_conversions.h"
10 #include "net/base/auth.h"
11 #include "net/base/mock_host_resolver.h"
12 #include "net/base/net_log.h"
13 #include "net/base/net_log_unittest.h"
14 #include "net/base/test_completion_callback.h"
15 #include "net/socket/socket_test_util.h"
16 #include "net/socket_stream/socket_stream.h"
17 #include "net/url_request/url_request_test_util.h"
18 #include "testing/gtest/include/gtest/gtest.h"
19 #include "testing/platform_test.h"
20
21 struct SocketStreamEvent {
22 enum EventType {
23 EVENT_CONNECTED, EVENT_SENT_DATA, EVENT_RECEIVED_DATA, EVENT_CLOSE,
24 EVENT_AUTH_REQUIRED,
25 };
26
SocketStreamEventSocketStreamEvent27 SocketStreamEvent(EventType type, net::SocketStream* socket_stream,
28 int num, const std::string& str,
29 net::AuthChallengeInfo* auth_challenge_info)
30 : event_type(type), socket(socket_stream), number(num), data(str),
31 auth_info(auth_challenge_info) {}
32
33 EventType event_type;
34 net::SocketStream* socket;
35 int number;
36 std::string data;
37 scoped_refptr<net::AuthChallengeInfo> auth_info;
38 };
39
40 class SocketStreamEventRecorder : public net::SocketStream::Delegate {
41 public:
SocketStreamEventRecorder(net::CompletionCallback * callback)42 explicit SocketStreamEventRecorder(net::CompletionCallback* callback)
43 : on_connected_(NULL),
44 on_sent_data_(NULL),
45 on_received_data_(NULL),
46 on_close_(NULL),
47 on_auth_required_(NULL),
48 callback_(callback) {}
~SocketStreamEventRecorder()49 virtual ~SocketStreamEventRecorder() {
50 delete on_connected_;
51 delete on_sent_data_;
52 delete on_received_data_;
53 delete on_close_;
54 delete on_auth_required_;
55 }
56
SetOnConnected(Callback1<SocketStreamEvent * >::Type * callback)57 void SetOnConnected(Callback1<SocketStreamEvent*>::Type* callback) {
58 on_connected_ = callback;
59 }
SetOnSentData(Callback1<SocketStreamEvent * >::Type * callback)60 void SetOnSentData(Callback1<SocketStreamEvent*>::Type* callback) {
61 on_sent_data_ = callback;
62 }
SetOnReceivedData(Callback1<SocketStreamEvent * >::Type * callback)63 void SetOnReceivedData(Callback1<SocketStreamEvent*>::Type* callback) {
64 on_received_data_ = callback;
65 }
SetOnClose(Callback1<SocketStreamEvent * >::Type * callback)66 void SetOnClose(Callback1<SocketStreamEvent*>::Type* callback) {
67 on_close_ = callback;
68 }
SetOnAuthRequired(Callback1<SocketStreamEvent * >::Type * callback)69 void SetOnAuthRequired(Callback1<SocketStreamEvent*>::Type* callback) {
70 on_auth_required_ = callback;
71 }
72
OnConnected(net::SocketStream * socket,int num_pending_send_allowed)73 virtual void OnConnected(net::SocketStream* socket,
74 int num_pending_send_allowed) {
75 events_.push_back(
76 SocketStreamEvent(SocketStreamEvent::EVENT_CONNECTED,
77 socket, num_pending_send_allowed, std::string(),
78 NULL));
79 if (on_connected_)
80 on_connected_->Run(&events_.back());
81 }
OnSentData(net::SocketStream * socket,int amount_sent)82 virtual void OnSentData(net::SocketStream* socket,
83 int amount_sent) {
84 events_.push_back(
85 SocketStreamEvent(SocketStreamEvent::EVENT_SENT_DATA,
86 socket, amount_sent, std::string(), NULL));
87 if (on_sent_data_)
88 on_sent_data_->Run(&events_.back());
89 }
OnReceivedData(net::SocketStream * socket,const char * data,int len)90 virtual void OnReceivedData(net::SocketStream* socket,
91 const char* data, int len) {
92 events_.push_back(
93 SocketStreamEvent(SocketStreamEvent::EVENT_RECEIVED_DATA,
94 socket, len, std::string(data, len), NULL));
95 if (on_received_data_)
96 on_received_data_->Run(&events_.back());
97 }
OnClose(net::SocketStream * socket)98 virtual void OnClose(net::SocketStream* socket) {
99 events_.push_back(
100 SocketStreamEvent(SocketStreamEvent::EVENT_CLOSE,
101 socket, 0, std::string(), NULL));
102 if (on_close_)
103 on_close_->Run(&events_.back());
104 if (callback_)
105 callback_->Run(net::OK);
106 }
OnAuthRequired(net::SocketStream * socket,net::AuthChallengeInfo * auth_info)107 virtual void OnAuthRequired(net::SocketStream* socket,
108 net::AuthChallengeInfo* auth_info) {
109 events_.push_back(
110 SocketStreamEvent(SocketStreamEvent::EVENT_AUTH_REQUIRED,
111 socket, 0, std::string(), auth_info));
112 if (on_auth_required_)
113 on_auth_required_->Run(&events_.back());
114 }
115
DoClose(SocketStreamEvent * event)116 void DoClose(SocketStreamEvent* event) {
117 event->socket->Close();
118 }
DoRestartWithAuth(SocketStreamEvent * event)119 void DoRestartWithAuth(SocketStreamEvent* event) {
120 VLOG(1) << "RestartWithAuth username=" << username_
121 << " password=" << password_;
122 event->socket->RestartWithAuth(username_, password_);
123 }
SetAuthInfo(const string16 & username,const string16 & password)124 void SetAuthInfo(const string16& username,
125 const string16& password) {
126 username_ = username;
127 password_ = password;
128 }
129
GetSeenEvents() const130 const std::vector<SocketStreamEvent>& GetSeenEvents() const {
131 return events_;
132 }
133
134 private:
135 std::vector<SocketStreamEvent> events_;
136 Callback1<SocketStreamEvent*>::Type* on_connected_;
137 Callback1<SocketStreamEvent*>::Type* on_sent_data_;
138 Callback1<SocketStreamEvent*>::Type* on_received_data_;
139 Callback1<SocketStreamEvent*>::Type* on_close_;
140 Callback1<SocketStreamEvent*>::Type* on_auth_required_;
141 net::CompletionCallback* callback_;
142
143 string16 username_;
144 string16 password_;
145
146 DISALLOW_COPY_AND_ASSIGN(SocketStreamEventRecorder);
147 };
148
149 namespace net {
150
151 class SocketStreamTest : public PlatformTest {
152 public:
~SocketStreamTest()153 virtual ~SocketStreamTest() {}
SetUp()154 virtual void SetUp() {
155 mock_socket_factory_.reset();
156 handshake_request_ = kWebSocketHandshakeRequest;
157 handshake_response_ = kWebSocketHandshakeResponse;
158 }
TearDown()159 virtual void TearDown() {
160 mock_socket_factory_.reset();
161 }
162
SetWebSocketHandshakeMessage(const char * request,const char * response)163 virtual void SetWebSocketHandshakeMessage(
164 const char* request, const char* response) {
165 handshake_request_ = request;
166 handshake_response_ = response;
167 }
AddWebSocketMessage(const std::string & message)168 virtual void AddWebSocketMessage(const std::string& message) {
169 messages_.push_back(message);
170 }
171
GetMockClientSocketFactory()172 virtual MockClientSocketFactory* GetMockClientSocketFactory() {
173 mock_socket_factory_.reset(new MockClientSocketFactory);
174 return mock_socket_factory_.get();
175 }
176
DoSendWebSocketHandshake(SocketStreamEvent * event)177 virtual void DoSendWebSocketHandshake(SocketStreamEvent* event) {
178 event->socket->SendData(
179 handshake_request_.data(), handshake_request_.size());
180 }
181
DoCloseFlushPendingWriteTest(SocketStreamEvent * event)182 virtual void DoCloseFlushPendingWriteTest(SocketStreamEvent* event) {
183 // handshake response received.
184 for (size_t i = 0; i < messages_.size(); i++) {
185 std::vector<char> frame;
186 frame.push_back('\0');
187 frame.insert(frame.end(), messages_[i].begin(), messages_[i].end());
188 frame.push_back('\xff');
189 EXPECT_TRUE(event->socket->SendData(&frame[0], frame.size()));
190 }
191 // Actual ClientSocket close must happen after all frames queued by
192 // SendData above are sent out.
193 event->socket->Close();
194 }
195
196 static const char* kWebSocketHandshakeRequest;
197 static const char* kWebSocketHandshakeResponse;
198
199 private:
200 std::string handshake_request_;
201 std::string handshake_response_;
202 std::vector<std::string> messages_;
203
204 scoped_ptr<MockClientSocketFactory> mock_socket_factory_;
205 };
206
// Default handshake request. SetUp() copies it into handshake_request_, which
// DoSendWebSocketHandshake() sends. CloseFlushPendingWrite also uses it as the
// expected bytes of the first MockWrite. The 8 bytes after the blank line form
// the request body (NOTE(review): looks like the draft-76 "key3" sample;
// confirm before editing). Keep this byte-for-byte.
const char* SocketStreamTest::kWebSocketHandshakeRequest =
    "GET /demo HTTP/1.1\r\n"
    "Host: example.com\r\n"
    "Connection: Upgrade\r\n"
    "Sec-WebSocket-Key2: 12998 5 Y3 1 .P00\r\n"
    "Sec-WebSocket-Protocol: sample\r\n"
    "Upgrade: WebSocket\r\n"
    "Sec-WebSocket-Key1: 4 @1 46546xW%0l 1 5\r\n"
    "Origin: http://example.com\r\n"
    "\r\n"
    "^n:ds[4U";
218
// Default handshake response. SetUp() copies it into handshake_response_, and
// CloseFlushPendingWrite returns it from the first MockRead. The 16 bytes after
// the blank line form the response body (NOTE(review): presumably the
// draft-76 challenge answer matching the request above; confirm). Keep this
// byte-for-byte.
const char* SocketStreamTest::kWebSocketHandshakeResponse =
    "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
    "Upgrade: WebSocket\r\n"
    "Connection: Upgrade\r\n"
    "Sec-WebSocket-Origin: http://example.com\r\n"
    "Sec-WebSocket-Location: ws://example.com/demo\r\n"
    "Sec-WebSocket-Protocol: sample\r\n"
    "\r\n"
    "8jKS'y:G*Co,Wxa-";
228
// After the handshake, the delegate queues two frames with SendData() and then
// calls Close(). Both frames must still be written (two EVENT_SENT_DATA) before
// EVENT_CLOSE is delivered.
TEST_F(SocketStreamTest, CloseFlushPendingWrite) {
  TestCompletionCallback callback;

  scoped_ptr<SocketStreamEventRecorder> delegate(
      new SocketStreamEventRecorder(&callback));
  // Necessary for NewCallback.
  SocketStreamTest* test = this;
  delegate->SetOnConnected(NewCallback(
      test, &SocketStreamTest::DoSendWebSocketHandshake));
  delegate->SetOnReceivedData(NewCallback(
      test, &SocketStreamTest::DoCloseFlushPendingWriteTest));

  MockHostResolver host_resolver;

  MockWrite data_writes[] = {
    MockWrite(SocketStreamTest::kWebSocketHandshakeRequest),
    MockWrite(true, "\0message1\xff", 10),
    MockWrite(true, "\0message2\xff", 10)
  };
  MockRead data_reads[] = {
    MockRead(SocketStreamTest::kWebSocketHandshakeResponse),
    // Server doesn't close the connection after handshake.
    MockRead(true, ERR_IO_PENDING)
  };
  AddWebSocketMessage("message1");
  AddWebSocketMessage("message2");

  scoped_refptr<DelayedSocketData> data_provider(
      new DelayedSocketData(1,
                            data_reads, arraysize(data_reads),
                            data_writes, arraysize(data_writes)));

  MockClientSocketFactory* mock_socket_factory =
      GetMockClientSocketFactory();
  mock_socket_factory->AddSocketDataProvider(data_provider.get());

  // Declared after the objects whose addresses it is handed (delegate, host
  // resolver, data provider), so this local reference is released before they
  // are destroyed.
  scoped_refptr<SocketStream> socket_stream(
      new SocketStream(GURL("ws://example.com/demo"), delegate.get()));

  socket_stream->set_context(new TestURLRequestContext());
  socket_stream->SetHostResolver(&host_resolver);
  socket_stream->SetClientSocketFactory(mock_socket_factory);

  socket_stream->Connect();

  callback.WaitForResult();

  const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
  // Fatal check: the lines below index events[0..5], which would read out of
  // bounds if fewer events were recorded.
  ASSERT_EQ(6U, events.size());

  EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[0].event_type);
  EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[1].event_type);
  EXPECT_EQ(SocketStreamEvent::EVENT_RECEIVED_DATA, events[2].event_type);
  EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[3].event_type);
  EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[4].event_type);
  EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[5].event_type);
}
287
// Connects through proxy "myproxy:70". The first CONNECT gets a 407 Basic
// challenge. The delegate restarts with foo/bar; the second CONNECT carries
// "Proxy-Authorization: Basic Zm9vOmJhcg==" (base64 of "foo:bar") and
// succeeds. On connect, the delegate closes the stream.
TEST_F(SocketStreamTest, BasicAuthProxy) {
  MockClientSocketFactory mock_socket_factory;
  MockWrite data_writes1[] = {
    MockWrite("CONNECT example.com:80 HTTP/1.1\r\n"
              "Host: example.com\r\n"
              "Proxy-Connection: keep-alive\r\n\r\n"),
  };
  MockRead data_reads1[] = {
    MockRead("HTTP/1.1 407 Proxy Authentication Required\r\n"),
    MockRead("Proxy-Authenticate: Basic realm=\"MyRealm1\"\r\n"),
    MockRead("\r\n"),
  };
  StaticSocketDataProvider data1(data_reads1, arraysize(data_reads1),
                                 data_writes1, arraysize(data_writes1));
  mock_socket_factory.AddSocketDataProvider(&data1);

  MockWrite data_writes2[] = {
    MockWrite("CONNECT example.com:80 HTTP/1.1\r\n"
              "Host: example.com\r\n"
              "Proxy-Connection: keep-alive\r\n"
              "Proxy-Authorization: Basic Zm9vOmJhcg==\r\n\r\n"),
  };
  MockRead data_reads2[] = {
    MockRead("HTTP/1.1 200 Connection Established\r\n"),
    MockRead("Proxy-agent: Apache/2.2.8\r\n"),
    MockRead("\r\n"),
    // SocketStream::DoClose is run asynchronously. Socket can be read after
    // "\r\n". We have to give ERR_IO_PENDING to SocketStream then to indicate
    // server doesn't close the connection.
    MockRead(true, ERR_IO_PENDING)
  };
  StaticSocketDataProvider data2(data_reads2, arraysize(data_reads2),
                                 data_writes2, arraysize(data_writes2));
  mock_socket_factory.AddSocketDataProvider(&data2);

  TestCompletionCallback callback;

  scoped_ptr<SocketStreamEventRecorder> delegate(
      new SocketStreamEventRecorder(&callback));
  delegate->SetOnConnected(NewCallback(delegate.get(),
                                       &SocketStreamEventRecorder::DoClose));
  delegate->SetAuthInfo(ASCIIToUTF16("foo"), ASCIIToUTF16("bar"));
  delegate->SetOnAuthRequired(
      NewCallback(delegate.get(),
                  &SocketStreamEventRecorder::DoRestartWithAuth));

  // Declared before socket_stream, which is handed its address via
  // SetHostResolver(). This way it is destroyed after the stream's local
  // reference, as in CloseFlushPendingWrite. Previously it was declared after
  // the stream and destroyed first.
  MockHostResolver host_resolver;

  scoped_refptr<SocketStream> socket_stream(
      new SocketStream(GURL("ws://example.com/demo"), delegate.get()));

  socket_stream->set_context(new TestURLRequestContext("myproxy:70"));
  socket_stream->SetHostResolver(&host_resolver);
  socket_stream->SetClientSocketFactory(&mock_socket_factory);

  socket_stream->Connect();

  callback.WaitForResult();

  const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
  // Fatal check: the lines below index events[0..2], which would read out of
  // bounds if fewer events were recorded.
  ASSERT_EQ(3U, events.size());

  EXPECT_EQ(SocketStreamEvent::EVENT_AUTH_REQUIRED, events[0].event_type);
  EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[1].event_type);
  EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[2].event_type);

  // TODO(eroman): Add back NetLogTest here...
}
355
356 } // namespace net
357