1 // Copyright 2018 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/websockets/websocket_http2_handshake_stream.h"
6
7 #include <cstddef>
8 #include <set>
9 #include <utility>
10
11 #include "base/check_op.h"
12 #include "base/functional/bind.h"
13 #include "base/notreached.h"
14 #include "base/strings/stringprintf.h"
15 #include "base/time/time.h"
16 #include "net/base/ip_endpoint.h"
17 #include "net/http/http_request_headers.h"
18 #include "net/http/http_request_info.h"
19 #include "net/http/http_response_headers.h"
20 #include "net/http/http_status_code.h"
21 #include "net/spdy/spdy_http_utils.h"
22 #include "net/spdy/spdy_session.h"
23 #include "net/traffic_annotation/network_traffic_annotation.h"
24 #include "net/websockets/websocket_basic_stream.h"
25 #include "net/websockets/websocket_deflate_parameters.h"
26 #include "net/websockets/websocket_deflate_predictor_impl.h"
27 #include "net/websockets/websocket_deflate_stream.h"
28 #include "net/websockets/websocket_deflater.h"
29 #include "net/websockets/websocket_handshake_constants.h"
30 #include "net/websockets/websocket_handshake_request_info.h"
31
32 namespace net {
33
34 namespace {
35
ValidateStatus(const HttpResponseHeaders * headers)36 bool ValidateStatus(const HttpResponseHeaders* headers) {
37 return headers->GetStatusLine() == "HTTP/1.1 200";
38 }
39
40 } // namespace
41
WebSocketHttp2HandshakeStream(base::WeakPtr<SpdySession> session,WebSocketStream::ConnectDelegate * connect_delegate,std::vector<std::string> requested_sub_protocols,std::vector<std::string> requested_extensions,WebSocketStreamRequestAPI * request,std::set<std::string> dns_aliases)42 WebSocketHttp2HandshakeStream::WebSocketHttp2HandshakeStream(
43 base::WeakPtr<SpdySession> session,
44 WebSocketStream::ConnectDelegate* connect_delegate,
45 std::vector<std::string> requested_sub_protocols,
46 std::vector<std::string> requested_extensions,
47 WebSocketStreamRequestAPI* request,
48 std::set<std::string> dns_aliases)
49 : session_(session),
50 connect_delegate_(connect_delegate),
51 requested_sub_protocols_(requested_sub_protocols),
52 requested_extensions_(requested_extensions),
53 stream_request_(request),
54 dns_aliases_(std::move(dns_aliases)) {
55 DCHECK(connect_delegate);
56 DCHECK(request);
57 }
58
~WebSocketHttp2HandshakeStream()59 WebSocketHttp2HandshakeStream::~WebSocketHttp2HandshakeStream() {
60 spdy_stream_request_.reset();
61 RecordHandshakeResult(result_);
62 }
63
RegisterRequest(const HttpRequestInfo * request_info)64 void WebSocketHttp2HandshakeStream::RegisterRequest(
65 const HttpRequestInfo* request_info) {
66 DCHECK(request_info);
67 DCHECK(request_info->traffic_annotation.is_valid());
68 request_info_ = request_info;
69 }
70
InitializeStream(bool can_send_early,RequestPriority priority,const NetLogWithSource & net_log,CompletionOnceCallback callback)71 int WebSocketHttp2HandshakeStream::InitializeStream(
72 bool can_send_early,
73 RequestPriority priority,
74 const NetLogWithSource& net_log,
75 CompletionOnceCallback callback) {
76 priority_ = priority;
77 net_log_ = net_log;
78 return OK;
79 }
80
SendRequest(const HttpRequestHeaders & headers,HttpResponseInfo * response,CompletionOnceCallback callback)81 int WebSocketHttp2HandshakeStream::SendRequest(
82 const HttpRequestHeaders& headers,
83 HttpResponseInfo* response,
84 CompletionOnceCallback callback) {
85 DCHECK(!headers.HasHeader(websockets::kSecWebSocketKey));
86 DCHECK(!headers.HasHeader(websockets::kSecWebSocketProtocol));
87 DCHECK(!headers.HasHeader(websockets::kSecWebSocketExtensions));
88 DCHECK(headers.HasHeader(HttpRequestHeaders::kOrigin));
89 DCHECK(headers.HasHeader(websockets::kUpgrade));
90 DCHECK(headers.HasHeader(HttpRequestHeaders::kConnection));
91 DCHECK(headers.HasHeader(websockets::kSecWebSocketVersion));
92
93 if (!session_) {
94 const int rv = ERR_CONNECTION_CLOSED;
95 OnFailure("Connection closed before sending request.", rv, absl::nullopt);
96 return rv;
97 }
98
99 http_response_info_ = response;
100
101 IPEndPoint address;
102 int result = session_->GetPeerAddress(&address);
103 if (result != OK) {
104 OnFailure("Error getting IP address.", result, absl::nullopt);
105 return result;
106 }
107 http_response_info_->remote_endpoint = address;
108
109 auto request = std::make_unique<WebSocketHandshakeRequestInfo>(
110 request_info_->url, base::Time::Now());
111 request->headers.CopyFrom(headers);
112
113 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions,
114 requested_extensions_, &request->headers);
115 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol,
116 requested_sub_protocols_, &request->headers);
117
118 CreateSpdyHeadersFromHttpRequestForWebSocket(
119 request_info_->url, request->headers, &http2_request_headers_);
120
121 connect_delegate_->OnStartOpeningHandshake(std::move(request));
122
123 callback_ = std::move(callback);
124 spdy_stream_request_ = std::make_unique<SpdyStreamRequest>();
125 // The initial request for the WebSocket is a CONNECT, so there is no need to
126 // call ConfirmHandshake().
127 int rv = spdy_stream_request_->StartRequest(
128 SPDY_BIDIRECTIONAL_STREAM, session_, request_info_->url, true, priority_,
129 request_info_->socket_tag, net_log_,
130 base::BindOnce(&WebSocketHttp2HandshakeStream::StartRequestCallback,
131 base::Unretained(this)),
132 NetworkTrafficAnnotationTag(request_info_->traffic_annotation));
133 if (rv == OK) {
134 StartRequestCallback(rv);
135 return ERR_IO_PENDING;
136 }
137 return rv;
138 }
139
ReadResponseHeaders(CompletionOnceCallback callback)140 int WebSocketHttp2HandshakeStream::ReadResponseHeaders(
141 CompletionOnceCallback callback) {
142 if (stream_closed_)
143 return stream_error_;
144
145 if (response_headers_complete_)
146 return ValidateResponse();
147
148 callback_ = std::move(callback);
149 return ERR_IO_PENDING;
150 }
151
ReadResponseBody(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)152 int WebSocketHttp2HandshakeStream::ReadResponseBody(
153 IOBuffer* buf,
154 int buf_len,
155 CompletionOnceCallback callback) {
156 // Callers should instead call Upgrade() to get a WebSocketStream
157 // and call ReadFrames() on that.
158 NOTREACHED();
159 return OK;
160 }
161
Close(bool not_reusable)162 void WebSocketHttp2HandshakeStream::Close(bool not_reusable) {
163 spdy_stream_request_.reset();
164 if (stream_) {
165 stream_ = nullptr;
166 stream_closed_ = true;
167 stream_error_ = ERR_CONNECTION_CLOSED;
168 }
169 stream_adapter_.reset();
170 }
171
IsResponseBodyComplete() const172 bool WebSocketHttp2HandshakeStream::IsResponseBodyComplete() const {
173 return false;
174 }
175
IsConnectionReused() const176 bool WebSocketHttp2HandshakeStream::IsConnectionReused() const {
177 return true;
178 }
179
SetConnectionReused()180 void WebSocketHttp2HandshakeStream::SetConnectionReused() {}
181
CanReuseConnection() const182 bool WebSocketHttp2HandshakeStream::CanReuseConnection() const {
183 return false;
184 }
185
GetTotalReceivedBytes() const186 int64_t WebSocketHttp2HandshakeStream::GetTotalReceivedBytes() const {
187 return stream_ ? stream_->raw_received_bytes() : 0;
188 }
189
GetTotalSentBytes() const190 int64_t WebSocketHttp2HandshakeStream::GetTotalSentBytes() const {
191 return stream_ ? stream_->raw_sent_bytes() : 0;
192 }
193
GetAlternativeService(AlternativeService * alternative_service) const194 bool WebSocketHttp2HandshakeStream::GetAlternativeService(
195 AlternativeService* alternative_service) const {
196 return false;
197 }
198
GetLoadTimingInfo(LoadTimingInfo * load_timing_info) const199 bool WebSocketHttp2HandshakeStream::GetLoadTimingInfo(
200 LoadTimingInfo* load_timing_info) const {
201 return stream_ && stream_->GetLoadTimingInfo(load_timing_info);
202 }
203
GetSSLInfo(SSLInfo * ssl_info)204 void WebSocketHttp2HandshakeStream::GetSSLInfo(SSLInfo* ssl_info) {
205 if (stream_)
206 stream_->GetSSLInfo(ssl_info);
207 }
208
GetSSLCertRequestInfo(SSLCertRequestInfo * cert_request_info)209 void WebSocketHttp2HandshakeStream::GetSSLCertRequestInfo(
210 SSLCertRequestInfo* cert_request_info) {
211 // A multiplexed stream cannot request client certificates. Client
212 // authentication may only occur during the initial SSL handshake.
213 NOTREACHED();
214 }
215
GetRemoteEndpoint(IPEndPoint * endpoint)216 int WebSocketHttp2HandshakeStream::GetRemoteEndpoint(IPEndPoint* endpoint) {
217 if (!session_)
218 return ERR_SOCKET_NOT_CONNECTED;
219
220 return session_->GetRemoteEndpoint(endpoint);
221 }
222
PopulateNetErrorDetails(NetErrorDetails *)223 void WebSocketHttp2HandshakeStream::PopulateNetErrorDetails(
224 NetErrorDetails* /*details*/) {
225 return;
226 }
227
Drain(HttpNetworkSession * session)228 void WebSocketHttp2HandshakeStream::Drain(HttpNetworkSession* session) {
229 Close(true /* not_reusable */);
230 }
231
SetPriority(RequestPriority priority)232 void WebSocketHttp2HandshakeStream::SetPriority(RequestPriority priority) {
233 priority_ = priority;
234 if (stream_)
235 stream_->SetPriority(priority_);
236 }
237
238 std::unique_ptr<HttpStream>
RenewStreamForAuth()239 WebSocketHttp2HandshakeStream::RenewStreamForAuth() {
240 // Renewing the stream is not supported.
241 return nullptr;
242 }
243
GetDnsAliases() const244 const std::set<std::string>& WebSocketHttp2HandshakeStream::GetDnsAliases()
245 const {
246 return dns_aliases_;
247 }
248
GetAcceptChViaAlps() const249 base::StringPiece WebSocketHttp2HandshakeStream::GetAcceptChViaAlps() const {
250 return {};
251 }
252
Upgrade()253 std::unique_ptr<WebSocketStream> WebSocketHttp2HandshakeStream::Upgrade() {
254 DCHECK(extension_params_.get());
255
256 stream_adapter_->DetachDelegate();
257 std::unique_ptr<WebSocketStream> basic_stream =
258 std::make_unique<WebSocketBasicStream>(std::move(stream_adapter_),
259 nullptr, sub_protocol_,
260 extensions_, net_log_);
261
262 if (!extension_params_->deflate_enabled)
263 return basic_stream;
264
265 return std::make_unique<WebSocketDeflateStream>(
266 std::move(basic_stream), extension_params_->deflate_parameters,
267 std::make_unique<WebSocketDeflatePredictorImpl>());
268 }
269
270 base::WeakPtr<WebSocketHandshakeStreamBase>
GetWeakPtr()271 WebSocketHttp2HandshakeStream::GetWeakPtr() {
272 return weak_ptr_factory_.GetWeakPtr();
273 }
274
OnHeadersSent()275 void WebSocketHttp2HandshakeStream::OnHeadersSent() {
276 std::move(callback_).Run(OK);
277 }
278
OnHeadersReceived(const spdy::Http2HeaderBlock & response_headers)279 void WebSocketHttp2HandshakeStream::OnHeadersReceived(
280 const spdy::Http2HeaderBlock& response_headers) {
281 DCHECK(!response_headers_complete_);
282 DCHECK(http_response_info_);
283
284 response_headers_complete_ = true;
285
286 const int rv =
287 SpdyHeadersToHttpResponse(response_headers, http_response_info_);
288 DCHECK_NE(rv, ERR_INCOMPLETE_HTTP2_HEADERS);
289
290 http_response_info_->response_time = stream_->response_time();
291 // Do not store SSLInfo in the response here, HttpNetworkTransaction will take
292 // care of that part.
293 http_response_info_->was_alpn_negotiated = true;
294 http_response_info_->request_time = stream_->GetRequestTime();
295 http_response_info_->connection_info =
296 HttpResponseInfo::CONNECTION_INFO_HTTP2;
297 http_response_info_->alpn_negotiated_protocol =
298 HttpResponseInfo::ConnectionInfoToString(
299 http_response_info_->connection_info);
300
301 if (callback_)
302 std::move(callback_).Run(ValidateResponse());
303 }
304
OnClose(int status)305 void WebSocketHttp2HandshakeStream::OnClose(int status) {
306 DCHECK(stream_adapter_);
307 DCHECK_GT(ERR_IO_PENDING, status);
308
309 stream_closed_ = true;
310 stream_error_ = status;
311 stream_ = nullptr;
312
313 stream_adapter_.reset();
314
315 // If response headers have already been received,
316 // then ValidateResponse() sets |result_|.
317 if (!response_headers_complete_)
318 result_ = HandshakeResult::HTTP2_FAILED;
319
320 OnFailure(std::string("Stream closed with error: ") + ErrorToString(status),
321 status, absl::nullopt);
322
323 if (callback_)
324 std::move(callback_).Run(status);
325 }
326
StartRequestCallback(int rv)327 void WebSocketHttp2HandshakeStream::StartRequestCallback(int rv) {
328 DCHECK(callback_);
329 if (rv != OK) {
330 spdy_stream_request_.reset();
331 std::move(callback_).Run(rv);
332 return;
333 }
334 stream_ = spdy_stream_request_->ReleaseStream();
335 spdy_stream_request_.reset();
336 stream_adapter_ =
337 std::make_unique<WebSocketSpdyStreamAdapter>(stream_, this, net_log_);
338 rv = stream_->SendRequestHeaders(std::move(http2_request_headers_),
339 MORE_DATA_TO_SEND);
340 // SendRequestHeaders() always returns asynchronously,
341 // and instead of taking a callback, it calls OnHeadersSent().
342 DCHECK_EQ(ERR_IO_PENDING, rv);
343 }
344
ValidateResponse()345 int WebSocketHttp2HandshakeStream::ValidateResponse() {
346 DCHECK(http_response_info_);
347 const HttpResponseHeaders* headers = http_response_info_->headers.get();
348 const int response_code = headers->response_code();
349 switch (response_code) {
350 case HTTP_OK:
351 return ValidateUpgradeResponse(headers);
352
353 // We need to pass these through for authentication to work.
354 case HTTP_UNAUTHORIZED:
355 case HTTP_PROXY_AUTHENTICATION_REQUIRED:
356 return OK;
357
358 // Other status codes are potentially risky (see the warnings in the
359 // WHATWG WebSocket API spec) and so are dropped by default.
360 default:
361 OnFailure(
362 base::StringPrintf(
363 "Error during WebSocket handshake: Unexpected response code: %d",
364 headers->response_code()),
365 ERR_FAILED, headers->response_code());
366 result_ = HandshakeResult::HTTP2_INVALID_STATUS;
367 return ERR_INVALID_RESPONSE;
368 }
369 }
370
ValidateUpgradeResponse(const HttpResponseHeaders * headers)371 int WebSocketHttp2HandshakeStream::ValidateUpgradeResponse(
372 const HttpResponseHeaders* headers) {
373 extension_params_ = std::make_unique<WebSocketExtensionParams>();
374 std::string failure_message;
375 if (!ValidateStatus(headers)) {
376 result_ = HandshakeResult::HTTP2_INVALID_STATUS;
377 } else if (!ValidateSubProtocol(headers, requested_sub_protocols_,
378 &sub_protocol_, &failure_message)) {
379 result_ = HandshakeResult::HTTP2_FAILED_SUBPROTO;
380 } else if (!ValidateExtensions(headers, &extensions_, &failure_message,
381 extension_params_.get())) {
382 result_ = HandshakeResult::HTTP2_FAILED_EXTENSIONS;
383 } else {
384 result_ = HandshakeResult::HTTP2_CONNECTED;
385 return OK;
386 }
387
388 const int rv = ERR_INVALID_RESPONSE;
389 OnFailure("Error during WebSocket handshake: " + failure_message, rv,
390 absl::nullopt);
391 return rv;
392 }
393
OnFailure(const std::string & message,int net_error,absl::optional<int> response_code)394 void WebSocketHttp2HandshakeStream::OnFailure(
395 const std::string& message,
396 int net_error,
397 absl::optional<int> response_code) {
398 stream_request_->OnFailure(message, net_error, response_code);
399 }
400
401 } // namespace net
402