• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_channel.h"
6 
7 #include <limits.h>  // for INT_MAX
8 #include <stddef.h>
9 
10 #include <utility>
11 #include <vector>
12 
13 #include "base/big_endian.h"
14 #include "base/containers/circular_deque.h"
15 #include "base/functional/bind.h"
16 #include "base/location.h"
17 #include "base/memory/raw_ptr.h"
18 #include "base/memory/weak_ptr.h"
19 #include "base/numerics/safe_conversions.h"
20 #include "base/ranges/algorithm.h"
21 #include "base/strings/string_piece.h"
22 #include "base/strings/stringprintf.h"
23 #include "base/task/single_thread_task_runner.h"
24 #include "base/time/time.h"
25 #include "net/base/auth.h"
26 #include "net/base/io_buffer.h"
27 #include "net/base/ip_endpoint.h"
28 #include "net/http/http_request_headers.h"
29 #include "net/http/http_response_headers.h"
30 #include "net/http/http_util.h"
31 #include "net/log/net_log_with_source.h"
32 #include "net/traffic_annotation/network_traffic_annotation.h"
33 #include "net/websockets/websocket_errors.h"
34 #include "net/websockets/websocket_event_interface.h"
35 #include "net/websockets/websocket_frame.h"
36 #include "net/websockets/websocket_handshake_request_info.h"
37 #include "net/websockets/websocket_handshake_response_info.h"
38 #include "net/websockets/websocket_stream.h"
39 #include "url/origin.h"
40 
41 namespace net {
42 
43 namespace {
44 
45 using base::StreamingUtf8Validator;
46 
// Length in bytes of the status code that starts a Close frame payload.
constexpr size_t kWebSocketCloseCodeLength = 2;
// Timeout for waiting for the server to acknowledge a closing handshake.
constexpr int kClosingHandshakeTimeoutSeconds = 60;
// We wait for the server to close the underlying connection, as recommended in
// https://tools.ietf.org/html/rfc6455#section-7.1.1
// We don't wait for 2MSL because some server implementations don't follow the
// recommendation and instead wait for the client to close the underlying
// connection; for those servers a 2MSL wait would needlessly delay the
// CloseEvent. We prefer avoiding that delay over strictly following the spec
// recommendation.
constexpr int kUnderlyingConnectionCloseTimeoutSeconds = 2;

using ChannelState = WebSocketChannel::ChannelState;

// Maximum close reason length = max control frame payload -
//                               status code length
//                             = 125 - 2
constexpr size_t kMaximumCloseReasonLength = 125 - kWebSocketCloseCodeLength;
65 
66 // Check a close status code for strict compliance with RFC6455. This is only
67 // used for close codes received from a renderer that we are intending to send
68 // out over the network. See ParseClose() for the restrictions on incoming close
69 // codes. The |code| parameter is type int for convenience of implementation;
70 // the real type is uint16_t. Code 1005 is treated specially; it cannot be set
71 // explicitly by Javascript but the renderer uses it to indicate we should send
72 // a Close frame with no payload.
IsStrictlyValidCloseStatusCode(int code)73 bool IsStrictlyValidCloseStatusCode(int code) {
74   static const int kInvalidRanges[] = {
75       // [BAD, OK)
76       0,    1000,   // 1000 is the first valid code
77       1006, 1007,   // 1006 MUST NOT be set.
78       1014, 3000,   // 1014 unassigned; 1015 up to 2999 are reserved.
79       5000, 65536,  // Codes above 5000 are invalid.
80   };
81   const int* const kInvalidRangesEnd =
82       kInvalidRanges + std::size(kInvalidRanges);
83 
84   DCHECK_GE(code, 0);
85   DCHECK_LT(code, 65536);
86   const int* upper = std::upper_bound(kInvalidRanges, kInvalidRangesEnd, code);
87   DCHECK_NE(kInvalidRangesEnd, upper);
88   DCHECK_GT(upper, kInvalidRanges);
89   DCHECK_GT(*upper, code);
90   DCHECK_LE(*(upper - 1), code);
91   return ((upper - kInvalidRanges) % 2) == 0;
92 }
93 
94 // Sets |name| to the name of the frame type for the given |opcode|. Note that
95 // for all of Text, Binary and Continuation opcode, this method returns
96 // "Data frame".
GetFrameTypeForOpcode(WebSocketFrameHeader::OpCode opcode,std::string * name)97 void GetFrameTypeForOpcode(WebSocketFrameHeader::OpCode opcode,
98                            std::string* name) {
99   switch (opcode) {
100     case WebSocketFrameHeader::kOpCodeText:    // fall-thru
101     case WebSocketFrameHeader::kOpCodeBinary:  // fall-thru
102     case WebSocketFrameHeader::kOpCodeContinuation:
103       *name = "Data frame";
104       break;
105 
106     case WebSocketFrameHeader::kOpCodePing:
107       *name = "Ping";
108       break;
109 
110     case WebSocketFrameHeader::kOpCodePong:
111       *name = "Pong";
112       break;
113 
114     case WebSocketFrameHeader::kOpCodeClose:
115       *name = "Close";
116       break;
117 
118     default:
119       *name = "Unknown frame type";
120       break;
121   }
122 
123   return;
124 }
125 
NetLogFailParam(uint16_t code,base::StringPiece reason,base::StringPiece message)126 base::Value::Dict NetLogFailParam(uint16_t code,
127                                   base::StringPiece reason,
128                                   base::StringPiece message) {
129   base::Value::Dict dict;
130   dict.Set("code", code);
131   dict.Set("reason", reason);
132   dict.Set("internal_reason", message);
133   return dict;
134 }
135 
// A WrappedIOBuffer that views the storage of another IOBuffer, starting
// |offset| bytes into it. It keeps a reference to that IOBuffer, so the
// viewed storage stays alive for as long as this wrapper does.
class DependentIOBuffer : public WrappedIOBuffer {
 public:
  // Base classes are constructed before members. |buffer| is therefore still
  // valid when data() is called for the base, and is only moved into
  // |buffer_| afterwards.
  // NOTE(review): |offset| is not bounds-checked here; callers presumably
  // guarantee it lies within |buffer| -- confirm.
  DependentIOBuffer(scoped_refptr<IOBuffer> buffer, size_t offset)
      : WrappedIOBuffer(buffer->data() + offset), buffer_(std::move(buffer)) {}

 private:
  ~DependentIOBuffer() override = default;
  // Owning reference that keeps the viewed storage alive.
  scoped_refptr<IOBuffer> buffer_;
};
145 
146 }  // namespace
147 
148 // A class to encapsulate a set of frames and information about the size of
149 // those frames.
150 class WebSocketChannel::SendBuffer {
151  public:
152   SendBuffer() = default;
153 
154   // Add a WebSocketFrame to the buffer and increase total_bytes_.
155   void AddFrame(std::unique_ptr<WebSocketFrame> chunk,
156                 scoped_refptr<IOBuffer> buffer);
157 
158   // Return a pointer to the frames_ for write purposes.
frames()159   std::vector<std::unique_ptr<WebSocketFrame>>* frames() { return &frames_; }
160 
161  private:
162   // The frames_ that will be sent in the next call to WriteFrames().
163   std::vector<std::unique_ptr<WebSocketFrame>> frames_;
164   // References of each WebSocketFrame.data;
165   std::vector<scoped_refptr<IOBuffer>> buffers_;
166 
167   // The total size of the payload data in |frames_|. This will be used to
168   // measure the throughput of the link.
169   // TODO(ricea): Measure the throughput of the link.
170   uint64_t total_bytes_ = 0;
171 };
172 
AddFrame(std::unique_ptr<WebSocketFrame> frame,scoped_refptr<IOBuffer> buffer)173 void WebSocketChannel::SendBuffer::AddFrame(
174     std::unique_ptr<WebSocketFrame> frame,
175     scoped_refptr<IOBuffer> buffer) {
176   total_bytes_ += frame->header.payload_length;
177   frames_.push_back(std::move(frame));
178   buffers_.push_back(std::move(buffer));
179 }
180 
// Implementation of WebSocketStream::ConnectDelegate that forwards each
// callback, unchanged, to the WebSocketChannel that created it. None of the
// methods touch members after forwarding, because some of the channel's
// handlers may delete this object.
class WebSocketChannel::ConnectDelegate
    : public WebSocketStream::ConnectDelegate {
 public:
  // |creator| must outlive this object; see the comment on |creator_|.
  explicit ConnectDelegate(WebSocketChannel* creator) : creator_(creator) {}

  ConnectDelegate(const ConnectDelegate&) = delete;
  ConnectDelegate& operator=(const ConnectDelegate&) = delete;

  // Forwards to WebSocketChannel::OnCreateURLRequest().
  void OnCreateRequest(URLRequest* request) override {
    creator_->OnCreateURLRequest(request);
  }

  // Forwards to WebSocketChannel::OnConnectSuccess(), passing ownership of
  // the connected stream and the handshake response.
  void OnSuccess(
      std::unique_ptr<WebSocketStream> stream,
      std::unique_ptr<WebSocketHandshakeResponseInfo> response) override {
    creator_->OnConnectSuccess(std::move(stream), std::move(response));
    // |this| may have been deleted.
  }

  // Forwards to WebSocketChannel::OnConnectFailure().
  void OnFailure(const std::string& message,
                 int net_error,
                 absl::optional<int> response_code) override {
    creator_->OnConnectFailure(message, net_error, response_code);
    // |this| has been deleted.
  }

  // Forwards to WebSocketChannel::OnStartOpeningHandshake().
  void OnStartOpeningHandshake(
      std::unique_ptr<WebSocketHandshakeRequestInfo> request) override {
    creator_->OnStartOpeningHandshake(std::move(request));
  }

  // Forwards to WebSocketChannel::OnSSLCertificateError().
  void OnSSLCertificateError(
      std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks>
          ssl_error_callbacks,
      int net_error,
      const SSLInfo& ssl_info,
      bool fatal) override {
    creator_->OnSSLCertificateError(std::move(ssl_error_callbacks), net_error,
                                    ssl_info, fatal);
  }

  // Forwards to WebSocketChannel::OnAuthRequired() and returns its result.
  int OnAuthRequired(const AuthChallengeInfo& auth_info,
                     scoped_refptr<HttpResponseHeaders> headers,
                     const IPEndPoint& remote_endpoint,
                     base::OnceCallback<void(const AuthCredentials*)> callback,
                     absl::optional<AuthCredentials>* credentials) override {
    return creator_->OnAuthRequired(auth_info, std::move(headers),
                                    remote_endpoint, std::move(callback),
                                    credentials);
  }

 private:
  // A pointer to the WebSocketChannel that created this object. There is no
  // danger of this pointer being stale, because deleting the WebSocketChannel
  // cancels the connect process, deleting this object and preventing its
  // callbacks from being called.
  const raw_ptr<WebSocketChannel> creator_;
};
241 
WebSocketChannel(std::unique_ptr<WebSocketEventInterface> event_interface,URLRequestContext * url_request_context)242 WebSocketChannel::WebSocketChannel(
243     std::unique_ptr<WebSocketEventInterface> event_interface,
244     URLRequestContext* url_request_context)
245     : event_interface_(std::move(event_interface)),
246       url_request_context_(url_request_context),
247       closing_handshake_timeout_(
248           base::Seconds(kClosingHandshakeTimeoutSeconds)),
249       underlying_connection_close_timeout_(
250           base::Seconds(kUnderlyingConnectionCloseTimeoutSeconds)) {}
251 
// Destroys the channel. The two steps below are deliberately ordered.
WebSocketChannel::~WebSocketChannel() {
  // The stream may hold a pointer to read_frames_, and so it needs to be
  // destroyed first.
  stream_.reset();
  // The timer may have a callback pointing back to us, so stop it just in case
  // someone decides to run the event loop from their destructor. This is also
  // what makes the base::Unretained(this) in StartClosingHandshake() safe.
  close_timer_.Stop();
}
260 
SendAddChannelRequest(const GURL & socket_url,const std::vector<std::string> & requested_subprotocols,const url::Origin & origin,const SiteForCookies & site_for_cookies,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,NetworkTrafficAnnotationTag traffic_annotation)261 void WebSocketChannel::SendAddChannelRequest(
262     const GURL& socket_url,
263     const std::vector<std::string>& requested_subprotocols,
264     const url::Origin& origin,
265     const SiteForCookies& site_for_cookies,
266     const IsolationInfo& isolation_info,
267     const HttpRequestHeaders& additional_headers,
268     NetworkTrafficAnnotationTag traffic_annotation) {
269   SendAddChannelRequestWithSuppliedCallback(
270       socket_url, requested_subprotocols, origin, site_for_cookies,
271       isolation_info, additional_headers, traffic_annotation,
272       base::BindOnce(&WebSocketStream::CreateAndConnectStream));
273 }
274 
// Moves the channel to |new_state|. Setting the state it is already in is
// treated as a bug (DCHECK); no other transition validation happens here.
void WebSocketChannel::SetState(State new_state) {
  DCHECK_NE(state_, new_state);

  state_ = new_state;
}
280 
InClosingState() const281 bool WebSocketChannel::InClosingState() const {
282   // The state RECV_CLOSED is not supported here, because it is only used in one
283   // code path and should not leak into the code in general.
284   DCHECK_NE(RECV_CLOSED, state_)
285       << "InClosingState called with state_ == RECV_CLOSED";
286   return state_ == SEND_CLOSED || state_ == CLOSE_WAIT || state_ == CLOSED;
287 }
288 
SendFrame(bool fin,WebSocketFrameHeader::OpCode op_code,scoped_refptr<IOBuffer> buffer,size_t buffer_size)289 WebSocketChannel::ChannelState WebSocketChannel::SendFrame(
290     bool fin,
291     WebSocketFrameHeader::OpCode op_code,
292     scoped_refptr<IOBuffer> buffer,
293     size_t buffer_size) {
294   DCHECK_LE(buffer_size, static_cast<size_t>(INT_MAX));
295   DCHECK(stream_) << "Got SendFrame without a connection established; fin="
296                   << fin << " op_code=" << op_code
297                   << " buffer_size=" << buffer_size;
298 
299   if (InClosingState()) {
300     DVLOG(1) << "SendFrame called in state " << state_
301              << ". This may be a bug, or a harmless race.";
302     return CHANNEL_ALIVE;
303   }
304 
305   DCHECK_EQ(state_, CONNECTED);
306 
307   DCHECK(WebSocketFrameHeader::IsKnownDataOpCode(op_code))
308       << "Got SendFrame with bogus op_code " << op_code << " fin=" << fin
309       << " buffer_size=" << buffer_size;
310 
311   if (op_code == WebSocketFrameHeader::kOpCodeText ||
312       (op_code == WebSocketFrameHeader::kOpCodeContinuation &&
313        sending_text_message_)) {
314     StreamingUtf8Validator::State state =
315         outgoing_utf8_validator_.AddBytes(buffer->data(), buffer_size);
316     if (state == StreamingUtf8Validator::INVALID ||
317         (state == StreamingUtf8Validator::VALID_MIDPOINT && fin)) {
318       // TODO(ricea): Kill renderer.
319       FailChannel("Browser sent a text frame containing invalid UTF-8",
320                   kWebSocketErrorGoingAway, "");
321       return CHANNEL_DELETED;
322       // |this| has been deleted.
323     }
324     sending_text_message_ = !fin;
325     DCHECK(!fin || state == StreamingUtf8Validator::VALID_ENDPOINT);
326   }
327 
328   return SendFrameInternal(fin, op_code, std::move(buffer), buffer_size);
329   // |this| may have been deleted.
330 }
331 
StartClosingHandshake(uint16_t code,const std::string & reason)332 ChannelState WebSocketChannel::StartClosingHandshake(
333     uint16_t code,
334     const std::string& reason) {
335   if (InClosingState()) {
336     // When the associated renderer process is killed while the channel is in
337     // CLOSING state we reach here.
338     DVLOG(1) << "StartClosingHandshake called in state " << state_
339              << ". This may be a bug, or a harmless race.";
340     return CHANNEL_ALIVE;
341   }
342   if (has_received_close_frame_) {
343     // We reach here if the client wants to start a closing handshake while
344     // the browser is waiting for the client to consume incoming data frames
345     // before responding to a closing handshake initiated by the server.
346     // As the client doesn't want the data frames any more, we can respond to
347     // the closing handshake initiated by the server.
348     return RespondToClosingHandshake();
349   }
350   if (state_ == CONNECTING) {
351     // Abort the in-progress handshake and drop the connection immediately.
352     stream_request_.reset();
353     SetState(CLOSED);
354     DoDropChannel(false, kWebSocketErrorAbnormalClosure, "");
355     return CHANNEL_DELETED;
356   }
357   DCHECK_EQ(state_, CONNECTED);
358 
359   DCHECK(!close_timer_.IsRunning());
360   // This use of base::Unretained() is safe because we stop the timer in the
361   // destructor.
362   close_timer_.Start(
363       FROM_HERE, closing_handshake_timeout_,
364       base::BindOnce(&WebSocketChannel::CloseTimeout, base::Unretained(this)));
365 
366   // Javascript actually only permits 1000 and 3000-4999, but the implementation
367   // itself may produce different codes. The length of |reason| is also checked
368   // by Javascript.
369   if (!IsStrictlyValidCloseStatusCode(code) ||
370       reason.size() > kMaximumCloseReasonLength) {
371     // "InternalServerError" is actually used for errors from any endpoint, per
372     // errata 3227 to RFC6455. If the renderer is sending us an invalid code or
373     // reason it must be malfunctioning in some way, and based on that we
374     // interpret this as an internal error.
375     if (SendClose(kWebSocketErrorInternalServerError, "") == CHANNEL_DELETED)
376       return CHANNEL_DELETED;
377     DCHECK_EQ(CONNECTED, state_);
378     SetState(SEND_CLOSED);
379     return CHANNEL_ALIVE;
380   }
381   if (SendClose(code, StreamingUtf8Validator::Validate(reason)
382                           ? reason
383                           : std::string()) == CHANNEL_DELETED)
384     return CHANNEL_DELETED;
385   DCHECK_EQ(CONNECTED, state_);
386   SetState(SEND_CLOSED);
387   return CHANNEL_ALIVE;
388 }
389 
// Test-only variant of SendAddChannelRequest(). The caller supplies the
// stream-creation |callback| used in place of
// WebSocketStream::CreateAndConnectStream.
void WebSocketChannel::SendAddChannelRequestForTesting(
    const GURL& socket_url,
    const std::vector<std::string>& requested_subprotocols,
    const url::Origin& origin,
    const SiteForCookies& site_for_cookies,
    const IsolationInfo& isolation_info,
    const HttpRequestHeaders& additional_headers,
    NetworkTrafficAnnotationTag traffic_annotation,
    WebSocketStreamRequestCreationCallback callback) {
  SendAddChannelRequestWithSuppliedCallback(
      socket_url, requested_subprotocols, origin, site_for_cookies,
      isolation_info, additional_headers, traffic_annotation,
      std::move(callback));
}
404 
// Test-only: replaces the closing handshake timeout (by default
// kClosingHandshakeTimeoutSeconds) with |delay|.
void WebSocketChannel::SetClosingHandshakeTimeoutForTesting(
    base::TimeDelta delay) {
  closing_handshake_timeout_ = delay;
}
409 
// Test-only: replaces the underlying-connection close timeout (by default
// kUnderlyingConnectionCloseTimeoutSeconds) with |delay|.
void WebSocketChannel::SetUnderlyingConnectionCloseTimeoutForTesting(
    base::TimeDelta delay) {
  underlying_connection_close_timeout_ = delay;
}
414 
// Begins connecting to |socket_url| by running |callback| to create the
// stream request. A ConnectDelegate is passed in so that connection events
// come back to this channel. A non-ws/wss scheme fails the channel
// immediately, which deletes |this|.
// NOTE(review): the state only becomes CONNECTING after |callback| returns,
// while OnConnectSuccess()/OnConnectFailure() DCHECK CONNECTING. This
// presumably relies on |callback| never invoking the delegate synchronously --
// confirm.
void WebSocketChannel::SendAddChannelRequestWithSuppliedCallback(
    const GURL& socket_url,
    const std::vector<std::string>& requested_subprotocols,
    const url::Origin& origin,
    const SiteForCookies& site_for_cookies,
    const IsolationInfo& isolation_info,
    const HttpRequestHeaders& additional_headers,
    NetworkTrafficAnnotationTag traffic_annotation,
    WebSocketStreamRequestCreationCallback callback) {
  DCHECK_EQ(FRESHLY_CONSTRUCTED, state_);
  if (!socket_url.SchemeIsWSOrWSS()) {
    // TODO(ricea): Kill the renderer (this error should have been caught by
    // Javascript).
    event_interface_->OnFailChannel("Invalid scheme", ERR_FAILED,
                                    absl::nullopt);
    // |this| is deleted here.
    return;
  }
  socket_url_ = socket_url;
  auto connect_delegate = std::make_unique<ConnectDelegate>(this);
  stream_request_ = std::move(callback).Run(
      socket_url_, requested_subprotocols, origin, site_for_cookies,
      isolation_info, additional_headers, url_request_context_.get(),
      NetLogWithSource(), traffic_annotation, std::move(connect_delegate));
  SetState(CONNECTING);
}
441 
// Reached via ConnectDelegate::OnCreateRequest(); passes |request| on to the
// event interface.
void WebSocketChannel::OnCreateURLRequest(URLRequest* request) {
  event_interface_->OnCreateURLRequest(request);
}
445 
// Handles a successful opening handshake. Takes ownership of |stream|, moves
// to CONNECTED, discards the no-longer-needed stream request, and reports the
// negotiated subprotocol and extensions to the event interface.
void WebSocketChannel::OnConnectSuccess(
    std::unique_ptr<WebSocketStream> stream,
    std::unique_ptr<WebSocketHandshakeResponseInfo> response) {
  DCHECK(stream);
  DCHECK_EQ(CONNECTING, state_);

  stream_ = std::move(stream);

  SetState(CONNECTED);

  // |stream_request_| is not used once the connection has succeeded.
  stream_request_.reset();

  event_interface_->OnAddChannelResponse(
      std::move(response), stream_->GetSubProtocol(), stream_->GetExtensions());
  // |this| may have been deleted after OnAddChannelResponse, so nothing may
  // follow this call.
}
463 
// Handles a failed opening handshake. Moves to CLOSED, discards the stream
// request and reports the failure to the event interface, which deletes
// |this|.
void WebSocketChannel::OnConnectFailure(const std::string& message,
                                        int net_error,
                                        absl::optional<int> response_code) {
  DCHECK_EQ(CONNECTING, state_);

  // Copy the message before we delete its owner. |message| may be owned by
  // |stream_request_| (or something it owns), which is reset below.
  std::string message_copy = message;

  SetState(CLOSED);
  stream_request_.reset();

  event_interface_->OnFailChannel(message_copy, net_error, response_code);
  // |this| has been deleted.
}
478 
// Passes a certificate error raised during connection on to the event
// interface, adding the URL this channel is connecting to.
void WebSocketChannel::OnSSLCertificateError(
    std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks>
        ssl_error_callbacks,
    int net_error,
    const SSLInfo& ssl_info,
    bool fatal) {
  event_interface_->OnSSLCertificateError(
      std::move(ssl_error_callbacks), socket_url_, net_error, ssl_info, fatal);
}
488 
// Passes an authentication challenge on to the event interface and returns
// whatever result code it produces.
int WebSocketChannel::OnAuthRequired(
    const AuthChallengeInfo& auth_info,
    scoped_refptr<HttpResponseHeaders> response_headers,
    const IPEndPoint& remote_endpoint,
    base::OnceCallback<void(const AuthCredentials*)> callback,
    absl::optional<AuthCredentials>* credentials) {
  return event_interface_->OnAuthRequired(
      auth_info, std::move(response_headers), remote_endpoint,
      std::move(callback), credentials);
}
499 
// Passes the opening-handshake request info on to the event interface.
void WebSocketChannel::OnStartOpeningHandshake(
    std::unique_ptr<WebSocketHandshakeRequestInfo> request) {
  event_interface_->OnStartOpeningHandshake(std::move(request));
}
504 
WriteFrames()505 ChannelState WebSocketChannel::WriteFrames() {
506   int result = OK;
507   do {
508     // This use of base::Unretained is safe because this object owns the
509     // WebSocketStream and destroying it cancels all callbacks.
510     result = stream_->WriteFrames(
511         data_being_sent_->frames(),
512         base::BindOnce(base::IgnoreResult(&WebSocketChannel::OnWriteDone),
513                        base::Unretained(this), false));
514     if (result != ERR_IO_PENDING) {
515       if (OnWriteDone(true, result) == CHANNEL_DELETED)
516         return CHANNEL_DELETED;
517       // OnWriteDone() returns CHANNEL_DELETED on error. Here |state_| is
518       // guaranteed to be the same as before OnWriteDone() call.
519     }
520   } while (result == OK && data_being_sent_);
521   return CHANNEL_ALIVE;
522 }
523 
OnWriteDone(bool synchronous,int result)524 ChannelState WebSocketChannel::OnWriteDone(bool synchronous, int result) {
525   DCHECK_NE(FRESHLY_CONSTRUCTED, state_);
526   DCHECK_NE(CONNECTING, state_);
527   DCHECK_NE(ERR_IO_PENDING, result);
528   DCHECK(data_being_sent_);
529   switch (result) {
530     case OK:
531       if (data_to_send_next_) {
532         data_being_sent_ = std::move(data_to_send_next_);
533         if (!synchronous)
534           return WriteFrames();
535       } else {
536         data_being_sent_.reset();
537         event_interface_->OnSendDataFrameDone();
538       }
539       return CHANNEL_ALIVE;
540 
541     // If a recoverable error condition existed, it would go here.
542 
543     default:
544       DCHECK_LT(result, 0)
545           << "WriteFrames() should only return OK or ERR_ codes";
546 
547       stream_->Close();
548       SetState(CLOSED);
549       DoDropChannel(false, kWebSocketErrorAbnormalClosure, "");
550       return CHANNEL_DELETED;
551   }
552 }
553 
// Reads frames from the stream and dispatches them.
//
// Keeps reading while the event interface reports no pending data frames.
// Stops when a read goes asynchronous (recorded in |is_reading_|) or when the
// channel is deleted. Returns CHANNEL_DELETED if |this| was deleted.
ChannelState WebSocketChannel::ReadFrames() {
  DCHECK(stream_);
  DCHECK(state_ == CONNECTED || state_ == SEND_CLOSED || state_ == CLOSE_WAIT);
  DCHECK(read_frames_.empty());
  // A read is already outstanding. Its asynchronous completion in
  // OnReadDone() clears |is_reading_| and continues reading.
  if (is_reading_) {
    return CHANNEL_ALIVE;
  }

  if (!InClosingState() && has_received_close_frame_) {
    DCHECK(!event_interface_->HasPendingDataFrames());
    // We've been waiting for the client to consume the frames before
    // responding to the closing handshake initiated by the server.
    if (RespondToClosingHandshake() == CHANNEL_DELETED) {
      return CHANNEL_DELETED;
    }
  }

  // TODO(crbug.com/999235): Remove this CHECK.
  CHECK(event_interface_);
  while (!event_interface_->HasPendingDataFrames()) {
    DCHECK(stream_);
    // This use of base::Unretained is safe because this object owns the
    // WebSocketStream, and any pending reads will be cancelled when it is
    // destroyed.
    const int result = stream_->ReadFrames(
        &read_frames_,
        base::BindOnce(base::IgnoreResult(&WebSocketChannel::OnReadDone),
                       base::Unretained(this), false));
    if (result == ERR_IO_PENDING) {
      is_reading_ = true;
      return CHANNEL_ALIVE;
    }
    // Synchronous completion: handle the frames here. OnReadDone(true, ...)
    // does not re-enter ReadFrames().
    if (OnReadDone(true, result) == CHANNEL_DELETED) {
      return CHANNEL_DELETED;
    }
    DCHECK_NE(CLOSED, state_);
    // TODO(crbug.com/999235): Remove this CHECK.
    CHECK(event_interface_);
  }
  return CHANNEL_ALIVE;
}
595 
OnReadDone(bool synchronous,int result)596 ChannelState WebSocketChannel::OnReadDone(bool synchronous, int result) {
597   DVLOG(3) << "WebSocketChannel::OnReadDone synchronous?" << synchronous
598            << ", result=" << result
599            << ", read_frames_.size=" << read_frames_.size();
600   DCHECK_NE(FRESHLY_CONSTRUCTED, state_);
601   DCHECK_NE(CONNECTING, state_);
602   DCHECK_NE(ERR_IO_PENDING, result);
603   switch (result) {
604     case OK:
605       // ReadFrames() must use ERR_CONNECTION_CLOSED for a closed connection
606       // with no data read, not an empty response.
607       DCHECK(!read_frames_.empty())
608           << "ReadFrames() returned OK, but nothing was read.";
609       for (auto& read_frame : read_frames_) {
610         if (HandleFrame(std::move(read_frame)) == CHANNEL_DELETED)
611           return CHANNEL_DELETED;
612       }
613       read_frames_.clear();
614       DCHECK_NE(CLOSED, state_);
615       if (!synchronous) {
616         is_reading_ = false;
617         if (!event_interface_->HasPendingDataFrames()) {
618           return ReadFrames();
619         }
620       }
621       return CHANNEL_ALIVE;
622 
623     case ERR_WS_PROTOCOL_ERROR:
624       // This could be kWebSocketErrorProtocolError (specifically, non-minimal
625       // encoding of payload length) or kWebSocketErrorMessageTooBig, or an
626       // extension-specific error.
627       FailChannel("Invalid frame header", kWebSocketErrorProtocolError,
628                   "WebSocket Protocol Error");
629       return CHANNEL_DELETED;
630 
631     default:
632       DCHECK_LT(result, 0)
633           << "ReadFrames() should only return OK or ERR_ codes";
634 
635       stream_->Close();
636       SetState(CLOSED);
637 
638       uint16_t code = kWebSocketErrorAbnormalClosure;
639       std::string reason = "";
640       bool was_clean = false;
641       if (has_received_close_frame_) {
642         code = received_close_code_;
643         reason = received_close_reason_;
644         was_clean = (result == ERR_CONNECTION_CLOSED);
645       }
646 
647       DoDropChannel(was_clean, code, reason);
648       return CHANNEL_DELETED;
649   }
650 }
651 
HandleFrame(std::unique_ptr<WebSocketFrame> frame)652 ChannelState WebSocketChannel::HandleFrame(
653     std::unique_ptr<WebSocketFrame> frame) {
654   if (frame->header.masked) {
655     // RFC6455 Section 5.1 "A client MUST close a connection if it detects a
656     // masked frame."
657     FailChannel(
658         "A server must not mask any frames that it sends to the "
659         "client.",
660         kWebSocketErrorProtocolError, "Masked frame from server");
661     return CHANNEL_DELETED;
662   }
663   const WebSocketFrameHeader::OpCode opcode = frame->header.opcode;
664   DCHECK(!WebSocketFrameHeader::IsKnownControlOpCode(opcode) ||
665          frame->header.final);
666   if (frame->header.reserved1 || frame->header.reserved2 ||
667       frame->header.reserved3) {
668     FailChannel(
669         base::StringPrintf("One or more reserved bits are on: reserved1 = %d, "
670                            "reserved2 = %d, reserved3 = %d",
671                            static_cast<int>(frame->header.reserved1),
672                            static_cast<int>(frame->header.reserved2),
673                            static_cast<int>(frame->header.reserved3)),
674         kWebSocketErrorProtocolError, "Invalid reserved bit");
675     return CHANNEL_DELETED;
676   }
677 
678   // Respond to the frame appropriately to its type.
679   return HandleFrameByState(
680       opcode, frame->header.final,
681       base::make_span(frame->payload, base::checked_cast<size_t>(
682                                           frame->header.payload_length)));
683 }
684 
HandleFrameByState(const WebSocketFrameHeader::OpCode opcode,bool final,base::span<const char> payload)685 ChannelState WebSocketChannel::HandleFrameByState(
686     const WebSocketFrameHeader::OpCode opcode,
687     bool final,
688     base::span<const char> payload) {
689   DCHECK_NE(RECV_CLOSED, state_)
690       << "HandleFrame() does not support being called re-entrantly from within "
691          "SendClose()";
692   DCHECK_NE(CLOSED, state_);
693   if (state_ == CLOSE_WAIT) {
694     std::string frame_name;
695     GetFrameTypeForOpcode(opcode, &frame_name);
696 
697     // FailChannel() won't send another Close frame.
698     FailChannel(frame_name + " received after close",
699                 kWebSocketErrorProtocolError, "");
700     return CHANNEL_DELETED;
701   }
702   switch (opcode) {
703     case WebSocketFrameHeader::kOpCodeText:  // fall-thru
704     case WebSocketFrameHeader::kOpCodeBinary:
705     case WebSocketFrameHeader::kOpCodeContinuation:
706       return HandleDataFrame(opcode, final, std::move(payload));
707 
708     case WebSocketFrameHeader::kOpCodePing:
709       DVLOG(1) << "Got Ping of size " << payload.size();
710       if (state_ == CONNECTED) {
711         auto buffer = base::MakeRefCounted<IOBuffer>(payload.size());
712         memcpy(buffer->data(), payload.data(), payload.size());
713         return SendFrameInternal(true, WebSocketFrameHeader::kOpCodePong,
714                                  std::move(buffer), payload.size());
715       }
716       DVLOG(3) << "Ignored ping in state " << state_;
717       return CHANNEL_ALIVE;
718 
719     case WebSocketFrameHeader::kOpCodePong:
720       DVLOG(1) << "Got Pong of size " << payload.size();
721       // There is no need to do anything with pong messages.
722       return CHANNEL_ALIVE;
723 
724     case WebSocketFrameHeader::kOpCodeClose: {
725       uint16_t code = kWebSocketNormalClosure;
726       std::string reason;
727       std::string message;
728       if (!ParseClose(payload, &code, &reason, &message)) {
729         FailChannel(message, code, reason);
730         return CHANNEL_DELETED;
731       }
732       // TODO(ricea): Find a way to safely log the message from the close
733       // message (escape control codes and so on).
734       return HandleCloseFrame(code, reason);
735     }
736 
737     default:
738       FailChannel(base::StringPrintf("Unrecognized frame opcode: %d", opcode),
739                   kWebSocketErrorProtocolError, "Unknown opcode");
740       return CHANNEL_DELETED;
741   }
742 }
743 
HandleDataFrame(WebSocketFrameHeader::OpCode opcode,bool final,base::span<const char> payload)744 ChannelState WebSocketChannel::HandleDataFrame(
745     WebSocketFrameHeader::OpCode opcode,
746     bool final,
747     base::span<const char> payload) {
748   DVLOG(3) << "WebSocketChannel::HandleDataFrame opcode=" << opcode
749            << ", final?" << final << ", data=" << (void*)payload.data()
750            << ", size=" << payload.size();
751   if (state_ != CONNECTED) {
752     DVLOG(3) << "Ignored data packet received in state " << state_;
753     return CHANNEL_ALIVE;
754   }
755   if (has_received_close_frame_) {
756     DVLOG(3) << "Ignored data packet as we've received a close frame.";
757     return CHANNEL_ALIVE;
758   }
759   DCHECK(opcode == WebSocketFrameHeader::kOpCodeContinuation ||
760          opcode == WebSocketFrameHeader::kOpCodeText ||
761          opcode == WebSocketFrameHeader::kOpCodeBinary);
762   const bool got_continuation =
763       (opcode == WebSocketFrameHeader::kOpCodeContinuation);
764   if (got_continuation != expecting_to_handle_continuation_) {
765     const std::string console_log = got_continuation
766         ? "Received unexpected continuation frame."
767         : "Received start of new message but previous message is unfinished.";
768     const std::string reason = got_continuation
769         ? "Unexpected continuation"
770         : "Previous data frame unfinished";
771     FailChannel(console_log, kWebSocketErrorProtocolError, reason);
772     return CHANNEL_DELETED;
773   }
774   expecting_to_handle_continuation_ = !final;
775   WebSocketFrameHeader::OpCode opcode_to_send = opcode;
776   if (!initial_frame_forwarded_ &&
777       opcode == WebSocketFrameHeader::kOpCodeContinuation) {
778     opcode_to_send = receiving_text_message_
779                          ? WebSocketFrameHeader::kOpCodeText
780                          : WebSocketFrameHeader::kOpCodeBinary;
781   }
782   if (opcode == WebSocketFrameHeader::kOpCodeText ||
783       (opcode == WebSocketFrameHeader::kOpCodeContinuation &&
784        receiving_text_message_)) {
785     // This call is not redundant when size == 0 because it tells us what
786     // the current state is.
787     StreamingUtf8Validator::State state = incoming_utf8_validator_.AddBytes(
788         payload.data(), static_cast<size_t>(payload.size()));
789     if (state == StreamingUtf8Validator::INVALID ||
790         (state == StreamingUtf8Validator::VALID_MIDPOINT && final)) {
791       FailChannel("Could not decode a text frame as UTF-8.",
792                   kWebSocketErrorProtocolError, "Invalid UTF-8 in text frame");
793       return CHANNEL_DELETED;
794     }
795     receiving_text_message_ = !final;
796     DCHECK(!final || state == StreamingUtf8Validator::VALID_ENDPOINT);
797   }
798   if (payload.size() == 0U && !final)
799     return CHANNEL_ALIVE;
800 
801   initial_frame_forwarded_ = !final;
802   // Sends the received frame to the renderer process.
803   event_interface_->OnDataFrame(final, opcode_to_send, payload);
804   return CHANNEL_ALIVE;
805 }
806 
HandleCloseFrame(uint16_t code,const std::string & reason)807 ChannelState WebSocketChannel::HandleCloseFrame(uint16_t code,
808                                                 const std::string& reason) {
809   DVLOG(1) << "Got Close with code " << code;
810   switch (state_) {
811     case CONNECTED:
812       has_received_close_frame_ = true;
813       received_close_code_ = code;
814       received_close_reason_ = reason;
815       if (event_interface_->HasPendingDataFrames()) {
816         // We have some data to be sent to the renderer before sending this
817         // frame.
818         return CHANNEL_ALIVE;
819       }
820       return RespondToClosingHandshake();
821 
822     case SEND_CLOSED:
823       SetState(CLOSE_WAIT);
824       DCHECK(close_timer_.IsRunning());
825       close_timer_.Stop();
826       // This use of base::Unretained() is safe because we stop the timer
827       // in the destructor.
828       close_timer_.Start(FROM_HERE, underlying_connection_close_timeout_,
829                          base::BindOnce(&WebSocketChannel::CloseTimeout,
830                                         base::Unretained(this)));
831 
832       // From RFC6455 section 7.1.5: "Each endpoint
833       // will see the status code sent by the other end as _The WebSocket
834       // Connection Close Code_."
835       has_received_close_frame_ = true;
836       received_close_code_ = code;
837       received_close_reason_ = reason;
838       break;
839 
840     default:
841       LOG(DFATAL) << "Got Close in unexpected state " << state_;
842       break;
843   }
844   return CHANNEL_ALIVE;
845 }
846 
RespondToClosingHandshake()847 ChannelState WebSocketChannel::RespondToClosingHandshake() {
848   DCHECK(has_received_close_frame_);
849   DCHECK_EQ(CONNECTED, state_);
850   SetState(RECV_CLOSED);
851   if (SendClose(received_close_code_, received_close_reason_) ==
852       CHANNEL_DELETED)
853     return CHANNEL_DELETED;
854   DCHECK_EQ(RECV_CLOSED, state_);
855 
856   SetState(CLOSE_WAIT);
857   DCHECK(!close_timer_.IsRunning());
858   // This use of base::Unretained() is safe because we stop the timer
859   // in the destructor.
860   close_timer_.Start(
861       FROM_HERE, underlying_connection_close_timeout_,
862       base::BindOnce(&WebSocketChannel::CloseTimeout, base::Unretained(this)));
863 
864   event_interface_->OnClosingHandshake();
865   return CHANNEL_ALIVE;
866 }
867 
SendFrameInternal(bool fin,WebSocketFrameHeader::OpCode op_code,scoped_refptr<IOBuffer> buffer,uint64_t buffer_size)868 ChannelState WebSocketChannel::SendFrameInternal(
869     bool fin,
870     WebSocketFrameHeader::OpCode op_code,
871     scoped_refptr<IOBuffer> buffer,
872     uint64_t buffer_size) {
873   DCHECK(state_ == CONNECTED || state_ == RECV_CLOSED);
874   DCHECK(stream_);
875 
876   auto frame = std::make_unique<WebSocketFrame>(op_code);
877   WebSocketFrameHeader& header = frame->header;
878   header.final = fin;
879   header.masked = true;
880   header.payload_length = buffer_size;
881   frame->payload = buffer->data();
882 
883   if (data_being_sent_) {
884     // Either the link to the WebSocket server is saturated, or several messages
885     // are being sent in a batch.
886     if (!data_to_send_next_)
887       data_to_send_next_ = std::make_unique<SendBuffer>();
888     data_to_send_next_->AddFrame(std::move(frame), std::move(buffer));
889     return CHANNEL_ALIVE;
890   }
891 
892   data_being_sent_ = std::make_unique<SendBuffer>();
893   data_being_sent_->AddFrame(std::move(frame), std::move(buffer));
894   return WriteFrames();
895 }
896 
FailChannel(const std::string & message,uint16_t code,const std::string & reason)897 void WebSocketChannel::FailChannel(const std::string& message,
898                                    uint16_t code,
899                                    const std::string& reason) {
900   DCHECK_NE(FRESHLY_CONSTRUCTED, state_);
901   DCHECK_NE(CONNECTING, state_);
902   DCHECK_NE(CLOSED, state_);
903 
904   stream_->GetNetLogWithSource().AddEvent(
905       net::NetLogEventType::WEBSOCKET_INVALID_FRAME,
906       [&] { return NetLogFailParam(code, reason, message); });
907 
908   if (state_ == CONNECTED) {
909     if (SendClose(code, reason) == CHANNEL_DELETED)
910       return;
911   }
912 
913   // Careful study of RFC6455 section 7.1.7 and 7.1.1 indicates the browser
914   // should close the connection itself without waiting for the closing
915   // handshake.
916   stream_->Close();
917   SetState(CLOSED);
918   event_interface_->OnFailChannel(message, ERR_FAILED, absl::nullopt);
919 }
920 
SendClose(uint16_t code,const std::string & reason)921 ChannelState WebSocketChannel::SendClose(uint16_t code,
922                                          const std::string& reason) {
923   DCHECK(state_ == CONNECTED || state_ == RECV_CLOSED);
924   DCHECK_LE(reason.size(), kMaximumCloseReasonLength);
925   scoped_refptr<IOBuffer> body;
926   uint64_t size = 0;
927   if (code == kWebSocketErrorNoStatusReceived) {
928     // Special case: translate kWebSocketErrorNoStatusReceived into a Close
929     // frame with no payload.
930     DCHECK(reason.empty());
931     body = base::MakeRefCounted<IOBuffer>(0);
932   } else {
933     const size_t payload_length = kWebSocketCloseCodeLength + reason.length();
934     body = base::MakeRefCounted<IOBuffer>(payload_length);
935     size = payload_length;
936     base::WriteBigEndian(body->data(), code);
937     static_assert(sizeof(code) == kWebSocketCloseCodeLength,
938                   "they should both be two");
939     base::ranges::copy(reason, body->data() + kWebSocketCloseCodeLength);
940   }
941 
942   return SendFrameInternal(true, WebSocketFrameHeader::kOpCodeClose,
943                            std::move(body), size);
944 }
945 
ParseClose(base::span<const char> payload,uint16_t * code,std::string * reason,std::string * message)946 bool WebSocketChannel::ParseClose(base::span<const char> payload,
947                                   uint16_t* code,
948                                   std::string* reason,
949                                   std::string* message) {
950   const uint64_t size = static_cast<uint64_t>(payload.size());
951   reason->clear();
952   if (size < kWebSocketCloseCodeLength) {
953     if (size == 0U) {
954       *code = kWebSocketErrorNoStatusReceived;
955       return true;
956     }
957 
958     DVLOG(1) << "Close frame with payload size " << size << " received "
959              << "(the first byte is " << std::hex
960              << static_cast<int>(payload[0]) << ")";
961     *code = kWebSocketErrorProtocolError;
962     *message =
963         "Received a broken close frame containing an invalid size body.";
964     return false;
965   }
966 
967   const char* data = payload.data();
968   uint16_t unchecked_code = 0;
969   base::ReadBigEndian(reinterpret_cast<const uint8_t*>(data), &unchecked_code);
970   static_assert(sizeof(unchecked_code) == kWebSocketCloseCodeLength,
971                 "they should both be two bytes");
972 
973   switch (unchecked_code) {
974     case kWebSocketErrorNoStatusReceived:
975     case kWebSocketErrorAbnormalClosure:
976     case kWebSocketErrorTlsHandshake:
977       *code = kWebSocketErrorProtocolError;
978       *message =
979           "Received a broken close frame containing a reserved status code.";
980       return false;
981 
982     default:
983       *code = unchecked_code;
984       break;
985   }
986 
987   std::string text(data + kWebSocketCloseCodeLength, data + size);
988   if (StreamingUtf8Validator::Validate(text)) {
989     reason->swap(text);
990     return true;
991   }
992 
993   *code = kWebSocketErrorProtocolError;
994   *reason = "Invalid UTF-8 in Close frame";
995   *message = "Received a broken close frame containing invalid UTF-8.";
996   return false;
997 }
998 
DoDropChannel(bool was_clean,uint16_t code,const std::string & reason)999 void WebSocketChannel::DoDropChannel(bool was_clean,
1000                                      uint16_t code,
1001                                      const std::string& reason) {
1002   event_interface_->OnDropChannel(was_clean, code, reason);
1003 }
1004 
CloseTimeout()1005 void WebSocketChannel::CloseTimeout() {
1006   stream_->GetNetLogWithSource().AddEvent(
1007       net::NetLogEventType::WEBSOCKET_CLOSE_TIMEOUT);
1008   stream_->Close();
1009   SetState(CLOSED);
1010   if (has_received_close_frame_) {
1011     DoDropChannel(true, received_close_code_, received_close_reason_);
1012   } else {
1013     DoDropChannel(false, kWebSocketErrorAbnormalClosure, "");
1014   }
1015   // |this| has been deleted.
1016 }
1017 
1018 }  // namespace net
1019