• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Tests for sendfile functionality."""
2
3import asyncio
4import os
5import socket
6import sys
7import tempfile
8import unittest
9from asyncio import base_events
10from asyncio import constants
11from unittest import mock
12from test import support
13from test.support import socket_helper
14from test.test_asyncio import utils as test_utils
15
16try:
17    import ssl
18except ImportError:
19    ssl = None
20
21
def tearDownModule():
    """Restore the default global event loop policy once the module is done."""
    default_policy = None  # passing None reinstates the default policy
    asyncio.set_event_loop_policy(default_policy)
24
25
class MySendfileProto(asyncio.Protocol):
    """Protocol that records received bytes and tracks its lifecycle state.

    State moves INITIAL -> CONNECTED -> (EOF ->) CLOSED; each callback
    asserts it is called in the expected state.

    Args:
        loop: optional event loop. When given, ``connected`` and ``done``
            futures are created and resolved on connection_made() and
            connection_lost() respectively; otherwise both are None.
        close_after: when non-zero, close the transport as soon as at least
            this many bytes have been received.
    """

    def __init__(self, loop=None, close_after=0):
        self.transport = None
        self.state = 'INITIAL'
        self.nbytes = 0
        # Without a loop there is nothing to create futures on. Default them
        # to None so the callbacks below skip them instead of raising
        # AttributeError.
        if loop is not None:
            self.connected = loop.create_future()
            self.done = loop.create_future()
        else:
            self.connected = None
            self.done = None
        self.data = bytearray()
        self.close_after = close_after

    def connection_made(self, transport):
        self.transport = transport
        assert self.state == 'INITIAL', self.state
        self.state = 'CONNECTED'
        if self.connected is not None:
            self.connected.set_result(None)

    def eof_received(self):
        assert self.state == 'CONNECTED', self.state
        self.state = 'EOF'

    def connection_lost(self, exc):
        assert self.state in ('CONNECTED', 'EOF'), self.state
        self.state = 'CLOSED'
        if self.done is not None:
            self.done.set_result(None)

    def data_received(self, data):
        assert self.state == 'CONNECTED', self.state
        self.nbytes += len(data)
        self.data.extend(data)
        super().data_received(data)
        # Simulate a peer that hangs up after reading close_after bytes.
        if self.close_after and self.nbytes >= self.close_after:
            self.transport.close()
62
63
class MyProto(asyncio.Protocol):
    """Protocol that accumulates incoming bytes and reports when it closes.

    ``fut`` is created on *loop* and resolved from connection_lost();
    wait_closed() awaits it.
    """

    def __init__(self, loop):
        self.started = self.closed = False
        self.data = bytearray()
        self.fut = loop.create_future()
        self.transport = None

    def connection_made(self, transport):
        self.started, self.transport = True, transport

    def data_received(self, data):
        # In-place concatenation extends the same bytearray object.
        self.data += data

    def connection_lost(self, exc):
        self.closed = True
        self.fut.set_result(None)

    async def wait_closed(self):
        """Suspend until connection_lost() has been called."""
        await self.fut
86
87
class SendfileBase:
    """Shared fixture for the sendfile tests.

    Writes DATA to support.TESTFN once per class. Each test gets that file
    opened for binary reading as ``self.file`` and a fresh event loop from
    create_event_loop() as ``self.loop``.
    """

    # 128 KiB plus a small tail (16 bytes * 8193 = 131088 bytes) so the
    # payload is not aligned to the buffer chunk size.
    DATA = b"SendfileBaseData" * (1024 * 8 + 1)

    # Reduce socket buffer size to test on relative small data sets.
    BUF_SIZE = 4 * 1024   # 4 KiB

    def create_event_loop(self):
        """Return the event loop under test; concrete classes override it."""
        raise NotImplementedError

    @classmethod
    def setUpClass(cls):
        with open(support.TESTFN, 'wb') as out:
            out.write(cls.DATA)
        super().setUpClass()

    @classmethod
    def tearDownClass(cls):
        support.unlink(support.TESTFN)
        super().tearDownClass()

    def setUp(self):
        source = open(support.TESTFN, 'rb')
        self.file = source
        self.addCleanup(source.close)
        loop = self.create_event_loop()
        self.loop = loop
        self.set_event_loop(loop)
        super().setUp()

    def tearDown(self):
        # Let any pending transport close callbacks run before cleanups.
        loop = self.loop
        if not loop.is_closed():
            test_utils.run_briefly(loop)

        self.doCleanups()
        support.gc_collect()
        super().tearDown()

    def run_loop(self, coro):
        """Run *coro* to completion on self.loop and return its result."""
        loop = self.loop
        return loop.run_until_complete(coro)
128
129
class SockSendfileMixin(SendfileBase):
    """Tests for loop.sock_sendfile() over a plain, connected TCP socket."""

    @classmethod
    def setUpClass(cls):
        # Shrink the fallback read buffer (restored in tearDownClass) so the
        # fallback path needs several reads to cover DATA.
        cls.__old_bufsize = constants.SENDFILE_FALLBACK_READBUFFER_SIZE
        constants.SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 16
        super().setUpClass()

    @classmethod
    def tearDownClass(cls):
        constants.SENDFILE_FALLBACK_READBUFFER_SIZE = cls.__old_bufsize
        super().tearDownClass()

    def make_socket(self, cleanup=True):
        """Create a non-blocking IPv4 TCP socket.

        When *cleanup* is true the socket is closed automatically after the
        test.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setblocking(False)
        if cleanup:
            self.addCleanup(sock.close)
        return sock

    def reduce_receive_buffer_size(self, sock):
        # Reduce receive socket buffer size to test on relative
        # small data sets.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self.BUF_SIZE)

    def reduce_send_buffer_size(self, sock, transport=None):
        # Reduce send socket buffer size to test on relative small data sets.

        # On macOS, SO_SNDBUF is reset by connect(). So this method
        # should be called after the socket is connected.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.BUF_SIZE)

        if transport is not None:
            transport.set_write_buffer_limits(high=self.BUF_SIZE)

    def prepare_socksendfile(self):
        """Start a server and return ``(client_sock, server_proto)``.

        The client socket is connected to the server. The server transport
        and the server itself are torn down in a registered cleanup.
        """
        proto = MyProto(self.loop)
        srv_sock = self.make_socket(cleanup=False)
        # bind_port() lets the OS pick a free port atomically. This avoids
        # the window between find_unused_port() and bind() in which another
        # process could grab the port. Once create_server() succeeds the
        # server owns srv_sock; until then we must close it ourselves.
        try:
            port = socket_helper.bind_port(srv_sock, socket_helper.HOST)
            server = self.run_loop(self.loop.create_server(
                lambda: proto, sock=srv_sock))
        except BaseException:
            srv_sock.close()
            raise
        self.reduce_receive_buffer_size(srv_sock)

        sock = self.make_socket()
        # Connect to the same host the server is bound to.
        addr = (socket_helper.HOST, port)
        self.run_loop(self.loop.sock_connect(sock, addr))
        self.reduce_send_buffer_size(sock)

        def cleanup():
            if proto.transport is not None:
                # can be None if the task was cancelled before
                # connection_made callback
                proto.transport.close()
                self.run_loop(proto.wait_closed())

            server.close()
            self.run_loop(server.wait_closed())

        self.addCleanup(cleanup)

        return sock, proto

    def test_sock_sendfile_success(self):
        sock, proto = self.prepare_socksendfile()
        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sock_sendfile_with_offset_and_count(self):
        sock, proto = self.prepare_socksendfile()
        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file,
                                                    1000, 2000))
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(proto.data, self.DATA[1000:3000])
        self.assertEqual(self.file.tell(), 3000)
        self.assertEqual(ret, 2000)

    def test_sock_sendfile_zero_size(self):
        sock, proto = self.prepare_socksendfile()
        with tempfile.TemporaryFile() as f:
            ret = self.run_loop(self.loop.sock_sendfile(sock, f,
                                                        0, None))
            # Check the position of the file that was actually sent. It
            # must be read here, before the temporary file is closed.
            sent_pos = f.tell()
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(ret, 0)
        self.assertEqual(sent_pos, 0)
        self.assertEqual(proto.data, b'')
        # The shared fixture file was not involved and must be untouched.
        self.assertEqual(self.file.tell(), 0)

    def test_sock_sendfile_mix_with_regular_send(self):
        buf = b"mix_regular_send" * (4 * 1024)  # 64 KiB
        sock, proto = self.prepare_socksendfile()
        self.run_loop(self.loop.sock_sendall(sock, buf))
        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
        self.run_loop(self.loop.sock_sendall(sock, buf))
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(ret, len(self.DATA))
        expected = buf + self.DATA + buf
        self.assertEqual(proto.data, expected)
        self.assertEqual(self.file.tell(), len(self.DATA))
237
238
class SendfileMixin(SendfileBase):
    """Tests for loop.sendfile() over stream transports (plain TCP and SSL).

    Note: sendfile via an SSL transport is equivalent to the sendfile
    fallback.
    """

    def _disable_native_sendfile(self):
        """Force loop.sendfile() onto its non-native path.

        Replaces the loop's _sendfile_native with the BaseEventLoop version,
        which raises SendfileNotAvailableError. loop.sendfile() then falls
        back, or fails when fallback=False.
        """

        def sendfile_native(transp, file, offset, count):
            # to raise SendfileNotAvailableError
            return base_events.BaseEventLoop._sendfile_native(
                self.loop, transp, file, offset, count)

        self.loop._sendfile_native = sendfile_native

    def prepare_sendfile(self, *, is_ssl=False, close_after=0):
        """Start a server, connect a client transport, and return protocols.

        Returns ``(srv_proto, cli_proto)``. The server protocol closes its
        transport after receiving *close_after* bytes (0 disables this).
        With *is_ssl* both sides use the test SSL contexts; the test is
        skipped if the ssl module is unavailable. Teardown of both
        transports and the server is registered as a cleanup.
        """
        srv_proto = MySendfileProto(loop=self.loop,
                                    close_after=close_after)
        if is_ssl:
            if not ssl:
                self.skipTest("No ssl module")
            srv_ctx = test_utils.simple_server_sslcontext()
            cli_ctx = test_utils.simple_client_sslcontext()
            server_hostname = socket_helper.HOST
        else:
            srv_ctx = None
            cli_ctx = None
            server_hostname = None

        srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # bind_port() lets the OS pick a free port atomically. This avoids
        # the race between find_unused_port() and bind(). The server owns
        # srv_sock only once create_server() succeeds.
        try:
            port = socket_helper.bind_port(srv_sock, socket_helper.HOST)
            server = self.run_loop(self.loop.create_server(
                lambda: srv_proto, sock=srv_sock, ssl=srv_ctx))
        except BaseException:
            srv_sock.close()
            raise
        self.reduce_receive_buffer_size(srv_sock)

        cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        cli_sock.connect((socket_helper.HOST, port))

        cli_proto = MySendfileProto(loop=self.loop)
        tr, _ = self.run_loop(self.loop.create_connection(
            lambda: cli_proto, sock=cli_sock,
            ssl=cli_ctx, server_hostname=server_hostname))
        self.reduce_send_buffer_size(cli_sock, transport=tr)

        def cleanup():
            # A protocol whose connection_made() never ran has no transport
            # and its ``done`` future would never resolve; skip it rather
            # than crash (which would also leave the server open).
            protos = [p for p in (srv_proto, cli_proto)
                      if p.transport is not None]
            for proto in protos:
                proto.transport.close()
            for proto in protos:
                self.run_loop(proto.done)

            server.close()
            self.run_loop(server.wait_closed())

        self.addCleanup(cleanup)
        return srv_proto, cli_proto

    @unittest.skipIf(sys.platform == 'win32', "UDP sockets are not supported")
    def test_sendfile_not_supported(self):
        tr, pr = self.run_loop(
            self.loop.create_datagram_endpoint(
                asyncio.DatagramProtocol,
                family=socket.AF_INET))
        try:
            with self.assertRaisesRegex(RuntimeError, "not supported"):
                self.run_loop(
                    self.loop.sendfile(tr, self.file))
            self.assertEqual(0, self.file.tell())
        finally:
            # don't use self.addCleanup because it produces resource warning
            tr.close()

    def test_sendfile(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_force_fallback(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        self._disable_native_sendfile()

        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_force_unsupported_native(self):
        # ProactorEventLoop only exists on Windows; the short-circuit keeps
        # the attribute lookup from running elsewhere.
        if (sys.platform == 'win32'
                and isinstance(self.loop, asyncio.ProactorEventLoop)):
            self.skipTest("Fails on proactor event loop")
        srv_proto, cli_proto = self.prepare_sendfile()
        self._disable_native_sendfile()

        with self.assertRaisesRegex(asyncio.SendfileNotAvailableError,
                                    "not supported"):
            self.run_loop(
                self.loop.sendfile(cli_proto.transport, self.file,
                                   fallback=False))

        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(srv_proto.nbytes, 0)
        self.assertEqual(self.file.tell(), 0)

    def test_sendfile_ssl(self):
        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_for_closing_transp(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        cli_proto.transport.close()
        with self.assertRaisesRegex(RuntimeError, "is closing"):
            self.run_loop(self.loop.sendfile(cli_proto.transport, self.file))
        self.run_loop(srv_proto.done)
        self.assertEqual(srv_proto.nbytes, 0)
        self.assertEqual(self.file.tell(), 0)

    def test_sendfile_pre_and_post_data(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        PREFIX = b'PREFIX__' * 1024  # 8 KiB
        SUFFIX = b'--SUFFIX' * 1024  # 8 KiB
        cli_proto.transport.write(PREFIX)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.write(SUFFIX)
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_ssl_pre_and_post_data(self):
        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
        PREFIX = b'zxcvbnm' * 1024
        SUFFIX = b'0987654321' * 1024
        cli_proto.transport.write(PREFIX)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.write(SUFFIX)
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_partial(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, 100)
        self.assertEqual(srv_proto.nbytes, 100)
        self.assertEqual(srv_proto.data, self.DATA[1000:1100])
        self.assertEqual(self.file.tell(), 1100)

    def test_sendfile_ssl_partial(self):
        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, 100)
        self.assertEqual(srv_proto.nbytes, 100)
        self.assertEqual(srv_proto.data, self.DATA[1000:1100])
        self.assertEqual(self.file.tell(), 1100)

    def test_sendfile_close_peer_after_receiving(self):
        srv_proto, cli_proto = self.prepare_sendfile(
            close_after=len(self.DATA))
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_ssl_close_peer_after_receiving(self):
        srv_proto, cli_proto = self.prepare_sendfile(
            is_ssl=True, close_after=len(self.DATA))
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    # On Solaris, lowering SO_RCVBUF on a TCP connection after it has been
    # established has no effect. Due to its age, this bug affects both Oracle
    # Solaris as well as all other OpenSolaris forks (unless they fixed it
    # themselves).
    @unittest.skipIf(sys.platform.startswith('sunos'),
                     "Doesn't work on Solaris")
    def test_sendfile_close_peer_in_the_middle_of_receiving(self):
        srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
        with self.assertRaises(ConnectionError):
            self.run_loop(
                self.loop.sendfile(cli_proto.transport, self.file))
        self.run_loop(srv_proto.done)

        self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
                        srv_proto.nbytes)
        self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
                        self.file.tell())
        self.assertTrue(cli_proto.transport.is_closing())

    def test_sendfile_fallback_close_peer_in_the_middle_of_receiving(self):
        # Native sendfile is disabled before the connection is set up.
        self._disable_native_sendfile()

        srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
        with self.assertRaises(ConnectionError):
            self.run_loop(
                self.loop.sendfile(cli_proto.transport, self.file))
        self.run_loop(srv_proto.done)

        self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
                        srv_proto.nbytes)
        self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
                        self.file.tell())

    @unittest.skipIf(not hasattr(os, 'sendfile'),
                     "Don't have native sendfile support")
    def test_sendfile_prevents_bare_write(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        fut = self.loop.create_future()

        async def coro():
            fut.set_result(None)
            return await self.loop.sendfile(cli_proto.transport, self.file)

        t = self.loop.create_task(coro())
        self.run_loop(fut)
        with self.assertRaisesRegex(RuntimeError,
                                    "sendfile is in progress"):
            cli_proto.transport.write(b'data')
        ret = self.run_loop(t)
        self.assertEqual(ret, len(self.DATA))

    def test_sendfile_no_fallback_for_fallback_transport(self):
        transport = mock.Mock()
        transport.is_closing.side_effect = lambda: False
        transport._sendfile_compatible = constants._SendfileMode.FALLBACK
        with self.assertRaisesRegex(RuntimeError, 'fallback is disabled'):
            self.run_loop(
                self.loop.sendfile(transport, None, fallback=False))
512
513
class SendfileTestsBase(SendfileMixin, SockSendfileMixin):
    """Combine the transport-level and socket-level sendfile test suites.

    Concrete subclasses also inherit test_utils.TestCase and implement
    create_event_loop().
    """
516
517
# Instantiate the shared sendfile tests once per event loop implementation
# available on this platform. Each concrete class only chooses the loop.
if sys.platform == 'win32':

    # Selector-based loop using asyncio's default selector.
    class SelectEventLoopTests(SendfileTestsBase,
                               test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.SelectorEventLoop()

    # Proactor-based loop; only exercised on Windows.
    class ProactorEventLoopTests(SendfileTestsBase,
                                 test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.ProactorEventLoop()

else:
    import selectors

    # Selector classes are platform-dependent, so each test class is only
    # defined when the corresponding selector exists in this build.
    if hasattr(selectors, 'KqueueSelector'):
        class KqueueEventLoopTests(SendfileTestsBase,
                                   test_utils.TestCase):

            def create_event_loop(self):
                return asyncio.SelectorEventLoop(
                    selectors.KqueueSelector())

    if hasattr(selectors, 'EpollSelector'):
        class EPollEventLoopTests(SendfileTestsBase,
                                  test_utils.TestCase):

            def create_event_loop(self):
                return asyncio.SelectorEventLoop(selectors.EpollSelector())

    if hasattr(selectors, 'PollSelector'):
        class PollEventLoopTests(SendfileTestsBase,
                                 test_utils.TestCase):

            def create_event_loop(self):
                return asyncio.SelectorEventLoop(selectors.PollSelector())

    # SelectSelector should always exist, so this class is defined
    # unconditionally.
    class SelectEventLoopTests(SendfileTestsBase,
                               test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.SelectorEventLoop(selectors.SelectSelector())
563