• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import collections
2import warnings
3try:
4    import ssl
5except ImportError:  # pragma: no cover
6    ssl = None
7
8from . import base_events
9from . import compat
10from . import protocols
11from . import transports
12from .log import logger
13
14
def _create_transport_context(server_side, server_hostname):
    """Return a default client-side SSLContext.

    Used when the caller asked for SSL (e.g. ``ssl=True``) without supplying
    a context of its own.

    :param server_side: must be false; a server cannot invent a context
        because it needs its own certificate and key.
    :param server_hostname: host being connected to.  When it is empty or
        None, hostname checking is switched off on the returned context.
    :raises ValueError: if *server_side* is true.
    """
    if server_side:
        raise ValueError('Server side SSL needs a valid SSLContext')

    if not hasattr(ssl, 'create_default_context'):
        # Legacy path (Python 3.3): build a verifying context by hand,
        # refusing the broken SSLv2/SSLv3 protocols.
        context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        context.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3
        context.set_default_verify_paths()
        context.verify_mode = ssl.CERT_REQUIRED
        return context

    # Modern path: the stdlib default context carries current secure
    # settings.  Hostname checking needs a hostname to check against.
    context = ssl.create_default_context()
    if not server_hostname:
        context.check_hostname = False
    return context
35
36
def _is_sslproto_available():
    """Return True if the ssl module exposes ``MemoryBIO``.

    The SSL pipe below is built on ssl.MemoryBIO.  The result is False when
    the ssl module could not be imported (``ssl`` is None) or lacks that
    attribute.
    """
    try:
        ssl.MemoryBIO  # noqa: B018 -- attribute probe only
    except AttributeError:
        return False
    return True
39
40
# States of an _SSLPipe (stored in _SSLPipe._state):
#   _UNWRAPPED    -- no security layer; data is passed through untouched
#   _DO_HANDSHAKE -- do_handshake() was called, handshake not yet finished
#   _WRAPPED      -- handshake finished; data is encrypted/decrypted
#   _SHUTDOWN     -- shutdown() was called, SSL unwrap in progress
_UNWRAPPED, _DO_HANDSHAKE, _WRAPPED, _SHUTDOWN = (
    "UNWRAPPED", "DO_HANDSHAKE", "WRAPPED", "SHUTDOWN")
46
47
class _SSLPipe(object):
    """An SSL "Pipe".

    An SSL pipe allows you to communicate with an SSL/TLS protocol instance
    through memory buffers. It can be used to implement a security layer for an
    existing connection where you don't have access to the connection's file
    descriptor, or for some reason you don't want to use it.

    An SSL pipe can be in "wrapped" and "unwrapped" mode. In unwrapped mode,
    data is passed through untransformed. In wrapped mode, application level
    data is encrypted to SSL record level data and vice versa. The SSL record
    level is the lowest level in the SSL protocol suite and is what travels
    as-is over the wire.

    An _SSLPipe initially is in "unwrapped" mode. To start SSL, call
    do_handshake(). To shutdown SSL again, call shutdown().
    """

    max_size = 256 * 1024   # Buffer size passed to read()

    def __init__(self, context, server_side, server_hostname=None):
        """
        The *context* argument specifies the ssl.SSLContext to use.

        The *server_side* argument indicates whether this is a server side or
        client side transport.

        The optional *server_hostname* argument can be used to specify the
        hostname you are connecting to. You may only specify this parameter if
        the _ssl module supports Server Name Indication (SNI).
        """
        self._context = context
        self._server_side = server_side
        self._server_hostname = server_hostname
        self._state = _UNWRAPPED
        # Record-level data received from the peer is written into _incoming
        # (see feed_ssldata); record-level data produced for the peer is read
        # back out of _outgoing.  Both are handed to wrap_bio() later.
        self._incoming = ssl.MemoryBIO()
        self._outgoing = ssl.MemoryBIO()
        self._sslobj = None
        self._need_ssldata = False
        self._handshake_cb = None
        self._shutdown_cb = None

    @property
    def context(self):
        """The SSL context passed to the constructor."""
        return self._context

    @property
    def ssl_object(self):
        """The internal ssl.SSLObject instance.

        Return None if the pipe is not wrapped.
        """
        return self._sslobj

    @property
    def need_ssldata(self):
        """Whether more record level data is needed to complete a handshake
        that is currently in progress."""
        return self._need_ssldata

    @property
    def wrapped(self):
        """
        Whether a security layer is currently in effect.

        Return False during handshake.
        """
        return self._state == _WRAPPED

    def do_handshake(self, callback=None):
        """Start the SSL handshake.

        Return a list of buffers containing record level data that must be
        sent to the remote SSL instance.

        The optional *callback* argument can be used to install a callback that
        will be called when the handshake is complete. The callback will be
        called with None if successful, else an exception instance.

        Raise RuntimeError if a handshake was already started.
        """
        if self._state != _UNWRAPPED:
            raise RuntimeError('handshake in progress or completed')
        self._sslobj = self._context.wrap_bio(
            self._incoming, self._outgoing,
            server_side=self._server_side,
            server_hostname=self._server_hostname)
        self._state = _DO_HANDSHAKE
        self._handshake_cb = callback
        ssldata, appdata = self.feed_ssldata(b'', only_handshake=True)
        # only_handshake=True never reads application data.
        assert len(appdata) == 0
        return ssldata

    def shutdown(self, callback=None):
        """Start the SSL shutdown sequence.

        Return a list of buffers containing record level data that must be
        sent to the remote SSL instance.

        The optional *callback* argument can be used to install a callback that
        will be called when the shutdown is complete. The callback will be
        called without arguments.

        Raise RuntimeError if there is no security layer or a shutdown is
        already in progress.
        """
        if self._state == _UNWRAPPED:
            raise RuntimeError('no security layer present')
        if self._state == _SHUTDOWN:
            raise RuntimeError('shutdown in progress')
        assert self._state in (_WRAPPED, _DO_HANDSHAKE)
        self._state = _SHUTDOWN
        self._shutdown_cb = callback
        ssldata, appdata = self.feed_ssldata(b'')
        assert appdata == [] or appdata == [b'']
        return ssldata

    def feed_eof(self):
        """Send a potentially "ragged" EOF.

        This method will raise an SSL_ERROR_EOF exception if the EOF is
        unexpected.  Any record level data produced in response is
        discarded.
        """
        self._incoming.write_eof()
        _, appdata = self.feed_ssldata(b'')
        assert appdata == [] or appdata == [b'']

    def feed_ssldata(self, data, only_handshake=False):
        """Feed SSL record level data into the pipe.

        The data must be a bytes instance. It is OK to send an empty bytes
        instance. This can be used to get ssldata for a handshake initiated by
        this endpoint.

        Return a (ssldata, appdata) tuple. The ssldata element is a list of
        buffers containing SSL data that needs to be sent to the remote SSL.

        The appdata element is a list of buffers containing plaintext data that
        needs to be forwarded to the application. The appdata list may contain
        an empty buffer indicating an SSL "close_notify" alert. This alert must
        be acknowledged by calling shutdown().

        If *only_handshake* is true, return right after a successful
        handshake without reading any application data.

        SSL errors other than want-read / want-write / syscall are re-raised
        (after notifying the handshake callback, if a handshake is running).
        """
        if self._state == _UNWRAPPED:
            # If unwrapped, pass plaintext data straight through.
            if data:
                appdata = [data]
            else:
                appdata = []
            return ([], appdata)

        self._need_ssldata = False
        if data:
            self._incoming.write(data)

        ssldata = []
        appdata = []
        try:
            if self._state == _DO_HANDSHAKE:
                # Call do_handshake() until it doesn't raise anymore.
                self._sslobj.do_handshake()
                self._state = _WRAPPED
                if self._handshake_cb:
                    self._handshake_cb(None)
                if only_handshake:
                    return (ssldata, appdata)
                # Handshake done: execute the wrapped block

            if self._state == _WRAPPED:
                # Main state: read data from SSL until close_notify
                while True:
                    chunk = self._sslobj.read(self.max_size)
                    appdata.append(chunk)
                    if not chunk:  # close_notify
                        break

            elif self._state == _SHUTDOWN:
                # Call shutdown() until it doesn't raise anymore.
                self._sslobj.unwrap()
                self._sslobj = None
                self._state = _UNWRAPPED
                if self._shutdown_cb:
                    self._shutdown_cb()

            elif self._state == _UNWRAPPED:
                # Drain possible plaintext data after close_notify.
                appdata.append(self._incoming.read())
        except (ssl.SSLError, ssl.CertificateError) as exc:
            # A CertificateError is not guaranteed to define errno, so read
            # it defensively once and reuse the value below.
            exc_errno = getattr(exc, 'errno', None)
            if exc_errno not in (
                    ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE,
                    ssl.SSL_ERROR_SYSCALL):
                if self._state == _DO_HANDSHAKE and self._handshake_cb:
                    self._handshake_cb(exc)
                raise
            self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)

        # Check for record level data that needs to be sent back.
        # Happens for the initial handshake and renegotiations.
        if self._outgoing.pending:
            ssldata.append(self._outgoing.read())
        return (ssldata, appdata)

    def feed_appdata(self, data, offset=0):
        """Feed plaintext data into the pipe.

        Return an (ssldata, offset) tuple. The ssldata element is a list of
        buffers containing record level data that needs to be sent to the
        remote SSL instance. The offset is the number of plaintext bytes that
        were processed, which may be less than the length of data.

        NOTE: In case of short writes, this call MUST be retried with the SAME
        buffer passed into the *data* argument (i.e. the id() must be the
        same). This is an OpenSSL requirement. A further particularity is that
        a short write will always have offset == 0, because the _ssl module
        does not enable partial writes. And even though the offset is zero,
        there will still be encrypted data in ssldata.
        """
        assert 0 <= offset <= len(data)
        if self._state == _UNWRAPPED:
            # pass through data in unwrapped mode
            if offset < len(data):
                ssldata = [data[offset:]]
            else:
                ssldata = []
            return (ssldata, len(data))

        ssldata = []
        # memoryview slicing avoids copying the remaining plaintext.
        view = memoryview(data)
        while True:
            self._need_ssldata = False
            try:
                if offset < len(view):
                    offset += self._sslobj.write(view[offset:])
            except ssl.SSLError as exc:
                exc_errno = getattr(exc, 'errno', None)
                # It is not allowed to call write() after unwrap() until the
                # close_notify is acknowledged. We return the condition to the
                # caller as a short write.
                if getattr(exc, 'reason', None) == 'PROTOCOL_IS_SHUTDOWN':
                    exc_errno = exc.errno = ssl.SSL_ERROR_WANT_READ
                if exc_errno not in (ssl.SSL_ERROR_WANT_READ,
                                     ssl.SSL_ERROR_WANT_WRITE,
                                     ssl.SSL_ERROR_SYSCALL):
                    raise
                self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)

            # See if there's any record level data back for us.
            if self._outgoing.pending:
                ssldata.append(self._outgoing.read())
            if offset == len(view) or self._need_ssldata:
                break
        return (ssldata, offset)
292
293
class _SSLProtocolTransport(transports._FlowControlMixin,
                            transports.Transport):
    """Transport handed to the application protocol of an SSL connection.

    Application writes and close/abort requests are forwarded to the owning
    SSLProtocol; reading and write-buffer flow control are delegated to that
    SSLProtocol's underlying transport (``ssl_protocol._transport``).
    """

    def __init__(self, loop, ssl_protocol, app_protocol):
        self._loop = loop
        # SSLProtocol instance
        self._ssl_protocol = ssl_protocol
        self._app_protocol = app_protocol
        self._closed = False

    def get_extra_info(self, name, default=None):
        """Get optional transport information."""
        return self._ssl_protocol._get_extra_info(name, default)

    def set_protocol(self, protocol):
        """Replace the application protocol.

        The SSLProtocol dispatches decrypted data and connection events to
        its own ``_app_protocol`` attribute, so that attribute is updated as
        well; otherwise the new protocol would never be called.
        """
        self._app_protocol = protocol
        self._ssl_protocol._app_protocol = protocol

    def get_protocol(self):
        """Return the current application protocol."""
        return self._app_protocol

    def is_closing(self):
        """Return True once close() or abort() has been called."""
        return self._closed

    def close(self):
        """Close the transport.

        Buffered data will be flushed asynchronously.  No more data
        will be received.  After all buffered data is flushed, the
        protocol's connection_lost() method will (eventually) be called
        with None as its argument.
        """
        self._closed = True
        self._ssl_protocol._start_shutdown()

    # On Python 3.3 and older, objects with a destructor part of a reference
    # cycle are never destroyed. It's not more the case on Python 3.4 thanks
    # to the PEP 442.
    if compat.PY34:
        def __del__(self):
            if not self._closed:
                warnings.warn("unclosed transport %r" % self, ResourceWarning,
                              source=self)
                self.close()

    def pause_reading(self):
        """Pause the receiving end.

        No data will be passed to the protocol's data_received()
        method until resume_reading() is called.
        """
        self._ssl_protocol._transport.pause_reading()

    def resume_reading(self):
        """Resume the receiving end.

        Data received will once again be passed to the protocol's
        data_received() method.
        """
        self._ssl_protocol._transport.resume_reading()

    def set_write_buffer_limits(self, high=None, low=None):
        """Set the high- and low-water limits for write flow control.

        These two values control when to call the protocol's
        pause_writing() and resume_writing() methods.  If specified,
        the low-water limit must be less than or equal to the
        high-water limit.  Neither value can be negative.

        The defaults are implementation-specific.  If only the
        high-water limit is given, the low-water limit defaults to an
        implementation-specific value less than or equal to the
        high-water limit.  Setting high to zero forces low to zero as
        well, and causes pause_writing() to be called whenever the
        buffer becomes non-empty.  Setting low to zero causes
        resume_writing() to be called only once the buffer is empty.
        Use of zero for either limit is generally sub-optimal as it
        reduces opportunities for doing I/O and computation
        concurrently.
        """
        self._ssl_protocol._transport.set_write_buffer_limits(high, low)

    def get_write_buffer_size(self):
        """Return the current size of the write buffer."""
        return self._ssl_protocol._transport.get_write_buffer_size()

    def write(self, data):
        """Write some data bytes to the transport.

        This does not block; it buffers the data and arranges for it
        to be sent out asynchronously.

        Raise TypeError if *data* is not bytes, bytearray or memoryview.
        Empty data is ignored.
        """
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise TypeError("data: expecting a bytes-like instance, got {!r}"
                                .format(type(data).__name__))
        if not data:
            return
        self._ssl_protocol._write_appdata(data)

    def can_write_eof(self):
        """Return True if this transport supports write_eof(), False if not."""
        return False

    def abort(self):
        """Close the transport immediately.

        Buffered data will be lost.  No more data will be received.
        The protocol's connection_lost() method will (eventually) be
        called with None as its argument.
        """
        # Mark closed first so is_closing() reports True and __del__ does not
        # warn about (and re-close) an already aborted transport.
        self._closed = True
        self._ssl_protocol._abort()
404
405
class SSLProtocol(protocols.Protocol):
    """SSL protocol.

    Implementation of SSL on top of a socket using incoming and outgoing
    buffers which are ssl.MemoryBIO objects.

    It sits between a low-level transport and the application protocol:
    received record-level data is fed through an _SSLPipe, decrypted data is
    passed to the application protocol, and the application writes through
    an _SSLProtocolTransport.
    """

    def __init__(self, loop, app_protocol, sslcontext, waiter,
                 server_side=False, server_hostname=None,
                 call_connection_made=True):
        """
        :param loop: event loop used for scheduling callbacks.
        :param app_protocol: application protocol receiving decrypted data.
        :param sslcontext: ssl.SSLContext to use; when false, a default
            client context is created (server side requires a context).
        :param waiter: future resolved when the handshake completes (or with
            the error if it fails); may be None.
        :param server_side: whether this is the server end.
        :param server_hostname: hostname to verify (client side only).
        :param call_connection_made: whether to call the application
            protocol's connection_made() after a successful handshake.
        :raises RuntimeError: if the ssl module is unavailable.
        """
        if ssl is None:
            raise RuntimeError('stdlib ssl module not available')

        if not sslcontext:
            sslcontext = _create_transport_context(server_side, server_hostname)

        self._server_side = server_side
        if server_hostname and not server_side:
            self._server_hostname = server_hostname
        else:
            self._server_hostname = None
        self._sslcontext = sslcontext
        # SSL-specific extra info. More info are set when the handshake
        # completes.
        self._extra = dict(sslcontext=sslcontext)

        # App data write buffering
        self._write_backlog = collections.deque()
        self._write_buffer_size = 0

        self._waiter = waiter
        self._loop = loop
        self._app_protocol = app_protocol
        self._app_transport = _SSLProtocolTransport(self._loop,
                                                    self, self._app_protocol)
        # _SSLPipe instance (None until the connection is made)
        self._sslpipe = None
        self._session_established = False
        self._in_handshake = False
        self._in_shutdown = False
        # transport, ex: SelectorSocketTransport
        self._transport = None
        self._call_connection_made = call_connection_made

    def _wakeup_waiter(self, exc=None):
        """Resolve the waiter (once) with *exc*, or with None on success."""
        if self._waiter is None:
            return
        if not self._waiter.cancelled():
            if exc is not None:
                self._waiter.set_exception(exc)
            else:
                self._waiter.set_result(None)
        self._waiter = None

    def connection_made(self, transport):
        """Called when the low-level connection is made.

        Start the SSL handshake.
        """
        self._transport = transport
        self._sslpipe = _SSLPipe(self._sslcontext,
                                 self._server_side,
                                 self._server_hostname)
        self._start_handshake()

    def connection_lost(self, exc):
        """Called when the low-level connection is lost or closed.

        The argument is an exception object or None (the latter
        meaning a regular EOF is received or the connection was
        aborted or closed).
        """
        if self._session_established:
            self._session_established = False
            self._loop.call_soon(self._app_protocol.connection_lost, exc)
        elif self._app_transport is not None:
            # The session was never established (e.g. the connection dropped
            # during the handshake).  Mark the app transport closed so its
            # __del__ does not emit a spurious "unclosed transport" warning.
            self._app_transport._closed = True
        self._transport = None
        self._app_transport = None
        self._wakeup_waiter(exc)

    def pause_writing(self):
        """Called when the low-level transport's buffer goes over
        the high-water mark.
        """
        self._app_protocol.pause_writing()

    def resume_writing(self):
        """Called when the low-level transport's buffer drains below
        the low-water mark.
        """
        self._app_protocol.resume_writing()

    def data_received(self, data):
        """Called when some SSL data is received.

        The argument is a bytes object.

        Record-level replies are written back to the low-level transport and
        plaintext is passed to the application protocol; an empty plaintext
        chunk (close_notify) starts the shutdown.  SSL failures, including
        the ssl.CertificateError that _SSLPipe.feed_ssldata() may re-raise,
        abort the connection.
        """
        try:
            ssldata, appdata = self._sslpipe.feed_ssldata(data)
        except (ssl.SSLError, ssl.CertificateError) as e:
            if self._loop.get_debug():
                # CertificateError may lack errno/reason attributes.
                logger.warning('%r: SSL error %s (reason %s)',
                               self, getattr(e, 'errno', None),
                               getattr(e, 'reason', None))
            self._abort()
            return

        for chunk in ssldata:
            self._transport.write(chunk)

        for chunk in appdata:
            if chunk:
                self._app_protocol.data_received(chunk)
            else:
                self._start_shutdown()
                break

    def eof_received(self):
        """Called when the other end of the low-level stream
        is half-closed.

        If this returns a false value (including None), the transport
        will close itself.  If it returns a true value, closing the
        transport is up to the protocol.

        Here a pending handshake waiter fails with ConnectionResetError, the
        application's eof_received() is called outside a handshake, and the
        low-level transport is always closed.
        """
        try:
            if self._loop.get_debug():
                logger.debug("%r received EOF", self)

            self._wakeup_waiter(ConnectionResetError)

            if not self._in_handshake:
                keep_open = self._app_protocol.eof_received()
                if keep_open:
                    logger.warning('returning true from eof_received() '
                                   'has no effect when using ssl')
        finally:
            self._transport.close()

    def _get_extra_info(self, name, default=None):
        """Return SSL extra info, else ask the low-level transport.

        Return *default* once the low-level transport is gone (after
        connection_lost()).
        """
        if name in self._extra:
            return self._extra[name]
        elif self._transport is not None:
            return self._transport.get_extra_info(name, default)
        else:
            return default

    def _start_shutdown(self):
        """Queue the SSL shutdown (an empty backlog entry), only once."""
        if self._in_shutdown:
            return
        self._in_shutdown = True
        self._write_appdata(b'')

    def _write_appdata(self, data):
        """Queue plaintext *data* for encryption and try to flush it."""
        self._write_backlog.append((data, 0))
        self._write_buffer_size += len(data)
        self._process_write_backlog()

    def _start_handshake(self):
        """Queue the handshake and schedule backlog processing."""
        if self._loop.get_debug():
            logger.debug("%r starts SSL handshake", self)
            self._handshake_start_time = self._loop.time()
        else:
            self._handshake_start_time = None
        self._in_handshake = True
        # (b'', 1) is a special value in _process_write_backlog() to do
        # the SSL handshake
        self._write_backlog.append((b'', 1))
        self._loop.call_soon(self._process_write_backlog)

    def _on_handshake_complete(self, handshake_exc):
        """Handshake callback: record results or report the failure.

        On failure the low-level transport is closed and the waiter receives
        the exception (BaseExceptions that are not Exceptions are re-raised).
        On success extra info is recorded, the application protocol is
        connected (if requested) and the write backlog is rescheduled.
        """
        self._in_handshake = False

        sslobj = self._sslpipe.ssl_object
        try:
            if handshake_exc is not None:
                raise handshake_exc

            peercert = sslobj.getpeercert()
            if not hasattr(self._sslcontext, 'check_hostname'):
                # Verify hostname if requested, Python 3.4+ uses check_hostname
                # and checks the hostname in do_handshake()
                if (self._server_hostname
                and self._sslcontext.verify_mode != ssl.CERT_NONE):
                    ssl.match_hostname(peercert, self._server_hostname)
        except BaseException as exc:
            if self._loop.get_debug():
                if isinstance(exc, ssl.CertificateError):
                    logger.warning("%r: SSL handshake failed "
                                   "on verifying the certificate",
                                   self, exc_info=True)
                else:
                    logger.warning("%r: SSL handshake failed",
                                   self, exc_info=True)
            self._transport.close()
            if isinstance(exc, Exception):
                self._wakeup_waiter(exc)
                return
            else:
                raise

        # The start time is only recorded when debug mode was on at handshake
        # start; debug may have been enabled since.
        if (self._loop.get_debug()
                and self._handshake_start_time is not None):
            dt = self._loop.time() - self._handshake_start_time
            logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3)

        # Add extra info that becomes available after handshake.
        self._extra.update(peercert=peercert,
                           cipher=sslobj.cipher(),
                           compression=sslobj.compression(),
                           ssl_object=sslobj,
                           )
        if self._call_connection_made:
            self._app_protocol.connection_made(self._app_transport)
        self._wakeup_waiter()
        self._session_established = True
        # In case transport.write() was already called. Don't call
        # immediately _process_write_backlog(), but schedule it:
        # _on_handshake_complete() can be called indirectly from
        # _process_write_backlog(), and _process_write_backlog() is not
        # reentrant.
        self._loop.call_soon(self._process_write_backlog)

    def _process_write_backlog(self):
        """Try to make progress on the write backlog.

        Backlog entries are (data, offset) pairs: non-empty data is
        plaintext to encrypt, (b'', 1) starts the handshake and (b'', 0)
        starts the SSL shutdown.  Processing stops at the first short write.
        """
        if self._transport is None:
            return

        try:
            for _ in range(len(self._write_backlog)):
                data, offset = self._write_backlog[0]
                if data:
                    ssldata, offset = self._sslpipe.feed_appdata(data, offset)
                elif offset:
                    ssldata = self._sslpipe.do_handshake(
                        self._on_handshake_complete)
                    offset = 1
                else:
                    ssldata = self._sslpipe.shutdown(self._finalize)
                    offset = 1

                for chunk in ssldata:
                    self._transport.write(chunk)

                if offset < len(data):
                    self._write_backlog[0] = (data, offset)
                    # A short write means that a write is blocked on a read
                    # We need to enable reading if it is paused!
                    assert self._sslpipe.need_ssldata
                    if self._transport._paused:
                        self._transport.resume_reading()
                    break

                # An entire chunk from the backlog was processed. We can
                # delete it and reduce the outstanding buffer size.
                del self._write_backlog[0]
                self._write_buffer_size -= len(data)
        except BaseException as exc:
            if self._in_handshake:
                # BaseExceptions will be re-raised in _on_handshake_complete.
                self._on_handshake_complete(exc)
            else:
                self._fatal_error(exc, 'Fatal error on SSL transport')
            if not isinstance(exc, Exception):
                # BaseException
                raise

    def _fatal_error(self, exc, message='Fatal error on transport'):
        # Should be called from exception handler only.
        if isinstance(exc, base_events._FATAL_ERROR_IGNORE):
            if self._loop.get_debug():
                logger.debug("%r: %s", self, message, exc_info=True)
        else:
            self._loop.call_exception_handler({
                'message': message,
                'exception': exc,
                'transport': self._transport,
                'protocol': self,
            })
        if self._transport:
            self._transport._force_close(exc)

    def _finalize(self):
        """Close the low-level transport, if any."""
        if self._transport is not None:
            self._transport.close()

    def _abort(self):
        """Abort the low-level transport, then close it."""
        if self._transport is not None:
            try:
                self._transport.abort()
            finally:
                self._finalize()
687    def _abort(self):
688        if self._transport is not None:
689            try:
690                self._transport.abort()
691            finally:
692                self._finalize()
693