1__all__ = ( 2 'StreamReader', 'StreamWriter', 'StreamReaderProtocol', 3 'open_connection', 'start_server') 4 5import socket 6import sys 7import warnings 8import weakref 9 10if hasattr(socket, 'AF_UNIX'): 11 __all__ += ('open_unix_connection', 'start_unix_server') 12 13from . import coroutines 14from . import events 15from . import exceptions 16from . import format_helpers 17from . import protocols 18from .log import logger 19from .tasks import sleep 20 21 22_DEFAULT_LIMIT = 2 ** 16 # 64 KiB 23 24 25async def open_connection(host=None, port=None, *, 26 limit=_DEFAULT_LIMIT, **kwds): 27 """A wrapper for create_connection() returning a (reader, writer) pair. 28 29 The reader returned is a StreamReader instance; the writer is a 30 StreamWriter instance. 31 32 The arguments are all the usual arguments to create_connection() 33 except protocol_factory; most common are positional host and port, 34 with various optional keyword arguments following. 35 36 Additional optional keyword arguments are loop (to set the event loop 37 instance to use) and limit (to set the buffer limit passed to the 38 StreamReader). 39 40 (If you want to customize the StreamReader and/or 41 StreamReaderProtocol classes, just copy the code -- there's 42 really nothing special here except some convenience.) 43 """ 44 loop = events.get_running_loop() 45 reader = StreamReader(limit=limit, loop=loop) 46 protocol = StreamReaderProtocol(reader, loop=loop) 47 transport, _ = await loop.create_connection( 48 lambda: protocol, host, port, **kwds) 49 writer = StreamWriter(transport, protocol, reader, loop) 50 return reader, writer 51 52 53async def start_server(client_connected_cb, host=None, port=None, *, 54 limit=_DEFAULT_LIMIT, **kwds): 55 """Start a socket server, call back for each client connected. 56 57 The first parameter, `client_connected_cb`, takes two parameters: 58 client_reader, client_writer. client_reader is a StreamReader 59 object, while client_writer is a StreamWriter object. This 60 parameter can either be a plain callback function or a coroutine; 61 if it is a coroutine, it will be automatically converted into a 62 Task. 63 64 The rest of the arguments are all the usual arguments to 65 loop.create_server() except protocol_factory; most common are 66 positional host and port, with various optional keyword arguments 67 following. The return value is the same as loop.create_server(). 68 69 Additional optional keyword arguments are loop (to set the event loop 70 instance to use) and limit (to set the buffer limit passed to the 71 StreamReader). 72 73 The return value is the same as loop.create_server(), i.e. a 74 Server object which can be used to stop the service. 75 """ 76 loop = events.get_running_loop() 77 78 def factory(): 79 reader = StreamReader(limit=limit, loop=loop) 80 protocol = StreamReaderProtocol(reader, client_connected_cb, 81 loop=loop) 82 return protocol 83 84 return await loop.create_server(factory, host, port, **kwds) 85 86 87if hasattr(socket, 'AF_UNIX'): 88 # UNIX Domain Sockets are supported on this platform 89 90 async def open_unix_connection(path=None, *, 91 limit=_DEFAULT_LIMIT, **kwds): 92 """Similar to `open_connection` but works with UNIX Domain Sockets.""" 93 loop = events.get_running_loop() 94 95 reader = StreamReader(limit=limit, loop=loop) 96 protocol = StreamReaderProtocol(reader, loop=loop) 97 transport, _ = await loop.create_unix_connection( 98 lambda: protocol, path, **kwds) 99 writer = StreamWriter(transport, protocol, reader, loop) 100 return reader, writer 101 102 async def start_unix_server(client_connected_cb, path=None, *, 103 limit=_DEFAULT_LIMIT, **kwds): 104 """Similar to `start_server` but works with UNIX Domain Sockets.""" 105 loop = events.get_running_loop() 106 107 def factory(): 108 reader = StreamReader(limit=limit, loop=loop) 109 protocol = StreamReaderProtocol(reader, client_connected_cb, 110 loop=loop) 111 return protocol 112 113 return await loop.create_unix_server(factory, path, **kwds) 114 115 116class FlowControlMixin(protocols.Protocol): 117 """Reusable flow control logic for StreamWriter.drain(). 118 119 This implements the protocol methods pause_writing(), 120 resume_writing() and connection_lost(). If the subclass overrides 121 these it must call the super methods. 122 123 StreamWriter.drain() must wait for _drain_helper() coroutine. 124 """ 125 126 def __init__(self, loop=None): 127 if loop is None: 128 self._loop = events._get_event_loop(stacklevel=4) 129 else: 130 self._loop = loop 131 self._paused = False 132 self._drain_waiter = None 133 self._connection_lost = False 134 135 def pause_writing(self): 136 assert not self._paused 137 self._paused = True 138 if self._loop.get_debug(): 139 logger.debug("%r pauses writing", self) 140 141 def resume_writing(self): 142 assert self._paused 143 self._paused = False 144 if self._loop.get_debug(): 145 logger.debug("%r resumes writing", self) 146 147 waiter = self._drain_waiter 148 if waiter is not None: 149 self._drain_waiter = None 150 if not waiter.done(): 151 waiter.set_result(None) 152 153 def connection_lost(self, exc): 154 self._connection_lost = True 155 # Wake up the writer if currently paused. 156 if not self._paused: 157 return 158 waiter = self._drain_waiter 159 if waiter is None: 160 return 161 self._drain_waiter = None 162 if waiter.done(): 163 return 164 if exc is None: 165 waiter.set_result(None) 166 else: 167 waiter.set_exception(exc) 168 169 async def _drain_helper(self): 170 if self._connection_lost: 171 raise ConnectionResetError('Connection lost') 172 if not self._paused: 173 return 174 waiter = self._drain_waiter 175 assert waiter is None or waiter.cancelled() 176 waiter = self._loop.create_future() 177 self._drain_waiter = waiter 178 await waiter 179 180 def _get_close_waiter(self, stream): 181 raise NotImplementedError 182 183 184class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): 185 """Helper class to adapt between Protocol and StreamReader. 186 187 (This is a helper class instead of making StreamReader itself a 188 Protocol subclass, because the StreamReader has other potential 189 uses, and to prevent the user of the StreamReader to accidentally 190 call inappropriate methods of the protocol.) 191 """ 192 193 _source_traceback = None 194 195 def __init__(self, stream_reader, client_connected_cb=None, loop=None): 196 super().__init__(loop=loop) 197 if stream_reader is not None: 198 self._stream_reader_wr = weakref.ref(stream_reader) 199 self._source_traceback = stream_reader._source_traceback 200 else: 201 self._stream_reader_wr = None 202 if client_connected_cb is not None: 203 # This is a stream created by the `create_server()` function. 204 # Keep a strong reference to the reader until a connection 205 # is established. 206 self._strong_reader = stream_reader 207 self._reject_connection = False 208 self._stream_writer = None 209 self._transport = None 210 self._client_connected_cb = client_connected_cb 211 self._over_ssl = False 212 self._closed = self._loop.create_future() 213 214 @property 215 def _stream_reader(self): 216 if self._stream_reader_wr is None: 217 return None 218 return self._stream_reader_wr() 219 220 def connection_made(self, transport): 221 if self._reject_connection: 222 context = { 223 'message': ('An open stream was garbage collected prior to ' 224 'establishing network connection; ' 225 'call "stream.close()" explicitly.') 226 } 227 if self._source_traceback: 228 context['source_traceback'] = self._source_traceback 229 self._loop.call_exception_handler(context) 230 transport.abort() 231 return 232 self._transport = transport 233 reader = self._stream_reader 234 if reader is not None: 235 reader.set_transport(transport) 236 self._over_ssl = transport.get_extra_info('sslcontext') is not None 237 if self._client_connected_cb is not None: 238 self._stream_writer = StreamWriter(transport, self, 239 reader, 240 self._loop) 241 res = self._client_connected_cb(reader, 242 self._stream_writer) 243 if coroutines.iscoroutine(res): 244 self._loop.create_task(res) 245 self._strong_reader = None 246 247 def connection_lost(self, exc): 248 reader = self._stream_reader 249 if reader is not None: 250 if exc is None: 251 reader.feed_eof() 252 else: 253 reader.set_exception(exc) 254 if not self._closed.done(): 255 if exc is None: 256 self._closed.set_result(None) 257 else: 258 self._closed.set_exception(exc) 259 super().connection_lost(exc) 260 self._stream_reader_wr = None 261 self._stream_writer = None 262 self._transport = None 263 264 def data_received(self, data): 265 reader = self._stream_reader 266 if reader is not None: 267 reader.feed_data(data) 268 269 def eof_received(self): 270 reader = self._stream_reader 271 if reader is not None: 272 reader.feed_eof() 273 if self._over_ssl: 274 # Prevent a warning in SSLProtocol.eof_received: 275 # "returning true from eof_received() 276 # has no effect when using ssl" 277 return False 278 return True 279 280 def _get_close_waiter(self, stream): 281 return self._closed 282 283 def __del__(self): 284 # Prevent reports about unhandled exceptions. 285 # Better than self._closed._log_traceback = False hack 286 try: 287 closed = self._closed 288 except AttributeError: 289 pass # failed constructor 290 else: 291 if closed.done() and not closed.cancelled(): 292 closed.exception() 293 294 295class StreamWriter: 296 """Wraps a Transport. 297 298 This exposes write(), writelines(), [can_]write_eof(), 299 get_extra_info() and close(). It adds drain() which returns an 300 optional Future on which you can wait for flow control. It also 301 adds a transport property which references the Transport 302 directly. 303 """ 304 305 def __init__(self, transport, protocol, reader, loop): 306 self._transport = transport 307 self._protocol = protocol 308 # drain() expects that the reader has an exception() method 309 assert reader is None or isinstance(reader, StreamReader) 310 self._reader = reader 311 self._loop = loop 312 self._complete_fut = self._loop.create_future() 313 self._complete_fut.set_result(None) 314 315 def __repr__(self): 316 info = [self.__class__.__name__, f'transport={self._transport!r}'] 317 if self._reader is not None: 318 info.append(f'reader={self._reader!r}') 319 return '<{}>'.format(' '.join(info)) 320 321 @property 322 def transport(self): 323 return self._transport 324 325 def write(self, data): 326 self._transport.write(data) 327 328 def writelines(self, data): 329 self._transport.writelines(data) 330 331 def write_eof(self): 332 return self._transport.write_eof() 333 334 def can_write_eof(self): 335 return self._transport.can_write_eof() 336 337 def close(self): 338 return self._transport.close() 339 340 def is_closing(self): 341 return self._transport.is_closing() 342 343 async def wait_closed(self): 344 await self._protocol._get_close_waiter(self) 345 346 def get_extra_info(self, name, default=None): 347 return self._transport.get_extra_info(name, default) 348 349 async def drain(self): 350 """Flush the write buffer. 351 352 The intended use is to write 353 354 w.write(data) 355 await w.drain() 356 """ 357 if self._reader is not None: 358 exc = self._reader.exception() 359 if exc is not None: 360 raise exc 361 if self._transport.is_closing(): 362 # Wait for protocol.connection_lost() call 363 # Raise connection closing error if any, 364 # ConnectionResetError otherwise 365 # Yield to the event loop so connection_lost() may be 366 # called. Without this, _drain_helper() would return 367 # immediately, and code that calls 368 # write(...); await drain() 369 # in a loop would never call connection_lost(), so it 370 # would not see an error when the socket is closed. 371 await sleep(0) 372 await self._protocol._drain_helper() 373 374 375class StreamReader: 376 377 _source_traceback = None 378 379 def __init__(self, limit=_DEFAULT_LIMIT, loop=None): 380 # The line length limit is a security feature; 381 # it also doubles as half the buffer limit. 382 383 if limit <= 0: 384 raise ValueError('Limit cannot be <= 0') 385 386 self._limit = limit 387 if loop is None: 388 self._loop = events._get_event_loop() 389 else: 390 self._loop = loop 391 self._buffer = bytearray() 392 self._eof = False # Whether we're done. 393 self._waiter = None # A future used by _wait_for_data() 394 self._exception = None 395 self._transport = None 396 self._paused = False 397 if self._loop.get_debug(): 398 self._source_traceback = format_helpers.extract_stack( 399 sys._getframe(1)) 400 401 def __repr__(self): 402 info = ['StreamReader'] 403 if self._buffer: 404 info.append(f'{len(self._buffer)} bytes') 405 if self._eof: 406 info.append('eof') 407 if self._limit != _DEFAULT_LIMIT: 408 info.append(f'limit={self._limit}') 409 if self._waiter: 410 info.append(f'waiter={self._waiter!r}') 411 if self._exception: 412 info.append(f'exception={self._exception!r}') 413 if self._transport: 414 info.append(f'transport={self._transport!r}') 415 if self._paused: 416 info.append('paused') 417 return '<{}>'.format(' '.join(info)) 418 419 def exception(self): 420 return self._exception 421 422 def set_exception(self, exc): 423 self._exception = exc 424 425 waiter = self._waiter 426 if waiter is not None: 427 self._waiter = None 428 if not waiter.cancelled(): 429 waiter.set_exception(exc) 430 431 def _wakeup_waiter(self): 432 """Wakeup read*() functions waiting for data or EOF.""" 433 waiter = self._waiter 434 if waiter is not None: 435 self._waiter = None 436 if not waiter.cancelled(): 437 waiter.set_result(None) 438 439 def set_transport(self, transport): 440 assert self._transport is None, 'Transport already set' 441 self._transport = transport 442 443 def _maybe_resume_transport(self): 444 if self._paused and len(self._buffer) <= self._limit: 445 self._paused = False 446 self._transport.resume_reading() 447 448 def feed_eof(self): 449 self._eof = True 450 self._wakeup_waiter() 451 452 def at_eof(self): 453 """Return True if the buffer is empty and 'feed_eof' was called.""" 454 return self._eof and not self._buffer 455 456 def feed_data(self, data): 457 assert not self._eof, 'feed_data after feed_eof' 458 459 if not data: 460 return 461 462 self._buffer.extend(data) 463 self._wakeup_waiter() 464 465 if (self._transport is not None and 466 not self._paused and 467 len(self._buffer) > 2 * self._limit): 468 try: 469 self._transport.pause_reading() 470 except NotImplementedError: 471 # The transport can't be paused. 472 # We'll just have to buffer all data. 473 # Forget the transport so we don't keep trying. 474 self._transport = None 475 else: 476 self._paused = True 477 478 async def _wait_for_data(self, func_name): 479 """Wait until feed_data() or feed_eof() is called. 480 481 If stream was paused, automatically resume it. 482 """ 483 # StreamReader uses a future to link the protocol feed_data() method 484 # to a read coroutine. Running two read coroutines at the same time 485 # would have an unexpected behaviour. It would not possible to know 486 # which coroutine would get the next data. 487 if self._waiter is not None: 488 raise RuntimeError( 489 f'{func_name}() called while another coroutine is ' 490 f'already waiting for incoming data') 491 492 assert not self._eof, '_wait_for_data after EOF' 493 494 # Waiting for data while paused will make deadlock, so prevent it. 495 # This is essential for readexactly(n) for case when n > self._limit. 496 if self._paused: 497 self._paused = False 498 self._transport.resume_reading() 499 500 self._waiter = self._loop.create_future() 501 try: 502 await self._waiter 503 finally: 504 self._waiter = None 505 506 async def readline(self): 507 """Read chunk of data from the stream until newline (b'\n') is found. 508 509 On success, return chunk that ends with newline. If only partial 510 line can be read due to EOF, return incomplete line without 511 terminating newline. When EOF was reached while no bytes read, empty 512 bytes object is returned. 513 514 If limit is reached, ValueError will be raised. In that case, if 515 newline was found, complete line including newline will be removed 516 from internal buffer. Else, internal buffer will be cleared. Limit is 517 compared against part of the line without newline. 518 519 If stream was paused, this function will automatically resume it if 520 needed. 521 """ 522 sep = b'\n' 523 seplen = len(sep) 524 try: 525 line = await self.readuntil(sep) 526 except exceptions.IncompleteReadError as e: 527 return e.partial 528 except exceptions.LimitOverrunError as e: 529 if self._buffer.startswith(sep, e.consumed): 530 del self._buffer[:e.consumed + seplen] 531 else: 532 self._buffer.clear() 533 self._maybe_resume_transport() 534 raise ValueError(e.args[0]) 535 return line 536 537 async def readuntil(self, separator=b'\n'): 538 """Read data from the stream until ``separator`` is found. 539 540 On success, the data and separator will be removed from the 541 internal buffer (consumed). Returned data will include the 542 separator at the end. 543 544 Configured stream limit is used to check result. Limit sets the 545 maximal length of data that can be returned, not counting the 546 separator. 547 548 If an EOF occurs and the complete separator is still not found, 549 an IncompleteReadError exception will be raised, and the internal 550 buffer will be reset. The IncompleteReadError.partial attribute 551 may contain the separator partially. 552 553 If the data cannot be read because of over limit, a 554 LimitOverrunError exception will be raised, and the data 555 will be left in the internal buffer, so it can be read again. 556 """ 557 seplen = len(separator) 558 if seplen == 0: 559 raise ValueError('Separator should be at least one-byte string') 560 561 if self._exception is not None: 562 raise self._exception 563 564 # Consume whole buffer except last bytes, which length is 565 # one less than seplen. Let's check corner cases with 566 # separator='SEPARATOR': 567 # * we have received almost complete separator (without last 568 # byte). i.e buffer='some textSEPARATO'. In this case we 569 # can safely consume len(separator) - 1 bytes. 570 # * last byte of buffer is first byte of separator, i.e. 571 # buffer='abcdefghijklmnopqrS'. We may safely consume 572 # everything except that last byte, but this require to 573 # analyze bytes of buffer that match partial separator. 574 # This is slow and/or require FSM. For this case our 575 # implementation is not optimal, since require rescanning 576 # of data that is known to not belong to separator. In 577 # real world, separator will not be so long to notice 578 # performance problems. Even when reading MIME-encoded 579 # messages :) 580 581 # `offset` is the number of bytes from the beginning of the buffer 582 # where there is no occurrence of `separator`. 583 offset = 0 584 585 # Loop until we find `separator` in the buffer, exceed the buffer size, 586 # or an EOF has happened. 587 while True: 588 buflen = len(self._buffer) 589 590 # Check if we now have enough data in the buffer for `separator` to 591 # fit. 592 if buflen - offset >= seplen: 593 isep = self._buffer.find(separator, offset) 594 595 if isep != -1: 596 # `separator` is in the buffer. `isep` will be used later 597 # to retrieve the data. 598 break 599 600 # see upper comment for explanation. 601 offset = buflen + 1 - seplen 602 if offset > self._limit: 603 raise exceptions.LimitOverrunError( 604 'Separator is not found, and chunk exceed the limit', 605 offset) 606 607 # Complete message (with full separator) may be present in buffer 608 # even when EOF flag is set. This may happen when the last chunk 609 # adds data which makes separator be found. That's why we check for 610 # EOF *ater* inspecting the buffer. 611 if self._eof: 612 chunk = bytes(self._buffer) 613 self._buffer.clear() 614 raise exceptions.IncompleteReadError(chunk, None) 615 616 # _wait_for_data() will resume reading if stream was paused. 617 await self._wait_for_data('readuntil') 618 619 if isep > self._limit: 620 raise exceptions.LimitOverrunError( 621 'Separator is found, but chunk is longer than limit', isep) 622 623 chunk = self._buffer[:isep + seplen] 624 del self._buffer[:isep + seplen] 625 self._maybe_resume_transport() 626 return bytes(chunk) 627 628 async def read(self, n=-1): 629 """Read up to `n` bytes from the stream. 630 631 If n is not provided, or set to -1, read until EOF and return all read 632 bytes. If the EOF was received and the internal buffer is empty, return 633 an empty bytes object. 634 635 If n is zero, return empty bytes object immediately. 636 637 If n is positive, this function try to read `n` bytes, and may return 638 less or equal bytes than requested, but at least one byte. If EOF was 639 received before any byte is read, this function returns empty byte 640 object. 641 642 Returned value is not limited with limit, configured at stream 643 creation. 644 645 If stream was paused, this function will automatically resume it if 646 needed. 647 """ 648 649 if self._exception is not None: 650 raise self._exception 651 652 if n == 0: 653 return b'' 654 655 if n < 0: 656 # This used to just loop creating a new waiter hoping to 657 # collect everything in self._buffer, but that would 658 # deadlock if the subprocess sends more than self.limit 659 # bytes. So just call self.read(self._limit) until EOF. 660 blocks = [] 661 while True: 662 block = await self.read(self._limit) 663 if not block: 664 break 665 blocks.append(block) 666 return b''.join(blocks) 667 668 if not self._buffer and not self._eof: 669 await self._wait_for_data('read') 670 671 # This will work right even if buffer is less than n bytes 672 data = bytes(self._buffer[:n]) 673 del self._buffer[:n] 674 675 self._maybe_resume_transport() 676 return data 677 678 async def readexactly(self, n): 679 """Read exactly `n` bytes. 680 681 Raise an IncompleteReadError if EOF is reached before `n` bytes can be 682 read. The IncompleteReadError.partial attribute of the exception will 683 contain the partial read bytes. 684 685 if n is zero, return empty bytes object. 686 687 Returned value is not limited with limit, configured at stream 688 creation. 689 690 If stream was paused, this function will automatically resume it if 691 needed. 692 """ 693 if n < 0: 694 raise ValueError('readexactly size can not be less than zero') 695 696 if self._exception is not None: 697 raise self._exception 698 699 if n == 0: 700 return b'' 701 702 while len(self._buffer) < n: 703 if self._eof: 704 incomplete = bytes(self._buffer) 705 self._buffer.clear() 706 raise exceptions.IncompleteReadError(incomplete, n) 707 708 await self._wait_for_data('readexactly') 709 710 if len(self._buffer) == n: 711 data = bytes(self._buffer) 712 self._buffer.clear() 713 else: 714 data = bytes(self._buffer[:n]) 715 del self._buffer[:n] 716 self._maybe_resume_transport() 717 return data 718 719 def __aiter__(self): 720 return self 721 722 async def __anext__(self): 723 val = await self.readline() 724 if val == b'': 725 raise StopAsyncIteration 726 return val 727