• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Wrapper module for _ssl, providing some additional facilities
2# implemented in Python.  Written by Bill Janssen.
3
4"""This module provides some more Pythonic support for SSL.
5
6Object types:
7
8  SSLSocket -- subtype of socket.socket which does SSL over the socket
9
10Exceptions:
11
12  SSLError -- exception raised for I/O errors
13
14Functions:
15
  cert_time_to_seconds -- convert the time string used in the certificate
                          notBefore and notAfter fields to integer
                          seconds past the Epoch (the time values
                          returned from time.time())
20
21  get_server_certificate (addr, ssl_version, ca_certs, timeout) -- Retrieve the
22                          certificate from the server at the specified
23                          address and return it as a PEM-encoded string
24
25
26Integer constants:
27
28SSL_ERROR_ZERO_RETURN
29SSL_ERROR_WANT_READ
30SSL_ERROR_WANT_WRITE
31SSL_ERROR_WANT_X509_LOOKUP
32SSL_ERROR_SYSCALL
33SSL_ERROR_SSL
34SSL_ERROR_WANT_CONNECT
35
36SSL_ERROR_EOF
37SSL_ERROR_INVALID_ERROR_CODE
38
The following constants define the certificate requirements that one side
is allowing/requiring from the other side:
41
42CERT_NONE - no certificates from the other side are required (or will
43            be looked at if provided)
44CERT_OPTIONAL - certificates are not required, but if provided will be
45                validated, and if validation fails, the connection will
46                also fail
47CERT_REQUIRED - certificates are required, and will be validated, and
48                if validation fails, the connection will also fail
49
50The following constants identify various SSL protocol variants:
51
52PROTOCOL_SSLv2
53PROTOCOL_SSLv3
54PROTOCOL_SSLv23
55PROTOCOL_TLS
56PROTOCOL_TLS_CLIENT
57PROTOCOL_TLS_SERVER
58PROTOCOL_TLSv1
59PROTOCOL_TLSv1_1
60PROTOCOL_TLSv1_2
61
62The following constants identify various SSL alert message descriptions as per
63http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-6
64
65ALERT_DESCRIPTION_CLOSE_NOTIFY
66ALERT_DESCRIPTION_UNEXPECTED_MESSAGE
67ALERT_DESCRIPTION_BAD_RECORD_MAC
68ALERT_DESCRIPTION_RECORD_OVERFLOW
69ALERT_DESCRIPTION_DECOMPRESSION_FAILURE
70ALERT_DESCRIPTION_HANDSHAKE_FAILURE
71ALERT_DESCRIPTION_BAD_CERTIFICATE
72ALERT_DESCRIPTION_UNSUPPORTED_CERTIFICATE
73ALERT_DESCRIPTION_CERTIFICATE_REVOKED
74ALERT_DESCRIPTION_CERTIFICATE_EXPIRED
75ALERT_DESCRIPTION_CERTIFICATE_UNKNOWN
76ALERT_DESCRIPTION_ILLEGAL_PARAMETER
77ALERT_DESCRIPTION_UNKNOWN_CA
78ALERT_DESCRIPTION_ACCESS_DENIED
79ALERT_DESCRIPTION_DECODE_ERROR
80ALERT_DESCRIPTION_DECRYPT_ERROR
81ALERT_DESCRIPTION_PROTOCOL_VERSION
82ALERT_DESCRIPTION_INSUFFICIENT_SECURITY
83ALERT_DESCRIPTION_INTERNAL_ERROR
84ALERT_DESCRIPTION_USER_CANCELLED
85ALERT_DESCRIPTION_NO_RENEGOTIATION
86ALERT_DESCRIPTION_UNSUPPORTED_EXTENSION
87ALERT_DESCRIPTION_CERTIFICATE_UNOBTAINABLE
88ALERT_DESCRIPTION_UNRECOGNIZED_NAME
89ALERT_DESCRIPTION_BAD_CERTIFICATE_STATUS_RESPONSE
90ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE
91ALERT_DESCRIPTION_UNKNOWN_PSK_IDENTITY
92"""
93
94import sys
95import os
96from collections import namedtuple
97from enum import Enum as _Enum, IntEnum as _IntEnum, IntFlag as _IntFlag
98from enum import _simple_enum
99
100import _ssl             # if we can't import it, let the error propagate
101
102from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION
103from _ssl import _SSLContext, MemoryBIO, SSLSession
104from _ssl import (
105    SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError,
106    SSLSyscallError, SSLEOFError, SSLCertVerificationError
107    )
108from _ssl import txt2obj as _txt2obj, nid2obj as _nid2obj
109from _ssl import RAND_status, RAND_add, RAND_bytes
110try:
111    from _ssl import RAND_egd
112except ImportError:
113    # RAND_egd is not supported on some platforms
114    pass
115
116
117from _ssl import (
118    HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN, HAS_SSLv2, HAS_SSLv3, HAS_TLSv1,
119    HAS_TLSv1_1, HAS_TLSv1_2, HAS_TLSv1_3, HAS_PSK
120)
121from _ssl import _DEFAULT_CIPHERS, _OPENSSL_API_VERSION
122
# Build this module's enum classes from the integer constants exported by the
# _ssl extension module.  Each _convert_() call creates the named enum in this
# module (__name__) from every attribute of ``source`` accepted by the filter.
# NOTE(review): the bare names used further down (PROTOCOL_TLS, CERT_NONE,
# _SSLMethod, Options, ...) are expected to be provided by these calls.

# Protocol selectors.  PROTOCOL_SSLv23 is filtered out here and re-added below
# as an alias.
_IntEnum._convert_(
    '_SSLMethod', __name__,
    lambda name: name.startswith('PROTOCOL_') and name != 'PROTOCOL_SSLv23',
    source=_ssl)

# OP_* option bits (a flag enum, so members can be combined).
_IntFlag._convert_(
    'Options', __name__,
    lambda name: name.startswith('OP_'),
    source=_ssl)

# ALERT_DESCRIPTION_* alert codes.
_IntEnum._convert_(
    'AlertDescription', __name__,
    lambda name: name.startswith('ALERT_DESCRIPTION_'),
    source=_ssl)

# SSL_ERROR_* error numbers.
_IntEnum._convert_(
    'SSLErrorNumber', __name__,
    lambda name: name.startswith('SSL_ERROR_'),
    source=_ssl)

# VERIFY_* certificate verification flags (a flag enum).
_IntFlag._convert_(
    'VerifyFlags', __name__,
    lambda name: name.startswith('VERIFY_'),
    source=_ssl)

# CERT_* peer certificate requirement modes.
_IntEnum._convert_(
    'VerifyMode', __name__,
    lambda name: name.startswith('CERT_'),
    source=_ssl)

# PROTOCOL_SSLv23 is an alias of PROTOCOL_TLS, both at module level and as an
# attribute of the _SSLMethod class.
PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_TLS
# Reverse lookup: protocol enum member -> its name.
_PROTOCOL_NAMES = {value: name for name, value in _SSLMethod.__members__.items()}

# The PROTOCOL_SSLv2 member if _ssl provides it, otherwise None.
_SSLv2_IF_EXISTS = getattr(_SSLMethod, 'PROTOCOL_SSLv2', None)
157
158
@_simple_enum(_IntEnum)
class TLSVersion:
    # Protocol version identifiers.  Every value is taken from the matching
    # PROTO_* constant of the _ssl extension module; MINIMUM_SUPPORTED and
    # MAXIMUM_SUPPORTED are the two bound markers exported by _ssl.
    MINIMUM_SUPPORTED = _ssl.PROTO_MINIMUM_SUPPORTED
    SSLv3 = _ssl.PROTO_SSLv3
    TLSv1 = _ssl.PROTO_TLSv1
    TLSv1_1 = _ssl.PROTO_TLSv1_1
    TLSv1_2 = _ssl.PROTO_TLSv1_2
    TLSv1_3 = _ssl.PROTO_TLSv1_3
    MAXIMUM_SUPPORTED = _ssl.PROTO_MAXIMUM_SUPPORTED
168
169
@_simple_enum(_IntEnum)
class _TLSContentType:
    """Content types (record layer)

    See RFC 8446, section B.1

    Used by SSLContext._msg_callback to convert the raw ``content_type``
    integer handed to a message callback.
    """
    CHANGE_CIPHER_SPEC = 20
    ALERT = 21
    HANDSHAKE = 22
    APPLICATION_DATA = 23
    # pseudo content types (values above 0xFF, so they cannot collide with
    # the content types listed above)
    HEADER = 0x100
    INNER_CONTENT_TYPE = 0x101
183
184
@_simple_enum(_IntEnum)
class _TLSAlertType:
    """Alert types for _TLSContentType.ALERT messages

    See RFC 8446, section B.2

    Used by SSLContext._msg_callback to convert the raw ``msg_type`` integer
    when the content type is ALERT.
    """
    CLOSE_NOTIFY = 0
    UNEXPECTED_MESSAGE = 10
    BAD_RECORD_MAC = 20
    DECRYPTION_FAILED = 21
    RECORD_OVERFLOW = 22
    DECOMPRESSION_FAILURE = 30
    HANDSHAKE_FAILURE = 40
    NO_CERTIFICATE = 41
    BAD_CERTIFICATE = 42
    UNSUPPORTED_CERTIFICATE = 43
    CERTIFICATE_REVOKED = 44
    CERTIFICATE_EXPIRED = 45
    CERTIFICATE_UNKNOWN = 46
    ILLEGAL_PARAMETER = 47
    UNKNOWN_CA = 48
    ACCESS_DENIED = 49
    DECODE_ERROR = 50
    DECRYPT_ERROR = 51
    EXPORT_RESTRICTION = 60
    PROTOCOL_VERSION = 70
    INSUFFICIENT_SECURITY = 71
    INTERNAL_ERROR = 80
    INAPPROPRIATE_FALLBACK = 86
    USER_CANCELED = 90
    NO_RENEGOTIATION = 100
    MISSING_EXTENSION = 109
    UNSUPPORTED_EXTENSION = 110
    CERTIFICATE_UNOBTAINABLE = 111
    UNRECOGNIZED_NAME = 112
    BAD_CERTIFICATE_STATUS_RESPONSE = 113
    BAD_CERTIFICATE_HASH_VALUE = 114
    UNKNOWN_PSK_IDENTITY = 115
    CERTIFICATE_REQUIRED = 116
    NO_APPLICATION_PROTOCOL = 120
225
226
@_simple_enum(_IntEnum)
class _TLSMessageType:
    """Message types (handshake protocol)

    See RFC 8446, section B.3

    Used by SSLContext._msg_callback to convert the raw ``msg_type`` integer
    when the content type is neither HEADER nor ALERT.
    """
    HELLO_REQUEST = 0
    CLIENT_HELLO = 1
    SERVER_HELLO = 2
    HELLO_VERIFY_REQUEST = 3
    NEWSESSION_TICKET = 4
    END_OF_EARLY_DATA = 5
    HELLO_RETRY_REQUEST = 6
    ENCRYPTED_EXTENSIONS = 8
    CERTIFICATE = 11
    SERVER_KEY_EXCHANGE = 12
    CERTIFICATE_REQUEST = 13
    SERVER_DONE = 14
    CERTIFICATE_VERIFY = 15
    CLIENT_KEY_EXCHANGE = 16
    FINISHED = 20
    CERTIFICATE_URL = 21
    CERTIFICATE_STATUS = 22
    SUPPLEMENTAL_DATA = 23
    KEY_UPDATE = 24
    NEXT_PROTO = 67
    MESSAGE_HASH = 254
    # NOTE(review): value above 0xFF -- presumably a pseudo message type
    # rather than a one-byte handshake type; confirm against _ssl.
    CHANGE_CIPHER_SPEC = 0x0101
255
256
257if sys.platform == "win32":
258    from _ssl import enum_certificates, enum_crls
259
260from socket import socket, SOCK_STREAM, create_connection
261from socket import SOL_SOCKET, SO_TYPE, _GLOBAL_DEFAULT_TIMEOUT
262import socket as _socket
263import base64        # for DER-to-PEM translation
264import errno
265import warnings
266
267
socket_error = OSError  # keep that public name in module namespace

# Channel binding type names listed by this module.
CHANNEL_BINDING_TYPES = ['tls-unique']

# True when the _ssl extension exposes HOSTFLAG_NEVER_CHECK_SUBJECT.
HAS_NEVER_CHECK_COMMON_NAME = hasattr(_ssl, 'HOSTFLAG_NEVER_CHECK_SUBJECT')


# Same object as _DEFAULT_CIPHERS.
# NOTE(review): presumably kept as a separate name for compatibility.
_RESTRICTED_SERVER_CIPHERS = _DEFAULT_CIPHERS

# CertificateError is an alias of SSLCertVerificationError.
CertificateError = SSLCertVerificationError
278
279
def _dnsname_match(dn, hostname):
    """Tell whether *hostname* matches the certificate DNS name *dn*.

    Follows RFC 6125, section 6.4.3:

    - The comparison is done on lower-cased names.
    - For IDNA, both arguments must already be IDN A-labels (ACE).
    - Only a lone '*' as the complete left-most label is accepted as a
      wildcard.  More than one wildcard, a wildcard outside the left-most
      label, a wildcard with no further labels, or a partial wildcard such
      as 'www*.example.org' raises CertificateError.
    - The wildcard has to stand for at least one character.

    Returns True on a match and False otherwise; an empty *dn* never matches.
    """
    if not dn:
        return False

    star_count = dn.count('*')

    if star_count == 0:
        # Common case: no wildcard, plain case-insensitive comparison.
        return dn.lower() == hostname.lower()

    if star_count > 1:
        raise CertificateError(
            "too many wildcards in certificate DNS name: {!r}.".format(dn))

    first_label, dot, parent_domain = dn.partition('.')

    # The single wildcard must be the whole left-most label and must be
    # followed by at least one more label.
    if '*' in parent_domain:
        raise CertificateError(
            "wildcard can only be present in the leftmost label: "
            "{!r}.".format(dn))

    if not dot:
        raise CertificateError(
            "sole wildcard without additional labels are not support: "
            "{!r}.".format(dn))

    if first_label != '*':
        raise CertificateError(
            "partial wildcards in leftmost label are not supported: "
            "{!r}.".format(dn))

    host_label, host_dot, host_parent = hostname.partition('.')
    if host_label and host_dot:
        # The wildcard consumed a non-empty label; the rest must be equal.
        return parent_domain.lower() == host_parent.lower()
    return False
327
328
def _inet_paton(ipname):
    """Return the packed binary form of the IP address string *ipname*.

    IPv4 is handled everywhere; IPv6 only where the socket module offers
    inet_pton() and AF_INET6.  Raises ValueError for anything that is not
    a fully written address.
    """
    # inet_aton() is lenient: it takes short forms such as '1' or '127.1'
    # and, on some platforms, trailing junk like '127.0.0.1 whatever'.
    try:
        packed_v4 = _socket.inet_aton(ipname)
    except OSError:
        # Not IPv4 at all; fall through to the IPv6 attempt.
        packed_v4 = None

    if packed_v4 is not None:
        # Accept only names that round-trip unchanged, which rejects the
        # short notations and trailing data mentioned above.
        if _socket.inet_ntoa(packed_v4) != ipname:
            raise ValueError(
                "{!r} is not a quad-dotted IPv4 address.".format(ipname)
            )
        return packed_v4

    try:
        return _socket.inet_pton(_socket.AF_INET6, ipname)
    except OSError:
        raise ValueError("{!r} is neither an IPv4 nor an IP6 "
                         "address.".format(ipname))
    except AttributeError:
        # inet_pton() or AF_INET6 is missing on this platform.
        pass

    raise ValueError("{!r} is not an IPv4 address.".format(ipname))
362
363
def _ipaddress_match(cert_ipaddress, host_ip):
    """Compare a certificate IP address string with a packed host address.

    *cert_ipaddress* is converted with _inet_paton() and the result is
    compared for exact equality with *host_ip*.  RFC 6125 deliberately
    leaves IP matching undefined (section 1.7.2 - "Out of Scope").
    """
    # OpenSSL may append a newline to a subjectAltName IP address (seen
    # mostly with IPv6), so trailing whitespace is removed first.
    cleaned = cert_ipaddress.rstrip()
    return _inet_paton(cleaned) == host_ip
374
375
DefaultVerifyPaths = namedtuple(
    "DefaultVerifyPaths",
    "cafile capath openssl_cafile_env openssl_cafile openssl_capath_env "
    "openssl_capath")

def get_default_verify_paths():
    """Return the default CA file and CA directory as a DefaultVerifyPaths.

    ``cafile`` / ``capath`` are None when the resolved location is not an
    existing file / directory.  The remaining four fields are the values
    reported by _ssl.get_default_verify_paths(), unchanged.
    """
    openssl_paths = _ssl.get_default_verify_paths()
    cafile_env, cafile_default = openssl_paths[0], openssl_paths[1]
    capath_env, capath_default = openssl_paths[2], openssl_paths[3]

    # A set environment variable takes precedence over the built-in path.
    cafile = os.environ.get(cafile_env, cafile_default)
    capath = os.environ.get(capath_env, capath_default)

    if not os.path.isfile(cafile):
        cafile = None
    if not os.path.isdir(capath):
        capath = None

    return DefaultVerifyPaths(cafile, capath, *openssl_paths)
392
393
class _ASN1Object(namedtuple("_ASN1Object", "nid shortname longname oid")):
    """ASN.1 object identifier record: ``(nid, shortname, longname, oid)``.

    The fields are filled in from the lookup helpers of the _ssl module.
    """
    __slots__ = ()

    def __new__(cls, oid):
        # Lookup through _txt2obj with name=False.
        fields = _txt2obj(oid, name=False)
        return super().__new__(cls, *fields)

    @classmethod
    def fromnid(cls, nid):
        """Build an _ASN1Object from an OpenSSL numeric ID
        """
        fields = _nid2obj(nid)
        return super().__new__(cls, *fields)

    @classmethod
    def fromname(cls, name):
        """Build an _ASN1Object from a short name, a long name or an OID
        """
        fields = _txt2obj(name, name=True)
        return super().__new__(cls, *fields)
413
414
class Purpose(_ASN1Object, _Enum):
    """SSLContext purpose flags with X509v3 Extended Key Usage objects

    Each member is an _ASN1Object built from the dotted OID string given as
    its value.
    """
    # Context will be used to authenticate servers (i.e. client-side use).
    SERVER_AUTH = '1.3.6.1.5.5.7.3.1'
    # Context will be used to authenticate clients (i.e. server-side use).
    CLIENT_AUTH = '1.3.6.1.5.5.7.3.2'
420
421
class SSLContext(_SSLContext):
    """An SSLContext holds various SSL-related configuration options and
    data, such as certificates and possibly a private key."""
    # Windows system store names read by load_default_certs().
    _windows_cert_stores = ("CA", "ROOT")

    sslsocket_class = None  # SSLSocket is assigned later.
    sslobject_class = None  # SSLObject is assigned later.

    def __new__(cls, protocol=None, *args, **kwargs):
        """Create a context for *protocol*.

        Omitting *protocol* is deprecated: a DeprecationWarning is issued
        and PROTOCOL_TLS is used.  Extra positional and keyword arguments
        are accepted but are not passed on to the base class.
        """
        if protocol is None:
            warnings.warn(
                "ssl.SSLContext() without protocol argument is deprecated.",
                category=DeprecationWarning,
                stacklevel=2
            )
            protocol = PROTOCOL_TLS
        self = _SSLContext.__new__(cls, protocol)
        return self

    def _encode_hostname(self, hostname):
        """Return *hostname* as an ASCII-only str.

        None is passed through, a str is IDNA-encoded, and any other object
        is decoded as ASCII (i.e. it is expected to be bytes-like).
        """
        if hostname is None:
            return None
        elif isinstance(hostname, str):
            return hostname.encode('idna').decode('ascii')
        else:
            return hostname.decode('ascii')

    def wrap_socket(self, sock, server_side=False,
                    do_handshake_on_connect=True,
                    suppress_ragged_eofs=True,
                    server_hostname=None, session=None):
        """Wrap *sock* via ``sslsocket_class._create()`` using this context.

        All arguments are forwarded unchanged and the created object is
        returned.
        """
        # SSLSocket class handles server_hostname encoding before it calls
        # ctx._wrap_socket()
        return self.sslsocket_class._create(
            sock=sock,
            server_side=server_side,
            do_handshake_on_connect=do_handshake_on_connect,
            suppress_ragged_eofs=suppress_ragged_eofs,
            server_hostname=server_hostname,
            context=self,
            session=session
        )

    def wrap_bio(self, incoming, outgoing, server_side=False,
                 server_hostname=None, session=None):
        """Wrap the *incoming* / *outgoing* BIO pair via
        ``sslobject_class._create()`` using this context.

        *server_hostname* is converted with _encode_hostname() first.
        """
        # Need to encode server_hostname here because _wrap_bio() can only
        # handle ASCII str.
        return self.sslobject_class._create(
            incoming, outgoing, server_side=server_side,
            server_hostname=self._encode_hostname(server_hostname),
            session=session, context=self,
        )

    def set_npn_protocols(self, npn_protocols):
        """Set the NPN protocol list (deprecated in favour of ALPN).

        Emits a DeprecationWarning.  Each protocol name is encoded as ASCII
        and written as a length byte followed by the name.  Raises SSLError
        when an encoded name is empty or longer than 255 bytes.
        """
        warnings.warn(
            "ssl NPN is deprecated, use ALPN instead",
            DeprecationWarning,
            stacklevel=2
        )
        protos = bytearray()
        for protocol in npn_protocols:
            b = bytes(protocol, 'ascii')
            if len(b) == 0 or len(b) > 255:
                raise SSLError('NPN protocols must be 1 to 255 in length')
            protos.append(len(b))
            protos.extend(b)

        self._set_npn_protocols(protos)

    def set_servername_callback(self, server_name_callback):
        """Install *server_name_callback* as the SNI callback.

        None clears ``sni_callback``.  Otherwise the callback is wrapped so
        that it receives the server name converted by _encode_hostname().
        Raises TypeError when the argument is neither None nor callable.
        """
        if server_name_callback is None:
            self.sni_callback = None
        else:
            if not callable(server_name_callback):
                raise TypeError("not a callable object")

            def shim_cb(sslobj, servername, sslctx):
                servername = self._encode_hostname(servername)
                return server_name_callback(sslobj, servername, sslctx)

            self.sni_callback = shim_cb

    def set_alpn_protocols(self, alpn_protocols):
        """Set the ALPN protocol list.

        Each protocol name is encoded as ASCII and written as a length byte
        followed by the name.  Raises SSLError when an encoded name is empty
        or longer than 255 bytes.
        """
        protos = bytearray()
        for protocol in alpn_protocols:
            b = bytes(protocol, 'ascii')
            if len(b) == 0 or len(b) > 255:
                raise SSLError('ALPN protocols must be 1 to 255 in length')
            protos.append(len(b))
            protos.extend(b)

        self._set_alpn_protocols(protos)

    def _load_windows_store_certs(self, storename, purpose):
        """Load trusted certificates from the Windows store *storename*.

        Only "x509_asn" entries are considered, and only those trusted for
        everything (``trust is True``) or for *purpose*.  A certificate that
        fails to load, or a store that cannot be enumerated, produces a
        warning instead of an exception.
        """
        try:
            for cert, encoding, trust in enum_certificates(storename):
                # CA certs are never PKCS#7 encoded
                if encoding == "x509_asn":
                    if trust is True or purpose.oid in trust:
                        try:
                            self.load_verify_locations(cadata=cert)
                        except SSLError as exc:
                            warnings.warn(f"Bad certificate in Windows certificate store: {exc!s}")
        except PermissionError:
            warnings.warn("unable to enumerate Windows certificate store")

    def load_default_certs(self, purpose=Purpose.SERVER_AUTH):
        """Load default CA certificates for *purpose*.

        On Windows the stores named in ``_windows_cert_stores`` are read
        first; on every platform set_default_verify_paths() is then called.
        Raises TypeError when *purpose* is not an _ASN1Object.
        """
        if not isinstance(purpose, _ASN1Object):
            raise TypeError(purpose)
        if sys.platform == "win32":
            for storename in self._windows_cert_stores:
                self._load_windows_store_certs(storename, purpose)
        self.set_default_verify_paths()

    # The version bounds are only wrapped when the base class provides them.
    if hasattr(_SSLContext, 'minimum_version'):
        @property
        def minimum_version(self):
            """Lowest allowed protocol version, as a TLSVersion member."""
            return TLSVersion(super().minimum_version)

        @minimum_version.setter
        def minimum_version(self, value):
            # Asking for SSLv3 as the minimum also clears OP_NO_SSLv3 so the
            # option does not contradict the requested bound.
            if value == TLSVersion.SSLv3:
                self.options &= ~Options.OP_NO_SSLv3
            super(SSLContext, SSLContext).minimum_version.__set__(self, value)

        @property
        def maximum_version(self):
            """Highest allowed protocol version, as a TLSVersion member."""
            return TLSVersion(super().maximum_version)

        @maximum_version.setter
        def maximum_version(self, value):
            super(SSLContext, SSLContext).maximum_version.__set__(self, value)

    @property
    def options(self):
        """The base-class option bits, converted to an Options flag value."""
        return Options(super().options)

    @options.setter
    def options(self, value):
        super(SSLContext, SSLContext).options.__set__(self, value)

    if hasattr(_ssl, 'HOSTFLAG_NEVER_CHECK_SUBJECT'):
        @property
        def hostname_checks_common_name(self):
            """True unless HOSTFLAG_NEVER_CHECK_SUBJECT is set in
            ``_host_flags``."""
            ncs = self._host_flags & _ssl.HOSTFLAG_NEVER_CHECK_SUBJECT
            return ncs != _ssl.HOSTFLAG_NEVER_CHECK_SUBJECT

        @hostname_checks_common_name.setter
        def hostname_checks_common_name(self, value):
            # A true value clears the "never check subject" bit, a false
            # value sets it.
            if value:
                self._host_flags &= ~_ssl.HOSTFLAG_NEVER_CHECK_SUBJECT
            else:
                self._host_flags |= _ssl.HOSTFLAG_NEVER_CHECK_SUBJECT
    else:
        @property
        def hostname_checks_common_name(self):
            """Always True: the flag is unavailable, so this is read-only."""
            return True

    @property
    def _msg_callback(self):
        """TLS message callback

        The message callback provides a debugging hook to analyze TLS
        connections. The callback is called for any TLS protocol message
        (header, handshake, alert, and more), but not for application data.
        Due to technical  limitations, the callback can't be used to filter
        traffic or to abort a connection. Any exception raised in the
        callback is delayed until the handshake, read, or write operation
        has been performed.

        def msg_cb(conn, direction, version, content_type, msg_type, data):
            pass

        conn
            :class:`SSLSocket` or :class:`SSLObject` instance
        direction
            ``read`` or ``write``
        version
            :class:`TLSVersion` enum member or int for unknown version. For a
            frame header, it's the header version.
        content_type
            :class:`_TLSContentType` enum member or int for unsupported
            content type.
        msg_type
            Either a :class:`_TLSContentType` enum number for a header
            message, a :class:`_TLSAlertType` enum member for an alert
            message, a :class:`_TLSMessageType` enum member for other
            messages, or int for unsupported message types.
        data
            Raw, decrypted message content as bytes
        """
        # The setter stores a wrapper; hand back the user's own function.
        inner = super()._msg_callback
        if inner is not None:
            return inner.user_function
        else:
            return None

    @_msg_callback.setter
    def _msg_callback(self, callback):
        if callback is None:
            super(SSLContext, SSLContext)._msg_callback.__set__(self, None)
            return

        # callable() matches the check in set_servername_callback() and,
        # unlike hasattr(callback, '__call__'), rejects objects that merely
        # carry an instance attribute named __call__.
        if not callable(callback):
            raise TypeError(f"{callback} is not callable.")

        def inner(conn, direction, version, content_type, msg_type, data):
            # Convert the raw integers to enum members where possible;
            # unknown values are passed through unchanged.
            try:
                version = TLSVersion(version)
            except ValueError:
                pass

            try:
                content_type = _TLSContentType(content_type)
            except ValueError:
                pass

            # The enum used for msg_type depends on the content type.
            if content_type == _TLSContentType.HEADER:
                msg_enum = _TLSContentType
            elif content_type == _TLSContentType.ALERT:
                msg_enum = _TLSAlertType
            else:
                msg_enum = _TLSMessageType
            try:
                msg_type = msg_enum(msg_type)
            except ValueError:
                pass

            return callback(conn, direction, version,
                            content_type, msg_type, data)

        # Remembered so the getter can return the original callback.
        inner.user_function = callback

        super(SSLContext, SSLContext)._msg_callback.__set__(self, inner)

    @property
    def protocol(self):
        """The context's protocol, as an _SSLMethod member."""
        return _SSLMethod(super().protocol)

    @property
    def verify_flags(self):
        """The base-class verify flags, converted to a VerifyFlags value."""
        return VerifyFlags(super().verify_flags)

    @verify_flags.setter
    def verify_flags(self, value):
        super(SSLContext, SSLContext).verify_flags.__set__(self, value)

    @property
    def verify_mode(self):
        """The verify mode as a VerifyMode member, or the raw value when it
        does not correspond to any member."""
        value = super().verify_mode
        try:
            return VerifyMode(value)
        except ValueError:
            return value

    @verify_mode.setter
    def verify_mode(self, value):
        super(SSLContext, SSLContext).verify_mode.__set__(self, value)
680
681
def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None,
                           capath=None, cadata=None):
    """Build an SSLContext configured with the module's default settings.

    Purpose.SERVER_AUTH yields a client context that requires and verifies
    the peer certificate and checks the host name; Purpose.CLIENT_AUTH
    yields a server context.  CA material is taken from *cafile*, *capath*
    or *cadata* when any of them is given, otherwise the default
    certificates are loaded if verification is enabled.

    Raises TypeError if *purpose* is not an _ASN1Object and ValueError if it
    is neither of the two known purposes.

    NOTE: The protocol and settings may change anytime without prior
          deprecation. The values represent a fair balance between maximum
          compatibility and security.
    """
    if not isinstance(purpose, _ASN1Object):
        raise TypeError(purpose)

    # A new SSLContext already has OP_NO_SSLv2, OP_NO_SSLv3,
    # OP_NO_COMPRESSION, OP_CIPHER_SERVER_PREFERENCE, OP_SINGLE_DH_USE and
    # OP_SINGLE_ECDH_USE set.
    if purpose == Purpose.SERVER_AUTH:
        # Client side: verify the certificate and the host name.
        ctx = SSLContext(PROTOCOL_TLS_CLIENT)
        ctx.verify_mode = CERT_REQUIRED
        ctx.check_hostname = True
    elif purpose == Purpose.CLIENT_AUTH:
        ctx = SSLContext(PROTOCOL_TLS_SERVER)
    else:
        raise ValueError(purpose)

    # VERIFY_X509_PARTIAL_CHAIN: chain building stops at the first trust
    # anchor even if it is not self-signed, as RFC 3280 and 5280 specify.
    # VERIFY_X509_STRICT: OpenSSL becomes more conservative about the
    # certificates it accepts, including "disabling workarounds for some
    # broken certificates."
    extra_flags = _ssl.VERIFY_X509_PARTIAL_CHAIN | _ssl.VERIFY_X509_STRICT
    ctx.verify_flags |= extra_flags

    explicit_ca = cafile or capath or cadata
    if explicit_ca:
        ctx.load_verify_locations(cafile, capath, cadata)
    elif ctx.verify_mode != CERT_NONE:
        # Nothing explicit was given but the mode is CERT_OPTIONAL or
        # CERT_REQUIRED: try the system's default root CA certificates for
        # this purpose. This may fail silently.
        ctx.load_default_certs(purpose)

    # OpenSSL 1.1.1 keylog file
    if not hasattr(ctx, 'keylog_filename'):
        return ctx
    keylog_path = os.environ.get('SSLKEYLOGFILE')
    if keylog_path and not sys.flags.ignore_environment:
        ctx.keylog_filename = keylog_path
    return ctx
729
def _create_unverified_context(protocol=None, *, cert_reqs=CERT_NONE,
                           check_hostname=False, purpose=Purpose.SERVER_AUTH,
                           certfile=None, keyfile=None,
                           cafile=None, capath=None, cadata=None):
    """Create a SSLContext object for Python stdlib modules

    All Python stdlib modules shall use this function to create SSLContext
    objects in order to keep common settings in one place. The configuration
    is less restrictive than create_default_context()'s to increase backward
    compatibility.

    When *protocol* is None it defaults to PROTOCOL_TLS_CLIENT for
    Purpose.SERVER_AUTH and PROTOCOL_TLS_SERVER for Purpose.CLIENT_AUTH.
    *cert_reqs* and *check_hostname* are applied to the new context,
    *certfile* / *keyfile* are loaded as the certificate chain, and CA
    material comes from *cafile*, *capath* or *cadata*, or from the default
    certificates when verification is enabled.

    Raises TypeError if *purpose* is not an _ASN1Object, and ValueError if
    *purpose* is unknown or *keyfile* is given without *certfile*.
    """
    if not isinstance(purpose, _ASN1Object):
        raise TypeError(purpose)

    # SSLContext sets OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_COMPRESSION,
    # OP_CIPHER_SERVER_PREFERENCE, OP_SINGLE_DH_USE and OP_SINGLE_ECDH_USE
    # by default.
    if purpose == Purpose.SERVER_AUTH:
        # verify certs and host name in client mode
        if protocol is None:
            protocol = PROTOCOL_TLS_CLIENT
    elif purpose == Purpose.CLIENT_AUTH:
        if protocol is None:
            protocol = PROTOCOL_TLS_SERVER
    else:
        raise ValueError(purpose)

    context = SSLContext(protocol)
    # NOTE(review): check_hostname is assigned before verify_mode and then,
    # when true, once more afterwards; the sequence looks deliberate (the two
    # settings presumably constrain each other) -- keep the order as is.
    context.check_hostname = check_hostname
    if cert_reqs is not None:
        context.verify_mode = cert_reqs
    if check_hostname:
        context.check_hostname = True

    if keyfile and not certfile:
        raise ValueError("certfile must be specified")
    if certfile or keyfile:
        context.load_cert_chain(certfile, keyfile)

    # load CA root certs
    if cafile or capath or cadata:
        context.load_verify_locations(cafile, capath, cadata)
    elif context.verify_mode != CERT_NONE:
        # no explicit cafile, capath or cadata but the verify mode is
        # CERT_OPTIONAL or CERT_REQUIRED. Let's try to load default system
        # root CA certificates for the given purpose. This may fail silently.
        context.load_default_certs(purpose)
    # OpenSSL 1.1.1 keylog file
    if hasattr(context, 'keylog_filename'):
        keylogfile = os.environ.get('SSLKEYLOGFILE')
        # The environment variable is ignored when Python runs with the
        # ignore_environment flag set.
        if keylogfile and not sys.flags.ignore_environment:
            context.keylog_filename = keylogfile
    return context
783
# Used by http.client if no context is explicitly passed.  This is the same
# function object as create_default_context.
_create_default_https_context = create_default_context


# Backwards compatibility alias, even though it's not a public name.  This is
# the same function object as _create_unverified_context.
_create_stdlib_context = _create_unverified_context
790
791
792class SSLObject:
793    """This class implements an interface on top of a low-level SSL object as
794    implemented by OpenSSL. This object captures the state of an SSL connection
795    but does not provide any network IO itself. IO needs to be performed
796    through separate "BIO" objects which are OpenSSL's IO abstraction layer.
797
798    This class does not have a public constructor. Instances are returned by
799    ``SSLContext.wrap_bio``. This class is typically used by framework authors
800    that want to implement asynchronous IO for SSL through memory buffers.
801
802    When compared to ``SSLSocket``, this object lacks the following features:
803
804     * Any form of network IO, including methods such as ``recv`` and ``send``.
805     * The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery.
806    """
807    def __init__(self, *args, **kwargs):
808        raise TypeError(
809            f"{self.__class__.__name__} does not have a public "
810            f"constructor. Instances are returned by SSLContext.wrap_bio()."
811        )
812
813    @classmethod
814    def _create(cls, incoming, outgoing, server_side=False,
815                 server_hostname=None, session=None, context=None):
816        self = cls.__new__(cls)
817        sslobj = context._wrap_bio(
818            incoming, outgoing, server_side=server_side,
819            server_hostname=server_hostname,
820            owner=self, session=session
821        )
822        self._sslobj = sslobj
823        return self
824
825    @property
826    def context(self):
827        """The SSLContext that is currently in use."""
828        return self._sslobj.context
829
830    @context.setter
831    def context(self, ctx):
832        self._sslobj.context = ctx
833
834    @property
835    def session(self):
836        """The SSLSession for client socket."""
837        return self._sslobj.session
838
839    @session.setter
840    def session(self, session):
841        self._sslobj.session = session
842
843    @property
844    def session_reused(self):
845        """Was the client session reused during handshake"""
846        return self._sslobj.session_reused
847
848    @property
849    def server_side(self):
850        """Whether this is a server-side socket."""
851        return self._sslobj.server_side
852
853    @property
854    def server_hostname(self):
855        """The currently set server hostname (for SNI), or ``None`` if no
856        server hostname is set."""
857        return self._sslobj.server_hostname
858
859    def read(self, len=1024, buffer=None):
860        """Read up to 'len' bytes from the SSL object and return them.
861
862        If 'buffer' is provided, read into this buffer and return the number of
863        bytes read.
864        """
865        if buffer is not None:
866            v = self._sslobj.read(len, buffer)
867        else:
868            v = self._sslobj.read(len)
869        return v
870
871    def write(self, data):
872        """Write 'data' to the SSL object and return the number of bytes
873        written.
874
875        The 'data' argument must support the buffer interface.
876        """
877        return self._sslobj.write(data)
878
879    def getpeercert(self, binary_form=False):
880        """Returns a formatted version of the data in the certificate provided
881        by the other end of the SSL channel.
882
883        Return None if no certificate was provided, {} if a certificate was
884        provided, but not validated.
885        """
886        return self._sslobj.getpeercert(binary_form)
887
888    def get_verified_chain(self):
889        """Returns verified certificate chain provided by the other
890        end of the SSL channel as a list of DER-encoded bytes.
891
892        If certificate verification was disabled method acts the same as
893        ``SSLSocket.get_unverified_chain``.
894        """
895        chain = self._sslobj.get_verified_chain()
896
897        if chain is None:
898            return []
899
900        return [cert.public_bytes(_ssl.ENCODING_DER) for cert in chain]
901
902    def get_unverified_chain(self):
903        """Returns raw certificate chain provided by the other
904        end of the SSL channel as a list of DER-encoded bytes.
905        """
906        chain = self._sslobj.get_unverified_chain()
907
908        if chain is None:
909            return []
910
911        return [cert.public_bytes(_ssl.ENCODING_DER) for cert in chain]
912
913    def selected_npn_protocol(self):
914        """Return the currently selected NPN protocol as a string, or ``None``
915        if a next protocol was not negotiated or if NPN is not supported by one
916        of the peers."""
917        warnings.warn(
918            "ssl NPN is deprecated, use ALPN instead",
919            DeprecationWarning,
920            stacklevel=2
921        )
922
923    def selected_alpn_protocol(self):
924        """Return the currently selected ALPN protocol as a string, or ``None``
925        if a next protocol was not negotiated or if ALPN is not supported by one
926        of the peers."""
927        return self._sslobj.selected_alpn_protocol()
928
929    def cipher(self):
930        """Return the currently selected cipher as a 3-tuple ``(name,
931        ssl_version, secret_bits)``."""
932        return self._sslobj.cipher()
933
934    def shared_ciphers(self):
935        """Return a list of ciphers shared by the client during the handshake or
936        None if this is not a valid server connection.
937        """
938        return self._sslobj.shared_ciphers()
939
940    def compression(self):
941        """Return the current compression algorithm in use, or ``None`` if
942        compression was not negotiated or not supported by one of the peers."""
943        return self._sslobj.compression()
944
945    def pending(self):
946        """Return the number of bytes that can be read immediately."""
947        return self._sslobj.pending()
948
949    def do_handshake(self):
950        """Start the SSL/TLS handshake."""
951        self._sslobj.do_handshake()
952
953    def unwrap(self):
954        """Start the SSL shutdown handshake."""
955        return self._sslobj.shutdown()
956
957    def get_channel_binding(self, cb_type="tls-unique"):
958        """Get channel binding data for current connection.  Raise ValueError
959        if the requested `cb_type` is not supported.  Return bytes of the data
960        or None if the data is not available (e.g. before the handshake)."""
961        return self._sslobj.get_channel_binding(cb_type)
962
963    def version(self):
964        """Return a string identifying the protocol version used by the
965        current SSL channel. """
966        return self._sslobj.version()
967
968    def verify_client_post_handshake(self):
969        return self._sslobj.verify_client_post_handshake()
970
971
def _sslcopydoc(func):
    """Decorator: give *func* the docstring of the SSLObject attribute that
    has the same name, then return *func* itself."""
    counterpart = getattr(SSLObject, func.__name__)
    func.__doc__ = counterpart.__doc__
    return func
976
977
978class SSLSocket(socket):
979    """This class implements a subtype of socket.socket that wraps
980    the underlying OS socket in an SSL context when necessary, and
981    provides read and write methods over that channel. """
982
983    def __init__(self, *args, **kwargs):
984        raise TypeError(
985            f"{self.__class__.__name__} does not have a public "
986            f"constructor. Instances are returned by "
987            f"SSLContext.wrap_socket()."
988        )
989
990    @classmethod
991    def _create(cls, sock, server_side=False, do_handshake_on_connect=True,
992                suppress_ragged_eofs=True, server_hostname=None,
993                context=None, session=None):
994        if sock.getsockopt(SOL_SOCKET, SO_TYPE) != SOCK_STREAM:
995            raise NotImplementedError("only stream sockets are supported")
996        if server_side:
997            if server_hostname:
998                raise ValueError("server_hostname can only be specified "
999                                 "in client mode")
1000            if session is not None:
1001                raise ValueError("session can only be specified in "
1002                                 "client mode")
1003        if context.check_hostname and not server_hostname:
1004            raise ValueError("check_hostname requires server_hostname")
1005
1006        sock_timeout = sock.gettimeout()
1007        kwargs = dict(
1008            family=sock.family, type=sock.type, proto=sock.proto,
1009            fileno=sock.fileno()
1010        )
1011        self = cls.__new__(cls, **kwargs)
1012        super(SSLSocket, self).__init__(**kwargs)
1013        sock.detach()
1014        # Now SSLSocket is responsible for closing the file descriptor.
1015        try:
1016            self._context = context
1017            self._session = session
1018            self._closed = False
1019            self._sslobj = None
1020            self.server_side = server_side
1021            self.server_hostname = context._encode_hostname(server_hostname)
1022            self.do_handshake_on_connect = do_handshake_on_connect
1023            self.suppress_ragged_eofs = suppress_ragged_eofs
1024
1025            # See if we are connected
1026            try:
1027                self.getpeername()
1028            except OSError as e:
1029                if e.errno != errno.ENOTCONN:
1030                    raise
1031                connected = False
1032                blocking = self.getblocking()
1033                self.setblocking(False)
1034                try:
1035                    # We are not connected so this is not supposed to block, but
1036                    # testing revealed otherwise on macOS and Windows so we do
1037                    # the non-blocking dance regardless. Our raise when any data
1038                    # is found means consuming the data is harmless.
1039                    notconn_pre_handshake_data = self.recv(1)
1040                except OSError as e:
1041                    # EINVAL occurs for recv(1) on non-connected on unix sockets.
1042                    if e.errno not in (errno.ENOTCONN, errno.EINVAL):
1043                        raise
1044                    notconn_pre_handshake_data = b''
1045                self.setblocking(blocking)
1046                if notconn_pre_handshake_data:
1047                    # This prevents pending data sent to the socket before it was
1048                    # closed from escaping to the caller who could otherwise
1049                    # presume it came through a successful TLS connection.
1050                    reason = "Closed before TLS handshake with data in recv buffer."
1051                    notconn_pre_handshake_data_error = SSLError(e.errno, reason)
1052                    # Add the SSLError attributes that _ssl.c always adds.
1053                    notconn_pre_handshake_data_error.reason = reason
1054                    notconn_pre_handshake_data_error.library = None
1055                    try:
1056                        raise notconn_pre_handshake_data_error
1057                    finally:
1058                        # Explicitly break the reference cycle.
1059                        notconn_pre_handshake_data_error = None
1060            else:
1061                connected = True
1062
1063            self.settimeout(sock_timeout)  # Must come after setblocking() calls.
1064            self._connected = connected
1065            if connected:
1066                # create the SSL object
1067                self._sslobj = self._context._wrap_socket(
1068                    self, server_side, self.server_hostname,
1069                    owner=self, session=self._session,
1070                )
1071                if do_handshake_on_connect:
1072                    timeout = self.gettimeout()
1073                    if timeout == 0.0:
1074                        # non-blocking
1075                        raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets")
1076                    self.do_handshake()
1077        except:
1078            try:
1079                self.close()
1080            except OSError:
1081                pass
1082            raise
1083        return self
1084
1085    @property
1086    @_sslcopydoc
1087    def context(self):
1088        return self._context
1089
1090    @context.setter
1091    def context(self, ctx):
1092        self._context = ctx
1093        self._sslobj.context = ctx
1094
1095    @property
1096    @_sslcopydoc
1097    def session(self):
1098        if self._sslobj is not None:
1099            return self._sslobj.session
1100
1101    @session.setter
1102    def session(self, session):
1103        self._session = session
1104        if self._sslobj is not None:
1105            self._sslobj.session = session
1106
1107    @property
1108    @_sslcopydoc
1109    def session_reused(self):
1110        if self._sslobj is not None:
1111            return self._sslobj.session_reused
1112
1113    def dup(self):
1114        raise NotImplementedError("Can't dup() %s instances" %
1115                                  self.__class__.__name__)
1116
1117    def _checkClosed(self, msg=None):
1118        # raise an exception here if you wish to check for spurious closes
1119        pass
1120
1121    def _check_connected(self):
1122        if not self._connected:
1123            # getpeername() will raise ENOTCONN if the socket is really
1124            # not connected; note that we can be connected even without
1125            # _connected being set, e.g. if connect() first returned
1126            # EAGAIN.
1127            self.getpeername()
1128
1129    def read(self, len=1024, buffer=None):
1130        """Read up to LEN bytes and return them.
1131        Return zero-length string on EOF."""
1132
1133        self._checkClosed()
1134        if self._sslobj is None:
1135            raise ValueError("Read on closed or unwrapped SSL socket.")
1136        try:
1137            if buffer is not None:
1138                return self._sslobj.read(len, buffer)
1139            else:
1140                return self._sslobj.read(len)
1141        except SSLError as x:
1142            if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
1143                if buffer is not None:
1144                    return 0
1145                else:
1146                    return b''
1147            else:
1148                raise
1149
1150    def write(self, data):
1151        """Write DATA to the underlying SSL channel.  Returns
1152        number of bytes of DATA actually transmitted."""
1153
1154        self._checkClosed()
1155        if self._sslobj is None:
1156            raise ValueError("Write on closed or unwrapped SSL socket.")
1157        return self._sslobj.write(data)
1158
1159    @_sslcopydoc
1160    def getpeercert(self, binary_form=False):
1161        self._checkClosed()
1162        self._check_connected()
1163        return self._sslobj.getpeercert(binary_form)
1164
1165    @_sslcopydoc
1166    def get_verified_chain(self):
1167        chain = self._sslobj.get_verified_chain()
1168
1169        if chain is None:
1170            return []
1171
1172        return [cert.public_bytes(_ssl.ENCODING_DER) for cert in chain]
1173
1174    @_sslcopydoc
1175    def get_unverified_chain(self):
1176        chain = self._sslobj.get_unverified_chain()
1177
1178        if chain is None:
1179            return []
1180
1181        return [cert.public_bytes(_ssl.ENCODING_DER) for cert in chain]
1182
1183    @_sslcopydoc
1184    def selected_npn_protocol(self):
1185        self._checkClosed()
1186        warnings.warn(
1187            "ssl NPN is deprecated, use ALPN instead",
1188            DeprecationWarning,
1189            stacklevel=2
1190        )
1191        return None
1192
1193    @_sslcopydoc
1194    def selected_alpn_protocol(self):
1195        self._checkClosed()
1196        if self._sslobj is None or not _ssl.HAS_ALPN:
1197            return None
1198        else:
1199            return self._sslobj.selected_alpn_protocol()
1200
1201    @_sslcopydoc
1202    def cipher(self):
1203        self._checkClosed()
1204        if self._sslobj is None:
1205            return None
1206        else:
1207            return self._sslobj.cipher()
1208
1209    @_sslcopydoc
1210    def shared_ciphers(self):
1211        self._checkClosed()
1212        if self._sslobj is None:
1213            return None
1214        else:
1215            return self._sslobj.shared_ciphers()
1216
1217    @_sslcopydoc
1218    def compression(self):
1219        self._checkClosed()
1220        if self._sslobj is None:
1221            return None
1222        else:
1223            return self._sslobj.compression()
1224
1225    def send(self, data, flags=0):
1226        self._checkClosed()
1227        if self._sslobj is not None:
1228            if flags != 0:
1229                raise ValueError(
1230                    "non-zero flags not allowed in calls to send() on %s" %
1231                    self.__class__)
1232            return self._sslobj.write(data)
1233        else:
1234            return super().send(data, flags)
1235
1236    def sendto(self, data, flags_or_addr, addr=None):
1237        self._checkClosed()
1238        if self._sslobj is not None:
1239            raise ValueError("sendto not allowed on instances of %s" %
1240                             self.__class__)
1241        elif addr is None:
1242            return super().sendto(data, flags_or_addr)
1243        else:
1244            return super().sendto(data, flags_or_addr, addr)
1245
1246    def sendmsg(self, *args, **kwargs):
1247        # Ensure programs don't send data unencrypted if they try to
1248        # use this method.
1249        raise NotImplementedError("sendmsg not allowed on instances of %s" %
1250                                  self.__class__)
1251
1252    def sendall(self, data, flags=0):
1253        self._checkClosed()
1254        if self._sslobj is not None:
1255            if flags != 0:
1256                raise ValueError(
1257                    "non-zero flags not allowed in calls to sendall() on %s" %
1258                    self.__class__)
1259            count = 0
1260            with memoryview(data) as view, view.cast("B") as byte_view:
1261                amount = len(byte_view)
1262                while count < amount:
1263                    v = self.send(byte_view[count:])
1264                    count += v
1265        else:
1266            return super().sendall(data, flags)
1267
1268    def sendfile(self, file, offset=0, count=None):
1269        """Send a file, possibly by using os.sendfile() if this is a
1270        clear-text socket.  Return the total number of bytes sent.
1271        """
1272        if self._sslobj is not None:
1273            return self._sendfile_use_send(file, offset, count)
1274        else:
1275            # os.sendfile() works with plain sockets only
1276            return super().sendfile(file, offset, count)
1277
1278    def recv(self, buflen=1024, flags=0):
1279        self._checkClosed()
1280        if self._sslobj is not None:
1281            if flags != 0:
1282                raise ValueError(
1283                    "non-zero flags not allowed in calls to recv() on %s" %
1284                    self.__class__)
1285            return self.read(buflen)
1286        else:
1287            return super().recv(buflen, flags)
1288
1289    def recv_into(self, buffer, nbytes=None, flags=0):
1290        self._checkClosed()
1291        if nbytes is None:
1292            if buffer is not None:
1293                with memoryview(buffer) as view:
1294                    nbytes = view.nbytes
1295                if not nbytes:
1296                    nbytes = 1024
1297            else:
1298                nbytes = 1024
1299        if self._sslobj is not None:
1300            if flags != 0:
1301                raise ValueError(
1302                  "non-zero flags not allowed in calls to recv_into() on %s" %
1303                  self.__class__)
1304            return self.read(nbytes, buffer)
1305        else:
1306            return super().recv_into(buffer, nbytes, flags)
1307
1308    def recvfrom(self, buflen=1024, flags=0):
1309        self._checkClosed()
1310        if self._sslobj is not None:
1311            raise ValueError("recvfrom not allowed on instances of %s" %
1312                             self.__class__)
1313        else:
1314            return super().recvfrom(buflen, flags)
1315
1316    def recvfrom_into(self, buffer, nbytes=None, flags=0):
1317        self._checkClosed()
1318        if self._sslobj is not None:
1319            raise ValueError("recvfrom_into not allowed on instances of %s" %
1320                             self.__class__)
1321        else:
1322            return super().recvfrom_into(buffer, nbytes, flags)
1323
1324    def recvmsg(self, *args, **kwargs):
1325        raise NotImplementedError("recvmsg not allowed on instances of %s" %
1326                                  self.__class__)
1327
1328    def recvmsg_into(self, *args, **kwargs):
1329        raise NotImplementedError("recvmsg_into not allowed on instances of "
1330                                  "%s" % self.__class__)
1331
1332    @_sslcopydoc
1333    def pending(self):
1334        self._checkClosed()
1335        if self._sslobj is not None:
1336            return self._sslobj.pending()
1337        else:
1338            return 0
1339
1340    def shutdown(self, how):
1341        self._checkClosed()
1342        self._sslobj = None
1343        super().shutdown(how)
1344
1345    @_sslcopydoc
1346    def unwrap(self):
1347        if self._sslobj:
1348            s = self._sslobj.shutdown()
1349            self._sslobj = None
1350            return s
1351        else:
1352            raise ValueError("No SSL wrapper around " + str(self))
1353
1354    @_sslcopydoc
1355    def verify_client_post_handshake(self):
1356        if self._sslobj:
1357            return self._sslobj.verify_client_post_handshake()
1358        else:
1359            raise ValueError("No SSL wrapper around " + str(self))
1360
1361    def _real_close(self):
1362        self._sslobj = None
1363        super()._real_close()
1364
1365    @_sslcopydoc
1366    def do_handshake(self, block=False):
1367        self._check_connected()
1368        timeout = self.gettimeout()
1369        try:
1370            if timeout == 0.0 and block:
1371                self.settimeout(None)
1372            self._sslobj.do_handshake()
1373        finally:
1374            self.settimeout(timeout)
1375
1376    def _real_connect(self, addr, connect_ex):
1377        if self.server_side:
1378            raise ValueError("can't connect in server-side mode")
1379        # Here we assume that the socket is client-side, and not
1380        # connected at the time of the call.  We connect it, then wrap it.
1381        if self._connected or self._sslobj is not None:
1382            raise ValueError("attempt to connect already-connected SSLSocket!")
1383        self._sslobj = self.context._wrap_socket(
1384            self, False, self.server_hostname,
1385            owner=self, session=self._session
1386        )
1387        try:
1388            if connect_ex:
1389                rc = super().connect_ex(addr)
1390            else:
1391                rc = None
1392                super().connect(addr)
1393            if not rc:
1394                self._connected = True
1395                if self.do_handshake_on_connect:
1396                    self.do_handshake()
1397            return rc
1398        except (OSError, ValueError):
1399            self._sslobj = None
1400            raise
1401
1402    def connect(self, addr):
1403        """Connects to remote ADDR, and then wraps the connection in
1404        an SSL channel."""
1405        self._real_connect(addr, False)
1406
1407    def connect_ex(self, addr):
1408        """Connects to remote ADDR, and then wraps the connection in
1409        an SSL channel."""
1410        return self._real_connect(addr, True)
1411
1412    def accept(self):
1413        """Accepts a new connection from a remote client, and returns
1414        a tuple containing that new connection wrapped with a server-side
1415        SSL channel, and the address of the remote client."""
1416
1417        newsock, addr = super().accept()
1418        newsock = self.context.wrap_socket(newsock,
1419                    do_handshake_on_connect=self.do_handshake_on_connect,
1420                    suppress_ragged_eofs=self.suppress_ragged_eofs,
1421                    server_side=True)
1422        return newsock, addr
1423
1424    @_sslcopydoc
1425    def get_channel_binding(self, cb_type="tls-unique"):
1426        if self._sslobj is not None:
1427            return self._sslobj.get_channel_binding(cb_type)
1428        else:
1429            if cb_type not in CHANNEL_BINDING_TYPES:
1430                raise ValueError(
1431                    "{0} channel binding type not implemented".format(cb_type)
1432                )
1433            return None
1434
1435    @_sslcopydoc
1436    def version(self):
1437        if self._sslobj is not None:
1438            return self._sslobj.version()
1439        else:
1440            return None
1441
1442
# Python does not support forward declaration of types, so the wrapper
# classes are attached to SSLContext here, after both have been defined.
SSLContext.sslsocket_class = SSLSocket
SSLContext.sslobject_class = SSLObject
1446
1447
1448# some utility functions
1449
def cert_time_to_seconds(cert_time):
    """Convert a certificate "notBefore"/"notAfter" time string to integer
    seconds since the Epoch.

    *cert_time* uses the ``"%b %d %H:%M:%S %Y %Z"`` strptime layout in the
    C locale.  Such dates must be in UTC (RFC 5280), spelled as GMT (see
    ASN1_TIME_print()).

    The month is one of: Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec
    (its letter case is normalised before the lookup).  ValueError is raised
    when the month or the remainder of the string does not match.
    """
    from calendar import timegm
    from time import strptime

    month_names = (
        "Jan", "Feb", "Mar", "Apr", "May", "Jun",
        "Jul", "Aug", "Sep", "Oct", "Nov", "Dec",
    )
    # The month is handled by hand and the zone is fixed to GMT, so the
    # strptime pattern only covers the text after the month abbreviation.
    tail_format = ' %d %H:%M:%S %Y GMT'
    abbreviation = cert_time[:3].title()
    try:
        month_number = 1 + month_names.index(abbreviation)
    except ValueError:
        raise ValueError('time data %r does not match '
                         'format "%%b%s"' % (cert_time, tail_format))

    parsed = strptime(cert_time[3:], tail_format)
    # timegm() yields an integer; fractional seconds never occur here.
    year = parsed[0]
    return timegm((year, month_number) + parsed[2:6])
1479
PEM_HEADER = "-----BEGIN CERTIFICATE-----"
PEM_FOOTER = "-----END CERTIFICATE-----"

def DER_cert_to_PEM_cert(der_cert_bytes):
    """Takes a certificate in binary DER format and returns the
    PEM version of it as a string.

    The base64 text is split into lines of at most 64 characters between
    the PEM header and footer, and the result ends with a newline."""

    f = str(base64.standard_b64encode(der_cert_bytes), 'ASCII', 'strict')
    ss = [PEM_HEADER]
    ss += [f[i:i+64] for i in range(0, len(f), 64)]
    ss.append(PEM_FOOTER + '\n')
    return '\n'.join(ss)

def PEM_cert_to_DER_cert(pem_cert_string):
    """Takes a certificate in ASCII PEM format and returns the
    DER-encoded version of it as a byte sequence.

    Whitespace around the PEM block is ignored.  Raises ValueError if the
    stripped text does not start with PEM_HEADER or end with PEM_FOOTER."""

    # Strip once and use the same text for both checks and the slice, so
    # leading whitespace is tolerated exactly like trailing whitespace.
    pem = pem_cert_string.strip()
    if not pem.startswith(PEM_HEADER):
        raise ValueError("Invalid PEM encoding; must start with %s"
                         % PEM_HEADER)
    if not pem.endswith(PEM_FOOTER):
        raise ValueError("Invalid PEM encoding; must end with %s"
                         % PEM_FOOTER)
    d = pem[len(PEM_HEADER):-len(PEM_FOOTER)]
    return base64.decodebytes(d.encode('ASCII', 'strict'))
1505
def get_server_certificate(addr, ssl_version=PROTOCOL_TLS_CLIENT,
                           ca_certs=None, timeout=_GLOBAL_DEFAULT_TIMEOUT):
    """Fetch the certificate of the server at *addr* and return it as a
    PEM-encoded string.

    *addr* is a ``(host, port)`` pair.  When 'ca_certs' is given the server
    certificate is validated against it; otherwise no validation is
    requested.  'ssl_version' and 'timeout' are used for the connection
    attempt.
    """

    # Unpacking also rejects an addr that is not a two-item sequence.
    host, _port = addr
    cert_reqs = CERT_NONE if ca_certs is None else CERT_REQUIRED
    context = _create_stdlib_context(ssl_version,
                                     cert_reqs=cert_reqs,
                                     cafile=ca_certs)
    with create_connection(addr, timeout=timeout) as sock, \
            context.wrap_socket(sock, server_hostname=host) as sslsock:
        der_bytes = sslsock.getpeercert(True)
    return DER_cert_to_PEM_cert(der_bytes)
1527
def get_protocol_name(protocol_code):
    """Return the name stored for *protocol_code* in _PROTOCOL_NAMES, or
    '<unknown>' when the code has no entry."""
    fallback = '<unknown>'
    return _PROTOCOL_NAMES.get(protocol_code, fallback)
1530