__all__ = (
    'StreamReader', 'StreamWriter', 'StreamReaderProtocol',
    'open_connection', 'start_server')

import socket
import sys
import warnings
import weakref

if hasattr(socket, 'AF_UNIX'):
    __all__ += ('open_unix_connection', 'start_unix_server')

from . import coroutines
from . import events
from . import exceptions
from . import format_helpers
from . import protocols
from .log import logger
from .tasks import sleep


_DEFAULT_LIMIT = 2 ** 16  # 64 KiB


async def open_connection(host=None, port=None, *,
                          loop=None, limit=_DEFAULT_LIMIT, **kwds):
    """A wrapper for create_connection() returning a (reader, writer) pair.

    The reader returned is a StreamReader instance; the writer is a
    StreamWriter instance.

    The arguments are all the usual arguments to create_connection()
    except protocol_factory; most common are positional host and port,
    with various optional keyword arguments following.

    Additional optional keyword arguments are loop (to set the event loop
    instance to use) and limit (to set the buffer limit passed to the
    StreamReader).

    (If you want to customize the StreamReader and/or
    StreamReaderProtocol classes, just copy the code -- there's
    really nothing special here except some convenience.)
    """
    if loop is None:
        loop = events.get_event_loop()
    else:
        warnings.warn("The loop argument is deprecated since Python 3.8, "
                      "and scheduled for removal in Python 3.10.",
                      DeprecationWarning, stacklevel=2)
    reader = StreamReader(limit=limit, loop=loop)
    protocol = StreamReaderProtocol(reader, loop=loop)
    transport, _ = await loop.create_connection(
        lambda: protocol, host, port, **kwds)
    writer = StreamWriter(transport, protocol, reader, loop)
    return reader, writer


async def start_server(client_connected_cb, host=None, port=None, *,
                       loop=None, limit=_DEFAULT_LIMIT, **kwds):
    """Start a socket server, call back for each client connected.

    The first parameter, `client_connected_cb`, takes two parameters:
    client_reader, client_writer.  client_reader is a StreamReader
    object, while client_writer is a StreamWriter object.  This
    parameter can either be a plain callback function or a coroutine;
    if it is a coroutine, it will be automatically converted into a
    Task.

    The rest of the arguments are all the usual arguments to
    loop.create_server() except protocol_factory; most common are
    positional host and port, with various optional keyword arguments
    following.  The return value is the same as loop.create_server().

    Additional optional keyword arguments are loop (to set the event loop
    instance to use) and limit (to set the buffer limit passed to the
    StreamReader).

    The return value is the same as loop.create_server(), i.e. a
    Server object which can be used to stop the service.
    """
    if loop is None:
        loop = events.get_event_loop()
    else:
        warnings.warn("The loop argument is deprecated since Python 3.8, "
                      "and scheduled for removal in Python 3.10.",
                      DeprecationWarning, stacklevel=2)

    def factory():
        # A fresh reader/protocol pair is created for every incoming
        # connection; the protocol keeps the reader alive until the
        # connection is established.
        reader = StreamReader(limit=limit, loop=loop)
        protocol = StreamReaderProtocol(reader, client_connected_cb,
                                        loop=loop)
        return protocol

    return await loop.create_server(factory, host, port, **kwds)


if hasattr(socket, 'AF_UNIX'):
    # UNIX Domain Sockets are supported on this platform

    async def open_unix_connection(path=None, *,
                                   loop=None, limit=_DEFAULT_LIMIT, **kwds):
        """Similar to `open_connection` but works with UNIX Domain Sockets."""
        if loop is None:
            loop = events.get_event_loop()
        else:
            warnings.warn("The loop argument is deprecated since Python 3.8, "
                          "and scheduled for removal in Python 3.10.",
                          DeprecationWarning, stacklevel=2)
        reader = StreamReader(limit=limit, loop=loop)
        protocol = StreamReaderProtocol(reader, loop=loop)
        transport, _ = await loop.create_unix_connection(
            lambda: protocol, path, **kwds)
        writer = StreamWriter(transport, protocol, reader, loop)
        return reader, writer

    async def start_unix_server(client_connected_cb, path=None, *,
                                loop=None, limit=_DEFAULT_LIMIT, **kwds):
        """Similar to `start_server` but works with UNIX Domain Sockets."""
        if loop is None:
            loop = events.get_event_loop()
        else:
            warnings.warn("The loop argument is deprecated since Python 3.8, "
                          "and scheduled for removal in Python 3.10.",
                          DeprecationWarning, stacklevel=2)

        def factory():
            reader = StreamReader(limit=limit, loop=loop)
            protocol = StreamReaderProtocol(reader, client_connected_cb,
                                            loop=loop)
            return protocol

        return await loop.create_unix_server(factory, path, **kwds)


class FlowControlMixin(protocols.Protocol):
    """Reusable flow control logic for StreamWriter.drain().

    This implements the protocol methods pause_writing(),
    resume_writing() and connection_lost().  If the subclass overrides
    these it must call the super methods.

    StreamWriter.drain() must wait for _drain_helper() coroutine.
    """

    def __init__(self, loop=None):
        if loop is None:
            self._loop = events.get_event_loop()
        else:
            self._loop = loop
        self._paused = False           # True while the transport asked us to stop writing
        self._drain_waiter = None      # Future a drain() caller waits on while paused
        self._connection_lost = False  # Set once connection_lost() has been called

    def pause_writing(self):
        assert not self._paused
        self._paused = True
        if self._loop.get_debug():
            logger.debug("%r pauses writing", self)

    def resume_writing(self):
        assert self._paused
        self._paused = False
        if self._loop.get_debug():
            logger.debug("%r resumes writing", self)

        # Wake up the drain() caller, if any.
        waiter = self._drain_waiter
        if waiter is not None:
            self._drain_waiter = None
            if not waiter.done():
                waiter.set_result(None)

    def connection_lost(self, exc):
        self._connection_lost = True
        # Wake up the writer if currently paused.
        if not self._paused:
            return
        waiter = self._drain_waiter
        if waiter is None:
            return
        self._drain_waiter = None
        if waiter.done():
            return
        if exc is None:
            waiter.set_result(None)
        else:
            waiter.set_exception(exc)

    async def _drain_helper(self):
        if self._connection_lost:
            raise ConnectionResetError('Connection lost')
        if not self._paused:
            return
        # Only one un-cancelled waiter may exist at a time; create a fresh
        # future and park on it until resume_writing()/connection_lost().
        waiter = self._drain_waiter
        assert waiter is None or waiter.cancelled()
        waiter = self._loop.create_future()
        self._drain_waiter = waiter
        await waiter

    def _get_close_waiter(self, stream):
        raise NotImplementedError


class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
    """Helper class to adapt between Protocol and StreamReader.

    (This is a helper class instead of making StreamReader itself a
    Protocol subclass, because the StreamReader has other potential
    uses, and to prevent the user of the StreamReader to accidentally
    call inappropriate methods of the protocol.)
    """

    _source_traceback = None

    def __init__(self, stream_reader, client_connected_cb=None, loop=None):
        super().__init__(loop=loop)
        if stream_reader is not None:
            # Hold the reader only weakly so an abandoned stream can be
            # garbage collected (and the connection rejected, see
            # connection_made()).
            self._stream_reader_wr = weakref.ref(stream_reader)
            self._source_traceback = stream_reader._source_traceback
        else:
            self._stream_reader_wr = None
        if client_connected_cb is not None:
            # This is a stream created by the `create_server()` function.
            # Keep a strong reference to the reader until a connection
            # is established.
            self._strong_reader = stream_reader
        self._reject_connection = False
        self._stream_writer = None
        self._transport = None
        self._client_connected_cb = client_connected_cb
        self._over_ssl = False
        self._closed = self._loop.create_future()

    @property
    def _stream_reader(self):
        # Dereference the weakref; returns None if the reader was never
        # set or has been garbage collected.
        if self._stream_reader_wr is None:
            return None
        return self._stream_reader_wr()

    def connection_made(self, transport):
        if self._reject_connection:
            context = {
                'message': ('An open stream was garbage collected prior to '
                            'establishing network connection; '
                            'call "stream.close()" explicitly.')
            }
            if self._source_traceback:
                context['source_traceback'] = self._source_traceback
            self._loop.call_exception_handler(context)
            transport.abort()
            return
        self._transport = transport
        reader = self._stream_reader
        if reader is not None:
            reader.set_transport(transport)
        self._over_ssl = transport.get_extra_info('sslcontext') is not None
        if self._client_connected_cb is not None:
            self._stream_writer = StreamWriter(transport, self,
                                               reader,
                                               self._loop)
            res = self._client_connected_cb(reader,
                                            self._stream_writer)
            if coroutines.iscoroutine(res):
                self._loop.create_task(res)
            # Connection established: drop the strong reference kept
            # since __init__; the weakref (plus user code) now owns it.
            self._strong_reader = None

    def connection_lost(self, exc):
        reader = self._stream_reader
        if reader is not None:
            if exc is None:
                reader.feed_eof()
            else:
                reader.set_exception(exc)
        if not self._closed.done():
            if exc is None:
                self._closed.set_result(None)
            else:
                self._closed.set_exception(exc)
        super().connection_lost(exc)
        self._stream_reader_wr = None
        self._stream_writer = None
        self._transport = None

    def data_received(self, data):
        reader = self._stream_reader
        if reader is not None:
            reader.feed_data(data)

    def eof_received(self):
        reader = self._stream_reader
        if reader is not None:
            reader.feed_eof()
        if self._over_ssl:
            # Prevent a warning in SSLProtocol.eof_received:
            # "returning true from eof_received()
            # has no effect when using ssl"
            return False
        return True

    def _get_close_waiter(self, stream):
        return self._closed

    def __del__(self):
        # Prevent reports about unhandled exceptions.
        # Better than self._closed._log_traceback = False hack
        closed = self._closed
        if closed.done() and not closed.cancelled():
            closed.exception()


class StreamWriter:
    """Wraps a Transport.

    This exposes write(), writelines(), [can_]write_eof(),
    get_extra_info() and close().  It adds drain() which returns an
    optional Future on which you can wait for flow control.  It also
    adds a transport property which references the Transport
    directly.
    """

    def __init__(self, transport, protocol, reader, loop):
        self._transport = transport
        self._protocol = protocol
        # drain() expects that the reader has an exception() method
        assert reader is None or isinstance(reader, StreamReader)
        self._reader = reader
        self._loop = loop
        self._complete_fut = self._loop.create_future()
        self._complete_fut.set_result(None)

    def __repr__(self):
        info = [self.__class__.__name__, f'transport={self._transport!r}']
        if self._reader is not None:
            info.append(f'reader={self._reader!r}')
        return '<{}>'.format(' '.join(info))

    @property
    def transport(self):
        return self._transport

    def write(self, data):
        self._transport.write(data)

    def writelines(self, data):
        self._transport.writelines(data)

    def write_eof(self):
        return self._transport.write_eof()

    def can_write_eof(self):
        return self._transport.can_write_eof()

    def close(self):
        return self._transport.close()

    def is_closing(self):
        return self._transport.is_closing()

    async def wait_closed(self):
        await self._protocol._get_close_waiter(self)

    def get_extra_info(self, name, default=None):
        return self._transport.get_extra_info(name, default)

    async def drain(self):
        """Flush the write buffer.

        The intended use is to write

          w.write(data)
          await w.drain()
        """
        if self._reader is not None:
            exc = self._reader.exception()
            if exc is not None:
                raise exc
        if self._transport.is_closing():
            # Wait for protocol.connection_lost() call
            # Raise connection closing error if any,
            # ConnectionResetError otherwise
            # Yield to the event loop so connection_lost() may be
            # called.  Without this, _drain_helper() would return
            # immediately, and code that calls
            #     write(...); await drain()
            # in a loop would never call connection_lost(), so it
            # would not see an error when the socket is closed.
            await sleep(0)
        await self._protocol._drain_helper()


class StreamReader:

    _source_traceback = None

    def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
        # The line length limit is a security feature;
        # it also doubles as half the buffer limit.

        if limit <= 0:
            raise ValueError('Limit cannot be <= 0')

        self._limit = limit
        if loop is None:
            self._loop = events.get_event_loop()
        else:
            self._loop = loop
        self._buffer = bytearray()
        self._eof = False    # Whether we're done.
        self._waiter = None  # A future used by _wait_for_data()
        self._exception = None
        self._transport = None
        self._paused = False
        if self._loop.get_debug():
            self._source_traceback = format_helpers.extract_stack(
                sys._getframe(1))

    def __repr__(self):
        info = ['StreamReader']
        if self._buffer:
            info.append(f'{len(self._buffer)} bytes')
        if self._eof:
            info.append('eof')
        if self._limit != _DEFAULT_LIMIT:
            info.append(f'limit={self._limit}')
        if self._waiter:
            info.append(f'waiter={self._waiter!r}')
        if self._exception:
            info.append(f'exception={self._exception!r}')
        if self._transport:
            info.append(f'transport={self._transport!r}')
        if self._paused:
            info.append('paused')
        return '<{}>'.format(' '.join(info))

    def exception(self):
        return self._exception

    def set_exception(self, exc):
        self._exception = exc

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_exception(exc)

    def _wakeup_waiter(self):
        """Wakeup read*() functions waiting for data or EOF."""
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_result(None)

    def set_transport(self, transport):
        assert self._transport is None, 'Transport already set'
        self._transport = transport

    def _maybe_resume_transport(self):
        if self._paused and len(self._buffer) <= self._limit:
            self._paused = False
            self._transport.resume_reading()

    def feed_eof(self):
        self._eof = True
        self._wakeup_waiter()

    def at_eof(self):
        """Return True if the buffer is empty and 'feed_eof' was called."""
        return self._eof and not self._buffer

    def feed_data(self, data):
        assert not self._eof, 'feed_data after feed_eof'

        if not data:
            return

        self._buffer.extend(data)
        self._wakeup_waiter()

        # Apply back-pressure: once the buffer exceeds twice the limit,
        # ask the transport to stop reading until the buffer drains.
        if (self._transport is not None and
                not self._paused and
                len(self._buffer) > 2 * self._limit):
            try:
                self._transport.pause_reading()
            except NotImplementedError:
                # The transport can't be paused.
                # We'll just have to buffer all data.
                # Forget the transport so we don't keep trying.
                self._transport = None
            else:
                self._paused = True

    async def _wait_for_data(self, func_name):
        """Wait until feed_data() or feed_eof() is called.

        If stream was paused, automatically resume it.
        """
        # StreamReader uses a future to link the protocol feed_data() method
        # to a read coroutine.  Running two read coroutines at the same time
        # would have an unexpected behaviour.  It would not be possible to
        # know which coroutine would get the next data.
        if self._waiter is not None:
            raise RuntimeError(
                f'{func_name}() called while another coroutine is '
                f'already waiting for incoming data')

        assert not self._eof, '_wait_for_data after EOF'

        # Waiting for data while paused will make a deadlock, so prevent it.
        # This is essential for readexactly(n) for the case when n > self._limit.
        if self._paused:
            self._paused = False
            self._transport.resume_reading()

        self._waiter = self._loop.create_future()
        try:
            await self._waiter
        finally:
            self._waiter = None

    async def readline(self):
        """Read chunk of data from the stream until newline (b'\n') is found.

        On success, return chunk that ends with newline. If only partial
        line can be read due to EOF, return incomplete line without
        terminating newline. When EOF was reached while no bytes read, empty
        bytes object is returned.

        If limit is reached, ValueError will be raised. In that case, if
        newline was found, complete line including newline will be removed
        from internal buffer. Else, internal buffer will be cleared. Limit is
        compared against part of the line without newline.

        If stream was paused, this function will automatically resume it if
        needed.
        """
        sep = b'\n'
        seplen = len(sep)
        try:
            line = await self.readuntil(sep)
        except exceptions.IncompleteReadError as e:
            return e.partial
        except exceptions.LimitOverrunError as e:
            if self._buffer.startswith(sep, e.consumed):
                del self._buffer[:e.consumed + seplen]
            else:
                self._buffer.clear()
            self._maybe_resume_transport()
            raise ValueError(e.args[0])
        return line

    async def readuntil(self, separator=b'\n'):
        """Read data from the stream until ``separator`` is found.

        On success, the data and separator will be removed from the
        internal buffer (consumed). Returned data will include the
        separator at the end.

        Configured stream limit is used to check result. Limit sets the
        maximal length of data that can be returned, not counting the
        separator.

        If an EOF occurs and the complete separator is still not found,
        an IncompleteReadError exception will be raised, and the internal
        buffer will be reset.  The IncompleteReadError.partial attribute
        may contain the separator partially.

        If the data cannot be read because of over limit, a
        LimitOverrunError exception  will be raised, and the data
        will be left in the internal buffer, so it can be read again.
        """
        seplen = len(separator)
        if seplen == 0:
            raise ValueError('Separator should be at least one-byte string')

        if self._exception is not None:
            raise self._exception

        # Consume whole buffer except last bytes, which length is
        # one less than seplen. Let's check corner cases with
        # separator='SEPARATOR':
        # * we have received almost complete separator (without last
        #   byte). i.e buffer='some textSEPARATO'. In this case we
        #   can safely consume len(separator) - 1 bytes.
        # * last byte of buffer is first byte of separator, i.e.
        #   buffer='abcdefghijklmnopqrS'. We may safely consume
        #   everything except that last byte, but this requires
        #   analyzing the bytes of the buffer that match a partial
        #   separator. This is slow and/or requires an FSM. For this
        #   case our implementation is not optimal, since it requires
        #   rescanning of data that is known to not belong to the
        #   separator. In the real world, the separator will not be so
        #   long to notice performance problems. Even when reading
        #   MIME-encoded messages :)

        # `offset` is the number of bytes from the beginning of the buffer
        # where there is no occurrence of `separator`.
        offset = 0

        # Loop until we find `separator` in the buffer, exceed the buffer size,
        # or an EOF has happened.
        while True:
            buflen = len(self._buffer)

            # Check if we now have enough data in the buffer for `separator` to
            # fit.
            if buflen - offset >= seplen:
                isep = self._buffer.find(separator, offset)

                if isep != -1:
                    # `separator` is in the buffer. `isep` will be used later
                    # to retrieve the data.
                    break

                # see upper comment for explanation.
                offset = buflen + 1 - seplen
                if offset > self._limit:
                    raise exceptions.LimitOverrunError(
                        'Separator is not found, and chunk exceed the limit',
                        offset)

            # Complete message (with full separator) may be present in buffer
            # even when EOF flag is set. This may happen when the last chunk
            # adds data which makes separator be found. That's why we check for
            # EOF *after* inspecting the buffer.
            if self._eof:
                chunk = bytes(self._buffer)
                self._buffer.clear()
                raise exceptions.IncompleteReadError(chunk, None)

            # _wait_for_data() will resume reading if stream was paused.
            await self._wait_for_data('readuntil')

        if isep > self._limit:
            raise exceptions.LimitOverrunError(
                'Separator is found, but chunk is longer than limit', isep)

        chunk = self._buffer[:isep + seplen]
        del self._buffer[:isep + seplen]
        self._maybe_resume_transport()
        return bytes(chunk)

    async def read(self, n=-1):
        """Read up to `n` bytes from the stream.

        If n is not provided, or set to -1, read until EOF and return all read
        bytes. If the EOF was received and the internal buffer is empty, return
        an empty bytes object.

        If n is zero, return empty bytes object immediately.

        If n is positive, this function try to read `n` bytes, and may return
        less or equal bytes than requested, but at least one byte. If EOF was
        received before any byte is read, this function returns empty byte
        object.

        Returned value is not limited with limit, configured at stream
        creation.

        If stream was paused, this function will automatically resume it if
        needed.
        """

        if self._exception is not None:
            raise self._exception

        if n == 0:
            return b''

        if n < 0:
            # This used to just loop creating a new waiter hoping to
            # collect everything in self._buffer, but that would
            # deadlock if the subprocess sends more than self.limit
            # bytes.  So just call self.read(self._limit) until EOF.
            blocks = []
            while True:
                block = await self.read(self._limit)
                if not block:
                    break
                blocks.append(block)
            return b''.join(blocks)

        if not self._buffer and not self._eof:
            await self._wait_for_data('read')

        # This will work right even if buffer is less than n bytes
        data = bytes(self._buffer[:n])
        del self._buffer[:n]

        self._maybe_resume_transport()
        return data

    async def readexactly(self, n):
        """Read exactly `n` bytes.

        Raise an IncompleteReadError if EOF is reached before `n` bytes can be
        read. The IncompleteReadError.partial attribute of the exception will
        contain the partial read bytes.

        if n is zero, return empty bytes object.

        Returned value is not limited with limit, configured at stream
        creation.

        If stream was paused, this function will automatically resume it if
        needed.
        """
        if n < 0:
            raise ValueError('readexactly size can not be less than zero')

        if self._exception is not None:
            raise self._exception

        if n == 0:
            return b''

        while len(self._buffer) < n:
            if self._eof:
                incomplete = bytes(self._buffer)
                self._buffer.clear()
                raise exceptions.IncompleteReadError(incomplete, n)

            await self._wait_for_data('readexactly')

        if len(self._buffer) == n:
            data = bytes(self._buffer)
            self._buffer.clear()
        else:
            data = bytes(self._buffer[:n])
            del self._buffer[:n]
        self._maybe_resume_transport()
        return data

    def __aiter__(self):
        return self

    async def __anext__(self):
        val = await self.readline()
        if val == b'':
            raise StopAsyncIteration
        return val