__all__ = (
    'StreamReader', 'StreamWriter', 'StreamReaderProtocol',
    'open_connection', 'start_server')

import collections
import socket
import sys
import warnings
import weakref

if hasattr(socket, 'AF_UNIX'):
    __all__ += ('open_unix_connection', 'start_unix_server')

from . import coroutines
from . import events
from . import exceptions
from . import format_helpers
from . import protocols
from .log import logger
from .tasks import sleep


_DEFAULT_LIMIT = 2 ** 16  # 64 KiB


async def open_connection(host=None, port=None, *,
                          limit=_DEFAULT_LIMIT, **kwds):
    """A wrapper for create_connection() returning a (reader, writer) pair.

    The reader returned is a StreamReader instance; the writer is a
    StreamWriter instance.

    The arguments are all the usual arguments to create_connection()
    except protocol_factory; most common are positional host and port,
    with various optional keyword arguments following.

    An additional optional keyword argument is limit (to set the buffer
    limit passed to the StreamReader).

    (If you want to customize the StreamReader and/or
    StreamReaderProtocol classes, just copy the code -- there's
    really nothing special here except some convenience.)
    """
    loop = events.get_running_loop()
    reader = StreamReader(limit=limit, loop=loop)
    protocol = StreamReaderProtocol(reader, loop=loop)
    transport, _ = await loop.create_connection(
        lambda: protocol, host, port, **kwds)
    writer = StreamWriter(transport, protocol, reader, loop)
    return reader, writer
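
# A minimal client-side usage sketch of the function above (the host,
# port and payload are placeholders, not part of this module):
#
#     import asyncio
#
#     async def main():
#         reader, writer = await asyncio.open_connection('127.0.0.1', 8888)
#         writer.write(b'ping\n')
#         await writer.drain()
#         data = await reader.readline()
#         writer.close()
#         await writer.wait_closed()
#         return data
#
#     asyncio.run(main())
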
75 """ 76 loop = events.get_running_loop() 77 78 def factory(): 79 reader = StreamReader(limit=limit, loop=loop) 80 protocol = StreamReaderProtocol(reader, client_connected_cb, 81 loop=loop) 82 return protocol 83 84 return await loop.create_server(factory, host, port, **kwds) 85 86 87if hasattr(socket, 'AF_UNIX'): 88 # UNIX Domain Sockets are supported on this platform 89 90 async def open_unix_connection(path=None, *, 91 limit=_DEFAULT_LIMIT, **kwds): 92 """Similar to `open_connection` but works with UNIX Domain Sockets.""" 93 loop = events.get_running_loop() 94 95 reader = StreamReader(limit=limit, loop=loop) 96 protocol = StreamReaderProtocol(reader, loop=loop) 97 transport, _ = await loop.create_unix_connection( 98 lambda: protocol, path, **kwds) 99 writer = StreamWriter(transport, protocol, reader, loop) 100 return reader, writer 101 102 async def start_unix_server(client_connected_cb, path=None, *, 103 limit=_DEFAULT_LIMIT, **kwds): 104 """Similar to `start_server` but works with UNIX Domain Sockets.""" 105 loop = events.get_running_loop() 106 107 def factory(): 108 reader = StreamReader(limit=limit, loop=loop) 109 protocol = StreamReaderProtocol(reader, client_connected_cb, 110 loop=loop) 111 return protocol 112 113 return await loop.create_unix_server(factory, path, **kwds) 114 115 116class FlowControlMixin(protocols.Protocol): 117 """Reusable flow control logic for StreamWriter.drain(). 118 119 This implements the protocol methods pause_writing(), 120 resume_writing() and connection_lost(). If the subclass overrides 121 these it must call the super methods. 122 123 StreamWriter.drain() must wait for _drain_helper() coroutine. 124 """ 125 126 def __init__(self, loop=None): 127 if loop is None: 128 self._loop = events.get_event_loop() 129 else: 130 self._loop = loop 131 self._paused = False 132 self._drain_waiters = collections.deque() 133 self._connection_lost = False 134 135 def pause_writing(self): 136 assert not self._paused 137 self._paused = True 138 if self._loop.get_debug(): 139 logger.debug("%r pauses writing", self) 140 141 def resume_writing(self): 142 assert self._paused 143 self._paused = False 144 if self._loop.get_debug(): 145 logger.debug("%r resumes writing", self) 146 147 for waiter in self._drain_waiters: 148 if not waiter.done(): 149 waiter.set_result(None) 150 151 def connection_lost(self, exc): 152 self._connection_lost = True 153 # Wake up the writer(s) if currently paused. 154 if not self._paused: 155 return 156 157 for waiter in self._drain_waiters: 158 if not waiter.done(): 159 if exc is None: 160 waiter.set_result(None) 161 else: 162 waiter.set_exception(exc) 163 164 async def _drain_helper(self): 165 if self._connection_lost: 166 raise ConnectionResetError('Connection lost') 167 if not self._paused: 168 return 169 waiter = self._loop.create_future() 170 self._drain_waiters.append(waiter) 171 try: 172 await waiter 173 finally: 174 self._drain_waiters.remove(waiter) 175 176 def _get_close_waiter(self, stream): 177 raise NotImplementedError 178 179 180class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): 181 """Helper class to adapt between Protocol and StreamReader. 182 183 (This is a helper class instead of making StreamReader itself a 184 Protocol subclass, because the StreamReader has other potential 185 uses, and to prevent the user of the StreamReader to accidentally 186 call inappropriate methods of the protocol.) 
187 """ 188 189 _source_traceback = None 190 191 def __init__(self, stream_reader, client_connected_cb=None, loop=None): 192 super().__init__(loop=loop) 193 if stream_reader is not None: 194 self._stream_reader_wr = weakref.ref(stream_reader) 195 self._source_traceback = stream_reader._source_traceback 196 else: 197 self._stream_reader_wr = None 198 if client_connected_cb is not None: 199 # This is a stream created by the `create_server()` function. 200 # Keep a strong reference to the reader until a connection 201 # is established. 202 self._strong_reader = stream_reader 203 self._reject_connection = False 204 self._task = None 205 self._transport = None 206 self._client_connected_cb = client_connected_cb 207 self._over_ssl = False 208 self._closed = self._loop.create_future() 209 210 @property 211 def _stream_reader(self): 212 if self._stream_reader_wr is None: 213 return None 214 return self._stream_reader_wr() 215 216 def _replace_transport(self, transport): 217 loop = self._loop 218 self._transport = transport 219 self._over_ssl = transport.get_extra_info('sslcontext') is not None 220 221 def connection_made(self, transport): 222 if self._reject_connection: 223 context = { 224 'message': ('An open stream was garbage collected prior to ' 225 'establishing network connection; ' 226 'call "stream.close()" explicitly.') 227 } 228 if self._source_traceback: 229 context['source_traceback'] = self._source_traceback 230 self._loop.call_exception_handler(context) 231 transport.abort() 232 return 233 self._transport = transport 234 reader = self._stream_reader 235 if reader is not None: 236 reader.set_transport(transport) 237 self._over_ssl = transport.get_extra_info('sslcontext') is not None 238 if self._client_connected_cb is not None: 239 writer = StreamWriter(transport, self, reader, self._loop) 240 res = self._client_connected_cb(reader, writer) 241 if coroutines.iscoroutine(res): 242 def callback(task): 243 if task.cancelled(): 244 transport.close() 245 return 246 exc = task.exception() 247 if exc is not None: 248 self._loop.call_exception_handler({ 249 'message': 'Unhandled exception in client_connected_cb', 250 'exception': exc, 251 'transport': transport, 252 }) 253 transport.close() 254 255 self._task = self._loop.create_task(res) 256 self._task.add_done_callback(callback) 257 258 self._strong_reader = None 259 260 def connection_lost(self, exc): 261 reader = self._stream_reader 262 if reader is not None: 263 if exc is None: 264 reader.feed_eof() 265 else: 266 reader.set_exception(exc) 267 if not self._closed.done(): 268 if exc is None: 269 self._closed.set_result(None) 270 else: 271 self._closed.set_exception(exc) 272 super().connection_lost(exc) 273 self._stream_reader_wr = None 274 self._stream_writer = None 275 self._task = None 276 self._transport = None 277 278 def data_received(self, data): 279 reader = self._stream_reader 280 if reader is not None: 281 reader.feed_data(data) 282 283 def eof_received(self): 284 reader = self._stream_reader 285 if reader is not None: 286 reader.feed_eof() 287 if self._over_ssl: 288 # Prevent a warning in SSLProtocol.eof_received: 289 # "returning true from eof_received() 290 # has no effect when using ssl" 291 return False 292 return True 293 294 def _get_close_waiter(self, stream): 295 return self._closed 296 297 def __del__(self): 298 # Prevent reports about unhandled exceptions. 
class StreamWriter:
    """Wraps a Transport.

    This exposes write(), writelines(), [can_]write_eof(),
    get_extra_info() and close().  It adds drain() which returns an
    optional Future on which you can wait for flow control.  It also
    adds a transport property which references the Transport
    directly.
    """

    def __init__(self, transport, protocol, reader, loop):
        self._transport = transport
        self._protocol = protocol
        # drain() expects that the reader has an exception() method
        assert reader is None or isinstance(reader, StreamReader)
        self._reader = reader
        self._loop = loop
        self._complete_fut = self._loop.create_future()
        self._complete_fut.set_result(None)

    def __repr__(self):
        info = [self.__class__.__name__, f'transport={self._transport!r}']
        if self._reader is not None:
            info.append(f'reader={self._reader!r}')
        return '<{}>'.format(' '.join(info))

    @property
    def transport(self):
        return self._transport

    def write(self, data):
        self._transport.write(data)

    def writelines(self, data):
        self._transport.writelines(data)

    def write_eof(self):
        return self._transport.write_eof()

    def can_write_eof(self):
        return self._transport.can_write_eof()

    def close(self):
        return self._transport.close()

    def is_closing(self):
        return self._transport.is_closing()

    async def wait_closed(self):
        await self._protocol._get_close_waiter(self)

    def get_extra_info(self, name, default=None):
        return self._transport.get_extra_info(name, default)

    async def drain(self):
        """Flush the write buffer.

        The intended use is to write

            w.write(data)
            await w.drain()
        """
        if self._reader is not None:
            exc = self._reader.exception()
            if exc is not None:
                raise exc
        if self._transport.is_closing():
            # Wait for the protocol.connection_lost() call; raise the
            # connection-closing error if any, ConnectionResetError
            # otherwise.
            # Yield to the event loop so connection_lost() may be
            # called.  Without this, _drain_helper() would return
            # immediately, and code that calls
            #     write(...); await drain()
            # in a loop would never call connection_lost(), so it
            # would not see an error when the socket is closed.
            await sleep(0)
        await self._protocol._drain_helper()
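
    # drain() is the backpressure point: a write loop should await it after
    # each write so that a slow peer pauses the producer.  A sketch, given
    # a connected writer (`chunks` is a placeholder iterable of bytes):
    #
    #     for chunk in chunks:
    #         writer.write(chunk)
    #         await writer.drain()
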
    async def start_tls(self, sslcontext, *,
                        server_hostname=None,
                        ssl_handshake_timeout=None,
                        ssl_shutdown_timeout=None):
        """Upgrade an existing stream-based connection to TLS."""
        server_side = self._protocol._client_connected_cb is not None
        protocol = self._protocol
        await self.drain()
        new_transport = await self._loop.start_tls(  # type: ignore
            self._transport, protocol, sslcontext,
            server_side=server_side, server_hostname=server_hostname,
            ssl_handshake_timeout=ssl_handshake_timeout,
            ssl_shutdown_timeout=ssl_shutdown_timeout)
        self._transport = new_transport
        protocol._replace_transport(new_transport)

    def __del__(self, warnings=warnings):
        if not self._transport.is_closing():
            if self._loop.is_closed():
                warnings.warn("loop is closed", ResourceWarning)
            else:
                self.close()
                warnings.warn(f"unclosed {self!r}", ResourceWarning)


class StreamReader:

    _source_traceback = None

    def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
        # The line length limit is a security feature;
        # it also doubles as half the buffer limit.

        if limit <= 0:
            raise ValueError('Limit cannot be <= 0')

        self._limit = limit
        if loop is None:
            self._loop = events.get_event_loop()
        else:
            self._loop = loop
        self._buffer = bytearray()
        self._eof = False    # Whether we're done.
        self._waiter = None  # A future used by _wait_for_data().
        self._exception = None
        self._transport = None
        self._paused = False
        if self._loop.get_debug():
            self._source_traceback = format_helpers.extract_stack(
                sys._getframe(1))

    def __repr__(self):
        info = ['StreamReader']
        if self._buffer:
            info.append(f'{len(self._buffer)} bytes')
        if self._eof:
            info.append('eof')
        if self._limit != _DEFAULT_LIMIT:
            info.append(f'limit={self._limit}')
        if self._waiter:
            info.append(f'waiter={self._waiter!r}')
        if self._exception:
            info.append(f'exception={self._exception!r}')
        if self._transport:
            info.append(f'transport={self._transport!r}')
        if self._paused:
            info.append('paused')
        return '<{}>'.format(' '.join(info))

    def exception(self):
        return self._exception

    def set_exception(self, exc):
        self._exception = exc

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_exception(exc)

    def _wakeup_waiter(self):
        """Wake up read*() functions waiting for data or EOF."""
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_result(None)

    def set_transport(self, transport):
        assert self._transport is None, 'Transport already set'
        self._transport = transport

    def _maybe_resume_transport(self):
        if self._paused and len(self._buffer) <= self._limit:
            self._paused = False
            self._transport.resume_reading()

    def feed_eof(self):
        self._eof = True
        self._wakeup_waiter()

    def at_eof(self):
        """Return True if the buffer is empty and 'feed_eof' was called."""
        return self._eof and not self._buffer

    def feed_data(self, data):
        assert not self._eof, 'feed_data after feed_eof'

        if not data:
            return

        self._buffer.extend(data)
        self._wakeup_waiter()

        if (self._transport is not None and
                not self._paused and
                len(self._buffer) > 2 * self._limit):
            try:
                self._transport.pause_reading()
            except NotImplementedError:
                # The transport can't be paused.
                # We'll just have to buffer all data.
                # Forget the transport so we don't keep trying.
                self._transport = None
            else:
                self._paused = True
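
    # The 2 * limit threshold above is the read-side backpressure: once the
    # buffer holds more than twice the configured limit, the transport is
    # paused until reads shrink it back to at most the limit (see
    # _maybe_resume_transport()).  A StreamReader can also be driven with
    # no transport at all, which is a handy sketch for tests -- pausing
    # then never triggers:
    #
    #     reader = StreamReader(limit=16)
    #     reader.feed_data(b'hello\n')
    #     reader.feed_eof()
    #     # await reader.readline()  ->  b'hello\n'
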
    async def _wait_for_data(self, func_name):
        """Wait until feed_data() or feed_eof() is called.

        If stream was paused, automatically resume it.
        """
        # StreamReader uses a future to link the protocol feed_data() method
        # to a read coroutine.  Running two read coroutines at the same time
        # would have unexpected behaviour: it would not be possible to know
        # which coroutine would get the next data.
        if self._waiter is not None:
            raise RuntimeError(
                f'{func_name}() called while another coroutine is '
                f'already waiting for incoming data')

        assert not self._eof, '_wait_for_data after EOF'

        # Waiting for data while paused would deadlock, so prevent it.
        # This is essential for readexactly(n) when n > self._limit.
        if self._paused:
            self._paused = False
            self._transport.resume_reading()

        self._waiter = self._loop.create_future()
        try:
            await self._waiter
        finally:
            self._waiter = None

    async def readline(self):
        """Read a chunk of data from the stream until a newline (b'\n') is found.

        On success, return the chunk that ends with the newline.  If only a
        partial line can be read due to EOF, return the incomplete line
        without the terminating newline.  When EOF is reached while no
        bytes have been read, return an empty bytes object.

        If the limit is reached, ValueError will be raised.  In that case,
        if the newline was found, the complete line including the newline
        will be removed from the internal buffer; otherwise, the internal
        buffer will be cleared.  The limit is compared against the part of
        the line without the newline.

        If the stream was paused, this function will automatically resume
        it if needed.
        """
        sep = b'\n'
        seplen = len(sep)
        try:
            line = await self.readuntil(sep)
        except exceptions.IncompleteReadError as e:
            return e.partial
        except exceptions.LimitOverrunError as e:
            if self._buffer.startswith(sep, e.consumed):
                del self._buffer[:e.consumed + seplen]
            else:
                self._buffer.clear()
            self._maybe_resume_transport()
            raise ValueError(e.args[0])
        return line
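
    # A sketch of the usual line-reading loop over a connected reader
    # (`handle` is a placeholder):
    #
    #     while True:
    #         line = await reader.readline()
    #         if not line:
    #             break  # EOF: readline() returned an empty bytes object
    #         handle(line)
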
    async def readuntil(self, separator=b'\n'):
        """Read data from the stream until ``separator`` is found.

        On success, the data and separator will be removed from the
        internal buffer (consumed).  Returned data will include the
        separator at the end.

        The configured stream limit is used to check the result.  The
        limit sets the maximal length of data that can be returned, not
        counting the separator.

        If an EOF occurs and the complete separator is still not found,
        an IncompleteReadError exception will be raised, and the internal
        buffer will be reset.  The IncompleteReadError.partial attribute
        may contain a portion of the separator.

        If the data cannot be read because the limit is exceeded, a
        LimitOverrunError exception will be raised, and the data
        will be left in the internal buffer, so it can be read again.

        The ``separator`` may also be a tuple of separators.  In this
        case the return value will be the shortest possible that has any
        separator as the suffix.  For the purposes of LimitOverrunError,
        the shortest possible separator is considered to be the one that
        matched.
        """
        if isinstance(separator, tuple):
            # Make sure the shortest match wins.
            separator = sorted(separator, key=len)
        else:
            separator = [separator]
        if not separator:
            raise ValueError('Separator should contain at least one element')
        min_seplen = len(separator[0])
        max_seplen = len(separator[-1])
        if min_seplen == 0:
            raise ValueError('Separator should be at least one-byte string')

        if self._exception is not None:
            raise self._exception

        # Consume the whole buffer except the last bytes, whose length is
        # one less than max_seplen.  Let's check the corner cases with
        # separator[-1]='SEPARATOR':
        # * we have received an almost complete separator (without its
        #   last byte), i.e. buffer='some textSEPARATO'.  In this case we
        #   can safely consume max_seplen - 1 bytes.
        # * the last byte of the buffer is the first byte of the
        #   separator, i.e. buffer='abcdefghijklmnopqrS'.  We may safely
        #   consume everything except that last byte, but this requires
        #   analyzing the bytes of the buffer that match a partial
        #   separator.  This is slow and/or requires an FSM.  For this
        #   case our implementation is not optimal, since it requires
        #   rescanning data that is known not to belong to the separator.
        #   In the real world, the separator will not be long enough to
        #   cause performance problems.  Even when reading MIME-encoded
        #   messages :)

        # `offset` is the number of bytes from the beginning of the buffer
        # where there is no occurrence of any `separator`.
        offset = 0

        # Loop until we find a `separator` in the buffer, exceed the buffer
        # size, or an EOF has happened.
        while True:
            buflen = len(self._buffer)

            # Check if we now have enough data in the buffer for the
            # shortest separator to fit.
            if buflen - offset >= min_seplen:
                match_start = None
                match_end = None
                for sep in separator:
                    isep = self._buffer.find(sep, offset)

                    if isep != -1:
                        # `separator` is in the buffer.  `match_start` and
                        # `match_end` will be used later to retrieve the
                        # data.
                        end = isep + len(sep)
                        if match_end is None or end < match_end:
                            match_end = end
                            match_start = isep
                if match_end is not None:
                    break

                # See the comment above for an explanation.
                offset = max(0, buflen + 1 - max_seplen)
                if offset > self._limit:
                    raise exceptions.LimitOverrunError(
                        'Separator is not found, and chunk exceeds the limit',
                        offset)

            # A complete message (with a full separator) may be present in
            # the buffer even when the EOF flag is set.  This may happen
            # when the last chunk adds data which makes the separator be
            # found.  That's why we check for EOF *after* inspecting the
            # buffer.
            if self._eof:
                chunk = bytes(self._buffer)
                self._buffer.clear()
                raise exceptions.IncompleteReadError(chunk, None)

            # _wait_for_data() will resume reading if the stream was paused.
            await self._wait_for_data('readuntil')

        if match_start > self._limit:
            raise exceptions.LimitOverrunError(
                'Separator is found, but chunk is longer than limit',
                match_start)

        chunk = self._buffer[:match_end]
        del self._buffer[:match_end]
        self._maybe_resume_transport()
        return bytes(chunk)
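
    # A sketch of readuntil() with the failure modes described above
    # (the delimiter is a placeholder):
    #
    #     try:
    #         frame = await reader.readuntil(b'\r\n')
    #     except exceptions.IncompleteReadError as e:
    #         frame = e.partial   # EOF before the separator; buffer reset
    #     except exceptions.LimitOverrunError:
    #         ...                 # data stays buffered and can be re-read
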
    async def read(self, n=-1):
        """Read up to `n` bytes from the stream.

        If `n` is not provided or set to -1,
        read until EOF, then return all read bytes.
        If EOF was received and the internal buffer is empty,
        return an empty bytes object.

        If `n` is 0, return an empty bytes object immediately.

        If `n` is positive, return at most `n` available bytes
        as soon as at least 1 byte is available in the internal buffer.
        If EOF is received before any byte is read, return an empty
        bytes object.

        The returned value is not limited by the limit configured at
        stream creation.

        If the stream was paused, this function will automatically resume
        it if needed.
        """

        if self._exception is not None:
            raise self._exception

        if n == 0:
            return b''

        if n < 0:
            # This used to just loop creating a new waiter hoping to
            # collect everything in self._buffer, but that would
            # deadlock if the subprocess sends more than self.limit
            # bytes.  So just call self.read(self._limit) until EOF.
            blocks = []
            while True:
                block = await self.read(self._limit)
                if not block:
                    break
                blocks.append(block)
            return b''.join(blocks)

        if not self._buffer and not self._eof:
            await self._wait_for_data('read')

        # This will work right even if the buffer is less than n bytes.
        data = bytes(memoryview(self._buffer)[:n])
        del self._buffer[:n]

        self._maybe_resume_transport()
        return data

    async def readexactly(self, n):
        """Read exactly `n` bytes.

        Raise an IncompleteReadError if EOF is reached before `n` bytes
        can be read.  The IncompleteReadError.partial attribute of the
        exception will contain the partial read bytes.

        If `n` is zero, return an empty bytes object.

        The returned value is not limited by the limit configured at
        stream creation.

        If the stream was paused, this function will automatically resume
        it if needed.
        """
        if n < 0:
            raise ValueError('readexactly size can not be less than zero')

        if self._exception is not None:
            raise self._exception

        if n == 0:
            return b''

        while len(self._buffer) < n:
            if self._eof:
                incomplete = bytes(self._buffer)
                self._buffer.clear()
                raise exceptions.IncompleteReadError(incomplete, n)

            await self._wait_for_data('readexactly')

        if len(self._buffer) == n:
            data = bytes(self._buffer)
            self._buffer.clear()
        else:
            data = bytes(memoryview(self._buffer)[:n])
            del self._buffer[:n]
        self._maybe_resume_transport()
        return data

    def __aiter__(self):
        return self

    async def __anext__(self):
        val = await self.readline()
        if val == b'':
            raise StopAsyncIteration
        return val
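
# StreamReader is an async iterator: __anext__() above yields
# newline-terminated lines until EOF.  A sketch over a connected reader
# (`process` is a placeholder):
#
#     async for line in reader:
#         process(line)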