• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""
2Test suite for socketserver.
3"""
4
5import contextlib
6import io
7import os
8import select
9import signal
10import socket
11import tempfile
12import threading
13import unittest
14import socketserver
15
16import test.support
17from test.support import reap_children, reap_threads, verbose
18from test.support import socket_helper
19
20
# Every test here talks over real sockets; bail out early unless the
# "network" resource is enabled.
test.support.requires("network")

HOST = socket_helper.HOST
TEST_STR = b"hello world\n"

# Feature probes: AF_UNIX sockets and os.fork() are not available on every
# platform, so the tests that need them are skipped via the decorators below.
HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
HAVE_FORKING = hasattr(os, "fork")

requires_unix_sockets = unittest.skipUnless(
    HAVE_UNIX_SOCKETS, 'requires Unix sockets')
requires_forking = unittest.skipUnless(
    HAVE_FORKING, 'requires forking')
31
def signal_alarm(n):
    """Schedule a SIGALRM in *n* seconds (0 cancels a pending one).

    signal.alarm does not exist on every platform (not on Windows); where it
    is missing this function silently does nothing.
    """
    alarm = getattr(signal, 'alarm', None)
    if alarm is not None:
        alarm(n)
36
# Keep a handle on the genuine select.select so that tests which mock
# select() cannot interfere with receive().
_real_select = select.select

def receive(sock, n, timeout=test.support.SHORT_TIMEOUT):
    """Return up to *n* bytes read from *sock*.

    Waits at most *timeout* seconds for *sock* to become readable and raises
    RuntimeError if it does not.
    """
    readable, _, _ = _real_select([sock], [], [], timeout)
    if sock not in readable:
        raise RuntimeError("timed out on %r" % (sock,))
    return sock.recv(n)
46
if HAVE_UNIX_SOCKETS and HAVE_FORKING:
    # Forking flavours of the Unix-domain servers, composed locally for the
    # forking Unix tests. Only defined where both AF_UNIX and fork() exist.
    class ForkingUnixStreamServer(socketserver.ForkingMixIn,
                                  socketserver.UnixStreamServer):
        ...

    class ForkingUnixDatagramServer(socketserver.ForkingMixIn,
                                    socketserver.UnixDatagramServer):
        ...
55
56
@contextlib.contextmanager
def simple_subprocess(testcase):
    """Fork a throwaway child that lives across the ``with`` block.

    Tests that a custom child process is not waited on (Issue 1540386).
    The child exits immediately with status 72. When the block is left
    (normally or via an exception), the parent waits for that child and
    checks the exit code. If code inside the block (e.g. a ForkingMixIn
    server) had already reaped this unrelated child, that final
    wait_process() would presumably fail.

    *testcase* is accepted for call-site compatibility but is not used.
    """
    pid = os.fork()
    if pid == 0:
        # In the child: exit at once. Don't raise an exception; it would be
        # caught by the test harness.
        os._exit(72)
    try:
        yield None
    finally:
        # Runs on success and on failure; any exception from the block
        # propagates unchanged after the child has been collected.
        test.support.wait_process(pid, exitcode=72)
70
71
class SocketServerTest(unittest.TestCase):
    """Test all socket servers."""

    def setUp(self):
        """Arm a watchdog alarm and reset per-test bookkeeping."""
        signal_alarm(60)  # Kill deadlocks after 60 seconds.
        self.port_seed = 0
        # Paths of AF_UNIX socket files handed out by pickaddr(); they are
        # removed in tearDown().
        self.test_files = []

    def tearDown(self):
        """Cancel the watchdog, reap children and remove socket files."""
        signal_alarm(0)  # Didn't deadlock.
        reap_children()

        for fn in self.test_files:
            # Best effort: the path may never have been bound.
            with contextlib.suppress(OSError):
                os.remove(fn)
        self.test_files.clear()

    def pickaddr(self, proto):
        """Return an address to bind a server of family *proto* to.

        AF_INET gets (HOST, 0) so the OS picks a free port. Any other family
        is treated as AF_UNIX and gets a fresh temporary path, which is
        recorded in self.test_files for cleanup.
        """
        if proto == socket.AF_INET:
            return (HOST, 0)
        # XXX: We need a way to tell AF_UNIX to pick its own name
        # like AF_INET provides port==0. mktemp() only returns a name (the
        # socket file must not exist before bind()), at the cost of a
        # small race window.
        fn = tempfile.mktemp(prefix='unix_socket.')
        self.test_files.append(fn)
        return fn

    def make_server(self, addr, svrcls, hdlrbase):
        """Create a one-line echo server of class *svrcls* bound to *addr*.

        The handler (derived from *hdlrbase*) reads one line and writes it
        back. The server's handle_error() closes the request and re-raises,
        so handler failures are not silently swallowed. The test is skipped
        if creating the server raises PermissionError.
        """
        class MyServer(svrcls):
            def handle_error(self, request, client_address):
                self.close_request(request)
                # Bare raise: handle_error is expected to be called while
                # the handler's exception is being handled.
                raise

        class MyHandler(hdlrbase):
            def handle(self):
                line = self.rfile.readline()
                self.wfile.write(line)

        if verbose: print("creating server")
        try:
            server = MyServer(addr, MyHandler)
        except PermissionError as e:
            # Issue 29184: cannot bind() a Unix socket on Android.
            self.skipTest('Cannot create server (%s, %s): %s' %
                          (svrcls, addr, e))
        self.assertEqual(server.server_address, server.socket.getsockname())
        return server

    @reap_threads
    def run_server(self, svrcls, hdlrbase, testfunc):
        """Serve with *svrcls* in a background thread and run *testfunc*.

        *testfunc* is called three times as
        testfunc(svrcls.address_family, server_address). The server is
        always shut down, its thread joined and the server closed, even if
        *testfunc* fails. A failing client check therefore does not leak
        the listening socket or leave serve_forever() running.
        """
        server = self.make_server(self.pickaddr(svrcls.address_family),
                                  svrcls, hdlrbase)
        # We had the OS pick a port, so pull the real address out of
        # the server.
        addr = server.server_address
        if verbose:
            print("ADDR =", addr)
            print("CLASS =", svrcls)

        t = threading.Thread(
            name='%s serving' % svrcls,
            target=server.serve_forever,
            # Short poll interval to make the test finish quickly.
            # Time between requests is short enough that we won't wake
            # up spuriously too many times.
            kwargs={'poll_interval':0.01})
        t.daemon = True  # In case this function raises.
        t.start()
        try:
            if verbose: print("server running")
            for i in range(3):
                if verbose: print("test client", i)
                testfunc(svrcls.address_family, addr)
            if verbose: print("waiting for server")
        finally:
            # shutdown() makes serve_forever() return. It does not block if
            # serve_forever() already ended because of an error.
            server.shutdown()
            t.join()
            server.server_close()
        self.assertEqual(-1, server.socket.fileno())
        if HAVE_FORKING and isinstance(server, socketserver.ForkingMixIn):
            # bpo-31151: Check that ForkingMixIn.server_close() waits until
            # all children completed
            self.assertFalse(server.active_children)
        if verbose: print("done")

    def stream_examine(self, proto, addr):
        """Send TEST_STR over a stream socket and check it is echoed back."""
        with socket.socket(proto, socket.SOCK_STREAM) as s:
            s.connect(addr)
            s.sendall(TEST_STR)
            buf = data = receive(s, 100)
            # Keep reading until the echoed line is complete or EOF.
            while data and b'\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)

    def dgram_examine(self, proto, addr):
        """Send TEST_STR as a datagram and check it is echoed back."""
        with socket.socket(proto, socket.SOCK_DGRAM) as s:
            if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX:
                # An unbound AF_UNIX datagram client has no address the
                # server could reply to.
                s.bind(self.pickaddr(proto))
            s.sendto(TEST_STR, addr)
            buf = data = receive(s, 100)
            while data and b'\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)

    # One test per server class. The Forking* variants also run inside
    # simple_subprocess() to check that an unrelated child is not reaped.

    def test_TCPServer(self):
        self.run_server(socketserver.TCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    def test_ThreadingTCPServer(self):
        self.run_server(socketserver.ThreadingTCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_forking
    def test_ForkingTCPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingTCPServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    @requires_unix_sockets
    def test_UnixStreamServer(self):
        self.run_server(socketserver.UnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    def test_ThreadingUnixStreamServer(self):
        self.run_server(socketserver.ThreadingUnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixStreamServer(self):
        with simple_subprocess(self):
            self.run_server(ForkingUnixStreamServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    def test_UDPServer(self):
        self.run_server(socketserver.UDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    def test_ThreadingUDPServer(self):
        self.run_server(socketserver.ThreadingUDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_forking
    def test_ForkingUDPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingUDPServer,
                            socketserver.DatagramRequestHandler,
                            self.dgram_examine)

    @requires_unix_sockets
    def test_UnixDatagramServer(self):
        self.run_server(socketserver.UnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    def test_ThreadingUnixDatagramServer(self):
        self.run_server(socketserver.ThreadingUnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixDatagramServer(self):
        # Like the other Forking* tests, also check that an unrelated child
        # process is not waited on by the server (Issue 1540386).
        with simple_subprocess(self):
            self.run_server(ForkingUnixDatagramServer,
                            socketserver.DatagramRequestHandler,
                            self.dgram_examine)

    @reap_threads
    def test_shutdown(self):
        # Issue #2302: shutdown() should always succeed in making an
        # other thread leave serve_forever().
        class MyServer(socketserver.TCPServer):
            pass

        class MyHandler(socketserver.StreamRequestHandler):
            pass

        threads = []
        for i in range(20):
            s = MyServer((HOST, 0), MyHandler)
            t = threading.Thread(
                name='MyServer serving',
                target=s.serve_forever,
                kwargs={'poll_interval':0.01})
            t.daemon = True  # In case this function raises.
            threads.append((t, s))
        for t, s in threads:
            # shutdown() right after start(), possibly before
            # serve_forever() has begun looping.
            t.start()
            s.shutdown()
        for t, s in threads:
            t.join()
            s.server_close()

    def test_tcpserver_bind_leak(self):
        # Issue #22435: the server socket wouldn't be closed if bind()/listen()
        # failed.
        # Create many servers for which bind() will fail, to see if this result
        # in FD exhaustion.
        for i in range(1024):
            with self.assertRaises(OverflowError):
                socketserver.TCPServer((HOST, -1),
                                       socketserver.StreamRequestHandler)

    def test_context_manager(self):
        # Leaving the with-block must close the server socket.
        with socketserver.TCPServer((HOST, 0),
                                    socketserver.StreamRequestHandler) as server:
            pass
        self.assertEqual(-1, server.socket.fileno())
293
294
class ErrorHandlerTest(unittest.TestCase):
    """Check that servers route ordinary handler exceptions to
    handle_error(), while exiting exceptions such as SystemExit and
    KeyboardInterrupt are not passed to it."""

    def tearDown(self):
        # The handler and handle_error() append to TESTFN; remove the log
        # after every test.
        test.support.unlink(test.support.TESTFN)

    def test_sync_handled(self):
        BaseErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_sync_not_handled(self):
        # A synchronous server lets SystemExit escape to the caller.
        self.assertRaises(SystemExit, BaseErrorTestServer, SystemExit)
        self.check_result(handled=False)

    def test_threading_handled(self):
        ThreadingErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_threading_not_handled(self):
        ThreadingErrorTestServer(SystemExit)
        self.check_result(handled=False)

    @requires_forking
    def test_forking_handled(self):
        ForkingErrorTestServer(ValueError)
        self.check_result(handled=True)

    @requires_forking
    def test_forking_not_handled(self):
        ForkingErrorTestServer(SystemExit)
        self.check_result(handled=False)

    def check_result(self, handled):
        """Assert the log shows the handler ran, plus handle_error() iff *handled*."""
        expected = 'Handler called\n'
        if handled:
            expected += 'Error handled\n'
        with open(test.support.TESTFN) as log:
            contents = log.read()
        self.assertEqual(contents, expected)
334
335
class BaseErrorTestServer(socketserver.TCPServer):
    """TCP server that processes exactly one request from its constructor.

    It binds to an ephemeral port on HOST, opens and immediately closes one
    client connection to itself, handles that single request with
    BadHandler, closes itself and finally calls wait_done(). Errors
    routed to handle_error() are logged to test.support.TESTFN.
    """

    def __init__(self, exception):
        # Exception class that BadHandler raises (read via self.server).
        self.exception = exception
        super().__init__((HOST, 0), BadHandler)
        try:
            # Queue one connection on the listening socket. Closing the
            # client side at once is fine: BadHandler never reads from it.
            with socket.create_connection(self.server_address):
                pass
            self.handle_request()
        finally:
            # Close the listening socket even if connecting or handling
            # fails (e.g. SystemExit escaping a synchronous server).
            self.server_close()
        # Only reached on success. Subclasses override this to wait for
        # request processing that happens outside this thread.
        self.wait_done()

    def handle_error(self, request, client_address):
        """Record in the log file that an error reached handle_error()."""
        with open(test.support.TESTFN, 'a') as log:
            log.write('Error handled\n')

    def wait_done(self):
        """Hook for subclasses; the synchronous server has nothing to wait for."""
        pass
354
355
class BadHandler(socketserver.BaseRequestHandler):
    """Request handler that logs that it ran and then always fails."""

    def handle(self):
        with open(test.support.TESTFN, 'a') as log:
            log.write('Handler called\n')
        # The owning server decides which exception type is raised.
        exc_type = self.server.exception
        raise exc_type('Test error')
361
362
class ThreadingErrorTestServer(socketserver.ThreadingMixIn,
                               BaseErrorTestServer):
    """Threaded variant of BaseErrorTestServer.

    wait_done() blocks until shutdown_request() has run for the request.
    """

    def __init__(self, *args, **kwargs):
        # Must exist before the base constructor handles the request.
        self.done = threading.Event()
        super().__init__(*args, **kwargs)

    def shutdown_request(self, *args, **kwargs):
        # Signal completion only once the request has really been shut down.
        super().shutdown_request(*args, **kwargs)
        self.done.set()

    def wait_done(self):
        self.done.wait()
375
376
if HAVE_FORKING:
    # Forking variant of BaseErrorTestServer; only defined where os.fork
    # exists.
    class ForkingErrorTestServer(socketserver.ForkingMixIn,
                                 BaseErrorTestServer):
        ...
380
381
class SocketWriterTest(unittest.TestCase):
    """Tests for the wfile stream provided by StreamRequestHandler."""

    def test_basics(self):
        # wfile must be a buffered binary stream sharing the request
        # socket's file descriptor.
        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                srv = self.server
                srv.wfile = self.wfile
                srv.wfile_fileno = self.wfile.fileno()
                srv.request_fileno = self.request.fileno()

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        with socket.socket(server.address_family, socket.SOCK_STREAM,
                           socket.IPPROTO_TCP) as client:
            client.connect(server.server_address)
        server.handle_request()
        self.assertIsInstance(server.wfile, io.BufferedIOBase)
        self.assertEqual(server.wfile_fileno, server.request_fileno)

    def test_write(self):
        # Test that wfile.write() sends data immediately, and that it does
        # not truncate sends when interrupted by a Unix signal
        pthread_kill = test.support.get_attribute(signal, 'pthread_kill')

        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                srv = self.server
                srv.sent1 = self.wfile.write(b'write data\n')
                # Should be sent immediately, without requiring flush()
                srv.received = self.rfile.readline()
                big_chunk = b'\0' * test.support.SOCK_MAX_SIZE
                srv.sent2 = self.wfile.write(big_chunk)

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        interrupted = threading.Event()

        def signal_handler(signum, frame):
            interrupted.set()

        previous = signal.signal(signal.SIGUSR1, signal_handler)
        self.addCleanup(signal.signal, signal.SIGUSR1, previous)
        # Filled in by the client thread.
        response1 = None
        received2 = None
        main_thread = threading.get_ident()

        def run_client():
            nonlocal response1, received2
            with socket.socket(server.address_family, socket.SOCK_STREAM,
                               socket.IPPROTO_TCP) as sock, \
                    sock.makefile('rb') as reader:
                sock.connect(server.server_address)
                response1 = reader.readline()
                sock.sendall(b'client response\n')

                reader.read(100)
                # The main thread should now be blocking in a send() syscall.
                # But in theory, it could get interrupted by other signals,
                # and then retried. So keep sending the signal in a loop, in
                # case an earlier signal happens to be delivered at an
                # inconvenient moment.
                while True:
                    pthread_kill(main_thread, signal.SIGUSR1)
                    if interrupted.wait(timeout=1.0):
                        break
                received2 = len(reader.read())

        client_thread = threading.Thread(target=run_client)
        client_thread.start()
        server.handle_request()
        client_thread.join()
        self.assertEqual(server.sent1, len(response1))
        self.assertEqual(response1, b'write data\n')
        self.assertEqual(server.received, b'client response\n')
        self.assertEqual(server.sent2, test.support.SOCK_MAX_SIZE)
        self.assertEqual(received2, test.support.SOCK_MAX_SIZE - 100)
457
458
class MiscTestCase(unittest.TestCase):
    """Miscellaneous socketserver checks: __all__ and request shutdown."""

    def test_all(self):
        # objects defined in the module should be in __all__
        expected = []
        for name in dir(socketserver):
            if name.startswith('_'):
                continue
            mod_object = getattr(socketserver, name)
            if getattr(mod_object, '__module__', None) == 'socketserver':
                expected.append(name)
        self.assertCountEqual(socketserver.__all__, expected)

    def test_shutdown_request_called_if_verify_request_false(self):
        # Issue #26309: BaseServer should call shutdown_request even if
        # verify_request is False

        class MyServer(socketserver.TCPServer):
            def verify_request(self, request, client_address):
                # Reject every incoming request.
                return False

            shutdown_called = 0
            def shutdown_request(self, request):
                self.shutdown_called += 1
                socketserver.TCPServer.shutdown_request(self, request)

        server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
        # Close the listening socket even if the assertion below fails.
        self.addCleanup(server.server_close)
        with socket.socket(server.address_family, socket.SOCK_STREAM) as s:
            s.connect(server.server_address)
        server.handle_request()
        self.assertEqual(server.shutdown_called, 1)
491
492
if __name__ == "__main__":
    # Run the test suite when this file is executed as a script.
    unittest.main()
495