"""Tests for the low-level socket methods of asyncio event loops.

Exercises ``sock_connect``, ``sock_sendall``, ``sock_recv``,
``sock_recv_into``, ``sock_accept``, ``sock_sendto``, ``sock_recvfrom``
and ``sock_recvfrom_into`` against every event loop implementation that
is available on the current platform.
"""

import socket
import asyncio
import sys
import unittest

from asyncio import proactor_events
from itertools import cycle, islice
from unittest.mock import Mock
from test.test_asyncio import utils as test_utils
from test import support
from test.support import socket_helper

# NOTE(review): presumably the refused-connection / racing tests below
# cannot behave as expected while TCP blackholing is enabled on the host;
# confirm against socket_helper.tcp_blackhole().
if socket_helper.tcp_blackhole():
    raise unittest.SkipTest('Not relevant when TCP blackhole is enabled')


def tearDownModule():
    """Reset the global event loop policy once the module's tests finish."""
    asyncio.set_event_loop_policy(None)


class MyProto(asyncio.Protocol):
    """Protocol that sends an HTTP/1.0 GET request and counts reply bytes.

    The protocol moves through the states INITIAL -> CONNECTED -> (EOF) ->
    CLOSED and raises AssertionError on any unexpected transition.  When a
    loop is passed, ``connected`` and ``done`` are futures resolved from
    connection_made() and connection_lost() respectively.
    """

    connected = None
    done = None

    def __init__(self, loop=None):
        self.transport = None
        self.state = 'INITIAL'
        self.nbytes = 0
        if loop is not None:
            self.connected = loop.create_future()
            self.done = loop.create_future()

    def _assert_state(self, *expected):
        """Raise AssertionError unless the current state is in *expected*."""
        if self.state not in expected:
            raise AssertionError(f'state: {self.state!r}, expected: {expected!r}')

    def connection_made(self, transport):
        self.transport = transport
        self._assert_state('INITIAL')
        self.state = 'CONNECTED'
        if self.connected:
            self.connected.set_result(None)
        transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')

    def data_received(self, data):
        self._assert_state('CONNECTED')
        self.nbytes += len(data)

    def eof_received(self):
        self._assert_state('CONNECTED')
        self.state = 'EOF'

    def connection_lost(self, exc):
        self._assert_state('CONNECTED', 'EOF')
        self.state = 'CLOSED'
        if self.done:
            self.done.set_result(None)


class BaseSockTestsMixin:
    """Loop-agnostic socket tests.

    Concrete test cases mix this into ``test_utils.TestCase`` and
    implement create_event_loop() to select the loop under test.
    """

    def create_event_loop(self):
        """Return the event loop under test; must be overridden."""
        raise NotImplementedError

    def setUp(self):
        self.loop = self.create_event_loop()
        self.set_event_loop(self.loop)
        super().setUp()

    def tearDown(self):
        # just in case if we have transport close callbacks
        if not self.loop.is_closed():
            test_utils.run_briefly(self.loop)

        self.doCleanups()
        support.gc_collect()
        super().tearDown()

    def _basetest_sock_client_ops(self, httpd, sock):
        """Run an HTTP request over *sock* using sock_recv; closes *sock*."""
        if not isinstance(self.loop, proactor_events.BaseProactorEventLoop):
            # in debug mode, socket operations must fail
            # if the socket is in blocking mode
            self.loop.set_debug(True)
            sock.setblocking(True)
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_connect(sock, httpd.address))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_recv(sock, 1024))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_recv_into(sock, bytearray()))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_accept(sock))

        # test in non-blocking mode
        sock.setblocking(False)
        self.loop.run_until_complete(
            self.loop.sock_connect(sock, httpd.address))
        self.loop.run_until_complete(
            self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
        data = self.loop.run_until_complete(
            self.loop.sock_recv(sock, 1024))
        # consume data
        self.loop.run_until_complete(
            self.loop.sock_recv(sock, 1024))
        sock.close()
        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))

    def _basetest_sock_recv_into(self, httpd, sock):
        """Like _basetest_sock_client_ops, but reads via sock_recv_into."""
        sock.setblocking(False)
        self.loop.run_until_complete(
            self.loop.sock_connect(sock, httpd.address))
        self.loop.run_until_complete(
            self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
        data = bytearray(1024)
        with memoryview(data) as buf:
            nbytes = self.loop.run_until_complete(
                self.loop.sock_recv_into(sock, buf[:1024]))
            # consume data
            self.loop.run_until_complete(
                self.loop.sock_recv_into(sock, buf[nbytes:]))
        sock.close()
        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))

    def test_sock_client_ops(self):
        with test_utils.run_test_server() as httpd:
            sock = socket.socket()
            self._basetest_sock_client_ops(httpd, sock)
            sock = socket.socket()
            self._basetest_sock_recv_into(httpd, sock)

    async def _basetest_sock_recv_racing(self, httpd, sock):
        """A cancelled sock_recv must not break a subsequent sock_recv."""
        sock.setblocking(False)
        await self.loop.sock_connect(sock, httpd.address)

        task = asyncio.create_task(self.loop.sock_recv(sock, 1024))
        await asyncio.sleep(0)
        task.cancel()

        # Keep a reference to the send task and await it at the end, as the
        # sock_recv_into variant does, so it is neither dropped nor left
        # pending when the test finishes.
        send_task = asyncio.create_task(
            self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
        data = await self.loop.sock_recv(sock, 1024)
        # consume data
        await self.loop.sock_recv(sock, 1024)

        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))

        await send_task

    async def _basetest_sock_recv_into_racing(self, httpd, sock):
        """A cancelled sock_recv_into must not break a later one."""
        sock.setblocking(False)
        await self.loop.sock_connect(sock, httpd.address)

        data = bytearray(1024)
        with memoryview(data) as buf:
            task = asyncio.create_task(
                self.loop.sock_recv_into(sock, buf[:1024]))
            await asyncio.sleep(0)
            task.cancel()

            task = asyncio.create_task(
                self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
            nbytes = await self.loop.sock_recv_into(sock, buf[:1024])
            # consume data
            await self.loop.sock_recv_into(sock, buf[nbytes:])
            self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))

        await task

    async def _basetest_sock_send_racing(self, listener, sock):
        """A cancelled, blocked sock_sendall must not break a later one."""
        listener.bind(('127.0.0.1', 0))
        listener.listen(1)

        # make connection
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
        sock.setblocking(False)
        task = asyncio.create_task(
            self.loop.sock_connect(sock, listener.getsockname()))
        await asyncio.sleep(0)
        server = listener.accept()[0]
        server.setblocking(False)

        with server:
            await task

            # fill the buffer until sending 5 chars would block
            size = 8192
            while size >= 4:
                with self.assertRaises(BlockingIOError):
                    while True:
                        sock.send(b' ' * size)
                size //= 2

            # cancel a blocked sock_sendall
            task = asyncio.create_task(
                self.loop.sock_sendall(sock, b'hello'))
            await asyncio.sleep(0)
            task.cancel()

            # receive everything that is not a space
            async def recv_all():
                rv = b''
                while True:
                    buf = await self.loop.sock_recv(server, 8192)
                    if not buf:
                        return rv
                    rv += buf.strip()
            task = asyncio.create_task(recv_all())

            # immediately make another sock_sendall call
            await self.loop.sock_sendall(sock, b'world')
            sock.shutdown(socket.SHUT_WR)
            data = await task
            # ProactorEventLoop could deliver hello, so endswith is necessary
            self.assertTrue(data.endswith(b'world'))

    # After the first connect attempt before the listener is ready,
    # the socket needs time to "recover" to make the next connect call.
    # On Linux, a second retry will do. On Windows, the waiting time is
    # unpredictable; and on FreeBSD the socket may never come back
    # because it's a loopback address. Here we'll just retry for a few
    # times, and have to skip the test if it's not working. See also:
    # https://stackoverflow.com/a/54437602/3316267
    # https://lists.freebsd.org/pipermail/freebsd-current/2005-May/049876.html
    async def _basetest_sock_connect_racing(self, listener, sock):
        """A cancelled sock_connect must not prevent a later connect."""
        listener.bind(('127.0.0.1', 0))
        addr = listener.getsockname()
        sock.setblocking(False)

        task = asyncio.create_task(self.loop.sock_connect(sock, addr))
        await asyncio.sleep(0)
        task.cancel()

        listener.listen(1)

        skip_reason = "Max retries reached"
        for _ in range(128):
            try:
                await self.loop.sock_connect(sock, addr)
            except ConnectionRefusedError as e:
                skip_reason = e
            except OSError as e:
                skip_reason = e

                # Retry only for this error:
                # [WinError 10022] An invalid argument was supplied
                if getattr(e, 'winerror', 0) != 10022:
                    break
            else:
                # success
                return

        self.skipTest(skip_reason)

    def test_sock_client_racing(self):
        with test_utils.run_test_server() as httpd:
            sock = socket.socket()
            with sock:
                self.loop.run_until_complete(asyncio.wait_for(
                    self._basetest_sock_recv_racing(httpd, sock), 10))
            sock = socket.socket()
            with sock:
                self.loop.run_until_complete(asyncio.wait_for(
                    self._basetest_sock_recv_into_racing(httpd, sock), 10))
        listener = socket.socket()
        sock = socket.socket()
        with listener, sock:
            self.loop.run_until_complete(asyncio.wait_for(
                self._basetest_sock_send_racing(listener, sock), 10))

    def test_sock_client_connect_racing(self):
        listener = socket.socket()
        sock = socket.socket()
        with listener, sock:
            self.loop.run_until_complete(asyncio.wait_for(
                self._basetest_sock_connect_racing(listener, sock), 10))

    async def _basetest_huge_content(self, address):
        """POST 1 MB to the /loop endpoint and verify the echoed body.

        The socket is closed even when an assertion fails.
        """
        DATA_SIZE = 1_000_000

        chunk = b'0123456789' * (DATA_SIZE // 10)

        with socket.socket() as sock:
            sock.setblocking(False)
            await self.loop.sock_connect(sock, address)
            await self.loop.sock_sendall(sock,
                                         (b'POST /loop HTTP/1.0\r\n' +
                                          b'Content-Length: %d\r\n' % DATA_SIZE +
                                          b'\r\n'))

            task = asyncio.create_task(self.loop.sock_sendall(sock, chunk))

            data = await self.loop.sock_recv(sock, DATA_SIZE)
            # HTTP headers size is less than MTU,
            # they are sent by the first packet always
            self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
            while b'\r\n\r\n' not in data:
                more = await self.loop.sock_recv(sock, DATA_SIZE)
                if not more:
                    # Avoid spinning forever on EOF before the headers end.
                    self.fail('connection closed before end of headers')
                data += more
            # Strip headers
            headers = data[:data.index(b'\r\n\r\n') + 4]
            data = data[len(headers):]

            size = DATA_SIZE
            checker = cycle(b'0123456789')

            expected = bytes(islice(checker, len(data)))
            self.assertEqual(data, expected)
            size -= len(data)

            while True:
                data = await self.loop.sock_recv(sock, DATA_SIZE)
                if not data:
                    break
                expected = bytes(islice(checker, len(data)))
                self.assertEqual(data, expected)
                size -= len(data)
            self.assertEqual(size, 0)

            await task

    def test_huge_content(self):
        with test_utils.run_test_server() as httpd:
            self.loop.run_until_complete(
                self._basetest_huge_content(httpd.address))

    async def _basetest_huge_content_recvinto(self, address):
        """Same as _basetest_huge_content, but reads via sock_recv_into."""
        DATA_SIZE = 1_000_000

        chunk = b'0123456789' * (DATA_SIZE // 10)

        with socket.socket() as sock:
            sock.setblocking(False)
            await self.loop.sock_connect(sock, address)
            await self.loop.sock_sendall(sock,
                                         (b'POST /loop HTTP/1.0\r\n' +
                                          b'Content-Length: %d\r\n' % DATA_SIZE +
                                          b'\r\n'))

            task = asyncio.create_task(self.loop.sock_sendall(sock, chunk))

            storage = bytearray(DATA_SIZE)
            with memoryview(storage) as buf:
                nbytes = await self.loop.sock_recv_into(sock, buf)
                data = bytes(buf[:nbytes])
                # HTTP headers size is less than MTU,
                # they are sent by the first packet always
                self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
                while b'\r\n\r\n' not in data:
                    nbytes = await self.loop.sock_recv_into(sock, buf)
                    if not nbytes:
                        # Avoid spinning forever on EOF before headers end.
                        self.fail('connection closed before end of headers')
                    # Append rather than overwrite: the header terminator
                    # may straddle two reads.
                    data += bytes(buf[:nbytes])
                # Strip headers
                headers = data[:data.index(b'\r\n\r\n') + 4]
                data = data[len(headers):]

                size = DATA_SIZE
                checker = cycle(b'0123456789')

                expected = bytes(islice(checker, len(data)))
                self.assertEqual(data, expected)
                size -= len(data)

                while True:
                    nbytes = await self.loop.sock_recv_into(sock, buf)
                    if not nbytes:
                        break
                    # Copy out so no slice outlives the released view.
                    data = bytes(buf[:nbytes])
                    expected = bytes(islice(checker, len(data)))
                    self.assertEqual(data, expected)
                    size -= len(data)
                self.assertEqual(size, 0)

            await task

    def test_huge_content_recvinto(self):
        with test_utils.run_test_server() as httpd:
            self.loop.run_until_complete(
                self._basetest_huge_content_recvinto(httpd.address))

    async def _basetest_datagram_recvfrom(self, server_address):
        # Happy path, sock.sendto() returns immediately
        data = b'\x01' * 4096
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.setblocking(False)
            await self.loop.sock_sendto(sock, data, server_address)
            received_data, from_addr = await self.loop.sock_recvfrom(
                sock, 4096)
            self.assertEqual(received_data, data)
            self.assertEqual(from_addr, server_address)

    def test_recvfrom(self):
        with test_utils.run_udp_echo_server() as server_address:
            self.loop.run_until_complete(
                self._basetest_datagram_recvfrom(server_address))

    async def _basetest_datagram_recvfrom_into(self, server_address):
        # Happy path, sock.sendto() returns immediately
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.setblocking(False)

            buf = bytearray(4096)
            data = b'\x01' * 4096
            await self.loop.sock_sendto(sock, data, server_address)
            num_bytes, from_addr = await self.loop.sock_recvfrom_into(
                sock, buf)
            self.assertEqual(num_bytes, 4096)
            self.assertEqual(buf, data)
            self.assertEqual(from_addr, server_address)

            # An explicit nbytes limit smaller than the buffer.
            buf = bytearray(8192)
            await self.loop.sock_sendto(sock, data, server_address)
            num_bytes, from_addr = await self.loop.sock_recvfrom_into(
                sock, buf, 4096)
            self.assertEqual(num_bytes, 4096)
            self.assertEqual(buf[:4096], data[:4096])
            self.assertEqual(from_addr, server_address)

    def test_recvfrom_into(self):
        with test_utils.run_udp_echo_server() as server_address:
            self.loop.run_until_complete(
                self._basetest_datagram_recvfrom_into(server_address))

    async def _basetest_datagram_sendto_blocking(self, server_address):
        # Sad path, sock.sendto() raises BlockingIOError
        # This involves patching sock.sendto() to raise BlockingIOError but
        # sendto() is not used by the proactor event loop
        data = b'\x01' * 4096
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.setblocking(False)
            mock_sock = Mock(sock)
            mock_sock.gettimeout = sock.gettimeout
            mock_sock.sendto.configure_mock(side_effect=BlockingIOError)
            mock_sock.fileno = sock.fileno
            # Restore the real sendto on the next loop iteration so the
            # retried send succeeds.
            self.loop.call_soon(
                lambda: setattr(mock_sock, 'sendto', sock.sendto)
            )
            await self.loop.sock_sendto(mock_sock, data, server_address)

            received_data, from_addr = await self.loop.sock_recvfrom(
                sock, 4096)
            self.assertEqual(received_data, data)
            self.assertEqual(from_addr, server_address)

    def test_sendto_blocking(self):
        if sys.platform == 'win32':
            if isinstance(self.loop, asyncio.ProactorEventLoop):
                raise unittest.SkipTest('Not relevant to ProactorEventLoop')

        with test_utils.run_udp_echo_server() as server_address:
            self.loop.run_until_complete(
                self._basetest_datagram_sendto_blocking(server_address))

    @socket_helper.skip_unless_bind_unix_socket
    def test_unix_sock_client_ops(self):
        with test_utils.run_test_unix_server() as httpd:
            sock = socket.socket(socket.AF_UNIX)
            self._basetest_sock_client_ops(httpd, sock)
            sock = socket.socket(socket.AF_UNIX)
            self._basetest_sock_recv_into(httpd, sock)

    def test_sock_client_fail(self):
        # Make sure that we will get an unused port
        with socket.socket() as s:
            s.bind(('127.0.0.1', 0))
            address = s.getsockname()

        with socket.socket() as sock:
            sock.setblocking(False)
            with self.assertRaises(ConnectionRefusedError):
                self.loop.run_until_complete(
                    self.loop.sock_connect(sock, address))

    def test_sock_accept(self):
        listener = socket.socket()
        listener.setblocking(False)
        listener.bind(('127.0.0.1', 0))
        listener.listen(1)
        client = socket.socket()
        client.connect(listener.getsockname())

        f = self.loop.sock_accept(listener)
        conn, addr = self.loop.run_until_complete(f)
        self.assertEqual(conn.gettimeout(), 0)
        self.assertEqual(addr, client.getsockname())
        self.assertEqual(client.getpeername(), listener.getsockname())
        client.close()
        conn.close()
        listener.close()

    def test_cancel_sock_accept(self):
        listener = socket.socket()
        listener.setblocking(False)
        listener.bind(('127.0.0.1', 0))
        listener.listen(1)
        sockaddr = listener.getsockname()
        f = asyncio.wait_for(self.loop.sock_accept(listener), 0.1)
        with self.assertRaises(asyncio.TimeoutError):
            self.loop.run_until_complete(f)

        # Once the listener is closed, the address must refuse connections.
        listener.close()
        client = socket.socket()
        client.setblocking(False)
        f = self.loop.sock_connect(client, sockaddr)
        with self.assertRaises(ConnectionRefusedError):
            self.loop.run_until_complete(f)

        client.close()

    def test_create_connection_sock(self):
        with test_utils.run_test_server() as httpd:
            sock = None
            infos = self.loop.run_until_complete(
                self.loop.getaddrinfo(
                    *httpd.address, type=socket.SOCK_STREAM))
            for family, sock_type, proto, _, address in infos:
                sock = None
                try:
                    sock = socket.socket(family=family, type=sock_type,
                                         proto=proto)
                    sock.setblocking(False)
                    self.loop.run_until_complete(
                        self.loop.sock_connect(sock, address))
                except OSError:
                    # Try the next address, releasing the failed socket.
                    if sock is not None:
                        sock.close()
                else:
                    break
            else:
                self.fail('Can not create socket.')

            f = self.loop.create_connection(
                lambda: MyProto(loop=self.loop), sock=sock)
            tr, pr = self.loop.run_until_complete(f)
            self.assertIsInstance(tr, asyncio.Transport)
            self.assertIsInstance(pr, asyncio.Protocol)
            self.loop.run_until_complete(pr.done)
            self.assertGreater(pr.nbytes, 0)
            tr.close()


if sys.platform == 'win32':

    class SelectEventLoopTests(BaseSockTestsMixin,
                               test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.SelectorEventLoop()


    class ProactorEventLoopTests(BaseSockTestsMixin,
                                 test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.ProactorEventLoop()


        async def _basetest_datagram_send_to_non_listening_address(self,
                                                                   recvfrom):
            # see:
            #   https://github.com/python/cpython/issues/91227
            #   https://github.com/python/cpython/issues/88906
            #   https://bugs.python.org/issue47071
            #   https://bugs.python.org/issue44743
            # The Proactor event loop would fail to receive datagram messages
            # after sending a message to an address that wasn't listening.

            def create_socket():
                sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
                sock.setblocking(False)
                sock.bind(('127.0.0.1', 0))
                return sock

            socket_1 = create_socket()
            addr_1 = socket_1.getsockname()

            socket_2 = create_socket()
            addr_2 = socket_2.getsockname()

            # creating and immediately closing this to try to get an address
            # that is not listening
            socket_3 = create_socket()
            addr_3 = socket_3.getsockname()
            socket_3.shutdown(socket.SHUT_RDWR)
            socket_3.close()

            socket_1_recv_task = self.loop.create_task(recvfrom(socket_1))
            socket_2_recv_task = self.loop.create_task(recvfrom(socket_2))
            await asyncio.sleep(0)

            await self.loop.sock_sendto(socket_1, b'a', addr_2)
            self.assertEqual(await socket_2_recv_task, b'a')

            await self.loop.sock_sendto(socket_2, b'b', addr_1)
            self.assertEqual(await socket_1_recv_task, b'b')
            socket_1_recv_task = self.loop.create_task(recvfrom(socket_1))
            await asyncio.sleep(0)

            # this should send to an address that isn't listening
            await self.loop.sock_sendto(socket_1, b'c', addr_3)
            self.assertEqual(await socket_1_recv_task, b'')
            socket_1_recv_task = self.loop.create_task(recvfrom(socket_1))
            await asyncio.sleep(0)

            # socket 1 should still be able to receive messages after sending
            # to an address that wasn't listening
            socket_2.sendto(b'd', addr_1)
            self.assertEqual(await socket_1_recv_task, b'd')

            socket_1.shutdown(socket.SHUT_RDWR)
            socket_1.close()
            socket_2.shutdown(socket.SHUT_RDWR)
            socket_2.close()


        def test_datagram_send_to_non_listening_address_recvfrom(self):
            async def recvfrom(socket):
                data, _ = await self.loop.sock_recvfrom(socket, 4096)
                return data

            self.loop.run_until_complete(
                self._basetest_datagram_send_to_non_listening_address(
                    recvfrom))


        def test_datagram_send_to_non_listening_address_recvfrom_into(self):
            async def recvfrom_into(socket):
                buf = bytearray(4096)
                length, _ = await self.loop.sock_recvfrom_into(socket, buf,
                                                               4096)
                return buf[:length]

            self.loop.run_until_complete(
                self._basetest_datagram_send_to_non_listening_address(
                    recvfrom_into))

else:
    import selectors

    if hasattr(selectors, 'KqueueSelector'):
        class KqueueEventLoopTests(BaseSockTestsMixin,
                                   test_utils.TestCase):

            def create_event_loop(self):
                return asyncio.SelectorEventLoop(
                    selectors.KqueueSelector())

    if hasattr(selectors, 'EpollSelector'):
        class EPollEventLoopTests(BaseSockTestsMixin,
                                  test_utils.TestCase):

            def create_event_loop(self):
                return asyncio.SelectorEventLoop(selectors.EpollSelector())

    if hasattr(selectors, 'PollSelector'):
        class PollEventLoopTests(BaseSockTestsMixin,
                                 test_utils.TestCase):

            def create_event_loop(self):
                return asyncio.SelectorEventLoop(selectors.PollSelector())

    # Should always exist.
    class SelectEventLoopTests(BaseSockTestsMixin,
                               test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.SelectorEventLoop(selectors.SelectSelector())


if __name__ == '__main__':
    unittest.main()