1 // Copyright (c) 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include <stdio.h>
6 #include <locale>
7 #include <string>
8 #include <vector>
9
10 #include "base/at_exit.h"
11 #include "base/bind.h"
12 #include "base/callback.h"
13 #include "base/command_line.h"
14 #include "base/files/file_path.h"
15 #include "base/lazy_instance.h"
16 #include "base/logging.h"
17 #include "base/memory/scoped_ptr.h"
18 #include "base/message_loop/message_loop.h"
19 #include "base/run_loop.h"
20 #include "base/strings/string_number_conversions.h"
21 #include "base/strings/string_split.h"
22 #include "base/strings/string_util.h"
23 #include "base/strings/stringprintf.h"
24 #include "base/synchronization/waitable_event.h"
25 #include "base/threading/thread.h"
26 #include "base/threading/thread_local.h"
27 #include "chrome/test/chromedriver/logging.h"
28 #include "chrome/test/chromedriver/net/port_server.h"
29 #include "chrome/test/chromedriver/server/http_handler.h"
30 #include "chrome/test/chromedriver/version.h"
31 #include "net/base/ip_endpoint.h"
32 #include "net/base/net_errors.h"
33 #include "net/server/http_server.h"
34 #include "net/server/http_server_request_info.h"
35 #include "net/server/http_server_response_info.h"
36 #include "net/socket/tcp_listen_socket.h"
37
38 namespace {
39
// Loopback address. The server binds to it unless remote connections are
// allowed, and requests coming from it bypass the IP whitelist check.
// Declared as an array (not a pointer) so the constant itself is immutable.
const char kLocalHostAddress[] = "127.0.0.1";
41
42 typedef base::Callback<
43 void(const net::HttpServerRequestInfo&, const HttpResponseSenderFunc&)>
44 HttpRequestHandlerFunc;
45
46 class HttpServer : public net::HttpServer::Delegate {
47 public:
HttpServer(const HttpRequestHandlerFunc & handle_request_func)48 explicit HttpServer(const HttpRequestHandlerFunc& handle_request_func)
49 : handle_request_func_(handle_request_func),
50 weak_factory_(this) {}
51
~HttpServer()52 virtual ~HttpServer() {}
53
Start(int port,bool allow_remote)54 bool Start(int port, bool allow_remote) {
55 std::string binding_ip = kLocalHostAddress;
56 if (allow_remote)
57 binding_ip = "0.0.0.0";
58 server_ = new net::HttpServer(
59 net::TCPListenSocketFactory(binding_ip, port), this);
60 net::IPEndPoint address;
61 return server_->GetLocalAddress(&address) == net::OK;
62 }
63
64 // Overridden from net::HttpServer::Delegate:
OnHttpRequest(int connection_id,const net::HttpServerRequestInfo & info)65 virtual void OnHttpRequest(int connection_id,
66 const net::HttpServerRequestInfo& info) OVERRIDE {
67 handle_request_func_.Run(
68 info,
69 base::Bind(&HttpServer::OnResponse,
70 weak_factory_.GetWeakPtr(),
71 connection_id));
72 }
OnWebSocketRequest(int connection_id,const net::HttpServerRequestInfo & info)73 virtual void OnWebSocketRequest(
74 int connection_id,
75 const net::HttpServerRequestInfo& info) OVERRIDE {}
OnWebSocketMessage(int connection_id,const std::string & data)76 virtual void OnWebSocketMessage(int connection_id,
77 const std::string& data) OVERRIDE {}
OnClose(int connection_id)78 virtual void OnClose(int connection_id) OVERRIDE {}
79
80 private:
OnResponse(int connection_id,scoped_ptr<net::HttpServerResponseInfo> response)81 void OnResponse(int connection_id,
82 scoped_ptr<net::HttpServerResponseInfo> response) {
83 // Don't support keep-alive, since there's no way to detect if the
84 // client is HTTP/1.0. In such cases, the client may hang waiting for
85 // the connection to close (e.g., python 2.7 urllib).
86 response->AddHeader("Connection", "close");
87 server_->SendResponse(connection_id, *response);
88 server_->Close(connection_id);
89 }
90
91 HttpRequestHandlerFunc handle_request_func_;
92 scoped_refptr<net::HttpServer> server_;
93 base::WeakPtrFactory<HttpServer> weak_factory_; // Should be last.
94 };
95
SendResponseOnCmdThread(const scoped_refptr<base::SingleThreadTaskRunner> & io_task_runner,const HttpResponseSenderFunc & send_response_on_io_func,scoped_ptr<net::HttpServerResponseInfo> response)96 void SendResponseOnCmdThread(
97 const scoped_refptr<base::SingleThreadTaskRunner>& io_task_runner,
98 const HttpResponseSenderFunc& send_response_on_io_func,
99 scoped_ptr<net::HttpServerResponseInfo> response) {
100 io_task_runner->PostTask(
101 FROM_HERE, base::Bind(send_response_on_io_func, base::Passed(&response)));
102 }
103
HandleRequestOnCmdThread(HttpHandler * handler,const std::vector<std::string> & whitelisted_ips,const net::HttpServerRequestInfo & request,const HttpResponseSenderFunc & send_response_func)104 void HandleRequestOnCmdThread(
105 HttpHandler* handler,
106 const std::vector<std::string>& whitelisted_ips,
107 const net::HttpServerRequestInfo& request,
108 const HttpResponseSenderFunc& send_response_func) {
109 if (!whitelisted_ips.empty()) {
110 std::string peer_address = request.peer.ToStringWithoutPort();
111 if (peer_address != kLocalHostAddress &&
112 std::find(whitelisted_ips.begin(), whitelisted_ips.end(),
113 peer_address) == whitelisted_ips.end()) {
114 LOG(WARNING) << "unauthorized access from " << request.peer.ToString();
115 scoped_ptr<net::HttpServerResponseInfo> response(
116 new net::HttpServerResponseInfo(net::HTTP_UNAUTHORIZED));
117 response->SetBody("Unauthorized access", "text/plain");
118 send_response_func.Run(response.Pass());
119 return;
120 }
121 }
122
123 handler->Handle(request, send_response_func);
124 }
125
HandleRequestOnIOThread(const scoped_refptr<base::SingleThreadTaskRunner> & cmd_task_runner,const HttpRequestHandlerFunc & handle_request_on_cmd_func,const net::HttpServerRequestInfo & request,const HttpResponseSenderFunc & send_response_func)126 void HandleRequestOnIOThread(
127 const scoped_refptr<base::SingleThreadTaskRunner>& cmd_task_runner,
128 const HttpRequestHandlerFunc& handle_request_on_cmd_func,
129 const net::HttpServerRequestInfo& request,
130 const HttpResponseSenderFunc& send_response_func) {
131 cmd_task_runner->PostTask(
132 FROM_HERE,
133 base::Bind(handle_request_on_cmd_func,
134 request,
135 base::Bind(&SendResponseOnCmdThread,
136 base::MessageLoopProxy::current(),
137 send_response_func)));
138 }
139
140 base::LazyInstance<base::ThreadLocalPointer<HttpServer> >
141 lazy_tls_server = LAZY_INSTANCE_INITIALIZER;
142
StopServerOnIOThread()143 void StopServerOnIOThread() {
144 // Note, |server| may be NULL.
145 HttpServer* server = lazy_tls_server.Pointer()->Get();
146 lazy_tls_server.Pointer()->Set(NULL);
147 delete server;
148 }
149
StartServerOnIOThread(int port,bool allow_remote,const HttpRequestHandlerFunc & handle_request_func)150 void StartServerOnIOThread(int port,
151 bool allow_remote,
152 const HttpRequestHandlerFunc& handle_request_func) {
153 scoped_ptr<HttpServer> temp_server(new HttpServer(handle_request_func));
154 if (!temp_server->Start(port, allow_remote)) {
155 printf("Port not available. Exiting...\n");
156 exit(1);
157 }
158 lazy_tls_server.Pointer()->Set(temp_server.release());
159 }
160
RunServer(int port,bool allow_remote,const std::vector<std::string> & whitelisted_ips,const std::string & url_base,int adb_port,scoped_ptr<PortServer> port_server)161 void RunServer(int port,
162 bool allow_remote,
163 const std::vector<std::string>& whitelisted_ips,
164 const std::string& url_base,
165 int adb_port,
166 scoped_ptr<PortServer> port_server) {
167 base::Thread io_thread("ChromeDriver IO");
168 CHECK(io_thread.StartWithOptions(
169 base::Thread::Options(base::MessageLoop::TYPE_IO, 0)));
170
171 base::MessageLoop cmd_loop;
172 base::RunLoop cmd_run_loop;
173 HttpHandler handler(cmd_run_loop.QuitClosure(),
174 io_thread.message_loop_proxy(),
175 url_base,
176 adb_port,
177 port_server.Pass());
178 HttpRequestHandlerFunc handle_request_func =
179 base::Bind(&HandleRequestOnCmdThread, &handler, whitelisted_ips);
180
181 io_thread.message_loop()
182 ->PostTask(FROM_HERE,
183 base::Bind(&StartServerOnIOThread,
184 port,
185 allow_remote,
186 base::Bind(&HandleRequestOnIOThread,
187 cmd_loop.message_loop_proxy(),
188 handle_request_func)));
189 // Run the command loop. This loop is quit after the response for a shutdown
190 // request is posted to the IO loop. After the command loop quits, a task
191 // is posted to the IO loop to stop the server. Lastly, the IO thread is
192 // destroyed, which waits until all pending tasks have been completed.
193 // This assumes the response is sent synchronously as part of the IO task.
194 cmd_run_loop.Run();
195 io_thread.message_loop()
196 ->PostTask(FROM_HERE, base::Bind(&StopServerOnIOThread));
197 }
198
199 } // namespace
200
main(int argc,char * argv[])201 int main(int argc, char *argv[]) {
202 CommandLine::Init(argc, argv);
203
204 base::AtExitManager at_exit;
205 CommandLine* cmd_line = CommandLine::ForCurrentProcess();
206
207 #if defined(OS_LINUX)
208 // Select the locale from the environment by passing an empty string instead
209 // of the default "C" locale. This is particularly needed for the keycode
210 // conversion code to work.
211 setlocale(LC_ALL, "");
212 #endif
213
214 // Parse command line flags.
215 int port = 9515;
216 int adb_port = 5037;
217 bool allow_remote = false;
218 std::vector<std::string> whitelisted_ips;
219 std::string url_base;
220 scoped_ptr<PortServer> port_server;
221 if (cmd_line->HasSwitch("h") || cmd_line->HasSwitch("help")) {
222 std::string options;
223 const char* kOptionAndDescriptions[] = {
224 "port=PORT", "port to listen on",
225 "adb-port=PORT", "adb server port",
226 "log-path=FILE", "write server log to file instead of stderr, "
227 "increases log level to INFO",
228 "verbose", "log verbosely",
229 "version", "print the version number and exit",
230 "silent", "log nothing",
231 "url-base", "base URL path prefix for commands, e.g. wd/url",
232 "port-server", "address of server to contact for reserving a port",
233 "whitelisted-ips", "comma-separated whitelist of remote IPv4 addresses "
234 "which are allowed to connect to ChromeDriver",
235 };
236 for (size_t i = 0; i < arraysize(kOptionAndDescriptions) - 1; i += 2) {
237 options += base::StringPrintf(
238 " --%-30s%s\n",
239 kOptionAndDescriptions[i], kOptionAndDescriptions[i + 1]);
240 }
241 printf("Usage: %s [OPTIONS]\n\nOptions\n%s", argv[0], options.c_str());
242 return 0;
243 }
244 if (cmd_line->HasSwitch("v") || cmd_line->HasSwitch("version")) {
245 printf("ChromeDriver %s\n", kChromeDriverVersion);
246 return 0;
247 }
248 if (cmd_line->HasSwitch("port")) {
249 if (!base::StringToInt(cmd_line->GetSwitchValueASCII("port"), &port)) {
250 printf("Invalid port. Exiting...\n");
251 return 1;
252 }
253 }
254 if (cmd_line->HasSwitch("adb-port")) {
255 if (!base::StringToInt(cmd_line->GetSwitchValueASCII("adb-port"),
256 &adb_port)) {
257 printf("Invalid adb-port. Exiting...\n");
258 return 1;
259 }
260 }
261 if (cmd_line->HasSwitch("port-server")) {
262 #if defined(OS_LINUX)
263 std::string address = cmd_line->GetSwitchValueASCII("port-server");
264 if (address.empty() || address[0] != '@') {
265 printf("Invalid port-server. Exiting...\n");
266 return 1;
267 }
268 std::string path;
269 // First character of path is \0 to use Linux's abstract namespace.
270 path.push_back(0);
271 path += address.substr(1);
272 port_server.reset(new PortServer(path));
273 #else
274 printf("Warning: port-server not implemented for this platform.\n");
275 #endif
276 }
277 if (cmd_line->HasSwitch("url-base"))
278 url_base = cmd_line->GetSwitchValueASCII("url-base");
279 if (url_base.empty() || url_base[0] != '/')
280 url_base = "/" + url_base;
281 if (url_base[url_base.length() - 1] != '/')
282 url_base = url_base + "/";
283 if (cmd_line->HasSwitch("whitelisted-ips")) {
284 allow_remote = true;
285 std::string whitelist = cmd_line->GetSwitchValueASCII("whitelisted-ips");
286 base::SplitString(whitelist, ',', &whitelisted_ips);
287 }
288 if (!cmd_line->HasSwitch("silent")) {
289 printf(
290 "Starting ChromeDriver (v%s) on port %d\n", kChromeDriverVersion, port);
291 if (!allow_remote) {
292 printf("Only local connections are allowed.\n");
293 } else if (!whitelisted_ips.empty()) {
294 printf("Remote connections are allowed by a whitelist (%s).\n",
295 cmd_line->GetSwitchValueASCII("whitelisted-ips").c_str());
296 } else {
297 printf("All remote connections are allowed. Use a whitelist instead!\n");
298 }
299 fflush(stdout);
300 }
301
302 if (!InitLogging()) {
303 printf("Unable to initialize logging. Exiting...\n");
304 return 1;
305 }
306 RunServer(port, allow_remote, whitelisted_ips,
307 url_base, adb_port, port_server.Pass());
308 return 0;
309 }
310