• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import collections
2import warnings
3try:
4    import ssl
5except ImportError:  # pragma: no cover
6    ssl = None
7
8from . import constants
9from . import protocols
10from . import transports
11from .log import logger
12
13
def _create_transport_context(server_side, server_hostname):
    """Build the default SSLContext used when no explicit context was given.

    Only client connections can fall back to a default context (this is the
    ``ssl=True`` case, where the sslcontext passed in is None); a server side
    connection raises instead.

    Hostname checking keeps the stdlib default unless *server_hostname* is
    falsy, in which case it is switched off.

    Raises:
        ValueError: if *server_side* is true.
    """
    if server_side:
        raise ValueError('Server side SSL needs a valid SSLContext')

    # Client side: ssl.create_default_context() provides the stdlib's
    # up-to-date secure settings for client connections.
    context = ssl.create_default_context()
    if server_hostname:
        return context

    # No hostname to match the peer certificate against.
    context.check_hostname = False
    return context
26
27
# States of an _SSLPipe. The normal lifecycle is
# UNWRAPPED -> DO_HANDSHAKE -> WRAPPED -> SHUTDOWN -> UNWRAPPED;
# a shutdown may also be started while still in DO_HANDSHAKE.
_UNWRAPPED, _DO_HANDSHAKE, _WRAPPED, _SHUTDOWN = (
    "UNWRAPPED",
    "DO_HANDSHAKE",
    "WRAPPED",
    "SHUTDOWN",
)
33
34
class _SSLPipe(object):
    """An SSL "Pipe".

    An SSL pipe allows you to communicate with an SSL/TLS protocol instance
    through memory buffers. It can be used to implement a security layer for an
    existing connection where you don't have access to the connection's file
    descriptor, or for some reason you don't want to use it.

    An SSL pipe can be in "wrapped" and "unwrapped" mode. In unwrapped mode,
    data is passed through untransformed. In wrapped mode, application level
    data is encrypted to SSL record level data and vice versa. The SSL record
    level is the lowest level in the SSL protocol suite and is what travels
    as-is over the wire.

    An _SSLPipe initially is in "unwrapped" mode. To start SSL, call
    do_handshake(). To shut SSL down again, call shutdown().
    """

    max_size = 256 * 1024   # Buffer size passed to read()

    def __init__(self, context, server_side, server_hostname=None):
        """
        The *context* argument specifies the ssl.SSLContext to use.

        The *server_side* argument indicates whether this is a server side or
        client side transport.

        The optional *server_hostname* argument can be used to specify the
        hostname you are connecting to. You may only specify this parameter if
        the _ssl module supports Server Name Indication (SNI).
        """
        self._context = context
        self._server_side = server_side
        self._server_hostname = server_hostname
        self._state = _UNWRAPPED
        # Record level data received from the peer is written into _incoming;
        # record level data destined for the peer is read out of _outgoing.
        self._incoming = ssl.MemoryBIO()
        self._outgoing = ssl.MemoryBIO()
        self._sslobj = None
        self._need_ssldata = False
        self._handshake_cb = None
        self._shutdown_cb = None

    @property
    def context(self):
        """The SSL context passed to the constructor."""
        return self._context

    @property
    def ssl_object(self):
        """The internal ssl.SSLObject instance.

        Return None if the pipe is not wrapped.
        """
        return self._sslobj

    @property
    def need_ssldata(self):
        """Whether more record level data is needed to complete a handshake
        that is currently in progress."""
        return self._need_ssldata

    @property
    def wrapped(self):
        """
        Whether a security layer is currently in effect.

        Return False during handshake.
        """
        return self._state == _WRAPPED

    def do_handshake(self, callback=None):
        """Start the SSL handshake.

        Return a list of buffers holding SSL record level data that must be
        sent to the peer.

        The optional *callback* argument can be used to install a callback that
        will be called when the handshake is complete. The callback will be
        called with None if successful, else an exception instance.

        Raise RuntimeError unless the pipe is currently unwrapped.
        """
        if self._state != _UNWRAPPED:
            raise RuntimeError('handshake in progress or completed')
        self._sslobj = self._context.wrap_bio(
            self._incoming, self._outgoing,
            server_side=self._server_side,
            server_hostname=self._server_hostname)
        self._state = _DO_HANDSHAKE
        self._handshake_cb = callback
        ssldata, appdata = self.feed_ssldata(b'', only_handshake=True)
        assert len(appdata) == 0
        return ssldata

    def shutdown(self, callback=None):
        """Start the SSL shutdown sequence.

        Return a list of buffers holding SSL record level data that must be
        sent to the peer.

        The optional *callback* argument can be used to install a callback that
        will be called when the shutdown is complete. The callback will be
        called without arguments.

        Raise RuntimeError if the pipe is unwrapped or already shutting down.
        """
        if self._state == _UNWRAPPED:
            raise RuntimeError('no security layer present')
        if self._state == _SHUTDOWN:
            raise RuntimeError('shutdown in progress')
        assert self._state in (_WRAPPED, _DO_HANDSHAKE)
        self._state = _SHUTDOWN
        self._shutdown_cb = callback
        ssldata, appdata = self.feed_ssldata(b'')
        assert appdata == [] or appdata == [b'']
        return ssldata

    def feed_eof(self):
        """Send a potentially "ragged" EOF.

        This method will raise an SSL_ERROR_EOF exception if the EOF is
        unexpected. Any record level data produced while processing the EOF
        is discarded.
        """
        self._incoming.write_eof()
        _, appdata = self.feed_ssldata(b'')
        assert appdata == [] or appdata == [b'']

    def feed_ssldata(self, data, only_handshake=False):
        """Feed SSL record level data into the pipe.

        The data must be a bytes instance. It is OK to send an empty bytes
        instance. This can be used to get ssldata for a handshake initiated by
        this endpoint.

        Return a (ssldata, appdata) tuple. The ssldata element is a list of
        buffers containing SSL data that needs to be sent to the remote SSL.

        The appdata element is a list of buffers containing plaintext data that
        needs to be forwarded to the application. The appdata list may contain
        an empty buffer indicating an SSL "close_notify" alert. This alert must
        be acknowledged by calling shutdown().

        If *only_handshake* is true and this call completes the handshake, no
        application data is read; pending record level data is still returned.
        """
        if self._state == _UNWRAPPED:
            # If unwrapped, pass plaintext data straight through.
            if data:
                appdata = [data]
            else:
                appdata = []
            return ([], appdata)

        self._need_ssldata = False
        if data:
            self._incoming.write(data)

        ssldata = []
        appdata = []
        try:
            if self._state == _DO_HANDSHAKE:
                # Call do_handshake() until it doesn't raise anymore.
                self._sslobj.do_handshake()
                self._state = _WRAPPED
                if self._handshake_cb:
                    self._handshake_cb(None)
                if only_handshake:
                    # Skip reading application data, but do not leave record
                    # level data stranded in the outgoing BIO.
                    self._flush_outgoing(ssldata)
                    return (ssldata, appdata)
                # Handshake done: execute the wrapped block

            if self._state == _WRAPPED:
                # Main state: read data from SSL until close_notify
                while True:
                    chunk = self._sslobj.read(self.max_size)
                    appdata.append(chunk)
                    if not chunk:  # close_notify
                        break

            elif self._state == _SHUTDOWN:
                # Call shutdown() until it doesn't raise anymore.
                self._sslobj.unwrap()
                self._sslobj = None
                self._state = _UNWRAPPED
                if self._shutdown_cb:
                    self._shutdown_cb()

            elif self._state == _UNWRAPPED:
                # Reachable when a callback above changed the state.
                # Drain possible plaintext data after close_notify.
                appdata.append(self._incoming.read())
        except (ssl.SSLError, ssl.CertificateError) as exc:
            exc_errno = getattr(exc, 'errno', None)
            if exc_errno not in (
                    ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE,
                    ssl.SSL_ERROR_SYSCALL):
                if self._state == _DO_HANDSHAKE and self._handshake_cb:
                    self._handshake_cb(exc)
                raise
            self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)

        # Check for record level data that needs to be sent back.
        # Happens for the initial handshake and renegotiations.
        self._flush_outgoing(ssldata)
        return (ssldata, appdata)

    def feed_appdata(self, data, offset=0):
        """Feed plaintext data into the pipe.

        Return an (ssldata, offset) tuple. The ssldata element is a list of
        buffers containing record level data that needs to be sent to the
        remote SSL instance. The offset is the number of plaintext bytes that
        were processed, which may be less than the length of data.

        NOTE: In case of short writes, this call MUST be retried with the SAME
        buffer passed into the *data* argument (i.e. the id() must be the
        same). This is an OpenSSL requirement. A further particularity is that
        a short write will always have offset == 0, because the _ssl module
        does not enable partial writes. And even though the offset is zero,
        there will still be encrypted data in ssldata.
        """
        assert 0 <= offset <= len(data)
        if self._state == _UNWRAPPED:
            # pass through data in unwrapped mode
            if offset < len(data):
                ssldata = [data[offset:]]
            else:
                ssldata = []
            return (ssldata, len(data))

        ssldata = []
        view = memoryview(data)
        while True:
            self._need_ssldata = False
            try:
                if offset < len(view):
                    offset += self._sslobj.write(view[offset:])
            except ssl.SSLError as exc:
                # It is not allowed to call write() after unwrap() until the
                # close_notify is acknowledged. We return the condition to the
                # caller as a short write.
                exc_errno = getattr(exc, 'errno', None)
                if exc.reason == 'PROTOCOL_IS_SHUTDOWN':
                    exc_errno = exc.errno = ssl.SSL_ERROR_WANT_READ
                if exc_errno not in (ssl.SSL_ERROR_WANT_READ,
                                     ssl.SSL_ERROR_WANT_WRITE,
                                     ssl.SSL_ERROR_SYSCALL):
                    raise
                self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)

            # See if there's any record level data back for us.
            self._flush_outgoing(ssldata)
            if offset == len(view) or self._need_ssldata:
                break
        return (ssldata, offset)

    def _flush_outgoing(self, ssldata):
        """Append any pending record level data from the outgoing BIO to
        the *ssldata* list."""
        if self._outgoing.pending:
            ssldata.append(self._outgoing.read())
281
282
class _SSLProtocolTransport(transports._FlowControlMixin,
                            transports.Transport):
    """Transport handed to the application protocol of an SSLProtocol.

    Application writes are passed to the owning SSLProtocol; reading and
    write-buffer control calls are forwarded to the SSLProtocol's underlying
    transport.
    """

    _sendfile_compatible = constants._SendfileMode.FALLBACK

    def __init__(self, loop, ssl_protocol):
        self._loop = loop
        # SSLProtocol instance
        self._ssl_protocol = ssl_protocol
        self._closed = False

    def get_extra_info(self, name, default=None):
        """Get optional transport information."""
        return self._ssl_protocol._get_extra_info(name, default)

    def set_protocol(self, protocol):
        self._ssl_protocol._set_app_protocol(protocol)

    def get_protocol(self):
        return self._ssl_protocol._app_protocol

    def is_closing(self):
        return self._closed

    def close(self):
        """Close the transport.

        Buffered data will be flushed asynchronously.  No more data
        will be received.  After all buffered data is flushed, the
        protocol's connection_lost() method will (eventually) be called
        with None as its argument.
        """
        self._closed = True
        self._ssl_protocol._start_shutdown()

    def __del__(self, _warn=warnings.warn):
        if not self._closed:
            _warn(f"unclosed transport {self!r}", ResourceWarning, source=self)
            self.close()

    def is_reading(self):
        tr = self._ssl_protocol._transport
        if tr is None:
            raise RuntimeError('SSL transport has not been initialized yet')
        return tr.is_reading()

    def pause_reading(self):
        """Pause the receiving end.

        No data will be passed to the protocol's data_received()
        method until resume_reading() is called.
        """
        self._ssl_protocol._transport.pause_reading()

    def resume_reading(self):
        """Resume the receiving end.

        Data received will once again be passed to the protocol's
        data_received() method.
        """
        self._ssl_protocol._transport.resume_reading()

    def set_write_buffer_limits(self, high=None, low=None):
        """Set the high- and low-water limits for write flow control.

        These two values control when to call the protocol's
        pause_writing() and resume_writing() methods.  If specified,
        the low-water limit must be less than or equal to the
        high-water limit.  Neither value can be negative.

        The defaults are implementation-specific.  If only the
        high-water limit is given, the low-water limit defaults to an
        implementation-specific value less than or equal to the
        high-water limit.  Setting high to zero forces low to zero as
        well, and causes pause_writing() to be called whenever the
        buffer becomes non-empty.  Setting low to zero causes
        resume_writing() to be called only once the buffer is empty.
        Use of zero for either limit is generally sub-optimal as it
        reduces opportunities for doing I/O and computation
        concurrently.
        """
        self._ssl_protocol._transport.set_write_buffer_limits(high, low)

    def get_write_buffer_size(self):
        """Return the current size of the write buffer."""
        return self._ssl_protocol._transport.get_write_buffer_size()

    def get_write_buffer_limits(self):
        """Get the high and low watermarks for write flow control.
        Return a tuple (low, high) where low and high are
        positive number of bytes."""
        return self._ssl_protocol._transport.get_write_buffer_limits()

    @property
    def _protocol_paused(self):
        # Required for sendfile fallback pause_writing/resume_writing logic
        return self._ssl_protocol._transport._protocol_paused

    def write(self, data):
        """Write some data bytes to the transport.

        This does not block; it buffers the data and arranges for it
        to be sent out asynchronously.
        """
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise TypeError(f"data: expecting a bytes-like instance, "
                            f"got {type(data).__name__}")
        if not data:
            return
        self._ssl_protocol._write_appdata(data)

    def can_write_eof(self):
        """Return True if this transport supports write_eof(), False if not."""
        return False

    def abort(self):
        """Close the transport immediately.

        Buffered data will be lost.  No more data will be received.
        The protocol's connection_lost() method will (eventually) be
        called with None as its argument.
        """
        try:
            self._ssl_protocol._abort()
        finally:
            # Mark closed even if aborting raised, so that __del__ neither
            # emits a ResourceWarning nor calls close() a second time.
            self._closed = True
407
408
class SSLProtocol(protocols.Protocol):
    """SSL protocol.

    Implementation of SSL on top of a socket using incoming and outgoing
    buffers which are ssl.MemoryBIO objects.

    Raw record level data arrives from the low-level transport through
    data_received(); decrypted data is forwarded to the application
    protocol, which (when call_connection_made is true) receives an
    _SSLProtocolTransport in connection_made() once the handshake succeeds.
    """

    def __init__(self, loop, app_protocol, sslcontext, waiter,
                 server_side=False, server_hostname=None,
                 call_connection_made=True,
                 ssl_handshake_timeout=None):
        if ssl is None:
            raise RuntimeError('stdlib ssl module not available')

        if ssl_handshake_timeout is None:
            ssl_handshake_timeout = constants.SSL_HANDSHAKE_TIMEOUT
        elif ssl_handshake_timeout <= 0:
            raise ValueError(
                f"ssl_handshake_timeout should be a positive number, "
                f"got {ssl_handshake_timeout}")

        if not sslcontext:
            sslcontext = _create_transport_context(
                server_side, server_hostname)

        self._server_side = server_side
        if server_hostname and not server_side:
            self._server_hostname = server_hostname
        else:
            self._server_hostname = None
        self._sslcontext = sslcontext
        # SSL-specific extra info. More info is added when the handshake
        # completes.
        self._extra = dict(sslcontext=sslcontext)

        # App data write buffering: a deque of (data, offset) entries.
        self._write_backlog = collections.deque()
        self._write_buffer_size = 0

        self._waiter = waiter
        self._loop = loop
        self._set_app_protocol(app_protocol)
        self._app_transport = _SSLProtocolTransport(self._loop, self)
        # _SSLPipe instance (None until the connection is made)
        self._sslpipe = None
        self._session_established = False
        self._in_handshake = False
        self._in_shutdown = False
        # transport, ex: SelectorSocketTransport
        self._transport = None
        self._call_connection_made = call_connection_made
        self._ssl_handshake_timeout = ssl_handshake_timeout
        # Set by _start_handshake(); initialized here so they always exist.
        self._handshake_start_time = None
        self._handshake_timeout_handle = None

    def _set_app_protocol(self, app_protocol):
        self._app_protocol = app_protocol
        self._app_protocol_is_buffer = \
            isinstance(app_protocol, protocols.BufferedProtocol)

    def _wakeup_waiter(self, exc=None):
        # Resolve the waiter future at most once, then drop the reference.
        if self._waiter is None:
            return
        if not self._waiter.cancelled():
            if exc is not None:
                self._waiter.set_exception(exc)
            else:
                self._waiter.set_result(None)
        self._waiter = None

    def connection_made(self, transport):
        """Called when the low-level connection is made.

        Start the SSL handshake.
        """
        self._transport = transport
        self._sslpipe = _SSLPipe(self._sslcontext,
                                 self._server_side,
                                 self._server_hostname)
        self._start_handshake()

    def connection_lost(self, exc):
        """Called when the low-level connection is lost or closed.

        The argument is an exception object or None (the latter
        meaning a regular EOF is received or the connection was
        aborted or closed).
        """
        if self._session_established:
            self._session_established = False
            self._loop.call_soon(self._app_protocol.connection_lost, exc)
        else:
            # Most likely an exception occurred while in SSL handshake.
            # Just mark the app transport as closed so that its __del__
            # doesn't complain.
            if self._app_transport is not None:
                self._app_transport._closed = True
        self._transport = None
        self._app_transport = None
        if getattr(self, '_handshake_timeout_handle', None):
            self._handshake_timeout_handle.cancel()
        self._wakeup_waiter(exc)
        self._app_protocol = None
        self._sslpipe = None

    def pause_writing(self):
        """Called when the low-level transport's buffer goes over
        the high-water mark.
        """
        self._app_protocol.pause_writing()

    def resume_writing(self):
        """Called when the low-level transport's buffer drains below
        the low-water mark.
        """
        self._app_protocol.resume_writing()

    def data_received(self, data):
        """Called when some SSL data is received.

        The argument is a bytes object.
        """
        if self._sslpipe is None:
            # transport closing, sslpipe is destroyed
            return

        try:
            ssldata, appdata = self._sslpipe.feed_ssldata(data)
        except (SystemExit, KeyboardInterrupt):
            raise
        except BaseException as e:
            self._fatal_error(e, 'SSL error in data received')
            return

        for chunk in ssldata:
            self._transport.write(chunk)

        for chunk in appdata:
            if chunk:
                try:
                    if self._app_protocol_is_buffer:
                        protocols._feed_data_to_buffered_proto(
                            self._app_protocol, chunk)
                    else:
                        self._app_protocol.data_received(chunk)
                except (SystemExit, KeyboardInterrupt):
                    raise
                except BaseException as ex:
                    self._fatal_error(
                        ex, 'application protocol failed to receive SSL data')
                    return
            else:
                # An empty chunk signals close_notify from the peer.
                self._start_shutdown()
                break

    def eof_received(self):
        """Called when the other end of the low-level stream
        is half-closed.

        If this returns a false value (including None), the transport
        will close itself.  If it returns a true value, closing the
        transport is up to the protocol.
        """
        try:
            if self._loop.get_debug():
                logger.debug("%r received EOF", self)

            self._wakeup_waiter(ConnectionResetError)

            if not self._in_handshake:
                keep_open = self._app_protocol.eof_received()
                if keep_open:
                    logger.warning('returning true from eof_received() '
                                   'has no effect when using ssl')
        finally:
            self._transport.close()

    def _get_extra_info(self, name, default=None):
        if name in self._extra:
            return self._extra[name]
        elif self._transport is not None:
            return self._transport.get_extra_info(name, default)
        else:
            return default

    def _start_shutdown(self):
        if self._in_shutdown:
            return
        if self._in_handshake:
            self._abort()
        else:
            self._in_shutdown = True
            # (b'', 0) is the backlog marker for "start SSL shutdown".
            self._write_appdata(b'')

    def _write_appdata(self, data):
        self._write_backlog.append((data, 0))
        self._write_buffer_size += len(data)
        self._process_write_backlog()

    def _start_handshake(self):
        if self._loop.get_debug():
            logger.debug("%r starts SSL handshake", self)
            self._handshake_start_time = self._loop.time()
        else:
            self._handshake_start_time = None
        self._in_handshake = True
        # (b'', 1) is a special value in _process_write_backlog() to do
        # the SSL handshake
        self._write_backlog.append((b'', 1))
        self._handshake_timeout_handle = \
            self._loop.call_later(self._ssl_handshake_timeout,
                                  self._check_handshake_timeout)
        self._process_write_backlog()

    def _check_handshake_timeout(self):
        if self._in_handshake is True:
            msg = (
                f"SSL handshake is taking longer than "
                f"{self._ssl_handshake_timeout} seconds: "
                f"aborting the connection"
            )
            self._fatal_error(ConnectionAbortedError(msg))

    def _on_handshake_complete(self, handshake_exc):
        self._in_handshake = False
        self._handshake_timeout_handle.cancel()

        sslobj = self._sslpipe.ssl_object
        try:
            if handshake_exc is not None:
                raise handshake_exc

            peercert = sslobj.getpeercert()
        except (SystemExit, KeyboardInterrupt):
            raise
        except BaseException as exc:
            if isinstance(exc, ssl.CertificateError):
                msg = 'SSL handshake failed on verifying the certificate'
            else:
                msg = 'SSL handshake failed'
            self._fatal_error(exc, msg)
            return

        # The start time is only recorded when debug mode was already on in
        # _start_handshake(); debug may have been enabled since then.
        if (self._loop.get_debug()
                and self._handshake_start_time is not None):
            dt = self._loop.time() - self._handshake_start_time
            logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3)

        # Add extra info that becomes available after handshake.
        self._extra.update(peercert=peercert,
                           cipher=sslobj.cipher(),
                           compression=sslobj.compression(),
                           ssl_object=sslobj,
                           )
        if self._call_connection_made:
            self._app_protocol.connection_made(self._app_transport)
        self._wakeup_waiter()
        self._session_established = True
        # In case transport.write() was already called. Don't call
        # immediately _process_write_backlog(), but schedule it:
        # _on_handshake_complete() can be called indirectly from
        # _process_write_backlog(), and _process_write_backlog() is not
        # reentrant.
        self._loop.call_soon(self._process_write_backlog)

    def _process_write_backlog(self):
        # Try to make progress on the write backlog.
        if self._transport is None or self._sslpipe is None:
            return

        try:
            for _ in range(len(self._write_backlog)):
                data, offset = self._write_backlog[0]
                if data:
                    ssldata, offset = self._sslpipe.feed_appdata(data, offset)
                elif offset:
                    # (b'', 1): handshake marker queued by _start_handshake().
                    ssldata = self._sslpipe.do_handshake(
                        self._on_handshake_complete)
                    offset = 1
                else:
                    # (b'', 0): shutdown marker queued by _start_shutdown().
                    ssldata = self._sslpipe.shutdown(self._finalize)
                    offset = 1

                for chunk in ssldata:
                    self._transport.write(chunk)

                if offset < len(data):
                    self._write_backlog[0] = (data, offset)
                    # A short write means that a write is blocked on a read
                    # We need to enable reading if it is paused!
                    assert self._sslpipe.need_ssldata
                    if self._transport._paused:
                        self._transport.resume_reading()
                    break

                # An entire chunk from the backlog was processed. We can
                # delete it and reduce the outstanding buffer size.
                del self._write_backlog[0]
                self._write_buffer_size -= len(data)
        except (SystemExit, KeyboardInterrupt):
            raise
        except BaseException as exc:
            if self._in_handshake:
                # Exceptions will be re-raised in _on_handshake_complete.
                self._on_handshake_complete(exc)
            else:
                self._fatal_error(exc, 'Fatal error on SSL transport')

    def _fatal_error(self, exc, message='Fatal error on transport'):
        # OSErrors are only logged in debug mode; anything else goes to the
        # loop's exception handler. Either way the transport is force-closed.
        if isinstance(exc, OSError):
            if self._loop.get_debug():
                logger.debug("%r: %s", self, message, exc_info=True)
        else:
            self._loop.call_exception_handler({
                'message': message,
                'exception': exc,
                'transport': self._transport,
                'protocol': self,
            })
        if self._transport:
            self._transport._force_close(exc)

    def _finalize(self):
        self._sslpipe = None

        if self._transport is not None:
            self._transport.close()

    def _abort(self):
        try:
            if self._transport is not None:
                self._transport.abort()
        finally:
            self._finalize()
740