• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Wrapper module for _ssl, providing some additional facilities
2# implemented in Python.  Written by Bill Janssen.
3
4"""This module provides some more Pythonic support for SSL.
5
6Object types:
7
8  SSLSocket -- subtype of socket.socket which does SSL over the socket
9
10Exceptions:
11
12  SSLError -- exception raised for I/O errors
13
14Functions:
15
16  cert_time_to_seconds -- convert time string used for certificate
17                          notBefore and notAfter functions to integer
18                          seconds past the Epoch (the time values
19                          returned from time.time())
20
21  fetch_server_certificate (HOST, PORT) -- fetch the certificate provided
22                          by the server running on HOST at port PORT.  No
23                          validation of the certificate is performed.
24
25Integer constants:
26
27SSL_ERROR_ZERO_RETURN
28SSL_ERROR_WANT_READ
29SSL_ERROR_WANT_WRITE
30SSL_ERROR_WANT_X509_LOOKUP
31SSL_ERROR_SYSCALL
32SSL_ERROR_SSL
33SSL_ERROR_WANT_CONNECT
34
35SSL_ERROR_EOF
36SSL_ERROR_INVALID_ERROR_CODE
37
The following group defines certificate requirements that one side is
39allowing/requiring from the other side:
40
41CERT_NONE - no certificates from the other side are required (or will
42            be looked at if provided)
43CERT_OPTIONAL - certificates are not required, but if provided will be
44                validated, and if validation fails, the connection will
45                also fail
46CERT_REQUIRED - certificates are required, and will be validated, and
47                if validation fails, the connection will also fail
48
49The following constants identify various SSL protocol variants:
50
51PROTOCOL_SSLv2
52PROTOCOL_SSLv3
53PROTOCOL_SSLv23
54PROTOCOL_TLS
55PROTOCOL_TLS_CLIENT
56PROTOCOL_TLS_SERVER
57PROTOCOL_TLSv1
58PROTOCOL_TLSv1_1
59PROTOCOL_TLSv1_2
60
61The following constants identify various SSL alert message descriptions as per
62http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-6
63
64ALERT_DESCRIPTION_CLOSE_NOTIFY
65ALERT_DESCRIPTION_UNEXPECTED_MESSAGE
66ALERT_DESCRIPTION_BAD_RECORD_MAC
67ALERT_DESCRIPTION_RECORD_OVERFLOW
68ALERT_DESCRIPTION_DECOMPRESSION_FAILURE
69ALERT_DESCRIPTION_HANDSHAKE_FAILURE
70ALERT_DESCRIPTION_BAD_CERTIFICATE
71ALERT_DESCRIPTION_UNSUPPORTED_CERTIFICATE
72ALERT_DESCRIPTION_CERTIFICATE_REVOKED
73ALERT_DESCRIPTION_CERTIFICATE_EXPIRED
74ALERT_DESCRIPTION_CERTIFICATE_UNKNOWN
75ALERT_DESCRIPTION_ILLEGAL_PARAMETER
76ALERT_DESCRIPTION_UNKNOWN_CA
77ALERT_DESCRIPTION_ACCESS_DENIED
78ALERT_DESCRIPTION_DECODE_ERROR
79ALERT_DESCRIPTION_DECRYPT_ERROR
80ALERT_DESCRIPTION_PROTOCOL_VERSION
81ALERT_DESCRIPTION_INSUFFICIENT_SECURITY
82ALERT_DESCRIPTION_INTERNAL_ERROR
83ALERT_DESCRIPTION_USER_CANCELLED
84ALERT_DESCRIPTION_NO_RENEGOTIATION
85ALERT_DESCRIPTION_UNSUPPORTED_EXTENSION
86ALERT_DESCRIPTION_CERTIFICATE_UNOBTAINABLE
87ALERT_DESCRIPTION_UNRECOGNIZED_NAME
88ALERT_DESCRIPTION_BAD_CERTIFICATE_STATUS_RESPONSE
89ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE
90ALERT_DESCRIPTION_UNKNOWN_PSK_IDENTITY
91"""
92
93import sys
94import os
95from collections import namedtuple
96from enum import Enum as _Enum, IntEnum as _IntEnum, IntFlag as _IntFlag
97
98import _ssl             # if we can't import it, let the error propagate
99
100from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION
101from _ssl import _SSLContext, MemoryBIO, SSLSession
102from _ssl import (
103    SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError,
104    SSLSyscallError, SSLEOFError, SSLCertVerificationError
105    )
106from _ssl import txt2obj as _txt2obj, nid2obj as _nid2obj
107from _ssl import RAND_status, RAND_add, RAND_bytes, RAND_pseudo_bytes
108try:
109    from _ssl import RAND_egd
110except ImportError:
111    # LibreSSL does not provide RAND_egd
112    pass
113
114
115from _ssl import (
116    HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN, HAS_SSLv2, HAS_SSLv3, HAS_TLSv1,
117    HAS_TLSv1_1, HAS_TLSv1_2, HAS_TLSv1_3
118)
119from _ssl import _DEFAULT_CIPHERS, _OPENSSL_API_VERSION
120
121
# Build enum classes from the integer constants exported by _ssl.  Each
# _convert_() call gathers the _ssl names accepted by its predicate and binds
# the resulting class under the given name in this module: _SSLMethod,
# Options, AlertDescription, SSLErrorNumber, VerifyFlags and VerifyMode are
# used later in this file without any other definition.
# NOTE(review): individual members (e.g. CERT_REQUIRED, PROTOCOL_TLS_CLIENT)
# are also used unqualified later on, so _convert_ presumably exports the
# members into the module namespace too -- confirm for the enum version used.

# PROTOCOL_SSLv23 is skipped here; it is re-added below as an alias of
# PROTOCOL_TLS.
_IntEnum._convert_(
    '_SSLMethod', __name__,
    lambda name: name.startswith('PROTOCOL_') and name != 'PROTOCOL_SSLv23',
    source=_ssl)

# OP_* values are combined and masked bitwise (SSLContext.options), hence
# IntFlag rather than IntEnum.
_IntFlag._convert_(
    'Options', __name__,
    lambda name: name.startswith('OP_'),
    source=_ssl)

# ALERT_DESCRIPTION_* constants (SSL alert message descriptions).
_IntEnum._convert_(
    'AlertDescription', __name__,
    lambda name: name.startswith('ALERT_DESCRIPTION_'),
    source=_ssl)

# SSL_ERROR_* codes.
_IntEnum._convert_(
    'SSLErrorNumber', __name__,
    lambda name: name.startswith('SSL_ERROR_'),
    source=_ssl)

# VERIFY_* flags, exposed as a flag set through SSLContext.verify_flags.
_IntFlag._convert_(
    'VerifyFlags', __name__,
    lambda name: name.startswith('VERIFY_'),
    source=_ssl)

# CERT_NONE / CERT_OPTIONAL / CERT_REQUIRED (SSLContext.verify_mode).
_IntEnum._convert_(
    'VerifyMode', __name__,
    lambda name: name.startswith('CERT_'),
    source=_ssl)

# Backward-compatible alias: PROTOCOL_SSLv23 names the same value as
# PROTOCOL_TLS, both as a module constant and as a _SSLMethod attribute.
PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_TLS
# Reverse lookup: protocol member -> member name.
# NOTE(review): the PROTOCOL_SSLv23 attribute assigned above is presumably not
# registered in __members__, so that value should map to 'PROTOCOL_TLS'.
_PROTOCOL_NAMES = {value: name for name, value in _SSLMethod.__members__.items()}

# The PROTOCOL_SSLv2 member, or None when _ssl does not export that constant.
_SSLv2_IF_EXISTS = getattr(_SSLMethod, 'PROTOCOL_SSLv2', None)
156
157
class TLSVersion(_IntEnum):
    """TLS/SSL protocol versions.

    Values mirror the _ssl.PROTO_* constants.  SSLContext.minimum_version and
    SSLContext.maximum_version return members of this enum, and the TLS
    message callback converts version numbers to members when possible.
    """
    # NOTE(review): MINIMUM_SUPPORTED / MAXIMUM_SUPPORTED presumably denote
    # "lowest / highest version the linked library supports" rather than a
    # concrete protocol version -- confirm against _ssl.
    MINIMUM_SUPPORTED = _ssl.PROTO_MINIMUM_SUPPORTED
    SSLv3 = _ssl.PROTO_SSLv3
    TLSv1 = _ssl.PROTO_TLSv1
    TLSv1_1 = _ssl.PROTO_TLSv1_1
    TLSv1_2 = _ssl.PROTO_TLSv1_2
    TLSv1_3 = _ssl.PROTO_TLSv1_3
    MAXIMUM_SUPPORTED = _ssl.PROTO_MAXIMUM_SUPPORTED
166
167
class _TLSContentType(_IntEnum):
    """Content types (record layer)

    See RFC 8446, section B.1

    SSLContext._msg_callback converts the content_type it receives to this
    enum when possible, and uses HEADER to decide how msg_type is decoded.
    """
    CHANGE_CIPHER_SPEC = 20
    ALERT = 21
    HANDSHAKE = 22
    APPLICATION_DATA = 23
    # pseudo content types
    # NOTE(review): these lie above the wire content-type range and are
    # presumably synthesized by _ssl's message callback (HEADER for record
    # headers, INNER_CONTENT_TYPE for the TLS 1.3 inner type) -- confirm.
    HEADER = 0x100
    INNER_CONTENT_TYPE = 0x101
180
181
class _TLSAlertType(_IntEnum):
    """Alert types for _TLSContentType.ALERT messages

    See RFC 8446, section B.2

    SSLContext._msg_callback uses this enum to decode msg_type for records
    whose content type is _TLSContentType.ALERT; unknown values stay ints.
    """
    CLOSE_NOTIFY = 0
    UNEXPECTED_MESSAGE = 10
    BAD_RECORD_MAC = 20
    DECRYPTION_FAILED = 21
    RECORD_OVERFLOW = 22
    DECOMPRESSION_FAILURE = 30
    HANDSHAKE_FAILURE = 40
    NO_CERTIFICATE = 41
    BAD_CERTIFICATE = 42
    UNSUPPORTED_CERTIFICATE = 43
    CERTIFICATE_REVOKED = 44
    CERTIFICATE_EXPIRED = 45
    CERTIFICATE_UNKNOWN = 46
    ILLEGAL_PARAMETER = 47
    UNKNOWN_CA = 48
    ACCESS_DENIED = 49
    DECODE_ERROR = 50
    DECRYPT_ERROR = 51
    EXPORT_RESTRICTION = 60
    PROTOCOL_VERSION = 70
    INSUFFICIENT_SECURITY = 71
    INTERNAL_ERROR = 80
    INAPPROPRIATE_FALLBACK = 86
    USER_CANCELED = 90
    NO_RENEGOTIATION = 100
    MISSING_EXTENSION = 109
    UNSUPPORTED_EXTENSION = 110
    CERTIFICATE_UNOBTAINABLE = 111
    UNRECOGNIZED_NAME = 112
    BAD_CERTIFICATE_STATUS_RESPONSE = 113
    BAD_CERTIFICATE_HASH_VALUE = 114
    UNKNOWN_PSK_IDENTITY = 115
    CERTIFICATE_REQUIRED = 116
    NO_APPLICATION_PROTOCOL = 120
221
222
class _TLSMessageType(_IntEnum):
    """Message types (handshake protocol)

    See RFC 8446, section B.3

    SSLContext._msg_callback uses this enum to decode msg_type for records
    that are neither frame headers nor alerts; unknown values stay ints.
    """
    # NOTE(review): several entries (e.g. HELLO_VERIFY_REQUEST, NEXT_PROTO)
    # are not TLS 1.3 handshake types; they are presumably kept for DTLS,
    # earlier TLS versions and NPN -- confirm before pruning.
    HELLO_REQUEST = 0
    CLIENT_HELLO = 1
    SERVER_HELLO = 2
    HELLO_VERIFY_REQUEST = 3
    NEWSESSION_TICKET = 4
    END_OF_EARLY_DATA = 5
    HELLO_RETRY_REQUEST = 6
    ENCRYPTED_EXTENSIONS = 8
    CERTIFICATE = 11
    SERVER_KEY_EXCHANGE = 12
    CERTIFICATE_REQUEST = 13
    SERVER_DONE = 14
    CERTIFICATE_VERIFY = 15
    CLIENT_KEY_EXCHANGE = 16
    FINISHED = 20
    CERTIFICATE_URL = 21
    CERTIFICATE_STATUS = 22
    SUPPLEMENTAL_DATA = 23
    KEY_UPDATE = 24
    NEXT_PROTO = 67
    MESSAGE_HASH = 254
    # Pseudo message type above the one-byte range, like the pseudo content
    # types of _TLSContentType.  NOTE(review): presumably reported by _ssl for
    # change_cipher_spec records -- confirm.
    CHANGE_CIPHER_SPEC = 0x0101
250
251
252if sys.platform == "win32":
253    from _ssl import enum_certificates, enum_crls
254
255from socket import socket, SOCK_STREAM, create_connection
256from socket import SOL_SOCKET, SO_TYPE, _GLOBAL_DEFAULT_TIMEOUT
257import socket as _socket
258import base64        # for DER-to-PEM translation
259import errno
260import warnings
261
262
# Public aliases of exception classes.
socket_error = OSError  # keep that public name in module namespace
CertificateError = SSLCertVerificationError

# Channel binding types known to this module.
CHANNEL_BINDING_TYPES = ['tls-unique']

# Values derived from what the _ssl extension provides.
HAS_NEVER_CHECK_COMMON_NAME = hasattr(_ssl, 'HOSTFLAG_NEVER_CHECK_SUBJECT')
_RESTRICTED_SERVER_CIPHERS = _DEFAULT_CIPHERS
273
274
def _dnsname_match(dn, hostname):
    """Return True if the certificate DNS name *dn* matches *hostname*.

    Matching follows RFC 6125, section 6.4.3, with these restrictions:

    - Comparison is case-insensitive.
    - For IDNA, both *dn* and *hostname* must already be IDN A-labels (ACE).
    - The only accepted wildcard form is a lone '*' forming the entire
      leftmost label (e.g. '*.example.org').  Partial wildcards
      ('www*.example.org'), several wildcards, a bare '*', or a wildcard
      outside the leftmost label raise CertificateError.
    - A wildcard stands for exactly one non-empty label.

    An empty *dn* never matches.
    """
    if not dn:
        return False

    star_count = dn.count('*')
    if star_count == 0:
        # Fast path: no wildcard, plain case-insensitive comparison.
        return dn.lower() == hostname.lower()
    if star_count > 1:
        raise CertificateError(
            "too many wildcards in certificate DNS name: {!r}.".format(dn))

    # Exactly one '*': validate its position.  The checks are ordered so the
    # first applicable complaint is the one reported.
    first_label, dot, rest = dn.partition('.')
    if '*' in rest:
        # Only match wildcard in leftmost segment.
        complaint = "wildcard can only be present in the leftmost label: "
    elif not dot:
        # No labels to the right of the wildcard.
        complaint = "sole wildcard without additional labels are not support: "
    elif first_label != '*':
        # No partial wildcard matching.
        complaint = "partial wildcards in leftmost label are not supported: "
    else:
        complaint = None
    if complaint is not None:
        raise CertificateError(complaint + "{!r}.".format(dn))

    # The wildcard consumes the host's leftmost label, which must be
    # non-empty and followed by at least one more label.
    host_first, host_dot, host_rest = hostname.partition('.')
    if host_first and host_dot:
        return rest.lower() == host_rest.lower()
    return False
322
323
def _inet_paton(ipname):
    """Try to convert an IP address to packed binary form

    Supports IPv4 addresses on all platforms and IPv6 on platforms with IPv6
    support.

    Returns the packed address as bytes.  Raises ValueError if *ipname* is
    not a canonical dotted-quad IPv4 address or (where supported) an IPv6
    address.
    """
    # inet_aton() also accepts strings like '1', '127.1', some also trailing
    # data like '127.0.0.1 whatever'.
    try:
        addr = _socket.inet_aton(ipname)
    except OSError:
        # not an IPv4 address
        pass
    else:
        if _socket.inet_ntoa(addr) == ipname:
            # only accept injective ipnames
            return addr
        else:
            # refuse for short IPv4 notation and additional trailing data
            raise ValueError(
                "{!r} is not a quad-dotted IPv4 address.".format(ipname)
            )

    try:
        return _socket.inet_pton(_socket.AF_INET6, ipname)
    except OSError:
        # The OSError only says "illegal address"; it adds nothing for
        # callers, so suppress it as context.
        raise ValueError("{!r} is neither an IPv4 nor an IPv6 "
                         "address.".format(ipname)) from None
    except AttributeError:
        # AF_INET6 not available
        pass

    raise ValueError("{!r} is not an IPv4 address.".format(ipname))
357
358
def _ipaddress_match(cert_ipaddress, host_ip):
    """Return True if a certificate IP entry equals the packed *host_ip*.

    This is an exact comparison of packed addresses; RFC 6125 leaves IP
    address matching out of scope (section 1.7.2 - "Out of Scope").  Raises
    ValueError when *cert_ipaddress* cannot be parsed by _inet_paton().
    """
    # OpenSSL may append a newline to a subjectAltName IP address (commonly
    # with IPv6 addresses), so strip trailing whitespace before parsing.
    return _inet_paton(cert_ipaddress.rstrip()) == host_ip
369
370
def match_hostname(cert, hostname):
    """Check that *cert* is valid for *hostname*; raise CertificateError if not.

    *cert* is a decoded certificate as returned by SSLSocket.getpeercert().
    RFC 2818 and RFC 6125 rules are followed.

    If *hostname* parses as an IP address, it is compared against the
    certificate's 'IP Address' subjectAltName entries, otherwise against its
    'DNS' entries.  IPv4 works on all platforms; IPv6 needs platform support
    (AF_INET6 and inet_pton).  The subject commonName is consulted only when
    subjectAltName yields no DNS or IP Address entries.

    Returns None on success.  Raises ValueError for an empty certificate.
    Deprecated: every call emits a DeprecationWarning.
    """
    warnings.warn(
        "ssl.match_hostname() is deprecated",
        category=DeprecationWarning,
        stacklevel=2
    )
    if not cert:
        raise ValueError("empty or no certificate, match_hostname needs a "
                         "SSL socket or SSL context with either "
                         "CERT_OPTIONAL or CERT_REQUIRED")

    try:
        host_ip = _inet_paton(hostname)
    except ValueError:
        # Not an IP address (common case): match DNS names instead.
        host_ip = None

    # Every candidate name examined; reported in the error on failure.
    tried = []
    for kind, name in cert.get('subjectAltName', ()):
        if kind == 'DNS':
            matched = host_ip is None and _dnsname_match(name, hostname)
        elif kind == 'IP Address':
            matched = host_ip is not None and _ipaddress_match(name, host_ip)
        else:
            continue
        if matched:
            return
        tried.append(name)

    if not tried:
        # Fall back to the subject only when subjectAltName supplied no DNS
        # or IP Address entries.
        for rdn in cert.get('subject', ()):
            for attr, name in rdn:
                if attr != 'commonName':
                    continue
                # XXX according to RFC 2818, the most specific Common Name
                # must be used.
                if _dnsname_match(name, hostname):
                    return
                tried.append(name)

    if not tried:
        raise CertificateError("no appropriate commonName or "
                               "subjectAltName fields were found")
    if len(tried) == 1:
        raise CertificateError("hostname %r "
                               "doesn't match %r"
                               % (hostname, tried[0]))
    raise CertificateError("hostname %r "
                           "doesn't match either of %s"
                           % (hostname, ', '.join(map(repr, tried))))
431
432
# Result type of get_default_verify_paths(): the effective cafile and capath
# (or None), followed by the four raw values from
# _ssl.get_default_verify_paths().
DefaultVerifyPaths = namedtuple(
    "DefaultVerifyPaths",
    [
        "cafile",
        "capath",
        "openssl_cafile_env",
        "openssl_cafile",
        "openssl_capath_env",
        "openssl_capath",
    ],
)
436
def get_default_verify_paths():
    """Return a DefaultVerifyPaths describing the default CA locations.

    The first two fields are the effective cafile and capath.  Each is the
    path named by the corresponding environment variable when it is set,
    otherwise the default reported by _ssl.  A field is None when that path
    is not an existing file (cafile) or directory (capath).  The remaining
    four fields are the raw values from _ssl.get_default_verify_paths().
    """
    raw = _ssl.get_default_verify_paths()
    cafile_env_var, default_cafile = raw[0], raw[1]
    capath_env_var, default_capath = raw[2], raw[3]

    # Environment variables take precedence over the reported defaults.
    cafile = os.environ.get(cafile_env_var, default_cafile)
    capath = os.environ.get(capath_env_var, default_capath)

    if not os.path.isfile(cafile):
        cafile = None
    if not os.path.isdir(capath):
        capath = None
    return DefaultVerifyPaths(cafile, capath, *raw)
449
450
class _ASN1Object(namedtuple("_ASN1Object", "nid shortname longname oid")):
    """ASN.1 object identifier lookup.

    A (nid, shortname, longname, oid) tuple whose fields come from
    _ssl.txt2obj() / _ssl.nid2obj().  Construct it from a dotted OID string,
    or use the fromnid() / fromname() alternate constructors.
    """
    __slots__ = ()

    def __new__(cls, oid):
        # name=False, unlike fromname(), which also accepts short/long names.
        fields = _txt2obj(oid, name=False)
        return super().__new__(cls, *fields)

    @classmethod
    def fromnid(cls, nid):
        """Create _ASN1Object from OpenSSL numeric ID
        """
        fields = _nid2obj(nid)
        return super().__new__(cls, *fields)

    @classmethod
    def fromname(cls, name):
        """Create _ASN1Object from short name, long name or OID
        """
        fields = _txt2obj(name, name=True)
        return super().__new__(cls, *fields)
470
471
class Purpose(_ASN1Object, _Enum):
    """SSLContext purpose flags with X509v3 Extended Key Usage objects

    Each member is defined by a dotted OID string that is passed to
    _ASN1Object.__new__, so members are (nid, shortname, longname, oid)
    tuples.
    """
    # create_default_context() builds a client-side context
    # (PROTOCOL_TLS_CLIENT, CERT_REQUIRED, hostname checking) for this purpose.
    SERVER_AUTH = '1.3.6.1.5.5.7.3.1'
    # create_default_context() builds a server-side context
    # (PROTOCOL_TLS_SERVER) for this purpose.
    CLIENT_AUTH = '1.3.6.1.5.5.7.3.2'
477
478
class SSLContext(_SSLContext):
    """An SSLContext holds various SSL-related configuration options and
    data, such as certificates and possibly a private key."""
    # Windows system certificate stores read by load_default_certs().
    _windows_cert_stores = ("CA", "ROOT")

    sslsocket_class = None  # SSLSocket is assigned later.
    sslobject_class = None  # SSLObject is assigned later.

    def __new__(cls, protocol=None, *args, **kwargs):
        # Omitting *protocol* is deprecated: warn (attributed to the caller)
        # and fall back to PROTOCOL_TLS.
        if protocol is None:
            warnings.warn(
                "ssl.SSLContext() without protocol argument is deprecated.",
                category=DeprecationWarning,
                stacklevel=2
            )
            protocol = PROTOCOL_TLS
        self = _SSLContext.__new__(cls, protocol)
        return self

    def _encode_hostname(self, hostname):
        """Return *hostname* as an ASCII str for the _ssl layer.

        None is returned unchanged, a str is IDNA-encoded, and any other
        value (bytes) is decoded as ASCII.
        """
        if hostname is None:
            return None
        elif isinstance(hostname, str):
            return hostname.encode('idna').decode('ascii')
        else:
            return hostname.decode('ascii')

    @staticmethod
    def _encode_protocol_list(protocols, kind):
        """Serialize protocol names into a length-prefixed bytearray.

        Each name is ASCII-encoded and preceded by one length byte; this
        format is shared by NPN and ALPN.  *kind* ('NPN' or 'ALPN') only
        appears in the error message.

        Raises SSLError if a name encodes to 0 or more than 255 bytes.
        Non-ASCII names raise UnicodeEncodeError.
        """
        protos = bytearray()
        for protocol in protocols:
            b = bytes(protocol, 'ascii')
            if len(b) == 0 or len(b) > 255:
                raise SSLError(f'{kind} protocols must be 1 to 255 in length')
            protos.append(len(b))
            protos.extend(b)
        return protos

    def wrap_socket(self, sock, server_side=False,
                    do_handshake_on_connect=True,
                    suppress_ragged_eofs=True,
                    server_hostname=None, session=None):
        """Wrap *sock* in an instance of sslsocket_class bound to this context."""
        # SSLSocket class handles server_hostname encoding before it calls
        # ctx._wrap_socket()
        return self.sslsocket_class._create(
            sock=sock,
            server_side=server_side,
            do_handshake_on_connect=do_handshake_on_connect,
            suppress_ragged_eofs=suppress_ragged_eofs,
            server_hostname=server_hostname,
            context=self,
            session=session
        )

    def wrap_bio(self, incoming, outgoing, server_side=False,
                 server_hostname=None, session=None):
        """Create an sslobject_class instance over two memory BIOs."""
        # Need to encode server_hostname here because _wrap_bio() can only
        # handle ASCII str.
        return self.sslobject_class._create(
            incoming, outgoing, server_side=server_side,
            server_hostname=self._encode_hostname(server_hostname),
            session=session, context=self,
        )

    def set_npn_protocols(self, npn_protocols):
        """Set the NPN protocol list (deprecated; use set_alpn_protocols()).

        Raises SSLError if a protocol name is empty or longer than 255 bytes.
        """
        warnings.warn(
            "ssl NPN is deprecated, use ALPN instead",
            DeprecationWarning,
            stacklevel=2
        )
        self._set_npn_protocols(
            self._encode_protocol_list(npn_protocols, 'NPN'))

    def set_servername_callback(self, server_name_callback):
        """Install *server_name_callback* as the SNI callback.

        None removes the callback.  The callback is called with the server
        name already normalized by _encode_hostname().  Raises TypeError if
        the argument is not callable.
        """
        if server_name_callback is None:
            self.sni_callback = None
        else:
            if not callable(server_name_callback):
                raise TypeError("not a callable object")

            def shim_cb(sslobj, servername, sslctx):
                servername = self._encode_hostname(servername)
                return server_name_callback(sslobj, servername, sslctx)

            self.sni_callback = shim_cb

    def set_alpn_protocols(self, alpn_protocols):
        """Set the ALPN protocol list.

        Raises SSLError if a protocol name is empty or longer than 255 bytes.
        """
        self._set_alpn_protocols(
            self._encode_protocol_list(alpn_protocols, 'ALPN'))

    def _load_windows_store_certs(self, storename, purpose):
        """Load CA certificates for *purpose* from Windows store *storename*.

        Only "x509_asn" entries are used, and only when the store trusts them
        for everything or for purpose.oid.  Returns the concatenated
        certificate bytes, which are empty if nothing was loaded.  A store
        that cannot be enumerated only produces a warning.
        """
        certs = bytearray()
        try:
            for cert, encoding, trust in enum_certificates(storename):
                # CA certs are never PKCS#7 encoded
                if encoding == "x509_asn":
                    if trust is True or purpose.oid in trust:
                        certs.extend(cert)
        except PermissionError:
            warnings.warn("unable to enumerate Windows certificate store")
        if certs:
            self.load_verify_locations(cadata=certs)
        return certs

    def load_default_certs(self, purpose=Purpose.SERVER_AUTH):
        """Load the system's default CA certificates for *purpose*.

        On Windows the stores in _windows_cert_stores are read first; then
        set_default_verify_paths() is called on every platform.  Raises
        TypeError if *purpose* is not an _ASN1Object.
        """
        if not isinstance(purpose, _ASN1Object):
            raise TypeError(purpose)
        if sys.platform == "win32":
            for storename in self._windows_cert_stores:
                self._load_windows_store_certs(storename, purpose)
        self.set_default_verify_paths()

    if hasattr(_SSLContext, 'minimum_version'):
        @property
        def minimum_version(self):
            """Lowest protocol version allowed, as a TLSVersion."""
            return TLSVersion(super().minimum_version)

        @minimum_version.setter
        def minimum_version(self, value):
            # Lowering the minimum to SSLv3 also clears OP_NO_SSLv3.
            if value == TLSVersion.SSLv3:
                self.options &= ~Options.OP_NO_SSLv3
            super(SSLContext, SSLContext).minimum_version.__set__(self, value)

        @property
        def maximum_version(self):
            """Highest protocol version allowed, as a TLSVersion."""
            return TLSVersion(super().maximum_version)

        @maximum_version.setter
        def maximum_version(self, value):
            super(SSLContext, SSLContext).maximum_version.__set__(self, value)

    @property
    def options(self):
        """Option bits of the context, as an Options flag."""
        return Options(super().options)

    @options.setter
    def options(self, value):
        super(SSLContext, SSLContext).options.__set__(self, value)

    if hasattr(_ssl, 'HOSTFLAG_NEVER_CHECK_SUBJECT'):
        @property
        def hostname_checks_common_name(self):
            """True unless HOSTFLAG_NEVER_CHECK_SUBJECT is set in _host_flags."""
            ncs = self._host_flags & _ssl.HOSTFLAG_NEVER_CHECK_SUBJECT
            return ncs != _ssl.HOSTFLAG_NEVER_CHECK_SUBJECT

        @hostname_checks_common_name.setter
        def hostname_checks_common_name(self, value):
            if value:
                self._host_flags &= ~_ssl.HOSTFLAG_NEVER_CHECK_SUBJECT
            else:
                self._host_flags |= _ssl.HOSTFLAG_NEVER_CHECK_SUBJECT
    else:
        @property
        def hostname_checks_common_name(self):
            """Always True: this _ssl build cannot disable the subject check."""
            return True

    @property
    def _msg_callback(self):
        """TLS message callback

        The message callback provides a debugging hook to analyze TLS
        connections. The callback is called for any TLS protocol message
        (header, handshake, alert, and more), but not for application data.
        Due to technical limitations, the callback can't be used to filter
        traffic or to abort a connection. Any exception raised in the
        callback is delayed until the handshake, read, or write operation
        has been performed.

        def msg_cb(conn, direction, version, content_type, msg_type, data):
            pass

        conn
            :class:`SSLSocket` or :class:`SSLObject` instance
        direction
            ``read`` or ``write``
        version
            :class:`TLSVersion` enum member or int for unknown version. For a
            frame header, it's the header version.
        content_type
            :class:`_TLSContentType` enum member or int for unsupported
            content type.
        msg_type
            Either a :class:`_TLSContentType` enum number for a header
            message, a :class:`_TLSAlertType` enum member for an alert
            message, a :class:`_TLSMessageType` enum member for other
            messages, or int for unsupported message types.
        data
            Raw, decrypted message content as bytes
        """
        inner = super()._msg_callback
        if inner is not None:
            return inner.user_function
        else:
            return None

    @_msg_callback.setter
    def _msg_callback(self, callback):
        if callback is None:
            super(SSLContext, SSLContext)._msg_callback.__set__(self, None)
            return

        # callable() checks the type, like set_servername_callback() does;
        # an instance attribute named __call__ does not make an object
        # callable.
        if not callable(callback):
            raise TypeError(f"{callback} is not callable.")

        def inner(conn, direction, version, content_type, msg_type, data):
            # Convert raw ints to enum members where possible; unknown
            # values are passed through unchanged.
            try:
                version = TLSVersion(version)
            except ValueError:
                pass

            try:
                content_type = _TLSContentType(content_type)
            except ValueError:
                pass

            if content_type == _TLSContentType.HEADER:
                msg_enum = _TLSContentType
            elif content_type == _TLSContentType.ALERT:
                msg_enum = _TLSAlertType
            else:
                msg_enum = _TLSMessageType
            try:
                msg_type = msg_enum(msg_type)
            except ValueError:
                pass

            return callback(conn, direction, version,
                            content_type, msg_type, data)

        # Keep the user's callable reachable so the getter can return it.
        inner.user_function = callback

        super(SSLContext, SSLContext)._msg_callback.__set__(self, inner)

    @property
    def protocol(self):
        """Protocol selected when the context was created, as a _SSLMethod."""
        return _SSLMethod(super().protocol)

    @property
    def verify_flags(self):
        """Certificate verification flags, as a VerifyFlags value."""
        return VerifyFlags(super().verify_flags)

    @verify_flags.setter
    def verify_flags(self, value):
        super(SSLContext, SSLContext).verify_flags.__set__(self, value)

    @property
    def verify_mode(self):
        """Peer verification mode as a VerifyMode, or the raw int if unknown."""
        value = super().verify_mode
        try:
            return VerifyMode(value)
        except ValueError:
            return value

    @verify_mode.setter
    def verify_mode(self, value):
        super(SSLContext, SSLContext).verify_mode.__set__(self, value)
738
739
def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None,
                           capath=None, cadata=None):
    """Create a SSLContext object with default settings.

    Purpose.SERVER_AUTH (the default) yields a client-side context
    (PROTOCOL_TLS_CLIENT) with CERT_REQUIRED and hostname checking.
    Purpose.CLIENT_AUTH yields a server-side context (PROTOCOL_TLS_SERVER).
    CA certificates come from *cafile*/*capath*/*cadata* when any is given.
    Otherwise, if peer certificates are verified, the system defaults for
    *purpose* are loaded.  When the context supports keylog_filename,
    SSLKEYLOGFILE is honoured unless the interpreter ignores the environment.

    Raises TypeError if *purpose* is not an _ASN1Object, and ValueError if
    it is neither SERVER_AUTH nor CLIENT_AUTH.

    NOTE: The protocol and settings may change anytime without prior
          deprecation. The values represent a fair balance between maximum
          compatibility and security.
    """
    if not isinstance(purpose, _ASN1Object):
        raise TypeError(purpose)

    # SSLContext sets OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_COMPRESSION,
    # OP_CIPHER_SERVER_PREFERENCE, OP_SINGLE_DH_USE and OP_SINGLE_ECDH_USE
    # by default.
    if purpose == Purpose.CLIENT_AUTH:
        context = SSLContext(PROTOCOL_TLS_SERVER)
    elif purpose == Purpose.SERVER_AUTH:
        # Client mode: verify the peer certificate and its host name.
        context = SSLContext(PROTOCOL_TLS_CLIENT)
        context.verify_mode = CERT_REQUIRED
        context.check_hostname = True
    else:
        raise ValueError(purpose)

    if cafile or capath or cadata:
        context.load_verify_locations(cafile, capath, cadata)
    elif context.verify_mode != CERT_NONE:
        # Peers are verified but no CA source was given: try the system
        # root certificates for this purpose (this may fail silently).
        context.load_default_certs(purpose)

    # OpenSSL 1.1.1 keylog file
    if hasattr(context, 'keylog_filename'):
        keylogfile = os.environ.get('SSLKEYLOGFILE')
        if not sys.flags.ignore_environment and keylogfile:
            context.keylog_filename = keylogfile
    return context
777
def _create_unverified_context(protocol=None, *, cert_reqs=CERT_NONE,
                           check_hostname=False, purpose=Purpose.SERVER_AUTH,
                           certfile=None, keyfile=None,
                           cafile=None, capath=None, cadata=None):
    """Create a SSLContext object for Python stdlib modules

    All Python stdlib modules shall use this function to create SSLContext
    objects in order to keep common settings in one place.  The configuration
    is less restrictive than create_default_context()'s to increase backward
    compatibility: by default (cert_reqs=CERT_NONE, check_hostname=False)
    neither the certificate nor the host name is verified.

    *protocol* defaults to PROTOCOL_TLS_CLIENT for Purpose.SERVER_AUTH and to
    PROTOCOL_TLS_SERVER for Purpose.CLIENT_AUTH.  Raises TypeError for a
    *purpose* that is not an _ASN1Object.  Raises ValueError for an unknown
    purpose or for *keyfile* given without *certfile*.
    """
    if not isinstance(purpose, _ASN1Object):
        raise TypeError(purpose)

    # SSLContext sets OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_COMPRESSION,
    # OP_CIPHER_SERVER_PREFERENCE, OP_SINGLE_DH_USE and OP_SINGLE_ECDH_USE
    # by default.
    if purpose == Purpose.SERVER_AUTH:
        default_protocol = PROTOCOL_TLS_CLIENT
    elif purpose == Purpose.CLIENT_AUTH:
        default_protocol = PROTOCOL_TLS_SERVER
    else:
        raise ValueError(purpose)
    context = SSLContext(default_protocol if protocol is None else protocol)

    # Keep this exact sequence: check_hostname, then verify_mode, then
    # check_hostname again.  NOTE(review): presumably because the two
    # settings constrain each other inside SSLContext -- confirm before
    # reordering.
    context.check_hostname = check_hostname
    if cert_reqs is not None:
        context.verify_mode = cert_reqs
    if check_hostname:
        context.check_hostname = True

    # Own certificate chain; a key file alone is not enough.
    if keyfile and not certfile:
        raise ValueError("certfile must be specified")
    if certfile or keyfile:
        context.load_cert_chain(certfile, keyfile)

    # CA roots: explicit sources win.  Otherwise, when peers are verified,
    # try the system defaults for this purpose (this may fail silently).
    if cafile or capath or cadata:
        context.load_verify_locations(cafile, capath, cadata)
    elif context.verify_mode != CERT_NONE:
        context.load_default_certs(purpose)

    # OpenSSL 1.1.1 keylog file
    if not hasattr(context, 'keylog_filename'):
        return context
    keylogfile = os.environ.get('SSLKEYLOGFILE')
    if keylogfile and not sys.flags.ignore_environment:
        context.keylog_filename = keylogfile
    return context
831
# Backwards compatibility alias, even though it's not a public name.
_create_stdlib_context = _create_unverified_context

# Used by http.client if no context is explicitly passed.
_create_default_https_context = create_default_context
838
839
class SSLObject:
    """This class implements an interface on top of a low-level SSL object as
    implemented by OpenSSL. This object captures the state of an SSL connection
    but does not provide any network IO itself. IO needs to be performed
    through separate "BIO" objects which are OpenSSL's IO abstraction layer.

    This class does not have a public constructor. Instances are returned by
    ``SSLContext.wrap_bio``. This class is typically used by framework authors
    that want to implement asynchronous IO for SSL through memory buffers.

    When compared to ``SSLSocket``, this object lacks the following features:

     * Any form of network IO, including methods such as ``recv`` and ``send``.
     * The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery.
    """
    def __init__(self, *args, **kwargs):
        raise TypeError(
            f"{self.__class__.__name__} does not have a public "
            f"constructor. Instances are returned by SSLContext.wrap_bio()."
        )

    @classmethod
    def _create(cls, incoming, outgoing, server_side=False,
                 server_hostname=None, session=None, context=None):
        """Build an SSLObject around *context*'s low-level BIO wrapper.

        The public __init__ always raises, so the instance is allocated with
        __new__ and the object returned by ``context._wrap_bio()`` for the
        *incoming*/*outgoing* BIOs is attached as ``_sslobj``.
        """
        self = cls.__new__(cls)
        sslobj = context._wrap_bio(
            incoming, outgoing, server_side=server_side,
            server_hostname=server_hostname,
            owner=self, session=session
        )
        self._sslobj = sslobj
        return self

    @property
    def context(self):
        """The SSLContext that is currently in use."""
        return self._sslobj.context

    @context.setter
    def context(self, ctx):
        self._sslobj.context = ctx

    @property
    def session(self):
        """The SSLSession for client socket."""
        return self._sslobj.session

    @session.setter
    def session(self, session):
        self._sslobj.session = session

    @property
    def session_reused(self):
        """Was the client session reused during handshake"""
        return self._sslobj.session_reused

    @property
    def server_side(self):
        """Whether this is a server-side socket."""
        return self._sslobj.server_side

    @property
    def server_hostname(self):
        """The currently set server hostname (for SNI), or ``None`` if no
        server hostname is set."""
        return self._sslobj.server_hostname

    def read(self, len=1024, buffer=None):
        """Read up to 'len' bytes from the SSL object and return them.

        If 'buffer' is provided, read into this buffer and return the number of
        bytes read.
        """
        if buffer is not None:
            return self._sslobj.read(len, buffer)
        return self._sslobj.read(len)

    def write(self, data):
        """Write 'data' to the SSL object and return the number of bytes
        written.

        The 'data' argument must support the buffer interface.
        """
        return self._sslobj.write(data)

    def getpeercert(self, binary_form=False):
        """Returns a formatted version of the data in the certificate provided
        by the other end of the SSL channel.

        Return None if no certificate was provided, {} if a certificate was
        provided, but not validated.
        """
        return self._sslobj.getpeercert(binary_form)

    def selected_npn_protocol(self):
        """Return the currently selected NPN protocol as a string, or ``None``
        if a next protocol was not negotiated or if NPN is not supported by one
        of the peers."""
        warnings.warn(
            "ssl NPN is deprecated, use ALPN instead",
            DeprecationWarning,
            stacklevel=2
        )
        # NPN support has been dropped; never report a negotiated protocol.
        return None

    def selected_alpn_protocol(self):
        """Return the currently selected ALPN protocol as a string, or ``None``
        if a next protocol was not negotiated or if ALPN is not supported by one
        of the peers."""
        return self._sslobj.selected_alpn_protocol()

    def cipher(self):
        """Return the currently selected cipher as a 3-tuple ``(name,
        ssl_version, secret_bits)``."""
        return self._sslobj.cipher()

    def shared_ciphers(self):
        """Return a list of ciphers shared by the client during the handshake or
        None if this is not a valid server connection.
        """
        return self._sslobj.shared_ciphers()

    def compression(self):
        """Return the current compression algorithm in use, or ``None`` if
        compression was not negotiated or not supported by one of the peers."""
        return self._sslobj.compression()

    def pending(self):
        """Return the number of bytes that can be read immediately."""
        return self._sslobj.pending()

    def do_handshake(self):
        """Start the SSL/TLS handshake."""
        self._sslobj.do_handshake()

    def unwrap(self):
        """Start the SSL shutdown handshake."""
        return self._sslobj.shutdown()

    def get_channel_binding(self, cb_type="tls-unique"):
        """Get channel binding data for current connection.  Raise ValueError
        if the requested `cb_type` is not supported.  Return bytes of the data
        or None if the data is not available (e.g. before the handshake)."""
        return self._sslobj.get_channel_binding(cb_type)

    def version(self):
        """Return a string identifying the protocol version used by the
        current SSL channel. """
        return self._sslobj.version()

    def verify_client_post_handshake(self):
        """Request post-handshake client authentication (TLS 1.3).

        Only meaningful on the server side of a TLS 1.3 connection where
        post-handshake authentication is enabled; see
        ``SSLContext.post_handshake_auth``.
        """
        return self._sslobj.verify_client_post_handshake()
993
994
def _sslcopydoc(func):
    """Copy the docstring of the same-named SSLObject member onto *func*.

    SSLSocket mirrors much of the SSLObject API, so its methods reuse the
    SSLObject documentation instead of repeating it.  When the SSLObject
    member has no docstring, *func*'s own docstring (if any) is kept rather
    than being replaced by ``None``.
    """
    doc = getattr(SSLObject, func.__name__).__doc__
    if doc is not None:
        func.__doc__ = doc
    return func
999
1000
class SSLSocket(socket):
    """This class implements a subtype of socket.socket that wraps
    the underlying OS socket in an SSL context when necessary, and
    provides read and write methods over that channel. """

    def __init__(self, *args, **kwargs):
        raise TypeError(
            f"{self.__class__.__name__} does not have a public "
            f"constructor. Instances are returned by "
            f"SSLContext.wrap_socket()."
        )

    @classmethod
    def _create(cls, sock, server_side=False, do_handshake_on_connect=True,
                suppress_ragged_eofs=True, server_hostname=None,
                context=None, session=None):
        """Build an SSLSocket that takes over the file descriptor of *sock*.

        *sock* is detached and must not be used afterwards.  If it is already
        connected, the SSL object is created immediately and, when
        *do_handshake_on_connect* is true, the handshake is performed.

        Raises NotImplementedError for non-stream sockets and ValueError for
        inconsistent arguments.  Any failure after *sock* has been detached
        closes the new socket before the exception propagates, so the file
        descriptor is never leaked.
        """
        if sock.getsockopt(SOL_SOCKET, SO_TYPE) != SOCK_STREAM:
            raise NotImplementedError("only stream sockets are supported")
        if server_side:
            if server_hostname:
                raise ValueError("server_hostname can only be specified "
                                 "in client mode")
            if session is not None:
                raise ValueError("session can only be specified in "
                                 "client mode")
        if context.check_hostname and not server_hostname:
            raise ValueError("check_hostname requires server_hostname")

        sock_timeout = sock.gettimeout()
        kwargs = dict(
            family=sock.family, type=sock.type, proto=sock.proto,
            fileno=sock.fileno()
        )
        self = cls.__new__(cls, **kwargs)
        super(SSLSocket, self).__init__(**kwargs)
        sock.detach()
        # From here on this object owns the file descriptor: every failure
        # below must close it instead of leaving it to the garbage collector.
        try:
            self._sslobj = None
            self._closed = False
            self.settimeout(sock_timeout)
            self._context = context
            self._session = session
            self.server_side = server_side
            self.server_hostname = context._encode_hostname(server_hostname)
            self.do_handshake_on_connect = do_handshake_on_connect
            self.suppress_ragged_eofs = suppress_ragged_eofs

            # See if we are connected
            try:
                self.getpeername()
            except OSError as e:
                if e.errno != errno.ENOTCONN:
                    raise
                connected = False
            else:
                connected = True
            if not connected:
                self._reject_pre_handshake_data()

            self._connected = connected
            if connected:
                # create the SSL object
                self._sslobj = self._context._wrap_socket(
                    self, server_side, self.server_hostname,
                    owner=self, session=self._session,
                )
                if do_handshake_on_connect:
                    timeout = self.gettimeout()
                    if timeout == 0.0:
                        # non-blocking
                        raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets")
                    self.do_handshake()
        except BaseException:
            try:
                self.close()
            except OSError:
                pass
            raise
        return self

    def _reject_pre_handshake_data(self):
        """Raise SSLError if a socket reporting "not connected" holds data.

        Called by _create() after getpeername() failed with ENOTCONN.  A peer
        can send data and close the connection before the socket is wrapped;
        getpeername() then fails while the bytes are still queued, and
        handing out the socket would let that unauthenticated plaintext be
        read as if it had arrived over TLS (CVE-2023-40217).  At most one
        byte is read, without blocking; consuming it is harmless because
        finding any data makes this method raise.
        """
        timeout = self.gettimeout()
        self.setblocking(False)
        try:
            data = self.recv(1)
        except BlockingIOError:
            # Nothing queued (presumably a non-blocking connect in progress).
            data = b''
        except OSError as exc:
            # Unconnected sockets typically report ENOTCONN; unconnected
            # AF_UNIX sockets report EINVAL instead.
            if exc.errno not in (errno.ENOTCONN, errno.EINVAL):
                raise
            data = b''
        finally:
            # Restore the caller's timeout (and therefore blocking mode).
            self.settimeout(timeout)
        if data:
            reason = "Closed before TLS handshake with data in recv buffer."
            error = SSLError(errno.ENOTCONN, reason)
            # SSLErrors raised by _ssl carry these attributes; provide them
            # so callers can inspect this error the same way.
            error.reason = reason
            error.library = None
            try:
                raise error
            finally:
                # Break the error -> traceback -> frame -> error cycle.
                error = None

    @property
    @_sslcopydoc
    def context(self):
        return self._context

    @context.setter
    def context(self, ctx):
        # Update the live SSL object first so a rejected *ctx* leaves the
        # socket unchanged.  Before the connection is wrapped there is no SSL
        # object yet; the stored context is what _real_connect() will use.
        if self._sslobj is not None:
            self._sslobj.context = ctx
        self._context = ctx

    @property
    @_sslcopydoc
    def session(self):
        if self._sslobj is not None:
            return self._sslobj.session

    @session.setter
    def session(self, session):
        self._session = session
        if self._sslobj is not None:
            self._sslobj.session = session

    @property
    @_sslcopydoc
    def session_reused(self):
        if self._sslobj is not None:
            return self._sslobj.session_reused

    def dup(self):
        """Not supported: always raises NotImplementedError."""
        raise NotImplementedError("Can't dup() %s instances" %
                                  self.__class__.__name__)

    def _checkClosed(self, msg=None):
        # raise an exception here if you wish to check for spurious closes
        pass

    def _check_connected(self):
        if not self._connected:
            # getpeername() will raise ENOTCONN if the socket is really
            # not connected; note that we can be connected even without
            # _connected being set, e.g. if connect() first returned
            # EAGAIN.
            self.getpeername()

    def read(self, len=1024, buffer=None):
        """Read up to LEN bytes and return them.
        Return zero-length string on EOF."""

        self._checkClosed()
        if self._sslobj is None:
            raise ValueError("Read on closed or unwrapped SSL socket.")
        try:
            if buffer is not None:
                return self._sslobj.read(len, buffer)
            else:
                return self._sslobj.read(len)
        except SSLError as x:
            if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
                if buffer is not None:
                    return 0
                else:
                    return b''
            else:
                raise

    def write(self, data):
        """Write DATA to the underlying SSL channel.  Returns
        number of bytes of DATA actually transmitted."""

        self._checkClosed()
        if self._sslobj is None:
            raise ValueError("Write on closed or unwrapped SSL socket.")
        return self._sslobj.write(data)

    @_sslcopydoc
    def getpeercert(self, binary_form=False):
        self._checkClosed()
        self._check_connected()
        return self._sslobj.getpeercert(binary_form)

    @_sslcopydoc
    def selected_npn_protocol(self):
        self._checkClosed()
        warnings.warn(
            "ssl NPN is deprecated, use ALPN instead",
            DeprecationWarning,
            stacklevel=2
        )
        return None

    @_sslcopydoc
    def selected_alpn_protocol(self):
        self._checkClosed()
        if self._sslobj is None or not _ssl.HAS_ALPN:
            return None
        else:
            return self._sslobj.selected_alpn_protocol()

    @_sslcopydoc
    def cipher(self):
        self._checkClosed()
        if self._sslobj is None:
            return None
        else:
            return self._sslobj.cipher()

    @_sslcopydoc
    def shared_ciphers(self):
        self._checkClosed()
        if self._sslobj is None:
            return None
        else:
            return self._sslobj.shared_ciphers()

    @_sslcopydoc
    def compression(self):
        self._checkClosed()
        if self._sslobj is None:
            return None
        else:
            return self._sslobj.compression()

    def send(self, data, flags=0):
        """Send *data*, encrypted through the SSL object when one exists.

        While SSL is active, non-zero *flags* raise ValueError; without an
        SSL object this is plain socket.send().
        """
        self._checkClosed()
        if self._sslobj is not None:
            if flags != 0:
                raise ValueError(
                    "non-zero flags not allowed in calls to send() on %s" %
                    self.__class__)
            return self._sslobj.write(data)
        else:
            return super().send(data, flags)

    def sendto(self, data, flags_or_addr, addr=None):
        """Plain socket.sendto(); raises ValueError while SSL is active."""
        self._checkClosed()
        if self._sslobj is not None:
            raise ValueError("sendto not allowed on instances of %s" %
                             self.__class__)
        elif addr is None:
            return super().sendto(data, flags_or_addr)
        else:
            return super().sendto(data, flags_or_addr, addr)

    def sendmsg(self, *args, **kwargs):
        # Ensure programs don't send data unencrypted if they try to
        # use this method.
        raise NotImplementedError("sendmsg not allowed on instances of %s" %
                                  self.__class__)

    def sendall(self, data, flags=0):
        """Send all of *data*, repeating send() until every byte is written.

        While SSL is active, non-zero *flags* raise ValueError; without an
        SSL object this is plain socket.sendall().
        """
        self._checkClosed()
        if self._sslobj is not None:
            if flags != 0:
                raise ValueError(
                    "non-zero flags not allowed in calls to sendall() on %s" %
                    self.__class__)
            count = 0
            with memoryview(data) as view, view.cast("B") as byte_view:
                amount = len(byte_view)
                while count < amount:
                    v = self.send(byte_view[count:])
                    count += v
        else:
            return super().sendall(data, flags)

    def sendfile(self, file, offset=0, count=None):
        """Send a file, possibly by using os.sendfile() if this is a
        clear-text socket.  Return the total number of bytes sent.
        """
        if self._sslobj is not None:
            return self._sendfile_use_send(file, offset, count)
        else:
            # os.sendfile() works with plain sockets only
            return super().sendfile(file, offset, count)

    def recv(self, buflen=1024, flags=0):
        """Receive up to *buflen* bytes, decrypted when SSL is active.

        While SSL is active, non-zero *flags* raise ValueError; without an
        SSL object this is plain socket.recv().
        """
        self._checkClosed()
        if self._sslobj is not None:
            if flags != 0:
                raise ValueError(
                    "non-zero flags not allowed in calls to recv() on %s" %
                    self.__class__)
            return self.read(buflen)
        else:
            return super().recv(buflen, flags)

    def recv_into(self, buffer, nbytes=None, flags=0):
        """Receive into *buffer*, decrypting when SSL is active.

        *nbytes* defaults to ``len(buffer)`` (1024 for an empty/falsy
        buffer).  While SSL is active, non-zero *flags* raise ValueError.
        """
        self._checkClosed()
        if buffer and (nbytes is None):
            nbytes = len(buffer)
        elif nbytes is None:
            nbytes = 1024
        if self._sslobj is not None:
            if flags != 0:
                raise ValueError(
                  "non-zero flags not allowed in calls to recv_into() on %s" %
                  self.__class__)
            return self.read(nbytes, buffer)
        else:
            return super().recv_into(buffer, nbytes, flags)

    def recvfrom(self, buflen=1024, flags=0):
        """Plain socket.recvfrom(); raises ValueError while SSL is active."""
        self._checkClosed()
        if self._sslobj is not None:
            raise ValueError("recvfrom not allowed on instances of %s" %
                             self.__class__)
        else:
            return super().recvfrom(buflen, flags)

    def recvfrom_into(self, buffer, nbytes=None, flags=0):
        """Plain socket.recvfrom_into(); raises ValueError while SSL is active."""
        self._checkClosed()
        if self._sslobj is not None:
            raise ValueError("recvfrom_into not allowed on instances of %s" %
                             self.__class__)
        else:
            return super().recvfrom_into(buffer, nbytes, flags)

    def recvmsg(self, *args, **kwargs):
        """Not supported: always raises NotImplementedError."""
        raise NotImplementedError("recvmsg not allowed on instances of %s" %
                                  self.__class__)

    def recvmsg_into(self, *args, **kwargs):
        """Not supported: always raises NotImplementedError."""
        raise NotImplementedError("recvmsg_into not allowed on instances of "
                                  "%s" % self.__class__)

    @_sslcopydoc
    def pending(self):
        self._checkClosed()
        if self._sslobj is not None:
            return self._sslobj.pending()
        else:
            return 0

    def shutdown(self, how):
        """Discard the SSL object and shut down the underlying socket.

        Unlike unwrap(), this does not start the SSL shutdown handshake.
        """
        self._checkClosed()
        self._sslobj = None
        super().shutdown(how)

    @_sslcopydoc
    def unwrap(self):
        if self._sslobj:
            s = self._sslobj.shutdown()
            self._sslobj = None
            return s
        else:
            raise ValueError("No SSL wrapper around " + str(self))

    @_sslcopydoc
    def verify_client_post_handshake(self):
        """Request post-handshake client authentication (TLS 1.3).

        Raises ValueError if the socket has no SSL wrapper.
        """
        if self._sslobj:
            return self._sslobj.verify_client_post_handshake()
        else:
            raise ValueError("No SSL wrapper around " + str(self))

    def _real_close(self):
        # Drop the SSL object before the descriptor is closed.
        self._sslobj = None
        super()._real_close()

    @_sslcopydoc
    def do_handshake(self, block=False):
        self._check_connected()
        timeout = self.gettimeout()
        try:
            if timeout == 0.0 and block:
                self.settimeout(None)
            self._sslobj.do_handshake()
        finally:
            self.settimeout(timeout)

    def _real_connect(self, addr, connect_ex):
        """Shared implementation of connect() and connect_ex().

        Creates the client-side SSL object, connects the raw socket and, if
        requested, performs the handshake.  Returns connect_ex()'s error code
        (None for connect()).
        """
        if self.server_side:
            raise ValueError("can't connect in server-side mode")
        # Here we assume that the socket is client-side, and not
        # connected at the time of the call.  We connect it, then wrap it.
        if self._connected or self._sslobj is not None:
            raise ValueError("attempt to connect already-connected SSLSocket!")
        self._sslobj = self.context._wrap_socket(
            self, False, self.server_hostname,
            owner=self, session=self._session
        )
        try:
            if connect_ex:
                rc = super().connect_ex(addr)
            else:
                rc = None
                super().connect(addr)
            if not rc:
                self._connected = True
                if self.do_handshake_on_connect:
                    self.do_handshake()
            return rc
        except (OSError, ValueError):
            self._sslobj = None
            raise

    def connect(self, addr):
        """Connects to remote ADDR, and then wraps the connection in
        an SSL channel."""
        self._real_connect(addr, False)

    def connect_ex(self, addr):
        """Connects to remote ADDR, and then wraps the connection in
        an SSL channel."""
        return self._real_connect(addr, True)

    def accept(self):
        """Accepts a new connection from a remote client, and returns
        a tuple containing that new connection wrapped with a server-side
        SSL channel, and the address of the remote client."""

        newsock, addr = super().accept()
        newsock = self.context.wrap_socket(newsock,
                    do_handshake_on_connect=self.do_handshake_on_connect,
                    suppress_ragged_eofs=self.suppress_ragged_eofs,
                    server_side=True)
        return newsock, addr

    @_sslcopydoc
    def get_channel_binding(self, cb_type="tls-unique"):
        if self._sslobj is not None:
            return self._sslobj.get_channel_binding(cb_type)
        else:
            if cb_type not in CHANNEL_BINDING_TYPES:
                raise ValueError(
                    "{0} channel binding type not implemented".format(cb_type)
                )
            return None

    @_sslcopydoc
    def version(self):
        if self._sslobj is not None:
            return self._sslobj.version()
        else:
            return None
1411
# Python does not support forward declaration of types.  SSLContext is
# defined before these classes, so the classes handed out by
# SSLContext.wrap_socket() and SSLContext.wrap_bio() are attached to it
# here, once SSLSocket and SSLObject exist.
SSLContext.sslsocket_class = SSLSocket
SSLContext.sslobject_class = SSLObject
1415
1416
def wrap_socket(sock, keyfile=None, certfile=None,
                server_side=False, cert_reqs=CERT_NONE,
                ssl_version=PROTOCOL_TLS, ca_certs=None,
                do_handshake_on_connect=True,
                suppress_ragged_eofs=True,
                ciphers=None):
    """Wrap *sock* in an SSL socket built from a one-off SSLContext.

    Deprecated in favour of SSLContext.wrap_socket(); every call issues a
    DeprecationWarning.  A new SSLContext(*ssl_version*) gets *cert_reqs* as
    its verify mode and is optionally loaded with *ca_certs*, the
    *certfile*/*keyfile* pair and the *ciphers* string before *sock* is
    wrapped with it.

    Raises ValueError when *certfile* is missing but *server_side* is true
    or a *keyfile* was given.
    """
    warnings.warn(
        "ssl.wrap_socket() is deprecated, use SSLContext.wrap_socket()",
        category=DeprecationWarning,
        stacklevel=2
    )
    if not certfile:
        # A private key or a server role is useless without a certificate.
        if server_side:
            raise ValueError("certfile must be specified for server-side "
                             "operations")
        if keyfile:
            raise ValueError("certfile must be specified")

    ctx = SSLContext(ssl_version)
    ctx.verify_mode = cert_reqs
    if ca_certs:
        ctx.load_verify_locations(ca_certs)
    if certfile:
        ctx.load_cert_chain(certfile, keyfile)
    if ciphers:
        ctx.set_ciphers(ciphers)

    wrap_options = {
        "server_side": server_side,
        "do_handshake_on_connect": do_handshake_on_connect,
        "suppress_ragged_eofs": suppress_ragged_eofs,
    }
    return ctx.wrap_socket(sock=sock, **wrap_options)
1446
1447# some utility functions
1448
def cert_time_to_seconds(cert_time):
    """Return the time in seconds since the Epoch, given the timestring
    representing the "notBefore" or "notAfter" date from a certificate
    in ``"%b %d %H:%M:%S %Y %Z"`` strptime format (C locale).

    "notBefore" or "notAfter" dates must use UTC (RFC 5280).

    Month is one of: Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec
    UTC should be specified as GMT (see ASN1_TIME_print())

    The month is matched case-insensitively.  The result is an int.
    Raises ValueError if *cert_time* does not match the format.
    """
    from time import strptime
    from calendar import timegm

    months = (
        "Jan","Feb","Mar","Apr","May","Jun",
        "Jul","Aug","Sep","Oct","Nov","Dec"
    )
    time_format = ' %d %H:%M:%S %Y GMT' # NOTE: no month, fixed GMT
    # The month is looked up by hand instead of using strptime's %b so the
    # result does not depend on the current locale.
    try:
        month_number = months.index(cert_time[:3].title()) + 1
    except ValueError:
        # The tuple.index() failure adds nothing useful for the caller.
        raise ValueError('time data %r does not match '
                         'format "%%b%s"' % (cert_time, time_format)) from None
    # found valid month
    tt = strptime(cert_time[3:], time_format)
    # return an integer, the previous mktime()-based implementation
    # returned a float (fractional seconds are always zero here).
    # timegm() interprets the tuple as UTC.
    return timegm((tt[0], month_number) + tt[2:6])
1478
PEM_HEADER = "-----BEGIN CERTIFICATE-----"
PEM_FOOTER = "-----END CERTIFICATE-----"

def DER_cert_to_PEM_cert(der_cert_bytes):
    """Return the DER-encoded certificate *der_cert_bytes* as a PEM string.

    The base64 body is split into 64-character lines, framed by PEM_HEADER
    and PEM_FOOTER, and the result ends with a newline.
    """
    encoded = base64.standard_b64encode(der_cert_bytes).decode('ASCII')
    body_lines = [encoded[start:start + 64]
                  for start in range(0, len(encoded), 64)]
    return '\n'.join([PEM_HEADER, *body_lines, PEM_FOOTER]) + '\n'
1491
def PEM_cert_to_DER_cert(pem_cert_string):
    """Takes a certificate in ASCII PEM format and returns the
    DER-encoded version of it as a byte sequence.

    Whitespace before PEM_HEADER and after PEM_FOOTER is ignored.  Raises
    ValueError if the stripped text does not start with PEM_HEADER and end
    with PEM_FOOTER; a non-ASCII body raises UnicodeEncodeError (also a
    ValueError).
    """
    # Strip once and validate/slice the same text, so leading and trailing
    # whitespace are treated alike.
    pem = pem_cert_string.strip()
    if not pem.startswith(PEM_HEADER):
        raise ValueError("Invalid PEM encoding; must start with %s"
                         % PEM_HEADER)
    if not pem.endswith(PEM_FOOTER):
        raise ValueError("Invalid PEM encoding; must end with %s"
                         % PEM_FOOTER)
    d = pem[len(PEM_HEADER):-len(PEM_FOOTER)]
    return base64.decodebytes(d.encode('ASCII', 'strict'))
1504
def get_server_certificate(addr, ssl_version=PROTOCOL_TLS_CLIENT,
                           ca_certs=None, timeout=_GLOBAL_DEFAULT_TIMEOUT):
    """Return the PEM-encoded certificate presented by the server at *addr*.

    *addr* must be a ``(host, port)`` pair; *host* is also passed as the
    server hostname when the connection is wrapped.  When *ca_certs* is
    given, the peer certificate is required and validated against it
    (CERT_REQUIRED); otherwise CERT_NONE is used.  *ssl_version* selects the
    protocol and *timeout* applies to the TCP connection attempt.
    """
    host, _port = addr
    cert_reqs = CERT_NONE if ca_certs is None else CERT_REQUIRED
    context = _create_stdlib_context(ssl_version,
                                     cert_reqs=cert_reqs,
                                     cafile=ca_certs)
    with create_connection(addr, timeout=timeout) as raw_sock, \
            context.wrap_socket(raw_sock, server_hostname=host) as tls_sock:
        der_cert = tls_sock.getpeercert(True)
    return DER_cert_to_PEM_cert(der_cert)
1526
def get_protocol_name(protocol_code):
    """Return the name of the PROTOCOL_* constant *protocol_code*.

    Unknown codes yield ``'<unknown>'``.
    """
    try:
        return _PROTOCOL_NAMES[protocol_code]
    except KeyError:
        return '<unknown>'
1529