# Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Spins up a trivial HTTP cgi form listener in a thread.

   This HTTPThread class is a utility for use with test cases that
   need to call back to the Autotest test case with some form value, e.g.
   http://localhost:nnnn/?status="Browser started!"
"""

import cgi
import errno
import functools
import logging
import os
import posixpath
import SimpleHTTPServer
import socket
import ssl
import sys
import threading
import urllib
import urlparse
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
from SocketServer import BaseServer, ThreadingMixIn


def _handle_http_errors(func):
    """Decorator function for cleaner presentation of certain exceptions.

    An IOError whose errno is EPIPE or ECONNRESET raised by the wrapped
    request-handler method is reported as a single line through
    self.log_error(); every other exception propagates unchanged.

    Args:
        func: request-handler method taking the handler instance as its
              first argument.

    Returns:
        The wrapped method; it returns whatever func returns, or None when
        one of the errors above was logged instead.
    """
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        try:
            return func(self, *args, **kwargs)
        except IOError as e:
            if e.errno in (errno.EPIPE, errno.ECONNRESET):
                # Instead of dumping a stack trace, a single line is
                # sufficient. Pass the exception as an argument so that a
                # '%' in its message is not treated as a format directive.
                self.log_error('%s', e)
            else:
                raise

    return wrapper


class FormHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
    """Implements a form handler (for POST requests only) which simply
    echoes the key=value parameters back in the response.

    If the form submission is a file upload, the file will be written
    to disk with the name contained in the 'filename' field.

    The handler reads these attributes from self.server: docroot, _urls,
    _wait_urls, _url_handlers and _form_entries.
    """

    # Note: this updates the map on the base class itself, not a copy.
    SimpleHTTPServer.SimpleHTTPRequestHandler.extensions_map.update({
        '.webm': 'video/webm',
    })

    # Override the default logging methods to use the logging module directly.
    def log_error(self, format, *args):
        logging.warning("(httpd error) %s - - [%s] %s\n",
                        self.address_string(),
                        self.log_date_time_string(),
                        format % args)

    def log_message(self, format, *args):
        logging.debug("%s - - [%s] %s\n",
                      self.address_string(),
                      self.log_date_time_string(),
                      format % args)

    @_handle_http_errors
    def do_POST(self):
        """Records the posted fields, then dispatches or echoes the form.

        Every field is stored in self.server._form_entries. If a handler is
        registered for the request path it is called as handler(self, form);
        otherwise write_post_response() echoes the form back.
        """
        form = cgi.FieldStorage(
            fp=self.rfile,
            headers=self.headers,
            # .get() so that a request without a Content-Type header does
            # not fail with KeyError before the form is even parsed.
            environ={'REQUEST_METHOD': 'POST',
                     'CONTENT_TYPE': self.headers.get('Content-Type')})
        # You'd think form.keys() would just return [], like it does for empty
        # python dicts; you'd be wrong. It raises TypeError if called when it
        # has no keys.
        if form:
            for field in form.keys():
                # getvalue() also copes with a field submitted more than
                # once, in which case it returns the list of values.
                self.server._form_entries[field] = form.getvalue(field)
        path = urlparse.urlparse(self.path)[2]
        if path in self.server._url_handlers:
            self.server._url_handlers[path](self, form)
        else:
            # Echo back information about what was posted in the form.
            self.write_post_response(form)
        self._fire_event()

    def write_post_response(self, form):
        """Called to fill out the response to an HTTP POST.

        Override this method to give custom responses.

        Args:
            form: the cgi.FieldStorage parsed from the request.
        """
        # Send response boilerplate
        self.send_response(200)
        self.end_headers()
        self.wfile.write('Hello from Autotest!\nClient: %s\n' %
                         str(self.client_address))
        self.wfile.write('Request for path: %s\n' % self.path)
        self.wfile.write('Got form data:\n')

        # form.keys() raises TypeError when the form has no keys, so only
        # call it on a non-empty form.
        if not form:
            return
        for field in form.keys():
            # A field name submitted more than once comes back as a list.
            items = form[field]
            if not isinstance(items, list):
                items = [items]
            for field_item in items:
                if field_item.filename:
                    # The field contains an uploaded file
                    upload = field_item.file.read()
                    self.wfile.write('\tUploaded %s (%d bytes)<br>' %
                                     (field, len(upload)))
                    # Write submitted file to specified filename.
                    # NOTE(review): the path is taken verbatim from the
                    # client, so a request can write anywhere this process
                    # can; confirm that only trusted test clients connect.
                    with open(field_item.filename, 'wb') as out:
                        out.write(upload)
                    del upload
                else:
                    self.wfile.write('\t%s=%s<br>' %
                                     (field, field_item.value))

    def translate_path(self, path):
        """Override SimpleHTTPRequestHandler's translate_path to serve
        from arbitrary docroot

        Args:
            path: request path, possibly carrying query parameters.

        Returns:
            The filesystem path below self.server.docroot.
        """
        # abandon query parameters
        path = urlparse.urlparse(path)[2]
        path = posixpath.normpath(urllib.unquote(path))
        words = [word for word in path.split('/') if word]
        path = self.server.docroot
        for word in words:
            # Keep only the final component of each word and skip '.' and
            # '..' so the result stays below the docroot.
            _, word = os.path.splitdrive(word)
            _, word = os.path.split(word)
            if word in (os.curdir, os.pardir):
                continue
            path = os.path.join(path, word)
        logging.debug('Translated path: %s', path)
        return path

    def _fire_event(self):
        """Sets and unregisters the wait event for self.path, if any."""
        # pop() removes the entry in one step; request threads share this
        # dict, so a separate membership test and delete could collide.
        entry = self.server._wait_urls.pop(self.path, None)
        if entry is not None:
            _, event = entry
            event.set()
        elif self.path not in self.server._urls:
            # if the url is not in _urls, this means it was neither setup
            # as a permanent, or event url.
            logging.debug('URL %s not in watch list', self.path)

    @_handle_http_errors
    def do_GET(self):
        """Dispatches a GET to a registered handler or serves a file.

        A registered handler is called as handler(self, args), where args
        is the parsed query string (a dict mapping names to value lists).
        """
        split_url = urlparse.urlsplit(self.path)
        path = split_url[2]
        # Strip off query parameters to ensure that the url path
        # matches any registered events.
        self.path = path
        args = urlparse.parse_qs(split_url[3])
        if path in self.server._url_handlers:
            self.server._url_handlers[path](self, args)
        else:
            SimpleHTTPServer.SimpleHTTPRequestHandler.do_GET(self)
        self._fire_event()

    @_handle_http_errors
    def do_HEAD(self):
        SimpleHTTPServer.SimpleHTTPRequestHandler.do_HEAD(self)


class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
    """HTTPServer that handles each request in its own thread."""

    def __init__(self, server_address, HandlerClass):
        HTTPServer.__init__(self, server_address, HandlerClass)


class HTTPListener(object):
    """Runs a ThreadedHTTPServer with FormHandler on a background thread."""

    # Point default docroot to a non-existent directory (instead of None) to
    # avoid exceptions when page content is served through handlers only.
    def __init__(self, port=0, docroot='/_', wait_urls=None,
                 url_handlers=None):
        """
        Args:
            port (int): port to bind on all interfaces.
            docroot (string): directory that static files are served from.
            wait_urls (dictionary): initial url -> (matchParams, event)
                    entries; a fresh dictionary is used when omitted.
            url_handlers (dictionary): initial url -> handler entries; a
                    fresh dictionary is used when omitted.
        """
        self._server = ThreadedHTTPServer(('', port), FormHandler)
        self.config_server(self._server, docroot, wait_urls, url_handlers)

    def config_server(self, server, docroot, wait_urls, url_handlers):
        """Attaches the listener state to server and prepares its thread.

        Args:
            server: the server object to configure.
            docroot (string): directory that static files are served from.
            wait_urls (dictionary): url -> (matchParams, event) entries, or
                    None for a new empty dictionary.
            url_handlers (dictionary): url -> handler entries, or None for a
                    new empty dictionary.
        """
        # Stuff some convenient data fields into the server object.
        server.docroot = docroot
        server._urls = set()
        # Build the dictionaries here rather than as default arguments, so
        # that separate listeners never share registrations by accident.
        server._wait_urls = {} if wait_urls is None else wait_urls
        server._url_handlers = {} if url_handlers is None else url_handlers
        server._form_entries = {}
        self._server_thread = threading.Thread(target=server.serve_forever)

    def add_url(self, url):
        """
        Add a url to the urls that the http server is actively watching for.

        Not adding a url via add_url or add_wait_url, and only installing a
        handler will still result in that handler being executed, but this
        server will warn in the debug logs that it does not expect that url.

        Args:
            url (string): url suffix to listen to
        """
        self._server._urls.add(url)

    def add_wait_url(self, url='/', matchParams=None):
        """
        Add a wait url to the urls that the http server is aware of.

        Not adding a url via add_url or add_wait_url, and only installing a
        handler will still result in that handler being executed, but this
        server will warn in the debug logs that it does not expect that url.

        Args:
            url (string): url suffix to listen to
            matchParams (dictionary): an unused dictionary

        Returns:
            e, an event object. Call e.wait() on the object to wait (block)
            until the server receives the first request for the wait url.

        """
        if matchParams is None:
            matchParams = {}
        e = threading.Event()
        self._server._wait_urls[url] = (matchParams, e)
        self._server._urls.add(url)
        return e

    def add_url_handler(self, url, handler_func):
        """Registers handler_func to serve requests for url.

        Args:
            url (string): url path to handle.
            handler_func: called as handler_func(request_handler, data),
                    where data is the parsed form for a POST and the parsed
                    query parameters for a GET.
        """
        self._server._url_handlers[url] = handler_func

    def clear_form_entries(self):
        """Discards every form value recorded so far."""
        self._server._form_entries = {}

    def get_form_entries(self):
        """Returns a dictionary of all field=values received by the server.
        """
        return self._server._form_entries

    def run(self):
        """Starts serving requests on the background thread."""
        logging.debug('http server on %s:%d',
                      self._server.server_name, self._server.server_port)
        self._server_thread.start()

    def stop(self):
        """Shuts the server down, closes its socket and joins the thread."""
        self._server.shutdown()
        self._server.socket.close()
        self._server_thread.join()


class SecureHTTPServer(ThreadingMixIn, HTTPServer):
    """Threaded HTTP server whose listening socket is wrapped in TLS."""

    def __init__(self, server_address, HandlerClass, cert_path, key_path):
        """
        Args:
            server_address: (host, port) tuple to bind.
            HandlerClass: request handler class.
            cert_path (string): path to the certificate file.
            key_path (string): path to the private key file.
        """
        _socket = socket.socket(self.address_family, self.socket_type)
        # NOTE(review): the protocol is pinned to TLSv1; confirm whether the
        # test clients still require it before changing.
        self.socket = ssl.wrap_socket(_socket,
                                      server_side=True,
                                      ssl_version=ssl.PROTOCOL_TLSv1,
                                      certfile=cert_path,
                                      keyfile=key_path)
        # BaseServer.__init__ is called directly so that the wrapped socket
        # created above is the one that gets bound and activated.
        BaseServer.__init__(self, server_address, HandlerClass)
        try:
            self.server_bind()
            self.server_activate()
        except Exception:
            # Do not leak the socket when binding or listening fails.
            self.server_close()
            raise


class SecureHTTPRequestHandler(FormHandler):
    """FormHandler variant for connections accepted by SecureHTTPServer.

    log_error() and log_message() are inherited from FormHandler.
    """

    def setup(self):
        """Builds rfile/wfile directly on top of the request socket."""
        self.connection = self.request
        self.rfile = socket._fileobject(self.request, 'rb', self.rbufsize)
        self.wfile = socket._fileobject(self.request, 'wb', self.wbufsize)


class SecureHTTPListener(HTTPListener):
    """HTTPListener that serves over TLS through SecureHTTPServer."""

    def __init__(self,
                 cert_path='/etc/login_trust_root.pem',
                 key_path='/etc/mock_server.key',
                 port=0,
                 docroot='/_',
                 wait_urls=None,
                 url_handlers=None):
        """
        Args:
            cert_path (string): path to the certificate file.
            key_path (string): path to the private key file.
            port (int): port to bind on all interfaces.
            docroot (string): directory that static files are served from.
            wait_urls (dictionary): initial url -> (matchParams, event)
                    entries; a fresh dictionary is used when omitted.
            url_handlers (dictionary): initial url -> handler entries; a
                    fresh dictionary is used when omitted.
        """
        self._server = SecureHTTPServer(('', port),
                                        SecureHTTPRequestHandler,
                                        cert_path,
                                        key_path)
        self.config_server(self._server, docroot, wait_urls, url_handlers)

    def getsockname(self):
        """Returns the address the server socket is bound to."""
        return self._server.socket.getsockname()