"""Tests for sendfile functionality."""

import asyncio
import os
import socket
import sys
import tempfile
import unittest
from asyncio import base_events
from asyncio import constants
from unittest import mock
from test import support
from test.support import os_helper
from test.support import socket_helper
from test.test_asyncio import utils as test_utils

try:
    import ssl
except ImportError:
    ssl = None


def tearDownModule():
    # Restore the default event loop policy once this module is done.
    asyncio.set_event_loop_policy(None)


class MySendfileProto(asyncio.Protocol):
    """Protocol that accumulates received bytes and checks its call order.

    The state advances INITIAL -> CONNECTED -> (EOF ->) CLOSED; a callback
    arriving in an unexpected state raises AssertionError.

    If *loop* is given, ``connected`` and ``done`` are futures resolved by
    ``connection_made`` and ``connection_lost`` respectively; otherwise both
    are None and nothing is signalled.  If *close_after* is non-zero, the
    transport is closed once at least that many bytes have been received.
    """

    def __init__(self, loop=None, close_after=0):
        self.transport = None
        self.state = 'INITIAL'
        self.nbytes = 0
        # Always define the futures: without a loop they used to be missing
        # entirely, so connection_made/connection_lost raised AttributeError.
        self.connected = None
        self.done = None
        if loop is not None:
            self.connected = loop.create_future()
            self.done = loop.create_future()
        self.data = bytearray()
        self.close_after = close_after

    def _assert_state(self, *expected):
        if self.state not in expected:
            raise AssertionError(f'state: {self.state!r}, expected: {expected!r}')

    def connection_made(self, transport):
        self.transport = transport
        self._assert_state('INITIAL')
        self.state = 'CONNECTED'
        if self.connected is not None:
            self.connected.set_result(None)

    def eof_received(self):
        self._assert_state('CONNECTED')
        self.state = 'EOF'

    def connection_lost(self, exc):
        self._assert_state('CONNECTED', 'EOF')
        self.state = 'CLOSED'
        if self.done is not None:
            self.done.set_result(None)

    def data_received(self, data):
        self._assert_state('CONNECTED')
        self.nbytes += len(data)
        self.data.extend(data)
        super().data_received(data)
        # Simulate a peer that hangs up after receiving close_after bytes.
        if self.close_after and self.nbytes >= self.close_after:
            self.transport.close()


class MyProto(asyncio.Protocol):
    """Server-side protocol that collects every received byte.

    ``wait_closed()`` completes once ``connection_lost`` has been called.
    """

    def __init__(self, loop):
        self.started = False
        self.closed = False
        self.data = bytearray()
        self.fut = loop.create_future()
        self.transport = None

    def connection_made(self, transport):
        self.started = True
        self.transport = transport

    def data_received(self, data):
        self.data.extend(data)

    def connection_lost(self, exc):
        self.closed = True
        self.fut.set_result(None)

    async def wait_closed(self):
        await self.fut


class SendfileBase:
    """Shared fixture: an on-disk file holding DATA plus a fresh event loop.

    Subclasses must implement ``create_event_loop()``.  ``set_event_loop``
    is expected to come from the ``test_utils.TestCase`` base that the
    concrete test classes at the bottom of this module mix in.
    """

    # 16 bytes * (8 * 1024 + 1): 128 KiB plus one extra 16-byte chunk, so
    # the size is not a multiple of the buffer sizes used by the tests.
    DATA = b"SendfileBaseData" * (1024 * 8 + 1)

    # Reduce socket buffer size to test on relative small data sets.
    BUF_SIZE = 4 * 1024  # 4 KiB

    def create_event_loop(self):
        raise NotImplementedError

    @classmethod
    def setUpClass(cls):
        with open(os_helper.TESTFN, 'wb') as fp:
            fp.write(cls.DATA)
        super().setUpClass()

    @classmethod
    def tearDownClass(cls):
        os_helper.unlink(os_helper.TESTFN)
        super().tearDownClass()

    def setUp(self):
        self.file = open(os_helper.TESTFN, 'rb')
        self.addCleanup(self.file.close)
        self.loop = self.create_event_loop()
        self.set_event_loop(self.loop)
        super().setUp()

    def tearDown(self):
        # just in case if we have transport close callbacks
        if not self.loop.is_closed():
            test_utils.run_briefly(self.loop)

        self.doCleanups()
        support.gc_collect()
        super().tearDown()

    def run_loop(self, coro):
        """Run *coro* to completion on the test loop and return its result."""
        return self.loop.run_until_complete(coro)


class SockSendfileMixin(SendfileBase):
    """Tests for ``loop.sock_sendfile()`` on plain non-blocking sockets."""

    @classmethod
    def setUpClass(cls):
        # Shrink the fallback read buffer so the fallback path loops
        # several times over DATA; restored in tearDownClass.
        cls.__old_bufsize = constants.SENDFILE_FALLBACK_READBUFFER_SIZE
        constants.SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 16
        super().setUpClass()

    @classmethod
    def tearDownClass(cls):
        constants.SENDFILE_FALLBACK_READBUFFER_SIZE = cls.__old_bufsize
        super().tearDownClass()

    def make_socket(self, cleanup=True):
        """Return a non-blocking TCP socket, closed at cleanup if asked."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setblocking(False)
        if cleanup:
            self.addCleanup(sock.close)
        return sock

    def reduce_receive_buffer_size(self, sock):
        # Reduce receive socket buffer size to test on relative
        # small data sets.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self.BUF_SIZE)

    def reduce_send_buffer_size(self, sock, transport=None):
        # Reduce send socket buffer size to test on relative small data sets.

        # On macOS, SO_SNDBUF is reset by connect(). So this method
        # should be called after the socket is connected.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.BUF_SIZE)

        if transport is not None:
            transport.set_write_buffer_limits(high=self.BUF_SIZE)

    def prepare_socksendfile(self):
        """Start a server and return ``(client_sock, server_proto)``.

        The client socket is already connected; server and connection
        teardown is registered with addCleanup.
        """
        proto = MyProto(self.loop)
        port = socket_helper.find_unused_port()
        # Not registered for cleanup here: it is handed to create_server,
        # and the server is closed by cleanup() below.
        srv_sock = self.make_socket(cleanup=False)
        srv_sock.bind((socket_helper.HOST, port))
        server = self.run_loop(self.loop.create_server(
            lambda: proto, sock=srv_sock))
        self.reduce_receive_buffer_size(srv_sock)

        sock = self.make_socket()
        self.run_loop(self.loop.sock_connect(sock, ('127.0.0.1', port)))
        self.reduce_send_buffer_size(sock)

        def cleanup():
            if proto.transport is not None:
                # can be None if the task was cancelled before
                # connection_made callback
                proto.transport.close()
                self.run_loop(proto.wait_closed())

            server.close()
            self.run_loop(server.wait_closed())

        self.addCleanup(cleanup)

        return sock, proto

    def test_sock_sendfile_success(self):
        sock, proto = self.prepare_socksendfile()
        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sock_sendfile_with_offset_and_count(self):
        sock, proto = self.prepare_socksendfile()
        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file,
                                                    1000, 2000))
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(proto.data, self.DATA[1000:3000])
        self.assertEqual(self.file.tell(), 3000)
        self.assertEqual(ret, 2000)

    def test_sock_sendfile_zero_size(self):
        sock, proto = self.prepare_socksendfile()
        with tempfile.TemporaryFile() as f:
            ret = self.run_loop(self.loop.sock_sendfile(sock, f,
                                                        0, None))
            # The empty temporary file is what was sent, so its position
            # is the one that must stay at 0; check it before it closes.
            self.assertEqual(f.tell(), 0)
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(ret, 0)
        # The fixture file was not involved and must be untouched.
        self.assertEqual(self.file.tell(), 0)

    def test_sock_sendfile_mix_with_regular_send(self):
        buf = b"mix_regular_send" * (4 * 1024)  # 64 KiB
        sock, proto = self.prepare_socksendfile()
        self.run_loop(self.loop.sock_sendall(sock, buf))
        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
        self.run_loop(self.loop.sock_sendall(sock, buf))
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(ret, len(self.DATA))
        expected = buf + self.DATA + buf
        self.assertEqual(proto.data, expected)
        self.assertEqual(self.file.tell(), len(self.DATA))


class SendfileMixin(SendfileBase):
    """Tests for ``loop.sendfile()`` on connected transports."""

    # Note: sendfile via SSL transport is equal to sendfile fallback

    def prepare_sendfile(self, *, is_ssl=False, close_after=0):
        """Connect a client to a fresh server; return ``(srv, cli)`` protos.

        With *is_ssl* both ends use the test SSL contexts (the test is
        skipped when ``ssl`` is unavailable).  *close_after* is forwarded to
        the server protocol.  Teardown is registered with addCleanup.
        """
        port = socket_helper.find_unused_port()
        srv_proto = MySendfileProto(loop=self.loop,
                                    close_after=close_after)
        if is_ssl:
            if not ssl:
                self.skipTest("No ssl module")
            srv_ctx = test_utils.simple_server_sslcontext()
            cli_ctx = test_utils.simple_client_sslcontext()
        else:
            srv_ctx = None
            cli_ctx = None
        srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        srv_sock.bind((socket_helper.HOST, port))
        server = self.run_loop(self.loop.create_server(
            lambda: srv_proto, sock=srv_sock, ssl=srv_ctx))
        self.reduce_receive_buffer_size(srv_sock)

        if is_ssl:
            server_hostname = socket_helper.HOST
        else:
            server_hostname = None
        cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        cli_sock.connect((socket_helper.HOST, port))

        cli_proto = MySendfileProto(loop=self.loop)
        tr, pr = self.run_loop(self.loop.create_connection(
            lambda: cli_proto, sock=cli_sock,
            ssl=cli_ctx, server_hostname=server_hostname))
        self.reduce_send_buffer_size(cli_sock, transport=tr)

        def cleanup():
            srv_proto.transport.close()
            cli_proto.transport.close()
            self.run_loop(srv_proto.done)
            self.run_loop(cli_proto.done)

            server.close()
            self.run_loop(server.wait_closed())

        self.addCleanup(cleanup)
        return srv_proto, cli_proto

    @unittest.skipIf(sys.platform == 'win32', "UDP sockets are not supported")
    def test_sendfile_not_supported(self):
        tr, pr = self.run_loop(
            self.loop.create_datagram_endpoint(
                asyncio.DatagramProtocol,
                family=socket.AF_INET))
        try:
            with self.assertRaisesRegex(RuntimeError, "not supported"):
                self.run_loop(
                    self.loop.sendfile(tr, self.file))
            self.assertEqual(0, self.file.tell())
        finally:
            # don't use self.addCleanup because it produces resource warning
            tr.close()

    def test_sendfile(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_force_fallback(self):
        srv_proto, cli_proto = self.prepare_sendfile()

        def sendfile_native(transp, file, offset, count):
            # to raise SendfileNotAvailableError
            return base_events.BaseEventLoop._sendfile_native(
                self.loop, transp, file, offset, count)

        self.loop._sendfile_native = sendfile_native

        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_force_unsupported_native(self):
        if sys.platform == 'win32':
            if isinstance(self.loop, asyncio.ProactorEventLoop):
                self.skipTest("Fails on proactor event loop")
        srv_proto, cli_proto = self.prepare_sendfile()

        def sendfile_native(transp, file, offset, count):
            # to raise SendfileNotAvailableError
            return base_events.BaseEventLoop._sendfile_native(
                self.loop, transp, file, offset, count)

        self.loop._sendfile_native = sendfile_native

        with self.assertRaisesRegex(asyncio.SendfileNotAvailableError,
                                    "not supported"):
            self.run_loop(
                self.loop.sendfile(cli_proto.transport, self.file,
                                   fallback=False))

        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(srv_proto.nbytes, 0)
        self.assertEqual(self.file.tell(), 0)

    def test_sendfile_ssl(self):
        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_for_closing_transp(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        cli_proto.transport.close()
        with self.assertRaisesRegex(RuntimeError, "is closing"):
            self.run_loop(self.loop.sendfile(cli_proto.transport, self.file))
        self.run_loop(srv_proto.done)
        self.assertEqual(srv_proto.nbytes, 0)
        self.assertEqual(self.file.tell(), 0)

    def test_sendfile_pre_and_post_data(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        PREFIX = b'PREFIX__' * 1024  # 8 KiB
        SUFFIX = b'--SUFFIX' * 1024  # 8 KiB
        cli_proto.transport.write(PREFIX)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.write(SUFFIX)
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_ssl_pre_and_post_data(self):
        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
        PREFIX = b'zxcvbnm' * 1024
        SUFFIX = b'0987654321' * 1024
        cli_proto.transport.write(PREFIX)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.write(SUFFIX)
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_partial(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, 100)
        self.assertEqual(srv_proto.nbytes, 100)
        self.assertEqual(srv_proto.data, self.DATA[1000:1100])
        self.assertEqual(self.file.tell(), 1100)

    def test_sendfile_ssl_partial(self):
        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, 100)
        self.assertEqual(srv_proto.nbytes, 100)
        self.assertEqual(srv_proto.data, self.DATA[1000:1100])
        self.assertEqual(self.file.tell(), 1100)

    def test_sendfile_close_peer_after_receiving(self):
        srv_proto, cli_proto = self.prepare_sendfile(
            close_after=len(self.DATA))
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_ssl_close_peer_after_receiving(self):
        srv_proto, cli_proto = self.prepare_sendfile(
            is_ssl=True, close_after=len(self.DATA))
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    # On Solaris, lowering SO_RCVBUF on a TCP connection after it has been
    # established has no effect. Due to its age, this bug affects both Oracle
    # Solaris as well as all other OpenSolaris forks (unless they fixed it
    # themselves).
    @unittest.skipIf(sys.platform.startswith('sunos'),
                     "Doesn't work on Solaris")
    def test_sendfile_close_peer_in_the_middle_of_receiving(self):
        srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
        with self.assertRaises(ConnectionError):
            self.run_loop(
                self.loop.sendfile(cli_proto.transport, self.file))
        self.run_loop(srv_proto.done)

        self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
                        srv_proto.nbytes)
        self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
                        self.file.tell())
        self.assertTrue(cli_proto.transport.is_closing())

    def test_sendfile_fallback_close_peer_in_the_middle_of_receiving(self):

        def sendfile_native(transp, file, offset, count):
            # to raise SendfileNotAvailableError
            return base_events.BaseEventLoop._sendfile_native(
                self.loop, transp, file, offset, count)

        self.loop._sendfile_native = sendfile_native

        srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
        with self.assertRaises(ConnectionError):
            self.run_loop(
                self.loop.sendfile(cli_proto.transport, self.file))
        self.run_loop(srv_proto.done)

        self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
                        srv_proto.nbytes)
        self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
                        self.file.tell())

    @unittest.skipIf(not hasattr(os, 'sendfile'),
                     "Don't have native sendfile support")
    def test_sendfile_prevents_bare_write(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        fut = self.loop.create_future()

        async def coro():
            fut.set_result(None)
            return await self.loop.sendfile(cli_proto.transport, self.file)

        t = self.loop.create_task(coro())
        # Run until the task has started sendfile, then try a plain write.
        self.run_loop(fut)
        with self.assertRaisesRegex(RuntimeError,
                                    "sendfile is in progress"):
            cli_proto.transport.write(b'data')
        ret = self.run_loop(t)
        self.assertEqual(ret, len(self.DATA))

    def test_sendfile_no_fallback_for_fallback_transport(self):
        transport = mock.Mock()
        transport.is_closing.side_effect = lambda: False
        transport._sendfile_compatible = constants._SendfileMode.FALLBACK
        with self.assertRaisesRegex(RuntimeError, 'fallback is disabled'):
            self.loop.run_until_complete(
                self.loop.sendfile(transport, None, fallback=False))


class SendfileTestsBase(SendfileMixin, SockSendfileMixin):
    pass


if sys.platform == 'win32':

    class SelectEventLoopTests(SendfileTestsBase,
                               test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.SelectorEventLoop()

    class ProactorEventLoopTests(SendfileTestsBase,
                                 test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.ProactorEventLoop()

else:
    import selectors

    if hasattr(selectors, 'KqueueSelector'):
        class KqueueEventLoopTests(SendfileTestsBase,
                                   test_utils.TestCase):

            def create_event_loop(self):
                return asyncio.SelectorEventLoop(
                    selectors.KqueueSelector())

    if hasattr(selectors, 'EpollSelector'):
        class EPollEventLoopTests(SendfileTestsBase,
                                  test_utils.TestCase):

            def create_event_loop(self):
                return asyncio.SelectorEventLoop(selectors.EpollSelector())

    if hasattr(selectors, 'PollSelector'):
        class PollEventLoopTests(SendfileTestsBase,
                                 test_utils.TestCase):

            def create_event_loop(self):
                return asyncio.SelectorEventLoop(selectors.PollSelector())

    # Should always exist.
    class SelectEventLoopTests(SendfileTestsBase,
                               test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.SelectorEventLoop(selectors.SelectSelector())