• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1__all__ = (
2    'StreamReader', 'StreamWriter', 'StreamReaderProtocol',
3    'open_connection', 'start_server')
4
5import socket
6import sys
7import warnings
8import weakref
9
10if hasattr(socket, 'AF_UNIX'):
11    __all__ += ('open_unix_connection', 'start_unix_server')
12
13from . import coroutines
14from . import events
15from . import exceptions
16from . import format_helpers
17from . import protocols
18from .log import logger
19from .tasks import sleep
20
21
22_DEFAULT_LIMIT = 2 ** 16  # 64 KiB
23
24
async def open_connection(host=None, port=None, *,
                          limit=_DEFAULT_LIMIT, **kwds):
    """Connect to (host, port) and return a (reader, writer) stream pair.

    This is a convenience wrapper around loop.create_connection(); all
    keyword arguments other than *limit* are forwarded to it.  The
    returned reader is a StreamReader created with the given buffer
    *limit*; the writer is a StreamWriter wrapping the new transport.

    (If you want to customize the StreamReader and/or
    StreamReaderProtocol classes, just copy the code -- there's
    really nothing special here except some convenience.)
    """
    loop = events.get_running_loop()
    stream_reader = StreamReader(limit=limit, loop=loop)
    stream_protocol = StreamReaderProtocol(stream_reader, loop=loop)
    transport, _ = await loop.create_connection(
        lambda: stream_protocol, host, port, **kwds)
    stream_writer = StreamWriter(transport, stream_protocol,
                                 stream_reader, loop)
    return stream_reader, stream_writer
51
52
async def start_server(client_connected_cb, host=None, port=None, *,
                       limit=_DEFAULT_LIMIT, **kwds):
    """Start a socket server and invoke a callback for every client.

    `client_connected_cb` is called with two arguments, client_reader
    (a StreamReader) and client_writer (a StreamWriter), once per
    accepted connection.  It may be a plain callable or a coroutine
    function; a coroutine result is automatically scheduled as a Task.

    All remaining arguments are forwarded to loop.create_server()
    (except protocol_factory); *limit* sets the buffer limit passed to
    each per-connection StreamReader.

    The return value is the same as loop.create_server(), i.e. a
    Server object which can be used to stop the service.
    """
    loop = events.get_running_loop()

    def protocol_factory():
        # A fresh reader/protocol pair is created for every connection.
        stream_reader = StreamReader(limit=limit, loop=loop)
        return StreamReaderProtocol(stream_reader, client_connected_cb,
                                    loop=loop)

    return await loop.create_server(protocol_factory, host, port, **kwds)
85
86
if hasattr(socket, 'AF_UNIX'):
    # UNIX Domain Sockets are supported on this platform

    async def open_unix_connection(path=None, *,
                                   limit=_DEFAULT_LIMIT, **kwds):
        """Similar to `open_connection` but works with UNIX Domain Sockets."""
        loop = events.get_running_loop()
        stream_reader = StreamReader(limit=limit, loop=loop)
        stream_protocol = StreamReaderProtocol(stream_reader, loop=loop)
        transport, _ = await loop.create_unix_connection(
            lambda: stream_protocol, path, **kwds)
        stream_writer = StreamWriter(transport, stream_protocol,
                                     stream_reader, loop)
        return stream_reader, stream_writer

    async def start_unix_server(client_connected_cb, path=None, *,
                                limit=_DEFAULT_LIMIT, **kwds):
        """Similar to `start_server` but works with UNIX Domain Sockets."""
        loop = events.get_running_loop()

        def protocol_factory():
            # A fresh reader/protocol pair is created for every connection.
            stream_reader = StreamReader(limit=limit, loop=loop)
            return StreamReaderProtocol(stream_reader, client_connected_cb,
                                        loop=loop)

        return await loop.create_unix_server(protocol_factory, path, **kwds)
114
115
class FlowControlMixin(protocols.Protocol):
    """Reusable flow control logic for StreamWriter.drain().

    This implements the protocol methods pause_writing(),
    resume_writing() and connection_lost().  If the subclass overrides
    these it must call the super methods.

    StreamWriter.drain() must wait for _drain_helper() coroutine.
    """

    def __init__(self, loop=None):
        self._loop = (events._get_event_loop(stacklevel=4)
                      if loop is None else loop)
        self._paused = False
        self._drain_waiter = None
        self._connection_lost = False

    def pause_writing(self):
        # Transport signalled that its write buffer crossed the
        # high-water mark.
        assert not self._paused
        self._paused = True
        if self._loop.get_debug():
            logger.debug("%r pauses writing", self)

    def resume_writing(self):
        assert self._paused
        self._paused = False
        if self._loop.get_debug():
            logger.debug("%r resumes writing", self)

        # Release whoever is blocked in _drain_helper(), if anyone.
        waiter, self._drain_waiter = self._drain_waiter, None
        if waiter is not None and not waiter.done():
            waiter.set_result(None)

    def connection_lost(self, exc):
        self._connection_lost = True
        # Wake up the writer if currently paused.
        if not self._paused:
            return
        waiter = self._drain_waiter
        if waiter is None:
            return
        self._drain_waiter = None
        if not waiter.done():
            if exc is None:
                waiter.set_result(None)
            else:
                waiter.set_exception(exc)

    async def _drain_helper(self):
        if self._connection_lost:
            raise ConnectionResetError('Connection lost')
        if not self._paused:
            return
        # At most one live drain waiter may exist at any time.
        assert self._drain_waiter is None or self._drain_waiter.cancelled()
        self._drain_waiter = self._loop.create_future()
        await self._drain_waiter

    def _get_close_waiter(self, stream):
        # Subclasses must provide the future awaited by wait_closed().
        raise NotImplementedError
182
183
class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
    """Helper class to adapt between Protocol and StreamReader.

    (This is a helper class instead of making StreamReader itself a
    Protocol subclass, because the StreamReader has other potential
    uses, and to prevent the user of the StreamReader to accidentally
    call inappropriate methods of the protocol.)
    """

    # Set in __init__ when the reader was created in debug mode.
    _source_traceback = None

    def __init__(self, stream_reader, client_connected_cb=None, loop=None):
        super().__init__(loop=loop)
        if stream_reader is not None:
            # Hold the reader only weakly, so that dropping all user
            # references to the stream allows it to be garbage collected.
            self._stream_reader_wr = weakref.ref(stream_reader)
            self._source_traceback = stream_reader._source_traceback
        else:
            self._stream_reader_wr = None
        if client_connected_cb is not None:
            # This is a stream created by the `create_server()` function.
            # Keep a strong reference to the reader until a connection
            # is established.
            self._strong_reader = stream_reader
        # NOTE(review): nothing in this chunk sets _reject_connection to
        # True; presumably it is flagged elsewhere when the stream is
        # collected before a connection is made -- confirm in full file.
        self._reject_connection = False
        self._stream_writer = None
        self._transport = None
        self._client_connected_cb = client_connected_cb
        self._over_ssl = False
        # Resolved by connection_lost(); awaited by
        # StreamWriter.wait_closed() via _get_close_waiter().
        self._closed = self._loop.create_future()

    @property
    def _stream_reader(self):
        # Dereference the weakref; returns None if no reader was set or
        # if the reader has already been garbage collected.
        if self._stream_reader_wr is None:
            return None
        return self._stream_reader_wr()

    def connection_made(self, transport):
        if self._reject_connection:
            # The stream was abandoned before the connection completed:
            # report it to the exception handler and drop the transport.
            context = {
                'message': ('An open stream was garbage collected prior to '
                            'establishing network connection; '
                            'call "stream.close()" explicitly.')
            }
            if self._source_traceback:
                context['source_traceback'] = self._source_traceback
            self._loop.call_exception_handler(context)
            transport.abort()
            return
        self._transport = transport
        reader = self._stream_reader
        if reader is not None:
            reader.set_transport(transport)
        self._over_ssl = transport.get_extra_info('sslcontext') is not None
        if self._client_connected_cb is not None:
            # Server-side stream: build the writer and hand both ends to
            # the user callback, scheduling it as a Task if it returned
            # a coroutine.
            self._stream_writer = StreamWriter(transport, self,
                                               reader,
                                               self._loop)
            res = self._client_connected_cb(reader,
                                            self._stream_writer)
            if coroutines.iscoroutine(res):
                self._loop.create_task(res)
            # Connection established: drop the strong reference taken in
            # __init__; the weakref alone tracks the reader from now on.
            self._strong_reader = None

    def connection_lost(self, exc):
        reader = self._stream_reader
        if reader is not None:
            # Propagate the outcome to pending read*() coroutines.
            if exc is None:
                reader.feed_eof()
            else:
                reader.set_exception(exc)
        if not self._closed.done():
            # Resolve the wait_closed() future with the same outcome.
            if exc is None:
                self._closed.set_result(None)
            else:
                self._closed.set_exception(exc)
        super().connection_lost(exc)
        # Break reference cycles so transport/reader/writer can be freed.
        self._stream_reader_wr = None
        self._stream_writer = None
        self._transport = None

    def data_received(self, data):
        # Forward incoming bytes to the reader's buffer (if it is alive).
        reader = self._stream_reader
        if reader is not None:
            reader.feed_data(data)

    def eof_received(self):
        reader = self._stream_reader
        if reader is not None:
            reader.feed_eof()
        if self._over_ssl:
            # Prevent a warning in SSLProtocol.eof_received:
            # "returning true from eof_received()
            # has no effect when using ssl"
            return False
        # True keeps the transport open for writing after remote EOF.
        return True

    def _get_close_waiter(self, stream):
        return self._closed

    def __del__(self):
        # Prevent reports about unhandled exceptions.
        # Better than self._closed._log_traceback = False hack
        try:
            closed = self._closed
        except AttributeError:
            pass  # failed constructor
        else:
            if closed.done() and not closed.cancelled():
                closed.exception()
293
294
class StreamWriter:
    """Wraps a Transport.

    This exposes write(), writelines(), [can_]write_eof(),
    get_extra_info() and close().  It adds drain() which returns an
    optional Future on which you can wait for flow control.  It also
    adds a transport property which references the Transport
    directly.
    """

    def __init__(self, transport, protocol, reader, loop):
        self._transport = transport
        self._protocol = protocol
        # drain() expects that the reader has an exception() method
        assert reader is None or isinstance(reader, StreamReader)
        self._reader = reader
        self._loop = loop
        self._complete_fut = self._loop.create_future()
        self._complete_fut.set_result(None)

    def __repr__(self):
        parts = [self.__class__.__name__, f'transport={self._transport!r}']
        if self._reader is not None:
            parts.append(f'reader={self._reader!r}')
        joined = ' '.join(parts)
        return f'<{joined}>'

    @property
    def transport(self):
        return self._transport

    def write(self, data):
        # Delegated to the transport; may buffer and return immediately.
        self._transport.write(data)

    def writelines(self, data):
        self._transport.writelines(data)

    def write_eof(self):
        return self._transport.write_eof()

    def can_write_eof(self):
        return self._transport.can_write_eof()

    def close(self):
        return self._transport.close()

    def is_closing(self):
        return self._transport.is_closing()

    async def wait_closed(self):
        # Block until the protocol reports connection_lost().
        await self._protocol._get_close_waiter(self)

    def get_extra_info(self, name, default=None):
        return self._transport.get_extra_info(name, default)

    async def drain(self):
        """Flush the write buffer.

        The intended use is to write

          w.write(data)
          await w.drain()
        """
        reader = self._reader
        if reader is not None:
            exc = reader.exception()
            if exc is not None:
                raise exc
        if self._transport.is_closing():
            # Wait for protocol.connection_lost() call
            # Raise connection closing error if any,
            # ConnectionResetError otherwise
            # Yield to the event loop once so connection_lost() may be
            # called.  Without this, _drain_helper() would return
            # immediately, and code that calls
            #     write(...); await drain()
            # in a loop would never call connection_lost(), so it
            # would not see an error when the socket is closed.
            await sleep(0)
        await self._protocol._drain_helper()
373
374
class StreamReader:
    """Buffered byte-stream reader fed by feed_data()/feed_eof().

    read(), readline(), readuntil() and readexactly() consume from an
    internal bytearray and wait on a future when more data is needed.
    Flow control: when the buffer exceeds twice the configured limit the
    transport is paused, and it is resumed once the buffer drains back
    to the limit or below.
    """

    # Set in __init__ when the loop runs in debug mode.
    _source_traceback = None

    def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
        # The line length limit is a security feature;
        # it also doubles as half the buffer limit.

        if limit <= 0:
            raise ValueError('Limit cannot be <= 0')

        self._limit = limit
        if loop is None:
            self._loop = events._get_event_loop()
        else:
            self._loop = loop
        self._buffer = bytearray()
        self._eof = False    # Whether we're done.
        self._waiter = None  # A future used by _wait_for_data()
        self._exception = None  # Set by set_exception(); re-raised by read*()
        self._transport = None  # Set later via set_transport()
        self._paused = False    # True while transport reading is paused
        if self._loop.get_debug():
            self._source_traceback = format_helpers.extract_stack(
                sys._getframe(1))

    def __repr__(self):
        # Only mention non-default state, to keep the repr compact.
        info = ['StreamReader']
        if self._buffer:
            info.append(f'{len(self._buffer)} bytes')
        if self._eof:
            info.append('eof')
        if self._limit != _DEFAULT_LIMIT:
            info.append(f'limit={self._limit}')
        if self._waiter:
            info.append(f'waiter={self._waiter!r}')
        if self._exception:
            info.append(f'exception={self._exception!r}')
        if self._transport:
            info.append(f'transport={self._transport!r}')
        if self._paused:
            info.append('paused')
        return '<{}>'.format(' '.join(info))

    def exception(self):
        """Return the exception set with set_exception(), or None."""
        return self._exception

    def set_exception(self, exc):
        """Record *exc* and wake the pending read*() coroutine with it."""
        self._exception = exc

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_exception(exc)

    def _wakeup_waiter(self):
        """Wakeup read*() functions waiting for data or EOF."""
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_result(None)

    def set_transport(self, transport):
        """Attach the transport used for pause/resume flow control."""
        assert self._transport is None, 'Transport already set'
        self._transport = transport

    def _maybe_resume_transport(self):
        # Resume reading once the buffer has drained back to the limit.
        if self._paused and len(self._buffer) <= self._limit:
            self._paused = False
            self._transport.resume_reading()

    def feed_eof(self):
        """Mark the stream as ended and wake any pending reader."""
        self._eof = True
        self._wakeup_waiter()

    def at_eof(self):
        """Return True if the buffer is empty and 'feed_eof' was called."""
        return self._eof and not self._buffer

    def feed_data(self, data):
        """Append *data* to the buffer and apply flow control."""
        assert not self._eof, 'feed_data after feed_eof'

        if not data:
            return

        self._buffer.extend(data)
        self._wakeup_waiter()

        # Pause the producer once the buffer exceeds twice the limit.
        if (self._transport is not None and
                not self._paused and
                len(self._buffer) > 2 * self._limit):
            try:
                self._transport.pause_reading()
            except NotImplementedError:
                # The transport can't be paused.
                # We'll just have to buffer all data.
                # Forget the transport so we don't keep trying.
                self._transport = None
            else:
                self._paused = True

    async def _wait_for_data(self, func_name):
        """Wait until feed_data() or feed_eof() is called.

        If stream was paused, automatically resume it.
        """
        # StreamReader uses a future to link the protocol feed_data() method
        # to a read coroutine. Running two read coroutines at the same time
        # would have an unexpected behaviour. It would not be possible to know
        # which coroutine would get the next data.
        if self._waiter is not None:
            raise RuntimeError(
                f'{func_name}() called while another coroutine is '
                f'already waiting for incoming data')

        assert not self._eof, '_wait_for_data after EOF'

        # Waiting for data while paused will make deadlock, so prevent it.
        # This is essential for readexactly(n) for case when n > self._limit.
        if self._paused:
            self._paused = False
            self._transport.resume_reading()

        self._waiter = self._loop.create_future()
        try:
            await self._waiter
        finally:
            # Clear the waiter even if the await was cancelled.
            self._waiter = None

    async def readline(self):
        """Read chunk of data from the stream until newline (b'\n') is found.

        On success, return chunk that ends with newline. If only partial
        line can be read due to EOF, return incomplete line without
        terminating newline. When EOF was reached while no bytes read, empty
        bytes object is returned.

        If limit is reached, ValueError will be raised. In that case, if
        newline was found, complete line including newline will be removed
        from internal buffer. Else, internal buffer will be cleared. Limit is
        compared against part of the line without newline.

        If stream was paused, this function will automatically resume it if
        needed.
        """
        sep = b'\n'
        seplen = len(sep)
        try:
            line = await self.readuntil(sep)
        except exceptions.IncompleteReadError as e:
            # EOF before a newline: return whatever was read.
            return e.partial
        except exceptions.LimitOverrunError as e:
            # Discard the oversized line (or the whole buffer if the
            # newline was never found), then re-raise as ValueError for
            # backward compatibility.
            if self._buffer.startswith(sep, e.consumed):
                del self._buffer[:e.consumed + seplen]
            else:
                self._buffer.clear()
            self._maybe_resume_transport()
            raise ValueError(e.args[0])
        return line

    async def readuntil(self, separator=b'\n'):
        """Read data from the stream until ``separator`` is found.

        On success, the data and separator will be removed from the
        internal buffer (consumed). Returned data will include the
        separator at the end.

        Configured stream limit is used to check result. Limit sets the
        maximal length of data that can be returned, not counting the
        separator.

        If an EOF occurs and the complete separator is still not found,
        an IncompleteReadError exception will be raised, and the internal
        buffer will be reset.  The IncompleteReadError.partial attribute
        may contain the separator partially.

        If the data cannot be read because of over limit, a
        LimitOverrunError exception  will be raised, and the data
        will be left in the internal buffer, so it can be read again.
        """
        seplen = len(separator)
        if seplen == 0:
            raise ValueError('Separator should be at least one-byte string')

        if self._exception is not None:
            raise self._exception

        # Consume whole buffer except last bytes, which length is
        # one less than seplen. Let's check corner cases with
        # separator='SEPARATOR':
        # * we have received almost complete separator (without last
        #   byte). i.e buffer='some textSEPARATO'. In this case we
        #   can safely consume len(separator) - 1 bytes.
        # * last byte of buffer is first byte of separator, i.e.
        #   buffer='abcdefghijklmnopqrS'. We may safely consume
        #   everything except that last byte, but this require to
        #   analyze bytes of buffer that match partial separator.
        #   This is slow and/or require FSM. For this case our
        #   implementation is not optimal, since require rescanning
        #   of data that is known to not belong to separator. In
        #   real world, separator will not be so long to notice
        #   performance problems. Even when reading MIME-encoded
        #   messages :)

        # `offset` is the number of bytes from the beginning of the buffer
        # where there is no occurrence of `separator`.
        offset = 0

        # Loop until we find `separator` in the buffer, exceed the buffer size,
        # or an EOF has happened.
        while True:
            buflen = len(self._buffer)

            # Check if we now have enough data in the buffer for `separator` to
            # fit.
            if buflen - offset >= seplen:
                isep = self._buffer.find(separator, offset)

                if isep != -1:
                    # `separator` is in the buffer. `isep` will be used later
                    # to retrieve the data.
                    break

                # see upper comment for explanation.
                offset = buflen + 1 - seplen
                if offset > self._limit:
                    raise exceptions.LimitOverrunError(
                        'Separator is not found, and chunk exceed the limit',
                        offset)

            # Complete message (with full separator) may be present in buffer
            # even when EOF flag is set. This may happen when the last chunk
            # adds data which makes separator be found. That's why we check for
            # EOF *after* inspecting the buffer.
            if self._eof:
                chunk = bytes(self._buffer)
                self._buffer.clear()
                raise exceptions.IncompleteReadError(chunk, None)

            # _wait_for_data() will resume reading if stream was paused.
            await self._wait_for_data('readuntil')

        if isep > self._limit:
            raise exceptions.LimitOverrunError(
                'Separator is found, but chunk is longer than limit', isep)

        chunk = self._buffer[:isep + seplen]
        del self._buffer[:isep + seplen]
        self._maybe_resume_transport()
        return bytes(chunk)

    async def read(self, n=-1):
        """Read up to `n` bytes from the stream.

        If n is not provided, or set to -1, read until EOF and return all read
        bytes. If the EOF was received and the internal buffer is empty, return
        an empty bytes object.

        If n is zero, return empty bytes object immediately.

        If n is positive, this function try to read `n` bytes, and may return
        less or equal bytes than requested, but at least one byte. If EOF was
        received before any byte is read, this function returns empty byte
        object.

        Returned value is not limited with limit, configured at stream
        creation.

        If stream was paused, this function will automatically resume it if
        needed.
        """

        if self._exception is not None:
            raise self._exception

        if n == 0:
            return b''

        if n < 0:
            # This used to just loop creating a new waiter hoping to
            # collect everything in self._buffer, but that would
            # deadlock if the subprocess sends more than self.limit
            # bytes.  So just call self.read(self._limit) until EOF.
            blocks = []
            while True:
                block = await self.read(self._limit)
                if not block:
                    break
                blocks.append(block)
            return b''.join(blocks)

        if not self._buffer and not self._eof:
            await self._wait_for_data('read')

        # This will work right even if buffer is less than n bytes
        data = bytes(self._buffer[:n])
        del self._buffer[:n]

        self._maybe_resume_transport()
        return data

    async def readexactly(self, n):
        """Read exactly `n` bytes.

        Raise an IncompleteReadError if EOF is reached before `n` bytes can be
        read. The IncompleteReadError.partial attribute of the exception will
        contain the partial read bytes.

        if n is zero, return empty bytes object.

        Returned value is not limited with limit, configured at stream
        creation.

        If stream was paused, this function will automatically resume it if
        needed.
        """
        if n < 0:
            raise ValueError('readexactly size can not be less than zero')

        if self._exception is not None:
            raise self._exception

        if n == 0:
            return b''

        # Accumulate until the buffer holds at least n bytes or EOF.
        while len(self._buffer) < n:
            if self._eof:
                incomplete = bytes(self._buffer)
                self._buffer.clear()
                raise exceptions.IncompleteReadError(incomplete, n)

            await self._wait_for_data('readexactly')

        if len(self._buffer) == n:
            # Exact fit: take the whole buffer (avoids a copy of a slice).
            data = bytes(self._buffer)
            self._buffer.clear()
        else:
            data = bytes(self._buffer[:n])
            del self._buffer[:n]
        self._maybe_resume_transport()
        return data

    def __aiter__(self):
        # Allow `async for line in reader:` iteration over lines.
        return self

    async def __anext__(self):
        val = await self.readline()
        if val == b'':
            # EOF with no data terminates the async iteration.
            raise StopAsyncIteration
        return val
727