"""Stream-related things."""

__all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol',
           'open_connection', 'start_server',
           'IncompleteReadError',
           'LimitOverrunError',
           ]

import socket

if hasattr(socket, 'AF_UNIX'):
    __all__.extend(['open_unix_connection', 'start_unix_server'])

from . import coroutines
from . import compat
from . import events
from . import protocols
from .coroutines import coroutine
from .log import logger


# Default StreamReader buffer limit, in bytes.
_DEFAULT_LIMIT = 2 ** 16


class IncompleteReadError(EOFError):
    """
    Incomplete read error. Attributes:

    - partial: read bytes string before the end of stream was reached
    - expected: total number of expected bytes (or None if unknown)
    """
    def __init__(self, partial, expected):
        super().__init__("%d bytes read on a total of %r expected bytes"
                         % (len(partial), expected))
        self.partial = partial
        self.expected = expected

    def __reduce__(self):
        # The default reduction re-creates the exception from self.args,
        # which only holds the formatted message; that does not match the
        # two-argument constructor.  Rebuild from the real arguments so
        # the exception can be copied and pickled.
        return type(self), (self.partial, self.expected)


class LimitOverrunError(Exception):
    """Reached the buffer limit while looking for a separator.

    Attributes:
    - consumed: total number of to be consumed bytes.
    """
    def __init__(self, message, consumed):
        super().__init__(message)
        self.consumed = consumed

    def __reduce__(self):
        # self.args only holds the message; pass `consumed` explicitly so
        # that copying/unpickling calls the constructor correctly.
        return type(self), (self.args[0], self.consumed)


@coroutine
def open_connection(host=None, port=None, *,
                    loop=None, limit=_DEFAULT_LIMIT, **kwds):
    """A wrapper for create_connection() returning a (reader, writer) pair.

    The reader returned is a StreamReader instance; the writer is a
    StreamWriter instance.

    The arguments are all the usual arguments to create_connection()
    except protocol_factory; most common are positional host and port,
    with various optional keyword arguments following.

    Additional optional keyword arguments are loop (to set the event loop
    instance to use) and limit (to set the buffer limit passed to the
    StreamReader).

    (If you want to customize the StreamReader and/or
    StreamReaderProtocol classes, just copy the code -- there's
    really nothing special here except some convenience.)
    """
    if loop is None:
        loop = events.get_event_loop()
    reader = StreamReader(limit=limit, loop=loop)
    protocol = StreamReaderProtocol(reader, loop=loop)
    transport, _ = yield from loop.create_connection(
        lambda: protocol, host, port, **kwds)
    writer = StreamWriter(transport, protocol, reader, loop)
    return reader, writer


@coroutine
def start_server(client_connected_cb, host=None, port=None, *,
                 loop=None, limit=_DEFAULT_LIMIT, **kwds):
    """Start a socket server, call back for each client connected.

    The first parameter, `client_connected_cb`, takes two parameters:
    client_reader, client_writer.  client_reader is a StreamReader
    object, while client_writer is a StreamWriter object.  This
    parameter can either be a plain callback function or a coroutine;
    if it is a coroutine, it will be automatically converted into a
    Task.

    The rest of the arguments are all the usual arguments to
    loop.create_server() except protocol_factory; most common are
    positional host and port, with various optional keyword arguments
    following.

    Additional optional keyword arguments are loop (to set the event loop
    instance to use) and limit (to set the buffer limit passed to the
    StreamReader).

    The return value is the same as loop.create_server(), i.e. a
    Server object which can be used to stop the service.
    """
    if loop is None:
        loop = events.get_event_loop()

    def factory():
        # A fresh reader/protocol pair is built for every accepted client.
        reader = StreamReader(limit=limit, loop=loop)
        protocol = StreamReaderProtocol(reader, client_connected_cb,
                                        loop=loop)
        return protocol

    return (yield from loop.create_server(factory, host, port, **kwds))


if hasattr(socket, 'AF_UNIX'):
    # UNIX Domain Sockets are supported on this platform

    @coroutine
    def open_unix_connection(path=None, *,
                             loop=None, limit=_DEFAULT_LIMIT, **kwds):
        """Similar to `open_connection` but works with UNIX Domain Sockets."""
        if loop is None:
            loop = events.get_event_loop()
        reader = StreamReader(limit=limit, loop=loop)
        protocol = StreamReaderProtocol(reader, loop=loop)
        transport, _ = yield from loop.create_unix_connection(
            lambda: protocol, path, **kwds)
        writer = StreamWriter(transport, protocol, reader, loop)
        return reader, writer

    @coroutine
    def start_unix_server(client_connected_cb, path=None, *,
                          loop=None, limit=_DEFAULT_LIMIT, **kwds):
        """Similar to `start_server` but works with UNIX Domain Sockets."""
        if loop is None:
            loop = events.get_event_loop()

        def factory():
            # A fresh reader/protocol pair is built for every accepted
            # client.
            reader = StreamReader(limit=limit, loop=loop)
            protocol = StreamReaderProtocol(reader, client_connected_cb,
                                            loop=loop)
            return protocol

        return (yield from loop.create_unix_server(factory, path, **kwds))


class FlowControlMixin(protocols.Protocol):
    """Reusable flow control logic for StreamWriter.drain().

    This implements the protocol methods pause_writing(),
    resume_writing() and connection_lost().  If the subclass overrides
    these it must call the super methods.

    StreamWriter.drain() must wait for _drain_helper() coroutine.
    """

    def __init__(self, loop=None):
        """Initialise the flow-control state.

        If `loop` is None the current event loop is looked up with
        events.get_event_loop().
        """
        if loop is None:
            self._loop = events.get_event_loop()
        else:
            self._loop = loop
        self._paused = False
        # Future awaited by _drain_helper() while writing is paused.
        self._drain_waiter = None
        self._connection_lost = False

    def pause_writing(self):
        """Mark writing as paused so _drain_helper() will block."""
        assert not self._paused
        self._paused = True
        if self._loop.get_debug():
            logger.debug("%r pauses writing", self)

    def resume_writing(self):
        """Mark writing as resumed and wake up a pending drain waiter."""
        assert self._paused
        self._paused = False
        if self._loop.get_debug():
            logger.debug("%r resumes writing", self)

        waiter = self._drain_waiter
        if waiter is not None:
            self._drain_waiter = None
            if not waiter.done():
                waiter.set_result(None)

    def connection_lost(self, exc):
        """Record the lost connection and release a pending drain waiter.

        The waiter gets `exc` as its exception, or a None result when
        `exc` is None.
        """
        self._connection_lost = True
        # Wake up the writer if currently paused.
        if not self._paused:
            return
        waiter = self._drain_waiter
        if waiter is None:
            return
        self._drain_waiter = None
        if waiter.done():
            return
        if exc is None:
            waiter.set_result(None)
        else:
            waiter.set_exception(exc)

    @coroutine
    def _drain_helper(self):
        """Wait until writing is resumed.

        Raise ConnectionResetError if connection_lost() was already
        called.  Return immediately when writing is not paused.
        Otherwise wait on a new future, which resume_writing() or
        connection_lost() completes.
        """
        if self._connection_lost:
            raise ConnectionResetError('Connection lost')
        if not self._paused:
            return
        waiter = self._drain_waiter
        assert waiter is None or waiter.cancelled()
        waiter = self._loop.create_future()
        self._drain_waiter = waiter
        yield from waiter


class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
    """Helper class to adapt between Protocol and StreamReader.

    (This is a helper class instead of making StreamReader itself a
    Protocol subclass, because the StreamReader has other potential
    uses, and to prevent the user of the StreamReader to accidentally
    call inappropriate methods of the protocol.)
    """

    def __init__(self, stream_reader, client_connected_cb=None, loop=None):
        """Bind the protocol to `stream_reader`.

        `client_connected_cb`, when given, is called from
        connection_made() with the (reader, writer) pair.
        """
        super().__init__(loop=loop)
        self._stream_reader = stream_reader
        self._stream_writer = None
        self._client_connected_cb = client_connected_cb
        self._over_ssl = False

    def connection_made(self, transport):
        """Attach `transport` to the reader and notify the client callback.

        If the callback returns a coroutine object it is scheduled as a
        task on the loop.
        """
        self._stream_reader.set_transport(transport)
        self._over_ssl = transport.get_extra_info('sslcontext') is not None
        if self._client_connected_cb is not None:
            self._stream_writer = StreamWriter(transport, self,
                                               self._stream_reader,
                                               self._loop)
            res = self._client_connected_cb(self._stream_reader,
                                            self._stream_writer)
            if coroutines.iscoroutine(res):
                self._loop.create_task(res)

    def connection_lost(self, exc):
        """Propagate EOF or `exc` to the reader and drop the references."""
        if self._stream_reader is not None:
            if exc is None:
                self._stream_reader.feed_eof()
            else:
                self._stream_reader.set_exception(exc)
        super().connection_lost(exc)
        self._stream_reader = None
        self._stream_writer = None

    def data_received(self, data):
        """Forward `data` to the reader, if one is still attached."""
        # connection_lost() resets the reader to None; ignore late data
        # instead of failing with AttributeError.
        reader = self._stream_reader
        if reader is not None:
            reader.feed_data(data)

    def eof_received(self):
        """Forward EOF to the reader, if one is still attached.

        Return False when running over SSL and True otherwise.
        """
        # connection_lost() resets the reader to None; see data_received().
        reader = self._stream_reader
        if reader is not None:
            reader.feed_eof()
        if self._over_ssl:
            # Prevent a warning in SSLProtocol.eof_received:
            # "returning true from eof_received()
            # has no effect when using ssl"
            return False
        return True


class StreamWriter:
    """Wraps a Transport.

    This exposes write(), writelines(), [can_]write_eof(),
    get_extra_info() and close().  It adds drain() which returns an
    optional Future on which you can wait for flow control.  It also
    adds a transport property which references the Transport
    directly.
    """

    def __init__(self, transport, protocol, reader, loop):
        """Store the transport, protocol, reader (may be None) and loop."""
        self._transport = transport
        self._protocol = protocol
        # drain() expects that the reader has an exception() method
        assert reader is None or isinstance(reader, StreamReader)
        self._reader = reader
        self._loop = loop

    def __repr__(self):
        info = [self.__class__.__name__, 'transport=%r' % self._transport]
        if self._reader is not None:
            info.append('reader=%r' % self._reader)
        return '<%s>' % ' '.join(info)

    @property
    def transport(self):
        """The wrapped transport."""
        return self._transport

    def write(self, data):
        """Pass `data` to the transport's write()."""
        self._transport.write(data)

    def writelines(self, data):
        """Pass `data` to the transport's writelines()."""
        self._transport.writelines(data)

    def write_eof(self):
        """Call the transport's write_eof() and return its result."""
        return self._transport.write_eof()

    def can_write_eof(self):
        """Return the transport's can_write_eof() result."""
        return self._transport.can_write_eof()

    def close(self):
        """Call the transport's close() and return its result."""
        return self._transport.close()

    def get_extra_info(self, name, default=None):
        """Return the transport's extra info for `name`."""
        return self._transport.get_extra_info(name, default)

    @coroutine
    def drain(self):
        """Flush the write buffer.

        The intended use is to write

          w.write(data)
          yield from w.drain()

        If the associated reader holds an exception, it is raised here.
        """
        if self._reader is not None:
            exc = self._reader.exception()
            if exc is not None:
                raise exc
        if self._transport is not None:
            if self._transport.is_closing():
                # Yield to the event loop so connection_lost() may be
                # called.  Without this, _drain_helper() would return
                # immediately, and code that calls
                #     write(...); yield from drain()
                # in a loop would never call connection_lost(), so it
                # would not see an error when the socket is closed.
                yield
        yield from self._protocol._drain_helper()


class StreamReader:
    """Buffered byte stream fed by a protocol and consumed by coroutines.

    Data enters through feed_data() / feed_eof() (or set_exception())
    and is read back with the read(), readline(), readuntil() and
    readexactly() coroutines.
    """

    def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
        """Create an empty reader.

        `limit` must be positive, otherwise ValueError is raised.  If
        `loop` is None the current event loop is looked up with
        events.get_event_loop().
        """
        # The line length limit is a security feature;
        # it also doubles as half the buffer limit.

        if limit <= 0:
            raise ValueError('Limit cannot be <= 0')

        self._limit = limit
        if loop is None:
            self._loop = events.get_event_loop()
        else:
            self._loop = loop
        self._buffer = bytearray()
        self._eof = False    # Whether we're done.
        self._waiter = None  # A future used by _wait_for_data()
        self._exception = None
        self._transport = None
        self._paused = False

    def __repr__(self):
        info = ['StreamReader']
        if self._buffer:
            info.append('%d bytes' % len(self._buffer))
        if self._eof:
            info.append('eof')
        if self._limit != _DEFAULT_LIMIT:
            info.append('l=%d' % self._limit)
        if self._waiter:
            info.append('w=%r' % self._waiter)
        if self._exception:
            info.append('e=%r' % self._exception)
        if self._transport:
            info.append('t=%r' % self._transport)
        if self._paused:
            info.append('paused')
        return '<%s>' % ' '.join(info)

    def exception(self):
        """Return the exception stored by set_exception(), or None."""
        return self._exception

    def set_exception(self, exc):
        """Store `exc` and deliver it to a coroutine waiting for data."""
        self._exception = exc

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_exception(exc)

    def _wakeup_waiter(self):
        """Wakeup read*() functions waiting for data or EOF."""
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_result(None)

    def set_transport(self, transport):
        """Remember `transport`; it may only be set once."""
        assert self._transport is None, 'Transport already set'
        self._transport = transport

    def _maybe_resume_transport(self):
        """Resume a paused transport once the buffer is within the limit."""
        if self._paused and len(self._buffer) <= self._limit:
            self._paused = False
            self._transport.resume_reading()

    def feed_eof(self):
        """Mark the stream as finished and wake up a waiting reader."""
        self._eof = True
        self._wakeup_waiter()

    def at_eof(self):
        """Return True if the buffer is empty and 'feed_eof' was called."""
        return self._eof and not self._buffer

    def feed_data(self, data):
        """Append `data` to the buffer and wake up a waiting reader.

        Empty `data` is ignored.  When the buffer grows beyond twice the
        limit, the transport (if any) is asked to pause reading.
        """
        assert not self._eof, 'feed_data after feed_eof'

        if not data:
            return

        self._buffer.extend(data)
        self._wakeup_waiter()

        if (self._transport is not None and
                not self._paused and
                len(self._buffer) > 2 * self._limit):
            try:
                self._transport.pause_reading()
            except NotImplementedError:
                # The transport can't be paused.
                # We'll just have to buffer all data.
                # Forget the transport so we don't keep trying.
                self._transport = None
            else:
                self._paused = True

    @coroutine
    def _wait_for_data(self, func_name):
        """Wait until feed_data() or feed_eof() is called.

        If stream was paused, automatically resume it.

        `func_name` is only used in the error message raised when
        another coroutine is already waiting.
        """
        # StreamReader uses a future to link the protocol feed_data() method
        # to a read coroutine. Running two read coroutines at the same time
        # would have an unexpected behaviour. It would not be possible to
        # know which coroutine would get the next data.
        if self._waiter is not None:
            raise RuntimeError('%s() called while another coroutine is '
                               'already waiting for incoming data' % func_name)

        assert not self._eof, '_wait_for_data after EOF'

        # Waiting for data while paused will make deadlock, so prevent it.
        # This is essential for readexactly(n) for case when n > self._limit.
        if self._paused:
            self._paused = False
            self._transport.resume_reading()

        self._waiter = self._loop.create_future()
        try:
            yield from self._waiter
        finally:
            self._waiter = None

    @coroutine
    def readline(self):
        """Read chunk of data from the stream until newline (b'\n') is found.

        On success, return chunk that ends with newline. If only partial
        line can be read due to EOF, return incomplete line without
        terminating newline. When EOF was reached while no bytes read, empty
        bytes object is returned.

        If limit is reached, ValueError will be raised. In that case, if
        newline was found, complete line including newline will be removed
        from internal buffer. Else, internal buffer will be cleared. Limit is
        compared against part of the line without newline.

        If stream was paused, this function will automatically resume it if
        needed.
        """
        sep = b'\n'
        seplen = len(sep)
        try:
            line = yield from self.readuntil(sep)
        except IncompleteReadError as e:
            return e.partial
        except LimitOverrunError as e:
            # Discard the over-long line (or everything, if its end has
            # not arrived yet) so the next call starts on fresh data.
            if self._buffer.startswith(sep, e.consumed):
                del self._buffer[:e.consumed + seplen]
            else:
                self._buffer.clear()
            self._maybe_resume_transport()
            raise ValueError(e.args[0])
        return line

    @coroutine
    def readuntil(self, separator=b'\n'):
        """Read data from the stream until ``separator`` is found.

        On success, the data and separator will be removed from the
        internal buffer (consumed). Returned data will include the
        separator at the end.

        Configured stream limit is used to check result. Limit sets the
        maximal length of data that can be returned, not counting the
        separator.

        If an EOF occurs and the complete separator is still not found,
        an IncompleteReadError exception will be raised, and the internal
        buffer will be reset.  The IncompleteReadError.partial attribute
        may contain the separator partially.

        If the data cannot be read because of over limit, a
        LimitOverrunError exception  will be raised, and the data
        will be left in the internal buffer, so it can be read again.

        An empty ``separator`` raises ValueError.
        """
        seplen = len(separator)
        if seplen == 0:
            raise ValueError('Separator should be at least one-byte string')

        if self._exception is not None:
            raise self._exception

        # Consume whole buffer except last bytes, which length is
        # one less than seplen. Let's check corner cases with
        # separator='SEPARATOR':
        # * we have received almost complete separator (without last
        #   byte). i.e. buffer='some textSEPARATO'. In this case we
        #   can safely consume len(separator) - 1 bytes.
        # * last byte of buffer is first byte of separator, i.e.
        #   buffer='abcdefghijklmnopqrS'. We may safely consume
        #   everything except that last byte, but this require to
        #   analyze bytes of buffer that match partial separator.
        #   This is slow and/or require FSM. For this case our
        #   implementation is not optimal, since require rescanning
        #   of data that is known to not belong to separator. In
        #   real world, separator will not be so long to notice
        #   performance problems. Even when reading MIME-encoded
        #   messages :)

        # `offset` is the number of bytes from the beginning of the buffer
        # where there is no occurrence of `separator`.
        offset = 0

        # Loop until we find `separator` in the buffer, exceed the buffer size,
        # or an EOF has happened.
        while True:
            buflen = len(self._buffer)

            # Check if we now have enough data in the buffer for `separator` to
            # fit.
            if buflen - offset >= seplen:
                isep = self._buffer.find(separator, offset)

                if isep != -1:
                    # `separator` is in the buffer. `isep` will be used later
                    # to retrieve the data.
                    break

                # see upper comment for explanation.
                offset = buflen + 1 - seplen
                if offset > self._limit:
                    raise LimitOverrunError(
                        'Separator is not found, and chunk exceed the limit',
                        offset)

            # Complete message (with full separator) may be present in buffer
            # even when EOF flag is set. This may happen when the last chunk
            # adds data which makes separator be found. That's why we check for
            # EOF *after* inspecting the buffer.
            if self._eof:
                chunk = bytes(self._buffer)
                self._buffer.clear()
                raise IncompleteReadError(chunk, None)

            # _wait_for_data() will resume reading if stream was paused.
            yield from self._wait_for_data('readuntil')

        if isep > self._limit:
            raise LimitOverrunError(
                'Separator is found, but chunk is longer than limit', isep)

        chunk = self._buffer[:isep + seplen]
        del self._buffer[:isep + seplen]
        self._maybe_resume_transport()
        return bytes(chunk)

    @coroutine
    def read(self, n=-1):
        """Read up to `n` bytes from the stream.

        If n is not provided, or set to -1, read until EOF and return all read
        bytes. If the EOF was received and the internal buffer is empty, return
        an empty bytes object.

        If n is zero, return empty bytes object immediately.

        If n is positive, this function try to read `n` bytes, and may return
        less or equal bytes than requested, but at least one byte. If EOF was
        received before any byte is read, this function returns empty byte
        object.

        Returned value is not limited with limit, configured at stream
        creation.

        If stream was paused, this function will automatically resume it if
        needed.
        """

        if self._exception is not None:
            raise self._exception

        if n == 0:
            return b''

        if n < 0:
            # This used to just loop creating a new waiter hoping to
            # collect everything in self._buffer, but that would
            # deadlock if the subprocess sends more than self.limit
            # bytes.  So just call self.read(self._limit) until EOF.
            blocks = []
            while True:
                block = yield from self.read(self._limit)
                if not block:
                    break
                blocks.append(block)
            return b''.join(blocks)

        if not self._buffer and not self._eof:
            yield from self._wait_for_data('read')

        # This will work right even if buffer is less than n bytes
        data = bytes(self._buffer[:n])
        del self._buffer[:n]

        self._maybe_resume_transport()
        return data

    @coroutine
    def readexactly(self, n):
        """Read exactly `n` bytes.

        Raise an IncompleteReadError if EOF is reached before `n` bytes can be
        read. The IncompleteReadError.partial attribute of the exception will
        contain the partial read bytes.

        If n is zero, return empty bytes object.  If n is negative, raise
        ValueError.

        Returned value is not limited with limit, configured at stream
        creation.

        If stream was paused, this function will automatically resume it if
        needed.
        """
        if n < 0:
            raise ValueError('readexactly size can not be less than zero')

        if self._exception is not None:
            raise self._exception

        if n == 0:
            return b''

        while len(self._buffer) < n:
            if self._eof:
                incomplete = bytes(self._buffer)
                self._buffer.clear()
                raise IncompleteReadError(incomplete, n)

            yield from self._wait_for_data('readexactly')

        if len(self._buffer) == n:
            data = bytes(self._buffer)
            self._buffer.clear()
        else:
            data = bytes(self._buffer[:n])
            del self._buffer[:n]
        self._maybe_resume_transport()
        return data

    if compat.PY35:
        @coroutine
        def __aiter__(self):
            return self

        @coroutine
        def __anext__(self):
            # Iterate line by line; an empty read signals the end.
            val = yield from self.readline()
            if val == b'':
                raise StopAsyncIteration
            return val

    if compat.PY352:
        # In Python 3.5.2 and greater, __aiter__ should return
        # the asynchronous iterator directly.
        def __aiter__(self):
            return self