1 // Copyright (c) 2011 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
#include <ctype.h>
#include <string.h>

#include <algorithm>
#include <string>
#include <vector>

#include "base/memory/scoped_ptr.h"
#include "base/string_split.h"
#include "base/string_util.h"
#include "base/stringprintf.h"
#include "net/websockets/websocket_handshake.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "testing/platform_test.h"
16
17 namespace net {
18
19 class WebSocketHandshakeTest : public testing::Test {
20 public:
SetUpParameter(WebSocketHandshake * handshake,uint32 number_1,uint32 number_2,const std::string & key_1,const std::string & key_2,const std::string & key_3)21 static void SetUpParameter(WebSocketHandshake* handshake,
22 uint32 number_1, uint32 number_2,
23 const std::string& key_1, const std::string& key_2,
24 const std::string& key_3) {
25 WebSocketHandshake::Parameter* parameter =
26 new WebSocketHandshake::Parameter;
27 parameter->number_1_ = number_1;
28 parameter->number_2_ = number_2;
29 parameter->key_1_ = key_1;
30 parameter->key_2_ = key_2;
31 parameter->key_3_ = key_3;
32 handshake->parameter_.reset(parameter);
33 }
34
ExpectHeaderEquals(const std::string & expected,const std::string & actual)35 static void ExpectHeaderEquals(const std::string& expected,
36 const std::string& actual) {
37 std::vector<std::string> expected_lines;
38 Tokenize(expected, "\r\n", &expected_lines);
39 std::vector<std::string> actual_lines;
40 Tokenize(actual, "\r\n", &actual_lines);
41 // Request lines.
42 EXPECT_EQ(expected_lines[0], actual_lines[0]);
43
44 std::vector<std::string> expected_headers;
45 for (size_t i = 1; i < expected_lines.size(); i++) {
46 // Finish at first CRLF CRLF. Note that /key_3/ might include CRLF.
47 if (expected_lines[i] == "")
48 break;
49 expected_headers.push_back(expected_lines[i]);
50 }
51 sort(expected_headers.begin(), expected_headers.end());
52
53 std::vector<std::string> actual_headers;
54 for (size_t i = 1; i < actual_lines.size(); i++) {
55 // Finish at first CRLF CRLF. Note that /key_3/ might include CRLF.
56 if (actual_lines[i] == "")
57 break;
58 actual_headers.push_back(actual_lines[i]);
59 }
60 sort(actual_headers.begin(), actual_headers.end());
61
62 EXPECT_EQ(expected_headers.size(), actual_headers.size())
63 << "expected:" << expected
64 << "\nactual:" << actual;
65 for (size_t i = 0; i < expected_headers.size(); i++) {
66 EXPECT_EQ(expected_headers[i], actual_headers[i]);
67 }
68 }
69
ExpectHandshakeMessageEquals(const std::string & expected,const std::string & actual)70 static void ExpectHandshakeMessageEquals(const std::string& expected,
71 const std::string& actual) {
72 // Headers.
73 ExpectHeaderEquals(expected, actual);
74 // Compare tailing \r\n\r\n<key3> (4 + 8 bytes).
75 ASSERT_GT(expected.size(), 12U);
76 const char* expected_key3 = expected.data() + expected.size() - 12;
77 EXPECT_GT(actual.size(), 12U);
78 if (actual.size() <= 12U)
79 return;
80 const char* actual_key3 = actual.data() + actual.size() - 12;
81 EXPECT_TRUE(memcmp(expected_key3, actual_key3, 12) == 0)
82 << "expected_key3:" << DumpKey(expected_key3, 12)
83 << ", actual_key3:" << DumpKey(actual_key3, 12);
84 }
85
DumpKey(const char * buf,int len)86 static std::string DumpKey(const char* buf, int len) {
87 std::string s;
88 for (int i = 0; i < len; i++) {
89 if (isprint(buf[i]))
90 s += base::StringPrintf("%c", buf[i]);
91 else
92 s += base::StringPrintf("\\x%02x", buf[i]);
93 }
94 return s;
95 }
96
GetResourceName(WebSocketHandshake * handshake)97 static std::string GetResourceName(WebSocketHandshake* handshake) {
98 return handshake->GetResourceName();
99 }
GetHostFieldValue(WebSocketHandshake * handshake)100 static std::string GetHostFieldValue(WebSocketHandshake* handshake) {
101 return handshake->GetHostFieldValue();
102 }
GetOriginFieldValue(WebSocketHandshake * handshake)103 static std::string GetOriginFieldValue(WebSocketHandshake* handshake) {
104 return handshake->GetOriginFieldValue();
105 }
106 };
107
108
// Walks a handshake through a full exchange: checks the generated client
// message, feeds the server response in growing pieces (each of which must be
// reported as incomplete), then feeds the whole response and expects
// MODE_CONNECTED.
TEST_F(WebSocketHandshakeTest, Connect) {
  const std::string kExpectedClientHandshakeMessage =
      "GET /demo HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: Upgrade\r\n"
      "Host: example.com\r\n"
      "Origin: http://example.com\r\n"
      "Sec-WebSocket-Protocol: sample\r\n"
      "Sec-WebSocket-Key1: 388P O503D&ul7 {K%gX( %7 15\r\n"
      "Sec-WebSocket-Key2: 1 N ?|k UT0or 3o 4 I97N 5-S3O 31\r\n"
      "\r\n"
      "\x47\x30\x22\x2D\x5A\x3F\x47\x58";

  scoped_ptr<WebSocketHandshake> handshake(
      new WebSocketHandshake(GURL("ws://example.com/demo"),
                             "http://example.com",
                             "ws://example.com/demo",
                             "sample"));
  SetUpParameter(handshake.get(), 777007543U, 114997259U,
                 "388P O503D&ul7 {K%gX( %7 15",
                 "1 N ?|k UT0or 3o 4 I97N 5-S3O 31",
                 std::string("\x47\x30\x22\x2D\x5A\x3F\x47\x58", 8));
  EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());
  ExpectHandshakeMessageEquals(
      kExpectedClientHandshakeMessage,
      handshake->CreateClientHandshakeMessage());

  const char kResponse[] = "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: Upgrade\r\n"
      "Sec-WebSocket-Origin: http://example.com\r\n"
      "Sec-WebSocket-Location: ws://example.com/demo\r\n"
      "Sec-WebSocket-Protocol: sample\r\n"
      "\r\n"
      "\x30\x73\x74\x33\x52\x6C\x26\x71\x2D\x32\x5A\x55\x5E\x77\x65\x75";

  // kResponse has 7 LF-terminated lines (status line, five header lines and
  // the blank line) followed by the 16 challenge-response bytes.
  const size_t kNumTerminatedLines = 7;
  std::vector<std::string> response_lines;
  base::SplitStringDontTrim(kResponse, '\n', &response_lines);
  // Guard the response_lines[i] accesses below.
  ASSERT_GE(response_lines.size(), kNumTerminatedLines);

  EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());
  // too short
  EXPECT_EQ(-1, handshake->ReadServerHandshake(kResponse, 16));
  EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());

  // Feed the response one more line at a time.  The split removed the '\n'
  // separators, so put each one back: every partial response must be a real
  // prefix of kResponse rather than lines glued together with a bare '\r'.
  // Even after the blank line the challenge response is still missing, so
  // each step is expected to report "incomplete".
  std::string response;
  for (size_t i = 0; i < kNumTerminatedLines; ++i) {
    SCOPED_TRACE(base::StringPrintf("after response line %d",
                                    static_cast<int>(i)));
    response += response_lines[i];
    response += '\n';
    EXPECT_EQ(-1, handshake->ReadServerHandshake(
        response.data(), response.size()));
    EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());
  }

  // The complete response, challenge bytes included.
  const int handshake_length =
      static_cast<int>(sizeof(kResponse) - 1);  // -1 for terminating \0
  EXPECT_EQ(handshake_length, handshake->ReadServerHandshake(
      kResponse, handshake_length));
  EXPECT_EQ(WebSocketHandshake::MODE_CONNECTED, handshake->mode());
}
185
// The server response is followed immediately by extra bytes.
// ReadServerHandshake() must return only the length of the handshake itself
// and switch the handshake to MODE_CONNECTED.
TEST_F(WebSocketHandshakeTest, ServerSentData) {
  scoped_ptr<WebSocketHandshake> handshake(
      new WebSocketHandshake(GURL("ws://example.com/demo"),
                             "http://example.com",
                             "ws://example.com/demo",
                             "sample"));
  const std::string key_3("\x47\x30\x22\x2D\x5A\x3F\x47\x58", 8);
  SetUpParameter(handshake.get(), 777007543U, 114997259U,
                 "388P O503D&ul7 {K%gX( %7 15",
                 "1 N ?|k UT0or 3o 4 I97N 5-S3O 31",
                 key_3);
  EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());

  // What the client is expected to send for the parameters set above.
  const std::string kExpectedClientHandshakeMessage =
      "GET /demo HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: Upgrade\r\n"
      "Host: example.com\r\n"
      "Origin: http://example.com\r\n"
      "Sec-WebSocket-Protocol: sample\r\n"
      "Sec-WebSocket-Key1: 388P O503D&ul7 {K%gX( %7 15\r\n"
      "Sec-WebSocket-Key2: 1 N ?|k UT0or 3o 4 I97N 5-S3O 31\r\n"
      "\r\n"
      "\x47\x30\x22\x2D\x5A\x3F\x47\x58";
  const std::string client_message = handshake->CreateClientHandshakeMessage();
  ExpectHandshakeMessageEquals(kExpectedClientHandshakeMessage,
                               client_message);

  // Handshake response with "\0Hello\xff" appended right after the 16
  // challenge-response bytes.
  const char kResponse[] = "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: Upgrade\r\n"
      "Sec-WebSocket-Origin: http://example.com\r\n"
      "Sec-WebSocket-Location: ws://example.com/demo\r\n"
      "Sec-WebSocket-Protocol: sample\r\n"
      "\r\n"
      "\x30\x73\x74\x33\x52\x6C\x26\x71\x2D\x32\x5A\x55\x5E\x77\x65\x75"
      "\0Hello\xff";

  // None of the challenge-response bytes is \0, so strlen() stops exactly at
  // the \0 that starts the appended data, i.e. at the end of the handshake.
  const int expected_consumed = static_cast<int>(strlen(kResponse));
  const size_t total_length = sizeof(kResponse) - 1;  // -1 for terminating \0
  const int consumed = handshake->ReadServerHandshake(kResponse, total_length);
  EXPECT_EQ(expected_consumed, consumed);
  EXPECT_EQ(WebSocketHandshake::MODE_CONNECTED, handshake->mode());
}
227
// A handshake built from a ws:// URL must not report itself as secure.
TEST_F(WebSocketHandshakeTest, is_secure_false) {
  const GURL url("ws://example.com/demo");
  scoped_ptr<WebSocketHandshake> handshake(
      new WebSocketHandshake(url,
                             "http://example.com",
                             "ws://example.com/demo",
                             "sample"));
  const bool secure = handshake->is_secure();
  EXPECT_FALSE(secure);
}
236
// A handshake built from a wss:// URL must report itself as secure.
TEST_F(WebSocketHandshakeTest, is_secure_true) {
  const GURL url("wss://example.com/demo");
  scoped_ptr<WebSocketHandshake> handshake(
      new WebSocketHandshake(url,
                             "http://example.com",
                             "wss://example.com/demo",
                             "sample"));
  const bool secure = handshake->is_secure();
  EXPECT_TRUE(secure);
}
246
// The resource name must be the URL's path and query exactly as given,
// including the escaped %20.
TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_ResourceName) {
  const GURL url("ws://example.com/Test?q=xxx&p=%20");
  scoped_ptr<WebSocketHandshake> handshake(
      new WebSocketHandshake(url,
                             "http://example.com",
                             "ws://example.com/demo",
                             "sample"));
  const std::string resource_name = GetResourceName(handshake.get());
  EXPECT_EQ("/Test?q=xxx&p=%20", resource_name);
}
256
// Mixed-case input must come back lowercased in both the Host and the
// Origin field values.
TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_Host) {
  const GURL mixed_case_url("ws://Example.Com/demo");
  scoped_ptr<WebSocketHandshake> handshake(
      new WebSocketHandshake(mixed_case_url,
                             "http://Example.Com",
                             "ws://Example.Com/demo",
                             "sample"));

  const std::string host = GetHostFieldValue(handshake.get());
  EXPECT_EQ("example.com", host);

  const std::string origin = GetOriginFieldValue(handshake.get());
  EXPECT_EQ("http://example.com", origin);
}
267
// An explicit :80 on a ws:// URL is the default port and must not show up in
// the Host field value.
TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_TrimPort80) {
  const GURL url("ws://example.com:80/demo");
  scoped_ptr<WebSocketHandshake> handshake(
      new WebSocketHandshake(url,
                             "http://example.com",
                             "ws://example.com/demo",
                             "sample"));
  const std::string host = GetHostFieldValue(handshake.get());
  EXPECT_EQ("example.com", host);
}
277
// An explicit :443 on a wss:// URL is the default port and must not show up
// in the Host field value.
TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_TrimPort443) {
  const GURL url("wss://example.com:443/demo");
  scoped_ptr<WebSocketHandshake> handshake(
      new WebSocketHandshake(url,
                             "http://example.com",
                             "wss://example.com/demo",
                             "sample"));
  const std::string host = GetHostFieldValue(handshake.get());
  EXPECT_EQ("example.com", host);
}
287
// :8080 is not the default port for ws://, so the Host field value must keep
// it.
TEST_F(WebSocketHandshakeTest,
       CreateClientHandshakeMessage_NonDefaultPortForWs) {
  const GURL url("ws://example.com:8080/demo");
  // NOTE(review): the third argument uses wss:// although the URL is ws://;
  // the sibling ws:// tests pass a ws:// string here.  Looks like a
  // copy/paste slip that this test's single assertion does not notice --
  // confirm before changing.
  scoped_ptr<WebSocketHandshake> handshake(
      new WebSocketHandshake(url,
                             "http://example.com",
                             "wss://example.com/demo",
                             "sample"));
  const std::string host = GetHostFieldValue(handshake.get());
  EXPECT_EQ("example.com:8080", host);
}
298
// :4443 is not the default port for wss://, so the Host field value must
// keep it.
TEST_F(WebSocketHandshakeTest,
       CreateClientHandshakeMessage_NonDefaultPortForWss) {
  const GURL url("wss://example.com:4443/demo");
  scoped_ptr<WebSocketHandshake> handshake(
      new WebSocketHandshake(url,
                             "http://example.com",
                             "wss://example.com/demo",
                             "sample"));
  const std::string host = GetHostFieldValue(handshake.get());
  EXPECT_EQ("example.com:4443", host);
}
309
// :443 is only the default for wss://; on a ws:// URL it is a non-default
// port and must stay in the Host field value.
TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_WsBut443) {
  const GURL url("ws://example.com:443/demo");
  scoped_ptr<WebSocketHandshake> handshake(
      new WebSocketHandshake(url,
                             "http://example.com",
                             "ws://example.com/demo",
                             "sample"));
  const std::string host = GetHostFieldValue(handshake.get());
  EXPECT_EQ("example.com:443", host);
}
319
// :80 is only the default for ws://; on a wss:// URL it is a non-default
// port and must stay in the Host field value.
TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_WssBut80) {
  const GURL url("wss://example.com:80/demo");
  scoped_ptr<WebSocketHandshake> handshake(
      new WebSocketHandshake(url,
                             "http://example.com",
                             "wss://example.com/demo",
                             "sample"));
  const std::string host = GetHostFieldValue(handshake.get());
  EXPECT_EQ("example.com:80", host);
}
329
330 } // namespace net
331