• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1__all__ = (
2    'StreamReader', 'StreamWriter', 'StreamReaderProtocol',
3    'open_connection', 'start_server')
4
5import socket
6import sys
7import warnings
8import weakref
9
10if hasattr(socket, 'AF_UNIX'):
11    __all__ += ('open_unix_connection', 'start_unix_server')
12
13from . import coroutines
14from . import events
15from . import exceptions
16from . import format_helpers
17from . import protocols
18from .log import logger
19from .tasks import sleep
20
21
22_DEFAULT_LIMIT = 2 ** 16  # 64 KiB
23
24
async def open_connection(host=None, port=None, *,
                          loop=None, limit=_DEFAULT_LIMIT, **kwds):
    """Connect to (host, port) and return a (reader, writer) pair.

    The reader is a StreamReader, the writer a StreamWriter.  This is a
    thin convenience wrapper around loop.create_connection(): all extra
    keyword arguments are forwarded to it (everything except
    protocol_factory).

    *limit* sets the buffer limit passed to the StreamReader; *loop*
    (deprecated) selects the event loop instance to use.

    (If you want to customize the StreamReader and/or
    StreamReaderProtocol classes, just copy the code -- there's
    really nothing special here except some convenience.)
    """
    if loop is not None:
        warnings.warn("The loop argument is deprecated since Python 3.8, "
                      "and scheduled for removal in Python 3.10.",
                      DeprecationWarning, stacklevel=2)
    else:
        loop = events.get_event_loop()
    reader = StreamReader(limit=limit, loop=loop)
    protocol = StreamReaderProtocol(reader, loop=loop)
    transport, _ = await loop.create_connection(
        lambda: protocol, host, port, **kwds)
    return reader, StreamWriter(transport, protocol, reader, loop)
56
57
async def start_server(client_connected_cb, host=None, port=None, *,
                       loop=None, limit=_DEFAULT_LIMIT, **kwds):
    """Start a socket server and invoke a callback for each client.

    `client_connected_cb` receives two arguments: client_reader (a
    StreamReader) and client_writer (a StreamWriter).  It may be a plain
    callable or a coroutine function; a coroutine is automatically
    scheduled as a Task.

    The remaining arguments are forwarded to loop.create_server()
    (everything except protocol_factory); most common are the positional
    host and port.

    *limit* sets the buffer limit passed to each StreamReader; *loop*
    (deprecated) selects the event loop instance to use.

    The return value is the same as loop.create_server(), i.e. a
    Server object which can be used to stop the service.
    """
    if loop is not None:
        warnings.warn("The loop argument is deprecated since Python 3.8, "
                      "and scheduled for removal in Python 3.10.",
                      DeprecationWarning, stacklevel=2)
    else:
        loop = events.get_event_loop()

    def factory():
        # Each accepted connection gets its own reader/protocol pair.
        reader = StreamReader(limit=limit, loop=loop)
        return StreamReaderProtocol(reader, client_connected_cb,
                                    loop=loop)

    return await loop.create_server(factory, host, port, **kwds)
95
96
if hasattr(socket, 'AF_UNIX'):
    # UNIX Domain Sockets are supported on this platform

    async def open_unix_connection(path=None, *,
                                   loop=None, limit=_DEFAULT_LIMIT, **kwds):
        """Similar to `open_connection` but works with UNIX Domain Sockets."""
        if loop is not None:
            warnings.warn("The loop argument is deprecated since Python 3.8, "
                          "and scheduled for removal in Python 3.10.",
                          DeprecationWarning, stacklevel=2)
        else:
            loop = events.get_event_loop()
        reader = StreamReader(limit=limit, loop=loop)
        protocol = StreamReaderProtocol(reader, loop=loop)
        transport, _ = await loop.create_unix_connection(
            lambda: protocol, path, **kwds)
        return reader, StreamWriter(transport, protocol, reader, loop)

    async def start_unix_server(client_connected_cb, path=None, *,
                                loop=None, limit=_DEFAULT_LIMIT, **kwds):
        """Similar to `start_server` but works with UNIX Domain Sockets."""
        if loop is not None:
            warnings.warn("The loop argument is deprecated since Python 3.8, "
                          "and scheduled for removal in Python 3.10.",
                          DeprecationWarning, stacklevel=2)
        else:
            loop = events.get_event_loop()

        def factory():
            # One reader/protocol pair per accepted connection.
            reader = StreamReader(limit=limit, loop=loop)
            return StreamReaderProtocol(reader, client_connected_cb,
                                        loop=loop)

        return await loop.create_unix_server(factory, path, **kwds)
133
134
class FlowControlMixin(protocols.Protocol):
    """Reusable flow control logic for StreamWriter.drain().

    This implements the protocol methods pause_writing(),
    resume_writing() and connection_lost().  If the subclass overrides
    these it must call the super methods.

    StreamWriter.drain() must wait for _drain_helper() coroutine.
    """

    def __init__(self, loop=None):
        self._loop = events.get_event_loop() if loop is None else loop
        self._paused = False            # True between pause/resume_writing
        self._drain_waiter = None       # Future awaited by _drain_helper()
        self._connection_lost = False

    def pause_writing(self):
        assert not self._paused
        self._paused = True
        if self._loop.get_debug():
            logger.debug("%r pauses writing", self)

    def resume_writing(self):
        assert self._paused
        self._paused = False
        if self._loop.get_debug():
            logger.debug("%r resumes writing", self)

        # Release a drain() call blocked on the flow-control pause.
        waiter, self._drain_waiter = self._drain_waiter, None
        if waiter is not None and not waiter.done():
            waiter.set_result(None)

    def connection_lost(self, exc):
        self._connection_lost = True
        # Wake up the writer if currently paused.
        if not self._paused:
            return
        waiter, self._drain_waiter = self._drain_waiter, None
        if waiter is None or waiter.done():
            return
        if exc is None:
            waiter.set_result(None)
        else:
            waiter.set_exception(exc)

    async def _drain_helper(self):
        if self._connection_lost:
            raise ConnectionResetError('Connection lost')
        if not self._paused:
            return
        # Only one drain() may be pending; a leftover waiter must be a
        # cancelled one.
        assert self._drain_waiter is None or self._drain_waiter.cancelled()
        waiter = self._loop.create_future()
        self._drain_waiter = waiter
        await waiter

    def _get_close_waiter(self, stream):
        # Subclasses supply the future completed at connection_lost().
        raise NotImplementedError
201
202
class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
    """Helper class to adapt between Protocol and StreamReader.

    (This is a helper class instead of making StreamReader itself a
    Protocol subclass, because the StreamReader has other potential
    uses, and to prevent the user of the StreamReader to accidentally
    call inappropriate methods of the protocol.)
    """

    # Copied from the reader in debug mode; attached to the exception
    # handler context when a garbage-collected stream is reported.
    _source_traceback = None

    def __init__(self, stream_reader, client_connected_cb=None, loop=None):
        super().__init__(loop=loop)
        if stream_reader is not None:
            # Hold the reader only weakly so that an abandoned stream can
            # be garbage collected (see connection_made()'s rejection path).
            self._stream_reader_wr = weakref.ref(stream_reader)
            self._source_traceback = stream_reader._source_traceback
        else:
            self._stream_reader_wr = None
        if client_connected_cb is not None:
            # This is a stream created by the `create_server()` function.
            # Keep a strong reference to the reader until a connection
            # is established.
            self._strong_reader = stream_reader
        self._reject_connection = False
        self._stream_writer = None
        self._transport = None
        self._client_connected_cb = client_connected_cb
        self._over_ssl = False
        # Completed (or failed) by connection_lost(); returned by
        # _get_close_waiter() so StreamWriter.wait_closed() can await it.
        self._closed = self._loop.create_future()

    @property
    def _stream_reader(self):
        # Dereference the weak reference; None if the reader was collected.
        if self._stream_reader_wr is None:
            return None
        return self._stream_reader_wr()

    def connection_made(self, transport):
        if self._reject_connection:
            # The reader was garbage collected before the connection was
            # established: report it and drop the connection.
            context = {
                'message': ('An open stream was garbage collected prior to '
                            'establishing network connection; '
                            'call "stream.close()" explicitly.')
            }
            if self._source_traceback:
                context['source_traceback'] = self._source_traceback
            self._loop.call_exception_handler(context)
            transport.abort()
            return
        self._transport = transport
        reader = self._stream_reader
        if reader is not None:
            reader.set_transport(transport)
        self._over_ssl = transport.get_extra_info('sslcontext') is not None
        if self._client_connected_cb is not None:
            # Server-side stream: build the writer and hand both ends to
            # the user callback; schedule it as a task if it's a coroutine.
            self._stream_writer = StreamWriter(transport, self,
                                               reader,
                                               self._loop)
            res = self._client_connected_cb(reader,
                                            self._stream_writer)
            if coroutines.iscoroutine(res):
                self._loop.create_task(res)
            # The connection is established; drop the strong reference so
            # only the weak reference keeps the reader alive.
            self._strong_reader = None

    def connection_lost(self, exc):
        # Propagate EOF or the error to the reader, resolve the close
        # waiter, then let FlowControlMixin wake up a pending drain().
        reader = self._stream_reader
        if reader is not None:
            if exc is None:
                reader.feed_eof()
            else:
                reader.set_exception(exc)
        if not self._closed.done():
            if exc is None:
                self._closed.set_result(None)
            else:
                self._closed.set_exception(exc)
        super().connection_lost(exc)
        # Break reference cycles now that the connection is gone.
        self._stream_reader_wr = None
        self._stream_writer = None
        self._transport = None

    def data_received(self, data):
        reader = self._stream_reader
        if reader is not None:
            reader.feed_data(data)

    def eof_received(self):
        reader = self._stream_reader
        if reader is not None:
            reader.feed_eof()
        if self._over_ssl:
            # Prevent a warning in SSLProtocol.eof_received:
            # "returning true from eof_received()
            # has no effect when using ssl"
            return False
        return True

    def _get_close_waiter(self, stream):
        return self._closed

    def __del__(self):
        # Prevent reports about unhandled exceptions.
        # Better than self._closed._log_traceback = False hack
        closed = self._closed
        if closed.done() and not closed.cancelled():
            closed.exception()
308
309
class StreamWriter:
    """Wraps a Transport.

    This exposes write(), writelines(), [can_]write_eof(),
    get_extra_info() and close().  It adds drain() which returns an
    optional Future on which you can wait for flow control.  It also
    adds a transport property which references the Transport
    directly.
    """

    def __init__(self, transport, protocol, reader, loop):
        self._transport = transport
        self._protocol = protocol
        # drain() expects that the reader has an exception() method
        assert reader is None or isinstance(reader, StreamReader)
        self._reader = reader
        self._loop = loop
        self._complete_fut = loop.create_future()
        self._complete_fut.set_result(None)

    def __repr__(self):
        parts = [self.__class__.__name__, f'transport={self._transport!r}']
        if self._reader is not None:
            parts.append(f'reader={self._reader!r}')
        joined = ' '.join(parts)
        return f'<{joined}>'

    @property
    def transport(self):
        return self._transport

    def write(self, data):
        self._transport.write(data)

    def writelines(self, data):
        self._transport.writelines(data)

    def write_eof(self):
        return self._transport.write_eof()

    def can_write_eof(self):
        return self._transport.can_write_eof()

    def close(self):
        return self._transport.close()

    def is_closing(self):
        return self._transport.is_closing()

    async def wait_closed(self):
        # Resolved by the protocol when connection_lost() runs.
        await self._protocol._get_close_waiter(self)

    def get_extra_info(self, name, default=None):
        return self._transport.get_extra_info(name, default)

    async def drain(self):
        """Flush the write buffer.

        The intended use is to write

          w.write(data)
          await w.drain()
        """
        reader = self._reader
        if reader is not None:
            pending_exc = reader.exception()
            if pending_exc is not None:
                raise pending_exc
        if self._transport.is_closing():
            # Wait for protocol.connection_lost() call
            # Raise connection closing error if any,
            # ConnectionResetError otherwise
            # Yield to the event loop so connection_lost() may be
            # called.  Without this, _drain_helper() would return
            # immediately, and code that calls
            #     write(...); await drain()
            # in a loop would never call connection_lost(), so it
            # would not see an error when the socket is closed.
            await sleep(0)
        await self._protocol._drain_helper()
388
389
class StreamReader:
    """Buffered stream reader.

    A protocol feeds bytes in via feed_data()/feed_eof(); coroutines
    consume them via read(), readline(), readuntil() and readexactly().
    Reading on the transport is paused once the internal buffer exceeds
    twice the configured limit and resumed when it drains back down.
    """

    # Filled in __init__ when the event loop runs in debug mode.
    _source_traceback = None

    def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
        # The line length limit is a security feature;
        # it also doubles as half the buffer limit.

        if limit <= 0:
            raise ValueError('Limit cannot be <= 0')

        self._limit = limit
        if loop is None:
            self._loop = events.get_event_loop()
        else:
            self._loop = loop
        self._buffer = bytearray()
        self._eof = False    # Whether we're done.
        self._waiter = None  # A future used by _wait_for_data()
        self._exception = None  # Set by set_exception(); re-raised by read*()
        self._transport = None  # Set later via set_transport()
        self._paused = False    # True while transport reading is paused
        if self._loop.get_debug():
            self._source_traceback = format_helpers.extract_stack(
                sys._getframe(1))

    def __repr__(self):
        # Include only the state that is actually set, keeping the repr short.
        info = ['StreamReader']
        if self._buffer:
            info.append(f'{len(self._buffer)} bytes')
        if self._eof:
            info.append('eof')
        if self._limit != _DEFAULT_LIMIT:
            info.append(f'limit={self._limit}')
        if self._waiter:
            info.append(f'waiter={self._waiter!r}')
        if self._exception:
            info.append(f'exception={self._exception!r}')
        if self._transport:
            info.append(f'transport={self._transport!r}')
        if self._paused:
            info.append('paused')
        return '<{}>'.format(' '.join(info))

    def exception(self):
        """Return the exception set with set_exception(), or None."""
        return self._exception

    def set_exception(self, exc):
        """Store *exc* and wake up a pending read*() call with it."""
        self._exception = exc

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_exception(exc)

    def _wakeup_waiter(self):
        """Wakeup read*() functions waiting for data or EOF."""
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_result(None)

    def set_transport(self, transport):
        """Attach the transport used to pause/resume reading."""
        assert self._transport is None, 'Transport already set'
        self._transport = transport

    def _maybe_resume_transport(self):
        # Resume reading once the buffer has drained back to the limit.
        if self._paused and len(self._buffer) <= self._limit:
            self._paused = False
            self._transport.resume_reading()

    def feed_eof(self):
        """Mark the stream as ended and wake up any pending reader."""
        self._eof = True
        self._wakeup_waiter()

    def at_eof(self):
        """Return True if the buffer is empty and 'feed_eof' was called."""
        return self._eof and not self._buffer

    def feed_data(self, data):
        """Append *data* to the buffer and wake up any pending reader.

        Pauses the transport when the buffer grows beyond twice the limit.
        """
        assert not self._eof, 'feed_data after feed_eof'

        if not data:
            return

        self._buffer.extend(data)
        self._wakeup_waiter()

        if (self._transport is not None and
                not self._paused and
                len(self._buffer) > 2 * self._limit):
            try:
                self._transport.pause_reading()
            except NotImplementedError:
                # The transport can't be paused.
                # We'll just have to buffer all data.
                # Forget the transport so we don't keep trying.
                self._transport = None
            else:
                self._paused = True

    async def _wait_for_data(self, func_name):
        """Wait until feed_data() or feed_eof() is called.

        If stream was paused, automatically resume it.
        """
        # StreamReader uses a future to link the protocol feed_data() method
        # to a read coroutine. Running two read coroutines at the same time
        # would have an unexpected behaviour. It would not be possible to know
        # which coroutine would get the next data.
        if self._waiter is not None:
            raise RuntimeError(
                f'{func_name}() called while another coroutine is '
                f'already waiting for incoming data')

        assert not self._eof, '_wait_for_data after EOF'

        # Waiting for data while paused would deadlock, so prevent it.
        # This is essential for readexactly(n) for the case when n > self._limit.
        if self._paused:
            self._paused = False
            self._transport.resume_reading()

        self._waiter = self._loop.create_future()
        try:
            await self._waiter
        finally:
            self._waiter = None

    async def readline(self):
        """Read chunk of data from the stream until newline (b'\n') is found.

        On success, return chunk that ends with newline. If only partial
        line can be read due to EOF, return incomplete line without
        terminating newline. When EOF was reached while no bytes read, empty
        bytes object is returned.

        If limit is reached, ValueError will be raised. In that case, if
        newline was found, complete line including newline will be removed
        from internal buffer. Else, internal buffer will be cleared. Limit is
        compared against part of the line without newline.

        If stream was paused, this function will automatically resume it if
        needed.
        """
        sep = b'\n'
        seplen = len(sep)
        try:
            line = await self.readuntil(sep)
        except exceptions.IncompleteReadError as e:
            # EOF hit before a newline: return whatever was read.
            return e.partial
        except exceptions.LimitOverrunError as e:
            # Drop the oversized line (up to and including the separator if
            # one was found) and surface the overrun as a ValueError.
            if self._buffer.startswith(sep, e.consumed):
                del self._buffer[:e.consumed + seplen]
            else:
                self._buffer.clear()
            self._maybe_resume_transport()
            raise ValueError(e.args[0])
        return line

    async def readuntil(self, separator=b'\n'):
        """Read data from the stream until ``separator`` is found.

        On success, the data and separator will be removed from the
        internal buffer (consumed). Returned data will include the
        separator at the end.

        Configured stream limit is used to check result. Limit sets the
        maximal length of data that can be returned, not counting the
        separator.

        If an EOF occurs and the complete separator is still not found,
        an IncompleteReadError exception will be raised, and the internal
        buffer will be reset.  The IncompleteReadError.partial attribute
        may contain the separator partially.

        If the data cannot be read because of over limit, a
        LimitOverrunError exception will be raised, and the data
        will be left in the internal buffer, so it can be read again.
        """
        seplen = len(separator)
        if seplen == 0:
            raise ValueError('Separator should be at least one-byte string')

        if self._exception is not None:
            raise self._exception

        # Consume whole buffer except last bytes, which length is
        # one less than seplen. Let's check corner cases with
        # separator='SEPARATOR':
        # * we have received almost complete separator (without last
        #   byte). i.e buffer='some textSEPARATO'. In this case we
        #   can safely consume len(separator) - 1 bytes.
        # * last byte of buffer is first byte of separator, i.e.
        #   buffer='abcdefghijklmnopqrS'. We may safely consume
        #   everything except that last byte, but this requires
        #   analyzing bytes of buffer that match partial separator.
        #   This is slow and/or requires an FSM. For this case our
        #   implementation is not optimal, since it requires rescanning
        #   of data that is known to not belong to separator. In the
        #   real world, separator will not be so long to notice
        #   performance problems. Even when reading MIME-encoded
        #   messages :)

        # `offset` is the number of bytes from the beginning of the buffer
        # where there is no occurrence of `separator`.
        offset = 0

        # Loop until we find `separator` in the buffer, exceed the buffer size,
        # or an EOF has happened.
        while True:
            buflen = len(self._buffer)

            # Check if we now have enough data in the buffer for `separator` to
            # fit.
            if buflen - offset >= seplen:
                isep = self._buffer.find(separator, offset)

                if isep != -1:
                    # `separator` is in the buffer. `isep` will be used later
                    # to retrieve the data.
                    break

                # see upper comment for explanation.
                offset = buflen + 1 - seplen
                if offset > self._limit:
                    raise exceptions.LimitOverrunError(
                        'Separator is not found, and chunk exceed the limit',
                        offset)

            # Complete message (with full separator) may be present in buffer
            # even when EOF flag is set. This may happen when the last chunk
            # adds data which makes separator be found. That's why we check for
            # EOF *after* inspecting the buffer.
            if self._eof:
                chunk = bytes(self._buffer)
                self._buffer.clear()
                raise exceptions.IncompleteReadError(chunk, None)

            # _wait_for_data() will resume reading if stream was paused.
            await self._wait_for_data('readuntil')

        if isep > self._limit:
            raise exceptions.LimitOverrunError(
                'Separator is found, but chunk is longer than limit', isep)

        chunk = self._buffer[:isep + seplen]
        del self._buffer[:isep + seplen]
        self._maybe_resume_transport()
        return bytes(chunk)

    async def read(self, n=-1):
        """Read up to `n` bytes from the stream.

        If n is not provided, or set to -1, read until EOF and return all read
        bytes. If the EOF was received and the internal buffer is empty, return
        an empty bytes object.

        If n is zero, return empty bytes object immediately.

        If n is positive, this function try to read `n` bytes, and may return
        less or equal bytes than requested, but at least one byte. If EOF was
        received before any byte is read, this function returns empty byte
        object.

        Returned value is not limited with limit, configured at stream
        creation.

        If stream was paused, this function will automatically resume it if
        needed.
        """

        if self._exception is not None:
            raise self._exception

        if n == 0:
            return b''

        if n < 0:
            # This used to just loop creating a new waiter hoping to
            # collect everything in self._buffer, but that would
            # deadlock if the subprocess sends more than self.limit
            # bytes.  So just call self.read(self._limit) until EOF.
            blocks = []
            while True:
                block = await self.read(self._limit)
                if not block:
                    break
                blocks.append(block)
            return b''.join(blocks)

        if not self._buffer and not self._eof:
            await self._wait_for_data('read')

        # This will work right even if buffer is less than n bytes
        data = bytes(self._buffer[:n])
        del self._buffer[:n]

        self._maybe_resume_transport()
        return data

    async def readexactly(self, n):
        """Read exactly `n` bytes.

        Raise an IncompleteReadError if EOF is reached before `n` bytes can be
        read. The IncompleteReadError.partial attribute of the exception will
        contain the partial read bytes.

        if n is zero, return empty bytes object.

        Returned value is not limited with limit, configured at stream
        creation.

        If stream was paused, this function will automatically resume it if
        needed.
        """
        if n < 0:
            raise ValueError('readexactly size can not be less than zero')

        if self._exception is not None:
            raise self._exception

        if n == 0:
            return b''

        while len(self._buffer) < n:
            if self._eof:
                incomplete = bytes(self._buffer)
                self._buffer.clear()
                raise exceptions.IncompleteReadError(incomplete, n)

            await self._wait_for_data('readexactly')

        if len(self._buffer) == n:
            # Fast path: the buffer holds exactly n bytes.
            data = bytes(self._buffer)
            self._buffer.clear()
        else:
            data = bytes(self._buffer[:n])
            del self._buffer[:n]
        self._maybe_resume_transport()
        return data

    def __aiter__(self):
        # Support `async for line in reader:`.
        return self

    async def __anext__(self):
        # Iteration yields lines and stops at EOF (readline() returns b'').
        val = await self.readline()
        if val == b'':
            raise StopAsyncIteration
        return val
742