• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""
2Test suite for socketserver.
3"""
4
5import contextlib
6import io
7import os
8import select
9import signal
10import socket
11import tempfile
12import threading
13import unittest
14import socketserver
15
16import test.support
17from test.support import reap_children, reap_threads, verbose
18
19
# Declare that this module needs the "network" test resource.
# NOTE(review): presumably this skips the module when the resource is not
# enabled -- confirm against test.support.requires.
test.support.requires("network")

# Payload each test client sends and expects the server to echo back.
TEST_STR = b"hello world\n"
# Host used when building the (HOST, port) address the TCP/UDP servers bind to.
HOST = test.support.HOST

# Platform capability flags and the matching skip decorators.
HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
requires_unix_sockets = unittest.skipUnless(HAVE_UNIX_SOCKETS,
                                            'requires Unix sockets')
HAVE_FORKING = hasattr(os, "fork")
requires_forking = unittest.skipUnless(HAVE_FORKING, 'requires forking')
30
def signal_alarm(n):
    """Forward *n* to signal.alarm() on platforms that provide it.

    signal.alarm does not exist everywhere (e.g. not on Windows); when it
    is missing this function silently does nothing.  Always returns None.
    """
    if not hasattr(signal, 'alarm'):
        return
    signal.alarm(n)
35
# Keep a reference to the genuine select() so that tests which mock
# select.select cannot interfere with this helper.
_real_select = select.select

def receive(sock, n, timeout=20):
    """Return up to *n* bytes read from *sock*.

    Waits at most *timeout* seconds for the socket to become readable and
    raises RuntimeError if it does not.
    """
    readable = _real_select([sock], [], [], timeout)[0]
    if sock not in readable:
        raise RuntimeError("timed out on %r" % (sock,))
    return sock.recv(n)
45
# Forking variants of the Unix-socket servers, assembled from ForkingMixIn.
# They are only defined when the platform offers both AF_UNIX sockets and
# os.fork(); the tests that use them carry the matching skip decorators.
if HAVE_UNIX_SOCKETS and HAVE_FORKING:
    class ForkingUnixStreamServer(socketserver.ForkingMixIn,
                                  socketserver.UnixStreamServer):
        # UnixStreamServer combined with ForkingMixIn.
        pass

    class ForkingUnixDatagramServer(socketserver.ForkingMixIn,
                                    socketserver.UnixDatagramServer):
        # UnixDatagramServer combined with ForkingMixIn.
        pass
54
55
@contextlib.contextmanager
def simple_subprocess(testcase):
    """Tests that a custom child process is not waited on (Issue 1540386)

    Context manager that forks a child which exits at once with status 72,
    runs the ``with`` body, and then reaps that child itself.  On exit it
    asserts through *testcase* that waitpid() returned this child's pid and
    a raw wait status of ``72 << 8``.

    The checks sit in a ``finally`` clause, so they also run when the
    ``with`` body raises; that exception then propagates unchanged unless
    one of the checks fails.
    """
    pid = os.fork()
    if pid == 0:
        # Don't raise an exception; it would be caught by the test harness.
        os._exit(72)
    try:
        yield None
    finally:
        # The child must still be waitable here, i.e. nothing in the body
        # has reaped it already.
        pid2, status = os.waitpid(pid, 0)
        testcase.assertEqual(pid2, pid)
        testcase.assertEqual(72 << 8, status)
71
72
class SocketServerTest(unittest.TestCase):
    """Test all socket servers."""

    def setUp(self):
        # Watchdog: kill deadlocks after 60 seconds.
        signal_alarm(60)
        self.port_seed = 0
        # Unix socket paths handed out by pickaddr(); removed in tearDown().
        self.test_files = []

    def tearDown(self):
        # The test did not deadlock, so cancel the watchdog.
        signal_alarm(0)
        reap_children()

        # Best-effort cleanup: a path may never have been created.
        for path in self.test_files:
            with contextlib.suppress(OSError):
                os.remove(path)
        del self.test_files[:]

    def pickaddr(self, proto):
        """Return an address suitable for binding a socket of family *proto*.

        For AF_INET this is ``(HOST, 0)``.  For any other family a fresh
        temporary file name is returned and remembered in ``self.test_files``
        so that tearDown() can delete it.
        """
        if proto != socket.AF_INET:
            # XXX: We need a way to tell AF_UNIX to pick its own name
            # like AF_INET provides port==0.
            path = tempfile.mktemp(prefix='unix_socket.', dir=None)
            self.test_files.append(path)
            return path
        return (HOST, 0)

    def make_server(self, addr, svrcls, hdlrbase):
        """Create and return an echo server of class *svrcls* bound to *addr*.

        The handler (derived from *hdlrbase*) reads one line and writes it
        back.  The test is skipped if creating the server raises
        PermissionError.
        """
        class MyServer(svrcls):
            def handle_error(self, request, client_address):
                # Close the request, then re-raise the exception that is
                # currently being handled.
                self.close_request(request)
                raise

        class MyHandler(hdlrbase):
            def handle(self):
                self.wfile.write(self.rfile.readline())

        if verbose:
            print("creating server")
        try:
            server = MyServer(addr, MyHandler)
        except PermissionError as exc:
            # Issue 29184: cannot bind() a Unix socket on Android.
            self.skipTest('Cannot create server (%s, %s): %s' %
                          (svrcls, addr, exc))
        self.assertEqual(server.server_address, server.socket.getsockname())
        return server

    @reap_threads
    def run_server(self, svrcls, hdlrbase, testfunc):
        """Serve with *svrcls* in a background thread and run three clients.

        *testfunc* is called three times as ``testfunc(family, address)``;
        afterwards the server is shut down and closed, and the method checks
        that its socket is closed and, for forking servers, that no child
        is left active.
        """
        family = svrcls.address_family
        server = self.make_server(self.pickaddr(family), svrcls, hdlrbase)
        # We had the OS pick a port, so pull the real address out of
        # the server.
        addr = server.server_address
        if verbose:
            print("ADDR =", addr)
            print("CLASS =", svrcls)

        # Short poll interval to make the test finish quickly.
        # Time between requests is short enough that we won't wake
        # up spuriously too many times.
        # The thread is a daemon in case this function raises.
        serve_thread = threading.Thread(
            name='%s serving' % svrcls,
            target=server.serve_forever,
            kwargs={'poll_interval': 0.01},
            daemon=True)
        serve_thread.start()
        if verbose:
            print("server running")
        for attempt in range(3):
            if verbose:
                print("test client", attempt)
            testfunc(family, addr)
        if verbose:
            print("waiting for server")
        server.shutdown()
        serve_thread.join()
        server.server_close()
        self.assertEqual(-1, server.socket.fileno())
        if HAVE_FORKING and isinstance(server, socketserver.ForkingMixIn):
            # bpo-31151: Check that ForkingMixIn.server_close() waits until
            # all children completed
            self.assertFalse(server.active_children)
        if verbose:
            print("done")

    def _recv_reply(self, sock):
        """Read from *sock* until a newline arrives or an empty read occurs."""
        chunk = receive(sock, 100)
        reply = chunk
        while chunk and b'\n' not in reply:
            chunk = receive(sock, 100)
            reply += chunk
        return reply

    def stream_examine(self, proto, addr):
        """Stream client: send TEST_STR to *addr* and expect it echoed back."""
        with socket.socket(proto, socket.SOCK_STREAM) as client:
            client.connect(addr)
            client.sendall(TEST_STR)
            self.assertEqual(self._recv_reply(client), TEST_STR)

    def dgram_examine(self, proto, addr):
        """Datagram client: send TEST_STR to *addr* and expect it echoed back."""
        with socket.socket(proto, socket.SOCK_DGRAM) as client:
            if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX:
                # Unix datagram clients are bound to a name of their own
                # (presumably so the server has an address to reply to).
                client.bind(self.pickaddr(proto))
            client.sendto(TEST_STR, addr)
            self.assertEqual(self._recv_reply(client), TEST_STR)

    # --- stream servers -------------------------------------------------

    def test_TCPServer(self):
        self.run_server(
            socketserver.TCPServer,
            socketserver.StreamRequestHandler,
            self.stream_examine)

    def test_ThreadingTCPServer(self):
        self.run_server(
            socketserver.ThreadingTCPServer,
            socketserver.StreamRequestHandler,
            self.stream_examine)

    @requires_forking
    def test_ForkingTCPServer(self):
        with simple_subprocess(self):
            self.run_server(
                socketserver.ForkingTCPServer,
                socketserver.StreamRequestHandler,
                self.stream_examine)

    @requires_unix_sockets
    def test_UnixStreamServer(self):
        self.run_server(
            socketserver.UnixStreamServer,
            socketserver.StreamRequestHandler,
            self.stream_examine)

    @requires_unix_sockets
    def test_ThreadingUnixStreamServer(self):
        self.run_server(
            socketserver.ThreadingUnixStreamServer,
            socketserver.StreamRequestHandler,
            self.stream_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixStreamServer(self):
        with simple_subprocess(self):
            self.run_server(
                ForkingUnixStreamServer,
                socketserver.StreamRequestHandler,
                self.stream_examine)

    # --- datagram servers -----------------------------------------------

    def test_UDPServer(self):
        self.run_server(
            socketserver.UDPServer,
            socketserver.DatagramRequestHandler,
            self.dgram_examine)

    def test_ThreadingUDPServer(self):
        self.run_server(
            socketserver.ThreadingUDPServer,
            socketserver.DatagramRequestHandler,
            self.dgram_examine)

    @requires_forking
    def test_ForkingUDPServer(self):
        with simple_subprocess(self):
            self.run_server(
                socketserver.ForkingUDPServer,
                socketserver.DatagramRequestHandler,
                self.dgram_examine)

    @requires_unix_sockets
    def test_UnixDatagramServer(self):
        self.run_server(
            socketserver.UnixDatagramServer,
            socketserver.DatagramRequestHandler,
            self.dgram_examine)

    @requires_unix_sockets
    def test_ThreadingUnixDatagramServer(self):
        self.run_server(
            socketserver.ThreadingUnixDatagramServer,
            socketserver.DatagramRequestHandler,
            self.dgram_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixDatagramServer(self):
        # NOTE(review): unlike the other Forking* tests, this one is not
        # wrapped in simple_subprocess().
        self.run_server(
            ForkingUnixDatagramServer,
            socketserver.DatagramRequestHandler,
            self.dgram_examine)

    # --- miscellaneous --------------------------------------------------

    @reap_threads
    def test_shutdown(self):
        # Issue #2302: shutdown() should always succeed in making an
        # other thread leave serve_forever().
        class MyServer(socketserver.TCPServer):
            pass

        class MyHandler(socketserver.StreamRequestHandler):
            pass

        pairs = []
        for _ in range(20):
            server = MyServer((HOST, 0), MyHandler)
            # Daemon thread in case this function raises.
            worker = threading.Thread(
                name='MyServer serving',
                target=server.serve_forever,
                kwargs={'poll_interval': 0.01},
                daemon=True)
            pairs.append((worker, server))
        # Ask each server to shut down right after its thread is started.
        for worker, server in pairs:
            worker.start()
            server.shutdown()
        for worker, server in pairs:
            worker.join()
            server.server_close()

    def test_tcpserver_bind_leak(self):
        # Issue #22435: the server socket wouldn't be closed if bind()/listen()
        # failed.
        # Create many servers for which bind() will fail, to see if this result
        # in FD exhaustion.
        for _ in range(1024):
            with self.assertRaises(OverflowError):
                socketserver.TCPServer(
                    (HOST, -1), socketserver.StreamRequestHandler)

    def test_context_manager(self):
        # Leaving the with-block must close the server socket.
        handler_cls = socketserver.StreamRequestHandler
        with socketserver.TCPServer((HOST, 0), handler_cls) as server:
            pass
        self.assertEqual(-1, server.socket.fileno())
294
295
class ErrorHandlerTest(unittest.TestCase):
    """Test that the servers pass normal exceptions from the handler to
    handle_error(), and that exiting exceptions like SystemExit and
    KeyboardInterrupt are not passed."""

    def tearDown(self):
        # Remove the log file written by BadHandler / handle_error().
        test.support.unlink(test.support.TESTFN)

    # Each test constructs a server; the constructor itself performs one
    # request whose handler raises the given exception type.

    def test_sync_handled(self):
        BaseErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_sync_not_handled(self):
        # With the synchronous server the SystemExit reaches the caller.
        with self.assertRaises(SystemExit):
            BaseErrorTestServer(SystemExit)
        self.check_result(handled=False)

    def test_threading_handled(self):
        ThreadingErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_threading_not_handled(self):
        ThreadingErrorTestServer(SystemExit)
        self.check_result(handled=False)

    @requires_forking
    def test_forking_handled(self):
        ForkingErrorTestServer(ValueError)
        self.check_result(handled=True)

    @requires_forking
    def test_forking_not_handled(self):
        ForkingErrorTestServer(SystemExit)
        self.check_result(handled=False)

    def check_result(self, handled):
        """Compare the log file with what is expected for *handled*.

        The handler line must always be present; the handle_error() line
        is expected only when *handled* is true.
        """
        with open(test.support.TESTFN) as log:
            actual = log.read()
        wanted = ['Handler called\n'] + ['Error handled\n'] * handled
        self.assertEqual(actual, ''.join(wanted))
335
336
class BaseErrorTestServer(socketserver.TCPServer):
    """TCP server whose constructor runs one complete failing request.

    Constructing an instance binds to ``(HOST, 0)``, makes a client
    connection to itself, handles that single request with BadHandler
    (which raises an instance of *exception*), closes the server and
    finally calls wait_done().
    """

    def __init__(self, exception):
        # BadHandler reads this through self.server.exception.
        self.exception = exception
        super().__init__((HOST, 0), BadHandler)
        try:
            # Open and immediately close a client connection so that
            # handle_request() has a request to handle.  This sits inside
            # the try block so the listening socket is closed even if the
            # connection attempt fails.
            with socket.create_connection(self.server_address):
                pass
            self.handle_request()
        finally:
            self.server_close()
        self.wait_done()

    def handle_error(self, request, client_address):
        # Record in the shared log file that handle_error() was invoked.
        with open(test.support.TESTFN, 'a') as log:
            log.write('Error handled\n')

    def wait_done(self):
        # Hook for subclasses that must wait for the request to finish;
        # nothing to wait for here.
        pass
355
356
class BadHandler(socketserver.BaseRequestHandler):
    # Request handler that logs its invocation and then fails with the
    # exception type stored on the server.

    def handle(self):
        log = open(test.support.TESTFN, 'a')
        with log:
            log.write('Handler called\n')
        exc_type = self.server.exception
        raise exc_type('Test error')
362
363
class ThreadingErrorTestServer(socketserver.ThreadingMixIn,
                               BaseErrorTestServer):
    # Threaded variant: the event below is set by shutdown_request() and
    # awaited by wait_done(), which the base constructor calls last.

    def __init__(self, *args, **kwargs):
        # Must exist before the base constructor runs, because that
        # constructor already handles a request.
        self.done = threading.Event()
        super().__init__(*args, **kwargs)

    def shutdown_request(self, *args, **kwargs):
        super().shutdown_request(*args, **kwargs)
        self.done.set()

    def wait_done(self):
        # Block until shutdown_request() has been called.
        self.done.wait()
376
377
# Forking variant of the error test server; only defined where os.fork()
# exists (the tests that use it are decorated with requires_forking).
if HAVE_FORKING:
    class ForkingErrorTestServer(socketserver.ForkingMixIn, BaseErrorTestServer):
        # BaseErrorTestServer combined with ForkingMixIn; no overrides.
        pass
381
382
class SocketWriterTest(unittest.TestCase):
    # Tests for the wfile object that StreamRequestHandler hands to handlers.

    def test_basics(self):
        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                # Expose the handler's view of the connection to the test.
                srv = self.server
                srv.wfile = self.wfile
                srv.wfile_fileno = self.wfile.fileno()
                srv.request_fileno = self.request.fileno()

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        # Connect and disconnect right away; the pending connection is
        # enough for handle_request() below.
        with socket.socket(server.address_family, socket.SOCK_STREAM,
                           socket.IPPROTO_TCP) as client:
            client.connect(server.server_address)
        server.handle_request()
        self.assertIsInstance(server.wfile, io.BufferedIOBase)
        self.assertEqual(server.wfile_fileno, server.request_fileno)

    def test_write(self):
        # Test that wfile.write() sends data immediately, and that it does
        # not truncate sends when interrupted by a Unix signal
        pthread_kill = test.support.get_attribute(signal, 'pthread_kill')

        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                srv = self.server
                srv.sent1 = self.wfile.write(b'write data\n')
                # Should be sent immediately, without requiring flush()
                srv.received = self.rfile.readline()
                big_chunk = b'\0' * test.support.SOCK_MAX_SIZE
                srv.sent2 = self.wfile.write(big_chunk)

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        interrupted = threading.Event()

        def signal_handler(signum, frame):
            interrupted.set()

        # Install the handler and make sure the previous one is restored.
        original = signal.signal(signal.SIGUSR1, signal_handler)
        self.addCleanup(signal.signal, signal.SIGUSR1, original)
        response1 = None
        received2 = None
        main_thread = threading.get_ident()

        def run_client():
            nonlocal response1, received2
            conn = socket.socket(server.address_family, socket.SOCK_STREAM,
                                 socket.IPPROTO_TCP)
            with conn, conn.makefile('rb') as reader:
                conn.connect(server.server_address)
                response1 = reader.readline()
                conn.sendall(b'client response\n')

                # Consume the first 100 bytes of the big chunk.
                reader.read(100)
                # The main thread should now be blocking in a send() syscall.
                # But in theory, it could get interrupted by other signals,
                # and then retried. So keep sending the signal in a loop, in
                # case an earlier signal happens to be delivered at an
                # inconvenient moment.
                pthread_kill(main_thread, signal.SIGUSR1)
                while not interrupted.wait(timeout=1.0):
                    pthread_kill(main_thread, signal.SIGUSR1)
                received2 = len(reader.read())

        background = threading.Thread(target=run_client)
        background.start()
        server.handle_request()
        background.join()
        self.assertEqual(server.sent1, len(response1))
        self.assertEqual(response1, b'write data\n')
        self.assertEqual(server.received, b'client response\n')
        self.assertEqual(server.sent2, test.support.SOCK_MAX_SIZE)
        # 100 bytes of the big chunk were already consumed by the client.
        self.assertEqual(received2, test.support.SOCK_MAX_SIZE - 100)
458
459
460class MiscTestCase(unittest.TestCase):
461
462    def test_all(self):
463        # objects defined in the module should be in __all__
464        expected = []
465        for name in dir(socketserver):
466            if not name.startswith('_'):
467                mod_object = getattr(socketserver, name)
468                if getattr(mod_object, '__module__', None) == 'socketserver':
469                    expected.append(name)
470        self.assertCountEqual(socketserver.__all__, expected)
471
472    def test_shutdown_request_called_if_verify_request_false(self):
473        # Issue #26309: BaseServer should call shutdown_request even if
474        # verify_request is False
475
476        class MyServer(socketserver.TCPServer):
477            def verify_request(self, request, client_address):
478                return False
479
480            shutdown_called = 0
481            def shutdown_request(self, request):
482                self.shutdown_called += 1
483                socketserver.TCPServer.shutdown_request(self, request)
484
485        server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
486        s = socket.socket(server.address_family, socket.SOCK_STREAM)
487        s.connect(server.server_address)
488        s.close()
489        server.handle_request()
490        self.assertEqual(server.shutdown_called, 1)
491        server.server_close()
492
493
# Run every test in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
496