• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import socket
2import time
3import asyncio
4import sys
5import unittest
6
7from asyncio import proactor_events
8from itertools import cycle, islice
9from test.test_asyncio import utils as test_utils
10from test import support
11from test.support import socket_helper
12
13
class MyProto(asyncio.Protocol):
    """Small client protocol that records its own lifecycle.

    On connection it writes a fixed HTTP/1.0 GET request, then counts the
    bytes it receives.  Each callback asserts the expected state ordering
    INITIAL -> CONNECTED -> (EOF) -> CLOSED.  When *loop* is given,
    ``connected`` and ``done`` are futures resolved by ``connection_made``
    and ``connection_lost`` respectively; otherwise they stay ``None``.
    """
    connected = None
    done = None

    def __init__(self, loop=None):
        self.transport = None
        self.state = 'INITIAL'
        self.nbytes = 0
        if loop is None:
            # No loop: keep the class-level ``None`` placeholders.
            return
        self.connected = loop.create_future()
        self.done = loop.create_future()

    def connection_made(self, transport):
        self.transport = transport
        assert self.state == 'INITIAL', self.state
        self.state = 'CONNECTED'
        if self.connected is not None:
            self.connected.set_result(None)
        transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')

    def data_received(self, data):
        assert self.state == 'CONNECTED', self.state
        self.nbytes = self.nbytes + len(data)

    def eof_received(self):
        assert self.state == 'CONNECTED', self.state
        self.state = 'EOF'

    def connection_lost(self, exc):
        assert self.state in ('CONNECTED', 'EOF'), self.state
        self.state = 'CLOSED'
        if self.done is not None:
            self.done.set_result(None)
47
48
class BaseSockTestsMixin:
    """Low-level ``loop.sock_*`` tests shared by every event loop flavour.

    Concrete subclasses combine this mixin with ``test_utils.TestCase`` and
    implement :meth:`create_event_loop` to return the loop under test.
    """

    def create_event_loop(self):
        """Return a new event loop instance; subclasses must override."""
        raise NotImplementedError

    def setUp(self):
        self.loop = self.create_event_loop()
        self.set_event_loop(self.loop)
        super().setUp()

    def tearDown(self):
        # just in case if we have transport close callbacks
        if not self.loop.is_closed():
            test_utils.run_briefly(self.loop)

        self.doCleanups()
        support.gc_collect()
        super().tearDown()

    def _basetest_sock_client_ops(self, httpd, sock):
        """Run connect/sendall/recv on *sock* against *httpd*.

        On non-proactor loops, first check that debug mode rejects each
        operation with ``ValueError`` while *sock* is blocking.  *sock* is
        closed before returning.
        """
        if not isinstance(self.loop, proactor_events.BaseProactorEventLoop):
            # in debug mode, socket operations must fail
            # if the socket is in blocking mode
            self.loop.set_debug(True)
            sock.setblocking(True)
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_connect(sock, httpd.address))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_recv(sock, 1024))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_recv_into(sock, bytearray()))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_accept(sock))

        # test in non-blocking mode
        sock.setblocking(False)
        self.loop.run_until_complete(
            self.loop.sock_connect(sock, httpd.address))
        self.loop.run_until_complete(
            self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
        data = self.loop.run_until_complete(
            self.loop.sock_recv(sock, 1024))
        # consume data
        self.loop.run_until_complete(
            self.loop.sock_recv(sock, 1024))
        sock.close()
        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))

    def _basetest_sock_recv_into(self, httpd, sock):
        """Same as ``_basetest_sock_client_ops``, but via ``sock_recv_into``.

        *sock* is closed before returning.
        """
        sock.setblocking(False)
        self.loop.run_until_complete(
            self.loop.sock_connect(sock, httpd.address))
        self.loop.run_until_complete(
            self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
        data = bytearray(1024)
        with memoryview(data) as buf:
            nbytes = self.loop.run_until_complete(
                self.loop.sock_recv_into(sock, buf[:1024]))
            # consume data
            self.loop.run_until_complete(
                self.loop.sock_recv_into(sock, buf[nbytes:]))
        sock.close()
        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))

    def test_sock_client_ops(self):
        with test_utils.run_test_server() as httpd:
            # The helpers close the socket themselves; ``with`` additionally
            # guarantees it is closed if an assertion fails part-way.
            with socket.socket() as sock:
                self._basetest_sock_client_ops(httpd, sock)
            with socket.socket() as sock:
                self._basetest_sock_recv_into(httpd, sock)

    async def _basetest_sock_recv_racing(self, httpd, sock):
        """Cancel a pending ``sock_recv``, then check a new one still works."""
        sock.setblocking(False)
        await self.loop.sock_connect(sock, httpd.address)

        task = asyncio.create_task(self.loop.sock_recv(sock, 1024))
        await asyncio.sleep(0)
        task.cancel()

        # Keep a reference so the send task cannot be garbage-collected
        # mid-flight, and await it below so a failed send is reported.
        send_task = asyncio.create_task(
            self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
        data = await self.loop.sock_recv(sock, 1024)
        # consume data
        await self.loop.sock_recv(sock, 1024)
        await send_task

        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))

    async def _basetest_sock_recv_into_racing(self, httpd, sock):
        """Cancel a pending ``sock_recv_into``, then check a new one works."""
        sock.setblocking(False)
        await self.loop.sock_connect(sock, httpd.address)

        data = bytearray(1024)
        with memoryview(data) as buf:
            task = asyncio.create_task(
                self.loop.sock_recv_into(sock, buf[:1024]))
            await asyncio.sleep(0)
            task.cancel()

            task = asyncio.create_task(
                self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
            nbytes = await self.loop.sock_recv_into(sock, buf[:1024])
            # consume data
            await self.loop.sock_recv_into(sock, buf[nbytes:])
            self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))

        await task

    async def _basetest_sock_send_racing(self, listener, sock):
        """Cancel a blocked ``sock_sendall`` and check the next one succeeds.

        The receiver strips whitespace padding, so only the payloads of the
        ``sock_sendall`` calls are compared.
        """
        listener.bind(('127.0.0.1', 0))
        listener.listen(1)

        # make connection
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
        sock.setblocking(False)
        task = asyncio.create_task(
            self.loop.sock_connect(sock, listener.getsockname()))
        await asyncio.sleep(0)
        server = listener.accept()[0]
        server.setblocking(False)

        with server:
            await task

            # fill the buffer until sending 5 chars would block
            size = 8192
            while size >= 4:
                with self.assertRaises(BlockingIOError):
                    while True:
                        sock.send(b' ' * size)
                size //= 2

            # cancel a blocked sock_sendall
            task = asyncio.create_task(
                self.loop.sock_sendall(sock, b'hello'))
            await asyncio.sleep(0)
            task.cancel()

            # receive everything that is not a space
            async def recv_all():
                rv = b''
                while True:
                    buf = await self.loop.sock_recv(server, 8192)
                    if not buf:
                        return rv
                    rv += buf.strip()
            task = asyncio.create_task(recv_all())

            # immediately make another sock_sendall call
            await self.loop.sock_sendall(sock, b'world')
            sock.shutdown(socket.SHUT_WR)
            data = await task
            # ProactorEventLoop could deliver hello, so endswith is necessary
            self.assertTrue(data.endswith(b'world'))

    # After the first connect attempt before the listener is ready,
    # the socket needs time to "recover" to make the next connect call.
    # On Linux, a second retry will do. On Windows, the waiting time is
    # unpredictable; and on FreeBSD the socket may never come back
    # because it's a loopback address. Here we'll just retry for a few
    # times, and have to skip the test if it's not working. See also:
    # https://stackoverflow.com/a/54437602/3316267
    # https://lists.freebsd.org/pipermail/freebsd-current/2005-May/049876.html
    async def _basetest_sock_connect_racing(self, listener, sock):
        """Cancel a ``sock_connect`` made before ``listen()``, then retry.

        Skips the test when no retry succeeds (see the comment above).
        """
        listener.bind(('127.0.0.1', 0))
        addr = listener.getsockname()
        sock.setblocking(False)

        task = asyncio.create_task(self.loop.sock_connect(sock, addr))
        await asyncio.sleep(0)
        task.cancel()

        listener.listen(1)

        skip_reason = "Max retries reached"
        for _ in range(128):
            try:
                await self.loop.sock_connect(sock, addr)
            except ConnectionRefusedError as e:
                skip_reason = e
            except OSError as e:
                skip_reason = e

                # Retry only for this error:
                # [WinError 10022] An invalid argument was supplied
                if getattr(e, 'winerror', 0) != 10022:
                    break
            else:
                # success
                return

        self.skipTest(skip_reason)

    def test_sock_client_racing(self):
        with test_utils.run_test_server() as httpd:
            sock = socket.socket()
            with sock:
                self.loop.run_until_complete(asyncio.wait_for(
                    self._basetest_sock_recv_racing(httpd, sock), 10))
            sock = socket.socket()
            with sock:
                self.loop.run_until_complete(asyncio.wait_for(
                    self._basetest_sock_recv_into_racing(httpd, sock), 10))
        listener = socket.socket()
        sock = socket.socket()
        with listener, sock:
            self.loop.run_until_complete(asyncio.wait_for(
                self._basetest_sock_send_racing(listener, sock), 10))

    def test_sock_client_connect_racing(self):
        listener = socket.socket()
        sock = socket.socket()
        with listener, sock:
            self.loop.run_until_complete(asyncio.wait_for(
                self._basetest_sock_connect_racing(listener, sock), 10))

    async def _basetest_huge_content(self, address):
        """POST ~1 MB to ``/loop`` and verify the response body via recv.

        The body is expected to come back byte-for-byte as the repeating
        ``b'0123456789'`` pattern that was sent.
        """
        sock = socket.socket()
        # Ensure the socket is released even if an assertion fails.
        self.addCleanup(sock.close)
        sock.setblocking(False)
        DATA_SIZE = 1_000_000

        chunk = b'0123456789' * (DATA_SIZE // 10)

        await self.loop.sock_connect(sock, address)
        await self.loop.sock_sendall(sock,
                                     (b'POST /loop HTTP/1.0\r\n' +
                                      b'Content-Length: %d\r\n' % DATA_SIZE +
                                      b'\r\n'))

        task = asyncio.create_task(self.loop.sock_sendall(sock, chunk))

        data = await self.loop.sock_recv(sock, DATA_SIZE)
        # HTTP headers size is less than MTU,
        # they are sent by the first packet always
        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
        while b'\r\n\r\n' not in data:
            more = await self.loop.sock_recv(sock, DATA_SIZE)
            if not more:
                # Avoid spinning forever on EOF before the headers end.
                self.fail('connection closed before end of HTTP headers')
            data += more
        # Strip headers
        _, _, data = data.partition(b'\r\n\r\n')

        size = DATA_SIZE
        checker = cycle(b'0123456789')

        expected = bytes(islice(checker, len(data)))
        self.assertEqual(data, expected)
        size -= len(data)

        while True:
            data = await self.loop.sock_recv(sock, DATA_SIZE)
            if not data:
                break
            expected = bytes(islice(checker, len(data)))
            self.assertEqual(data, expected)
            size -= len(data)
        self.assertEqual(size, 0)

        await task
        sock.close()

    def test_huge_content(self):
        with test_utils.run_test_server() as httpd:
            self.loop.run_until_complete(
                self._basetest_huge_content(httpd.address))

    async def _basetest_huge_content_recvinto(self, address):
        """Like ``_basetest_huge_content``, but reads with ``sock_recv_into``."""
        sock = socket.socket()
        # Ensure the socket is released even if an assertion fails.
        self.addCleanup(sock.close)
        sock.setblocking(False)
        DATA_SIZE = 1_000_000

        chunk = b'0123456789' * (DATA_SIZE // 10)

        await self.loop.sock_connect(sock, address)
        await self.loop.sock_sendall(sock,
                                     (b'POST /loop HTTP/1.0\r\n' +
                                      b'Content-Length: %d\r\n' % DATA_SIZE +
                                      b'\r\n'))

        task = asyncio.create_task(self.loop.sock_sendall(sock, chunk))

        array = bytearray(DATA_SIZE)
        buf = memoryview(array)

        nbytes = await self.loop.sock_recv_into(sock, buf)
        data = bytes(buf[:nbytes])
        # HTTP headers size is less than MTU,
        # they are sent by the first packet always
        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
        while b'\r\n\r\n' not in data:
            # Append rather than overwrite: ``buf`` is reused by every read
            # and the header terminator may straddle two reads.
            nbytes = await self.loop.sock_recv_into(sock, buf)
            if not nbytes:
                # Avoid spinning forever on EOF before the headers end.
                self.fail('connection closed before end of HTTP headers')
            data += bytes(buf[:nbytes])
        # Strip headers
        _, _, data = data.partition(b'\r\n\r\n')

        size = DATA_SIZE
        checker = cycle(b'0123456789')

        expected = bytes(islice(checker, len(data)))
        self.assertEqual(data, expected)
        size -= len(data)

        while True:
            nbytes = await self.loop.sock_recv_into(sock, buf)
            data = buf[:nbytes]
            if not data:
                break
            expected = bytes(islice(checker, len(data)))
            self.assertEqual(data, expected)
            size -= len(data)
        self.assertEqual(size, 0)

        await task
        sock.close()

    def test_huge_content_recvinto(self):
        with test_utils.run_test_server() as httpd:
            self.loop.run_until_complete(
                self._basetest_huge_content_recvinto(httpd.address))

    @socket_helper.skip_unless_bind_unix_socket
    def test_unix_sock_client_ops(self):
        with test_utils.run_test_unix_server() as httpd:
            # Helpers close the socket; ``with`` covers the failure path.
            with socket.socket(socket.AF_UNIX) as sock:
                self._basetest_sock_client_ops(httpd, sock)
            with socket.socket(socket.AF_UNIX) as sock:
                self._basetest_sock_recv_into(httpd, sock)

    def test_sock_client_fail(self):
        # Make sure that we will get an unused port: bind an ephemeral
        # port, remember it, and release it before connecting.
        with socket.socket() as s:
            s.bind(('127.0.0.1', 0))
            address = s.getsockname()

        with socket.socket() as sock:
            sock.setblocking(False)
            with self.assertRaises(ConnectionRefusedError):
                self.loop.run_until_complete(
                    self.loop.sock_connect(sock, address))

    def test_sock_accept(self):
        listener = socket.socket()
        listener.setblocking(False)
        listener.bind(('127.0.0.1', 0))
        listener.listen(1)
        client = socket.socket()
        client.connect(listener.getsockname())

        f = self.loop.sock_accept(listener)
        conn, addr = self.loop.run_until_complete(f)
        # The accepted socket must be non-blocking.
        self.assertEqual(conn.gettimeout(), 0)
        self.assertEqual(addr, client.getsockname())
        self.assertEqual(client.getpeername(), listener.getsockname())
        client.close()
        conn.close()
        listener.close()

    def test_cancel_sock_accept(self):
        listener = socket.socket()
        listener.setblocking(False)
        listener.bind(('127.0.0.1', 0))
        listener.listen(1)
        sockaddr = listener.getsockname()
        f = asyncio.wait_for(self.loop.sock_accept(listener), 0.1)
        with self.assertRaises(asyncio.TimeoutError):
            self.loop.run_until_complete(f)

        # After the cancelled accept, closing the listener must leave
        # nothing accepting on that address.
        listener.close()
        client = socket.socket()
        client.setblocking(False)
        f = self.loop.sock_connect(client, sockaddr)
        with self.assertRaises(ConnectionRefusedError):
            self.loop.run_until_complete(f)

        client.close()

    def test_create_connection_sock(self):
        with test_utils.run_test_server() as httpd:
            infos = self.loop.run_until_complete(
                self.loop.getaddrinfo(
                    *httpd.address, type=socket.SOCK_STREAM))
            # Try each resolved address until one connects; a candidate
            # that fails is closed and skipped.
            sock = None
            for family, type_, proto, _cname, address in infos:
                try:
                    sock = socket.socket(family=family, type=type_,
                                         proto=proto)
                    sock.setblocking(False)
                    self.loop.run_until_complete(
                        self.loop.sock_connect(sock, address))
                except Exception:
                    if sock is not None:
                        sock.close()
                        sock = None
                else:
                    break
            else:
                self.fail('Can not create socket.')

            f = self.loop.create_connection(
                lambda: MyProto(loop=self.loop), sock=sock)
            tr, pr = self.loop.run_until_complete(f)
            self.assertIsInstance(tr, asyncio.Transport)
            self.assertIsInstance(pr, asyncio.Protocol)
            self.loop.run_until_complete(pr.done)
            self.assertGreater(pr.nbytes, 0)
            tr.close()
464
465
if sys.platform == 'win32':

    class SelectEventLoopTests(BaseSockTestsMixin,
                               test_utils.TestCase):
        # Runs the mixin against the default SelectorEventLoop on Windows.

        def create_event_loop(self):
            return asyncio.SelectorEventLoop()

    class ProactorEventLoopTests(BaseSockTestsMixin,
                                 test_utils.TestCase):
        # Runs the mixin against the Windows-only ProactorEventLoop.

        def create_event_loop(self):
            return asyncio.ProactorEventLoop()

else:
    import selectors

    # Each optional selector gets its own TestCase only when the running
    # platform's ``selectors`` module provides it.
    if getattr(selectors, 'KqueueSelector', None) is not None:
        class KqueueEventLoopTests(BaseSockTestsMixin,
                                   test_utils.TestCase):

            def create_event_loop(self):
                selector = selectors.KqueueSelector()
                return asyncio.SelectorEventLoop(selector)

    if getattr(selectors, 'EpollSelector', None) is not None:
        class EPollEventLoopTests(BaseSockTestsMixin,
                                  test_utils.TestCase):

            def create_event_loop(self):
                selector = selectors.EpollSelector()
                return asyncio.SelectorEventLoop(selector)

    if getattr(selectors, 'PollSelector', None) is not None:
        class PollEventLoopTests(BaseSockTestsMixin,
                                 test_utils.TestCase):

            def create_event_loop(self):
                selector = selectors.PollSelector()
                return asyncio.SelectorEventLoop(selector)

    # SelectSelector is defined unconditionally, so this case is always
    # generated on non-Windows platforms.
    class SelectEventLoopTests(BaseSockTestsMixin,
                               test_utils.TestCase):

        def create_event_loop(self):
            selector = selectors.SelectSelector()
            return asyncio.SelectorEventLoop(selector)
511