__all__ = (
    'StreamReader', 'StreamWriter', 'StreamReaderProtocol',
    'open_connection', 'start_server',
    'IncompleteReadError', 'LimitOverrunError',
)

import socket

if hasattr(socket, 'AF_UNIX'):
    # The UNIX-domain helpers are only exported on platforms that have them.
    __all__ += ('open_unix_connection', 'start_unix_server')

from . import coroutines
from . import events
from . import protocols
from .log import logger
from .tasks import sleep


_DEFAULT_LIMIT = 2 ** 16  # 64 KiB


class IncompleteReadError(EOFError):
    """
    Incomplete read error. Attributes:

    - partial: read bytes string before the end of stream was reached
    - expected: total number of expected bytes (or None if unknown)
    """
    def __init__(self, partial, expected):
        super().__init__(f'{len(partial)} bytes read on a total of '
                         f'{expected!r} expected bytes')
        self.partial = partial
        self.expected = expected

    def __reduce__(self):
        # Support pickling: rebuild from the same (partial, expected) pair.
        return type(self), (self.partial, self.expected)


class LimitOverrunError(Exception):
    """Reached the buffer limit while looking for a separator.

    Attributes:
    - consumed: total number of bytes that should be consumed.
    """
    def __init__(self, message, consumed):
        super().__init__(message)
        self.consumed = consumed

    def __reduce__(self):
        # Support pickling: rebuild from the original message and count.
        return type(self), (self.args[0], self.consumed)


async def open_connection(host=None, port=None, *,
                          loop=None, limit=_DEFAULT_LIMIT, **kwds):
    """A wrapper for create_connection() returning a (reader, writer) pair.

    The reader returned is a StreamReader instance; the writer is a
    StreamWriter instance.

    The arguments are all the usual arguments to create_connection()
    except protocol_factory; most common are positional host and port,
    with various optional keyword arguments following.

    Additional optional keyword arguments are loop (to set the event loop
    instance to use) and limit (to set the buffer limit passed to the
    StreamReader).

    (If you want to customize the StreamReader and/or
    StreamReaderProtocol classes, just copy the code -- there's
    really nothing special here except some convenience.)
    """
    if loop is None:
        loop = events.get_event_loop()
    reader = StreamReader(limit=limit, loop=loop)
    protocol = StreamReaderProtocol(reader, loop=loop)
    transport, _ = await loop.create_connection(
        lambda: protocol, host, port, **kwds)
    writer = StreamWriter(transport, protocol, reader, loop)
    return reader, writer


async def start_server(client_connected_cb, host=None, port=None, *,
                       loop=None, limit=_DEFAULT_LIMIT, **kwds):
    """Start a socket server, call back for each client connected.

    The first parameter, `client_connected_cb`, takes two parameters:
    client_reader, client_writer.  client_reader is a StreamReader
    object, while client_writer is a StreamWriter object.  This
    parameter can either be a plain callback function or a coroutine;
    if it is a coroutine, it will be automatically converted into a
    Task.

    The rest of the arguments are all the usual arguments to
    loop.create_server() except protocol_factory; most common are
    positional host and port, with various optional keyword arguments
    following.  The return value is the same as loop.create_server().

    Additional optional keyword arguments are loop (to set the event loop
    instance to use) and limit (to set the buffer limit passed to the
    StreamReader).

    The return value is the same as loop.create_server(), i.e. a
    Server object which can be used to stop the service.
    """
    if loop is None:
        loop = events.get_event_loop()

    def factory():
        # A fresh reader/protocol pair is built per incoming connection.
        reader = StreamReader(limit=limit, loop=loop)
        protocol = StreamReaderProtocol(reader, client_connected_cb,
                                        loop=loop)
        return protocol

    return await loop.create_server(factory, host, port, **kwds)


if hasattr(socket, 'AF_UNIX'):
    # UNIX Domain Sockets are supported on this platform

    async def open_unix_connection(path=None, *,
                                   loop=None, limit=_DEFAULT_LIMIT, **kwds):
        """Similar to `open_connection` but works with UNIX Domain Sockets."""
        if loop is None:
            loop = events.get_event_loop()
        reader = StreamReader(limit=limit, loop=loop)
        protocol = StreamReaderProtocol(reader, loop=loop)
        transport, _ = await loop.create_unix_connection(
            lambda: protocol, path, **kwds)
        writer = StreamWriter(transport, protocol, reader, loop)
        return reader, writer

    async def start_unix_server(client_connected_cb, path=None, *,
                                loop=None, limit=_DEFAULT_LIMIT, **kwds):
        """Similar to `start_server` but works with UNIX Domain Sockets."""
        if loop is None:
            loop = events.get_event_loop()

        def factory():
            reader = StreamReader(limit=limit, loop=loop)
            protocol = StreamReaderProtocol(reader, client_connected_cb,
                                            loop=loop)
            return protocol

        return await loop.create_unix_server(factory, path, **kwds)


class FlowControlMixin(protocols.Protocol):
    """Reusable flow control logic for StreamWriter.drain().

    This implements the protocol methods pause_writing(),
    resume_writing() and connection_lost().  If the subclass overrides
    these it must call the super methods.

    StreamWriter.drain() must wait for _drain_helper() coroutine.
    """

    def __init__(self, loop=None):
        if loop is None:
            self._loop = events.get_event_loop()
        else:
            self._loop = loop
        self._paused = False          # True between pause_writing()/resume_writing()
        self._drain_waiter = None     # Future a drain() call is blocked on, if any
        self._connection_lost = False

    def pause_writing(self):
        assert not self._paused
        self._paused = True
        if self._loop.get_debug():
            logger.debug("%r pauses writing", self)

    def resume_writing(self):
        assert self._paused
        self._paused = False
        if self._loop.get_debug():
            logger.debug("%r resumes writing", self)

        # Wake up the coroutine waiting in _drain_helper(), if any.
        waiter = self._drain_waiter
        if waiter is not None:
            self._drain_waiter = None
            if not waiter.done():
                waiter.set_result(None)

    def connection_lost(self, exc):
        self._connection_lost = True
        # Wake up the writer if currently paused.
        if not self._paused:
            return
        waiter = self._drain_waiter
        if waiter is None:
            return
        self._drain_waiter = None
        if waiter.done():
            return
        if exc is None:
            waiter.set_result(None)
        else:
            waiter.set_exception(exc)

    async def _drain_helper(self):
        if self._connection_lost:
            raise ConnectionResetError('Connection lost')
        if not self._paused:
            return
        waiter = self._drain_waiter
        # Any previous waiter must already be cancelled; there is only
        # one drain() waiter slot at a time.
        assert waiter is None or waiter.cancelled()
        waiter = self._loop.create_future()
        self._drain_waiter = waiter
        await waiter


class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
    """Helper class to adapt between Protocol and StreamReader.

    (This is a helper class instead of making StreamReader itself a
    Protocol subclass, because the StreamReader has other potential
    uses, and to prevent the user of the StreamReader to accidentally
    call inappropriate methods of the protocol.)
    """

    def __init__(self, stream_reader, client_connected_cb=None, loop=None):
        super().__init__(loop=loop)
        self._stream_reader = stream_reader
        self._stream_writer = None
        self._client_connected_cb = client_connected_cb
        self._over_ssl = False
        self._closed = self._loop.create_future()

    def connection_made(self, transport):
        self._stream_reader.set_transport(transport)
        self._over_ssl = transport.get_extra_info('sslcontext') is not None
        if self._client_connected_cb is not None:
            # Server mode: build the writer and invoke the user callback.
            self._stream_writer = StreamWriter(transport, self,
                                               self._stream_reader,
                                               self._loop)
            res = self._client_connected_cb(self._stream_reader,
                                            self._stream_writer)
            if coroutines.iscoroutine(res):
                # A coroutine callback is scheduled as a Task.
                self._loop.create_task(res)

    def connection_lost(self, exc):
        if self._stream_reader is not None:
            if exc is None:
                self._stream_reader.feed_eof()
            else:
                self._stream_reader.set_exception(exc)
        if not self._closed.done():
            if exc is None:
                self._closed.set_result(None)
            else:
                self._closed.set_exception(exc)
        super().connection_lost(exc)
        # Drop references so the reader/writer can be garbage-collected.
        self._stream_reader = None
        self._stream_writer = None

    def data_received(self, data):
        self._stream_reader.feed_data(data)

    def eof_received(self):
        self._stream_reader.feed_eof()
        if self._over_ssl:
            # Prevent a warning in SSLProtocol.eof_received:
            # "returning true from eof_received()
            # has no effect when using ssl"
            return False
        return True

    def __del__(self):
        # Prevent reports about unhandled exceptions.
        # Better than self._closed._log_traceback = False hack
        closed = self._closed
        if closed.done() and not closed.cancelled():
            closed.exception()


class StreamWriter:
    """Wraps a Transport.

    This exposes write(), writelines(), [can_]write_eof(),
    get_extra_info() and close().  It adds drain() which returns an
    optional Future on which you can wait for flow control.  It also
    adds a transport property which references the Transport
    directly.
    """

    def __init__(self, transport, protocol, reader, loop):
        self._transport = transport
        self._protocol = protocol
        # drain() expects that the reader has an exception() method
        assert reader is None or isinstance(reader, StreamReader)
        self._reader = reader
        self._loop = loop

    def __repr__(self):
        info = [self.__class__.__name__, f'transport={self._transport!r}']
        if self._reader is not None:
            info.append(f'reader={self._reader!r}')
        return '<{}>'.format(' '.join(info))

    @property
    def transport(self):
        return self._transport

    def write(self, data):
        self._transport.write(data)

    def writelines(self, data):
        self._transport.writelines(data)

    def write_eof(self):
        return self._transport.write_eof()

    def can_write_eof(self):
        return self._transport.can_write_eof()

    def close(self):
        return self._transport.close()

    def is_closing(self):
        return self._transport.is_closing()

    async def wait_closed(self):
        await self._protocol._closed

    def get_extra_info(self, name, default=None):
        return self._transport.get_extra_info(name, default)

    async def drain(self):
        """Flush the write buffer.

        The intended use is to write

          w.write(data)
          await w.drain()
        """
        if self._reader is not None:
            # Propagate a connection error recorded on the reader side.
            exc = self._reader.exception()
            if exc is not None:
                raise exc
        if self._transport.is_closing():
            # Yield to the event loop so connection_lost() may be
            # called.  Without this, _drain_helper() would return
            # immediately, and code that calls
            #     write(...); await drain()
            # in a loop would never call connection_lost(), so it
            # would not see an error when the socket is closed.
            await sleep(0, loop=self._loop)
        await self._protocol._drain_helper()


class StreamReader:
    """Buffered reader fed by a protocol via feed_data()/feed_eof().

    Provides the read(), readline(), readuntil() and readexactly()
    coroutines, applies the buffer limit, and pauses/resumes the
    transport for flow control.
    """

    def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
        # The line length limit is a security feature;
        # it also doubles as half the buffer limit.

        if limit <= 0:
            raise ValueError('Limit cannot be <= 0')

        self._limit = limit
        if loop is None:
            self._loop = events.get_event_loop()
        else:
            self._loop = loop
        self._buffer = bytearray()
        self._eof = False    # Whether we're done.
        self._waiter = None  # A future used by _wait_for_data()
        self._exception = None
        self._transport = None
        self._paused = False  # True while the transport's reading is paused

    def __repr__(self):
        info = ['StreamReader']
        if self._buffer:
            info.append(f'{len(self._buffer)} bytes')
        if self._eof:
            info.append('eof')
        if self._limit != _DEFAULT_LIMIT:
            info.append(f'limit={self._limit}')
        if self._waiter:
            info.append(f'waiter={self._waiter!r}')
        if self._exception:
            info.append(f'exception={self._exception!r}')
        if self._transport:
            info.append(f'transport={self._transport!r}')
        if self._paused:
            info.append('paused')
        return '<{}>'.format(' '.join(info))

    def exception(self):
        return self._exception

    def set_exception(self, exc):
        self._exception = exc

        # Deliver the exception to any coroutine blocked in a read.
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_exception(exc)

    def _wakeup_waiter(self):
        """Wakeup read*() functions waiting for data or EOF."""
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_result(None)

    def set_transport(self, transport):
        assert self._transport is None, 'Transport already set'
        self._transport = transport

    def _maybe_resume_transport(self):
        # Resume reading once the buffer has drained below the limit.
        if self._paused and len(self._buffer) <= self._limit:
            self._paused = False
            self._transport.resume_reading()

    def feed_eof(self):
        self._eof = True
        self._wakeup_waiter()

    def at_eof(self):
        """Return True if the buffer is empty and 'feed_eof' was called."""
        return self._eof and not self._buffer

    def feed_data(self, data):
        assert not self._eof, 'feed_data after feed_eof'

        if not data:
            return

        self._buffer.extend(data)
        self._wakeup_waiter()

        # Apply back-pressure when the buffer grows past twice the limit.
        if (self._transport is not None and
                not self._paused and
                len(self._buffer) > 2 * self._limit):
            try:
                self._transport.pause_reading()
            except NotImplementedError:
                # The transport can't be paused.
                # We'll just have to buffer all data.
                # Forget the transport so we don't keep trying.
                self._transport = None
            else:
                self._paused = True

    async def _wait_for_data(self, func_name):
        """Wait until feed_data() or feed_eof() is called.

        If stream was paused, automatically resume it.
        """
        # StreamReader uses a future to link the protocol feed_data() method
        # to a read coroutine.  Running two read coroutines at the same time
        # would have an unexpected behaviour.  It would not be possible to
        # know which coroutine would get the next data.
        if self._waiter is not None:
            raise RuntimeError(
                f'{func_name}() called while another coroutine is '
                f'already waiting for incoming data')

        assert not self._eof, '_wait_for_data after EOF'

        # Waiting for data while paused would cause a deadlock, so prevent it.
        # This is essential for readexactly(n) for case when n > self._limit.
        if self._paused:
            self._paused = False
            self._transport.resume_reading()

        self._waiter = self._loop.create_future()
        try:
            await self._waiter
        finally:
            self._waiter = None

    async def readline(self):
        """Read chunk of data from the stream until newline (b'\n') is found.

        On success, return chunk that ends with newline. If only partial
        line can be read due to EOF, return incomplete line without
        terminating newline. When EOF was reached while no bytes read, empty
        bytes object is returned.

        If limit is reached, ValueError will be raised. In that case, if
        newline was found, complete line including newline will be removed
        from internal buffer. Else, internal buffer will be cleared. Limit is
        compared against part of the line without newline.

        If stream was paused, this function will automatically resume it if
        needed.
        """
        sep = b'\n'
        seplen = len(sep)
        try:
            line = await self.readuntil(sep)
        except IncompleteReadError as e:
            # EOF before a newline: return whatever was read.
            return e.partial
        except LimitOverrunError as e:
            # Over-limit line: discard it (including the newline, if found),
            # then surface the overrun as a ValueError.
            if self._buffer.startswith(sep, e.consumed):
                del self._buffer[:e.consumed + seplen]
            else:
                self._buffer.clear()
            self._maybe_resume_transport()
            raise ValueError(e.args[0])
        return line

    async def readuntil(self, separator=b'\n'):
        """Read data from the stream until ``separator`` is found.

        On success, the data and separator will be removed from the
        internal buffer (consumed). Returned data will include the
        separator at the end.

        Configured stream limit is used to check result. Limit sets the
        maximal length of data that can be returned, not counting the
        separator.

        If an EOF occurs and the complete separator is still not found,
        an IncompleteReadError exception will be raised, and the internal
        buffer will be reset.  The IncompleteReadError.partial attribute
        may contain the separator partially.

        If the data cannot be read because of over limit, a
        LimitOverrunError exception  will be raised, and the data
        will be left in the internal buffer, so it can be read again.
        """
        seplen = len(separator)
        if seplen == 0:
            raise ValueError('Separator should be at least one-byte string')

        if self._exception is not None:
            raise self._exception

        # Consume whole buffer except last bytes, which length is
        # one less than seplen. Let's check corner cases with
        # separator='SEPARATOR':
        # * we have received almost complete separator (without last
        #   byte). i.e buffer='some textSEPARATO'. In this case we
        #   can safely consume len(separator) - 1 bytes.
        # * last byte of buffer is first byte of separator, i.e.
        #   buffer='abcdefghijklmnopqrS'. We may safely consume
        #   everything except that last byte, but this would require
        #   analyzing bytes of buffer that match a partial separator.
        #   This is slow and/or requires an FSM. For this case our
        #   implementation is not optimal, since it requires rescanning
        #   of data that is known to not belong to separator. In
        #   real world, separator will not be so long to notice
        #   performance problems. Even when reading MIME-encoded
        #   messages :)

        # `offset` is the number of bytes from the beginning of the buffer
        # where there is no occurrence of `separator`.
        offset = 0

        # Loop until we find `separator` in the buffer, exceed the buffer size,
        # or an EOF has happened.
        while True:
            buflen = len(self._buffer)

            # Check if we now have enough data in the buffer for `separator` to
            # fit.
            if buflen - offset >= seplen:
                isep = self._buffer.find(separator, offset)

                if isep != -1:
                    # `separator` is in the buffer. `isep` will be used later
                    # to retrieve the data.
                    break

                # see upper comment for explanation.
                offset = buflen + 1 - seplen
                if offset > self._limit:
                    raise LimitOverrunError(
                        'Separator is not found, and chunk exceed the limit',
                        offset)

            # Complete message (with full separator) may be present in buffer
            # even when EOF flag is set. This may happen when the last chunk
            # adds data which makes separator be found. That's why we check for
            # EOF *after* inspecting the buffer.
            if self._eof:
                chunk = bytes(self._buffer)
                self._buffer.clear()
                raise IncompleteReadError(chunk, None)

            # _wait_for_data() will resume reading if stream was paused.
            await self._wait_for_data('readuntil')

        if isep > self._limit:
            raise LimitOverrunError(
                'Separator is found, but chunk is longer than limit', isep)

        chunk = self._buffer[:isep + seplen]
        del self._buffer[:isep + seplen]
        self._maybe_resume_transport()
        return bytes(chunk)

    async def read(self, n=-1):
        """Read up to `n` bytes from the stream.

        If n is not provided, or set to -1, read until EOF and return all read
        bytes. If the EOF was received and the internal buffer is empty, return
        an empty bytes object.

        If n is zero, return empty bytes object immediately.

        If n is positive, this function try to read `n` bytes, and may return
        less or equal bytes than requested, but at least one byte. If EOF was
        received before any byte is read, this function returns empty byte
        object.

        Returned value is not limited with limit, configured at stream
        creation.

        If stream was paused, this function will automatically resume it if
        needed.
        """

        if self._exception is not None:
            raise self._exception

        if n == 0:
            return b''

        if n < 0:
            # This used to just loop creating a new waiter hoping to
            # collect everything in self._buffer, but that would
            # deadlock if the subprocess sends more than self.limit
            # bytes.  So just call self.read(self._limit) until EOF.
            blocks = []
            while True:
                block = await self.read(self._limit)
                if not block:
                    break
                blocks.append(block)
            return b''.join(blocks)

        if not self._buffer and not self._eof:
            await self._wait_for_data('read')

        # This will work right even if buffer is less than n bytes
        data = bytes(self._buffer[:n])
        del self._buffer[:n]

        self._maybe_resume_transport()
        return data

    async def readexactly(self, n):
        """Read exactly `n` bytes.

        Raise an IncompleteReadError if EOF is reached before `n` bytes can be
        read. The IncompleteReadError.partial attribute of the exception will
        contain the partial read bytes.

        if n is zero, return empty bytes object.

        Returned value is not limited with limit, configured at stream
        creation.

        If stream was paused, this function will automatically resume it if
        needed.
        """
        if n < 0:
            raise ValueError('readexactly size can not be less than zero')

        if self._exception is not None:
            raise self._exception

        if n == 0:
            return b''

        while len(self._buffer) < n:
            if self._eof:
                incomplete = bytes(self._buffer)
                self._buffer.clear()
                raise IncompleteReadError(incomplete, n)

            await self._wait_for_data('readexactly')

        if len(self._buffer) == n:
            # Fast path: the buffer holds exactly n bytes.
            data = bytes(self._buffer)
            self._buffer.clear()
        else:
            data = bytes(self._buffer[:n])
            del self._buffer[:n]
        self._maybe_resume_transport()
        return data

    def __aiter__(self):
        return self

    async def __anext__(self):
        # Iterating a StreamReader yields lines until EOF.
        val = await self.readline()
        if val == b'':
            raise StopAsyncIteration
        return val