• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Tests for sendfile functionality."""
2
3import asyncio
4import os
5import socket
6import sys
7import tempfile
8import unittest
9from asyncio import base_events
10from asyncio import constants
11from unittest import mock
12from test import support
13from test.test_asyncio import utils as test_utils
14
15try:
16    import ssl
17except ImportError:
18    ssl = None
19
20
def tearDownModule():
    """Reset the global event loop policy after the module's tests ran."""
    default_policy = None  # passing None restores asyncio's default policy
    asyncio.set_event_loop_policy(default_policy)
23
24
class MySendfileProto(asyncio.Protocol):
    """Protocol that records received bytes and checks its lifecycle.

    The state moves INITIAL -> CONNECTED -> (EOF ->) CLOSED, and every
    callback asserts it is invoked in an allowed state.  When *loop* is
    given, the ``connected`` and ``done`` futures are resolved from
    connection_made() and connection_lost() respectively; without a loop
    both attributes are None.  A non-zero *close_after* closes the
    transport once at least that many bytes have been received.
    """

    def __init__(self, loop=None, close_after=0):
        self.transport = None
        self.state = 'INITIAL'
        self.nbytes = 0
        # Always define the futures so the callbacks never fail with
        # AttributeError when the protocol is created without a loop.
        self.connected = None
        self.done = None
        if loop is not None:
            self.connected = loop.create_future()
            self.done = loop.create_future()
        self.data = bytearray()
        self.close_after = close_after

    def connection_made(self, transport):
        self.transport = transport
        assert self.state == 'INITIAL', self.state
        self.state = 'CONNECTED'
        if self.connected is not None:
            self.connected.set_result(None)

    def eof_received(self):
        assert self.state == 'CONNECTED', self.state
        self.state = 'EOF'

    def connection_lost(self, exc):
        assert self.state in ('CONNECTED', 'EOF'), self.state
        self.state = 'CLOSED'
        if self.done is not None:
            self.done.set_result(None)

    def data_received(self, data):
        assert self.state == 'CONNECTED', self.state
        self.nbytes += len(data)
        self.data.extend(data)
        super().data_received(data)
        # Simulate a peer that hangs up after receiving enough data.
        if self.close_after and self.nbytes >= self.close_after:
            self.transport.close()
61
62
class MyProto(asyncio.Protocol):
    """Server-side protocol that accumulates everything it receives.

    ``wait_closed()`` completes once connection_lost() has been called.
    """

    def __init__(self, loop):
        self.transport = None
        self.data = bytearray()
        self.started = self.closed = False
        self.fut = loop.create_future()

    def connection_made(self, transport):
        self.transport, self.started = transport, True

    def data_received(self, data):
        # In-place extension of the same bytearray object.
        self.data += data

    def connection_lost(self, exc):
        self.closed = True
        self.fut.set_result(None)

    async def wait_closed(self):
        """Wait until the connection has been lost."""
        await self.fut
85
86
class SendfileBase:
    """Common fixture shared by the sendfile test mixins.

    Writes ``DATA`` to ``support.TESTFN`` once per class, opens that file
    for reading before each test and installs a fresh event loop obtained
    from ``create_event_loop()``, which concrete test classes override.
    """

    # A 16-byte pattern repeated 8193 times: 128 KiB plus a 16-byte tail,
    # so the payload is not aligned to the buffer sizes the tests use.
    DATA = b"SendfileBaseData" * 8193

    # Small socket buffers let relatively small payloads exercise
    # partial writes and flow control.
    BUF_SIZE = 4096  # 4 KiB

    def create_event_loop(self):
        """Return the event loop under test; must be overridden."""
        raise NotImplementedError

    @classmethod
    def setUpClass(cls):
        with open(support.TESTFN, mode='wb') as out:
            out.write(cls.DATA)
        super().setUpClass()

    @classmethod
    def tearDownClass(cls):
        support.unlink(support.TESTFN)
        super().tearDownClass()

    def setUp(self):
        fobj = open(support.TESTFN, 'rb')
        self.file = fobj
        self.addCleanup(fobj.close)
        loop = self.create_event_loop()
        self.loop = loop
        self.set_event_loop(loop)
        super().setUp()

    def tearDown(self):
        loop = self.loop
        # Let pending callbacks (e.g. transport close callbacks) run.
        if not loop.is_closed():
            test_utils.run_briefly(loop)
        self.doCleanups()
        support.gc_collect()
        super().tearDown()

    def run_loop(self, coro):
        """Run *coro* to completion on ``self.loop`` and return its result."""
        loop = self.loop
        return loop.run_until_complete(coro)
127
128
class SockSendfileMixin(SendfileBase):
    """Tests for ``loop.sock_sendfile()`` on connected TCP sockets."""

    @classmethod
    def setUpClass(cls):
        # Shrink the fallback read buffer so that the fallback path needs
        # several read/send rounds to transfer DATA.
        cls.__old_bufsize = constants.SENDFILE_FALLBACK_READBUFFER_SIZE
        constants.SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 16
        try:
            super().setUpClass()
        except BaseException:
            # unittest skips tearDownClass() when setUpClass() fails, so
            # undo the patch here rather than leak it into other tests.
            constants.SENDFILE_FALLBACK_READBUFFER_SIZE = cls.__old_bufsize
            raise

    @classmethod
    def tearDownClass(cls):
        constants.SENDFILE_FALLBACK_READBUFFER_SIZE = cls.__old_bufsize
        super().tearDownClass()

    def make_socket(self, cleanup=True):
        """Return a new non-blocking IPv4 TCP socket.

        The socket is closed at test cleanup unless *cleanup* is false
        (used when ownership is handed over to a server).
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setblocking(False)
        if cleanup:
            self.addCleanup(sock.close)
        return sock

    def reduce_receive_buffer_size(self, sock):
        # Reduce receive socket buffer size to test on relatively
        # small data sets.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self.BUF_SIZE)

    def reduce_send_buffer_size(self, sock, transport=None):
        # Reduce send socket buffer size to test on relatively small data
        # sets.

        # On macOS, SO_SNDBUF is reset by connect(). So this method
        # should be called after the socket is connected.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.BUF_SIZE)

        if transport is not None:
            transport.set_write_buffer_limits(high=self.BUF_SIZE)

    def prepare_socksendfile(self):
        """Start a server and return ``(client_sock, server_proto)``.

        The client socket is connected to the server and has a reduced
        send buffer.  The accepted transport, the server and the client
        socket are all closed at test cleanup.
        """
        proto = MyProto(self.loop)
        srv_sock = self.make_socket(cleanup=False)
        try:
            # Bind to port 0 so the OS picks a free port atomically;
            # probing with find_unused_port() and binding afterwards races
            # with other processes allocating ports.
            srv_sock.bind((support.HOST, 0))
            srv_addr = srv_sock.getsockname()
            server = self.run_loop(self.loop.create_server(
                lambda: proto, sock=srv_sock))
        except BaseException:
            # No server owns the socket yet; close it so it does not leak.
            srv_sock.close()
            raise
        self.reduce_receive_buffer_size(srv_sock)

        sock = self.make_socket()
        # Connect to the address that was actually bound.
        self.run_loop(self.loop.sock_connect(sock, srv_addr))
        self.reduce_send_buffer_size(sock)

        def cleanup():
            if proto.transport is not None:
                # can be None if the task was cancelled before
                # connection_made callback
                proto.transport.close()
                self.run_loop(proto.wait_closed())

            server.close()
            self.run_loop(server.wait_closed())

        self.addCleanup(cleanup)

        return sock, proto

    def test_sock_sendfile_success(self):
        sock, proto = self.prepare_socksendfile()
        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sock_sendfile_with_offset_and_count(self):
        sock, proto = self.prepare_socksendfile()
        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file,
                                                    1000, 2000))
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(proto.data, self.DATA[1000:3000])
        self.assertEqual(self.file.tell(), 3000)
        self.assertEqual(ret, 2000)

    def test_sock_sendfile_zero_size(self):
        sock, proto = self.prepare_socksendfile()
        with tempfile.TemporaryFile() as f:
            ret = self.run_loop(self.loop.sock_sendfile(sock, f,
                                                        0, None))
            # Check the position of the file that was actually sent; it
            # must be inspected while it is still open.
            self.assertEqual(f.tell(), 0)
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(ret, 0)
        self.assertEqual(proto.data, b'')
        self.assertEqual(self.file.tell(), 0)

    def test_sock_sendfile_mix_with_regular_send(self):
        buf = b"mix_regular_send" * (4 * 1024)  # 64 KiB
        sock, proto = self.prepare_socksendfile()
        self.run_loop(self.loop.sock_sendall(sock, buf))
        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
        self.run_loop(self.loop.sock_sendall(sock, buf))
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(ret, len(self.DATA))
        expected = buf + self.DATA + buf
        self.assertEqual(proto.data, expected)
        self.assertEqual(self.file.tell(), len(self.DATA))
236
237
class SendfileMixin(SendfileBase):
    """Tests for ``loop.sendfile()`` on stream transports (plain and SSL).

    Note: sendfile via an SSL transport is equivalent to the sendfile
    fallback.
    """

    def _disable_native_sendfile(self):
        """Make ``self.loop`` report native sendfile as unavailable.

        The override delegates to ``BaseEventLoop._sendfile_native``, which
        raises SendfileNotAvailableError, so ``loop.sendfile()`` either
        uses its fallback or, with ``fallback=False``, propagates the error.
        """
        def sendfile_native(transp, file, offset, count):
            # to raise SendfileNotAvailableError
            return base_events.BaseEventLoop._sendfile_native(
                self.loop, transp, file, offset, count)

        self.loop._sendfile_native = sendfile_native

    def prepare_sendfile(self, *, is_ssl=False, close_after=0):
        """Connect a client transport to a server; return both protocols.

        Returns ``(srv_proto, cli_proto)``.  A non-zero *close_after* makes
        the server close its side after receiving that many bytes.  With
        *is_ssl* the connection uses the test SSL contexts (the test is
        skipped if the ssl module is missing).  Both transports and the
        server are closed at test cleanup.
        """
        srv_proto = MySendfileProto(loop=self.loop,
                                    close_after=close_after)
        if is_ssl:
            if not ssl:
                self.skipTest("No ssl module")
            srv_ctx = test_utils.simple_server_sslcontext()
            cli_ctx = test_utils.simple_client_sslcontext()
            server_hostname = support.HOST
        else:
            srv_ctx = None
            cli_ctx = None
            server_hostname = None

        srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            # Bind to port 0 so the OS picks a free port atomically instead
            # of racing other processes after find_unused_port().
            srv_sock.bind((support.HOST, 0))
            srv_addr = srv_sock.getsockname()
            server = self.run_loop(self.loop.create_server(
                lambda: srv_proto, sock=srv_sock, ssl=srv_ctx))
        except BaseException:
            # No server owns the socket yet; close it so it does not leak.
            srv_sock.close()
            raise
        self.reduce_receive_buffer_size(srv_sock)

        cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            cli_sock.connect(srv_addr)
            cli_proto = MySendfileProto(loop=self.loop)
            tr, _ = self.run_loop(self.loop.create_connection(
                lambda: cli_proto, sock=cli_sock,
                ssl=cli_ctx, server_hostname=server_hostname))
        except BaseException:
            # Setup failed before cleanup() was registered.
            cli_sock.close()
            server.close()
            raise
        self.reduce_send_buffer_size(cli_sock, transport=tr)

        def cleanup():
            srv_proto.transport.close()
            cli_proto.transport.close()
            self.run_loop(srv_proto.done)
            self.run_loop(cli_proto.done)

            server.close()
            self.run_loop(server.wait_closed())

        self.addCleanup(cleanup)
        return srv_proto, cli_proto

    @unittest.skipIf(sys.platform == 'win32', "UDP sockets are not supported")
    def test_sendfile_not_supported(self):
        tr, pr = self.run_loop(
            self.loop.create_datagram_endpoint(
                asyncio.DatagramProtocol,
                family=socket.AF_INET))
        try:
            with self.assertRaisesRegex(RuntimeError, "not supported"):
                self.run_loop(
                    self.loop.sendfile(tr, self.file))
            self.assertEqual(0, self.file.tell())
        finally:
            # don't use self.addCleanup because it produces resource warning
            tr.close()

    def test_sendfile(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_force_fallback(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        self._disable_native_sendfile()

        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_force_unsupported_native(self):
        if sys.platform == 'win32':
            if isinstance(self.loop, asyncio.ProactorEventLoop):
                self.skipTest("Fails on proactor event loop")
        srv_proto, cli_proto = self.prepare_sendfile()
        self._disable_native_sendfile()

        with self.assertRaisesRegex(asyncio.SendfileNotAvailableError,
                                    "not supported"):
            self.run_loop(
                self.loop.sendfile(cli_proto.transport, self.file,
                                   fallback=False))

        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(srv_proto.nbytes, 0)
        self.assertEqual(self.file.tell(), 0)

    def test_sendfile_ssl(self):
        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_for_closing_transp(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        cli_proto.transport.close()
        with self.assertRaisesRegex(RuntimeError, "is closing"):
            self.run_loop(self.loop.sendfile(cli_proto.transport, self.file))
        self.run_loop(srv_proto.done)
        self.assertEqual(srv_proto.nbytes, 0)
        self.assertEqual(self.file.tell(), 0)

    def test_sendfile_pre_and_post_data(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        PREFIX = b'PREFIX__' * 1024  # 8 KiB
        SUFFIX = b'--SUFFIX' * 1024  # 8 KiB
        cli_proto.transport.write(PREFIX)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.write(SUFFIX)
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_ssl_pre_and_post_data(self):
        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
        PREFIX = b'zxcvbnm' * 1024
        SUFFIX = b'0987654321' * 1024
        cli_proto.transport.write(PREFIX)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.write(SUFFIX)
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_partial(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, 100)
        self.assertEqual(srv_proto.nbytes, 100)
        self.assertEqual(srv_proto.data, self.DATA[1000:1100])
        self.assertEqual(self.file.tell(), 1100)

    def test_sendfile_ssl_partial(self):
        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, 100)
        self.assertEqual(srv_proto.nbytes, 100)
        self.assertEqual(srv_proto.data, self.DATA[1000:1100])
        self.assertEqual(self.file.tell(), 1100)

    def test_sendfile_close_peer_after_receiving(self):
        srv_proto, cli_proto = self.prepare_sendfile(
            close_after=len(self.DATA))
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_ssl_close_peer_after_receiving(self):
        srv_proto, cli_proto = self.prepare_sendfile(
            is_ssl=True, close_after=len(self.DATA))
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_close_peer_in_the_middle_of_receiving(self):
        srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
        with self.assertRaises(ConnectionError):
            self.run_loop(
                self.loop.sendfile(cli_proto.transport, self.file))
        self.run_loop(srv_proto.done)

        self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
                        srv_proto.nbytes)
        self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
                        self.file.tell())
        self.assertTrue(cli_proto.transport.is_closing())

    def test_sendfile_fallback_close_peer_in_the_middle_of_receiving(self):
        self._disable_native_sendfile()

        srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
        with self.assertRaises(ConnectionError):
            self.run_loop(
                self.loop.sendfile(cli_proto.transport, self.file))
        self.run_loop(srv_proto.done)

        self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
                        srv_proto.nbytes)
        self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
                        self.file.tell())

    @unittest.skipIf(not hasattr(os, 'sendfile'),
                     "Don't have native sendfile support")
    def test_sendfile_prevents_bare_write(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        fut = self.loop.create_future()

        async def coro():
            fut.set_result(None)
            return await self.loop.sendfile(cli_proto.transport, self.file)

        t = self.loop.create_task(coro())
        self.run_loop(fut)
        with self.assertRaisesRegex(RuntimeError,
                                    "sendfile is in progress"):
            cli_proto.transport.write(b'data')
        ret = self.run_loop(t)
        self.assertEqual(ret, len(self.DATA))

    def test_sendfile_no_fallback_for_fallback_transport(self):
        transport = mock.Mock()
        transport.is_closing.side_effect = lambda: False
        transport._sendfile_compatible = constants._SendfileMode.FALLBACK
        with self.assertRaisesRegex(RuntimeError, 'fallback is disabled'):
            self.run_loop(
                self.loop.sendfile(transport, None, fallback=False))
505
506
class SendfileTestsBase(SendfileMixin, SockSendfileMixin):
    """Bundle the transport-level and socket-level sendfile test mixins.

    Concrete test classes add ``test_utils.TestCase`` as a base and
    override ``create_event_loop()``.
    """
509
510
if sys.platform == 'win32':

    class SelectEventLoopTests(SendfileTestsBase,
                               test_utils.TestCase):
        """Sendfile tests on the selector event loop (default selector)."""

        def create_event_loop(self):
            loop = asyncio.SelectorEventLoop()
            return loop

    class ProactorEventLoopTests(SendfileTestsBase,
                                 test_utils.TestCase):
        """Sendfile tests on the proactor event loop."""

        def create_event_loop(self):
            loop = asyncio.ProactorEventLoop()
            return loop

else:
    import selectors

    # Only define a test class for each selector this platform provides.
    if hasattr(selectors, 'KqueueSelector'):
        class KqueueEventLoopTests(SendfileTestsBase,
                                   test_utils.TestCase):
            """Sendfile tests on a kqueue-based selector event loop."""

            def create_event_loop(self):
                selector = selectors.KqueueSelector()
                return asyncio.SelectorEventLoop(selector)

    if hasattr(selectors, 'EpollSelector'):
        class EPollEventLoopTests(SendfileTestsBase,
                                  test_utils.TestCase):
            """Sendfile tests on an epoll-based selector event loop."""

            def create_event_loop(self):
                selector = selectors.EpollSelector()
                return asyncio.SelectorEventLoop(selector)

    if hasattr(selectors, 'PollSelector'):
        class PollEventLoopTests(SendfileTestsBase,
                                 test_utils.TestCase):
            """Sendfile tests on a poll-based selector event loop."""

            def create_event_loop(self):
                selector = selectors.PollSelector()
                return asyncio.SelectorEventLoop(selector)

    # select() is expected to be available everywhere, so this class is
    # defined unconditionally.
    class SelectEventLoopTests(SendfileTestsBase,
                               test_utils.TestCase):
        """Sendfile tests on a select-based selector event loop."""

        def create_event_loop(self):
            selector = selectors.SelectSelector()
            return asyncio.SelectorEventLoop(selector)
556