• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright (c) 2009 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4 
5 #include <string>
6 #include <vector>
7 
8 #include "base/callback.h"
9 #include "net/base/completion_callback.h"
10 #include "net/base/io_buffer.h"
11 #include "net/base/mock_host_resolver.h"
12 #include "net/base/test_completion_callback.h"
13 #include "net/socket/socket_test_util.h"
14 #include "net/url_request/url_request_test_util.h"
15 #include "net/websockets/websocket.h"
16 #include "testing/gtest/include/gtest/gtest.h"
17 #include "testing/gmock/include/gmock/gmock.h"
18 #include "testing/platform_test.h"
19 
20 struct WebSocketEvent {
21   enum EventType {
22     EVENT_OPEN, EVENT_MESSAGE, EVENT_ERROR, EVENT_CLOSE,
23   };
24 
WebSocketEventWebSocketEvent25   WebSocketEvent(EventType type, net::WebSocket* websocket,
26                  const std::string& websocket_msg, bool websocket_flag)
27       : event_type(type), socket(websocket), msg(websocket_msg),
28         flag(websocket_flag) {}
29 
30   EventType event_type;
31   net::WebSocket* socket;
32   std::string msg;
33   bool flag;
34 };
35 
36 class WebSocketEventRecorder : public net::WebSocketDelegate {
37  public:
WebSocketEventRecorder(net::CompletionCallback * callback)38   explicit WebSocketEventRecorder(net::CompletionCallback* callback)
39       : onopen_(NULL),
40         onmessage_(NULL),
41         onerror_(NULL),
42         onclose_(NULL),
43         callback_(callback) {}
~WebSocketEventRecorder()44   virtual ~WebSocketEventRecorder() {
45     delete onopen_;
46     delete onmessage_;
47     delete onerror_;
48     delete onclose_;
49   }
50 
SetOnOpen(Callback1<WebSocketEvent * >::Type * callback)51   void SetOnOpen(Callback1<WebSocketEvent*>::Type* callback) {
52     onopen_ = callback;
53   }
SetOnMessage(Callback1<WebSocketEvent * >::Type * callback)54   void SetOnMessage(Callback1<WebSocketEvent*>::Type* callback) {
55     onmessage_ = callback;
56   }
SetOnClose(Callback1<WebSocketEvent * >::Type * callback)57   void SetOnClose(Callback1<WebSocketEvent*>::Type* callback) {
58     onclose_ = callback;
59   }
60 
OnOpen(net::WebSocket * socket)61   virtual void OnOpen(net::WebSocket* socket) {
62     events_.push_back(
63         WebSocketEvent(WebSocketEvent::EVENT_OPEN, socket,
64                        std::string(), false));
65     if (onopen_)
66       onopen_->Run(&events_.back());
67   }
68 
OnMessage(net::WebSocket * socket,const std::string & msg)69   virtual void OnMessage(net::WebSocket* socket, const std::string& msg) {
70     events_.push_back(
71         WebSocketEvent(WebSocketEvent::EVENT_MESSAGE, socket, msg, false));
72     if (onmessage_)
73       onmessage_->Run(&events_.back());
74   }
OnError(net::WebSocket * socket)75   virtual void OnError(net::WebSocket* socket) {
76     events_.push_back(
77         WebSocketEvent(WebSocketEvent::EVENT_ERROR, socket,
78                        std::string(), false));
79     if (onerror_)
80       onerror_->Run(&events_.back());
81   }
OnClose(net::WebSocket * socket,bool was_clean)82   virtual void OnClose(net::WebSocket* socket, bool was_clean) {
83     events_.push_back(
84         WebSocketEvent(WebSocketEvent::EVENT_CLOSE, socket,
85                        std::string(), was_clean));
86     if (onclose_)
87       onclose_->Run(&events_.back());
88     if (callback_)
89       callback_->Run(net::OK);
90   }
91 
DoClose(WebSocketEvent * event)92   void DoClose(WebSocketEvent* event) {
93     event->socket->Close();
94   }
95 
GetSeenEvents() const96   const std::vector<WebSocketEvent>& GetSeenEvents() const {
97     return events_;
98   }
99 
100  private:
101   std::vector<WebSocketEvent> events_;
102   Callback1<WebSocketEvent*>::Type* onopen_;
103   Callback1<WebSocketEvent*>::Type* onmessage_;
104   Callback1<WebSocketEvent*>::Type* onerror_;
105   Callback1<WebSocketEvent*>::Type* onclose_;
106   net::CompletionCallback* callback_;
107 
108   DISALLOW_COPY_AND_ASSIGN(WebSocketEventRecorder);
109 };
110 
111 namespace net {
112 
113 class WebSocketTest : public PlatformTest {
114  protected:
InitReadBuf(WebSocket * websocket)115   void InitReadBuf(WebSocket* websocket) {
116     // Set up |current_read_buf_|.
117     websocket->current_read_buf_ = new GrowableIOBuffer();
118   }
SetReadConsumed(WebSocket * websocket,int consumed)119   void SetReadConsumed(WebSocket* websocket, int consumed) {
120     websocket->read_consumed_len_ = consumed;
121   }
AddToReadBuf(WebSocket * websocket,const char * data,int len)122   void AddToReadBuf(WebSocket* websocket, const char* data, int len) {
123     websocket->AddToReadBuffer(data, len);
124   }
125 
TestProcessFrameData(WebSocket * websocket,const char * expected_remaining_data,int expected_remaining_len)126   void TestProcessFrameData(WebSocket* websocket,
127                             const char* expected_remaining_data,
128                             int expected_remaining_len) {
129     websocket->ProcessFrameData();
130 
131     const char* actual_remaining_data =
132         websocket->current_read_buf_->StartOfBuffer()
133         + websocket->read_consumed_len_;
134     int actual_remaining_len =
135         websocket->current_read_buf_->offset() - websocket->read_consumed_len_;
136 
137     EXPECT_EQ(expected_remaining_len, actual_remaining_len);
138     EXPECT_TRUE(!memcmp(expected_remaining_data, actual_remaining_data,
139                         expected_remaining_len));
140   }
141 };
142 
// Performs a successful draft-75 handshake against a mock socket. The
// delegate closes the socket from its onopen hook, so exactly an open event
// and a close event are expected.
TEST_F(WebSocketTest, Connect) {
  MockClientSocketFactory mock_socket_factory;
  MockRead data_reads[] = {
    MockRead("HTTP/1.1 101 Web Socket Protocol Handshake\r\n"
             "Upgrade: WebSocket\r\n"
             "Connection: Upgrade\r\n"
             "WebSocket-Origin: http://example.com\r\n"
             "WebSocket-Location: ws://example.com/demo\r\n"
             "WebSocket-Protocol: sample\r\n"
             "\r\n"),
    // Server doesn't close the connection after handshake.
    MockRead(true, ERR_IO_PENDING),
  };
  MockWrite data_writes[] = {
    MockWrite("GET /demo HTTP/1.1\r\n"
              "Upgrade: WebSocket\r\n"
              "Connection: Upgrade\r\n"
              "Host: example.com\r\n"
              "Origin: http://example.com\r\n"
              "WebSocket-Protocol: sample\r\n"
              "\r\n"),
  };
  StaticSocketDataProvider data(data_reads, arraysize(data_reads),
                                data_writes, arraysize(data_writes));
  mock_socket_factory.AddSocketDataProvider(&data);
  MockHostResolver host_resolver;

  WebSocket::Request* request(
      new WebSocket::Request(GURL("ws://example.com/demo"),
                             "sample",
                             "http://example.com",
                             "ws://example.com/demo",
                             WebSocket::DRAFT75,
                             new TestURLRequestContext()));
  request->SetHostResolver(&host_resolver);
  request->SetClientSocketFactory(&mock_socket_factory);

  TestCompletionCallback callback;

  scoped_ptr<WebSocketEventRecorder> delegate(
      new WebSocketEventRecorder(&callback));
  // Close as soon as the connection opens. OnClose() then signals |callback|.
  delegate->SetOnOpen(NewCallback(delegate.get(),
                                  &WebSocketEventRecorder::DoClose));

  scoped_refptr<WebSocket> websocket(
      new WebSocket(request, delegate.get()));

  EXPECT_EQ(WebSocket::INITIALIZED, websocket->ready_state());
  websocket->Connect();

  // WebSocketEventRecorder::OnClose() always reports OK.
  EXPECT_EQ(OK, callback.WaitForResult());

  const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
  // Fatal check: the expectations below index into |events|.
  ASSERT_EQ(2U, events.size());

  EXPECT_EQ(WebSocketEvent::EVENT_OPEN, events[0].event_type);
  EXPECT_EQ(WebSocketEvent::EVENT_CLOSE, events[1].event_type);
}
201 
TEST_F(WebSocketTest,ServerSentData)202 TEST_F(WebSocketTest, ServerSentData) {
203   MockClientSocketFactory mock_socket_factory;
204   static const char kMessage[] = "Hello";
205   static const char kFrame[] = "\x00Hello\xff";
206   static const int kFrameLen = sizeof(kFrame) - 1;
207   MockRead data_reads[] = {
208     MockRead("HTTP/1.1 101 Web Socket Protocol Handshake\r\n"
209              "Upgrade: WebSocket\r\n"
210              "Connection: Upgrade\r\n"
211              "WebSocket-Origin: http://example.com\r\n"
212              "WebSocket-Location: ws://example.com/demo\r\n"
213              "WebSocket-Protocol: sample\r\n"
214              "\r\n"),
215     MockRead(true, kFrame, kFrameLen),
216     // Server doesn't close the connection after handshake.
217     MockRead(true, ERR_IO_PENDING),
218   };
219   MockWrite data_writes[] = {
220     MockWrite("GET /demo HTTP/1.1\r\n"
221               "Upgrade: WebSocket\r\n"
222               "Connection: Upgrade\r\n"
223               "Host: example.com\r\n"
224               "Origin: http://example.com\r\n"
225               "WebSocket-Protocol: sample\r\n"
226               "\r\n"),
227   };
228   StaticSocketDataProvider data(data_reads, arraysize(data_reads),
229                                 data_writes, arraysize(data_writes));
230   mock_socket_factory.AddSocketDataProvider(&data);
231   MockHostResolver host_resolver;
232 
233   WebSocket::Request* request(
234       new WebSocket::Request(GURL("ws://example.com/demo"),
235                              "sample",
236                              "http://example.com",
237                              "ws://example.com/demo",
238                              WebSocket::DRAFT75,
239                              new TestURLRequestContext()));
240   request->SetHostResolver(&host_resolver);
241   request->SetClientSocketFactory(&mock_socket_factory);
242 
243   TestCompletionCallback callback;
244 
245   scoped_ptr<WebSocketEventRecorder> delegate(
246       new WebSocketEventRecorder(&callback));
247   delegate->SetOnMessage(NewCallback(delegate.get(),
248                                      &WebSocketEventRecorder::DoClose));
249 
250   scoped_refptr<WebSocket> websocket(
251       new WebSocket(request, delegate.get()));
252 
253   EXPECT_EQ(WebSocket::INITIALIZED, websocket->ready_state());
254   websocket->Connect();
255 
256   callback.WaitForResult();
257 
258   const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
259   EXPECT_EQ(3U, events.size());
260 
261   EXPECT_EQ(WebSocketEvent::EVENT_OPEN, events[0].event_type);
262   EXPECT_EQ(WebSocketEvent::EVENT_MESSAGE, events[1].event_type);
263   EXPECT_EQ(kMessage, events[1].msg);
264   EXPECT_EQ(WebSocketEvent::EVENT_CLOSE, events[2].event_type);
265 }
266 
// Feeds ProcessFrameData() a length-prefixed frame followed by a second
// length-prefixed frame whose declared length exceeds the buffered data.
TEST_F(WebSocketTest, ProcessFrameDataForLengthCalculation) {
  WebSocket::Request* request(
      new WebSocket::Request(GURL("ws://example.com/demo"),
                             "sample",
                             "http://example.com",
                             "ws://example.com/demo",
                             WebSocket::DRAFT75,
                             new TestURLRequestContext()));
  TestCompletionCallback callback;
  scoped_ptr<WebSocketEventRecorder> delegate(
      new WebSocketEventRecorder(&callback));

  scoped_refptr<WebSocket> websocket(
      new WebSocket(request, delegate.get()));

  // Frame data: skip length 1 ('x'), and try to skip length 129
  // (1 * 128 + 1) bytes after \x81\x01, but buffer is too short to skip.
  static const char kTestLengthFrame[] =
      "\x80\x01x\x80\x81\x01\x01\x00unexpected data\xFF";
  const int kTestLengthFrameLength = sizeof(kTestLengthFrame) - 1;
  InitReadBuf(websocket.get());
  AddToReadBuf(websocket.get(), kTestLengthFrame, kTestLengthFrameLength);
  SetReadConsumed(websocket.get(), 0);

  // Only the first frame is consumed; the too-long frame stays buffered.
  static const char kExpectedRemainingFrame[] =
      "\x80\x81\x01\x01\x00unexpected data\xFF";
  const int kExpectedRemainingLength = sizeof(kExpectedRemainingFrame) - 1;
  TestProcessFrameData(websocket.get(),
                       kExpectedRemainingFrame, kExpectedRemainingLength);

  // No onmessage event is expected; the only recorded event is an error.
  const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
  // Fatal check: the expectation below indexes into |events|.
  ASSERT_EQ(1U, events.size());

  EXPECT_EQ(WebSocketEvent::EVENT_ERROR, events[0].event_type);

  websocket->DetachDelegate();
}
304 
// A text frame split across two reads. The first half must stay buffered
// without producing an event. Once the 0xff terminator arrives, a single
// message with the concatenated payload must be delivered.
TEST_F(WebSocketTest, ProcessFrameDataForUnterminatedString) {
  WebSocket::Request* request(
      new WebSocket::Request(GURL("ws://example.com/demo"),
                             "sample",
                             "http://example.com",
                             "ws://example.com/demo",
                             WebSocket::DRAFT75,
                             new TestURLRequestContext()));
  TestCompletionCallback callback;
  scoped_ptr<WebSocketEventRecorder> delegate(
      new WebSocketEventRecorder(&callback));

  scoped_refptr<WebSocket> websocket(
      new WebSocket(request, delegate.get()));

  static const char kTestUnterminatedFrame[] =
      "\x00unterminated frame";
  const int kTestUnterminatedFrameLength = sizeof(kTestUnterminatedFrame) - 1;
  InitReadBuf(websocket.get());
  AddToReadBuf(websocket.get(), kTestUnterminatedFrame,
               kTestUnterminatedFrameLength);
  SetReadConsumed(websocket.get(), 0);
  // Nothing can be consumed yet: the whole partial frame remains.
  TestProcessFrameData(websocket.get(),
                       kTestUnterminatedFrame, kTestUnterminatedFrameLength);
  {
    // No onmessage event expected.
    const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
    EXPECT_EQ(0U, events.size());
  }

  static const char kTestTerminateFrame[] = " is terminated in next read\xff";
  const int kTestTerminateFrameLength = sizeof(kTestTerminateFrame) - 1;
  AddToReadBuf(websocket.get(), kTestTerminateFrame,
               kTestTerminateFrameLength);
  // The frame is now complete, so the entire buffer is consumed.
  TestProcessFrameData(websocket.get(), "", 0);

  static const char kExpectedMsg[] =
      "unterminated frame is terminated in next read";
  {
    const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
    // Fatal check: the expectations below index into |events|.
    ASSERT_EQ(1U, events.size());

    EXPECT_EQ(WebSocketEvent::EVENT_MESSAGE, events[0].event_type);
    EXPECT_EQ(kExpectedMsg, events[0].msg);
  }

  websocket->DetachDelegate();
}
353 
354 }  // namespace net
355