• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import socket
2import asyncio
3import sys
4from asyncio import proactor_events
5from itertools import cycle, islice
6from test.test_asyncio import utils as test_utils
7from test import support
8
9
class MyProto(asyncio.Protocol):
    """Tiny HTTP client protocol that records its connection lifecycle.

    ``state`` starts as ``'INITIAL'`` and is moved to ``'CONNECTED'``,
    ``'EOF'`` and ``'CLOSED'`` by the callbacks below; every callback
    asserts that it is invoked in a state where it is expected.
    ``nbytes`` counts the bytes passed to ``data_received()``.
    """

    # Futures resolved on connect / disconnect.  They stay None unless a
    # loop is handed to the constructor.
    connected = None
    done = None

    def __init__(self, loop=None):
        self.transport = None
        self.state = 'INITIAL'
        self.nbytes = 0
        if loop is None:
            return
        self.connected = loop.create_future()
        self.done = loop.create_future()

    def connection_made(self, transport):
        """Remember the transport and send a fixed HTTP/1.0 GET request."""
        self.transport = transport
        assert self.state == 'INITIAL', self.state
        self.state = 'CONNECTED'
        if self.connected is not None:
            self.connected.set_result(None)
        request = b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n'
        transport.write(request)

    def data_received(self, data):
        """Add the size of *data* to the running byte count."""
        assert self.state == 'CONNECTED', self.state
        self.nbytes = self.nbytes + len(data)

    def eof_received(self):
        """Record that EOF was seen on the connection."""
        assert self.state == 'CONNECTED', self.state
        self.state = 'EOF'

    def connection_lost(self, exc):
        """Mark the protocol closed and resolve ``done`` if it exists."""
        assert self.state in ('CONNECTED', 'EOF'), self.state
        self.state = 'CLOSED'
        if self.done is not None:
            self.done.set_result(None)
43
44
45class BaseSockTestsMixin:
46
47    def create_event_loop(self):
48        raise NotImplementedError
49
50    def setUp(self):
51        self.loop = self.create_event_loop()
52        self.set_event_loop(self.loop)
53        super().setUp()
54
55    def tearDown(self):
56        # just in case if we have transport close callbacks
57        if not self.loop.is_closed():
58            test_utils.run_briefly(self.loop)
59
60        self.doCleanups()
61        support.gc_collect()
62        super().tearDown()
63
64    def _basetest_sock_client_ops(self, httpd, sock):
65        if not isinstance(self.loop, proactor_events.BaseProactorEventLoop):
66            # in debug mode, socket operations must fail
67            # if the socket is not in blocking mode
68            self.loop.set_debug(True)
69            sock.setblocking(True)
70            with self.assertRaises(ValueError):
71                self.loop.run_until_complete(
72                    self.loop.sock_connect(sock, httpd.address))
73            with self.assertRaises(ValueError):
74                self.loop.run_until_complete(
75                    self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
76            with self.assertRaises(ValueError):
77                self.loop.run_until_complete(
78                    self.loop.sock_recv(sock, 1024))
79            with self.assertRaises(ValueError):
80                self.loop.run_until_complete(
81                    self.loop.sock_recv_into(sock, bytearray()))
82            with self.assertRaises(ValueError):
83                self.loop.run_until_complete(
84                    self.loop.sock_accept(sock))
85
86        # test in non-blocking mode
87        sock.setblocking(False)
88        self.loop.run_until_complete(
89            self.loop.sock_connect(sock, httpd.address))
90        self.loop.run_until_complete(
91            self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
92        data = self.loop.run_until_complete(
93            self.loop.sock_recv(sock, 1024))
94        # consume data
95        self.loop.run_until_complete(
96            self.loop.sock_recv(sock, 1024))
97        sock.close()
98        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
99
100    def _basetest_sock_recv_into(self, httpd, sock):
101        # same as _basetest_sock_client_ops, but using sock_recv_into
102        sock.setblocking(False)
103        self.loop.run_until_complete(
104            self.loop.sock_connect(sock, httpd.address))
105        self.loop.run_until_complete(
106            self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
107        data = bytearray(1024)
108        with memoryview(data) as buf:
109            nbytes = self.loop.run_until_complete(
110                self.loop.sock_recv_into(sock, buf[:1024]))
111            # consume data
112            self.loop.run_until_complete(
113                self.loop.sock_recv_into(sock, buf[nbytes:]))
114        sock.close()
115        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
116
117    def test_sock_client_ops(self):
118        with test_utils.run_test_server() as httpd:
119            sock = socket.socket()
120            self._basetest_sock_client_ops(httpd, sock)
121            sock = socket.socket()
122            self._basetest_sock_recv_into(httpd, sock)
123
124    async def _basetest_huge_content(self, address):
125        sock = socket.socket()
126        sock.setblocking(False)
127        DATA_SIZE = 10_000_00
128
129        chunk = b'0123456789' * (DATA_SIZE // 10)
130
131        await self.loop.sock_connect(sock, address)
132        await self.loop.sock_sendall(sock,
133                                     (b'POST /loop HTTP/1.0\r\n' +
134                                      b'Content-Length: %d\r\n' % DATA_SIZE +
135                                      b'\r\n'))
136
137        task = asyncio.create_task(self.loop.sock_sendall(sock, chunk))
138
139        data = await self.loop.sock_recv(sock, DATA_SIZE)
140        # HTTP headers size is less than MTU,
141        # they are sent by the first packet always
142        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
143        while data.find(b'\r\n\r\n') == -1:
144            data += await self.loop.sock_recv(sock, DATA_SIZE)
145        # Strip headers
146        headers = data[:data.index(b'\r\n\r\n') + 4]
147        data = data[len(headers):]
148
149        size = DATA_SIZE
150        checker = cycle(b'0123456789')
151
152        expected = bytes(islice(checker, len(data)))
153        self.assertEqual(data, expected)
154        size -= len(data)
155
156        while True:
157            data = await self.loop.sock_recv(sock, DATA_SIZE)
158            if not data:
159                break
160            expected = bytes(islice(checker, len(data)))
161            self.assertEqual(data, expected)
162            size -= len(data)
163        self.assertEqual(size, 0)
164
165        await task
166        sock.close()
167
168    def test_huge_content(self):
169        with test_utils.run_test_server() as httpd:
170            self.loop.run_until_complete(
171                self._basetest_huge_content(httpd.address))
172
173    async def _basetest_huge_content_recvinto(self, address):
174        sock = socket.socket()
175        sock.setblocking(False)
176        DATA_SIZE = 10_000_00
177
178        chunk = b'0123456789' * (DATA_SIZE // 10)
179
180        await self.loop.sock_connect(sock, address)
181        await self.loop.sock_sendall(sock,
182                                     (b'POST /loop HTTP/1.0\r\n' +
183                                      b'Content-Length: %d\r\n' % DATA_SIZE +
184                                      b'\r\n'))
185
186        task = asyncio.create_task(self.loop.sock_sendall(sock, chunk))
187
188        array = bytearray(DATA_SIZE)
189        buf = memoryview(array)
190
191        nbytes = await self.loop.sock_recv_into(sock, buf)
192        data = bytes(buf[:nbytes])
193        # HTTP headers size is less than MTU,
194        # they are sent by the first packet always
195        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
196        while data.find(b'\r\n\r\n') == -1:
197            nbytes = await self.loop.sock_recv_into(sock, buf)
198            data = bytes(buf[:nbytes])
199        # Strip headers
200        headers = data[:data.index(b'\r\n\r\n') + 4]
201        data = data[len(headers):]
202
203        size = DATA_SIZE
204        checker = cycle(b'0123456789')
205
206        expected = bytes(islice(checker, len(data)))
207        self.assertEqual(data, expected)
208        size -= len(data)
209
210        while True:
211            nbytes = await self.loop.sock_recv_into(sock, buf)
212            data = buf[:nbytes]
213            if not data:
214                break
215            expected = bytes(islice(checker, len(data)))
216            self.assertEqual(data, expected)
217            size -= len(data)
218        self.assertEqual(size, 0)
219
220        await task
221        sock.close()
222
223    def test_huge_content_recvinto(self):
224        with test_utils.run_test_server() as httpd:
225            self.loop.run_until_complete(
226                self._basetest_huge_content_recvinto(httpd.address))
227
228    @support.skip_unless_bind_unix_socket
229    def test_unix_sock_client_ops(self):
230        with test_utils.run_test_unix_server() as httpd:
231            sock = socket.socket(socket.AF_UNIX)
232            self._basetest_sock_client_ops(httpd, sock)
233            sock = socket.socket(socket.AF_UNIX)
234            self._basetest_sock_recv_into(httpd, sock)
235
236    def test_sock_client_fail(self):
237        # Make sure that we will get an unused port
238        address = None
239        try:
240            s = socket.socket()
241            s.bind(('127.0.0.1', 0))
242            address = s.getsockname()
243        finally:
244            s.close()
245
246        sock = socket.socket()
247        sock.setblocking(False)
248        with self.assertRaises(ConnectionRefusedError):
249            self.loop.run_until_complete(
250                self.loop.sock_connect(sock, address))
251        sock.close()
252
253    def test_sock_accept(self):
254        listener = socket.socket()
255        listener.setblocking(False)
256        listener.bind(('127.0.0.1', 0))
257        listener.listen(1)
258        client = socket.socket()
259        client.connect(listener.getsockname())
260
261        f = self.loop.sock_accept(listener)
262        conn, addr = self.loop.run_until_complete(f)
263        self.assertEqual(conn.gettimeout(), 0)
264        self.assertEqual(addr, client.getsockname())
265        self.assertEqual(client.getpeername(), listener.getsockname())
266        client.close()
267        conn.close()
268        listener.close()
269
270    def test_create_connection_sock(self):
271        with test_utils.run_test_server() as httpd:
272            sock = None
273            infos = self.loop.run_until_complete(
274                self.loop.getaddrinfo(
275                    *httpd.address, type=socket.SOCK_STREAM))
276            for family, type, proto, cname, address in infos:
277                try:
278                    sock = socket.socket(family=family, type=type, proto=proto)
279                    sock.setblocking(False)
280                    self.loop.run_until_complete(
281                        self.loop.sock_connect(sock, address))
282                except BaseException:
283                    pass
284                else:
285                    break
286            else:
287                assert False, 'Can not create socket.'
288
289            f = self.loop.create_connection(
290                lambda: MyProto(loop=self.loop), sock=sock)
291            tr, pr = self.loop.run_until_complete(f)
292            self.assertIsInstance(tr, asyncio.Transport)
293            self.assertIsInstance(pr, asyncio.Protocol)
294            self.loop.run_until_complete(pr.done)
295            self.assertGreater(pr.nbytes, 0)
296            tr.close()
297
298
# Concrete test classes: BaseSockTestsMixin is bound to each event loop
# flavour that exists on the current platform.
if sys.platform == 'win32':

    class SelectEventLoopTests(BaseSockTestsMixin, test_utils.TestCase):
        """Run the socket tests on the default selector event loop."""

        def create_event_loop(self):
            return asyncio.SelectorEventLoop()

    class ProactorEventLoopTests(BaseSockTestsMixin, test_utils.TestCase):
        """Run the socket tests on the proactor event loop."""

        def create_event_loop(self):
            return asyncio.ProactorEventLoop()

else:
    import selectors

    if hasattr(selectors, 'KqueueSelector'):
        class KqueueEventLoopTests(BaseSockTestsMixin, test_utils.TestCase):
            """Run the socket tests on a kqueue-based selector loop."""

            def create_event_loop(self):
                selector = selectors.KqueueSelector()
                return asyncio.SelectorEventLoop(selector)

    if hasattr(selectors, 'EpollSelector'):
        class EPollEventLoopTests(BaseSockTestsMixin, test_utils.TestCase):
            """Run the socket tests on an epoll-based selector loop."""

            def create_event_loop(self):
                selector = selectors.EpollSelector()
                return asyncio.SelectorEventLoop(selector)

    if hasattr(selectors, 'PollSelector'):
        class PollEventLoopTests(BaseSockTestsMixin, test_utils.TestCase):
            """Run the socket tests on a poll-based selector loop."""

            def create_event_loop(self):
                selector = selectors.PollSelector()
                return asyncio.SelectorEventLoop(selector)

    # Should always exist.
    class SelectEventLoopTests(BaseSockTestsMixin, test_utils.TestCase):
        """Run the socket tests on a select-based selector loop."""

        def create_event_loop(self):
            selector = selectors.SelectSelector()
            return asyncio.SelectorEventLoop(selector)
344