"""
Test suite for socketserver.
"""

import contextlib
import io
import os
import select
import signal
import socket
import tempfile
import unittest
import socketserver

import test.support
from test.support import reap_children, reap_threads, verbose
try:
    import threading
except ImportError:
    threading = None

test.support.requires("network")

TEST_STR = b"hello world\n"
HOST = test.support.HOST

HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
requires_unix_sockets = unittest.skipUnless(HAVE_UNIX_SOCKETS,
                                            'requires Unix sockets')
HAVE_FORKING = hasattr(os, "fork")
requires_forking = unittest.skipUnless(HAVE_FORKING, 'requires forking')

def signal_alarm(n):
    """Call signal.alarm when it exists (i.e. not on Windows)."""
    if hasattr(signal, 'alarm'):
        signal.alarm(n)

# Remember real select() to avoid interferences with mocking
_real_select = select.select

def receive(sock, n, timeout=20):
    """Return up to *n* bytes read from *sock*.

    Waits at most *timeout* seconds for the socket to become readable and
    raises RuntimeError if it does not.
    """
    r, w, x = _real_select([sock], [], [], timeout)
    if sock in r:
        return sock.recv(n)
    else:
        raise RuntimeError("timed out on %r" % (sock,))

if HAVE_UNIX_SOCKETS and HAVE_FORKING:
    class ForkingUnixStreamServer(socketserver.ForkingMixIn,
                                  socketserver.UnixStreamServer):
        pass

    class ForkingUnixDatagramServer(socketserver.ForkingMixIn,
                                    socketserver.UnixDatagramServer):
        pass


@contextlib.contextmanager
def simple_subprocess(testcase):
    """Tests that a custom child process is not waited on (Issue 1540386)

    Forks a child that exits immediately with status 72, runs the body of
    the ``with`` statement, then waits for that child and checks (through
    *testcase*) that its pid and exit status are the expected ones.
    """
    pid = os.fork()
    if pid == 0:
        # Don't raise an exception; it would be caught by the test harness.
        os._exit(72)
    try:
        yield None
    finally:
        # Always collect the child, even when the with-body raised, so it
        # is not left behind as a zombie.
        pid2, status = os.waitpid(pid, 0)
    # Kept outside the ``finally`` so that a failure here can never mask an
    # exception raised by the with-body.
    testcase.assertEqual(pid2, pid)
    testcase.assertEqual(72 << 8, status)


@unittest.skipUnless(threading, 'Threading required for this test.')
class SocketServerTest(unittest.TestCase):
    """Test all socket servers."""

    def setUp(self):
        signal_alarm(60)  # Kill deadlocks after 60 seconds.
        self.port_seed = 0
        self.test_files = []

    def tearDown(self):
        signal_alarm(0)  # Didn't deadlock.
        reap_children()

        for fn in self.test_files:
            try:
                os.remove(fn)
            except OSError:
                pass
        self.test_files[:] = []

    def pickaddr(self, proto):
        """Return an address to bind to for address family *proto*.

        For AF_INET this is ``(HOST, 0)``.  For any other family a fresh
        temporary file name is returned and recorded in ``self.test_files``
        so that tearDown() removes it.
        """
        if proto == socket.AF_INET:
            return (HOST, 0)
        else:
            # XXX: We need a way to tell AF_UNIX to pick its own name
            # like AF_INET provides port==0.
            fn = tempfile.mktemp(prefix='unix_socket.')
            self.test_files.append(fn)
            return fn

    def make_server(self, addr, svrcls, hdlrbase):
        """Create a server of class *svrcls* bound to *addr*.

        The handler derives from *hdlrbase* and echoes back one line.
        """
        class MyServer(svrcls):
            def handle_error(self, request, client_address):
                # Close the request, then re-raise the exception currently
                # being handled instead of swallowing it.
                self.close_request(request)
                raise

        class MyHandler(hdlrbase):
            def handle(self):
                line = self.rfile.readline()
                self.wfile.write(line)

        if verbose: print("creating server")
        server = MyServer(addr, MyHandler)
        self.assertEqual(server.server_address, server.socket.getsockname())
        return server

    @reap_threads
    def run_server(self, svrcls, hdlrbase, testfunc):
        """Serve with *svrcls* in a background thread and run *testfunc*.

        *testfunc* is called three times with the server's address family
        and address.  The server is shut down and closed afterwards, also
        when *testfunc* raises.
        """
        server = self.make_server(self.pickaddr(svrcls.address_family),
                                  svrcls, hdlrbase)
        # We had the OS pick a port, so pull the real address out of
        # the server.
        addr = server.server_address
        if verbose:
            print("ADDR =", addr)
            print("CLASS =", svrcls)

        t = threading.Thread(
            name='%s serving' % svrcls,
            target=server.serve_forever,
            # Short poll interval to make the test finish quickly.
            # Time between requests is short enough that we won't wake
            # up spuriously too many times.
            kwargs={'poll_interval':0.01})
        t.daemon = True  # In case this function raises.
        t.start()
        try:
            if verbose: print("server running")
            for i in range(3):
                if verbose: print("test client", i)
                testfunc(svrcls.address_family, addr)
            if verbose: print("waiting for server")
        finally:
            # Stop the serving thread and release the listening socket even
            # if a client check failed, so nothing leaks into later tests.
            server.shutdown()
            t.join()
            server.server_close()
        self.assertEqual(-1, server.socket.fileno())
        if verbose: print("done")

    def stream_examine(self, proto, addr):
        """Send TEST_STR over a stream socket and expect it echoed back."""
        # The context manager closes the socket even when receive() times
        # out or the assertion fails.
        with socket.socket(proto, socket.SOCK_STREAM) as s:
            s.connect(addr)
            s.sendall(TEST_STR)
            buf = data = receive(s, 100)
            while data and b'\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)

    def dgram_examine(self, proto, addr):
        """Send TEST_STR as a datagram and expect it echoed back."""
        # The context manager closes the socket even when receive() times
        # out or the assertion fails.
        with socket.socket(proto, socket.SOCK_DGRAM) as s:
            if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX:
                s.bind(self.pickaddr(proto))
            s.sendto(TEST_STR, addr)
            buf = data = receive(s, 100)
            while data and b'\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)

    def test_TCPServer(self):
        self.run_server(socketserver.TCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    def test_ThreadingTCPServer(self):
        self.run_server(socketserver.ThreadingTCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_forking
    def test_ForkingTCPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingTCPServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    @requires_unix_sockets
    def test_UnixStreamServer(self):
        self.run_server(socketserver.UnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    def test_ThreadingUnixStreamServer(self):
        self.run_server(socketserver.ThreadingUnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixStreamServer(self):
        with simple_subprocess(self):
            self.run_server(ForkingUnixStreamServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    def test_UDPServer(self):
        self.run_server(socketserver.UDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    def test_ThreadingUDPServer(self):
        self.run_server(socketserver.ThreadingUDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_forking
    def test_ForkingUDPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingUDPServer,
                            socketserver.DatagramRequestHandler,
                            self.dgram_examine)

    @requires_unix_sockets
    def test_UnixDatagramServer(self):
        self.run_server(socketserver.UnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    def test_ThreadingUnixDatagramServer(self):
        self.run_server(socketserver.ThreadingUnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixDatagramServer(self):
        self.run_server(ForkingUnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @reap_threads
    def test_shutdown(self):
        # Issue #2302: shutdown() should always succeed in making an
        # other thread leave serve_forever().
        class MyServer(socketserver.TCPServer):
            pass

        class MyHandler(socketserver.StreamRequestHandler):
            pass

        threads = []
        for i in range(20):
            s = MyServer((HOST, 0), MyHandler)
            t = threading.Thread(
                name='MyServer serving',
                target=s.serve_forever,
                kwargs={'poll_interval':0.01})
            t.daemon = True  # In case this function raises.
            threads.append((t, s))
        for t, s in threads:
            t.start()
            s.shutdown()
        for t, s in threads:
            t.join()
            s.server_close()

    def test_tcpserver_bind_leak(self):
        # Issue #22435: the server socket wouldn't be closed if bind()/listen()
        # failed.
        # Create many servers for which bind() will fail, to see if this
        # results in FD exhaustion.
        for i in range(1024):
            with self.assertRaises(OverflowError):
                socketserver.TCPServer((HOST, -1),
                                       socketserver.StreamRequestHandler)

    def test_context_manager(self):
        with socketserver.TCPServer((HOST, 0),
                                    socketserver.StreamRequestHandler) as server:
            pass
        self.assertEqual(-1, server.socket.fileno())


class ErrorHandlerTest(unittest.TestCase):
    """Test that the servers pass normal exceptions from the handler to
    handle_error(), and that exiting exceptions like SystemExit and
    KeyboardInterrupt are not passed."""

    def tearDown(self):
        test.support.unlink(test.support.TESTFN)

    def test_sync_handled(self):
        BaseErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_sync_not_handled(self):
        with self.assertRaises(SystemExit):
            BaseErrorTestServer(SystemExit)
        self.check_result(handled=False)

    @unittest.skipUnless(threading, 'Threading required for this test.')
    def test_threading_handled(self):
        ThreadingErrorTestServer(ValueError)
        self.check_result(handled=True)

    @unittest.skipUnless(threading, 'Threading required for this test.')
    def test_threading_not_handled(self):
        ThreadingErrorTestServer(SystemExit)
        self.check_result(handled=False)

    @requires_forking
    def test_forking_handled(self):
        ForkingErrorTestServer(ValueError)
        self.check_result(handled=True)

    @requires_forking
    def test_forking_not_handled(self):
        ForkingErrorTestServer(SystemExit)
        self.check_result(handled=False)

    def check_result(self, handled):
        """Check the log file written by BadHandler and handle_error().

        The 'Error handled' line is expected only when *handled* is true.
        """
        with open(test.support.TESTFN) as log:
            expected = 'Handler called\n' + 'Error handled\n' * handled
            self.assertEqual(log.read(), expected)


class BaseErrorTestServer(socketserver.TCPServer):
    """TCP server whose constructor runs one complete request.

    Instantiating it binds to an ephemeral port, connects a throw-away
    client, handles that single request with BadHandler (which raises
    *exception*), closes the server and finally calls wait_done().
    """

    def __init__(self, exception):
        self.exception = exception
        super().__init__((HOST, 0), BadHandler)
        with socket.create_connection(self.server_address):
            pass
        try:
            self.handle_request()
        finally:
            self.server_close()
        self.wait_done()

    def handle_error(self, request, client_address):
        with open(test.support.TESTFN, 'a') as log:
            log.write('Error handled\n')

    def wait_done(self):
        # Nothing to wait for in the synchronous server; subclasses
        # override this.
        pass


class BadHandler(socketserver.BaseRequestHandler):
    """Handler that logs its invocation, then raises server.exception."""

    def handle(self):
        with open(test.support.TESTFN, 'a') as log:
            log.write('Handler called\n')
        raise self.server.exception('Test error')


class ThreadingErrorTestServer(socketserver.ThreadingMixIn,
                               BaseErrorTestServer):
    def __init__(self, *pos, **kw):
        # Must exist before the base __init__ runs, because that already
        # handles the request and calls wait_done().
        self.done = threading.Event()
        super().__init__(*pos, **kw)

    def shutdown_request(self, *pos, **kw):
        super().shutdown_request(*pos, **kw)
        self.done.set()

    def wait_done(self):
        self.done.wait()


if HAVE_FORKING:
    class ForkingErrorTestServer(socketserver.ForkingMixIn, BaseErrorTestServer):
        def wait_done(self):
            # Exactly one child is expected; the unpacking fails otherwise.
            [child] = self.active_children
            os.waitpid(child, 0)
            self.active_children.clear()


class SocketWriterTest(unittest.TestCase):
    def test_basics(self):
        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                # Publish what the test needs to inspect on the server.
                self.server.wfile = self.wfile
                self.server.wfile_fileno = self.wfile.fileno()
                self.server.request_fileno = self.request.fileno()

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        s = socket.socket(
            server.address_family, socket.SOCK_STREAM, socket.IPPROTO_TCP)
        with s:
            s.connect(server.server_address)
            server.handle_request()
        self.assertIsInstance(server.wfile, io.BufferedIOBase)
        self.assertEqual(server.wfile_fileno, server.request_fileno)

    @unittest.skipUnless(threading, 'Threading required for this test.')
    def test_write(self):
        # Test that wfile.write() sends data immediately, and that it does
        # not truncate sends when interrupted by a Unix signal
        pthread_kill = test.support.get_attribute(signal, 'pthread_kill')

        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                self.server.sent1 = self.wfile.write(b'write data\n')
                # Should be sent immediately, without requiring flush()
                self.server.received = self.rfile.readline()
                big_chunk = b'\0' * test.support.SOCK_MAX_SIZE
                self.server.sent2 = self.wfile.write(big_chunk)

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        interrupted = threading.Event()

        def signal_handler(signum, frame):
            interrupted.set()

        original = signal.signal(signal.SIGUSR1, signal_handler)
        self.addCleanup(signal.signal, signal.SIGUSR1, original)
        response1 = None
        received2 = None
        main_thread = threading.get_ident()

        def run_client():
            nonlocal response1, received2
            s = socket.socket(server.address_family, socket.SOCK_STREAM,
                              socket.IPPROTO_TCP)
            with s, s.makefile('rb') as reader:
                s.connect(server.server_address)
                response1 = reader.readline()
                s.sendall(b'client response\n')

                reader.read(100)
                # The main thread should now be blocking in a send() syscall.
                # But in theory, it could get interrupted by other signals,
                # and then retried. So keep sending the signal in a loop, in
                # case an earlier signal happens to be delivered at an
                # inconvenient moment.
                while True:
                    pthread_kill(main_thread, signal.SIGUSR1)
                    if interrupted.wait(timeout=float(1)):
                        break
                received2 = len(reader.read())

        background = threading.Thread(target=run_client)
        background.start()
        server.handle_request()
        background.join()
        self.assertEqual(server.sent1, len(response1))
        self.assertEqual(response1, b'write data\n')
        self.assertEqual(server.received, b'client response\n')
        self.assertEqual(server.sent2, test.support.SOCK_MAX_SIZE)
        self.assertEqual(received2, test.support.SOCK_MAX_SIZE - 100)


class MiscTestCase(unittest.TestCase):

    def test_all(self):
        # objects defined in the module should be in __all__
        expected = []
        for name in dir(socketserver):
            if not name.startswith('_'):
                mod_object = getattr(socketserver, name)
                if getattr(mod_object, '__module__', None) == 'socketserver':
                    expected.append(name)
        self.assertCountEqual(socketserver.__all__, expected)

    def test_shutdown_request_called_if_verify_request_false(self):
        # Issue #26309: BaseServer should call shutdown_request even if
        # verify_request is False

        class MyServer(socketserver.TCPServer):
            def verify_request(self, request, client_address):
                return False

            shutdown_called = 0
            def shutdown_request(self, request):
                self.shutdown_called += 1
                socketserver.TCPServer.shutdown_request(self, request)

        server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
        # Registered as a cleanup so the listening socket is released even
        # if the connection or the assertion below fails.
        self.addCleanup(server.server_close)
        with socket.socket(server.address_family, socket.SOCK_STREAM) as s:
            s.connect(server.server_address)
        server.handle_request()
        self.assertEqual(server.shutdown_called, 1)


if __name__ == "__main__":
    unittest.main()