1"""Tests for sendfile functionality.""" 2 3import asyncio 4import os 5import socket 6import sys 7import tempfile 8import unittest 9from asyncio import base_events 10from asyncio import constants 11from unittest import mock 12from test import support 13from test.test_asyncio import utils as test_utils 14 15try: 16 import ssl 17except ImportError: 18 ssl = None 19 20 21def tearDownModule(): 22 asyncio.set_event_loop_policy(None) 23 24 25class MySendfileProto(asyncio.Protocol): 26 27 def __init__(self, loop=None, close_after=0): 28 self.transport = None 29 self.state = 'INITIAL' 30 self.nbytes = 0 31 if loop is not None: 32 self.connected = loop.create_future() 33 self.done = loop.create_future() 34 self.data = bytearray() 35 self.close_after = close_after 36 37 def connection_made(self, transport): 38 self.transport = transport 39 assert self.state == 'INITIAL', self.state 40 self.state = 'CONNECTED' 41 if self.connected: 42 self.connected.set_result(None) 43 44 def eof_received(self): 45 assert self.state == 'CONNECTED', self.state 46 self.state = 'EOF' 47 48 def connection_lost(self, exc): 49 assert self.state in ('CONNECTED', 'EOF'), self.state 50 self.state = 'CLOSED' 51 if self.done: 52 self.done.set_result(None) 53 54 def data_received(self, data): 55 assert self.state == 'CONNECTED', self.state 56 self.nbytes += len(data) 57 self.data.extend(data) 58 super().data_received(data) 59 if self.close_after and self.nbytes >= self.close_after: 60 self.transport.close() 61 62 63class MyProto(asyncio.Protocol): 64 65 def __init__(self, loop): 66 self.started = False 67 self.closed = False 68 self.data = bytearray() 69 self.fut = loop.create_future() 70 self.transport = None 71 72 def connection_made(self, transport): 73 self.started = True 74 self.transport = transport 75 76 def data_received(self, data): 77 self.data.extend(data) 78 79 def connection_lost(self, exc): 80 self.closed = True 81 self.fut.set_result(None) 82 83 async def wait_closed(self): 84 await self.fut 85 86 87class SendfileBase: 88 89 # 128 KiB plus small unaligned to buffer chunk 90 DATA = b"SendfileBaseData" * (1024 * 8 + 1) 91 92 # Reduce socket buffer size to test on relative small data sets. 93 BUF_SIZE = 4 * 1024 # 4 KiB 94 95 def create_event_loop(self): 96 raise NotImplementedError 97 98 @classmethod 99 def setUpClass(cls): 100 with open(support.TESTFN, 'wb') as fp: 101 fp.write(cls.DATA) 102 super().setUpClass() 103 104 @classmethod 105 def tearDownClass(cls): 106 support.unlink(support.TESTFN) 107 super().tearDownClass() 108 109 def setUp(self): 110 self.file = open(support.TESTFN, 'rb') 111 self.addCleanup(self.file.close) 112 self.loop = self.create_event_loop() 113 self.set_event_loop(self.loop) 114 super().setUp() 115 116 def tearDown(self): 117 # just in case if we have transport close callbacks 118 if not self.loop.is_closed(): 119 test_utils.run_briefly(self.loop) 120 121 self.doCleanups() 122 support.gc_collect() 123 super().tearDown() 124 125 def run_loop(self, coro): 126 return self.loop.run_until_complete(coro) 127 128 129class SockSendfileMixin(SendfileBase): 130 131 @classmethod 132 def setUpClass(cls): 133 cls.__old_bufsize = constants.SENDFILE_FALLBACK_READBUFFER_SIZE 134 constants.SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 16 135 super().setUpClass() 136 137 @classmethod 138 def tearDownClass(cls): 139 constants.SENDFILE_FALLBACK_READBUFFER_SIZE = cls.__old_bufsize 140 super().tearDownClass() 141 142 def make_socket(self, cleanup=True): 143 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 144 sock.setblocking(False) 145 if cleanup: 146 self.addCleanup(sock.close) 147 return sock 148 149 def reduce_receive_buffer_size(self, sock): 150 # Reduce receive socket buffer size to test on relative 151 # small data sets. 152 sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self.BUF_SIZE) 153 154 def reduce_send_buffer_size(self, sock, transport=None): 155 # Reduce send socket buffer size to test on relative small data sets. 156 157 # On macOS, SO_SNDBUF is reset by connect(). So this method 158 # should be called after the socket is connected. 159 sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.BUF_SIZE) 160 161 if transport is not None: 162 transport.set_write_buffer_limits(high=self.BUF_SIZE) 163 164 def prepare_socksendfile(self): 165 proto = MyProto(self.loop) 166 port = support.find_unused_port() 167 srv_sock = self.make_socket(cleanup=False) 168 srv_sock.bind((support.HOST, port)) 169 server = self.run_loop(self.loop.create_server( 170 lambda: proto, sock=srv_sock)) 171 self.reduce_receive_buffer_size(srv_sock) 172 173 sock = self.make_socket() 174 self.run_loop(self.loop.sock_connect(sock, ('127.0.0.1', port))) 175 self.reduce_send_buffer_size(sock) 176 177 def cleanup(): 178 if proto.transport is not None: 179 # can be None if the task was cancelled before 180 # connection_made callback 181 proto.transport.close() 182 self.run_loop(proto.wait_closed()) 183 184 server.close() 185 self.run_loop(server.wait_closed()) 186 187 self.addCleanup(cleanup) 188 189 return sock, proto 190 191 def test_sock_sendfile_success(self): 192 sock, proto = self.prepare_socksendfile() 193 ret = self.run_loop(self.loop.sock_sendfile(sock, self.file)) 194 sock.close() 195 self.run_loop(proto.wait_closed()) 196 197 self.assertEqual(ret, len(self.DATA)) 198 self.assertEqual(proto.data, self.DATA) 199 self.assertEqual(self.file.tell(), len(self.DATA)) 200 201 def test_sock_sendfile_with_offset_and_count(self): 202 sock, proto = self.prepare_socksendfile() 203 ret = self.run_loop(self.loop.sock_sendfile(sock, self.file, 204 1000, 2000)) 205 sock.close() 206 self.run_loop(proto.wait_closed()) 207 208 self.assertEqual(proto.data, self.DATA[1000:3000]) 209 self.assertEqual(self.file.tell(), 3000) 210 self.assertEqual(ret, 2000) 211 212 def test_sock_sendfile_zero_size(self): 213 sock, proto = self.prepare_socksendfile() 214 with tempfile.TemporaryFile() as f: 215 ret = self.run_loop(self.loop.sock_sendfile(sock, f, 216 0, None)) 217 sock.close() 218 self.run_loop(proto.wait_closed()) 219 220 self.assertEqual(ret, 0) 221 self.assertEqual(self.file.tell(), 0) 222 223 def test_sock_sendfile_mix_with_regular_send(self): 224 buf = b"mix_regular_send" * (4 * 1024) # 64 KiB 225 sock, proto = self.prepare_socksendfile() 226 self.run_loop(self.loop.sock_sendall(sock, buf)) 227 ret = self.run_loop(self.loop.sock_sendfile(sock, self.file)) 228 self.run_loop(self.loop.sock_sendall(sock, buf)) 229 sock.close() 230 self.run_loop(proto.wait_closed()) 231 232 self.assertEqual(ret, len(self.DATA)) 233 expected = buf + self.DATA + buf 234 self.assertEqual(proto.data, expected) 235 self.assertEqual(self.file.tell(), len(self.DATA)) 236 237 238class SendfileMixin(SendfileBase): 239 240 # Note: sendfile via SSL transport is equal to sendfile fallback 241 242 def prepare_sendfile(self, *, is_ssl=False, close_after=0): 243 port = support.find_unused_port() 244 srv_proto = MySendfileProto(loop=self.loop, 245 close_after=close_after) 246 if is_ssl: 247 if not ssl: 248 self.skipTest("No ssl module") 249 srv_ctx = test_utils.simple_server_sslcontext() 250 cli_ctx = test_utils.simple_client_sslcontext() 251 else: 252 srv_ctx = None 253 cli_ctx = None 254 srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 255 srv_sock.bind((support.HOST, port)) 256 server = self.run_loop(self.loop.create_server( 257 lambda: srv_proto, sock=srv_sock, ssl=srv_ctx)) 258 self.reduce_receive_buffer_size(srv_sock) 259 260 if is_ssl: 261 server_hostname = support.HOST 262 else: 263 server_hostname = None 264 cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 265 cli_sock.connect((support.HOST, port)) 266 267 cli_proto = MySendfileProto(loop=self.loop) 268 tr, pr = self.run_loop(self.loop.create_connection( 269 lambda: cli_proto, sock=cli_sock, 270 ssl=cli_ctx, server_hostname=server_hostname)) 271 self.reduce_send_buffer_size(cli_sock, transport=tr) 272 273 def cleanup(): 274 srv_proto.transport.close() 275 cli_proto.transport.close() 276 self.run_loop(srv_proto.done) 277 self.run_loop(cli_proto.done) 278 279 server.close() 280 self.run_loop(server.wait_closed()) 281 282 self.addCleanup(cleanup) 283 return srv_proto, cli_proto 284 285 @unittest.skipIf(sys.platform == 'win32', "UDP sockets are not supported") 286 def test_sendfile_not_supported(self): 287 tr, pr = self.run_loop( 288 self.loop.create_datagram_endpoint( 289 asyncio.DatagramProtocol, 290 family=socket.AF_INET)) 291 try: 292 with self.assertRaisesRegex(RuntimeError, "not supported"): 293 self.run_loop( 294 self.loop.sendfile(tr, self.file)) 295 self.assertEqual(0, self.file.tell()) 296 finally: 297 # don't use self.addCleanup because it produces resource warning 298 tr.close() 299 300 def test_sendfile(self): 301 srv_proto, cli_proto = self.prepare_sendfile() 302 ret = self.run_loop( 303 self.loop.sendfile(cli_proto.transport, self.file)) 304 cli_proto.transport.close() 305 self.run_loop(srv_proto.done) 306 self.assertEqual(ret, len(self.DATA)) 307 self.assertEqual(srv_proto.nbytes, len(self.DATA)) 308 self.assertEqual(srv_proto.data, self.DATA) 309 self.assertEqual(self.file.tell(), len(self.DATA)) 310 311 def test_sendfile_force_fallback(self): 312 srv_proto, cli_proto = self.prepare_sendfile() 313 314 def sendfile_native(transp, file, offset, count): 315 # to raise SendfileNotAvailableError 316 return base_events.BaseEventLoop._sendfile_native( 317 self.loop, transp, file, offset, count) 318 319 self.loop._sendfile_native = sendfile_native 320 321 ret = self.run_loop( 322 self.loop.sendfile(cli_proto.transport, self.file)) 323 cli_proto.transport.close() 324 self.run_loop(srv_proto.done) 325 self.assertEqual(ret, len(self.DATA)) 326 self.assertEqual(srv_proto.nbytes, len(self.DATA)) 327 self.assertEqual(srv_proto.data, self.DATA) 328 self.assertEqual(self.file.tell(), len(self.DATA)) 329 330 def test_sendfile_force_unsupported_native(self): 331 if sys.platform == 'win32': 332 if isinstance(self.loop, asyncio.ProactorEventLoop): 333 self.skipTest("Fails on proactor event loop") 334 srv_proto, cli_proto = self.prepare_sendfile() 335 336 def sendfile_native(transp, file, offset, count): 337 # to raise SendfileNotAvailableError 338 return base_events.BaseEventLoop._sendfile_native( 339 self.loop, transp, file, offset, count) 340 341 self.loop._sendfile_native = sendfile_native 342 343 with self.assertRaisesRegex(asyncio.SendfileNotAvailableError, 344 "not supported"): 345 self.run_loop( 346 self.loop.sendfile(cli_proto.transport, self.file, 347 fallback=False)) 348 349 cli_proto.transport.close() 350 self.run_loop(srv_proto.done) 351 self.assertEqual(srv_proto.nbytes, 0) 352 self.assertEqual(self.file.tell(), 0) 353 354 def test_sendfile_ssl(self): 355 srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True) 356 ret = self.run_loop( 357 self.loop.sendfile(cli_proto.transport, self.file)) 358 cli_proto.transport.close() 359 self.run_loop(srv_proto.done) 360 self.assertEqual(ret, len(self.DATA)) 361 self.assertEqual(srv_proto.nbytes, len(self.DATA)) 362 self.assertEqual(srv_proto.data, self.DATA) 363 self.assertEqual(self.file.tell(), len(self.DATA)) 364 365 def test_sendfile_for_closing_transp(self): 366 srv_proto, cli_proto = self.prepare_sendfile() 367 cli_proto.transport.close() 368 with self.assertRaisesRegex(RuntimeError, "is closing"): 369 self.run_loop(self.loop.sendfile(cli_proto.transport, self.file)) 370 self.run_loop(srv_proto.done) 371 self.assertEqual(srv_proto.nbytes, 0) 372 self.assertEqual(self.file.tell(), 0) 373 374 def test_sendfile_pre_and_post_data(self): 375 srv_proto, cli_proto = self.prepare_sendfile() 376 PREFIX = b'PREFIX__' * 1024 # 8 KiB 377 SUFFIX = b'--SUFFIX' * 1024 # 8 KiB 378 cli_proto.transport.write(PREFIX) 379 ret = self.run_loop( 380 self.loop.sendfile(cli_proto.transport, self.file)) 381 cli_proto.transport.write(SUFFIX) 382 cli_proto.transport.close() 383 self.run_loop(srv_proto.done) 384 self.assertEqual(ret, len(self.DATA)) 385 self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX) 386 self.assertEqual(self.file.tell(), len(self.DATA)) 387 388 def test_sendfile_ssl_pre_and_post_data(self): 389 srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True) 390 PREFIX = b'zxcvbnm' * 1024 391 SUFFIX = b'0987654321' * 1024 392 cli_proto.transport.write(PREFIX) 393 ret = self.run_loop( 394 self.loop.sendfile(cli_proto.transport, self.file)) 395 cli_proto.transport.write(SUFFIX) 396 cli_proto.transport.close() 397 self.run_loop(srv_proto.done) 398 self.assertEqual(ret, len(self.DATA)) 399 self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX) 400 self.assertEqual(self.file.tell(), len(self.DATA)) 401 402 def test_sendfile_partial(self): 403 srv_proto, cli_proto = self.prepare_sendfile() 404 ret = self.run_loop( 405 self.loop.sendfile(cli_proto.transport, self.file, 1000, 100)) 406 cli_proto.transport.close() 407 self.run_loop(srv_proto.done) 408 self.assertEqual(ret, 100) 409 self.assertEqual(srv_proto.nbytes, 100) 410 self.assertEqual(srv_proto.data, self.DATA[1000:1100]) 411 self.assertEqual(self.file.tell(), 1100) 412 413 def test_sendfile_ssl_partial(self): 414 srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True) 415 ret = self.run_loop( 416 self.loop.sendfile(cli_proto.transport, self.file, 1000, 100)) 417 cli_proto.transport.close() 418 self.run_loop(srv_proto.done) 419 self.assertEqual(ret, 100) 420 self.assertEqual(srv_proto.nbytes, 100) 421 self.assertEqual(srv_proto.data, self.DATA[1000:1100]) 422 self.assertEqual(self.file.tell(), 1100) 423 424 def test_sendfile_close_peer_after_receiving(self): 425 srv_proto, cli_proto = self.prepare_sendfile( 426 close_after=len(self.DATA)) 427 ret = self.run_loop( 428 self.loop.sendfile(cli_proto.transport, self.file)) 429 cli_proto.transport.close() 430 self.run_loop(srv_proto.done) 431 self.assertEqual(ret, len(self.DATA)) 432 self.assertEqual(srv_proto.nbytes, len(self.DATA)) 433 self.assertEqual(srv_proto.data, self.DATA) 434 self.assertEqual(self.file.tell(), len(self.DATA)) 435 436 def test_sendfile_ssl_close_peer_after_receiving(self): 437 srv_proto, cli_proto = self.prepare_sendfile( 438 is_ssl=True, close_after=len(self.DATA)) 439 ret = self.run_loop( 440 self.loop.sendfile(cli_proto.transport, self.file)) 441 self.run_loop(srv_proto.done) 442 self.assertEqual(ret, len(self.DATA)) 443 self.assertEqual(srv_proto.nbytes, len(self.DATA)) 444 self.assertEqual(srv_proto.data, self.DATA) 445 self.assertEqual(self.file.tell(), len(self.DATA)) 446 447 def test_sendfile_close_peer_in_the_middle_of_receiving(self): 448 srv_proto, cli_proto = self.prepare_sendfile(close_after=1024) 449 with self.assertRaises(ConnectionError): 450 self.run_loop( 451 self.loop.sendfile(cli_proto.transport, self.file)) 452 self.run_loop(srv_proto.done) 453 454 self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA), 455 srv_proto.nbytes) 456 self.assertTrue(1024 <= self.file.tell() < len(self.DATA), 457 self.file.tell()) 458 self.assertTrue(cli_proto.transport.is_closing()) 459 460 def test_sendfile_fallback_close_peer_in_the_middle_of_receiving(self): 461 462 def sendfile_native(transp, file, offset, count): 463 # to raise SendfileNotAvailableError 464 return base_events.BaseEventLoop._sendfile_native( 465 self.loop, transp, file, offset, count) 466 467 self.loop._sendfile_native = sendfile_native 468 469 srv_proto, cli_proto = self.prepare_sendfile(close_after=1024) 470 with self.assertRaises(ConnectionError): 471 self.run_loop( 472 self.loop.sendfile(cli_proto.transport, self.file)) 473 self.run_loop(srv_proto.done) 474 475 self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA), 476 srv_proto.nbytes) 477 self.assertTrue(1024 <= self.file.tell() < len(self.DATA), 478 self.file.tell()) 479 480 @unittest.skipIf(not hasattr(os, 'sendfile'), 481 "Don't have native sendfile support") 482 def test_sendfile_prevents_bare_write(self): 483 srv_proto, cli_proto = self.prepare_sendfile() 484 fut = self.loop.create_future() 485 486 async def coro(): 487 fut.set_result(None) 488 return await self.loop.sendfile(cli_proto.transport, self.file) 489 490 t = self.loop.create_task(coro()) 491 self.run_loop(fut) 492 with self.assertRaisesRegex(RuntimeError, 493 "sendfile is in progress"): 494 cli_proto.transport.write(b'data') 495 ret = self.run_loop(t) 496 self.assertEqual(ret, len(self.DATA)) 497 498 def test_sendfile_no_fallback_for_fallback_transport(self): 499 transport = mock.Mock() 500 transport.is_closing.side_effect = lambda: False 501 transport._sendfile_compatible = constants._SendfileMode.FALLBACK 502 with self.assertRaisesRegex(RuntimeError, 'fallback is disabled'): 503 self.loop.run_until_complete( 504 self.loop.sendfile(transport, None, fallback=False)) 505 506 507class SendfileTestsBase(SendfileMixin, SockSendfileMixin): 508 pass 509 510 511if sys.platform == 'win32': 512 513 class SelectEventLoopTests(SendfileTestsBase, 514 test_utils.TestCase): 515 516 def create_event_loop(self): 517 return asyncio.SelectorEventLoop() 518 519 class ProactorEventLoopTests(SendfileTestsBase, 520 test_utils.TestCase): 521 522 def create_event_loop(self): 523 return asyncio.ProactorEventLoop() 524 525else: 526 import selectors 527 528 if hasattr(selectors, 'KqueueSelector'): 529 class KqueueEventLoopTests(SendfileTestsBase, 530 test_utils.TestCase): 531 532 def create_event_loop(self): 533 return asyncio.SelectorEventLoop( 534 selectors.KqueueSelector()) 535 536 if hasattr(selectors, 'EpollSelector'): 537 class EPollEventLoopTests(SendfileTestsBase, 538 test_utils.TestCase): 539 540 def create_event_loop(self): 541 return asyncio.SelectorEventLoop(selectors.EpollSelector()) 542 543 if hasattr(selectors, 'PollSelector'): 544 class PollEventLoopTests(SendfileTestsBase, 545 test_utils.TestCase): 546 547 def create_event_loop(self): 548 return asyncio.SelectorEventLoop(selectors.PollSelector()) 549 550 # Should always exist. 551 class SelectEventLoopTests(SendfileTestsBase, 552 test_utils.TestCase): 553 554 def create_event_loop(self): 555 return asyncio.SelectorEventLoop(selectors.SelectSelector()) 556