1# Wrapper module for _ssl, providing some additional facilities 2# implemented in Python. Written by Bill Janssen. 3 4"""This module provides some more Pythonic support for SSL. 5 6Object types: 7 8 SSLSocket -- subtype of socket.socket which does SSL over the socket 9 10Exceptions: 11 12 SSLError -- exception raised for I/O errors 13 14Functions: 15 16 cert_time_to_seconds -- convert time string used for certificate 17 notBefore and notAfter fields to integer 18 seconds past the Epoch (the time values 19 returned from time.time()) 20 21 fetch_server_certificate (HOST, PORT) -- fetch the certificate provided 22 by the server running on HOST at port PORT. No 23 validation of the certificate is performed. 24 25Integer constants: 26 27SSL_ERROR_ZERO_RETURN 28SSL_ERROR_WANT_READ 29SSL_ERROR_WANT_WRITE 30SSL_ERROR_WANT_X509_LOOKUP 31SSL_ERROR_SYSCALL 32SSL_ERROR_SSL 33SSL_ERROR_WANT_CONNECT 34 35SSL_ERROR_EOF 36SSL_ERROR_INVALID_ERROR_CODE 37 38The following group defines certificate requirements that one side is 39allowing/requiring from the other side: 40 41CERT_NONE - no certificates from the other side are required (or will 42 be looked at if provided) 43CERT_OPTIONAL - certificates are not required, but if provided will be 44 validated, and if validation fails, the connection will 45 also fail 46CERT_REQUIRED - certificates are required, and will be validated, and 47 if validation fails, the connection will also fail 48 49The following constants identify various SSL protocol variants: 50 51PROTOCOL_SSLv2 52PROTOCOL_SSLv3 53PROTOCOL_SSLv23 54PROTOCOL_TLS 55PROTOCOL_TLS_CLIENT 56PROTOCOL_TLS_SERVER 57PROTOCOL_TLSv1 58PROTOCOL_TLSv1_1 59PROTOCOL_TLSv1_2 60 61The following constants identify various SSL alert message descriptions as per 62http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-6 63 64ALERT_DESCRIPTION_CLOSE_NOTIFY 65ALERT_DESCRIPTION_UNEXPECTED_MESSAGE 66ALERT_DESCRIPTION_BAD_RECORD_MAC 67ALERT_DESCRIPTION_RECORD_OVERFLOW 
68ALERT_DESCRIPTION_DECOMPRESSION_FAILURE 69ALERT_DESCRIPTION_HANDSHAKE_FAILURE 70ALERT_DESCRIPTION_BAD_CERTIFICATE 71ALERT_DESCRIPTION_UNSUPPORTED_CERTIFICATE 72ALERT_DESCRIPTION_CERTIFICATE_REVOKED 73ALERT_DESCRIPTION_CERTIFICATE_EXPIRED 74ALERT_DESCRIPTION_CERTIFICATE_UNKNOWN 75ALERT_DESCRIPTION_ILLEGAL_PARAMETER 76ALERT_DESCRIPTION_UNKNOWN_CA 77ALERT_DESCRIPTION_ACCESS_DENIED 78ALERT_DESCRIPTION_DECODE_ERROR 79ALERT_DESCRIPTION_DECRYPT_ERROR 80ALERT_DESCRIPTION_PROTOCOL_VERSION 81ALERT_DESCRIPTION_INSUFFICIENT_SECURITY 82ALERT_DESCRIPTION_INTERNAL_ERROR 83ALERT_DESCRIPTION_USER_CANCELLED 84ALERT_DESCRIPTION_NO_RENEGOTIATION 85ALERT_DESCRIPTION_UNSUPPORTED_EXTENSION 86ALERT_DESCRIPTION_CERTIFICATE_UNOBTAINABLE 87ALERT_DESCRIPTION_UNRECOGNIZED_NAME 88ALERT_DESCRIPTION_BAD_CERTIFICATE_STATUS_RESPONSE 89ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE 90ALERT_DESCRIPTION_UNKNOWN_PSK_IDENTITY 91""" 92 93import sys 94import os 95from collections import namedtuple 96from enum import Enum as _Enum, IntEnum as _IntEnum, IntFlag as _IntFlag 97 98import _ssl # if we can't import it, let the error propagate 99 100from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION 101from _ssl import _SSLContext, MemoryBIO, SSLSession 102from _ssl import ( 103 SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError, 104 SSLSyscallError, SSLEOFError, SSLCertVerificationError 105 ) 106from _ssl import txt2obj as _txt2obj, nid2obj as _nid2obj 107from _ssl import RAND_status, RAND_add, RAND_bytes, RAND_pseudo_bytes 108try: 109 from _ssl import RAND_egd 110except ImportError: 111 # LibreSSL does not provide RAND_egd 112 pass 113 114 115from _ssl import ( 116 HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN, HAS_SSLv2, HAS_SSLv3, HAS_TLSv1, 117 HAS_TLSv1_1, HAS_TLSv1_2, HAS_TLSv1_3 118) 119from _ssl import _DEFAULT_CIPHERS, _OPENSSL_API_VERSION 120 121 122_IntEnum._convert( 123 '_SSLMethod', __name__, 124 lambda name: name.startswith('PROTOCOL_') and name != 
'PROTOCOL_SSLv23', 125 source=_ssl) 126 127_IntFlag._convert( 128 'Options', __name__, 129 lambda name: name.startswith('OP_'), 130 source=_ssl) 131 132_IntEnum._convert( 133 'AlertDescription', __name__, 134 lambda name: name.startswith('ALERT_DESCRIPTION_'), 135 source=_ssl) 136 137_IntEnum._convert( 138 'SSLErrorNumber', __name__, 139 lambda name: name.startswith('SSL_ERROR_'), 140 source=_ssl) 141 142_IntFlag._convert( 143 'VerifyFlags', __name__, 144 lambda name: name.startswith('VERIFY_'), 145 source=_ssl) 146 147_IntEnum._convert( 148 'VerifyMode', __name__, 149 lambda name: name.startswith('CERT_'), 150 source=_ssl) 151 152PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_TLS 153_PROTOCOL_NAMES = {value: name for name, value in _SSLMethod.__members__.items()} 154 155_SSLv2_IF_EXISTS = getattr(_SSLMethod, 'PROTOCOL_SSLv2', None) 156 157 158class TLSVersion(_IntEnum): 159 MINIMUM_SUPPORTED = _ssl.PROTO_MINIMUM_SUPPORTED 160 SSLv3 = _ssl.PROTO_SSLv3 161 TLSv1 = _ssl.PROTO_TLSv1 162 TLSv1_1 = _ssl.PROTO_TLSv1_1 163 TLSv1_2 = _ssl.PROTO_TLSv1_2 164 TLSv1_3 = _ssl.PROTO_TLSv1_3 165 MAXIMUM_SUPPORTED = _ssl.PROTO_MAXIMUM_SUPPORTED 166 167 168if sys.platform == "win32": 169 from _ssl import enum_certificates, enum_crls 170 171from socket import socket, AF_INET, SOCK_STREAM, create_connection 172from socket import SOL_SOCKET, SO_TYPE 173import socket as _socket 174import base64 # for DER-to-PEM translation 175import errno 176import warnings 177 178 179socket_error = OSError # keep that public name in module namespace 180 181CHANNEL_BINDING_TYPES = ['tls-unique'] 182 183HAS_NEVER_CHECK_COMMON_NAME = hasattr(_ssl, 'HOSTFLAG_NEVER_CHECK_SUBJECT') 184 185 186_RESTRICTED_SERVER_CIPHERS = _DEFAULT_CIPHERS 187 188CertificateError = SSLCertVerificationError 189 190 191def _dnsname_match(dn, hostname): 192 """Matching according to RFC 6125, section 6.4.3 193 194 - Hostnames are compared lower case. 
195 - For IDNA, both dn and hostname must be encoded as IDN A-label (ACE). 196 - Partial wildcards like 'www*.example.org', multiple wildcards, sole 197 wildcard or wildcards in labels other then the left-most label are not 198 supported and a CertificateError is raised. 199 - A wildcard must match at least one character. 200 """ 201 if not dn: 202 return False 203 204 wildcards = dn.count('*') 205 # speed up common case w/o wildcards 206 if not wildcards: 207 return dn.lower() == hostname.lower() 208 209 if wildcards > 1: 210 raise CertificateError( 211 "too many wildcards in certificate DNS name: {!r}.".format(dn)) 212 213 dn_leftmost, sep, dn_remainder = dn.partition('.') 214 215 if '*' in dn_remainder: 216 # Only match wildcard in leftmost segment. 217 raise CertificateError( 218 "wildcard can only be present in the leftmost label: " 219 "{!r}.".format(dn)) 220 221 if not sep: 222 # no right side 223 raise CertificateError( 224 "sole wildcard without additional labels are not support: " 225 "{!r}.".format(dn)) 226 227 if dn_leftmost != '*': 228 # no partial wildcard matching 229 raise CertificateError( 230 "partial wildcards in leftmost label are not supported: " 231 "{!r}.".format(dn)) 232 233 hostname_leftmost, sep, hostname_remainder = hostname.partition('.') 234 if not hostname_leftmost or not sep: 235 # wildcard must match at least one char 236 return False 237 return dn_remainder.lower() == hostname_remainder.lower() 238 239 240def _inet_paton(ipname): 241 """Try to convert an IP address to packed binary form 242 243 Supports IPv4 addresses on all platforms and IPv6 on platforms with IPv6 244 support. 
245 """ 246 # inet_aton() also accepts strings like '1' 247 if ipname.count('.') == 3: 248 try: 249 return _socket.inet_aton(ipname) 250 except OSError: 251 pass 252 253 try: 254 return _socket.inet_pton(_socket.AF_INET6, ipname) 255 except OSError: 256 raise ValueError("{!r} is neither an IPv4 nor an IP6 " 257 "address.".format(ipname)) 258 except AttributeError: 259 # AF_INET6 not available 260 pass 261 262 raise ValueError("{!r} is not an IPv4 address.".format(ipname)) 263 264 265def _ipaddress_match(ipname, host_ip): 266 """Exact matching of IP addresses. 267 268 RFC 6125 explicitly doesn't define an algorithm for this 269 (section 1.7.2 - "Out of Scope"). 270 """ 271 # OpenSSL may add a trailing newline to a subjectAltName's IP address 272 ip = _inet_paton(ipname.rstrip()) 273 return ip == host_ip 274 275 276def match_hostname(cert, hostname): 277 """Verify that *cert* (in decoded format as returned by 278 SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 and RFC 6125 279 rules are followed. 280 281 The function matches IP addresses rather than dNSNames if hostname is a 282 valid ipaddress string. IPv4 addresses are supported on all platforms. 283 IPv6 addresses are supported on platforms with IPv6 support (AF_INET6 284 and inet_pton). 285 286 CertificateError is raised on failure. On success, the function 287 returns nothing. 
288 """ 289 if not cert: 290 raise ValueError("empty or no certificate, match_hostname needs a " 291 "SSL socket or SSL context with either " 292 "CERT_OPTIONAL or CERT_REQUIRED") 293 try: 294 host_ip = _inet_paton(hostname) 295 except ValueError: 296 # Not an IP address (common case) 297 host_ip = None 298 dnsnames = [] 299 san = cert.get('subjectAltName', ()) 300 for key, value in san: 301 if key == 'DNS': 302 if host_ip is None and _dnsname_match(value, hostname): 303 return 304 dnsnames.append(value) 305 elif key == 'IP Address': 306 if host_ip is not None and _ipaddress_match(value, host_ip): 307 return 308 dnsnames.append(value) 309 if not dnsnames: 310 # The subject is only checked when there is no dNSName entry 311 # in subjectAltName 312 for sub in cert.get('subject', ()): 313 for key, value in sub: 314 # XXX according to RFC 2818, the most specific Common Name 315 # must be used. 316 if key == 'commonName': 317 if _dnsname_match(value, hostname): 318 return 319 dnsnames.append(value) 320 if len(dnsnames) > 1: 321 raise CertificateError("hostname %r " 322 "doesn't match either of %s" 323 % (hostname, ', '.join(map(repr, dnsnames)))) 324 elif len(dnsnames) == 1: 325 raise CertificateError("hostname %r " 326 "doesn't match %r" 327 % (hostname, dnsnames[0])) 328 else: 329 raise CertificateError("no appropriate commonName or " 330 "subjectAltName fields were found") 331 332 333DefaultVerifyPaths = namedtuple("DefaultVerifyPaths", 334 "cafile capath openssl_cafile_env openssl_cafile openssl_capath_env " 335 "openssl_capath") 336 337def get_default_verify_paths(): 338 """Return paths to default cafile and capath. 
339 """ 340 parts = _ssl.get_default_verify_paths() 341 342 # environment vars shadow paths 343 cafile = os.environ.get(parts[0], parts[1]) 344 capath = os.environ.get(parts[2], parts[3]) 345 346 return DefaultVerifyPaths(cafile if os.path.isfile(cafile) else None, 347 capath if os.path.isdir(capath) else None, 348 *parts) 349 350 351class _ASN1Object(namedtuple("_ASN1Object", "nid shortname longname oid")): 352 """ASN.1 object identifier lookup 353 """ 354 __slots__ = () 355 356 def __new__(cls, oid): 357 return super().__new__(cls, *_txt2obj(oid, name=False)) 358 359 @classmethod 360 def fromnid(cls, nid): 361 """Create _ASN1Object from OpenSSL numeric ID 362 """ 363 return super().__new__(cls, *_nid2obj(nid)) 364 365 @classmethod 366 def fromname(cls, name): 367 """Create _ASN1Object from short name, long name or OID 368 """ 369 return super().__new__(cls, *_txt2obj(name, name=True)) 370 371 372class Purpose(_ASN1Object, _Enum): 373 """SSLContext purpose flags with X509v3 Extended Key Usage objects 374 """ 375 SERVER_AUTH = '1.3.6.1.5.5.7.3.1' 376 CLIENT_AUTH = '1.3.6.1.5.5.7.3.2' 377 378 379class SSLContext(_SSLContext): 380 """An SSLContext holds various SSL-related configuration options and 381 data, such as certificates and possibly a private key.""" 382 _windows_cert_stores = ("CA", "ROOT") 383 384 sslsocket_class = None # SSLSocket is assigned later. 385 sslobject_class = None # SSLObject is assigned later. 
386 387 def __new__(cls, protocol=PROTOCOL_TLS, *args, **kwargs): 388 self = _SSLContext.__new__(cls, protocol) 389 return self 390 391 def _encode_hostname(self, hostname): 392 if hostname is None: 393 return None 394 elif isinstance(hostname, str): 395 return hostname.encode('idna').decode('ascii') 396 else: 397 return hostname.decode('ascii') 398 399 def wrap_socket(self, sock, server_side=False, 400 do_handshake_on_connect=True, 401 suppress_ragged_eofs=True, 402 server_hostname=None, session=None): 403 # SSLSocket class handles server_hostname encoding before it calls 404 # ctx._wrap_socket() 405 return self.sslsocket_class._create( 406 sock=sock, 407 server_side=server_side, 408 do_handshake_on_connect=do_handshake_on_connect, 409 suppress_ragged_eofs=suppress_ragged_eofs, 410 server_hostname=server_hostname, 411 context=self, 412 session=session 413 ) 414 415 def wrap_bio(self, incoming, outgoing, server_side=False, 416 server_hostname=None, session=None): 417 # Need to encode server_hostname here because _wrap_bio() can only 418 # handle ASCII str. 
419 return self.sslobject_class._create( 420 incoming, outgoing, server_side=server_side, 421 server_hostname=self._encode_hostname(server_hostname), 422 session=session, context=self, 423 ) 424 425 def set_npn_protocols(self, npn_protocols): 426 protos = bytearray() 427 for protocol in npn_protocols: 428 b = bytes(protocol, 'ascii') 429 if len(b) == 0 or len(b) > 255: 430 raise SSLError('NPN protocols must be 1 to 255 in length') 431 protos.append(len(b)) 432 protos.extend(b) 433 434 self._set_npn_protocols(protos) 435 436 def set_servername_callback(self, server_name_callback): 437 if server_name_callback is None: 438 self.sni_callback = None 439 else: 440 if not callable(server_name_callback): 441 raise TypeError("not a callable object") 442 443 def shim_cb(sslobj, servername, sslctx): 444 servername = self._encode_hostname(servername) 445 return server_name_callback(sslobj, servername, sslctx) 446 447 self.sni_callback = shim_cb 448 449 def set_alpn_protocols(self, alpn_protocols): 450 protos = bytearray() 451 for protocol in alpn_protocols: 452 b = bytes(protocol, 'ascii') 453 if len(b) == 0 or len(b) > 255: 454 raise SSLError('ALPN protocols must be 1 to 255 in length') 455 protos.append(len(b)) 456 protos.extend(b) 457 458 self._set_alpn_protocols(protos) 459 460 def _load_windows_store_certs(self, storename, purpose): 461 certs = bytearray() 462 try: 463 for cert, encoding, trust in enum_certificates(storename): 464 # CA certs are never PKCS#7 encoded 465 if encoding == "x509_asn": 466 if trust is True or purpose.oid in trust: 467 certs.extend(cert) 468 except PermissionError: 469 warnings.warn("unable to enumerate Windows certificate store") 470 if certs: 471 self.load_verify_locations(cadata=certs) 472 return certs 473 474 def load_default_certs(self, purpose=Purpose.SERVER_AUTH): 475 if not isinstance(purpose, _ASN1Object): 476 raise TypeError(purpose) 477 if sys.platform == "win32": 478 for storename in self._windows_cert_stores: 479 
self._load_windows_store_certs(storename, purpose) 480 self.set_default_verify_paths() 481 482 if hasattr(_SSLContext, 'minimum_version'): 483 @property 484 def minimum_version(self): 485 return TLSVersion(super().minimum_version) 486 487 @minimum_version.setter 488 def minimum_version(self, value): 489 if value == TLSVersion.SSLv3: 490 self.options &= ~Options.OP_NO_SSLv3 491 super(SSLContext, SSLContext).minimum_version.__set__(self, value) 492 493 @property 494 def maximum_version(self): 495 return TLSVersion(super().maximum_version) 496 497 @maximum_version.setter 498 def maximum_version(self, value): 499 super(SSLContext, SSLContext).maximum_version.__set__(self, value) 500 501 @property 502 def options(self): 503 return Options(super().options) 504 505 @options.setter 506 def options(self, value): 507 super(SSLContext, SSLContext).options.__set__(self, value) 508 509 if hasattr(_ssl, 'HOSTFLAG_NEVER_CHECK_SUBJECT'): 510 @property 511 def hostname_checks_common_name(self): 512 ncs = self._host_flags & _ssl.HOSTFLAG_NEVER_CHECK_SUBJECT 513 return ncs != _ssl.HOSTFLAG_NEVER_CHECK_SUBJECT 514 515 @hostname_checks_common_name.setter 516 def hostname_checks_common_name(self, value): 517 if value: 518 self._host_flags &= ~_ssl.HOSTFLAG_NEVER_CHECK_SUBJECT 519 else: 520 self._host_flags |= _ssl.HOSTFLAG_NEVER_CHECK_SUBJECT 521 else: 522 @property 523 def hostname_checks_common_name(self): 524 return True 525 526 @property 527 def protocol(self): 528 return _SSLMethod(super().protocol) 529 530 @property 531 def verify_flags(self): 532 return VerifyFlags(super().verify_flags) 533 534 @verify_flags.setter 535 def verify_flags(self, value): 536 super(SSLContext, SSLContext).verify_flags.__set__(self, value) 537 538 @property 539 def verify_mode(self): 540 value = super().verify_mode 541 try: 542 return VerifyMode(value) 543 except ValueError: 544 return value 545 546 @verify_mode.setter 547 def verify_mode(self, value): 548 super(SSLContext, 
SSLContext).verify_mode.__set__(self, value) 549 550 551def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None, 552 capath=None, cadata=None): 553 """Create a SSLContext object with default settings. 554 555 NOTE: The protocol and settings may change anytime without prior 556 deprecation. The values represent a fair balance between maximum 557 compatibility and security. 558 """ 559 if not isinstance(purpose, _ASN1Object): 560 raise TypeError(purpose) 561 562 # SSLContext sets OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_COMPRESSION, 563 # OP_CIPHER_SERVER_PREFERENCE, OP_SINGLE_DH_USE and OP_SINGLE_ECDH_USE 564 # by default. 565 context = SSLContext(PROTOCOL_TLS) 566 567 if purpose == Purpose.SERVER_AUTH: 568 # verify certs and host name in client mode 569 context.verify_mode = CERT_REQUIRED 570 context.check_hostname = True 571 572 if cafile or capath or cadata: 573 context.load_verify_locations(cafile, capath, cadata) 574 elif context.verify_mode != CERT_NONE: 575 # no explicit cafile, capath or cadata but the verify mode is 576 # CERT_OPTIONAL or CERT_REQUIRED. Let's try to load default system 577 # root CA certificates for the given purpose. This may fail silently. 578 context.load_default_certs(purpose) 579 return context 580 581def _create_unverified_context(protocol=PROTOCOL_TLS, *, cert_reqs=CERT_NONE, 582 check_hostname=False, purpose=Purpose.SERVER_AUTH, 583 certfile=None, keyfile=None, 584 cafile=None, capath=None, cadata=None): 585 """Create a SSLContext object for Python stdlib modules 586 587 All Python stdlib modules shall use this function to create SSLContext 588 objects in order to keep common settings in one place. The configuration 589 is less restrict than create_default_context()'s to increase backward 590 compatibility. 
591 """ 592 if not isinstance(purpose, _ASN1Object): 593 raise TypeError(purpose) 594 595 # SSLContext sets OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_COMPRESSION, 596 # OP_CIPHER_SERVER_PREFERENCE, OP_SINGLE_DH_USE and OP_SINGLE_ECDH_USE 597 # by default. 598 context = SSLContext(protocol) 599 600 if not check_hostname: 601 context.check_hostname = False 602 if cert_reqs is not None: 603 context.verify_mode = cert_reqs 604 if check_hostname: 605 context.check_hostname = True 606 607 if keyfile and not certfile: 608 raise ValueError("certfile must be specified") 609 if certfile or keyfile: 610 context.load_cert_chain(certfile, keyfile) 611 612 # load CA root certs 613 if cafile or capath or cadata: 614 context.load_verify_locations(cafile, capath, cadata) 615 elif context.verify_mode != CERT_NONE: 616 # no explicit cafile, capath or cadata but the verify mode is 617 # CERT_OPTIONAL or CERT_REQUIRED. Let's try to load default system 618 # root CA certificates for the given purpose. This may fail silently. 619 context.load_default_certs(purpose) 620 621 return context 622 623# Used by http.client if no context is explicitly passed. 624_create_default_https_context = create_default_context 625 626 627# Backwards compatibility alias, even though it's not a public name. 628_create_stdlib_context = _create_unverified_context 629 630 631class SSLObject: 632 """This class implements an interface on top of a low-level SSL object as 633 implemented by OpenSSL. This object captures the state of an SSL connection 634 but does not provide any network IO itself. IO needs to be performed 635 through separate "BIO" objects which are OpenSSL's IO abstraction layer. 636 637 This class does not have a public constructor. Instances are returned by 638 ``SSLContext.wrap_bio``. This class is typically used by framework authors 639 that want to implement asynchronous IO for SSL through memory buffers. 
640 641 When compared to ``SSLSocket``, this object lacks the following features: 642 643 * Any form of network IO, including methods such as ``recv`` and ``send``. 644 * The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery. 645 """ 646 def __init__(self, *args, **kwargs): 647 raise TypeError( 648 f"{self.__class__.__name__} does not have a public " 649 f"constructor. Instances are returned by SSLContext.wrap_bio()." 650 ) 651 652 @classmethod 653 def _create(cls, incoming, outgoing, server_side=False, 654 server_hostname=None, session=None, context=None): 655 self = cls.__new__(cls) 656 sslobj = context._wrap_bio( 657 incoming, outgoing, server_side=server_side, 658 server_hostname=server_hostname, 659 owner=self, session=session 660 ) 661 self._sslobj = sslobj 662 return self 663 664 @property 665 def context(self): 666 """The SSLContext that is currently in use.""" 667 return self._sslobj.context 668 669 @context.setter 670 def context(self, ctx): 671 self._sslobj.context = ctx 672 673 @property 674 def session(self): 675 """The SSLSession for client socket.""" 676 return self._sslobj.session 677 678 @session.setter 679 def session(self, session): 680 self._sslobj.session = session 681 682 @property 683 def session_reused(self): 684 """Was the client session reused during handshake""" 685 return self._sslobj.session_reused 686 687 @property 688 def server_side(self): 689 """Whether this is a server-side socket.""" 690 return self._sslobj.server_side 691 692 @property 693 def server_hostname(self): 694 """The currently set server hostname (for SNI), or ``None`` if no 695 server hostame is set.""" 696 return self._sslobj.server_hostname 697 698 def read(self, len=1024, buffer=None): 699 """Read up to 'len' bytes from the SSL object and return them. 700 701 If 'buffer' is provided, read into this buffer and return the number of 702 bytes read. 
703 """ 704 if buffer is not None: 705 v = self._sslobj.read(len, buffer) 706 else: 707 v = self._sslobj.read(len) 708 return v 709 710 def write(self, data): 711 """Write 'data' to the SSL object and return the number of bytes 712 written. 713 714 The 'data' argument must support the buffer interface. 715 """ 716 return self._sslobj.write(data) 717 718 def getpeercert(self, binary_form=False): 719 """Returns a formatted version of the data in the certificate provided 720 by the other end of the SSL channel. 721 722 Return None if no certificate was provided, {} if a certificate was 723 provided, but not validated. 724 """ 725 return self._sslobj.getpeercert(binary_form) 726 727 def selected_npn_protocol(self): 728 """Return the currently selected NPN protocol as a string, or ``None`` 729 if a next protocol was not negotiated or if NPN is not supported by one 730 of the peers.""" 731 if _ssl.HAS_NPN: 732 return self._sslobj.selected_npn_protocol() 733 734 def selected_alpn_protocol(self): 735 """Return the currently selected ALPN protocol as a string, or ``None`` 736 if a next protocol was not negotiated or if ALPN is not supported by one 737 of the peers.""" 738 if _ssl.HAS_ALPN: 739 return self._sslobj.selected_alpn_protocol() 740 741 def cipher(self): 742 """Return the currently selected cipher as a 3-tuple ``(name, 743 ssl_version, secret_bits)``.""" 744 return self._sslobj.cipher() 745 746 def shared_ciphers(self): 747 """Return a list of ciphers shared by the client during the handshake or 748 None if this is not a valid server connection. 
749 """ 750 return self._sslobj.shared_ciphers() 751 752 def compression(self): 753 """Return the current compression algorithm in use, or ``None`` if 754 compression was not negotiated or not supported by one of the peers.""" 755 return self._sslobj.compression() 756 757 def pending(self): 758 """Return the number of bytes that can be read immediately.""" 759 return self._sslobj.pending() 760 761 def do_handshake(self): 762 """Start the SSL/TLS handshake.""" 763 self._sslobj.do_handshake() 764 765 def unwrap(self): 766 """Start the SSL shutdown handshake.""" 767 return self._sslobj.shutdown() 768 769 def get_channel_binding(self, cb_type="tls-unique"): 770 """Get channel binding data for current connection. Raise ValueError 771 if the requested `cb_type` is not supported. Return bytes of the data 772 or None if the data is not available (e.g. before the handshake).""" 773 return self._sslobj.get_channel_binding(cb_type) 774 775 def version(self): 776 """Return a string identifying the protocol version used by the 777 current SSL channel. """ 778 return self._sslobj.version() 779 780 def verify_client_post_handshake(self): 781 return self._sslobj.verify_client_post_handshake() 782 783 784class SSLSocket(socket): 785 """This class implements a subtype of socket.socket that wraps 786 the underlying OS socket in an SSL context when necessary, and 787 provides read and write methods over that channel. """ 788 789 def __init__(self, *args, **kwargs): 790 raise TypeError( 791 f"{self.__class__.__name__} does not have a public " 792 f"constructor. Instances are returned by " 793 f"SSLContext.wrap_socket()." 
794 ) 795 796 @classmethod 797 def _create(cls, sock, server_side=False, do_handshake_on_connect=True, 798 suppress_ragged_eofs=True, server_hostname=None, 799 context=None, session=None): 800 if sock.getsockopt(SOL_SOCKET, SO_TYPE) != SOCK_STREAM: 801 raise NotImplementedError("only stream sockets are supported") 802 if server_side: 803 if server_hostname: 804 raise ValueError("server_hostname can only be specified " 805 "in client mode") 806 if session is not None: 807 raise ValueError("session can only be specified in " 808 "client mode") 809 if context.check_hostname and not server_hostname: 810 raise ValueError("check_hostname requires server_hostname") 811 812 kwargs = dict( 813 family=sock.family, type=sock.type, proto=sock.proto, 814 fileno=sock.fileno() 815 ) 816 self = cls.__new__(cls, **kwargs) 817 super(SSLSocket, self).__init__(**kwargs) 818 self.settimeout(sock.gettimeout()) 819 sock.detach() 820 821 self._context = context 822 self._session = session 823 self._closed = False 824 self._sslobj = None 825 self.server_side = server_side 826 self.server_hostname = context._encode_hostname(server_hostname) 827 self.do_handshake_on_connect = do_handshake_on_connect 828 self.suppress_ragged_eofs = suppress_ragged_eofs 829 830 # See if we are connected 831 try: 832 self.getpeername() 833 except OSError as e: 834 if e.errno != errno.ENOTCONN: 835 raise 836 connected = False 837 else: 838 connected = True 839 840 self._connected = connected 841 if connected: 842 # create the SSL object 843 try: 844 self._sslobj = self._context._wrap_socket( 845 self, server_side, self.server_hostname, 846 owner=self, session=self._session, 847 ) 848 if do_handshake_on_connect: 849 timeout = self.gettimeout() 850 if timeout == 0.0: 851 # non-blocking 852 raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets") 853 self.do_handshake() 854 except (OSError, ValueError): 855 self.close() 856 raise 857 return self 858 859 @property 860 def 
context(self): 861 return self._context 862 863 @context.setter 864 def context(self, ctx): 865 self._context = ctx 866 self._sslobj.context = ctx 867 868 @property 869 def session(self): 870 """The SSLSession for client socket.""" 871 if self._sslobj is not None: 872 return self._sslobj.session 873 874 @session.setter 875 def session(self, session): 876 self._session = session 877 if self._sslobj is not None: 878 self._sslobj.session = session 879 880 @property 881 def session_reused(self): 882 """Was the client session reused during handshake""" 883 if self._sslobj is not None: 884 return self._sslobj.session_reused 885 886 def dup(self): 887 raise NotImplementedError("Can't dup() %s instances" % 888 self.__class__.__name__) 889 890 def _checkClosed(self, msg=None): 891 # raise an exception here if you wish to check for spurious closes 892 pass 893 894 def _check_connected(self): 895 if not self._connected: 896 # getpeername() will raise ENOTCONN if the socket is really 897 # not connected; note that we can be connected even without 898 # _connected being set, e.g. if connect() first returned 899 # EAGAIN. 900 self.getpeername() 901 902 def read(self, len=1024, buffer=None): 903 """Read up to LEN bytes and return them. 904 Return zero-length string on EOF.""" 905 906 self._checkClosed() 907 if self._sslobj is None: 908 raise ValueError("Read on closed or unwrapped SSL socket.") 909 try: 910 if buffer is not None: 911 return self._sslobj.read(len, buffer) 912 else: 913 return self._sslobj.read(len) 914 except SSLError as x: 915 if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs: 916 if buffer is not None: 917 return 0 918 else: 919 return b'' 920 else: 921 raise 922 923 def write(self, data): 924 """Write DATA to the underlying SSL channel. 
Returns 925 number of bytes of DATA actually transmitted.""" 926 927 self._checkClosed() 928 if self._sslobj is None: 929 raise ValueError("Write on closed or unwrapped SSL socket.") 930 return self._sslobj.write(data) 931 932 def getpeercert(self, binary_form=False): 933 """Returns a formatted version of the data in the 934 certificate provided by the other end of the SSL channel. 935 Return None if no certificate was provided, {} if a 936 certificate was provided, but not validated.""" 937 938 self._checkClosed() 939 self._check_connected() 940 return self._sslobj.getpeercert(binary_form) 941 942 def selected_npn_protocol(self): 943 self._checkClosed() 944 if self._sslobj is None or not _ssl.HAS_NPN: 945 return None 946 else: 947 return self._sslobj.selected_npn_protocol() 948 949 def selected_alpn_protocol(self): 950 self._checkClosed() 951 if self._sslobj is None or not _ssl.HAS_ALPN: 952 return None 953 else: 954 return self._sslobj.selected_alpn_protocol() 955 956 def cipher(self): 957 self._checkClosed() 958 if self._sslobj is None: 959 return None 960 else: 961 return self._sslobj.cipher() 962 963 def shared_ciphers(self): 964 self._checkClosed() 965 if self._sslobj is None: 966 return None 967 else: 968 return self._sslobj.shared_ciphers() 969 970 def compression(self): 971 self._checkClosed() 972 if self._sslobj is None: 973 return None 974 else: 975 return self._sslobj.compression() 976 977 def send(self, data, flags=0): 978 self._checkClosed() 979 if self._sslobj is not None: 980 if flags != 0: 981 raise ValueError( 982 "non-zero flags not allowed in calls to send() on %s" % 983 self.__class__) 984 return self._sslobj.write(data) 985 else: 986 return super().send(data, flags) 987 988 def sendto(self, data, flags_or_addr, addr=None): 989 self._checkClosed() 990 if self._sslobj is not None: 991 raise ValueError("sendto not allowed on instances of %s" % 992 self.__class__) 993 elif addr is None: 994 return super().sendto(data, flags_or_addr) 995 else: 996 
return super().sendto(data, flags_or_addr, addr) 997 998 def sendmsg(self, *args, **kwargs): 999 # Ensure programs don't send data unencrypted if they try to 1000 # use this method. 1001 raise NotImplementedError("sendmsg not allowed on instances of %s" % 1002 self.__class__) 1003 1004 def sendall(self, data, flags=0): 1005 self._checkClosed() 1006 if self._sslobj is not None: 1007 if flags != 0: 1008 raise ValueError( 1009 "non-zero flags not allowed in calls to sendall() on %s" % 1010 self.__class__) 1011 count = 0 1012 with memoryview(data) as view, view.cast("B") as byte_view: 1013 amount = len(byte_view) 1014 while count < amount: 1015 v = self.send(byte_view[count:]) 1016 count += v 1017 else: 1018 return super().sendall(data, flags) 1019 1020 def sendfile(self, file, offset=0, count=None): 1021 """Send a file, possibly by using os.sendfile() if this is a 1022 clear-text socket. Return the total number of bytes sent. 1023 """ 1024 if self._sslobj is not None: 1025 return self._sendfile_use_send(file, offset, count) 1026 else: 1027 # os.sendfile() works with plain sockets only 1028 return super().sendfile(file, offset, count) 1029 1030 def recv(self, buflen=1024, flags=0): 1031 self._checkClosed() 1032 if self._sslobj is not None: 1033 if flags != 0: 1034 raise ValueError( 1035 "non-zero flags not allowed in calls to recv() on %s" % 1036 self.__class__) 1037 return self.read(buflen) 1038 else: 1039 return super().recv(buflen, flags) 1040 1041 def recv_into(self, buffer, nbytes=None, flags=0): 1042 self._checkClosed() 1043 if buffer and (nbytes is None): 1044 nbytes = len(buffer) 1045 elif nbytes is None: 1046 nbytes = 1024 1047 if self._sslobj is not None: 1048 if flags != 0: 1049 raise ValueError( 1050 "non-zero flags not allowed in calls to recv_into() on %s" % 1051 self.__class__) 1052 return self.read(nbytes, buffer) 1053 else: 1054 return super().recv_into(buffer, nbytes, flags) 1055 1056 def recvfrom(self, buflen=1024, flags=0): 1057 self._checkClosed() 
1058 if self._sslobj is not None: 1059 raise ValueError("recvfrom not allowed on instances of %s" % 1060 self.__class__) 1061 else: 1062 return super().recvfrom(buflen, flags) 1063 1064 def recvfrom_into(self, buffer, nbytes=None, flags=0): 1065 self._checkClosed() 1066 if self._sslobj is not None: 1067 raise ValueError("recvfrom_into not allowed on instances of %s" % 1068 self.__class__) 1069 else: 1070 return super().recvfrom_into(buffer, nbytes, flags) 1071 1072 def recvmsg(self, *args, **kwargs): 1073 raise NotImplementedError("recvmsg not allowed on instances of %s" % 1074 self.__class__) 1075 1076 def recvmsg_into(self, *args, **kwargs): 1077 raise NotImplementedError("recvmsg_into not allowed on instances of " 1078 "%s" % self.__class__) 1079 1080 def pending(self): 1081 self._checkClosed() 1082 if self._sslobj is not None: 1083 return self._sslobj.pending() 1084 else: 1085 return 0 1086 1087 def shutdown(self, how): 1088 self._checkClosed() 1089 self._sslobj = None 1090 super().shutdown(how) 1091 1092 def unwrap(self): 1093 if self._sslobj: 1094 s = self._sslobj.shutdown() 1095 self._sslobj = None 1096 return s 1097 else: 1098 raise ValueError("No SSL wrapper around " + str(self)) 1099 1100 def verify_client_post_handshake(self): 1101 if self._sslobj: 1102 return self._sslobj.verify_client_post_handshake() 1103 else: 1104 raise ValueError("No SSL wrapper around " + str(self)) 1105 1106 def _real_close(self): 1107 self._sslobj = None 1108 super()._real_close() 1109 1110 def do_handshake(self, block=False): 1111 """Perform a TLS/SSL handshake.""" 1112 self._check_connected() 1113 timeout = self.gettimeout() 1114 try: 1115 if timeout == 0.0 and block: 1116 self.settimeout(None) 1117 self._sslobj.do_handshake() 1118 finally: 1119 self.settimeout(timeout) 1120 1121 def _real_connect(self, addr, connect_ex): 1122 if self.server_side: 1123 raise ValueError("can't connect in server-side mode") 1124 # Here we assume that the socket is client-side, and not 1125 # 
connected at the time of the call. We connect it, then wrap it. 1126 if self._connected or self._sslobj is not None: 1127 raise ValueError("attempt to connect already-connected SSLSocket!") 1128 self._sslobj = self.context._wrap_socket( 1129 self, False, self.server_hostname, 1130 owner=self, session=self._session 1131 ) 1132 try: 1133 if connect_ex: 1134 rc = super().connect_ex(addr) 1135 else: 1136 rc = None 1137 super().connect(addr) 1138 if not rc: 1139 self._connected = True 1140 if self.do_handshake_on_connect: 1141 self.do_handshake() 1142 return rc 1143 except (OSError, ValueError): 1144 self._sslobj = None 1145 raise 1146 1147 def connect(self, addr): 1148 """Connects to remote ADDR, and then wraps the connection in 1149 an SSL channel.""" 1150 self._real_connect(addr, False) 1151 1152 def connect_ex(self, addr): 1153 """Connects to remote ADDR, and then wraps the connection in 1154 an SSL channel.""" 1155 return self._real_connect(addr, True) 1156 1157 def accept(self): 1158 """Accepts a new connection from a remote client, and returns 1159 a tuple containing that new connection wrapped with a server-side 1160 SSL channel, and the address of the remote client.""" 1161 1162 newsock, addr = super().accept() 1163 newsock = self.context.wrap_socket(newsock, 1164 do_handshake_on_connect=self.do_handshake_on_connect, 1165 suppress_ragged_eofs=self.suppress_ragged_eofs, 1166 server_side=True) 1167 return newsock, addr 1168 1169 def get_channel_binding(self, cb_type="tls-unique"): 1170 """Get channel binding data for current connection. Raise ValueError 1171 if the requested `cb_type` is not supported. Return bytes of the data 1172 or None if the data is not available (e.g. before the handshake). 
1173 """ 1174 if self._sslobj is not None: 1175 return self._sslobj.get_channel_binding(cb_type) 1176 else: 1177 if cb_type not in CHANNEL_BINDING_TYPES: 1178 raise ValueError( 1179 "{0} channel binding type not implemented".format(cb_type) 1180 ) 1181 return None 1182 1183 def version(self): 1184 """ 1185 Return a string identifying the protocol version used by the 1186 current SSL channel, or None if there is no established channel. 1187 """ 1188 if self._sslobj is not None: 1189 return self._sslobj.version() 1190 else: 1191 return None 1192 1193 1194# Python does not support forward declaration of types. 1195SSLContext.sslsocket_class = SSLSocket 1196SSLContext.sslobject_class = SSLObject 1197 1198 1199def wrap_socket(sock, keyfile=None, certfile=None, 1200 server_side=False, cert_reqs=CERT_NONE, 1201 ssl_version=PROTOCOL_TLS, ca_certs=None, 1202 do_handshake_on_connect=True, 1203 suppress_ragged_eofs=True, 1204 ciphers=None): 1205 1206 if server_side and not certfile: 1207 raise ValueError("certfile must be specified for server-side " 1208 "operations") 1209 if keyfile and not certfile: 1210 raise ValueError("certfile must be specified") 1211 context = SSLContext(ssl_version) 1212 context.verify_mode = cert_reqs 1213 if ca_certs: 1214 context.load_verify_locations(ca_certs) 1215 if certfile: 1216 context.load_cert_chain(certfile, keyfile) 1217 if ciphers: 1218 context.set_ciphers(ciphers) 1219 return context.wrap_socket( 1220 sock=sock, server_side=server_side, 1221 do_handshake_on_connect=do_handshake_on_connect, 1222 suppress_ragged_eofs=suppress_ragged_eofs 1223 ) 1224 1225# some utility functions 1226 1227def cert_time_to_seconds(cert_time): 1228 """Return the time in seconds since the Epoch, given the timestring 1229 representing the "notBefore" or "notAfter" date from a certificate 1230 in ``"%b %d %H:%M:%S %Y %Z"`` strptime format (C locale). 1231 1232 "notBefore" or "notAfter" dates must use UTC (RFC 5280). 
1233 1234 Month is one of: Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec 1235 UTC should be specified as GMT (see ASN1_TIME_print()) 1236 """ 1237 from time import strptime 1238 from calendar import timegm 1239 1240 months = ( 1241 "Jan","Feb","Mar","Apr","May","Jun", 1242 "Jul","Aug","Sep","Oct","Nov","Dec" 1243 ) 1244 time_format = ' %d %H:%M:%S %Y GMT' # NOTE: no month, fixed GMT 1245 try: 1246 month_number = months.index(cert_time[:3].title()) + 1 1247 except ValueError: 1248 raise ValueError('time data %r does not match ' 1249 'format "%%b%s"' % (cert_time, time_format)) 1250 else: 1251 # found valid month 1252 tt = strptime(cert_time[3:], time_format) 1253 # return an integer, the previous mktime()-based implementation 1254 # returned a float (fractional seconds are always zero here). 1255 return timegm((tt[0], month_number) + tt[2:6]) 1256 1257PEM_HEADER = "-----BEGIN CERTIFICATE-----" 1258PEM_FOOTER = "-----END CERTIFICATE-----" 1259 1260def DER_cert_to_PEM_cert(der_cert_bytes): 1261 """Takes a certificate in binary DER format and returns the 1262 PEM version of it as a string.""" 1263 1264 f = str(base64.standard_b64encode(der_cert_bytes), 'ASCII', 'strict') 1265 ss = [PEM_HEADER] 1266 ss += [f[i:i+64] for i in range(0, len(f), 64)] 1267 ss.append(PEM_FOOTER + '\n') 1268 return '\n'.join(ss) 1269 1270def PEM_cert_to_DER_cert(pem_cert_string): 1271 """Takes a certificate in ASCII PEM format and returns the 1272 DER-encoded version of it as a byte sequence""" 1273 1274 if not pem_cert_string.startswith(PEM_HEADER): 1275 raise ValueError("Invalid PEM encoding; must start with %s" 1276 % PEM_HEADER) 1277 if not pem_cert_string.strip().endswith(PEM_FOOTER): 1278 raise ValueError("Invalid PEM encoding; must end with %s" 1279 % PEM_FOOTER) 1280 d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)] 1281 return base64.decodebytes(d.encode('ASCII', 'strict')) 1282 1283def get_server_certificate(addr, ssl_version=PROTOCOL_TLS, ca_certs=None): 1284 
"""Retrieve the certificate from the server at the specified address, 1285 and return it as a PEM-encoded string. 1286 If 'ca_certs' is specified, validate the server cert against it. 1287 If 'ssl_version' is specified, use it in the connection attempt.""" 1288 1289 host, port = addr 1290 if ca_certs is not None: 1291 cert_reqs = CERT_REQUIRED 1292 else: 1293 cert_reqs = CERT_NONE 1294 context = _create_stdlib_context(ssl_version, 1295 cert_reqs=cert_reqs, 1296 cafile=ca_certs) 1297 with create_connection(addr) as sock: 1298 with context.wrap_socket(sock) as sslsock: 1299 dercert = sslsock.getpeercert(True) 1300 return DER_cert_to_PEM_cert(dercert) 1301 1302def get_protocol_name(protocol_code): 1303 return _PROTOCOL_NAMES.get(protocol_code, '<unknown>') 1304