1 // Copyright (c) 2009 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include <string>
6 #include <vector>
7
8 #include "base/callback.h"
9 #include "net/base/completion_callback.h"
10 #include "net/base/io_buffer.h"
11 #include "net/base/mock_host_resolver.h"
12 #include "net/base/test_completion_callback.h"
13 #include "net/socket/socket_test_util.h"
14 #include "net/url_request/url_request_test_util.h"
15 #include "net/websockets/websocket.h"
16 #include "testing/gtest/include/gtest/gtest.h"
17 #include "testing/gmock/include/gmock/gmock.h"
18 #include "testing/platform_test.h"
19
20 struct WebSocketEvent {
21 enum EventType {
22 EVENT_OPEN, EVENT_MESSAGE, EVENT_ERROR, EVENT_CLOSE,
23 };
24
WebSocketEventWebSocketEvent25 WebSocketEvent(EventType type, net::WebSocket* websocket,
26 const std::string& websocket_msg, bool websocket_flag)
27 : event_type(type), socket(websocket), msg(websocket_msg),
28 flag(websocket_flag) {}
29
30 EventType event_type;
31 net::WebSocket* socket;
32 std::string msg;
33 bool flag;
34 };
35
36 class WebSocketEventRecorder : public net::WebSocketDelegate {
37 public:
WebSocketEventRecorder(net::CompletionCallback * callback)38 explicit WebSocketEventRecorder(net::CompletionCallback* callback)
39 : onopen_(NULL),
40 onmessage_(NULL),
41 onerror_(NULL),
42 onclose_(NULL),
43 callback_(callback) {}
~WebSocketEventRecorder()44 virtual ~WebSocketEventRecorder() {
45 delete onopen_;
46 delete onmessage_;
47 delete onerror_;
48 delete onclose_;
49 }
50
SetOnOpen(Callback1<WebSocketEvent * >::Type * callback)51 void SetOnOpen(Callback1<WebSocketEvent*>::Type* callback) {
52 onopen_ = callback;
53 }
SetOnMessage(Callback1<WebSocketEvent * >::Type * callback)54 void SetOnMessage(Callback1<WebSocketEvent*>::Type* callback) {
55 onmessage_ = callback;
56 }
SetOnClose(Callback1<WebSocketEvent * >::Type * callback)57 void SetOnClose(Callback1<WebSocketEvent*>::Type* callback) {
58 onclose_ = callback;
59 }
60
OnOpen(net::WebSocket * socket)61 virtual void OnOpen(net::WebSocket* socket) {
62 events_.push_back(
63 WebSocketEvent(WebSocketEvent::EVENT_OPEN, socket,
64 std::string(), false));
65 if (onopen_)
66 onopen_->Run(&events_.back());
67 }
68
OnMessage(net::WebSocket * socket,const std::string & msg)69 virtual void OnMessage(net::WebSocket* socket, const std::string& msg) {
70 events_.push_back(
71 WebSocketEvent(WebSocketEvent::EVENT_MESSAGE, socket, msg, false));
72 if (onmessage_)
73 onmessage_->Run(&events_.back());
74 }
OnError(net::WebSocket * socket)75 virtual void OnError(net::WebSocket* socket) {
76 events_.push_back(
77 WebSocketEvent(WebSocketEvent::EVENT_ERROR, socket,
78 std::string(), false));
79 if (onerror_)
80 onerror_->Run(&events_.back());
81 }
OnClose(net::WebSocket * socket,bool was_clean)82 virtual void OnClose(net::WebSocket* socket, bool was_clean) {
83 events_.push_back(
84 WebSocketEvent(WebSocketEvent::EVENT_CLOSE, socket,
85 std::string(), was_clean));
86 if (onclose_)
87 onclose_->Run(&events_.back());
88 if (callback_)
89 callback_->Run(net::OK);
90 }
91
DoClose(WebSocketEvent * event)92 void DoClose(WebSocketEvent* event) {
93 event->socket->Close();
94 }
95
GetSeenEvents() const96 const std::vector<WebSocketEvent>& GetSeenEvents() const {
97 return events_;
98 }
99
100 private:
101 std::vector<WebSocketEvent> events_;
102 Callback1<WebSocketEvent*>::Type* onopen_;
103 Callback1<WebSocketEvent*>::Type* onmessage_;
104 Callback1<WebSocketEvent*>::Type* onerror_;
105 Callback1<WebSocketEvent*>::Type* onclose_;
106 net::CompletionCallback* callback_;
107
108 DISALLOW_COPY_AND_ASSIGN(WebSocketEventRecorder);
109 };
110
111 namespace net {
112
113 class WebSocketTest : public PlatformTest {
114 protected:
InitReadBuf(WebSocket * websocket)115 void InitReadBuf(WebSocket* websocket) {
116 // Set up |current_read_buf_|.
117 websocket->current_read_buf_ = new GrowableIOBuffer();
118 }
SetReadConsumed(WebSocket * websocket,int consumed)119 void SetReadConsumed(WebSocket* websocket, int consumed) {
120 websocket->read_consumed_len_ = consumed;
121 }
AddToReadBuf(WebSocket * websocket,const char * data,int len)122 void AddToReadBuf(WebSocket* websocket, const char* data, int len) {
123 websocket->AddToReadBuffer(data, len);
124 }
125
TestProcessFrameData(WebSocket * websocket,const char * expected_remaining_data,int expected_remaining_len)126 void TestProcessFrameData(WebSocket* websocket,
127 const char* expected_remaining_data,
128 int expected_remaining_len) {
129 websocket->ProcessFrameData();
130
131 const char* actual_remaining_data =
132 websocket->current_read_buf_->StartOfBuffer()
133 + websocket->read_consumed_len_;
134 int actual_remaining_len =
135 websocket->current_read_buf_->offset() - websocket->read_consumed_len_;
136
137 EXPECT_EQ(expected_remaining_len, actual_remaining_len);
138 EXPECT_TRUE(!memcmp(expected_remaining_data, actual_remaining_data,
139 expected_remaining_len));
140 }
141 };
142
// Completes a handshake against a mock server, closes the socket from the
// OnOpen handler, and expects exactly an open event followed by a close event.
TEST_F(WebSocketTest, Connect) {
  MockClientSocketFactory mock_socket_factory;
  MockRead data_reads[] = {
    MockRead("HTTP/1.1 101 Web Socket Protocol Handshake\r\n"
             "Upgrade: WebSocket\r\n"
             "Connection: Upgrade\r\n"
             "WebSocket-Origin: http://example.com\r\n"
             "WebSocket-Location: ws://example.com/demo\r\n"
             "WebSocket-Protocol: sample\r\n"
             "\r\n"),
    // Server doesn't close the connection after handshake.
    MockRead(true, ERR_IO_PENDING),
  };
  MockWrite data_writes[] = {
    MockWrite("GET /demo HTTP/1.1\r\n"
              "Upgrade: WebSocket\r\n"
              "Connection: Upgrade\r\n"
              "Host: example.com\r\n"
              "Origin: http://example.com\r\n"
              "WebSocket-Protocol: sample\r\n"
              "\r\n"),
  };
  StaticSocketDataProvider data(data_reads, arraysize(data_reads),
                                data_writes, arraysize(data_writes));
  mock_socket_factory.AddSocketDataProvider(&data);
  MockHostResolver host_resolver;

  // NOTE(review): |request| is handed to WebSocket below, presumably taking
  // ownership -- confirm against websocket.h.
  WebSocket::Request* request(
      new WebSocket::Request(GURL("ws://example.com/demo"),
                             "sample",
                             "http://example.com",
                             "ws://example.com/demo",
                             WebSocket::DRAFT75,
                             new TestURLRequestContext()));
  request->SetHostResolver(&host_resolver);
  request->SetClientSocketFactory(&mock_socket_factory);

  TestCompletionCallback callback;

  scoped_ptr<WebSocketEventRecorder> delegate(
      new WebSocketEventRecorder(&callback));
  // Close as soon as the connection opens.
  delegate->SetOnOpen(NewCallback(delegate.get(),
                                  &WebSocketEventRecorder::DoClose));

  scoped_refptr<WebSocket> websocket(
      new WebSocket(request, delegate.get()));

  EXPECT_EQ(WebSocket::INITIALIZED, websocket->ready_state());
  websocket->Connect();

  // The recorder runs |callback| with OK from its OnClose.
  EXPECT_EQ(OK, callback.WaitForResult());

  const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
  // Fatal so that a short vector is never indexed below.
  ASSERT_EQ(2U, events.size());

  EXPECT_EQ(WebSocketEvent::EVENT_OPEN, events[0].event_type);
  EXPECT_EQ(WebSocketEvent::EVENT_CLOSE, events[1].event_type);
}
201
// After the handshake the mock server sends one text frame; the OnMessage
// handler closes the socket. Expects open, message ("Hello"), close.
TEST_F(WebSocketTest, ServerSentData) {
  MockClientSocketFactory mock_socket_factory;
  static const char kMessage[] = "Hello";
  // Draft-75 text frame: 0x00, UTF-8 payload, 0xFF terminator.
  static const char kFrame[] = "\x00Hello\xff";
  static const int kFrameLen = sizeof(kFrame) - 1;
  MockRead data_reads[] = {
    MockRead("HTTP/1.1 101 Web Socket Protocol Handshake\r\n"
             "Upgrade: WebSocket\r\n"
             "Connection: Upgrade\r\n"
             "WebSocket-Origin: http://example.com\r\n"
             "WebSocket-Location: ws://example.com/demo\r\n"
             "WebSocket-Protocol: sample\r\n"
             "\r\n"),
    MockRead(true, kFrame, kFrameLen),
    // Server doesn't close the connection after sending the frame.
    MockRead(true, ERR_IO_PENDING),
  };
  MockWrite data_writes[] = {
    MockWrite("GET /demo HTTP/1.1\r\n"
              "Upgrade: WebSocket\r\n"
              "Connection: Upgrade\r\n"
              "Host: example.com\r\n"
              "Origin: http://example.com\r\n"
              "WebSocket-Protocol: sample\r\n"
              "\r\n"),
  };
  StaticSocketDataProvider data(data_reads, arraysize(data_reads),
                                data_writes, arraysize(data_writes));
  mock_socket_factory.AddSocketDataProvider(&data);
  MockHostResolver host_resolver;

  WebSocket::Request* request(
      new WebSocket::Request(GURL("ws://example.com/demo"),
                             "sample",
                             "http://example.com",
                             "ws://example.com/demo",
                             WebSocket::DRAFT75,
                             new TestURLRequestContext()));
  request->SetHostResolver(&host_resolver);
  request->SetClientSocketFactory(&mock_socket_factory);

  TestCompletionCallback callback;

  scoped_ptr<WebSocketEventRecorder> delegate(
      new WebSocketEventRecorder(&callback));
  // Close as soon as the first message arrives.
  delegate->SetOnMessage(NewCallback(delegate.get(),
                                     &WebSocketEventRecorder::DoClose));

  scoped_refptr<WebSocket> websocket(
      new WebSocket(request, delegate.get()));

  EXPECT_EQ(WebSocket::INITIALIZED, websocket->ready_state());
  websocket->Connect();

  // The recorder runs |callback| with OK from its OnClose.
  EXPECT_EQ(OK, callback.WaitForResult());

  const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
  // Fatal so that a short vector is never indexed below.
  ASSERT_EQ(3U, events.size());

  EXPECT_EQ(WebSocketEvent::EVENT_OPEN, events[0].event_type);
  EXPECT_EQ(WebSocketEvent::EVENT_MESSAGE, events[1].event_type);
  EXPECT_EQ(kMessage, events[1].msg);
  EXPECT_EQ(WebSocketEvent::EVENT_CLOSE, events[2].event_type);
}
266
// Feeds a valid length-prefixed frame followed by one whose declared length
// exceeds the buffered data; expects the first to be skipped, the second left
// unconsumed, and a single error event.
TEST_F(WebSocketTest, ProcessFrameDataForLengthCalculation) {
  WebSocket::Request* request(
      new WebSocket::Request(GURL("ws://example.com/demo"),
                             "sample",
                             "http://example.com",
                             "ws://example.com/demo",
                             WebSocket::DRAFT75,
                             new TestURLRequestContext()));
  TestCompletionCallback callback;
  scoped_ptr<WebSocketEventRecorder> delegate(
      new WebSocketEventRecorder(&callback));

  scoped_refptr<WebSocket> websocket(
      new WebSocket(request, delegate.get()));

  // Frame data: skip length 1 ('x'), and try to skip length 129
  // (1 * 128 + 1) bytes after \x81\x01, but buffer is too short to skip.
  static const char kTestLengthFrame[] =
      "\x80\x01x\x80\x81\x01\x01\x00unexpected data\xFF";
  const int kTestLengthFrameLength = sizeof(kTestLengthFrame) - 1;
  InitReadBuf(websocket.get());
  AddToReadBuf(websocket.get(), kTestLengthFrame, kTestLengthFrameLength);
  SetReadConsumed(websocket.get(), 0);

  static const char kExpectedRemainingFrame[] =
      "\x80\x81\x01\x01\x00unexpected data\xFF";
  const int kExpectedRemainingLength = sizeof(kExpectedRemainingFrame) - 1;
  TestProcessFrameData(websocket.get(),
                       kExpectedRemainingFrame, kExpectedRemainingLength);
  // Only an error is expected; the truncated frame must not produce a message.
  const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
  EXPECT_EQ(1U, events.size());
  // Guard instead of ASSERT so DetachDelegate() below still runs on failure.
  if (!events.empty()) {
    EXPECT_EQ(WebSocketEvent::EVENT_ERROR, events[0].event_type);
  }

  websocket->DetachDelegate();
}
304
// A text frame split across two reads: the first read (no 0xFF terminator)
// must be left unconsumed without an event; once the rest arrives, exactly
// one message carrying the concatenated payload is expected.
TEST_F(WebSocketTest, ProcessFrameDataForUnterminatedString) {
  WebSocket::Request* request(
      new WebSocket::Request(GURL("ws://example.com/demo"),
                             "sample",
                             "http://example.com",
                             "ws://example.com/demo",
                             WebSocket::DRAFT75,
                             new TestURLRequestContext()));
  TestCompletionCallback callback;
  scoped_ptr<WebSocketEventRecorder> delegate(
      new WebSocketEventRecorder(&callback));

  scoped_refptr<WebSocket> websocket(
      new WebSocket(request, delegate.get()));

  static const char kTestUnterminatedFrame[] =
      "\x00unterminated frame";
  const int kTestUnterminatedFrameLength = sizeof(kTestUnterminatedFrame) - 1;
  InitReadBuf(websocket.get());
  AddToReadBuf(websocket.get(), kTestUnterminatedFrame,
               kTestUnterminatedFrameLength);
  SetReadConsumed(websocket.get(), 0);
  // Nothing can be consumed yet: the whole partial frame must remain.
  TestProcessFrameData(websocket.get(),
                       kTestUnterminatedFrame, kTestUnterminatedFrameLength);
  {
    // No onmessage event expected.
    const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
    EXPECT_EQ(0U, events.size());
  }

  static const char kTestTerminateFrame[] = " is terminated in next read\xff";
  const int kTestTerminateFrameLength = sizeof(kTestTerminateFrame) - 1;
  AddToReadBuf(websocket.get(), kTestTerminateFrame,
               kTestTerminateFrameLength);
  // With the terminator present, everything should be consumed.
  TestProcessFrameData(websocket.get(), "", 0);

  static const char kExpectedMsg[] =
      "unterminated frame is terminated in next read";
  {
    const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
    EXPECT_EQ(1U, events.size());
    // Guard instead of ASSERT so DetachDelegate() below still runs on failure.
    if (!events.empty()) {
      EXPECT_EQ(WebSocketEvent::EVENT_MESSAGE, events[0].event_type);
      EXPECT_EQ(kExpectedMsg, events[0].msg);
    }
  }

  websocket->DetachDelegate();
}
353
354 } // namespace net
355