1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/websockets/websocket_channel.h"
6
7 #include <limits.h> // for INT_MAX
8 #include <stddef.h>
9
10 #include <utility>
11 #include <vector>
12
13 #include "base/big_endian.h"
14 #include "base/containers/circular_deque.h"
15 #include "base/functional/bind.h"
16 #include "base/location.h"
17 #include "base/memory/raw_ptr.h"
18 #include "base/memory/weak_ptr.h"
19 #include "base/numerics/safe_conversions.h"
20 #include "base/ranges/algorithm.h"
21 #include "base/strings/string_piece.h"
22 #include "base/strings/stringprintf.h"
23 #include "base/task/single_thread_task_runner.h"
24 #include "base/time/time.h"
25 #include "net/base/auth.h"
26 #include "net/base/io_buffer.h"
27 #include "net/base/ip_endpoint.h"
28 #include "net/http/http_request_headers.h"
29 #include "net/http/http_response_headers.h"
30 #include "net/http/http_util.h"
31 #include "net/log/net_log_with_source.h"
32 #include "net/traffic_annotation/network_traffic_annotation.h"
33 #include "net/websockets/websocket_errors.h"
34 #include "net/websockets/websocket_event_interface.h"
35 #include "net/websockets/websocket_frame.h"
36 #include "net/websockets/websocket_handshake_request_info.h"
37 #include "net/websockets/websocket_handshake_response_info.h"
38 #include "net/websockets/websocket_stream.h"
39 #include "url/origin.h"
40
41 namespace net {
42
43 namespace {
44
45 using base::StreamingUtf8Validator;
46
47 constexpr size_t kWebSocketCloseCodeLength = 2;
48 // Timeout for waiting for the server to acknowledge a closing handshake.
49 constexpr int kClosingHandshakeTimeoutSeconds = 60;
50 // We wait for the server to close the underlying connection as recommended in
51 // https://tools.ietf.org/html/rfc6455#section-7.1.1
52 // We don't use 2MSL since there're server implementations that don't follow
53 // the recommendation and wait for the client to close the underlying
54 // connection. It leads to unnecessarily long time before CloseEvent
55 // invocation. We want to avoid this rather than strictly following the spec
56 // recommendation.
57 constexpr int kUnderlyingConnectionCloseTimeoutSeconds = 2;
58
59 using ChannelState = WebSocketChannel::ChannelState;
60
61 // Maximum close reason length = max control frame payload -
62 // status code length
63 // = 125 - 2
64 constexpr size_t kMaximumCloseReasonLength = 125 - kWebSocketCloseCodeLength;
65
66 // Check a close status code for strict compliance with RFC6455. This is only
67 // used for close codes received from a renderer that we are intending to send
68 // out over the network. See ParseClose() for the restrictions on incoming close
69 // codes. The |code| parameter is type int for convenience of implementation;
70 // the real type is uint16_t. Code 1005 is treated specially; it cannot be set
71 // explicitly by Javascript but the renderer uses it to indicate we should send
72 // a Close frame with no payload.
IsStrictlyValidCloseStatusCode(int code)73 bool IsStrictlyValidCloseStatusCode(int code) {
74 static const int kInvalidRanges[] = {
75 // [BAD, OK)
76 0, 1000, // 1000 is the first valid code
77 1006, 1007, // 1006 MUST NOT be set.
78 1014, 3000, // 1014 unassigned; 1015 up to 2999 are reserved.
79 5000, 65536, // Codes above 5000 are invalid.
80 };
81 const int* const kInvalidRangesEnd =
82 kInvalidRanges + std::size(kInvalidRanges);
83
84 DCHECK_GE(code, 0);
85 DCHECK_LT(code, 65536);
86 const int* upper = std::upper_bound(kInvalidRanges, kInvalidRangesEnd, code);
87 DCHECK_NE(kInvalidRangesEnd, upper);
88 DCHECK_GT(upper, kInvalidRanges);
89 DCHECK_GT(*upper, code);
90 DCHECK_LE(*(upper - 1), code);
91 return ((upper - kInvalidRanges) % 2) == 0;
92 }
93
94 // Sets |name| to the name of the frame type for the given |opcode|. Note that
95 // for all of Text, Binary and Continuation opcode, this method returns
96 // "Data frame".
GetFrameTypeForOpcode(WebSocketFrameHeader::OpCode opcode,std::string * name)97 void GetFrameTypeForOpcode(WebSocketFrameHeader::OpCode opcode,
98 std::string* name) {
99 switch (opcode) {
100 case WebSocketFrameHeader::kOpCodeText: // fall-thru
101 case WebSocketFrameHeader::kOpCodeBinary: // fall-thru
102 case WebSocketFrameHeader::kOpCodeContinuation:
103 *name = "Data frame";
104 break;
105
106 case WebSocketFrameHeader::kOpCodePing:
107 *name = "Ping";
108 break;
109
110 case WebSocketFrameHeader::kOpCodePong:
111 *name = "Pong";
112 break;
113
114 case WebSocketFrameHeader::kOpCodeClose:
115 *name = "Close";
116 break;
117
118 default:
119 *name = "Unknown frame type";
120 break;
121 }
122
123 return;
124 }
125
NetLogFailParam(uint16_t code,base::StringPiece reason,base::StringPiece message)126 base::Value::Dict NetLogFailParam(uint16_t code,
127 base::StringPiece reason,
128 base::StringPiece message) {
129 base::Value::Dict dict;
130 dict.Set("code", code);
131 dict.Set("reason", reason);
132 dict.Set("internal_reason", message);
133 return dict;
134 }
135
136 class DependentIOBuffer : public WrappedIOBuffer {
137 public:
DependentIOBuffer(scoped_refptr<IOBuffer> buffer,size_t offset)138 DependentIOBuffer(scoped_refptr<IOBuffer> buffer, size_t offset)
139 : WrappedIOBuffer(buffer->data() + offset), buffer_(std::move(buffer)) {}
140
141 private:
142 ~DependentIOBuffer() override = default;
143 scoped_refptr<IOBuffer> buffer_;
144 };
145
146 } // namespace
147
148 // A class to encapsulate a set of frames and information about the size of
149 // those frames.
150 class WebSocketChannel::SendBuffer {
151 public:
152 SendBuffer() = default;
153
154 // Add a WebSocketFrame to the buffer and increase total_bytes_.
155 void AddFrame(std::unique_ptr<WebSocketFrame> chunk,
156 scoped_refptr<IOBuffer> buffer);
157
158 // Return a pointer to the frames_ for write purposes.
frames()159 std::vector<std::unique_ptr<WebSocketFrame>>* frames() { return &frames_; }
160
161 private:
162 // The frames_ that will be sent in the next call to WriteFrames().
163 std::vector<std::unique_ptr<WebSocketFrame>> frames_;
164 // References of each WebSocketFrame.data;
165 std::vector<scoped_refptr<IOBuffer>> buffers_;
166
167 // The total size of the payload data in |frames_|. This will be used to
168 // measure the throughput of the link.
169 // TODO(ricea): Measure the throughput of the link.
170 uint64_t total_bytes_ = 0;
171 };
172
AddFrame(std::unique_ptr<WebSocketFrame> frame,scoped_refptr<IOBuffer> buffer)173 void WebSocketChannel::SendBuffer::AddFrame(
174 std::unique_ptr<WebSocketFrame> frame,
175 scoped_refptr<IOBuffer> buffer) {
176 total_bytes_ += frame->header.payload_length;
177 frames_.push_back(std::move(frame));
178 buffers_.push_back(std::move(buffer));
179 }
180
181 // Implementation of WebSocketStream::ConnectDelegate that simply forwards the
182 // calls on to the WebSocketChannel that created it.
183 class WebSocketChannel::ConnectDelegate
184 : public WebSocketStream::ConnectDelegate {
185 public:
ConnectDelegate(WebSocketChannel * creator)186 explicit ConnectDelegate(WebSocketChannel* creator) : creator_(creator) {}
187
188 ConnectDelegate(const ConnectDelegate&) = delete;
189 ConnectDelegate& operator=(const ConnectDelegate&) = delete;
190
OnCreateRequest(URLRequest * request)191 void OnCreateRequest(URLRequest* request) override {
192 creator_->OnCreateURLRequest(request);
193 }
194
OnSuccess(std::unique_ptr<WebSocketStream> stream,std::unique_ptr<WebSocketHandshakeResponseInfo> response)195 void OnSuccess(
196 std::unique_ptr<WebSocketStream> stream,
197 std::unique_ptr<WebSocketHandshakeResponseInfo> response) override {
198 creator_->OnConnectSuccess(std::move(stream), std::move(response));
199 // |this| may have been deleted.
200 }
201
OnFailure(const std::string & message,int net_error,absl::optional<int> response_code)202 void OnFailure(const std::string& message,
203 int net_error,
204 absl::optional<int> response_code) override {
205 creator_->OnConnectFailure(message, net_error, response_code);
206 // |this| has been deleted.
207 }
208
OnStartOpeningHandshake(std::unique_ptr<WebSocketHandshakeRequestInfo> request)209 void OnStartOpeningHandshake(
210 std::unique_ptr<WebSocketHandshakeRequestInfo> request) override {
211 creator_->OnStartOpeningHandshake(std::move(request));
212 }
213
OnSSLCertificateError(std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks,int net_error,const SSLInfo & ssl_info,bool fatal)214 void OnSSLCertificateError(
215 std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks>
216 ssl_error_callbacks,
217 int net_error,
218 const SSLInfo& ssl_info,
219 bool fatal) override {
220 creator_->OnSSLCertificateError(std::move(ssl_error_callbacks), net_error,
221 ssl_info, fatal);
222 }
223
OnAuthRequired(const AuthChallengeInfo & auth_info,scoped_refptr<HttpResponseHeaders> headers,const IPEndPoint & remote_endpoint,base::OnceCallback<void (const AuthCredentials *)> callback,absl::optional<AuthCredentials> * credentials)224 int OnAuthRequired(const AuthChallengeInfo& auth_info,
225 scoped_refptr<HttpResponseHeaders> headers,
226 const IPEndPoint& remote_endpoint,
227 base::OnceCallback<void(const AuthCredentials*)> callback,
228 absl::optional<AuthCredentials>* credentials) override {
229 return creator_->OnAuthRequired(auth_info, std::move(headers),
230 remote_endpoint, std::move(callback),
231 credentials);
232 }
233
234 private:
235 // A pointer to the WebSocketChannel that created this object. There is no
236 // danger of this pointer being stale, because deleting the WebSocketChannel
237 // cancels the connect process, deleting this object and preventing its
238 // callbacks from being called.
239 const raw_ptr<WebSocketChannel> creator_;
240 };
241
WebSocketChannel(std::unique_ptr<WebSocketEventInterface> event_interface,URLRequestContext * url_request_context)242 WebSocketChannel::WebSocketChannel(
243 std::unique_ptr<WebSocketEventInterface> event_interface,
244 URLRequestContext* url_request_context)
245 : event_interface_(std::move(event_interface)),
246 url_request_context_(url_request_context),
247 closing_handshake_timeout_(
248 base::Seconds(kClosingHandshakeTimeoutSeconds)),
249 underlying_connection_close_timeout_(
250 base::Seconds(kUnderlyingConnectionCloseTimeoutSeconds)) {}
251
~WebSocketChannel()252 WebSocketChannel::~WebSocketChannel() {
253 // The stream may hold a pointer to read_frames_, and so it needs to be
254 // destroyed first.
255 stream_.reset();
256 // The timer may have a callback pointing back to us, so stop it just in case
257 // someone decides to run the event loop from their destructor.
258 close_timer_.Stop();
259 }
260
SendAddChannelRequest(const GURL & socket_url,const std::vector<std::string> & requested_subprotocols,const url::Origin & origin,const SiteForCookies & site_for_cookies,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,NetworkTrafficAnnotationTag traffic_annotation)261 void WebSocketChannel::SendAddChannelRequest(
262 const GURL& socket_url,
263 const std::vector<std::string>& requested_subprotocols,
264 const url::Origin& origin,
265 const SiteForCookies& site_for_cookies,
266 const IsolationInfo& isolation_info,
267 const HttpRequestHeaders& additional_headers,
268 NetworkTrafficAnnotationTag traffic_annotation) {
269 SendAddChannelRequestWithSuppliedCallback(
270 socket_url, requested_subprotocols, origin, site_for_cookies,
271 isolation_info, additional_headers, traffic_annotation,
272 base::BindOnce(&WebSocketStream::CreateAndConnectStream));
273 }
274
SetState(State new_state)275 void WebSocketChannel::SetState(State new_state) {
276 DCHECK_NE(state_, new_state);
277
278 state_ = new_state;
279 }
280
InClosingState() const281 bool WebSocketChannel::InClosingState() const {
282 // The state RECV_CLOSED is not supported here, because it is only used in one
283 // code path and should not leak into the code in general.
284 DCHECK_NE(RECV_CLOSED, state_)
285 << "InClosingState called with state_ == RECV_CLOSED";
286 return state_ == SEND_CLOSED || state_ == CLOSE_WAIT || state_ == CLOSED;
287 }
288
SendFrame(bool fin,WebSocketFrameHeader::OpCode op_code,scoped_refptr<IOBuffer> buffer,size_t buffer_size)289 WebSocketChannel::ChannelState WebSocketChannel::SendFrame(
290 bool fin,
291 WebSocketFrameHeader::OpCode op_code,
292 scoped_refptr<IOBuffer> buffer,
293 size_t buffer_size) {
294 DCHECK_LE(buffer_size, static_cast<size_t>(INT_MAX));
295 DCHECK(stream_) << "Got SendFrame without a connection established; fin="
296 << fin << " op_code=" << op_code
297 << " buffer_size=" << buffer_size;
298
299 if (InClosingState()) {
300 DVLOG(1) << "SendFrame called in state " << state_
301 << ". This may be a bug, or a harmless race.";
302 return CHANNEL_ALIVE;
303 }
304
305 DCHECK_EQ(state_, CONNECTED);
306
307 DCHECK(WebSocketFrameHeader::IsKnownDataOpCode(op_code))
308 << "Got SendFrame with bogus op_code " << op_code << " fin=" << fin
309 << " buffer_size=" << buffer_size;
310
311 if (op_code == WebSocketFrameHeader::kOpCodeText ||
312 (op_code == WebSocketFrameHeader::kOpCodeContinuation &&
313 sending_text_message_)) {
314 StreamingUtf8Validator::State state =
315 outgoing_utf8_validator_.AddBytes(buffer->data(), buffer_size);
316 if (state == StreamingUtf8Validator::INVALID ||
317 (state == StreamingUtf8Validator::VALID_MIDPOINT && fin)) {
318 // TODO(ricea): Kill renderer.
319 FailChannel("Browser sent a text frame containing invalid UTF-8",
320 kWebSocketErrorGoingAway, "");
321 return CHANNEL_DELETED;
322 // |this| has been deleted.
323 }
324 sending_text_message_ = !fin;
325 DCHECK(!fin || state == StreamingUtf8Validator::VALID_ENDPOINT);
326 }
327
328 return SendFrameInternal(fin, op_code, std::move(buffer), buffer_size);
329 // |this| may have been deleted.
330 }
331
StartClosingHandshake(uint16_t code,const std::string & reason)332 ChannelState WebSocketChannel::StartClosingHandshake(
333 uint16_t code,
334 const std::string& reason) {
335 if (InClosingState()) {
336 // When the associated renderer process is killed while the channel is in
337 // CLOSING state we reach here.
338 DVLOG(1) << "StartClosingHandshake called in state " << state_
339 << ". This may be a bug, or a harmless race.";
340 return CHANNEL_ALIVE;
341 }
342 if (has_received_close_frame_) {
343 // We reach here if the client wants to start a closing handshake while
344 // the browser is waiting for the client to consume incoming data frames
345 // before responding to a closing handshake initiated by the server.
346 // As the client doesn't want the data frames any more, we can respond to
347 // the closing handshake initiated by the server.
348 return RespondToClosingHandshake();
349 }
350 if (state_ == CONNECTING) {
351 // Abort the in-progress handshake and drop the connection immediately.
352 stream_request_.reset();
353 SetState(CLOSED);
354 DoDropChannel(false, kWebSocketErrorAbnormalClosure, "");
355 return CHANNEL_DELETED;
356 }
357 DCHECK_EQ(state_, CONNECTED);
358
359 DCHECK(!close_timer_.IsRunning());
360 // This use of base::Unretained() is safe because we stop the timer in the
361 // destructor.
362 close_timer_.Start(
363 FROM_HERE, closing_handshake_timeout_,
364 base::BindOnce(&WebSocketChannel::CloseTimeout, base::Unretained(this)));
365
366 // Javascript actually only permits 1000 and 3000-4999, but the implementation
367 // itself may produce different codes. The length of |reason| is also checked
368 // by Javascript.
369 if (!IsStrictlyValidCloseStatusCode(code) ||
370 reason.size() > kMaximumCloseReasonLength) {
371 // "InternalServerError" is actually used for errors from any endpoint, per
372 // errata 3227 to RFC6455. If the renderer is sending us an invalid code or
373 // reason it must be malfunctioning in some way, and based on that we
374 // interpret this as an internal error.
375 if (SendClose(kWebSocketErrorInternalServerError, "") == CHANNEL_DELETED)
376 return CHANNEL_DELETED;
377 DCHECK_EQ(CONNECTED, state_);
378 SetState(SEND_CLOSED);
379 return CHANNEL_ALIVE;
380 }
381 if (SendClose(code, StreamingUtf8Validator::Validate(reason)
382 ? reason
383 : std::string()) == CHANNEL_DELETED)
384 return CHANNEL_DELETED;
385 DCHECK_EQ(CONNECTED, state_);
386 SetState(SEND_CLOSED);
387 return CHANNEL_ALIVE;
388 }
389
SendAddChannelRequestForTesting(const GURL & socket_url,const std::vector<std::string> & requested_subprotocols,const url::Origin & origin,const SiteForCookies & site_for_cookies,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,NetworkTrafficAnnotationTag traffic_annotation,WebSocketStreamRequestCreationCallback callback)390 void WebSocketChannel::SendAddChannelRequestForTesting(
391 const GURL& socket_url,
392 const std::vector<std::string>& requested_subprotocols,
393 const url::Origin& origin,
394 const SiteForCookies& site_for_cookies,
395 const IsolationInfo& isolation_info,
396 const HttpRequestHeaders& additional_headers,
397 NetworkTrafficAnnotationTag traffic_annotation,
398 WebSocketStreamRequestCreationCallback callback) {
399 SendAddChannelRequestWithSuppliedCallback(
400 socket_url, requested_subprotocols, origin, site_for_cookies,
401 isolation_info, additional_headers, traffic_annotation,
402 std::move(callback));
403 }
404
SetClosingHandshakeTimeoutForTesting(base::TimeDelta delay)405 void WebSocketChannel::SetClosingHandshakeTimeoutForTesting(
406 base::TimeDelta delay) {
407 closing_handshake_timeout_ = delay;
408 }
409
SetUnderlyingConnectionCloseTimeoutForTesting(base::TimeDelta delay)410 void WebSocketChannel::SetUnderlyingConnectionCloseTimeoutForTesting(
411 base::TimeDelta delay) {
412 underlying_connection_close_timeout_ = delay;
413 }
414
SendAddChannelRequestWithSuppliedCallback(const GURL & socket_url,const std::vector<std::string> & requested_subprotocols,const url::Origin & origin,const SiteForCookies & site_for_cookies,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,NetworkTrafficAnnotationTag traffic_annotation,WebSocketStreamRequestCreationCallback callback)415 void WebSocketChannel::SendAddChannelRequestWithSuppliedCallback(
416 const GURL& socket_url,
417 const std::vector<std::string>& requested_subprotocols,
418 const url::Origin& origin,
419 const SiteForCookies& site_for_cookies,
420 const IsolationInfo& isolation_info,
421 const HttpRequestHeaders& additional_headers,
422 NetworkTrafficAnnotationTag traffic_annotation,
423 WebSocketStreamRequestCreationCallback callback) {
424 DCHECK_EQ(FRESHLY_CONSTRUCTED, state_);
425 if (!socket_url.SchemeIsWSOrWSS()) {
426 // TODO(ricea): Kill the renderer (this error should have been caught by
427 // Javascript).
428 event_interface_->OnFailChannel("Invalid scheme", ERR_FAILED,
429 absl::nullopt);
430 // |this| is deleted here.
431 return;
432 }
433 socket_url_ = socket_url;
434 auto connect_delegate = std::make_unique<ConnectDelegate>(this);
435 stream_request_ = std::move(callback).Run(
436 socket_url_, requested_subprotocols, origin, site_for_cookies,
437 isolation_info, additional_headers, url_request_context_.get(),
438 NetLogWithSource(), traffic_annotation, std::move(connect_delegate));
439 SetState(CONNECTING);
440 }
441
OnCreateURLRequest(URLRequest * request)442 void WebSocketChannel::OnCreateURLRequest(URLRequest* request) {
443 event_interface_->OnCreateURLRequest(request);
444 }
445
OnConnectSuccess(std::unique_ptr<WebSocketStream> stream,std::unique_ptr<WebSocketHandshakeResponseInfo> response)446 void WebSocketChannel::OnConnectSuccess(
447 std::unique_ptr<WebSocketStream> stream,
448 std::unique_ptr<WebSocketHandshakeResponseInfo> response) {
449 DCHECK(stream);
450 DCHECK_EQ(CONNECTING, state_);
451
452 stream_ = std::move(stream);
453
454 SetState(CONNECTED);
455
456 // |stream_request_| is not used once the connection has succeeded.
457 stream_request_.reset();
458
459 event_interface_->OnAddChannelResponse(
460 std::move(response), stream_->GetSubProtocol(), stream_->GetExtensions());
461 // |this| may have been deleted after OnAddChannelResponse.
462 }
463
OnConnectFailure(const std::string & message,int net_error,absl::optional<int> response_code)464 void WebSocketChannel::OnConnectFailure(const std::string& message,
465 int net_error,
466 absl::optional<int> response_code) {
467 DCHECK_EQ(CONNECTING, state_);
468
469 // Copy the message before we delete its owner.
470 std::string message_copy = message;
471
472 SetState(CLOSED);
473 stream_request_.reset();
474
475 event_interface_->OnFailChannel(message_copy, net_error, response_code);
476 // |this| has been deleted.
477 }
478
OnSSLCertificateError(std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks,int net_error,const SSLInfo & ssl_info,bool fatal)479 void WebSocketChannel::OnSSLCertificateError(
480 std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks>
481 ssl_error_callbacks,
482 int net_error,
483 const SSLInfo& ssl_info,
484 bool fatal) {
485 event_interface_->OnSSLCertificateError(
486 std::move(ssl_error_callbacks), socket_url_, net_error, ssl_info, fatal);
487 }
488
OnAuthRequired(const AuthChallengeInfo & auth_info,scoped_refptr<HttpResponseHeaders> response_headers,const IPEndPoint & remote_endpoint,base::OnceCallback<void (const AuthCredentials *)> callback,absl::optional<AuthCredentials> * credentials)489 int WebSocketChannel::OnAuthRequired(
490 const AuthChallengeInfo& auth_info,
491 scoped_refptr<HttpResponseHeaders> response_headers,
492 const IPEndPoint& remote_endpoint,
493 base::OnceCallback<void(const AuthCredentials*)> callback,
494 absl::optional<AuthCredentials>* credentials) {
495 return event_interface_->OnAuthRequired(
496 auth_info, std::move(response_headers), remote_endpoint,
497 std::move(callback), credentials);
498 }
499
OnStartOpeningHandshake(std::unique_ptr<WebSocketHandshakeRequestInfo> request)500 void WebSocketChannel::OnStartOpeningHandshake(
501 std::unique_ptr<WebSocketHandshakeRequestInfo> request) {
502 event_interface_->OnStartOpeningHandshake(std::move(request));
503 }
504
WriteFrames()505 ChannelState WebSocketChannel::WriteFrames() {
506 int result = OK;
507 do {
508 // This use of base::Unretained is safe because this object owns the
509 // WebSocketStream and destroying it cancels all callbacks.
510 result = stream_->WriteFrames(
511 data_being_sent_->frames(),
512 base::BindOnce(base::IgnoreResult(&WebSocketChannel::OnWriteDone),
513 base::Unretained(this), false));
514 if (result != ERR_IO_PENDING) {
515 if (OnWriteDone(true, result) == CHANNEL_DELETED)
516 return CHANNEL_DELETED;
517 // OnWriteDone() returns CHANNEL_DELETED on error. Here |state_| is
518 // guaranteed to be the same as before OnWriteDone() call.
519 }
520 } while (result == OK && data_being_sent_);
521 return CHANNEL_ALIVE;
522 }
523
OnWriteDone(bool synchronous,int result)524 ChannelState WebSocketChannel::OnWriteDone(bool synchronous, int result) {
525 DCHECK_NE(FRESHLY_CONSTRUCTED, state_);
526 DCHECK_NE(CONNECTING, state_);
527 DCHECK_NE(ERR_IO_PENDING, result);
528 DCHECK(data_being_sent_);
529 switch (result) {
530 case OK:
531 if (data_to_send_next_) {
532 data_being_sent_ = std::move(data_to_send_next_);
533 if (!synchronous)
534 return WriteFrames();
535 } else {
536 data_being_sent_.reset();
537 event_interface_->OnSendDataFrameDone();
538 }
539 return CHANNEL_ALIVE;
540
541 // If a recoverable error condition existed, it would go here.
542
543 default:
544 DCHECK_LT(result, 0)
545 << "WriteFrames() should only return OK or ERR_ codes";
546
547 stream_->Close();
548 SetState(CLOSED);
549 DoDropChannel(false, kWebSocketErrorAbnormalClosure, "");
550 return CHANNEL_DELETED;
551 }
552 }
553
// Reads frames from |stream_| and handles them. It stops when a read is left
// pending, when |this| is deleted, or when |event_interface_| reports pending
// data frames, i.e. the client has not yet consumed frames already delivered.
// A server-initiated closing handshake that was deferred for that reason is
// answered first, once the client has caught up.
// Returns CHANNEL_DELETED if |this| has been destroyed.
ChannelState WebSocketChannel::ReadFrames() {
  DCHECK(stream_);
  DCHECK(state_ == CONNECTED || state_ == SEND_CLOSED || state_ == CLOSE_WAIT);
  DCHECK(read_frames_.empty());
  // An asynchronous read is already outstanding; its completion
  // (OnReadDone()) will call back into ReadFrames().
  if (is_reading_) {
    return CHANNEL_ALIVE;
  }

  // InClosingState() is evaluated first on purpose: it DCHECKs that |state_|
  // is not RECV_CLOSED.
  if (!InClosingState() && has_received_close_frame_) {
    DCHECK(!event_interface_->HasPendingDataFrames());
    // We've been waiting for the client to consume the frames before
    // responding to the closing handshake initiated by the server.
    if (RespondToClosingHandshake() == CHANNEL_DELETED) {
      return CHANNEL_DELETED;
    }
  }

  // TODO(crbug.com/999235): Remove this CHECK.
  CHECK(event_interface_);
  // Only read more while the client has nothing left to consume, so incoming
  // data is not queued without bound.
  while (!event_interface_->HasPendingDataFrames()) {
    DCHECK(stream_);
    // This use of base::Unretained is safe because this object owns the
    // WebSocketStream, and any pending reads will be cancelled when it is
    // destroyed.
    const int result = stream_->ReadFrames(
        &read_frames_,
        base::BindOnce(base::IgnoreResult(&WebSocketChannel::OnReadDone),
                       base::Unretained(this), false));
    if (result == ERR_IO_PENDING) {
      is_reading_ = true;
      return CHANNEL_ALIVE;
    }
    // Synchronous completion: handle the frames now. OnReadDone(true, ...)
    // does not re-enter ReadFrames().
    if (OnReadDone(true, result) == CHANNEL_DELETED) {
      return CHANNEL_DELETED;
    }
    DCHECK_NE(CLOSED, state_);
    // TODO(crbug.com/999235): Remove this CHECK.
    CHECK(event_interface_);
  }
  return CHANNEL_ALIVE;
}
595
OnReadDone(bool synchronous,int result)596 ChannelState WebSocketChannel::OnReadDone(bool synchronous, int result) {
597 DVLOG(3) << "WebSocketChannel::OnReadDone synchronous?" << synchronous
598 << ", result=" << result
599 << ", read_frames_.size=" << read_frames_.size();
600 DCHECK_NE(FRESHLY_CONSTRUCTED, state_);
601 DCHECK_NE(CONNECTING, state_);
602 DCHECK_NE(ERR_IO_PENDING, result);
603 switch (result) {
604 case OK:
605 // ReadFrames() must use ERR_CONNECTION_CLOSED for a closed connection
606 // with no data read, not an empty response.
607 DCHECK(!read_frames_.empty())
608 << "ReadFrames() returned OK, but nothing was read.";
609 for (auto& read_frame : read_frames_) {
610 if (HandleFrame(std::move(read_frame)) == CHANNEL_DELETED)
611 return CHANNEL_DELETED;
612 }
613 read_frames_.clear();
614 DCHECK_NE(CLOSED, state_);
615 if (!synchronous) {
616 is_reading_ = false;
617 if (!event_interface_->HasPendingDataFrames()) {
618 return ReadFrames();
619 }
620 }
621 return CHANNEL_ALIVE;
622
623 case ERR_WS_PROTOCOL_ERROR:
624 // This could be kWebSocketErrorProtocolError (specifically, non-minimal
625 // encoding of payload length) or kWebSocketErrorMessageTooBig, or an
626 // extension-specific error.
627 FailChannel("Invalid frame header", kWebSocketErrorProtocolError,
628 "WebSocket Protocol Error");
629 return CHANNEL_DELETED;
630
631 default:
632 DCHECK_LT(result, 0)
633 << "ReadFrames() should only return OK or ERR_ codes";
634
635 stream_->Close();
636 SetState(CLOSED);
637
638 uint16_t code = kWebSocketErrorAbnormalClosure;
639 std::string reason = "";
640 bool was_clean = false;
641 if (has_received_close_frame_) {
642 code = received_close_code_;
643 reason = received_close_reason_;
644 was_clean = (result == ERR_CONNECTION_CLOSED);
645 }
646
647 DoDropChannel(was_clean, code, reason);
648 return CHANNEL_DELETED;
649 }
650 }
651
HandleFrame(std::unique_ptr<WebSocketFrame> frame)652 ChannelState WebSocketChannel::HandleFrame(
653 std::unique_ptr<WebSocketFrame> frame) {
654 if (frame->header.masked) {
655 // RFC6455 Section 5.1 "A client MUST close a connection if it detects a
656 // masked frame."
657 FailChannel(
658 "A server must not mask any frames that it sends to the "
659 "client.",
660 kWebSocketErrorProtocolError, "Masked frame from server");
661 return CHANNEL_DELETED;
662 }
663 const WebSocketFrameHeader::OpCode opcode = frame->header.opcode;
664 DCHECK(!WebSocketFrameHeader::IsKnownControlOpCode(opcode) ||
665 frame->header.final);
666 if (frame->header.reserved1 || frame->header.reserved2 ||
667 frame->header.reserved3) {
668 FailChannel(
669 base::StringPrintf("One or more reserved bits are on: reserved1 = %d, "
670 "reserved2 = %d, reserved3 = %d",
671 static_cast<int>(frame->header.reserved1),
672 static_cast<int>(frame->header.reserved2),
673 static_cast<int>(frame->header.reserved3)),
674 kWebSocketErrorProtocolError, "Invalid reserved bit");
675 return CHANNEL_DELETED;
676 }
677
678 // Respond to the frame appropriately to its type.
679 return HandleFrameByState(
680 opcode, frame->header.final,
681 base::make_span(frame->payload, base::checked_cast<size_t>(
682 frame->header.payload_length)));
683 }
684
HandleFrameByState(const WebSocketFrameHeader::OpCode opcode,bool final,base::span<const char> payload)685 ChannelState WebSocketChannel::HandleFrameByState(
686 const WebSocketFrameHeader::OpCode opcode,
687 bool final,
688 base::span<const char> payload) {
689 DCHECK_NE(RECV_CLOSED, state_)
690 << "HandleFrame() does not support being called re-entrantly from within "
691 "SendClose()";
692 DCHECK_NE(CLOSED, state_);
693 if (state_ == CLOSE_WAIT) {
694 std::string frame_name;
695 GetFrameTypeForOpcode(opcode, &frame_name);
696
697 // FailChannel() won't send another Close frame.
698 FailChannel(frame_name + " received after close",
699 kWebSocketErrorProtocolError, "");
700 return CHANNEL_DELETED;
701 }
702 switch (opcode) {
703 case WebSocketFrameHeader::kOpCodeText: // fall-thru
704 case WebSocketFrameHeader::kOpCodeBinary:
705 case WebSocketFrameHeader::kOpCodeContinuation:
706 return HandleDataFrame(opcode, final, std::move(payload));
707
708 case WebSocketFrameHeader::kOpCodePing:
709 DVLOG(1) << "Got Ping of size " << payload.size();
710 if (state_ == CONNECTED) {
711 auto buffer = base::MakeRefCounted<IOBuffer>(payload.size());
712 memcpy(buffer->data(), payload.data(), payload.size());
713 return SendFrameInternal(true, WebSocketFrameHeader::kOpCodePong,
714 std::move(buffer), payload.size());
715 }
716 DVLOG(3) << "Ignored ping in state " << state_;
717 return CHANNEL_ALIVE;
718
719 case WebSocketFrameHeader::kOpCodePong:
720 DVLOG(1) << "Got Pong of size " << payload.size();
721 // There is no need to do anything with pong messages.
722 return CHANNEL_ALIVE;
723
724 case WebSocketFrameHeader::kOpCodeClose: {
725 uint16_t code = kWebSocketNormalClosure;
726 std::string reason;
727 std::string message;
728 if (!ParseClose(payload, &code, &reason, &message)) {
729 FailChannel(message, code, reason);
730 return CHANNEL_DELETED;
731 }
732 // TODO(ricea): Find a way to safely log the message from the close
733 // message (escape control codes and so on).
734 return HandleCloseFrame(code, reason);
735 }
736
737 default:
738 FailChannel(base::StringPrintf("Unrecognized frame opcode: %d", opcode),
739 kWebSocketErrorProtocolError, "Unknown opcode");
740 return CHANNEL_DELETED;
741 }
742 }
743
HandleDataFrame(WebSocketFrameHeader::OpCode opcode,bool final,base::span<const char> payload)744 ChannelState WebSocketChannel::HandleDataFrame(
745 WebSocketFrameHeader::OpCode opcode,
746 bool final,
747 base::span<const char> payload) {
748 DVLOG(3) << "WebSocketChannel::HandleDataFrame opcode=" << opcode
749 << ", final?" << final << ", data=" << (void*)payload.data()
750 << ", size=" << payload.size();
751 if (state_ != CONNECTED) {
752 DVLOG(3) << "Ignored data packet received in state " << state_;
753 return CHANNEL_ALIVE;
754 }
755 if (has_received_close_frame_) {
756 DVLOG(3) << "Ignored data packet as we've received a close frame.";
757 return CHANNEL_ALIVE;
758 }
759 DCHECK(opcode == WebSocketFrameHeader::kOpCodeContinuation ||
760 opcode == WebSocketFrameHeader::kOpCodeText ||
761 opcode == WebSocketFrameHeader::kOpCodeBinary);
762 const bool got_continuation =
763 (opcode == WebSocketFrameHeader::kOpCodeContinuation);
764 if (got_continuation != expecting_to_handle_continuation_) {
765 const std::string console_log = got_continuation
766 ? "Received unexpected continuation frame."
767 : "Received start of new message but previous message is unfinished.";
768 const std::string reason = got_continuation
769 ? "Unexpected continuation"
770 : "Previous data frame unfinished";
771 FailChannel(console_log, kWebSocketErrorProtocolError, reason);
772 return CHANNEL_DELETED;
773 }
774 expecting_to_handle_continuation_ = !final;
775 WebSocketFrameHeader::OpCode opcode_to_send = opcode;
776 if (!initial_frame_forwarded_ &&
777 opcode == WebSocketFrameHeader::kOpCodeContinuation) {
778 opcode_to_send = receiving_text_message_
779 ? WebSocketFrameHeader::kOpCodeText
780 : WebSocketFrameHeader::kOpCodeBinary;
781 }
782 if (opcode == WebSocketFrameHeader::kOpCodeText ||
783 (opcode == WebSocketFrameHeader::kOpCodeContinuation &&
784 receiving_text_message_)) {
785 // This call is not redundant when size == 0 because it tells us what
786 // the current state is.
787 StreamingUtf8Validator::State state = incoming_utf8_validator_.AddBytes(
788 payload.data(), static_cast<size_t>(payload.size()));
789 if (state == StreamingUtf8Validator::INVALID ||
790 (state == StreamingUtf8Validator::VALID_MIDPOINT && final)) {
791 FailChannel("Could not decode a text frame as UTF-8.",
792 kWebSocketErrorProtocolError, "Invalid UTF-8 in text frame");
793 return CHANNEL_DELETED;
794 }
795 receiving_text_message_ = !final;
796 DCHECK(!final || state == StreamingUtf8Validator::VALID_ENDPOINT);
797 }
798 if (payload.size() == 0U && !final)
799 return CHANNEL_ALIVE;
800
801 initial_frame_forwarded_ = !final;
802 // Sends the received frame to the renderer process.
803 event_interface_->OnDataFrame(final, opcode_to_send, payload);
804 return CHANNEL_ALIVE;
805 }
806
HandleCloseFrame(uint16_t code,const std::string & reason)807 ChannelState WebSocketChannel::HandleCloseFrame(uint16_t code,
808 const std::string& reason) {
809 DVLOG(1) << "Got Close with code " << code;
810 switch (state_) {
811 case CONNECTED:
812 has_received_close_frame_ = true;
813 received_close_code_ = code;
814 received_close_reason_ = reason;
815 if (event_interface_->HasPendingDataFrames()) {
816 // We have some data to be sent to the renderer before sending this
817 // frame.
818 return CHANNEL_ALIVE;
819 }
820 return RespondToClosingHandshake();
821
822 case SEND_CLOSED:
823 SetState(CLOSE_WAIT);
824 DCHECK(close_timer_.IsRunning());
825 close_timer_.Stop();
826 // This use of base::Unretained() is safe because we stop the timer
827 // in the destructor.
828 close_timer_.Start(FROM_HERE, underlying_connection_close_timeout_,
829 base::BindOnce(&WebSocketChannel::CloseTimeout,
830 base::Unretained(this)));
831
832 // From RFC6455 section 7.1.5: "Each endpoint
833 // will see the status code sent by the other end as _The WebSocket
834 // Connection Close Code_."
835 has_received_close_frame_ = true;
836 received_close_code_ = code;
837 received_close_reason_ = reason;
838 break;
839
840 default:
841 LOG(DFATAL) << "Got Close in unexpected state " << state_;
842 break;
843 }
844 return CHANNEL_ALIVE;
845 }
846
RespondToClosingHandshake()847 ChannelState WebSocketChannel::RespondToClosingHandshake() {
848 DCHECK(has_received_close_frame_);
849 DCHECK_EQ(CONNECTED, state_);
850 SetState(RECV_CLOSED);
851 if (SendClose(received_close_code_, received_close_reason_) ==
852 CHANNEL_DELETED)
853 return CHANNEL_DELETED;
854 DCHECK_EQ(RECV_CLOSED, state_);
855
856 SetState(CLOSE_WAIT);
857 DCHECK(!close_timer_.IsRunning());
858 // This use of base::Unretained() is safe because we stop the timer
859 // in the destructor.
860 close_timer_.Start(
861 FROM_HERE, underlying_connection_close_timeout_,
862 base::BindOnce(&WebSocketChannel::CloseTimeout, base::Unretained(this)));
863
864 event_interface_->OnClosingHandshake();
865 return CHANNEL_ALIVE;
866 }
867
SendFrameInternal(bool fin,WebSocketFrameHeader::OpCode op_code,scoped_refptr<IOBuffer> buffer,uint64_t buffer_size)868 ChannelState WebSocketChannel::SendFrameInternal(
869 bool fin,
870 WebSocketFrameHeader::OpCode op_code,
871 scoped_refptr<IOBuffer> buffer,
872 uint64_t buffer_size) {
873 DCHECK(state_ == CONNECTED || state_ == RECV_CLOSED);
874 DCHECK(stream_);
875
876 auto frame = std::make_unique<WebSocketFrame>(op_code);
877 WebSocketFrameHeader& header = frame->header;
878 header.final = fin;
879 header.masked = true;
880 header.payload_length = buffer_size;
881 frame->payload = buffer->data();
882
883 if (data_being_sent_) {
884 // Either the link to the WebSocket server is saturated, or several messages
885 // are being sent in a batch.
886 if (!data_to_send_next_)
887 data_to_send_next_ = std::make_unique<SendBuffer>();
888 data_to_send_next_->AddFrame(std::move(frame), std::move(buffer));
889 return CHANNEL_ALIVE;
890 }
891
892 data_being_sent_ = std::make_unique<SendBuffer>();
893 data_being_sent_->AddFrame(std::move(frame), std::move(buffer));
894 return WriteFrames();
895 }
896
FailChannel(const std::string & message,uint16_t code,const std::string & reason)897 void WebSocketChannel::FailChannel(const std::string& message,
898 uint16_t code,
899 const std::string& reason) {
900 DCHECK_NE(FRESHLY_CONSTRUCTED, state_);
901 DCHECK_NE(CONNECTING, state_);
902 DCHECK_NE(CLOSED, state_);
903
904 stream_->GetNetLogWithSource().AddEvent(
905 net::NetLogEventType::WEBSOCKET_INVALID_FRAME,
906 [&] { return NetLogFailParam(code, reason, message); });
907
908 if (state_ == CONNECTED) {
909 if (SendClose(code, reason) == CHANNEL_DELETED)
910 return;
911 }
912
913 // Careful study of RFC6455 section 7.1.7 and 7.1.1 indicates the browser
914 // should close the connection itself without waiting for the closing
915 // handshake.
916 stream_->Close();
917 SetState(CLOSED);
918 event_interface_->OnFailChannel(message, ERR_FAILED, absl::nullopt);
919 }
920
SendClose(uint16_t code,const std::string & reason)921 ChannelState WebSocketChannel::SendClose(uint16_t code,
922 const std::string& reason) {
923 DCHECK(state_ == CONNECTED || state_ == RECV_CLOSED);
924 DCHECK_LE(reason.size(), kMaximumCloseReasonLength);
925 scoped_refptr<IOBuffer> body;
926 uint64_t size = 0;
927 if (code == kWebSocketErrorNoStatusReceived) {
928 // Special case: translate kWebSocketErrorNoStatusReceived into a Close
929 // frame with no payload.
930 DCHECK(reason.empty());
931 body = base::MakeRefCounted<IOBuffer>(0);
932 } else {
933 const size_t payload_length = kWebSocketCloseCodeLength + reason.length();
934 body = base::MakeRefCounted<IOBuffer>(payload_length);
935 size = payload_length;
936 base::WriteBigEndian(body->data(), code);
937 static_assert(sizeof(code) == kWebSocketCloseCodeLength,
938 "they should both be two");
939 base::ranges::copy(reason, body->data() + kWebSocketCloseCodeLength);
940 }
941
942 return SendFrameInternal(true, WebSocketFrameHeader::kOpCodeClose,
943 std::move(body), size);
944 }
945
ParseClose(base::span<const char> payload,uint16_t * code,std::string * reason,std::string * message)946 bool WebSocketChannel::ParseClose(base::span<const char> payload,
947 uint16_t* code,
948 std::string* reason,
949 std::string* message) {
950 const uint64_t size = static_cast<uint64_t>(payload.size());
951 reason->clear();
952 if (size < kWebSocketCloseCodeLength) {
953 if (size == 0U) {
954 *code = kWebSocketErrorNoStatusReceived;
955 return true;
956 }
957
958 DVLOG(1) << "Close frame with payload size " << size << " received "
959 << "(the first byte is " << std::hex
960 << static_cast<int>(payload[0]) << ")";
961 *code = kWebSocketErrorProtocolError;
962 *message =
963 "Received a broken close frame containing an invalid size body.";
964 return false;
965 }
966
967 const char* data = payload.data();
968 uint16_t unchecked_code = 0;
969 base::ReadBigEndian(reinterpret_cast<const uint8_t*>(data), &unchecked_code);
970 static_assert(sizeof(unchecked_code) == kWebSocketCloseCodeLength,
971 "they should both be two bytes");
972
973 switch (unchecked_code) {
974 case kWebSocketErrorNoStatusReceived:
975 case kWebSocketErrorAbnormalClosure:
976 case kWebSocketErrorTlsHandshake:
977 *code = kWebSocketErrorProtocolError;
978 *message =
979 "Received a broken close frame containing a reserved status code.";
980 return false;
981
982 default:
983 *code = unchecked_code;
984 break;
985 }
986
987 std::string text(data + kWebSocketCloseCodeLength, data + size);
988 if (StreamingUtf8Validator::Validate(text)) {
989 reason->swap(text);
990 return true;
991 }
992
993 *code = kWebSocketErrorProtocolError;
994 *reason = "Invalid UTF-8 in Close frame";
995 *message = "Received a broken close frame containing invalid UTF-8.";
996 return false;
997 }
998
DoDropChannel(bool was_clean,uint16_t code,const std::string & reason)999 void WebSocketChannel::DoDropChannel(bool was_clean,
1000 uint16_t code,
1001 const std::string& reason) {
1002 event_interface_->OnDropChannel(was_clean, code, reason);
1003 }
1004
CloseTimeout()1005 void WebSocketChannel::CloseTimeout() {
1006 stream_->GetNetLogWithSource().AddEvent(
1007 net::NetLogEventType::WEBSOCKET_CLOSE_TIMEOUT);
1008 stream_->Close();
1009 SetState(CLOSED);
1010 if (has_received_close_frame_) {
1011 DoDropChannel(true, received_close_code_, received_close_reason_);
1012 } else {
1013 DoDropChannel(false, kWebSocketErrorAbnormalClosure, "");
1014 }
1015 // |this| has been deleted.
1016 }
1017
1018 } // namespace net
1019