• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
#include "net/websockets/websocket_handshake_handler.h"

#include <algorithm>
#include <limits>

#include "base/base64.h"
#include "base/sha1.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/string_piece.h"
#include "base/strings/string_tokenizer.h"
#include "base/strings/string_util.h"
#include "base/strings/stringprintf.h"
#include "net/http/http_request_headers.h"
#include "net/http/http_response_headers.h"
#include "net/http/http_util.h"
#include "net/websockets/websocket_handshake_constants.h"
#include "url/gurl.h"
21 
22 namespace net {
23 namespace {
24 
// The only Sec-WebSocket-Version value accepted by CheckVersionInRequest().
const int kVersionHeaderValueForRFC6455 = 13;
26 
// Splits the first |len| bytes of |handshake_message| into its first line and
// its header block.
//
// |request_line| receives everything up to and including the first "\r\n"
// (for a response this is the Status-Line). |headers| receives the remaining
// header lines, each still terminated by "\r\n", minus the last two bytes of
// the message (the "\r\n" of the blank line that ends the header block).
// If the message contains no "\r\n" at all, the whole message becomes
// |request_line| and |headers| is cleared.
void ParseHandshakeHeader(
    const char* handshake_message, int len,
    std::string* request_line,
    std::string* headers) {
  static const char kCrLf[] = "\r\n";
  const char* const message_end = handshake_message + len;

  // Look for the two-byte sequence "\r\n". Matching either byte on its own
  // and then skipping two bytes would split a bare-'\n' line in the middle of
  // the next line, and would read past |message_end| if the only terminator
  // were a trailing '\r'.
  const char* const line_break =
      std::search(handshake_message, message_end, kCrLf, kCrLf + 2);
  if (line_break == message_end) {
    request_line->assign(handshake_message, len);
    headers->clear();
    return;
  }

  // |request_line| includes its terminating \r\n.
  const char* const headers_begin = line_break + 2;
  request_line->assign(handshake_message, headers_begin);

  // |headers_begin| never passes |message_end| because a full "\r\n" was found.
  const size_t remaining = static_cast<size_t>(message_end - headers_begin);
  if (remaining > 2)
    headers->assign(headers_begin, remaining - 2);
  else
    headers->clear();
}
52 
FetchHeaders(const std::string & headers,const char * const headers_to_get[],size_t headers_to_get_len,std::vector<std::string> * values)53 void FetchHeaders(const std::string& headers,
54                   const char* const headers_to_get[],
55                   size_t headers_to_get_len,
56                   std::vector<std::string>* values) {
57   net::HttpUtil::HeadersIterator iter(headers.begin(), headers.end(), "\r\n");
58   while (iter.GetNext()) {
59     for (size_t i = 0; i < headers_to_get_len; i++) {
60       if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(),
61                                headers_to_get[i])) {
62         values->push_back(iter.values());
63       }
64     }
65   }
66 }
67 
GetHeaderName(std::string::const_iterator line_begin,std::string::const_iterator line_end,std::string::const_iterator * name_begin,std::string::const_iterator * name_end)68 bool GetHeaderName(std::string::const_iterator line_begin,
69                    std::string::const_iterator line_end,
70                    std::string::const_iterator* name_begin,
71                    std::string::const_iterator* name_end) {
72   std::string::const_iterator colon = std::find(line_begin, line_end, ':');
73   if (colon == line_end) {
74     return false;
75   }
76   *name_begin = line_begin;
77   *name_end = colon;
78   if (*name_begin == *name_end || net::HttpUtil::IsLWS(**name_begin))
79     return false;
80   net::HttpUtil::TrimLWS(name_begin, name_end);
81   return true;
82 }
83 
84 // Similar to HttpUtil::StripHeaders, but it preserves malformed headers, that
85 // is, lines that are not formatted as "<name>: <value>\r\n".
FilterHeaders(const std::string & headers,const char * const headers_to_remove[],size_t headers_to_remove_len)86 std::string FilterHeaders(
87     const std::string& headers,
88     const char* const headers_to_remove[],
89     size_t headers_to_remove_len) {
90   std::string filtered_headers;
91 
92   base::StringTokenizer lines(headers.begin(), headers.end(), "\r\n");
93   while (lines.GetNext()) {
94     std::string::const_iterator line_begin = lines.token_begin();
95     std::string::const_iterator line_end = lines.token_end();
96     std::string::const_iterator name_begin;
97     std::string::const_iterator name_end;
98     bool should_remove = false;
99     if (GetHeaderName(line_begin, line_end, &name_begin, &name_end)) {
100       for (size_t i = 0; i < headers_to_remove_len; ++i) {
101         if (LowerCaseEqualsASCII(name_begin, name_end, headers_to_remove[i])) {
102           should_remove = true;
103           break;
104         }
105       }
106     }
107     if (!should_remove) {
108       filtered_headers.append(line_begin, line_end);
109       filtered_headers.append("\r\n");
110     }
111   }
112   return filtered_headers;
113 }
114 
CheckVersionInRequest(const std::string & request_headers)115 bool CheckVersionInRequest(const std::string& request_headers) {
116   std::vector<std::string> values;
117   const char* const headers_to_get[1] = {
118     websockets::kSecWebSocketVersionLowercase};
119   FetchHeaders(request_headers, headers_to_get, 1, &values);
120   DCHECK_LE(values.size(), 1U);
121   if (values.empty())
122     return false;
123 
124   int version;
125   bool conversion_success = base::StringToInt(values[0], &version);
126   if (!conversion_success)
127     return false;
128 
129   return version == kVersionHeaderValueForRFC6455;
130 }
131 
132 // Append a header to a string. Equivalent to
133 //   response_message += header + ": " + value + "\r\n"
134 // but avoids unnecessary allocations and copies.
AppendHeader(const base::StringPiece & header,const base::StringPiece & value,std::string * response_message)135 void AppendHeader(const base::StringPiece& header,
136                   const base::StringPiece& value,
137                   std::string* response_message) {
138   static const char kColonSpace[] = ": ";
139   const size_t kColonSpaceSize = sizeof(kColonSpace) - 1;
140   static const char kCrNl[] = "\r\n";
141   const size_t kCrNlSize = sizeof(kCrNl) - 1;
142 
143   size_t extra_size =
144       header.size() + kColonSpaceSize + value.size() + kCrNlSize;
145   response_message->reserve(response_message->size() + extra_size);
146   response_message->append(header.begin(), header.end());
147   response_message->append(kColonSpace, kColonSpace + kColonSpaceSize);
148   response_message->append(value.begin(), value.end());
149   response_message->append(kCrNl, kCrNl + kCrNlSize);
150 }
151 
152 }  // namespace
153 
WebSocketHandshakeRequestHandler()154 WebSocketHandshakeRequestHandler::WebSocketHandshakeRequestHandler()
155     : original_length_(0),
156       raw_length_(0) {}
157 
ParseRequest(const char * data,int length)158 bool WebSocketHandshakeRequestHandler::ParseRequest(
159     const char* data, int length) {
160   DCHECK_GT(length, 0);
161   std::string input(data, length);
162   int input_header_length =
163       HttpUtil::LocateEndOfHeaders(input.data(), input.size(), 0);
164   if (input_header_length <= 0)
165     return false;
166 
167   ParseHandshakeHeader(input.data(),
168                        input_header_length,
169                        &request_line_,
170                        &headers_);
171 
172   if (!CheckVersionInRequest(headers_)) {
173     NOTREACHED();
174     return false;
175   }
176 
177   original_length_ = input_header_length;
178   return true;
179 }
180 
original_length() const181 size_t WebSocketHandshakeRequestHandler::original_length() const {
182   return original_length_;
183 }
184 
AppendHeaderIfMissing(const std::string & name,const std::string & value)185 void WebSocketHandshakeRequestHandler::AppendHeaderIfMissing(
186     const std::string& name, const std::string& value) {
187   DCHECK(!headers_.empty());
188   HttpUtil::AppendHeaderIfMissing(name.c_str(), value, &headers_);
189 }
190 
RemoveHeaders(const char * const headers_to_remove[],size_t headers_to_remove_len)191 void WebSocketHandshakeRequestHandler::RemoveHeaders(
192     const char* const headers_to_remove[],
193     size_t headers_to_remove_len) {
194   DCHECK(!headers_.empty());
195   headers_ = FilterHeaders(
196       headers_, headers_to_remove, headers_to_remove_len);
197 }
198 
GetRequestInfo(const GURL & url,std::string * challenge)199 HttpRequestInfo WebSocketHandshakeRequestHandler::GetRequestInfo(
200     const GURL& url, std::string* challenge) {
201   HttpRequestInfo request_info;
202   request_info.url = url;
203   size_t method_end = base::StringPiece(request_line_).find_first_of(" ");
204   if (method_end != base::StringPiece::npos)
205     request_info.method = std::string(request_line_.data(), method_end);
206 
207   request_info.extra_headers.Clear();
208   request_info.extra_headers.AddHeadersFromString(headers_);
209 
210   request_info.extra_headers.RemoveHeader(websockets::kUpgrade);
211   request_info.extra_headers.RemoveHeader(HttpRequestHeaders::kConnection);
212 
213   std::string key;
214   bool header_present = request_info.extra_headers.GetHeader(
215       websockets::kSecWebSocketKey, &key);
216   DCHECK(header_present);
217   request_info.extra_headers.RemoveHeader(websockets::kSecWebSocketKey);
218   *challenge = key;
219   return request_info;
220 }
221 
GetRequestHeaderBlock(const GURL & url,SpdyHeaderBlock * headers,std::string * challenge,int spdy_protocol_version)222 bool WebSocketHandshakeRequestHandler::GetRequestHeaderBlock(
223     const GURL& url,
224     SpdyHeaderBlock* headers,
225     std::string* challenge,
226     int spdy_protocol_version) {
227   // Construct opening handshake request headers as a SPDY header block.
228   // For details, see WebSocket Layering over SPDY/3 Draft 8.
229   if (spdy_protocol_version <= 2) {
230     (*headers)["path"] = url.path();
231     (*headers)["version"] = "WebSocket/13";
232     (*headers)["scheme"] = url.scheme();
233   } else {
234     (*headers)[":path"] = url.path();
235     (*headers)[":version"] = "WebSocket/13";
236     (*headers)[":scheme"] = url.scheme();
237   }
238 
239   HttpUtil::HeadersIterator iter(headers_.begin(), headers_.end(), "\r\n");
240   while (iter.GetNext()) {
241     if (LowerCaseEqualsASCII(iter.name_begin(),
242                              iter.name_end(),
243                              websockets::kUpgradeLowercase) ||
244         LowerCaseEqualsASCII(
245             iter.name_begin(), iter.name_end(), "connection") ||
246         LowerCaseEqualsASCII(iter.name_begin(),
247                              iter.name_end(),
248                              websockets::kSecWebSocketVersionLowercase)) {
249       // These headers must be ignored.
250       continue;
251     } else if (LowerCaseEqualsASCII(iter.name_begin(),
252                                     iter.name_end(),
253                                     websockets::kSecWebSocketKeyLowercase)) {
254       *challenge = iter.values();
255       // Sec-WebSocket-Key is not sent to the server.
256       continue;
257     } else if (LowerCaseEqualsASCII(
258                    iter.name_begin(), iter.name_end(), "host") ||
259                LowerCaseEqualsASCII(
260                    iter.name_begin(), iter.name_end(), "origin") ||
261                LowerCaseEqualsASCII(
262                    iter.name_begin(),
263                    iter.name_end(),
264                    websockets::kSecWebSocketProtocolLowercase) ||
265                LowerCaseEqualsASCII(
266                    iter.name_begin(),
267                    iter.name_end(),
268                    websockets::kSecWebSocketExtensionsLowercase)) {
269       // TODO(toyoshim): Some WebSocket extensions may not be compatible with
270       // SPDY. We should omit them from a Sec-WebSocket-Extension header.
271       std::string name;
272       if (spdy_protocol_version <= 2)
273         name = base::StringToLowerASCII(iter.name());
274       else
275         name = ":" + base::StringToLowerASCII(iter.name());
276       (*headers)[name] = iter.values();
277       continue;
278     }
279     // Others should be sent out to |headers|.
280     std::string name = base::StringToLowerASCII(iter.name());
281     SpdyHeaderBlock::iterator found = headers->find(name);
282     if (found == headers->end()) {
283       (*headers)[name] = iter.values();
284     } else {
285       // For now, websocket doesn't use multiple headers, but follows to http.
286       found->second.append(1, '\0');  // +=() doesn't append 0's
287       found->second.append(iter.values());
288     }
289   }
290 
291   return true;
292 }
293 
GetRawRequest()294 std::string WebSocketHandshakeRequestHandler::GetRawRequest() {
295   DCHECK(!request_line_.empty());
296   DCHECK(!headers_.empty());
297 
298   std::string raw_request = request_line_ + headers_ + "\r\n";
299   raw_length_ = raw_request.size();
300   return raw_request;
301 }
302 
raw_length() const303 size_t WebSocketHandshakeRequestHandler::raw_length() const {
304   DCHECK_GT(raw_length_, 0);
305   return raw_length_;
306 }
307 
WebSocketHandshakeResponseHandler()308 WebSocketHandshakeResponseHandler::WebSocketHandshakeResponseHandler()
309     : original_header_length_(0) {}
310 
~WebSocketHandshakeResponseHandler()311 WebSocketHandshakeResponseHandler::~WebSocketHandshakeResponseHandler() {}
312 
ParseRawResponse(const char * data,int length)313 size_t WebSocketHandshakeResponseHandler::ParseRawResponse(
314     const char* data, int length) {
315   DCHECK_GT(length, 0);
316   if (HasResponse()) {
317     DCHECK(!status_line_.empty());
318     // headers_ might be empty for wrong response from server.
319 
320     return 0;
321   }
322 
323   size_t old_original_length = original_.size();
324 
325   original_.append(data, length);
326   // TODO(ukai): fail fast when response gives wrong status code.
327   original_header_length_ = HttpUtil::LocateEndOfHeaders(
328       original_.data(), original_.size(), 0);
329   if (!HasResponse())
330     return length;
331 
332   ParseHandshakeHeader(original_.data(),
333                        original_header_length_,
334                        &status_line_,
335                        &headers_);
336   int header_size = status_line_.size() + headers_.size();
337   DCHECK_GE(original_header_length_, header_size);
338   header_separator_ = std::string(original_.data() + header_size,
339                                   original_header_length_ - header_size);
340   return original_header_length_ - old_original_length;
341 }
342 
HasResponse() const343 bool WebSocketHandshakeResponseHandler::HasResponse() const {
344   return original_header_length_ > 0 &&
345       static_cast<size_t>(original_header_length_) <= original_.size();
346 }
347 
ComputeSecWebSocketAccept(const std::string & key,std::string * accept)348 void ComputeSecWebSocketAccept(const std::string& key,
349                                std::string* accept) {
350   DCHECK(accept);
351 
352   std::string hash =
353       base::SHA1HashString(key + websockets::kWebSocketGuid);
354   base::Base64Encode(hash, accept);
355 }
356 
ParseResponseInfo(const HttpResponseInfo & response_info,const std::string & challenge)357 bool WebSocketHandshakeResponseHandler::ParseResponseInfo(
358     const HttpResponseInfo& response_info,
359     const std::string& challenge) {
360   if (!response_info.headers.get())
361     return false;
362 
363   // TODO(ricea): Eliminate all the reallocations and string copies.
364   std::string response_message;
365   response_message = response_info.headers->GetStatusLine();
366   response_message += "\r\n";
367 
368   AppendHeader(websockets::kUpgrade,
369                websockets::kWebSocketLowercase,
370                &response_message);
371 
372   AppendHeader(
373       HttpRequestHeaders::kConnection, websockets::kUpgrade, &response_message);
374 
375   std::string websocket_accept;
376   ComputeSecWebSocketAccept(challenge, &websocket_accept);
377   AppendHeader(
378       websockets::kSecWebSocketAccept, websocket_accept, &response_message);
379 
380   void* iter = NULL;
381   std::string name;
382   std::string value;
383   while (response_info.headers->EnumerateHeaderLines(&iter, &name, &value)) {
384     AppendHeader(name, value, &response_message);
385   }
386   response_message += "\r\n";
387 
388   return ParseRawResponse(response_message.data(),
389                           response_message.size()) == response_message.size();
390 }
391 
ParseResponseHeaderBlock(const SpdyHeaderBlock & headers,const std::string & challenge,int spdy_protocol_version)392 bool WebSocketHandshakeResponseHandler::ParseResponseHeaderBlock(
393     const SpdyHeaderBlock& headers,
394     const std::string& challenge,
395     int spdy_protocol_version) {
396   SpdyHeaderBlock::const_iterator status;
397   if (spdy_protocol_version <= 2)
398     status = headers.find("status");
399   else
400     status = headers.find(":status");
401   if (status == headers.end())
402     return false;
403 
404   std::string hash =
405       base::SHA1HashString(challenge + websockets::kWebSocketGuid);
406   std::string websocket_accept;
407   base::Base64Encode(hash, &websocket_accept);
408 
409   std::string response_message = base::StringPrintf(
410       "%s %s\r\n", websockets::kHttpProtocolVersion, status->second.c_str());
411 
412   AppendHeader(
413       websockets::kUpgrade, websockets::kWebSocketLowercase, &response_message);
414   AppendHeader(
415       HttpRequestHeaders::kConnection, websockets::kUpgrade, &response_message);
416   AppendHeader(
417       websockets::kSecWebSocketAccept, websocket_accept, &response_message);
418 
419   for (SpdyHeaderBlock::const_iterator iter = headers.begin();
420        iter != headers.end();
421        ++iter) {
422     // For each value, if the server sends a NUL-separated list of values,
423     // we separate that back out into individual headers for each value
424     // in the list.
425     if ((spdy_protocol_version <= 2 &&
426          LowerCaseEqualsASCII(iter->first, "status")) ||
427         (spdy_protocol_version >= 3 &&
428          LowerCaseEqualsASCII(iter->first, ":status"))) {
429       // The status value is already handled as the first line of
430       // |response_message|. Just skip here.
431       continue;
432     }
433     const std::string& value = iter->second;
434     size_t start = 0;
435     size_t end = 0;
436     do {
437       end = value.find('\0', start);
438       std::string tval;
439       if (end != std::string::npos)
440         tval = value.substr(start, (end - start));
441       else
442         tval = value.substr(start);
443       if (spdy_protocol_version >= 3 &&
444           (LowerCaseEqualsASCII(iter->first,
445                                 websockets::kSecWebSocketProtocolSpdy3) ||
446            LowerCaseEqualsASCII(iter->first,
447                                 websockets::kSecWebSocketExtensionsSpdy3)))
448         AppendHeader(iter->first.substr(1), tval, &response_message);
449       else
450         AppendHeader(iter->first, tval, &response_message);
451       start = end + 1;
452     } while (end != std::string::npos);
453   }
454   response_message += "\r\n";
455 
456   return ParseRawResponse(response_message.data(),
457                           response_message.size()) == response_message.size();
458 }
459 
GetHeaders(const char * const headers_to_get[],size_t headers_to_get_len,std::vector<std::string> * values)460 void WebSocketHandshakeResponseHandler::GetHeaders(
461     const char* const headers_to_get[],
462     size_t headers_to_get_len,
463     std::vector<std::string>* values) {
464   DCHECK(HasResponse());
465   DCHECK(!status_line_.empty());
466   // headers_ might be empty for wrong response from server.
467   if (headers_.empty())
468     return;
469 
470   FetchHeaders(headers_, headers_to_get, headers_to_get_len, values);
471 }
472 
RemoveHeaders(const char * const headers_to_remove[],size_t headers_to_remove_len)473 void WebSocketHandshakeResponseHandler::RemoveHeaders(
474     const char* const headers_to_remove[],
475     size_t headers_to_remove_len) {
476   DCHECK(HasResponse());
477   DCHECK(!status_line_.empty());
478   // headers_ might be empty for wrong response from server.
479   if (headers_.empty())
480     return;
481 
482   headers_ = FilterHeaders(headers_, headers_to_remove, headers_to_remove_len);
483 }
484 
GetRawResponse() const485 std::string WebSocketHandshakeResponseHandler::GetRawResponse() const {
486   DCHECK(HasResponse());
487   return original_.substr(0, original_header_length_);
488 }
489 
GetResponse()490 std::string WebSocketHandshakeResponseHandler::GetResponse() {
491   DCHECK(HasResponse());
492   DCHECK(!status_line_.empty());
493   // headers_ might be empty for wrong response from server.
494 
495   return status_line_ + headers_ + header_separator_;
496 }
497 
498 }  // namespace net
499