"""
Test suite for socketserver.
"""

import contextlib
import io
import os
import select
import signal
import socket
import tempfile
import threading
import unittest
import socketserver

import test.support
from test.support import reap_children, reap_threads, verbose


test.support.requires("network")

TEST_STR = b"hello world\n"
HOST = test.support.HOST

HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
requires_unix_sockets = unittest.skipUnless(HAVE_UNIX_SOCKETS,
                                            'requires Unix sockets')
HAVE_FORKING = hasattr(os, "fork")
requires_forking = unittest.skipUnless(HAVE_FORKING, 'requires forking')


def signal_alarm(n):
    """Call signal.alarm when it exists (i.e. not on Windows)."""
    if hasattr(signal, 'alarm'):
        signal.alarm(n)


# Remember real select() to avoid interferences with mocking
_real_select = select.select


def receive(sock, n, timeout=20):
    """Return up to *n* bytes read from *sock*.

    Waits at most *timeout* seconds for *sock* to become readable and
    raises RuntimeError if it does not.
    """
    readable, _, _ = _real_select([sock], [], [], timeout)
    if sock in readable:
        return sock.recv(n)
    raise RuntimeError("timed out on %r" % (sock,))


def _receive_line(sock):
    """Read from *sock* until a newline has arrived or an empty chunk is
    received, and return all bytes collected."""
    buf = data = receive(sock, 100)
    while data and b'\n' not in buf:
        data = receive(sock, 100)
        buf += data
    return buf


if HAVE_UNIX_SOCKETS and HAVE_FORKING:
    class ForkingUnixStreamServer(socketserver.ForkingMixIn,
                                  socketserver.UnixStreamServer):
        pass

    class ForkingUnixDatagramServer(socketserver.ForkingMixIn,
                                    socketserver.UnixDatagramServer):
        pass


@contextlib.contextmanager
def simple_subprocess(testcase):
    """Tests that a custom child process is not waited on (Issue 1540386)

    Forks a child that exits immediately with status 72, runs the body of
    the ``with`` statement, then reaps the child itself and asserts it got
    the expected pid and exit status, i.e. that nothing inside the body
    (such as a forking server) collected the child first.
    """
    pid = os.fork()
    if pid == 0:
        # Don't raise an exception; it would be caught by the test harness.
        os._exit(72)
    try:
        yield None
    finally:
        pid2, status = os.waitpid(pid, 0)
        testcase.assertEqual(pid2, pid)
        testcase.assertEqual(72 << 8, status)


class SocketServerTest(unittest.TestCase):
    """Test all socket servers."""

    def setUp(self):
        signal_alarm(60)  # Kill deadlocks after 60 seconds.
        self.port_seed = 0
        # Unix socket paths created by pickaddr(), removed in tearDown().
        self.test_files = []

    def tearDown(self):
        signal_alarm(0)  # Didn't deadlock.
        reap_children()

        for fn in self.test_files:
            # The path may never have been bound; ignore removal failures.
            with contextlib.suppress(OSError):
                os.remove(fn)
        self.test_files.clear()

    def pickaddr(self, proto):
        """Return a bindable address for address family *proto*.

        AF_INET gets port 0 so the OS picks a free port; any other family
        gets a fresh temporary filesystem path (registered for cleanup).
        """
        if proto == socket.AF_INET:
            return (HOST, 0)
        # XXX: We need a way to tell AF_UNIX to pick its own name
        # like AF_INET provides port==0.  mktemp() is used because bind()
        # must create the path itself, so it must not exist yet.
        fn = tempfile.mktemp(prefix='unix_socket.')
        self.test_files.append(fn)
        return fn

    def make_server(self, addr, svrcls, hdlrbase):
        """Create an echo server of class *svrcls* bound to *addr*.

        The handler (derived from *hdlrbase*) echoes back one line, and the
        server re-raises handler errors instead of swallowing them.  Skips
        the test if binding is not permitted.
        """
        class MyServer(svrcls):
            def handle_error(self, request, client_address):
                self.close_request(request)
                raise

        class MyHandler(hdlrbase):
            def handle(self):
                line = self.rfile.readline()
                self.wfile.write(line)

        if verbose:
            print("creating server")
        try:
            server = MyServer(addr, MyHandler)
        except PermissionError as e:
            # Issue 29184: cannot bind() a Unix socket on Android.
            self.skipTest('Cannot create server (%s, %s): %s' %
                          (svrcls, addr, e))
        self.assertEqual(server.server_address, server.socket.getsockname())
        return server

    @reap_threads
    def run_server(self, svrcls, hdlrbase, testfunc):
        """Serve with *svrcls* in a background thread and call
        ``testfunc(address_family, addr)`` three times against it.

        The server is always shut down, joined and closed, even if
        *testfunc* fails, so the thread and socket cannot leak.
        """
        server = self.make_server(self.pickaddr(svrcls.address_family),
                                  svrcls, hdlrbase)
        # We had the OS pick a port, so pull the real address out of
        # the server.
        addr = server.server_address
        if verbose:
            print("ADDR =", addr)
            print("CLASS =", svrcls)

        t = threading.Thread(
            name='%s serving' % svrcls,
            target=server.serve_forever,
            # Short poll interval to make the test finish quickly.
            # Time between requests is short enough that we won't wake
            # up spuriously too many times.
            kwargs={'poll_interval': 0.01})
        t.daemon = True  # In case this function raises.
        t.start()
        if verbose:
            print("server running")
        try:
            for i in range(3):
                if verbose:
                    print("test client", i)
                testfunc(svrcls.address_family, addr)
        finally:
            if verbose:
                print("waiting for server")
            server.shutdown()
            t.join()
            server.server_close()
        self.assertEqual(-1, server.socket.fileno())
        if HAVE_FORKING and isinstance(server, socketserver.ForkingMixIn):
            # bpo-31151: Check that ForkingMixIn.server_close() waits until
            # all children completed
            self.assertFalse(server.active_children)
        if verbose:
            print("done")

    def stream_examine(self, proto, addr):
        """Send TEST_STR over a stream socket and check it is echoed."""
        with socket.socket(proto, socket.SOCK_STREAM) as s:
            s.connect(addr)
            s.sendall(TEST_STR)
            self.assertEqual(_receive_line(s), TEST_STR)

    def dgram_examine(self, proto, addr):
        """Send TEST_STR as a datagram and check it is echoed."""
        with socket.socket(proto, socket.SOCK_DGRAM) as s:
            if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX:
                # An unbound Unix datagram socket has no address the
                # server could reply to.
                s.bind(self.pickaddr(proto))
            s.sendto(TEST_STR, addr)
            self.assertEqual(_receive_line(s), TEST_STR)

    def test_TCPServer(self):
        self.run_server(socketserver.TCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    def test_ThreadingTCPServer(self):
        self.run_server(socketserver.ThreadingTCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_forking
    def test_ForkingTCPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingTCPServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    @requires_unix_sockets
    def test_UnixStreamServer(self):
        self.run_server(socketserver.UnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    def test_ThreadingUnixStreamServer(self):
        self.run_server(socketserver.ThreadingUnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixStreamServer(self):
        with simple_subprocess(self):
            self.run_server(ForkingUnixStreamServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    def test_UDPServer(self):
        self.run_server(socketserver.UDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    def test_ThreadingUDPServer(self):
        self.run_server(socketserver.ThreadingUDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_forking
    def test_ForkingUDPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingUDPServer,
                            socketserver.DatagramRequestHandler,
                            self.dgram_examine)

    @requires_unix_sockets
    def test_UnixDatagramServer(self):
        self.run_server(socketserver.UnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    def test_ThreadingUnixDatagramServer(self):
        self.run_server(socketserver.ThreadingUnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixDatagramServer(self):
        # Like the other forking tests, also check that the server does
        # not reap a child process it did not spawn.
        with simple_subprocess(self):
            self.run_server(ForkingUnixDatagramServer,
                            socketserver.DatagramRequestHandler,
                            self.dgram_examine)

    @reap_threads
    def test_shutdown(self):
        # Issue #2302: shutdown() should always succeed in making an
        # other thread leave serve_forever().
        class MyServer(socketserver.TCPServer):
            pass

        class MyHandler(socketserver.StreamRequestHandler):
            pass

        threads = []
        for i in range(20):
            s = MyServer((HOST, 0), MyHandler)
            t = threading.Thread(
                name='MyServer serving',
                target=s.serve_forever,
                kwargs={'poll_interval': 0.01})
            t.daemon = True  # In case this function raises.
            threads.append((t, s))
        for t, s in threads:
            t.start()
            s.shutdown()
        for t, s in threads:
            t.join()
            s.server_close()

    def test_tcpserver_bind_leak(self):
        # Issue #22435: the server socket wouldn't be closed if bind()/listen()
        # failed.
        # Create many servers for which bind() will fail, to see if this result
        # in FD exhaustion.
        for i in range(1024):
            with self.assertRaises(OverflowError):
                socketserver.TCPServer((HOST, -1),
                                       socketserver.StreamRequestHandler)

    def test_context_manager(self):
        with socketserver.TCPServer((HOST, 0),
                                    socketserver.StreamRequestHandler) as server:
            pass
        self.assertEqual(-1, server.socket.fileno())


class ErrorHandlerTest(unittest.TestCase):
    """Test that the servers pass normal exceptions from the handler to
    handle_error(), and that exiting exceptions like SystemExit and
    KeyboardInterrupt are not passed."""

    def tearDown(self):
        test.support.unlink(test.support.TESTFN)

    def test_sync_handled(self):
        BaseErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_sync_not_handled(self):
        with self.assertRaises(SystemExit):
            BaseErrorTestServer(SystemExit)
        self.check_result(handled=False)

    def test_threading_handled(self):
        ThreadingErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_threading_not_handled(self):
        ThreadingErrorTestServer(SystemExit)
        self.check_result(handled=False)

    @requires_forking
    def test_forking_handled(self):
        ForkingErrorTestServer(ValueError)
        self.check_result(handled=True)

    @requires_forking
    def test_forking_not_handled(self):
        ForkingErrorTestServer(SystemExit)
        self.check_result(handled=False)

    def check_result(self, handled):
        """Assert the log shows the handler ran and, iff *handled*, that
        handle_error() was called."""
        with open(test.support.TESTFN) as log:
            expected = 'Handler called\n' + 'Error handled\n' * handled
            self.assertEqual(log.read(), expected)


class BaseErrorTestServer(socketserver.TCPServer):
    """TCP server that, on construction, accepts a single connection whose
    handler raises ``exception``, then closes itself.

    handle_error() appends 'Error handled' to TESTFN so tests can tell
    whether the exception was routed to it.
    """

    def __init__(self, exception):
        self.exception = exception
        super().__init__((HOST, 0), BadHandler)
        with socket.create_connection(self.server_address):
            pass
        try:
            self.handle_request()
        finally:
            self.server_close()
        self.wait_done()

    def handle_error(self, request, client_address):
        with open(test.support.TESTFN, 'a') as log:
            log.write('Error handled\n')

    def wait_done(self):
        """Hook for subclasses whose request completes asynchronously."""
        pass


class BadHandler(socketserver.BaseRequestHandler):
    """Handler that logs 'Handler called' and then raises the server's
    configured exception."""

    def handle(self):
        with open(test.support.TESTFN, 'a') as log:
            log.write('Handler called\n')
        raise self.server.exception('Test error')


class ThreadingErrorTestServer(socketserver.ThreadingMixIn,
                               BaseErrorTestServer):
    """Threaded variant: wait_done() blocks until shutdown_request() ran in
    the request thread."""

    def __init__(self, *pos, **kw):
        # Must exist before the base __init__ handles the request.
        self.done = threading.Event()
        super().__init__(*pos, **kw)

    def shutdown_request(self, *pos, **kw):
        super().shutdown_request(*pos, **kw)
        self.done.set()

    def wait_done(self):
        self.done.wait()


if HAVE_FORKING:
    class ForkingErrorTestServer(socketserver.ForkingMixIn, BaseErrorTestServer):
        pass


class SocketWriterTest(unittest.TestCase):
    def test_basics(self):
        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                self.server.wfile = self.wfile
                self.server.wfile_fileno = self.wfile.fileno()
                self.server.request_fileno = self.request.fileno()

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        s = socket.socket(
            server.address_family, socket.SOCK_STREAM, socket.IPPROTO_TCP)
        with s:
            s.connect(server.server_address)
        server.handle_request()
        self.assertIsInstance(server.wfile, io.BufferedIOBase)
        self.assertEqual(server.wfile_fileno, server.request_fileno)

    def test_write(self):
        # Test that wfile.write() sends data immediately, and that it does
        # not truncate sends when interrupted by a Unix signal
        pthread_kill = test.support.get_attribute(signal, 'pthread_kill')

        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                self.server.sent1 = self.wfile.write(b'write data\n')
                # Should be sent immediately, without requiring flush()
                self.server.received = self.rfile.readline()
                big_chunk = b'\0' * test.support.SOCK_MAX_SIZE
                self.server.sent2 = self.wfile.write(big_chunk)

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        interrupted = threading.Event()

        def signal_handler(signum, frame):
            interrupted.set()

        original = signal.signal(signal.SIGUSR1, signal_handler)
        self.addCleanup(signal.signal, signal.SIGUSR1, original)
        response1 = None
        received2 = None
        main_thread = threading.get_ident()

        def run_client():
            nonlocal response1, received2
            s = socket.socket(server.address_family, socket.SOCK_STREAM,
                              socket.IPPROTO_TCP)
            with s, s.makefile('rb') as reader:
                s.connect(server.server_address)
                response1 = reader.readline()
                s.sendall(b'client response\n')

                reader.read(100)
                # The main thread should now be blocking in a send() syscall.
                # But in theory, it could get interrupted by other signals,
                # and then retried. So keep sending the signal in a loop, in
                # case an earlier signal happens to be delivered at an
                # inconvenient moment.
                while True:
                    pthread_kill(main_thread, signal.SIGUSR1)
                    if interrupted.wait(timeout=1.0):
                        break
                received2 = len(reader.read())

        background = threading.Thread(target=run_client)
        background.start()
        server.handle_request()
        background.join()
        self.assertEqual(server.sent1, len(response1))
        self.assertEqual(response1, b'write data\n')
        self.assertEqual(server.received, b'client response\n')
        self.assertEqual(server.sent2, test.support.SOCK_MAX_SIZE)
        self.assertEqual(received2, test.support.SOCK_MAX_SIZE - 100)


class MiscTestCase(unittest.TestCase):

    def test_all(self):
        # objects defined in the module should be in __all__
        expected = []
        for name in dir(socketserver):
            if not name.startswith('_'):
                mod_object = getattr(socketserver, name)
                if getattr(mod_object, '__module__', None) == 'socketserver':
                    expected.append(name)
        self.assertCountEqual(socketserver.__all__, expected)

    def test_shutdown_request_called_if_verify_request_false(self):
        # Issue #26309: BaseServer should call shutdown_request even if
        # verify_request is False

        class MyServer(socketserver.TCPServer):
            def verify_request(self, request, client_address):
                return False

            shutdown_called = 0

            def shutdown_request(self, request):
                self.shutdown_called += 1
                socketserver.TCPServer.shutdown_request(self, request)

        server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
        # Close the listening socket even if the assertion below fails.
        self.addCleanup(server.server_close)
        with socket.socket(server.address_family, socket.SOCK_STREAM) as s:
            s.connect(server.server_address)
        server.handle_request()
        self.assertEqual(server.shutdown_called, 1)


if __name__ == "__main__":
    unittest.main()