• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1__all__ = (
2    'StreamReader', 'StreamWriter', 'StreamReaderProtocol',
3    'open_connection', 'start_server')
4
5import collections
6import socket
7import sys
8import warnings
9import weakref
10
11if hasattr(socket, 'AF_UNIX'):
12    __all__ += ('open_unix_connection', 'start_unix_server')
13
14from . import coroutines
15from . import events
16from . import exceptions
17from . import format_helpers
18from . import protocols
19from .log import logger
20from .tasks import sleep
21
22
23_DEFAULT_LIMIT = 2 ** 16  # 64 KiB
24
25
26async def open_connection(host=None, port=None, *,
27                          limit=_DEFAULT_LIMIT, **kwds):
28    """A wrapper for create_connection() returning a (reader, writer) pair.
29
30    The reader returned is a StreamReader instance; the writer is a
31    StreamWriter instance.
32
33    The arguments are all the usual arguments to create_connection()
34    except protocol_factory; most common are positional host and port,
35    with various optional keyword arguments following.
36
37    Additional optional keyword arguments are loop (to set the event loop
38    instance to use) and limit (to set the buffer limit passed to the
39    StreamReader).
40
41    (If you want to customize the StreamReader and/or
42    StreamReaderProtocol classes, just copy the code -- there's
43    really nothing special here except some convenience.)
44    """
45    loop = events.get_running_loop()
46    reader = StreamReader(limit=limit, loop=loop)
47    protocol = StreamReaderProtocol(reader, loop=loop)
48    transport, _ = await loop.create_connection(
49        lambda: protocol, host, port, **kwds)
50    writer = StreamWriter(transport, protocol, reader, loop)
51    return reader, writer
52
53
54async def start_server(client_connected_cb, host=None, port=None, *,
55                       limit=_DEFAULT_LIMIT, **kwds):
56    """Start a socket server, call back for each client connected.
57
58    The first parameter, `client_connected_cb`, takes two parameters:
59    client_reader, client_writer.  client_reader is a StreamReader
60    object, while client_writer is a StreamWriter object.  This
61    parameter can either be a plain callback function or a coroutine;
62    if it is a coroutine, it will be automatically converted into a
63    Task.
64
65    The rest of the arguments are all the usual arguments to
66    loop.create_server() except protocol_factory; most common are
67    positional host and port, with various optional keyword arguments
68    following.  The return value is the same as loop.create_server().
69
70    Additional optional keyword argument is limit (to set the buffer
71    limit passed to the StreamReader).
72
73    The return value is the same as loop.create_server(), i.e. a
74    Server object which can be used to stop the service.
75    """
76    loop = events.get_running_loop()
77
78    def factory():
79        reader = StreamReader(limit=limit, loop=loop)
80        protocol = StreamReaderProtocol(reader, client_connected_cb,
81                                        loop=loop)
82        return protocol
83
84    return await loop.create_server(factory, host, port, **kwds)
85
86
87if hasattr(socket, 'AF_UNIX'):
88    # UNIX Domain Sockets are supported on this platform
89
90    async def open_unix_connection(path=None, *,
91                                   limit=_DEFAULT_LIMIT, **kwds):
92        """Similar to `open_connection` but works with UNIX Domain Sockets."""
93        loop = events.get_running_loop()
94
95        reader = StreamReader(limit=limit, loop=loop)
96        protocol = StreamReaderProtocol(reader, loop=loop)
97        transport, _ = await loop.create_unix_connection(
98            lambda: protocol, path, **kwds)
99        writer = StreamWriter(transport, protocol, reader, loop)
100        return reader, writer
101
102    async def start_unix_server(client_connected_cb, path=None, *,
103                                limit=_DEFAULT_LIMIT, **kwds):
104        """Similar to `start_server` but works with UNIX Domain Sockets."""
105        loop = events.get_running_loop()
106
107        def factory():
108            reader = StreamReader(limit=limit, loop=loop)
109            protocol = StreamReaderProtocol(reader, client_connected_cb,
110                                            loop=loop)
111            return protocol
112
113        return await loop.create_unix_server(factory, path, **kwds)
114
115
116class FlowControlMixin(protocols.Protocol):
117    """Reusable flow control logic for StreamWriter.drain().
118
119    This implements the protocol methods pause_writing(),
120    resume_writing() and connection_lost().  If the subclass overrides
121    these it must call the super methods.
122
123    StreamWriter.drain() must wait for _drain_helper() coroutine.
124    """
125
126    def __init__(self, loop=None):
127        if loop is None:
128            self._loop = events.get_event_loop()
129        else:
130            self._loop = loop
131        self._paused = False
132        self._drain_waiters = collections.deque()
133        self._connection_lost = False
134
135    def pause_writing(self):
136        assert not self._paused
137        self._paused = True
138        if self._loop.get_debug():
139            logger.debug("%r pauses writing", self)
140
141    def resume_writing(self):
142        assert self._paused
143        self._paused = False
144        if self._loop.get_debug():
145            logger.debug("%r resumes writing", self)
146
147        for waiter in self._drain_waiters:
148            if not waiter.done():
149                waiter.set_result(None)
150
151    def connection_lost(self, exc):
152        self._connection_lost = True
153        # Wake up the writer(s) if currently paused.
154        if not self._paused:
155            return
156
157        for waiter in self._drain_waiters:
158            if not waiter.done():
159                if exc is None:
160                    waiter.set_result(None)
161                else:
162                    waiter.set_exception(exc)
163
164    async def _drain_helper(self):
165        if self._connection_lost:
166            raise ConnectionResetError('Connection lost')
167        if not self._paused:
168            return
169        waiter = self._loop.create_future()
170        self._drain_waiters.append(waiter)
171        try:
172            await waiter
173        finally:
174            self._drain_waiters.remove(waiter)
175
176    def _get_close_waiter(self, stream):
177        raise NotImplementedError
178
179
180class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
181    """Helper class to adapt between Protocol and StreamReader.
182
183    (This is a helper class instead of making StreamReader itself a
184    Protocol subclass, because the StreamReader has other potential
185    uses, and to prevent the user of the StreamReader to accidentally
186    call inappropriate methods of the protocol.)
187    """
188
189    _source_traceback = None
190
191    def __init__(self, stream_reader, client_connected_cb=None, loop=None):
192        super().__init__(loop=loop)
193        if stream_reader is not None:
194            self._stream_reader_wr = weakref.ref(stream_reader)
195            self._source_traceback = stream_reader._source_traceback
196        else:
197            self._stream_reader_wr = None
198        if client_connected_cb is not None:
199            # This is a stream created by the `create_server()` function.
200            # Keep a strong reference to the reader until a connection
201            # is established.
202            self._strong_reader = stream_reader
203        self._reject_connection = False
204        self._task = None
205        self._transport = None
206        self._client_connected_cb = client_connected_cb
207        self._over_ssl = False
208        self._closed = self._loop.create_future()
209
210    @property
211    def _stream_reader(self):
212        if self._stream_reader_wr is None:
213            return None
214        return self._stream_reader_wr()
215
216    def _replace_transport(self, transport):
217        loop = self._loop
218        self._transport = transport
219        self._over_ssl = transport.get_extra_info('sslcontext') is not None
220
221    def connection_made(self, transport):
222        if self._reject_connection:
223            context = {
224                'message': ('An open stream was garbage collected prior to '
225                            'establishing network connection; '
226                            'call "stream.close()" explicitly.')
227            }
228            if self._source_traceback:
229                context['source_traceback'] = self._source_traceback
230            self._loop.call_exception_handler(context)
231            transport.abort()
232            return
233        self._transport = transport
234        reader = self._stream_reader
235        if reader is not None:
236            reader.set_transport(transport)
237        self._over_ssl = transport.get_extra_info('sslcontext') is not None
238        if self._client_connected_cb is not None:
239            writer = StreamWriter(transport, self, reader, self._loop)
240            res = self._client_connected_cb(reader, writer)
241            if coroutines.iscoroutine(res):
242                def callback(task):
243                    if task.cancelled():
244                        transport.close()
245                        return
246                    exc = task.exception()
247                    if exc is not None:
248                        self._loop.call_exception_handler({
249                            'message': 'Unhandled exception in client_connected_cb',
250                            'exception': exc,
251                            'transport': transport,
252                        })
253                        transport.close()
254
255                self._task = self._loop.create_task(res)
256                self._task.add_done_callback(callback)
257
258            self._strong_reader = None
259
260    def connection_lost(self, exc):
261        reader = self._stream_reader
262        if reader is not None:
263            if exc is None:
264                reader.feed_eof()
265            else:
266                reader.set_exception(exc)
267        if not self._closed.done():
268            if exc is None:
269                self._closed.set_result(None)
270            else:
271                self._closed.set_exception(exc)
272        super().connection_lost(exc)
273        self._stream_reader_wr = None
274        self._stream_writer = None
275        self._task = None
276        self._transport = None
277
278    def data_received(self, data):
279        reader = self._stream_reader
280        if reader is not None:
281            reader.feed_data(data)
282
283    def eof_received(self):
284        reader = self._stream_reader
285        if reader is not None:
286            reader.feed_eof()
287        if self._over_ssl:
288            # Prevent a warning in SSLProtocol.eof_received:
289            # "returning true from eof_received()
290            # has no effect when using ssl"
291            return False
292        return True
293
294    def _get_close_waiter(self, stream):
295        return self._closed
296
297    def __del__(self):
298        # Prevent reports about unhandled exceptions.
299        # Better than self._closed._log_traceback = False hack
300        try:
301            closed = self._closed
302        except AttributeError:
303            pass  # failed constructor
304        else:
305            if closed.done() and not closed.cancelled():
306                closed.exception()
307
308
309class StreamWriter:
310    """Wraps a Transport.
311
312    This exposes write(), writelines(), [can_]write_eof(),
313    get_extra_info() and close().  It adds drain() which returns an
314    optional Future on which you can wait for flow control.  It also
315    adds a transport property which references the Transport
316    directly.
317    """
318
319    def __init__(self, transport, protocol, reader, loop):
320        self._transport = transport
321        self._protocol = protocol
322        # drain() expects that the reader has an exception() method
323        assert reader is None or isinstance(reader, StreamReader)
324        self._reader = reader
325        self._loop = loop
326        self._complete_fut = self._loop.create_future()
327        self._complete_fut.set_result(None)
328
329    def __repr__(self):
330        info = [self.__class__.__name__, f'transport={self._transport!r}']
331        if self._reader is not None:
332            info.append(f'reader={self._reader!r}')
333        return '<{}>'.format(' '.join(info))
334
335    @property
336    def transport(self):
337        return self._transport
338
339    def write(self, data):
340        self._transport.write(data)
341
342    def writelines(self, data):
343        self._transport.writelines(data)
344
345    def write_eof(self):
346        return self._transport.write_eof()
347
348    def can_write_eof(self):
349        return self._transport.can_write_eof()
350
351    def close(self):
352        return self._transport.close()
353
354    def is_closing(self):
355        return self._transport.is_closing()
356
357    async def wait_closed(self):
358        await self._protocol._get_close_waiter(self)
359
360    def get_extra_info(self, name, default=None):
361        return self._transport.get_extra_info(name, default)
362
363    async def drain(self):
364        """Flush the write buffer.
365
366        The intended use is to write
367
368          w.write(data)
369          await w.drain()
370        """
371        if self._reader is not None:
372            exc = self._reader.exception()
373            if exc is not None:
374                raise exc
375        if self._transport.is_closing():
376            # Wait for protocol.connection_lost() call
377            # Raise connection closing error if any,
378            # ConnectionResetError otherwise
379            # Yield to the event loop so connection_lost() may be
380            # called.  Without this, _drain_helper() would return
381            # immediately, and code that calls
382            #     write(...); await drain()
383            # in a loop would never call connection_lost(), so it
384            # would not see an error when the socket is closed.
385            await sleep(0)
386        await self._protocol._drain_helper()
387
388    async def start_tls(self, sslcontext, *,
389                        server_hostname=None,
390                        ssl_handshake_timeout=None,
391                        ssl_shutdown_timeout=None):
392        """Upgrade an existing stream-based connection to TLS."""
393        server_side = self._protocol._client_connected_cb is not None
394        protocol = self._protocol
395        await self.drain()
396        new_transport = await self._loop.start_tls(  # type: ignore
397            self._transport, protocol, sslcontext,
398            server_side=server_side, server_hostname=server_hostname,
399            ssl_handshake_timeout=ssl_handshake_timeout,
400            ssl_shutdown_timeout=ssl_shutdown_timeout)
401        self._transport = new_transport
402        protocol._replace_transport(new_transport)
403
404    def __del__(self, warnings=warnings):
405        if not self._transport.is_closing():
406            if self._loop.is_closed():
407                warnings.warn("loop is closed", ResourceWarning)
408            else:
409                self.close()
410                warnings.warn(f"unclosed {self!r}", ResourceWarning)
411
412class StreamReader:
413
414    _source_traceback = None
415
416    def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
417        # The line length limit is  a security feature;
418        # it also doubles as half the buffer limit.
419
420        if limit <= 0:
421            raise ValueError('Limit cannot be <= 0')
422
423        self._limit = limit
424        if loop is None:
425            self._loop = events.get_event_loop()
426        else:
427            self._loop = loop
428        self._buffer = bytearray()
429        self._eof = False    # Whether we're done.
430        self._waiter = None  # A future used by _wait_for_data()
431        self._exception = None
432        self._transport = None
433        self._paused = False
434        if self._loop.get_debug():
435            self._source_traceback = format_helpers.extract_stack(
436                sys._getframe(1))
437
438    def __repr__(self):
439        info = ['StreamReader']
440        if self._buffer:
441            info.append(f'{len(self._buffer)} bytes')
442        if self._eof:
443            info.append('eof')
444        if self._limit != _DEFAULT_LIMIT:
445            info.append(f'limit={self._limit}')
446        if self._waiter:
447            info.append(f'waiter={self._waiter!r}')
448        if self._exception:
449            info.append(f'exception={self._exception!r}')
450        if self._transport:
451            info.append(f'transport={self._transport!r}')
452        if self._paused:
453            info.append('paused')
454        return '<{}>'.format(' '.join(info))
455
456    def exception(self):
457        return self._exception
458
459    def set_exception(self, exc):
460        self._exception = exc
461
462        waiter = self._waiter
463        if waiter is not None:
464            self._waiter = None
465            if not waiter.cancelled():
466                waiter.set_exception(exc)
467
468    def _wakeup_waiter(self):
469        """Wakeup read*() functions waiting for data or EOF."""
470        waiter = self._waiter
471        if waiter is not None:
472            self._waiter = None
473            if not waiter.cancelled():
474                waiter.set_result(None)
475
476    def set_transport(self, transport):
477        assert self._transport is None, 'Transport already set'
478        self._transport = transport
479
480    def _maybe_resume_transport(self):
481        if self._paused and len(self._buffer) <= self._limit:
482            self._paused = False
483            self._transport.resume_reading()
484
485    def feed_eof(self):
486        self._eof = True
487        self._wakeup_waiter()
488
489    def at_eof(self):
490        """Return True if the buffer is empty and 'feed_eof' was called."""
491        return self._eof and not self._buffer
492
493    def feed_data(self, data):
494        assert not self._eof, 'feed_data after feed_eof'
495
496        if not data:
497            return
498
499        self._buffer.extend(data)
500        self._wakeup_waiter()
501
502        if (self._transport is not None and
503                not self._paused and
504                len(self._buffer) > 2 * self._limit):
505            try:
506                self._transport.pause_reading()
507            except NotImplementedError:
508                # The transport can't be paused.
509                # We'll just have to buffer all data.
510                # Forget the transport so we don't keep trying.
511                self._transport = None
512            else:
513                self._paused = True
514
515    async def _wait_for_data(self, func_name):
516        """Wait until feed_data() or feed_eof() is called.
517
518        If stream was paused, automatically resume it.
519        """
520        # StreamReader uses a future to link the protocol feed_data() method
521        # to a read coroutine. Running two read coroutines at the same time
522        # would have an unexpected behaviour. It would not possible to know
523        # which coroutine would get the next data.
524        if self._waiter is not None:
525            raise RuntimeError(
526                f'{func_name}() called while another coroutine is '
527                f'already waiting for incoming data')
528
529        assert not self._eof, '_wait_for_data after EOF'
530
531        # Waiting for data while paused will make deadlock, so prevent it.
532        # This is essential for readexactly(n) for case when n > self._limit.
533        if self._paused:
534            self._paused = False
535            self._transport.resume_reading()
536
537        self._waiter = self._loop.create_future()
538        try:
539            await self._waiter
540        finally:
541            self._waiter = None
542
543    async def readline(self):
544        """Read chunk of data from the stream until newline (b'\n') is found.
545
546        On success, return chunk that ends with newline. If only partial
547        line can be read due to EOF, return incomplete line without
548        terminating newline. When EOF was reached while no bytes read, empty
549        bytes object is returned.
550
551        If limit is reached, ValueError will be raised. In that case, if
552        newline was found, complete line including newline will be removed
553        from internal buffer. Else, internal buffer will be cleared. Limit is
554        compared against part of the line without newline.
555
556        If stream was paused, this function will automatically resume it if
557        needed.
558        """
559        sep = b'\n'
560        seplen = len(sep)
561        try:
562            line = await self.readuntil(sep)
563        except exceptions.IncompleteReadError as e:
564            return e.partial
565        except exceptions.LimitOverrunError as e:
566            if self._buffer.startswith(sep, e.consumed):
567                del self._buffer[:e.consumed + seplen]
568            else:
569                self._buffer.clear()
570            self._maybe_resume_transport()
571            raise ValueError(e.args[0])
572        return line
573
574    async def readuntil(self, separator=b'\n'):
575        """Read data from the stream until ``separator`` is found.
576
577        On success, the data and separator will be removed from the
578        internal buffer (consumed). Returned data will include the
579        separator at the end.
580
581        Configured stream limit is used to check result. Limit sets the
582        maximal length of data that can be returned, not counting the
583        separator.
584
585        If an EOF occurs and the complete separator is still not found,
586        an IncompleteReadError exception will be raised, and the internal
587        buffer will be reset.  The IncompleteReadError.partial attribute
588        may contain the separator partially.
589
590        If the data cannot be read because of over limit, a
591        LimitOverrunError exception  will be raised, and the data
592        will be left in the internal buffer, so it can be read again.
593
594        The ``separator`` may also be a tuple of separators. In this
595        case the return value will be the shortest possible that has any
596        separator as the suffix. For the purposes of LimitOverrunError,
597        the shortest possible separator is considered to be the one that
598        matched.
599        """
600        if isinstance(separator, tuple):
601            # Makes sure shortest matches wins
602            separator = sorted(separator, key=len)
603        else:
604            separator = [separator]
605        if not separator:
606            raise ValueError('Separator should contain at least one element')
607        min_seplen = len(separator[0])
608        max_seplen = len(separator[-1])
609        if min_seplen == 0:
610            raise ValueError('Separator should be at least one-byte string')
611
612        if self._exception is not None:
613            raise self._exception
614
615        # Consume whole buffer except last bytes, which length is
616        # one less than max_seplen. Let's check corner cases with
617        # separator[-1]='SEPARATOR':
618        # * we have received almost complete separator (without last
619        #   byte). i.e buffer='some textSEPARATO'. In this case we
620        #   can safely consume max_seplen - 1 bytes.
621        # * last byte of buffer is first byte of separator, i.e.
622        #   buffer='abcdefghijklmnopqrS'. We may safely consume
623        #   everything except that last byte, but this require to
624        #   analyze bytes of buffer that match partial separator.
625        #   This is slow and/or require FSM. For this case our
626        #   implementation is not optimal, since require rescanning
627        #   of data that is known to not belong to separator. In
628        #   real world, separator will not be so long to notice
629        #   performance problems. Even when reading MIME-encoded
630        #   messages :)
631
632        # `offset` is the number of bytes from the beginning of the buffer
633        # where there is no occurrence of any `separator`.
634        offset = 0
635
636        # Loop until we find a `separator` in the buffer, exceed the buffer size,
637        # or an EOF has happened.
638        while True:
639            buflen = len(self._buffer)
640
641            # Check if we now have enough data in the buffer for shortest
642            # separator to fit.
643            if buflen - offset >= min_seplen:
644                match_start = None
645                match_end = None
646                for sep in separator:
647                    isep = self._buffer.find(sep, offset)
648
649                    if isep != -1:
650                        # `separator` is in the buffer. `match_start` and
651                        # `match_end` will be used later to retrieve the
652                        # data.
653                        end = isep + len(sep)
654                        if match_end is None or end < match_end:
655                            match_end = end
656                            match_start = isep
657                if match_end is not None:
658                    break
659
660                # see upper comment for explanation.
661                offset = max(0, buflen + 1 - max_seplen)
662                if offset > self._limit:
663                    raise exceptions.LimitOverrunError(
664                        'Separator is not found, and chunk exceed the limit',
665                        offset)
666
667            # Complete message (with full separator) may be present in buffer
668            # even when EOF flag is set. This may happen when the last chunk
669            # adds data which makes separator be found. That's why we check for
670            # EOF *after* inspecting the buffer.
671            if self._eof:
672                chunk = bytes(self._buffer)
673                self._buffer.clear()
674                raise exceptions.IncompleteReadError(chunk, None)
675
676            # _wait_for_data() will resume reading if stream was paused.
677            await self._wait_for_data('readuntil')
678
679        if match_start > self._limit:
680            raise exceptions.LimitOverrunError(
681                'Separator is found, but chunk is longer than limit', match_start)
682
683        chunk = self._buffer[:match_end]
684        del self._buffer[:match_end]
685        self._maybe_resume_transport()
686        return bytes(chunk)
687
688    async def read(self, n=-1):
689        """Read up to `n` bytes from the stream.
690
691        If `n` is not provided or set to -1,
692        read until EOF, then return all read bytes.
693        If EOF was received and the internal buffer is empty,
694        return an empty bytes object.
695
696        If `n` is 0, return an empty bytes object immediately.
697
698        If `n` is positive, return at most `n` available bytes
699        as soon as at least 1 byte is available in the internal buffer.
700        If EOF is received before any byte is read, return an empty
701        bytes object.
702
703        Returned value is not limited with limit, configured at stream
704        creation.
705
706        If stream was paused, this function will automatically resume it if
707        needed.
708        """
709
710        if self._exception is not None:
711            raise self._exception
712
713        if n == 0:
714            return b''
715
716        if n < 0:
717            # This used to just loop creating a new waiter hoping to
718            # collect everything in self._buffer, but that would
719            # deadlock if the subprocess sends more than self.limit
720            # bytes.  So just call self.read(self._limit) until EOF.
721            blocks = []
722            while True:
723                block = await self.read(self._limit)
724                if not block:
725                    break
726                blocks.append(block)
727            return b''.join(blocks)
728
729        if not self._buffer and not self._eof:
730            await self._wait_for_data('read')
731
732        # This will work right even if buffer is less than n bytes
733        data = bytes(memoryview(self._buffer)[:n])
734        del self._buffer[:n]
735
736        self._maybe_resume_transport()
737        return data
738
739    async def readexactly(self, n):
740        """Read exactly `n` bytes.
741
742        Raise an IncompleteReadError if EOF is reached before `n` bytes can be
743        read. The IncompleteReadError.partial attribute of the exception will
744        contain the partial read bytes.
745
746        if n is zero, return empty bytes object.
747
748        Returned value is not limited with limit, configured at stream
749        creation.
750
751        If stream was paused, this function will automatically resume it if
752        needed.
753        """
754        if n < 0:
755            raise ValueError('readexactly size can not be less than zero')
756
757        if self._exception is not None:
758            raise self._exception
759
760        if n == 0:
761            return b''
762
763        while len(self._buffer) < n:
764            if self._eof:
765                incomplete = bytes(self._buffer)
766                self._buffer.clear()
767                raise exceptions.IncompleteReadError(incomplete, n)
768
769            await self._wait_for_data('readexactly')
770
771        if len(self._buffer) == n:
772            data = bytes(self._buffer)
773            self._buffer.clear()
774        else:
775            data = bytes(memoryview(self._buffer)[:n])
776            del self._buffer[:n]
777        self._maybe_resume_transport()
778        return data
779
780    def __aiter__(self):
781        return self
782
783    async def __anext__(self):
784        val = await self.readline()
785        if val == b'':
786            raise StopAsyncIteration
787        return val
788