1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/websockets/websocket_handshake_handler.h"
6
7 #include <limits>
8
9 #include "base/base64.h"
10 #include "base/sha1.h"
11 #include "base/strings/string_number_conversions.h"
12 #include "base/strings/string_piece.h"
13 #include "base/strings/string_tokenizer.h"
14 #include "base/strings/string_util.h"
15 #include "base/strings/stringprintf.h"
16 #include "net/http/http_request_headers.h"
17 #include "net/http/http_response_headers.h"
18 #include "net/http/http_util.h"
19 #include "net/websockets/websocket_handshake_constants.h"
20 #include "url/gurl.h"
21
22 namespace net {
23 namespace {
24
// Sec-WebSocket-Version value required by CheckVersionInRequest(); 13 is the
// protocol version defined by RFC 6455.
const int kVersionHeaderValueForRFC6455 = 13;
26
// Splits the first |len| bytes of |handshake_message| into the Status-Line or
// Request-Line (including its line terminator) and the header lines
// (excluding the terminator of the blank line that ends a handshake header
// block, e.g. the 2nd CRLF of the trailing double CRLF).
//
// The first line ends at the first LF, so both CRLF and bare-LF line endings
// are split correctly. Only a terminator that is actually present ("\r\n" or
// "\n") is stripped from the end of the header lines, so the function never
// reads outside [handshake_message, handshake_message + len).
void ParseHandshakeHeader(
    const char* handshake_message, int len,
    std::string* request_line,
    std::string* headers) {
  headers->clear();
  if (len <= 0) {
    request_line->clear();
    return;
  }

  const std::string message(handshake_message, len);
  const size_t line_end = message.find('\n');
  if (line_end == std::string::npos) {
    // No line terminator at all: the whole message is the first line.
    *request_line = message;
    return;
  }
  // |request_line| includes its line terminator (normally "\r\n").
  *request_line = message.substr(0, line_end + 1);

  // What follows is the header lines plus the blank line that terminates the
  // header block. Drop that blank line's terminator; the terminator of the
  // last header line itself is kept.
  const size_t headers_begin = line_end + 1;
  size_t headers_end = message.size();
  if (headers_end > headers_begin && message[headers_end - 1] == '\n') {
    --headers_end;
    if (headers_end > headers_begin && message[headers_end - 1] == '\r')
      --headers_end;
  }
  headers->assign(message, headers_begin, headers_end - headers_begin);
}
52
FetchHeaders(const std::string & headers,const char * const headers_to_get[],size_t headers_to_get_len,std::vector<std::string> * values)53 void FetchHeaders(const std::string& headers,
54 const char* const headers_to_get[],
55 size_t headers_to_get_len,
56 std::vector<std::string>* values) {
57 net::HttpUtil::HeadersIterator iter(headers.begin(), headers.end(), "\r\n");
58 while (iter.GetNext()) {
59 for (size_t i = 0; i < headers_to_get_len; i++) {
60 if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(),
61 headers_to_get[i])) {
62 values->push_back(iter.values());
63 }
64 }
65 }
66 }
67
GetHeaderName(std::string::const_iterator line_begin,std::string::const_iterator line_end,std::string::const_iterator * name_begin,std::string::const_iterator * name_end)68 bool GetHeaderName(std::string::const_iterator line_begin,
69 std::string::const_iterator line_end,
70 std::string::const_iterator* name_begin,
71 std::string::const_iterator* name_end) {
72 std::string::const_iterator colon = std::find(line_begin, line_end, ':');
73 if (colon == line_end) {
74 return false;
75 }
76 *name_begin = line_begin;
77 *name_end = colon;
78 if (*name_begin == *name_end || net::HttpUtil::IsLWS(**name_begin))
79 return false;
80 net::HttpUtil::TrimLWS(name_begin, name_end);
81 return true;
82 }
83
84 // Similar to HttpUtil::StripHeaders, but it preserves malformed headers, that
85 // is, lines that are not formatted as "<name>: <value>\r\n".
FilterHeaders(const std::string & headers,const char * const headers_to_remove[],size_t headers_to_remove_len)86 std::string FilterHeaders(
87 const std::string& headers,
88 const char* const headers_to_remove[],
89 size_t headers_to_remove_len) {
90 std::string filtered_headers;
91
92 base::StringTokenizer lines(headers.begin(), headers.end(), "\r\n");
93 while (lines.GetNext()) {
94 std::string::const_iterator line_begin = lines.token_begin();
95 std::string::const_iterator line_end = lines.token_end();
96 std::string::const_iterator name_begin;
97 std::string::const_iterator name_end;
98 bool should_remove = false;
99 if (GetHeaderName(line_begin, line_end, &name_begin, &name_end)) {
100 for (size_t i = 0; i < headers_to_remove_len; ++i) {
101 if (LowerCaseEqualsASCII(name_begin, name_end, headers_to_remove[i])) {
102 should_remove = true;
103 break;
104 }
105 }
106 }
107 if (!should_remove) {
108 filtered_headers.append(line_begin, line_end);
109 filtered_headers.append("\r\n");
110 }
111 }
112 return filtered_headers;
113 }
114
CheckVersionInRequest(const std::string & request_headers)115 bool CheckVersionInRequest(const std::string& request_headers) {
116 std::vector<std::string> values;
117 const char* const headers_to_get[1] = {
118 websockets::kSecWebSocketVersionLowercase};
119 FetchHeaders(request_headers, headers_to_get, 1, &values);
120 DCHECK_LE(values.size(), 1U);
121 if (values.empty())
122 return false;
123
124 int version;
125 bool conversion_success = base::StringToInt(values[0], &version);
126 if (!conversion_success)
127 return false;
128
129 return version == kVersionHeaderValueForRFC6455;
130 }
131
132 // Append a header to a string. Equivalent to
133 // response_message += header + ": " + value + "\r\n"
134 // but avoids unnecessary allocations and copies.
AppendHeader(const base::StringPiece & header,const base::StringPiece & value,std::string * response_message)135 void AppendHeader(const base::StringPiece& header,
136 const base::StringPiece& value,
137 std::string* response_message) {
138 static const char kColonSpace[] = ": ";
139 const size_t kColonSpaceSize = sizeof(kColonSpace) - 1;
140 static const char kCrNl[] = "\r\n";
141 const size_t kCrNlSize = sizeof(kCrNl) - 1;
142
143 size_t extra_size =
144 header.size() + kColonSpaceSize + value.size() + kCrNlSize;
145 response_message->reserve(response_message->size() + extra_size);
146 response_message->append(header.begin(), header.end());
147 response_message->append(kColonSpace, kColonSpace + kColonSpaceSize);
148 response_message->append(value.begin(), value.end());
149 response_message->append(kCrNl, kCrNl + kCrNlSize);
150 }
151
152 } // namespace
153
WebSocketHandshakeRequestHandler()154 WebSocketHandshakeRequestHandler::WebSocketHandshakeRequestHandler()
155 : original_length_(0),
156 raw_length_(0) {}
157
ParseRequest(const char * data,int length)158 bool WebSocketHandshakeRequestHandler::ParseRequest(
159 const char* data, int length) {
160 DCHECK_GT(length, 0);
161 std::string input(data, length);
162 int input_header_length =
163 HttpUtil::LocateEndOfHeaders(input.data(), input.size(), 0);
164 if (input_header_length <= 0)
165 return false;
166
167 ParseHandshakeHeader(input.data(),
168 input_header_length,
169 &request_line_,
170 &headers_);
171
172 if (!CheckVersionInRequest(headers_)) {
173 NOTREACHED();
174 return false;
175 }
176
177 original_length_ = input_header_length;
178 return true;
179 }
180
original_length() const181 size_t WebSocketHandshakeRequestHandler::original_length() const {
182 return original_length_;
183 }
184
AppendHeaderIfMissing(const std::string & name,const std::string & value)185 void WebSocketHandshakeRequestHandler::AppendHeaderIfMissing(
186 const std::string& name, const std::string& value) {
187 DCHECK(!headers_.empty());
188 HttpUtil::AppendHeaderIfMissing(name.c_str(), value, &headers_);
189 }
190
RemoveHeaders(const char * const headers_to_remove[],size_t headers_to_remove_len)191 void WebSocketHandshakeRequestHandler::RemoveHeaders(
192 const char* const headers_to_remove[],
193 size_t headers_to_remove_len) {
194 DCHECK(!headers_.empty());
195 headers_ = FilterHeaders(
196 headers_, headers_to_remove, headers_to_remove_len);
197 }
198
GetRequestInfo(const GURL & url,std::string * challenge)199 HttpRequestInfo WebSocketHandshakeRequestHandler::GetRequestInfo(
200 const GURL& url, std::string* challenge) {
201 HttpRequestInfo request_info;
202 request_info.url = url;
203 size_t method_end = base::StringPiece(request_line_).find_first_of(" ");
204 if (method_end != base::StringPiece::npos)
205 request_info.method = std::string(request_line_.data(), method_end);
206
207 request_info.extra_headers.Clear();
208 request_info.extra_headers.AddHeadersFromString(headers_);
209
210 request_info.extra_headers.RemoveHeader(websockets::kUpgrade);
211 request_info.extra_headers.RemoveHeader(HttpRequestHeaders::kConnection);
212
213 std::string key;
214 bool header_present = request_info.extra_headers.GetHeader(
215 websockets::kSecWebSocketKey, &key);
216 DCHECK(header_present);
217 request_info.extra_headers.RemoveHeader(websockets::kSecWebSocketKey);
218 *challenge = key;
219 return request_info;
220 }
221
GetRequestHeaderBlock(const GURL & url,SpdyHeaderBlock * headers,std::string * challenge,int spdy_protocol_version)222 bool WebSocketHandshakeRequestHandler::GetRequestHeaderBlock(
223 const GURL& url,
224 SpdyHeaderBlock* headers,
225 std::string* challenge,
226 int spdy_protocol_version) {
227 // Construct opening handshake request headers as a SPDY header block.
228 // For details, see WebSocket Layering over SPDY/3 Draft 8.
229 if (spdy_protocol_version <= 2) {
230 (*headers)["path"] = url.path();
231 (*headers)["version"] = "WebSocket/13";
232 (*headers)["scheme"] = url.scheme();
233 } else {
234 (*headers)[":path"] = url.path();
235 (*headers)[":version"] = "WebSocket/13";
236 (*headers)[":scheme"] = url.scheme();
237 }
238
239 HttpUtil::HeadersIterator iter(headers_.begin(), headers_.end(), "\r\n");
240 while (iter.GetNext()) {
241 if (LowerCaseEqualsASCII(iter.name_begin(),
242 iter.name_end(),
243 websockets::kUpgradeLowercase) ||
244 LowerCaseEqualsASCII(
245 iter.name_begin(), iter.name_end(), "connection") ||
246 LowerCaseEqualsASCII(iter.name_begin(),
247 iter.name_end(),
248 websockets::kSecWebSocketVersionLowercase)) {
249 // These headers must be ignored.
250 continue;
251 } else if (LowerCaseEqualsASCII(iter.name_begin(),
252 iter.name_end(),
253 websockets::kSecWebSocketKeyLowercase)) {
254 *challenge = iter.values();
255 // Sec-WebSocket-Key is not sent to the server.
256 continue;
257 } else if (LowerCaseEqualsASCII(
258 iter.name_begin(), iter.name_end(), "host") ||
259 LowerCaseEqualsASCII(
260 iter.name_begin(), iter.name_end(), "origin") ||
261 LowerCaseEqualsASCII(
262 iter.name_begin(),
263 iter.name_end(),
264 websockets::kSecWebSocketProtocolLowercase) ||
265 LowerCaseEqualsASCII(
266 iter.name_begin(),
267 iter.name_end(),
268 websockets::kSecWebSocketExtensionsLowercase)) {
269 // TODO(toyoshim): Some WebSocket extensions may not be compatible with
270 // SPDY. We should omit them from a Sec-WebSocket-Extension header.
271 std::string name;
272 if (spdy_protocol_version <= 2)
273 name = StringToLowerASCII(iter.name());
274 else
275 name = ":" + StringToLowerASCII(iter.name());
276 (*headers)[name] = iter.values();
277 continue;
278 }
279 // Others should be sent out to |headers|.
280 std::string name = StringToLowerASCII(iter.name());
281 SpdyHeaderBlock::iterator found = headers->find(name);
282 if (found == headers->end()) {
283 (*headers)[name] = iter.values();
284 } else {
285 // For now, websocket doesn't use multiple headers, but follows to http.
286 found->second.append(1, '\0'); // +=() doesn't append 0's
287 found->second.append(iter.values());
288 }
289 }
290
291 return true;
292 }
293
GetRawRequest()294 std::string WebSocketHandshakeRequestHandler::GetRawRequest() {
295 DCHECK(!request_line_.empty());
296 DCHECK(!headers_.empty());
297
298 std::string raw_request = request_line_ + headers_ + "\r\n";
299 raw_length_ = raw_request.size();
300 return raw_request;
301 }
302
raw_length() const303 size_t WebSocketHandshakeRequestHandler::raw_length() const {
304 DCHECK_GT(raw_length_, 0);
305 return raw_length_;
306 }
307
WebSocketHandshakeResponseHandler()308 WebSocketHandshakeResponseHandler::WebSocketHandshakeResponseHandler()
309 : original_header_length_(0) {}
310
~WebSocketHandshakeResponseHandler()311 WebSocketHandshakeResponseHandler::~WebSocketHandshakeResponseHandler() {}
312
ParseRawResponse(const char * data,int length)313 size_t WebSocketHandshakeResponseHandler::ParseRawResponse(
314 const char* data, int length) {
315 DCHECK_GT(length, 0);
316 if (HasResponse()) {
317 DCHECK(!status_line_.empty());
318 // headers_ might be empty for wrong response from server.
319
320 return 0;
321 }
322
323 size_t old_original_length = original_.size();
324
325 original_.append(data, length);
326 // TODO(ukai): fail fast when response gives wrong status code.
327 original_header_length_ = HttpUtil::LocateEndOfHeaders(
328 original_.data(), original_.size(), 0);
329 if (!HasResponse())
330 return length;
331
332 ParseHandshakeHeader(original_.data(),
333 original_header_length_,
334 &status_line_,
335 &headers_);
336 int header_size = status_line_.size() + headers_.size();
337 DCHECK_GE(original_header_length_, header_size);
338 header_separator_ = std::string(original_.data() + header_size,
339 original_header_length_ - header_size);
340 return original_header_length_ - old_original_length;
341 }
342
HasResponse() const343 bool WebSocketHandshakeResponseHandler::HasResponse() const {
344 return original_header_length_ > 0 &&
345 static_cast<size_t>(original_header_length_) <= original_.size();
346 }
347
ComputeSecWebSocketAccept(const std::string & key,std::string * accept)348 void ComputeSecWebSocketAccept(const std::string& key,
349 std::string* accept) {
350 DCHECK(accept);
351
352 std::string hash =
353 base::SHA1HashString(key + websockets::kWebSocketGuid);
354 base::Base64Encode(hash, accept);
355 }
356
ParseResponseInfo(const HttpResponseInfo & response_info,const std::string & challenge)357 bool WebSocketHandshakeResponseHandler::ParseResponseInfo(
358 const HttpResponseInfo& response_info,
359 const std::string& challenge) {
360 if (!response_info.headers.get())
361 return false;
362
363 // TODO(ricea): Eliminate all the reallocations and string copies.
364 std::string response_message;
365 response_message = response_info.headers->GetStatusLine();
366 response_message += "\r\n";
367
368 AppendHeader(websockets::kUpgrade,
369 websockets::kWebSocketLowercase,
370 &response_message);
371
372 AppendHeader(
373 HttpRequestHeaders::kConnection, websockets::kUpgrade, &response_message);
374
375 std::string websocket_accept;
376 ComputeSecWebSocketAccept(challenge, &websocket_accept);
377 AppendHeader(
378 websockets::kSecWebSocketAccept, websocket_accept, &response_message);
379
380 void* iter = NULL;
381 std::string name;
382 std::string value;
383 while (response_info.headers->EnumerateHeaderLines(&iter, &name, &value)) {
384 AppendHeader(name, value, &response_message);
385 }
386 response_message += "\r\n";
387
388 return ParseRawResponse(response_message.data(),
389 response_message.size()) == response_message.size();
390 }
391
ParseResponseHeaderBlock(const SpdyHeaderBlock & headers,const std::string & challenge,int spdy_protocol_version)392 bool WebSocketHandshakeResponseHandler::ParseResponseHeaderBlock(
393 const SpdyHeaderBlock& headers,
394 const std::string& challenge,
395 int spdy_protocol_version) {
396 SpdyHeaderBlock::const_iterator status;
397 if (spdy_protocol_version <= 2)
398 status = headers.find("status");
399 else
400 status = headers.find(":status");
401 if (status == headers.end())
402 return false;
403
404 std::string hash =
405 base::SHA1HashString(challenge + websockets::kWebSocketGuid);
406 std::string websocket_accept;
407 base::Base64Encode(hash, &websocket_accept);
408
409 std::string response_message = base::StringPrintf(
410 "%s %s\r\n", websockets::kHttpProtocolVersion, status->second.c_str());
411
412 AppendHeader(
413 websockets::kUpgrade, websockets::kWebSocketLowercase, &response_message);
414 AppendHeader(
415 HttpRequestHeaders::kConnection, websockets::kUpgrade, &response_message);
416 AppendHeader(
417 websockets::kSecWebSocketAccept, websocket_accept, &response_message);
418
419 for (SpdyHeaderBlock::const_iterator iter = headers.begin();
420 iter != headers.end();
421 ++iter) {
422 // For each value, if the server sends a NUL-separated list of values,
423 // we separate that back out into individual headers for each value
424 // in the list.
425 if ((spdy_protocol_version <= 2 &&
426 LowerCaseEqualsASCII(iter->first, "status")) ||
427 (spdy_protocol_version >= 3 &&
428 LowerCaseEqualsASCII(iter->first, ":status"))) {
429 // The status value is already handled as the first line of
430 // |response_message|. Just skip here.
431 continue;
432 }
433 const std::string& value = iter->second;
434 size_t start = 0;
435 size_t end = 0;
436 do {
437 end = value.find('\0', start);
438 std::string tval;
439 if (end != std::string::npos)
440 tval = value.substr(start, (end - start));
441 else
442 tval = value.substr(start);
443 if (spdy_protocol_version >= 3 &&
444 (LowerCaseEqualsASCII(iter->first,
445 websockets::kSecWebSocketProtocolSpdy3) ||
446 LowerCaseEqualsASCII(iter->first,
447 websockets::kSecWebSocketExtensionsSpdy3)))
448 AppendHeader(iter->first.substr(1), tval, &response_message);
449 else
450 AppendHeader(iter->first, tval, &response_message);
451 start = end + 1;
452 } while (end != std::string::npos);
453 }
454 response_message += "\r\n";
455
456 return ParseRawResponse(response_message.data(),
457 response_message.size()) == response_message.size();
458 }
459
GetHeaders(const char * const headers_to_get[],size_t headers_to_get_len,std::vector<std::string> * values)460 void WebSocketHandshakeResponseHandler::GetHeaders(
461 const char* const headers_to_get[],
462 size_t headers_to_get_len,
463 std::vector<std::string>* values) {
464 DCHECK(HasResponse());
465 DCHECK(!status_line_.empty());
466 // headers_ might be empty for wrong response from server.
467 if (headers_.empty())
468 return;
469
470 FetchHeaders(headers_, headers_to_get, headers_to_get_len, values);
471 }
472
RemoveHeaders(const char * const headers_to_remove[],size_t headers_to_remove_len)473 void WebSocketHandshakeResponseHandler::RemoveHeaders(
474 const char* const headers_to_remove[],
475 size_t headers_to_remove_len) {
476 DCHECK(HasResponse());
477 DCHECK(!status_line_.empty());
478 // headers_ might be empty for wrong response from server.
479 if (headers_.empty())
480 return;
481
482 headers_ = FilterHeaders(headers_, headers_to_remove, headers_to_remove_len);
483 }
484
GetRawResponse() const485 std::string WebSocketHandshakeResponseHandler::GetRawResponse() const {
486 DCHECK(HasResponse());
487 return original_.substr(0, original_header_length_);
488 }
489
GetResponse()490 std::string WebSocketHandshakeResponseHandler::GetResponse() {
491 DCHECK(HasResponse());
492 DCHECK(!status_line_.empty());
493 // headers_ might be empty for wrong response from server.
494
495 return status_line_ + headers_ + header_separator_;
496 }
497
498 } // namespace net
499