// Copyright (c) 2011 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/websockets/websocket_handshake.h" #include #include #include "base/logging.h" #include "base/md5.h" #include "base/memory/ref_counted.h" #include "base/rand_util.h" #include "base/string_number_conversions.h" #include "base/string_util.h" #include "base/stringprintf.h" #include "net/http/http_response_headers.h" #include "net/http/http_util.h" namespace net { const int WebSocketHandshake::kWebSocketPort = 80; const int WebSocketHandshake::kSecureWebSocketPort = 443; WebSocketHandshake::WebSocketHandshake( const GURL& url, const std::string& origin, const std::string& location, const std::string& protocol) : url_(url), origin_(origin), location_(location), protocol_(protocol), mode_(MODE_INCOMPLETE) { } WebSocketHandshake::~WebSocketHandshake() { } bool WebSocketHandshake::is_secure() const { return url_.SchemeIs("wss"); } std::string WebSocketHandshake::CreateClientHandshakeMessage() { if (!parameter_.get()) { parameter_.reset(new Parameter); parameter_->GenerateKeys(); } std::string msg; // WebSocket protocol 4.1 Opening handshake. msg = "GET "; msg += GetResourceName(); msg += " HTTP/1.1\r\n"; std::vector fields; fields.push_back("Upgrade: WebSocket"); fields.push_back("Connection: Upgrade"); fields.push_back("Host: " + GetHostFieldValue()); fields.push_back("Origin: " + GetOriginFieldValue()); if (!protocol_.empty()) fields.push_back("Sec-WebSocket-Protocol: " + protocol_); // TODO(ukai): Add cookie if necessary. fields.push_back("Sec-WebSocket-Key1: " + parameter_->GetSecWebSocketKey1()); fields.push_back("Sec-WebSocket-Key2: " + parameter_->GetSecWebSocketKey2()); std::random_shuffle(fields.begin(), fields.end(), base::RandGenerator); for (size_t i = 0; i < fields.size(); i++) { msg += fields[i] + "\r\n"; } msg += "\r\n"; msg.append(parameter_->GetKey3()); return msg; } int WebSocketHandshake::ReadServerHandshake(const char* data, size_t len) { mode_ = MODE_INCOMPLETE; int eoh = HttpUtil::LocateEndOfHeaders(data, len); if (eoh < 0) return -1; scoped_refptr headers( new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(data, eoh))); if (headers->response_code() != 101) { mode_ = MODE_FAILED; DVLOG(1) << "Bad response code: " << headers->response_code(); return eoh; } mode_ = MODE_NORMAL; if (!ProcessHeaders(*headers) || !CheckResponseHeaders()) { DVLOG(1) << "Process Headers failed: " << std::string(data, eoh); mode_ = MODE_FAILED; return eoh; } if (len < static_cast(eoh + Parameter::kExpectedResponseSize)) { mode_ = MODE_INCOMPLETE; return -1; } uint8 expected[Parameter::kExpectedResponseSize]; parameter_->GetExpectedResponse(expected); if (memcmp(&data[eoh], expected, Parameter::kExpectedResponseSize)) { mode_ = MODE_FAILED; return eoh + Parameter::kExpectedResponseSize; } mode_ = MODE_CONNECTED; return eoh + Parameter::kExpectedResponseSize; } std::string WebSocketHandshake::GetResourceName() const { std::string resource_name = url_.path(); if (url_.has_query()) { resource_name += "?"; resource_name += url_.query(); } return resource_name; } std::string WebSocketHandshake::GetHostFieldValue() const { // url_.host() is expected to be encoded in punnycode here. std::string host = StringToLowerASCII(url_.host()); if (url_.has_port()) { bool secure = is_secure(); int port = url_.EffectiveIntPort(); if ((!secure && port != kWebSocketPort && port != url_parse::PORT_UNSPECIFIED) || (secure && port != kSecureWebSocketPort && port != url_parse::PORT_UNSPECIFIED)) { host += ":"; host += base::IntToString(port); } } return host; } std::string WebSocketHandshake::GetOriginFieldValue() const { // It's OK to lowercase the origin as the Origin header does not contain // the path or query portions, as per // http://tools.ietf.org/html/draft-abarth-origin-00. // // TODO(satorux): Should we trim the port portion here if it's 80 for // http:// or 443 for https:// ? Or can we assume it's done by the // client of the library? return StringToLowerASCII(origin_); } /* static */ bool WebSocketHandshake::GetSingleHeader(const HttpResponseHeaders& headers, const std::string& name, std::string* value) { std::string first_value; void* iter = NULL; if (!headers.EnumerateHeader(&iter, name, &first_value)) return false; // Checks no more |name| found in |headers|. // Second call of EnumerateHeader() must return false. std::string second_value; if (headers.EnumerateHeader(&iter, name, &second_value)) return false; *value = first_value; return true; } bool WebSocketHandshake::ProcessHeaders(const HttpResponseHeaders& headers) { std::string value; if (!GetSingleHeader(headers, "upgrade", &value) || value != "WebSocket") return false; if (!GetSingleHeader(headers, "connection", &value) || !LowerCaseEqualsASCII(value, "upgrade")) return false; if (!GetSingleHeader(headers, "sec-websocket-origin", &ws_origin_)) return false; if (!GetSingleHeader(headers, "sec-websocket-location", &ws_location_)) return false; // If |protocol_| is not specified by client, we don't care if there's // protocol field or not as specified in the spec. if (!protocol_.empty() && !GetSingleHeader(headers, "sec-websocket-protocol", &ws_protocol_)) return false; return true; } bool WebSocketHandshake::CheckResponseHeaders() const { DCHECK(mode_ == MODE_NORMAL); if (!LowerCaseEqualsASCII(origin_, ws_origin_.c_str())) return false; if (location_ != ws_location_) return false; if (!protocol_.empty() && protocol_ != ws_protocol_) return false; return true; } namespace { // unsigned int version of base::RandInt(). // we can't use base::RandInt(), because max would be negative if it is // represented as int, so DCHECK(min <= max) fails. uint32 RandUint32(uint32 min, uint32 max) { DCHECK(min <= max); uint64 range = static_cast(max) - min + 1; uint64 number = base::RandUint64(); // TODO(ukai): fix to be uniform. // the distribution of the result of modulo will be biased. uint32 result = min + static_cast(number % range); DCHECK(result >= min && result <= max); return result; } } uint32 (*WebSocketHandshake::Parameter::rand_)(uint32 min, uint32 max) = RandUint32; uint8 randomCharacterInSecWebSocketKey[0x2F - 0x20 + 0x7E - 0x39]; WebSocketHandshake::Parameter::Parameter() : number_1_(0), number_2_(0) { if (randomCharacterInSecWebSocketKey[0] == '\0') { int i = 0; for (int ch = 0x21; ch <= 0x2F; ch++, i++) randomCharacterInSecWebSocketKey[i] = ch; for (int ch = 0x3A; ch <= 0x7E; ch++, i++) randomCharacterInSecWebSocketKey[i] = ch; } } WebSocketHandshake::Parameter::~Parameter() {} void WebSocketHandshake::Parameter::GenerateKeys() { GenerateSecWebSocketKey(&number_1_, &key_1_); GenerateSecWebSocketKey(&number_2_, &key_2_); GenerateKey3(); } static void SetChallengeNumber(uint8* buf, uint32 number) { uint8* p = buf + 3; for (int i = 0; i < 4; i++) { *p = (uint8)(number & 0xFF); --p; number >>= 8; } } void WebSocketHandshake::Parameter::GetExpectedResponse(uint8 *expected) const { uint8 challenge[kExpectedResponseSize]; SetChallengeNumber(&challenge[0], number_1_); SetChallengeNumber(&challenge[4], number_2_); memcpy(&challenge[8], key_3_.data(), kKey3Size); MD5Digest digest; MD5Sum(challenge, kExpectedResponseSize, &digest); memcpy(expected, digest.a, kExpectedResponseSize); } /* static */ void WebSocketHandshake::Parameter::SetRandomNumberGenerator( uint32 (*rand)(uint32 min, uint32 max)) { rand_ = rand; } void WebSocketHandshake::Parameter::GenerateSecWebSocketKey( uint32* number, std::string* key) { uint32 space = rand_(1, 12); uint32 max = 4294967295U / space; *number = rand_(0, max); uint32 product = *number * space; std::string s = base::StringPrintf("%u", product); int n = rand_(1, 12); for (int i = 0; i < n; i++) { int pos = rand_(0, s.length()); int chpos = rand_(0, sizeof(randomCharacterInSecWebSocketKey) - 1); s = s.substr(0, pos).append(1, randomCharacterInSecWebSocketKey[chpos]) + s.substr(pos); } for (uint32 i = 0; i < space; i++) { int pos = rand_(1, s.length() - 1); s = s.substr(0, pos) + " " + s.substr(pos); } *key = s; } void WebSocketHandshake::Parameter::GenerateKey3() { key_3_.clear(); for (int i = 0; i < 8; i++) { key_3_.append(1, rand_(0, 255)); } } } // namespace net