"""Tests for the low-level socket methods of asyncio event loops.

Covers ``sock_connect``, ``sock_sendall``, ``sock_recv``, ``sock_recv_into``
and ``sock_accept`` (including cancellation races) against every event loop
implementation available on the current platform.
"""

import socket
import time
import asyncio
import sys
import unittest

from asyncio import proactor_events
from itertools import cycle, islice
from test.test_asyncio import utils as test_utils
from test import support
from test.support import socket_helper


class MyProto(asyncio.Protocol):
    """Protocol that sends an HTTP/1.0 GET request and counts reply bytes.

    Callbacks are checked against the expected order
    INITIAL -> CONNECTED -> [EOF] -> CLOSED; an out-of-order callback raises
    AssertionError.  When *loop* is given, the ``connected`` and ``done``
    futures are resolved from connection_made() and connection_lost().
    """

    connected = None
    done = None

    def __init__(self, loop=None):
        self.transport = None
        self.state = 'INITIAL'
        self.nbytes = 0  # total number of bytes passed to data_received()
        if loop is not None:
            self.connected = loop.create_future()
            self.done = loop.create_future()

    def _assert_state(self, *expected):
        """Raise AssertionError unless the current state is in *expected*."""
        if self.state not in expected:
            raise AssertionError(f'state: {self.state!r}, expected: {expected!r}')

    def connection_made(self, transport):
        self.transport = transport
        self._assert_state('INITIAL')
        self.state = 'CONNECTED'
        if self.connected is not None:
            self.connected.set_result(None)
        transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')

    def data_received(self, data):
        self._assert_state('CONNECTED')
        self.nbytes += len(data)

    def eof_received(self):
        self._assert_state('CONNECTED')
        self.state = 'EOF'

    def connection_lost(self, exc):
        self._assert_state('CONNECTED', 'EOF')
        self.state = 'CLOSED'
        if self.done is not None:
            self.done.set_result(None)


class BaseSockTestsMixin:
    """Socket tests shared by every event loop implementation.

    Concrete subclasses (defined at the bottom of this module) combine this
    mixin with ``test_utils.TestCase`` and override create_event_loop().
    """

    def create_event_loop(self):
        """Return the event loop under test; must be overridden."""
        raise NotImplementedError

    def setUp(self):
        self.loop = self.create_event_loop()
        self.set_event_loop(self.loop)
        super().setUp()

    def tearDown(self):
        # just in case if we have transport close callbacks
        if not self.loop.is_closed():
            test_utils.run_briefly(self.loop)

        self.doCleanups()
        support.gc_collect()
        super().tearDown()

    def _basetest_sock_client_ops(self, httpd, sock):
        """Run the sock_* client methods against *httpd* using *sock*.

        On non-proactor loops, first check that debug mode rejects every
        operation on a blocking socket with ValueError.  Then switch *sock*
        to non-blocking mode, send an HTTP/1.0 request and check the status
        line of the reply.  *sock* is closed once the exchange completes.
        """
        if not isinstance(self.loop, proactor_events.BaseProactorEventLoop):
            # In debug mode, socket operations must fail with ValueError
            # if the socket is in blocking mode.
            self.loop.set_debug(True)
            sock.setblocking(True)
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_connect(sock, httpd.address))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_recv(sock, 1024))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_recv_into(sock, bytearray()))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_accept(sock))

        # test in non-blocking mode
        sock.setblocking(False)
        self.loop.run_until_complete(
            self.loop.sock_connect(sock, httpd.address))
        self.loop.run_until_complete(
            self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
        data = self.loop.run_until_complete(
            self.loop.sock_recv(sock, 1024))
        # consume data
        self.loop.run_until_complete(
            self.loop.sock_recv(sock, 1024))
        sock.close()
        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))

    def _basetest_sock_recv_into(self, httpd, sock):
        """Same as _basetest_sock_client_ops(), but reading via sock_recv_into."""
        sock.setblocking(False)
        self.loop.run_until_complete(
            self.loop.sock_connect(sock, httpd.address))
        self.loop.run_until_complete(
            self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
        data = bytearray(1024)
        with memoryview(data) as buf:
            nbytes = self.loop.run_until_complete(
                self.loop.sock_recv_into(sock, buf[:1024]))
            # consume data
            self.loop.run_until_complete(
                self.loop.sock_recv_into(sock, buf[nbytes:]))
        sock.close()
        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))

    def test_sock_client_ops(self):
        with test_utils.run_test_server() as httpd:
            # The helpers close the socket themselves; ``with`` additionally
            # releases it when an assertion fails part-way through.
            with socket.socket() as sock:
                self._basetest_sock_client_ops(httpd, sock)
            with socket.socket() as sock:
                self._basetest_sock_recv_into(httpd, sock)

    async def _basetest_sock_recv_racing(self, httpd, sock):
        """Check that a cancelled sock_recv does not disturb a later one."""
        sock.setblocking(False)
        await self.loop.sock_connect(sock, httpd.address)

        task = asyncio.create_task(self.loop.sock_recv(sock, 1024))
        await asyncio.sleep(0)
        task.cancel()

        # Keep a reference to the send task (an unreferenced task can be
        # garbage-collected) and await it below so send errors surface.
        send_task = asyncio.create_task(
            self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
        data = await self.loop.sock_recv(sock, 1024)
        # consume data
        await self.loop.sock_recv(sock, 1024)

        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))

        await send_task

    async def _basetest_sock_recv_into_racing(self, httpd, sock):
        """Check that a cancelled sock_recv_into does not disturb a later one."""
        sock.setblocking(False)
        await self.loop.sock_connect(sock, httpd.address)

        data = bytearray(1024)
        with memoryview(data) as buf:
            task = asyncio.create_task(
                self.loop.sock_recv_into(sock, buf[:1024]))
            await asyncio.sleep(0)
            task.cancel()

            task = asyncio.create_task(
                self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
            nbytes = await self.loop.sock_recv_into(sock, buf[:1024])
            # consume data
            await self.loop.sock_recv_into(sock, buf[nbytes:])
            self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))

        await task

    async def _basetest_sock_send_racing(self, listener, sock):
        """Check that a cancelled sock_sendall does not disturb a later one.

        *sock* gets a 1024-byte send buffer that is filled until sends would
        block, so the first sock_sendall is still pending when cancelled.
        """
        listener.bind(('127.0.0.1', 0))
        listener.listen(1)

        # make connection
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
        sock.setblocking(False)
        task = asyncio.create_task(
            self.loop.sock_connect(sock, listener.getsockname()))
        await asyncio.sleep(0)
        server = listener.accept()[0]
        server.setblocking(False)

        with server:
            await task

            # fill the buffer until sending 5 chars would block
            size = 8192
            while size >= 4:
                with self.assertRaises(BlockingIOError):
                    while True:
                        sock.send(b' ' * size)
                size //= 2

            # cancel a blocked sock_sendall
            task = asyncio.create_task(
                self.loop.sock_sendall(sock, b'hello'))
            await asyncio.sleep(0)
            task.cancel()

            # receive everything that is not a space
            async def recv_all():
                rv = b''
                while True:
                    buf = await self.loop.sock_recv(server, 8192)
                    if not buf:
                        return rv
                    rv += buf.strip()
            task = asyncio.create_task(recv_all())

            # immediately make another sock_sendall call
            await self.loop.sock_sendall(sock, b'world')
            sock.shutdown(socket.SHUT_WR)
            data = await task
            # ProactorEventLoop could deliver hello, so endswith is necessary
            self.assertTrue(data.endswith(b'world'))

    # After the first connect attempt before the listener is ready,
    # the socket needs time to "recover" to make the next connect call.
    # On Linux, a second retry will do. On Windows, the waiting time is
    # unpredictable; and on FreeBSD the socket may never come back
    # because it's a loopback address. Here we'll just retry for a few
    # times, and have to skip the test if it's not working. See also:
    # https://stackoverflow.com/a/54437602/3316267
    # https://lists.freebsd.org/pipermail/freebsd-current/2005-May/049876.html
    async def _basetest_sock_connect_racing(self, listener, sock):
        """Check that sock_connect works again after a cancelled attempt.

        Skips the test if no retry succeeds (see the comment above).
        """
        listener.bind(('127.0.0.1', 0))
        addr = listener.getsockname()
        sock.setblocking(False)

        task = asyncio.create_task(self.loop.sock_connect(sock, addr))
        await asyncio.sleep(0)
        task.cancel()

        listener.listen(1)

        skip_reason = "Max retries reached"
        for _ in range(128):
            try:
                await self.loop.sock_connect(sock, addr)
            except ConnectionRefusedError as e:
                skip_reason = e
            except OSError as e:
                skip_reason = e

                # Retry only for this error:
                # [WinError 10022] An invalid argument was supplied
                if getattr(e, 'winerror', 0) != 10022:
                    break
            else:
                # success
                return

        self.skipTest(skip_reason)

    def test_sock_client_racing(self):
        with test_utils.run_test_server() as httpd:
            with socket.socket() as sock:
                self.loop.run_until_complete(asyncio.wait_for(
                    self._basetest_sock_recv_racing(httpd, sock), 10))
            with socket.socket() as sock:
                self.loop.run_until_complete(asyncio.wait_for(
                    self._basetest_sock_recv_into_racing(httpd, sock), 10))
        with socket.socket() as listener, socket.socket() as sock:
            self.loop.run_until_complete(asyncio.wait_for(
                self._basetest_sock_send_racing(listener, sock), 10))

    def test_sock_client_connect_racing(self):
        with socket.socket() as listener, socket.socket() as sock:
            self.loop.run_until_complete(asyncio.wait_for(
                self._basetest_sock_connect_racing(listener, sock), 10))

    async def _basetest_huge_content(self, address):
        """POST 1,000,000 bytes to ``/loop`` on *address* and check the reply.

        The request body is sent from a separate task while the response is
        being read.  The response body is expected to repeat the request
        body (the digits 0-9 cycled) byte for byte.
        """
        DATA_SIZE = 1_000_000

        chunk = b'0123456789' * (DATA_SIZE // 10)

        with socket.socket() as sock:
            sock.setblocking(False)
            await self.loop.sock_connect(sock, address)
            await self.loop.sock_sendall(sock,
                                         (b'POST /loop HTTP/1.0\r\n' +
                                          b'Content-Length: %d\r\n' % DATA_SIZE +
                                          b'\r\n'))

            task = asyncio.create_task(self.loop.sock_sendall(sock, chunk))

            data = await self.loop.sock_recv(sock, DATA_SIZE)
            # HTTP headers size is less than MTU,
            # they are sent by the first packet always
            self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
            while b'\r\n\r\n' not in data:
                more = await self.loop.sock_recv(sock, DATA_SIZE)
                if not more:
                    self.fail('connection closed before end of HTTP headers')
                data += more
            # Strip headers
            _, _, data = data.partition(b'\r\n\r\n')

            size = DATA_SIZE
            checker = cycle(b'0123456789')

            expected = bytes(islice(checker, len(data)))
            self.assertEqual(data, expected)
            size -= len(data)

            while True:
                data = await self.loop.sock_recv(sock, DATA_SIZE)
                if not data:
                    break
                expected = bytes(islice(checker, len(data)))
                self.assertEqual(data, expected)
                size -= len(data)
            self.assertEqual(size, 0)

            await task

    def test_huge_content(self):
        with test_utils.run_test_server() as httpd:
            self.loop.run_until_complete(
                self._basetest_huge_content(httpd.address))

    async def _basetest_huge_content_recvinto(self, address):
        """Same as _basetest_huge_content(), but reading via sock_recv_into."""
        DATA_SIZE = 1_000_000

        chunk = b'0123456789' * (DATA_SIZE // 10)

        with socket.socket() as sock:
            sock.setblocking(False)
            await self.loop.sock_connect(sock, address)
            await self.loop.sock_sendall(sock,
                                         (b'POST /loop HTTP/1.0\r\n' +
                                          b'Content-Length: %d\r\n' % DATA_SIZE +
                                          b'\r\n'))

            task = asyncio.create_task(self.loop.sock_sendall(sock, chunk))

            storage = bytearray(DATA_SIZE)
            buf = memoryview(storage)

            nbytes = await self.loop.sock_recv_into(sock, buf)
            data = bytes(buf[:nbytes])
            # HTTP headers size is less than MTU,
            # they are sent by the first packet always
            self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
            while b'\r\n\r\n' not in data:
                nbytes = await self.loop.sock_recv_into(sock, buf)
                if not nbytes:
                    self.fail('connection closed before end of HTTP headers')
                # Append rather than overwrite: every read refills *buf*
                # from its start, and the header terminator may be split
                # across two reads.
                data += bytes(buf[:nbytes])
            # Strip headers
            _, _, data = data.partition(b'\r\n\r\n')

            size = DATA_SIZE
            checker = cycle(b'0123456789')

            expected = bytes(islice(checker, len(data)))
            self.assertEqual(data, expected)
            size -= len(data)

            while True:
                nbytes = await self.loop.sock_recv_into(sock, buf)
                data = buf[:nbytes]
                if not data:
                    break
                expected = bytes(islice(checker, len(data)))
                self.assertEqual(data, expected)
                size -= len(data)
            self.assertEqual(size, 0)

            await task

    def test_huge_content_recvinto(self):
        with test_utils.run_test_server() as httpd:
            self.loop.run_until_complete(
                self._basetest_huge_content_recvinto(httpd.address))

    @socket_helper.skip_unless_bind_unix_socket
    def test_unix_sock_client_ops(self):
        with test_utils.run_test_unix_server() as httpd:
            with socket.socket(socket.AF_UNIX) as sock:
                self._basetest_sock_client_ops(httpd, sock)
            with socket.socket(socket.AF_UNIX) as sock:
                self._basetest_sock_recv_into(httpd, sock)

    def test_sock_client_fail(self):
        # Make sure that we will get an unused port
        with socket.socket() as s:
            s.bind(('127.0.0.1', 0))
            address = s.getsockname()

        with socket.socket() as sock:
            sock.setblocking(False)
            with self.assertRaises(ConnectionRefusedError):
                self.loop.run_until_complete(
                    self.loop.sock_connect(sock, address))

    def test_sock_accept(self):
        with socket.socket() as listener, socket.socket() as client:
            listener.setblocking(False)
            listener.bind(('127.0.0.1', 0))
            listener.listen(1)
            client.connect(listener.getsockname())

            f = self.loop.sock_accept(listener)
            conn, addr = self.loop.run_until_complete(f)
            with conn:
                self.assertEqual(conn.gettimeout(), 0)
                self.assertEqual(addr, client.getsockname())
                self.assertEqual(client.getpeername(), listener.getsockname())

    def test_cancel_sock_accept(self):
        with socket.socket() as listener:
            listener.setblocking(False)
            listener.bind(('127.0.0.1', 0))
            listener.listen(1)
            sockaddr = listener.getsockname()
            f = asyncio.wait_for(self.loop.sock_accept(listener), 0.1)
            with self.assertRaises(asyncio.TimeoutError):
                self.loop.run_until_complete(f)

        # The listener is closed now, so connecting to its address is refused.
        with socket.socket() as client:
            client.setblocking(False)
            f = self.loop.sock_connect(client, sockaddr)
            with self.assertRaises(ConnectionRefusedError):
                self.loop.run_until_complete(f)

    def test_create_connection_sock(self):
        with test_utils.run_test_server() as httpd:
            infos = self.loop.run_until_complete(
                self.loop.getaddrinfo(
                    *httpd.address, type=socket.SOCK_STREAM))
            for family, type_, proto, cname, address in infos:
                sock = None
                try:
                    sock = socket.socket(family=family, type=type_, proto=proto)
                    sock.setblocking(False)
                    self.loop.run_until_complete(
                        self.loop.sock_connect(sock, address))
                except Exception:
                    # This address did not work: release the socket (if one
                    # was created) and try the next candidate.
                    if sock is not None:
                        sock.close()
                else:
                    break
            else:
                self.fail('Can not create socket.')

            f = self.loop.create_connection(
                lambda: MyProto(loop=self.loop), sock=sock)
            tr, pr = self.loop.run_until_complete(f)
            self.assertIsInstance(tr, asyncio.Transport)
            self.assertIsInstance(pr, asyncio.Protocol)
            self.loop.run_until_complete(pr.done)
            self.assertGreater(pr.nbytes, 0)
            tr.close()


if sys.platform == 'win32':

    class SelectEventLoopTests(BaseSockTestsMixin,
                               test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.SelectorEventLoop()

    class ProactorEventLoopTests(BaseSockTestsMixin,
                                 test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.ProactorEventLoop()

else:
    import selectors

    if hasattr(selectors, 'KqueueSelector'):
        class KqueueEventLoopTests(BaseSockTestsMixin,
                                   test_utils.TestCase):

            def create_event_loop(self):
                return asyncio.SelectorEventLoop(
                    selectors.KqueueSelector())

    if hasattr(selectors, 'EpollSelector'):
        class EPollEventLoopTests(BaseSockTestsMixin,
                                  test_utils.TestCase):

            def create_event_loop(self):
                return asyncio.SelectorEventLoop(selectors.EpollSelector())

    if hasattr(selectors, 'PollSelector'):
        class PollEventLoopTests(BaseSockTestsMixin,
                                 test_utils.TestCase):

            def create_event_loop(self):
                return asyncio.SelectorEventLoop(selectors.PollSelector())

    # Should always exist.
    class SelectEventLoopTests(BaseSockTestsMixin,
                               test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.SelectorEventLoop(selectors.SelectSelector())