1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/server/http_server.h"
6
7 #include "base/compiler_specific.h"
8 #include "base/logging.h"
9 #include "base/stl_util.h"
10 #include "base/strings/string_number_conversions.h"
11 #include "base/strings/string_util.h"
12 #include "base/strings/stringprintf.h"
13 #include "base/sys_byteorder.h"
14 #include "build/build_config.h"
15 #include "net/base/net_errors.h"
16 #include "net/server/http_connection.h"
17 #include "net/server/http_server_request_info.h"
18 #include "net/server/http_server_response_info.h"
19 #include "net/server/web_socket.h"
20 #include "net/socket/tcp_listen_socket.h"
21
22 namespace net {
23
HttpServer(const StreamListenSocketFactory & factory,HttpServer::Delegate * delegate)24 HttpServer::HttpServer(const StreamListenSocketFactory& factory,
25 HttpServer::Delegate* delegate)
26 : delegate_(delegate),
27 server_(factory.CreateAndListen(this)) {
28 }
29
AcceptWebSocket(int connection_id,const HttpServerRequestInfo & request)30 void HttpServer::AcceptWebSocket(
31 int connection_id,
32 const HttpServerRequestInfo& request) {
33 HttpConnection* connection = FindConnection(connection_id);
34 if (connection == NULL)
35 return;
36
37 DCHECK(connection->web_socket_.get());
38 connection->web_socket_->Accept(request);
39 }
40
SendOverWebSocket(int connection_id,const std::string & data)41 void HttpServer::SendOverWebSocket(int connection_id,
42 const std::string& data) {
43 HttpConnection* connection = FindConnection(connection_id);
44 if (connection == NULL)
45 return;
46 DCHECK(connection->web_socket_.get());
47 connection->web_socket_->Send(data);
48 }
49
SendRaw(int connection_id,const std::string & data)50 void HttpServer::SendRaw(int connection_id, const std::string& data) {
51 HttpConnection* connection = FindConnection(connection_id);
52 if (connection == NULL)
53 return;
54 connection->Send(data);
55 }
56
SendResponse(int connection_id,const HttpServerResponseInfo & response)57 void HttpServer::SendResponse(int connection_id,
58 const HttpServerResponseInfo& response) {
59 HttpConnection* connection = FindConnection(connection_id);
60 if (connection == NULL)
61 return;
62 connection->Send(response);
63 }
64
Send(int connection_id,HttpStatusCode status_code,const std::string & data,const std::string & content_type)65 void HttpServer::Send(int connection_id,
66 HttpStatusCode status_code,
67 const std::string& data,
68 const std::string& content_type) {
69 HttpServerResponseInfo response(status_code);
70 response.SetBody(data, content_type);
71 SendResponse(connection_id, response);
72 }
73
Send200(int connection_id,const std::string & data,const std::string & content_type)74 void HttpServer::Send200(int connection_id,
75 const std::string& data,
76 const std::string& content_type) {
77 Send(connection_id, HTTP_OK, data, content_type);
78 }
79
Send404(int connection_id)80 void HttpServer::Send404(int connection_id) {
81 SendResponse(connection_id, HttpServerResponseInfo::CreateFor404());
82 }
83
Send500(int connection_id,const std::string & message)84 void HttpServer::Send500(int connection_id, const std::string& message) {
85 SendResponse(connection_id, HttpServerResponseInfo::CreateFor500(message));
86 }
87
Close(int connection_id)88 void HttpServer::Close(int connection_id) {
89 HttpConnection* connection = FindConnection(connection_id);
90 if (connection == NULL)
91 return;
92
93 // Initiating close from server-side does not lead to the DidClose call.
94 // Do it manually here.
95 DidClose(connection->socket_.get());
96 }
97
GetLocalAddress(IPEndPoint * address)98 int HttpServer::GetLocalAddress(IPEndPoint* address) {
99 if (!server_)
100 return ERR_SOCKET_NOT_CONNECTED;
101 return server_->GetLocalAddress(address);
102 }
103
DidAccept(StreamListenSocket * server,scoped_ptr<StreamListenSocket> socket)104 void HttpServer::DidAccept(StreamListenSocket* server,
105 scoped_ptr<StreamListenSocket> socket) {
106 HttpConnection* connection = new HttpConnection(this, socket.Pass());
107 id_to_connection_[connection->id()] = connection;
108 // TODO(szym): Fix socket access. Make HttpConnection the Delegate.
109 socket_to_connection_[connection->socket_.get()] = connection;
110 }
111
DidRead(StreamListenSocket * socket,const char * data,int len)112 void HttpServer::DidRead(StreamListenSocket* socket,
113 const char* data,
114 int len) {
115 HttpConnection* connection = FindConnection(socket);
116 DCHECK(connection != NULL);
117 if (connection == NULL)
118 return;
119
120 connection->recv_data_.append(data, len);
121 while (connection->recv_data_.length()) {
122 if (connection->web_socket_.get()) {
123 std::string message;
124 WebSocket::ParseResult result = connection->web_socket_->Read(&message);
125 if (result == WebSocket::FRAME_INCOMPLETE)
126 break;
127
128 if (result == WebSocket::FRAME_CLOSE ||
129 result == WebSocket::FRAME_ERROR) {
130 Close(connection->id());
131 break;
132 }
133 delegate_->OnWebSocketMessage(connection->id(), message);
134 continue;
135 }
136
137 HttpServerRequestInfo request;
138 size_t pos = 0;
139 if (!ParseHeaders(connection, &request, &pos))
140 break;
141
142 // Sets peer address if exists.
143 socket->GetPeerAddress(&request.peer);
144
145 if (request.HasHeaderValue("connection", "upgrade")) {
146 connection->web_socket_.reset(WebSocket::CreateWebSocket(connection,
147 request,
148 &pos));
149
150 if (!connection->web_socket_.get()) // Not enough data was received.
151 break;
152 delegate_->OnWebSocketRequest(connection->id(), request);
153 connection->Shift(pos);
154 continue;
155 }
156
157 const char kContentLength[] = "content-length";
158 if (request.headers.count(kContentLength)) {
159 size_t content_length = 0;
160 const size_t kMaxBodySize = 100 << 20;
161 if (!base::StringToSizeT(request.GetHeaderValue(kContentLength),
162 &content_length) ||
163 content_length > kMaxBodySize) {
164 connection->Send(HttpServerResponseInfo::CreateFor500(
165 "request content-length too big or unknown: " +
166 request.GetHeaderValue(kContentLength)));
167 DidClose(socket);
168 break;
169 }
170
171 if (connection->recv_data_.length() - pos < content_length)
172 break; // Not enough data was received yet.
173 request.data = connection->recv_data_.substr(pos, content_length);
174 pos += content_length;
175 }
176
177 delegate_->OnHttpRequest(connection->id(), request);
178 connection->Shift(pos);
179 }
180 }
181
DidClose(StreamListenSocket * socket)182 void HttpServer::DidClose(StreamListenSocket* socket) {
183 HttpConnection* connection = FindConnection(socket);
184 DCHECK(connection != NULL);
185 id_to_connection_.erase(connection->id());
186 socket_to_connection_.erase(connection->socket_.get());
187 delete connection;
188 }
189
~HttpServer()190 HttpServer::~HttpServer() {
191 STLDeleteContainerPairSecondPointers(
192 id_to_connection_.begin(), id_to_connection_.end());
193 }
194
//
// HTTP Request Parser
// This HTTP request parser uses a simple state machine to quickly parse
// through the headers. The parser is not 100% complete, as it is designed
// for use in this simple test driver.
//
// Known issues:
// - Does not handle whitespace on the first HTTP line correctly. It expects
//   a single space between the method and the URL, and a single space between
//   the URL and the protocol.
204
// Classes of input characters recognised by the header parser. The values
// are used as the column index into the state transition table.
enum header_parse_inputs {
  INPUT_LWS,      // Linear whitespace: ' ' or '\t'.
  INPUT_CR,       // Carriage return: '\r'.
  INPUT_LF,       // Line feed: '\n'.
  INPUT_COLON,    // ':'.
  INPUT_DEFAULT,  // Every other character.
  MAX_INPUTS,     // Number of input classes; not a real input.
};
214
// States of the header parser. The values are used as the row index into the
// state transition table.
enum header_parse_states {
  ST_METHOD,     // Receiving the method.
  ST_URL,        // Receiving the URL.
  ST_PROTO,      // Receiving the protocol.
  ST_HEADER,     // Starting a request header.
  ST_NAME,       // Receiving a request header name.
  ST_SEPARATOR,  // Receiving the separator between header name and value.
  ST_VALUE,      // Receiving a request header value.
  ST_DONE,       // Parsing is complete and successful.
  ST_ERR,        // Parsing encountered invalid syntax.
  MAX_STATES     // Number of states; not a real state.
};
228
229 // State transition table
230 int parser_state[MAX_STATES][MAX_INPUTS] = {
231 /* METHOD */ { ST_URL, ST_ERR, ST_ERR, ST_ERR, ST_METHOD },
232 /* URL */ { ST_PROTO, ST_ERR, ST_ERR, ST_URL, ST_URL },
233 /* PROTOCOL */ { ST_ERR, ST_HEADER, ST_NAME, ST_ERR, ST_PROTO },
234 /* HEADER */ { ST_ERR, ST_ERR, ST_NAME, ST_ERR, ST_ERR },
235 /* NAME */ { ST_SEPARATOR, ST_DONE, ST_ERR, ST_VALUE, ST_NAME },
236 /* SEPARATOR */ { ST_SEPARATOR, ST_ERR, ST_ERR, ST_VALUE, ST_ERR },
237 /* VALUE */ { ST_VALUE, ST_HEADER, ST_NAME, ST_VALUE, ST_VALUE },
238 /* DONE */ { ST_DONE, ST_DONE, ST_DONE, ST_DONE, ST_DONE },
239 /* ERR */ { ST_ERR, ST_ERR, ST_ERR, ST_ERR, ST_ERR }
240 };
241
242 // Convert an input character to the parser's input token.
charToInput(char ch)243 int charToInput(char ch) {
244 switch(ch) {
245 case ' ':
246 case '\t':
247 return INPUT_LWS;
248 case '\r':
249 return INPUT_CR;
250 case '\n':
251 return INPUT_LF;
252 case ':':
253 return INPUT_COLON;
254 }
255 return INPUT_DEFAULT;
256 }
257
ParseHeaders(HttpConnection * connection,HttpServerRequestInfo * info,size_t * ppos)258 bool HttpServer::ParseHeaders(HttpConnection* connection,
259 HttpServerRequestInfo* info,
260 size_t* ppos) {
261 size_t& pos = *ppos;
262 size_t data_len = connection->recv_data_.length();
263 int state = ST_METHOD;
264 std::string buffer;
265 std::string header_name;
266 std::string header_value;
267 while (pos < data_len) {
268 char ch = connection->recv_data_[pos++];
269 int input = charToInput(ch);
270 int next_state = parser_state[state][input];
271
272 bool transition = (next_state != state);
273 HttpServerRequestInfo::HeadersMap::iterator it;
274 if (transition) {
275 // Do any actions based on state transitions.
276 switch (state) {
277 case ST_METHOD:
278 info->method = buffer;
279 buffer.clear();
280 break;
281 case ST_URL:
282 info->path = buffer;
283 buffer.clear();
284 break;
285 case ST_PROTO:
286 // TODO(mbelshe): Deal better with parsing protocol.
287 DCHECK(buffer == "HTTP/1.1");
288 buffer.clear();
289 break;
290 case ST_NAME:
291 header_name = StringToLowerASCII(buffer);
292 buffer.clear();
293 break;
294 case ST_VALUE:
295 base::TrimWhitespaceASCII(buffer, base::TRIM_LEADING, &header_value);
296 it = info->headers.find(header_name);
297 // See last paragraph ("Multiple message-header fields...")
298 // of www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
299 if (it == info->headers.end()) {
300 info->headers[header_name] = header_value;
301 } else {
302 it->second.append(",");
303 it->second.append(header_value);
304 }
305 buffer.clear();
306 break;
307 case ST_SEPARATOR:
308 break;
309 }
310 state = next_state;
311 } else {
312 // Do any actions based on current state
313 switch (state) {
314 case ST_METHOD:
315 case ST_URL:
316 case ST_PROTO:
317 case ST_VALUE:
318 case ST_NAME:
319 buffer.append(&ch, 1);
320 break;
321 case ST_DONE:
322 DCHECK(input == INPUT_LF);
323 return true;
324 case ST_ERR:
325 return false;
326 }
327 }
328 }
329 // No more characters, but we haven't finished parsing yet.
330 return false;
331 }
332
FindConnection(int connection_id)333 HttpConnection* HttpServer::FindConnection(int connection_id) {
334 IdToConnectionMap::iterator it = id_to_connection_.find(connection_id);
335 if (it == id_to_connection_.end())
336 return NULL;
337 return it->second;
338 }
339
FindConnection(StreamListenSocket * socket)340 HttpConnection* HttpServer::FindConnection(StreamListenSocket* socket) {
341 SocketToConnectionMap::iterator it = socket_to_connection_.find(socket);
342 if (it == socket_to_connection_.end())
343 return NULL;
344 return it->second;
345 }
346
347 } // namespace net
348