• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import collections
2import warnings
3try:
4    import ssl
5except ImportError:  # pragma: no cover
6    ssl = None
7
8from . import base_events
9from . import constants
10from . import protocols
11from . import transports
12from .log import logger
13
14
15def _create_transport_context(server_side, server_hostname):
16    if server_side:
17        raise ValueError('Server side SSL needs a valid SSLContext')
18
19    # Client side may pass ssl=True to use a default
20    # context; in that case the sslcontext passed is None.
21    # The default is secure for client connections.
22    # Python 3.4+: use up-to-date strong settings.
23    sslcontext = ssl.create_default_context()
24    if not server_hostname:
25        sslcontext.check_hostname = False
26    return sslcontext
27
28
29# States of an _SSLPipe.
30_UNWRAPPED = "UNWRAPPED"
31_DO_HANDSHAKE = "DO_HANDSHAKE"
32_WRAPPED = "WRAPPED"
33_SHUTDOWN = "SHUTDOWN"
34
35
class _SSLPipe(object):
    """An SSL "Pipe".

    An SSL pipe allows you to communicate with an SSL/TLS protocol instance
    through memory buffers. It can be used to implement a security layer for an
    existing connection where you don't have access to the connection's file
    descriptor, or for some reason you don't want to use it.

    An SSL pipe can be in "wrapped" and "unwrapped" mode. In unwrapped mode,
    data is passed through untransformed. In wrapped mode, application level
    data is encrypted to SSL record level data and vice versa. The SSL record
    level is the lowest level in the SSL protocol suite and is what travels
    as-is over the wire.

    An _SSLPipe is initially in "unwrapped" mode. To start SSL, call
    do_handshake(). To shut SSL down again, call shutdown().
    """

    max_size = 256 * 1024   # Max bytes requested per SSLObject.read() call

    def __init__(self, context, server_side, server_hostname=None):
        """
        The *context* argument specifies the ssl.SSLContext to use.

        The *server_side* argument indicates whether this is a server side or
        client side transport.

        The optional *server_hostname* argument can be used to specify the
        hostname you are connecting to. You may only specify this parameter if
        the _ssl module supports Server Name Indication (SNI).
        """
        self._context = context
        self._server_side = server_side
        self._server_hostname = server_hostname
        self._state = _UNWRAPPED
        # _incoming receives record level bytes from the peer (see
        # feed_ssldata); _outgoing accumulates record level bytes that must
        # be sent to the peer.
        self._incoming = ssl.MemoryBIO()
        self._outgoing = ssl.MemoryBIO()
        # ssl.SSLObject created by do_handshake(); reset to None after a
        # completed shutdown.
        self._sslobj = None
        self._need_ssldata = False
        self._handshake_cb = None
        self._shutdown_cb = None

    @property
    def context(self):
        """The SSL context passed to the constructor."""
        return self._context

    @property
    def ssl_object(self):
        """The internal ssl.SSLObject instance.

        Return None if the pipe is not wrapped.
        """
        return self._sslobj

    @property
    def need_ssldata(self):
        """Whether more record level data is needed to complete a handshake
        that is currently in progress."""
        return self._need_ssldata

    @property
    def wrapped(self):
        """
        Whether a security layer is currently in effect.

        Return False during handshake.
        """
        return self._state == _WRAPPED

    def do_handshake(self, callback=None):
        """Start the SSL handshake.

        Return the ssldata list: buffers of record level data that must be
        sent to the peer to drive the handshake forward.

        The optional *callback* argument can be used to install a callback that
        will be called when the handshake is complete. The callback will be
        called with None if successful, else an exception instance.

        Raise RuntimeError if a handshake was already started or completed.
        """
        if self._state != _UNWRAPPED:
            raise RuntimeError('handshake in progress or completed')
        self._sslobj = self._context.wrap_bio(
            self._incoming, self._outgoing,
            server_side=self._server_side,
            server_hostname=self._server_hostname)
        self._state = _DO_HANDSHAKE
        self._handshake_cb = callback
        # Feeding an empty buffer makes the SSL object produce its first
        # handshake records (e.g. the client hello) in _outgoing.
        ssldata, appdata = self.feed_ssldata(b'', only_handshake=True)
        assert len(appdata) == 0
        return ssldata

    def shutdown(self, callback=None):
        """Start the SSL shutdown sequence.

        Return the ssldata list: buffers of record level data (the
        close_notify alert) that must be sent to the peer.

        The optional *callback* argument can be used to install a callback that
        will be called when the shutdown is complete. The callback will be
        called without arguments.

        Raise RuntimeError if no security layer is present or a shutdown is
        already in progress.
        """
        if self._state == _UNWRAPPED:
            raise RuntimeError('no security layer present')
        if self._state == _SHUTDOWN:
            raise RuntimeError('shutdown in progress')
        assert self._state in (_WRAPPED, _DO_HANDSHAKE)
        self._state = _SHUTDOWN
        self._shutdown_cb = callback
        ssldata, appdata = self.feed_ssldata(b'')
        assert appdata == [] or appdata == [b'']
        return ssldata

    def feed_eof(self):
        """Send a potentially "ragged" EOF.

        This method will raise an SSL_ERROR_EOF exception if the EOF is
        unexpected.
        """
        self._incoming.write_eof()
        # Any record level output produced here is discarded; only the
        # application data is sanity-checked.
        ssldata, appdata = self.feed_ssldata(b'')
        assert appdata == [] or appdata == [b'']

    def feed_ssldata(self, data, only_handshake=False):
        """Feed SSL record level data into the pipe.

        The data must be a bytes instance. It is OK to send an empty bytes
        instance. This can be used to get ssldata for a handshake initiated by
        this endpoint.

        Return a (ssldata, appdata) tuple. The ssldata element is a list of
        buffers containing SSL data that needs to be sent to the remote SSL.

        The appdata element is a list of buffers containing plaintext data that
        needs to be forwarded to the application. The appdata list may contain
        an empty buffer indicating an SSL "close_notify" alert. This alert must
        be acknowledged by calling shutdown().

        SSL errors other than WANT_READ / WANT_WRITE / SYSCALL are re-raised;
        during a handshake the handshake callback receives them first.
        """
        if self._state == _UNWRAPPED:
            # If unwrapped, pass plaintext data straight through.
            if data:
                appdata = [data]
            else:
                appdata = []
            return ([], appdata)

        self._need_ssldata = False
        if data:
            self._incoming.write(data)

        ssldata = []
        appdata = []
        try:
            if self._state == _DO_HANDSHAKE:
                # Call do_handshake() until it doesn't raise anymore.
                # Each call to feed_ssldata() retries it with the data
                # received so far; WANT_READ is swallowed below.
                self._sslobj.do_handshake()
                self._state = _WRAPPED
                if self._handshake_cb:
                    self._handshake_cb(None)
                if only_handshake:
                    # NOTE(review): this early return skips the _outgoing
                    # drain at the end of this method.
                    return (ssldata, appdata)
                # Handshake done: execute the wrapped block

            if self._state == _WRAPPED:
                # Main state: read data from SSL until close_notify
                while True:
                    chunk = self._sslobj.read(self.max_size)
                    appdata.append(chunk)
                    if not chunk:  # close_notify
                        break

            elif self._state == _SHUTDOWN:
                # Call shutdown() until it doesn't raise anymore.
                self._sslobj.unwrap()
                self._sslobj = None
                self._state = _UNWRAPPED
                if self._shutdown_cb:
                    self._shutdown_cb()

            elif self._state == _UNWRAPPED:
                # Drain possible plaintext data after close_notify.
                appdata.append(self._incoming.read())
        except (ssl.SSLError, ssl.CertificateError) as exc:
            # WANT_READ / WANT_WRITE / SYSCALL mean "cannot make progress
            # yet" and are not treated as failures; anything else is fatal.
            exc_errno = getattr(exc, 'errno', None)
            if exc_errno not in (
                    ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE,
                    ssl.SSL_ERROR_SYSCALL):
                if self._state == _DO_HANDSHAKE and self._handshake_cb:
                    self._handshake_cb(exc)
                raise
            self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)

        # Check for record level data that needs to be sent back.
        # Happens for the initial handshake and renegotiations.
        if self._outgoing.pending:
            ssldata.append(self._outgoing.read())
        return (ssldata, appdata)

    def feed_appdata(self, data, offset=0):
        """Feed plaintext data into the pipe.

        Return an (ssldata, offset) tuple. The ssldata element is a list of
        buffers containing record level data that needs to be sent to the
        remote SSL instance. The offset is the number of plaintext bytes that
        were processed, which may be less than the length of data.

        NOTE: In case of short writes, this call MUST be retried with the SAME
        buffer passed into the *data* argument (i.e. the id() must be the
        same). This is an OpenSSL requirement. A further particularity is that
        a short write will always have offset == 0, because the _ssl module
        does not enable partial writes. And even though the offset is zero,
        there will still be encrypted data in ssldata.
        """
        assert 0 <= offset <= len(data)
        if self._state == _UNWRAPPED:
            # pass through data in unwrapped mode
            if offset < len(data):
                ssldata = [data[offset:]]
            else:
                ssldata = []
            return (ssldata, len(data))

        ssldata = []
        # memoryview slicing avoids copying the remaining plaintext.
        view = memoryview(data)
        while True:
            self._need_ssldata = False
            try:
                if offset < len(view):
                    offset += self._sslobj.write(view[offset:])
            except ssl.SSLError as exc:
                # It is not allowed to call write() after unwrap() until the
                # close_notify is acknowledged. We return the condition to the
                # caller as a short write.
                exc_errno = getattr(exc, 'errno', None)
                if exc.reason == 'PROTOCOL_IS_SHUTDOWN':
                    exc_errno = exc.errno = ssl.SSL_ERROR_WANT_READ
                if exc_errno not in (ssl.SSL_ERROR_WANT_READ,
                                     ssl.SSL_ERROR_WANT_WRITE,
                                     ssl.SSL_ERROR_SYSCALL):
                    raise
                self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)

            # See if there's any record level data back for us.
            if self._outgoing.pending:
                ssldata.append(self._outgoing.read())
            # Stop when everything is written, or when the write is blocked
            # until more record level data arrives from the peer.
            if offset == len(view) or self._need_ssldata:
                break
        return (ssldata, offset)
282
283
284class _SSLProtocolTransport(transports._FlowControlMixin,
285                            transports.Transport):
286
287    _sendfile_compatible = constants._SendfileMode.FALLBACK
288
289    def __init__(self, loop, ssl_protocol):
290        self._loop = loop
291        # SSLProtocol instance
292        self._ssl_protocol = ssl_protocol
293        self._closed = False
294
295    def get_extra_info(self, name, default=None):
296        """Get optional transport information."""
297        return self._ssl_protocol._get_extra_info(name, default)
298
299    def set_protocol(self, protocol):
300        self._ssl_protocol._set_app_protocol(protocol)
301
302    def get_protocol(self):
303        return self._ssl_protocol._app_protocol
304
305    def is_closing(self):
306        return self._closed
307
308    def close(self):
309        """Close the transport.
310
311        Buffered data will be flushed asynchronously.  No more data
312        will be received.  After all buffered data is flushed, the
313        protocol's connection_lost() method will (eventually) called
314        with None as its argument.
315        """
316        self._closed = True
317        self._ssl_protocol._start_shutdown()
318
319    def __del__(self, _warn=warnings.warn):
320        if not self._closed:
321            _warn(f"unclosed transport {self!r}", ResourceWarning, source=self)
322            self.close()
323
324    def is_reading(self):
325        tr = self._ssl_protocol._transport
326        if tr is None:
327            raise RuntimeError('SSL transport has not been initialized yet')
328        return tr.is_reading()
329
330    def pause_reading(self):
331        """Pause the receiving end.
332
333        No data will be passed to the protocol's data_received()
334        method until resume_reading() is called.
335        """
336        self._ssl_protocol._transport.pause_reading()
337
338    def resume_reading(self):
339        """Resume the receiving end.
340
341        Data received will once again be passed to the protocol's
342        data_received() method.
343        """
344        self._ssl_protocol._transport.resume_reading()
345
346    def set_write_buffer_limits(self, high=None, low=None):
347        """Set the high- and low-water limits for write flow control.
348
349        These two values control when to call the protocol's
350        pause_writing() and resume_writing() methods.  If specified,
351        the low-water limit must be less than or equal to the
352        high-water limit.  Neither value can be negative.
353
354        The defaults are implementation-specific.  If only the
355        high-water limit is given, the low-water limit defaults to an
356        implementation-specific value less than or equal to the
357        high-water limit.  Setting high to zero forces low to zero as
358        well, and causes pause_writing() to be called whenever the
359        buffer becomes non-empty.  Setting low to zero causes
360        resume_writing() to be called only once the buffer is empty.
361        Use of zero for either limit is generally sub-optimal as it
362        reduces opportunities for doing I/O and computation
363        concurrently.
364        """
365        self._ssl_protocol._transport.set_write_buffer_limits(high, low)
366
367    def get_write_buffer_size(self):
368        """Return the current size of the write buffer."""
369        return self._ssl_protocol._transport.get_write_buffer_size()
370
371    @property
372    def _protocol_paused(self):
373        # Required for sendfile fallback pause_writing/resume_writing logic
374        return self._ssl_protocol._transport._protocol_paused
375
376    def write(self, data):
377        """Write some data bytes to the transport.
378
379        This does not block; it buffers the data and arranges for it
380        to be sent out asynchronously.
381        """
382        if not isinstance(data, (bytes, bytearray, memoryview)):
383            raise TypeError(f"data: expecting a bytes-like instance, "
384                            f"got {type(data).__name__}")
385        if not data:
386            return
387        self._ssl_protocol._write_appdata(data)
388
389    def can_write_eof(self):
390        """Return True if this transport supports write_eof(), False if not."""
391        return False
392
393    def abort(self):
394        """Close the transport immediately.
395
396        Buffered data will be lost.  No more data will be received.
397        The protocol's connection_lost() method will (eventually) be
398        called with None as its argument.
399        """
400        self._ssl_protocol._abort()
401        self._closed = True
402
403
class SSLProtocol(protocols.Protocol):
    """SSL protocol.

    Implementation of SSL on top of a socket using incoming and outgoing
    buffers which are ssl.MemoryBIO objects.

    An SSLProtocol sits between a raw transport and an application protocol.
    Record level data from the raw transport is decrypted through an
    _SSLPipe and handed to the application protocol; plaintext written by the
    application through an _SSLProtocolTransport is queued in a write
    backlog, encrypted, and written to the raw transport.
    """

    def __init__(self, loop, app_protocol, sslcontext, waiter,
                 server_side=False, server_hostname=None,
                 call_connection_made=True,
                 ssl_handshake_timeout=None):
        """Create the protocol.

        *sslcontext* may be a false value, in which case a default
        client-side context is created.  *waiter* (may be None) is completed
        by _wakeup_waiter(): with None after a successful handshake, or with
        the error passed on by connection_lost() / eof_received().
        *ssl_handshake_timeout* is the number of seconds allowed for the
        handshake (constants.SSL_HANDSHAKE_TIMEOUT when None).

        Raises RuntimeError when the ssl module is unavailable, and
        ValueError for a non-positive *ssl_handshake_timeout* or for a
        server-side protocol created without an SSLContext.
        """
        if ssl is None:
            raise RuntimeError('stdlib ssl module not available')

        if ssl_handshake_timeout is None:
            ssl_handshake_timeout = constants.SSL_HANDSHAKE_TIMEOUT
        elif ssl_handshake_timeout <= 0:
            raise ValueError(
                f"ssl_handshake_timeout should be a positive number, "
                f"got {ssl_handshake_timeout}")

        if not sslcontext:
            sslcontext = _create_transport_context(
                server_side, server_hostname)

        self._server_side = server_side
        # server_hostname is only kept for client-side protocols.
        if server_hostname and not server_side:
            self._server_hostname = server_hostname
        else:
            self._server_hostname = None
        self._sslcontext = sslcontext
        # SSL-specific extra info. More info are set when the handshake
        # completes.
        self._extra = dict(sslcontext=sslcontext)

        # App data write buffering: a deque of (data, offset) entries that
        # _process_write_backlog() consumes in order.
        self._write_backlog = collections.deque()
        self._write_buffer_size = 0

        self._waiter = waiter
        self._loop = loop
        self._set_app_protocol(app_protocol)
        self._app_transport = _SSLProtocolTransport(self._loop, self)
        # _SSLPipe instance (None until the connection is made)
        self._sslpipe = None
        self._session_established = False
        self._in_handshake = False
        self._in_shutdown = False
        # transport, ex: SelectorSocketTransport
        self._transport = None
        self._call_connection_made = call_connection_made
        self._ssl_handshake_timeout = ssl_handshake_timeout
        # Both are set by _start_handshake(); initialized here so that
        # connection_lost() can run before a handshake was ever started.
        self._handshake_timeout_handle = None
        self._handshake_start_time = None

    def _set_app_protocol(self, app_protocol):
        """Install *app_protocol* and remember whether it is buffered."""
        self._app_protocol = app_protocol
        self._app_protocol_is_buffer = \
            isinstance(app_protocol, protocols.BufferedProtocol)

    def _wakeup_waiter(self, exc=None):
        """Complete the waiter future (at most once).

        Sets *exc* as the exception if given, otherwise sets the result to
        None.  A cancelled waiter is left untouched.
        """
        if self._waiter is None:
            return
        if not self._waiter.cancelled():
            if exc is not None:
                self._waiter.set_exception(exc)
            else:
                self._waiter.set_result(None)
        self._waiter = None

    def connection_made(self, transport):
        """Called when the low-level connection is made.

        Start the SSL handshake.
        """
        self._transport = transport
        self._sslpipe = _SSLPipe(self._sslcontext,
                                 self._server_side,
                                 self._server_hostname)
        self._start_handshake()

    def connection_lost(self, exc):
        """Called when the low-level connection is lost or closed.

        The argument is an exception object or None (the latter
        meaning a regular EOF is received or the connection was
        aborted or closed).
        """
        if self._session_established:
            self._session_established = False
            self._loop.call_soon(self._app_protocol.connection_lost, exc)
        else:
            # Most likely an exception occurred while in SSL handshake.
            # Just mark the app transport as closed so that its __del__
            # doesn't complain.
            if self._app_transport is not None:
                self._app_transport._closed = True
        self._transport = None
        self._app_transport = None
        if self._handshake_timeout_handle is not None:
            self._handshake_timeout_handle.cancel()
        self._wakeup_waiter(exc)
        self._app_protocol = None
        self._sslpipe = None

    def pause_writing(self):
        """Called when the low-level transport's buffer goes over
        the high-water mark.
        """
        self._app_protocol.pause_writing()

    def resume_writing(self):
        """Called when the low-level transport's buffer drains below
        the low-water mark.
        """
        self._app_protocol.resume_writing()

    def data_received(self, data):
        """Called when some SSL data is received.

        The argument is a bytes object.  Handshake/renegotiation records
        produced in response are written back to the raw transport,
        decrypted data is forwarded to the application protocol, and an
        empty chunk (close_notify from the peer) starts the shutdown.
        """
        if self._sslpipe is None:
            # transport closing, sslpipe is destroyed
            return

        try:
            ssldata, appdata = self._sslpipe.feed_ssldata(data)
        except (SystemExit, KeyboardInterrupt):
            raise
        except BaseException as e:
            self._fatal_error(e, 'SSL error in data received')
            return

        for chunk in ssldata:
            self._transport.write(chunk)

        for chunk in appdata:
            if chunk:
                try:
                    if self._app_protocol_is_buffer:
                        protocols._feed_data_to_buffered_proto(
                            self._app_protocol, chunk)
                    else:
                        self._app_protocol.data_received(chunk)
                except (SystemExit, KeyboardInterrupt):
                    raise
                except BaseException as ex:
                    self._fatal_error(
                        ex, 'application protocol failed to receive SSL data')
                    return
            else:
                self._start_shutdown()
                break

    def eof_received(self):
        """Called when the other end of the low-level stream
        is half-closed.

        If this returns a false value (including None), the transport
        will close itself.  If it returns a true value, closing the
        transport is up to the protocol.

        Over SSL the raw transport is always closed here; a true value
        returned by the application's eof_received() only logs a warning.
        """
        try:
            if self._loop.get_debug():
                logger.debug("%r received EOF", self)

            self._wakeup_waiter(ConnectionResetError)

            if not self._in_handshake:
                keep_open = self._app_protocol.eof_received()
                if keep_open:
                    logger.warning('returning true from eof_received() '
                                   'has no effect when using ssl')
        finally:
            self._transport.close()

    def _get_extra_info(self, name, default=None):
        """Return SSL extra info, falling back to the raw transport's."""
        if name in self._extra:
            return self._extra[name]
        elif self._transport is not None:
            return self._transport.get_extra_info(name, default)
        else:
            return default

    def _start_shutdown(self):
        """Begin a graceful SSL shutdown (abort if still handshaking)."""
        if self._in_shutdown:
            return
        if self._in_handshake:
            self._abort()
        else:
            self._in_shutdown = True
            # (b'', 0) is the backlog marker for "start SSL shutdown".
            self._write_appdata(b'')

    def _write_appdata(self, data):
        """Queue application data for encryption and try to send it."""
        self._write_backlog.append((data, 0))
        self._write_buffer_size += len(data)
        self._process_write_backlog()

    def _start_handshake(self):
        """Queue the handshake and arm the handshake timeout."""
        if self._loop.get_debug():
            logger.debug("%r starts SSL handshake", self)
            self._handshake_start_time = self._loop.time()
        else:
            self._handshake_start_time = None
        self._in_handshake = True
        # (b'', 1) is a special value in _process_write_backlog() to do
        # the SSL handshake
        self._write_backlog.append((b'', 1))
        self._handshake_timeout_handle = \
            self._loop.call_later(self._ssl_handshake_timeout,
                                  self._check_handshake_timeout)
        self._process_write_backlog()

    def _check_handshake_timeout(self):
        """Abort the connection if the handshake is still running."""
        if self._in_handshake is True:
            msg = (
                f"SSL handshake is taking longer than "
                f"{self._ssl_handshake_timeout} seconds: "
                f"aborting the connection"
            )
            self._fatal_error(ConnectionAbortedError(msg))

    def _on_handshake_complete(self, handshake_exc):
        """Handshake callback installed via _SSLPipe.do_handshake().

        *handshake_exc* is None on success, else the exception that failed
        the handshake.  On success the extra info is populated, the
        application protocol is connected (if requested) and the waiter is
        woken; on failure the connection is torn down via _fatal_error().
        """
        self._in_handshake = False
        self._handshake_timeout_handle.cancel()

        sslobj = self._sslpipe.ssl_object
        try:
            if handshake_exc is not None:
                raise handshake_exc

            peercert = sslobj.getpeercert()
        except (SystemExit, KeyboardInterrupt):
            raise
        except BaseException as exc:
            if isinstance(exc, ssl.CertificateError):
                msg = 'SSL handshake failed on verifying the certificate'
            else:
                msg = 'SSL handshake failed'
            self._fatal_error(exc, msg)
            return

        # The start time is only recorded when debug mode was already on in
        # _start_handshake(); if debug was enabled mid-handshake it is None
        # and subtracting it would raise TypeError and abort the connection.
        if (self._loop.get_debug()
                and self._handshake_start_time is not None):
            dt = self._loop.time() - self._handshake_start_time
            logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3)

        # Add extra info that becomes available after handshake.
        self._extra.update(peercert=peercert,
                           cipher=sslobj.cipher(),
                           compression=sslobj.compression(),
                           ssl_object=sslobj,
                           )
        if self._call_connection_made:
            self._app_protocol.connection_made(self._app_transport)
        self._wakeup_waiter()
        self._session_established = True
        # In case transport.write() was already called. Don't call
        # immediately _process_write_backlog(), but schedule it:
        # _on_handshake_complete() can be called indirectly from
        # _process_write_backlog(), and _process_write_backlog() is not
        # reentrant.
        self._loop.call_soon(self._process_write_backlog)

    def _process_write_backlog(self):
        """Try to make progress on the write backlog.

        Backlog entries are (data, offset) pairs:
          * non-empty data: application data, *offset* bytes already sent;
          * (b'', 1): start the SSL handshake;
          * (b'', 0): start the SSL shutdown.
        Processing stops at the first short write, which means the SSL
        object needs more data from the peer before it can continue.
        """
        if self._transport is None or self._sslpipe is None:
            return

        try:
            for _ in range(len(self._write_backlog)):
                data, offset = self._write_backlog[0]
                if data:
                    ssldata, offset = self._sslpipe.feed_appdata(data, offset)
                elif offset:
                    ssldata = self._sslpipe.do_handshake(
                        self._on_handshake_complete)
                    offset = 1
                else:
                    ssldata = self._sslpipe.shutdown(self._finalize)
                    offset = 1

                for chunk in ssldata:
                    self._transport.write(chunk)

                if offset < len(data):
                    self._write_backlog[0] = (data, offset)
                    # A short write means that a write is blocked on a read
                    # We need to enable reading if it is paused!
                    assert self._sslpipe.need_ssldata
                    if self._transport._paused:
                        self._transport.resume_reading()
                    break

                # An entire chunk from the backlog was processed. We can
                # delete it and reduce the outstanding buffer size.
                del self._write_backlog[0]
                self._write_buffer_size -= len(data)
        except (SystemExit, KeyboardInterrupt):
            raise
        except BaseException as exc:
            if self._in_handshake:
                # Exceptions will be re-raised in _on_handshake_complete.
                self._on_handshake_complete(exc)
            else:
                self._fatal_error(exc, 'Fatal error on SSL transport')

    def _fatal_error(self, exc, message='Fatal error on transport'):
        """Report *exc* and force-close the raw transport.

        OSErrors are only logged in debug mode; any other exception is
        passed to the loop's exception handler.
        """
        if isinstance(exc, OSError):
            if self._loop.get_debug():
                logger.debug("%r: %s", self, message, exc_info=True)
        else:
            self._loop.call_exception_handler({
                'message': message,
                'exception': exc,
                'transport': self._transport,
                'protocol': self,
            })
        if self._transport:
            self._transport._force_close(exc)

    def _finalize(self):
        """Drop the SSL pipe and close the raw transport (if any)."""
        self._sslpipe = None

        if self._transport is not None:
            self._transport.close()

    def _abort(self):
        """Abort the raw transport, then finalize regardless of errors."""
        try:
            if self._transport is not None:
                self._transport.abort()
        finally:
            self._finalize()
735