import os
import sys
import ssl
import socket
import pprint
import posixpath
import urllib
import urlparse
# Rename HTTPServer to _HTTPServer so as to avoid confusion with HTTPSServer.
from BaseHTTPServer import HTTPServer as _HTTPServer, BaseHTTPRequestHandler
from SimpleHTTPServer import SimpleHTTPRequestHandler

from test import test_support as support
threading = support.import_module("threading")

here = os.path.dirname(__file__)

HOST = support.HOST
CERTFILE = os.path.join(here, 'keycert.pem')

# This one's based on HTTPServer, which is based on SocketServer

class HTTPSServer(_HTTPServer):
    """HTTPServer that wraps every accepted connection with an SSL context."""

    def __init__(self, server_address, handler_class, context):
        _HTTPServer.__init__(self, server_address, handler_class)
        # SSLContext used by get_request() to wrap each accepted socket.
        self.context = context

    def __str__(self):
        return ('<%s %s:%s>' %
                (self.__class__.__name__,
                 self.server_name,
                 self.server_port))

    def get_request(self):
        """Accept a connection and wrap it server-side with self.context.

        Returns an (sslconn, addr) pair.  Socket and SSL errors are written
        to stderr in verbose mode and then re-raised.
        """
        # override this to wrap socket with SSL
        sock = None
        try:
            sock, addr = self.socket.accept()
            sslconn = self.context.wrap_socket(sock, server_side=True)
        except socket.error as e:
            # socket errors are silenced by the caller, print them here.
            # socket.error (the base of ssl.SSLError) is caught instead of
            # OSError: on Python 2 socket errors derive from IOError, so
            # "except OSError" never matched them.
            if support.verbose:
                sys.stderr.write("Got an error:\n%s\n" % e)
            if sock is not None:
                # Wrapping failed: don't leave the accepted socket open.
                sock.close()
            raise
        return sslconn, addr

class RootedHTTPRequestHandler(SimpleHTTPRequestHandler):
    # need to override translate_path to get a known root,
    # instead of using os.curdir, since the test could be
    # run from anywhere

    server_version = "TestHTTPS/1.0"
    root = here
    # Avoid hanging when a request gets interrupted by the client
    timeout = 5

    def translate_path(self, path):
        """Translate a /-separated PATH to the local filename syntax.

        Components that mean special things to the local file system
        (e.g. drive or directory names) are ignored.  (XXX They should
        probably be diagnosed.)

        The result always lies under self.root: "." and ".." components
        are dropped so a request cannot escape the served directory.
        """
        # abandon query parameters
        path = urlparse.urlparse(path)[2]
        # URL paths are always '/'-separated, whatever the host OS uses.
        path = posixpath.normpath(urllib.unquote(path))
        result = self.root
        for word in path.split('/'):
            # Keep only the last component left after stripping any drive
            # letter or OS-specific separator embedded in the word.
            word = os.path.split(os.path.splitdrive(word)[1])[1]
            if not word or word in (os.curdir, os.pardir):
                # normpath keeps leading ".." on relative paths (e.g. a
                # request line "GET ../secret"); joining it would escape root.
                continue
            result = os.path.join(result, word)
        return result

    def log_message(self, format, *args):
        # we override this to suppress logging unless "verbose"
        if support.verbose:
            # server_name (not the server_address tuple) fills the host slot,
            # matching HTTPSServer.__str__.
            sys.stdout.write(" server (%s:%d %s):\n [%s] %s\n" %
                             (self.server.server_name,
                              self.server.server_port,
                              self.request.cipher(),
                              self.log_date_time_string(),
                              format%args))


class StatsRequestHandler(BaseHTTPRequestHandler):
    """Example HTTP request handler which returns SSL statistics on GET
    requests.
    """

    server_version = "StatsHTTPS/1.0"

    def do_GET(self, send_body=True):
        """Serve a GET request.

        The body is a pprint-formatted dict with the context's session
        cache statistics and the connection's cipher and compression.
        send_body=False sends only the headers (used by do_HEAD).
        """
        # self.request is the SSL socket returned by HTTPSServer.get_request().
        # (Python 2's socket._fileobject rfile has no ".raw" attribute, so
        # self.rfile.raw._sock cannot be used to reach it.)
        sock = self.request
        context = sock.context
        stats = {
            'session_cache': context.session_stats(),
            'cipher': sock.cipher(),
            'compression': sock.compression(),
        }
        body = pprint.pformat(stats)
        body = body.encode('utf-8')
        self.send_response(200)
        self.send_header("Content-type", "text/plain; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        if send_body:
            self.wfile.write(body)

    def do_HEAD(self):
        """Serve a HEAD request."""
        self.do_GET(send_body=False)

    def log_request(self, code='-', size='-'):
        # Same signature as BaseHTTPRequestHandler.log_request; requests are
        # only logged in verbose mode.
        if support.verbose:
            BaseHTTPRequestHandler.log_request(self, code, size)


class HTTPSServerThread(threading.Thread):
    """Daemon thread running an HTTPSServer bound to (host, 0).

    The port chosen by the OS is available as self.port.
    """

    def __init__(self, context, host=HOST, handler_class=None):
        # Optional Event, set by run() once the thread has started.
        self.flag = None
        self.server = HTTPSServer((host, 0),
                                  handler_class or RootedHTTPRequestHandler,
                                  context)
        self.port = self.server.server_port
        threading.Thread.__init__(self)
        self.daemon = True

    def __str__(self):
        return "<%s %s>" % (self.__class__.__name__, self.server)

    def start(self, flag=None):
        self.flag = flag
        threading.Thread.start(self)

    def run(self):
        if self.flag:
            self.flag.set()
        try:
            # Poll every 0.05s so shutdown() is noticed promptly.
            self.server.serve_forever(0.05)
        finally:
            self.server.server_close()

    def stop(self):
        self.server.shutdown()


def make_https_server(case, context=None, certfile=CERTFILE,
                      host=HOST, handler_class=None):
    """Start an HTTPSServerThread and register its shutdown on `case`.

    When no context is given, a default CLIENT_AUTH context is created and
    loaded with `certfile`.  The server thread is stopped and joined through
    case.addCleanup().  Returns the running HTTPSServerThread.
    """
    if context is None:
        context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
    # We assume the certfile contains both private key and certificate
    context.load_cert_chain(certfile)
    server = HTTPSServerThread(context, host, handler_class)
    flag = threading.Event()
    server.start(flag)
    flag.wait()
    def cleanup():
        if support.verbose:
            sys.stdout.write('stopping HTTPS server\n')
        server.stop()
        if support.verbose:
            sys.stdout.write('joining HTTPS thread\n')
        server.join()
    case.addCleanup(cleanup)
    return server


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(
        description='Run a test HTTPS server. '
                    'By default, the current directory is served.')
    parser.add_argument('-p', '--port', type=int, default=4433,
                        help='port to listen on (default: %(default)s)')
    parser.add_argument('-q', '--quiet', dest='verbose', default=True,
                        action='store_false', help='be less verbose')
    parser.add_argument('-s', '--stats', dest='use_stats_handler', default=False,
                        action='store_true', help='always return stats page')
    parser.add_argument('--curve-name', dest='curve_name', type=str,
                        action='store',
                        help='curve name for EC-based Diffie-Hellman')
    parser.add_argument('--ciphers', dest='ciphers', type=str,
                        help='allowed cipher list')
    parser.add_argument('--dh', dest='dh_file', type=str, action='store',
                        help='PEM file containing DH parameters')
    args = parser.parse_args()

    support.verbose = args.verbose
    if args.use_stats_handler:
        handler_class = StatsRequestHandler
    else:
        handler_class = RootedHTTPRequestHandler
        # Serve the current working directory instead of this file's dir.
        handler_class.root = os.getcwd()
    context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
    context.load_cert_chain(CERTFILE)
    if args.curve_name:
        context.set_ecdh_curve(args.curve_name)
    if args.dh_file:
        context.load_dh_params(args.dh_file)
    if args.ciphers:
        context.set_ciphers(args.ciphers)

    server = HTTPSServer(("", args.port), handler_class, context)
    if args.verbose:
        print("Listening on https://localhost:{0.port}".format(args))
    server.serve_forever(0.1)