• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_stream.h"
6 
7 #include <utility>
8 
9 #include "base/functional/bind.h"
10 #include "base/logging.h"
11 #include "base/memory/raw_ptr.h"
12 #include "base/memory/weak_ptr.h"
13 #include "base/metrics/histogram_functions.h"
14 #include "base/time/time.h"
15 #include "base/timer/timer.h"
16 #include "net/base/ip_endpoint.h"
17 #include "net/base/isolation_info.h"
18 #include "net/base/load_flags.h"
19 #include "net/base/url_util.h"
20 #include "net/http/http_request_headers.h"
21 #include "net/http/http_response_headers.h"
22 #include "net/http/http_response_info.h"
23 #include "net/http/http_status_code.h"
24 #include "net/traffic_annotation/network_traffic_annotation.h"
25 #include "net/url_request/redirect_info.h"
26 #include "net/url_request/url_request.h"
27 #include "net/url_request/url_request_context.h"
28 #include "net/url_request/websocket_handshake_userdata_key.h"
29 #include "net/websockets/websocket_basic_handshake_stream.h"
30 #include "net/websockets/websocket_errors.h"
31 #include "net/websockets/websocket_event_interface.h"
32 #include "net/websockets/websocket_handshake_constants.h"
33 #include "net/websockets/websocket_handshake_stream_base.h"
34 #include "net/websockets/websocket_handshake_stream_create_helper.h"
35 #include "net/websockets/websocket_http2_handshake_stream.h"
36 #include "net/websockets/websocket_http3_handshake_stream.h"
37 #include "third_party/abseil-cpp/absl/types/optional.h"
38 #include "url/gurl.h"
39 #include "url/origin.h"
40 
41 namespace net {
42 namespace {
43 
// Upper bound, in seconds, on how long the WebSocket opening handshake may run.
// The value deliberately matches the TCP connection timeout used in
// net/socket/websocket_transport_client_socket_pool.cc, so that JavaScript
// cannot easily tell which of the two timeouts fired.
constexpr int kHandshakeTimeoutIntervalInSeconds = 240;

class WebSocketStreamRequestImpl;
51 
52 class Delegate : public URLRequest::Delegate {
53  public:
Delegate(WebSocketStreamRequestImpl * owner)54   explicit Delegate(WebSocketStreamRequestImpl* owner) : owner_(owner) {}
55   ~Delegate() override = default;
56 
57   // Implementation of URLRequest::Delegate methods.
58   void OnReceivedRedirect(URLRequest* request,
59                           const RedirectInfo& redirect_info,
60                           bool* defer_redirect) override;
61 
62   void OnResponseStarted(URLRequest* request, int net_error) override;
63 
64   void OnAuthRequired(URLRequest* request,
65                       const AuthChallengeInfo& auth_info) override;
66 
67   void OnCertificateRequested(URLRequest* request,
68                               SSLCertRequestInfo* cert_request_info) override;
69 
70   void OnSSLCertificateError(URLRequest* request,
71                              int net_error,
72                              const SSLInfo& ssl_info,
73                              bool fatal) override;
74 
75   void OnReadCompleted(URLRequest* request, int bytes_read) override;
76 
77  private:
78   void OnAuthRequiredComplete(URLRequest* request,
79                               const AuthCredentials* auth_credentials);
80 
81   raw_ptr<WebSocketStreamRequestImpl> owner_;
82 };
83 
84 class WebSocketStreamRequestImpl : public WebSocketStreamRequestAPI {
85  public:
WebSocketStreamRequestImpl(const GURL & url,const std::vector<std::string> & requested_subprotocols,const URLRequestContext * context,const url::Origin & origin,const SiteForCookies & site_for_cookies,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,NetworkTrafficAnnotationTag traffic_annotation,std::unique_ptr<WebSocketStream::ConnectDelegate> connect_delegate,std::unique_ptr<WebSocketStreamRequestAPI> api_delegate)86   WebSocketStreamRequestImpl(
87       const GURL& url,
88       const std::vector<std::string>& requested_subprotocols,
89       const URLRequestContext* context,
90       const url::Origin& origin,
91       const SiteForCookies& site_for_cookies,
92       const IsolationInfo& isolation_info,
93       const HttpRequestHeaders& additional_headers,
94       NetworkTrafficAnnotationTag traffic_annotation,
95       std::unique_ptr<WebSocketStream::ConnectDelegate> connect_delegate,
96       std::unique_ptr<WebSocketStreamRequestAPI> api_delegate)
97       : delegate_(this),
98         url_request_(context->CreateRequest(url,
99                                             DEFAULT_PRIORITY,
100                                             &delegate_,
101                                             traffic_annotation,
102                                             /*is_for_websockets=*/true)),
103         connect_delegate_(std::move(connect_delegate)),
104         api_delegate_(std::move(api_delegate)) {
105     DCHECK_EQ(IsolationInfo::RequestType::kOther,
106               isolation_info.request_type());
107 
108     HttpRequestHeaders headers = additional_headers;
109     headers.SetHeader(websockets::kUpgrade, websockets::kWebSocketLowercase);
110     headers.SetHeader(HttpRequestHeaders::kConnection, websockets::kUpgrade);
111     headers.SetHeader(HttpRequestHeaders::kOrigin, origin.Serialize());
112     headers.SetHeader(websockets::kSecWebSocketVersion,
113                       websockets::kSupportedVersion);
114 
115     // Remove HTTP headers that are important to websocket connections: they
116     // will be added later.
117     headers.RemoveHeader(websockets::kSecWebSocketExtensions);
118     headers.RemoveHeader(websockets::kSecWebSocketKey);
119     headers.RemoveHeader(websockets::kSecWebSocketProtocol);
120 
121     url_request_->SetExtraRequestHeaders(headers);
122     url_request_->set_initiator(origin);
123     url_request_->set_site_for_cookies(site_for_cookies);
124     url_request_->set_isolation_info(isolation_info);
125 
126     auto create_helper = std::make_unique<WebSocketHandshakeStreamCreateHelper>(
127         connect_delegate_.get(), requested_subprotocols, this);
128     url_request_->SetUserData(kWebSocketHandshakeUserDataKey,
129                               std::move(create_helper));
130     url_request_->SetLoadFlags(LOAD_DISABLE_CACHE | LOAD_BYPASS_CACHE);
131     connect_delegate_->OnCreateRequest(url_request_.get());
132   }
133 
134   // Destroying this object destroys the URLRequest, which cancels the request
135   // and so terminates the handshake if it is incomplete.
136   ~WebSocketStreamRequestImpl() override = default;
137 
OnBasicHandshakeStreamCreated(WebSocketBasicHandshakeStream * handshake_stream)138   void OnBasicHandshakeStreamCreated(
139       WebSocketBasicHandshakeStream* handshake_stream) override {
140     if (api_delegate_) {
141       api_delegate_->OnBasicHandshakeStreamCreated(handshake_stream);
142     }
143     OnHandshakeStreamCreated(handshake_stream);
144   }
145 
OnHttp2HandshakeStreamCreated(WebSocketHttp2HandshakeStream * handshake_stream)146   void OnHttp2HandshakeStreamCreated(
147       WebSocketHttp2HandshakeStream* handshake_stream) override {
148     if (api_delegate_) {
149       api_delegate_->OnHttp2HandshakeStreamCreated(handshake_stream);
150     }
151     OnHandshakeStreamCreated(handshake_stream);
152   }
153 
OnHttp3HandshakeStreamCreated(WebSocketHttp3HandshakeStream * handshake_stream)154   void OnHttp3HandshakeStreamCreated(
155       WebSocketHttp3HandshakeStream* handshake_stream) override {
156     if (api_delegate_) {
157       api_delegate_->OnHttp3HandshakeStreamCreated(handshake_stream);
158     }
159     OnHandshakeStreamCreated(handshake_stream);
160   }
161 
OnFailure(const std::string & message,int net_error,absl::optional<int> response_code)162   void OnFailure(const std::string& message,
163                  int net_error,
164                  absl::optional<int> response_code) override {
165     if (api_delegate_)
166       api_delegate_->OnFailure(message, net_error, response_code);
167     failure_message_ = message;
168     failure_net_error_ = net_error;
169     failure_response_code_ = response_code;
170   }
171 
Start(std::unique_ptr<base::OneShotTimer> timer)172   void Start(std::unique_ptr<base::OneShotTimer> timer) {
173     DCHECK(timer);
174     base::TimeDelta timeout(base::Seconds(kHandshakeTimeoutIntervalInSeconds));
175     timer_ = std::move(timer);
176     timer_->Start(FROM_HERE, timeout,
177                   base::BindOnce(&WebSocketStreamRequestImpl::OnTimeout,
178                                  base::Unretained(this)));
179     url_request_->Start();
180   }
181 
PerformUpgrade()182   void PerformUpgrade() {
183     DCHECK(timer_);
184     DCHECK(connect_delegate_);
185 
186     timer_->Stop();
187 
188     if (!handshake_stream_) {
189       ReportFailureWithMessage(
190           "No handshake stream has been created or handshake stream is already "
191           "destroyed.",
192           ERR_FAILED, absl::nullopt);
193       return;
194     }
195 
196     std::unique_ptr<URLRequest> url_request = std::move(url_request_);
197     WebSocketHandshakeStreamBase* handshake_stream = handshake_stream_.get();
198     handshake_stream_.reset();
199     auto handshake_response_info =
200         std::make_unique<WebSocketHandshakeResponseInfo>(
201             url_request->url(), url_request->response_headers(),
202             url_request->GetResponseRemoteEndpoint(),
203             url_request->response_time());
204     connect_delegate_->OnSuccess(handshake_stream->Upgrade(),
205                                  std::move(handshake_response_info));
206 
207     // This is safe even if |this| has already been deleted.
208     url_request->CancelWithError(ERR_WS_UPGRADE);
209   }
210 
FailureMessageFromNetError(int net_error)211   std::string FailureMessageFromNetError(int net_error) {
212     if (net_error == ERR_TUNNEL_CONNECTION_FAILED) {
213       // This error is common and confusing, so special-case it.
214       // TODO(ricea): Include the HostPortPair of the selected proxy server in
215       // the error message. This is not currently possible because it isn't set
216       // in HttpResponseInfo when a ERR_TUNNEL_CONNECTION_FAILED error happens.
217       return "Establishing a tunnel via proxy server failed.";
218     } else {
219       return std::string("Error in connection establishment: ") +
220              ErrorToString(net_error);
221     }
222   }
223 
ReportFailure(int net_error,absl::optional<int> response_code)224   void ReportFailure(int net_error, absl::optional<int> response_code) {
225     DCHECK(timer_);
226     timer_->Stop();
227     if (failure_message_.empty()) {
228       switch (net_error) {
229         case OK:
230         case ERR_IO_PENDING:
231           break;
232         case ERR_ABORTED:
233           failure_message_ = "WebSocket opening handshake was canceled";
234           break;
235         case ERR_TIMED_OUT:
236           failure_message_ = "WebSocket opening handshake timed out";
237           break;
238         default:
239           failure_message_ = FailureMessageFromNetError(net_error);
240           break;
241       }
242     }
243 
244     ReportFailureWithMessage(
245         failure_message_, failure_net_error_.value_or(net_error),
246         failure_response_code_ ? failure_response_code_ : response_code);
247   }
248 
ReportFailureWithMessage(const std::string & failure_message,int net_error,absl::optional<int> response_code)249   void ReportFailureWithMessage(const std::string& failure_message,
250                                 int net_error,
251                                 absl::optional<int> response_code) {
252     connect_delegate_->OnFailure(failure_message, net_error, response_code);
253   }
254 
connect_delegate() const255   WebSocketStream::ConnectDelegate* connect_delegate() const {
256     return connect_delegate_.get();
257   }
258 
OnTimeout()259   void OnTimeout() {
260     url_request_->CancelWithError(ERR_TIMED_OUT);
261   }
262 
263  private:
OnHandshakeStreamCreated(WebSocketHandshakeStreamBase * handshake_stream)264   void OnHandshakeStreamCreated(
265       WebSocketHandshakeStreamBase* handshake_stream) {
266     DCHECK(handshake_stream);
267 
268     handshake_stream_ = handshake_stream->GetWeakPtr();
269   }
270 
271   // |delegate_| needs to be declared before |url_request_| so that it gets
272   // initialised first.
273   Delegate delegate_;
274 
275   // Deleting the WebSocketStreamRequestImpl object deletes this URLRequest
276   // object, cancelling the whole connection.
277   std::unique_ptr<URLRequest> url_request_;
278 
279   std::unique_ptr<WebSocketStream::ConnectDelegate> connect_delegate_;
280 
281   // This is owned by the caller of
282   // WebsocketHandshakeStreamCreateHelper::CreateBasicStream() or
283   // CreateHttp2Stream() or CreateHttp3Stream().  Both the stream and this
284   // object will be destroyed during the destruction of the URLRequest object
285   // associated with the handshake. This is only guaranteed to be a valid
286   // pointer if the handshake succeeded.
287   base::WeakPtr<WebSocketHandshakeStreamBase> handshake_stream_;
288 
289   // The failure information supplied by WebSocketBasicHandshakeStream, if any.
290   std::string failure_message_;
291   absl::optional<int> failure_net_error_;
292   absl::optional<int> failure_response_code_;
293 
294   // A timer for handshake timeout.
295   std::unique_ptr<base::OneShotTimer> timer_;
296 
297   // A delegate for On*HandshakeCreated and OnFailure calls.
298   std::unique_ptr<WebSocketStreamRequestAPI> api_delegate_;
299 };
300 
301 class SSLErrorCallbacks : public WebSocketEventInterface::SSLErrorCallbacks {
302  public:
SSLErrorCallbacks(URLRequest * url_request)303   explicit SSLErrorCallbacks(URLRequest* url_request)
304       : url_request_(url_request->GetWeakPtr()) {}
305 
CancelSSLRequest(int error,const SSLInfo * ssl_info)306   void CancelSSLRequest(int error, const SSLInfo* ssl_info) override {
307     if (!url_request_)
308       return;
309 
310     if (ssl_info) {
311       url_request_->CancelWithSSLError(error, *ssl_info);
312     } else {
313       url_request_->CancelWithError(error);
314     }
315   }
316 
ContinueSSLRequest()317   void ContinueSSLRequest() override {
318     if (url_request_)
319       url_request_->ContinueDespiteLastError();
320   }
321 
322  private:
323   base::WeakPtr<URLRequest> url_request_;
324 };
325 
OnReceivedRedirect(URLRequest * request,const RedirectInfo & redirect_info,bool * defer_redirect)326 void Delegate::OnReceivedRedirect(URLRequest* request,
327                                   const RedirectInfo& redirect_info,
328                                   bool* defer_redirect) {
329   // This code should never be reached for externally generated redirects,
330   // as WebSocketBasicHandshakeStream is responsible for filtering out
331   // all response codes besides 101, 401, and 407. As such, the URLRequest
332   // should never see a redirect sent over the network. However, internal
333   // redirects also result in this method being called, such as those
334   // caused by HSTS.
335   // Because it's security critical to prevent externally-generated
336   // redirects in WebSockets, perform additional checks to ensure this
337   // is only internal.
338   GURL::Replacements replacements;
339   replacements.SetSchemeStr("wss");
340   GURL expected_url = request->original_url().ReplaceComponents(replacements);
341   if (redirect_info.new_method != "GET" ||
342       redirect_info.new_url != expected_url) {
343     // This should not happen.
344     DLOG(FATAL) << "Unauthorized WebSocket redirect to "
345                 << redirect_info.new_method << " "
346                 << redirect_info.new_url.spec();
347     request->Cancel();
348   }
349 }
350 
OnResponseStarted(URLRequest * request,int net_error)351 void Delegate::OnResponseStarted(URLRequest* request, int net_error) {
352   DCHECK_NE(ERR_IO_PENDING, net_error);
353   // All error codes, including OK and ABORTED, as with
354   // Net.ErrorCodesForMainFrame4
355   base::UmaHistogramSparse("Net.WebSocket.ErrorCodes", -net_error);
356   if (net::IsLocalhost(request->url())) {
357     base::UmaHistogramSparse("Net.WebSocket.ErrorCodes_Localhost", -net_error);
358   } else {
359     base::UmaHistogramSparse("Net.WebSocket.ErrorCodes_NotLocalhost",
360                              -net_error);
361   }
362 
363   if (net_error != OK) {
364     DVLOG(3) << "OnResponseStarted (request failed)";
365     owner_->ReportFailure(net_error, absl::nullopt);
366     return;
367   }
368   const int response_code = request->GetResponseCode();
369   DVLOG(3) << "OnResponseStarted (response code " << response_code << ")";
370 
371   if (request->response_info().connection_info ==
372       HttpResponseInfo::CONNECTION_INFO_HTTP2) {
373     if (response_code == HTTP_OK) {
374       owner_->PerformUpgrade();
375       return;
376     }
377 
378     owner_->ReportFailure(net_error, absl::nullopt);
379     return;
380   }
381 
382   switch (response_code) {
383     case HTTP_SWITCHING_PROTOCOLS:
384       owner_->PerformUpgrade();
385       return;
386 
387     case HTTP_UNAUTHORIZED:
388       owner_->ReportFailureWithMessage(
389           "HTTP Authentication failed; no valid credentials available",
390           net_error, response_code);
391       return;
392 
393     case HTTP_PROXY_AUTHENTICATION_REQUIRED:
394       owner_->ReportFailureWithMessage("Proxy authentication failed", net_error,
395                                        response_code);
396       return;
397 
398     default:
399       owner_->ReportFailure(net_error, response_code);
400   }
401 }
402 
OnAuthRequired(URLRequest * request,const AuthChallengeInfo & auth_info)403 void Delegate::OnAuthRequired(URLRequest* request,
404                               const AuthChallengeInfo& auth_info) {
405   absl::optional<AuthCredentials> credentials;
406   // This base::Unretained(this) relies on an assumption that |callback| can
407   // be called called during the opening handshake.
408   int rv = owner_->connect_delegate()->OnAuthRequired(
409       auth_info, request->response_headers(),
410       request->GetResponseRemoteEndpoint(),
411       base::BindOnce(&Delegate::OnAuthRequiredComplete, base::Unretained(this),
412                      request),
413       &credentials);
414   request->LogBlockedBy("WebSocketStream::Delegate::OnAuthRequired");
415   if (rv == ERR_IO_PENDING)
416     return;
417   if (rv != OK) {
418     request->LogUnblocked();
419     owner_->ReportFailure(rv, absl::nullopt);
420     return;
421   }
422   OnAuthRequiredComplete(request, nullptr);
423 }
424 
OnAuthRequiredComplete(URLRequest * request,const AuthCredentials * credentials)425 void Delegate::OnAuthRequiredComplete(URLRequest* request,
426                                       const AuthCredentials* credentials) {
427   request->LogUnblocked();
428   if (!credentials) {
429     request->CancelAuth();
430     return;
431   }
432   request->SetAuth(*credentials);
433 }
434 
OnCertificateRequested(URLRequest * request,SSLCertRequestInfo * cert_request_info)435 void Delegate::OnCertificateRequested(URLRequest* request,
436                                       SSLCertRequestInfo* cert_request_info) {
437   // This method is called when a client certificate is requested, and the
438   // request context does not already contain a client certificate selection for
439   // the endpoint. In this case, a main frame resource request would pop-up UI
440   // to permit selection of a client certificate, but since WebSockets are
441   // sub-resources they should not pop-up UI and so there is nothing more we can
442   // do.
443   request->Cancel();
444 }
445 
OnSSLCertificateError(URLRequest * request,int net_error,const SSLInfo & ssl_info,bool fatal)446 void Delegate::OnSSLCertificateError(URLRequest* request,
447                                      int net_error,
448                                      const SSLInfo& ssl_info,
449                                      bool fatal) {
450   owner_->connect_delegate()->OnSSLCertificateError(
451       std::make_unique<SSLErrorCallbacks>(request), net_error, ssl_info, fatal);
452 }
453 
OnReadCompleted(URLRequest * request,int bytes_read)454 void Delegate::OnReadCompleted(URLRequest* request, int bytes_read) {
455   NOTREACHED();
456 }
457 
458 }  // namespace
459 
460 WebSocketStreamRequest::~WebSocketStreamRequest() = default;
461 
462 WebSocketStream::WebSocketStream() = default;
463 WebSocketStream::~WebSocketStream() = default;
464 
465 WebSocketStream::ConnectDelegate::~ConnectDelegate() = default;
466 
CreateAndConnectStream(const GURL & socket_url,const std::vector<std::string> & requested_subprotocols,const url::Origin & origin,const SiteForCookies & site_for_cookies,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,URLRequestContext * url_request_context,const NetLogWithSource & net_log,NetworkTrafficAnnotationTag traffic_annotation,std::unique_ptr<ConnectDelegate> connect_delegate)467 std::unique_ptr<WebSocketStreamRequest> WebSocketStream::CreateAndConnectStream(
468     const GURL& socket_url,
469     const std::vector<std::string>& requested_subprotocols,
470     const url::Origin& origin,
471     const SiteForCookies& site_for_cookies,
472     const IsolationInfo& isolation_info,
473     const HttpRequestHeaders& additional_headers,
474     URLRequestContext* url_request_context,
475     const NetLogWithSource& net_log,
476     NetworkTrafficAnnotationTag traffic_annotation,
477     std::unique_ptr<ConnectDelegate> connect_delegate) {
478   auto request = std::make_unique<WebSocketStreamRequestImpl>(
479       socket_url, requested_subprotocols, url_request_context, origin,
480       site_for_cookies, isolation_info, additional_headers, traffic_annotation,
481       std::move(connect_delegate), nullptr);
482   request->Start(std::make_unique<base::OneShotTimer>());
483   return std::move(request);
484 }
485 
486 std::unique_ptr<WebSocketStreamRequest>
CreateAndConnectStreamForTesting(const GURL & socket_url,const std::vector<std::string> & requested_subprotocols,const url::Origin & origin,const SiteForCookies & site_for_cookies,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,URLRequestContext * url_request_context,const NetLogWithSource & net_log,NetworkTrafficAnnotationTag traffic_annotation,std::unique_ptr<WebSocketStream::ConnectDelegate> connect_delegate,std::unique_ptr<base::OneShotTimer> timer,std::unique_ptr<WebSocketStreamRequestAPI> api_delegate)487 WebSocketStream::CreateAndConnectStreamForTesting(
488     const GURL& socket_url,
489     const std::vector<std::string>& requested_subprotocols,
490     const url::Origin& origin,
491     const SiteForCookies& site_for_cookies,
492     const IsolationInfo& isolation_info,
493     const HttpRequestHeaders& additional_headers,
494     URLRequestContext* url_request_context,
495     const NetLogWithSource& net_log,
496     NetworkTrafficAnnotationTag traffic_annotation,
497     std::unique_ptr<WebSocketStream::ConnectDelegate> connect_delegate,
498     std::unique_ptr<base::OneShotTimer> timer,
499     std::unique_ptr<WebSocketStreamRequestAPI> api_delegate) {
500   auto request = std::make_unique<WebSocketStreamRequestImpl>(
501       socket_url, requested_subprotocols, url_request_context, origin,
502       site_for_cookies, isolation_info, additional_headers, traffic_annotation,
503       std::move(connect_delegate), std::move(api_delegate));
504   request->Start(std::move(timer));
505   return std::move(request);
506 }
507 
508 }  // namespace net
509