• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import collections
2import warnings
3try:
4    import ssl
5except ImportError:  # pragma: no cover
6    ssl = None
7
8from . import constants
9from . import protocols
10from . import transports
11from .log import logger
12
13
14def _create_transport_context(server_side, server_hostname):
15    if server_side:
16        raise ValueError('Server side SSL needs a valid SSLContext')
17
18    # Client side may pass ssl=True to use a default
19    # context; in that case the sslcontext passed is None.
20    # The default is secure for client connections.
21    # Python 3.4+: use up-to-date strong settings.
22    sslcontext = ssl.create_default_context()
23    if not server_hostname:
24        sslcontext.check_hostname = False
25    return sslcontext
26
27
28# States of an _SSLPipe.
29_UNWRAPPED = "UNWRAPPED"
30_DO_HANDSHAKE = "DO_HANDSHAKE"
31_WRAPPED = "WRAPPED"
32_SHUTDOWN = "SHUTDOWN"
33
34
35class _SSLPipe(object):
36    """An SSL "Pipe".
37
38    An SSL pipe allows you to communicate with an SSL/TLS protocol instance
39    through memory buffers. It can be used to implement a security layer for an
40    existing connection where you don't have access to the connection's file
41    descriptor, or for some reason you don't want to use it.
42
43    An SSL pipe can be in "wrapped" and "unwrapped" mode. In unwrapped mode,
44    data is passed through untransformed. In wrapped mode, application level
45    data is encrypted to SSL record level data and vice versa. The SSL record
46    level is the lowest level in the SSL protocol suite and is what travels
47    as-is over the wire.
48
49    An SslPipe initially is in "unwrapped" mode. To start SSL, call
50    do_handshake(). To shutdown SSL again, call unwrap().
51    """
52
53    max_size = 256 * 1024   # Buffer size passed to read()
54
55    def __init__(self, context, server_side, server_hostname=None):
56        """
57        The *context* argument specifies the ssl.SSLContext to use.
58
59        The *server_side* argument indicates whether this is a server side or
60        client side transport.
61
62        The optional *server_hostname* argument can be used to specify the
63        hostname you are connecting to. You may only specify this parameter if
64        the _ssl module supports Server Name Indication (SNI).
65        """
66        self._context = context
67        self._server_side = server_side
68        self._server_hostname = server_hostname
69        self._state = _UNWRAPPED
70        self._incoming = ssl.MemoryBIO()
71        self._outgoing = ssl.MemoryBIO()
72        self._sslobj = None
73        self._need_ssldata = False
74        self._handshake_cb = None
75        self._shutdown_cb = None
76
77    @property
78    def context(self):
79        """The SSL context passed to the constructor."""
80        return self._context
81
82    @property
83    def ssl_object(self):
84        """The internal ssl.SSLObject instance.
85
86        Return None if the pipe is not wrapped.
87        """
88        return self._sslobj
89
90    @property
91    def need_ssldata(self):
92        """Whether more record level data is needed to complete a handshake
93        that is currently in progress."""
94        return self._need_ssldata
95
96    @property
97    def wrapped(self):
98        """
99        Whether a security layer is currently in effect.
100
101        Return False during handshake.
102        """
103        return self._state == _WRAPPED
104
105    def do_handshake(self, callback=None):
106        """Start the SSL handshake.
107
108        Return a list of ssldata. A ssldata element is a list of buffers
109
110        The optional *callback* argument can be used to install a callback that
111        will be called when the handshake is complete. The callback will be
112        called with None if successful, else an exception instance.
113        """
114        if self._state != _UNWRAPPED:
115            raise RuntimeError('handshake in progress or completed')
116        self._sslobj = self._context.wrap_bio(
117            self._incoming, self._outgoing,
118            server_side=self._server_side,
119            server_hostname=self._server_hostname)
120        self._state = _DO_HANDSHAKE
121        self._handshake_cb = callback
122        ssldata, appdata = self.feed_ssldata(b'', only_handshake=True)
123        assert len(appdata) == 0
124        return ssldata
125
126    def shutdown(self, callback=None):
127        """Start the SSL shutdown sequence.
128
129        Return a list of ssldata. A ssldata element is a list of buffers
130
131        The optional *callback* argument can be used to install a callback that
132        will be called when the shutdown is complete. The callback will be
133        called without arguments.
134        """
135        if self._state == _UNWRAPPED:
136            raise RuntimeError('no security layer present')
137        if self._state == _SHUTDOWN:
138            raise RuntimeError('shutdown in progress')
139        assert self._state in (_WRAPPED, _DO_HANDSHAKE)
140        self._state = _SHUTDOWN
141        self._shutdown_cb = callback
142        ssldata, appdata = self.feed_ssldata(b'')
143        assert appdata == [] or appdata == [b'']
144        return ssldata
145
146    def feed_eof(self):
147        """Send a potentially "ragged" EOF.
148
149        This method will raise an SSL_ERROR_EOF exception if the EOF is
150        unexpected.
151        """
152        self._incoming.write_eof()
153        ssldata, appdata = self.feed_ssldata(b'')
154        assert appdata == [] or appdata == [b'']
155
156    def feed_ssldata(self, data, only_handshake=False):
157        """Feed SSL record level data into the pipe.
158
159        The data must be a bytes instance. It is OK to send an empty bytes
160        instance. This can be used to get ssldata for a handshake initiated by
161        this endpoint.
162
163        Return a (ssldata, appdata) tuple. The ssldata element is a list of
164        buffers containing SSL data that needs to be sent to the remote SSL.
165
166        The appdata element is a list of buffers containing plaintext data that
167        needs to be forwarded to the application. The appdata list may contain
168        an empty buffer indicating an SSL "close_notify" alert. This alert must
169        be acknowledged by calling shutdown().
170        """
171        if self._state == _UNWRAPPED:
172            # If unwrapped, pass plaintext data straight through.
173            if data:
174                appdata = [data]
175            else:
176                appdata = []
177            return ([], appdata)
178
179        self._need_ssldata = False
180        if data:
181            self._incoming.write(data)
182
183        ssldata = []
184        appdata = []
185        try:
186            if self._state == _DO_HANDSHAKE:
187                # Call do_handshake() until it doesn't raise anymore.
188                self._sslobj.do_handshake()
189                self._state = _WRAPPED
190                if self._handshake_cb:
191                    self._handshake_cb(None)
192                if only_handshake:
193                    return (ssldata, appdata)
194                # Handshake done: execute the wrapped block
195
196            if self._state == _WRAPPED:
197                # Main state: read data from SSL until close_notify
198                while True:
199                    chunk = self._sslobj.read(self.max_size)
200                    appdata.append(chunk)
201                    if not chunk:  # close_notify
202                        break
203
204            elif self._state == _SHUTDOWN:
205                # Call shutdown() until it doesn't raise anymore.
206                self._sslobj.unwrap()
207                self._sslobj = None
208                self._state = _UNWRAPPED
209                if self._shutdown_cb:
210                    self._shutdown_cb()
211
212            elif self._state == _UNWRAPPED:
213                # Drain possible plaintext data after close_notify.
214                appdata.append(self._incoming.read())
215        except (ssl.SSLError, ssl.CertificateError) as exc:
216            exc_errno = getattr(exc, 'errno', None)
217            if exc_errno not in (
218                    ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE,
219                    ssl.SSL_ERROR_SYSCALL):
220                if self._state == _DO_HANDSHAKE and self._handshake_cb:
221                    self._handshake_cb(exc)
222                raise
223            self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)
224
225        # Check for record level data that needs to be sent back.
226        # Happens for the initial handshake and renegotiations.
227        if self._outgoing.pending:
228            ssldata.append(self._outgoing.read())
229        return (ssldata, appdata)
230
231    def feed_appdata(self, data, offset=0):
232        """Feed plaintext data into the pipe.
233
234        Return an (ssldata, offset) tuple. The ssldata element is a list of
235        buffers containing record level data that needs to be sent to the
236        remote SSL instance. The offset is the number of plaintext bytes that
237        were processed, which may be less than the length of data.
238
239        NOTE: In case of short writes, this call MUST be retried with the SAME
240        buffer passed into the *data* argument (i.e. the id() must be the
241        same). This is an OpenSSL requirement. A further particularity is that
242        a short write will always have offset == 0, because the _ssl module
243        does not enable partial writes. And even though the offset is zero,
244        there will still be encrypted data in ssldata.
245        """
246        assert 0 <= offset <= len(data)
247        if self._state == _UNWRAPPED:
248            # pass through data in unwrapped mode
249            if offset < len(data):
250                ssldata = [data[offset:]]
251            else:
252                ssldata = []
253            return (ssldata, len(data))
254
255        ssldata = []
256        view = memoryview(data)
257        while True:
258            self._need_ssldata = False
259            try:
260                if offset < len(view):
261                    offset += self._sslobj.write(view[offset:])
262            except ssl.SSLError as exc:
263                # It is not allowed to call write() after unwrap() until the
264                # close_notify is acknowledged. We return the condition to the
265                # caller as a short write.
266                exc_errno = getattr(exc, 'errno', None)
267                if exc.reason == 'PROTOCOL_IS_SHUTDOWN':
268                    exc_errno = exc.errno = ssl.SSL_ERROR_WANT_READ
269                if exc_errno not in (ssl.SSL_ERROR_WANT_READ,
270                                     ssl.SSL_ERROR_WANT_WRITE,
271                                     ssl.SSL_ERROR_SYSCALL):
272                    raise
273                self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)
274
275            # See if there's any record level data back for us.
276            if self._outgoing.pending:
277                ssldata.append(self._outgoing.read())
278            if offset == len(view) or self._need_ssldata:
279                break
280        return (ssldata, offset)
281
282
class _SSLProtocolTransport(transports._FlowControlMixin,
                            transports.Transport):
    """Transport handed to the application protocol of an SSL connection.

    Writes are passed to the owning SSLProtocol, which encrypts them.
    Reading control, write-buffer limits and buffer size queries are
    delegated to the SSLProtocol's underlying transport.
    """

    _sendfile_compatible = constants._SendfileMode.FALLBACK

    def __init__(self, loop, ssl_protocol):
        # NOTE(review): _FlowControlMixin.__init__ is not called; write
        # buffer limits and _protocol_paused are read from the underlying
        # transport instead (see set_write_buffer_limits/_protocol_paused).
        self._loop = loop
        # SSLProtocol instance
        self._ssl_protocol = ssl_protocol
        self._closed = False

    def get_extra_info(self, name, default=None):
        """Get optional transport information."""
        return self._ssl_protocol._get_extra_info(name, default)

    def set_protocol(self, protocol):
        """Replace the application protocol receiving decrypted data."""
        self._ssl_protocol._set_app_protocol(protocol)

    def get_protocol(self):
        """Return the current application protocol."""
        return self._ssl_protocol._app_protocol

    def is_closing(self):
        """Return True once close() or abort() has been called."""
        return self._closed

    def close(self):
        """Close the transport.

        Buffered data will be flushed asynchronously.  No more data
        will be received.  After all buffered data is flushed, the
        protocol's connection_lost() method will (eventually) be called
        with None as its argument.
        """
        self._closed = True
        self._ssl_protocol._start_shutdown()

    def __del__(self, _warn=warnings.warn):
        # Warn about, then close, a transport that was never closed.
        if not self._closed:
            _warn(f"unclosed transport {self!r}", ResourceWarning, source=self)
            self.close()

    def is_reading(self):
        """Return True if the underlying transport is receiving.

        Raise RuntimeError if the SSL protocol has no underlying transport.
        """
        tr = self._ssl_protocol._transport
        if tr is None:
            raise RuntimeError('SSL transport has not been initialized yet')
        return tr.is_reading()

    def pause_reading(self):
        """Pause the receiving end.

        No data will be passed to the protocol's data_received()
        method until resume_reading() is called.
        """
        self._ssl_protocol._transport.pause_reading()

    def resume_reading(self):
        """Resume the receiving end.

        Data received will once again be passed to the protocol's
        data_received() method.
        """
        self._ssl_protocol._transport.resume_reading()

    def set_write_buffer_limits(self, high=None, low=None):
        """Set the high- and low-water limits for write flow control.

        These two values control when to call the protocol's
        pause_writing() and resume_writing() methods.  If specified,
        the low-water limit must be less than or equal to the
        high-water limit.  Neither value can be negative.

        The defaults are implementation-specific.  If only the
        high-water limit is given, the low-water limit defaults to an
        implementation-specific value less than or equal to the
        high-water limit.  Setting high to zero forces low to zero as
        well, and causes pause_writing() to be called whenever the
        buffer becomes non-empty.  Setting low to zero causes
        resume_writing() to be called only once the buffer is empty.
        Use of zero for either limit is generally sub-optimal as it
        reduces opportunities for doing I/O and computation
        concurrently.
        """
        self._ssl_protocol._transport.set_write_buffer_limits(high, low)

    def get_write_buffer_size(self):
        """Return the current size of the write buffer."""
        return self._ssl_protocol._transport.get_write_buffer_size()

    @property
    def _protocol_paused(self):
        # Required for sendfile fallback pause_writing/resume_writing logic
        return self._ssl_protocol._transport._protocol_paused

    def write(self, data):
        """Write some data bytes to the transport.

        This does not block; it buffers the data and arranges for it
        to be sent out asynchronously.

        Raise TypeError if *data* is not bytes, bytearray or memoryview.
        Empty data is ignored.
        """
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise TypeError(f"data: expecting a bytes-like instance, "
                            f"got {type(data).__name__}")
        if not data:
            return
        self._ssl_protocol._write_appdata(data)

    def can_write_eof(self):
        """Return True if this transport supports write_eof(), False if not."""
        return False

    def abort(self):
        """Close the transport immediately.

        Buffered data will be lost.  No more data will be received.
        The protocol's connection_lost() method will (eventually) be
        called with None as its argument.
        """
        try:
            self._ssl_protocol._abort()
        finally:
            # Mark the transport closed even if aborting the underlying
            # transport raised: is_closing() must report the abort, and
            # __del__ must not warn about (and re-close) this transport.
            self._closed = True
401
402
class SSLProtocol(protocols.Protocol):
    """SSL protocol.

    Implementation of SSL on top of a socket using incoming and outgoing
    buffers which are ssl.MemoryBIO objects.

    An instance sits between a raw transport (given to connection_made())
    and an application protocol. Record level data received from the raw
    transport is fed through an _SSLPipe and the resulting plaintext is
    passed to the application protocol. Data written through the
    application-facing _SSLProtocolTransport is queued in a write backlog,
    encrypted by the _SSLPipe and written to the raw transport.
    """

    def __init__(self, loop, app_protocol, sslcontext, waiter,
                 server_side=False, server_hostname=None,
                 call_connection_made=True,
                 ssl_handshake_timeout=None):
        if ssl is None:
            raise RuntimeError('stdlib ssl module not available')

        if ssl_handshake_timeout is None:
            ssl_handshake_timeout = constants.SSL_HANDSHAKE_TIMEOUT
        elif ssl_handshake_timeout <= 0:
            raise ValueError(
                f"ssl_handshake_timeout should be a positive number, "
                f"got {ssl_handshake_timeout}")

        if not sslcontext:
            sslcontext = _create_transport_context(
                server_side, server_hostname)

        self._server_side = server_side
        if server_hostname and not server_side:
            self._server_hostname = server_hostname
        else:
            self._server_hostname = None
        self._sslcontext = sslcontext
        # SSL-specific extra info. More info are set when the handshake
        # completes.
        self._extra = dict(sslcontext=sslcontext)

        # App data write buffering
        self._write_backlog = collections.deque()
        self._write_buffer_size = 0

        self._waiter = waiter
        self._loop = loop
        self._set_app_protocol(app_protocol)
        self._app_transport = _SSLProtocolTransport(self._loop, self)
        # _SSLPipe instance (None until the connection is made)
        self._sslpipe = None
        self._session_established = False
        self._in_handshake = False
        self._in_shutdown = False
        # transport, ex: SelectorSocketTransport
        self._transport = None
        self._call_connection_made = call_connection_made
        self._ssl_handshake_timeout = ssl_handshake_timeout
        # Both are set by _start_handshake(). They are initialised here so
        # that connection_lost() and _on_handshake_complete() can rely on
        # the attributes existing. _handshake_start_time stays None unless
        # debug mode was on when the handshake started.
        self._handshake_timeout_handle = None
        self._handshake_start_time = None

    def _set_app_protocol(self, app_protocol):
        """Install *app_protocol* as the receiver of decrypted data."""
        self._app_protocol = app_protocol
        self._app_protocol_is_buffer = \
            isinstance(app_protocol, protocols.BufferedProtocol)

    def _wakeup_waiter(self, exc=None):
        """Resolve the connection waiter once.

        The waiter gets *exc* as its exception if given, else None as its
        result; a cancelled waiter is left alone. The waiter is then
        dropped so later calls are no-ops.
        """
        if self._waiter is None:
            return
        if not self._waiter.cancelled():
            if exc is not None:
                self._waiter.set_exception(exc)
            else:
                self._waiter.set_result(None)
        self._waiter = None

    def connection_made(self, transport):
        """Called when the low-level connection is made.

        Start the SSL handshake.
        """
        self._transport = transport
        self._sslpipe = _SSLPipe(self._sslcontext,
                                 self._server_side,
                                 self._server_hostname)
        self._start_handshake()

    def connection_lost(self, exc):
        """Called when the low-level connection is lost or closed.

        The argument is an exception object or None (the latter
        meaning a regular EOF is received or the connection was
        aborted or closed).
        """
        if self._session_established:
            self._session_established = False
            self._loop.call_soon(self._app_protocol.connection_lost, exc)
        else:
            # Most likely an exception occurred while in SSL handshake.
            # Just mark the app transport as closed so that its __del__
            # doesn't complain.
            if self._app_transport is not None:
                self._app_transport._closed = True
        self._transport = None
        self._app_transport = None
        if self._handshake_timeout_handle is not None:
            self._handshake_timeout_handle.cancel()
        self._wakeup_waiter(exc)
        self._app_protocol = None
        self._sslpipe = None

    def pause_writing(self):
        """Called when the low-level transport's buffer goes over
        the high-water mark.
        """
        self._app_protocol.pause_writing()

    def resume_writing(self):
        """Called when the low-level transport's buffer drains below
        the low-water mark.
        """
        self._app_protocol.resume_writing()

    def data_received(self, data):
        """Called when some SSL data is received.

        The argument is a bytes object.
        """
        if self._sslpipe is None:
            # transport closing, sslpipe is destroyed
            return

        try:
            ssldata, appdata = self._sslpipe.feed_ssldata(data)
        except (SystemExit, KeyboardInterrupt):
            raise
        except BaseException as e:
            self._fatal_error(e, 'SSL error in data received')
            return

        for chunk in ssldata:
            self._transport.write(chunk)

        for chunk in appdata:
            if chunk:
                try:
                    if self._app_protocol_is_buffer:
                        protocols._feed_data_to_buffered_proto(
                            self._app_protocol, chunk)
                    else:
                        self._app_protocol.data_received(chunk)
                except (SystemExit, KeyboardInterrupt):
                    raise
                except BaseException as ex:
                    self._fatal_error(
                        ex, 'application protocol failed to receive SSL data')
                    return
            else:
                # An empty chunk signals the peer's close_notify.
                self._start_shutdown()
                break

    def eof_received(self):
        """Called when the other end of the low-level stream
        is half-closed.

        If this returns a false value (including None), the transport
        will close itself.  If it returns a true value, closing the
        transport is up to the protocol.
        """
        try:
            if self._loop.get_debug():
                logger.debug("%r received EOF", self)

            self._wakeup_waiter(ConnectionResetError)

            if not self._in_handshake:
                keep_open = self._app_protocol.eof_received()
                if keep_open:
                    logger.warning('returning true from eof_received() '
                                   'has no effect when using ssl')
        finally:
            self._transport.close()

    def _get_extra_info(self, name, default=None):
        """Look up *name* in the SSL extra info, then the raw transport."""
        if name in self._extra:
            return self._extra[name]
        elif self._transport is not None:
            return self._transport.get_extra_info(name, default)
        else:
            return default

    def _start_shutdown(self):
        """Begin an orderly SSL shutdown, or abort during the handshake.

        Does nothing if a shutdown is already in progress.
        """
        if self._in_shutdown:
            return
        if self._in_handshake:
            self._abort()
        else:
            self._in_shutdown = True
            # (b'', 0) in the backlog means "start the SSL shutdown".
            self._write_appdata(b'')

    def _write_appdata(self, data):
        """Queue application *data* and try to flush the write backlog."""
        self._write_backlog.append((data, 0))
        self._write_buffer_size += len(data)
        self._process_write_backlog()

    def _start_handshake(self):
        """Queue the SSL handshake and arm the handshake timeout."""
        if self._loop.get_debug():
            logger.debug("%r starts SSL handshake", self)
            self._handshake_start_time = self._loop.time()
        else:
            self._handshake_start_time = None
        self._in_handshake = True
        # (b'', 1) is a special value in _process_write_backlog() to do
        # the SSL handshake
        self._write_backlog.append((b'', 1))
        self._handshake_timeout_handle = \
            self._loop.call_later(self._ssl_handshake_timeout,
                                  self._check_handshake_timeout)
        self._process_write_backlog()

    def _check_handshake_timeout(self):
        """Abort the connection if the handshake is still running."""
        if self._in_handshake:
            msg = (
                f"SSL handshake is taking longer than "
                f"{self._ssl_handshake_timeout} seconds: "
                f"aborting the connection"
            )
            self._fatal_error(ConnectionAbortedError(msg))

    def _on_handshake_complete(self, handshake_exc):
        """Handle the end of the handshake.

        *handshake_exc* is None on success, else the exception that ended
        the handshake. On failure the connection is torn down through
        _fatal_error(). On success the SSL extra info is recorded, the
        application protocol is connected (if requested), the waiter is
        resolved and flushing of any queued writes is scheduled.
        """
        self._in_handshake = False
        self._handshake_timeout_handle.cancel()

        sslobj = self._sslpipe.ssl_object
        try:
            if handshake_exc is not None:
                raise handshake_exc

            peercert = sslobj.getpeercert()
        except (SystemExit, KeyboardInterrupt):
            raise
        except BaseException as exc:
            if isinstance(exc, ssl.CertificateError):
                msg = 'SSL handshake failed on verifying the certificate'
            else:
                msg = 'SSL handshake failed'
            self._fatal_error(exc, msg)
            return

        # The start time is only recorded when debug mode was enabled at
        # handshake start; debug may have been switched on since then.
        if self._loop.get_debug() and self._handshake_start_time is not None:
            dt = self._loop.time() - self._handshake_start_time
            logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3)

        # Add extra info that becomes available after handshake.
        self._extra.update(peercert=peercert,
                           cipher=sslobj.cipher(),
                           compression=sslobj.compression(),
                           ssl_object=sslobj,
                           )
        if self._call_connection_made:
            self._app_protocol.connection_made(self._app_transport)
        self._wakeup_waiter()
        self._session_established = True
        # In case transport.write() was already called. Don't call
        # immediately _process_write_backlog(), but schedule it:
        # _on_handshake_complete() can be called indirectly from
        # _process_write_backlog(), and _process_write_backlog() is not
        # reentrant.
        self._loop.call_soon(self._process_write_backlog)

    def _process_write_backlog(self):
        """Try to make progress on the write backlog.

        Each backlog entry is a (data, offset) pair:

        * non-empty data: application data, *offset* bytes already sent;
        * (b'', 1): start the SSL handshake;
        * (b'', 0): start the SSL shutdown.

        Processing stops at the first short write; that entry is kept with
        its updated offset.
        """
        if self._transport is None or self._sslpipe is None:
            return

        try:
            # Bounded by the current length: entries appended while
            # processing are left for a later call.
            for _ in range(len(self._write_backlog)):
                data, offset = self._write_backlog[0]
                if data:
                    ssldata, offset = self._sslpipe.feed_appdata(data, offset)
                elif offset:
                    ssldata = self._sslpipe.do_handshake(
                        self._on_handshake_complete)
                    offset = 1
                else:
                    ssldata = self._sslpipe.shutdown(self._finalize)
                    offset = 1

                for chunk in ssldata:
                    self._transport.write(chunk)

                if offset < len(data):
                    self._write_backlog[0] = (data, offset)
                    # A short write means that a write is blocked on a read
                    # We need to enable reading if it is paused!
                    assert self._sslpipe.need_ssldata
                    # NOTE(review): relies on the private `_paused`
                    # attribute of the underlying transport.
                    if self._transport._paused:
                        self._transport.resume_reading()
                    break

                # An entire chunk from the backlog was processed. We can
                # delete it and reduce the outstanding buffer size.
                del self._write_backlog[0]
                self._write_buffer_size -= len(data)
        except (SystemExit, KeyboardInterrupt):
            raise
        except BaseException as exc:
            if self._in_handshake:
                # Exceptions will be re-raised in _on_handshake_complete.
                self._on_handshake_complete(exc)
            else:
                self._fatal_error(exc, 'Fatal error on SSL transport')

    def _fatal_error(self, exc, message='Fatal error on transport'):
        """Report *exc* and force-close the underlying transport.

        OSErrors are only logged in debug mode; other exceptions are passed
        to the loop's exception handler.
        """
        if isinstance(exc, OSError):
            if self._loop.get_debug():
                logger.debug("%r: %s", self, message, exc_info=True)
        else:
            self._loop.call_exception_handler({
                'message': message,
                'exception': exc,
                'transport': self._transport,
                'protocol': self,
            })
        if self._transport:
            self._transport._force_close(exc)

    def _finalize(self):
        """Drop the SSL pipe and close the underlying transport, if any."""
        self._sslpipe = None

        if self._transport is not None:
            self._transport.close()

    def _abort(self):
        """Abort the underlying transport, then finalize regardless."""
        try:
            if self._transport is not None:
                self._transport.abort()
        finally:
            self._finalize()
734