• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_channel.h"
6 
7 #include <limits.h>  // for INT_MAX
8 #include <stddef.h>
9 #include <string.h>
10 
11 #include <algorithm>
12 #include <iterator>
13 #include <ostream>
14 #include <utility>
15 #include <vector>
16 
17 #include "base/big_endian.h"
18 #include "base/check.h"
19 #include "base/check_op.h"
20 #include "base/functional/bind.h"
21 #include "base/location.h"
22 #include "base/logging.h"
23 #include "base/memory/raw_ptr.h"
24 #include "base/numerics/safe_conversions.h"
25 #include "base/ranges/algorithm.h"
26 #include "base/strings/string_piece.h"
27 #include "base/strings/stringprintf.h"
28 #include "base/time/time.h"
29 #include "base/values.h"
30 #include "net/base/io_buffer.h"
31 #include "net/base/net_errors.h"
32 #include "net/http/http_response_headers.h"
33 #include "net/log/net_log_event_type.h"
34 #include "net/log/net_log_with_source.h"
35 #include "net/traffic_annotation/network_traffic_annotation.h"
36 #include "net/websockets/websocket_errors.h"
37 #include "net/websockets/websocket_event_interface.h"
38 #include "net/websockets/websocket_frame.h"
39 #include "net/websockets/websocket_handshake_request_info.h"
40 #include "net/websockets/websocket_handshake_response_info.h"
41 #include "net/websockets/websocket_stream.h"
42 
43 namespace net {
44 class AuthChallengeInfo;
45 class AuthCredentials;
46 class SSLInfo;
47 
48 namespace {
49 
50 using base::StreamingUtf8Validator;
51 
52 constexpr size_t kWebSocketCloseCodeLength = 2;
53 // Timeout for waiting for the server to acknowledge a closing handshake.
54 constexpr int kClosingHandshakeTimeoutSeconds = 60;
55 // We wait for the server to close the underlying connection as recommended in
56 // https://tools.ietf.org/html/rfc6455#section-7.1.1
57 // We don't use 2MSL since there're server implementations that don't follow
58 // the recommendation and wait for the client to close the underlying
59 // connection. It leads to unnecessarily long time before CloseEvent
60 // invocation. We want to avoid this rather than strictly following the spec
61 // recommendation.
62 constexpr int kUnderlyingConnectionCloseTimeoutSeconds = 2;
63 
64 using ChannelState = WebSocketChannel::ChannelState;
65 
66 // Maximum close reason length = max control frame payload -
67 //                               status code length
68 //                             = 125 - 2
69 constexpr size_t kMaximumCloseReasonLength = 125 - kWebSocketCloseCodeLength;
70 
71 // Check a close status code for strict compliance with RFC6455. This is only
72 // used for close codes received from a renderer that we are intending to send
73 // out over the network. See ParseClose() for the restrictions on incoming close
74 // codes. The |code| parameter is type int for convenience of implementation;
75 // the real type is uint16_t. Code 1005 is treated specially; it cannot be set
76 // explicitly by Javascript but the renderer uses it to indicate we should send
77 // a Close frame with no payload.
IsStrictlyValidCloseStatusCode(int code)78 bool IsStrictlyValidCloseStatusCode(int code) {
79   static const int kInvalidRanges[] = {
80       // [BAD, OK)
81       0,    1000,   // 1000 is the first valid code
82       1006, 1007,   // 1006 MUST NOT be set.
83       1014, 3000,   // 1014 unassigned; 1015 up to 2999 are reserved.
84       5000, 65536,  // Codes above 5000 are invalid.
85   };
86   const int* const kInvalidRangesEnd =
87       kInvalidRanges + std::size(kInvalidRanges);
88 
89   DCHECK_GE(code, 0);
90   DCHECK_LT(code, 65536);
91   const int* upper = std::upper_bound(kInvalidRanges, kInvalidRangesEnd, code);
92   DCHECK_NE(kInvalidRangesEnd, upper);
93   DCHECK_GT(upper, kInvalidRanges);
94   DCHECK_GT(*upper, code);
95   DCHECK_LE(*(upper - 1), code);
96   return ((upper - kInvalidRanges) % 2) == 0;
97 }
98 
99 // Sets |name| to the name of the frame type for the given |opcode|. Note that
100 // for all of Text, Binary and Continuation opcode, this method returns
101 // "Data frame".
GetFrameTypeForOpcode(WebSocketFrameHeader::OpCode opcode,std::string * name)102 void GetFrameTypeForOpcode(WebSocketFrameHeader::OpCode opcode,
103                            std::string* name) {
104   switch (opcode) {
105     case WebSocketFrameHeader::kOpCodeText:    // fall-thru
106     case WebSocketFrameHeader::kOpCodeBinary:  // fall-thru
107     case WebSocketFrameHeader::kOpCodeContinuation:
108       *name = "Data frame";
109       break;
110 
111     case WebSocketFrameHeader::kOpCodePing:
112       *name = "Ping";
113       break;
114 
115     case WebSocketFrameHeader::kOpCodePong:
116       *name = "Pong";
117       break;
118 
119     case WebSocketFrameHeader::kOpCodeClose:
120       *name = "Close";
121       break;
122 
123     default:
124       *name = "Unknown frame type";
125       break;
126   }
127 
128   return;
129 }
130 
NetLogFailParam(uint16_t code,base::StringPiece reason,base::StringPiece message)131 base::Value::Dict NetLogFailParam(uint16_t code,
132                                   base::StringPiece reason,
133                                   base::StringPiece message) {
134   base::Value::Dict dict;
135   dict.Set("code", code);
136   dict.Set("reason", reason);
137   dict.Set("internal_reason", message);
138   return dict;
139 }
140 
141 class DependentIOBuffer : public WrappedIOBuffer {
142  public:
DependentIOBuffer(scoped_refptr<IOBufferWithSize> buffer,size_t offset)143   DependentIOBuffer(scoped_refptr<IOBufferWithSize> buffer, size_t offset)
144       : WrappedIOBuffer(buffer->data() + offset, buffer->size() - offset),
145         buffer_(std::move(buffer)) {}
146 
147  private:
~DependentIOBuffer()148   ~DependentIOBuffer() override {
149     // Prevent `data_` from dangling should this destructor remove the
150     // last reference to `buffer_`.
151     data_ = nullptr;
152   }
153 
154   scoped_refptr<IOBufferWithSize> buffer_;
155 };
156 
157 }  // namespace
158 
159 // A class to encapsulate a set of frames and information about the size of
160 // those frames.
161 class WebSocketChannel::SendBuffer {
162  public:
163   SendBuffer() = default;
164 
165   // Add a WebSocketFrame to the buffer and increase total_bytes_.
166   void AddFrame(std::unique_ptr<WebSocketFrame> chunk,
167                 scoped_refptr<IOBuffer> buffer);
168 
169   // Return a pointer to the frames_ for write purposes.
frames()170   std::vector<std::unique_ptr<WebSocketFrame>>* frames() { return &frames_; }
171 
172  private:
173   // The frames_ that will be sent in the next call to WriteFrames().
174   std::vector<std::unique_ptr<WebSocketFrame>> frames_;
175   // References of each WebSocketFrame.data;
176   std::vector<scoped_refptr<IOBuffer>> buffers_;
177 
178   // The total size of the payload data in |frames_|. This will be used to
179   // measure the throughput of the link.
180   // TODO(ricea): Measure the throughput of the link.
181   uint64_t total_bytes_ = 0;
182 };
183 
AddFrame(std::unique_ptr<WebSocketFrame> frame,scoped_refptr<IOBuffer> buffer)184 void WebSocketChannel::SendBuffer::AddFrame(
185     std::unique_ptr<WebSocketFrame> frame,
186     scoped_refptr<IOBuffer> buffer) {
187   total_bytes_ += frame->header.payload_length;
188   frames_.push_back(std::move(frame));
189   buffers_.push_back(std::move(buffer));
190 }
191 
192 // Implementation of WebSocketStream::ConnectDelegate that simply forwards the
193 // calls on to the WebSocketChannel that created it.
194 class WebSocketChannel::ConnectDelegate
195     : public WebSocketStream::ConnectDelegate {
196  public:
ConnectDelegate(WebSocketChannel * creator)197   explicit ConnectDelegate(WebSocketChannel* creator) : creator_(creator) {}
198 
199   ConnectDelegate(const ConnectDelegate&) = delete;
200   ConnectDelegate& operator=(const ConnectDelegate&) = delete;
201 
OnCreateRequest(URLRequest * request)202   void OnCreateRequest(URLRequest* request) override {
203     creator_->OnCreateURLRequest(request);
204   }
205 
OnSuccess(std::unique_ptr<WebSocketStream> stream,std::unique_ptr<WebSocketHandshakeResponseInfo> response)206   void OnSuccess(
207       std::unique_ptr<WebSocketStream> stream,
208       std::unique_ptr<WebSocketHandshakeResponseInfo> response) override {
209     creator_->OnConnectSuccess(std::move(stream), std::move(response));
210     // |this| may have been deleted.
211   }
212 
OnFailure(const std::string & message,int net_error,absl::optional<int> response_code)213   void OnFailure(const std::string& message,
214                  int net_error,
215                  absl::optional<int> response_code) override {
216     creator_->OnConnectFailure(message, net_error, response_code);
217     // |this| has been deleted.
218   }
219 
OnStartOpeningHandshake(std::unique_ptr<WebSocketHandshakeRequestInfo> request)220   void OnStartOpeningHandshake(
221       std::unique_ptr<WebSocketHandshakeRequestInfo> request) override {
222     creator_->OnStartOpeningHandshake(std::move(request));
223   }
224 
OnSSLCertificateError(std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks,int net_error,const SSLInfo & ssl_info,bool fatal)225   void OnSSLCertificateError(
226       std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks>
227           ssl_error_callbacks,
228       int net_error,
229       const SSLInfo& ssl_info,
230       bool fatal) override {
231     creator_->OnSSLCertificateError(std::move(ssl_error_callbacks), net_error,
232                                     ssl_info, fatal);
233   }
234 
OnAuthRequired(const AuthChallengeInfo & auth_info,scoped_refptr<HttpResponseHeaders> headers,const IPEndPoint & remote_endpoint,base::OnceCallback<void (const AuthCredentials *)> callback,absl::optional<AuthCredentials> * credentials)235   int OnAuthRequired(const AuthChallengeInfo& auth_info,
236                      scoped_refptr<HttpResponseHeaders> headers,
237                      const IPEndPoint& remote_endpoint,
238                      base::OnceCallback<void(const AuthCredentials*)> callback,
239                      absl::optional<AuthCredentials>* credentials) override {
240     return creator_->OnAuthRequired(auth_info, std::move(headers),
241                                     remote_endpoint, std::move(callback),
242                                     credentials);
243   }
244 
245  private:
246   // A pointer to the WebSocketChannel that created this object. There is no
247   // danger of this pointer being stale, because deleting the WebSocketChannel
248   // cancels the connect process, deleting this object and preventing its
249   // callbacks from being called.
250   const raw_ptr<WebSocketChannel, DanglingUntriaged> creator_;
251 };
252 
WebSocketChannel(std::unique_ptr<WebSocketEventInterface> event_interface,URLRequestContext * url_request_context)253 WebSocketChannel::WebSocketChannel(
254     std::unique_ptr<WebSocketEventInterface> event_interface,
255     URLRequestContext* url_request_context)
256     : event_interface_(std::move(event_interface)),
257       url_request_context_(url_request_context),
258       closing_handshake_timeout_(
259           base::Seconds(kClosingHandshakeTimeoutSeconds)),
260       underlying_connection_close_timeout_(
261           base::Seconds(kUnderlyingConnectionCloseTimeoutSeconds)) {}
262 
~WebSocketChannel()263 WebSocketChannel::~WebSocketChannel() {
264   // The stream may hold a pointer to read_frames_, and so it needs to be
265   // destroyed first.
266   stream_.reset();
267   // The timer may have a callback pointing back to us, so stop it just in case
268   // someone decides to run the event loop from their destructor.
269   close_timer_.Stop();
270 }
271 
SendAddChannelRequest(const GURL & socket_url,const std::vector<std::string> & requested_subprotocols,const url::Origin & origin,const SiteForCookies & site_for_cookies,bool has_storage_access,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,NetworkTrafficAnnotationTag traffic_annotation)272 void WebSocketChannel::SendAddChannelRequest(
273     const GURL& socket_url,
274     const std::vector<std::string>& requested_subprotocols,
275     const url::Origin& origin,
276     const SiteForCookies& site_for_cookies,
277     bool has_storage_access,
278     const IsolationInfo& isolation_info,
279     const HttpRequestHeaders& additional_headers,
280     NetworkTrafficAnnotationTag traffic_annotation) {
281   SendAddChannelRequestWithSuppliedCallback(
282       socket_url, requested_subprotocols, origin, site_for_cookies,
283       has_storage_access, isolation_info, additional_headers,
284       traffic_annotation,
285       base::BindOnce(&WebSocketStream::CreateAndConnectStream));
286 }
287 
SetState(State new_state)288 void WebSocketChannel::SetState(State new_state) {
289   DCHECK_NE(state_, new_state);
290 
291   state_ = new_state;
292 }
293 
InClosingState() const294 bool WebSocketChannel::InClosingState() const {
295   // The state RECV_CLOSED is not supported here, because it is only used in one
296   // code path and should not leak into the code in general.
297   DCHECK_NE(RECV_CLOSED, state_)
298       << "InClosingState called with state_ == RECV_CLOSED";
299   return state_ == SEND_CLOSED || state_ == CLOSE_WAIT || state_ == CLOSED;
300 }
301 
SendFrame(bool fin,WebSocketFrameHeader::OpCode op_code,scoped_refptr<IOBuffer> buffer,size_t buffer_size)302 WebSocketChannel::ChannelState WebSocketChannel::SendFrame(
303     bool fin,
304     WebSocketFrameHeader::OpCode op_code,
305     scoped_refptr<IOBuffer> buffer,
306     size_t buffer_size) {
307   DCHECK_LE(buffer_size, static_cast<size_t>(INT_MAX));
308   DCHECK(stream_) << "Got SendFrame without a connection established; fin="
309                   << fin << " op_code=" << op_code
310                   << " buffer_size=" << buffer_size;
311 
312   if (InClosingState()) {
313     DVLOG(1) << "SendFrame called in state " << state_
314              << ". This may be a bug, or a harmless race.";
315     return CHANNEL_ALIVE;
316   }
317 
318   DCHECK_EQ(state_, CONNECTED);
319 
320   DCHECK(WebSocketFrameHeader::IsKnownDataOpCode(op_code))
321       << "Got SendFrame with bogus op_code " << op_code << " fin=" << fin
322       << " buffer_size=" << buffer_size;
323 
324   if (op_code == WebSocketFrameHeader::kOpCodeText ||
325       (op_code == WebSocketFrameHeader::kOpCodeContinuation &&
326        sending_text_message_)) {
327     StreamingUtf8Validator::State state = outgoing_utf8_validator_.AddBytes(
328         base::make_span(buffer->bytes(), buffer_size));
329     if (state == StreamingUtf8Validator::INVALID ||
330         (state == StreamingUtf8Validator::VALID_MIDPOINT && fin)) {
331       // TODO(ricea): Kill renderer.
332       FailChannel("Browser sent a text frame containing invalid UTF-8",
333                   kWebSocketErrorGoingAway, "");
334       return CHANNEL_DELETED;
335       // |this| has been deleted.
336     }
337     sending_text_message_ = !fin;
338     DCHECK(!fin || state == StreamingUtf8Validator::VALID_ENDPOINT);
339   }
340 
341   return SendFrameInternal(fin, op_code, std::move(buffer), buffer_size);
342   // |this| may have been deleted.
343 }
344 
StartClosingHandshake(uint16_t code,const std::string & reason)345 ChannelState WebSocketChannel::StartClosingHandshake(
346     uint16_t code,
347     const std::string& reason) {
348   if (InClosingState()) {
349     // When the associated renderer process is killed while the channel is in
350     // CLOSING state we reach here.
351     DVLOG(1) << "StartClosingHandshake called in state " << state_
352              << ". This may be a bug, or a harmless race.";
353     return CHANNEL_ALIVE;
354   }
355   if (has_received_close_frame_) {
356     // We reach here if the client wants to start a closing handshake while
357     // the browser is waiting for the client to consume incoming data frames
358     // before responding to a closing handshake initiated by the server.
359     // As the client doesn't want the data frames any more, we can respond to
360     // the closing handshake initiated by the server.
361     return RespondToClosingHandshake();
362   }
363   if (state_ == CONNECTING) {
364     // Abort the in-progress handshake and drop the connection immediately.
365     stream_request_.reset();
366     SetState(CLOSED);
367     DoDropChannel(false, kWebSocketErrorAbnormalClosure, "");
368     return CHANNEL_DELETED;
369   }
370   DCHECK_EQ(state_, CONNECTED);
371 
372   DCHECK(!close_timer_.IsRunning());
373   // This use of base::Unretained() is safe because we stop the timer in the
374   // destructor.
375   close_timer_.Start(
376       FROM_HERE, closing_handshake_timeout_,
377       base::BindOnce(&WebSocketChannel::CloseTimeout, base::Unretained(this)));
378 
379   // Javascript actually only permits 1000 and 3000-4999, but the implementation
380   // itself may produce different codes. The length of |reason| is also checked
381   // by Javascript.
382   if (!IsStrictlyValidCloseStatusCode(code) ||
383       reason.size() > kMaximumCloseReasonLength) {
384     // "InternalServerError" is actually used for errors from any endpoint, per
385     // errata 3227 to RFC6455. If the renderer is sending us an invalid code or
386     // reason it must be malfunctioning in some way, and based on that we
387     // interpret this as an internal error.
388     if (SendClose(kWebSocketErrorInternalServerError, "") == CHANNEL_DELETED)
389       return CHANNEL_DELETED;
390     DCHECK_EQ(CONNECTED, state_);
391     SetState(SEND_CLOSED);
392     return CHANNEL_ALIVE;
393   }
394   if (SendClose(code, StreamingUtf8Validator::Validate(reason)
395                           ? reason
396                           : std::string()) == CHANNEL_DELETED)
397     return CHANNEL_DELETED;
398   DCHECK_EQ(CONNECTED, state_);
399   SetState(SEND_CLOSED);
400   return CHANNEL_ALIVE;
401 }
402 
SendAddChannelRequestForTesting(const GURL & socket_url,const std::vector<std::string> & requested_subprotocols,const url::Origin & origin,const SiteForCookies & site_for_cookies,bool has_storage_access,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,NetworkTrafficAnnotationTag traffic_annotation,WebSocketStreamRequestCreationCallback callback)403 void WebSocketChannel::SendAddChannelRequestForTesting(
404     const GURL& socket_url,
405     const std::vector<std::string>& requested_subprotocols,
406     const url::Origin& origin,
407     const SiteForCookies& site_for_cookies,
408     bool has_storage_access,
409     const IsolationInfo& isolation_info,
410     const HttpRequestHeaders& additional_headers,
411     NetworkTrafficAnnotationTag traffic_annotation,
412     WebSocketStreamRequestCreationCallback callback) {
413   SendAddChannelRequestWithSuppliedCallback(
414       socket_url, requested_subprotocols, origin, site_for_cookies,
415       has_storage_access, isolation_info, additional_headers,
416       traffic_annotation, std::move(callback));
417 }
418 
SetClosingHandshakeTimeoutForTesting(base::TimeDelta delay)419 void WebSocketChannel::SetClosingHandshakeTimeoutForTesting(
420     base::TimeDelta delay) {
421   closing_handshake_timeout_ = delay;
422 }
423 
SetUnderlyingConnectionCloseTimeoutForTesting(base::TimeDelta delay)424 void WebSocketChannel::SetUnderlyingConnectionCloseTimeoutForTesting(
425     base::TimeDelta delay) {
426   underlying_connection_close_timeout_ = delay;
427 }
428 
SendAddChannelRequestWithSuppliedCallback(const GURL & socket_url,const std::vector<std::string> & requested_subprotocols,const url::Origin & origin,const SiteForCookies & site_for_cookies,bool has_storage_access,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,NetworkTrafficAnnotationTag traffic_annotation,WebSocketStreamRequestCreationCallback callback)429 void WebSocketChannel::SendAddChannelRequestWithSuppliedCallback(
430     const GURL& socket_url,
431     const std::vector<std::string>& requested_subprotocols,
432     const url::Origin& origin,
433     const SiteForCookies& site_for_cookies,
434     bool has_storage_access,
435     const IsolationInfo& isolation_info,
436     const HttpRequestHeaders& additional_headers,
437     NetworkTrafficAnnotationTag traffic_annotation,
438     WebSocketStreamRequestCreationCallback callback) {
439   DCHECK_EQ(FRESHLY_CONSTRUCTED, state_);
440   if (!socket_url.SchemeIsWSOrWSS()) {
441     // TODO(ricea): Kill the renderer (this error should have been caught by
442     // Javascript).
443     event_interface_->OnFailChannel("Invalid scheme", ERR_FAILED,
444                                     absl::nullopt);
445     // |this| is deleted here.
446     return;
447   }
448   socket_url_ = socket_url;
449   auto connect_delegate = std::make_unique<ConnectDelegate>(this);
450   stream_request_ = std::move(callback).Run(
451       socket_url_, requested_subprotocols, origin, site_for_cookies,
452       has_storage_access, isolation_info, additional_headers,
453       url_request_context_.get(), NetLogWithSource(), traffic_annotation,
454       std::move(connect_delegate));
455   SetState(CONNECTING);
456 }
457 
OnCreateURLRequest(URLRequest * request)458 void WebSocketChannel::OnCreateURLRequest(URLRequest* request) {
459   event_interface_->OnCreateURLRequest(request);
460 }
461 
OnConnectSuccess(std::unique_ptr<WebSocketStream> stream,std::unique_ptr<WebSocketHandshakeResponseInfo> response)462 void WebSocketChannel::OnConnectSuccess(
463     std::unique_ptr<WebSocketStream> stream,
464     std::unique_ptr<WebSocketHandshakeResponseInfo> response) {
465   DCHECK(stream);
466   DCHECK_EQ(CONNECTING, state_);
467 
468   stream_ = std::move(stream);
469 
470   SetState(CONNECTED);
471 
472   // |stream_request_| is not used once the connection has succeeded.
473   stream_request_.reset();
474 
475   event_interface_->OnAddChannelResponse(
476       std::move(response), stream_->GetSubProtocol(), stream_->GetExtensions());
477   // |this| may have been deleted after OnAddChannelResponse.
478 }
479 
OnConnectFailure(const std::string & message,int net_error,absl::optional<int> response_code)480 void WebSocketChannel::OnConnectFailure(const std::string& message,
481                                         int net_error,
482                                         absl::optional<int> response_code) {
483   DCHECK_EQ(CONNECTING, state_);
484 
485   // Copy the message before we delete its owner.
486   std::string message_copy = message;
487 
488   SetState(CLOSED);
489   stream_request_.reset();
490 
491   event_interface_->OnFailChannel(message_copy, net_error, response_code);
492   // |this| has been deleted.
493 }
494 
OnSSLCertificateError(std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks,int net_error,const SSLInfo & ssl_info,bool fatal)495 void WebSocketChannel::OnSSLCertificateError(
496     std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks>
497         ssl_error_callbacks,
498     int net_error,
499     const SSLInfo& ssl_info,
500     bool fatal) {
501   event_interface_->OnSSLCertificateError(
502       std::move(ssl_error_callbacks), socket_url_, net_error, ssl_info, fatal);
503 }
504 
OnAuthRequired(const AuthChallengeInfo & auth_info,scoped_refptr<HttpResponseHeaders> response_headers,const IPEndPoint & remote_endpoint,base::OnceCallback<void (const AuthCredentials *)> callback,absl::optional<AuthCredentials> * credentials)505 int WebSocketChannel::OnAuthRequired(
506     const AuthChallengeInfo& auth_info,
507     scoped_refptr<HttpResponseHeaders> response_headers,
508     const IPEndPoint& remote_endpoint,
509     base::OnceCallback<void(const AuthCredentials*)> callback,
510     absl::optional<AuthCredentials>* credentials) {
511   return event_interface_->OnAuthRequired(
512       auth_info, std::move(response_headers), remote_endpoint,
513       std::move(callback), credentials);
514 }
515 
OnStartOpeningHandshake(std::unique_ptr<WebSocketHandshakeRequestInfo> request)516 void WebSocketChannel::OnStartOpeningHandshake(
517     std::unique_ptr<WebSocketHandshakeRequestInfo> request) {
518   event_interface_->OnStartOpeningHandshake(std::move(request));
519 }
520 
WriteFrames()521 ChannelState WebSocketChannel::WriteFrames() {
522   int result = OK;
523   do {
524     // This use of base::Unretained is safe because this object owns the
525     // WebSocketStream and destroying it cancels all callbacks.
526     result = stream_->WriteFrames(
527         data_being_sent_->frames(),
528         base::BindOnce(base::IgnoreResult(&WebSocketChannel::OnWriteDone),
529                        base::Unretained(this), false));
530     if (result != ERR_IO_PENDING) {
531       if (OnWriteDone(true, result) == CHANNEL_DELETED)
532         return CHANNEL_DELETED;
533       // OnWriteDone() returns CHANNEL_DELETED on error. Here |state_| is
534       // guaranteed to be the same as before OnWriteDone() call.
535     }
536   } while (result == OK && data_being_sent_);
537   return CHANNEL_ALIVE;
538 }
539 
OnWriteDone(bool synchronous,int result)540 ChannelState WebSocketChannel::OnWriteDone(bool synchronous, int result) {
541   DCHECK_NE(FRESHLY_CONSTRUCTED, state_);
542   DCHECK_NE(CONNECTING, state_);
543   DCHECK_NE(ERR_IO_PENDING, result);
544   DCHECK(data_being_sent_);
545   switch (result) {
546     case OK:
547       if (data_to_send_next_) {
548         data_being_sent_ = std::move(data_to_send_next_);
549         if (!synchronous)
550           return WriteFrames();
551       } else {
552         data_being_sent_.reset();
553         event_interface_->OnSendDataFrameDone();
554       }
555       return CHANNEL_ALIVE;
556 
557     // If a recoverable error condition existed, it would go here.
558 
559     default:
560       DCHECK_LT(result, 0)
561           << "WriteFrames() should only return OK or ERR_ codes";
562 
563       stream_->Close();
564       SetState(CLOSED);
565       DoDropChannel(false, kWebSocketErrorAbnormalClosure, "");
566       return CHANNEL_DELETED;
567   }
568 }
569 
ReadFrames()570 ChannelState WebSocketChannel::ReadFrames() {
571   DCHECK(stream_);
572   DCHECK(state_ == CONNECTED || state_ == SEND_CLOSED || state_ == CLOSE_WAIT);
573   DCHECK(read_frames_.empty());
574   if (is_reading_) {
575     return CHANNEL_ALIVE;
576   }
577 
578   if (!InClosingState() && has_received_close_frame_) {
579     DCHECK(!event_interface_->HasPendingDataFrames());
580     // We've been waiting for the client to consume the frames before
581     // responding to the closing handshake initiated by the server.
582     if (RespondToClosingHandshake() == CHANNEL_DELETED) {
583       return CHANNEL_DELETED;
584     }
585   }
586 
587   // TODO(crbug.com/999235): Remove this CHECK.
588   CHECK(event_interface_);
589   while (!event_interface_->HasPendingDataFrames()) {
590     DCHECK(stream_);
591     // This use of base::Unretained is safe because this object owns the
592     // WebSocketStream, and any pending reads will be cancelled when it is
593     // destroyed.
594     const int result = stream_->ReadFrames(
595         &read_frames_,
596         base::BindOnce(base::IgnoreResult(&WebSocketChannel::OnReadDone),
597                        base::Unretained(this), false));
598     if (result == ERR_IO_PENDING) {
599       is_reading_ = true;
600       return CHANNEL_ALIVE;
601     }
602     if (OnReadDone(true, result) == CHANNEL_DELETED) {
603       return CHANNEL_DELETED;
604     }
605     DCHECK_NE(CLOSED, state_);
606     // TODO(crbug.com/999235): Remove this CHECK.
607     CHECK(event_interface_);
608   }
609   return CHANNEL_ALIVE;
610 }
611 
OnReadDone(bool synchronous,int result)612 ChannelState WebSocketChannel::OnReadDone(bool synchronous, int result) {
613   DVLOG(3) << "WebSocketChannel::OnReadDone synchronous?" << synchronous
614            << ", result=" << result
615            << ", read_frames_.size=" << read_frames_.size();
616   DCHECK_NE(FRESHLY_CONSTRUCTED, state_);
617   DCHECK_NE(CONNECTING, state_);
618   DCHECK_NE(ERR_IO_PENDING, result);
619   switch (result) {
620     case OK:
621       // ReadFrames() must use ERR_CONNECTION_CLOSED for a closed connection
622       // with no data read, not an empty response.
623       DCHECK(!read_frames_.empty())
624           << "ReadFrames() returned OK, but nothing was read.";
625       for (auto& read_frame : read_frames_) {
626         if (HandleFrame(std::move(read_frame)) == CHANNEL_DELETED)
627           return CHANNEL_DELETED;
628       }
629       read_frames_.clear();
630       DCHECK_NE(CLOSED, state_);
631       if (!synchronous) {
632         is_reading_ = false;
633         if (!event_interface_->HasPendingDataFrames()) {
634           return ReadFrames();
635         }
636       }
637       return CHANNEL_ALIVE;
638 
639     case ERR_WS_PROTOCOL_ERROR:
640       // This could be kWebSocketErrorProtocolError (specifically, non-minimal
641       // encoding of payload length) or kWebSocketErrorMessageTooBig, or an
642       // extension-specific error.
643       FailChannel("Invalid frame header", kWebSocketErrorProtocolError,
644                   "WebSocket Protocol Error");
645       return CHANNEL_DELETED;
646 
647     default:
648       DCHECK_LT(result, 0)
649           << "ReadFrames() should only return OK or ERR_ codes";
650 
651       stream_->Close();
652       SetState(CLOSED);
653 
654       uint16_t code = kWebSocketErrorAbnormalClosure;
655       std::string reason = "";
656       bool was_clean = false;
657       if (has_received_close_frame_) {
658         code = received_close_code_;
659         reason = received_close_reason_;
660         was_clean = (result == ERR_CONNECTION_CLOSED);
661       }
662 
663       DoDropChannel(was_clean, code, reason);
664       return CHANNEL_DELETED;
665   }
666 }
667 
HandleFrame(std::unique_ptr<WebSocketFrame> frame)668 ChannelState WebSocketChannel::HandleFrame(
669     std::unique_ptr<WebSocketFrame> frame) {
670   if (frame->header.masked) {
671     // RFC6455 Section 5.1 "A client MUST close a connection if it detects a
672     // masked frame."
673     FailChannel(
674         "A server must not mask any frames that it sends to the "
675         "client.",
676         kWebSocketErrorProtocolError, "Masked frame from server");
677     return CHANNEL_DELETED;
678   }
679   const WebSocketFrameHeader::OpCode opcode = frame->header.opcode;
680   DCHECK(!WebSocketFrameHeader::IsKnownControlOpCode(opcode) ||
681          frame->header.final);
682   if (frame->header.reserved1 || frame->header.reserved2 ||
683       frame->header.reserved3) {
684     FailChannel(
685         base::StringPrintf("One or more reserved bits are on: reserved1 = %d, "
686                            "reserved2 = %d, reserved3 = %d",
687                            static_cast<int>(frame->header.reserved1),
688                            static_cast<int>(frame->header.reserved2),
689                            static_cast<int>(frame->header.reserved3)),
690         kWebSocketErrorProtocolError, "Invalid reserved bit");
691     return CHANNEL_DELETED;
692   }
693 
694   // Respond to the frame appropriately to its type.
695   return HandleFrameByState(
696       opcode, frame->header.final,
697       base::make_span(frame->payload, base::checked_cast<size_t>(
698                                           frame->header.payload_length)));
699 }
700 
HandleFrameByState(const WebSocketFrameHeader::OpCode opcode,bool final,base::span<const char> payload)701 ChannelState WebSocketChannel::HandleFrameByState(
702     const WebSocketFrameHeader::OpCode opcode,
703     bool final,
704     base::span<const char> payload) {
705   DCHECK_NE(RECV_CLOSED, state_)
706       << "HandleFrame() does not support being called re-entrantly from within "
707          "SendClose()";
708   DCHECK_NE(CLOSED, state_);
709   if (state_ == CLOSE_WAIT) {
710     std::string frame_name;
711     GetFrameTypeForOpcode(opcode, &frame_name);
712 
713     // FailChannel() won't send another Close frame.
714     FailChannel(frame_name + " received after close",
715                 kWebSocketErrorProtocolError, "");
716     return CHANNEL_DELETED;
717   }
718   switch (opcode) {
719     case WebSocketFrameHeader::kOpCodeText:  // fall-thru
720     case WebSocketFrameHeader::kOpCodeBinary:
721     case WebSocketFrameHeader::kOpCodeContinuation:
722       return HandleDataFrame(opcode, final, std::move(payload));
723 
724     case WebSocketFrameHeader::kOpCodePing:
725       DVLOG(1) << "Got Ping of size " << payload.size();
726       if (state_ == CONNECTED) {
727         auto buffer = base::MakeRefCounted<IOBufferWithSize>(payload.size());
728         base::ranges::copy(payload, buffer->data());
729         return SendFrameInternal(true, WebSocketFrameHeader::kOpCodePong,
730                                  std::move(buffer), payload.size());
731       }
732       DVLOG(3) << "Ignored ping in state " << state_;
733       return CHANNEL_ALIVE;
734 
735     case WebSocketFrameHeader::kOpCodePong:
736       DVLOG(1) << "Got Pong of size " << payload.size();
737       // There is no need to do anything with pong messages.
738       return CHANNEL_ALIVE;
739 
740     case WebSocketFrameHeader::kOpCodeClose: {
741       uint16_t code = kWebSocketNormalClosure;
742       std::string reason;
743       std::string message;
744       if (!ParseClose(payload, &code, &reason, &message)) {
745         FailChannel(message, code, reason);
746         return CHANNEL_DELETED;
747       }
748       // TODO(ricea): Find a way to safely log the message from the close
749       // message (escape control codes and so on).
750       return HandleCloseFrame(code, reason);
751     }
752 
753     default:
754       FailChannel(base::StringPrintf("Unrecognized frame opcode: %d", opcode),
755                   kWebSocketErrorProtocolError, "Unknown opcode");
756       return CHANNEL_DELETED;
757   }
758 }
759 
HandleDataFrame(WebSocketFrameHeader::OpCode opcode,bool final,base::span<const char> payload)760 ChannelState WebSocketChannel::HandleDataFrame(
761     WebSocketFrameHeader::OpCode opcode,
762     bool final,
763     base::span<const char> payload) {
764   DVLOG(3) << "WebSocketChannel::HandleDataFrame opcode=" << opcode
765            << ", final?" << final << ", data=" << (void*)payload.data()
766            << ", size=" << payload.size();
767   if (state_ != CONNECTED) {
768     DVLOG(3) << "Ignored data packet received in state " << state_;
769     return CHANNEL_ALIVE;
770   }
771   if (has_received_close_frame_) {
772     DVLOG(3) << "Ignored data packet as we've received a close frame.";
773     return CHANNEL_ALIVE;
774   }
775   DCHECK(opcode == WebSocketFrameHeader::kOpCodeContinuation ||
776          opcode == WebSocketFrameHeader::kOpCodeText ||
777          opcode == WebSocketFrameHeader::kOpCodeBinary);
778   const bool got_continuation =
779       (opcode == WebSocketFrameHeader::kOpCodeContinuation);
780   if (got_continuation != expecting_to_handle_continuation_) {
781     const std::string console_log = got_continuation
782         ? "Received unexpected continuation frame."
783         : "Received start of new message but previous message is unfinished.";
784     const std::string reason = got_continuation
785         ? "Unexpected continuation"
786         : "Previous data frame unfinished";
787     FailChannel(console_log, kWebSocketErrorProtocolError, reason);
788     return CHANNEL_DELETED;
789   }
790   expecting_to_handle_continuation_ = !final;
791   WebSocketFrameHeader::OpCode opcode_to_send = opcode;
792   if (!initial_frame_forwarded_ &&
793       opcode == WebSocketFrameHeader::kOpCodeContinuation) {
794     opcode_to_send = receiving_text_message_
795                          ? WebSocketFrameHeader::kOpCodeText
796                          : WebSocketFrameHeader::kOpCodeBinary;
797   }
798   if (opcode == WebSocketFrameHeader::kOpCodeText ||
799       (opcode == WebSocketFrameHeader::kOpCodeContinuation &&
800        receiving_text_message_)) {
801     // This call is not redundant when size == 0 because it tells us what
802     // the current state is.
803     StreamingUtf8Validator::State state =
804         incoming_utf8_validator_.AddBytes(base::as_byte_span(payload));
805     if (state == StreamingUtf8Validator::INVALID ||
806         (state == StreamingUtf8Validator::VALID_MIDPOINT && final)) {
807       FailChannel("Could not decode a text frame as UTF-8.",
808                   kWebSocketErrorProtocolError, "Invalid UTF-8 in text frame");
809       return CHANNEL_DELETED;
810     }
811     receiving_text_message_ = !final;
812     DCHECK(!final || state == StreamingUtf8Validator::VALID_ENDPOINT);
813   }
814   if (payload.size() == 0U && !final)
815     return CHANNEL_ALIVE;
816 
817   initial_frame_forwarded_ = !final;
818   // Sends the received frame to the renderer process.
819   event_interface_->OnDataFrame(final, opcode_to_send, payload);
820   return CHANNEL_ALIVE;
821 }
822 
HandleCloseFrame(uint16_t code,const std::string & reason)823 ChannelState WebSocketChannel::HandleCloseFrame(uint16_t code,
824                                                 const std::string& reason) {
825   DVLOG(1) << "Got Close with code " << code;
826   switch (state_) {
827     case CONNECTED:
828       has_received_close_frame_ = true;
829       received_close_code_ = code;
830       received_close_reason_ = reason;
831       if (event_interface_->HasPendingDataFrames()) {
832         // We have some data to be sent to the renderer before sending this
833         // frame.
834         return CHANNEL_ALIVE;
835       }
836       return RespondToClosingHandshake();
837 
838     case SEND_CLOSED:
839       SetState(CLOSE_WAIT);
840       DCHECK(close_timer_.IsRunning());
841       close_timer_.Stop();
842       // This use of base::Unretained() is safe because we stop the timer
843       // in the destructor.
844       close_timer_.Start(FROM_HERE, underlying_connection_close_timeout_,
845                          base::BindOnce(&WebSocketChannel::CloseTimeout,
846                                         base::Unretained(this)));
847 
848       // From RFC6455 section 7.1.5: "Each endpoint
849       // will see the status code sent by the other end as _The WebSocket
850       // Connection Close Code_."
851       has_received_close_frame_ = true;
852       received_close_code_ = code;
853       received_close_reason_ = reason;
854       break;
855 
856     default:
857       LOG(DFATAL) << "Got Close in unexpected state " << state_;
858       break;
859   }
860   return CHANNEL_ALIVE;
861 }
862 
RespondToClosingHandshake()863 ChannelState WebSocketChannel::RespondToClosingHandshake() {
864   DCHECK(has_received_close_frame_);
865   DCHECK_EQ(CONNECTED, state_);
866   SetState(RECV_CLOSED);
867   if (SendClose(received_close_code_, received_close_reason_) ==
868       CHANNEL_DELETED)
869     return CHANNEL_DELETED;
870   DCHECK_EQ(RECV_CLOSED, state_);
871 
872   SetState(CLOSE_WAIT);
873   DCHECK(!close_timer_.IsRunning());
874   // This use of base::Unretained() is safe because we stop the timer
875   // in the destructor.
876   close_timer_.Start(
877       FROM_HERE, underlying_connection_close_timeout_,
878       base::BindOnce(&WebSocketChannel::CloseTimeout, base::Unretained(this)));
879 
880   event_interface_->OnClosingHandshake();
881   return CHANNEL_ALIVE;
882 }
883 
SendFrameInternal(bool fin,WebSocketFrameHeader::OpCode op_code,scoped_refptr<IOBuffer> buffer,uint64_t buffer_size)884 ChannelState WebSocketChannel::SendFrameInternal(
885     bool fin,
886     WebSocketFrameHeader::OpCode op_code,
887     scoped_refptr<IOBuffer> buffer,
888     uint64_t buffer_size) {
889   DCHECK(state_ == CONNECTED || state_ == RECV_CLOSED);
890   DCHECK(stream_);
891 
892   auto frame = std::make_unique<WebSocketFrame>(op_code);
893   WebSocketFrameHeader& header = frame->header;
894   header.final = fin;
895   header.masked = true;
896   header.payload_length = buffer_size;
897   frame->payload = buffer->data();
898 
899   if (data_being_sent_) {
900     // Either the link to the WebSocket server is saturated, or several messages
901     // are being sent in a batch.
902     if (!data_to_send_next_)
903       data_to_send_next_ = std::make_unique<SendBuffer>();
904     data_to_send_next_->AddFrame(std::move(frame), std::move(buffer));
905     return CHANNEL_ALIVE;
906   }
907 
908   data_being_sent_ = std::make_unique<SendBuffer>();
909   data_being_sent_->AddFrame(std::move(frame), std::move(buffer));
910   return WriteFrames();
911 }
912 
FailChannel(const std::string & message,uint16_t code,const std::string & reason)913 void WebSocketChannel::FailChannel(const std::string& message,
914                                    uint16_t code,
915                                    const std::string& reason) {
916   DCHECK_NE(FRESHLY_CONSTRUCTED, state_);
917   DCHECK_NE(CONNECTING, state_);
918   DCHECK_NE(CLOSED, state_);
919 
920   stream_->GetNetLogWithSource().AddEvent(
921       net::NetLogEventType::WEBSOCKET_INVALID_FRAME,
922       [&] { return NetLogFailParam(code, reason, message); });
923 
924   if (state_ == CONNECTED) {
925     if (SendClose(code, reason) == CHANNEL_DELETED)
926       return;
927   }
928 
929   // Careful study of RFC6455 section 7.1.7 and 7.1.1 indicates the browser
930   // should close the connection itself without waiting for the closing
931   // handshake.
932   stream_->Close();
933   SetState(CLOSED);
934   event_interface_->OnFailChannel(message, ERR_FAILED, absl::nullopt);
935 }
936 
SendClose(uint16_t code,const std::string & reason)937 ChannelState WebSocketChannel::SendClose(uint16_t code,
938                                          const std::string& reason) {
939   DCHECK(state_ == CONNECTED || state_ == RECV_CLOSED);
940   DCHECK_LE(reason.size(), kMaximumCloseReasonLength);
941   scoped_refptr<IOBuffer> body;
942   uint64_t size = 0;
943   if (code == kWebSocketErrorNoStatusReceived) {
944     // Special case: translate kWebSocketErrorNoStatusReceived into a Close
945     // frame with no payload.
946     DCHECK(reason.empty());
947     body = base::MakeRefCounted<IOBufferWithSize>();
948   } else {
949     const size_t payload_length = kWebSocketCloseCodeLength + reason.length();
950     body = base::MakeRefCounted<IOBufferWithSize>(payload_length);
951     size = payload_length;
952     base::WriteBigEndian(body->data(), code);
953     static_assert(sizeof(code) == kWebSocketCloseCodeLength,
954                   "they should both be two");
955     base::ranges::copy(reason, body->data() + kWebSocketCloseCodeLength);
956   }
957 
958   return SendFrameInternal(true, WebSocketFrameHeader::kOpCodeClose,
959                            std::move(body), size);
960 }
961 
ParseClose(base::span<const char> payload,uint16_t * code,std::string * reason,std::string * message)962 bool WebSocketChannel::ParseClose(base::span<const char> payload,
963                                   uint16_t* code,
964                                   std::string* reason,
965                                   std::string* message) {
966   const uint64_t size = static_cast<uint64_t>(payload.size());
967   reason->clear();
968   if (size < kWebSocketCloseCodeLength) {
969     if (size == 0U) {
970       *code = kWebSocketErrorNoStatusReceived;
971       return true;
972     }
973 
974     DVLOG(1) << "Close frame with payload size " << size << " received "
975              << "(the first byte is " << std::hex
976              << static_cast<int>(payload[0]) << ")";
977     *code = kWebSocketErrorProtocolError;
978     *message =
979         "Received a broken close frame containing an invalid size body.";
980     return false;
981   }
982 
983   const char* data = payload.data();
984   uint16_t unchecked_code = 0;
985   base::ReadBigEndian(reinterpret_cast<const uint8_t*>(data), &unchecked_code);
986   static_assert(sizeof(unchecked_code) == kWebSocketCloseCodeLength,
987                 "they should both be two bytes");
988 
989   switch (unchecked_code) {
990     case kWebSocketErrorNoStatusReceived:
991     case kWebSocketErrorAbnormalClosure:
992     case kWebSocketErrorTlsHandshake:
993       *code = kWebSocketErrorProtocolError;
994       *message =
995           "Received a broken close frame containing a reserved status code.";
996       return false;
997 
998     default:
999       *code = unchecked_code;
1000       break;
1001   }
1002 
1003   std::string text(data + kWebSocketCloseCodeLength, data + size);
1004   if (StreamingUtf8Validator::Validate(text)) {
1005     reason->swap(text);
1006     return true;
1007   }
1008 
1009   *code = kWebSocketErrorProtocolError;
1010   *reason = "Invalid UTF-8 in Close frame";
1011   *message = "Received a broken close frame containing invalid UTF-8.";
1012   return false;
1013 }
1014 
DoDropChannel(bool was_clean,uint16_t code,const std::string & reason)1015 void WebSocketChannel::DoDropChannel(bool was_clean,
1016                                      uint16_t code,
1017                                      const std::string& reason) {
1018   event_interface_->OnDropChannel(was_clean, code, reason);
1019 }
1020 
CloseTimeout()1021 void WebSocketChannel::CloseTimeout() {
1022   stream_->GetNetLogWithSource().AddEvent(
1023       net::NetLogEventType::WEBSOCKET_CLOSE_TIMEOUT);
1024   stream_->Close();
1025   SetState(CLOSED);
1026   if (has_received_close_frame_) {
1027     DoDropChannel(true, received_close_code_, received_close_reason_);
1028   } else {
1029     DoDropChannel(false, kWebSocketErrorAbnormalClosure, "");
1030   }
1031   // |this| has been deleted.
1032 }
1033 
1034 }  // namespace net
1035