• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_basic_handshake_stream.h"
6 
7 #include <algorithm>
8 #include <iterator>
9 #include <set>
10 #include <string>
11 #include <vector>
12 
13 #include "base/base64.h"
14 #include "base/basictypes.h"
15 #include "base/bind.h"
16 #include "base/containers/hash_tables.h"
17 #include "base/logging.h"
18 #include "base/metrics/histogram.h"
19 #include "base/metrics/sparse_histogram.h"
20 #include "base/stl_util.h"
21 #include "base/strings/string_number_conversions.h"
22 #include "base/strings/string_piece.h"
23 #include "base/strings/string_util.h"
24 #include "base/strings/stringprintf.h"
25 #include "base/time/time.h"
26 #include "crypto/random.h"
27 #include "net/http/http_request_headers.h"
28 #include "net/http/http_request_info.h"
29 #include "net/http/http_response_body_drainer.h"
30 #include "net/http/http_response_headers.h"
31 #include "net/http/http_status_code.h"
32 #include "net/http/http_stream_parser.h"
33 #include "net/socket/client_socket_handle.h"
34 #include "net/socket/websocket_transport_client_socket_pool.h"
35 #include "net/websockets/websocket_basic_stream.h"
36 #include "net/websockets/websocket_deflate_predictor.h"
37 #include "net/websockets/websocket_deflate_predictor_impl.h"
38 #include "net/websockets/websocket_deflate_stream.h"
39 #include "net/websockets/websocket_deflater.h"
40 #include "net/websockets/websocket_extension_parser.h"
41 #include "net/websockets/websocket_handshake_constants.h"
42 #include "net/websockets/websocket_handshake_handler.h"
43 #include "net/websockets/websocket_handshake_request_info.h"
44 #include "net/websockets/websocket_handshake_response_info.h"
45 #include "net/websockets/websocket_stream.h"
46 
47 namespace net {
48 
49 // TODO(ricea): If more extensions are added, replace this with a more general
50 // mechanism.
51 struct WebSocketExtensionParams {
WebSocketExtensionParamsnet::WebSocketExtensionParams52   WebSocketExtensionParams()
53       : deflate_enabled(false),
54         client_window_bits(15),
55         deflate_mode(WebSocketDeflater::TAKE_OVER_CONTEXT) {}
56 
57   bool deflate_enabled;
58   int client_window_bits;
59   WebSocketDeflater::ContextTakeOverMode deflate_mode;
60 };
61 
62 namespace {
63 
// Outcome of looking up a response header that must appear exactly once.
enum GetHeaderResult {
  GET_HEADER_OK,       // Exactly one value was present.
  GET_HEADER_MISSING,  // The header was absent.
  GET_HEADER_MULTIPLE  // More than one value was present.
};
69 
// Builds the failure message used when |header_name| is absent from the
// handshake response.
std::string MissingHeaderMessage(const std::string& header_name) {
  std::string message("'");
  message += header_name;
  message += "' header is missing";
  return message;
}
73 
// Builds the failure message used when |header_name| appears more than once
// in the handshake response.
std::string MultipleHeaderValuesMessage(const std::string& header_name) {
  std::string message("'");
  message.append(header_name);
  message.append("' header must not appear more than once in a response");
  return message;
}
80 
GenerateHandshakeChallenge()81 std::string GenerateHandshakeChallenge() {
82   std::string raw_challenge(websockets::kRawChallengeLength, '\0');
83   crypto::RandBytes(string_as_array(&raw_challenge), raw_challenge.length());
84   std::string encoded_challenge;
85   base::Base64Encode(raw_challenge, &encoded_challenge);
86   return encoded_challenge;
87 }
88 
AddVectorHeaderIfNonEmpty(const char * name,const std::vector<std::string> & value,HttpRequestHeaders * headers)89 void AddVectorHeaderIfNonEmpty(const char* name,
90                                const std::vector<std::string>& value,
91                                HttpRequestHeaders* headers) {
92   if (value.empty())
93     return;
94   headers->SetHeader(name, JoinString(value, ", "));
95 }
96 
GetSingleHeaderValue(const HttpResponseHeaders * headers,const base::StringPiece & name,std::string * value)97 GetHeaderResult GetSingleHeaderValue(const HttpResponseHeaders* headers,
98                                      const base::StringPiece& name,
99                                      std::string* value) {
100   void* state = NULL;
101   size_t num_values = 0;
102   std::string temp_value;
103   while (headers->EnumerateHeader(&state, name, &temp_value)) {
104     if (++num_values > 1)
105       return GET_HEADER_MULTIPLE;
106     *value = temp_value;
107   }
108   return num_values > 0 ? GET_HEADER_OK : GET_HEADER_MISSING;
109 }
110 
ValidateHeaderHasSingleValue(GetHeaderResult result,const std::string & header_name,std::string * failure_message)111 bool ValidateHeaderHasSingleValue(GetHeaderResult result,
112                                   const std::string& header_name,
113                                   std::string* failure_message) {
114   if (result == GET_HEADER_MISSING) {
115     *failure_message = MissingHeaderMessage(header_name);
116     return false;
117   }
118   if (result == GET_HEADER_MULTIPLE) {
119     *failure_message = MultipleHeaderValuesMessage(header_name);
120     return false;
121   }
122   DCHECK_EQ(result, GET_HEADER_OK);
123   return true;
124 }
125 
ValidateUpgrade(const HttpResponseHeaders * headers,std::string * failure_message)126 bool ValidateUpgrade(const HttpResponseHeaders* headers,
127                      std::string* failure_message) {
128   std::string value;
129   GetHeaderResult result =
130       GetSingleHeaderValue(headers, websockets::kUpgrade, &value);
131   if (!ValidateHeaderHasSingleValue(result,
132                                     websockets::kUpgrade,
133                                     failure_message)) {
134     return false;
135   }
136 
137   if (!LowerCaseEqualsASCII(value, websockets::kWebSocketLowercase)) {
138     *failure_message =
139         "'Upgrade' header value is not 'WebSocket': " + value;
140     return false;
141   }
142   return true;
143 }
144 
ValidateSecWebSocketAccept(const HttpResponseHeaders * headers,const std::string & expected,std::string * failure_message)145 bool ValidateSecWebSocketAccept(const HttpResponseHeaders* headers,
146                                 const std::string& expected,
147                                 std::string* failure_message) {
148   std::string actual;
149   GetHeaderResult result =
150       GetSingleHeaderValue(headers, websockets::kSecWebSocketAccept, &actual);
151   if (!ValidateHeaderHasSingleValue(result,
152                                     websockets::kSecWebSocketAccept,
153                                     failure_message)) {
154     return false;
155   }
156 
157   if (expected != actual) {
158     *failure_message = "Incorrect 'Sec-WebSocket-Accept' header value";
159     return false;
160   }
161   return true;
162 }
163 
ValidateConnection(const HttpResponseHeaders * headers,std::string * failure_message)164 bool ValidateConnection(const HttpResponseHeaders* headers,
165                         std::string* failure_message) {
166   // Connection header is permitted to contain other tokens.
167   if (!headers->HasHeader(HttpRequestHeaders::kConnection)) {
168     *failure_message = MissingHeaderMessage(HttpRequestHeaders::kConnection);
169     return false;
170   }
171   if (!headers->HasHeaderValue(HttpRequestHeaders::kConnection,
172                                websockets::kUpgrade)) {
173     *failure_message = "'Connection' header value must contain 'Upgrade'";
174     return false;
175   }
176   return true;
177 }
178 
ValidateSubProtocol(const HttpResponseHeaders * headers,const std::vector<std::string> & requested_sub_protocols,std::string * sub_protocol,std::string * failure_message)179 bool ValidateSubProtocol(
180     const HttpResponseHeaders* headers,
181     const std::vector<std::string>& requested_sub_protocols,
182     std::string* sub_protocol,
183     std::string* failure_message) {
184   void* state = NULL;
185   std::string value;
186   base::hash_set<std::string> requested_set(requested_sub_protocols.begin(),
187                                             requested_sub_protocols.end());
188   int count = 0;
189   bool has_multiple_protocols = false;
190   bool has_invalid_protocol = false;
191 
192   while (!has_invalid_protocol || !has_multiple_protocols) {
193     std::string temp_value;
194     if (!headers->EnumerateHeader(
195             &state, websockets::kSecWebSocketProtocol, &temp_value))
196       break;
197     value = temp_value;
198     if (requested_set.count(value) == 0)
199       has_invalid_protocol = true;
200     if (++count > 1)
201       has_multiple_protocols = true;
202   }
203 
204   if (has_multiple_protocols) {
205     *failure_message =
206         MultipleHeaderValuesMessage(websockets::kSecWebSocketProtocol);
207     return false;
208   } else if (count > 0 && requested_sub_protocols.size() == 0) {
209     *failure_message =
210         std::string("Response must not include 'Sec-WebSocket-Protocol' "
211                     "header if not present in request: ")
212         + value;
213     return false;
214   } else if (has_invalid_protocol) {
215     *failure_message =
216         "'Sec-WebSocket-Protocol' header value '" +
217         value +
218         "' in response does not match any of sent values";
219     return false;
220   } else if (requested_sub_protocols.size() > 0 && count == 0) {
221     *failure_message =
222         "Sent non-empty 'Sec-WebSocket-Protocol' header "
223         "but no response was received";
224     return false;
225   }
226   *sub_protocol = value;
227   return true;
228 }
229 
DeflateError(std::string * message,const base::StringPiece & piece)230 bool DeflateError(std::string* message, const base::StringPiece& piece) {
231   *message = "Error in permessage-deflate: ";
232   piece.AppendToString(message);
233   return false;
234 }
235 
ValidatePerMessageDeflateExtension(const WebSocketExtension & extension,std::string * failure_message,WebSocketExtensionParams * params)236 bool ValidatePerMessageDeflateExtension(const WebSocketExtension& extension,
237                                         std::string* failure_message,
238                                         WebSocketExtensionParams* params) {
239   static const char kClientPrefix[] = "client_";
240   static const char kServerPrefix[] = "server_";
241   static const char kNoContextTakeover[] = "no_context_takeover";
242   static const char kMaxWindowBits[] = "max_window_bits";
243   const size_t kPrefixLen = arraysize(kClientPrefix) - 1;
244   COMPILE_ASSERT(kPrefixLen == arraysize(kServerPrefix) - 1,
245                  the_strings_server_and_client_must_be_the_same_length);
246   typedef std::vector<WebSocketExtension::Parameter> ParameterVector;
247 
248   DCHECK_EQ("permessage-deflate", extension.name());
249   const ParameterVector& parameters = extension.parameters();
250   std::set<std::string> seen_names;
251   for (ParameterVector::const_iterator it = parameters.begin();
252        it != parameters.end(); ++it) {
253     const std::string& name = it->name();
254     if (seen_names.count(name) != 0) {
255       return DeflateError(
256           failure_message,
257           "Received duplicate permessage-deflate extension parameter " + name);
258     }
259     seen_names.insert(name);
260     const std::string client_or_server(name, 0, kPrefixLen);
261     const bool is_client = (client_or_server == kClientPrefix);
262     if (!is_client && client_or_server != kServerPrefix) {
263       return DeflateError(
264           failure_message,
265           "Received an unexpected permessage-deflate extension parameter");
266     }
267     const std::string rest(name, kPrefixLen);
268     if (rest == kNoContextTakeover) {
269       if (it->HasValue()) {
270         return DeflateError(failure_message,
271                             "Received invalid " + name + " parameter");
272       }
273       if (is_client)
274         params->deflate_mode = WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT;
275     } else if (rest == kMaxWindowBits) {
276       if (!it->HasValue())
277         return DeflateError(failure_message, name + " must have value");
278       int bits = 0;
279       if (!base::StringToInt(it->value(), &bits) || bits < 8 || bits > 15 ||
280           it->value()[0] == '0' ||
281           it->value().find_first_not_of("0123456789") != std::string::npos) {
282         return DeflateError(failure_message,
283                             "Received invalid " + name + " parameter");
284       }
285       if (is_client)
286         params->client_window_bits = bits;
287     } else {
288       return DeflateError(
289           failure_message,
290           "Received an unexpected permessage-deflate extension parameter");
291     }
292   }
293   params->deflate_enabled = true;
294   return true;
295 }
296 
ValidateExtensions(const HttpResponseHeaders * headers,const std::vector<std::string> & requested_extensions,std::string * extensions,std::string * failure_message,WebSocketExtensionParams * params)297 bool ValidateExtensions(const HttpResponseHeaders* headers,
298                         const std::vector<std::string>& requested_extensions,
299                         std::string* extensions,
300                         std::string* failure_message,
301                         WebSocketExtensionParams* params) {
302   void* state = NULL;
303   std::string value;
304   std::vector<std::string> accepted_extensions;
305   // TODO(ricea): If adding support for additional extensions, generalise this
306   // code.
307   bool seen_permessage_deflate = false;
308   while (headers->EnumerateHeader(
309              &state, websockets::kSecWebSocketExtensions, &value)) {
310     WebSocketExtensionParser parser;
311     parser.Parse(value);
312     if (parser.has_error()) {
313       // TODO(yhirano) Set appropriate failure message.
314       *failure_message =
315           "'Sec-WebSocket-Extensions' header value is "
316           "rejected by the parser: " +
317           value;
318       return false;
319     }
320     if (parser.extension().name() == "permessage-deflate") {
321       if (seen_permessage_deflate) {
322         *failure_message = "Received duplicate permessage-deflate response";
323         return false;
324       }
325       seen_permessage_deflate = true;
326       if (!ValidatePerMessageDeflateExtension(
327               parser.extension(), failure_message, params))
328         return false;
329     } else {
330       *failure_message =
331           "Found an unsupported extension '" +
332           parser.extension().name() +
333           "' in 'Sec-WebSocket-Extensions' header";
334       return false;
335     }
336     accepted_extensions.push_back(value);
337   }
338   *extensions = JoinString(accepted_extensions, ", ");
339   return true;
340 }
341 
342 }  // namespace
343 
WebSocketBasicHandshakeStream(scoped_ptr<ClientSocketHandle> connection,WebSocketStream::ConnectDelegate * connect_delegate,bool using_proxy,std::vector<std::string> requested_sub_protocols,std::vector<std::string> requested_extensions,std::string * failure_message)344 WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream(
345     scoped_ptr<ClientSocketHandle> connection,
346     WebSocketStream::ConnectDelegate* connect_delegate,
347     bool using_proxy,
348     std::vector<std::string> requested_sub_protocols,
349     std::vector<std::string> requested_extensions,
350     std::string* failure_message)
351     : state_(connection.release(), using_proxy),
352       connect_delegate_(connect_delegate),
353       http_response_info_(NULL),
354       requested_sub_protocols_(requested_sub_protocols),
355       requested_extensions_(requested_extensions),
356       failure_message_(failure_message) {
357   DCHECK(connect_delegate);
358   DCHECK(failure_message);
359 }
360 
~WebSocketBasicHandshakeStream()361 WebSocketBasicHandshakeStream::~WebSocketBasicHandshakeStream() {}
362 
InitializeStream(const HttpRequestInfo * request_info,RequestPriority priority,const BoundNetLog & net_log,const CompletionCallback & callback)363 int WebSocketBasicHandshakeStream::InitializeStream(
364     const HttpRequestInfo* request_info,
365     RequestPriority priority,
366     const BoundNetLog& net_log,
367     const CompletionCallback& callback) {
368   url_ = request_info->url;
369   state_.Initialize(request_info, priority, net_log, callback);
370   return OK;
371 }
372 
SendRequest(const HttpRequestHeaders & headers,HttpResponseInfo * response,const CompletionCallback & callback)373 int WebSocketBasicHandshakeStream::SendRequest(
374     const HttpRequestHeaders& headers,
375     HttpResponseInfo* response,
376     const CompletionCallback& callback) {
377   DCHECK(!headers.HasHeader(websockets::kSecWebSocketKey));
378   DCHECK(!headers.HasHeader(websockets::kSecWebSocketProtocol));
379   DCHECK(!headers.HasHeader(websockets::kSecWebSocketExtensions));
380   DCHECK(headers.HasHeader(HttpRequestHeaders::kOrigin));
381   DCHECK(headers.HasHeader(websockets::kUpgrade));
382   DCHECK(headers.HasHeader(HttpRequestHeaders::kConnection));
383   DCHECK(headers.HasHeader(websockets::kSecWebSocketVersion));
384   DCHECK(parser());
385 
386   http_response_info_ = response;
387 
388   // Create a copy of the headers object, so that we can add the
389   // Sec-WebSockey-Key header.
390   HttpRequestHeaders enriched_headers;
391   enriched_headers.CopyFrom(headers);
392   std::string handshake_challenge;
393   if (handshake_challenge_for_testing_) {
394     handshake_challenge = *handshake_challenge_for_testing_;
395     handshake_challenge_for_testing_.reset();
396   } else {
397     handshake_challenge = GenerateHandshakeChallenge();
398   }
399   enriched_headers.SetHeader(websockets::kSecWebSocketKey, handshake_challenge);
400 
401   AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions,
402                             requested_extensions_,
403                             &enriched_headers);
404   AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol,
405                             requested_sub_protocols_,
406                             &enriched_headers);
407 
408   ComputeSecWebSocketAccept(handshake_challenge,
409                             &handshake_challenge_response_);
410 
411   DCHECK(connect_delegate_);
412   scoped_ptr<WebSocketHandshakeRequestInfo> request(
413       new WebSocketHandshakeRequestInfo(url_, base::Time::Now()));
414   request->headers.CopyFrom(enriched_headers);
415   connect_delegate_->OnStartOpeningHandshake(request.Pass());
416 
417   return parser()->SendRequest(
418       state_.GenerateRequestLine(), enriched_headers, response, callback);
419 }
420 
ReadResponseHeaders(const CompletionCallback & callback)421 int WebSocketBasicHandshakeStream::ReadResponseHeaders(
422     const CompletionCallback& callback) {
423   // HttpStreamParser uses a weak pointer when reading from the
424   // socket, so it won't be called back after being destroyed. The
425   // HttpStreamParser is owned by HttpBasicState which is owned by this object,
426   // so this use of base::Unretained() is safe.
427   int rv = parser()->ReadResponseHeaders(
428       base::Bind(&WebSocketBasicHandshakeStream::ReadResponseHeadersCallback,
429                  base::Unretained(this),
430                  callback));
431   if (rv == ERR_IO_PENDING)
432     return rv;
433   return ValidateResponse(rv);
434 }
435 
ReadResponseBody(IOBuffer * buf,int buf_len,const CompletionCallback & callback)436 int WebSocketBasicHandshakeStream::ReadResponseBody(
437     IOBuffer* buf,
438     int buf_len,
439     const CompletionCallback& callback) {
440   return parser()->ReadResponseBody(buf, buf_len, callback);
441 }
442 
Close(bool not_reusable)443 void WebSocketBasicHandshakeStream::Close(bool not_reusable) {
444   // This class ignores the value of |not_reusable| and never lets the socket be
445   // re-used.
446   if (parser())
447     parser()->Close(true);
448 }
449 
IsResponseBodyComplete() const450 bool WebSocketBasicHandshakeStream::IsResponseBodyComplete() const {
451   return parser()->IsResponseBodyComplete();
452 }
453 
CanFindEndOfResponse() const454 bool WebSocketBasicHandshakeStream::CanFindEndOfResponse() const {
455   return parser() && parser()->CanFindEndOfResponse();
456 }
457 
IsConnectionReused() const458 bool WebSocketBasicHandshakeStream::IsConnectionReused() const {
459   return parser()->IsConnectionReused();
460 }
461 
SetConnectionReused()462 void WebSocketBasicHandshakeStream::SetConnectionReused() {
463   parser()->SetConnectionReused();
464 }
465 
IsConnectionReusable() const466 bool WebSocketBasicHandshakeStream::IsConnectionReusable() const {
467   return false;
468 }
469 
GetTotalReceivedBytes() const470 int64 WebSocketBasicHandshakeStream::GetTotalReceivedBytes() const {
471   return 0;
472 }
473 
GetLoadTimingInfo(LoadTimingInfo * load_timing_info) const474 bool WebSocketBasicHandshakeStream::GetLoadTimingInfo(
475     LoadTimingInfo* load_timing_info) const {
476   return state_.connection()->GetLoadTimingInfo(IsConnectionReused(),
477                                                 load_timing_info);
478 }
479 
GetSSLInfo(SSLInfo * ssl_info)480 void WebSocketBasicHandshakeStream::GetSSLInfo(SSLInfo* ssl_info) {
481   parser()->GetSSLInfo(ssl_info);
482 }
483 
GetSSLCertRequestInfo(SSLCertRequestInfo * cert_request_info)484 void WebSocketBasicHandshakeStream::GetSSLCertRequestInfo(
485     SSLCertRequestInfo* cert_request_info) {
486   parser()->GetSSLCertRequestInfo(cert_request_info);
487 }
488 
IsSpdyHttpStream() const489 bool WebSocketBasicHandshakeStream::IsSpdyHttpStream() const { return false; }
490 
Drain(HttpNetworkSession * session)491 void WebSocketBasicHandshakeStream::Drain(HttpNetworkSession* session) {
492   HttpResponseBodyDrainer* drainer = new HttpResponseBodyDrainer(this);
493   drainer->Start(session);
494   // |drainer| will delete itself.
495 }
496 
SetPriority(RequestPriority priority)497 void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority) {
498   // TODO(ricea): See TODO comment in HttpBasicStream::SetPriority(). If it is
499   // gone, then copy whatever has happened there over here.
500 }
501 
Upgrade()502 scoped_ptr<WebSocketStream> WebSocketBasicHandshakeStream::Upgrade() {
503   // The HttpStreamParser object has a pointer to our ClientSocketHandle. Make
504   // sure it does not touch it again before it is destroyed.
505   state_.DeleteParser();
506   WebSocketTransportClientSocketPool::UnlockEndpoint(state_.connection());
507   scoped_ptr<WebSocketStream> basic_stream(
508       new WebSocketBasicStream(state_.ReleaseConnection(),
509                                state_.read_buf(),
510                                sub_protocol_,
511                                extensions_));
512   DCHECK(extension_params_.get());
513   if (extension_params_->deflate_enabled) {
514     UMA_HISTOGRAM_ENUMERATION(
515         "Net.WebSocket.DeflateMode",
516         extension_params_->deflate_mode,
517         WebSocketDeflater::NUM_CONTEXT_TAKEOVER_MODE_TYPES);
518 
519     return scoped_ptr<WebSocketStream>(
520         new WebSocketDeflateStream(basic_stream.Pass(),
521                                    extension_params_->deflate_mode,
522                                    extension_params_->client_window_bits,
523                                    scoped_ptr<WebSocketDeflatePredictor>(
524                                        new WebSocketDeflatePredictorImpl)));
525   } else {
526     return basic_stream.Pass();
527   }
528 }
529 
SetWebSocketKeyForTesting(const std::string & key)530 void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting(
531     const std::string& key) {
532   handshake_challenge_for_testing_.reset(new std::string(key));
533 }
534 
ReadResponseHeadersCallback(const CompletionCallback & callback,int result)535 void WebSocketBasicHandshakeStream::ReadResponseHeadersCallback(
536     const CompletionCallback& callback,
537     int result) {
538   callback.Run(ValidateResponse(result));
539 }
540 
OnFinishOpeningHandshake()541 void WebSocketBasicHandshakeStream::OnFinishOpeningHandshake() {
542   DCHECK(http_response_info_);
543   WebSocketDispatchOnFinishOpeningHandshake(connect_delegate_,
544                                             url_,
545                                             http_response_info_->headers,
546                                             http_response_info_->response_time);
547 }
548 
ValidateResponse(int rv)549 int WebSocketBasicHandshakeStream::ValidateResponse(int rv) {
550   DCHECK(http_response_info_);
551   // Most net errors happen during connection, so they are not seen by this
552   // method. The histogram for error codes is created in
553   // Delegate::OnResponseStarted in websocket_stream.cc instead.
554   if (rv >= 0) {
555     const HttpResponseHeaders* headers = http_response_info_->headers.get();
556     const int response_code = headers->response_code();
557     UMA_HISTOGRAM_SPARSE_SLOWLY("Net.WebSocket.ResponseCode", response_code);
558     switch (response_code) {
559       case HTTP_SWITCHING_PROTOCOLS:
560         OnFinishOpeningHandshake();
561         return ValidateUpgradeResponse(headers);
562 
563       // We need to pass these through for authentication to work.
564       case HTTP_UNAUTHORIZED:
565       case HTTP_PROXY_AUTHENTICATION_REQUIRED:
566         return OK;
567 
568       // Other status codes are potentially risky (see the warnings in the
569       // WHATWG WebSocket API spec) and so are dropped by default.
570       default:
571         // A WebSocket server cannot be using HTTP/0.9, so if we see version
572         // 0.9, it means the response was garbage.
573         // Reporting "Unexpected response code: 200" in this case is not
574         // helpful, so use a different error message.
575         if (headers->GetHttpVersion() == HttpVersion(0, 9)) {
576           set_failure_message(
577               "Error during WebSocket handshake: Invalid status line");
578         } else {
579           set_failure_message(base::StringPrintf(
580               "Error during WebSocket handshake: Unexpected response code: %d",
581               headers->response_code()));
582         }
583         OnFinishOpeningHandshake();
584         return ERR_INVALID_RESPONSE;
585     }
586   } else {
587     if (rv == ERR_EMPTY_RESPONSE) {
588       set_failure_message(
589           "Connection closed before receiving a handshake response");
590       return rv;
591     }
592     set_failure_message(std::string("Error during WebSocket handshake: ") +
593                         ErrorToString(rv));
594     OnFinishOpeningHandshake();
595     return rv;
596   }
597 }
598 
ValidateUpgradeResponse(const HttpResponseHeaders * headers)599 int WebSocketBasicHandshakeStream::ValidateUpgradeResponse(
600     const HttpResponseHeaders* headers) {
601   extension_params_.reset(new WebSocketExtensionParams);
602   std::string failure_message;
603   if (ValidateUpgrade(headers, &failure_message) &&
604       ValidateSecWebSocketAccept(
605           headers, handshake_challenge_response_, &failure_message) &&
606       ValidateConnection(headers, &failure_message) &&
607       ValidateSubProtocol(headers,
608                           requested_sub_protocols_,
609                           &sub_protocol_,
610                           &failure_message) &&
611       ValidateExtensions(headers,
612                          requested_extensions_,
613                          &extensions_,
614                          &failure_message,
615                          extension_params_.get())) {
616     return OK;
617   }
618   set_failure_message("Error during WebSocket handshake: " + failure_message);
619   return ERR_INVALID_RESPONSE;
620 }
621 
set_failure_message(const std::string & failure_message)622 void WebSocketBasicHandshakeStream::set_failure_message(
623     const std::string& failure_message) {
624   *failure_message_ = failure_message;
625 }
626 
627 }  // namespace net
628