1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/websockets/websocket_basic_handshake_stream.h"
6
7 #include <algorithm>
8 #include <iterator>
9 #include <set>
10 #include <string>
11 #include <vector>
12
13 #include "base/base64.h"
14 #include "base/basictypes.h"
15 #include "base/bind.h"
16 #include "base/containers/hash_tables.h"
17 #include "base/logging.h"
18 #include "base/metrics/histogram.h"
19 #include "base/metrics/sparse_histogram.h"
20 #include "base/stl_util.h"
21 #include "base/strings/string_number_conversions.h"
22 #include "base/strings/string_piece.h"
23 #include "base/strings/string_util.h"
24 #include "base/strings/stringprintf.h"
25 #include "base/time/time.h"
26 #include "crypto/random.h"
27 #include "net/http/http_request_headers.h"
28 #include "net/http/http_request_info.h"
29 #include "net/http/http_response_body_drainer.h"
30 #include "net/http/http_response_headers.h"
31 #include "net/http/http_status_code.h"
32 #include "net/http/http_stream_parser.h"
33 #include "net/socket/client_socket_handle.h"
34 #include "net/socket/websocket_transport_client_socket_pool.h"
35 #include "net/websockets/websocket_basic_stream.h"
36 #include "net/websockets/websocket_deflate_predictor.h"
37 #include "net/websockets/websocket_deflate_predictor_impl.h"
38 #include "net/websockets/websocket_deflate_stream.h"
39 #include "net/websockets/websocket_deflater.h"
40 #include "net/websockets/websocket_extension_parser.h"
41 #include "net/websockets/websocket_handshake_constants.h"
42 #include "net/websockets/websocket_handshake_handler.h"
43 #include "net/websockets/websocket_handshake_request_info.h"
44 #include "net/websockets/websocket_handshake_response_info.h"
45 #include "net/websockets/websocket_stream.h"
46
47 namespace net {
48
49 // TODO(ricea): If more extensions are added, replace this with a more general
50 // mechanism.
51 struct WebSocketExtensionParams {
WebSocketExtensionParamsnet::WebSocketExtensionParams52 WebSocketExtensionParams()
53 : deflate_enabled(false),
54 client_window_bits(15),
55 deflate_mode(WebSocketDeflater::TAKE_OVER_CONTEXT) {}
56
57 bool deflate_enabled;
58 int client_window_bits;
59 WebSocketDeflater::ContextTakeOverMode deflate_mode;
60 };
61
62 namespace {
63
// Outcome of looking up a response header that must occur exactly once.
enum GetHeaderResult {
  GET_HEADER_OK = 0,        // Exactly one occurrence was found.
  GET_HEADER_MISSING = 1,   // The header is absent.
  GET_HEADER_MULTIPLE = 2,  // The header occurs more than once.
};
69
// Builds the failure text used when |header_name| is absent from the
// handshake response, e.g. "'Upgrade' header is missing".
std::string MissingHeaderMessage(const std::string& header_name) {
  std::string message("'");
  message.append(header_name);
  message.append("' header is missing");
  return message;
}
73
// Builds the failure text used when |header_name| occurs more than once in
// the handshake response.
std::string MultipleHeaderValuesMessage(const std::string& header_name) {
  std::string message = "'" + header_name;
  message += "' header must not appear more than once in a response";
  return message;
}
80
GenerateHandshakeChallenge()81 std::string GenerateHandshakeChallenge() {
82 std::string raw_challenge(websockets::kRawChallengeLength, '\0');
83 crypto::RandBytes(string_as_array(&raw_challenge), raw_challenge.length());
84 std::string encoded_challenge;
85 base::Base64Encode(raw_challenge, &encoded_challenge);
86 return encoded_challenge;
87 }
88
AddVectorHeaderIfNonEmpty(const char * name,const std::vector<std::string> & value,HttpRequestHeaders * headers)89 void AddVectorHeaderIfNonEmpty(const char* name,
90 const std::vector<std::string>& value,
91 HttpRequestHeaders* headers) {
92 if (value.empty())
93 return;
94 headers->SetHeader(name, JoinString(value, ", "));
95 }
96
GetSingleHeaderValue(const HttpResponseHeaders * headers,const base::StringPiece & name,std::string * value)97 GetHeaderResult GetSingleHeaderValue(const HttpResponseHeaders* headers,
98 const base::StringPiece& name,
99 std::string* value) {
100 void* state = NULL;
101 size_t num_values = 0;
102 std::string temp_value;
103 while (headers->EnumerateHeader(&state, name, &temp_value)) {
104 if (++num_values > 1)
105 return GET_HEADER_MULTIPLE;
106 *value = temp_value;
107 }
108 return num_values > 0 ? GET_HEADER_OK : GET_HEADER_MISSING;
109 }
110
ValidateHeaderHasSingleValue(GetHeaderResult result,const std::string & header_name,std::string * failure_message)111 bool ValidateHeaderHasSingleValue(GetHeaderResult result,
112 const std::string& header_name,
113 std::string* failure_message) {
114 if (result == GET_HEADER_MISSING) {
115 *failure_message = MissingHeaderMessage(header_name);
116 return false;
117 }
118 if (result == GET_HEADER_MULTIPLE) {
119 *failure_message = MultipleHeaderValuesMessage(header_name);
120 return false;
121 }
122 DCHECK_EQ(result, GET_HEADER_OK);
123 return true;
124 }
125
ValidateUpgrade(const HttpResponseHeaders * headers,std::string * failure_message)126 bool ValidateUpgrade(const HttpResponseHeaders* headers,
127 std::string* failure_message) {
128 std::string value;
129 GetHeaderResult result =
130 GetSingleHeaderValue(headers, websockets::kUpgrade, &value);
131 if (!ValidateHeaderHasSingleValue(result,
132 websockets::kUpgrade,
133 failure_message)) {
134 return false;
135 }
136
137 if (!LowerCaseEqualsASCII(value, websockets::kWebSocketLowercase)) {
138 *failure_message =
139 "'Upgrade' header value is not 'WebSocket': " + value;
140 return false;
141 }
142 return true;
143 }
144
ValidateSecWebSocketAccept(const HttpResponseHeaders * headers,const std::string & expected,std::string * failure_message)145 bool ValidateSecWebSocketAccept(const HttpResponseHeaders* headers,
146 const std::string& expected,
147 std::string* failure_message) {
148 std::string actual;
149 GetHeaderResult result =
150 GetSingleHeaderValue(headers, websockets::kSecWebSocketAccept, &actual);
151 if (!ValidateHeaderHasSingleValue(result,
152 websockets::kSecWebSocketAccept,
153 failure_message)) {
154 return false;
155 }
156
157 if (expected != actual) {
158 *failure_message = "Incorrect 'Sec-WebSocket-Accept' header value";
159 return false;
160 }
161 return true;
162 }
163
ValidateConnection(const HttpResponseHeaders * headers,std::string * failure_message)164 bool ValidateConnection(const HttpResponseHeaders* headers,
165 std::string* failure_message) {
166 // Connection header is permitted to contain other tokens.
167 if (!headers->HasHeader(HttpRequestHeaders::kConnection)) {
168 *failure_message = MissingHeaderMessage(HttpRequestHeaders::kConnection);
169 return false;
170 }
171 if (!headers->HasHeaderValue(HttpRequestHeaders::kConnection,
172 websockets::kUpgrade)) {
173 *failure_message = "'Connection' header value must contain 'Upgrade'";
174 return false;
175 }
176 return true;
177 }
178
ValidateSubProtocol(const HttpResponseHeaders * headers,const std::vector<std::string> & requested_sub_protocols,std::string * sub_protocol,std::string * failure_message)179 bool ValidateSubProtocol(
180 const HttpResponseHeaders* headers,
181 const std::vector<std::string>& requested_sub_protocols,
182 std::string* sub_protocol,
183 std::string* failure_message) {
184 void* state = NULL;
185 std::string value;
186 base::hash_set<std::string> requested_set(requested_sub_protocols.begin(),
187 requested_sub_protocols.end());
188 int count = 0;
189 bool has_multiple_protocols = false;
190 bool has_invalid_protocol = false;
191
192 while (!has_invalid_protocol || !has_multiple_protocols) {
193 std::string temp_value;
194 if (!headers->EnumerateHeader(
195 &state, websockets::kSecWebSocketProtocol, &temp_value))
196 break;
197 value = temp_value;
198 if (requested_set.count(value) == 0)
199 has_invalid_protocol = true;
200 if (++count > 1)
201 has_multiple_protocols = true;
202 }
203
204 if (has_multiple_protocols) {
205 *failure_message =
206 MultipleHeaderValuesMessage(websockets::kSecWebSocketProtocol);
207 return false;
208 } else if (count > 0 && requested_sub_protocols.size() == 0) {
209 *failure_message =
210 std::string("Response must not include 'Sec-WebSocket-Protocol' "
211 "header if not present in request: ")
212 + value;
213 return false;
214 } else if (has_invalid_protocol) {
215 *failure_message =
216 "'Sec-WebSocket-Protocol' header value '" +
217 value +
218 "' in response does not match any of sent values";
219 return false;
220 } else if (requested_sub_protocols.size() > 0 && count == 0) {
221 *failure_message =
222 "Sent non-empty 'Sec-WebSocket-Protocol' header "
223 "but no response was received";
224 return false;
225 }
226 *sub_protocol = value;
227 return true;
228 }
229
DeflateError(std::string * message,const base::StringPiece & piece)230 bool DeflateError(std::string* message, const base::StringPiece& piece) {
231 *message = "Error in permessage-deflate: ";
232 piece.AppendToString(message);
233 return false;
234 }
235
ValidatePerMessageDeflateExtension(const WebSocketExtension & extension,std::string * failure_message,WebSocketExtensionParams * params)236 bool ValidatePerMessageDeflateExtension(const WebSocketExtension& extension,
237 std::string* failure_message,
238 WebSocketExtensionParams* params) {
239 static const char kClientPrefix[] = "client_";
240 static const char kServerPrefix[] = "server_";
241 static const char kNoContextTakeover[] = "no_context_takeover";
242 static const char kMaxWindowBits[] = "max_window_bits";
243 const size_t kPrefixLen = arraysize(kClientPrefix) - 1;
244 COMPILE_ASSERT(kPrefixLen == arraysize(kServerPrefix) - 1,
245 the_strings_server_and_client_must_be_the_same_length);
246 typedef std::vector<WebSocketExtension::Parameter> ParameterVector;
247
248 DCHECK_EQ("permessage-deflate", extension.name());
249 const ParameterVector& parameters = extension.parameters();
250 std::set<std::string> seen_names;
251 for (ParameterVector::const_iterator it = parameters.begin();
252 it != parameters.end(); ++it) {
253 const std::string& name = it->name();
254 if (seen_names.count(name) != 0) {
255 return DeflateError(
256 failure_message,
257 "Received duplicate permessage-deflate extension parameter " + name);
258 }
259 seen_names.insert(name);
260 const std::string client_or_server(name, 0, kPrefixLen);
261 const bool is_client = (client_or_server == kClientPrefix);
262 if (!is_client && client_or_server != kServerPrefix) {
263 return DeflateError(
264 failure_message,
265 "Received an unexpected permessage-deflate extension parameter");
266 }
267 const std::string rest(name, kPrefixLen);
268 if (rest == kNoContextTakeover) {
269 if (it->HasValue()) {
270 return DeflateError(failure_message,
271 "Received invalid " + name + " parameter");
272 }
273 if (is_client)
274 params->deflate_mode = WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT;
275 } else if (rest == kMaxWindowBits) {
276 if (!it->HasValue())
277 return DeflateError(failure_message, name + " must have value");
278 int bits = 0;
279 if (!base::StringToInt(it->value(), &bits) || bits < 8 || bits > 15 ||
280 it->value()[0] == '0' ||
281 it->value().find_first_not_of("0123456789") != std::string::npos) {
282 return DeflateError(failure_message,
283 "Received invalid " + name + " parameter");
284 }
285 if (is_client)
286 params->client_window_bits = bits;
287 } else {
288 return DeflateError(
289 failure_message,
290 "Received an unexpected permessage-deflate extension parameter");
291 }
292 }
293 params->deflate_enabled = true;
294 return true;
295 }
296
ValidateExtensions(const HttpResponseHeaders * headers,const std::vector<std::string> & requested_extensions,std::string * extensions,std::string * failure_message,WebSocketExtensionParams * params)297 bool ValidateExtensions(const HttpResponseHeaders* headers,
298 const std::vector<std::string>& requested_extensions,
299 std::string* extensions,
300 std::string* failure_message,
301 WebSocketExtensionParams* params) {
302 void* state = NULL;
303 std::string value;
304 std::vector<std::string> accepted_extensions;
305 // TODO(ricea): If adding support for additional extensions, generalise this
306 // code.
307 bool seen_permessage_deflate = false;
308 while (headers->EnumerateHeader(
309 &state, websockets::kSecWebSocketExtensions, &value)) {
310 WebSocketExtensionParser parser;
311 parser.Parse(value);
312 if (parser.has_error()) {
313 // TODO(yhirano) Set appropriate failure message.
314 *failure_message =
315 "'Sec-WebSocket-Extensions' header value is "
316 "rejected by the parser: " +
317 value;
318 return false;
319 }
320 if (parser.extension().name() == "permessage-deflate") {
321 if (seen_permessage_deflate) {
322 *failure_message = "Received duplicate permessage-deflate response";
323 return false;
324 }
325 seen_permessage_deflate = true;
326 if (!ValidatePerMessageDeflateExtension(
327 parser.extension(), failure_message, params))
328 return false;
329 } else {
330 *failure_message =
331 "Found an unsupported extension '" +
332 parser.extension().name() +
333 "' in 'Sec-WebSocket-Extensions' header";
334 return false;
335 }
336 accepted_extensions.push_back(value);
337 }
338 *extensions = JoinString(accepted_extensions, ", ");
339 return true;
340 }
341
342 } // namespace
343
WebSocketBasicHandshakeStream(scoped_ptr<ClientSocketHandle> connection,WebSocketStream::ConnectDelegate * connect_delegate,bool using_proxy,std::vector<std::string> requested_sub_protocols,std::vector<std::string> requested_extensions,std::string * failure_message)344 WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream(
345 scoped_ptr<ClientSocketHandle> connection,
346 WebSocketStream::ConnectDelegate* connect_delegate,
347 bool using_proxy,
348 std::vector<std::string> requested_sub_protocols,
349 std::vector<std::string> requested_extensions,
350 std::string* failure_message)
351 : state_(connection.release(), using_proxy),
352 connect_delegate_(connect_delegate),
353 http_response_info_(NULL),
354 requested_sub_protocols_(requested_sub_protocols),
355 requested_extensions_(requested_extensions),
356 failure_message_(failure_message) {
357 DCHECK(connect_delegate);
358 DCHECK(failure_message);
359 }
360
~WebSocketBasicHandshakeStream()361 WebSocketBasicHandshakeStream::~WebSocketBasicHandshakeStream() {}
362
InitializeStream(const HttpRequestInfo * request_info,RequestPriority priority,const BoundNetLog & net_log,const CompletionCallback & callback)363 int WebSocketBasicHandshakeStream::InitializeStream(
364 const HttpRequestInfo* request_info,
365 RequestPriority priority,
366 const BoundNetLog& net_log,
367 const CompletionCallback& callback) {
368 url_ = request_info->url;
369 state_.Initialize(request_info, priority, net_log, callback);
370 return OK;
371 }
372
SendRequest(const HttpRequestHeaders & headers,HttpResponseInfo * response,const CompletionCallback & callback)373 int WebSocketBasicHandshakeStream::SendRequest(
374 const HttpRequestHeaders& headers,
375 HttpResponseInfo* response,
376 const CompletionCallback& callback) {
377 DCHECK(!headers.HasHeader(websockets::kSecWebSocketKey));
378 DCHECK(!headers.HasHeader(websockets::kSecWebSocketProtocol));
379 DCHECK(!headers.HasHeader(websockets::kSecWebSocketExtensions));
380 DCHECK(headers.HasHeader(HttpRequestHeaders::kOrigin));
381 DCHECK(headers.HasHeader(websockets::kUpgrade));
382 DCHECK(headers.HasHeader(HttpRequestHeaders::kConnection));
383 DCHECK(headers.HasHeader(websockets::kSecWebSocketVersion));
384 DCHECK(parser());
385
386 http_response_info_ = response;
387
388 // Create a copy of the headers object, so that we can add the
389 // Sec-WebSockey-Key header.
390 HttpRequestHeaders enriched_headers;
391 enriched_headers.CopyFrom(headers);
392 std::string handshake_challenge;
393 if (handshake_challenge_for_testing_) {
394 handshake_challenge = *handshake_challenge_for_testing_;
395 handshake_challenge_for_testing_.reset();
396 } else {
397 handshake_challenge = GenerateHandshakeChallenge();
398 }
399 enriched_headers.SetHeader(websockets::kSecWebSocketKey, handshake_challenge);
400
401 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions,
402 requested_extensions_,
403 &enriched_headers);
404 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol,
405 requested_sub_protocols_,
406 &enriched_headers);
407
408 ComputeSecWebSocketAccept(handshake_challenge,
409 &handshake_challenge_response_);
410
411 DCHECK(connect_delegate_);
412 scoped_ptr<WebSocketHandshakeRequestInfo> request(
413 new WebSocketHandshakeRequestInfo(url_, base::Time::Now()));
414 request->headers.CopyFrom(enriched_headers);
415 connect_delegate_->OnStartOpeningHandshake(request.Pass());
416
417 return parser()->SendRequest(
418 state_.GenerateRequestLine(), enriched_headers, response, callback);
419 }
420
ReadResponseHeaders(const CompletionCallback & callback)421 int WebSocketBasicHandshakeStream::ReadResponseHeaders(
422 const CompletionCallback& callback) {
423 // HttpStreamParser uses a weak pointer when reading from the
424 // socket, so it won't be called back after being destroyed. The
425 // HttpStreamParser is owned by HttpBasicState which is owned by this object,
426 // so this use of base::Unretained() is safe.
427 int rv = parser()->ReadResponseHeaders(
428 base::Bind(&WebSocketBasicHandshakeStream::ReadResponseHeadersCallback,
429 base::Unretained(this),
430 callback));
431 if (rv == ERR_IO_PENDING)
432 return rv;
433 return ValidateResponse(rv);
434 }
435
ReadResponseBody(IOBuffer * buf,int buf_len,const CompletionCallback & callback)436 int WebSocketBasicHandshakeStream::ReadResponseBody(
437 IOBuffer* buf,
438 int buf_len,
439 const CompletionCallback& callback) {
440 return parser()->ReadResponseBody(buf, buf_len, callback);
441 }
442
Close(bool not_reusable)443 void WebSocketBasicHandshakeStream::Close(bool not_reusable) {
444 // This class ignores the value of |not_reusable| and never lets the socket be
445 // re-used.
446 if (parser())
447 parser()->Close(true);
448 }
449
IsResponseBodyComplete() const450 bool WebSocketBasicHandshakeStream::IsResponseBodyComplete() const {
451 return parser()->IsResponseBodyComplete();
452 }
453
CanFindEndOfResponse() const454 bool WebSocketBasicHandshakeStream::CanFindEndOfResponse() const {
455 return parser() && parser()->CanFindEndOfResponse();
456 }
457
IsConnectionReused() const458 bool WebSocketBasicHandshakeStream::IsConnectionReused() const {
459 return parser()->IsConnectionReused();
460 }
461
SetConnectionReused()462 void WebSocketBasicHandshakeStream::SetConnectionReused() {
463 parser()->SetConnectionReused();
464 }
465
IsConnectionReusable() const466 bool WebSocketBasicHandshakeStream::IsConnectionReusable() const {
467 return false;
468 }
469
GetTotalReceivedBytes() const470 int64 WebSocketBasicHandshakeStream::GetTotalReceivedBytes() const {
471 return 0;
472 }
473
GetLoadTimingInfo(LoadTimingInfo * load_timing_info) const474 bool WebSocketBasicHandshakeStream::GetLoadTimingInfo(
475 LoadTimingInfo* load_timing_info) const {
476 return state_.connection()->GetLoadTimingInfo(IsConnectionReused(),
477 load_timing_info);
478 }
479
GetSSLInfo(SSLInfo * ssl_info)480 void WebSocketBasicHandshakeStream::GetSSLInfo(SSLInfo* ssl_info) {
481 parser()->GetSSLInfo(ssl_info);
482 }
483
GetSSLCertRequestInfo(SSLCertRequestInfo * cert_request_info)484 void WebSocketBasicHandshakeStream::GetSSLCertRequestInfo(
485 SSLCertRequestInfo* cert_request_info) {
486 parser()->GetSSLCertRequestInfo(cert_request_info);
487 }
488
IsSpdyHttpStream() const489 bool WebSocketBasicHandshakeStream::IsSpdyHttpStream() const { return false; }
490
Drain(HttpNetworkSession * session)491 void WebSocketBasicHandshakeStream::Drain(HttpNetworkSession* session) {
492 HttpResponseBodyDrainer* drainer = new HttpResponseBodyDrainer(this);
493 drainer->Start(session);
494 // |drainer| will delete itself.
495 }
496
SetPriority(RequestPriority priority)497 void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority) {
498 // TODO(ricea): See TODO comment in HttpBasicStream::SetPriority(). If it is
499 // gone, then copy whatever has happened there over here.
500 }
501
Upgrade()502 scoped_ptr<WebSocketStream> WebSocketBasicHandshakeStream::Upgrade() {
503 // The HttpStreamParser object has a pointer to our ClientSocketHandle. Make
504 // sure it does not touch it again before it is destroyed.
505 state_.DeleteParser();
506 WebSocketTransportClientSocketPool::UnlockEndpoint(state_.connection());
507 scoped_ptr<WebSocketStream> basic_stream(
508 new WebSocketBasicStream(state_.ReleaseConnection(),
509 state_.read_buf(),
510 sub_protocol_,
511 extensions_));
512 DCHECK(extension_params_.get());
513 if (extension_params_->deflate_enabled) {
514 UMA_HISTOGRAM_ENUMERATION(
515 "Net.WebSocket.DeflateMode",
516 extension_params_->deflate_mode,
517 WebSocketDeflater::NUM_CONTEXT_TAKEOVER_MODE_TYPES);
518
519 return scoped_ptr<WebSocketStream>(
520 new WebSocketDeflateStream(basic_stream.Pass(),
521 extension_params_->deflate_mode,
522 extension_params_->client_window_bits,
523 scoped_ptr<WebSocketDeflatePredictor>(
524 new WebSocketDeflatePredictorImpl)));
525 } else {
526 return basic_stream.Pass();
527 }
528 }
529
SetWebSocketKeyForTesting(const std::string & key)530 void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting(
531 const std::string& key) {
532 handshake_challenge_for_testing_.reset(new std::string(key));
533 }
534
ReadResponseHeadersCallback(const CompletionCallback & callback,int result)535 void WebSocketBasicHandshakeStream::ReadResponseHeadersCallback(
536 const CompletionCallback& callback,
537 int result) {
538 callback.Run(ValidateResponse(result));
539 }
540
OnFinishOpeningHandshake()541 void WebSocketBasicHandshakeStream::OnFinishOpeningHandshake() {
542 DCHECK(http_response_info_);
543 WebSocketDispatchOnFinishOpeningHandshake(connect_delegate_,
544 url_,
545 http_response_info_->headers,
546 http_response_info_->response_time);
547 }
548
ValidateResponse(int rv)549 int WebSocketBasicHandshakeStream::ValidateResponse(int rv) {
550 DCHECK(http_response_info_);
551 // Most net errors happen during connection, so they are not seen by this
552 // method. The histogram for error codes is created in
553 // Delegate::OnResponseStarted in websocket_stream.cc instead.
554 if (rv >= 0) {
555 const HttpResponseHeaders* headers = http_response_info_->headers.get();
556 const int response_code = headers->response_code();
557 UMA_HISTOGRAM_SPARSE_SLOWLY("Net.WebSocket.ResponseCode", response_code);
558 switch (response_code) {
559 case HTTP_SWITCHING_PROTOCOLS:
560 OnFinishOpeningHandshake();
561 return ValidateUpgradeResponse(headers);
562
563 // We need to pass these through for authentication to work.
564 case HTTP_UNAUTHORIZED:
565 case HTTP_PROXY_AUTHENTICATION_REQUIRED:
566 return OK;
567
568 // Other status codes are potentially risky (see the warnings in the
569 // WHATWG WebSocket API spec) and so are dropped by default.
570 default:
571 // A WebSocket server cannot be using HTTP/0.9, so if we see version
572 // 0.9, it means the response was garbage.
573 // Reporting "Unexpected response code: 200" in this case is not
574 // helpful, so use a different error message.
575 if (headers->GetHttpVersion() == HttpVersion(0, 9)) {
576 set_failure_message(
577 "Error during WebSocket handshake: Invalid status line");
578 } else {
579 set_failure_message(base::StringPrintf(
580 "Error during WebSocket handshake: Unexpected response code: %d",
581 headers->response_code()));
582 }
583 OnFinishOpeningHandshake();
584 return ERR_INVALID_RESPONSE;
585 }
586 } else {
587 if (rv == ERR_EMPTY_RESPONSE) {
588 set_failure_message(
589 "Connection closed before receiving a handshake response");
590 return rv;
591 }
592 set_failure_message(std::string("Error during WebSocket handshake: ") +
593 ErrorToString(rv));
594 OnFinishOpeningHandshake();
595 return rv;
596 }
597 }
598
ValidateUpgradeResponse(const HttpResponseHeaders * headers)599 int WebSocketBasicHandshakeStream::ValidateUpgradeResponse(
600 const HttpResponseHeaders* headers) {
601 extension_params_.reset(new WebSocketExtensionParams);
602 std::string failure_message;
603 if (ValidateUpgrade(headers, &failure_message) &&
604 ValidateSecWebSocketAccept(
605 headers, handshake_challenge_response_, &failure_message) &&
606 ValidateConnection(headers, &failure_message) &&
607 ValidateSubProtocol(headers,
608 requested_sub_protocols_,
609 &sub_protocol_,
610 &failure_message) &&
611 ValidateExtensions(headers,
612 requested_extensions_,
613 &extensions_,
614 &failure_message,
615 extension_params_.get())) {
616 return OK;
617 }
618 set_failure_message("Error during WebSocket handshake: " + failure_message);
619 return ERR_INVALID_RESPONSE;
620 }
621
set_failure_message(const std::string & failure_message)622 void WebSocketBasicHandshakeStream::set_failure_message(
623 const std::string& failure_message) {
624 *failure_message_ = failure_message;
625 }
626
627 } // namespace net
628