1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/websockets/websocket_channel.h"
6
7 #include <limits.h> // for INT_MAX
8 #include <stddef.h>
9 #include <string.h>
10
11 #include <algorithm>
12 #include <iterator>
13 #include <ostream>
14 #include <utility>
15 #include <vector>
16
17 #include "base/big_endian.h"
18 #include "base/check.h"
19 #include "base/check_op.h"
20 #include "base/functional/bind.h"
21 #include "base/location.h"
22 #include "base/logging.h"
23 #include "base/memory/raw_ptr.h"
24 #include "base/numerics/safe_conversions.h"
25 #include "base/ranges/algorithm.h"
26 #include "base/strings/string_piece.h"
27 #include "base/strings/stringprintf.h"
28 #include "base/time/time.h"
29 #include "base/values.h"
30 #include "net/base/io_buffer.h"
31 #include "net/base/net_errors.h"
32 #include "net/http/http_response_headers.h"
33 #include "net/log/net_log_event_type.h"
34 #include "net/log/net_log_with_source.h"
35 #include "net/traffic_annotation/network_traffic_annotation.h"
36 #include "net/websockets/websocket_errors.h"
37 #include "net/websockets/websocket_event_interface.h"
38 #include "net/websockets/websocket_frame.h"
39 #include "net/websockets/websocket_handshake_request_info.h"
40 #include "net/websockets/websocket_handshake_response_info.h"
41 #include "net/websockets/websocket_stream.h"
42
43 namespace net {
44 class AuthChallengeInfo;
45 class AuthCredentials;
46 class SSLInfo;
47
48 namespace {
49
50 using base::StreamingUtf8Validator;
51
52 constexpr size_t kWebSocketCloseCodeLength = 2;
53 // Timeout for waiting for the server to acknowledge a closing handshake.
54 constexpr int kClosingHandshakeTimeoutSeconds = 60;
55 // We wait for the server to close the underlying connection as recommended in
56 // https://tools.ietf.org/html/rfc6455#section-7.1.1
57 // We don't use 2MSL since there're server implementations that don't follow
58 // the recommendation and wait for the client to close the underlying
59 // connection. It leads to unnecessarily long time before CloseEvent
60 // invocation. We want to avoid this rather than strictly following the spec
61 // recommendation.
62 constexpr int kUnderlyingConnectionCloseTimeoutSeconds = 2;
63
64 using ChannelState = WebSocketChannel::ChannelState;
65
66 // Maximum close reason length = max control frame payload -
67 // status code length
68 // = 125 - 2
69 constexpr size_t kMaximumCloseReasonLength = 125 - kWebSocketCloseCodeLength;
70
// Checks a close status code for strict compliance with RFC6455. This is only
// used for close codes received from a renderer that we intend to send out
// over the network; see ParseClose() for the restrictions on incoming close
// codes. |code| is an int for convenience of implementation, but callers pass
// a uint16_t. Any value outside the valid ranges below, including values that
// do not fit in a uint16_t, is rejected.
//
// Code 1005 is treated specially: Javascript cannot set it explicitly, but the
// renderer uses it to indicate that we should send a Close frame with no
// payload, so it is accepted here.
bool IsStrictlyValidCloseStatusCode(int code) {
  // Half-open [first, last) ranges of codes that may be sent. Everything else
  // is rejected:
  //   below 1000  not used.
  //   1006        MUST NOT be set in a Close frame.
  //   1014-2999   1014 is not accepted; 1015 up to 2999 are reserved.
  //   5000 and up invalid.
  struct CodeRange {
    int first;
    int last;
  };
  static constexpr CodeRange kValidRanges[] = {
      {1000, 1006},  // 1000 is the first valid code.
      {1007, 1014},
      {3000, 5000},
  };
  for (const CodeRange& range : kValidRanges) {
    if (code >= range.first && code < range.last)
      return true;
  }
  return false;
}
98
99 // Sets |name| to the name of the frame type for the given |opcode|. Note that
100 // for all of Text, Binary and Continuation opcode, this method returns
101 // "Data frame".
GetFrameTypeForOpcode(WebSocketFrameHeader::OpCode opcode,std::string * name)102 void GetFrameTypeForOpcode(WebSocketFrameHeader::OpCode opcode,
103 std::string* name) {
104 switch (opcode) {
105 case WebSocketFrameHeader::kOpCodeText: // fall-thru
106 case WebSocketFrameHeader::kOpCodeBinary: // fall-thru
107 case WebSocketFrameHeader::kOpCodeContinuation:
108 *name = "Data frame";
109 break;
110
111 case WebSocketFrameHeader::kOpCodePing:
112 *name = "Ping";
113 break;
114
115 case WebSocketFrameHeader::kOpCodePong:
116 *name = "Pong";
117 break;
118
119 case WebSocketFrameHeader::kOpCodeClose:
120 *name = "Close";
121 break;
122
123 default:
124 *name = "Unknown frame type";
125 break;
126 }
127
128 return;
129 }
130
NetLogFailParam(uint16_t code,base::StringPiece reason,base::StringPiece message)131 base::Value::Dict NetLogFailParam(uint16_t code,
132 base::StringPiece reason,
133 base::StringPiece message) {
134 base::Value::Dict dict;
135 dict.Set("code", code);
136 dict.Set("reason", reason);
137 dict.Set("internal_reason", message);
138 return dict;
139 }
140
141 class DependentIOBuffer : public WrappedIOBuffer {
142 public:
DependentIOBuffer(scoped_refptr<IOBufferWithSize> buffer,size_t offset)143 DependentIOBuffer(scoped_refptr<IOBufferWithSize> buffer, size_t offset)
144 : WrappedIOBuffer(buffer->data() + offset, buffer->size() - offset),
145 buffer_(std::move(buffer)) {}
146
147 private:
~DependentIOBuffer()148 ~DependentIOBuffer() override {
149 // Prevent `data_` from dangling should this destructor remove the
150 // last reference to `buffer_`.
151 data_ = nullptr;
152 }
153
154 scoped_refptr<IOBufferWithSize> buffer_;
155 };
156
157 } // namespace
158
159 // A class to encapsulate a set of frames and information about the size of
160 // those frames.
161 class WebSocketChannel::SendBuffer {
162 public:
163 SendBuffer() = default;
164
165 // Add a WebSocketFrame to the buffer and increase total_bytes_.
166 void AddFrame(std::unique_ptr<WebSocketFrame> chunk,
167 scoped_refptr<IOBuffer> buffer);
168
169 // Return a pointer to the frames_ for write purposes.
frames()170 std::vector<std::unique_ptr<WebSocketFrame>>* frames() { return &frames_; }
171
172 private:
173 // The frames_ that will be sent in the next call to WriteFrames().
174 std::vector<std::unique_ptr<WebSocketFrame>> frames_;
175 // References of each WebSocketFrame.data;
176 std::vector<scoped_refptr<IOBuffer>> buffers_;
177
178 // The total size of the payload data in |frames_|. This will be used to
179 // measure the throughput of the link.
180 // TODO(ricea): Measure the throughput of the link.
181 uint64_t total_bytes_ = 0;
182 };
183
AddFrame(std::unique_ptr<WebSocketFrame> frame,scoped_refptr<IOBuffer> buffer)184 void WebSocketChannel::SendBuffer::AddFrame(
185 std::unique_ptr<WebSocketFrame> frame,
186 scoped_refptr<IOBuffer> buffer) {
187 total_bytes_ += frame->header.payload_length;
188 frames_.push_back(std::move(frame));
189 buffers_.push_back(std::move(buffer));
190 }
191
192 // Implementation of WebSocketStream::ConnectDelegate that simply forwards the
193 // calls on to the WebSocketChannel that created it.
194 class WebSocketChannel::ConnectDelegate
195 : public WebSocketStream::ConnectDelegate {
196 public:
ConnectDelegate(WebSocketChannel * creator)197 explicit ConnectDelegate(WebSocketChannel* creator) : creator_(creator) {}
198
199 ConnectDelegate(const ConnectDelegate&) = delete;
200 ConnectDelegate& operator=(const ConnectDelegate&) = delete;
201
OnCreateRequest(URLRequest * request)202 void OnCreateRequest(URLRequest* request) override {
203 creator_->OnCreateURLRequest(request);
204 }
205
OnSuccess(std::unique_ptr<WebSocketStream> stream,std::unique_ptr<WebSocketHandshakeResponseInfo> response)206 void OnSuccess(
207 std::unique_ptr<WebSocketStream> stream,
208 std::unique_ptr<WebSocketHandshakeResponseInfo> response) override {
209 creator_->OnConnectSuccess(std::move(stream), std::move(response));
210 // |this| may have been deleted.
211 }
212
OnFailure(const std::string & message,int net_error,absl::optional<int> response_code)213 void OnFailure(const std::string& message,
214 int net_error,
215 absl::optional<int> response_code) override {
216 creator_->OnConnectFailure(message, net_error, response_code);
217 // |this| has been deleted.
218 }
219
OnStartOpeningHandshake(std::unique_ptr<WebSocketHandshakeRequestInfo> request)220 void OnStartOpeningHandshake(
221 std::unique_ptr<WebSocketHandshakeRequestInfo> request) override {
222 creator_->OnStartOpeningHandshake(std::move(request));
223 }
224
OnSSLCertificateError(std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks,int net_error,const SSLInfo & ssl_info,bool fatal)225 void OnSSLCertificateError(
226 std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks>
227 ssl_error_callbacks,
228 int net_error,
229 const SSLInfo& ssl_info,
230 bool fatal) override {
231 creator_->OnSSLCertificateError(std::move(ssl_error_callbacks), net_error,
232 ssl_info, fatal);
233 }
234
OnAuthRequired(const AuthChallengeInfo & auth_info,scoped_refptr<HttpResponseHeaders> headers,const IPEndPoint & remote_endpoint,base::OnceCallback<void (const AuthCredentials *)> callback,absl::optional<AuthCredentials> * credentials)235 int OnAuthRequired(const AuthChallengeInfo& auth_info,
236 scoped_refptr<HttpResponseHeaders> headers,
237 const IPEndPoint& remote_endpoint,
238 base::OnceCallback<void(const AuthCredentials*)> callback,
239 absl::optional<AuthCredentials>* credentials) override {
240 return creator_->OnAuthRequired(auth_info, std::move(headers),
241 remote_endpoint, std::move(callback),
242 credentials);
243 }
244
245 private:
246 // A pointer to the WebSocketChannel that created this object. There is no
247 // danger of this pointer being stale, because deleting the WebSocketChannel
248 // cancels the connect process, deleting this object and preventing its
249 // callbacks from being called.
250 const raw_ptr<WebSocketChannel, DanglingUntriaged> creator_;
251 };
252
WebSocketChannel(std::unique_ptr<WebSocketEventInterface> event_interface,URLRequestContext * url_request_context)253 WebSocketChannel::WebSocketChannel(
254 std::unique_ptr<WebSocketEventInterface> event_interface,
255 URLRequestContext* url_request_context)
256 : event_interface_(std::move(event_interface)),
257 url_request_context_(url_request_context),
258 closing_handshake_timeout_(
259 base::Seconds(kClosingHandshakeTimeoutSeconds)),
260 underlying_connection_close_timeout_(
261 base::Seconds(kUnderlyingConnectionCloseTimeoutSeconds)) {}
262
~WebSocketChannel()263 WebSocketChannel::~WebSocketChannel() {
264 // The stream may hold a pointer to read_frames_, and so it needs to be
265 // destroyed first.
266 stream_.reset();
267 // The timer may have a callback pointing back to us, so stop it just in case
268 // someone decides to run the event loop from their destructor.
269 close_timer_.Stop();
270 }
271
SendAddChannelRequest(const GURL & socket_url,const std::vector<std::string> & requested_subprotocols,const url::Origin & origin,const SiteForCookies & site_for_cookies,bool has_storage_access,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,NetworkTrafficAnnotationTag traffic_annotation)272 void WebSocketChannel::SendAddChannelRequest(
273 const GURL& socket_url,
274 const std::vector<std::string>& requested_subprotocols,
275 const url::Origin& origin,
276 const SiteForCookies& site_for_cookies,
277 bool has_storage_access,
278 const IsolationInfo& isolation_info,
279 const HttpRequestHeaders& additional_headers,
280 NetworkTrafficAnnotationTag traffic_annotation) {
281 SendAddChannelRequestWithSuppliedCallback(
282 socket_url, requested_subprotocols, origin, site_for_cookies,
283 has_storage_access, isolation_info, additional_headers,
284 traffic_annotation,
285 base::BindOnce(&WebSocketStream::CreateAndConnectStream));
286 }
287
SetState(State new_state)288 void WebSocketChannel::SetState(State new_state) {
289 DCHECK_NE(state_, new_state);
290
291 state_ = new_state;
292 }
293
InClosingState() const294 bool WebSocketChannel::InClosingState() const {
295 // The state RECV_CLOSED is not supported here, because it is only used in one
296 // code path and should not leak into the code in general.
297 DCHECK_NE(RECV_CLOSED, state_)
298 << "InClosingState called with state_ == RECV_CLOSED";
299 return state_ == SEND_CLOSED || state_ == CLOSE_WAIT || state_ == CLOSED;
300 }
301
SendFrame(bool fin,WebSocketFrameHeader::OpCode op_code,scoped_refptr<IOBuffer> buffer,size_t buffer_size)302 WebSocketChannel::ChannelState WebSocketChannel::SendFrame(
303 bool fin,
304 WebSocketFrameHeader::OpCode op_code,
305 scoped_refptr<IOBuffer> buffer,
306 size_t buffer_size) {
307 DCHECK_LE(buffer_size, static_cast<size_t>(INT_MAX));
308 DCHECK(stream_) << "Got SendFrame without a connection established; fin="
309 << fin << " op_code=" << op_code
310 << " buffer_size=" << buffer_size;
311
312 if (InClosingState()) {
313 DVLOG(1) << "SendFrame called in state " << state_
314 << ". This may be a bug, or a harmless race.";
315 return CHANNEL_ALIVE;
316 }
317
318 DCHECK_EQ(state_, CONNECTED);
319
320 DCHECK(WebSocketFrameHeader::IsKnownDataOpCode(op_code))
321 << "Got SendFrame with bogus op_code " << op_code << " fin=" << fin
322 << " buffer_size=" << buffer_size;
323
324 if (op_code == WebSocketFrameHeader::kOpCodeText ||
325 (op_code == WebSocketFrameHeader::kOpCodeContinuation &&
326 sending_text_message_)) {
327 StreamingUtf8Validator::State state = outgoing_utf8_validator_.AddBytes(
328 base::make_span(buffer->bytes(), buffer_size));
329 if (state == StreamingUtf8Validator::INVALID ||
330 (state == StreamingUtf8Validator::VALID_MIDPOINT && fin)) {
331 // TODO(ricea): Kill renderer.
332 FailChannel("Browser sent a text frame containing invalid UTF-8",
333 kWebSocketErrorGoingAway, "");
334 return CHANNEL_DELETED;
335 // |this| has been deleted.
336 }
337 sending_text_message_ = !fin;
338 DCHECK(!fin || state == StreamingUtf8Validator::VALID_ENDPOINT);
339 }
340
341 return SendFrameInternal(fin, op_code, std::move(buffer), buffer_size);
342 // |this| may have been deleted.
343 }
344
StartClosingHandshake(uint16_t code,const std::string & reason)345 ChannelState WebSocketChannel::StartClosingHandshake(
346 uint16_t code,
347 const std::string& reason) {
348 if (InClosingState()) {
349 // When the associated renderer process is killed while the channel is in
350 // CLOSING state we reach here.
351 DVLOG(1) << "StartClosingHandshake called in state " << state_
352 << ". This may be a bug, or a harmless race.";
353 return CHANNEL_ALIVE;
354 }
355 if (has_received_close_frame_) {
356 // We reach here if the client wants to start a closing handshake while
357 // the browser is waiting for the client to consume incoming data frames
358 // before responding to a closing handshake initiated by the server.
359 // As the client doesn't want the data frames any more, we can respond to
360 // the closing handshake initiated by the server.
361 return RespondToClosingHandshake();
362 }
363 if (state_ == CONNECTING) {
364 // Abort the in-progress handshake and drop the connection immediately.
365 stream_request_.reset();
366 SetState(CLOSED);
367 DoDropChannel(false, kWebSocketErrorAbnormalClosure, "");
368 return CHANNEL_DELETED;
369 }
370 DCHECK_EQ(state_, CONNECTED);
371
372 DCHECK(!close_timer_.IsRunning());
373 // This use of base::Unretained() is safe because we stop the timer in the
374 // destructor.
375 close_timer_.Start(
376 FROM_HERE, closing_handshake_timeout_,
377 base::BindOnce(&WebSocketChannel::CloseTimeout, base::Unretained(this)));
378
379 // Javascript actually only permits 1000 and 3000-4999, but the implementation
380 // itself may produce different codes. The length of |reason| is also checked
381 // by Javascript.
382 if (!IsStrictlyValidCloseStatusCode(code) ||
383 reason.size() > kMaximumCloseReasonLength) {
384 // "InternalServerError" is actually used for errors from any endpoint, per
385 // errata 3227 to RFC6455. If the renderer is sending us an invalid code or
386 // reason it must be malfunctioning in some way, and based on that we
387 // interpret this as an internal error.
388 if (SendClose(kWebSocketErrorInternalServerError, "") == CHANNEL_DELETED)
389 return CHANNEL_DELETED;
390 DCHECK_EQ(CONNECTED, state_);
391 SetState(SEND_CLOSED);
392 return CHANNEL_ALIVE;
393 }
394 if (SendClose(code, StreamingUtf8Validator::Validate(reason)
395 ? reason
396 : std::string()) == CHANNEL_DELETED)
397 return CHANNEL_DELETED;
398 DCHECK_EQ(CONNECTED, state_);
399 SetState(SEND_CLOSED);
400 return CHANNEL_ALIVE;
401 }
402
SendAddChannelRequestForTesting(const GURL & socket_url,const std::vector<std::string> & requested_subprotocols,const url::Origin & origin,const SiteForCookies & site_for_cookies,bool has_storage_access,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,NetworkTrafficAnnotationTag traffic_annotation,WebSocketStreamRequestCreationCallback callback)403 void WebSocketChannel::SendAddChannelRequestForTesting(
404 const GURL& socket_url,
405 const std::vector<std::string>& requested_subprotocols,
406 const url::Origin& origin,
407 const SiteForCookies& site_for_cookies,
408 bool has_storage_access,
409 const IsolationInfo& isolation_info,
410 const HttpRequestHeaders& additional_headers,
411 NetworkTrafficAnnotationTag traffic_annotation,
412 WebSocketStreamRequestCreationCallback callback) {
413 SendAddChannelRequestWithSuppliedCallback(
414 socket_url, requested_subprotocols, origin, site_for_cookies,
415 has_storage_access, isolation_info, additional_headers,
416 traffic_annotation, std::move(callback));
417 }
418
SetClosingHandshakeTimeoutForTesting(base::TimeDelta delay)419 void WebSocketChannel::SetClosingHandshakeTimeoutForTesting(
420 base::TimeDelta delay) {
421 closing_handshake_timeout_ = delay;
422 }
423
SetUnderlyingConnectionCloseTimeoutForTesting(base::TimeDelta delay)424 void WebSocketChannel::SetUnderlyingConnectionCloseTimeoutForTesting(
425 base::TimeDelta delay) {
426 underlying_connection_close_timeout_ = delay;
427 }
428
SendAddChannelRequestWithSuppliedCallback(const GURL & socket_url,const std::vector<std::string> & requested_subprotocols,const url::Origin & origin,const SiteForCookies & site_for_cookies,bool has_storage_access,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,NetworkTrafficAnnotationTag traffic_annotation,WebSocketStreamRequestCreationCallback callback)429 void WebSocketChannel::SendAddChannelRequestWithSuppliedCallback(
430 const GURL& socket_url,
431 const std::vector<std::string>& requested_subprotocols,
432 const url::Origin& origin,
433 const SiteForCookies& site_for_cookies,
434 bool has_storage_access,
435 const IsolationInfo& isolation_info,
436 const HttpRequestHeaders& additional_headers,
437 NetworkTrafficAnnotationTag traffic_annotation,
438 WebSocketStreamRequestCreationCallback callback) {
439 DCHECK_EQ(FRESHLY_CONSTRUCTED, state_);
440 if (!socket_url.SchemeIsWSOrWSS()) {
441 // TODO(ricea): Kill the renderer (this error should have been caught by
442 // Javascript).
443 event_interface_->OnFailChannel("Invalid scheme", ERR_FAILED,
444 absl::nullopt);
445 // |this| is deleted here.
446 return;
447 }
448 socket_url_ = socket_url;
449 auto connect_delegate = std::make_unique<ConnectDelegate>(this);
450 stream_request_ = std::move(callback).Run(
451 socket_url_, requested_subprotocols, origin, site_for_cookies,
452 has_storage_access, isolation_info, additional_headers,
453 url_request_context_.get(), NetLogWithSource(), traffic_annotation,
454 std::move(connect_delegate));
455 SetState(CONNECTING);
456 }
457
OnCreateURLRequest(URLRequest * request)458 void WebSocketChannel::OnCreateURLRequest(URLRequest* request) {
459 event_interface_->OnCreateURLRequest(request);
460 }
461
OnConnectSuccess(std::unique_ptr<WebSocketStream> stream,std::unique_ptr<WebSocketHandshakeResponseInfo> response)462 void WebSocketChannel::OnConnectSuccess(
463 std::unique_ptr<WebSocketStream> stream,
464 std::unique_ptr<WebSocketHandshakeResponseInfo> response) {
465 DCHECK(stream);
466 DCHECK_EQ(CONNECTING, state_);
467
468 stream_ = std::move(stream);
469
470 SetState(CONNECTED);
471
472 // |stream_request_| is not used once the connection has succeeded.
473 stream_request_.reset();
474
475 event_interface_->OnAddChannelResponse(
476 std::move(response), stream_->GetSubProtocol(), stream_->GetExtensions());
477 // |this| may have been deleted after OnAddChannelResponse.
478 }
479
OnConnectFailure(const std::string & message,int net_error,absl::optional<int> response_code)480 void WebSocketChannel::OnConnectFailure(const std::string& message,
481 int net_error,
482 absl::optional<int> response_code) {
483 DCHECK_EQ(CONNECTING, state_);
484
485 // Copy the message before we delete its owner.
486 std::string message_copy = message;
487
488 SetState(CLOSED);
489 stream_request_.reset();
490
491 event_interface_->OnFailChannel(message_copy, net_error, response_code);
492 // |this| has been deleted.
493 }
494
OnSSLCertificateError(std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks,int net_error,const SSLInfo & ssl_info,bool fatal)495 void WebSocketChannel::OnSSLCertificateError(
496 std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks>
497 ssl_error_callbacks,
498 int net_error,
499 const SSLInfo& ssl_info,
500 bool fatal) {
501 event_interface_->OnSSLCertificateError(
502 std::move(ssl_error_callbacks), socket_url_, net_error, ssl_info, fatal);
503 }
504
OnAuthRequired(const AuthChallengeInfo & auth_info,scoped_refptr<HttpResponseHeaders> response_headers,const IPEndPoint & remote_endpoint,base::OnceCallback<void (const AuthCredentials *)> callback,absl::optional<AuthCredentials> * credentials)505 int WebSocketChannel::OnAuthRequired(
506 const AuthChallengeInfo& auth_info,
507 scoped_refptr<HttpResponseHeaders> response_headers,
508 const IPEndPoint& remote_endpoint,
509 base::OnceCallback<void(const AuthCredentials*)> callback,
510 absl::optional<AuthCredentials>* credentials) {
511 return event_interface_->OnAuthRequired(
512 auth_info, std::move(response_headers), remote_endpoint,
513 std::move(callback), credentials);
514 }
515
OnStartOpeningHandshake(std::unique_ptr<WebSocketHandshakeRequestInfo> request)516 void WebSocketChannel::OnStartOpeningHandshake(
517 std::unique_ptr<WebSocketHandshakeRequestInfo> request) {
518 event_interface_->OnStartOpeningHandshake(std::move(request));
519 }
520
WriteFrames()521 ChannelState WebSocketChannel::WriteFrames() {
522 int result = OK;
523 do {
524 // This use of base::Unretained is safe because this object owns the
525 // WebSocketStream and destroying it cancels all callbacks.
526 result = stream_->WriteFrames(
527 data_being_sent_->frames(),
528 base::BindOnce(base::IgnoreResult(&WebSocketChannel::OnWriteDone),
529 base::Unretained(this), false));
530 if (result != ERR_IO_PENDING) {
531 if (OnWriteDone(true, result) == CHANNEL_DELETED)
532 return CHANNEL_DELETED;
533 // OnWriteDone() returns CHANNEL_DELETED on error. Here |state_| is
534 // guaranteed to be the same as before OnWriteDone() call.
535 }
536 } while (result == OK && data_being_sent_);
537 return CHANNEL_ALIVE;
538 }
539
OnWriteDone(bool synchronous,int result)540 ChannelState WebSocketChannel::OnWriteDone(bool synchronous, int result) {
541 DCHECK_NE(FRESHLY_CONSTRUCTED, state_);
542 DCHECK_NE(CONNECTING, state_);
543 DCHECK_NE(ERR_IO_PENDING, result);
544 DCHECK(data_being_sent_);
545 switch (result) {
546 case OK:
547 if (data_to_send_next_) {
548 data_being_sent_ = std::move(data_to_send_next_);
549 if (!synchronous)
550 return WriteFrames();
551 } else {
552 data_being_sent_.reset();
553 event_interface_->OnSendDataFrameDone();
554 }
555 return CHANNEL_ALIVE;
556
557 // If a recoverable error condition existed, it would go here.
558
559 default:
560 DCHECK_LT(result, 0)
561 << "WriteFrames() should only return OK or ERR_ codes";
562
563 stream_->Close();
564 SetState(CLOSED);
565 DoDropChannel(false, kWebSocketErrorAbnormalClosure, "");
566 return CHANNEL_DELETED;
567 }
568 }
569
ReadFrames()570 ChannelState WebSocketChannel::ReadFrames() {
571 DCHECK(stream_);
572 DCHECK(state_ == CONNECTED || state_ == SEND_CLOSED || state_ == CLOSE_WAIT);
573 DCHECK(read_frames_.empty());
574 if (is_reading_) {
575 return CHANNEL_ALIVE;
576 }
577
578 if (!InClosingState() && has_received_close_frame_) {
579 DCHECK(!event_interface_->HasPendingDataFrames());
580 // We've been waiting for the client to consume the frames before
581 // responding to the closing handshake initiated by the server.
582 if (RespondToClosingHandshake() == CHANNEL_DELETED) {
583 return CHANNEL_DELETED;
584 }
585 }
586
587 // TODO(crbug.com/999235): Remove this CHECK.
588 CHECK(event_interface_);
589 while (!event_interface_->HasPendingDataFrames()) {
590 DCHECK(stream_);
591 // This use of base::Unretained is safe because this object owns the
592 // WebSocketStream, and any pending reads will be cancelled when it is
593 // destroyed.
594 const int result = stream_->ReadFrames(
595 &read_frames_,
596 base::BindOnce(base::IgnoreResult(&WebSocketChannel::OnReadDone),
597 base::Unretained(this), false));
598 if (result == ERR_IO_PENDING) {
599 is_reading_ = true;
600 return CHANNEL_ALIVE;
601 }
602 if (OnReadDone(true, result) == CHANNEL_DELETED) {
603 return CHANNEL_DELETED;
604 }
605 DCHECK_NE(CLOSED, state_);
606 // TODO(crbug.com/999235): Remove this CHECK.
607 CHECK(event_interface_);
608 }
609 return CHANNEL_ALIVE;
610 }
611
OnReadDone(bool synchronous,int result)612 ChannelState WebSocketChannel::OnReadDone(bool synchronous, int result) {
613 DVLOG(3) << "WebSocketChannel::OnReadDone synchronous?" << synchronous
614 << ", result=" << result
615 << ", read_frames_.size=" << read_frames_.size();
616 DCHECK_NE(FRESHLY_CONSTRUCTED, state_);
617 DCHECK_NE(CONNECTING, state_);
618 DCHECK_NE(ERR_IO_PENDING, result);
619 switch (result) {
620 case OK:
621 // ReadFrames() must use ERR_CONNECTION_CLOSED for a closed connection
622 // with no data read, not an empty response.
623 DCHECK(!read_frames_.empty())
624 << "ReadFrames() returned OK, but nothing was read.";
625 for (auto& read_frame : read_frames_) {
626 if (HandleFrame(std::move(read_frame)) == CHANNEL_DELETED)
627 return CHANNEL_DELETED;
628 }
629 read_frames_.clear();
630 DCHECK_NE(CLOSED, state_);
631 if (!synchronous) {
632 is_reading_ = false;
633 if (!event_interface_->HasPendingDataFrames()) {
634 return ReadFrames();
635 }
636 }
637 return CHANNEL_ALIVE;
638
639 case ERR_WS_PROTOCOL_ERROR:
640 // This could be kWebSocketErrorProtocolError (specifically, non-minimal
641 // encoding of payload length) or kWebSocketErrorMessageTooBig, or an
642 // extension-specific error.
643 FailChannel("Invalid frame header", kWebSocketErrorProtocolError,
644 "WebSocket Protocol Error");
645 return CHANNEL_DELETED;
646
647 default:
648 DCHECK_LT(result, 0)
649 << "ReadFrames() should only return OK or ERR_ codes";
650
651 stream_->Close();
652 SetState(CLOSED);
653
654 uint16_t code = kWebSocketErrorAbnormalClosure;
655 std::string reason = "";
656 bool was_clean = false;
657 if (has_received_close_frame_) {
658 code = received_close_code_;
659 reason = received_close_reason_;
660 was_clean = (result == ERR_CONNECTION_CLOSED);
661 }
662
663 DoDropChannel(was_clean, code, reason);
664 return CHANNEL_DELETED;
665 }
666 }
667
HandleFrame(std::unique_ptr<WebSocketFrame> frame)668 ChannelState WebSocketChannel::HandleFrame(
669 std::unique_ptr<WebSocketFrame> frame) {
670 if (frame->header.masked) {
671 // RFC6455 Section 5.1 "A client MUST close a connection if it detects a
672 // masked frame."
673 FailChannel(
674 "A server must not mask any frames that it sends to the "
675 "client.",
676 kWebSocketErrorProtocolError, "Masked frame from server");
677 return CHANNEL_DELETED;
678 }
679 const WebSocketFrameHeader::OpCode opcode = frame->header.opcode;
680 DCHECK(!WebSocketFrameHeader::IsKnownControlOpCode(opcode) ||
681 frame->header.final);
682 if (frame->header.reserved1 || frame->header.reserved2 ||
683 frame->header.reserved3) {
684 FailChannel(
685 base::StringPrintf("One or more reserved bits are on: reserved1 = %d, "
686 "reserved2 = %d, reserved3 = %d",
687 static_cast<int>(frame->header.reserved1),
688 static_cast<int>(frame->header.reserved2),
689 static_cast<int>(frame->header.reserved3)),
690 kWebSocketErrorProtocolError, "Invalid reserved bit");
691 return CHANNEL_DELETED;
692 }
693
694 // Respond to the frame appropriately to its type.
695 return HandleFrameByState(
696 opcode, frame->header.final,
697 base::make_span(frame->payload, base::checked_cast<size_t>(
698 frame->header.payload_length)));
699 }
700
HandleFrameByState(const WebSocketFrameHeader::OpCode opcode,bool final,base::span<const char> payload)701 ChannelState WebSocketChannel::HandleFrameByState(
702 const WebSocketFrameHeader::OpCode opcode,
703 bool final,
704 base::span<const char> payload) {
705 DCHECK_NE(RECV_CLOSED, state_)
706 << "HandleFrame() does not support being called re-entrantly from within "
707 "SendClose()";
708 DCHECK_NE(CLOSED, state_);
709 if (state_ == CLOSE_WAIT) {
710 std::string frame_name;
711 GetFrameTypeForOpcode(opcode, &frame_name);
712
713 // FailChannel() won't send another Close frame.
714 FailChannel(frame_name + " received after close",
715 kWebSocketErrorProtocolError, "");
716 return CHANNEL_DELETED;
717 }
718 switch (opcode) {
719 case WebSocketFrameHeader::kOpCodeText: // fall-thru
720 case WebSocketFrameHeader::kOpCodeBinary:
721 case WebSocketFrameHeader::kOpCodeContinuation:
722 return HandleDataFrame(opcode, final, std::move(payload));
723
724 case WebSocketFrameHeader::kOpCodePing:
725 DVLOG(1) << "Got Ping of size " << payload.size();
726 if (state_ == CONNECTED) {
727 auto buffer = base::MakeRefCounted<IOBufferWithSize>(payload.size());
728 base::ranges::copy(payload, buffer->data());
729 return SendFrameInternal(true, WebSocketFrameHeader::kOpCodePong,
730 std::move(buffer), payload.size());
731 }
732 DVLOG(3) << "Ignored ping in state " << state_;
733 return CHANNEL_ALIVE;
734
735 case WebSocketFrameHeader::kOpCodePong:
736 DVLOG(1) << "Got Pong of size " << payload.size();
737 // There is no need to do anything with pong messages.
738 return CHANNEL_ALIVE;
739
740 case WebSocketFrameHeader::kOpCodeClose: {
741 uint16_t code = kWebSocketNormalClosure;
742 std::string reason;
743 std::string message;
744 if (!ParseClose(payload, &code, &reason, &message)) {
745 FailChannel(message, code, reason);
746 return CHANNEL_DELETED;
747 }
748 // TODO(ricea): Find a way to safely log the message from the close
749 // message (escape control codes and so on).
750 return HandleCloseFrame(code, reason);
751 }
752
753 default:
754 FailChannel(base::StringPrintf("Unrecognized frame opcode: %d", opcode),
755 kWebSocketErrorProtocolError, "Unknown opcode");
756 return CHANNEL_DELETED;
757 }
758 }
759
HandleDataFrame(WebSocketFrameHeader::OpCode opcode,bool final,base::span<const char> payload)760 ChannelState WebSocketChannel::HandleDataFrame(
761 WebSocketFrameHeader::OpCode opcode,
762 bool final,
763 base::span<const char> payload) {
764 DVLOG(3) << "WebSocketChannel::HandleDataFrame opcode=" << opcode
765 << ", final?" << final << ", data=" << (void*)payload.data()
766 << ", size=" << payload.size();
767 if (state_ != CONNECTED) {
768 DVLOG(3) << "Ignored data packet received in state " << state_;
769 return CHANNEL_ALIVE;
770 }
771 if (has_received_close_frame_) {
772 DVLOG(3) << "Ignored data packet as we've received a close frame.";
773 return CHANNEL_ALIVE;
774 }
775 DCHECK(opcode == WebSocketFrameHeader::kOpCodeContinuation ||
776 opcode == WebSocketFrameHeader::kOpCodeText ||
777 opcode == WebSocketFrameHeader::kOpCodeBinary);
778 const bool got_continuation =
779 (opcode == WebSocketFrameHeader::kOpCodeContinuation);
780 if (got_continuation != expecting_to_handle_continuation_) {
781 const std::string console_log = got_continuation
782 ? "Received unexpected continuation frame."
783 : "Received start of new message but previous message is unfinished.";
784 const std::string reason = got_continuation
785 ? "Unexpected continuation"
786 : "Previous data frame unfinished";
787 FailChannel(console_log, kWebSocketErrorProtocolError, reason);
788 return CHANNEL_DELETED;
789 }
790 expecting_to_handle_continuation_ = !final;
791 WebSocketFrameHeader::OpCode opcode_to_send = opcode;
792 if (!initial_frame_forwarded_ &&
793 opcode == WebSocketFrameHeader::kOpCodeContinuation) {
794 opcode_to_send = receiving_text_message_
795 ? WebSocketFrameHeader::kOpCodeText
796 : WebSocketFrameHeader::kOpCodeBinary;
797 }
798 if (opcode == WebSocketFrameHeader::kOpCodeText ||
799 (opcode == WebSocketFrameHeader::kOpCodeContinuation &&
800 receiving_text_message_)) {
801 // This call is not redundant when size == 0 because it tells us what
802 // the current state is.
803 StreamingUtf8Validator::State state =
804 incoming_utf8_validator_.AddBytes(base::as_byte_span(payload));
805 if (state == StreamingUtf8Validator::INVALID ||
806 (state == StreamingUtf8Validator::VALID_MIDPOINT && final)) {
807 FailChannel("Could not decode a text frame as UTF-8.",
808 kWebSocketErrorProtocolError, "Invalid UTF-8 in text frame");
809 return CHANNEL_DELETED;
810 }
811 receiving_text_message_ = !final;
812 DCHECK(!final || state == StreamingUtf8Validator::VALID_ENDPOINT);
813 }
814 if (payload.size() == 0U && !final)
815 return CHANNEL_ALIVE;
816
817 initial_frame_forwarded_ = !final;
818 // Sends the received frame to the renderer process.
819 event_interface_->OnDataFrame(final, opcode_to_send, payload);
820 return CHANNEL_ALIVE;
821 }
822
HandleCloseFrame(uint16_t code,const std::string & reason)823 ChannelState WebSocketChannel::HandleCloseFrame(uint16_t code,
824 const std::string& reason) {
825 DVLOG(1) << "Got Close with code " << code;
826 switch (state_) {
827 case CONNECTED:
828 has_received_close_frame_ = true;
829 received_close_code_ = code;
830 received_close_reason_ = reason;
831 if (event_interface_->HasPendingDataFrames()) {
832 // We have some data to be sent to the renderer before sending this
833 // frame.
834 return CHANNEL_ALIVE;
835 }
836 return RespondToClosingHandshake();
837
838 case SEND_CLOSED:
839 SetState(CLOSE_WAIT);
840 DCHECK(close_timer_.IsRunning());
841 close_timer_.Stop();
842 // This use of base::Unretained() is safe because we stop the timer
843 // in the destructor.
844 close_timer_.Start(FROM_HERE, underlying_connection_close_timeout_,
845 base::BindOnce(&WebSocketChannel::CloseTimeout,
846 base::Unretained(this)));
847
848 // From RFC6455 section 7.1.5: "Each endpoint
849 // will see the status code sent by the other end as _The WebSocket
850 // Connection Close Code_."
851 has_received_close_frame_ = true;
852 received_close_code_ = code;
853 received_close_reason_ = reason;
854 break;
855
856 default:
857 LOG(DFATAL) << "Got Close in unexpected state " << state_;
858 break;
859 }
860 return CHANNEL_ALIVE;
861 }
862
RespondToClosingHandshake()863 ChannelState WebSocketChannel::RespondToClosingHandshake() {
864 DCHECK(has_received_close_frame_);
865 DCHECK_EQ(CONNECTED, state_);
866 SetState(RECV_CLOSED);
867 if (SendClose(received_close_code_, received_close_reason_) ==
868 CHANNEL_DELETED)
869 return CHANNEL_DELETED;
870 DCHECK_EQ(RECV_CLOSED, state_);
871
872 SetState(CLOSE_WAIT);
873 DCHECK(!close_timer_.IsRunning());
874 // This use of base::Unretained() is safe because we stop the timer
875 // in the destructor.
876 close_timer_.Start(
877 FROM_HERE, underlying_connection_close_timeout_,
878 base::BindOnce(&WebSocketChannel::CloseTimeout, base::Unretained(this)));
879
880 event_interface_->OnClosingHandshake();
881 return CHANNEL_ALIVE;
882 }
883
SendFrameInternal(bool fin,WebSocketFrameHeader::OpCode op_code,scoped_refptr<IOBuffer> buffer,uint64_t buffer_size)884 ChannelState WebSocketChannel::SendFrameInternal(
885 bool fin,
886 WebSocketFrameHeader::OpCode op_code,
887 scoped_refptr<IOBuffer> buffer,
888 uint64_t buffer_size) {
889 DCHECK(state_ == CONNECTED || state_ == RECV_CLOSED);
890 DCHECK(stream_);
891
892 auto frame = std::make_unique<WebSocketFrame>(op_code);
893 WebSocketFrameHeader& header = frame->header;
894 header.final = fin;
895 header.masked = true;
896 header.payload_length = buffer_size;
897 frame->payload = buffer->data();
898
899 if (data_being_sent_) {
900 // Either the link to the WebSocket server is saturated, or several messages
901 // are being sent in a batch.
902 if (!data_to_send_next_)
903 data_to_send_next_ = std::make_unique<SendBuffer>();
904 data_to_send_next_->AddFrame(std::move(frame), std::move(buffer));
905 return CHANNEL_ALIVE;
906 }
907
908 data_being_sent_ = std::make_unique<SendBuffer>();
909 data_being_sent_->AddFrame(std::move(frame), std::move(buffer));
910 return WriteFrames();
911 }
912
FailChannel(const std::string & message,uint16_t code,const std::string & reason)913 void WebSocketChannel::FailChannel(const std::string& message,
914 uint16_t code,
915 const std::string& reason) {
916 DCHECK_NE(FRESHLY_CONSTRUCTED, state_);
917 DCHECK_NE(CONNECTING, state_);
918 DCHECK_NE(CLOSED, state_);
919
920 stream_->GetNetLogWithSource().AddEvent(
921 net::NetLogEventType::WEBSOCKET_INVALID_FRAME,
922 [&] { return NetLogFailParam(code, reason, message); });
923
924 if (state_ == CONNECTED) {
925 if (SendClose(code, reason) == CHANNEL_DELETED)
926 return;
927 }
928
929 // Careful study of RFC6455 section 7.1.7 and 7.1.1 indicates the browser
930 // should close the connection itself without waiting for the closing
931 // handshake.
932 stream_->Close();
933 SetState(CLOSED);
934 event_interface_->OnFailChannel(message, ERR_FAILED, absl::nullopt);
935 }
936
SendClose(uint16_t code,const std::string & reason)937 ChannelState WebSocketChannel::SendClose(uint16_t code,
938 const std::string& reason) {
939 DCHECK(state_ == CONNECTED || state_ == RECV_CLOSED);
940 DCHECK_LE(reason.size(), kMaximumCloseReasonLength);
941 scoped_refptr<IOBuffer> body;
942 uint64_t size = 0;
943 if (code == kWebSocketErrorNoStatusReceived) {
944 // Special case: translate kWebSocketErrorNoStatusReceived into a Close
945 // frame with no payload.
946 DCHECK(reason.empty());
947 body = base::MakeRefCounted<IOBufferWithSize>();
948 } else {
949 const size_t payload_length = kWebSocketCloseCodeLength + reason.length();
950 body = base::MakeRefCounted<IOBufferWithSize>(payload_length);
951 size = payload_length;
952 base::WriteBigEndian(body->data(), code);
953 static_assert(sizeof(code) == kWebSocketCloseCodeLength,
954 "they should both be two");
955 base::ranges::copy(reason, body->data() + kWebSocketCloseCodeLength);
956 }
957
958 return SendFrameInternal(true, WebSocketFrameHeader::kOpCodeClose,
959 std::move(body), size);
960 }
961
ParseClose(base::span<const char> payload,uint16_t * code,std::string * reason,std::string * message)962 bool WebSocketChannel::ParseClose(base::span<const char> payload,
963 uint16_t* code,
964 std::string* reason,
965 std::string* message) {
966 const uint64_t size = static_cast<uint64_t>(payload.size());
967 reason->clear();
968 if (size < kWebSocketCloseCodeLength) {
969 if (size == 0U) {
970 *code = kWebSocketErrorNoStatusReceived;
971 return true;
972 }
973
974 DVLOG(1) << "Close frame with payload size " << size << " received "
975 << "(the first byte is " << std::hex
976 << static_cast<int>(payload[0]) << ")";
977 *code = kWebSocketErrorProtocolError;
978 *message =
979 "Received a broken close frame containing an invalid size body.";
980 return false;
981 }
982
983 const char* data = payload.data();
984 uint16_t unchecked_code = 0;
985 base::ReadBigEndian(reinterpret_cast<const uint8_t*>(data), &unchecked_code);
986 static_assert(sizeof(unchecked_code) == kWebSocketCloseCodeLength,
987 "they should both be two bytes");
988
989 switch (unchecked_code) {
990 case kWebSocketErrorNoStatusReceived:
991 case kWebSocketErrorAbnormalClosure:
992 case kWebSocketErrorTlsHandshake:
993 *code = kWebSocketErrorProtocolError;
994 *message =
995 "Received a broken close frame containing a reserved status code.";
996 return false;
997
998 default:
999 *code = unchecked_code;
1000 break;
1001 }
1002
1003 std::string text(data + kWebSocketCloseCodeLength, data + size);
1004 if (StreamingUtf8Validator::Validate(text)) {
1005 reason->swap(text);
1006 return true;
1007 }
1008
1009 *code = kWebSocketErrorProtocolError;
1010 *reason = "Invalid UTF-8 in Close frame";
1011 *message = "Received a broken close frame containing invalid UTF-8.";
1012 return false;
1013 }
1014
DoDropChannel(bool was_clean,uint16_t code,const std::string & reason)1015 void WebSocketChannel::DoDropChannel(bool was_clean,
1016 uint16_t code,
1017 const std::string& reason) {
1018 event_interface_->OnDropChannel(was_clean, code, reason);
1019 }
1020
CloseTimeout()1021 void WebSocketChannel::CloseTimeout() {
1022 stream_->GetNetLogWithSource().AddEvent(
1023 net::NetLogEventType::WEBSOCKET_CLOSE_TIMEOUT);
1024 stream_->Close();
1025 SetState(CLOSED);
1026 if (has_received_close_frame_) {
1027 DoDropChannel(true, received_close_code_, received_close_reason_);
1028 } else {
1029 DoDropChannel(false, kWebSocketErrorAbnormalClosure, "");
1030 }
1031 // |this| has been deleted.
1032 }
1033
1034 } // namespace net
1035