• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2018 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_handshake_stream_base.h"
6 
7 #include <stddef.h>
8 
9 #include "base/metrics/histogram_functions.h"
10 #include "base/metrics/histogram_macros.h"
11 #include "base/strings/strcat.h"
12 #include "base/strings/string_util.h"
13 #include "net/http/http_request_headers.h"
14 #include "net/http/http_response_headers.h"
15 #include "net/websockets/websocket_extension.h"
16 #include "net/websockets/websocket_extension_parser.h"
17 #include "net/websockets/websocket_handshake_constants.h"
18 
19 namespace net {
20 
21 namespace {
22 
AddVectorHeaderIfNonEmpty(const char * name,const std::vector<std::string> & value,HttpRequestHeaders * headers)23 size_t AddVectorHeaderIfNonEmpty(const char* name,
24                                  const std::vector<std::string>& value,
25                                  HttpRequestHeaders* headers) {
26   if (value.empty()) {
27     return 0u;
28   }
29   std::string joined = base::JoinString(value, ", ");
30   const size_t size = joined.size();
31   headers->SetHeader(name, std::move(joined));
32   return size;
33 }
34 
35 }  // namespace
36 
37 // static
MultipleHeaderValuesMessage(const std::string & header_name)38 std::string WebSocketHandshakeStreamBase::MultipleHeaderValuesMessage(
39     const std::string& header_name) {
40   return base::StrCat(
41       {"'", header_name,
42        "' header must not appear more than once in a response"});
43 }
44 
45 // static
AddVectorHeaders(const std::vector<std::string> & extensions,const std::vector<std::string> & protocols,HttpRequestHeaders * headers)46 void WebSocketHandshakeStreamBase::AddVectorHeaders(
47     const std::vector<std::string>& extensions,
48     const std::vector<std::string>& protocols,
49     HttpRequestHeaders* headers) {
50   AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions, extensions,
51                             headers);
52   const size_t protocol_header_size = AddVectorHeaderIfNonEmpty(
53       websockets::kSecWebSocketProtocol, protocols, headers);
54   base::UmaHistogramCounts10000("Net.WebSocket.ProtocolHeaderSize",
55                                 protocol_header_size);
56 }
57 
58 // static
ValidateSubProtocol(const HttpResponseHeaders * headers,const std::vector<std::string> & requested_sub_protocols,std::string * sub_protocol,std::string * failure_message)59 bool WebSocketHandshakeStreamBase::ValidateSubProtocol(
60     const HttpResponseHeaders* headers,
61     const std::vector<std::string>& requested_sub_protocols,
62     std::string* sub_protocol,
63     std::string* failure_message) {
64   size_t iter = 0;
65   std::optional<std::string> value;
66   while (std::optional<std::string_view> maybe_value = headers->EnumerateHeader(
67              &iter, websockets::kSecWebSocketProtocol)) {
68     if (value) {
69       *failure_message =
70           MultipleHeaderValuesMessage(websockets::kSecWebSocketProtocol);
71       return false;
72     }
73     if (requested_sub_protocols.empty()) {
74       *failure_message =
75           base::StrCat({"Response must not include 'Sec-WebSocket-Protocol' "
76                         "header if not present in request: ",
77                         *maybe_value});
78       return false;
79     }
80     auto it = std::ranges::find(requested_sub_protocols, *maybe_value);
81     if (it == requested_sub_protocols.end()) {
82       *failure_message =
83           base::StrCat({"'Sec-WebSocket-Protocol' header value '", *maybe_value,
84                         "' in response does not match any of sent values"});
85       return false;
86     }
87     value = *maybe_value;
88   }
89 
90   if (!requested_sub_protocols.empty() && !value.has_value()) {
91     *failure_message =
92         "Sent non-empty 'Sec-WebSocket-Protocol' header "
93         "but no response was received";
94     return false;
95   }
96   if (value) {
97     *sub_protocol = *value;
98   } else {
99     sub_protocol->clear();
100   }
101   return true;
102 }
103 
104 // static
ValidateExtensions(const HttpResponseHeaders * headers,std::string * accepted_extensions_descriptor,std::string * failure_message,WebSocketExtensionParams * params)105 bool WebSocketHandshakeStreamBase::ValidateExtensions(
106     const HttpResponseHeaders* headers,
107     std::string* accepted_extensions_descriptor,
108     std::string* failure_message,
109     WebSocketExtensionParams* params) {
110   size_t iter = 0;
111   std::vector<std::string> header_values;
112   // TODO(ricea): If adding support for additional extensions, generalise this
113   // code.
114   bool seen_permessage_deflate = false;
115   while (std::optional<std::string_view> header_value =
116              headers->EnumerateHeader(&iter,
117                                       websockets::kSecWebSocketExtensions)) {
118     WebSocketExtensionParser parser;
119     if (!parser.Parse(*header_value)) {
120       // TODO(yhirano) Set appropriate failure message.
121       *failure_message =
122           base::StrCat({"'Sec-WebSocket-Extensions' header value is "
123                         "rejected by the parser: ",
124                         *header_value});
125       return false;
126     }
127 
128     const std::vector<WebSocketExtension>& extensions = parser.extensions();
129     for (const auto& extension : extensions) {
130       if (extension.name() == "permessage-deflate") {
131         if (seen_permessage_deflate) {
132           *failure_message = "Received duplicate permessage-deflate response";
133           return false;
134         }
135         seen_permessage_deflate = true;
136         auto& deflate_parameters = params->deflate_parameters;
137         if (!deflate_parameters.Initialize(extension, failure_message) ||
138             !deflate_parameters.IsValidAsResponse(failure_message)) {
139           *failure_message = "Error in permessage-deflate: " + *failure_message;
140           return false;
141         }
142         // Note that we don't have to check the request-response compatibility
143         // here because we send a request compatible with any valid responses.
144         // TODO(yhirano): Place a DCHECK here.
145 
146         header_values.emplace_back(*header_value);
147       } else {
148         *failure_message = "Found an unsupported extension '" +
149                            extension.name() +
150                            "' in 'Sec-WebSocket-Extensions' header";
151         return false;
152       }
153     }
154   }
155   *accepted_extensions_descriptor = base::JoinString(header_values, ", ");
156   params->deflate_enabled = seen_permessage_deflate;
157   return true;
158 }
159 
RecordHandshakeResult(HandshakeResult result)160 void WebSocketHandshakeStreamBase::RecordHandshakeResult(
161     HandshakeResult result) {
162   UMA_HISTOGRAM_ENUMERATION("Net.WebSocket.HandshakeResult2", result,
163                             HandshakeResult::NUM_HANDSHAKE_RESULT_TYPES);
164 }
165 
166 }  // namespace net
167