"""
Test suite for socketserver.
"""

import contextlib
import io
import os
import select
import signal
import socket
import tempfile
import threading
import unittest
import socketserver

import test.support
from test.support import reap_children, reap_threads, verbose
from test.support import socket_helper


test.support.requires("network")

TEST_STR = b"hello world\n"
HOST = socket_helper.HOST

HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
requires_unix_sockets = unittest.skipUnless(HAVE_UNIX_SOCKETS,
                                            'requires Unix sockets')
HAVE_FORKING = hasattr(os, "fork")
requires_forking = unittest.skipUnless(HAVE_FORKING, 'requires forking')

def signal_alarm(n):
    """Call signal.alarm when it exists (i.e. not on Windows)."""
    if hasattr(signal, 'alarm'):
        signal.alarm(n)

# Remember real select() to avoid interferences with mocking
_real_select = select.select

def receive(sock, n, timeout=test.support.SHORT_TIMEOUT):
    """Receive up to *n* bytes from *sock*, waiting at most *timeout* seconds.

    Raises RuntimeError if *sock* does not become readable within *timeout*.
    """
    readable, _, _ = _real_select([sock], [], [], timeout)
    if sock in readable:
        return sock.recv(n)
    else:
        raise RuntimeError("timed out on %r" % (sock,))

if HAVE_UNIX_SOCKETS and HAVE_FORKING:
    class ForkingUnixStreamServer(socketserver.ForkingMixIn,
                                  socketserver.UnixStreamServer):
        pass

    class ForkingUnixDatagramServer(socketserver.ForkingMixIn,
                                    socketserver.UnixDatagramServer):
        pass


@contextlib.contextmanager
def simple_subprocess(testcase):
    """Tests that a custom child process is not waited on (Issue 1540386)

    Forks a child that immediately exits with status 72, runs the body of
    the ``with`` statement, then checks with wait_process() that the child
    exited with status 72.  Code in the body (e.g. a ForkingMixIn server)
    is expected not to reap this unrelated child.  *testcase* is unused.
    """
    pid = os.fork()
    if pid == 0:
        # Don't raise an exception; it would be caught by the test harness.
        os._exit(72)
    try:
        yield None
    finally:
        # Always collect the child, even if the body failed.
        test.support.wait_process(pid, exitcode=72)


class SocketServerTest(unittest.TestCase):
    """Test all socket servers."""

    def setUp(self):
        signal_alarm(60)  # Kill deadlocks after 60 seconds.
        self.port_seed = 0
        # AF_UNIX socket paths created by pickaddr(), removed in tearDown().
        self.test_files = []

    def tearDown(self):
        signal_alarm(0)  # Didn't deadlock.
        reap_children()

        for fn in self.test_files:
            try:
                os.remove(fn)
            except OSError:
                # Best effort: the path may never have been bound.
                pass
        self.test_files[:] = []

    def pickaddr(self, proto):
        """Return a fresh address suitable for a server of family *proto*.

        AF_INET gets (HOST, 0) so the OS picks the port; any other family
        gets a new temporary filesystem path (recorded for cleanup).
        """
        if proto == socket.AF_INET:
            return (HOST, 0)
        else:
            # XXX: We need a way to tell AF_UNIX to pick its own name
            # like AF_INET provides port==0.  mktemp() only picks a name
            # that does not exist yet, which is what bind() needs here.
            fn = tempfile.mktemp(prefix='unix_socket.')
            self.test_files.append(fn)
            return fn

    def make_server(self, addr, svrcls, hdlrbase):
        """Create an echo server of class *svrcls* bound to *addr*.

        The handler (derived from *hdlrbase*) echoes back one line.  Errors
        raised while handling a request are re-raised instead of being
        logged, so they are not silently hidden.
        """
        class MyServer(svrcls):
            def handle_error(self, request, client_address):
                self.close_request(request)
                raise

        class MyHandler(hdlrbase):
            def handle(self):
                line = self.rfile.readline()
                self.wfile.write(line)

        if verbose: print("creating server")
        try:
            server = MyServer(addr, MyHandler)
        except PermissionError as e:
            # Issue 29184: cannot bind() a Unix socket on Android.
            self.skipTest('Cannot create server (%s, %s): %s' %
                          (svrcls, addr, e))
        self.assertEqual(server.server_address, server.socket.getsockname())
        return server

    @reap_threads
    def run_server(self, svrcls, hdlrbase, testfunc):
        """Serve from *svrcls* in a thread and run *testfunc* against it 3 times.

        *testfunc* is called as testfunc(address_family, server_address).
        The server is shut down and closed even if *testfunc* fails.
        """
        server = self.make_server(self.pickaddr(svrcls.address_family),
                                  svrcls, hdlrbase)
        # We had the OS pick a port, so pull the real address out of
        # the server.
        addr = server.server_address
        if verbose:
            print("ADDR =", addr)
            print("CLASS =", svrcls)

        t = threading.Thread(
            name='%s serving' % svrcls,
            target=server.serve_forever,
            # Short poll interval to make the test finish quickly.
            # Time between requests is short enough that we won't wake
            # up spuriously too many times.
            kwargs={'poll_interval': 0.01})
        t.daemon = True  # In case this function raises.
        t.start()
        try:
            if verbose: print("server running")
            for i in range(3):
                if verbose: print("test client", i)
                testfunc(svrcls.address_family, addr)
        finally:
            # Stop and close the server even when a client check failed,
            # so the listening socket (and any forked children) don't leak.
            if verbose: print("waiting for server")
            server.shutdown()
            t.join()
            server.server_close()
        self.assertEqual(-1, server.socket.fileno())
        if HAVE_FORKING and isinstance(server, socketserver.ForkingMixIn):
            # bpo-31151: Check that ForkingMixIn.server_close() waits until
            # all children completed
            self.assertFalse(server.active_children)
        if verbose: print("done")

    def stream_examine(self, proto, addr):
        """Send TEST_STR over a stream socket and check it is echoed back."""
        with socket.socket(proto, socket.SOCK_STREAM) as s:
            s.connect(addr)
            s.sendall(TEST_STR)
            buf = data = receive(s, 100)
            while data and b'\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)

    def dgram_examine(self, proto, addr):
        """Send TEST_STR as a datagram and check it is echoed back."""
        with socket.socket(proto, socket.SOCK_DGRAM) as s:
            if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX:
                # An unbound AF_UNIX datagram socket has no address the
                # server could reply to.
                s.bind(self.pickaddr(proto))
            s.sendto(TEST_STR, addr)
            buf = data = receive(s, 100)
            while data and b'\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)

    def test_TCPServer(self):
        self.run_server(socketserver.TCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    def test_ThreadingTCPServer(self):
        self.run_server(socketserver.ThreadingTCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_forking
    def test_ForkingTCPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingTCPServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    @requires_unix_sockets
    def test_UnixStreamServer(self):
        self.run_server(socketserver.UnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    def test_ThreadingUnixStreamServer(self):
        self.run_server(socketserver.ThreadingUnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixStreamServer(self):
        with simple_subprocess(self):
            self.run_server(ForkingUnixStreamServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    def test_UDPServer(self):
        self.run_server(socketserver.UDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    def test_ThreadingUDPServer(self):
        self.run_server(socketserver.ThreadingUDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_forking
    def test_ForkingUDPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingUDPServer,
                            socketserver.DatagramRequestHandler,
                            self.dgram_examine)

    @requires_unix_sockets
    def test_UnixDatagramServer(self):
        self.run_server(socketserver.UnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    def test_ThreadingUnixDatagramServer(self):
        self.run_server(socketserver.ThreadingUnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixDatagramServer(self):
        self.run_server(ForkingUnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @reap_threads
    def test_shutdown(self):
        # Issue #2302: shutdown() should always succeed in making an
        # other thread leave serve_forever().
        class MyServer(socketserver.TCPServer):
            pass

        class MyHandler(socketserver.StreamRequestHandler):
            pass

        threads = []
        for i in range(20):
            s = MyServer((HOST, 0), MyHandler)
            t = threading.Thread(
                name='MyServer serving',
                target=s.serve_forever,
                kwargs={'poll_interval': 0.01})
            t.daemon = True  # In case this function raises.
            threads.append((t, s))
        for t, s in threads:
            # shutdown() right after start() races with serve_forever()
            # entering its loop; that race is what is being tested.
            t.start()
            s.shutdown()
        for t, s in threads:
            t.join()
            s.server_close()

    def test_tcpserver_bind_leak(self):
        # Issue #22435: the server socket wouldn't be closed if bind()/listen()
        # failed.
        # Create many servers for which bind() will fail, to see if this result
        # in FD exhaustion.
        for i in range(1024):
            with self.assertRaises(OverflowError):
                socketserver.TCPServer((HOST, -1),
                                       socketserver.StreamRequestHandler)

    def test_context_manager(self):
        with socketserver.TCPServer((HOST, 0),
                                    socketserver.StreamRequestHandler) as server:
            pass
        self.assertEqual(-1, server.socket.fileno())


class ErrorHandlerTest(unittest.TestCase):
    """Test that the servers pass normal exceptions from the handler to
    handle_error(), and that exiting exceptions like SystemExit and
    KeyboardInterrupt are not passed."""

    def tearDown(self):
        test.support.unlink(test.support.TESTFN)

    def test_sync_handled(self):
        BaseErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_sync_not_handled(self):
        with self.assertRaises(SystemExit):
            BaseErrorTestServer(SystemExit)
        self.check_result(handled=False)

    def test_threading_handled(self):
        ThreadingErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_threading_not_handled(self):
        ThreadingErrorTestServer(SystemExit)
        self.check_result(handled=False)

    @requires_forking
    def test_forking_handled(self):
        ForkingErrorTestServer(ValueError)
        self.check_result(handled=True)

    @requires_forking
    def test_forking_not_handled(self):
        ForkingErrorTestServer(SystemExit)
        self.check_result(handled=False)

    def check_result(self, handled):
        """Check the TESTFN log: handler always ran; handle_error() iff *handled*."""
        with open(test.support.TESTFN) as log:
            # 'Error handled\n' * False == '' and * True == 'Error handled\n'.
            expected = 'Handler called\n' + 'Error handled\n' * handled
            self.assertEqual(log.read(), expected)


class BaseErrorTestServer(socketserver.TCPServer):
    """TCP server that runs one whole request cycle from its constructor.

    The handler (BadHandler) raises ``exception``; handle_error() records
    in TESTFN that it was called.
    """

    def __init__(self, exception):
        self.exception = exception
        super().__init__((HOST, 0), BadHandler)
        # Connect and immediately disconnect; the server still accepts it.
        with socket.create_connection(self.server_address):
            pass
        try:
            self.handle_request()
        finally:
            self.server_close()
        self.wait_done()

    def handle_error(self, request, client_address):
        with open(test.support.TESTFN, 'a') as log:
            log.write('Error handled\n')

    def wait_done(self):
        """Hook for subclasses that handle the request asynchronously."""
        pass


class BadHandler(socketserver.BaseRequestHandler):
    """Handler that logs its invocation, then raises the server's exception."""

    def handle(self):
        with open(test.support.TESTFN, 'a') as log:
            log.write('Handler called\n')
        raise self.server.exception('Test error')


class ThreadingErrorTestServer(socketserver.ThreadingMixIn,
                               BaseErrorTestServer):
    """Threaded variant; wait_done() blocks until shutdown_request() has run."""

    def __init__(self, *pos, **kw):
        # Must exist before BaseErrorTestServer.__init__ handles the request.
        self.done = threading.Event()
        super().__init__(*pos, **kw)

    def shutdown_request(self, *pos, **kw):
        super().shutdown_request(*pos, **kw)
        self.done.set()

    def wait_done(self):
        self.done.wait()


if HAVE_FORKING:
    class ForkingErrorTestServer(socketserver.ForkingMixIn, BaseErrorTestServer):
        pass


class SocketWriterTest(unittest.TestCase):
    def test_basics(self):
        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                self.server.wfile = self.wfile
                self.server.wfile_fileno = self.wfile.fileno()
                self.server.request_fileno = self.request.fileno()

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        s = socket.socket(
            server.address_family, socket.SOCK_STREAM, socket.IPPROTO_TCP)
        with s:
            s.connect(server.server_address)
            server.handle_request()
        self.assertIsInstance(server.wfile, io.BufferedIOBase)
        self.assertEqual(server.wfile_fileno, server.request_fileno)

    def test_write(self):
        # Test that wfile.write() sends data immediately, and that it does
        # not truncate sends when interrupted by a Unix signal
        pthread_kill = test.support.get_attribute(signal, 'pthread_kill')

        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                self.server.sent1 = self.wfile.write(b'write data\n')
                # Should be sent immediately, without requiring flush()
                self.server.received = self.rfile.readline()
                big_chunk = b'\0' * test.support.SOCK_MAX_SIZE
                self.server.sent2 = self.wfile.write(big_chunk)

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        interrupted = threading.Event()

        def signal_handler(signum, frame):
            interrupted.set()

        original = signal.signal(signal.SIGUSR1, signal_handler)
        self.addCleanup(signal.signal, signal.SIGUSR1, original)
        response1 = None
        received2 = None
        main_thread = threading.get_ident()

        def run_client():
            s = socket.socket(server.address_family, socket.SOCK_STREAM,
                              socket.IPPROTO_TCP)
            with s, s.makefile('rb') as reader:
                s.connect(server.server_address)
                nonlocal response1
                response1 = reader.readline()
                s.sendall(b'client response\n')

                reader.read(100)
                # The main thread should now be blocking in a send() syscall.
                # But in theory, it could get interrupted by other signals,
                # and then retried. So keep sending the signal in a loop, in
                # case an earlier signal happens to be delivered at an
                # inconvenient moment.
                while True:
                    pthread_kill(main_thread, signal.SIGUSR1)
                    if interrupted.wait(timeout=float(1)):
                        break
                nonlocal received2
                received2 = len(reader.read())

        background = threading.Thread(target=run_client)
        background.start()
        server.handle_request()
        background.join()
        self.assertEqual(server.sent1, len(response1))
        self.assertEqual(response1, b'write data\n')
        self.assertEqual(server.received, b'client response\n')
        self.assertEqual(server.sent2, test.support.SOCK_MAX_SIZE)
        self.assertEqual(received2, test.support.SOCK_MAX_SIZE - 100)


class MiscTestCase(unittest.TestCase):

    def test_all(self):
        # objects defined in the module should be in __all__
        expected = []
        for name in dir(socketserver):
            if not name.startswith('_'):
                mod_object = getattr(socketserver, name)
                if getattr(mod_object, '__module__', None) == 'socketserver':
                    expected.append(name)
        self.assertCountEqual(socketserver.__all__, expected)

    def test_shutdown_request_called_if_verify_request_false(self):
        # Issue #26309: BaseServer should call shutdown_request even if
        # verify_request is False

        class MyServer(socketserver.TCPServer):
            def verify_request(self, request, client_address):
                return False

            shutdown_called = 0
            def shutdown_request(self, request):
                self.shutdown_called += 1
                socketserver.TCPServer.shutdown_request(self, request)

        server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
        # Close the server even if connect()/handle_request()/assert fails.
        self.addCleanup(server.server_close)
        # Connect and disconnect; the connection is queued for accept().
        with socket.socket(server.address_family, socket.SOCK_STREAM) as s:
            s.connect(server.server_address)
        server.handle_request()
        self.assertEqual(server.shutdown_called, 1)


if __name__ == "__main__":
    unittest.main()