1import socket 2import time 3import asyncio 4import sys 5import unittest 6 7from asyncio import proactor_events 8from itertools import cycle, islice 9from test.test_asyncio import utils as test_utils 10from test import support 11from test.support import socket_helper 12 13 14class MyProto(asyncio.Protocol): 15 connected = None 16 done = None 17 18 def __init__(self, loop=None): 19 self.transport = None 20 self.state = 'INITIAL' 21 self.nbytes = 0 22 if loop is not None: 23 self.connected = loop.create_future() 24 self.done = loop.create_future() 25 26 def connection_made(self, transport): 27 self.transport = transport 28 assert self.state == 'INITIAL', self.state 29 self.state = 'CONNECTED' 30 if self.connected: 31 self.connected.set_result(None) 32 transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') 33 34 def data_received(self, data): 35 assert self.state == 'CONNECTED', self.state 36 self.nbytes += len(data) 37 38 def eof_received(self): 39 assert self.state == 'CONNECTED', self.state 40 self.state = 'EOF' 41 42 def connection_lost(self, exc): 43 assert self.state in ('CONNECTED', 'EOF'), self.state 44 self.state = 'CLOSED' 45 if self.done: 46 self.done.set_result(None) 47 48 49class BaseSockTestsMixin: 50 51 def create_event_loop(self): 52 raise NotImplementedError 53 54 def setUp(self): 55 self.loop = self.create_event_loop() 56 self.set_event_loop(self.loop) 57 super().setUp() 58 59 def tearDown(self): 60 # just in case if we have transport close callbacks 61 if not self.loop.is_closed(): 62 test_utils.run_briefly(self.loop) 63 64 self.doCleanups() 65 support.gc_collect() 66 super().tearDown() 67 68 def _basetest_sock_client_ops(self, httpd, sock): 69 if not isinstance(self.loop, proactor_events.BaseProactorEventLoop): 70 # in debug mode, socket operations must fail 71 # if the socket is not in blocking mode 72 self.loop.set_debug(True) 73 sock.setblocking(True) 74 with self.assertRaises(ValueError): 75 self.loop.run_until_complete( 76 self.loop.sock_connect(sock, httpd.address)) 77 with self.assertRaises(ValueError): 78 self.loop.run_until_complete( 79 self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) 80 with self.assertRaises(ValueError): 81 self.loop.run_until_complete( 82 self.loop.sock_recv(sock, 1024)) 83 with self.assertRaises(ValueError): 84 self.loop.run_until_complete( 85 self.loop.sock_recv_into(sock, bytearray())) 86 with self.assertRaises(ValueError): 87 self.loop.run_until_complete( 88 self.loop.sock_accept(sock)) 89 90 # test in non-blocking mode 91 sock.setblocking(False) 92 self.loop.run_until_complete( 93 self.loop.sock_connect(sock, httpd.address)) 94 self.loop.run_until_complete( 95 self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) 96 data = self.loop.run_until_complete( 97 self.loop.sock_recv(sock, 1024)) 98 # consume data 99 self.loop.run_until_complete( 100 self.loop.sock_recv(sock, 1024)) 101 sock.close() 102 self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) 103 104 def _basetest_sock_recv_into(self, httpd, sock): 105 # same as _basetest_sock_client_ops, but using sock_recv_into 106 sock.setblocking(False) 107 self.loop.run_until_complete( 108 self.loop.sock_connect(sock, httpd.address)) 109 self.loop.run_until_complete( 110 self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) 111 data = bytearray(1024) 112 with memoryview(data) as buf: 113 nbytes = self.loop.run_until_complete( 114 self.loop.sock_recv_into(sock, buf[:1024])) 115 # consume data 116 self.loop.run_until_complete( 117 self.loop.sock_recv_into(sock, buf[nbytes:])) 118 sock.close() 119 self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) 120 121 def test_sock_client_ops(self): 122 with test_utils.run_test_server() as httpd: 123 sock = socket.socket() 124 self._basetest_sock_client_ops(httpd, sock) 125 sock = socket.socket() 126 self._basetest_sock_recv_into(httpd, sock) 127 128 async def _basetest_sock_recv_racing(self, httpd, sock): 129 sock.setblocking(False) 130 await self.loop.sock_connect(sock, httpd.address) 131 132 task = asyncio.create_task(self.loop.sock_recv(sock, 1024)) 133 await asyncio.sleep(0) 134 task.cancel() 135 136 asyncio.create_task( 137 self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) 138 data = await self.loop.sock_recv(sock, 1024) 139 # consume data 140 await self.loop.sock_recv(sock, 1024) 141 142 self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) 143 144 async def _basetest_sock_recv_into_racing(self, httpd, sock): 145 sock.setblocking(False) 146 await self.loop.sock_connect(sock, httpd.address) 147 148 data = bytearray(1024) 149 with memoryview(data) as buf: 150 task = asyncio.create_task( 151 self.loop.sock_recv_into(sock, buf[:1024])) 152 await asyncio.sleep(0) 153 task.cancel() 154 155 task = asyncio.create_task( 156 self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) 157 nbytes = await self.loop.sock_recv_into(sock, buf[:1024]) 158 # consume data 159 await self.loop.sock_recv_into(sock, buf[nbytes:]) 160 self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) 161 162 await task 163 164 async def _basetest_sock_send_racing(self, listener, sock): 165 listener.bind(('127.0.0.1', 0)) 166 listener.listen(1) 167 168 # make connection 169 sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024) 170 sock.setblocking(False) 171 task = asyncio.create_task( 172 self.loop.sock_connect(sock, listener.getsockname())) 173 await asyncio.sleep(0) 174 server = listener.accept()[0] 175 server.setblocking(False) 176 177 with server: 178 await task 179 180 # fill the buffer until sending 5 chars would block 181 size = 8192 182 while size >= 4: 183 with self.assertRaises(BlockingIOError): 184 while True: 185 sock.send(b' ' * size) 186 size = int(size / 2) 187 188 # cancel a blocked sock_sendall 189 task = asyncio.create_task( 190 self.loop.sock_sendall(sock, b'hello')) 191 await asyncio.sleep(0) 192 task.cancel() 193 194 # receive everything that is not a space 195 async def recv_all(): 196 rv = b'' 197 while True: 198 buf = await self.loop.sock_recv(server, 8192) 199 if not buf: 200 return rv 201 rv += buf.strip() 202 task = asyncio.create_task(recv_all()) 203 204 # immediately make another sock_sendall call 205 await self.loop.sock_sendall(sock, b'world') 206 sock.shutdown(socket.SHUT_WR) 207 data = await task 208 # ProactorEventLoop could deliver hello, so endswith is necessary 209 self.assertTrue(data.endswith(b'world')) 210 211 # After the first connect attempt before the listener is ready, 212 # the socket needs time to "recover" to make the next connect call. 213 # On Linux, a second retry will do. On Windows, the waiting time is 214 # unpredictable; and on FreeBSD the socket may never come back 215 # because it's a loopback address. Here we'll just retry for a few 216 # times, and have to skip the test if it's not working. See also: 217 # https://stackoverflow.com/a/54437602/3316267 218 # https://lists.freebsd.org/pipermail/freebsd-current/2005-May/049876.html 219 async def _basetest_sock_connect_racing(self, listener, sock): 220 listener.bind(('127.0.0.1', 0)) 221 addr = listener.getsockname() 222 sock.setblocking(False) 223 224 task = asyncio.create_task(self.loop.sock_connect(sock, addr)) 225 await asyncio.sleep(0) 226 task.cancel() 227 228 listener.listen(1) 229 230 skip_reason = "Max retries reached" 231 for i in range(128): 232 try: 233 await self.loop.sock_connect(sock, addr) 234 except ConnectionRefusedError as e: 235 skip_reason = e 236 except OSError as e: 237 skip_reason = e 238 239 # Retry only for this error: 240 # [WinError 10022] An invalid argument was supplied 241 if getattr(e, 'winerror', 0) != 10022: 242 break 243 else: 244 # success 245 return 246 247 self.skipTest(skip_reason) 248 249 def test_sock_client_racing(self): 250 with test_utils.run_test_server() as httpd: 251 sock = socket.socket() 252 with sock: 253 self.loop.run_until_complete(asyncio.wait_for( 254 self._basetest_sock_recv_racing(httpd, sock), 10)) 255 sock = socket.socket() 256 with sock: 257 self.loop.run_until_complete(asyncio.wait_for( 258 self._basetest_sock_recv_into_racing(httpd, sock), 10)) 259 listener = socket.socket() 260 sock = socket.socket() 261 with listener, sock: 262 self.loop.run_until_complete(asyncio.wait_for( 263 self._basetest_sock_send_racing(listener, sock), 10)) 264 265 def test_sock_client_connect_racing(self): 266 listener = socket.socket() 267 sock = socket.socket() 268 with listener, sock: 269 self.loop.run_until_complete(asyncio.wait_for( 270 self._basetest_sock_connect_racing(listener, sock), 10)) 271 272 async def _basetest_huge_content(self, address): 273 sock = socket.socket() 274 sock.setblocking(False) 275 DATA_SIZE = 10_000_00 276 277 chunk = b'0123456789' * (DATA_SIZE // 10) 278 279 await self.loop.sock_connect(sock, address) 280 await self.loop.sock_sendall(sock, 281 (b'POST /loop HTTP/1.0\r\n' + 282 b'Content-Length: %d\r\n' % DATA_SIZE + 283 b'\r\n')) 284 285 task = asyncio.create_task(self.loop.sock_sendall(sock, chunk)) 286 287 data = await self.loop.sock_recv(sock, DATA_SIZE) 288 # HTTP headers size is less than MTU, 289 # they are sent by the first packet always 290 self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) 291 while data.find(b'\r\n\r\n') == -1: 292 data += await self.loop.sock_recv(sock, DATA_SIZE) 293 # Strip headers 294 headers = data[:data.index(b'\r\n\r\n') + 4] 295 data = data[len(headers):] 296 297 size = DATA_SIZE 298 checker = cycle(b'0123456789') 299 300 expected = bytes(islice(checker, len(data))) 301 self.assertEqual(data, expected) 302 size -= len(data) 303 304 while True: 305 data = await self.loop.sock_recv(sock, DATA_SIZE) 306 if not data: 307 break 308 expected = bytes(islice(checker, len(data))) 309 self.assertEqual(data, expected) 310 size -= len(data) 311 self.assertEqual(size, 0) 312 313 await task 314 sock.close() 315 316 def test_huge_content(self): 317 with test_utils.run_test_server() as httpd: 318 self.loop.run_until_complete( 319 self._basetest_huge_content(httpd.address)) 320 321 async def _basetest_huge_content_recvinto(self, address): 322 sock = socket.socket() 323 sock.setblocking(False) 324 DATA_SIZE = 10_000_00 325 326 chunk = b'0123456789' * (DATA_SIZE // 10) 327 328 await self.loop.sock_connect(sock, address) 329 await self.loop.sock_sendall(sock, 330 (b'POST /loop HTTP/1.0\r\n' + 331 b'Content-Length: %d\r\n' % DATA_SIZE + 332 b'\r\n')) 333 334 task = asyncio.create_task(self.loop.sock_sendall(sock, chunk)) 335 336 array = bytearray(DATA_SIZE) 337 buf = memoryview(array) 338 339 nbytes = await self.loop.sock_recv_into(sock, buf) 340 data = bytes(buf[:nbytes]) 341 # HTTP headers size is less than MTU, 342 # they are sent by the first packet always 343 self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) 344 while data.find(b'\r\n\r\n') == -1: 345 nbytes = await self.loop.sock_recv_into(sock, buf) 346 data = bytes(buf[:nbytes]) 347 # Strip headers 348 headers = data[:data.index(b'\r\n\r\n') + 4] 349 data = data[len(headers):] 350 351 size = DATA_SIZE 352 checker = cycle(b'0123456789') 353 354 expected = bytes(islice(checker, len(data))) 355 self.assertEqual(data, expected) 356 size -= len(data) 357 358 while True: 359 nbytes = await self.loop.sock_recv_into(sock, buf) 360 data = buf[:nbytes] 361 if not data: 362 break 363 expected = bytes(islice(checker, len(data))) 364 self.assertEqual(data, expected) 365 size -= len(data) 366 self.assertEqual(size, 0) 367 368 await task 369 sock.close() 370 371 def test_huge_content_recvinto(self): 372 with test_utils.run_test_server() as httpd: 373 self.loop.run_until_complete( 374 self._basetest_huge_content_recvinto(httpd.address)) 375 376 @socket_helper.skip_unless_bind_unix_socket 377 def test_unix_sock_client_ops(self): 378 with test_utils.run_test_unix_server() as httpd: 379 sock = socket.socket(socket.AF_UNIX) 380 self._basetest_sock_client_ops(httpd, sock) 381 sock = socket.socket(socket.AF_UNIX) 382 self._basetest_sock_recv_into(httpd, sock) 383 384 def test_sock_client_fail(self): 385 # Make sure that we will get an unused port 386 address = None 387 try: 388 s = socket.socket() 389 s.bind(('127.0.0.1', 0)) 390 address = s.getsockname() 391 finally: 392 s.close() 393 394 sock = socket.socket() 395 sock.setblocking(False) 396 with self.assertRaises(ConnectionRefusedError): 397 self.loop.run_until_complete( 398 self.loop.sock_connect(sock, address)) 399 sock.close() 400 401 def test_sock_accept(self): 402 listener = socket.socket() 403 listener.setblocking(False) 404 listener.bind(('127.0.0.1', 0)) 405 listener.listen(1) 406 client = socket.socket() 407 client.connect(listener.getsockname()) 408 409 f = self.loop.sock_accept(listener) 410 conn, addr = self.loop.run_until_complete(f) 411 self.assertEqual(conn.gettimeout(), 0) 412 self.assertEqual(addr, client.getsockname()) 413 self.assertEqual(client.getpeername(), listener.getsockname()) 414 client.close() 415 conn.close() 416 listener.close() 417 418 def test_cancel_sock_accept(self): 419 listener = socket.socket() 420 listener.setblocking(False) 421 listener.bind(('127.0.0.1', 0)) 422 listener.listen(1) 423 sockaddr = listener.getsockname() 424 f = asyncio.wait_for(self.loop.sock_accept(listener), 0.1) 425 with self.assertRaises(asyncio.TimeoutError): 426 self.loop.run_until_complete(f) 427 428 listener.close() 429 client = socket.socket() 430 client.setblocking(False) 431 f = self.loop.sock_connect(client, sockaddr) 432 with self.assertRaises(ConnectionRefusedError): 433 self.loop.run_until_complete(f) 434 435 client.close() 436 437 def test_create_connection_sock(self): 438 with test_utils.run_test_server() as httpd: 439 sock = None 440 infos = self.loop.run_until_complete( 441 self.loop.getaddrinfo( 442 *httpd.address, type=socket.SOCK_STREAM)) 443 for family, type, proto, cname, address in infos: 444 try: 445 sock = socket.socket(family=family, type=type, proto=proto) 446 sock.setblocking(False) 447 self.loop.run_until_complete( 448 self.loop.sock_connect(sock, address)) 449 except BaseException: 450 pass 451 else: 452 break 453 else: 454 assert False, 'Can not create socket.' 455 456 f = self.loop.create_connection( 457 lambda: MyProto(loop=self.loop), sock=sock) 458 tr, pr = self.loop.run_until_complete(f) 459 self.assertIsInstance(tr, asyncio.Transport) 460 self.assertIsInstance(pr, asyncio.Protocol) 461 self.loop.run_until_complete(pr.done) 462 self.assertGreater(pr.nbytes, 0) 463 tr.close() 464 465 466if sys.platform == 'win32': 467 468 class SelectEventLoopTests(BaseSockTestsMixin, 469 test_utils.TestCase): 470 471 def create_event_loop(self): 472 return asyncio.SelectorEventLoop() 473 474 class ProactorEventLoopTests(BaseSockTestsMixin, 475 test_utils.TestCase): 476 477 def create_event_loop(self): 478 return asyncio.ProactorEventLoop() 479 480else: 481 import selectors 482 483 if hasattr(selectors, 'KqueueSelector'): 484 class KqueueEventLoopTests(BaseSockTestsMixin, 485 test_utils.TestCase): 486 487 def create_event_loop(self): 488 return asyncio.SelectorEventLoop( 489 selectors.KqueueSelector()) 490 491 if hasattr(selectors, 'EpollSelector'): 492 class EPollEventLoopTests(BaseSockTestsMixin, 493 test_utils.TestCase): 494 495 def create_event_loop(self): 496 return asyncio.SelectorEventLoop(selectors.EpollSelector()) 497 498 if hasattr(selectors, 'PollSelector'): 499 class PollEventLoopTests(BaseSockTestsMixin, 500 test_utils.TestCase): 501 502 def create_event_loop(self): 503 return asyncio.SelectorEventLoop(selectors.PollSelector()) 504 505 # Should always exist. 506 class SelectEventLoopTests(BaseSockTestsMixin, 507 test_utils.TestCase): 508 509 def create_event_loop(self): 510 return asyncio.SelectorEventLoop(selectors.SelectSelector()) 511