• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import os
2import sys
3import ssl
4import pprint
5import threading
6import urllib.parse
7# Rename HTTPServer to _HTTPServer so as to avoid confusion with HTTPSServer.
8from http.server import (HTTPServer as _HTTPServer,
9    SimpleHTTPRequestHandler, BaseHTTPRequestHandler)
10
11from test import support
12
# Directory containing this module; used as the base for bundled test files.
here = os.path.dirname(__file__)

# Default host for the test servers, as provided by test.support.
HOST = support.HOST
# Default certificate file, looked up next to this module.
CERTFILE = os.path.join(here, 'keycert.pem')
17
18# This one's based on HTTPServer, which is based on socketserver
19
class HTTPSServer(_HTTPServer):
    """HTTP server that wraps every accepted connection with SSL.

    The SSL context passed to the constructor is kept on the instance as
    ``context`` and used by get_request() to wrap each accepted socket.
    """

    def __init__(self, server_address, handler_class, context):
        """Set up the underlying HTTP server, then remember *context*."""
        _HTTPServer.__init__(self, server_address, handler_class)
        self.context = context

    def __str__(self):
        """Return a ``<ClassName server_name:server_port>`` description."""
        details = (self.__class__.__name__,
                   self.server_name,
                   self.server_port)
        return '<%s %s:%s>' % details

    def get_request(self):
        """Accept a connection and return ``(ssl_socket, address)``.

        Overrides the base implementation so that the accepted socket is
        wrapped server-side with ``self.context``.  Any OSError is
        re-raised after optionally being reported on stderr.
        """
        try:
            raw_sock, peer = self.socket.accept()
            tls_sock = self.context.wrap_socket(raw_sock, server_side=True)
        except OSError as exc:
            # socket errors are silenced by the caller, print them here
            if support.verbose:
                sys.stderr.write("Got an error:\n%s\n" % exc)
            raise
        else:
            return tls_sock, peer
43
class RootedHTTPRequestHandler(SimpleHTTPRequestHandler):
    """File-serving request handler with a known root directory.

    translate_path() is overridden to resolve request paths against
    ``root`` instead of os.curdir, since the test could be run from
    anywhere.
    """

    server_version = "TestHTTPS/1.0"
    # Directory that request paths are resolved against.
    root = here
    # Avoid hanging when a request gets interrupted by the client
    timeout = 5

    def translate_path(self, path):
        """Translate a /-separated PATH to the local filename syntax.

        Only the path component of the URL is used; it is unquoted,
        normalised, and each remaining component is joined onto
        ``self.root``.  Components that mean special things to the local
        file system (drive names, directory parts, ``.`` and ``..``) are
        ignored, so the result never points above ``self.root``.
        """
        # Local import: only this method needs it.
        import posixpath

        # abandon query parameters
        path = urllib.parse.urlparse(path)[2]
        # URL paths always use '/'.  os.path.normpath would turn the
        # separators into '\\' on Windows and defeat the split below.
        path = posixpath.normpath(urllib.parse.unquote(path))
        result = self.root
        for word in filter(None, path.split('/')):
            word = os.path.splitdrive(word)[1]
            word = os.path.split(word)[1]
            if word in (os.curdir, os.pardir):
                # A relative request path such as "../../x" survives
                # normpath(); never let it climb out of the root.
                continue
            result = os.path.join(result, word)
        return result

    def log_message(self, format, *args):
        """Write a log line to stdout, but only when support.verbose is set.

        The line includes the server address and port, the cipher of the
        request socket, the timestamp and the formatted message.
        """
        # we override this to suppress logging unless "verbose"
        if support.verbose:
            sys.stdout.write(" server (%s:%d %s):\n   [%s] %s\n" %
                             (self.server.server_address,
                              self.server.server_port,
                              self.request.cipher(),
                              self.log_date_time_string(),
                              format%args))
83
84
class StatsRequestHandler(BaseHTTPRequestHandler):
    """Example HTTP request handler which returns SSL statistics on GET
    requests.
    """

    server_version = "StatsHTTPS/1.0"

    def do_GET(self, send_body=True):
        """Serve a GET request.

        Responds with a plain-text, pretty-printed dict holding the SSL
        context's session statistics plus the cipher and compression of
        the connection.  When *send_body* is false only the status line
        and headers are sent.
        """
        # self.request is the socket this handler was created for; use it
        # directly instead of digging it out of the rfile buffer through
        # private attributes.
        sock = self.request
        context = sock.context
        stats = {
            'session_cache': context.session_stats(),
            'cipher': sock.cipher(),
            'compression': sock.compression(),
            }
        body = pprint.pformat(stats)
        body = body.encode('utf-8')
        self.send_response(200)
        self.send_header("Content-type", "text/plain; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        if send_body:
            self.wfile.write(body)

    def do_HEAD(self):
        """Serve a HEAD request: same headers as GET, but no body."""
        self.do_GET(send_body=False)

    def log_request(self, format, *args):
        """Delegate to the base class logger only when support.verbose is set.

        The arguments are passed through unchanged to
        BaseHTTPRequestHandler.log_request().
        """
        if support.verbose:
            BaseHTTPRequestHandler.log_request(self, format, *args)
117
118
class HTTPSServerThread(threading.Thread):
    """Daemon thread that runs an HTTPSServer until stop() is called.

    The server is bound to port 0 on *host* at construction time; the
    port actually assigned is available as ``port``.
    """

    def __init__(self, context, host=HOST, handler_class=None):
        """Create the server; fall back to RootedHTTPRequestHandler."""
        self.flag = None
        handler = handler_class or RootedHTTPRequestHandler
        self.server = HTTPSServer((host, 0), handler, context)
        self.port = self.server.server_port
        threading.Thread.__init__(self)
        self.daemon = True

    def __str__(self):
        """Return ``<ClassName <server description>>``."""
        return "<%s %s>" % (self.__class__.__name__, self.server)

    def start(self, flag=None):
        """Start the thread; *flag*, if given, is set once run() begins."""
        self.flag = flag
        threading.Thread.start(self)

    def run(self):
        """Signal the start flag, serve requests, then close the server."""
        started = self.flag
        if started:
            started.set()
        try:
            # Poll for shutdown requests every 0.05 seconds.
            self.server.serve_forever(0.05)
        finally:
            self.server.server_close()

    def stop(self):
        """Ask the serve_forever() loop to shut down."""
        self.server.shutdown()
147
148
def make_https_server(case, *, context=None, certfile=CERTFILE,
                      host=HOST, handler_class=None):
    """Start an HTTPS server thread and register its teardown on *case*.

    A default server-side SSL context is created when *context* is None.
    *certfile* is loaded into the context in either case.  The function
    waits until the server thread has started running, registers a
    cleanup on *case* that stops and joins the thread, and returns the
    HTTPSServerThread.
    """
    if context is None:
        context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
    # We assume the certfile contains both private key and certificate
    context.load_cert_chain(certfile)

    server = HTTPSServerThread(context, host, handler_class)
    started = threading.Event()
    server.start(started)
    started.wait()

    def _say(text):
        # Progress output is only wanted in verbose test runs.
        if support.verbose:
            sys.stdout.write(text)

    def cleanup():
        _say('stopping HTTPS server\n')
        server.stop()
        _say('joining HTTPS thread\n')
        server.join()

    case.addCleanup(cleanup)
    return server
168
169
if __name__ == "__main__":
    # Stand-alone mode: run a test HTTPS server configured from the
    # command line.
    import argparse

    parser = argparse.ArgumentParser(
        description='Run a test HTTPS server. '
                    'By default, the current directory is served.')
    parser.add_argument(
        '-p', '--port', type=int, default=4433,
        help='port to listen on (default: %(default)s)')
    parser.add_argument(
        '-q', '--quiet', dest='verbose', default=True,
        action='store_false', help='be less verbose')
    parser.add_argument(
        '-s', '--stats', dest='use_stats_handler', default=False,
        action='store_true', help='always return stats page')
    parser.add_argument(
        '--curve-name', dest='curve_name', type=str, action='store',
        help='curve name for EC-based Diffie-Hellman')
    parser.add_argument(
        '--ciphers', dest='ciphers', type=str,
        help='allowed cipher list')
    parser.add_argument(
        '--dh', dest='dh_file', type=str, action='store',
        help='PEM file containing DH parameters')
    args = parser.parse_args()

    support.verbose = args.verbose

    # Either always answer with the stats page, or serve files from the
    # directory the script was started in.
    if not args.use_stats_handler:
        handler_class = RootedHTTPRequestHandler
        handler_class.root = os.getcwd()
    else:
        handler_class = StatsRequestHandler

    context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
    context.load_cert_chain(CERTFILE)
    # Optional tuning, applied only for the options actually given.
    if args.curve_name:
        context.set_ecdh_curve(args.curve_name)
    if args.dh_file:
        context.load_dh_params(args.dh_file)
    if args.ciphers:
        context.set_ciphers(args.ciphers)

    listen_address = ("", args.port)
    server = HTTPSServer(listen_address, handler_class, context)
    if args.verbose:
        print("Listening on https://localhost:{0.port}".format(args))
    server.serve_forever(0.1)
209