import socket
import asyncio
import sys
from asyncio import proactor_events
from itertools import cycle, islice
from test.test_asyncio import utils as test_utils
from test import support


class MyProto(asyncio.Protocol):
    """Client protocol that sends an HTTP/1.0 GET request once connected.

    The lifecycle is tracked in ``state``: 'INITIAL' -> 'CONNECTED' ->
    (optionally 'EOF') -> 'CLOSED'.  ``nbytes`` counts the bytes received.
    When a loop is passed, ``connected`` and ``done`` are futures resolved
    in connection_made() and connection_lost() respectively; otherwise they
    stay None.
    """

    connected = None
    done = None

    def __init__(self, loop=None):
        self.transport = None
        self.state = 'INITIAL'
        self.nbytes = 0
        if loop is not None:
            self.connected = loop.create_future()
            self.done = loop.create_future()

    def connection_made(self, transport):
        self.transport = transport
        assert self.state == 'INITIAL', self.state
        self.state = 'CONNECTED'
        if self.connected is not None:
            self.connected.set_result(None)
        transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')

    def data_received(self, data):
        assert self.state == 'CONNECTED', self.state
        self.nbytes += len(data)

    def eof_received(self):
        assert self.state == 'CONNECTED', self.state
        self.state = 'EOF'

    def connection_lost(self, exc):
        # connection_lost() may arrive with or without a preceding
        # eof_received(), so both states are accepted here.
        assert self.state in ('CONNECTED', 'EOF'), self.state
        self.state = 'CLOSED'
        if self.done is not None:
            self.done.set_result(None)


class BaseSockTestsMixin:
    """Low-level ``loop.sock_*`` tests shared by all event loop flavours.

    Concrete test classes combine this mixin with ``test_utils.TestCase``
    and implement create_event_loop() to return the loop under test.
    """

    def create_event_loop(self):
        raise NotImplementedError

    def setUp(self):
        self.loop = self.create_event_loop()
        self.set_event_loop(self.loop)
        super().setUp()

    def tearDown(self):
        # Give any pending transport close callbacks a chance to run.
        if not self.loop.is_closed():
            test_utils.run_briefly(self.loop)

        self.doCleanups()
        support.gc_collect()
        super().tearDown()

    def _basetest_sock_client_ops(self, httpd, sock):
        """Send an HTTP request over *sock* using sock_connect/sendall/recv.

        On non-proactor loops this first checks that, in debug mode, the
        sock_* operations raise ValueError on a blocking socket.  *sock* is
        closed before returning on success.
        """
        if not isinstance(self.loop, proactor_events.BaseProactorEventLoop):
            # In debug mode, socket operations must fail
            # if the socket is in blocking mode.
            self.loop.set_debug(True)
            sock.setblocking(True)
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_connect(sock, httpd.address))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_recv(sock, 1024))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_recv_into(sock, bytearray()))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_accept(sock))

        # test in non-blocking mode
        sock.setblocking(False)
        self.loop.run_until_complete(
            self.loop.sock_connect(sock, httpd.address))
        self.loop.run_until_complete(
            self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
        data = self.loop.run_until_complete(
            self.loop.sock_recv(sock, 1024))
        # consume data
        self.loop.run_until_complete(
            self.loop.sock_recv(sock, 1024))
        sock.close()
        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))

    def _basetest_sock_recv_into(self, httpd, sock):
        """Same as _basetest_sock_client_ops, but using sock_recv_into."""
        sock.setblocking(False)
        self.loop.run_until_complete(
            self.loop.sock_connect(sock, httpd.address))
        self.loop.run_until_complete(
            self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
        data = bytearray(1024)
        with memoryview(data) as buf:
            nbytes = self.loop.run_until_complete(
                self.loop.sock_recv_into(sock, buf[:1024]))
            # consume data
            self.loop.run_until_complete(
                self.loop.sock_recv_into(sock, buf[nbytes:]))
        sock.close()
        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))

    def test_sock_client_ops(self):
        with test_utils.run_test_server() as httpd:
            sock = socket.socket()
            self._basetest_sock_client_ops(httpd, sock)
            sock = socket.socket()
            self._basetest_sock_recv_into(httpd, sock)

    async def _basetest_huge_content(self, address):
        """POST DATA_SIZE bytes to /loop and verify the echoed body.

        The upload runs as a concurrent task while the response is read
        with sock_recv, so sending and receiving overlap.
        """
        sock = socket.socket()
        sock.setblocking(False)
        DATA_SIZE = 1_000_000

        chunk = b'0123456789' * (DATA_SIZE // 10)

        await self.loop.sock_connect(sock, address)
        await self.loop.sock_sendall(sock,
                                     (b'POST /loop HTTP/1.0\r\n' +
                                      b'Content-Length: %d\r\n' % DATA_SIZE +
                                      b'\r\n'))

        task = asyncio.create_task(self.loop.sock_sendall(sock, chunk))

        data = await self.loop.sock_recv(sock, DATA_SIZE)
        # HTTP headers size is less than MTU,
        # they are sent by the first packet always
        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
        while data.find(b'\r\n\r\n') == -1:
            data += await self.loop.sock_recv(sock, DATA_SIZE)
        # Strip headers
        headers = data[:data.index(b'\r\n\r\n') + 4]
        data = data[len(headers):]

        size = DATA_SIZE
        checker = cycle(b'0123456789')

        expected = bytes(islice(checker, len(data)))
        self.assertEqual(data, expected)
        size -= len(data)

        while True:
            data = await self.loop.sock_recv(sock, DATA_SIZE)
            if not data:
                break
            expected = bytes(islice(checker, len(data)))
            self.assertEqual(data, expected)
            size -= len(data)
        self.assertEqual(size, 0)

        await task
        sock.close()

    def test_huge_content(self):
        with test_utils.run_test_server() as httpd:
            self.loop.run_until_complete(
                self._basetest_huge_content(httpd.address))

    async def _basetest_huge_content_recvinto(self, address):
        """Same as _basetest_huge_content, but reading via sock_recv_into.

        A single preallocated buffer is reused for every read, so each
        read's bytes are copied out before the next read overwrites them.
        """
        sock = socket.socket()
        sock.setblocking(False)
        DATA_SIZE = 1_000_000

        chunk = b'0123456789' * (DATA_SIZE // 10)

        await self.loop.sock_connect(sock, address)
        await self.loop.sock_sendall(sock,
                                     (b'POST /loop HTTP/1.0\r\n' +
                                      b'Content-Length: %d\r\n' % DATA_SIZE +
                                      b'\r\n'))

        task = asyncio.create_task(self.loop.sock_sendall(sock, chunk))

        array = bytearray(DATA_SIZE)
        buf = memoryview(array)

        nbytes = await self.loop.sock_recv_into(sock, buf)
        data = bytes(buf[:nbytes])
        # HTTP headers size is less than MTU,
        # they are sent by the first packet always
        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
        while data.find(b'\r\n\r\n') == -1:
            nbytes = await self.loop.sock_recv_into(sock, buf)
            # Accumulate: the header terminator may straddle two reads, and
            # overwriting would drop everything received so far.
            data += bytes(buf[:nbytes])
        # Strip headers
        headers = data[:data.index(b'\r\n\r\n') + 4]
        data = data[len(headers):]

        size = DATA_SIZE
        checker = cycle(b'0123456789')

        expected = bytes(islice(checker, len(data)))
        self.assertEqual(data, expected)
        size -= len(data)

        while True:
            nbytes = await self.loop.sock_recv_into(sock, buf)
            data = buf[:nbytes]
            if not data:
                break
            expected = bytes(islice(checker, len(data)))
            self.assertEqual(data, expected)
            size -= len(data)
        self.assertEqual(size, 0)

        await task
        sock.close()

    def test_huge_content_recvinto(self):
        with test_utils.run_test_server() as httpd:
            self.loop.run_until_complete(
                self._basetest_huge_content_recvinto(httpd.address))

    @support.skip_unless_bind_unix_socket
    def test_unix_sock_client_ops(self):
        with test_utils.run_test_unix_server() as httpd:
            sock = socket.socket(socket.AF_UNIX)
            self._basetest_sock_client_ops(httpd, sock)
            sock = socket.socket(socket.AF_UNIX)
            self._basetest_sock_recv_into(httpd, sock)

    def test_sock_client_fail(self):
        # Bind and immediately close a socket to obtain a port that
        # nothing is listening on.
        with socket.socket() as s:
            s.bind(('127.0.0.1', 0))
            address = s.getsockname()

        with socket.socket() as sock:
            sock.setblocking(False)
            with self.assertRaises(ConnectionRefusedError):
                self.loop.run_until_complete(
                    self.loop.sock_connect(sock, address))

    def test_sock_accept(self):
        with socket.socket() as listener, socket.socket() as client:
            listener.setblocking(False)
            listener.bind(('127.0.0.1', 0))
            listener.listen(1)
            client.connect(listener.getsockname())

            f = self.loop.sock_accept(listener)
            conn, addr = self.loop.run_until_complete(f)
            with conn:
                # The accepted socket must be non-blocking.
                self.assertEqual(conn.gettimeout(), 0)
                self.assertEqual(addr, client.getsockname())
                self.assertEqual(client.getpeername(), listener.getsockname())

    def test_create_connection_sock(self):
        with test_utils.run_test_server() as httpd:
            sock = None
            infos = self.loop.run_until_complete(
                self.loop.getaddrinfo(
                    *httpd.address, type=socket.SOCK_STREAM))
            # Use the first resolved address we can actually connect to.
            for family, sock_type, proto, _cname, address in infos:
                sock = None
                try:
                    sock = socket.socket(family=family, type=sock_type,
                                         proto=proto)
                    sock.setblocking(False)
                    self.loop.run_until_complete(
                        self.loop.sock_connect(sock, address))
                except OSError:
                    # Don't leak the socket of a failed attempt.
                    if sock is not None:
                        sock.close()
                else:
                    break
            else:
                self.fail('Can not create socket.')

            f = self.loop.create_connection(
                lambda: MyProto(loop=self.loop), sock=sock)
            tr, pr = self.loop.run_until_complete(f)
            self.assertIsInstance(tr, asyncio.Transport)
            self.assertIsInstance(pr, asyncio.Protocol)
            self.loop.run_until_complete(pr.done)
            self.assertGreater(pr.nbytes, 0)
            tr.close()


# One concrete TestCase per event loop / selector available on this platform.
if sys.platform == 'win32':

    class SelectEventLoopTests(BaseSockTestsMixin,
                               test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.SelectorEventLoop()

    class ProactorEventLoopTests(BaseSockTestsMixin,
                                 test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.ProactorEventLoop()

else:
    import selectors

    if hasattr(selectors, 'KqueueSelector'):
        class KqueueEventLoopTests(BaseSockTestsMixin,
                                   test_utils.TestCase):

            def create_event_loop(self):
                return asyncio.SelectorEventLoop(
                    selectors.KqueueSelector())

    if hasattr(selectors, 'EpollSelector'):
        class EPollEventLoopTests(BaseSockTestsMixin,
                                  test_utils.TestCase):

            def create_event_loop(self):
                return asyncio.SelectorEventLoop(selectors.EpollSelector())

    if hasattr(selectors, 'PollSelector'):
        class PollEventLoopTests(BaseSockTestsMixin,
                                 test_utils.TestCase):

            def create_event_loop(self):
                return asyncio.SelectorEventLoop(selectors.PollSelector())

    # Should always exist.
    class SelectEventLoopTests(BaseSockTestsMixin,
                               test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.SelectorEventLoop(selectors.SelectSelector())