• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Tests for sendfile functionality."""
2
3import asyncio
4import os
5import socket
6import sys
7import tempfile
8import unittest
9from asyncio import base_events
10from asyncio import constants
11from unittest import mock
12from test import support
13from test.support import os_helper
14from test.support import socket_helper
15from test.test_asyncio import utils as test_utils
16
17try:
18    import ssl
19except ImportError:
20    ssl = None
21
22
def tearDownModule():
    """Reinstall the default event loop policy once this module is done."""
    # Passing None asks asyncio to restore its default policy.
    default_policy = None
    asyncio.set_event_loop_policy(default_policy)
25
26
class MySendfileProto(asyncio.Protocol):
    """Protocol that records received bytes and checks its lifecycle.

    The state moves INITIAL -> CONNECTED -> (EOF ->) CLOSED.  A callback
    that arrives in an unexpected state raises AssertionError.

    If *loop* is given, ``connected`` and ``done`` are futures that are
    resolved by ``connection_made()`` and ``connection_lost()``
    respectively.  Otherwise both attributes are ``None``.  If
    *close_after* is non-zero, the transport is closed as soon as at least
    that many bytes have been received.
    """

    def __init__(self, loop=None, close_after=0):
        self.transport = None
        self.state = 'INITIAL'
        self.nbytes = 0
        # Always define the futures so the callbacks below can test them
        # even when no loop was supplied (previously an AttributeError).
        self.connected = None
        self.done = None
        if loop is not None:
            self.connected = loop.create_future()
            self.done = loop.create_future()
        self.data = bytearray()
        self.close_after = close_after

    def _assert_state(self, *expected):
        """Raise AssertionError unless the current state is in *expected*."""
        if self.state not in expected:
            raise AssertionError(f'state: {self.state!r}, expected: {expected!r}')

    def connection_made(self, transport):
        self.transport = transport
        self._assert_state('INITIAL')
        self.state = 'CONNECTED'
        if self.connected is not None:
            self.connected.set_result(None)

    def eof_received(self):
        self._assert_state('CONNECTED')
        self.state = 'EOF'

    def connection_lost(self, exc):
        self._assert_state('CONNECTED', 'EOF')
        self.state = 'CLOSED'
        if self.done is not None:
            self.done.set_result(None)

    def data_received(self, data):
        self._assert_state('CONNECTED')
        self.nbytes += len(data)
        self.data.extend(data)
        super().data_received(data)
        # Emulate a peer that hangs up after receiving close_after bytes.
        if self.close_after and self.nbytes >= self.close_after:
            self.transport.close()
67
68
class MyProto(asyncio.Protocol):
    """Minimal server-side protocol used by the sock_sendfile tests.

    It accumulates every received chunk in ``data`` and resolves ``fut``
    when the connection is lost, so tests can ``await wait_closed()``.
    """

    def __init__(self, loop):
        self.transport = None
        self.data = bytearray()
        self.started = self.closed = False
        self.fut = loop.create_future()

    def connection_made(self, transport):
        self.transport, self.started = transport, True

    def data_received(self, data):
        # In-place extend of the same bytearray object.
        self.data += data

    def connection_lost(self, exc):
        self.closed = True
        self.fut.set_result(None)

    async def wait_closed(self):
        """Block until connection_lost() has been called."""
        await self.fut
91
92
class SendfileBase:
    """Shared fixture for the sendfile test mixins.

    ``DATA`` is written to a test file once per class and reopened for every
    test, and each test runs on its own event loop.  Concrete subclasses
    must implement :meth:`create_event_loop`.
    """

    # 128 KiB plus a small tail, so the payload is not aligned to the
    # buffer chunk size.
    DATA = b"SendfileBaseData" * (1024 * 8 + 1)

    # A reduced socket buffer size lets the tests exercise partial writes
    # with a relatively small payload.
    BUF_SIZE = 4 * 1024   # 4 KiB

    def create_event_loop(self):
        """Return a fresh event loop for one test; subclasses must override."""
        raise NotImplementedError

    @classmethod
    def setUpClass(cls):
        with open(os_helper.TESTFN, 'wb') as out:
            out.write(cls.DATA)
        super().setUpClass()

    @classmethod
    def tearDownClass(cls):
        os_helper.unlink(os_helper.TESTFN)
        super().tearDownClass()

    def setUp(self):
        fobj = open(os_helper.TESTFN, 'rb')
        self.file = fobj
        self.addCleanup(fobj.close)
        self.loop = loop = self.create_event_loop()
        self.set_event_loop(loop)
        super().setUp()

    def tearDown(self):
        loop_still_open = not self.loop.is_closed()
        if loop_still_open:
            # Give any pending transport close callbacks a chance to run.
            test_utils.run_briefly(self.loop)

        self.doCleanups()
        support.gc_collect()
        super().tearDown()

    def run_loop(self, coro):
        """Run *coro* to completion on the test loop and return its result."""
        loop = self.loop
        return loop.run_until_complete(coro)
133
134
class SockSendfileMixin(SendfileBase):
    """Tests for ``loop.sock_sendfile()`` on raw non-blocking sockets.

    Also provides the socket-buffer helpers (reduce_receive_buffer_size,
    reduce_send_buffer_size) that the transport-level tests call.
    """

    @classmethod
    def setUpClass(cls):
        # Shrink the fallback read buffer for the whole class; the original
        # value is restored in tearDownClass().
        cls.__old_bufsize = constants.SENDFILE_FALLBACK_READBUFFER_SIZE
        constants.SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 16
        try:
            super().setUpClass()
        except BaseException:
            # unittest does not run tearDownClass() when setUpClass() fails,
            # so restore the module-level constant here instead of leaking
            # the patched value into unrelated tests.
            constants.SENDFILE_FALLBACK_READBUFFER_SIZE = cls.__old_bufsize
            raise

    @classmethod
    def tearDownClass(cls):
        constants.SENDFILE_FALLBACK_READBUFFER_SIZE = cls.__old_bufsize
        super().tearDownClass()

    def make_socket(self, cleanup=True):
        """Return a non-blocking IPv4 TCP socket.

        When *cleanup* is true, the socket is closed automatically at the
        end of the test.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setblocking(False)
        if cleanup:
            self.addCleanup(sock.close)
        return sock

    def reduce_receive_buffer_size(self, sock):
        # Reduce receive socket buffer size to test on relative
        # small data sets.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self.BUF_SIZE)

    def reduce_send_buffer_size(self, sock, transport=None):
        # Reduce send socket buffer size to test on relative small data sets.

        # On macOS, SO_SNDBUF is reset by connect(). So this method
        # should be called after the socket is connected.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.BUF_SIZE)

        if transport is not None:
            transport.set_write_buffer_limits(high=self.BUF_SIZE)

    def prepare_socksendfile(self):
        """Start a MyProto server and return ``(sock, proto)``.

        *sock* is a non-blocking client socket already connected to the
        server.  *proto* is the server-side protocol that records the
        received bytes.  A cleanup closes the server connection and the
        server itself.
        """
        proto = MyProto(self.loop)
        srv_sock = self.make_socket(cleanup=False)
        try:
            # Let the OS assign a free port at bind time instead of probing
            # for one first, which races with other processes.
            port = socket_helper.bind_port(srv_sock)
            server = self.run_loop(self.loop.create_server(
                lambda: proto, sock=srv_sock))
        except BaseException:
            # srv_sock is not registered for cleanup; don't leak it.
            srv_sock.close()
            raise
        self.reduce_receive_buffer_size(srv_sock)

        sock = self.make_socket()
        self.run_loop(self.loop.sock_connect(sock, ('127.0.0.1', port)))
        self.reduce_send_buffer_size(sock)

        def cleanup():
            if proto.transport is not None:
                # can be None if the task was cancelled before
                # connection_made callback
                proto.transport.close()
                self.run_loop(proto.wait_closed())

            server.close()
            self.run_loop(server.wait_closed())

        self.addCleanup(cleanup)

        return sock, proto

    def test_sock_sendfile_success(self):
        sock, proto = self.prepare_socksendfile()
        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sock_sendfile_with_offset_and_count(self):
        sock, proto = self.prepare_socksendfile()
        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file,
                                                    1000, 2000))
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(proto.data, self.DATA[1000:3000])
        self.assertEqual(self.file.tell(), 3000)
        self.assertEqual(ret, 2000)

    def test_sock_sendfile_zero_size(self):
        sock, proto = self.prepare_socksendfile()
        with tempfile.TemporaryFile() as f:
            ret = self.run_loop(self.loop.sock_sendfile(sock, f,
                                                        0, None))
            # The empty temp file is what was sent, so its position is the
            # one that must stay at 0 (checked before the file is closed).
            self.assertEqual(f.tell(), 0)
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(ret, 0)
        self.assertEqual(proto.data, b'')
        self.assertEqual(self.file.tell(), 0)

    def test_sock_sendfile_mix_with_regular_send(self):
        buf = b"mix_regular_send" * (4 * 1024)  # 64 KiB
        sock, proto = self.prepare_socksendfile()
        self.run_loop(self.loop.sock_sendall(sock, buf))
        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
        self.run_loop(self.loop.sock_sendall(sock, buf))
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(ret, len(self.DATA))
        expected = buf + self.DATA + buf
        self.assertEqual(proto.data, expected)
        self.assertEqual(self.file.tell(), len(self.DATA))
242
243
class SendfileMixin(SendfileBase):
    """Tests for ``loop.sendfile()`` on plain and SSL stream transports.

    prepare_sendfile() calls reduce_receive_buffer_size() and
    reduce_send_buffer_size(), which this class does not define.  They are
    expected to come from a sibling mixin (SockSendfileMixin) in the
    concrete test class.
    """

    # Note: sendfile via SSL transport is equal to sendfile fallback

    def _disable_native_sendfile(self):
        """Replace the loop's native sendfile with the BaseEventLoop version.

        The base implementation raises SendfileNotAvailableError.  As a
        result, loop.sendfile() takes its fallback path, or fails when
        fallback=False.
        """
        def sendfile_native(transp, file, offset, count):
            # to raise SendfileNotAvailableError
            return base_events.BaseEventLoop._sendfile_native(
                self.loop, transp, file, offset, count)

        self.loop._sendfile_native = sendfile_native

    def prepare_sendfile(self, *, is_ssl=False, close_after=0):
        """Connect a client transport to a fresh server.

        Returns ``(srv_proto, cli_proto)``, both MySendfileProto instances.
        When *is_ssl* is true, both ends use the test SSL contexts; the test
        is skipped if the ssl module is unavailable.  *close_after* is
        forwarded to the server protocol.  A cleanup closes both transports
        and the server.
        """
        srv_proto = MySendfileProto(loop=self.loop,
                                    close_after=close_after)
        if is_ssl:
            if not ssl:
                self.skipTest("No ssl module")
            srv_ctx = test_utils.simple_server_sslcontext()
            cli_ctx = test_utils.simple_client_sslcontext()
        else:
            srv_ctx = None
            cli_ctx = None
        srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            # Let the OS assign a free port at bind time instead of probing
            # for one first, which races with other processes.
            port = socket_helper.bind_port(srv_sock)
            server = self.run_loop(self.loop.create_server(
                lambda: srv_proto, sock=srv_sock, ssl=srv_ctx))
        except BaseException:
            # Nothing is registered to close srv_sock yet; don't leak it.
            srv_sock.close()
            raise
        self.reduce_receive_buffer_size(srv_sock)

        if is_ssl:
            server_hostname = socket_helper.HOST
        else:
            server_hostname = None
        cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        cli_sock.connect((socket_helper.HOST, port))

        cli_proto = MySendfileProto(loop=self.loop)
        tr, _ = self.run_loop(self.loop.create_connection(
            lambda: cli_proto, sock=cli_sock,
            ssl=cli_ctx, server_hostname=server_hostname))
        # Done after connect(): on macOS, connect() resets SO_SNDBUF.
        self.reduce_send_buffer_size(cli_sock, transport=tr)

        def cleanup():
            srv_proto.transport.close()
            cli_proto.transport.close()
            self.run_loop(srv_proto.done)
            self.run_loop(cli_proto.done)

            server.close()
            self.run_loop(server.wait_closed())

        self.addCleanup(cleanup)
        return srv_proto, cli_proto

    @unittest.skipIf(sys.platform == 'win32', "UDP sockets are not supported")
    def test_sendfile_not_supported(self):
        tr, pr = self.run_loop(
            self.loop.create_datagram_endpoint(
                asyncio.DatagramProtocol,
                family=socket.AF_INET))
        try:
            with self.assertRaisesRegex(RuntimeError, "not supported"):
                self.run_loop(
                    self.loop.sendfile(tr, self.file))
            self.assertEqual(0, self.file.tell())
        finally:
            # don't use self.addCleanup because it produces resource warning
            tr.close()

    def test_sendfile(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_force_fallback(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        self._disable_native_sendfile()

        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_force_unsupported_native(self):
        if sys.platform == 'win32':
            if isinstance(self.loop, asyncio.ProactorEventLoop):
                self.skipTest("Fails on proactor event loop")
        srv_proto, cli_proto = self.prepare_sendfile()
        self._disable_native_sendfile()

        with self.assertRaisesRegex(asyncio.SendfileNotAvailableError,
                                    "not supported"):
            self.run_loop(
                self.loop.sendfile(cli_proto.transport, self.file,
                                   fallback=False))

        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(srv_proto.nbytes, 0)
        self.assertEqual(self.file.tell(), 0)

    def test_sendfile_ssl(self):
        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_for_closing_transp(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        cli_proto.transport.close()
        with self.assertRaisesRegex(RuntimeError, "is closing"):
            self.run_loop(self.loop.sendfile(cli_proto.transport, self.file))
        self.run_loop(srv_proto.done)
        self.assertEqual(srv_proto.nbytes, 0)
        self.assertEqual(self.file.tell(), 0)

    def test_sendfile_pre_and_post_data(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        PREFIX = b'PREFIX__' * 1024  # 8 KiB
        SUFFIX = b'--SUFFIX' * 1024  # 8 KiB
        cli_proto.transport.write(PREFIX)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.write(SUFFIX)
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_ssl_pre_and_post_data(self):
        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
        PREFIX = b'zxcvbnm' * 1024
        SUFFIX = b'0987654321' * 1024
        cli_proto.transport.write(PREFIX)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.write(SUFFIX)
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_partial(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, 100)
        self.assertEqual(srv_proto.nbytes, 100)
        self.assertEqual(srv_proto.data, self.DATA[1000:1100])
        self.assertEqual(self.file.tell(), 1100)

    def test_sendfile_ssl_partial(self):
        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, 100)
        self.assertEqual(srv_proto.nbytes, 100)
        self.assertEqual(srv_proto.data, self.DATA[1000:1100])
        self.assertEqual(self.file.tell(), 1100)

    def test_sendfile_close_peer_after_receiving(self):
        srv_proto, cli_proto = self.prepare_sendfile(
            close_after=len(self.DATA))
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_ssl_close_peer_after_receiving(self):
        srv_proto, cli_proto = self.prepare_sendfile(
            is_ssl=True, close_after=len(self.DATA))
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    # On Solaris, lowering SO_RCVBUF on a TCP connection after it has been
    # established has no effect. Due to its age, this bug affects both Oracle
    # Solaris as well as all other OpenSolaris forks (unless they fixed it
    # themselves).
    @unittest.skipIf(sys.platform.startswith('sunos'),
                     "Doesn't work on Solaris")
    def test_sendfile_close_peer_in_the_middle_of_receiving(self):
        srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
        with self.assertRaises(ConnectionError):
            self.run_loop(
                self.loop.sendfile(cli_proto.transport, self.file))
        self.run_loop(srv_proto.done)

        self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
                        srv_proto.nbytes)
        self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
                        self.file.tell())
        self.assertTrue(cli_proto.transport.is_closing())

    def test_sendfile_fallback_close_peer_in_the_middle_of_receiving(self):
        self._disable_native_sendfile()

        srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
        with self.assertRaises(ConnectionError):
            self.run_loop(
                self.loop.sendfile(cli_proto.transport, self.file))
        self.run_loop(srv_proto.done)

        self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
                        srv_proto.nbytes)
        self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
                        self.file.tell())

    @unittest.skipIf(not hasattr(os, 'sendfile'),
                     "Don't have native sendfile support")
    def test_sendfile_prevents_bare_write(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        fut = self.loop.create_future()

        async def coro():
            fut.set_result(None)
            return await self.loop.sendfile(cli_proto.transport, self.file)

        t = self.loop.create_task(coro())
        self.run_loop(fut)
        with self.assertRaisesRegex(RuntimeError,
                                    "sendfile is in progress"):
            cli_proto.transport.write(b'data')
        ret = self.run_loop(t)
        self.assertEqual(ret, len(self.DATA))

    def test_sendfile_no_fallback_for_fallback_transport(self):
        transport = mock.Mock()
        transport.is_closing.side_effect = lambda: False
        transport._sendfile_compatible = constants._SendfileMode.FALLBACK
        with self.assertRaisesRegex(RuntimeError, 'fallback is disabled'):
            self.loop.run_until_complete(
                self.loop.sendfile(transport, None, fallback=False))
511        transport = mock.Mock()
512        transport.is_closing.side_effect = lambda: False
513        transport._sendfile_compatible = constants._SendfileMode.FALLBACK
514        with self.assertRaisesRegex(RuntimeError, 'fallback is disabled'):
515            self.loop.run_until_complete(
516                self.loop.sendfile(transport, None, fallback=False))
517
518
class SendfileTestsBase(SendfileMixin, SockSendfileMixin):
    """Bundle the transport-level and socket-level sendfile tests.

    Concrete per-event-loop test cases derive from this class together with
    test_utils.TestCase and supply create_event_loop().
    """
521
522
if sys.platform != 'win32':
    import selectors

    if hasattr(selectors, 'KqueueSelector'):
        class KqueueEventLoopTests(SendfileTestsBase,
                                   test_utils.TestCase):
            """Run the sendfile tests on a kqueue-backed selector loop."""

            def create_event_loop(self):
                selector = selectors.KqueueSelector()
                return asyncio.SelectorEventLoop(selector)

    if hasattr(selectors, 'EpollSelector'):
        class EPollEventLoopTests(SendfileTestsBase,
                                  test_utils.TestCase):
            """Run the sendfile tests on an epoll-backed selector loop."""

            def create_event_loop(self):
                selector = selectors.EpollSelector()
                return asyncio.SelectorEventLoop(selector)

    if hasattr(selectors, 'PollSelector'):
        class PollEventLoopTests(SendfileTestsBase,
                                 test_utils.TestCase):
            """Run the sendfile tests on a poll-backed selector loop."""

            def create_event_loop(self):
                selector = selectors.PollSelector()
                return asyncio.SelectorEventLoop(selector)

    # SelectSelector is expected to exist on every platform, so this class
    # is defined unconditionally.
    class SelectEventLoopTests(SendfileTestsBase,
                               test_utils.TestCase):
        """Run the sendfile tests on a select()-backed selector loop."""

        def create_event_loop(self):
            selector = selectors.SelectSelector()
            return asyncio.SelectorEventLoop(selector)

else:

    class SelectEventLoopTests(SendfileTestsBase,
                               test_utils.TestCase):
        """Run the sendfile tests on the default Windows selector loop."""

        def create_event_loop(self):
            return asyncio.SelectorEventLoop()

    class ProactorEventLoopTests(SendfileTestsBase,
                                 test_utils.TestCase):
        """Run the sendfile tests on the Windows proactor loop."""

        def create_event_loop(self):
            return asyncio.ProactorEventLoop()
568