1 // Copyright 2018 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/websockets/websocket_handshake_stream_base.h"
6
7 #include <stddef.h>
8
9 #include "base/metrics/histogram_functions.h"
10 #include "base/metrics/histogram_macros.h"
11 #include "base/strings/strcat.h"
12 #include "base/strings/string_util.h"
13 #include "net/http/http_request_headers.h"
14 #include "net/http/http_response_headers.h"
15 #include "net/websockets/websocket_extension.h"
16 #include "net/websockets/websocket_extension_parser.h"
17 #include "net/websockets/websocket_handshake_constants.h"
18
19 namespace net {
20
21 namespace {
22
AddVectorHeaderIfNonEmpty(const char * name,const std::vector<std::string> & value,HttpRequestHeaders * headers)23 size_t AddVectorHeaderIfNonEmpty(const char* name,
24 const std::vector<std::string>& value,
25 HttpRequestHeaders* headers) {
26 if (value.empty()) {
27 return 0u;
28 }
29 std::string joined = base::JoinString(value, ", ");
30 const size_t size = joined.size();
31 headers->SetHeader(name, std::move(joined));
32 return size;
33 }
34
35 } // namespace
36
37 // static
MultipleHeaderValuesMessage(const std::string & header_name)38 std::string WebSocketHandshakeStreamBase::MultipleHeaderValuesMessage(
39 const std::string& header_name) {
40 return base::StrCat(
41 {"'", header_name,
42 "' header must not appear more than once in a response"});
43 }
44
45 // static
AddVectorHeaders(const std::vector<std::string> & extensions,const std::vector<std::string> & protocols,HttpRequestHeaders * headers)46 void WebSocketHandshakeStreamBase::AddVectorHeaders(
47 const std::vector<std::string>& extensions,
48 const std::vector<std::string>& protocols,
49 HttpRequestHeaders* headers) {
50 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions, extensions,
51 headers);
52 const size_t protocol_header_size = AddVectorHeaderIfNonEmpty(
53 websockets::kSecWebSocketProtocol, protocols, headers);
54 base::UmaHistogramCounts10000("Net.WebSocket.ProtocolHeaderSize",
55 protocol_header_size);
56 }
57
58 // static
ValidateSubProtocol(const HttpResponseHeaders * headers,const std::vector<std::string> & requested_sub_protocols,std::string * sub_protocol,std::string * failure_message)59 bool WebSocketHandshakeStreamBase::ValidateSubProtocol(
60 const HttpResponseHeaders* headers,
61 const std::vector<std::string>& requested_sub_protocols,
62 std::string* sub_protocol,
63 std::string* failure_message) {
64 size_t iter = 0;
65 std::optional<std::string> value;
66 while (std::optional<std::string_view> maybe_value = headers->EnumerateHeader(
67 &iter, websockets::kSecWebSocketProtocol)) {
68 if (value) {
69 *failure_message =
70 MultipleHeaderValuesMessage(websockets::kSecWebSocketProtocol);
71 return false;
72 }
73 if (requested_sub_protocols.empty()) {
74 *failure_message =
75 base::StrCat({"Response must not include 'Sec-WebSocket-Protocol' "
76 "header if not present in request: ",
77 *maybe_value});
78 return false;
79 }
80 auto it = std::ranges::find(requested_sub_protocols, *maybe_value);
81 if (it == requested_sub_protocols.end()) {
82 *failure_message =
83 base::StrCat({"'Sec-WebSocket-Protocol' header value '", *maybe_value,
84 "' in response does not match any of sent values"});
85 return false;
86 }
87 value = *maybe_value;
88 }
89
90 if (!requested_sub_protocols.empty() && !value.has_value()) {
91 *failure_message =
92 "Sent non-empty 'Sec-WebSocket-Protocol' header "
93 "but no response was received";
94 return false;
95 }
96 if (value) {
97 *sub_protocol = *value;
98 } else {
99 sub_protocol->clear();
100 }
101 return true;
102 }
103
104 // static
ValidateExtensions(const HttpResponseHeaders * headers,std::string * accepted_extensions_descriptor,std::string * failure_message,WebSocketExtensionParams * params)105 bool WebSocketHandshakeStreamBase::ValidateExtensions(
106 const HttpResponseHeaders* headers,
107 std::string* accepted_extensions_descriptor,
108 std::string* failure_message,
109 WebSocketExtensionParams* params) {
110 size_t iter = 0;
111 std::vector<std::string> header_values;
112 // TODO(ricea): If adding support for additional extensions, generalise this
113 // code.
114 bool seen_permessage_deflate = false;
115 while (std::optional<std::string_view> header_value =
116 headers->EnumerateHeader(&iter,
117 websockets::kSecWebSocketExtensions)) {
118 WebSocketExtensionParser parser;
119 if (!parser.Parse(*header_value)) {
120 // TODO(yhirano) Set appropriate failure message.
121 *failure_message =
122 base::StrCat({"'Sec-WebSocket-Extensions' header value is "
123 "rejected by the parser: ",
124 *header_value});
125 return false;
126 }
127
128 const std::vector<WebSocketExtension>& extensions = parser.extensions();
129 for (const auto& extension : extensions) {
130 if (extension.name() == "permessage-deflate") {
131 if (seen_permessage_deflate) {
132 *failure_message = "Received duplicate permessage-deflate response";
133 return false;
134 }
135 seen_permessage_deflate = true;
136 auto& deflate_parameters = params->deflate_parameters;
137 if (!deflate_parameters.Initialize(extension, failure_message) ||
138 !deflate_parameters.IsValidAsResponse(failure_message)) {
139 *failure_message = "Error in permessage-deflate: " + *failure_message;
140 return false;
141 }
142 // Note that we don't have to check the request-response compatibility
143 // here because we send a request compatible with any valid responses.
144 // TODO(yhirano): Place a DCHECK here.
145
146 header_values.emplace_back(*header_value);
147 } else {
148 *failure_message = "Found an unsupported extension '" +
149 extension.name() +
150 "' in 'Sec-WebSocket-Extensions' header";
151 return false;
152 }
153 }
154 }
155 *accepted_extensions_descriptor = base::JoinString(header_values, ", ");
156 params->deflate_enabled = seen_permessage_deflate;
157 return true;
158 }
159
RecordHandshakeResult(HandshakeResult result)160 void WebSocketHandshakeStreamBase::RecordHandshakeResult(
161 HandshakeResult result) {
162 UMA_HISTOGRAM_ENUMERATION("Net.WebSocket.HandshakeResult2", result,
163 HandshakeResult::NUM_HANDSHAKE_RESULT_TYPES);
164 }
165
166 } // namespace net
167