• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/server/web_socket.h"
6 
7 #include <vector>
8 
9 #include "base/base64.h"
10 #include "base/check.h"
11 #include "base/hash/sha1.h"
12 #include "base/strings/string_number_conversions.h"
13 #include "base/strings/stringprintf.h"
14 #include "base/sys_byteorder.h"
15 #include "net/server/http_connection.h"
16 #include "net/server/http_server.h"
17 #include "net/server/http_server_request_info.h"
18 #include "net/server/http_server_response_info.h"
19 #include "net/server/web_socket_encoder.h"
20 #include "net/websockets/websocket_deflate_parameters.h"
21 #include "net/websockets/websocket_extension.h"
22 #include "net/websockets/websocket_handshake_constants.h"
23 
24 namespace net {
25 
26 namespace {
27 
ExtensionsHeaderString(const std::vector<WebSocketExtension> & extensions)28 std::string ExtensionsHeaderString(
29     const std::vector<WebSocketExtension>& extensions) {
30   if (extensions.empty())
31     return std::string();
32 
33   std::string result = "Sec-WebSocket-Extensions: " + extensions[0].ToString();
34   for (size_t i = 1; i < extensions.size(); ++i)
35     result += ", " + extensions[i].ToString();
36   return result + "\r\n";
37 }
38 
ValidResponseString(const std::string & accept_hash,const std::vector<WebSocketExtension> extensions)39 std::string ValidResponseString(
40     const std::string& accept_hash,
41     const std::vector<WebSocketExtension> extensions) {
42   return base::StringPrintf(
43       "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
44       "Upgrade: WebSocket\r\n"
45       "Connection: Upgrade\r\n"
46       "Sec-WebSocket-Accept: %s\r\n"
47       "%s"
48       "\r\n",
49       accept_hash.c_str(), ExtensionsHeaderString(extensions).c_str());
50 }
51 
52 }  // namespace
53 
WebSocket(HttpServer * server,HttpConnection * connection)54 WebSocket::WebSocket(HttpServer* server, HttpConnection* connection)
55     : server_(server), connection_(connection) {}
56 
57 WebSocket::~WebSocket() = default;
58 
Accept(const HttpServerRequestInfo & request,const NetworkTrafficAnnotationTag traffic_annotation)59 void WebSocket::Accept(const HttpServerRequestInfo& request,
60                        const NetworkTrafficAnnotationTag traffic_annotation) {
61   std::string version = request.GetHeaderValue("sec-websocket-version");
62   if (version != "8" && version != "13") {
63     SendErrorResponse("Invalid request format. The version is not valid.",
64                       traffic_annotation);
65     return;
66   }
67 
68   std::string key = request.GetHeaderValue("sec-websocket-key");
69   if (key.empty()) {
70     SendErrorResponse(
71         "Invalid request format. Sec-WebSocket-Key is empty or isn't "
72         "specified.",
73         traffic_annotation);
74     return;
75   }
76   std::string encoded_hash;
77   base::Base64Encode(base::SHA1HashString(key + websockets::kWebSocketGuid),
78                      &encoded_hash);
79 
80   std::vector<WebSocketExtension> response_extensions;
81   auto i = request.headers.find("sec-websocket-extensions");
82   if (i == request.headers.end()) {
83     encoder_ = WebSocketEncoder::CreateServer();
84   } else {
85     WebSocketDeflateParameters params;
86     encoder_ = WebSocketEncoder::CreateServer(i->second, &params);
87     if (!encoder_) {
88       Fail();
89       return;
90     }
91     if (encoder_->deflate_enabled()) {
92       DCHECK(params.IsValidAsResponse());
93       response_extensions.push_back(params.AsExtension());
94     }
95   }
96   server_->SendRaw(connection_->id(),
97                    ValidResponseString(encoded_hash, response_extensions),
98                    traffic_annotation);
99   traffic_annotation_ = std::make_unique<NetworkTrafficAnnotationTag>(
100       NetworkTrafficAnnotationTag(traffic_annotation));
101 }
102 
Read(std::string * message)103 WebSocket::ParseResult WebSocket::Read(std::string* message) {
104   if (closed_)
105     return FRAME_CLOSE;
106 
107   if (!encoder_) {
108     // RFC6455, section 4.1 says "Once the client's opening handshake has been
109     // sent, the client MUST wait for a response from the server before sending
110     // any further data". If |encoder_| is null here, ::Accept either has not
111     // been called at all, or has rejected a request rather than producing
112     // a server handshake. Either way, the client clearly couldn't have gotten
113     // a proper server handshake, so error out, especially since this method
114     // can't proceed without an |encoder_|.
115     return FRAME_ERROR;
116   }
117 
118   ParseResult result = FRAME_OK_MIDDLE;
119   HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf();
120   base::StringPiece frame(read_buf->StartOfBuffer(), read_buf->GetSize());
121   int bytes_consumed = 0;
122   result = encoder_->DecodeFrame(frame, &bytes_consumed, message);
123   read_buf->DidConsume(bytes_consumed);
124 
125   if (result == FRAME_CLOSE) {
126     // The current websocket implementation does not initiate the Close
127     // handshake before closing the connection.
128     // Therefore the received Close frame most likely belongs to the client that
129     // initiated the Closing handshake.
130     // According to https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.1
131     // if an endpoint receives a Close frame and did not previously send a
132     // Close frame, the endpoint MUST send a Close frame in response.
133     // It also MAY provide the close reason listed in
134     // https://datatracker.ietf.org/doc/html/rfc6455#section-7.4.1.
135     // As the closure was initiated by the client the "normal closure" status
136     // code is appropriate.
137     std::string code = "\x03\xe8";  // code = 1000;
138     std::string encoded;
139     encoder_->EncodeCloseFrame(code, 0, &encoded);
140     server_->SendRaw(connection_->id(), encoded, *traffic_annotation_);
141 
142     closed_ = true;
143   }
144 
145   if (result == FRAME_PING) {
146     if (!traffic_annotation_)
147       return FRAME_ERROR;
148     Send(*message, WebSocketFrameHeader::kOpCodePong, *traffic_annotation_);
149   }
150   return result;
151 }
152 
Send(base::StringPiece message,WebSocketFrameHeader::OpCodeEnum op_code,const NetworkTrafficAnnotationTag traffic_annotation)153 void WebSocket::Send(base::StringPiece message,
154                      WebSocketFrameHeader::OpCodeEnum op_code,
155                      const NetworkTrafficAnnotationTag traffic_annotation) {
156   if (closed_)
157     return;
158   std::string encoded;
159   switch (op_code) {
160     case WebSocketFrameHeader::kOpCodeText:
161       encoder_->EncodeTextFrame(message, 0, &encoded);
162       break;
163 
164     case WebSocketFrameHeader::kOpCodePong:
165       encoder_->EncodePongFrame(message, 0, &encoded);
166       break;
167 
168     default:
169       // Only Pong and Text frame types are supported.
170       NOTREACHED();
171   }
172   server_->SendRaw(connection_->id(), encoded, traffic_annotation);
173 }
174 
Fail()175 void WebSocket::Fail() {
176   closed_ = true;
177   // TODO(yhirano): The server SHOULD log the problem.
178   server_->Close(connection_->id());
179 }
180 
SendErrorResponse(const std::string & message,const NetworkTrafficAnnotationTag traffic_annotation)181 void WebSocket::SendErrorResponse(
182     const std::string& message,
183     const NetworkTrafficAnnotationTag traffic_annotation) {
184   if (closed_)
185     return;
186   closed_ = true;
187   server_->Send500(connection_->id(), message, traffic_annotation);
188 }
189 
190 }  // namespace net
191