• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2010 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_handshake_handler.h"
6 
7 #include "base/md5.h"
8 #include "base/string_piece.h"
9 #include "base/string_util.h"
10 #include "googleurl/src/gurl.h"
11 #include "net/http/http_response_headers.h"
12 #include "net/http/http_util.h"
13 
14 namespace {
15 
// Number of bytes of /key3/ data expected directly after the handshake
// request header (see WebSocketHandshakeRequestHandler::ParseRequest).
16 const size_t kRequestKey3Size = 8U;
// Number of bytes of key data taken directly after the handshake response
// header (see WebSocketHandshakeResponseHandler::ParseRawResponse).
17 const size_t kResponseKeySize = 16U;
18 
// Splits a raw handshake message into its first line and its header block.
//
// |handshake_message| points at |len| bytes which are expected to end with
// the blank line that terminates the headers ("\r\n\r\n").
//
// On return:
//   |status_line| holds the first line including its line break (assumed to
//                 be the two bytes "\r\n"), or the whole message when it
//                 contains no line break at all.
//   |headers|     holds the bytes between the first line and the final
//                 "\r\n" of the message (so every header line keeps its own
//                 "\r\n"), or is empty when there is nothing in between.
//
// Never reads outside [handshake_message, handshake_message + len).
void ParseHandshakeHeader(
    const char* handshake_message, int len,
    std::string* status_line,
    std::string* headers) {
  const size_t message_len = len > 0 ? static_cast<size_t>(len) : 0;

  // Find the first '\r' or '\n'.
  size_t eol = 0;
  while (eol < message_len &&
         handshake_message[eol] != '\r' &&
         handshake_message[eol] != '\n') {
    ++eol;
  }
  if (eol == message_len) {
    // No line break: everything is the status line.
    status_line->assign(handshake_message, message_len);
    headers->clear();
    return;
  }

  // |status_line| includes "\r\n".  The break may be a lone '\r' or '\n' in
  // the very last byte, so clamp instead of blindly taking two bytes.
  size_t status_len = eol + 2;
  if (status_len > message_len)
    status_len = message_len;
  status_line->assign(handshake_message, status_len);

  // |handshake_message| includes the trailing "\r\n\r\n";
  // |headers| doesn't include the 2nd "\r\n".
  const size_t kTrailingCrlfSize = 2;
  if (message_len > status_len + kTrailingCrlfSize) {
    headers->assign(handshake_message + status_len,
                    message_len - status_len - kTrailingCrlfSize);
  } else {
    headers->clear();
  }
}
41 
FetchHeaders(const std::string & headers,const char * const headers_to_get[],size_t headers_to_get_len,std::vector<std::string> * values)42 void FetchHeaders(const std::string& headers,
43                   const char* const headers_to_get[],
44                   size_t headers_to_get_len,
45                   std::vector<std::string>* values) {
46   net::HttpUtil::HeadersIterator iter(headers.begin(), headers.end(), "\r\n");
47   while (iter.GetNext()) {
48     for (size_t i = 0; i < headers_to_get_len; i++) {
49       if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(),
50                                headers_to_get[i])) {
51         values->push_back(iter.values());
52       }
53     }
54   }
55 }
56 
GetHeaderName(std::string::const_iterator line_begin,std::string::const_iterator line_end,std::string::const_iterator * name_begin,std::string::const_iterator * name_end)57 bool GetHeaderName(std::string::const_iterator line_begin,
58                    std::string::const_iterator line_end,
59                    std::string::const_iterator* name_begin,
60                    std::string::const_iterator* name_end) {
61   std::string::const_iterator colon = std::find(line_begin, line_end, ':');
62   if (colon == line_end) {
63     return false;
64   }
65   *name_begin = line_begin;
66   *name_end = colon;
67   if (*name_begin == *name_end || net::HttpUtil::IsLWS(**name_begin))
68     return false;
69   net::HttpUtil::TrimLWS(name_begin, name_end);
70   return true;
71 }
72 
73 // Similar to HttpUtil::StripHeaders, but it preserves malformed headers, that
74 // is, lines that are not formatted as "<name>: <value>\r\n".
FilterHeaders(const std::string & headers,const char * const headers_to_remove[],size_t headers_to_remove_len)75 std::string FilterHeaders(
76     const std::string& headers,
77     const char* const headers_to_remove[],
78     size_t headers_to_remove_len) {
79   std::string filtered_headers;
80 
81   StringTokenizer lines(headers.begin(), headers.end(), "\r\n");
82   while (lines.GetNext()) {
83     std::string::const_iterator line_begin = lines.token_begin();
84     std::string::const_iterator line_end = lines.token_end();
85     std::string::const_iterator name_begin;
86     std::string::const_iterator name_end;
87     bool should_remove = false;
88     if (GetHeaderName(line_begin, line_end, &name_begin, &name_end)) {
89       for (size_t i = 0; i < headers_to_remove_len; ++i) {
90         if (LowerCaseEqualsASCII(name_begin, name_end, headers_to_remove[i])) {
91           should_remove = true;
92           break;
93         }
94       }
95     }
96     if (!should_remove) {
97       filtered_headers.append(line_begin, line_end);
98       filtered_headers.append("\r\n");
99     }
100   }
101   return filtered_headers;
102 }
103 
104 // Gets a key number from |key| and appends the number to |challenge|.
105 // The key number (/part_N/) is extracted as step 4.-8. in
106 // 5.2. Sending the server's opening handshake of
107 // http://www.ietf.org/id/draft-ietf-hybi-thewebsocketprotocol-00.txt
GetKeyNumber(const std::string & key,std::string * challenge)108 void GetKeyNumber(const std::string& key, std::string* challenge) {
109   uint32 key_number = 0;
110   uint32 spaces = 0;
111   for (size_t i = 0; i < key.size(); ++i) {
112     if (isdigit(key[i])) {
113       // key_number should not overflow. (it comes from
114       // WebCore/websockets/WebSocketHandshake.cpp).
115       key_number = key_number * 10 + key[i] - '0';
116     } else if (key[i] == ' ') {
117       ++spaces;
118     }
119   }
120   // spaces should not be zero in valid handshake request.
121   if (spaces == 0)
122     return;
123   key_number /= spaces;
124 
125   char part[4];
126   for (int i = 0; i < 4; i++) {
127     part[3 - i] = key_number & 0xFF;
128     key_number >>= 8;
129   }
130   challenge->append(part, 4);
131 }
132 
133 }  // anonymous namespace
134 
135 namespace net {
136 
WebSocketHandshakeRequestHandler()137 WebSocketHandshakeRequestHandler::WebSocketHandshakeRequestHandler()
138     : original_length_(0),
139       raw_length_(0) {}
140 
ParseRequest(const char * data,int length)141 bool WebSocketHandshakeRequestHandler::ParseRequest(
142     const char* data, int length) {
143   DCHECK_GT(length, 0);
144   std::string input(data, length);
145   int input_header_length =
146       HttpUtil::LocateEndOfHeaders(input.data(), input.size(), 0);
147   if (input_header_length <= 0 ||
148       input_header_length + kRequestKey3Size > input.size())
149     return false;
150 
151   ParseHandshakeHeader(input.data(),
152                        input_header_length,
153                        &status_line_,
154                        &headers_);
155 
156   // draft-hixie-thewebsocketprotocol-76 or later will send /key3/
157   // after handshake request header.
158   // Assumes WebKit doesn't send any data after handshake request message
159   // until handshake is finished.
160   // Thus, |key3_| is part of handshake message, and not in part
161   // of WebSocket frame stream.
162   DCHECK_EQ(kRequestKey3Size,
163             input.size() -
164             input_header_length);
165   key3_ = std::string(input.data() + input_header_length,
166                       input.size() - input_header_length);
167   original_length_ = input.size();
168   return true;
169 }
170 
original_length() const171 size_t WebSocketHandshakeRequestHandler::original_length() const {
172   return original_length_;
173 }
174 
AppendHeaderIfMissing(const std::string & name,const std::string & value)175 void WebSocketHandshakeRequestHandler::AppendHeaderIfMissing(
176     const std::string& name, const std::string& value) {
177   DCHECK(!headers_.empty());
178   HttpUtil::AppendHeaderIfMissing(name.c_str(), value, &headers_);
179 }
180 
RemoveHeaders(const char * const headers_to_remove[],size_t headers_to_remove_len)181 void WebSocketHandshakeRequestHandler::RemoveHeaders(
182     const char* const headers_to_remove[],
183     size_t headers_to_remove_len) {
184   DCHECK(!headers_.empty());
185   headers_ = FilterHeaders(
186       headers_, headers_to_remove, headers_to_remove_len);
187 }
188 
GetRequestInfo(const GURL & url,std::string * challenge)189 HttpRequestInfo WebSocketHandshakeRequestHandler::GetRequestInfo(
190     const GURL& url, std::string* challenge) {
191   HttpRequestInfo request_info;
192   request_info.url = url;
193   base::StringPiece method = status_line_.data();
194   size_t method_end = base::StringPiece(
195       status_line_.data(), status_line_.size()).find_first_of(" ");
196   if (method_end != base::StringPiece::npos)
197     request_info.method = std::string(status_line_.data(), method_end);
198 
199   request_info.extra_headers.Clear();
200   request_info.extra_headers.AddHeadersFromString(headers_);
201 
202   request_info.extra_headers.RemoveHeader("Upgrade");
203   request_info.extra_headers.RemoveHeader("Connection");
204 
205   challenge->clear();
206   std::string key;
207   request_info.extra_headers.GetHeader("Sec-WebSocket-Key1", &key);
208   request_info.extra_headers.RemoveHeader("Sec-WebSocket-Key1");
209   GetKeyNumber(key, challenge);
210 
211   request_info.extra_headers.GetHeader("Sec-WebSocket-Key2", &key);
212   request_info.extra_headers.RemoveHeader("Sec-WebSocket-Key2");
213   GetKeyNumber(key, challenge);
214 
215   challenge->append(key3_);
216 
217   return request_info;
218 }
219 
GetRequestHeaderBlock(const GURL & url,spdy::SpdyHeaderBlock * headers,std::string * challenge)220 bool WebSocketHandshakeRequestHandler::GetRequestHeaderBlock(
221     const GURL& url, spdy::SpdyHeaderBlock* headers, std::string* challenge) {
222   // We don't set "method" and "version".  These are fixed value in WebSocket
223   // protocol.
224   (*headers)["url"] = url.spec();
225 
226   std::string key1;
227   std::string key2;
228   HttpUtil::HeadersIterator iter(headers_.begin(), headers_.end(), "\r\n");
229   while (iter.GetNext()) {
230     if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(),
231                              "connection")) {
232       // Ignore "Connection" header.
233       continue;
234     } else if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(),
235                                     "upgrade")) {
236       // Ignore "Upgrade" header.
237       continue;
238     } else if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(),
239                                     "sec-websocket-key1")) {
240       // Use only for generating challenge.
241       key1 = iter.values();
242       continue;
243     } else if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(),
244                                     "sec-websocket-key2")) {
245       // Use only for generating challenge.
246       key2 = iter.values();
247       continue;
248     }
249     // Others should be sent out to |headers|.
250     std::string name = StringToLowerASCII(iter.name());
251     spdy::SpdyHeaderBlock::iterator found = headers->find(name);
252     if (found == headers->end()) {
253       (*headers)[name] = iter.values();
254     } else {
255       // For now, websocket doesn't use multiple headers, but follows to http.
256       found->second.append(1, '\0');  // +=() doesn't append 0's
257       found->second.append(iter.values());
258     }
259   }
260 
261   challenge->clear();
262   GetKeyNumber(key1, challenge);
263   GetKeyNumber(key2, challenge);
264   challenge->append(key3_);
265 
266   return true;
267 }
268 
GetRawRequest()269 std::string WebSocketHandshakeRequestHandler::GetRawRequest() {
270   DCHECK(!status_line_.empty());
271   DCHECK(!headers_.empty());
272   DCHECK_EQ(kRequestKey3Size, key3_.size());
273   std::string raw_request = status_line_ + headers_ + "\r\n" + key3_;
274   raw_length_ = raw_request.size();
275   return raw_request;
276 }
277 
raw_length() const278 size_t WebSocketHandshakeRequestHandler::raw_length() const {
279   DCHECK_GT(raw_length_, 0);
280   return raw_length_;
281 }
282 
WebSocketHandshakeResponseHandler()283 WebSocketHandshakeResponseHandler::WebSocketHandshakeResponseHandler()
284     : original_header_length_(0) {
285 }
286 
~WebSocketHandshakeResponseHandler()287 WebSocketHandshakeResponseHandler::~WebSocketHandshakeResponseHandler() {}
288 
ParseRawResponse(const char * data,int length)289 size_t WebSocketHandshakeResponseHandler::ParseRawResponse(
290     const char* data, int length) {
291   DCHECK_GT(length, 0);
292   if (HasResponse()) {
293     DCHECK(!status_line_.empty());
294     DCHECK(!headers_.empty());
295     DCHECK_EQ(kResponseKeySize, key_.size());
296     return 0;
297   }
298 
299   size_t old_original_length = original_.size();
300 
301   original_.append(data, length);
302   // TODO(ukai): fail fast when response gives wrong status code.
303   original_header_length_ = HttpUtil::LocateEndOfHeaders(
304       original_.data(), original_.size(), 0);
305   if (!HasResponse())
306     return length;
307 
308   ParseHandshakeHeader(original_.data(),
309                        original_header_length_,
310                        &status_line_,
311                        &headers_);
312   int header_size = status_line_.size() + headers_.size();
313   DCHECK_GE(original_header_length_, header_size);
314   header_separator_ = std::string(original_.data() + header_size,
315                                   original_header_length_ - header_size);
316   key_ = std::string(original_.data() + original_header_length_,
317                      kResponseKeySize);
318 
319   return original_header_length_ + kResponseKeySize - old_original_length;
320 }
321 
HasResponse() const322 bool WebSocketHandshakeResponseHandler::HasResponse() const {
323   return original_header_length_ > 0 &&
324       original_header_length_ + kResponseKeySize <= original_.size();
325 }
326 
ParseResponseInfo(const HttpResponseInfo & response_info,const std::string & challenge)327 bool WebSocketHandshakeResponseHandler::ParseResponseInfo(
328     const HttpResponseInfo& response_info,
329     const std::string& challenge) {
330   if (!response_info.headers.get())
331     return false;
332 
333   std::string response_message;
334   response_message = response_info.headers->GetStatusLine();
335   response_message += "\r\n";
336   response_message += "Upgrade: WebSocket\r\n";
337   response_message += "Connection: Upgrade\r\n";
338   void* iter = NULL;
339   std::string name;
340   std::string value;
341   while (response_info.headers->EnumerateHeaderLines(&iter, &name, &value)) {
342     response_message += name + ": " + value + "\r\n";
343   }
344   response_message += "\r\n";
345 
346   MD5Digest digest;
347   MD5Sum(challenge.data(), challenge.size(), &digest);
348 
349   const char* digest_data = reinterpret_cast<char*>(digest.a);
350   response_message.append(digest_data, sizeof(digest.a));
351 
352   return ParseRawResponse(response_message.data(),
353                           response_message.size()) == response_message.size();
354 }
355 
ParseResponseHeaderBlock(const spdy::SpdyHeaderBlock & headers,const std::string & challenge)356 bool WebSocketHandshakeResponseHandler::ParseResponseHeaderBlock(
357     const spdy::SpdyHeaderBlock& headers,
358     const std::string& challenge) {
359   std::string response_message;
360   response_message = "HTTP/1.1 101 WebSocket Protocol Handshake\r\n";
361   response_message += "Upgrade: WebSocket\r\n";
362   response_message += "Connection: Upgrade\r\n";
363   for (spdy::SpdyHeaderBlock::const_iterator iter = headers.begin();
364        iter != headers.end();
365        ++iter) {
366     // For each value, if the server sends a NUL-separated list of values,
367     // we separate that back out into individual headers for each value
368     // in the list.
369     const std::string& value = iter->second;
370     size_t start = 0;
371     size_t end = 0;
372     do {
373       end = value.find('\0', start);
374       std::string tval;
375       if (end != std::string::npos)
376         tval = value.substr(start, (end - start));
377       else
378         tval = value.substr(start);
379       response_message += iter->first + ": " + tval + "\r\n";
380       start = end + 1;
381     } while (end != std::string::npos);
382   }
383   response_message += "\r\n";
384 
385   MD5Digest digest;
386   MD5Sum(challenge.data(), challenge.size(), &digest);
387 
388   const char* digest_data = reinterpret_cast<char*>(digest.a);
389   response_message.append(digest_data, sizeof(digest.a));
390 
391   return ParseRawResponse(response_message.data(),
392                           response_message.size()) == response_message.size();
393 }
394 
GetHeaders(const char * const headers_to_get[],size_t headers_to_get_len,std::vector<std::string> * values)395 void WebSocketHandshakeResponseHandler::GetHeaders(
396     const char* const headers_to_get[],
397     size_t headers_to_get_len,
398     std::vector<std::string>* values) {
399   DCHECK(HasResponse());
400   DCHECK(!status_line_.empty());
401   DCHECK(!headers_.empty());
402   DCHECK_EQ(kResponseKeySize, key_.size());
403 
404   FetchHeaders(headers_, headers_to_get, headers_to_get_len, values);
405 }
406 
RemoveHeaders(const char * const headers_to_remove[],size_t headers_to_remove_len)407 void WebSocketHandshakeResponseHandler::RemoveHeaders(
408     const char* const headers_to_remove[],
409     size_t headers_to_remove_len) {
410   DCHECK(HasResponse());
411   DCHECK(!status_line_.empty());
412   DCHECK(!headers_.empty());
413   DCHECK_EQ(kResponseKeySize, key_.size());
414 
415   headers_ = FilterHeaders(headers_, headers_to_remove, headers_to_remove_len);
416 }
417 
GetRawResponse() const418 std::string WebSocketHandshakeResponseHandler::GetRawResponse() const {
419   DCHECK(HasResponse());
420   return std::string(original_.data(),
421                      original_header_length_ + kResponseKeySize);
422 }
423 
GetResponse()424 std::string WebSocketHandshakeResponseHandler::GetResponse() {
425   DCHECK(HasResponse());
426   DCHECK(!status_line_.empty());
427   // headers_ might be empty for wrong response from server.
428   DCHECK_EQ(kResponseKeySize, key_.size());
429 
430   return status_line_ + headers_ + header_separator_ + key_;
431 }
432 
433 }  // namespace net
434