• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
#include <locale.h>
#include <stdio.h>
#include <stdlib.h>

#include <algorithm>
#include <locale>
#include <string>
#include <vector>

#include "base/at_exit.h"
#include "base/bind.h"
#include "base/callback.h"
#include "base/command_line.h"
#include "base/files/file_path.h"
#include "base/lazy_instance.h"
#include "base/logging.h"
#include "base/memory/scoped_ptr.h"
#include "base/message_loop/message_loop.h"
#include "base/run_loop.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/string_split.h"
#include "base/strings/string_util.h"
#include "base/strings/stringprintf.h"
#include "base/synchronization/waitable_event.h"
#include "base/threading/thread.h"
#include "base/threading/thread_local.h"
#include "chrome/test/chromedriver/logging.h"
#include "chrome/test/chromedriver/net/port_server.h"
#include "chrome/test/chromedriver/server/http_handler.h"
#include "chrome/test/chromedriver/version.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
#include "net/server/http_server.h"
#include "net/server/http_server_request_info.h"
#include "net/server/http_server_response_info.h"
#include "net/socket/tcp_listen_socket.h"
37 
38 namespace {
39 
// IPv4 loopback address. The server binds to it when remote connections are
// not allowed, and requests coming from it always pass the IP whitelist.
// Declared as an array (not a `const char*`) so the constant itself cannot be
// reassigned and no separate pointer object is emitted.
const char kLocalHostAddress[] = "127.0.0.1";
41 
42 typedef base::Callback<
43     void(const net::HttpServerRequestInfo&, const HttpResponseSenderFunc&)>
44     HttpRequestHandlerFunc;
45 
46 class HttpServer : public net::HttpServer::Delegate {
47  public:
HttpServer(const HttpRequestHandlerFunc & handle_request_func)48   explicit HttpServer(const HttpRequestHandlerFunc& handle_request_func)
49       : handle_request_func_(handle_request_func),
50         weak_factory_(this) {}
51 
~HttpServer()52   virtual ~HttpServer() {}
53 
Start(int port,bool allow_remote)54   bool Start(int port, bool allow_remote) {
55     std::string binding_ip = kLocalHostAddress;
56     if (allow_remote)
57       binding_ip = "0.0.0.0";
58     server_ = new net::HttpServer(
59         net::TCPListenSocketFactory(binding_ip, port), this);
60     net::IPEndPoint address;
61     return server_->GetLocalAddress(&address) == net::OK;
62   }
63 
64   // Overridden from net::HttpServer::Delegate:
OnHttpRequest(int connection_id,const net::HttpServerRequestInfo & info)65   virtual void OnHttpRequest(int connection_id,
66                              const net::HttpServerRequestInfo& info) OVERRIDE {
67     handle_request_func_.Run(
68         info,
69         base::Bind(&HttpServer::OnResponse,
70                    weak_factory_.GetWeakPtr(),
71                    connection_id));
72   }
OnWebSocketRequest(int connection_id,const net::HttpServerRequestInfo & info)73   virtual void OnWebSocketRequest(
74       int connection_id,
75       const net::HttpServerRequestInfo& info) OVERRIDE {}
OnWebSocketMessage(int connection_id,const std::string & data)76   virtual void OnWebSocketMessage(int connection_id,
77                                   const std::string& data) OVERRIDE {}
OnClose(int connection_id)78   virtual void OnClose(int connection_id) OVERRIDE {}
79 
80  private:
OnResponse(int connection_id,scoped_ptr<net::HttpServerResponseInfo> response)81   void OnResponse(int connection_id,
82                   scoped_ptr<net::HttpServerResponseInfo> response) {
83     // Don't support keep-alive, since there's no way to detect if the
84     // client is HTTP/1.0. In such cases, the client may hang waiting for
85     // the connection to close (e.g., python 2.7 urllib).
86     response->AddHeader("Connection", "close");
87     server_->SendResponse(connection_id, *response);
88     server_->Close(connection_id);
89   }
90 
91   HttpRequestHandlerFunc handle_request_func_;
92   scoped_refptr<net::HttpServer> server_;
93   base::WeakPtrFactory<HttpServer> weak_factory_;  // Should be last.
94 };
95 
SendResponseOnCmdThread(const scoped_refptr<base::SingleThreadTaskRunner> & io_task_runner,const HttpResponseSenderFunc & send_response_on_io_func,scoped_ptr<net::HttpServerResponseInfo> response)96 void SendResponseOnCmdThread(
97     const scoped_refptr<base::SingleThreadTaskRunner>& io_task_runner,
98     const HttpResponseSenderFunc& send_response_on_io_func,
99     scoped_ptr<net::HttpServerResponseInfo> response) {
100   io_task_runner->PostTask(
101       FROM_HERE, base::Bind(send_response_on_io_func, base::Passed(&response)));
102 }
103 
HandleRequestOnCmdThread(HttpHandler * handler,const std::vector<std::string> & whitelisted_ips,const net::HttpServerRequestInfo & request,const HttpResponseSenderFunc & send_response_func)104 void HandleRequestOnCmdThread(
105     HttpHandler* handler,
106     const std::vector<std::string>& whitelisted_ips,
107     const net::HttpServerRequestInfo& request,
108     const HttpResponseSenderFunc& send_response_func) {
109   if (!whitelisted_ips.empty()) {
110     std::string peer_address = request.peer.ToStringWithoutPort();
111     if (peer_address != kLocalHostAddress &&
112         std::find(whitelisted_ips.begin(), whitelisted_ips.end(),
113                   peer_address) == whitelisted_ips.end()) {
114       LOG(WARNING) << "unauthorized access from " << request.peer.ToString();
115       scoped_ptr<net::HttpServerResponseInfo> response(
116           new net::HttpServerResponseInfo(net::HTTP_UNAUTHORIZED));
117       response->SetBody("Unauthorized access", "text/plain");
118       send_response_func.Run(response.Pass());
119       return;
120     }
121   }
122 
123   handler->Handle(request, send_response_func);
124 }
125 
HandleRequestOnIOThread(const scoped_refptr<base::SingleThreadTaskRunner> & cmd_task_runner,const HttpRequestHandlerFunc & handle_request_on_cmd_func,const net::HttpServerRequestInfo & request,const HttpResponseSenderFunc & send_response_func)126 void HandleRequestOnIOThread(
127     const scoped_refptr<base::SingleThreadTaskRunner>& cmd_task_runner,
128     const HttpRequestHandlerFunc& handle_request_on_cmd_func,
129     const net::HttpServerRequestInfo& request,
130     const HttpResponseSenderFunc& send_response_func) {
131   cmd_task_runner->PostTask(
132       FROM_HERE,
133       base::Bind(handle_request_on_cmd_func,
134                  request,
135                  base::Bind(&SendResponseOnCmdThread,
136                             base::MessageLoopProxy::current(),
137                             send_response_func)));
138 }
139 
// Thread-local slot for the HttpServer. StartServerOnIOThread stores the
// server here and StopServerOnIOThread removes and deletes it. RunServer posts
// both as tasks to the IO thread's message loop, so the server is reached only
// through that thread's slot.
base::LazyInstance<base::ThreadLocalPointer<HttpServer> >
    lazy_tls_server = LAZY_INSTANCE_INITIALIZER;
142 
StopServerOnIOThread()143 void StopServerOnIOThread() {
144   // Note, |server| may be NULL.
145   HttpServer* server = lazy_tls_server.Pointer()->Get();
146   lazy_tls_server.Pointer()->Set(NULL);
147   delete server;
148 }
149 
StartServerOnIOThread(int port,bool allow_remote,const HttpRequestHandlerFunc & handle_request_func)150 void StartServerOnIOThread(int port,
151                            bool allow_remote,
152                            const HttpRequestHandlerFunc& handle_request_func) {
153   scoped_ptr<HttpServer> temp_server(new HttpServer(handle_request_func));
154   if (!temp_server->Start(port, allow_remote)) {
155     printf("Port not available. Exiting...\n");
156     exit(1);
157   }
158   lazy_tls_server.Pointer()->Set(temp_server.release());
159 }
160 
// Runs ChromeDriver's HTTP server until a shutdown request is handled.
//
// Requests are accepted on a dedicated IO thread and processed on the calling
// thread (the "command" loop) by an HttpHandler. Responses are posted back to
// the IO thread to be sent.
//
// |port| / |allow_remote|: where the server listens (see HttpServer::Start).
// |whitelisted_ips|: peers allowed in addition to localhost; empty means no
//     restriction (see HandleRequestOnCmdThread). Bound by value below.
// |url_base|, |adb_port|, |port_server|: forwarded to the HttpHandler;
//     ownership of |port_server| moves to it.
void RunServer(int port,
               bool allow_remote,
               const std::vector<std::string>& whitelisted_ips,
               const std::string& url_base,
               int adb_port,
               scoped_ptr<PortServer> port_server) {
  // Declared first so it is destroyed last. Destroying the thread waits for
  // its pending tasks, including the StopServerOnIOThread task posted at the
  // end of this function.
  base::Thread io_thread("ChromeDriver IO");
  // NOTE(review): the 0 is presumably the stack size (platform default) --
  // confirm against base::Thread::Options.
  CHECK(io_thread.StartWithOptions(
      base::Thread::Options(base::MessageLoop::TYPE_IO, 0)));

  base::MessageLoop cmd_loop;
  base::RunLoop cmd_run_loop;
  // The handler is given the run loop's quit closure, so it can end
  // cmd_run_loop.Run() below.
  HttpHandler handler(cmd_run_loop.QuitClosure(),
                      io_thread.message_loop_proxy(),
                      url_base,
                      adb_port,
                      port_server.Pass());
  HttpRequestHandlerFunc handle_request_func =
      base::Bind(&HandleRequestOnCmdThread, &handler, whitelisted_ips);

  // Start the server on the IO thread. Each request it receives is bounced to
  // the command loop via HandleRequestOnIOThread.
  io_thread.message_loop()
      ->PostTask(FROM_HERE,
                 base::Bind(&StartServerOnIOThread,
                            port,
                            allow_remote,
                            base::Bind(&HandleRequestOnIOThread,
                                       cmd_loop.message_loop_proxy(),
                                       handle_request_func)));
  // Run the command loop. This loop is quit after the response for a shutdown
  // request is posted to the IO loop. After the command loop quits, a task
  // is posted to the IO loop to stop the server. Lastly, the IO thread is
  // destroyed, which waits until all pending tasks have been completed.
  // This assumes the response is sent synchronously as part of the IO task.
  cmd_run_loop.Run();
  io_thread.message_loop()
      ->PostTask(FROM_HERE, base::Bind(&StopServerOnIOThread));
}
198 
199 }  // namespace
200 
main(int argc,char * argv[])201 int main(int argc, char *argv[]) {
202   CommandLine::Init(argc, argv);
203 
204   base::AtExitManager at_exit;
205   CommandLine* cmd_line = CommandLine::ForCurrentProcess();
206 
207 #if defined(OS_LINUX)
208   // Select the locale from the environment by passing an empty string instead
209   // of the default "C" locale. This is particularly needed for the keycode
210   // conversion code to work.
211   setlocale(LC_ALL, "");
212 #endif
213 
214   // Parse command line flags.
215   int port = 9515;
216   int adb_port = 5037;
217   bool allow_remote = false;
218   std::vector<std::string> whitelisted_ips;
219   std::string url_base;
220   scoped_ptr<PortServer> port_server;
221   if (cmd_line->HasSwitch("h") || cmd_line->HasSwitch("help")) {
222     std::string options;
223     const char* kOptionAndDescriptions[] = {
224         "port=PORT", "port to listen on",
225         "adb-port=PORT", "adb server port",
226         "log-path=FILE", "write server log to file instead of stderr, "
227             "increases log level to INFO",
228         "verbose", "log verbosely",
229         "version", "print the version number and exit",
230         "silent", "log nothing",
231         "url-base", "base URL path prefix for commands, e.g. wd/url",
232         "port-server", "address of server to contact for reserving a port",
233         "whitelisted-ips", "comma-separated whitelist of remote IPv4 addresses "
234             "which are allowed to connect to ChromeDriver",
235     };
236     for (size_t i = 0; i < arraysize(kOptionAndDescriptions) - 1; i += 2) {
237       options += base::StringPrintf(
238           "  --%-30s%s\n",
239           kOptionAndDescriptions[i], kOptionAndDescriptions[i + 1]);
240     }
241     printf("Usage: %s [OPTIONS]\n\nOptions\n%s", argv[0], options.c_str());
242     return 0;
243   }
244   if (cmd_line->HasSwitch("v") || cmd_line->HasSwitch("version")) {
245     printf("ChromeDriver %s\n", kChromeDriverVersion);
246     return 0;
247   }
248   if (cmd_line->HasSwitch("port")) {
249     if (!base::StringToInt(cmd_line->GetSwitchValueASCII("port"), &port)) {
250       printf("Invalid port. Exiting...\n");
251       return 1;
252     }
253   }
254   if (cmd_line->HasSwitch("adb-port")) {
255     if (!base::StringToInt(cmd_line->GetSwitchValueASCII("adb-port"),
256                            &adb_port)) {
257       printf("Invalid adb-port. Exiting...\n");
258       return 1;
259     }
260   }
261   if (cmd_line->HasSwitch("port-server")) {
262 #if defined(OS_LINUX)
263     std::string address = cmd_line->GetSwitchValueASCII("port-server");
264     if (address.empty() || address[0] != '@') {
265       printf("Invalid port-server. Exiting...\n");
266       return 1;
267     }
268     std::string path;
269     // First character of path is \0 to use Linux's abstract namespace.
270     path.push_back(0);
271     path += address.substr(1);
272     port_server.reset(new PortServer(path));
273 #else
274     printf("Warning: port-server not implemented for this platform.\n");
275 #endif
276   }
277   if (cmd_line->HasSwitch("url-base"))
278     url_base = cmd_line->GetSwitchValueASCII("url-base");
279   if (url_base.empty() || url_base[0] != '/')
280     url_base = "/" + url_base;
281   if (url_base[url_base.length() - 1] != '/')
282     url_base = url_base + "/";
283   if (cmd_line->HasSwitch("whitelisted-ips")) {
284     allow_remote = true;
285     std::string whitelist = cmd_line->GetSwitchValueASCII("whitelisted-ips");
286     base::SplitString(whitelist, ',', &whitelisted_ips);
287   }
288   if (!cmd_line->HasSwitch("silent")) {
289     printf(
290         "Starting ChromeDriver (v%s) on port %d\n", kChromeDriverVersion, port);
291     if (!allow_remote) {
292       printf("Only local connections are allowed.\n");
293     } else if (!whitelisted_ips.empty()) {
294       printf("Remote connections are allowed by a whitelist (%s).\n",
295              cmd_line->GetSwitchValueASCII("whitelisted-ips").c_str());
296     } else {
297       printf("All remote connections are allowed. Use a whitelist instead!\n");
298     }
299     fflush(stdout);
300   }
301 
302   if (!InitLogging()) {
303     printf("Unable to initialize logging. Exiting...\n");
304     return 1;
305   }
306   RunServer(port, allow_remote, whitelisted_ips,
307             url_base, adb_port, port_server.Pass());
308   return 0;
309 }
310