• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""
2Test suite for SocketServer.py.
3"""
4
5import contextlib
6import imp
7import os
8import select
9import signal
10import socket
11import tempfile
12import unittest
13import SocketServer
14
15import test.test_support
16from test.test_support import reap_children, reap_threads, verbose
17try:
18    import threading
19except ImportError:
20    threading = None
21
# Declare that this module needs the "network" test resource.
test.test_support.requires("network")

# Line echoed back by every handler under test; the trailing newline is
# where the stream handlers' readline() stops and where the clients stop
# reading.
TEST_STR = "hello world\n"
# Host used for every AF_INET server (provided by test_support).
HOST = test.test_support.HOST

# Platform feature probes that decide which server variants get tested.
HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
HAVE_FORKING = os.name != "os2" and hasattr(os, "fork")
29
def signal_alarm(n):
    """Call signal.alarm(n) on platforms that provide it.

    signal.alarm does not exist on Windows; there this is a no-op.  The
    tests pass n=0 to cancel a previously armed alarm.
    """
    alarm = getattr(signal, 'alarm', None)
    if alarm is None:
        return
    alarm(n)
34
def receive(sock, n, timeout=20):
    """Read up to *n* bytes from *sock*, waiting at most *timeout* seconds.

    select() is used first so that a server that never answers makes the
    test fail with a clear error instead of blocking in recv().

    Raises RuntimeError if *sock* does not become readable in time.
    """
    readable, _, _ = select.select([sock], [], [], timeout)
    if sock not in readable:
        # Call-form raise: same exception and message as the old
        # "raise E, msg" statement, but valid on every Python version.
        raise RuntimeError("timed out on %r" % (sock,))
    return sock.recv(n)
41
if HAVE_UNIX_SOCKETS:
    # Local ForkingMixIn combinations for the AF_UNIX servers; the tests
    # reference these classes directly rather than a SocketServer attribute.

    class ForkingUnixStreamServer(SocketServer.ForkingMixIn,
                                  SocketServer.UnixStreamServer):
        """UnixStreamServer with SocketServer.ForkingMixIn mixed in."""

    class ForkingUnixDatagramServer(SocketServer.ForkingMixIn,
                                    SocketServer.UnixDatagramServer):
        """UnixDatagramServer with SocketServer.ForkingMixIn mixed in."""
50
51
@contextlib.contextmanager
def simple_subprocess(testcase):
    """Run the with-body while an unrelated forked child process exists.

    A child that exits immediately with status 72 is forked before the
    body runs and waited for afterwards (presumably to check that the
    forking servers under test do not reap children they did not create).
    On normal exit the child's pid and exit status are verified with
    *testcase*.assertEqual.

    The child is always reaped, even when the body raises, so a failing
    test leaves no zombie behind; the status assertions are made only on
    the success path so they cannot mask the body's original exception.
    """
    pid = os.fork()
    if pid == 0:
        # Child: exit at once.  Don't throw an exception; it would be
        # caught by the test harness running in the forked copy.
        os._exit(72)
    try:
        yield None
    finally:
        pid2, status = os.waitpid(pid, 0)
    testcase.assertEqual(pid2, pid)
    # waitpid() reports a normal exit code in the high byte of status.
    testcase.assertEqual(72 << 8, status)
62
63
64@unittest.skipUnless(threading, 'Threading required for this test.')
65class SocketServerTest(unittest.TestCase):
66    """Test all socket servers."""
67
68    def setUp(self):
69        signal_alarm(20)  # Kill deadlocks after 20 seconds.
70        self.port_seed = 0
71        self.test_files = []
72
73    def tearDown(self):
74        signal_alarm(0)  # Didn't deadlock.
75        reap_children()
76
77        for fn in self.test_files:
78            try:
79                os.remove(fn)
80            except os.error:
81                pass
82        self.test_files[:] = []
83
84    def pickaddr(self, proto):
85        if proto == socket.AF_INET:
86            return (HOST, 0)
87        else:
88            # XXX: We need a way to tell AF_UNIX to pick its own name
89            # like AF_INET provides port==0.
90            dir = None
91            if os.name == 'os2':
92                dir = '\socket'
93            fn = tempfile.mktemp(prefix='unix_socket.', dir=dir)
94            if os.name == 'os2':
95                # AF_UNIX socket names on OS/2 require a specific prefix
96                # which can't include a drive letter and must also use
97                # backslashes as directory separators
98                if fn[1] == ':':
99                    fn = fn[2:]
100                if fn[0] in (os.sep, os.altsep):
101                    fn = fn[1:]
102                if os.sep == '/':
103                    fn = fn.replace(os.sep, os.altsep)
104                else:
105                    fn = fn.replace(os.altsep, os.sep)
106            self.test_files.append(fn)
107            return fn
108
109    def make_server(self, addr, svrcls, hdlrbase):
110        class MyServer(svrcls):
111            def handle_error(self, request, client_address):
112                self.close_request(request)
113                self.server_close()
114                raise
115
116        class MyHandler(hdlrbase):
117            def handle(self):
118                line = self.rfile.readline()
119                self.wfile.write(line)
120
121        if verbose: print "creating server"
122        server = MyServer(addr, MyHandler)
123        self.assertEqual(server.server_address, server.socket.getsockname())
124        return server
125
126    @unittest.skipUnless(threading, 'Threading required for this test.')
127    @reap_threads
128    def run_server(self, svrcls, hdlrbase, testfunc):
129        server = self.make_server(self.pickaddr(svrcls.address_family),
130                                  svrcls, hdlrbase)
131        # We had the OS pick a port, so pull the real address out of
132        # the server.
133        addr = server.server_address
134        if verbose:
135            print "server created"
136            print "ADDR =", addr
137            print "CLASS =", svrcls
138        t = threading.Thread(
139            name='%s serving' % svrcls,
140            target=server.serve_forever,
141            # Short poll interval to make the test finish quickly.
142            # Time between requests is short enough that we won't wake
143            # up spuriously too many times.
144            kwargs={'poll_interval':0.01})
145        t.daemon = True  # In case this function raises.
146        t.start()
147        if verbose: print "server running"
148        for i in range(3):
149            if verbose: print "test client", i
150            testfunc(svrcls.address_family, addr)
151        if verbose: print "waiting for server"
152        server.shutdown()
153        t.join()
154        if verbose: print "done"
155
156    def stream_examine(self, proto, addr):
157        s = socket.socket(proto, socket.SOCK_STREAM)
158        s.connect(addr)
159        s.sendall(TEST_STR)
160        buf = data = receive(s, 100)
161        while data and '\n' not in buf:
162            data = receive(s, 100)
163            buf += data
164        self.assertEqual(buf, TEST_STR)
165        s.close()
166
167    def dgram_examine(self, proto, addr):
168        s = socket.socket(proto, socket.SOCK_DGRAM)
169        s.sendto(TEST_STR, addr)
170        buf = data = receive(s, 100)
171        while data and '\n' not in buf:
172            data = receive(s, 100)
173            buf += data
174        self.assertEqual(buf, TEST_STR)
175        s.close()
176
177    def test_TCPServer(self):
178        self.run_server(SocketServer.TCPServer,
179                        SocketServer.StreamRequestHandler,
180                        self.stream_examine)
181
182    def test_ThreadingTCPServer(self):
183        self.run_server(SocketServer.ThreadingTCPServer,
184                        SocketServer.StreamRequestHandler,
185                        self.stream_examine)
186
187    if HAVE_FORKING:
188        def test_ForkingTCPServer(self):
189            with simple_subprocess(self):
190                self.run_server(SocketServer.ForkingTCPServer,
191                                SocketServer.StreamRequestHandler,
192                                self.stream_examine)
193
194    if HAVE_UNIX_SOCKETS:
195        def test_UnixStreamServer(self):
196            self.run_server(SocketServer.UnixStreamServer,
197                            SocketServer.StreamRequestHandler,
198                            self.stream_examine)
199
200        def test_ThreadingUnixStreamServer(self):
201            self.run_server(SocketServer.ThreadingUnixStreamServer,
202                            SocketServer.StreamRequestHandler,
203                            self.stream_examine)
204
205        if HAVE_FORKING:
206            def test_ForkingUnixStreamServer(self):
207                with simple_subprocess(self):
208                    self.run_server(ForkingUnixStreamServer,
209                                    SocketServer.StreamRequestHandler,
210                                    self.stream_examine)
211
212    def test_UDPServer(self):
213        self.run_server(SocketServer.UDPServer,
214                        SocketServer.DatagramRequestHandler,
215                        self.dgram_examine)
216
217    def test_ThreadingUDPServer(self):
218        self.run_server(SocketServer.ThreadingUDPServer,
219                        SocketServer.DatagramRequestHandler,
220                        self.dgram_examine)
221
222    if HAVE_FORKING:
223        def test_ForkingUDPServer(self):
224            with simple_subprocess(self):
225                self.run_server(SocketServer.ForkingUDPServer,
226                                SocketServer.DatagramRequestHandler,
227                                self.dgram_examine)
228
229    # Alas, on Linux (at least) recvfrom() doesn't return a meaningful
230    # client address so this cannot work:
231
232    # if HAVE_UNIX_SOCKETS:
233    #     def test_UnixDatagramServer(self):
234    #         self.run_server(SocketServer.UnixDatagramServer,
235    #                         SocketServer.DatagramRequestHandler,
236    #                         self.dgram_examine)
237    #
238    #     def test_ThreadingUnixDatagramServer(self):
239    #         self.run_server(SocketServer.ThreadingUnixDatagramServer,
240    #                         SocketServer.DatagramRequestHandler,
241    #                         self.dgram_examine)
242    #
243    #     if HAVE_FORKING:
244    #         def test_ForkingUnixDatagramServer(self):
245    #             self.run_server(SocketServer.ForkingUnixDatagramServer,
246    #                             SocketServer.DatagramRequestHandler,
247    #                             self.dgram_examine)
248
249    @reap_threads
250    def test_shutdown(self):
251        # Issue #2302: shutdown() should always succeed in making an
252        # other thread leave serve_forever().
253        class MyServer(SocketServer.TCPServer):
254            pass
255
256        class MyHandler(SocketServer.StreamRequestHandler):
257            pass
258
259        threads = []
260        for i in range(20):
261            s = MyServer((HOST, 0), MyHandler)
262            t = threading.Thread(
263                name='MyServer serving',
264                target=s.serve_forever,
265                kwargs={'poll_interval':0.01})
266            t.daemon = True  # In case this function raises.
267            threads.append((t, s))
268        for t, s in threads:
269            t.start()
270            s.shutdown()
271        for t, s in threads:
272            t.join()
273
274
def test_main():
    """Run SocketServerTest via test_support unless the import lock is held."""
    # The server threads started by the tests would hang on the import
    # lock if the caller is still holding it, so skip instead.
    import_lock_busy = imp.lock_held()
    if import_lock_busy:
        raise unittest.SkipTest("can't run when import lock is held")
    test.test_support.run_unittest(SocketServerTest)


if __name__ == "__main__":
    test_main()
    # Armed only after the tests finish: if interpreter shutdown hangs,
    # SIGALRM fires after 3 seconds.
    signal_alarm(3)  # Shutdown shouldn't take more than 3 seconds.
285