• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2011 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_handshake.h"
6 
7 #include <algorithm>
8 #include <vector>
9 
10 #include "base/logging.h"
11 #include "base/md5.h"
12 #include "base/memory/ref_counted.h"
13 #include "base/rand_util.h"
14 #include "base/string_number_conversions.h"
15 #include "base/string_util.h"
16 #include "base/stringprintf.h"
17 #include "net/http/http_response_headers.h"
18 #include "net/http/http_util.h"
19 
20 namespace net {
21 
22 const int WebSocketHandshake::kWebSocketPort = 80;
23 const int WebSocketHandshake::kSecureWebSocketPort = 443;
24 
WebSocketHandshake(const GURL & url,const std::string & origin,const std::string & location,const std::string & protocol)25 WebSocketHandshake::WebSocketHandshake(
26     const GURL& url,
27     const std::string& origin,
28     const std::string& location,
29     const std::string& protocol)
30     : url_(url),
31       origin_(origin),
32       location_(location),
33       protocol_(protocol),
34       mode_(MODE_INCOMPLETE) {
35 }
36 
~WebSocketHandshake()37 WebSocketHandshake::~WebSocketHandshake() {
38 }
39 
is_secure() const40 bool WebSocketHandshake::is_secure() const {
41   return url_.SchemeIs("wss");
42 }
43 
CreateClientHandshakeMessage()44 std::string WebSocketHandshake::CreateClientHandshakeMessage() {
45   if (!parameter_.get()) {
46     parameter_.reset(new Parameter);
47     parameter_->GenerateKeys();
48   }
49   std::string msg;
50 
51   // WebSocket protocol 4.1 Opening handshake.
52 
53   msg = "GET ";
54   msg += GetResourceName();
55   msg += " HTTP/1.1\r\n";
56 
57   std::vector<std::string> fields;
58 
59   fields.push_back("Upgrade: WebSocket");
60   fields.push_back("Connection: Upgrade");
61 
62   fields.push_back("Host: " + GetHostFieldValue());
63 
64   fields.push_back("Origin: " + GetOriginFieldValue());
65 
66   if (!protocol_.empty())
67     fields.push_back("Sec-WebSocket-Protocol: " + protocol_);
68 
69   // TODO(ukai): Add cookie if necessary.
70 
71   fields.push_back("Sec-WebSocket-Key1: " + parameter_->GetSecWebSocketKey1());
72   fields.push_back("Sec-WebSocket-Key2: " + parameter_->GetSecWebSocketKey2());
73 
74   std::random_shuffle(fields.begin(), fields.end(), base::RandGenerator);
75 
76   for (size_t i = 0; i < fields.size(); i++) {
77     msg += fields[i] + "\r\n";
78   }
79   msg += "\r\n";
80 
81   msg.append(parameter_->GetKey3());
82   return msg;
83 }
84 
ReadServerHandshake(const char * data,size_t len)85 int WebSocketHandshake::ReadServerHandshake(const char* data, size_t len) {
86   mode_ = MODE_INCOMPLETE;
87   int eoh = HttpUtil::LocateEndOfHeaders(data, len);
88   if (eoh < 0)
89     return -1;
90 
91   scoped_refptr<HttpResponseHeaders> headers(
92       new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(data, eoh)));
93 
94   if (headers->response_code() != 101) {
95     mode_ = MODE_FAILED;
96     DVLOG(1) << "Bad response code: " << headers->response_code();
97     return eoh;
98   }
99   mode_ = MODE_NORMAL;
100   if (!ProcessHeaders(*headers) || !CheckResponseHeaders()) {
101     DVLOG(1) << "Process Headers failed: " << std::string(data, eoh);
102     mode_ = MODE_FAILED;
103     return eoh;
104   }
105   if (len < static_cast<size_t>(eoh + Parameter::kExpectedResponseSize)) {
106     mode_ = MODE_INCOMPLETE;
107     return -1;
108   }
109   uint8 expected[Parameter::kExpectedResponseSize];
110   parameter_->GetExpectedResponse(expected);
111   if (memcmp(&data[eoh], expected, Parameter::kExpectedResponseSize)) {
112     mode_ = MODE_FAILED;
113     return eoh + Parameter::kExpectedResponseSize;
114   }
115   mode_ = MODE_CONNECTED;
116   return eoh + Parameter::kExpectedResponseSize;
117 }
118 
GetResourceName() const119 std::string WebSocketHandshake::GetResourceName() const {
120   std::string resource_name = url_.path();
121   if (url_.has_query()) {
122     resource_name += "?";
123     resource_name += url_.query();
124   }
125   return resource_name;
126 }
127 
GetHostFieldValue() const128 std::string WebSocketHandshake::GetHostFieldValue() const {
129   // url_.host() is expected to be encoded in punnycode here.
130   std::string host = StringToLowerASCII(url_.host());
131   if (url_.has_port()) {
132     bool secure = is_secure();
133     int port = url_.EffectiveIntPort();
134     if ((!secure &&
135          port != kWebSocketPort && port != url_parse::PORT_UNSPECIFIED) ||
136         (secure &&
137          port != kSecureWebSocketPort && port != url_parse::PORT_UNSPECIFIED)) {
138       host += ":";
139       host += base::IntToString(port);
140     }
141   }
142   return host;
143 }
144 
GetOriginFieldValue() const145 std::string WebSocketHandshake::GetOriginFieldValue() const {
146   // It's OK to lowercase the origin as the Origin header does not contain
147   // the path or query portions, as per
148   // http://tools.ietf.org/html/draft-abarth-origin-00.
149   //
150   // TODO(satorux): Should we trim the port portion here if it's 80 for
151   // http:// or 443 for https:// ? Or can we assume it's done by the
152   // client of the library?
153   return StringToLowerASCII(origin_);
154 }
155 
156 /* static */
GetSingleHeader(const HttpResponseHeaders & headers,const std::string & name,std::string * value)157 bool WebSocketHandshake::GetSingleHeader(const HttpResponseHeaders& headers,
158                                          const std::string& name,
159                                          std::string* value) {
160   std::string first_value;
161   void* iter = NULL;
162   if (!headers.EnumerateHeader(&iter, name, &first_value))
163     return false;
164 
165   // Checks no more |name| found in |headers|.
166   // Second call of EnumerateHeader() must return false.
167   std::string second_value;
168   if (headers.EnumerateHeader(&iter, name, &second_value))
169     return false;
170   *value = first_value;
171   return true;
172 }
173 
ProcessHeaders(const HttpResponseHeaders & headers)174 bool WebSocketHandshake::ProcessHeaders(const HttpResponseHeaders& headers) {
175   std::string value;
176   if (!GetSingleHeader(headers, "upgrade", &value) ||
177       value != "WebSocket")
178     return false;
179 
180   if (!GetSingleHeader(headers, "connection", &value) ||
181       !LowerCaseEqualsASCII(value, "upgrade"))
182     return false;
183 
184   if (!GetSingleHeader(headers, "sec-websocket-origin", &ws_origin_))
185     return false;
186 
187   if (!GetSingleHeader(headers, "sec-websocket-location", &ws_location_))
188     return false;
189 
190   // If |protocol_| is not specified by client, we don't care if there's
191   // protocol field or not as specified in the spec.
192   if (!protocol_.empty()
193       && !GetSingleHeader(headers, "sec-websocket-protocol", &ws_protocol_))
194     return false;
195   return true;
196 }
197 
CheckResponseHeaders() const198 bool WebSocketHandshake::CheckResponseHeaders() const {
199   DCHECK(mode_ == MODE_NORMAL);
200   if (!LowerCaseEqualsASCII(origin_, ws_origin_.c_str()))
201     return false;
202   if (location_ != ws_location_)
203     return false;
204   if (!protocol_.empty() && protocol_ != ws_protocol_)
205     return false;
206   return true;
207 }
208 
209 namespace {
210 
211 // unsigned int version of base::RandInt().
212 // we can't use base::RandInt(), because max would be negative if it is
213 // represented as int, so DCHECK(min <= max) fails.
RandUint32(uint32 min,uint32 max)214 uint32 RandUint32(uint32 min, uint32 max) {
215   DCHECK(min <= max);
216 
217   uint64 range = static_cast<int64>(max) - min + 1;
218   uint64 number = base::RandUint64();
219   // TODO(ukai): fix to be uniform.
220   // the distribution of the result of modulo will be biased.
221   uint32 result = min + static_cast<uint32>(number % range);
222   DCHECK(result >= min && result <= max);
223   return result;
224 }
225 
226 }
227 
228 uint32 (*WebSocketHandshake::Parameter::rand_)(uint32 min, uint32 max) =
229     RandUint32;
230 uint8 randomCharacterInSecWebSocketKey[0x2F - 0x20 + 0x7E - 0x39];
231 
Parameter()232 WebSocketHandshake::Parameter::Parameter()
233     : number_1_(0), number_2_(0) {
234   if (randomCharacterInSecWebSocketKey[0] == '\0') {
235     int i = 0;
236     for (int ch = 0x21; ch <= 0x2F; ch++, i++)
237       randomCharacterInSecWebSocketKey[i] = ch;
238     for (int ch = 0x3A; ch <= 0x7E; ch++, i++)
239       randomCharacterInSecWebSocketKey[i] = ch;
240   }
241 }
242 
~Parameter()243 WebSocketHandshake::Parameter::~Parameter() {}
244 
GenerateKeys()245 void WebSocketHandshake::Parameter::GenerateKeys() {
246   GenerateSecWebSocketKey(&number_1_, &key_1_);
247   GenerateSecWebSocketKey(&number_2_, &key_2_);
248   GenerateKey3();
249 }
250 
SetChallengeNumber(uint8 * buf,uint32 number)251 static void SetChallengeNumber(uint8* buf, uint32 number) {
252   uint8* p = buf + 3;
253   for (int i = 0; i < 4; i++) {
254     *p = (uint8)(number & 0xFF);
255     --p;
256     number >>= 8;
257   }
258 }
259 
GetExpectedResponse(uint8 * expected) const260 void WebSocketHandshake::Parameter::GetExpectedResponse(uint8 *expected) const {
261   uint8 challenge[kExpectedResponseSize];
262   SetChallengeNumber(&challenge[0], number_1_);
263   SetChallengeNumber(&challenge[4], number_2_);
264   memcpy(&challenge[8], key_3_.data(), kKey3Size);
265   MD5Digest digest;
266   MD5Sum(challenge, kExpectedResponseSize, &digest);
267   memcpy(expected, digest.a, kExpectedResponseSize);
268 }
269 
270 /* static */
SetRandomNumberGenerator(uint32 (* rand)(uint32 min,uint32 max))271 void WebSocketHandshake::Parameter::SetRandomNumberGenerator(
272     uint32 (*rand)(uint32 min, uint32 max)) {
273   rand_ = rand;
274 }
275 
GenerateSecWebSocketKey(uint32 * number,std::string * key)276 void WebSocketHandshake::Parameter::GenerateSecWebSocketKey(
277     uint32* number, std::string* key) {
278   uint32 space = rand_(1, 12);
279   uint32 max = 4294967295U / space;
280   *number = rand_(0, max);
281   uint32 product = *number * space;
282 
283   std::string s = base::StringPrintf("%u", product);
284   int n = rand_(1, 12);
285   for (int i = 0; i < n; i++) {
286     int pos = rand_(0, s.length());
287     int chpos = rand_(0, sizeof(randomCharacterInSecWebSocketKey) - 1);
288     s = s.substr(0, pos).append(1, randomCharacterInSecWebSocketKey[chpos]) +
289         s.substr(pos);
290   }
291   for (uint32 i = 0; i < space; i++) {
292     int pos = rand_(1, s.length() - 1);
293     s = s.substr(0, pos) + " " + s.substr(pos);
294   }
295   *key = s;
296 }
297 
GenerateKey3()298 void WebSocketHandshake::Parameter::GenerateKey3() {
299   key_3_.clear();
300   for (int i = 0; i < 8; i++) {
301     key_3_.append(1, rand_(0, 255));
302   }
303 }
304 
305 }  // namespace net
306