• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Test the support for SSL and sockets
2
3import sys
4import unittest
5import unittest.mock
6from test import support
7from test.support import socket_helper
8import socket
9import select
10import time
11import datetime
12import gc
13import os
14import errno
15import pprint
16import urllib.request
17import threading
18import traceback
19import asyncore
20import weakref
21import platform
22import sysconfig
23import functools
24try:
25    import ctypes
26except ImportError:
27    ctypes = None
28
29ssl = support.import_module("ssl")
30
31from ssl import TLSVersion, _TLSContentType, _TLSMessageType
32
# True on a debug (Py_DEBUG) build of CPython, where sys.gettotalrefcount()
# exists; Py_DEBUG_WIN32 narrows that to Windows debug builds.
Py_DEBUG = hasattr(sys, 'gettotalrefcount')
Py_DEBUG_WIN32 = Py_DEBUG and sys.platform == 'win32'

# Sorted keys of the ssl module's private protocol-name table.
PROTOCOLS = sorted(ssl._PROTOCOL_NAMES)
HOST = socket_helper.HOST
# Identify the linked TLS library from its version string / version tuple.
IS_LIBRESSL = ssl.OPENSSL_VERSION.startswith('LibreSSL')
IS_OPENSSL_1_1_0 = not IS_LIBRESSL and ssl.OPENSSL_VERSION_INFO >= (1, 1, 0)
IS_OPENSSL_1_1_1 = not IS_LIBRESSL and ssl.OPENSSL_VERSION_INFO >= (1, 1, 1)
# Build configuration value PY_SSL_DEFAULT_CIPHERS (None when it is not set).
PY_SSL_DEFAULT_CIPHERS = sysconfig.get_config_var('PY_SSL_DEFAULT_CIPHERS')

# Map each legacy PROTOCOL_* constant that this ssl build defines to a
# ssl.TLSVersion member.  Pairs whose PROTOCOL_* or TLSVersion attribute is
# missing are skipped.  NOTE(review): PROTOCOL_SSLv23 maps to SSLv3,
# presumably its lowest negotiable version -- confirm against the callers.
# The loop variables proto/ver stay bound at module level after the loop.
PROTOCOL_TO_TLS_VERSION = {}
for proto, ver in (
    ("PROTOCOL_SSLv23", "SSLv3"),
    ("PROTOCOL_TLSv1", "TLSv1"),
    ("PROTOCOL_TLSv1_1", "TLSv1_1"),
):
    try:
        # Resolve both names to the real enum members; either may be absent.
        proto = getattr(ssl, proto)
        ver = getattr(ssl.TLSVersion, ver)
    except AttributeError:
        continue
    PROTOCOL_TO_TLS_VERSION[proto] = ver
55
def data_file(*name):
    """Return the path of a test data file stored beside this module.

    Each positional argument is one path component, joined below the
    directory that contains this source file.
    """
    here = os.path.dirname(__file__)
    return os.path.join(here, *name)
58
# The custom key and certificate files used in test_ssl are generated
# using Lib/test/make_ssl_certs.py.
# Other certificates are simply fetched from the Internet servers they
# are meant to authenticate.

# Key and certificate in one file (per the file name), plus the certificate
# and key as separate files; BYTES_* are the same paths, fsencoded to bytes.
CERTFILE = data_file("keycert.pem")
BYTES_CERTFILE = os.fsencode(CERTFILE)
ONLYCERT = data_file("ssl_cert.pem")
ONLYKEY = data_file("ssl_key.pem")
BYTES_ONLYCERT = os.fsencode(ONLYCERT)
BYTES_ONLYKEY = os.fsencode(ONLYKEY)
# Password-protected variants; KEY_PASSWORD is presumably their passphrase
# (inferred from the names -- confirm against make_ssl_certs.py).
CERTFILE_PROTECTED = data_file("keycert.passwd.pem")
ONLYKEY_PROTECTED = data_file("ssl_key.passwd.pem")
KEY_PASSWORD = "somepass"
# CA certificate directory, and two individual files inside it.
CAPATH = data_file("capath")
BYTES_CAPATH = os.fsencode(CAPATH)
CAFILE_NEURONIO = data_file("capath", "4e1295a3.0")
CAFILE_CACERT = data_file("capath", "5ed36f99.0")

# Expected ssl._ssl._test_decode_cert() result for CERTFILE.  Issuer and
# subject are identical, i.e. the certificate is self-signed.
CERTFILE_INFO = {
    'issuer': ((('countryName', 'XY'),),
               (('localityName', 'Castle Anthrax'),),
               (('organizationName', 'Python Software Foundation'),),
               (('commonName', 'localhost'),)),
    'notAfter': 'Aug 26 14:23:15 2028 GMT',
    'notBefore': 'Aug 29 14:23:15 2018 GMT',
    'serialNumber': '98A7CF88C74A32ED',
    'subject': ((('countryName', 'XY'),),
             (('localityName', 'Castle Anthrax'),),
             (('organizationName', 'Python Software Foundation'),),
             (('commonName', 'localhost'),)),
    'subjectAltName': (('DNS', 'localhost'),),
    'version': 3
}

# empty CRL
CRLFILE = data_file("revocation.crl")

# Two keys and certs signed by the same CA (for SNI tests).
# Each *_HOSTNAME is the host name paired with its certificate; for
# SIGNED_CERTFILE it matches the subject commonName in SIGNED_CERTFILE_INFO.
SIGNED_CERTFILE = data_file("keycert3.pem")
SIGNED_CERTFILE_HOSTNAME = 'localhost'

# Expected ssl._ssl._test_decode_cert() result for SIGNED_CERTFILE,
# issued by 'our-ca-server'.
SIGNED_CERTFILE_INFO = {
    'OCSP': ('http://testca.pythontest.net/testca/ocsp/',),
    'caIssuers': ('http://testca.pythontest.net/testca/pycacert.cer',),
    'crlDistributionPoints': ('http://testca.pythontest.net/testca/revocation.crl',),
    'issuer': ((('countryName', 'XY'),),
            (('organizationName', 'Python Software Foundation CA'),),
            (('commonName', 'our-ca-server'),)),
    'notAfter': 'Jul  7 14:23:16 2028 GMT',
    'notBefore': 'Aug 29 14:23:16 2018 GMT',
    'serialNumber': 'CB2D80995A69525C',
    'subject': ((('countryName', 'XY'),),
             (('localityName', 'Castle Anthrax'),),
             (('organizationName', 'Python Software Foundation'),),
             (('commonName', 'localhost'),)),
    'subjectAltName': (('DNS', 'localhost'),),
    'version': 3
}

SIGNED_CERTFILE2 = data_file("keycert4.pem")
SIGNED_CERTFILE2_HOSTNAME = 'fakehostname'
SIGNED_CERTFILE_ECC = data_file("keycertecc.pem")
SIGNED_CERTFILE_ECC_HOSTNAME = 'localhost-ecc'

# Same certificate as pycacert.pem, but without extra text in file
SIGNING_CA = data_file("capath", "ceff1710.0")
# cert with all kinds of subject alt names
ALLSANFILE = data_file("allsans.pem")
IDNSANSFILE = data_file("idnsans.pem")

REMOTE_HOST = "self-signed.pythontest.net"

# Deliberately broken inputs: empty, malformed, and missing files, plus
# certificates reproducing specific parser bugs (Nokia SAN issue #13034,
# NUL bytes CVE-2013-4238, invalid CRL distribution point CVE-2019-5010).
EMPTYCERT = data_file("nullcert.pem")
BADCERT = data_file("badcert.pem")
NONEXISTINGCERT = data_file("XXXnonexisting.pem")
BADKEY = data_file("badkey.pem")
NOKIACERT = data_file("nokia.pem")
NULLBYTECERT = data_file("nullbytecert.pem")
TALOS_INVALID_CRLDP = data_file("talos-2019-0758.pem")

# Diffie-Hellman parameters file (by its name); BYTES_DHFILE is fsencoded.
DHFILE = data_file("ffdh3072.pem")
BYTES_DHFILE = os.fsencode(DHFILE)

# Not defined in all versions of OpenSSL; fall back to 0 (no option bits).
OP_NO_COMPRESSION = getattr(ssl, "OP_NO_COMPRESSION", 0)
OP_SINGLE_DH_USE = getattr(ssl, "OP_SINGLE_DH_USE", 0)
OP_SINGLE_ECDH_USE = getattr(ssl, "OP_SINGLE_ECDH_USE", 0)
OP_CIPHER_SERVER_PREFERENCE = getattr(ssl, "OP_CIPHER_SERVER_PREFERENCE", 0)
OP_ENABLE_MIDDLEBOX_COMPAT = getattr(ssl, "OP_ENABLE_MIDDLEBOX_COMPAT", 0)
149
150
def has_tls_protocol(protocol):
    """Report whether *protocol* can be used with this ssl build.

    :param protocol: an ``ssl._SSLMethod`` member, or its attribute name as
        a string starting with ``'PROTOCOL_'``
    :return: bool -- False for a name the ssl module does not define, True
        for the auto-negotiating PROTOCOL_TLS / _SERVER / _CLIENT methods,
        otherwise has_tls_version() of the protocol's fixed version.
    """
    if isinstance(protocol, str):
        assert protocol.startswith('PROTOCOL_')
        method = getattr(ssl, protocol, None)
        if method is None:
            return False
    else:
        method = protocol

    negotiating = {
        ssl.PROTOCOL_TLS,
        ssl.PROTOCOL_TLS_SERVER,
        ssl.PROTOCOL_TLS_CLIENT,
    }
    if method in negotiating:
        # These pick the best mutually supported version, so always usable.
        return True
    prefix_len = len('PROTOCOL_')
    return has_tls_version(method.name[prefix_len:])
170
171
@functools.lru_cache
def has_tls_version(version):
    """Report whether a TLS/SSL protocol version is usable.

    A version counts as usable when it is compiled in (``ssl.HAS_<name>``)
    and is not excluded by the minimum/maximum version of a default
    SSLContext, which reflects runtime crypto policy and configuration.
    Results are memoised per argument.

    :param version: TLS version name (e.g. ``"TLSv1_2"``) or ssl.TLSVersion
        member
    :return: bool
    """
    # SSLv2 was never supported and has no TLSVersion member at all.
    if version == "SSLv2":
        return False

    if isinstance(version, str):
        version = ssl.TLSVersion.__members__[version]

    # Compile-time availability, e.g. ssl.HAS_TLSv1_2.
    if not getattr(ssl, f'HAS_{version.name}'):
        return False

    # Runtime availability: the default context's range may exclude a
    # compiled-in version.  The *_SUPPORTED sentinels mean "no limit".
    ctx = ssl.SSLContext()
    lowest = getattr(ctx, 'minimum_version', ssl.TLSVersion.MINIMUM_SUPPORTED)
    highest = getattr(ctx, 'maximum_version', ssl.TLSVersion.MAXIMUM_SUPPORTED)
    too_old = lowest != ssl.TLSVersion.MINIMUM_SUPPORTED and version < lowest
    too_new = highest != ssl.TLSVersion.MAXIMUM_SUPPORTED and version > highest
    return not (too_old or too_new)
207
208
def requires_tls_version(version):
    """Build a decorator that skips a test unless *version* is usable.

    :param version: TLS version name or ssl.TLSVersion member, as accepted
        by has_tls_version()
    :return: a decorator.  The decorated test checks has_tls_version() each
        time it is called; it raises unittest.SkipTest when the version is
        unavailable, otherwise it runs and returns the original result.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kw):
            if has_tls_version(version):
                return func(*args, **kw)
            raise unittest.SkipTest(f"{version} is not available.")
        return wrapper
    return decorator
224
225
# Skip decorator for tests that need SSLContext.minimum_version.
requires_minimum_version = unittest.skipUnless(
    condition=hasattr(ssl.SSLContext, 'minimum_version'),
    reason="required OpenSSL >= 1.1.0g",
)
230
231
def handle_error(prefix):
    """Echo the exception currently being handled to stdout, in verbose mode.

    Meant to be called from inside an ``except`` block.  *prefix* is written
    immediately before the formatted traceback.
    """
    if support.verbose:
        # format_exc() concatenates the traceback lines with no separator.
        # The previous ' '.join(format_exception(...)) inserted a stray
        # space at the start of every traceback line after the first.
        sys.stdout.write(prefix + traceback.format_exc())
236
def can_clear_options():
    """Return True if ssl._OPENSSL_API_VERSION is 0.9.8m or newer."""
    minimum = (0, 9, 8, 13, 15)  # 0.9.8m: 'm' is patch letter number 13
    return ssl._OPENSSL_API_VERSION >= minimum
240
def no_sslv2_implies_sslv3_hello():
    """Return True if the loaded OpenSSL library is 0.9.7h or newer."""
    minimum = (0, 9, 7, 8, 15)  # 0.9.7h: 'h' is patch letter number 8
    return ssl.OPENSSL_VERSION_INFO >= minimum
244
def have_verify_flags():
    """Return True if the loaded OpenSSL library is 0.9.8 or newer."""
    minimum = (0, 9, 8, 0, 15)
    return ssl.OPENSSL_VERSION_INFO >= minimum
248
def _have_secp_curves():
    """Return True if ECDH is compiled in and secp384r1 can be selected.

    Probes a throw-away server context: set_ecdh_curve() raising ValueError
    means the curve is unavailable.
    """
    if ssl.HAS_ECDH:
        probe = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        try:
            probe.set_ecdh_curve("secp384r1")
        except ValueError:
            return False
        return True
    return False


HAVE_SECP_CURVES = _have_secp_curves()
262
263
def utc_offset():
    """Return the local UTC offset in seconds (local time = UTC + offset).

    The DST offset is used when the zone has DST and it is currently in
    effect.  NOTE: ignore issues like #1647654.
    """
    in_dst = time.daylight and time.localtime().tm_isdst > 0
    return -(time.altzone if in_dst else time.timezone)
269
def asn1time(cert_time):
    """Normalize a certificate time string the way OpenSSL would print it.

    Some OpenSSL versions ignore seconds (see #18207); the one handled here
    is API version 0.9.8i.  On that version the seconds are zeroed and the
    day is space-padded as ASN1_TIME_print() does.  On every other version
    *cert_time* is returned unchanged.
    """
    if ssl._OPENSSL_API_VERSION != (0, 9, 8, 9, 15):
        return cert_time

    fmt = "%b %d %H:%M:%S %Y GMT"
    truncated = datetime.datetime.strptime(cert_time, fmt).replace(second=0)
    result = truncated.strftime(fmt)
    # strftime's %d zero-pads the day; ASN1_TIME_print() pads with a space.
    if result[4] == "0":
        result = result[:4] + " " + result[5:]
    return result
283
# Decorator: skip the decorated test when the ssl build lacks SNI support.
needs_sni = unittest.skipUnless(
    ssl.HAS_SNI,
    "SNI support needed for this test",
)
285
286
287def test_wrap_socket(sock, ssl_version=ssl.PROTOCOL_TLS, *,
288                     cert_reqs=ssl.CERT_NONE, ca_certs=None,
289                     ciphers=None, certfile=None, keyfile=None,
290                     **kwargs):
291    context = ssl.SSLContext(ssl_version)
292    if cert_reqs is not None:
293        if cert_reqs == ssl.CERT_NONE:
294            context.check_hostname = False
295        context.verify_mode = cert_reqs
296    if ca_certs is not None:
297        context.load_verify_locations(ca_certs)
298    if certfile is not None or keyfile is not None:
299        context.load_cert_chain(certfile, keyfile)
300    if ciphers is not None:
301        context.set_ciphers(ciphers)
302    return context.wrap_socket(sock, **kwargs)
303
304
def testing_context(server_cert=SIGNED_CERTFILE):
    """Create a matching client/server context pair for handshake tests.

    Usage::

        client_context, server_context, hostname = testing_context()

    *server_cert* must be SIGNED_CERTFILE or SIGNED_CERTFILE2; anything else
    raises ValueError.  The server context loads that certificate chain, and
    both contexts trust SIGNING_CA.  *hostname* is the host name constant
    paired with the chosen certificate.
    """
    known = (
        (SIGNED_CERTFILE, SIGNED_CERTFILE_HOSTNAME),
        (SIGNED_CERTFILE2, SIGNED_CERTFILE2_HOSTNAME),
    )
    for cert, name in known:
        if server_cert == cert:
            hostname = name
            break
    else:
        raise ValueError(server_cert)

    client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    client_context.load_verify_locations(SIGNING_CA)

    server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    server_context.load_cert_chain(server_cert)
    server_context.load_verify_locations(SIGNING_CA)

    return client_context, server_context, hostname
325
326
327class BasicSocketTests(unittest.TestCase):
328
329    def test_constants(self):
330        ssl.CERT_NONE
331        ssl.CERT_OPTIONAL
332        ssl.CERT_REQUIRED
333        ssl.OP_CIPHER_SERVER_PREFERENCE
334        ssl.OP_SINGLE_DH_USE
335        if ssl.HAS_ECDH:
336            ssl.OP_SINGLE_ECDH_USE
337        if ssl.OPENSSL_VERSION_INFO >= (1, 0):
338            ssl.OP_NO_COMPRESSION
339        self.assertIn(ssl.HAS_SNI, {True, False})
340        self.assertIn(ssl.HAS_ECDH, {True, False})
341        ssl.OP_NO_SSLv2
342        ssl.OP_NO_SSLv3
343        ssl.OP_NO_TLSv1
344        ssl.OP_NO_TLSv1_3
345        if ssl.OPENSSL_VERSION_INFO >= (1, 0, 1):
346            ssl.OP_NO_TLSv1_1
347            ssl.OP_NO_TLSv1_2
348        self.assertEqual(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv23)
349
350    def test_private_init(self):
351        with self.assertRaisesRegex(TypeError, "public constructor"):
352            with socket.socket() as s:
353                ssl.SSLSocket(s)
354
355    def test_str_for_enums(self):
356        # Make sure that the PROTOCOL_* constants have enum-like string
357        # reprs.
358        proto = ssl.PROTOCOL_TLS
359        self.assertEqual(str(proto), '_SSLMethod.PROTOCOL_TLS')
360        ctx = ssl.SSLContext(proto)
361        self.assertIs(ctx.protocol, proto)
362
363    def test_random(self):
364        v = ssl.RAND_status()
365        if support.verbose:
366            sys.stdout.write("\n RAND_status is %d (%s)\n"
367                             % (v, (v and "sufficient randomness") or
368                                "insufficient randomness"))
369
370        data, is_cryptographic = ssl.RAND_pseudo_bytes(16)
371        self.assertEqual(len(data), 16)
372        self.assertEqual(is_cryptographic, v == 1)
373        if v:
374            data = ssl.RAND_bytes(16)
375            self.assertEqual(len(data), 16)
376        else:
377            self.assertRaises(ssl.SSLError, ssl.RAND_bytes, 16)
378
379        # negative num is invalid
380        self.assertRaises(ValueError, ssl.RAND_bytes, -5)
381        self.assertRaises(ValueError, ssl.RAND_pseudo_bytes, -5)
382
383        if hasattr(ssl, 'RAND_egd'):
384            self.assertRaises(TypeError, ssl.RAND_egd, 1)
385            self.assertRaises(TypeError, ssl.RAND_egd, 'foo', 1)
386        ssl.RAND_add("this is a random string", 75.0)
387        ssl.RAND_add(b"this is a random bytes object", 75.0)
388        ssl.RAND_add(bytearray(b"this is a random bytearray object"), 75.0)
389
390    @unittest.skipUnless(os.name == 'posix', 'requires posix')
391    def test_random_fork(self):
392        status = ssl.RAND_status()
393        if not status:
394            self.fail("OpenSSL's PRNG has insufficient randomness")
395
396        rfd, wfd = os.pipe()
397        pid = os.fork()
398        if pid == 0:
399            try:
400                os.close(rfd)
401                child_random = ssl.RAND_pseudo_bytes(16)[0]
402                self.assertEqual(len(child_random), 16)
403                os.write(wfd, child_random)
404                os.close(wfd)
405            except BaseException:
406                os._exit(1)
407            else:
408                os._exit(0)
409        else:
410            os.close(wfd)
411            self.addCleanup(os.close, rfd)
412            support.wait_process(pid, exitcode=0)
413
414            child_random = os.read(rfd, 16)
415            self.assertEqual(len(child_random), 16)
416            parent_random = ssl.RAND_pseudo_bytes(16)[0]
417            self.assertEqual(len(parent_random), 16)
418
419            self.assertNotEqual(child_random, parent_random)
420
421    maxDiff = None
422
423    def test_parse_cert(self):
424        # note that this uses an 'unofficial' function in _ssl.c,
425        # provided solely for this test, to exercise the certificate
426        # parsing code
427        self.assertEqual(
428            ssl._ssl._test_decode_cert(CERTFILE),
429            CERTFILE_INFO
430        )
431        self.assertEqual(
432            ssl._ssl._test_decode_cert(SIGNED_CERTFILE),
433            SIGNED_CERTFILE_INFO
434        )
435
436        # Issue #13034: the subjectAltName in some certificates
437        # (notably projects.developer.nokia.com:443) wasn't parsed
438        p = ssl._ssl._test_decode_cert(NOKIACERT)
439        if support.verbose:
440            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
441        self.assertEqual(p['subjectAltName'],
442                         (('DNS', 'projects.developer.nokia.com'),
443                          ('DNS', 'projects.forum.nokia.com'))
444                        )
445        # extra OCSP and AIA fields
446        self.assertEqual(p['OCSP'], ('http://ocsp.verisign.com',))
447        self.assertEqual(p['caIssuers'],
448                         ('http://SVRIntl-G3-aia.verisign.com/SVRIntlG3.cer',))
449        self.assertEqual(p['crlDistributionPoints'],
450                         ('http://SVRIntl-G3-crl.verisign.com/SVRIntlG3.crl',))
451
452    def test_parse_cert_CVE_2019_5010(self):
453        p = ssl._ssl._test_decode_cert(TALOS_INVALID_CRLDP)
454        if support.verbose:
455            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
456        self.assertEqual(
457            p,
458            {
459                'issuer': (
460                    (('countryName', 'UK'),), (('commonName', 'cody-ca'),)),
461                'notAfter': 'Jun 14 18:00:58 2028 GMT',
462                'notBefore': 'Jun 18 18:00:58 2018 GMT',
463                'serialNumber': '02',
464                'subject': ((('countryName', 'UK'),),
465                            (('commonName',
466                              'codenomicon-vm-2.test.lal.cisco.com'),)),
467                'subjectAltName': (
468                    ('DNS', 'codenomicon-vm-2.test.lal.cisco.com'),),
469                'version': 3
470            }
471        )
472
473    def test_parse_cert_CVE_2013_4238(self):
474        p = ssl._ssl._test_decode_cert(NULLBYTECERT)
475        if support.verbose:
476            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
477        subject = ((('countryName', 'US'),),
478                   (('stateOrProvinceName', 'Oregon'),),
479                   (('localityName', 'Beaverton'),),
480                   (('organizationName', 'Python Software Foundation'),),
481                   (('organizationalUnitName', 'Python Core Development'),),
482                   (('commonName', 'null.python.org\x00example.org'),),
483                   (('emailAddress', 'python-dev@python.org'),))
484        self.assertEqual(p['subject'], subject)
485        self.assertEqual(p['issuer'], subject)
486        if ssl._OPENSSL_API_VERSION >= (0, 9, 8):
487            san = (('DNS', 'altnull.python.org\x00example.com'),
488                   ('email', 'null@python.org\x00user@example.org'),
489                   ('URI', 'http://null.python.org\x00http://example.org'),
490                   ('IP Address', '192.0.2.1'),
491                   ('IP Address', '2001:DB8:0:0:0:0:0:1'))
492        else:
493            # OpenSSL 0.9.7 doesn't support IPv6 addresses in subjectAltName
494            san = (('DNS', 'altnull.python.org\x00example.com'),
495                   ('email', 'null@python.org\x00user@example.org'),
496                   ('URI', 'http://null.python.org\x00http://example.org'),
497                   ('IP Address', '192.0.2.1'),
498                   ('IP Address', '<invalid>'))
499
500        self.assertEqual(p['subjectAltName'], san)
501
502    def test_parse_all_sans(self):
503        p = ssl._ssl._test_decode_cert(ALLSANFILE)
504        self.assertEqual(p['subjectAltName'],
505            (
506                ('DNS', 'allsans'),
507                ('othername', '<unsupported>'),
508                ('othername', '<unsupported>'),
509                ('email', 'user@example.org'),
510                ('DNS', 'www.example.org'),
511                ('DirName',
512                    ((('countryName', 'XY'),),
513                    (('localityName', 'Castle Anthrax'),),
514                    (('organizationName', 'Python Software Foundation'),),
515                    (('commonName', 'dirname example'),))),
516                ('URI', 'https://www.python.org/'),
517                ('IP Address', '127.0.0.1'),
518                ('IP Address', '0:0:0:0:0:0:0:1'),
519                ('Registered ID', '1.2.3.4.5')
520            )
521        )
522
523    def test_DER_to_PEM(self):
524        with open(CAFILE_CACERT, 'r') as f:
525            pem = f.read()
526        d1 = ssl.PEM_cert_to_DER_cert(pem)
527        p2 = ssl.DER_cert_to_PEM_cert(d1)
528        d2 = ssl.PEM_cert_to_DER_cert(p2)
529        self.assertEqual(d1, d2)
530        if not p2.startswith(ssl.PEM_HEADER + '\n'):
531            self.fail("DER-to-PEM didn't include correct header:\n%r\n" % p2)
532        if not p2.endswith('\n' + ssl.PEM_FOOTER + '\n'):
533            self.fail("DER-to-PEM didn't include correct footer:\n%r\n" % p2)
534
535    def test_openssl_version(self):
536        n = ssl.OPENSSL_VERSION_NUMBER
537        t = ssl.OPENSSL_VERSION_INFO
538        s = ssl.OPENSSL_VERSION
539        self.assertIsInstance(n, int)
540        self.assertIsInstance(t, tuple)
541        self.assertIsInstance(s, str)
542        # Some sanity checks follow
543        # >= 0.9
544        self.assertGreaterEqual(n, 0x900000)
545        # < 4.0
546        self.assertLess(n, 0x40000000)
547        major, minor, fix, patch, status = t
548        self.assertGreaterEqual(major, 1)
549        self.assertLess(major, 4)
550        self.assertGreaterEqual(minor, 0)
551        self.assertLess(minor, 256)
552        self.assertGreaterEqual(fix, 0)
553        self.assertLess(fix, 256)
554        self.assertGreaterEqual(patch, 0)
555        self.assertLessEqual(patch, 63)
556        self.assertGreaterEqual(status, 0)
557        self.assertLessEqual(status, 15)
558        # Version string as returned by {Open,Libre}SSL, the format might change
559        if IS_LIBRESSL:
560            self.assertTrue(s.startswith("LibreSSL {:d}".format(major)),
561                            (s, t, hex(n)))
562        else:
563            self.assertTrue(s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, fix)),
564                            (s, t, hex(n)))
565
566    @support.cpython_only
567    def test_refcycle(self):
568        # Issue #7943: an SSL object doesn't create reference cycles with
569        # itself.
570        s = socket.socket(socket.AF_INET)
571        ss = test_wrap_socket(s)
572        wr = weakref.ref(ss)
573        with support.check_warnings(("", ResourceWarning)):
574            del ss
575        self.assertEqual(wr(), None)
576
577    def test_wrapped_unconnected(self):
578        # Methods on an unconnected SSLSocket propagate the original
579        # OSError raise by the underlying socket object.
580        s = socket.socket(socket.AF_INET)
581        with test_wrap_socket(s) as ss:
582            self.assertRaises(OSError, ss.recv, 1)
583            self.assertRaises(OSError, ss.recv_into, bytearray(b'x'))
584            self.assertRaises(OSError, ss.recvfrom, 1)
585            self.assertRaises(OSError, ss.recvfrom_into, bytearray(b'x'), 1)
586            self.assertRaises(OSError, ss.send, b'x')
587            self.assertRaises(OSError, ss.sendto, b'x', ('0.0.0.0', 0))
588            self.assertRaises(NotImplementedError, ss.dup)
589            self.assertRaises(NotImplementedError, ss.sendmsg,
590                              [b'x'], (), 0, ('0.0.0.0', 0))
591            self.assertRaises(NotImplementedError, ss.recvmsg, 100)
592            self.assertRaises(NotImplementedError, ss.recvmsg_into,
593                              [bytearray(100)])
594
595    def test_timeout(self):
596        # Issue #8524: when creating an SSL socket, the timeout of the
597        # original socket should be retained.
598        for timeout in (None, 0.0, 5.0):
599            s = socket.socket(socket.AF_INET)
600            s.settimeout(timeout)
601            with test_wrap_socket(s) as ss:
602                self.assertEqual(timeout, ss.gettimeout())
603
604    def test_errors_sslwrap(self):
605        sock = socket.socket()
606        self.assertRaisesRegex(ValueError,
607                        "certfile must be specified",
608                        ssl.wrap_socket, sock, keyfile=CERTFILE)
609        self.assertRaisesRegex(ValueError,
610                        "certfile must be specified for server-side operations",
611                        ssl.wrap_socket, sock, server_side=True)
612        self.assertRaisesRegex(ValueError,
613                        "certfile must be specified for server-side operations",
614                         ssl.wrap_socket, sock, server_side=True, certfile="")
615        with ssl.wrap_socket(sock, server_side=True, certfile=CERTFILE) as s:
616            self.assertRaisesRegex(ValueError, "can't connect in server-side mode",
617                                     s.connect, (HOST, 8080))
618        with self.assertRaises(OSError) as cm:
619            with socket.socket() as sock:
620                ssl.wrap_socket(sock, certfile=NONEXISTINGCERT)
621        self.assertEqual(cm.exception.errno, errno.ENOENT)
622        with self.assertRaises(OSError) as cm:
623            with socket.socket() as sock:
624                ssl.wrap_socket(sock,
625                    certfile=CERTFILE, keyfile=NONEXISTINGCERT)
626        self.assertEqual(cm.exception.errno, errno.ENOENT)
627        with self.assertRaises(OSError) as cm:
628            with socket.socket() as sock:
629                ssl.wrap_socket(sock,
630                    certfile=NONEXISTINGCERT, keyfile=NONEXISTINGCERT)
631        self.assertEqual(cm.exception.errno, errno.ENOENT)
632
633    def bad_cert_test(self, certfile):
634        """Check that trying to use the given client certificate fails"""
635        certfile = os.path.join(os.path.dirname(__file__) or os.curdir,
636                                   certfile)
637        sock = socket.socket()
638        self.addCleanup(sock.close)
639        with self.assertRaises(ssl.SSLError):
640            test_wrap_socket(sock,
641                             certfile=certfile)
642
643    def test_empty_cert(self):
644        """Wrapping with an empty cert file"""
645        self.bad_cert_test("nullcert.pem")
646
647    def test_malformed_cert(self):
648        """Wrapping with a badly formatted certificate (syntax error)"""
649        self.bad_cert_test("badcert.pem")
650
651    def test_malformed_key(self):
652        """Wrapping with a badly formatted key (syntax error)"""
653        self.bad_cert_test("badkey.pem")
654
655    def test_match_hostname(self):
656        def ok(cert, hostname):
657            ssl.match_hostname(cert, hostname)
658        def fail(cert, hostname):
659            self.assertRaises(ssl.CertificateError,
660                              ssl.match_hostname, cert, hostname)
661
662        # -- Hostname matching --
663
664        cert = {'subject': ((('commonName', 'example.com'),),)}
665        ok(cert, 'example.com')
666        ok(cert, 'ExAmple.cOm')
667        fail(cert, 'www.example.com')
668        fail(cert, '.example.com')
669        fail(cert, 'example.org')
670        fail(cert, 'exampleXcom')
671
672        cert = {'subject': ((('commonName', '*.a.com'),),)}
673        ok(cert, 'foo.a.com')
674        fail(cert, 'bar.foo.a.com')
675        fail(cert, 'a.com')
676        fail(cert, 'Xa.com')
677        fail(cert, '.a.com')
678
679        # only match wildcards when they are the only thing
680        # in left-most segment
681        cert = {'subject': ((('commonName', 'f*.com'),),)}
682        fail(cert, 'foo.com')
683        fail(cert, 'f.com')
684        fail(cert, 'bar.com')
685        fail(cert, 'foo.a.com')
686        fail(cert, 'bar.foo.com')
687
688        # NULL bytes are bad, CVE-2013-4073
689        cert = {'subject': ((('commonName',
690                              'null.python.org\x00example.org'),),)}
691        ok(cert, 'null.python.org\x00example.org') # or raise an error?
692        fail(cert, 'example.org')
693        fail(cert, 'null.python.org')
694
695        # error cases with wildcards
696        cert = {'subject': ((('commonName', '*.*.a.com'),),)}
697        fail(cert, 'bar.foo.a.com')
698        fail(cert, 'a.com')
699        fail(cert, 'Xa.com')
700        fail(cert, '.a.com')
701
702        cert = {'subject': ((('commonName', 'a.*.com'),),)}
703        fail(cert, 'a.foo.com')
704        fail(cert, 'a..com')
705        fail(cert, 'a.com')
706
707        # wildcard doesn't match IDNA prefix 'xn--'
708        idna = 'püthon.python.org'.encode("idna").decode("ascii")
709        cert = {'subject': ((('commonName', idna),),)}
710        ok(cert, idna)
711        cert = {'subject': ((('commonName', 'x*.python.org'),),)}
712        fail(cert, idna)
713        cert = {'subject': ((('commonName', 'xn--p*.python.org'),),)}
714        fail(cert, idna)
715
716        # wildcard in first fragment and  IDNA A-labels in sequent fragments
717        # are supported.
718        idna = 'www*.pythön.org'.encode("idna").decode("ascii")
719        cert = {'subject': ((('commonName', idna),),)}
720        fail(cert, 'www.pythön.org'.encode("idna").decode("ascii"))
721        fail(cert, 'www1.pythön.org'.encode("idna").decode("ascii"))
722        fail(cert, 'ftp.pythön.org'.encode("idna").decode("ascii"))
723        fail(cert, 'pythön.org'.encode("idna").decode("ascii"))
724
725        # Slightly fake real-world example
726        cert = {'notAfter': 'Jun 26 21:41:46 2011 GMT',
727                'subject': ((('commonName', 'linuxfrz.org'),),),
728                'subjectAltName': (('DNS', 'linuxfr.org'),
729                                   ('DNS', 'linuxfr.com'),
730                                   ('othername', '<unsupported>'))}
731        ok(cert, 'linuxfr.org')
732        ok(cert, 'linuxfr.com')
733        # Not a "DNS" entry
734        fail(cert, '<unsupported>')
735        # When there is a subjectAltName, commonName isn't used
736        fail(cert, 'linuxfrz.org')
737
738        # A pristine real-world example
739        cert = {'notAfter': 'Dec 18 23:59:59 2011 GMT',
740                'subject': ((('countryName', 'US'),),
741                            (('stateOrProvinceName', 'California'),),
742                            (('localityName', 'Mountain View'),),
743                            (('organizationName', 'Google Inc'),),
744                            (('commonName', 'mail.google.com'),))}
745        ok(cert, 'mail.google.com')
746        fail(cert, 'gmail.com')
747        # Only commonName is considered
748        fail(cert, 'California')
749
750        # -- IPv4 matching --
751        cert = {'subject': ((('commonName', 'example.com'),),),
752                'subjectAltName': (('DNS', 'example.com'),
753                                   ('IP Address', '10.11.12.13'),
754                                   ('IP Address', '14.15.16.17'),
755                                   ('IP Address', '127.0.0.1'))}
756        ok(cert, '10.11.12.13')
757        ok(cert, '14.15.16.17')
758        # socket.inet_ntoa(socket.inet_aton('127.1')) == '127.0.0.1'
759        fail(cert, '127.1')
760        fail(cert, '14.15.16.17 ')
761        fail(cert, '14.15.16.17 extra data')
762        fail(cert, '14.15.16.18')
763        fail(cert, 'example.net')
764
765        # -- IPv6 matching --
766        if socket_helper.IPV6_ENABLED:
767            cert = {'subject': ((('commonName', 'example.com'),),),
768                    'subjectAltName': (
769                        ('DNS', 'example.com'),
770                        ('IP Address', '2001:0:0:0:0:0:0:CAFE\n'),
771                        ('IP Address', '2003:0:0:0:0:0:0:BABA\n'))}
772            ok(cert, '2001::cafe')
773            ok(cert, '2003::baba')
774            fail(cert, '2003::baba ')
775            fail(cert, '2003::baba extra data')
776            fail(cert, '2003::bebe')
777            fail(cert, 'example.net')
778
779        # -- Miscellaneous --
780
781        # Neither commonName nor subjectAltName
782        cert = {'notAfter': 'Dec 18 23:59:59 2011 GMT',
783                'subject': ((('countryName', 'US'),),
784                            (('stateOrProvinceName', 'California'),),
785                            (('localityName', 'Mountain View'),),
786                            (('organizationName', 'Google Inc'),))}
787        fail(cert, 'mail.google.com')
788
789        # No DNS entry in subjectAltName but a commonName
790        cert = {'notAfter': 'Dec 18 23:59:59 2099 GMT',
791                'subject': ((('countryName', 'US'),),
792                            (('stateOrProvinceName', 'California'),),
793                            (('localityName', 'Mountain View'),),
794                            (('commonName', 'mail.google.com'),)),
795                'subjectAltName': (('othername', 'blabla'), )}
796        ok(cert, 'mail.google.com')
797
798        # No DNS entry subjectAltName and no commonName
799        cert = {'notAfter': 'Dec 18 23:59:59 2099 GMT',
800                'subject': ((('countryName', 'US'),),
801                            (('stateOrProvinceName', 'California'),),
802                            (('localityName', 'Mountain View'),),
803                            (('organizationName', 'Google Inc'),)),
804                'subjectAltName': (('othername', 'blabla'),)}
805        fail(cert, 'google.com')
806
807        # Empty cert / no cert
808        self.assertRaises(ValueError, ssl.match_hostname, None, 'example.com')
809        self.assertRaises(ValueError, ssl.match_hostname, {}, 'example.com')
810
811        # Issue #17980: avoid denials of service by refusing more than one
812        # wildcard per fragment.
813        cert = {'subject': ((('commonName', 'a*b.example.com'),),)}
814        with self.assertRaisesRegex(
815                ssl.CertificateError,
816                "partial wildcards in leftmost label are not supported"):
817            ssl.match_hostname(cert, 'axxb.example.com')
818
819        cert = {'subject': ((('commonName', 'www.*.example.com'),),)}
820        with self.assertRaisesRegex(
821                ssl.CertificateError,
822                "wildcard can only be present in the leftmost label"):
823            ssl.match_hostname(cert, 'www.sub.example.com')
824
825        cert = {'subject': ((('commonName', 'a*b*.example.com'),),)}
826        with self.assertRaisesRegex(
827                ssl.CertificateError,
828                "too many wildcards"):
829            ssl.match_hostname(cert, 'axxbxxc.example.com')
830
831        cert = {'subject': ((('commonName', '*'),),)}
832        with self.assertRaisesRegex(
833                ssl.CertificateError,
834                "sole wildcard without additional labels are not support"):
835            ssl.match_hostname(cert, 'host')
836
837        cert = {'subject': ((('commonName', '*.com'),),)}
838        with self.assertRaisesRegex(
839                ssl.CertificateError,
840                r"hostname 'com' doesn't match '\*.com'"):
841            ssl.match_hostname(cert, 'com')
842
843        # extra checks for _inet_paton()
844        for invalid in ['1', '', '1.2.3', '256.0.0.1', '127.0.0.1/24']:
845            with self.assertRaises(ValueError):
846                ssl._inet_paton(invalid)
847        for ipaddr in ['127.0.0.1', '192.168.0.1']:
848            self.assertTrue(ssl._inet_paton(ipaddr))
849        if socket_helper.IPV6_ENABLED:
850            for ipaddr in ['::1', '2001:db8:85a3::8a2e:370:7334']:
851                self.assertTrue(ssl._inet_paton(ipaddr))
852
853    def test_server_side(self):
854        # server_hostname doesn't work for server sockets
855        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
856        with socket.socket() as sock:
857            self.assertRaises(ValueError, ctx.wrap_socket, sock, True,
858                              server_hostname="some.hostname")
859
860    def test_unknown_channel_binding(self):
861        # should raise ValueError for unknown type
862        s = socket.create_server(('127.0.0.1', 0))
863        c = socket.socket(socket.AF_INET)
864        c.connect(s.getsockname())
865        with test_wrap_socket(c, do_handshake_on_connect=False) as ss:
866            with self.assertRaises(ValueError):
867                ss.get_channel_binding("unknown-type")
868        s.close()
869
870    @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
871                         "'tls-unique' channel binding not available")
872    def test_tls_unique_channel_binding(self):
873        # unconnected should return None for known type
874        s = socket.socket(socket.AF_INET)
875        with test_wrap_socket(s) as ss:
876            self.assertIsNone(ss.get_channel_binding("tls-unique"))
877        # the same for server-side
878        s = socket.socket(socket.AF_INET)
879        with test_wrap_socket(s, server_side=True, certfile=CERTFILE) as ss:
880            self.assertIsNone(ss.get_channel_binding("tls-unique"))
881
882    def test_dealloc_warn(self):
883        ss = test_wrap_socket(socket.socket(socket.AF_INET))
884        r = repr(ss)
885        with self.assertWarns(ResourceWarning) as cm:
886            ss = None
887            support.gc_collect()
888        self.assertIn(r, str(cm.warning.args[0]))
889
890    def test_get_default_verify_paths(self):
891        paths = ssl.get_default_verify_paths()
892        self.assertEqual(len(paths), 6)
893        self.assertIsInstance(paths, ssl.DefaultVerifyPaths)
894
895        with support.EnvironmentVarGuard() as env:
896            env["SSL_CERT_DIR"] = CAPATH
897            env["SSL_CERT_FILE"] = CERTFILE
898            paths = ssl.get_default_verify_paths()
899            self.assertEqual(paths.cafile, CERTFILE)
900            self.assertEqual(paths.capath, CAPATH)
901
902    @unittest.skipUnless(sys.platform == "win32", "Windows specific")
903    def test_enum_certificates(self):
904        self.assertTrue(ssl.enum_certificates("CA"))
905        self.assertTrue(ssl.enum_certificates("ROOT"))
906
907        self.assertRaises(TypeError, ssl.enum_certificates)
908        self.assertRaises(WindowsError, ssl.enum_certificates, "")
909
910        trust_oids = set()
911        for storename in ("CA", "ROOT"):
912            store = ssl.enum_certificates(storename)
913            self.assertIsInstance(store, list)
914            for element in store:
915                self.assertIsInstance(element, tuple)
916                self.assertEqual(len(element), 3)
917                cert, enc, trust = element
918                self.assertIsInstance(cert, bytes)
919                self.assertIn(enc, {"x509_asn", "pkcs_7_asn"})
920                self.assertIsInstance(trust, (frozenset, set, bool))
921                if isinstance(trust, (frozenset, set)):
922                    trust_oids.update(trust)
923
924        serverAuth = "1.3.6.1.5.5.7.3.1"
925        self.assertIn(serverAuth, trust_oids)
926
927    @unittest.skipUnless(sys.platform == "win32", "Windows specific")
928    def test_enum_crls(self):
929        self.assertTrue(ssl.enum_crls("CA"))
930        self.assertRaises(TypeError, ssl.enum_crls)
931        self.assertRaises(WindowsError, ssl.enum_crls, "")
932
933        crls = ssl.enum_crls("CA")
934        self.assertIsInstance(crls, list)
935        for element in crls:
936            self.assertIsInstance(element, tuple)
937            self.assertEqual(len(element), 2)
938            self.assertIsInstance(element[0], bytes)
939            self.assertIn(element[1], {"x509_asn", "pkcs_7_asn"})
940
941
942    def test_asn1object(self):
943        expected = (129, 'serverAuth', 'TLS Web Server Authentication',
944                    '1.3.6.1.5.5.7.3.1')
945
946        val = ssl._ASN1Object('1.3.6.1.5.5.7.3.1')
947        self.assertEqual(val, expected)
948        self.assertEqual(val.nid, 129)
949        self.assertEqual(val.shortname, 'serverAuth')
950        self.assertEqual(val.longname, 'TLS Web Server Authentication')
951        self.assertEqual(val.oid, '1.3.6.1.5.5.7.3.1')
952        self.assertIsInstance(val, ssl._ASN1Object)
953        self.assertRaises(ValueError, ssl._ASN1Object, 'serverAuth')
954
955        val = ssl._ASN1Object.fromnid(129)
956        self.assertEqual(val, expected)
957        self.assertIsInstance(val, ssl._ASN1Object)
958        self.assertRaises(ValueError, ssl._ASN1Object.fromnid, -1)
959        with self.assertRaisesRegex(ValueError, "unknown NID 100000"):
960            ssl._ASN1Object.fromnid(100000)
961        for i in range(1000):
962            try:
963                obj = ssl._ASN1Object.fromnid(i)
964            except ValueError:
965                pass
966            else:
967                self.assertIsInstance(obj.nid, int)
968                self.assertIsInstance(obj.shortname, str)
969                self.assertIsInstance(obj.longname, str)
970                self.assertIsInstance(obj.oid, (str, type(None)))
971
972        val = ssl._ASN1Object.fromname('TLS Web Server Authentication')
973        self.assertEqual(val, expected)
974        self.assertIsInstance(val, ssl._ASN1Object)
975        self.assertEqual(ssl._ASN1Object.fromname('serverAuth'), expected)
976        self.assertEqual(ssl._ASN1Object.fromname('1.3.6.1.5.5.7.3.1'),
977                         expected)
978        with self.assertRaisesRegex(ValueError, "unknown object 'serverauth'"):
979            ssl._ASN1Object.fromname('serverauth')
980
981    def test_purpose_enum(self):
982        val = ssl._ASN1Object('1.3.6.1.5.5.7.3.1')
983        self.assertIsInstance(ssl.Purpose.SERVER_AUTH, ssl._ASN1Object)
984        self.assertEqual(ssl.Purpose.SERVER_AUTH, val)
985        self.assertEqual(ssl.Purpose.SERVER_AUTH.nid, 129)
986        self.assertEqual(ssl.Purpose.SERVER_AUTH.shortname, 'serverAuth')
987        self.assertEqual(ssl.Purpose.SERVER_AUTH.oid,
988                              '1.3.6.1.5.5.7.3.1')
989
990        val = ssl._ASN1Object('1.3.6.1.5.5.7.3.2')
991        self.assertIsInstance(ssl.Purpose.CLIENT_AUTH, ssl._ASN1Object)
992        self.assertEqual(ssl.Purpose.CLIENT_AUTH, val)
993        self.assertEqual(ssl.Purpose.CLIENT_AUTH.nid, 130)
994        self.assertEqual(ssl.Purpose.CLIENT_AUTH.shortname, 'clientAuth')
995        self.assertEqual(ssl.Purpose.CLIENT_AUTH.oid,
996                              '1.3.6.1.5.5.7.3.2')
997
998    def test_unsupported_dtls(self):
999        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
1000        self.addCleanup(s.close)
1001        with self.assertRaises(NotImplementedError) as cx:
1002            test_wrap_socket(s, cert_reqs=ssl.CERT_NONE)
1003        self.assertEqual(str(cx.exception), "only stream sockets are supported")
1004        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1005        with self.assertRaises(NotImplementedError) as cx:
1006            ctx.wrap_socket(s)
1007        self.assertEqual(str(cx.exception), "only stream sockets are supported")
1008
1009    def cert_time_ok(self, timestring, timestamp):
1010        self.assertEqual(ssl.cert_time_to_seconds(timestring), timestamp)
1011
1012    def cert_time_fail(self, timestring):
1013        with self.assertRaises(ValueError):
1014            ssl.cert_time_to_seconds(timestring)
1015
1016    @unittest.skipUnless(utc_offset(),
1017                         'local time needs to be different from UTC')
1018    def test_cert_time_to_seconds_timezone(self):
1019        # Issue #19940: ssl.cert_time_to_seconds() returns wrong
1020        #               results if local timezone is not UTC
1021        self.cert_time_ok("May  9 00:00:00 2007 GMT", 1178668800.0)
1022        self.cert_time_ok("Jan  5 09:34:43 2018 GMT", 1515144883.0)
1023
1024    def test_cert_time_to_seconds(self):
1025        timestring = "Jan  5 09:34:43 2018 GMT"
1026        ts = 1515144883.0
1027        self.cert_time_ok(timestring, ts)
1028        # accept keyword parameter, assert its name
1029        self.assertEqual(ssl.cert_time_to_seconds(cert_time=timestring), ts)
1030        # accept both %e and %d (space or zero generated by strftime)
1031        self.cert_time_ok("Jan 05 09:34:43 2018 GMT", ts)
1032        # case-insensitive
1033        self.cert_time_ok("JaN  5 09:34:43 2018 GmT", ts)
1034        self.cert_time_fail("Jan  5 09:34 2018 GMT")     # no seconds
1035        self.cert_time_fail("Jan  5 09:34:43 2018")      # no GMT
1036        self.cert_time_fail("Jan  5 09:34:43 2018 UTC")  # not GMT timezone
1037        self.cert_time_fail("Jan 35 09:34:43 2018 GMT")  # invalid day
1038        self.cert_time_fail("Jon  5 09:34:43 2018 GMT")  # invalid month
1039        self.cert_time_fail("Jan  5 24:00:00 2018 GMT")  # invalid hour
1040        self.cert_time_fail("Jan  5 09:60:43 2018 GMT")  # invalid minute
1041
1042        newyear_ts = 1230768000.0
1043        # leap seconds
1044        self.cert_time_ok("Dec 31 23:59:60 2008 GMT", newyear_ts)
1045        # same timestamp
1046        self.cert_time_ok("Jan  1 00:00:00 2009 GMT", newyear_ts)
1047
1048        self.cert_time_ok("Jan  5 09:34:59 2018 GMT", 1515144899)
1049        #  allow 60th second (even if it is not a leap second)
1050        self.cert_time_ok("Jan  5 09:34:60 2018 GMT", 1515144900)
1051        #  allow 2nd leap second for compatibility with time.strptime()
1052        self.cert_time_ok("Jan  5 09:34:61 2018 GMT", 1515144901)
1053        self.cert_time_fail("Jan  5 09:34:62 2018 GMT")  # invalid seconds
1054
1055        # no special treatment for the special value:
1056        #   99991231235959Z (rfc 5280)
1057        self.cert_time_ok("Dec 31 23:59:59 9999 GMT", 253402300799.0)
1058
1059    @support.run_with_locale('LC_ALL', '')
1060    def test_cert_time_to_seconds_locale(self):
1061        # `cert_time_to_seconds()` should be locale independent
1062
1063        def local_february_name():
1064            return time.strftime('%b', (1, 2, 3, 4, 5, 6, 0, 0, 0))
1065
1066        if local_february_name().lower() == 'feb':
1067            self.skipTest("locale-specific month name needs to be "
1068                          "different from C locale")
1069
1070        # locale-independent
1071        self.cert_time_ok("Feb  9 00:00:00 2007 GMT", 1170979200.0)
1072        self.cert_time_fail(local_february_name() + "  9 00:00:00 2007 GMT")
1073
1074    def test_connect_ex_error(self):
1075        server = socket.socket(socket.AF_INET)
1076        self.addCleanup(server.close)
1077        port = socket_helper.bind_port(server)  # Reserve port but don't listen
1078        s = test_wrap_socket(socket.socket(socket.AF_INET),
1079                            cert_reqs=ssl.CERT_REQUIRED)
1080        self.addCleanup(s.close)
1081        rc = s.connect_ex((HOST, port))
1082        # Issue #19919: Windows machines or VMs hosted on Windows
1083        # machines sometimes return EWOULDBLOCK.
1084        errors = (
1085            errno.ECONNREFUSED, errno.EHOSTUNREACH, errno.ETIMEDOUT,
1086            errno.EWOULDBLOCK,
1087        )
1088        self.assertIn(rc, errors)
1089
1090
1091class ContextTests(unittest.TestCase):
1092
1093    def test_constructor(self):
1094        for protocol in PROTOCOLS:
1095            ssl.SSLContext(protocol)
1096        ctx = ssl.SSLContext()
1097        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
1098        self.assertRaises(ValueError, ssl.SSLContext, -1)
1099        self.assertRaises(ValueError, ssl.SSLContext, 42)
1100
1101    def test_protocol(self):
1102        for proto in PROTOCOLS:
1103            ctx = ssl.SSLContext(proto)
1104            self.assertEqual(ctx.protocol, proto)
1105
1106    def test_ciphers(self):
1107        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1108        ctx.set_ciphers("ALL")
1109        ctx.set_ciphers("DEFAULT")
1110        with self.assertRaisesRegex(ssl.SSLError, "No cipher can be selected"):
1111            ctx.set_ciphers("^$:,;?*'dorothyx")
1112
1113    @unittest.skipUnless(PY_SSL_DEFAULT_CIPHERS == 1,
1114                         "Test applies only to Python default ciphers")
1115    def test_python_ciphers(self):
1116        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1117        ciphers = ctx.get_ciphers()
1118        for suite in ciphers:
1119            name = suite['name']
1120            self.assertNotIn("PSK", name)
1121            self.assertNotIn("SRP", name)
1122            self.assertNotIn("MD5", name)
1123            self.assertNotIn("RC4", name)
1124            self.assertNotIn("3DES", name)
1125
1126    @unittest.skipIf(ssl.OPENSSL_VERSION_INFO < (1, 0, 2, 0, 0), 'OpenSSL too old')
1127    def test_get_ciphers(self):
1128        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1129        ctx.set_ciphers('AESGCM')
1130        names = set(d['name'] for d in ctx.get_ciphers())
1131        self.assertIn('AES256-GCM-SHA384', names)
1132        self.assertIn('AES128-GCM-SHA256', names)
1133
1134    def test_options(self):
1135        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1136        # OP_ALL | OP_NO_SSLv2 | OP_NO_SSLv3 is the default value
1137        default = (ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
1138        # SSLContext also enables these by default
1139        default |= (OP_NO_COMPRESSION | OP_CIPHER_SERVER_PREFERENCE |
1140                    OP_SINGLE_DH_USE | OP_SINGLE_ECDH_USE |
1141                    OP_ENABLE_MIDDLEBOX_COMPAT)
1142        self.assertEqual(default, ctx.options)
1143        ctx.options |= ssl.OP_NO_TLSv1
1144        self.assertEqual(default | ssl.OP_NO_TLSv1, ctx.options)
1145        if can_clear_options():
1146            ctx.options = (ctx.options & ~ssl.OP_NO_TLSv1)
1147            self.assertEqual(default, ctx.options)
1148            ctx.options = 0
1149            # Ubuntu has OP_NO_SSLv3 forced on by default
1150            self.assertEqual(0, ctx.options & ~ssl.OP_NO_SSLv3)
1151        else:
1152            with self.assertRaises(ValueError):
1153                ctx.options = 0
1154
1155    def test_verify_mode_protocol(self):
1156        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
1157        # Default value
1158        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1159        ctx.verify_mode = ssl.CERT_OPTIONAL
1160        self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
1161        ctx.verify_mode = ssl.CERT_REQUIRED
1162        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1163        ctx.verify_mode = ssl.CERT_NONE
1164        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1165        with self.assertRaises(TypeError):
1166            ctx.verify_mode = None
1167        with self.assertRaises(ValueError):
1168            ctx.verify_mode = 42
1169
1170        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1171        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1172        self.assertFalse(ctx.check_hostname)
1173
1174        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1175        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1176        self.assertTrue(ctx.check_hostname)
1177
1178    def test_hostname_checks_common_name(self):
1179        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1180        self.assertTrue(ctx.hostname_checks_common_name)
1181        if ssl.HAS_NEVER_CHECK_COMMON_NAME:
1182            ctx.hostname_checks_common_name = True
1183            self.assertTrue(ctx.hostname_checks_common_name)
1184            ctx.hostname_checks_common_name = False
1185            self.assertFalse(ctx.hostname_checks_common_name)
1186            ctx.hostname_checks_common_name = True
1187            self.assertTrue(ctx.hostname_checks_common_name)
1188        else:
1189            with self.assertRaises(AttributeError):
1190                ctx.hostname_checks_common_name = True
1191
    @requires_minimum_version
    @unittest.skipIf(IS_LIBRESSL, "see bpo-34001")
    def test_min_max_version(self):
        # Checks the defaults of SSLContext.minimum_version/maximum_version
        # on a server context and round-trips several settings. Then checks
        # that a context created for a fixed protocol (PROTOCOL_TLSv1_1)
        # refuses changes to either bound.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        # OpenSSL default is MINIMUM_SUPPORTED, however some vendors like
        # Fedora override the setting to TLS 1.0.
        minimum_range = {
            # stock OpenSSL
            ssl.TLSVersion.MINIMUM_SUPPORTED,
            # Fedora 29 uses TLS 1.0 by default
            ssl.TLSVersion.TLSv1,
            # RHEL 8 uses TLS 1.2 by default
            ssl.TLSVersion.TLSv1_2
        }
        maximum_range = {
            # stock OpenSSL
            ssl.TLSVersion.MAXIMUM_SUPPORTED,
            # Fedora 32 uses TLS 1.3 by default
            ssl.TLSVersion.TLSv1_3
        }

        self.assertIn(
            ctx.minimum_version, minimum_range
        )
        self.assertIn(
            ctx.maximum_version, maximum_range
        )

        # Concrete versions are read back exactly as they were set.
        ctx.minimum_version = ssl.TLSVersion.TLSv1_1
        ctx.maximum_version = ssl.TLSVersion.TLSv1_2
        self.assertEqual(
            ctx.minimum_version, ssl.TLSVersion.TLSv1_1
        )
        self.assertEqual(
            ctx.maximum_version, ssl.TLSVersion.TLSv1_2
        )

        # Each sentinel is read back unchanged when assigned to its own
        # bound (MINIMUM_SUPPORTED as minimum, MAXIMUM_SUPPORTED as maximum).
        ctx.minimum_version = ssl.TLSVersion.MINIMUM_SUPPORTED
        ctx.maximum_version = ssl.TLSVersion.TLSv1
        self.assertEqual(
            ctx.minimum_version, ssl.TLSVersion.MINIMUM_SUPPORTED
        )
        self.assertEqual(
            ctx.maximum_version, ssl.TLSVersion.TLSv1
        )

        ctx.maximum_version = ssl.TLSVersion.MAXIMUM_SUPPORTED
        self.assertEqual(
            ctx.maximum_version, ssl.TLSVersion.MAXIMUM_SUPPORTED
        )

        # A sentinel assigned to the *opposite* bound reads back as a
        # concrete version. Which one depends on the build, so the result
        # is checked against a set of acceptable values.
        ctx.maximum_version = ssl.TLSVersion.MINIMUM_SUPPORTED
        self.assertIn(
            ctx.maximum_version,
            {ssl.TLSVersion.TLSv1, ssl.TLSVersion.SSLv3}
        )

        ctx.minimum_version = ssl.TLSVersion.MAXIMUM_SUPPORTED
        self.assertIn(
            ctx.minimum_version,
            {ssl.TLSVersion.TLSv1_2, ssl.TLSVersion.TLSv1_3}
        )

        # An integer that is not a valid version value is rejected.
        with self.assertRaises(ValueError):
            ctx.minimum_version = 42

        # A fixed-protocol context reports MAXIMUM_SUPPORTED as its maximum
        # and raises ValueError on any attempt to change either bound.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_1)

        self.assertIn(
            ctx.minimum_version, minimum_range
        )
        self.assertEqual(
            ctx.maximum_version, ssl.TLSVersion.MAXIMUM_SUPPORTED
        )
        with self.assertRaises(ValueError):
            ctx.minimum_version = ssl.TLSVersion.MINIMUM_SUPPORTED
        with self.assertRaises(ValueError):
            ctx.maximum_version = ssl.TLSVersion.TLSv1
1270
1271
1272    @unittest.skipUnless(have_verify_flags(),
1273                         "verify_flags need OpenSSL > 0.9.8")
1274    def test_verify_flags(self):
1275        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1276        # default value
1277        tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0)
1278        self.assertEqual(ctx.verify_flags, ssl.VERIFY_DEFAULT | tf)
1279        ctx.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF
1280        self.assertEqual(ctx.verify_flags, ssl.VERIFY_CRL_CHECK_LEAF)
1281        ctx.verify_flags = ssl.VERIFY_CRL_CHECK_CHAIN
1282        self.assertEqual(ctx.verify_flags, ssl.VERIFY_CRL_CHECK_CHAIN)
1283        ctx.verify_flags = ssl.VERIFY_DEFAULT
1284        self.assertEqual(ctx.verify_flags, ssl.VERIFY_DEFAULT)
1285        # supports any value
1286        ctx.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF | ssl.VERIFY_X509_STRICT
1287        self.assertEqual(ctx.verify_flags,
1288                         ssl.VERIFY_CRL_CHECK_LEAF | ssl.VERIFY_X509_STRICT)
1289        with self.assertRaises(TypeError):
1290            ctx.verify_flags = None
1291
    def test_load_cert_chain(self):
        # Exercises SSLContext.load_cert_chain():
        # - combined and separate cert/key files, with str and bytes paths;
        # - the errors raised for missing, malformed or mismatched inputs;
        # - every accepted form of the ``password`` argument, including
        #   callables.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        # Combined key and cert in a single file
        ctx.load_cert_chain(CERTFILE, keyfile=None)
        ctx.load_cert_chain(CERTFILE, keyfile=CERTFILE)
        # Omitting certfile is a TypeError.
        self.assertRaises(TypeError, ctx.load_cert_chain, keyfile=CERTFILE)
        with self.assertRaises(OSError) as cm:
            ctx.load_cert_chain(NONEXISTINGCERT)
        self.assertEqual(cm.exception.errno, errno.ENOENT)
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(BADCERT)
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(EMPTYCERT)
        # Separate key and cert
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        ctx.load_cert_chain(ONLYCERT, ONLYKEY)
        ctx.load_cert_chain(certfile=ONLYCERT, keyfile=ONLYKEY)
        ctx.load_cert_chain(certfile=BYTES_ONLYCERT, keyfile=BYTES_ONLYKEY)
        # A cert without its key, a key on its own, or the two swapped all
        # fail with a PEM error.
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(ONLYCERT)
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(ONLYKEY)
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(certfile=ONLYKEY, keyfile=ONLYCERT)
        # Mismatching key and cert
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        with self.assertRaisesRegex(ssl.SSLError, "key values mismatch"):
            ctx.load_cert_chain(CAFILE_CACERT, ONLYKEY)
        # Password protected key and cert: the password may be str, bytes
        # or bytearray, for both a combined file and a separate key file.
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD.encode())
        ctx.load_cert_chain(CERTFILE_PROTECTED,
                            password=bytearray(KEY_PASSWORD.encode()))
        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD)
        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD.encode())
        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED,
                            bytearray(KEY_PASSWORD.encode()))
        # A bool password is a TypeError; a wrong password is an SSLError.
        with self.assertRaisesRegex(TypeError, "should be a string"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=True)
        with self.assertRaises(ssl.SSLError):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password="badpass")
        with self.assertRaisesRegex(ValueError, "cannot be longer"):
            # openssl has a fixed limit on the password buffer.
            # PEM_BUFSIZE is generally set to 1kb.
            # Return a string larger than this.
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=b'a' * 102400)
        # Password callback. These helpers cover valid return types, a
        # wrong password, an oversized result, a non-string result and a
        # callback that raises.
        def getpass_unicode():
            return KEY_PASSWORD
        def getpass_bytes():
            return KEY_PASSWORD.encode()
        def getpass_bytearray():
            return bytearray(KEY_PASSWORD.encode())
        def getpass_badpass():
            return "badpass"
        def getpass_huge():
            return b'a' * (1024 * 1024)
        def getpass_bad_type():
            return 9
        def getpass_exception():
            raise Exception('getpass error')
        # Callable instances and bound methods are accepted as callbacks too.
        class GetPassCallable:
            def __call__(self):
                return KEY_PASSWORD
            def getpass(self):
                return KEY_PASSWORD
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_unicode)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytes)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytearray)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=GetPassCallable())
        ctx.load_cert_chain(CERTFILE_PROTECTED,
                            password=GetPassCallable().getpass)
        with self.assertRaises(ssl.SSLError):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_badpass)
        with self.assertRaisesRegex(ValueError, "cannot be longer"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_huge)
        with self.assertRaisesRegex(TypeError, "must return a string"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bad_type)
        # An exception raised inside the callback propagates unchanged.
        with self.assertRaisesRegex(Exception, "getpass error"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_exception)
        # Make sure the password function isn't called if it isn't needed
        ctx.load_cert_chain(CERTFILE, password=getpass_exception)
1374
1375    def test_load_verify_locations(self):
1376        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1377        ctx.load_verify_locations(CERTFILE)
1378        ctx.load_verify_locations(cafile=CERTFILE, capath=None)
1379        ctx.load_verify_locations(BYTES_CERTFILE)
1380        ctx.load_verify_locations(cafile=BYTES_CERTFILE, capath=None)
1381        self.assertRaises(TypeError, ctx.load_verify_locations)
1382        self.assertRaises(TypeError, ctx.load_verify_locations, None, None, None)
1383        with self.assertRaises(OSError) as cm:
1384            ctx.load_verify_locations(NONEXISTINGCERT)
1385        self.assertEqual(cm.exception.errno, errno.ENOENT)
1386        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
1387            ctx.load_verify_locations(BADCERT)
1388        ctx.load_verify_locations(CERTFILE, CAPATH)
1389        ctx.load_verify_locations(CERTFILE, capath=BYTES_CAPATH)
1390
1391        # Issue #10989: crash if the second argument type is invalid
1392        self.assertRaises(TypeError, ctx.load_verify_locations, None, True)
1393
    def test_load_verify_cadata(self):
        # test cadata
        # load_verify_locations(cadata=...) accepts PEM text (str) and DER
        # data (bytes), holding one certificate or several concatenated.
        # cert_store_stats()["x509_ca"] is used as the observable: loading
        # a certificate that is already in the store leaves it unchanged.
        with open(CAFILE_CACERT) as f:
            cacert_pem = f.read()
        cacert_der = ssl.PEM_cert_to_DER_cert(cacert_pem)
        with open(CAFILE_NEURONIO) as f:
            neuronio_pem = f.read()
        neuronio_der = ssl.PEM_cert_to_DER_cert(neuronio_pem)

        # test PEM
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 0)
        ctx.load_verify_locations(cadata=cacert_pem)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 1)
        ctx.load_verify_locations(cadata=neuronio_pem)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
        # cert already in hash table
        ctx.load_verify_locations(cadata=neuronio_pem)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # combined
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        combined = "\n".join((cacert_pem, neuronio_pem))
        ctx.load_verify_locations(cadata=combined)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # with junk around the certs
        # (neuronio_pem appears twice, yet only two CAs end up stored)
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        combined = ["head", cacert_pem, "other", neuronio_pem, "again",
                    neuronio_pem, "tail"]
        ctx.load_verify_locations(cadata="\n".join(combined))
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # test DER
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        ctx.load_verify_locations(cadata=cacert_der)
        ctx.load_verify_locations(cadata=neuronio_der)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
        # cert already in hash table
        ctx.load_verify_locations(cadata=cacert_der)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # combined
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        combined = b"".join((cacert_der, neuronio_der))
        ctx.load_verify_locations(cadata=combined)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # error cases
        # An arbitrary object is a TypeError. Malformed PEM (str) and
        # malformed DER (bytes) raise SSLError with different messages.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        self.assertRaises(TypeError, ctx.load_verify_locations, cadata=object)

        with self.assertRaisesRegex(ssl.SSLError, "no start line"):
            ctx.load_verify_locations(cadata="broken")
        with self.assertRaisesRegex(ssl.SSLError, "not enough data"):
            ctx.load_verify_locations(cadata=b"broken")
1450
1451
1452    @unittest.skipIf(Py_DEBUG_WIN32, "Avoid mixing debug/release CRT on Windows")
1453    def test_load_dh_params(self):
1454        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1455        ctx.load_dh_params(DHFILE)
1456        if os.name != 'nt':
1457            ctx.load_dh_params(BYTES_DHFILE)
1458        self.assertRaises(TypeError, ctx.load_dh_params)
1459        self.assertRaises(TypeError, ctx.load_dh_params, None)
1460        with self.assertRaises(FileNotFoundError) as cm:
1461            ctx.load_dh_params(NONEXISTINGCERT)
1462        self.assertEqual(cm.exception.errno, errno.ENOENT)
1463        with self.assertRaises(ssl.SSLError) as cm:
1464            ctx.load_dh_params(CERTFILE)
1465
1466    def test_session_stats(self):
1467        for proto in PROTOCOLS:
1468            ctx = ssl.SSLContext(proto)
1469            self.assertEqual(ctx.session_stats(), {
1470                'number': 0,
1471                'connect': 0,
1472                'connect_good': 0,
1473                'connect_renegotiate': 0,
1474                'accept': 0,
1475                'accept_good': 0,
1476                'accept_renegotiate': 0,
1477                'hits': 0,
1478                'misses': 0,
1479                'timeouts': 0,
1480                'cache_full': 0,
1481            })
1482
1483    def test_set_default_verify_paths(self):
1484        # There's not much we can do to test that it acts as expected,
1485        # so just check it doesn't crash or raise an exception.
1486        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1487        ctx.set_default_verify_paths()
1488
1489    @unittest.skipUnless(ssl.HAS_ECDH, "ECDH disabled on this OpenSSL build")
1490    def test_set_ecdh_curve(self):
1491        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1492        ctx.set_ecdh_curve("prime256v1")
1493        ctx.set_ecdh_curve(b"prime256v1")
1494        self.assertRaises(TypeError, ctx.set_ecdh_curve)
1495        self.assertRaises(TypeError, ctx.set_ecdh_curve, None)
1496        self.assertRaises(ValueError, ctx.set_ecdh_curve, "foo")
1497        self.assertRaises(ValueError, ctx.set_ecdh_curve, b"foo")
1498
1499    @needs_sni
1500    def test_sni_callback(self):
1501        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1502
1503        # set_servername_callback expects a callable, or None
1504        self.assertRaises(TypeError, ctx.set_servername_callback)
1505        self.assertRaises(TypeError, ctx.set_servername_callback, 4)
1506        self.assertRaises(TypeError, ctx.set_servername_callback, "")
1507        self.assertRaises(TypeError, ctx.set_servername_callback, ctx)
1508
1509        def dummycallback(sock, servername, ctx):
1510            pass
1511        ctx.set_servername_callback(None)
1512        ctx.set_servername_callback(dummycallback)
1513
1514    @needs_sni
1515    def test_sni_callback_refcycle(self):
1516        # Reference cycles through the servername callback are detected
1517        # and cleared.
1518        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1519        def dummycallback(sock, servername, ctx, cycle=ctx):
1520            pass
1521        ctx.set_servername_callback(dummycallback)
1522        wr = weakref.ref(ctx)
1523        del ctx, dummycallback
1524        gc.collect()
1525        self.assertIs(wr(), None)
1526
1527    def test_cert_store_stats(self):
1528        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1529        self.assertEqual(ctx.cert_store_stats(),
1530            {'x509_ca': 0, 'crl': 0, 'x509': 0})
1531        ctx.load_cert_chain(CERTFILE)
1532        self.assertEqual(ctx.cert_store_stats(),
1533            {'x509_ca': 0, 'crl': 0, 'x509': 0})
1534        ctx.load_verify_locations(CERTFILE)
1535        self.assertEqual(ctx.cert_store_stats(),
1536            {'x509_ca': 0, 'crl': 0, 'x509': 1})
1537        ctx.load_verify_locations(CAFILE_CACERT)
1538        self.assertEqual(ctx.cert_store_stats(),
1539            {'x509_ca': 1, 'crl': 0, 'x509': 2})
1540
1541    def test_get_ca_certs(self):
1542        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1543        self.assertEqual(ctx.get_ca_certs(), [])
1544        # CERTFILE is not flagged as X509v3 Basic Constraints: CA:TRUE
1545        ctx.load_verify_locations(CERTFILE)
1546        self.assertEqual(ctx.get_ca_certs(), [])
1547        # but CAFILE_CACERT is a CA cert
1548        ctx.load_verify_locations(CAFILE_CACERT)
1549        self.assertEqual(ctx.get_ca_certs(),
1550            [{'issuer': ((('organizationName', 'Root CA'),),
1551                         (('organizationalUnitName', 'http://www.cacert.org'),),
1552                         (('commonName', 'CA Cert Signing Authority'),),
1553                         (('emailAddress', 'support@cacert.org'),)),
1554              'notAfter': asn1time('Mar 29 12:29:49 2033 GMT'),
1555              'notBefore': asn1time('Mar 30 12:29:49 2003 GMT'),
1556              'serialNumber': '00',
1557              'crlDistributionPoints': ('https://www.cacert.org/revoke.crl',),
1558              'subject': ((('organizationName', 'Root CA'),),
1559                          (('organizationalUnitName', 'http://www.cacert.org'),),
1560                          (('commonName', 'CA Cert Signing Authority'),),
1561                          (('emailAddress', 'support@cacert.org'),)),
1562              'version': 3}])
1563
1564        with open(CAFILE_CACERT) as f:
1565            pem = f.read()
1566        der = ssl.PEM_cert_to_DER_cert(pem)
1567        self.assertEqual(ctx.get_ca_certs(True), [der])
1568
1569    def test_load_default_certs(self):
1570        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1571        ctx.load_default_certs()
1572
1573        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1574        ctx.load_default_certs(ssl.Purpose.SERVER_AUTH)
1575        ctx.load_default_certs()
1576
1577        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1578        ctx.load_default_certs(ssl.Purpose.CLIENT_AUTH)
1579
1580        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1581        self.assertRaises(TypeError, ctx.load_default_certs, None)
1582        self.assertRaises(TypeError, ctx.load_default_certs, 'SERVER_AUTH')
1583
1584    @unittest.skipIf(sys.platform == "win32", "not-Windows specific")
1585    @unittest.skipIf(IS_LIBRESSL, "LibreSSL doesn't support env vars")
1586    def test_load_default_certs_env(self):
1587        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1588        with support.EnvironmentVarGuard() as env:
1589            env["SSL_CERT_DIR"] = CAPATH
1590            env["SSL_CERT_FILE"] = CERTFILE
1591            ctx.load_default_certs()
1592            self.assertEqual(ctx.cert_store_stats(), {"crl": 0, "x509": 1, "x509_ca": 0})
1593
1594    @unittest.skipUnless(sys.platform == "win32", "Windows specific")
1595    @unittest.skipIf(hasattr(sys, "gettotalrefcount"), "Debug build does not share environment between CRTs")
1596    def test_load_default_certs_env_windows(self):
1597        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1598        ctx.load_default_certs()
1599        stats = ctx.cert_store_stats()
1600
1601        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1602        with support.EnvironmentVarGuard() as env:
1603            env["SSL_CERT_DIR"] = CAPATH
1604            env["SSL_CERT_FILE"] = CERTFILE
1605            ctx.load_default_certs()
1606            stats["x509"] += 1
1607            self.assertEqual(ctx.cert_store_stats(), stats)
1608
1609    def _assert_context_options(self, ctx):
1610        self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2)
1611        if OP_NO_COMPRESSION != 0:
1612            self.assertEqual(ctx.options & OP_NO_COMPRESSION,
1613                             OP_NO_COMPRESSION)
1614        if OP_SINGLE_DH_USE != 0:
1615            self.assertEqual(ctx.options & OP_SINGLE_DH_USE,
1616                             OP_SINGLE_DH_USE)
1617        if OP_SINGLE_ECDH_USE != 0:
1618            self.assertEqual(ctx.options & OP_SINGLE_ECDH_USE,
1619                             OP_SINGLE_ECDH_USE)
1620        if OP_CIPHER_SERVER_PREFERENCE != 0:
1621            self.assertEqual(ctx.options & OP_CIPHER_SERVER_PREFERENCE,
1622                             OP_CIPHER_SERVER_PREFERENCE)
1623
1624    def test_create_default_context(self):
1625        ctx = ssl.create_default_context()
1626
1627        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
1628        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1629        self.assertTrue(ctx.check_hostname)
1630        self._assert_context_options(ctx)
1631
1632        with open(SIGNING_CA) as f:
1633            cadata = f.read()
1634        ctx = ssl.create_default_context(cafile=SIGNING_CA, capath=CAPATH,
1635                                         cadata=cadata)
1636        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
1637        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1638        self._assert_context_options(ctx)
1639
1640        ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
1641        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
1642        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1643        self._assert_context_options(ctx)
1644
1645    def test__create_stdlib_context(self):
1646        ctx = ssl._create_stdlib_context()
1647        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
1648        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1649        self.assertFalse(ctx.check_hostname)
1650        self._assert_context_options(ctx)
1651
1652        ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLSv1)
1653        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1)
1654        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1655        self._assert_context_options(ctx)
1656
1657        ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLSv1,
1658                                         cert_reqs=ssl.CERT_REQUIRED,
1659                                         check_hostname=True)
1660        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1)
1661        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1662        self.assertTrue(ctx.check_hostname)
1663        self._assert_context_options(ctx)
1664
1665        ctx = ssl._create_stdlib_context(purpose=ssl.Purpose.CLIENT_AUTH)
1666        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
1667        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1668        self._assert_context_options(ctx)
1669
    def test_check_hostname(self):
        """Exercise the coupling between check_hostname and verify_mode.

        Enabling check_hostname upgrades CERT_NONE to CERT_REQUIRED but leaves
        CERT_OPTIONAL alone, and CERT_NONE cannot be set while check_hostname
        is enabled.  The statement order below is the scenario under test.
        """
        # a plain PROTOCOL_TLS context starts with no verification at all
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
        self.assertFalse(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)

        # Auto set CERT_REQUIRED
        ctx.check_hostname = True
        self.assertTrue(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
        # turning check_hostname off again does not downgrade verify_mode
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_REQUIRED
        self.assertFalse(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)

        # Changing verify_mode does not affect check_hostname
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE
        ctx.check_hostname = False
        self.assertFalse(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
        # Auto set
        ctx.check_hostname = True
        self.assertTrue(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)

        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_OPTIONAL
        ctx.check_hostname = False
        self.assertFalse(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
        # keep CERT_OPTIONAL
        ctx.check_hostname = True
        self.assertTrue(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)

        # Cannot set CERT_NONE with check_hostname enabled
        with self.assertRaises(ValueError):
            ctx.verify_mode = ssl.CERT_NONE
        # ... but it is allowed once check_hostname is disabled first
        ctx.check_hostname = False
        self.assertFalse(ctx.check_hostname)
        ctx.verify_mode = ssl.CERT_NONE
        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1712
1713    def test_context_client_server(self):
1714        # PROTOCOL_TLS_CLIENT has sane defaults
1715        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1716        self.assertTrue(ctx.check_hostname)
1717        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1718
1719        # PROTOCOL_TLS_SERVER has different but also sane defaults
1720        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1721        self.assertFalse(ctx.check_hostname)
1722        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1723
1724    def test_context_custom_class(self):
1725        class MySSLSocket(ssl.SSLSocket):
1726            pass
1727
1728        class MySSLObject(ssl.SSLObject):
1729            pass
1730
1731        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1732        ctx.sslsocket_class = MySSLSocket
1733        ctx.sslobject_class = MySSLObject
1734
1735        with ctx.wrap_socket(socket.socket(), server_side=True) as sock:
1736            self.assertIsInstance(sock, MySSLSocket)
1737        obj = ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO())
1738        self.assertIsInstance(obj, MySSLObject)
1739
1740    @unittest.skipUnless(IS_OPENSSL_1_1_1, "Test requires OpenSSL 1.1.1")
1741    def test_num_tickest(self):
1742        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1743        self.assertEqual(ctx.num_tickets, 2)
1744        ctx.num_tickets = 1
1745        self.assertEqual(ctx.num_tickets, 1)
1746        ctx.num_tickets = 0
1747        self.assertEqual(ctx.num_tickets, 0)
1748        with self.assertRaises(ValueError):
1749            ctx.num_tickets = -1
1750        with self.assertRaises(TypeError):
1751            ctx.num_tickets = None
1752
1753        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1754        self.assertEqual(ctx.num_tickets, 2)
1755        with self.assertRaises(ValueError):
1756            ctx.num_tickets = 1
1757
1758
class SSLErrorTests(unittest.TestCase):
    """Tests for ssl.SSLError: str(), library/reason attributes, subclasses."""

    def test_str(self):
        # The str() of a SSLError doesn't include the errno
        e = ssl.SSLError(1, "foo")
        self.assertEqual(str(e), "foo")
        self.assertEqual(e.errno, 1)
        # Same for a subclass
        e = ssl.SSLZeroReturnError(1, "foo")
        self.assertEqual(str(e), "foo")
        self.assertEqual(e.errno, 1)

    @unittest.skipIf(Py_DEBUG_WIN32, "Avoid mixing debug/release CRT on Windows")
    def test_lib_reason(self):
        # Test the library and reason attributes
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        with self.assertRaises(ssl.SSLError) as cm:
            # CERTFILE holds no DH parameters, so PEM parsing finds no start line
            ctx.load_dh_params(CERTFILE)
        self.assertEqual(cm.exception.library, 'PEM')
        self.assertEqual(cm.exception.reason, 'NO_START_LINE')
        msg = str(cm.exception)
        self.assertTrue(msg.startswith("[PEM: NO_START_LINE] no start line"), msg)

    def test_subclass(self):
        # Check that the appropriate SSLError subclass is raised
        # (this only tests one of them)
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE
        with socket.create_server(("127.0.0.1", 0)) as server:
            # The server never accepts nor sends anything, so a non-blocking
            # client handshake has nothing to read and raises WantRead.
            # The raw socket is managed by "with" so it is closed even if
            # wrap_socket() fails; after a successful wrap it has been
            # detached, and closing it again is harmless.
            with socket.create_connection(server.getsockname()) as raw:
                raw.setblocking(False)
                with ctx.wrap_socket(raw, False,
                                     do_handshake_on_connect=False) as c:
                    with self.assertRaises(ssl.SSLWantReadError) as cm:
                        c.do_handshake()
                    msg = str(cm.exception)
                    self.assertTrue(
                        msg.startswith("The operation did not complete (read)"),
                        msg)
                    # For compatibility
                    self.assertEqual(cm.exception.errno,
                                     ssl.SSL_ERROR_WANT_READ)

    def test_bad_server_hostname(self):
        # Empty, leading-dot and NUL-containing hostnames are all rejected
        ctx = ssl.create_default_context()
        with self.assertRaises(ValueError):
            ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO(),
                         server_hostname="")
        with self.assertRaises(ValueError):
            ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO(),
                         server_hostname=".example.org")
        with self.assertRaises(TypeError):
            ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO(),
                         server_hostname="example.org\x00evil.com")
1811
1812
class MemoryBIOTests(unittest.TestCase):
    """Unit tests for ssl.MemoryBIO, the in-memory buffer used by wrap_bio()."""

    def test_read_write(self):
        membio = ssl.MemoryBIO()
        # one write is fully drained by one read, leaving the buffer empty
        membio.write(b'foo')
        self.assertEqual(membio.read(), b'foo')
        self.assertEqual(membio.read(), b'')
        # consecutive writes are concatenated
        for chunk in (b'foo', b'bar'):
            membio.write(chunk)
        self.assertEqual(membio.read(), b'foobar')
        self.assertEqual(membio.read(), b'')
        # read(n) returns at most n bytes
        membio.write(b'baz')
        for size, expected in ((2, b'ba'), (1, b'z'), (1, b'')):
            self.assertEqual(membio.read(size), expected)

    def test_eof(self):
        membio = ssl.MemoryBIO()
        self.assertFalse(membio.eof)
        # an empty read on its own does not signal EOF
        self.assertEqual(membio.read(), b'')
        self.assertFalse(membio.eof)
        membio.write(b'foo')
        self.assertFalse(membio.eof)
        membio.write_eof()
        # eof turns true only after write_eof() AND once all data was read
        self.assertFalse(membio.eof)
        self.assertEqual(membio.read(2), b'fo')
        self.assertFalse(membio.eof)
        self.assertEqual(membio.read(1), b'o')
        self.assertTrue(membio.eof)
        self.assertEqual(membio.read(), b'')
        self.assertTrue(membio.eof)

    def test_pending(self):
        membio = ssl.MemoryBIO()
        self.assertEqual(membio.pending, 0)
        membio.write(b'foo')
        self.assertEqual(membio.pending, 3)
        # pending drops by one per single-byte read ...
        for remaining in (2, 1, 0):
            membio.read(1)
            self.assertEqual(membio.pending, remaining)
        # ... and grows by one per single-byte write
        for total in (1, 2, 3):
            membio.write(b'x')
            self.assertEqual(membio.pending, total)
        membio.read()
        self.assertEqual(membio.pending, 0)

    def test_buffer_types(self):
        membio = ssl.MemoryBIO()
        # any bytes-like object may be written; read() returns bytes
        for payload in (b'foo', bytearray(b'bar'), memoryview(b'baz')):
            membio.write(payload)
            self.assertEqual(membio.read(), bytes(payload))

    def test_error_types(self):
        membio = ssl.MemoryBIO()
        # str, None, bool and int are not bytes-like
        for not_bytes in ('foo', None, True, 1):
            self.assertRaises(TypeError, membio.write, not_bytes)
1874
1875
1876class SSLObjectTests(unittest.TestCase):
1877    def test_private_init(self):
1878        bio = ssl.MemoryBIO()
1879        with self.assertRaisesRegex(TypeError, "public constructor"):
1880            ssl.SSLObject(bio, bio)
1881
1882    def test_unwrap(self):
1883        client_ctx, server_ctx, hostname = testing_context()
1884        c_in = ssl.MemoryBIO()
1885        c_out = ssl.MemoryBIO()
1886        s_in = ssl.MemoryBIO()
1887        s_out = ssl.MemoryBIO()
1888        client = client_ctx.wrap_bio(c_in, c_out, server_hostname=hostname)
1889        server = server_ctx.wrap_bio(s_in, s_out, server_side=True)
1890
1891        # Loop on the handshake for a bit to get it settled
1892        for _ in range(5):
1893            try:
1894                client.do_handshake()
1895            except ssl.SSLWantReadError:
1896                pass
1897            if c_out.pending:
1898                s_in.write(c_out.read())
1899            try:
1900                server.do_handshake()
1901            except ssl.SSLWantReadError:
1902                pass
1903            if s_out.pending:
1904                c_in.write(s_out.read())
1905        # Now the handshakes should be complete (don't raise WantReadError)
1906        client.do_handshake()
1907        server.do_handshake()
1908
1909        # Now if we unwrap one side unilaterally, it should send close-notify
1910        # and raise WantReadError:
1911        with self.assertRaises(ssl.SSLWantReadError):
1912            client.unwrap()
1913
1914        # But server.unwrap() does not raise, because it reads the client's
1915        # close-notify:
1916        s_in.write(c_out.read())
1917        server.unwrap()
1918
1919        # And now that the client gets the server's close-notify, it doesn't
1920        # raise either.
1921        c_in.write(s_out.read())
1922        client.unwrap()
1923
1924class SimpleBackgroundTests(unittest.TestCase):
1925    """Tests that connect to a simple server running in the background"""
1926
1927    def setUp(self):
1928        server = ThreadedEchoServer(SIGNED_CERTFILE)
1929        self.server_addr = (HOST, server.port)
1930        server.__enter__()
1931        self.addCleanup(server.__exit__, None, None, None)
1932
1933    def test_connect(self):
1934        with test_wrap_socket(socket.socket(socket.AF_INET),
1935                            cert_reqs=ssl.CERT_NONE) as s:
1936            s.connect(self.server_addr)
1937            self.assertEqual({}, s.getpeercert())
1938            self.assertFalse(s.server_side)
1939
1940        # this should succeed because we specify the root cert
1941        with test_wrap_socket(socket.socket(socket.AF_INET),
1942                            cert_reqs=ssl.CERT_REQUIRED,
1943                            ca_certs=SIGNING_CA) as s:
1944            s.connect(self.server_addr)
1945            self.assertTrue(s.getpeercert())
1946            self.assertFalse(s.server_side)
1947
1948    def test_connect_fail(self):
1949        # This should fail because we have no verification certs. Connection
1950        # failure crashes ThreadedEchoServer, so run this in an independent
1951        # test method.
1952        s = test_wrap_socket(socket.socket(socket.AF_INET),
1953                            cert_reqs=ssl.CERT_REQUIRED)
1954        self.addCleanup(s.close)
1955        self.assertRaisesRegex(ssl.SSLError, "certificate verify failed",
1956                               s.connect, self.server_addr)
1957
1958    def test_connect_ex(self):
1959        # Issue #11326: check connect_ex() implementation
1960        s = test_wrap_socket(socket.socket(socket.AF_INET),
1961                            cert_reqs=ssl.CERT_REQUIRED,
1962                            ca_certs=SIGNING_CA)
1963        self.addCleanup(s.close)
1964        self.assertEqual(0, s.connect_ex(self.server_addr))
1965        self.assertTrue(s.getpeercert())
1966
1967    def test_non_blocking_connect_ex(self):
1968        # Issue #11326: non-blocking connect_ex() should allow handshake
1969        # to proceed after the socket gets ready.
1970        s = test_wrap_socket(socket.socket(socket.AF_INET),
1971                            cert_reqs=ssl.CERT_REQUIRED,
1972                            ca_certs=SIGNING_CA,
1973                            do_handshake_on_connect=False)
1974        self.addCleanup(s.close)
1975        s.setblocking(False)
1976        rc = s.connect_ex(self.server_addr)
1977        # EWOULDBLOCK under Windows, EINPROGRESS elsewhere
1978        self.assertIn(rc, (0, errno.EINPROGRESS, errno.EWOULDBLOCK))
1979        # Wait for connect to finish
1980        select.select([], [s], [], 5.0)
1981        # Non-blocking handshake
1982        while True:
1983            try:
1984                s.do_handshake()
1985                break
1986            except ssl.SSLWantReadError:
1987                select.select([s], [], [], 5.0)
1988            except ssl.SSLWantWriteError:
1989                select.select([], [s], [], 5.0)
1990        # SSL established
1991        self.assertTrue(s.getpeercert())
1992
1993    def test_connect_with_context(self):
1994        # Same as test_connect, but with a separately created context
1995        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
1996        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
1997            s.connect(self.server_addr)
1998            self.assertEqual({}, s.getpeercert())
1999        # Same with a server hostname
2000        with ctx.wrap_socket(socket.socket(socket.AF_INET),
2001                            server_hostname="dummy") as s:
2002            s.connect(self.server_addr)
2003        ctx.verify_mode = ssl.CERT_REQUIRED
2004        # This should succeed because we specify the root cert
2005        ctx.load_verify_locations(SIGNING_CA)
2006        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
2007            s.connect(self.server_addr)
2008            cert = s.getpeercert()
2009            self.assertTrue(cert)
2010
2011    def test_connect_with_context_fail(self):
2012        # This should fail because we have no verification certs. Connection
2013        # failure crashes ThreadedEchoServer, so run this in an independent
2014        # test method.
2015        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
2016        ctx.verify_mode = ssl.CERT_REQUIRED
2017        s = ctx.wrap_socket(socket.socket(socket.AF_INET))
2018        self.addCleanup(s.close)
2019        self.assertRaisesRegex(ssl.SSLError, "certificate verify failed",
2020                                s.connect, self.server_addr)
2021
2022    def test_connect_capath(self):
2023        # Verify server certificates using the `capath` argument
2024        # NOTE: the subject hashing algorithm has been changed between
2025        # OpenSSL 0.9.8n and 1.0.0, as a result the capath directory must
2026        # contain both versions of each certificate (same content, different
2027        # filename) for this test to be portable across OpenSSL releases.
2028        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
2029        ctx.verify_mode = ssl.CERT_REQUIRED
2030        ctx.load_verify_locations(capath=CAPATH)
2031        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
2032            s.connect(self.server_addr)
2033            cert = s.getpeercert()
2034            self.assertTrue(cert)
2035
2036        # Same with a bytes `capath` argument
2037        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
2038        ctx.verify_mode = ssl.CERT_REQUIRED
2039        ctx.load_verify_locations(capath=BYTES_CAPATH)
2040        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
2041            s.connect(self.server_addr)
2042            cert = s.getpeercert()
2043            self.assertTrue(cert)
2044
2045    def test_connect_cadata(self):
2046        with open(SIGNING_CA) as f:
2047            pem = f.read()
2048        der = ssl.PEM_cert_to_DER_cert(pem)
2049        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
2050        ctx.verify_mode = ssl.CERT_REQUIRED
2051        ctx.load_verify_locations(cadata=pem)
2052        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
2053            s.connect(self.server_addr)
2054            cert = s.getpeercert()
2055            self.assertTrue(cert)
2056
2057        # same with DER
2058        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
2059        ctx.verify_mode = ssl.CERT_REQUIRED
2060        ctx.load_verify_locations(cadata=der)
2061        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
2062            s.connect(self.server_addr)
2063            cert = s.getpeercert()
2064            self.assertTrue(cert)
2065
    @unittest.skipIf(os.name == "nt", "Can't use a socket as a file under Windows")
    def test_makefile_close(self):
        # Issue #5238: creating a file-like object with makefile() shouldn't
        # delay closing the underlying "real socket" (here tested with its
        # file descriptor, hence skipping the test under Windows).
        ss = test_wrap_socket(socket.socket(socket.AF_INET))
        ss.connect(self.server_addr)
        fd = ss.fileno()
        f = ss.makefile()
        f.close()
        # The fd is still open
        # (os.read(fd, 0) is a cheap probe: it succeeds on an open fd and
        # raises EBADF once the fd is closed)
        os.read(fd, 0)
        # Closing the SSL socket should close the fd too
        ss.close()
        # presumably collect so no lingering reference keeps the fd alive
        # before probing it
        gc.collect()
        with self.assertRaises(OSError) as e:
            os.read(fd, 0)
        self.assertEqual(e.exception.errno, errno.EBADF)
2084
2085    def test_non_blocking_handshake(self):
2086        s = socket.socket(socket.AF_INET)
2087        s.connect(self.server_addr)
2088        s.setblocking(False)
2089        s = test_wrap_socket(s,
2090                            cert_reqs=ssl.CERT_NONE,
2091                            do_handshake_on_connect=False)
2092        self.addCleanup(s.close)
2093        count = 0
2094        while True:
2095            try:
2096                count += 1
2097                s.do_handshake()
2098                break
2099            except ssl.SSLWantReadError:
2100                select.select([s], [], [])
2101            except ssl.SSLWantWriteError:
2102                select.select([], [s], [])
2103        if support.verbose:
2104            sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count)
2105
2106    def test_get_server_certificate(self):
2107        _test_get_server_certificate(self, *self.server_addr, cert=SIGNING_CA)
2108
2109    def test_get_server_certificate_fail(self):
2110        # Connection failure crashes ThreadedEchoServer, so run this in an
2111        # independent test method
2112        _test_get_server_certificate_fail(self, *self.server_addr)
2113
2114    def test_ciphers(self):
2115        with test_wrap_socket(socket.socket(socket.AF_INET),
2116                             cert_reqs=ssl.CERT_NONE, ciphers="ALL") as s:
2117            s.connect(self.server_addr)
2118        with test_wrap_socket(socket.socket(socket.AF_INET),
2119                             cert_reqs=ssl.CERT_NONE, ciphers="DEFAULT") as s:
2120            s.connect(self.server_addr)
2121        # Error checking can happen at instantiation or when connecting
2122        with self.assertRaisesRegex(ssl.SSLError, "No cipher can be selected"):
2123            with socket.socket(socket.AF_INET) as sock:
2124                s = test_wrap_socket(sock,
2125                                    cert_reqs=ssl.CERT_NONE, ciphers="^$:,;?*'dorothyx")
2126                s.connect(self.server_addr)
2127
2128    def test_get_ca_certs_capath(self):
2129        # capath certs are loaded on request
2130        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2131        ctx.load_verify_locations(capath=CAPATH)
2132        self.assertEqual(ctx.get_ca_certs(), [])
2133        with ctx.wrap_socket(socket.socket(socket.AF_INET),
2134                             server_hostname='localhost') as s:
2135            s.connect(self.server_addr)
2136            cert = s.getpeercert()
2137            self.assertTrue(cert)
2138        self.assertEqual(len(ctx.get_ca_certs()), 1)
2139
2140    @needs_sni
2141    def test_context_setget(self):
2142        # Check that the context of a connected socket can be replaced.
2143        ctx1 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2144        ctx1.load_verify_locations(capath=CAPATH)
2145        ctx2 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2146        ctx2.load_verify_locations(capath=CAPATH)
2147        s = socket.socket(socket.AF_INET)
2148        with ctx1.wrap_socket(s, server_hostname='localhost') as ss:
2149            ss.connect(self.server_addr)
2150            self.assertIs(ss.context, ctx1)
2151            self.assertIs(ss._sslobj.context, ctx1)
2152            ss.context = ctx2
2153            self.assertIs(ss.context, ctx2)
2154            self.assertIs(ss._sslobj.context, ctx2)
2155
2156    def ssl_io_loop(self, sock, incoming, outgoing, func, *args, **kwargs):
2157        # A simple IO loop. Call func(*args) depending on the error we get
2158        # (WANT_READ or WANT_WRITE) move data between the socket and the BIOs.
2159        timeout = kwargs.get('timeout', support.SHORT_TIMEOUT)
2160        deadline = time.monotonic() + timeout
2161        count = 0
2162        while True:
2163            if time.monotonic() > deadline:
2164                self.fail("timeout")
2165            errno = None
2166            count += 1
2167            try:
2168                ret = func(*args)
2169            except ssl.SSLError as e:
2170                if e.errno not in (ssl.SSL_ERROR_WANT_READ,
2171                                   ssl.SSL_ERROR_WANT_WRITE):
2172                    raise
2173                errno = e.errno
2174            # Get any data from the outgoing BIO irrespective of any error, and
2175            # send it to the socket.
2176            buf = outgoing.read()
2177            sock.sendall(buf)
2178            # If there's no error, we're done. For WANT_READ, we need to get
2179            # data from the socket and put it in the incoming BIO.
2180            if errno is None:
2181                break
2182            elif errno == ssl.SSL_ERROR_WANT_READ:
2183                buf = sock.recv(32768)
2184                if buf:
2185                    incoming.write(buf)
2186                else:
2187                    incoming.write_eof()
2188        if support.verbose:
2189            sys.stdout.write("Needed %d calls to complete %s().\n"
2190                             % (count, func.__name__))
2191        return ret
2192
2193    def test_bio_handshake(self):
2194        sock = socket.socket(socket.AF_INET)
2195        self.addCleanup(sock.close)
2196        sock.connect(self.server_addr)
2197        incoming = ssl.MemoryBIO()
2198        outgoing = ssl.MemoryBIO()
2199        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2200        self.assertTrue(ctx.check_hostname)
2201        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
2202        ctx.load_verify_locations(SIGNING_CA)
2203        sslobj = ctx.wrap_bio(incoming, outgoing, False,
2204                              SIGNED_CERTFILE_HOSTNAME)
2205        self.assertIs(sslobj._sslobj.owner, sslobj)
2206        self.assertIsNone(sslobj.cipher())
2207        self.assertIsNone(sslobj.version())
2208        self.assertIsNotNone(sslobj.shared_ciphers())
2209        self.assertRaises(ValueError, sslobj.getpeercert)
2210        if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
2211            self.assertIsNone(sslobj.get_channel_binding('tls-unique'))
2212        self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
2213        self.assertTrue(sslobj.cipher())
2214        self.assertIsNotNone(sslobj.shared_ciphers())
2215        self.assertIsNotNone(sslobj.version())
2216        self.assertTrue(sslobj.getpeercert())
2217        if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
2218            self.assertTrue(sslobj.get_channel_binding('tls-unique'))
2219        try:
2220            self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)
2221        except ssl.SSLSyscallError:
2222            # If the server shuts down the TCP connection without sending a
2223            # secure shutdown message, this is reported as SSL_ERROR_SYSCALL
2224            pass
2225        self.assertRaises(ssl.SSLError, sslobj.write, b'foo')
2226
2227    def test_bio_read_write_data(self):
2228        sock = socket.socket(socket.AF_INET)
2229        self.addCleanup(sock.close)
2230        sock.connect(self.server_addr)
2231        incoming = ssl.MemoryBIO()
2232        outgoing = ssl.MemoryBIO()
2233        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
2234        ctx.verify_mode = ssl.CERT_NONE
2235        sslobj = ctx.wrap_bio(incoming, outgoing, False)
2236        self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
2237        req = b'FOO\n'
2238        self.ssl_io_loop(sock, incoming, outgoing, sslobj.write, req)
2239        buf = self.ssl_io_loop(sock, incoming, outgoing, sslobj.read, 1024)
2240        self.assertEqual(buf, b'foo\n')
2241        self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)
2242
2243
class NetworkedTests(unittest.TestCase):
    """Tests that talk to real hosts on the Internet."""

    def test_timeout_connect_ex(self):
        # Issue #12065: on a timeout, connect_ex() should return the original
        # errno (mimicking the behaviour of non-SSL sockets).
        with socket_helper.transient_internet(REMOTE_HOST):
            raw = socket.socket(socket.AF_INET)
            conn = test_wrap_socket(raw,
                                    cert_reqs=ssl.CERT_REQUIRED,
                                    do_handshake_on_connect=False)
            self.addCleanup(conn.close)
            # A timeout this small is expected to interrupt the TCP connect.
            conn.settimeout(1e-7)
            result = conn.connect_ex((REMOTE_HOST, 443))
            if result == 0:
                self.skipTest("REMOTE_HOST responded too quickly")
            self.assertIn(result, (errno.EAGAIN, errno.EWOULDBLOCK))

    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'Needs IPv6')
    def test_get_server_certificate_ipv6(self):
        host, port = 'ipv6.google.com', 443
        with socket_helper.transient_internet(host):
            _test_get_server_certificate(self, host, port)
            _test_get_server_certificate_fail(self, host, port)
2265
2266
def _test_get_server_certificate(test, host, port, cert=None):
    """Fetch the PEM certificate of ``host:port``; fail *test* if it is empty.

    The certificate is fetched twice: first without a CA bundle, then with
    *cert* passed as ``ca_certs``.  In verbose mode the PEM from the second
    fetch is printed.
    """
    for ca_bundle in (None, cert):
        # ca_certs=None is get_server_certificate()'s default, so the first
        # pass is the same as calling it without ca_certs at all.
        pem = ssl.get_server_certificate((host, port), ca_certs=ca_bundle)
        if not pem:
            test.fail("No server certificate on %s:%s!" % (host, port))
    if support.verbose:
        sys.stdout.write("\nVerified certificate for %s:%s is\n%s\n"
                         % (host, port, pem))
2277
def _test_get_server_certificate_fail(test, host, port):
    """Check that fetching ``host:port``'s certificate with CERTFILE as the
    CA bundle raises ssl.SSLError; fail *test* if it succeeds instead."""
    try:
        pem = ssl.get_server_certificate((host, port), ca_certs=CERTFILE)
    except ssl.SSLError as exc:
        # The expected outcome: report it when verbose and stop here.
        if support.verbose:
            sys.stdout.write("%s\n" % exc)
        return
    test.fail("Got server certificate %s for %s:%s!" % (pem, host, port))
2287
2288
2289from test.ssl_servers import make_https_server
2290
class ThreadedEchoServer(threading.Thread):
    """Background echo server used by the threaded SSL tests.

    The server listens on a port bound with ``socket_helper.bind_port`` and,
    for every accepted connection, runs a ConnectionHandler thread that
    sends received data back lowercased.  The accept loop joins each handler
    before accepting the next connection, so connections are served one at
    a time.

    Use it as a context manager: ``__enter__`` starts the thread and waits
    until the socket is listening; ``__exit__`` stops and joins it.
    """

    class ConnectionHandler(threading.Thread):

        """A mildly complicated class, because we want it to work both
        with and without the SSL wrapper around the socket connection, so
        that we can test the STARTTLS functionality.

        Besides echoing, the handler understands a few line commands:
        ``over`` (close), ``STARTTLS``/``ENDTLS`` (starttls servers only),
        ``CB tls-unique``, ``PHA``, ``HASCERT`` and ``GETCERT``.  An empty
        read is treated as EOF and ends the handler.
        """

        def __init__(self, server, connsock, addr):
            self.server = server
            self.running = False
            self.sock = connsock
            self.addr = addr
            self.sock.setblocking(True)
            # Server-side SSL socket once wrap_conn() succeeds; None while
            # the connection is plaintext.
            self.sslconn = None
            threading.Thread.__init__(self)
            self.daemon = True

        def wrap_conn(self):
            """Wrap self.sock in a server-side SSL socket.

            On success, records the negotiated NPN/ALPN protocols and shared
            ciphers on the server and returns True.  On failure, stores the
            error text in ``server.conn_errors``, closes the connection and
            returns False.  SSLError/OSError failures also stop the server.
            """
            try:
                self.sslconn = self.server.context.wrap_socket(
                    self.sock, server_side=True)
                self.server.selected_npn_protocols.append(self.sslconn.selected_npn_protocol())
                self.server.selected_alpn_protocols.append(self.sslconn.selected_alpn_protocol())
            except (ConnectionResetError, BrokenPipeError, ConnectionAbortedError) as e:
                # We treat ConnectionResetError as though it were an
                # SSLError - OpenSSL on Ubuntu abruptly closes the
                # connection when asked to use an unsupported protocol.
                #
                # BrokenPipeError is raised in TLS 1.3 mode, when OpenSSL
                # tries to send session tickets after handshake.
                # https://github.com/openssl/openssl/issues/6342
                #
                # ConnectionAbortedError is raised in TLS 1.3 mode, when OpenSSL
                # tries to send session tickets after handshake when using WinSock.
                self.server.conn_errors.append(str(e))
                if self.server.chatty:
                    handle_error("\n server:  bad connection attempt from " + repr(self.addr) + ":\n")
                self.running = False
                self.close()
                return False
            except (ssl.SSLError, OSError) as e:
                # OSError may occur with wrong protocols, e.g. both
                # sides use PROTOCOL_TLS_SERVER.
                #
                # XXX Various errors can have happened here, for example
                # a mismatching protocol version, an invalid certificate,
                # or a low-level bug. This should be made more discriminating.
                #
                # bpo-31323: Store the exception as string to prevent
                # a reference leak: server -> conn_errors -> exception
                # -> traceback -> self (ConnectionHandler) -> server
                self.server.conn_errors.append(str(e))
                if self.server.chatty:
                    handle_error("\n server:  bad connection attempt from " + repr(self.addr) + ":\n")
                self.running = False
                self.server.stop()
                self.close()
                return False
            else:
                self.server.shared_ciphers.append(self.sslconn.shared_ciphers())
                if self.server.context.verify_mode == ssl.CERT_REQUIRED:
                    cert = self.sslconn.getpeercert()
                    if support.verbose and self.server.chatty:
                        sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n")
                    cert_binary = self.sslconn.getpeercert(True)
                    if support.verbose and self.server.chatty:
                        sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n")
                cipher = self.sslconn.cipher()
                if support.verbose and self.server.chatty:
                    sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n")
                    sys.stdout.write(" server: selected protocol is now "
                            + str(self.sslconn.selected_npn_protocol()) + "\n")
                return True

        def read(self):
            """Read from the SSL layer if wrapped, else up to 1024 raw bytes."""
            if self.sslconn:
                return self.sslconn.read()
            else:
                return self.sock.recv(1024)

        def write(self, bytes):
            """Write through the SSL layer if wrapped, else to the raw socket."""
            if self.sslconn:
                return self.sslconn.write(bytes)
            else:
                return self.sock.send(bytes)

        def close(self):
            """Close the SSL socket if wrapped, otherwise the raw socket."""
            if self.sslconn:
                self.sslconn.close()
            else:
                self.sock.close()

        def run(self):
            """Serve line commands and echo everything else until EOF/close."""
            self.running = True
            if not self.server.starttls_server:
                if not self.wrap_conn():
                    return
            while self.running:
                try:
                    msg = self.read()
                    stripped = msg.strip()
                    if not stripped:
                        # eof, so quit this handler
                        self.running = False
                        # On a STARTTLS server the connection may still (or
                        # again, after ENDTLS) be plaintext here; there is no
                        # SSL layer to shut down in that case.
                        if self.sslconn is not None:
                            try:
                                self.sock = self.sslconn.unwrap()
                            except OSError:
                                # Many tests shut the TCP connection down
                                # without an SSL shutdown. This causes
                                # unwrap() to raise OSError with errno=0!
                                pass
                            else:
                                self.sslconn = None
                        self.close()
                    elif stripped == b'over':
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: client closed connection\n")
                        self.close()
                        return
                    elif (self.server.starttls_server and
                          stripped == b'STARTTLS'):
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: read STARTTLS from client, sending OK...\n")
                        self.write(b"OK\n")
                        if not self.wrap_conn():
                            return
                    elif (self.server.starttls_server and self.sslconn
                          and stripped == b'ENDTLS'):
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: read ENDTLS from client, sending OK...\n")
                        self.write(b"OK\n")
                        self.sock = self.sslconn.unwrap()
                        self.sslconn = None
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: connection is now unencrypted...\n")
                    elif stripped == b'CB tls-unique':
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: read CB tls-unique from client, sending our CB data...\n")
                        data = self.sslconn.get_channel_binding("tls-unique")
                        self.write(repr(data).encode("us-ascii") + b"\n")
                    elif stripped == b'PHA':
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: initiating post handshake auth\n")
                        try:
                            self.sslconn.verify_client_post_handshake()
                        except ssl.SSLError as e:
                            self.write(repr(e).encode("us-ascii") + b"\n")
                        else:
                            self.write(b"OK\n")
                    elif stripped == b'HASCERT':
                        if self.sslconn.getpeercert() is not None:
                            self.write(b'TRUE\n')
                        else:
                            self.write(b'FALSE\n')
                    elif stripped == b'GETCERT':
                        cert = self.sslconn.getpeercert()
                        self.write(repr(cert).encode("us-ascii") + b"\n")
                    else:
                        if (support.verbose and
                            self.server.connectionchatty):
                            ctype = (self.sslconn and "encrypted") or "unencrypted"
                            sys.stdout.write(" server: read %r (%s), sending back %r (%s)...\n"
                                             % (msg, ctype, msg.lower(), ctype))
                        self.write(msg.lower())
                except (ConnectionResetError, ConnectionAbortedError):
                    # XXX: OpenSSL 1.1.1 sometimes raises ConnectionResetError
                    # when connection is not shut down gracefully.
                    if self.server.chatty and support.verbose:
                        sys.stdout.write(
                            " Connection reset by peer: {}\n".format(
                                self.addr)
                        )
                    self.close()
                    self.running = False
                except ssl.SSLError as err:
                    # On Windows sometimes test_pha_required_nocert receives the
                    # PEER_DID_NOT_RETURN_A_CERTIFICATE exception
                    # before the 'tlsv13 alert certificate required' exception.
                    # If the server is stopped when PEER_DID_NOT_RETURN_A_CERTIFICATE
                    # is received test_pha_required_nocert fails with ConnectionResetError
                    # because the underlying socket is closed
                    #
                    # NOTE(review): any other SSLError is swallowed here and
                    # the read loop keeps running.
                    if 'PEER_DID_NOT_RETURN_A_CERTIFICATE' == err.reason:
                        if self.server.chatty and support.verbose:
                            sys.stdout.write(err.args[1])
                        # test_pha_required_nocert is expecting this exception
                        raise ssl.SSLError('tlsv13 alert certificate required')
                except OSError:
                    if self.server.chatty:
                        handle_error("Test server failure:\n")
                    self.close()
                    self.running = False

                    # normally, we'd just stop here, but for the test
                    # harness, we want to stop the server
                    self.server.stop()

    def __init__(self, certificate=None, ssl_version=None,
                 certreqs=None, cacerts=None,
                 chatty=True, connectionchatty=False, starttls_server=False,
                 npn_protocols=None, alpn_protocols=None,
                 ciphers=None, context=None):
        """Create the server and bind its listening socket.

        If *context* is given it is used as-is and the certificate/protocol
        arguments are ignored.  Otherwise a context is built from
        *ssl_version* (default PROTOCOL_TLS_SERVER), *certreqs* (default
        CERT_NONE), *cacerts*, *certificate*, *npn_protocols*,
        *alpn_protocols* and *ciphers*.
        """
        if context:
            self.context = context
        else:
            self.context = ssl.SSLContext(ssl_version
                                          if ssl_version is not None
                                          else ssl.PROTOCOL_TLS_SERVER)
            self.context.verify_mode = (certreqs if certreqs is not None
                                        else ssl.CERT_NONE)
            if cacerts:
                self.context.load_verify_locations(cacerts)
            if certificate:
                self.context.load_cert_chain(certificate)
            if npn_protocols:
                self.context.set_npn_protocols(npn_protocols)
            if alpn_protocols:
                self.context.set_alpn_protocols(alpn_protocols)
            if ciphers:
                self.context.set_ciphers(ciphers)
        self.chatty = chatty
        self.connectionchatty = connectionchatty
        self.starttls_server = starttls_server
        self.sock = socket.socket()
        self.port = socket_helper.bind_port(self.sock)
        self.flag = None
        self.active = False
        # Per-connection results recorded by the handlers for the tests.
        self.selected_npn_protocols = []
        self.selected_alpn_protocols = []
        self.shared_ciphers = []
        self.conn_errors = []
        threading.Thread.__init__(self)
        self.daemon = True

    def __enter__(self):
        self.start(threading.Event())
        self.flag.wait()
        return self

    def __exit__(self, *args):
        self.stop()
        self.join()

    def start(self, flag=None):
        """Start the thread; *flag* (an Event) is set once listening."""
        self.flag = flag
        threading.Thread.start(self)

    def run(self):
        """Accept connections until stop() clears ``self.active``."""
        # Short accept timeout so the loop notices stop() promptly.
        self.sock.settimeout(0.05)
        self.sock.listen()
        self.active = True
        if self.flag:
            # signal an event
            self.flag.set()
        while self.active:
            try:
                newconn, connaddr = self.sock.accept()
                if support.verbose and self.chatty:
                    sys.stdout.write(' server:  new connection from '
                                     + repr(connaddr) + '\n')
                handler = self.ConnectionHandler(self, newconn, connaddr)
                handler.start()
                handler.join()
            except socket.timeout:
                pass
            except KeyboardInterrupt:
                self.stop()
            except BaseException as e:
                if support.verbose and self.chatty:
                    sys.stdout.write(
                        ' connection handling failed: ' + repr(e) + '\n')

        self.sock.close()

    def stop(self):
        """Ask the accept loop in run() to exit."""
        self.active = False
2567
class AsyncoreEchoServer(threading.Thread):
    """Echo server built on asyncore, run in a background thread.

    Each accepted connection is wrapped in a server-side SSL socket using
    *certfile*, and data received on it is sent back lowercased.  Use it as
    a context manager: ``__enter__`` starts the thread and waits until
    run() has started; ``__exit__`` stops the server, joins the thread and
    closes every channel still registered with asyncore.
    """

    # this one's based on asyncore.dispatcher

    class EchoServer (asyncore.dispatcher):
        """Listening dispatcher that creates a ConnectionHandler per client."""

        class ConnectionHandler(asyncore.dispatcher_with_send):
            """Per-connection dispatcher: non-blocking SSL handshake, then echo."""

            def __init__(self, conn, certfile):
                # Wrap without handshaking immediately; the handshake is
                # driven step by step from the asyncore callbacks instead.
                self.socket = test_wrap_socket(conn, server_side=True,
                                              certfile=certfile,
                                              do_handshake_on_connect=False)
                asyncore.dispatcher_with_send.__init__(self, self.socket)
                # True until do_handshake() has completed.
                self._ssl_accepting = True
                self._do_ssl_handshake()

            def readable(self):
                # Data already decrypted and buffered inside the SSL object
                # (pending() > 0) is not visible to select() on the raw
                # socket, so consume it here before asyncore waits again.
                if isinstance(self.socket, ssl.SSLSocket):
                    while self.socket.pending() > 0:
                        self.handle_read_event()
                return True

            def _do_ssl_handshake(self):
                """Advance the server-side handshake by one step.

                Want-read/want-write just returns (handle_read() calls this
                again while ``_ssl_accepting`` is set); an SSL EOF or
                ECONNABORTED closes the channel; any other SSLError
                propagates.  On success ``_ssl_accepting`` is cleared.
                """
                try:
                    self.socket.do_handshake()
                except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
                    return
                except ssl.SSLEOFError:
                    return self.handle_close()
                except ssl.SSLError:
                    raise
                except OSError as err:
                    # NOTE(review): OSErrors other than ECONNABORTED are
                    # silently ignored and _ssl_accepting stays True.
                    if err.args[0] == errno.ECONNABORTED:
                        return self.handle_close()
                else:
                    self._ssl_accepting = False

            def handle_read(self):
                # Finish the handshake first; afterwards echo up to 1024
                # bytes lowercased, closing on an empty read.
                if self._ssl_accepting:
                    self._do_ssl_handshake()
                else:
                    data = self.recv(1024)
                    if support.verbose:
                        sys.stdout.write(" server:  read %s from client\n" % repr(data))
                    if not data:
                        self.close()
                    else:
                        self.send(data.lower())

            def handle_close(self):
                self.close()
                if support.verbose:
                    sys.stdout.write(" server:  closed connection %s\n" % self.socket)

            def handle_error(self):
                # Re-raise the exception currently being handled instead of
                # letting asyncore log and swallow it (assumes asyncore calls
                # this from inside its except blocks).
                raise

        def __init__(self, certfile):
            self.certfile = certfile
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            # Bind (host '') and remember the port chosen by bind_port().
            self.port = socket_helper.bind_port(sock, '')
            asyncore.dispatcher.__init__(self, sock)
            self.listen(5)

        def handle_accepted(self, sock_obj, addr):
            # The handler is not stored here; __exit__ later removes it
            # from asyncore's socket_map via close_all().
            if support.verbose:
                sys.stdout.write(" server:  new connection from %s:%s\n" %addr)
            self.ConnectionHandler(sock_obj, self.certfile)

        def handle_error(self):
            # Re-raise the active exception rather than swallowing it.
            raise

    def __init__(self, certfile):
        self.flag = None
        self.active = False
        self.server = self.EchoServer(certfile)
        self.port = self.server.port
        threading.Thread.__init__(self)
        self.daemon = True

    def __str__(self):
        return "<%s %s>" % (self.__class__.__name__, self.server)

    def __enter__(self):
        self.start(threading.Event())
        self.flag.wait()
        return self

    def __exit__(self, *args):
        if support.verbose:
            sys.stdout.write(" cleanup: stopping server.\n")
        self.stop()
        if support.verbose:
            sys.stdout.write(" cleanup: joining server thread.\n")
        self.join()
        if support.verbose:
            sys.stdout.write(" cleanup: successfully joined.\n")
        # make sure that ConnectionHandler is removed from socket_map
        asyncore.close_all(ignore_all=True)

    def start (self, flag=None):
        # *flag*, if given, is an Event that run() sets once it starts.
        self.flag = flag
        threading.Thread.start(self)

    def run(self):
        # Run asyncore.loop(1) repeatedly until stop() clears self.active.
        self.active = True
        if self.flag:
            self.flag.set()
        while self.active:
            try:
                asyncore.loop(1)
            except:
                # NOTE(review): bare except swallows every error raised by
                # the handlers (including their re-raised handle_error).
                pass

    def stop(self):
        # Ask run() to exit and close the listening dispatcher.
        self.active = False
        self.server.close()
2685
def server_params_test(client_context, server_context, indata=b"FOO\n",
                       chatty=True, connectionchatty=False, sni_name=None,
                       session=None):
    """
    Launch a server, connect a client to it and try various reads
    and writes.

    *indata* is sent three times (as bytes, bytearray and memoryview) and
    each reply must equal ``indata.lower()``, otherwise AssertionError is
    raised.  Returns a dict with the client-side connection properties
    (compression, cipher, peer cert, ALPN/NPN protocol, version, session
    reuse and session) plus the server's recorded ALPN/NPN selections and
    shared ciphers.
    """
    stats = {}
    # The server is always created quiet; *connectionchatty* only controls
    # client-side logging below.
    server = ThreadedEchoServer(context=server_context,
                                chatty=chatty,
                                connectionchatty=False)
    with server:
        with client_context.wrap_socket(socket.socket(),
                server_hostname=sni_name, session=session) as s:
            s.connect((HOST, server.port))
            log = connectionchatty and support.verbose
            expected = indata.lower()
            for payload in (indata, bytearray(indata), memoryview(indata)):
                if log:
                    sys.stdout.write(
                        " client:  sending %r...\n" % indata)
                s.write(payload)
                outdata = s.read()
                if log:
                    sys.stdout.write(" client:  read %r\n" % outdata)
                if outdata != expected:
                    raise AssertionError(
                        "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
                        % (outdata[:20], len(outdata),
                           indata[:20].lower(), len(indata)))
            s.write(b"over\n")
            if log:
                sys.stdout.write(" client:  closing connection.\n")
            stats['compression'] = s.compression()
            stats['cipher'] = s.cipher()
            stats['peercert'] = s.getpeercert()
            stats['client_alpn_protocol'] = s.selected_alpn_protocol()
            stats['client_npn_protocol'] = s.selected_npn_protocol()
            stats['version'] = s.version()
            stats['session_reused'] = s.session_reused
            stats['session'] = s.session
            s.close()
        stats['server_alpn_protocols'] = server.selected_alpn_protocols
        stats['server_npn_protocols'] = server.selected_npn_protocols
        stats['server_shared_ciphers'] = server.shared_ciphers
    return stats
2735
def try_protocol_combo(server_protocol, client_protocol, expect_success,
                       certsreqs=None, server_options=0, client_options=0):
    """
    Try to SSL-connect using *client_protocol* to *server_protocol*.

    If *expect_success* is true, assert that the connection succeeds; if it
    is false, assert that it fails (an SSLError or a connection reset counts
    as failure).  If *expect_success* is a string, additionally assert that
    it equals the protocol version actually negotiated.  *certsreqs*
    (default CERT_NONE) is applied as verify_mode on both contexts.
    """
    cert_names = {
        ssl.CERT_NONE: "CERT_NONE",
        ssl.CERT_OPTIONAL: "CERT_OPTIONAL",
        ssl.CERT_REQUIRED: "CERT_REQUIRED",
    }
    if certsreqs is None:
        certsreqs = ssl.CERT_NONE
    certtype = cert_names[certsreqs]
    if support.verbose:
        if expect_success:
            formatstr = " %s->%s %s\n"
        else:
            formatstr = " {%s->%s} %s\n"
        sys.stdout.write(formatstr %
                         (ssl.get_protocol_name(client_protocol),
                          ssl.get_protocol_name(server_protocol),
                          certtype))
    client_context = ssl.SSLContext(client_protocol)
    client_context.options |= client_options
    server_context = ssl.SSLContext(server_protocol)
    server_context.options |= server_options

    min_version = PROTOCOL_TO_TLS_VERSION.get(client_protocol)
    # SSLContext.minimum_version is only available on recent OpenSSL
    # (setter added in OpenSSL 1.1.0, getter added in OpenSSL 1.1.1)
    can_lower_minimum = (min_version is not None
                         and hasattr(server_context, 'minimum_version')
                         and server_protocol == ssl.PROTOCOL_TLS)
    if can_lower_minimum and server_context.minimum_version > min_version:
        # If OpenSSL configuration is strict and requires more recent TLS
        # version, we have to change the minimum to test old TLS versions.
        server_context.minimum_version = min_version

    # NOTE: we must enable "ALL" ciphers on the client, otherwise an
    # SSLv23 client will send an SSLv3 hello (rather than SSLv2)
    # starting from OpenSSL 1.0.0 (see issue #8322).
    if client_context.protocol == ssl.PROTOCOL_TLS:
        client_context.set_ciphers("ALL")

    for ctx in (client_context, server_context):
        ctx.verify_mode = certsreqs
        ctx.load_cert_chain(SIGNED_CERTFILE)
        ctx.load_verify_locations(SIGNING_CA)

    try:
        stats = server_params_test(client_context, server_context,
                                   chatty=False, connectionchatty=False)
    # Protocol mismatch can result in either an SSLError, or a
    # "Connection reset by peer" error.  SSLError is an OSError subclass,
    # so it must be handled first.
    except ssl.SSLError:
        if expect_success:
            raise
    except OSError as e:
        if expect_success or e.errno != errno.ECONNRESET:
            raise
    else:
        if not expect_success:
            raise AssertionError(
                "Client protocol %s succeeded with server protocol %s!"
                % (ssl.get_protocol_name(client_protocol),
                   ssl.get_protocol_name(server_protocol)))
        if (expect_success is not True
                and expect_success != stats['version']):
            raise AssertionError("version mismatch: expected %r, got %r"
                                 % (expect_success, stats['version']))
2805
2806
2807class ThreadedTests(unittest.TestCase):
2808
2809    def test_echo(self):
2810        """Basic test of an SSL client connecting to a server"""
2811        if support.verbose:
2812            sys.stdout.write("\n")
2813        for protocol in PROTOCOLS:
2814            if protocol in {ssl.PROTOCOL_TLS_CLIENT, ssl.PROTOCOL_TLS_SERVER}:
2815                continue
2816            if not has_tls_protocol(protocol):
2817                continue
2818            with self.subTest(protocol=ssl._PROTOCOL_NAMES[protocol]):
2819                context = ssl.SSLContext(protocol)
2820                context.load_cert_chain(CERTFILE)
2821                server_params_test(context, context,
2822                                   chatty=True, connectionchatty=True)
2823
2824        client_context, server_context, hostname = testing_context()
2825
2826        with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_SERVER):
2827            server_params_test(client_context=client_context,
2828                               server_context=server_context,
2829                               chatty=True, connectionchatty=True,
2830                               sni_name=hostname)
2831
2832        client_context.check_hostname = False
2833        with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_CLIENT):
2834            with self.assertRaises(ssl.SSLError) as e:
2835                server_params_test(client_context=server_context,
2836                                   server_context=client_context,
2837                                   chatty=True, connectionchatty=True,
2838                                   sni_name=hostname)
2839            self.assertIn('called a function you should not call',
2840                          str(e.exception))
2841
2842        with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_SERVER):
2843            with self.assertRaises(ssl.SSLError) as e:
2844                server_params_test(client_context=server_context,
2845                                   server_context=server_context,
2846                                   chatty=True, connectionchatty=True)
2847            self.assertIn('called a function you should not call',
2848                          str(e.exception))
2849
2850        with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_CLIENT):
2851            with self.assertRaises(ssl.SSLError) as e:
2852                server_params_test(client_context=server_context,
2853                                   server_context=client_context,
2854                                   chatty=True, connectionchatty=True)
2855            self.assertIn('called a function you should not call',
2856                          str(e.exception))
2857
2858    def test_getpeercert(self):
2859        if support.verbose:
2860            sys.stdout.write("\n")
2861
2862        client_context, server_context, hostname = testing_context()
2863        server = ThreadedEchoServer(context=server_context, chatty=False)
2864        with server:
2865            with client_context.wrap_socket(socket.socket(),
2866                                            do_handshake_on_connect=False,
2867                                            server_hostname=hostname) as s:
2868                s.connect((HOST, server.port))
2869                # getpeercert() raise ValueError while the handshake isn't
2870                # done.
2871                with self.assertRaises(ValueError):
2872                    s.getpeercert()
2873                s.do_handshake()
2874                cert = s.getpeercert()
2875                self.assertTrue(cert, "Can't get peer certificate.")
2876                cipher = s.cipher()
2877                if support.verbose:
2878                    sys.stdout.write(pprint.pformat(cert) + '\n')
2879                    sys.stdout.write("Connection cipher is " + str(cipher) + '.\n')
2880                if 'subject' not in cert:
2881                    self.fail("No subject field in certificate: %s." %
2882                              pprint.pformat(cert))
2883                if ((('organizationName', 'Python Software Foundation'),)
2884                    not in cert['subject']):
2885                    self.fail(
2886                        "Missing or invalid 'organizationName' field in certificate subject; "
2887                        "should be 'Python Software Foundation'.")
2888                self.assertIn('notBefore', cert)
2889                self.assertIn('notAfter', cert)
2890                before = ssl.cert_time_to_seconds(cert['notBefore'])
2891                after = ssl.cert_time_to_seconds(cert['notAfter'])
2892                self.assertLess(before, after)
2893
2894    @unittest.skipUnless(have_verify_flags(),
2895                        "verify_flags need OpenSSL > 0.9.8")
2896    def test_crl_check(self):
2897        if support.verbose:
2898            sys.stdout.write("\n")
2899
2900        client_context, server_context, hostname = testing_context()
2901
2902        tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0)
2903        self.assertEqual(client_context.verify_flags, ssl.VERIFY_DEFAULT | tf)
2904
2905        # VERIFY_DEFAULT should pass
2906        server = ThreadedEchoServer(context=server_context, chatty=True)
2907        with server:
2908            with client_context.wrap_socket(socket.socket(),
2909                                            server_hostname=hostname) as s:
2910                s.connect((HOST, server.port))
2911                cert = s.getpeercert()
2912                self.assertTrue(cert, "Can't get peer certificate.")
2913
2914        # VERIFY_CRL_CHECK_LEAF without a loaded CRL file fails
2915        client_context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF
2916
2917        server = ThreadedEchoServer(context=server_context, chatty=True)
2918        with server:
2919            with client_context.wrap_socket(socket.socket(),
2920                                            server_hostname=hostname) as s:
2921                with self.assertRaisesRegex(ssl.SSLError,
2922                                            "certificate verify failed"):
2923                    s.connect((HOST, server.port))
2924
2925        # now load a CRL file. The CRL file is signed by the CA.
2926        client_context.load_verify_locations(CRLFILE)
2927
2928        server = ThreadedEchoServer(context=server_context, chatty=True)
2929        with server:
2930            with client_context.wrap_socket(socket.socket(),
2931                                            server_hostname=hostname) as s:
2932                s.connect((HOST, server.port))
2933                cert = s.getpeercert()
2934                self.assertTrue(cert, "Can't get peer certificate.")
2935
2936    def test_check_hostname(self):
2937        if support.verbose:
2938            sys.stdout.write("\n")
2939
2940        client_context, server_context, hostname = testing_context()
2941
2942        # correct hostname should verify
2943        server = ThreadedEchoServer(context=server_context, chatty=True)
2944        with server:
2945            with client_context.wrap_socket(socket.socket(),
2946                                            server_hostname=hostname) as s:
2947                s.connect((HOST, server.port))
2948                cert = s.getpeercert()
2949                self.assertTrue(cert, "Can't get peer certificate.")
2950
2951        # incorrect hostname should raise an exception
2952        server = ThreadedEchoServer(context=server_context, chatty=True)
2953        with server:
2954            with client_context.wrap_socket(socket.socket(),
2955                                            server_hostname="invalid") as s:
2956                with self.assertRaisesRegex(
2957                        ssl.CertificateError,
2958                        "Hostname mismatch, certificate is not valid for 'invalid'."):
2959                    s.connect((HOST, server.port))
2960
2961        # missing server_hostname arg should cause an exception, too
2962        server = ThreadedEchoServer(context=server_context, chatty=True)
2963        with server:
2964            with socket.socket() as s:
2965                with self.assertRaisesRegex(ValueError,
2966                                            "check_hostname requires server_hostname"):
2967                    client_context.wrap_socket(s)
2968
2969    def test_ecc_cert(self):
2970        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2971        client_context.load_verify_locations(SIGNING_CA)
2972        client_context.set_ciphers('ECDHE:ECDSA:!NULL:!aRSA')
2973        hostname = SIGNED_CERTFILE_ECC_HOSTNAME
2974
2975        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
2976        # load ECC cert
2977        server_context.load_cert_chain(SIGNED_CERTFILE_ECC)
2978
2979        # correct hostname should verify
2980        server = ThreadedEchoServer(context=server_context, chatty=True)
2981        with server:
2982            with client_context.wrap_socket(socket.socket(),
2983                                            server_hostname=hostname) as s:
2984                s.connect((HOST, server.port))
2985                cert = s.getpeercert()
2986                self.assertTrue(cert, "Can't get peer certificate.")
2987                cipher = s.cipher()[0].split('-')
2988                self.assertTrue(cipher[:2], ('ECDHE', 'ECDSA'))
2989
2990    def test_dual_rsa_ecc(self):
2991        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2992        client_context.load_verify_locations(SIGNING_CA)
2993        # TODO: fix TLSv1.3 once SSLContext can restrict signature
2994        #       algorithms.
2995        client_context.options |= ssl.OP_NO_TLSv1_3
2996        # only ECDSA certs
2997        client_context.set_ciphers('ECDHE:ECDSA:!NULL:!aRSA')
2998        hostname = SIGNED_CERTFILE_ECC_HOSTNAME
2999
3000        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
3001        # load ECC and RSA key/cert pairs
3002        server_context.load_cert_chain(SIGNED_CERTFILE_ECC)
3003        server_context.load_cert_chain(SIGNED_CERTFILE)
3004
3005        # correct hostname should verify
3006        server = ThreadedEchoServer(context=server_context, chatty=True)
3007        with server:
3008            with client_context.wrap_socket(socket.socket(),
3009                                            server_hostname=hostname) as s:
3010                s.connect((HOST, server.port))
3011                cert = s.getpeercert()
3012                self.assertTrue(cert, "Can't get peer certificate.")
3013                cipher = s.cipher()[0].split('-')
3014                self.assertTrue(cipher[:2], ('ECDHE', 'ECDSA'))
3015
3016    def test_check_hostname_idn(self):
3017        if support.verbose:
3018            sys.stdout.write("\n")
3019
3020        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
3021        server_context.load_cert_chain(IDNSANSFILE)
3022
3023        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3024        context.verify_mode = ssl.CERT_REQUIRED
3025        context.check_hostname = True
3026        context.load_verify_locations(SIGNING_CA)
3027
3028        # correct hostname should verify, when specified in several
3029        # different ways
3030        idn_hostnames = [
3031            ('könig.idn.pythontest.net',
3032             'xn--knig-5qa.idn.pythontest.net'),
3033            ('xn--knig-5qa.idn.pythontest.net',
3034             'xn--knig-5qa.idn.pythontest.net'),
3035            (b'xn--knig-5qa.idn.pythontest.net',
3036             'xn--knig-5qa.idn.pythontest.net'),
3037
3038            ('königsgäßchen.idna2003.pythontest.net',
3039             'xn--knigsgsschen-lcb0w.idna2003.pythontest.net'),
3040            ('xn--knigsgsschen-lcb0w.idna2003.pythontest.net',
3041             'xn--knigsgsschen-lcb0w.idna2003.pythontest.net'),
3042            (b'xn--knigsgsschen-lcb0w.idna2003.pythontest.net',
3043             'xn--knigsgsschen-lcb0w.idna2003.pythontest.net'),
3044
3045            # ('königsgäßchen.idna2008.pythontest.net',
3046            #  'xn--knigsgchen-b4a3dun.idna2008.pythontest.net'),
3047            ('xn--knigsgchen-b4a3dun.idna2008.pythontest.net',
3048             'xn--knigsgchen-b4a3dun.idna2008.pythontest.net'),
3049            (b'xn--knigsgchen-b4a3dun.idna2008.pythontest.net',
3050             'xn--knigsgchen-b4a3dun.idna2008.pythontest.net'),
3051
3052        ]
3053        for server_hostname, expected_hostname in idn_hostnames:
3054            server = ThreadedEchoServer(context=server_context, chatty=True)
3055            with server:
3056                with context.wrap_socket(socket.socket(),
3057                                         server_hostname=server_hostname) as s:
3058                    self.assertEqual(s.server_hostname, expected_hostname)
3059                    s.connect((HOST, server.port))
3060                    cert = s.getpeercert()
3061                    self.assertEqual(s.server_hostname, expected_hostname)
3062                    self.assertTrue(cert, "Can't get peer certificate.")
3063
3064        # incorrect hostname should raise an exception
3065        server = ThreadedEchoServer(context=server_context, chatty=True)
3066        with server:
3067            with context.wrap_socket(socket.socket(),
3068                                     server_hostname="python.example.org") as s:
3069                with self.assertRaises(ssl.CertificateError):
3070                    s.connect((HOST, server.port))
3071
3072    def test_wrong_cert_tls12(self):
3073        """Connecting when the server rejects the client's certificate
3074
3075        Launch a server with CERT_REQUIRED, and check that trying to
3076        connect to it with a wrong client certificate fails.
3077        """
3078        client_context, server_context, hostname = testing_context()
3079        # load client cert that is not signed by trusted CA
3080        client_context.load_cert_chain(CERTFILE)
3081        # require TLS client authentication
3082        server_context.verify_mode = ssl.CERT_REQUIRED
3083        # TLS 1.3 has different handshake
3084        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
3085
3086        server = ThreadedEchoServer(
3087            context=server_context, chatty=True, connectionchatty=True,
3088        )
3089
3090        with server, \
3091                client_context.wrap_socket(socket.socket(),
3092                                           server_hostname=hostname) as s:
3093            try:
3094                # Expect either an SSL error about the server rejecting
3095                # the connection, or a low-level connection reset (which
3096                # sometimes happens on Windows)
3097                s.connect((HOST, server.port))
3098            except ssl.SSLError as e:
3099                if support.verbose:
3100                    sys.stdout.write("\nSSLError is %r\n" % e)
3101            except OSError as e:
3102                if e.errno != errno.ECONNRESET:
3103                    raise
3104                if support.verbose:
3105                    sys.stdout.write("\nsocket.error is %r\n" % e)
3106            else:
3107                self.fail("Use of invalid cert should have failed!")
3108
3109    @requires_tls_version('TLSv1_3')
3110    def test_wrong_cert_tls13(self):
3111        client_context, server_context, hostname = testing_context()
3112        # load client cert that is not signed by trusted CA
3113        client_context.load_cert_chain(CERTFILE)
3114        server_context.verify_mode = ssl.CERT_REQUIRED
3115        server_context.minimum_version = ssl.TLSVersion.TLSv1_3
3116        client_context.minimum_version = ssl.TLSVersion.TLSv1_3
3117
3118        server = ThreadedEchoServer(
3119            context=server_context, chatty=True, connectionchatty=True,
3120        )
3121        with server, \
3122             client_context.wrap_socket(socket.socket(),
3123                                        server_hostname=hostname) as s:
3124            # TLS 1.3 perform client cert exchange after handshake
3125            s.connect((HOST, server.port))
3126            try:
3127                s.write(b'data')
3128                s.read(4)
3129            except ssl.SSLError as e:
3130                if support.verbose:
3131                    sys.stdout.write("\nSSLError is %r\n" % e)
3132            except OSError as e:
3133                if e.errno != errno.ECONNRESET:
3134                    raise
3135                if support.verbose:
3136                    sys.stdout.write("\nsocket.error is %r\n" % e)
3137            else:
3138                self.fail("Use of invalid cert should have failed!")
3139
3140    def test_rude_shutdown(self):
3141        """A brutal shutdown of an SSL server should raise an OSError
3142        in the client when attempting handshake.
3143        """
3144        listener_ready = threading.Event()
3145        listener_gone = threading.Event()
3146
3147        s = socket.socket()
3148        port = socket_helper.bind_port(s, HOST)
3149
3150        # `listener` runs in a thread.  It sits in an accept() until
3151        # the main thread connects.  Then it rudely closes the socket,
3152        # and sets Event `listener_gone` to let the main thread know
3153        # the socket is gone.
3154        def listener():
3155            s.listen()
3156            listener_ready.set()
3157            newsock, addr = s.accept()
3158            newsock.close()
3159            s.close()
3160            listener_gone.set()
3161
3162        def connector():
3163            listener_ready.wait()
3164            with socket.socket() as c:
3165                c.connect((HOST, port))
3166                listener_gone.wait()
3167                try:
3168                    ssl_sock = test_wrap_socket(c)
3169                except OSError:
3170                    pass
3171                else:
3172                    self.fail('connecting to closed SSL socket should have failed')
3173
3174        t = threading.Thread(target=listener)
3175        t.start()
3176        try:
3177            connector()
3178        finally:
3179            t.join()
3180
3181    def test_ssl_cert_verify_error(self):
3182        if support.verbose:
3183            sys.stdout.write("\n")
3184
3185        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
3186        server_context.load_cert_chain(SIGNED_CERTFILE)
3187
3188        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3189
3190        server = ThreadedEchoServer(context=server_context, chatty=True)
3191        with server:
3192            with context.wrap_socket(socket.socket(),
3193                                     server_hostname=SIGNED_CERTFILE_HOSTNAME) as s:
3194                try:
3195                    s.connect((HOST, server.port))
3196                except ssl.SSLError as e:
3197                    msg = 'unable to get local issuer certificate'
3198                    self.assertIsInstance(e, ssl.SSLCertVerificationError)
3199                    self.assertEqual(e.verify_code, 20)
3200                    self.assertEqual(e.verify_message, msg)
3201                    self.assertIn(msg, repr(e))
3202                    self.assertIn('certificate verify failed', repr(e))
3203
3204    @requires_tls_version('SSLv2')
3205    def test_protocol_sslv2(self):
3206        """Connecting to an SSLv2 server with various client options"""
3207        if support.verbose:
3208            sys.stdout.write("\n")
3209        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True)
3210        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL)
3211        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED)
3212        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False)
3213        if has_tls_version('SSLv3'):
3214            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False)
3215        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False)
3216        # SSLv23 client with specific SSL options
3217        if no_sslv2_implies_sslv3_hello():
3218            # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs
3219            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False,
3220                               client_options=ssl.OP_NO_SSLv2)
3221        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False,
3222                           client_options=ssl.OP_NO_SSLv3)
3223        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False,
3224                           client_options=ssl.OP_NO_TLSv1)
3225
3226    def test_PROTOCOL_TLS(self):
3227        """Connecting to an SSLv23 server with various client options"""
3228        if support.verbose:
3229            sys.stdout.write("\n")
3230        if has_tls_version('SSLv2'):
3231            try:
3232                try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv2, True)
3233            except OSError as x:
3234                # this fails on some older versions of OpenSSL (0.9.7l, for instance)
3235                if support.verbose:
3236                    sys.stdout.write(
3237                        " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n"
3238                        % str(x))
3239        if has_tls_version('SSLv3'):
3240            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False)
3241        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True)
3242        if has_tls_version('TLSv1'):
3243            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, 'TLSv1')
3244
3245        if has_tls_version('SSLv3'):
3246            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False, ssl.CERT_OPTIONAL)
3247        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True, ssl.CERT_OPTIONAL)
3248        if has_tls_version('TLSv1'):
3249            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
3250
3251        if has_tls_version('SSLv3'):
3252            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False, ssl.CERT_REQUIRED)
3253        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True, ssl.CERT_REQUIRED)
3254        if has_tls_version('TLSv1'):
3255            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
3256
3257        # Server with specific SSL options
3258        if has_tls_version('SSLv3'):
3259            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False,
3260                           server_options=ssl.OP_NO_SSLv3)
3261        # Will choose TLSv1
3262        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True,
3263                           server_options=ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
3264        if has_tls_version('TLSv1'):
3265            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, False,
3266                               server_options=ssl.OP_NO_TLSv1)
3267
3268    @requires_tls_version('SSLv3')
3269    def test_protocol_sslv3(self):
3270        """Connecting to an SSLv3 server with various client options"""
3271        if support.verbose:
3272            sys.stdout.write("\n")
3273        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3')
3274        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL)
3275        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED)
3276        if has_tls_version('SSLv2'):
3277            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False)
3278        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLS, False,
3279                           client_options=ssl.OP_NO_SSLv3)
3280        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False)
3281        if no_sslv2_implies_sslv3_hello():
3282            # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs
3283            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLS,
3284                               False, client_options=ssl.OP_NO_SSLv2)
3285
3286    @requires_tls_version('TLSv1')
3287    def test_protocol_tlsv1(self):
3288        """Connecting to a TLSv1 server with various client options"""
3289        if support.verbose:
3290            sys.stdout.write("\n")
3291        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1')
3292        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
3293        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
3294        if has_tls_version('SSLv2'):
3295            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False)
3296        if has_tls_version('SSLv3'):
3297            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False)
3298        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLS, False,
3299                           client_options=ssl.OP_NO_TLSv1)
3300
3301    @requires_tls_version('TLSv1_1')
3302    def test_protocol_tlsv1_1(self):
3303        """Connecting to a TLSv1.1 server with various client options.
3304           Testing against older TLS versions."""
3305        if support.verbose:
3306            sys.stdout.write("\n")
3307        try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
3308        if has_tls_version('SSLv2'):
3309            try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False)
3310        if has_tls_version('SSLv3'):
3311            try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv3, False)
3312        try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLS, False,
3313                           client_options=ssl.OP_NO_TLSv1_1)
3314
3315        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
3316        try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False)
3317        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False)
3318
3319    @requires_tls_version('TLSv1_2')
3320    def test_protocol_tlsv1_2(self):
3321        """Connecting to a TLSv1.2 server with various client options.
3322           Testing against older TLS versions."""
3323        if support.verbose:
3324            sys.stdout.write("\n")
3325        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2',
3326                           server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,
3327                           client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,)
3328        if has_tls_version('SSLv2'):
3329            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv2, False)
3330        if has_tls_version('SSLv3'):
3331            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv3, False)
3332        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLS, False,
3333                           client_options=ssl.OP_NO_TLSv1_2)
3334
3335        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2')
3336        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False)
3337        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False)
3338        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False)
3339        try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False)
3340
3341    def test_starttls(self):
3342        """Switching from clear text to encrypted and back again."""
3343        msgs = (b"msg 1", b"MSG 2", b"STARTTLS", b"MSG 3", b"msg 4", b"ENDTLS", b"msg 5", b"msg 6")
3344
3345        server = ThreadedEchoServer(CERTFILE,
3346                                    starttls_server=True,
3347                                    chatty=True,
3348                                    connectionchatty=True)
3349        wrapped = False
3350        with server:
3351            s = socket.socket()
3352            s.setblocking(True)
3353            s.connect((HOST, server.port))
3354            if support.verbose:
3355                sys.stdout.write("\n")
3356            for indata in msgs:
3357                if support.verbose:
3358                    sys.stdout.write(
3359                        " client:  sending %r...\n" % indata)
3360                if wrapped:
3361                    conn.write(indata)
3362                    outdata = conn.read()
3363                else:
3364                    s.send(indata)
3365                    outdata = s.recv(1024)
3366                msg = outdata.strip().lower()
3367                if indata == b"STARTTLS" and msg.startswith(b"ok"):
3368                    # STARTTLS ok, switch to secure mode
3369                    if support.verbose:
3370                        sys.stdout.write(
3371                            " client:  read %r from server, starting TLS...\n"
3372                            % msg)
3373                    conn = test_wrap_socket(s)
3374                    wrapped = True
3375                elif indata == b"ENDTLS" and msg.startswith(b"ok"):
3376                    # ENDTLS ok, switch back to clear text
3377                    if support.verbose:
3378                        sys.stdout.write(
3379                            " client:  read %r from server, ending TLS...\n"
3380                            % msg)
3381                    s = conn.unwrap()
3382                    wrapped = False
3383                else:
3384                    if support.verbose:
3385                        sys.stdout.write(
3386                            " client:  read %r from server\n" % msg)
3387            if support.verbose:
3388                sys.stdout.write(" client:  closing connection.\n")
3389            if wrapped:
3390                conn.write(b"over\n")
3391            else:
3392                s.send(b"over\n")
3393            if wrapped:
3394                conn.close()
3395            else:
3396                s.close()
3397
3398    def test_socketserver(self):
3399        """Using socketserver to create and manage SSL connections."""
3400        server = make_https_server(self, certfile=SIGNED_CERTFILE)
3401        # try to connect
3402        if support.verbose:
3403            sys.stdout.write('\n')
3404        with open(CERTFILE, 'rb') as f:
3405            d1 = f.read()
3406        d2 = ''
3407        # now fetch the same data from the HTTPS server
3408        url = 'https://localhost:%d/%s' % (
3409            server.port, os.path.split(CERTFILE)[1])
3410        context = ssl.create_default_context(cafile=SIGNING_CA)
3411        f = urllib.request.urlopen(url, context=context)
3412        try:
3413            dlen = f.info().get("content-length")
3414            if dlen and (int(dlen) > 0):
3415                d2 = f.read(int(dlen))
3416                if support.verbose:
3417                    sys.stdout.write(
3418                        " client: read %d bytes from remote server '%s'\n"
3419                        % (len(d2), server))
3420        finally:
3421            f.close()
3422        self.assertEqual(d1, d2)
3423
3424    def test_asyncore_server(self):
3425        """Check the example asyncore integration."""
3426        if support.verbose:
3427            sys.stdout.write("\n")
3428
3429        indata = b"FOO\n"
3430        server = AsyncoreEchoServer(CERTFILE)
3431        with server:
3432            s = test_wrap_socket(socket.socket())
3433            s.connect(('127.0.0.1', server.port))
3434            if support.verbose:
3435                sys.stdout.write(
3436                    " client:  sending %r...\n" % indata)
3437            s.write(indata)
3438            outdata = s.read()
3439            if support.verbose:
3440                sys.stdout.write(" client:  read %r\n" % outdata)
3441            if outdata != indata.lower():
3442                self.fail(
3443                    "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
3444                    % (outdata[:20], len(outdata),
3445                       indata[:20].lower(), len(indata)))
3446            s.write(b"over\n")
3447            if support.verbose:
3448                sys.stdout.write(" client:  closing connection.\n")
3449            s.close()
3450            if support.verbose:
3451                sys.stdout.write(" client:  connection closed.\n")
3452
3453    def test_recv_send(self):
3454        """Test recv(), send() and friends."""
3455        if support.verbose:
3456            sys.stdout.write("\n")
3457
3458        server = ThreadedEchoServer(CERTFILE,
3459                                    certreqs=ssl.CERT_NONE,
3460                                    ssl_version=ssl.PROTOCOL_TLS_SERVER,
3461                                    cacerts=CERTFILE,
3462                                    chatty=True,
3463                                    connectionchatty=False)
3464        with server:
3465            s = test_wrap_socket(socket.socket(),
3466                                server_side=False,
3467                                certfile=CERTFILE,
3468                                ca_certs=CERTFILE,
3469                                cert_reqs=ssl.CERT_NONE,
3470                                ssl_version=ssl.PROTOCOL_TLS_CLIENT)
3471            s.connect((HOST, server.port))
3472            # helper methods for standardising recv* method signatures
3473            def _recv_into():
3474                b = bytearray(b"\0"*100)
3475                count = s.recv_into(b)
3476                return b[:count]
3477
3478            def _recvfrom_into():
3479                b = bytearray(b"\0"*100)
3480                count, addr = s.recvfrom_into(b)
3481                return b[:count]
3482
3483            # (name, method, expect success?, *args, return value func)
3484            send_methods = [
3485                ('send', s.send, True, [], len),
3486                ('sendto', s.sendto, False, ["some.address"], len),
3487                ('sendall', s.sendall, True, [], lambda x: None),
3488            ]
3489            # (name, method, whether to expect success, *args)
3490            recv_methods = [
3491                ('recv', s.recv, True, []),
3492                ('recvfrom', s.recvfrom, False, ["some.address"]),
3493                ('recv_into', _recv_into, True, []),
3494                ('recvfrom_into', _recvfrom_into, False, []),
3495            ]
3496            data_prefix = "PREFIX_"
3497
3498            for (meth_name, send_meth, expect_success, args,
3499                    ret_val_meth) in send_methods:
3500                indata = (data_prefix + meth_name).encode('ascii')
3501                try:
3502                    ret = send_meth(indata, *args)
3503                    msg = "sending with {}".format(meth_name)
3504                    self.assertEqual(ret, ret_val_meth(indata), msg=msg)
3505                    outdata = s.read()
3506                    if outdata != indata.lower():
3507                        self.fail(
3508                            "While sending with <<{name:s}>> bad data "
3509                            "<<{outdata:r}>> ({nout:d}) received; "
3510                            "expected <<{indata:r}>> ({nin:d})\n".format(
3511                                name=meth_name, outdata=outdata[:20],
3512                                nout=len(outdata),
3513                                indata=indata[:20], nin=len(indata)
3514                            )
3515                        )
3516                except ValueError as e:
3517                    if expect_success:
3518                        self.fail(
3519                            "Failed to send with method <<{name:s}>>; "
3520                            "expected to succeed.\n".format(name=meth_name)
3521                        )
3522                    if not str(e).startswith(meth_name):
3523                        self.fail(
3524                            "Method <<{name:s}>> failed with unexpected "
3525                            "exception message: {exp:s}\n".format(
3526                                name=meth_name, exp=e
3527                            )
3528                        )
3529
3530            for meth_name, recv_meth, expect_success, args in recv_methods:
3531                indata = (data_prefix + meth_name).encode('ascii')
3532                try:
3533                    s.send(indata)
3534                    outdata = recv_meth(*args)
3535                    if outdata != indata.lower():
3536                        self.fail(
3537                            "While receiving with <<{name:s}>> bad data "
3538                            "<<{outdata:r}>> ({nout:d}) received; "
3539                            "expected <<{indata:r}>> ({nin:d})\n".format(
3540                                name=meth_name, outdata=outdata[:20],
3541                                nout=len(outdata),
3542                                indata=indata[:20], nin=len(indata)
3543                            )
3544                        )
3545                except ValueError as e:
3546                    if expect_success:
3547                        self.fail(
3548                            "Failed to receive with method <<{name:s}>>; "
3549                            "expected to succeed.\n".format(name=meth_name)
3550                        )
3551                    if not str(e).startswith(meth_name):
3552                        self.fail(
3553                            "Method <<{name:s}>> failed with unexpected "
3554                            "exception message: {exp:s}\n".format(
3555                                name=meth_name, exp=e
3556                            )
3557                        )
3558                    # consume data
3559                    s.read()
3560
3561            # read(-1, buffer) is supported, even though read(-1) is not
3562            data = b"data"
3563            s.send(data)
3564            buffer = bytearray(len(data))
3565            self.assertEqual(s.read(-1, buffer), len(data))
3566            self.assertEqual(buffer, data)
3567
3568            # sendall accepts bytes-like objects
3569            if ctypes is not None:
3570                ubyte = ctypes.c_ubyte * len(data)
3571                byteslike = ubyte.from_buffer_copy(data)
3572                s.sendall(byteslike)
3573                self.assertEqual(s.read(), data)
3574
3575            # Make sure sendmsg et al are disallowed to avoid
3576            # inadvertent disclosure of data and/or corruption
3577            # of the encrypted data stream
3578            self.assertRaises(NotImplementedError, s.dup)
3579            self.assertRaises(NotImplementedError, s.sendmsg, [b"data"])
3580            self.assertRaises(NotImplementedError, s.recvmsg, 100)
3581            self.assertRaises(NotImplementedError,
3582                              s.recvmsg_into, [bytearray(100)])
3583            s.write(b"over\n")
3584
3585            self.assertRaises(ValueError, s.recv, -1)
3586            self.assertRaises(ValueError, s.read, -1)
3587
3588            s.close()
3589
3590    def test_recv_zero(self):
3591        server = ThreadedEchoServer(CERTFILE)
3592        server.__enter__()
3593        self.addCleanup(server.__exit__, None, None)
3594        s = socket.create_connection((HOST, server.port))
3595        self.addCleanup(s.close)
3596        s = test_wrap_socket(s, suppress_ragged_eofs=False)
3597        self.addCleanup(s.close)
3598
3599        # recv/read(0) should return no data
3600        s.send(b"data")
3601        self.assertEqual(s.recv(0), b"")
3602        self.assertEqual(s.read(0), b"")
3603        self.assertEqual(s.read(), b"data")
3604
3605        # Should not block if the other end sends no data
3606        s.setblocking(False)
3607        self.assertEqual(s.recv(0), b"")
3608        self.assertEqual(s.recv_into(bytearray()), 0)
3609
3610    def test_nonblocking_send(self):
3611        server = ThreadedEchoServer(CERTFILE,
3612                                    certreqs=ssl.CERT_NONE,
3613                                    ssl_version=ssl.PROTOCOL_TLS_SERVER,
3614                                    cacerts=CERTFILE,
3615                                    chatty=True,
3616                                    connectionchatty=False)
3617        with server:
3618            s = test_wrap_socket(socket.socket(),
3619                                server_side=False,
3620                                certfile=CERTFILE,
3621                                ca_certs=CERTFILE,
3622                                cert_reqs=ssl.CERT_NONE,
3623                                ssl_version=ssl.PROTOCOL_TLS_CLIENT)
3624            s.connect((HOST, server.port))
3625            s.setblocking(False)
3626
3627            # If we keep sending data, at some point the buffers
3628            # will be full and the call will block
3629            buf = bytearray(8192)
3630            def fill_buffer():
3631                while True:
3632                    s.send(buf)
3633            self.assertRaises((ssl.SSLWantWriteError,
3634                               ssl.SSLWantReadError), fill_buffer)
3635
3636            # Now read all the output and discard it
3637            s.setblocking(True)
3638            s.close()
3639
3640    def test_handshake_timeout(self):
3641        # Issue #5103: SSL handshake must respect the socket timeout
3642        server = socket.socket(socket.AF_INET)
3643        host = "127.0.0.1"
3644        port = socket_helper.bind_port(server)
3645        started = threading.Event()
3646        finish = False
3647
3648        def serve():
3649            server.listen()
3650            started.set()
3651            conns = []
3652            while not finish:
3653                r, w, e = select.select([server], [], [], 0.1)
3654                if server in r:
3655                    # Let the socket hang around rather than having
3656                    # it closed by garbage collection.
3657                    conns.append(server.accept()[0])
3658            for sock in conns:
3659                sock.close()
3660
3661        t = threading.Thread(target=serve)
3662        t.start()
3663        started.wait()
3664
3665        try:
3666            try:
3667                c = socket.socket(socket.AF_INET)
3668                c.settimeout(0.2)
3669                c.connect((host, port))
3670                # Will attempt handshake and time out
3671                self.assertRaisesRegex(socket.timeout, "timed out",
3672                                       test_wrap_socket, c)
3673            finally:
3674                c.close()
3675            try:
3676                c = socket.socket(socket.AF_INET)
3677                c = test_wrap_socket(c)
3678                c.settimeout(0.2)
3679                # Will attempt handshake and time out
3680                self.assertRaisesRegex(socket.timeout, "timed out",
3681                                       c.connect, (host, port))
3682            finally:
3683                c.close()
3684        finally:
3685            finish = True
3686            t.join()
3687            server.close()
3688
    def test_server_accept(self):
        # Issue #16357: accept() on a SSLSocket created through
        # SSLContext.wrap_socket().
        #
        # A server-side SSLSocket accepts one connection in a helper thread,
        # which sends back what it reads (up to 4 bytes).  The test then
        # checks that accept() returned an SSLSocket and that the reported
        # peer address equals the client's local address.
        context = ssl.SSLContext(ssl.PROTOCOL_TLS)
        context.verify_mode = ssl.CERT_REQUIRED
        context.load_verify_locations(SIGNING_CA)
        context.load_cert_chain(SIGNED_CERTFILE)
        server = socket.socket(socket.AF_INET)
        host = "127.0.0.1"
        port = socket_helper.bind_port(server)
        server = context.wrap_socket(server, server_side=True)
        self.assertTrue(server.server_side)

        evt = threading.Event()
        # Set by serve() in the helper thread; read here only after t.join().
        remote = None
        peer = None
        def serve():
            nonlocal remote, peer
            server.listen()
            # Block on the accept and wait on the connection to close.
            evt.set()
            remote, peer = server.accept()
            # Echo the client's payload back so client.recv() returns.
            remote.send(remote.recv(4))

        t = threading.Thread(target=serve)
        t.start()
        # Client wait until server setup and perform a connect.
        evt.wait()
        # The client reuses the same context (CERT_REQUIRED + SIGNING_CA).
        client = context.wrap_socket(socket.socket())
        client.connect((host, port))
        client.send(b'data')
        client.recv()
        client_addr = client.getsockname()
        client.close()
        # NOTE(review): if connect()/handshake fails above, serve() may stay
        # blocked in accept() and this join could hang; and if serve() raised,
        # `remote` is still None below — confirm this is acceptable.
        t.join()
        remote.close()
        server.close()
        # Sanity checks.
        self.assertIsInstance(remote, ssl.SSLSocket)
        self.assertEqual(peer, client_addr)
3729
3730    def test_getpeercert_enotconn(self):
3731        context = ssl.SSLContext(ssl.PROTOCOL_TLS)
3732        with context.wrap_socket(socket.socket()) as sock:
3733            with self.assertRaises(OSError) as cm:
3734                sock.getpeercert()
3735            self.assertEqual(cm.exception.errno, errno.ENOTCONN)
3736
3737    def test_do_handshake_enotconn(self):
3738        context = ssl.SSLContext(ssl.PROTOCOL_TLS)
3739        with context.wrap_socket(socket.socket()) as sock:
3740            with self.assertRaises(OSError) as cm:
3741                sock.do_handshake()
3742            self.assertEqual(cm.exception.errno, errno.ENOTCONN)
3743
3744    def test_no_shared_ciphers(self):
3745        client_context, server_context, hostname = testing_context()
3746        # OpenSSL enables all TLS 1.3 ciphers, enforce TLS 1.2 for test
3747        client_context.options |= ssl.OP_NO_TLSv1_3
3748        # Force different suites on client and server
3749        client_context.set_ciphers("AES128")
3750        server_context.set_ciphers("AES256")
3751        with ThreadedEchoServer(context=server_context) as server:
3752            with client_context.wrap_socket(socket.socket(),
3753                                            server_hostname=hostname) as s:
3754                with self.assertRaises(OSError):
3755                    s.connect((HOST, server.port))
3756        self.assertIn("no shared cipher", server.conn_errors[0])
3757
3758    def test_version_basic(self):
3759        """
3760        Basic tests for SSLSocket.version().
3761        More tests are done in the test_protocol_*() methods.
3762        """
3763        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3764        context.check_hostname = False
3765        context.verify_mode = ssl.CERT_NONE
3766        with ThreadedEchoServer(CERTFILE,
3767                                ssl_version=ssl.PROTOCOL_TLS_SERVER,
3768                                chatty=False) as server:
3769            with context.wrap_socket(socket.socket()) as s:
3770                self.assertIs(s.version(), None)
3771                self.assertIs(s._sslobj, None)
3772                s.connect((HOST, server.port))
3773                if IS_OPENSSL_1_1_1 and has_tls_version('TLSv1_3'):
3774                    self.assertEqual(s.version(), 'TLSv1.3')
3775                elif ssl.OPENSSL_VERSION_INFO >= (1, 0, 2):
3776                    self.assertEqual(s.version(), 'TLSv1.2')
3777                else:  # 0.9.8 to 1.0.1
3778                    self.assertIn(s.version(), ('TLSv1', 'TLSv1.2'))
3779            self.assertIs(s._sslobj, None)
3780            self.assertIs(s.version(), None)
3781
3782    @requires_tls_version('TLSv1_3')
3783    def test_tls1_3(self):
3784        context = ssl.SSLContext(ssl.PROTOCOL_TLS)
3785        context.load_cert_chain(CERTFILE)
3786        context.options |= (
3787            ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 | ssl.OP_NO_TLSv1_2
3788        )
3789        with ThreadedEchoServer(context=context) as server:
3790            with context.wrap_socket(socket.socket()) as s:
3791                s.connect((HOST, server.port))
3792                self.assertIn(s.cipher()[0], {
3793                    'TLS_AES_256_GCM_SHA384',
3794                    'TLS_CHACHA20_POLY1305_SHA256',
3795                    'TLS_AES_128_GCM_SHA256',
3796                })
3797                self.assertEqual(s.version(), 'TLSv1.3')
3798
3799    @requires_minimum_version
3800    @requires_tls_version('TLSv1_2')
3801    def test_min_max_version_tlsv1_2(self):
3802        client_context, server_context, hostname = testing_context()
3803        # client TLSv1.0 to 1.2
3804        client_context.minimum_version = ssl.TLSVersion.TLSv1
3805        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
3806        # server only TLSv1.2
3807        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
3808        server_context.maximum_version = ssl.TLSVersion.TLSv1_2
3809
3810        with ThreadedEchoServer(context=server_context) as server:
3811            with client_context.wrap_socket(socket.socket(),
3812                                            server_hostname=hostname) as s:
3813                s.connect((HOST, server.port))
3814                self.assertEqual(s.version(), 'TLSv1.2')
3815
3816    @requires_minimum_version
3817    @requires_tls_version('TLSv1_1')
3818    def test_min_max_version_tlsv1_1(self):
3819        client_context, server_context, hostname = testing_context()
3820        # client 1.0 to 1.2, server 1.0 to 1.1
3821        client_context.minimum_version = ssl.TLSVersion.TLSv1
3822        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
3823        server_context.minimum_version = ssl.TLSVersion.TLSv1
3824        server_context.maximum_version = ssl.TLSVersion.TLSv1_1
3825
3826        with ThreadedEchoServer(context=server_context) as server:
3827            with client_context.wrap_socket(socket.socket(),
3828                                            server_hostname=hostname) as s:
3829                s.connect((HOST, server.port))
3830                self.assertEqual(s.version(), 'TLSv1.1')
3831
3832    @requires_minimum_version
3833    @requires_tls_version('TLSv1_2')
3834    @requires_tls_version('TLSv1')
3835    def test_min_max_version_mismatch(self):
3836        client_context, server_context, hostname = testing_context()
3837        # client 1.0, server 1.2 (mismatch)
3838        server_context.maximum_version = ssl.TLSVersion.TLSv1_2
3839        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
3840        client_context.maximum_version = ssl.TLSVersion.TLSv1
3841        client_context.minimum_version = ssl.TLSVersion.TLSv1
3842        with ThreadedEchoServer(context=server_context) as server:
3843            with client_context.wrap_socket(socket.socket(),
3844                                            server_hostname=hostname) as s:
3845                with self.assertRaises(ssl.SSLError) as e:
3846                    s.connect((HOST, server.port))
3847                self.assertIn("alert", str(e.exception))
3848
3849    @requires_minimum_version
3850    @requires_tls_version('SSLv3')
3851    def test_min_max_version_sslv3(self):
3852        client_context, server_context, hostname = testing_context()
3853        server_context.minimum_version = ssl.TLSVersion.SSLv3
3854        client_context.minimum_version = ssl.TLSVersion.SSLv3
3855        client_context.maximum_version = ssl.TLSVersion.SSLv3
3856        with ThreadedEchoServer(context=server_context) as server:
3857            with client_context.wrap_socket(socket.socket(),
3858                                            server_hostname=hostname) as s:
3859                s.connect((HOST, server.port))
3860                self.assertEqual(s.version(), 'SSLv3')
3861
3862    @unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL")
3863    def test_default_ecdh_curve(self):
3864        # Issue #21015: elliptic curve-based Diffie Hellman key exchange
3865        # should be enabled by default on SSL contexts.
3866        context = ssl.SSLContext(ssl.PROTOCOL_TLS)
3867        context.load_cert_chain(CERTFILE)
3868        # TLSv1.3 defaults to PFS key agreement and no longer has KEA in
3869        # cipher name.
3870        context.options |= ssl.OP_NO_TLSv1_3
3871        # Prior to OpenSSL 1.0.0, ECDH ciphers have to be enabled
3872        # explicitly using the 'ECCdraft' cipher alias.  Otherwise,
3873        # our default cipher list should prefer ECDH-based ciphers
3874        # automatically.
3875        if ssl.OPENSSL_VERSION_INFO < (1, 0, 0):
3876            context.set_ciphers("ECCdraft:ECDH")
3877        with ThreadedEchoServer(context=context) as server:
3878            with context.wrap_socket(socket.socket()) as s:
3879                s.connect((HOST, server.port))
3880                self.assertIn("ECDH", s.cipher()[0])
3881
3882    @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
3883                         "'tls-unique' channel binding not available")
3884    def test_tls_unique_channel_binding(self):
3885        """Test tls-unique channel binding."""
3886        if support.verbose:
3887            sys.stdout.write("\n")
3888
3889        client_context, server_context, hostname = testing_context()
3890
3891        server = ThreadedEchoServer(context=server_context,
3892                                    chatty=True,
3893                                    connectionchatty=False)
3894
3895        with server:
3896            with client_context.wrap_socket(
3897                    socket.socket(),
3898                    server_hostname=hostname) as s:
3899                s.connect((HOST, server.port))
3900                # get the data
3901                cb_data = s.get_channel_binding("tls-unique")
3902                if support.verbose:
3903                    sys.stdout.write(
3904                        " got channel binding data: {0!r}\n".format(cb_data))
3905
3906                # check if it is sane
3907                self.assertIsNotNone(cb_data)
3908                if s.version() == 'TLSv1.3':
3909                    self.assertEqual(len(cb_data), 48)
3910                else:
3911                    self.assertEqual(len(cb_data), 12)  # True for TLSv1
3912
3913                # and compare with the peers version
3914                s.write(b"CB tls-unique\n")
3915                peer_data_repr = s.read().strip()
3916                self.assertEqual(peer_data_repr,
3917                                 repr(cb_data).encode("us-ascii"))
3918
3919            # now, again
3920            with client_context.wrap_socket(
3921                    socket.socket(),
3922                    server_hostname=hostname) as s:
3923                s.connect((HOST, server.port))
3924                new_cb_data = s.get_channel_binding("tls-unique")
3925                if support.verbose:
3926                    sys.stdout.write(
3927                        "got another channel binding data: {0!r}\n".format(
3928                            new_cb_data)
3929                    )
3930                # is it really unique
3931                self.assertNotEqual(cb_data, new_cb_data)
3932                self.assertIsNotNone(cb_data)
3933                if s.version() == 'TLSv1.3':
3934                    self.assertEqual(len(cb_data), 48)
3935                else:
3936                    self.assertEqual(len(cb_data), 12)  # True for TLSv1
3937                s.write(b"CB tls-unique\n")
3938                peer_data_repr = s.read().strip()
3939                self.assertEqual(peer_data_repr,
3940                                 repr(new_cb_data).encode("us-ascii"))
3941
3942    def test_compression(self):
3943        client_context, server_context, hostname = testing_context()
3944        stats = server_params_test(client_context, server_context,
3945                                   chatty=True, connectionchatty=True,
3946                                   sni_name=hostname)
3947        if support.verbose:
3948            sys.stdout.write(" got compression: {!r}\n".format(stats['compression']))
3949        self.assertIn(stats['compression'], { None, 'ZLIB', 'RLE' })
3950
3951    @unittest.skipUnless(hasattr(ssl, 'OP_NO_COMPRESSION'),
3952                         "ssl.OP_NO_COMPRESSION needed for this test")
3953    def test_compression_disabled(self):
3954        client_context, server_context, hostname = testing_context()
3955        client_context.options |= ssl.OP_NO_COMPRESSION
3956        server_context.options |= ssl.OP_NO_COMPRESSION
3957        stats = server_params_test(client_context, server_context,
3958                                   chatty=True, connectionchatty=True,
3959                                   sni_name=hostname)
3960        self.assertIs(stats['compression'], None)
3961
3962    @unittest.skipIf(Py_DEBUG_WIN32, "Avoid mixing debug/release CRT on Windows")
3963    def test_dh_params(self):
3964        # Check we can get a connection with ephemeral Diffie-Hellman
3965        client_context, server_context, hostname = testing_context()
3966        # test scenario needs TLS <= 1.2
3967        client_context.options |= ssl.OP_NO_TLSv1_3
3968        server_context.load_dh_params(DHFILE)
3969        server_context.set_ciphers("kEDH")
3970        server_context.options |= ssl.OP_NO_TLSv1_3
3971        stats = server_params_test(client_context, server_context,
3972                                   chatty=True, connectionchatty=True,
3973                                   sni_name=hostname)
3974        cipher = stats["cipher"][0]
3975        parts = cipher.split("-")
3976        if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts:
3977            self.fail("Non-DH cipher: " + cipher[0])
3978
3979    @unittest.skipUnless(HAVE_SECP_CURVES, "needs secp384r1 curve support")
3980    @unittest.skipIf(IS_OPENSSL_1_1_1, "TODO: Test doesn't work on 1.1.1")
3981    def test_ecdh_curve(self):
3982        # server secp384r1, client auto
3983        client_context, server_context, hostname = testing_context()
3984
3985        server_context.set_ecdh_curve("secp384r1")
3986        server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
3987        server_context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
3988        stats = server_params_test(client_context, server_context,
3989                                   chatty=True, connectionchatty=True,
3990                                   sni_name=hostname)
3991
3992        # server auto, client secp384r1
3993        client_context, server_context, hostname = testing_context()
3994        client_context.set_ecdh_curve("secp384r1")
3995        server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
3996        server_context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
3997        stats = server_params_test(client_context, server_context,
3998                                   chatty=True, connectionchatty=True,
3999                                   sni_name=hostname)
4000
4001        # server / client curve mismatch
4002        client_context, server_context, hostname = testing_context()
4003        client_context.set_ecdh_curve("prime256v1")
4004        server_context.set_ecdh_curve("secp384r1")
4005        server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
4006        server_context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
4007        try:
4008            stats = server_params_test(client_context, server_context,
4009                                       chatty=True, connectionchatty=True,
4010                                       sni_name=hostname)
4011        except ssl.SSLError:
4012            pass
4013        else:
4014            # OpenSSL 1.0.2 does not fail although it should.
4015            if IS_OPENSSL_1_1_0:
4016                self.fail("mismatch curve did not fail")
4017
4018    def test_selected_alpn_protocol(self):
4019        # selected_alpn_protocol() is None unless ALPN is used.
4020        client_context, server_context, hostname = testing_context()
4021        stats = server_params_test(client_context, server_context,
4022                                   chatty=True, connectionchatty=True,
4023                                   sni_name=hostname)
4024        self.assertIs(stats['client_alpn_protocol'], None)
4025
4026    @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required")
4027    def test_selected_alpn_protocol_if_server_uses_alpn(self):
4028        # selected_alpn_protocol() is None unless ALPN is used by the client.
4029        client_context, server_context, hostname = testing_context()
4030        server_context.set_alpn_protocols(['foo', 'bar'])
4031        stats = server_params_test(client_context, server_context,
4032                                   chatty=True, connectionchatty=True,
4033                                   sni_name=hostname)
4034        self.assertIs(stats['client_alpn_protocol'], None)
4035
4036    @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support needed for this test")
4037    def test_alpn_protocols(self):
4038        server_protocols = ['foo', 'bar', 'milkshake']
4039        protocol_tests = [
4040            (['foo', 'bar'], 'foo'),
4041            (['bar', 'foo'], 'foo'),
4042            (['milkshake'], 'milkshake'),
4043            (['http/3.0', 'http/4.0'], None)
4044        ]
4045        for client_protocols, expected in protocol_tests:
4046            client_context, server_context, hostname = testing_context()
4047            server_context.set_alpn_protocols(server_protocols)
4048            client_context.set_alpn_protocols(client_protocols)
4049
4050            try:
4051                stats = server_params_test(client_context,
4052                                           server_context,
4053                                           chatty=True,
4054                                           connectionchatty=True,
4055                                           sni_name=hostname)
4056            except ssl.SSLError as e:
4057                stats = e
4058
4059            if (expected is None and IS_OPENSSL_1_1_0
4060                    and ssl.OPENSSL_VERSION_INFO < (1, 1, 0, 6)):
4061                # OpenSSL 1.1.0 to 1.1.0e raises handshake error
4062                self.assertIsInstance(stats, ssl.SSLError)
4063            else:
4064                msg = "failed trying %s (s) and %s (c).\n" \
4065                    "was expecting %s, but got %%s from the %%s" \
4066                        % (str(server_protocols), str(client_protocols),
4067                            str(expected))
4068                client_result = stats['client_alpn_protocol']
4069                self.assertEqual(client_result, expected,
4070                                 msg % (client_result, "client"))
4071                server_result = stats['server_alpn_protocols'][-1] \
4072                    if len(stats['server_alpn_protocols']) else 'nothing'
4073                self.assertEqual(server_result, expected,
4074                                 msg % (server_result, "server"))
4075
4076    def test_selected_npn_protocol(self):
4077        # selected_npn_protocol() is None unless NPN is used
4078        client_context, server_context, hostname = testing_context()
4079        stats = server_params_test(client_context, server_context,
4080                                   chatty=True, connectionchatty=True,
4081                                   sni_name=hostname)
4082        self.assertIs(stats['client_npn_protocol'], None)
4083
4084    @unittest.skipUnless(ssl.HAS_NPN, "NPN support needed for this test")
4085    def test_npn_protocols(self):
4086        server_protocols = ['http/1.1', 'spdy/2']
4087        protocol_tests = [
4088            (['http/1.1', 'spdy/2'], 'http/1.1'),
4089            (['spdy/2', 'http/1.1'], 'http/1.1'),
4090            (['spdy/2', 'test'], 'spdy/2'),
4091            (['abc', 'def'], 'abc')
4092        ]
4093        for client_protocols, expected in protocol_tests:
4094            client_context, server_context, hostname = testing_context()
4095            server_context.set_npn_protocols(server_protocols)
4096            client_context.set_npn_protocols(client_protocols)
4097            stats = server_params_test(client_context, server_context,
4098                                       chatty=True, connectionchatty=True,
4099                                       sni_name=hostname)
4100            msg = "failed trying %s (s) and %s (c).\n" \
4101                  "was expecting %s, but got %%s from the %%s" \
4102                      % (str(server_protocols), str(client_protocols),
4103                         str(expected))
4104            client_result = stats['client_npn_protocol']
4105            self.assertEqual(client_result, expected, msg % (client_result, "client"))
4106            server_result = stats['server_npn_protocols'][-1] \
4107                if len(stats['server_npn_protocols']) else 'nothing'
4108            self.assertEqual(server_result, expected, msg % (server_result, "server"))
4109
4110    def sni_contexts(self):
4111        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
4112        server_context.load_cert_chain(SIGNED_CERTFILE)
4113        other_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
4114        other_context.load_cert_chain(SIGNED_CERTFILE2)
4115        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
4116        client_context.load_verify_locations(SIGNING_CA)
4117        return server_context, other_context, client_context
4118
4119    def check_common_name(self, stats, name):
4120        cert = stats['peercert']
4121        self.assertIn((('commonName', name),), cert['subject'])
4122
4123    @needs_sni
4124    def test_sni_callback(self):
4125        calls = []
4126        server_context, other_context, client_context = self.sni_contexts()
4127
4128        client_context.check_hostname = False
4129
4130        def servername_cb(ssl_sock, server_name, initial_context):
4131            calls.append((server_name, initial_context))
4132            if server_name is not None:
4133                ssl_sock.context = other_context
4134        server_context.set_servername_callback(servername_cb)
4135
4136        stats = server_params_test(client_context, server_context,
4137                                   chatty=True,
4138                                   sni_name='supermessage')
4139        # The hostname was fetched properly, and the certificate was
4140        # changed for the connection.
4141        self.assertEqual(calls, [("supermessage", server_context)])
4142        # CERTFILE4 was selected
4143        self.check_common_name(stats, 'fakehostname')
4144
4145        calls = []
4146        # The callback is called with server_name=None
4147        stats = server_params_test(client_context, server_context,
4148                                   chatty=True,
4149                                   sni_name=None)
4150        self.assertEqual(calls, [(None, server_context)])
4151        self.check_common_name(stats, SIGNED_CERTFILE_HOSTNAME)
4152
4153        # Check disabling the callback
4154        calls = []
4155        server_context.set_servername_callback(None)
4156
4157        stats = server_params_test(client_context, server_context,
4158                                   chatty=True,
4159                                   sni_name='notfunny')
4160        # Certificate didn't change
4161        self.check_common_name(stats, SIGNED_CERTFILE_HOSTNAME)
4162        self.assertEqual(calls, [])
4163
4164    @needs_sni
4165    def test_sni_callback_alert(self):
4166        # Returning a TLS alert is reflected to the connecting client
4167        server_context, other_context, client_context = self.sni_contexts()
4168
4169        def cb_returning_alert(ssl_sock, server_name, initial_context):
4170            return ssl.ALERT_DESCRIPTION_ACCESS_DENIED
4171        server_context.set_servername_callback(cb_returning_alert)
4172        with self.assertRaises(ssl.SSLError) as cm:
4173            stats = server_params_test(client_context, server_context,
4174                                       chatty=False,
4175                                       sni_name='supermessage')
4176        self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_ACCESS_DENIED')
4177
4178    @needs_sni
4179    def test_sni_callback_raising(self):
4180        # Raising fails the connection with a TLS handshake failure alert.
4181        server_context, other_context, client_context = self.sni_contexts()
4182
4183        def cb_raising(ssl_sock, server_name, initial_context):
4184            1/0
4185        server_context.set_servername_callback(cb_raising)
4186
4187        with support.catch_unraisable_exception() as catch:
4188            with self.assertRaises(ssl.SSLError) as cm:
4189                stats = server_params_test(client_context, server_context,
4190                                           chatty=False,
4191                                           sni_name='supermessage')
4192
4193            self.assertEqual(cm.exception.reason,
4194                             'SSLV3_ALERT_HANDSHAKE_FAILURE')
4195            self.assertEqual(catch.unraisable.exc_type, ZeroDivisionError)
4196
4197    @needs_sni
4198    def test_sni_callback_wrong_return_type(self):
4199        # Returning the wrong return type terminates the TLS connection
4200        # with an internal error alert.
4201        server_context, other_context, client_context = self.sni_contexts()
4202
4203        def cb_wrong_return_type(ssl_sock, server_name, initial_context):
4204            return "foo"
4205        server_context.set_servername_callback(cb_wrong_return_type)
4206
4207        with support.catch_unraisable_exception() as catch:
4208            with self.assertRaises(ssl.SSLError) as cm:
4209                stats = server_params_test(client_context, server_context,
4210                                           chatty=False,
4211                                           sni_name='supermessage')
4212
4213
4214            self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_INTERNAL_ERROR')
4215            self.assertEqual(catch.unraisable.exc_type, TypeError)
4216
4217    def test_shared_ciphers(self):
4218        client_context, server_context, hostname = testing_context()
4219        client_context.set_ciphers("AES128:AES256")
4220        server_context.set_ciphers("AES256")
4221        expected_algs = [
4222            "AES256", "AES-256",
4223            # TLS 1.3 ciphers are always enabled
4224            "TLS_CHACHA20", "TLS_AES",
4225        ]
4226
4227        stats = server_params_test(client_context, server_context,
4228                                   sni_name=hostname)
4229        ciphers = stats['server_shared_ciphers'][0]
4230        self.assertGreater(len(ciphers), 0)
4231        for name, tls_version, bits in ciphers:
4232            if not any(alg in name for alg in expected_algs):
4233                self.fail(name)
4234
4235    def test_read_write_after_close_raises_valuerror(self):
4236        client_context, server_context, hostname = testing_context()
4237        server = ThreadedEchoServer(context=server_context, chatty=False)
4238
4239        with server:
4240            s = client_context.wrap_socket(socket.socket(),
4241                                           server_hostname=hostname)
4242            s.connect((HOST, server.port))
4243            s.close()
4244
4245            self.assertRaises(ValueError, s.read, 1024)
4246            self.assertRaises(ValueError, s.write, b'hello')
4247
4248    def test_sendfile(self):
4249        TEST_DATA = b"x" * 512
4250        with open(support.TESTFN, 'wb') as f:
4251            f.write(TEST_DATA)
4252        self.addCleanup(support.unlink, support.TESTFN)
4253        context = ssl.SSLContext(ssl.PROTOCOL_TLS)
4254        context.verify_mode = ssl.CERT_REQUIRED
4255        context.load_verify_locations(SIGNING_CA)
4256        context.load_cert_chain(SIGNED_CERTFILE)
4257        server = ThreadedEchoServer(context=context, chatty=False)
4258        with server:
4259            with context.wrap_socket(socket.socket()) as s:
4260                s.connect((HOST, server.port))
4261                with open(support.TESTFN, 'rb') as file:
4262                    s.sendfile(file)
4263                    self.assertEqual(s.recv(1024), TEST_DATA)
4264
    def test_session(self):
        """Client-side session resumption is reflected in SSLSession and stats.

        Four sequential connections share one server context. The server's
        session_stats() counters ('accept', 'hits') are cumulative, so the
        expected values depend on the exact order of the connections below.
        """
        client_context, server_context, hostname = testing_context()
        # TODO: sessions aren't compatible with TLSv1.3 yet
        client_context.options |= ssl.OP_NO_TLSv1_3

        # first connection without session
        stats = server_params_test(client_context, server_context,
                                   sni_name=hostname)
        session = stats['session']
        self.assertTrue(session.id)
        self.assertGreater(session.time, 0)
        self.assertGreater(session.timeout, 0)
        self.assertTrue(session.has_ticket)
        # NOTE(review): if OPENSSL_VERSION_INFO is the usual 5-tuple, this
        # tuple comparison is also true for 1.0.1 releases themselves.
        if ssl.OPENSSL_VERSION_INFO > (1, 0, 1):
            self.assertGreater(session.ticket_lifetime_hint, 0)
        self.assertFalse(stats['session_reused'])
        sess_stat = server_context.session_stats()
        self.assertEqual(sess_stat['accept'], 1)
        self.assertEqual(sess_stat['hits'], 0)

        # reuse session
        stats = server_params_test(client_context, server_context,
                                   session=session, sni_name=hostname)
        sess_stat = server_context.session_stats()
        self.assertEqual(sess_stat['accept'], 2)
        self.assertEqual(sess_stat['hits'], 1)
        self.assertTrue(stats['session_reused'])
        session2 = stats['session']
        self.assertEqual(session2.id, session.id)
        # Equal by value but a distinct Python object.
        self.assertEqual(session2, session)
        self.assertIsNot(session2, session)
        self.assertGreaterEqual(session2.time, session.time)
        self.assertGreaterEqual(session2.timeout, session.timeout)

        # another one without session
        stats = server_params_test(client_context, server_context,
                                   sni_name=hostname)
        self.assertFalse(stats['session_reused'])
        session3 = stats['session']
        self.assertNotEqual(session3.id, session.id)
        self.assertNotEqual(session3, session)
        sess_stat = server_context.session_stats()
        self.assertEqual(sess_stat['accept'], 3)
        self.assertEqual(sess_stat['hits'], 1)

        # reuse session again
        stats = server_params_test(client_context, server_context,
                                   session=session, sni_name=hostname)
        self.assertTrue(stats['session_reused'])
        session4 = stats['session']
        self.assertEqual(session4.id, session.id)
        self.assertEqual(session4, session)
        self.assertGreaterEqual(session4.time, session.time)
        self.assertGreaterEqual(session4.timeout, session.timeout)
        sess_stat = server_context.session_stats()
        self.assertEqual(sess_stat['accept'], 4)
        self.assertEqual(sess_stat['hits'], 2)
4322
4323    def test_session_handling(self):
4324        client_context, server_context, hostname = testing_context()
4325        client_context2, _, _ = testing_context()
4326
4327        # TODO: session reuse does not work with TLSv1.3
4328        client_context.options |= ssl.OP_NO_TLSv1_3
4329        client_context2.options |= ssl.OP_NO_TLSv1_3
4330
4331        server = ThreadedEchoServer(context=server_context, chatty=False)
4332        with server:
4333            with client_context.wrap_socket(socket.socket(),
4334                                            server_hostname=hostname) as s:
4335                # session is None before handshake
4336                self.assertEqual(s.session, None)
4337                self.assertEqual(s.session_reused, None)
4338                s.connect((HOST, server.port))
4339                session = s.session
4340                self.assertTrue(session)
4341                with self.assertRaises(TypeError) as e:
4342                    s.session = object
4343                self.assertEqual(str(e.exception), 'Value is not a SSLSession.')
4344
4345            with client_context.wrap_socket(socket.socket(),
4346                                            server_hostname=hostname) as s:
4347                s.connect((HOST, server.port))
4348                # cannot set session after handshake
4349                with self.assertRaises(ValueError) as e:
4350                    s.session = session
4351                self.assertEqual(str(e.exception),
4352                                 'Cannot set session after handshake.')
4353
4354            with client_context.wrap_socket(socket.socket(),
4355                                            server_hostname=hostname) as s:
4356                # can set session before handshake and before the
4357                # connection was established
4358                s.session = session
4359                s.connect((HOST, server.port))
4360                self.assertEqual(s.session.id, session.id)
4361                self.assertEqual(s.session, session)
4362                self.assertEqual(s.session_reused, True)
4363
4364            with client_context2.wrap_socket(socket.socket(),
4365                                             server_hostname=hostname) as s:
4366                # cannot re-use session with a different SSLContext
4367                with self.assertRaises(ValueError) as e:
4368                    s.session = session
4369                    s.connect((HOST, server.port))
4370                self.assertEqual(str(e.exception),
4371                                 'Session refers to a different SSLContext.')
4372
4373
@unittest.skipUnless(has_tls_version('TLSv1_3'), "Test needs TLS 1.3")
class TestPostHandshakeAuth(unittest.TestCase):
    """TLS 1.3 post-handshake client authentication (PHA) tests.

    Most tests connect to ThreadedEchoServer, send it the commands HASCERT,
    PHA and GETCERT, and check its replies.
    """

    def test_pha_setter(self):
        """post_handshake_auth defaults to False and is independent of verify_mode."""
        protocols = [
            ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS_SERVER, ssl.PROTOCOL_TLS_CLIENT
        ]
        for protocol in protocols:
            ctx = ssl.SSLContext(protocol)
            self.assertEqual(ctx.post_handshake_auth, False)

            ctx.post_handshake_auth = True
            self.assertEqual(ctx.post_handshake_auth, True)

            ctx.verify_mode = ssl.CERT_REQUIRED
            self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
            self.assertEqual(ctx.post_handshake_auth, True)

            ctx.post_handshake_auth = False
            self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
            self.assertEqual(ctx.post_handshake_auth, False)

            ctx.verify_mode = ssl.CERT_OPTIONAL
            ctx.post_handshake_auth = True
            self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
            self.assertEqual(ctx.post_handshake_auth, True)

    def test_pha_required(self):
        """With CERT_REQUIRED and a client cert, PHA gives the server the cert."""
        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')
                # PHA method just returns true when cert is already available
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'GETCERT')
                cert_text = s.recv(4096).decode('us-ascii')
                self.assertIn('Python Software Foundation CA', cert_text)

    def test_pha_required_nocert(self):
        """With CERT_REQUIRED and no client cert, PHA ends in a cert-required alert."""
        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.post_handshake_auth = True

        # Ignore expected SSLError in ConnectionHandler of ThreadedEchoServer
        # (it is only raised sometimes on Windows)
        with support.catch_threading_exception() as cm:
            server = ThreadedEchoServer(context=server_context, chatty=False)
            with server:
                with client_context.wrap_socket(socket.socket(),
                                                server_hostname=hostname) as s:
                    s.connect((HOST, server.port))
                    s.write(b'PHA')
                    # receive CertificateRequest
                    self.assertEqual(s.recv(1024), b'OK\n')
                    # send empty Certificate + Finish
                    s.write(b'HASCERT')
                    # receive alert
                    with self.assertRaisesRegex(
                            ssl.SSLError,
                            'tlsv13 alert certificate required'):
                        s.recv(1024)

    def test_pha_optional(self):
        """With CERT_OPTIONAL and a client cert, PHA still obtains the cert."""
        if support.verbose:
            sys.stdout.write("\n")

        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        # check CERT_OPTIONAL
        server_context.verify_mode = ssl.CERT_OPTIONAL
        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')

    def test_pha_optional_nocert(self):
        """With CERT_OPTIONAL, PHA succeeds even when the client has no cert."""
        if support.verbose:
            sys.stdout.write("\n")

        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_OPTIONAL
        client_context.post_handshake_auth = True

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                # optional doesn't fail when client does not have a cert
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')

    def test_pha_no_pha_client(self):
        """PHA is refused when the client did not enable post_handshake_auth."""
        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                # Only the server side may request PHA.
                with self.assertRaisesRegex(ssl.SSLError, 'not server'):
                    s.verify_client_post_handshake()
                s.write(b'PHA')
                self.assertIn(b'extension not received', s.recv(1024))

    def test_pha_no_pha_server(self):
        """Without server-side PHA the cert is requested during the handshake."""
        # server doesn't have PHA enabled, cert is requested in handshake
        client_context, server_context, hostname = testing_context()
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')
                # PHA doesn't fail if there is already a cert
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')

    def test_pha_not_tls13(self):
        """PHA is rejected when the client caps the connection at TLS 1.2."""
        # TLS 1.2
        client_context, server_context, hostname = testing_context()
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                # PHA fails for TLS != 1.3
                s.write(b'PHA')
                self.assertIn(b'WRONG_SSL_VERSION', s.recv(1024))

    def test_bpo37428_pha_cert_none(self):
        """bpo-37428: post_handshake_auth must not enable server cert checks."""
        # verify that post_handshake_auth does not implicitly enable cert
        # validation.
        hostname = SIGNED_CERTFILE_HOSTNAME
        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)
        # no cert validation and CA on client side
        client_context.check_hostname = False
        client_context.verify_mode = ssl.CERT_NONE

        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        server_context.load_cert_chain(SIGNED_CERTFILE)
        server_context.load_verify_locations(SIGNING_CA)
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')
                # server cert has not been validated
                self.assertEqual(s.getpeercert(), {})
4579
4580
# True when SSLContext exposes keylog_filename. Per the skip message below,
# this needs OpenSSL 1.1.1 with keylog callback support.
HAS_KEYLOG = hasattr(ssl.SSLContext, 'keylog_filename')
_KEYLOG_SKIP_REASON = 'test requires OpenSSL 1.1.1 with keylog callback'
# Decorator that skips a test when keylog support is unavailable.
requires_keylog = unittest.skipUnless(HAS_KEYLOG, _KEYLOG_SKIP_REASON)
4584
class TestSSLDebug(unittest.TestCase):
    """Tests for SSL debugging hooks: key logging and the message callback."""

    def keylog_lines(self, fname=support.TESTFN):
        """Return the number of lines currently in the keylog file *fname*."""
        with open(fname) as f:
            return len(list(f))

    @requires_keylog
    @unittest.skipIf(Py_DEBUG_WIN32, "Avoid mixing debug/release CRT on Windows")
    def test_keylog_defaults(self):
        """keylog_filename defaults to None, creates the file, and validates input."""
        self.addCleanup(support.unlink, support.TESTFN)
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        self.assertEqual(ctx.keylog_filename, None)

        self.assertFalse(os.path.isfile(support.TESTFN))
        ctx.keylog_filename = support.TESTFN
        self.assertEqual(ctx.keylog_filename, support.TESTFN)
        # Setting the attribute creates the file with a single line
        # before any connection is made.
        self.assertTrue(os.path.isfile(support.TESTFN))
        self.assertEqual(self.keylog_lines(), 1)

        ctx.keylog_filename = None
        self.assertEqual(ctx.keylog_filename, None)

        with self.assertRaises((IsADirectoryError, PermissionError)):
            # Windows raises PermissionError
            ctx.keylog_filename = os.path.dirname(
                os.path.abspath(support.TESTFN))

        with self.assertRaises(TypeError):
            ctx.keylog_filename = 1

    @requires_keylog
    @unittest.skipIf(Py_DEBUG_WIN32, "Avoid mixing debug/release CRT on Windows")
    def test_keylog_filename(self):
        """Client-side, server-side and combined key logging append to one file.

        The same file accumulates lines across the three connections below,
        so the later checks are cumulative lower bounds.
        """
        self.addCleanup(support.unlink, support.TESTFN)
        client_context, server_context, hostname = testing_context()

        client_context.keylog_filename = support.TESTFN
        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
        # header, 5 lines for TLS 1.3
        self.assertEqual(self.keylog_lines(), 6)

        client_context.keylog_filename = None
        server_context.keylog_filename = support.TESTFN
        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
        self.assertGreaterEqual(self.keylog_lines(), 11)

        client_context.keylog_filename = support.TESTFN
        server_context.keylog_filename = support.TESTFN
        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
        self.assertGreaterEqual(self.keylog_lines(), 21)

        client_context.keylog_filename = None
        server_context.keylog_filename = None

    @requires_keylog
    @unittest.skipIf(sys.flags.ignore_environment,
                     "test is not compatible with ignore_environment")
    @unittest.skipIf(Py_DEBUG_WIN32, "Avoid mixing debug/release CRT on Windows")
    def test_keylog_env(self):
        """SSLKEYLOGFILE is honored by the default-context factories only.

        A bare SSLContext ignores the variable, while create_default_context()
        and _create_stdlib_context() pick it up.
        """
        self.addCleanup(support.unlink, support.TESTFN)
        # patch.dict restores os.environ when the block exits.
        with unittest.mock.patch.dict(os.environ):
            os.environ['SSLKEYLOGFILE'] = support.TESTFN
            self.assertEqual(os.environ['SSLKEYLOGFILE'], support.TESTFN)

            ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
            self.assertEqual(ctx.keylog_filename, None)

            ctx = ssl.create_default_context()
            self.assertEqual(ctx.keylog_filename, support.TESTFN)

            ctx = ssl._create_stdlib_context()
            self.assertEqual(ctx.keylog_filename, support.TESTFN)

    def test_msg_callback(self):
        """_msg_callback defaults to None and accepts only callables."""
        client_context, server_context, hostname = testing_context()

        def msg_cb(conn, direction, version, content_type, msg_type, data):
            pass

        self.assertIs(client_context._msg_callback, None)
        client_context._msg_callback = msg_cb
        self.assertIs(client_context._msg_callback, msg_cb)
        with self.assertRaises(TypeError):
            client_context._msg_callback = object()

    def test_msg_callback_tls12(self):
        """_msg_callback reports TLS 1.2 handshake messages with typed fields."""
        client_context, server_context, hostname = testing_context()
        # Disabling TLS 1.3 on the client keeps the handshake at TLS 1.2.
        client_context.options |= ssl.OP_NO_TLSv1_3

        msg = []

        def msg_cb(conn, direction, version, content_type, msg_type, data):
            self.assertIsInstance(conn, ssl.SSLSocket)
            self.assertIsInstance(data, bytes)
            self.assertIn(direction, {'read', 'write'})
            msg.append((direction, version, content_type, msg_type))

        client_context._msg_callback = msg_cb

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))

        self.assertIn(
            ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
             _TLSMessageType.SERVER_KEY_EXCHANGE),
            msg
        )
        self.assertIn(
            ("write", TLSVersion.TLSv1_2, _TLSContentType.CHANGE_CIPHER_SPEC,
             _TLSMessageType.CHANGE_CIPHER_SPEC),
            msg
        )
4712
4713
def test_main(verbose=False):
    """Run every SSL test class through support.run_unittest().

    *verbose* is accepted but not consulted; the banner is controlled by
    support.verbose. Raises support.TestFailed if any certificate/key data
    file is missing. NetworkedTests run only when the 'network' resource
    is enabled.
    """
    if support.verbose:
        plat = None
        # Prefer a platform-specific description; fall back to the generic one.
        for label, probe in (('Mac', platform.mac_ver),
                             ('Windows', platform.win32_ver)):
            info = probe()
            if info and info[0]:
                plat = '%s %r' % (label, info)
                break
        if plat is None:
            plat = repr(platform.platform())
        print("test_ssl: testing with %r %r" %
              (ssl.OPENSSL_VERSION, ssl.OPENSSL_VERSION_INFO))
        print("          under %s" % plat)
        print("          HAS_SNI = %r" % ssl.HAS_SNI)
        print("          OP_ALL = 0x%8x" % ssl.OP_ALL)
        op_no_tlsv1_1 = getattr(ssl, 'OP_NO_TLSv1_1', None)
        if op_no_tlsv1_1 is not None:
            print("          OP_NO_TLSv1_1 = 0x%8x" % op_no_tlsv1_1)

    required_files = (
        CERTFILE, BYTES_CERTFILE,
        ONLYCERT, ONLYKEY, BYTES_ONLYCERT, BYTES_ONLYKEY,
        SIGNED_CERTFILE, SIGNED_CERTFILE2, SIGNING_CA,
        BADCERT, BADKEY, EMPTYCERT,
    )
    for path in required_files:
        if not os.path.exists(path):
            raise support.TestFailed("Can't read certificate file %r" % path)

    suites = [
        ContextTests, BasicSocketTests, SSLErrorTests, MemoryBIOTests,
        SSLObjectTests, SimpleBackgroundTests, ThreadedTests,
        TestPostHandshakeAuth, TestSSLDebug,
    ]
    if support.is_resource_enabled('network'):
        suites.append(NetworkedTests)

    saved_threads = support.threading_setup()
    try:
        support.run_unittest(*suites)
    finally:
        support.threading_cleanup(*saved_threads)
4759
# Allow running this test module directly as a script.
if __name__ == "__main__":
    test_main()
4762