"""Tests for sendfile functionality."""

import asyncio
import os
import socket
import sys
import tempfile
import unittest
from asyncio import base_events
from asyncio import constants
from unittest import mock
from test import support
from test.support import socket_helper
from test.test_asyncio import utils as test_utils

try:
    import ssl
except ImportError:
    ssl = None


def tearDownModule():
    # Restore the default event loop policy once this module is done.
    asyncio.set_event_loop_policy(None)


class MySendfileProto(asyncio.Protocol):
    """Protocol that records received bytes and checks its callback order.

    The state goes INITIAL -> CONNECTED -> (EOF) -> CLOSED; each callback
    asserts it was reached from an allowed state.  When *loop* is given,
    the ``connected`` and ``done`` futures are resolved in
    connection_made() and connection_lost() respectively.  A non-zero
    *close_after* closes the transport once at least that many bytes
    have been received.
    """

    def __init__(self, loop=None, close_after=0):
        self.transport = None
        self.state = 'INITIAL'
        self.nbytes = 0
        # Always define the futures: connection_made()/connection_lost()
        # read them even when no loop was supplied.
        self.connected = None
        self.done = None
        if loop is not None:
            self.connected = loop.create_future()
            self.done = loop.create_future()
        self.data = bytearray()
        self.close_after = close_after

    def connection_made(self, transport):
        self.transport = transport
        assert self.state == 'INITIAL', self.state
        self.state = 'CONNECTED'
        if self.connected is not None:
            self.connected.set_result(None)

    def eof_received(self):
        assert self.state == 'CONNECTED', self.state
        self.state = 'EOF'

    def connection_lost(self, exc):
        assert self.state in ('CONNECTED', 'EOF'), self.state
        self.state = 'CLOSED'
        if self.done is not None:
            self.done.set_result(None)

    def data_received(self, data):
        assert self.state == 'CONNECTED', self.state
        self.nbytes += len(data)
        self.data.extend(data)
        super().data_received(data)
        # Simulate a peer that hangs up after receiving close_after bytes.
        if self.close_after and self.nbytes >= self.close_after:
            self.transport.close()


class MyProto(asyncio.Protocol):
    """Server-side protocol that collects data and signals connection loss."""

    def __init__(self, loop):
        self.started = False
        self.closed = False
        self.data = bytearray()
        self.fut = loop.create_future()
        self.transport = None

    def connection_made(self, transport):
        self.started = True
        self.transport = transport

    def data_received(self, data):
        self.data.extend(data)

    def connection_lost(self, exc):
        self.closed = True
        self.fut.set_result(None)

    async def wait_closed(self):
        """Wait until connection_lost() has been called."""
        await self.fut


class SendfileBase:
    """Shared fixture: DATA written to a file plus a fresh loop per test.

    Subclasses must implement create_event_loop().  The socket buffer
    helpers live here because both the sock_sendfile() and the
    transport sendfile() mixins use them.
    """

    # 128 KiB plus small unaligned to buffer chunk
    DATA = b"SendfileBaseData" * (1024 * 8 + 1)

    # Reduce socket buffer size to test on relative small data sets.
    BUF_SIZE = 4 * 1024  # 4 KiB

    def create_event_loop(self):
        raise NotImplementedError

    @classmethod
    def setUpClass(cls):
        with open(support.TESTFN, 'wb') as fp:
            fp.write(cls.DATA)
        super().setUpClass()

    @classmethod
    def tearDownClass(cls):
        support.unlink(support.TESTFN)
        super().tearDownClass()

    def setUp(self):
        self.file = open(support.TESTFN, 'rb')
        self.addCleanup(self.file.close)
        self.loop = self.create_event_loop()
        self.set_event_loop(self.loop)
        super().setUp()

    def tearDown(self):
        # just in case if we have transport close callbacks
        if not self.loop.is_closed():
            test_utils.run_briefly(self.loop)

        # Run registered cleanups (some of them drive self.loop) before
        # collecting garbage and delegating to the base tearDown().
        self.doCleanups()
        support.gc_collect()
        super().tearDown()

    def reduce_receive_buffer_size(self, sock):
        # Reduce receive socket buffer size to test on relative
        # small data sets.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self.BUF_SIZE)

    def reduce_send_buffer_size(self, sock, transport=None):
        # Reduce send socket buffer size to test on relative small data sets.

        # On macOS, SO_SNDBUF is reset by connect(). So this method
        # should be called after the socket is connected.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.BUF_SIZE)

        if transport is not None:
            transport.set_write_buffer_limits(high=self.BUF_SIZE)

    def run_loop(self, coro):
        """Run *coro* to completion on the test loop and return its result."""
        return self.loop.run_until_complete(coro)


class SockSendfileMixin(SendfileBase):
    """Tests for loop.sock_sendfile() on raw non-blocking sockets."""

    @classmethod
    def setUpClass(cls):
        # Temporarily lower the fallback read-buffer size; the original
        # value is restored in tearDownClass().
        cls.__old_bufsize = constants.SENDFILE_FALLBACK_READBUFFER_SIZE
        constants.SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 16
        super().setUpClass()

    @classmethod
    def tearDownClass(cls):
        constants.SENDFILE_FALLBACK_READBUFFER_SIZE = cls.__old_bufsize
        super().tearDownClass()

    def make_socket(self, cleanup=True):
        """Return a non-blocking IPv4 TCP socket, closed at cleanup if asked."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setblocking(False)
        if cleanup:
            self.addCleanup(sock.close)
        return sock

    def prepare_socksendfile(self):
        """Start a server and return ``(client_sock, server_proto)``.

        The client socket is connected and has a reduced send buffer.
        Closing the server (and its accepted transport) is registered as
        a test cleanup.
        """
        proto = MyProto(self.loop)
        srv_sock = self.make_socket(cleanup=False)
        # Let the OS choose a free port: find_unused_port() followed by a
        # separate bind() races with other processes grabbing the port.
        srv_sock.bind((socket_helper.HOST, 0))
        port = srv_sock.getsockname()[1]
        server = self.run_loop(self.loop.create_server(
            lambda: proto, sock=srv_sock))
        self.reduce_receive_buffer_size(srv_sock)

        sock = self.make_socket()
        self.run_loop(self.loop.sock_connect(sock, ('127.0.0.1', port)))
        self.reduce_send_buffer_size(sock)

        def cleanup():
            if proto.transport is not None:
                # can be None if the task was cancelled before
                # connection_made callback
                proto.transport.close()
                self.run_loop(proto.wait_closed())

            server.close()
            self.run_loop(server.wait_closed())

        self.addCleanup(cleanup)

        return sock, proto

    def test_sock_sendfile_success(self):
        sock, proto = self.prepare_socksendfile()
        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sock_sendfile_with_offset_and_count(self):
        sock, proto = self.prepare_socksendfile()
        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file,
                                                    1000, 2000))
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(proto.data, self.DATA[1000:3000])
        self.assertEqual(self.file.tell(), 3000)
        self.assertEqual(ret, 2000)

    def test_sock_sendfile_zero_size(self):
        sock, proto = self.prepare_socksendfile()
        with tempfile.TemporaryFile() as f:
            ret = self.run_loop(self.loop.sock_sendfile(sock, f,
                                                        0, None))
            # Check the position of the file that was actually sent; it
            # has to be read while the temporary file is still open.
            pos = f.tell()
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(ret, 0)
        self.assertEqual(pos, 0)
        self.assertEqual(proto.data, b'')

    def test_sock_sendfile_mix_with_regular_send(self):
        buf = b"mix_regular_send" * (4 * 1024)  # 64 KiB
        sock, proto = self.prepare_socksendfile()
        self.run_loop(self.loop.sock_sendall(sock, buf))
        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
        self.run_loop(self.loop.sock_sendall(sock, buf))
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(ret, len(self.DATA))
        expected = buf + self.DATA + buf
        self.assertEqual(proto.data, expected)
        self.assertEqual(self.file.tell(), len(self.DATA))


class SendfileMixin(SendfileBase):
    """Tests for loop.sendfile() on stream transports (plain and SSL)."""

    # Note: sendfile via SSL transport is equal to sendfile fallback

    def prepare_sendfile(self, *, is_ssl=False, close_after=0):
        """Connect a client transport to a server; return both protocols.

        Returns ``(srv_proto, cli_proto)``.  With *is_ssl* both ends use
        the test SSL contexts (the test is skipped if ssl is missing).
        *close_after* is forwarded to the server protocol.  Closing both
        transports and the server is registered as a test cleanup.
        """
        srv_proto = MySendfileProto(loop=self.loop,
                                    close_after=close_after)
        if is_ssl:
            if not ssl:
                self.skipTest("No ssl module")
            srv_ctx = test_utils.simple_server_sslcontext()
            cli_ctx = test_utils.simple_client_sslcontext()
        else:
            srv_ctx = None
            cli_ctx = None
        srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Let the OS choose a free port instead of the racy
        # find_unused_port() + bind() sequence.
        srv_sock.bind((socket_helper.HOST, 0))
        port = srv_sock.getsockname()[1]
        server = self.run_loop(self.loop.create_server(
            lambda: srv_proto, sock=srv_sock, ssl=srv_ctx))
        self.reduce_receive_buffer_size(srv_sock)

        if is_ssl:
            server_hostname = socket_helper.HOST
        else:
            server_hostname = None
        cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        cli_sock.connect((socket_helper.HOST, port))

        cli_proto = MySendfileProto(loop=self.loop)
        tr, pr = self.run_loop(self.loop.create_connection(
            lambda: cli_proto, sock=cli_sock,
            ssl=cli_ctx, server_hostname=server_hostname))
        self.reduce_send_buffer_size(cli_sock, transport=tr)

        def cleanup():
            srv_proto.transport.close()
            cli_proto.transport.close()
            self.run_loop(srv_proto.done)
            self.run_loop(cli_proto.done)

            server.close()
            self.run_loop(server.wait_closed())

        self.addCleanup(cleanup)
        return srv_proto, cli_proto

    @unittest.skipIf(sys.platform == 'win32', "UDP sockets are not supported")
    def test_sendfile_not_supported(self):
        tr, pr = self.run_loop(
            self.loop.create_datagram_endpoint(
                asyncio.DatagramProtocol,
                family=socket.AF_INET))
        try:
            with self.assertRaisesRegex(RuntimeError, "not supported"):
                self.run_loop(
                    self.loop.sendfile(tr, self.file))
            self.assertEqual(0, self.file.tell())
        finally:
            # don't use self.addCleanup because it produces resource warning
            tr.close()

    def test_sendfile(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_force_fallback(self):
        srv_proto, cli_proto = self.prepare_sendfile()

        def sendfile_native(transp, file, offset, count):
            # to raise SendfileNotAvailableError
            return base_events.BaseEventLoop._sendfile_native(
                self.loop, transp, file, offset, count)

        self.loop._sendfile_native = sendfile_native

        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_force_unsupported_native(self):
        if sys.platform == 'win32':
            if isinstance(self.loop, asyncio.ProactorEventLoop):
                self.skipTest("Fails on proactor event loop")
        srv_proto, cli_proto = self.prepare_sendfile()

        def sendfile_native(transp, file, offset, count):
            # to raise SendfileNotAvailableError
            return base_events.BaseEventLoop._sendfile_native(
                self.loop, transp, file, offset, count)

        self.loop._sendfile_native = sendfile_native

        with self.assertRaisesRegex(asyncio.SendfileNotAvailableError,
                                    "not supported"):
            self.run_loop(
                self.loop.sendfile(cli_proto.transport, self.file,
                                   fallback=False))

        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(srv_proto.nbytes, 0)
        self.assertEqual(self.file.tell(), 0)

    def test_sendfile_ssl(self):
        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_for_closing_transp(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        cli_proto.transport.close()
        with self.assertRaisesRegex(RuntimeError, "is closing"):
            self.run_loop(self.loop.sendfile(cli_proto.transport, self.file))
        self.run_loop(srv_proto.done)
        self.assertEqual(srv_proto.nbytes, 0)
        self.assertEqual(self.file.tell(), 0)

    def test_sendfile_pre_and_post_data(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        PREFIX = b'PREFIX__' * 1024  # 8 KiB
        SUFFIX = b'--SUFFIX' * 1024  # 8 KiB
        cli_proto.transport.write(PREFIX)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.write(SUFFIX)
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_ssl_pre_and_post_data(self):
        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
        PREFIX = b'zxcvbnm' * 1024
        SUFFIX = b'0987654321' * 1024
        cli_proto.transport.write(PREFIX)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.write(SUFFIX)
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_partial(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, 100)
        self.assertEqual(srv_proto.nbytes, 100)
        self.assertEqual(srv_proto.data, self.DATA[1000:1100])
        self.assertEqual(self.file.tell(), 1100)

    def test_sendfile_ssl_partial(self):
        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, 100)
        self.assertEqual(srv_proto.nbytes, 100)
        self.assertEqual(srv_proto.data, self.DATA[1000:1100])
        self.assertEqual(self.file.tell(), 1100)

    def test_sendfile_close_peer_after_receiving(self):
        srv_proto, cli_proto = self.prepare_sendfile(
            close_after=len(self.DATA))
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_ssl_close_peer_after_receiving(self):
        srv_proto, cli_proto = self.prepare_sendfile(
            is_ssl=True, close_after=len(self.DATA))
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    # On Solaris, lowering SO_RCVBUF on a TCP connection after it has been
    # established has no effect. Due to its age, this bug affects both Oracle
    # Solaris as well as all other OpenSolaris forks (unless they fixed it
    # themselves).
    @unittest.skipIf(sys.platform.startswith('sunos'),
                     "Doesn't work on Solaris")
    def test_sendfile_close_peer_in_the_middle_of_receiving(self):
        srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
        with self.assertRaises(ConnectionError):
            self.run_loop(
                self.loop.sendfile(cli_proto.transport, self.file))
        self.run_loop(srv_proto.done)

        self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
                        srv_proto.nbytes)
        self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
                        self.file.tell())
        self.assertTrue(cli_proto.transport.is_closing())

    def test_sendfile_fallback_close_peer_in_the_middle_of_receiving(self):

        def sendfile_native(transp, file, offset, count):
            # to raise SendfileNotAvailableError
            return base_events.BaseEventLoop._sendfile_native(
                self.loop, transp, file, offset, count)

        self.loop._sendfile_native = sendfile_native

        srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
        with self.assertRaises(ConnectionError):
            self.run_loop(
                self.loop.sendfile(cli_proto.transport, self.file))
        self.run_loop(srv_proto.done)

        self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
                        srv_proto.nbytes)
        self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
                        self.file.tell())

    @unittest.skipIf(not hasattr(os, 'sendfile'),
                     "Don't have native sendfile support")
    def test_sendfile_prevents_bare_write(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        fut = self.loop.create_future()

        async def coro():
            # Signal the test body right before sendfile() starts.
            fut.set_result(None)
            return await self.loop.sendfile(cli_proto.transport, self.file)

        t = self.loop.create_task(coro())
        self.run_loop(fut)
        with self.assertRaisesRegex(RuntimeError,
                                    "sendfile is in progress"):
            cli_proto.transport.write(b'data')
        ret = self.run_loop(t)
        self.assertEqual(ret, len(self.DATA))

    def test_sendfile_no_fallback_for_fallback_transport(self):
        transport = mock.Mock()
        transport.is_closing.side_effect = lambda: False
        transport._sendfile_compatible = constants._SendfileMode.FALLBACK
        with self.assertRaisesRegex(RuntimeError, 'fallback is disabled'):
            self.loop.run_until_complete(
                self.loop.sendfile(transport, None, fallback=False))


class SendfileTestsBase(SendfileMixin, SockSendfileMixin):
    pass


if sys.platform == 'win32':

    class SelectEventLoopTests(SendfileTestsBase,
                               test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.SelectorEventLoop()

    class ProactorEventLoopTests(SendfileTestsBase,
                                 test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.ProactorEventLoop()

else:
    import selectors

    if hasattr(selectors, 'KqueueSelector'):
        class KqueueEventLoopTests(SendfileTestsBase,
                                   test_utils.TestCase):

            def create_event_loop(self):
                return asyncio.SelectorEventLoop(
                    selectors.KqueueSelector())

    if hasattr(selectors, 'EpollSelector'):
        class EPollEventLoopTests(SendfileTestsBase,
                                  test_utils.TestCase):

            def create_event_loop(self):
                return asyncio.SelectorEventLoop(selectors.EpollSelector())

    if hasattr(selectors, 'PollSelector'):
        class PollEventLoopTests(SendfileTestsBase,
                                 test_utils.TestCase):

            def create_event_loop(self):
                return asyncio.SelectorEventLoop(selectors.PollSelector())

    # Should always exist.
    class SelectEventLoopTests(SendfileTestsBase,
                               test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.SelectorEventLoop(selectors.SelectSelector())